From e3a5cadc0c625224999b121b24f5d00df3760fb6 Mon Sep 17 00:00:00 2001 From: Till Raab Date: Tue, 30 May 2023 15:03:48 +0200 Subject: [PATCH] works nice... adapt loop and clean up the code. --- event_time_analysis.py | 188 ++++++++++++++++++++++++++++++----------- 1 file changed, 140 insertions(+), 48 deletions(-) diff --git a/event_time_analysis.py b/event_time_analysis.py index 18c7c4a..376b496 100644 --- a/event_time_analysis.py +++ b/event_time_analysis.py @@ -87,7 +87,7 @@ def kde(event_dt, conv_t, kernal_w = 1, kernal_h = 0.2): return conv_array -def permulation_kde(event_dt, repetitions = 2000, max_dt = 60, max_mem_use_GB = 4, kernal_w = 1, kernal_h = 0.2): +def permulation_kde(event_dt, conv_t, repetitions = 2000, max_mem_use_GB = 4, kernal_w = 1, kernal_h = 0.2): def chunk_permutation(select_event_dt, conv_tt, n_chuck, max_jitter, kernal_w, kernal_h): # array.shape = (120, 100, 15486) = (len(conv_t), repetitions, len(event_dt)) # event_dt_perm = cp.tile(event_dt, (len(conv_t), repetitions, 1)) @@ -113,10 +113,10 @@ def permulation_kde(event_dt, repetitions = 2000, max_dt = 60, max_mem_use_GB = t0 = time.time() - max_jitter = 2*max_dt - select_event_dt = event_dt[np.abs(event_dt) <= max_dt + max_jitter*2] + max_jitter = float(2*cp.max(conv_t)) + select_event_dt = event_dt[np.abs(event_dt) <= float(cp.max(conv_t)) + max_jitter*2] - conv_t = cp.arange(-max_dt, max_dt, 1) + # conv_t = cp.arange(-max_dt, max_dt, 1) conv_tt = cp.reshape(conv_t, (len(conv_t), 1, 1)) chunk_size = int(np.floor(max_mem_use_GB / (select_event_dt.nbytes * conv_t.size / 1e9))) @@ -157,7 +157,7 @@ def permulation_kde(event_dt, repetitions = 2000, max_dt = 60, max_mem_use_GB = return chunk_collector -def jackknife_kde(event_dt, repetitions = 2000, max_dt = 60, max_mem_use_GB = 2, jack_pct = 0.9, kernal_w = 1, kernal_h = 0.2): +def jackknife_kde(event_dt, conv_t, repetitions = 2000, max_mem_use_GB = 2, jack_pct = 0.9, kernal_w = 1, kernal_h = 0.2): def chunk_jackknife(select_event_dt, conv_tt, n_chuck, jack_pct, kernal_w, kernal_h): event_dt_rep = cp.tile(select_event_dt, (n_chuck, 1)) idx = cp.random.rand(*event_dt_rep.shape).argsort(1)[:, :int(event_dt_rep.shape[-1]*jack_pct)] @@ -189,9 +189,9 @@ def jackknife_kde(event_dt, repetitions = 2000, max_dt = 60, max_mem_use_GB = 2, t0 = time.time() # max_jitter = 2*max_dt - select_event_dt = event_dt[np.abs(event_dt) <= max_dt * 2] + select_event_dt = event_dt[np.abs(event_dt) <= float(cp.max(conv_t)) * 2] - conv_t = cp.arange(-max_dt, max_dt, 1) + # conv_t = cp.arange(-max_dt, max_dt, 1) conv_tt = cp.reshape(conv_t, (len(conv_t), 1, 1)) chunk_size = int(np.floor(max_mem_use_GB / (select_event_dt.nbytes * jack_pct * conv_t.size / 1e9))) @@ -231,7 +231,24 @@ def main(base_path): trial_summary = pd.read_csv('trial_summary.csv', index_col=0) lose_chrips_centered_on_ag_off_t = [] + lose_chrips_centered_on_ag_on_t = [] + lose_chrips_centered_on_contact_t = [] + lose_chrips_centered_on_win_rises = [] lose_chirp_count = [] + + lose_rises_centered_on_ag_off_t = [] + lose_rises_centered_on_ag_on_t = [] + lose_rises_centered_on_contact_t = [] + lose_rises_centered_on_win_chirps = [] + lose_rises_count = [] + + + + win_chrips_centered_on_ag_off_t = [] + win_chrips_centered_on_ag_on_t = [] + win_chrips_centered_on_contact_t = [] + win_chirp_count = [] + for index, trial in tqdm(trial_summary.iterrows()): trial_path = os.path.join(base_path, trial['recording']) @@ -262,54 +279,129 @@ def main(base_path): rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))] rise_times = [times[rise_idx_int[0]], times[rise_idx_int[1]]] + ### collect for correlations #### lose_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[1])) + lose_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[1])) + lose_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[1])) + lose_chrips_centered_on_win_rises.append(event_centered_times(rise_times[0], chirp_times[1])) lose_chirp_count.append(len(chirp_times[1])) + lose_rises_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], rise_times[1])) + lose_rises_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], rise_times[1])) + lose_rises_centered_on_contact_t.append(event_centered_times(contact_t_GRID, rise_times[1])) + lose_rises_centered_on_win_chirps.append(event_centered_times(chirp_times[0], rise_times[1])) + lose_rises_count.append(len(rise_times[1])) - max_dt = 30 - conv_t = np.arange(-max_dt, max_dt, 1) - - kde_array = kde(np.hstack(lose_chrips_centered_on_ag_off_t), conv_t, kernal_w = 1, kernal_h = 1) - - boot_kde = permulation_kde(np.hstack(lose_chrips_centered_on_ag_off_t), max_dt=max_dt, kernal_w=1, kernal_h=1) + win_chrips_centered_on_ag_off_t.append(event_centered_times(ag_on_off_t_GRID[:, 1], chirp_times[0])) + win_chrips_centered_on_ag_on_t.append(event_centered_times(ag_on_off_t_GRID[:, 0], chirp_times[0])) + win_chrips_centered_on_contact_t.append(event_centered_times(contact_t_GRID, chirp_times[0])) + win_chirp_count.append(len(chirp_times[0])) + # embed() + # quit() + max_dt = 30 + conv_t_dt = 0.5 jack_pct = 0.9 - jk_kde = jackknife_kde(np.hstack(lose_chrips_centered_on_ag_off_t), max_dt=max_dt, jack_pct = jack_pct, kernal_w=1, kernal_h=1) - - perm_p1, perm_p50, perm_p99 = np.percentile(boot_kde, (1, 50, 99), axis=0) - jk_p1, jk_p50, jk_p99 = np.percentile(jk_kde, (1, 50, 99), axis=0) - - embed() - quit() - fig = plt.figure(figsize=(20/2.54, 12/2.54)) - gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.95, top=0.95) - ax = fig.add_subplot(gs[0, 0]) - ax.fill_between(conv_t, perm_p1/np.sum(lose_chirp_count), perm_p99/np.sum(lose_chirp_count), color='cornflowerblue', alpha=.8) - ax.plot(conv_t, perm_p50/np.sum(lose_chirp_count), color='dodgerblue', alpha=1, lw=3) - - ax.fill_between(conv_t, jk_p1/np.sum(lose_chirp_count)/jack_pct, jk_p99/np.sum(lose_chirp_count)/jack_pct, color='tab:red', alpha=.8) - ax.plot(conv_t, jk_p50/np.sum(lose_chirp_count)/jack_pct, color='firebrick', alpha=1, lw=3) - - ax_m = ax.twinx() - for enu, centered_events in enumerate(lose_chrips_centered_on_ag_off_t): - Cevents = centered_events[np.abs(centered_events) <= max_dt] - ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1) - - ax_m.set_yticks([]) - ax.set_xlabel('time [s]', fontsize=12) - ax.set_ylabel('event rate [Hz]', fontsize=12) - ax.set_xlim(-max_dt, max_dt) - ax.tick_params(labelsize=10) - - # for i in range(len(boot_kde)): - # ax.plot(conv_t, boot_kde[i] / np.sum(lose_chirp_count), color='tab:blue') - # - # for i in range(len(boot_kde)): - # ax.plot(conv_t, jk_kde[i] / np.sum(lose_chirp_count) / jack_pct, color='tab:red') - # ax.plot(conv_t, kde_array/np.sum(lose_chirp_count), color='k', lw=3) - plt.show() - pass + conv_t = cp.arange(-max_dt, max_dt, conv_t_dt) + # kde_array = kde(np.hstack(lose_chrips_centered_on_ag_off_t), conv_t, kernal_w = 1, kernal_h = 1) + + # for centered_times, event_counts, title in \ + # zip([lose_chrips_centered_on_ag_off_t, + # lose_chrips_centered_on_ag_on_t, + # lose_chrips_centered_on_contact_t, + # win_chrips_centered_on_ag_off_t, + # win_chrips_centered_on_ag_on_t, + # win_chrips_centered_on_contact_t, + # lose_rises_centered_on_ag_on_t, + # lose_chrips_centered_on_win_rises], + # + # [lose_chirp_count, + # lose_chirp_count, + # lose_chirp_count, + # win_chirp_count, + # win_chirp_count, + # win_chirp_count, + # lose_rises_count, + # lose_chirp_count], + # + # [r'chirp$_{lose}$ on chase$_{off}$', + # r'chirp$_{lose}$ on chase$_{on}$', + # r'chirp$_{lose}$ on contact', + # r'chirp$_{win}$ on chase$_{off}$', + # r'chirp$_{win}$ on chase$_{on}$', + # r'chirp$_{win}$ on contact', + # r'rise$_{lose}$ on chase$_{on}$', + # r'chirp$_{lose}$ on rise$_{win}$']): + for centered_times, event_counts, title in \ + [[lose_chrips_centered_on_ag_off_t, lose_chirp_count, r'chirp$_{lose}$ on chase$_{off}$'], + [lose_chrips_centered_on_ag_on_t, lose_chirp_count, r'chirp$_{lose}$ on chase$_{on}$']]: + + boot_kde = permulation_kde(np.hstack(centered_times), conv_t, kernal_w=1, kernal_h=1) + jk_kde = jackknife_kde(np.hstack(centered_times), conv_t, jack_pct=jack_pct, kernal_w=1, kernal_h=1) + + perm_p1, perm_p50, perm_p99 = np.percentile(boot_kde, (1, 50, 99), axis=0) + jk_p1, jk_p50, jk_p99 = np.percentile(jk_kde, (1, 50, 99), axis=0) + + + conv_t_numpy = cp.asnumpy(conv_t) + + fig = plt.figure(figsize=(20/2.54, 12/2.54)) + gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.95, top=0.95) + ax = fig.add_subplot(gs[0, 0]) + ax.fill_between(conv_t_numpy, perm_p1/np.sum(event_counts), perm_p99/np.sum(event_counts), color='cornflowerblue', alpha=.8) + ax.plot(conv_t_numpy, perm_p50/np.sum(event_counts), color='dodgerblue', alpha=1, lw=3) + + ax.fill_between(conv_t_numpy, jk_p1/np.sum(event_counts)/jack_pct, jk_p99/np.sum(event_counts)/jack_pct, color='tab:red', alpha=.8) + ax.plot(conv_t_numpy, jk_p50/np.sum(event_counts)/jack_pct, color='firebrick', alpha=1, lw=3) + + ax_m = ax.twinx() + for enu, centered_events in enumerate(centered_times): + Cevents = centered_events[np.abs(centered_events) <= max_dt] + ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1) + + ax_m.set_yticks([]) + ax.set_xlabel('time [s]', fontsize=12) + ax.set_ylabel('event rate [Hz]', fontsize=12) + ax.set_title(title) + ax.set_xlim(-max_dt, max_dt) + ax.tick_params(labelsize=10) + + plt.show() + + ################################## + # boot_kde = permulation_kde(np.hstack(lose_chrips_centered_on_ag_off_t), conv_t, kernal_w=1, kernal_h=1) + # jk_kde = jackknife_kde(np.hstack(lose_chrips_centered_on_ag_off_t), conv_t, jack_pct = jack_pct, kernal_w=1, kernal_h=1) + # + # perm_p1, perm_p50, perm_p99 = np.percentile(boot_kde, (1, 50, 99), axis=0) + # jk_p1, jk_p50, jk_p99 = np.percentile(jk_kde, (1, 50, 99), axis=0) + # + # ################################################################################################################# + # + # conv_t_numpy = cp.asnumpy(conv_t) + # + # fig = plt.figure(figsize=(20/2.54, 12/2.54)) + # gs = gridspec.GridSpec(1, 1, left=0.1, bottom=0.1, right=0.95, top=0.95) + # ax = fig.add_subplot(gs[0, 0]) + # ax.fill_between(conv_t_numpy, perm_p1/np.sum(lose_chirp_count), perm_p99/np.sum(lose_chirp_count), color='cornflowerblue', alpha=.8) + # ax.plot(conv_t_numpy, perm_p50/np.sum(lose_chirp_count), color='dodgerblue', alpha=1, lw=3) + # + # ax.fill_between(conv_t_numpy, jk_p1/np.sum(lose_chirp_count)/jack_pct, jk_p99/np.sum(lose_chirp_count)/jack_pct, color='tab:red', alpha=.8) + # ax.plot(conv_t_numpy, jk_p50/np.sum(lose_chirp_count)/jack_pct, color='firebrick', alpha=1, lw=3) + # + # ax_m = ax.twinx() + # for enu, centered_events in enumerate(lose_chrips_centered_on_ag_off_t): + # Cevents = centered_events[np.abs(centered_events) <= max_dt] + # ax_m.plot(Cevents, np.ones(len(Cevents)) * enu, '|', markersize=8, color='k', alpha=.1) + # + # ax_m.set_yticks([]) + # ax.set_xlabel('time [s]', fontsize=12) + # ax.set_ylabel('event rate [Hz]', fontsize=12) + # ax.set_xlim(-max_dt, max_dt) + # ax.tick_params(labelsize=10) + # + # plt.show() + # pass if __name__ == '__main__': main(sys.argv[1])