include mean shelter power (which denfines winner) to plot.

This commit is contained in:
Till Raab 2023-05-19 08:15:47 +02:00
parent 922c938fb6
commit 83d5e3164b
2 changed files with 12 additions and 8 deletions

View File

@ -238,8 +238,11 @@ def main(data_folder=None):
if ~np.all(meta_id_in_uid): if ~np.all(meta_id_in_uid):
continue continue
ids = np.load(os.path.join(trial_path, 'analysis', 'ids.npy'))
sorter = -1 if win_id != ids[0] else 1
temp_t, temp = get_temperature(trial_path) temp_t, temp = get_temperature(trial_path)
baseline_freqs = np.load(os.path.join(trial_path, 'analysis', 'baseline_freqs.npy')) baseline_freqs = np.load(os.path.join(trial_path, 'analysis', 'baseline_freqs.npy'))[::sorter]
baseline_freq_times = np.load(os.path.join(trial_path, 'analysis', 'baseline_freq_times.npy')) baseline_freq_times = np.load(os.path.join(trial_path, 'analysis', 'baseline_freq_times.npy'))
q10_comp_freq, q10_vals = frequency_q10_compensation(baseline_freqs, baseline_freq_times, temp, temp_t, light_start_sec=light_start_sec) q10_comp_freq, q10_vals = frequency_q10_compensation(baseline_freqs, baseline_freq_times, temp, temp_t, light_start_sec=light_start_sec)
@ -252,9 +255,11 @@ def main(data_folder=None):
got_chirps = True got_chirps = True
chirp_times = [chirp_t[chirp_ids == win_id], chirp_t[chirp_ids == lose_id]] chirp_times = [chirp_t[chirp_ids == win_id], chirp_t[chirp_ids == lose_id]]
rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy')) rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))] rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
############################################################################################################# #############################################################################################################
### physical behavior ### physical behavior
if video_eval: if video_eval:
@ -362,9 +367,6 @@ def main(data_folder=None):
plt.show() plt.show()
embed()
quit()
for g in pd.unique(trial_summary['group']): for g in pd.unique(trial_summary['group']):
fish_no = np.unique(np.concatenate((trial_summary['win_fish'][trial_summary['group'] == g], fish_no = np.unique(np.concatenate((trial_summary['win_fish'][trial_summary['group'] == g],

View File

@ -38,6 +38,8 @@ class Trial(object):
self.winner = None self.winner = None
self.loser = None self.loser = None
self.mean_shelter_power = None
if os.path.exists(os.path.join(self.base_path, self.folder, 'fund_v.npy')): if os.path.exists(os.path.join(self.base_path, self.folder, 'fund_v.npy')):
self.load() self.load()
@ -108,8 +110,8 @@ class Trial(object):
for enu, id in enumerate(self.ids): for enu, id in enumerate(self.ids):
shelter_power[enu] = self.fish_sign_interp[enu][day_idxs, -1] shelter_power[enu] = self.fish_sign_interp[enu][day_idxs, -1]
mean_shelter_power = np.nanmean(shelter_power, axis=1) self.mean_shelter_power = np.nanmean(shelter_power, axis=1)
self.winner = 1 if mean_shelter_power[1] > mean_shelter_power[0] else 0 self.winner = 1 if self.mean_shelter_power[1] > self.mean_shelter_power[0] else 0
self.loser = 0 if self.winner == 1 else 1 self.loser = 0 if self.winner == 1 else 1
def rise_detection(self, rise_th): def rise_detection(self, rise_th):
@ -187,7 +189,7 @@ class Trial(object):
for enu, id in enumerate(self.ids): for enu, id in enumerate(self.ids):
c = 'firebrick' if self.winner == enu else 'forestgreen' c = 'firebrick' if self.winner == enu else 'forestgreen'
ax.plot(self.times/3600, self.fish_freq[enu], marker='.', color=c, zorder=1) ax.plot(self.times/3600, self.fish_freq[enu], marker='.', color=c, zorder=1, label=f'{self.mean_shelter_power[enu]:.2f}dB')
ax.plot(self.times[np.isnan(self.fish_freq[enu])]/3600, self.fish_freq_interp[enu][np.isnan(self.fish_freq[enu])], '.', zorder=1, color=c, alpha=0.25) ax.plot(self.times[np.isnan(self.fish_freq[enu])]/3600, self.fish_freq_interp[enu][np.isnan(self.fish_freq[enu])], '.', zorder=1, color=c, alpha=0.25)
ax.plot(self.baseline_freq_times/3600, self.baseline_freqs[enu], '--', color='k', zorder=2) ax.plot(self.baseline_freq_times/3600, self.baseline_freqs[enu], '--', color='k', zorder=2)
ax.plot(self.baseline_freq_times/3600, self.pct95_freqs[enu], '--', color='k', zorder=2) ax.plot(self.baseline_freq_times/3600, self.pct95_freqs[enu], '--', color='k', zorder=2)