diff --git a/event_time_analysis.py b/event_time_analysis.py index 4a7d7dc..6e519bb 100644 --- a/event_time_analysis.py +++ b/event_time_analysis.py @@ -173,8 +173,8 @@ def main(base_path): # ToDo: for chirp and rise analysis different datasets!!! trial_summary = pd.read_csv(os.path.join(base_path, 'trial_summary.csv'), index_col=0) chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0) - good_chirp_trial_idx = np.arange(len(chirp_notes))[chirp_notes['good'] == 1] - trial_summary = trial_summary[chirp_notes['good'] == 1] + trial_mask = chirp_notes['good'] == 1 + # trial_summary = trial_summary[chirp_notes['good'] == 1] all_rise_times_lose = [] all_rise_times_win = [] @@ -233,8 +233,12 @@ def main(base_path): all_rise_times_lose.append(rise_times[1]) all_rise_times_win.append(rise_times[0]) - all_chirp_times_lose.append(chirp_times[1]) - all_chirp_times_win.append(chirp_times[0]) + if trial_mask[index]: + all_chirp_times_lose.append(chirp_times[1]) + all_chirp_times_win.append(chirp_times[0]) + else: + all_chirp_times_lose.append(np.array([])) + all_chirp_times_win.append(np.array([])) win_sex.append(trial['sex_win']) lose_sex.append(trial['sex_lose']) @@ -258,7 +262,6 @@ def main(base_path): ############################################################################# - for all_event_t, event_name in zip([all_chirp_times_lose, all_chirp_times_win, all_rise_times_lose, all_rise_times_win], [r'chirps$_{lose}$', r'chirps$_{win}$', r'rises$_{lose}$', r'rises$_{win}$']): print('') @@ -285,6 +288,9 @@ def main(base_path): if len(ag_on_t) == 0: continue + if len(event_times) == 0: + continue + pre_chase_event_mask = np.zeros_like(event_times) chase_event_mask = np.zeros_like(event_times) end_chase_event_mask = np.zeros_like(event_times) diff --git a/event_time_correlations.py b/event_time_correlations.py index 2c25cee..dad1e19 100644 --- a/event_time_correlations.py +++ b/event_time_correlations.py @@ -244,7 +244,8 @@ def main(base_path): # ToDo: for chirp and rise analysis different datasets!!! trial_summary = pd.read_csv('trial_summary.csv', index_col=0) chirp_notes = pd.read_csv(os.path.join(base_path, 'chirp_notes.csv'), index_col=0) - trial_summary = trial_summary[chirp_notes['good'] == 1] + # trial_summary = trial_summary[chirp_notes['good'] == 1] + trial_mask = chirp_notes['good'] == 1 lose_chrips_centered_on_ag_off_t = [] lose_chrips_centered_on_ag_on_t = [] @@ -301,24 +302,36 @@ def main(base_path): chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy')) chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]] - rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter] rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))] rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]] ### collect for correlations #### # chirps - lose_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[1])) - lose_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[1])) - lose_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[1])) - lose_chrips_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[1])) - lose_chirp_count.append(len(chirp_times[1])) - - win_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[0])) - win_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[0])) - win_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[0])) - win_chrips_centered_on_lose_rises.append(event_centered_times(rise_times[1], chirp_times[0])) - win_chirp_count.append(len(chirp_times[0])) + if trial_mask[index]: + lose_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[1])) + lose_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[1])) + lose_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[1])) + lose_chrips_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[1])) + lose_chirp_count.append(len(chirp_times[1])) + + win_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[0])) + win_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[0])) + win_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[0])) + win_chrips_centered_on_lose_rises.append(event_centered_times(rise_times[1], chirp_times[0])) + win_chirp_count.append(len(chirp_times[0])) + else: + lose_chrips_centered_on_ag_off_t.append(np.array([])) + lose_chrips_centered_on_ag_on_t.append(np.array([])) + lose_chrips_centered_on_contact_t.append(np.array([])) + lose_chrips_centered_on_win_rises.append(np.array([])) + lose_chirp_count.append(np.nan) + + win_chrips_centered_on_ag_off_t.append(np.array([])) + win_chrips_centered_on_ag_on_t.append(np.array([])) + win_chrips_centered_on_contact_t.append(np.array([])) + win_chrips_centered_on_lose_rises.append(np.array([])) + win_chirp_count.append(np.nan) # rises lose_rises_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], rise_times[1])) @@ -335,6 +348,7 @@ def main(base_path): sex_win.append(trial['sex_win']) sex_lose.append(trial['sex_lose']) + sex_win = np.array(sex_win) sex_lose = np.array(sex_lose) # embed() @@ -380,10 +394,10 @@ def main(base_path): if sex_w == sex_win[i] and sex_l == sex_lose[i]: centered_times_pairing[-1].append(centered_times[i]) - event_counts_pairings = [np.sum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'm')]), - np.sum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'f')]), - np.sum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'm')]), - np.sum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'f')])] + event_counts_pairings = [np.nansum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'm')]), + np.nansum(np.array(event_counts)[(sex_win == 'm') & (sex_lose == 'f')]), + np.nansum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'm')]), + np.nansum(np.array(event_counts)[(sex_win == 'f') & (sex_lose == 'f')])] color = [male_color, female_color, male_color, female_color] linestyle = ['-', '--', '--', '-'] @@ -414,9 +428,12 @@ def main(base_path): ax[enu].plot(conv_t_numpy, jk_p50 / event_count_p / jack_pct, color=color[enu], alpha=1, lw=3, linestyle=linestyle[enu]) ax_m = ax[enu].twinx() + counter = 0 for enu2, centered_events in enumerate(centered_times_p): Cevents = centered_events[np.abs(centered_events) <= max_dt] - ax_m.plot(Cevents, np.ones(len(Cevents)) * enu2, '|', markersize=8, color='k', alpha=.1) + if len(Cevents) != 0: + ax_m.plot(Cevents, np.ones(len(Cevents)) * counter, '|', markersize=8, color='k', alpha=.1) + counter += 1 ax_m.set_yticks([]) ax[enu].set_xlim(-max_dt, max_dt) @@ -461,7 +478,10 @@ def main(base_path): ax_m = ax.twinx() for enu, centered_events in enumerate(centered_times): Cevents = centered_events[np.abs(centered_events) <= max_dt] - ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1) + if len(Cevents) != 0: + ax_m.plot(Cevents, np.ones(len(Cevents)) * counter, '|', markersize=8, color='k', alpha=.1) + counter += 1 + # ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1) ax_m.set_yticks([]) ax.set_xlabel('time [s]', fontsize=12)