adapted code in a way that rise analysis include more recordings. chirp analysis skip those marked as bad by patrick

This commit is contained in:
Till Raab 2023-06-13 09:22:33 +02:00
parent 50c762ab8d
commit e2c284ea76
2 changed files with 50 additions and 24 deletions

View File

@ -173,8 +173,8 @@ def main(base_path):
# ToDo: for chirp and rise analysis different datasets!!!
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
good_chirp_trial_idx = np.arange(len(chirp_notes))[chirp_notes['good'] == 1]
trial_summary = trial_summary[chirp_notes['good'] == 1]
trial_mask = chirp_notes['good'] == 1
# trial_summary = trial_summary[chirp_notes['good'] == 1]
all_rise_times_lose = []
all_rise_times_win = []
@ -233,8 +233,12 @@ def main(base_path):
all_rise_times_lose.append(rise_times[1])
all_rise_times_win.append(rise_times[0])
all_chirp_times_lose.append(chirp_times[1])
all_chirp_times_win.append(chirp_times[0])
if trial_mask[index]:
all_chirp_times_lose.append(chirp_times[1])
all_chirp_times_win.append(chirp_times[0])
else:
all_chirp_times_lose.append(np.array([]))
all_chirp_times_win.append(np.array([]))
win_sex.append(trial['sex_win'])
lose_sex.append(trial['sex_lose'])
@ -258,7 +262,6 @@ def main(base_path):
#############################################################################
for all_event_t, event_name in zip([all_chirp_times_lose, all_chirp_times_win, all_rise_times_lose, all_rise_times_win],
[r'chirps$_{lose}$', r'chirps$_{win}$', r'rises$_{lose}$', r'rises$_{win}$']):
print('')
@ -285,6 +288,9 @@ def main(base_path):
if len(ag_on_t) == 0:
continue
if len(event_times) == 0:
continue
pre_chase_event_mask = np.zeros_like(event_times)
chase_event_mask = np.zeros_like(event_times)
end_chase_event_mask = np.zeros_like(event_times)

View File

@ -244,7 +244,8 @@ def main(base_path):
# ToDo: for chirp and rise analysis different datasets!!!
trial_summary = pd.read_csv('trial_summary.csv', index_col=0)
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
trial_summary = trial_summary[chirp_notes['good'] == 1]
# trial_summary = trial_summary[chirp_notes['good'] == 1]
trial_mask = chirp_notes['good'] == 1
lose_chrips_centered_on_ag_off_t = []
lose_chrips_centered_on_ag_on_t = []
@ -301,24 +302,36 @@ def main(base_path):
chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]
rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]]
### collect for correlations ####
# chirps
lose_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[1]))
lose_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[1]))
lose_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[1]))
lose_chrips_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[1]))
lose_chirp_count.append(len(chirp_times[1]))
win_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[0]))
win_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[0]))
win_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[0]))
win_chrips_centered_on_lose_rises.append(event_centered_times(rise_times[1], chirp_times[0]))
win_chirp_count.append(len(chirp_times[0]))
if trial_mask[index]:
lose_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[1]))
lose_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[1]))
lose_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[1]))
lose_chrips_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[1]))
lose_chirp_count.append(len(chirp_times[1]))
win_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[0]))
win_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[0]))
win_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[0]))
win_chrips_centered_on_lose_rises.append(event_centered_times(rise_times[1], chirp_times[0]))
win_chirp_count.append(len(chirp_times[0]))
else:
lose_chrips_centered_on_ag_off_t.append(np.array([]))
lose_chrips_centered_on_ag_on_t.append(np.array([]))
lose_chrips_centered_on_contact_t.append(np.array([]))
lose_chrips_centered_on_win_rises.append(np.array([]))
lose_chirp_count.append(np.nan)
win_chrips_centered_on_ag_off_t.append(np.array([]))
win_chrips_centered_on_ag_on_t.append(np.array([]))
win_chrips_centered_on_contact_t.append(np.array([]))
win_chrips_centered_on_lose_rises.append(np.array([]))
win_chirp_count.append(np.nan)
# rises
lose_rises_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], rise_times[1]))
@ -335,6 +348,7 @@ def main(base_path):
sex_win.append(trial['sex_win'])
sex_lose.append(trial['sex_lose'])
sex_win = np.array(sex_win)
sex_lose = np.array(sex_lose)
# embed()
@ -380,10 +394,10 @@ def main(base_path):
if sex_w == sex_win[i] and sex_l == sex_lose[i]:
centered_times_pairing[-1].append(centered_times[i])
event_counts_pairings = [np.sum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'm')]),
np.sum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'f')]),
np.sum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'm')]),
np.sum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'f')])]
event_counts_pairings = [np.nansum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'm')]),
np.nansum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'f')]),
np.nansum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'm')]),
np.nansum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'f')])]
color = [male_color, female_color, male_color, female_color]
linestyle = ['-', '--', '--', '-']
@ -414,9 +428,12 @@ def main(base_path):
ax[enu].plot(conv_t_numpy, jk_p50 / event_count_p / jack_pct, color=color[enu], alpha=1, lw=3, linestyle=linestyle[enu])
ax_m = ax[enu].twinx()
counter = 0
for enu2, centered_events in enumerate(centered_times_p):
Cevents = centered_events[np.abs(centered_events) <= max_dt]
ax_m.plot(Cevents, np.ones(len(Cevents)) * enu2, '|', markersize=8, color='k', alpha=.1)
if len(Cevents) != 0:
ax_m.plot(Cevents, np.ones(len(Cevents)) * counter, '|', markersize=8, color='k', alpha=.1)
counter += 1
ax_m.set_yticks([])
ax[enu].set_xlim(-max_dt, max_dt)
@ -461,7 +478,10 @@ def main(base_path):
ax_m = ax.twinx()
for enu, centered_events in enumerate(centered_times):
Cevents = centered_events[np.abs(centered_events) <= max_dt]
ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1)
if len(Cevents) != 0:
ax_m.plot(Cevents, np.ones(len(Cevents)) * counter, '|', markersize=8, color='k', alpha=.1)
counter += 1
# ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1)
ax_m.set_yticks([])
ax.set_xlabel('time [s]', fontsize=12)