Fix NaN errors (skip empty boldness estimates; mask NaN values before linregress/t-test)

This commit is contained in:
Jan Grewe 2021-06-11 13:02:07 +02:00
parent 7008a08c6d
commit bc24b900c5
2 changed files with 32 additions and 22 deletions

View File

@ -133,7 +133,8 @@ def estimate_boldness(visit_times, feeder_risks):
else: else:
y += 1 y += 1
score += y - intersection[x] score += y - intersection[x]
if len(sorted_risks) == 0:
return None, None
boldness = score / len(sorted_risks) boldness = score / len(sorted_risks)
return boldness, len(sorted_risks) return boldness, len(sorted_risks)
@ -155,6 +156,8 @@ def get_boldness_score(df, subject):
risks = get_feeder_risks(df, d) risks = get_feeder_risks(df, d)
boldness, count = estimate_boldness(visit_times, risks) boldness, count = estimate_boldness(visit_times, risks)
b = {"subject": subject, "day": d, "boldness": boldness, "total_visits": count} b = {"subject": subject, "day": d, "boldness": boldness, "total_visits": count}
if boldness is None:
continue
boldness_scores.append(b) boldness_scores.append(b)
return boldness_scores return boldness_scores
@ -205,8 +208,9 @@ if __name__ == "__main__":
scores = pd.DataFrame(all_boldness_scores) scores = pd.DataFrame(all_boldness_scores)
scores.to_csv("boldness_scores.csv", sep=";") scores.to_csv("boldness_scores.csv", sep=";")
for s in subjects: for s in subjects:
days = df.day[df.subject == s].unique()
print(s, np.mean(scores.boldness[scores.subject == s]), np.mean(scores.total_visits[scores.subject == s])) print(s, np.mean(scores.boldness[scores.subject == s]), np.mean(scores.total_visits[scores.subject == s]))
for day in days:
plot_boldness_analysis(df, s, "day_4") plot_boldness_analysis(df, s, day)
plot_boldness_analysis(df, s, "day_5") plot_boldness_analysis(df, s, day)

View File

@ -197,7 +197,7 @@ def plot_risk_per_day(data, subjects):
avg_feeder_risks_dark = avg_feeder_risks_dark[np.argsort(days)] avg_feeder_risks_dark = avg_feeder_risks_dark[np.argsort(days)]
days = days[np.argsort(days)] days = days[np.argsort(days)]
m, n, r, p, _ = st.linregress(days, trs) m, n, r, p, _ = st.linregress(days[~np.isnan(trs)], trs[~np.isnan(trs)])
ax1.plot(days, trs, marker=".", label=subject) ax1.plot(days, trs, marker=".", label=subject)
ax1.plot(days, m * days + n, label="r: %.2f, p:%.3f" % (r, p)) ax1.plot(days, m * days + n, label="r: %.2f, p:%.3f" % (r, p))
ax1.legend(fontsize=7) ax1.legend(fontsize=7)
@ -205,7 +205,7 @@ def plot_risk_per_day(data, subjects):
ax1.set_ylabel("total risk", fontsize=9) ax1.set_ylabel("total risk", fontsize=9)
ax2.plot(days, trs/num_visits, marker=".", label=subject) ax2.plot(days, trs/num_visits, marker=".", label=subject)
m, n, r, p, _ = st.linregress(days, trs/num_visits) m, n, r, p, _ = st.linregress(days[~np.isnan(trs/num_visits)], (trs/num_visits)[~np.isnan(trs/num_visits)])
ax2.plot(days, m * days + n, label="r: %.2f, p:%.3f" % (r, p)) ax2.plot(days, m * days + n, label="r: %.2f, p:%.3f" % (r, p))
ax2.set_ylim([0, 2]) ax2.set_ylim([0, 2])
ax2.set_ylabel("avg. risk per visit", fontsize=9) ax2.set_ylabel("avg. risk per visit", fontsize=9)
@ -213,8 +213,9 @@ def plot_risk_per_day(data, subjects):
ax2.legend(fontsize=7) ax2.legend(fontsize=7)
ax3.plot(days, trs/num_visits/avg_feeder_risks, marker=".", label="") ax3.plot(days, trs/num_visits/avg_feeder_risks, marker=".", label="")
m, n, r, p, _ = st.linregress(days, trs/num_visits/avg_feeder_risks) m, n, r, p, _ = st.linregress(days[~np.isnan(trs/num_visits/avg_feeder_risks)], (trs/num_visits/avg_feeder_risks)[~np.isnan(trs/num_visits/avg_feeder_risks)])
t, p = st.ttest_1samp(trs/num_visits/avg_feeder_risks - avg_feeder_risks_dark/avg_feeder_risks, 0) t, p = st.ttest_1samp((trs/num_visits/avg_feeder_risks - avg_feeder_risks_dark/avg_feeder_risks)[~np.isnan(trs/num_visits/avg_feeder_risks - avg_feeder_risks_dark/avg_feeder_risks)], 0)
print(subject, "ttest: p: %.3f," % p)
ax3.plot(days, m * days + n, label="r: %.2f, p:%.3f" % (r, p)) ax3.plot(days, m * days + n, label="r: %.2f, p:%.3f" % (r, p))
ax3.plot(days, avg_feeder_risks_dark/avg_feeder_risks, color="dodgerblue", ls="dashed", lw=0.5, label="avg. feeder risk dark") ax3.plot(days, avg_feeder_risks_dark/avg_feeder_risks, color="dodgerblue", ls="dashed", lw=0.5, label="avg. feeder risk dark")
ax3.hlines(1.0, min(days), max(days), ls="dashed", color="darkgreen", lw=0.5, zorder=0, label="avg. feeder risk") ax3.hlines(1.0, min(days), max(days), ls="dashed", color="darkgreen", lw=0.5, zorder=0, label="avg. feeder risk")
@ -246,8 +247,8 @@ def compare_risks_per_subject(data, subjects):
num_visits = num_visits[np.argsort(days)] num_visits = num_visits[np.argsort(days)]
days = days[np.argsort(days)] days = days[np.argsort(days)]
total_risks.append(trs) total_risks.append(trs[~np.isnan(trs)])
total_risks_rel.append(trs/num_visits) total_risks_rel.append((trs[num_visits > 0]/num_visits[num_visits > 0]))
fig = plt.figure(figsize=(5, 5)) fig = plt.figure(figsize=(5, 5))
ax1 = fig.add_subplot(2,1,1) ax1 = fig.add_subplot(2,1,1)
@ -260,7 +261,7 @@ def compare_risks_per_subject(data, subjects):
ax2.boxplot(total_risks_rel, showfliers=False) ax2.boxplot(total_risks_rel, showfliers=False)
ax2.set_xticklabels(subjects, fontsize=9, rotation=45) ax2.set_xticklabels(subjects, fontsize=9, rotation=45)
ax2.set_ylabel("risk per visit", fontsize=9) ax2.set_ylabel("risk per visit", fontsize=9)
ax2.set_ylim([0, 1]) ax2.set_ylim([0, 2.0])
fig.subplots_adjust(left=0.1, right=0.975, top=0.975, bottom=0.15, hspace=0.2) fig.subplots_adjust(left=0.1, right=0.975, top=0.975, bottom=0.15, hspace=0.2)
fig.savefig("total_risk_comparison.pdf") fig.savefig("total_risk_comparison.pdf")
@ -286,16 +287,22 @@ def plot_risk_over_time(data, subjects):
day = days[index] day = days[index]
day_times = times[index] day_times = times[index]
day_risks = risks[index] day_risks = risks[index]
m, n, _, _, _ = st.linregress(day_times, np.cumsum(day_risks)) print(subject, len(day_times), len(day_risks))
slopes.append(m) if len(day_risks) > 1:
m, n, _, _, _ = st.linregress(day_times, np.cumsum(day_risks))
slopes.append(m)
else:
slopes.append(np.nan)
l = ax1.plot(day_times, np.cumsum(day_risks), marker=".", lw=0.1, label=day) l = ax1.plot(day_times, np.cumsum(day_risks), marker=".", lw=0.1, label=day)
ax1.plot(day_times, m * day_times + n, lw=1, color=l[0].get_color()) ax1.plot(day_times, m * day_times + n, lw=1, color=l[0].get_color())
slopes = np.array(slopes)
all_slopes.append(slopes[~np.isnan(slopes)])
ax1.set_xlabel("time [s]") ax1.set_xlabel("time [s]")
ax1.set_ylabel("risk taken") ax1.set_ylabel("risk taken")
ax1.text(0.0, 1.05, subject, transform=ax1.transAxes, fontsize=9, fontweight="bold") ax1.text(0.0, 1.05, subject, transform=ax1.transAxes, fontsize=9, fontweight="bold")
ax2.scatter(day_numbers, slopes, marker="o", label="slopes") ax2.scatter(day_numbers, slopes, marker="o", label="slopes")
m, n, r, p, _ = st.linregress(day_numbers, slopes) m, n, r, p, _ = st.linregress(day_numbers[~np.isnan(slopes)], slopes[~np.isnan(slopes)])
ax2.plot(day_numbers, m * day_numbers + n, label="r: %.2f, p: %.3f" % (r, p)) ax2.plot(day_numbers, m * day_numbers + n, label="r: %.2f, p: %.3f" % (r, p))
ax2.set_xlabel("experimental day", fontsize=9) ax2.set_xlabel("experimental day", fontsize=9)
ax2.set_ylabel("slope [risk/s]", fontsize=9) ax2.set_ylabel("slope [risk/s]", fontsize=9)
@ -305,7 +312,6 @@ def plot_risk_over_time(data, subjects):
fig.savefig("risk_over_time_%s.pdf" % subject) fig.savefig("risk_over_time_%s.pdf" % subject)
plt.close() plt.close()
all_slopes.append(slopes)
fig = plt.figure(figsize=(3.5, 3.5)) fig = plt.figure(figsize=(3.5, 3.5))
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
ax.boxplot(all_slopes, showfliers=False) ax.boxplot(all_slopes, showfliers=False)
@ -325,8 +331,8 @@ if __name__ == '__main__':
days = df.day.unique() days = df.day.unique()
subjects = df.subject.unique() subjects = df.subject.unique()
plot_trial_visit_times(df, df2, days, subjects) plot_trial_visit_times(df, df2, days, subjects)
#boldness = analyze_feeder_risk(df, subjects, days) boldness = analyze_feeder_risk(df, subjects, days)
#plot_risk_per_day(boldness, subjects) plot_risk_per_day(boldness, subjects)
#compare_risks_per_subject(boldness, subjects) compare_risks_per_subject(boldness, subjects)
#plot_risk_over_time(boldness, subjects) plot_risk_over_time(boldness, subjects)