diff --git a/data/generate_dataset.py b/data/generate_dataset.py index bdc6570..d3fabb3 100644 --- a/data/generate_dataset.py +++ b/data/generate_dataset.py @@ -127,9 +127,33 @@ def main(args): (rise_idx[id_idx] >= times_v_idx0) & (rise_idx[id_idx] <= times_v_idx1) & (rise_size[id_idx] >= 10)], dtype=int) + rise_size_oi = rise_idx[id_idx][(rise_idx[id_idx] >= times_v_idx0) & + (rise_idx[id_idx] <= times_v_idx1) & + (rise_size[id_idx] >= 10)] + ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red') if len(rise_idx_oi) > 0: + closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi])) + closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx] + + upper_freq_bound = closest_baseline_freq + rise_size_oi + lower_freq_bound = closest_baseline_freq + + left_time_bound = times_v[rise_idx_oi] + right_time_bound = np.zeros_like(left_time_bound) + + for enu, Ct_idx in enumerate(times_v[rise_idx_oi]): + Crise_size = rise_size_oi[enu] + Cblf = closest_baseline_freq[enu] + + rise_end_t = times_v[(idx_v > Ct_idx) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] + if len(rise_end_t) == 0: + right_time_bound[enu] = np.nan + else: + right_time_bound[enu] = rise_end_t[0] + + embed() quit() plt.show()