diff --git a/code/chirpdetection.py b/code/chirpdetection.py index a10608e..083fc62 100644 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -183,19 +183,6 @@ def main(datapath: str) -> None: chirps = [] fish_ids = [] - baseline_ts = [[[ - [] for el in range(config.number_electrodes)] - for tr in range(len(data.ids))] - for wi in range(nwindows)] - search_ts = [[[ - [] for el in range(config.number_electrodes)] - for tr in range(len(data.ids))] - for wi in range(nwindows)] - freq_ts = [[[ - [] for el in range(config.number_electrodes)] - for tr in range(len(data.ids))] - for wi in range(nwindows)] - for st, start_index in enumerate(window_starts[: nwindows]): # make t0 and dt @@ -208,7 +195,7 @@ def main(datapath: str) -> None: # calucate median of fish frequencies in window median_freq = [] track_ids = [] - for el, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): + for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): window_idx = np.arange(len(data.idx))[ (data.ident == track_id) & (data.time[data.idx] >= t0) & ( data.time[data.idx] <= (t0 + dt)) @@ -333,7 +320,7 @@ def main(datapath: str) -> None: # iterate through electrodes for el, electrode in enumerate(best_electrodes): - + print(el) # load region of interest of raw data file data_oi = data.raw[start_index:stop_index, :] time_oi = raw_time[start_index:stop_index] @@ -454,47 +441,6 @@ def main(datapath: str) -> None: prominence=prominence ) - # DETECT CHIRPS IN SEARCH WINDOW ------------------------------- - - baseline_ts = time_oi[baseline_peaks] - search_ts = time_oi[search_peaks] - freq_ts = baseline_freq_time[inst_freq_peaks] - - # check if one list is empty - if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0: - continue - - # get index for each feature - baseline_idx = np.zeros_like(baseline_ts) - search_idx = np.ones_like(search_ts) - freq_idx = np.ones_like(freq_ts) * 2 - - timestamps_features = np.hstack( - [baseline_idx, search_idx, freq_idx]) - timestamps = np.hstack([baseline_ts, search_ts, freq_ts]) - - # sort timestamps - timestamps_idx = np.arange(len(timestamps)) - timestamps_features = timestamps_features[np.argsort(timestamps)] - timestamps = timestamps[np.argsort(timestamps)] - - # # get chirps - # diff = np.empty(timestamps.shape) - # diff[0] = np.inf # always retain the 1st element - # diff[1:] = np.diff(timestamps) - # mask = diff < config.chirp_window_threshold - # shared_peak_indices = timestamp_idx[mask] - - current_chirps = [] - for tt in timestamps: - cm = timestamps_idx[(timestamps >= tt) & ( - timestamps <= tt + config.chirp_window_threshold)] - if all([0, 1, 2]) in timestamps_features[cm]: - chirps.append(np.mean(timestamps[cm])) - current_chirps.append(np.mean(timestamps[cm])) - fish_ids.append(track_id) - - # # SAVE DATA --------------------------------------------------- # PLOT -------------------------------------------------------- @@ -503,9 +449,6 @@ def main(datapath: str) -> None: plot_spectrogram( axs[0, el], data_oi[:, electrode], data.raw_rate, t0) - for ct in current_chirps: - axs[0, el].axvline(ct, color='r', lw=1) - # plot baseline instantaneos frequency axs[1, el].plot(baseline_freq_time, baseline_freq - np.median(baseline_freq)) @@ -569,6 +512,68 @@ def main(datapath: str) -> None: axs[5, el].set_title("Search envelope") axs[6, el].set_title( "Filtered absolute instantaneous frequency") + print(el) + # DETECT CHIRPS IN SEARCH WINDOW ------------------------------- + + baseline_ts = time_oi[baseline_peaks] + search_ts = time_oi[search_peaks] + freq_ts = baseline_freq_time[inst_freq_peaks] + + # check if one list is empty + if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0: + continue + + # get index for each feature + baseline_idx = np.zeros_like(baseline_ts) + search_idx = np.ones_like(search_ts) + freq_idx = np.ones_like(freq_ts) * 2 + + timestamps_features = np.hstack( + [baseline_idx, search_idx, freq_idx]) + timestamps = np.hstack([baseline_ts, search_ts, freq_ts]) + + # sort timestamps + timestamps_idx = np.arange(len(timestamps)) + timestamps_features = timestamps_features[np.argsort( + timestamps)] + timestamps = timestamps[np.argsort(timestamps)] + + # # get chirps + # diff = np.empty(timestamps.shape) + # diff[0] = np.inf # always retain the 1st element + # diff[1:] = np.diff(timestamps) + # mask = diff < config.chirp_window_threshold + # shared_peak_indices = timestamp_idx[mask] + + current_chirps = [] + for tt in timestamps: + cm = timestamps_idx[(timestamps >= tt) & ( + timestamps <= tt + config.chirp_window_threshold)] + if set([0, 1, 2]).issubset(timestamps_features[cm]): + chirps.append(np.mean(timestamps[cm])) + current_chirps.append(np.mean(timestamps[cm])) + fish_ids.append(track_id) + + for ct in current_chirps: + axs[0, el].axvline(ct, color='r', lw=1) + + axs[0, el].scatter( + baseline_freq_time[inst_freq_peaks], + np.ones_like(baseline_freq_time[inst_freq_peaks]) * 600, + c=ps.red, + ) + axs[0, el].scatter( + (time_oi)[search_peaks], + np.ones_like((time_oi)[search_peaks]) * 600, + c=ps.red, + ) + + axs[0, el].scatter( + (time_oi)[baseline_peaks], + np.ones_like((time_oi)[baseline_peaks]) * 600, + c=ps.red, + ) + plt.show()