adapted code in a way that rise analysis include more recordings. chirp analysis skip those marked as bad by patrick
This commit is contained in:
parent
50c762ab8d
commit
e2c284ea76
@ -173,8 +173,8 @@ def main(base_path):
|
|||||||
# ToDo: for chirp and rise analysis different datasets!!!
|
# ToDo: for chirp and rise analysis different datasets!!!
|
||||||
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
|
trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0)
|
||||||
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
||||||
good_chirp_trial_idx = np.arange(len(chirp_notes))[chirp_notes['good'] == 1]
|
trial_mask = chirp_notes['good'] == 1
|
||||||
trial_summary = trial_summary[chirp_notes['good'] == 1]
|
# trial_summary = trial_summary[chirp_notes['good'] == 1]
|
||||||
|
|
||||||
all_rise_times_lose = []
|
all_rise_times_lose = []
|
||||||
all_rise_times_win = []
|
all_rise_times_win = []
|
||||||
@ -233,8 +233,12 @@ def main(base_path):
|
|||||||
all_rise_times_lose.append(rise_times[1])
|
all_rise_times_lose.append(rise_times[1])
|
||||||
all_rise_times_win.append(rise_times[0])
|
all_rise_times_win.append(rise_times[0])
|
||||||
|
|
||||||
all_chirp_times_lose.append(chirp_times[1])
|
if trial_mask[index]:
|
||||||
all_chirp_times_win.append(chirp_times[0])
|
all_chirp_times_lose.append(chirp_times[1])
|
||||||
|
all_chirp_times_win.append(chirp_times[0])
|
||||||
|
else:
|
||||||
|
all_chirp_times_lose.append(np.array([]))
|
||||||
|
all_chirp_times_win.append(np.array([]))
|
||||||
|
|
||||||
win_sex.append(trial['sex_win'])
|
win_sex.append(trial['sex_win'])
|
||||||
lose_sex.append(trial['sex_lose'])
|
lose_sex.append(trial['sex_lose'])
|
||||||
@ -258,7 +262,6 @@ def main(base_path):
|
|||||||
|
|
||||||
|
|
||||||
#############################################################################
|
#############################################################################
|
||||||
|
|
||||||
for all_event_t, event_name in zip([all_chirp_times_lose, all_chirp_times_win, all_rise_times_lose, all_rise_times_win],
|
for all_event_t, event_name in zip([all_chirp_times_lose, all_chirp_times_win, all_rise_times_lose, all_rise_times_win],
|
||||||
[r'chirps$_{lose}$', r'chirps$_{win}$', r'rises$_{lose}$', r'rises$_{win}$']):
|
[r'chirps$_{lose}$', r'chirps$_{win}$', r'rises$_{lose}$', r'rises$_{win}$']):
|
||||||
print('')
|
print('')
|
||||||
@ -285,6 +288,9 @@ def main(base_path):
|
|||||||
if len(ag_on_t) == 0:
|
if len(ag_on_t) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if len(event_times) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
pre_chase_event_mask = np.zeros_like(event_times)
|
pre_chase_event_mask = np.zeros_like(event_times)
|
||||||
chase_event_mask = np.zeros_like(event_times)
|
chase_event_mask = np.zeros_like(event_times)
|
||||||
end_chase_event_mask = np.zeros_like(event_times)
|
end_chase_event_mask = np.zeros_like(event_times)
|
||||||
|
@ -244,7 +244,8 @@ def main(base_path):
|
|||||||
# ToDo: for chirp and rise analysis different datasets!!!
|
# ToDo: for chirp and rise analysis different datasets!!!
|
||||||
trial_summary = pd.read_csv('trial_summary.csv', index_col=0)
|
trial_summary = pd.read_csv('trial_summary.csv', index_col=0)
|
||||||
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0)
|
||||||
trial_summary = trial_summary[chirp_notes['good'] == 1]
|
# trial_summary = trial_summary[chirp_notes['good'] == 1]
|
||||||
|
trial_mask = chirp_notes['good'] == 1
|
||||||
|
|
||||||
lose_chrips_centered_on_ag_off_t = []
|
lose_chrips_centered_on_ag_off_t = []
|
||||||
lose_chrips_centered_on_ag_on_t = []
|
lose_chrips_centered_on_ag_on_t = []
|
||||||
@ -301,24 +302,36 @@ def main(base_path):
|
|||||||
chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
|
chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
|
||||||
chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]
|
chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]
|
||||||
|
|
||||||
|
|
||||||
rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
|
rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
|
||||||
rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
|
rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]
|
||||||
rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]]
|
rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]]
|
||||||
|
|
||||||
### collect for correlations ####
|
### collect for correlations ####
|
||||||
# chirps
|
# chirps
|
||||||
lose_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[1]))
|
if trial_mask[index]:
|
||||||
lose_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[1]))
|
lose_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[1]))
|
||||||
lose_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[1]))
|
lose_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[1]))
|
||||||
lose_chrips_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[1]))
|
lose_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[1]))
|
||||||
lose_chirp_count.append(len(chirp_times[1]))
|
lose_chrips_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[1]))
|
||||||
|
lose_chirp_count.append(len(chirp_times[1]))
|
||||||
win_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[0]))
|
|
||||||
win_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[0]))
|
win_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[0]))
|
||||||
win_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[0]))
|
win_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[0]))
|
||||||
win_chrips_centered_on_lose_rises.append(event_centered_times(rise_times[1], chirp_times[0]))
|
win_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[0]))
|
||||||
win_chirp_count.append(len(chirp_times[0]))
|
win_chrips_centered_on_lose_rises.append(event_centered_times(rise_times[1], chirp_times[0]))
|
||||||
|
win_chirp_count.append(len(chirp_times[0]))
|
||||||
|
else:
|
||||||
|
lose_chrips_centered_on_ag_off_t.append(np.array([]))
|
||||||
|
lose_chrips_centered_on_ag_on_t.append(np.array([]))
|
||||||
|
lose_chrips_centered_on_contact_t.append(np.array([]))
|
||||||
|
lose_chrips_centered_on_win_rises.append(np.array([]))
|
||||||
|
lose_chirp_count.append(np.nan)
|
||||||
|
|
||||||
|
win_chrips_centered_on_ag_off_t.append(np.array([]))
|
||||||
|
win_chrips_centered_on_ag_on_t.append(np.array([]))
|
||||||
|
win_chrips_centered_on_contact_t.append(np.array([]))
|
||||||
|
win_chrips_centered_on_lose_rises.append(np.array([]))
|
||||||
|
win_chirp_count.append(np.nan)
|
||||||
|
|
||||||
# rises
|
# rises
|
||||||
lose_rises_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], rise_times[1]))
|
lose_rises_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], rise_times[1]))
|
||||||
@ -335,6 +348,7 @@ def main(base_path):
|
|||||||
|
|
||||||
sex_win.append(trial['sex_win'])
|
sex_win.append(trial['sex_win'])
|
||||||
sex_lose.append(trial['sex_lose'])
|
sex_lose.append(trial['sex_lose'])
|
||||||
|
|
||||||
sex_win = np.array(sex_win)
|
sex_win = np.array(sex_win)
|
||||||
sex_lose = np.array(sex_lose)
|
sex_lose = np.array(sex_lose)
|
||||||
# embed()
|
# embed()
|
||||||
@ -380,10 +394,10 @@ def main(base_path):
|
|||||||
if sex_w == sex_win[i] and sex_l == sex_lose[i]:
|
if sex_w == sex_win[i] and sex_l == sex_lose[i]:
|
||||||
centered_times_pairing[-1].append(centered_times[i])
|
centered_times_pairing[-1].append(centered_times[i])
|
||||||
|
|
||||||
event_counts_pairings = [np.sum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'm')]),
|
event_counts_pairings = [np.nansum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'm')]),
|
||||||
np.sum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'f')]),
|
np.nansum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'f')]),
|
||||||
np.sum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'm')]),
|
np.nansum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'm')]),
|
||||||
np.sum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'f')])]
|
np.nansum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'f')])]
|
||||||
color = [male_color, female_color, male_color, female_color]
|
color = [male_color, female_color, male_color, female_color]
|
||||||
linestyle = ['-', '--', '--', '-']
|
linestyle = ['-', '--', '--', '-']
|
||||||
|
|
||||||
@ -414,9 +428,12 @@ def main(base_path):
|
|||||||
ax[enu].plot(conv_t_numpy, jk_p50 / event_count_p / jack_pct, color=color[enu], alpha=1, lw=3, linestyle=linestyle[enu])
|
ax[enu].plot(conv_t_numpy, jk_p50 / event_count_p / jack_pct, color=color[enu], alpha=1, lw=3, linestyle=linestyle[enu])
|
||||||
|
|
||||||
ax_m = ax[enu].twinx()
|
ax_m = ax[enu].twinx()
|
||||||
|
counter = 0
|
||||||
for enu2, centered_events in enumerate(centered_times_p):
|
for enu2, centered_events in enumerate(centered_times_p):
|
||||||
Cevents = centered_events[np.abs(centered_events) <= max_dt]
|
Cevents = centered_events[np.abs(centered_events) <= max_dt]
|
||||||
ax_m.plot(Cevents, np.ones(len(Cevents)) * enu2, '|', markersize=8, color='k', alpha=.1)
|
if len(Cevents) != 0:
|
||||||
|
ax_m.plot(Cevents, np.ones(len(Cevents)) * counter, '|', markersize=8, color='k', alpha=.1)
|
||||||
|
counter += 1
|
||||||
|
|
||||||
ax_m.set_yticks([])
|
ax_m.set_yticks([])
|
||||||
ax[enu].set_xlim(-max_dt, max_dt)
|
ax[enu].set_xlim(-max_dt, max_dt)
|
||||||
@ -461,7 +478,10 @@ def main(base_path):
|
|||||||
ax_m = ax.twinx()
|
ax_m = ax.twinx()
|
||||||
for enu, centered_events in enumerate(centered_times):
|
for enu, centered_events in enumerate(centered_times):
|
||||||
Cevents = centered_events[np.abs(centered_events) <= max_dt]
|
Cevents = centered_events[np.abs(centered_events) <= max_dt]
|
||||||
ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1)
|
if len(Cevents) != 0:
|
||||||
|
ax_m.plot(Cevents, np.ones(len(Cevents)) * counter, '|', markersize=8, color='k', alpha=.1)
|
||||||
|
counter += 1
|
||||||
|
# ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1)
|
||||||
|
|
||||||
ax_m.set_yticks([])
|
ax_m.set_yticks([])
|
||||||
ax.set_xlabel('time [s]', fontsize=12)
|
ax.set_xlabel('time [s]', fontsize=12)
|
||||||
|
Loading…
Reference in New Issue
Block a user