From c4ac372647fb5bb5808c7713b40fb37b082ebb21 Mon Sep 17 00:00:00 2001 From: weygoldt <88969563+weygoldt@users.noreply.github.com> Date: Tue, 23 May 2023 15:49:47 +0200 Subject: [PATCH] cleanup --- code/chirpdetection.py | 108 +++++++++++++++-------------------------- 1 file changed, 39 insertions(+), 69 deletions(-) diff --git a/code/chirpdetection.py b/code/chirpdetection.py index 121f6dc..161730d 100755 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -87,12 +87,18 @@ class ChirpPlotBuffer: print(search_upper, search_lower) # get indices on raw data - start_idx = (self.t0 - 5) * self.data.raw_rate + start_idx = int((self.t0 - 5) * self.data.raw_rate) window_duration = (self.dt + 10) * self.data.raw_rate - stop_idx = start_idx + window_duration + stop_idx = int(start_idx + window_duration) + + if start_idx < 0: + start_idx = 0 # get raw data - data_oi = self.data.raw[start_idx:stop_idx, self.electrode] + try: + data_oi = self.data.raw[start_idx:stop_idx, self.electrode] + except: + embed() self.time = self.time - self.t0 self.frequency_time = self.frequency_time - self.t0 @@ -281,9 +287,7 @@ class ChirpPlotBuffer: ) # plot filtered instantaneous frequency - ax6.plot( - self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw - ) + ax6.plot(self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw) ax6.scatter( self.frequency_time[self.frequency_peaks], self.frequency_filtered[self.frequency_peaks], @@ -317,9 +321,7 @@ class ChirpPlotBuffer: # ax7.spines.bottom.set_bounds((0, 5)) ax0.set_xlim(0, self.config.window) - plt.subplots_adjust( - left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2 - ) + plt.subplots_adjust(left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2) fig.align_labels() if plot == "show": @@ -330,8 +332,9 @@ class ChirpPlotBuffer: self.config.outputdir + self.data.datapath.split("/")[-2] + "/" ) - plt.savefig(f"{out}{self.track_id}_{self.t0_old}.pdf") - plt.savefig(f"{out}{self.track_id}_{self.t0_old}.svg") + # plt.savefig(f"{out}{self.track_id}_{self.t0_old}.pdf") + # plt.savefig(f"{out}{self.track_id}_{self.t0_old}.svg") + plt.savefig(f"{out}{self.track_id}_{self.t0_old}.png") plt.close() @@ -424,9 +427,7 @@ def extract_frequency_bands( q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2 # filter baseline - filtered_baseline = bandpass_filter( - raw_data, samplerate, lowf=q25, highf=q75 - ) + filtered_baseline = bandpass_filter(raw_data, samplerate, lowf=q25, highf=q75) # filter search area filtered_search_freq = bandpass_filter( @@ -475,10 +476,7 @@ def window_median_all_track_ids( window_idx = np.arange(len(data.idx))[ (data.ident == track_id) & (data.time[data.idx] >= window_start_seconds) - & ( - data.time[data.idx] - <= (window_start_seconds + window_duration_seconds) - ) + & (data.time[data.idx] <= (window_start_seconds + window_duration_seconds)) ] if len(data.freq[window_idx]) > 0: @@ -609,9 +607,7 @@ def find_searchband( q75 = np.asarray([i[2] for i in frequency_percentiles]) # get tracks that fall into search window - check_track_ids = percentiles_ids[ - (q25 > current_median) & (q75 < search_window[-1]) - ] + check_track_ids = percentiles_ids[(q25 > current_median) & (q75 < search_window[-1])] # iterate through theses tracks if check_track_ids.size != 0: @@ -621,9 +617,7 @@ def find_searchband( bool_lower[search_window > q25_temp - config.search_res] = False bool_upper[search_window < q75_temp + config.search_res] = False - search_window_bool[ - (bool_lower == False) & (bool_upper == False) - ] = False + search_window_bool[(bool_lower == False) & (bool_upper == False)] = False # find gaps in search window search_window_indices = np.arange(len(search_window)) @@ -642,9 +636,7 @@ def find_searchband( # if the first value is -1, the array starst with true, so a gap if nonzeros[0] == -1: stops = search_window_indices[search_window_gaps == -1] - starts = np.append( - 0, search_window_indices[search_window_gaps == 1] - ) + starts = np.append(0, search_window_indices[search_window_gaps == 1]) # if the last value is -1, the array ends with true, so a gap if nonzeros[-1] == 1: @@ -749,8 +741,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: multiwindow_chirps = [] multiwindow_ids = [] - for st, window_start_index in enumerate(window_start_indices[1853:]): - + for st, window_start_index in enumerate(window_start_indices): logger.info(f"Processing window {st} of {len(window_start_indices)}") window_start_seconds = window_start_index / data.raw_rate @@ -765,9 +756,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: ) # iterate through all fish - for tr, track_id in enumerate( - np.unique(data.ident[~np.isnan(data.ident)]) - ): + for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): logger.debug(f"Processing track {tr} of {len(data.ids)}") # get index of track data in this time window @@ -786,26 +775,23 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # check if tracked data available in this window if len(current_frequencies) < 3: - logger.warning( - f"Track {track_id} has no data in window {st}, skipping." - ) + logger.warning(f"Track {track_id} has no data in window {st}, skipping.") continue # check if there are powers available in this window nanchecker = np.unique(np.isnan(current_powers)) if (len(nanchecker) == 1) and nanchecker[0] is True: logger.warning( - f"No powers available for track {track_id} window {st}," - "skipping." + f"No powers available for track {track_id} window {st}," "skipping." ) continue # find the strongest electrodes for the current fish in the current # window - best_electrode_index = np.argsort( - np.nanmean(current_powers, axis=0) - )[-config.number_electrodes :] + best_electrode_index = np.argsort(np.nanmean(current_powers, axis=0))[ + -config.number_electrodes : + ] # find a frequency above the baseline of the current fish in which # no other fish is active to search for chirps there @@ -826,8 +812,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # iterate through electrodes for el, electrode_index in enumerate(best_electrode_index): logger.debug( - f"Processing electrode {el+1} of " - f"{len(best_electrode_index)}" + f"Processing electrode {el+1} of " f"{len(best_electrode_index)}" ) # LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------ @@ -836,9 +821,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: current_raw_data = data.raw[ window_start_index:window_stop_index, electrode_index ] - current_raw_time = raw_time[ - window_start_index:window_stop_index - ] + current_raw_time = raw_time[window_start_index:window_stop_index] # EXTRACT FEATURES -------------------------------------------- @@ -924,7 +907,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # if not has_chirp( # baseline_frequency_filtered[amplitude_mask], # config.baseline_frequency_peakheight, - # ): + # ): # logger.warning( # f"Amplitude to small for the chirp detection of track {track_id} window {st},") # continue @@ -941,20 +924,14 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: current_raw_time = current_raw_time[no_edges] baselineband = baselineband[no_edges] - baseline_envelope_unfiltered = baseline_envelope_unfiltered[ - no_edges - ] + baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges] searchband = searchband[no_edges] baseline_envelope = baseline_envelope[no_edges] - search_envelope_unfiltered = search_envelope_unfiltered[ - no_edges - ] + search_envelope_unfiltered = search_envelope_unfiltered[no_edges] search_envelope = search_envelope[no_edges] baseline_frequency = baseline_frequency[no_edges] - baseline_frequency_filtered = baseline_frequency_filtered[ - no_edges - ] + baseline_frequency_filtered = baseline_frequency_filtered[no_edges] baseline_frequency_time = current_raw_time # # get instantaneous frequency withoup edges @@ -998,14 +975,11 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: baseline_frequency_filtered, prominence=config.frequency_prominence, ) - # DETECT CHIRPS IN SEARCH WINDOW ------------------------------ # get the peak timestamps from the peak indices - baseline_peak_timestamps = current_raw_time[ - baseline_peak_indices - ] + baseline_peak_timestamps = current_raw_time[baseline_peak_indices] search_peak_timestamps = current_raw_time[search_peak_indices] frequency_peak_timestamps = baseline_frequency_time[ @@ -1017,7 +991,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: one_feature_empty = ( len(baseline_peak_timestamps) == 0 or len(search_peak_timestamps) == 0 - #or len(frequency_peak_timestamps) == 0 + # or len(frequency_peak_timestamps) == 0 ) if one_feature_empty and (debug == "false"): @@ -1029,7 +1003,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: sublists = [ list(baseline_peak_timestamps), list(search_peak_timestamps), - #list(frequency_peak_timestamps), + # list(frequency_peak_timestamps), ] singleelectrode_chirps = group_timestamps( @@ -1038,7 +1012,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: difference_threshold=config.chirp_window_threshold, ) - # check it there are chirps detected after grouping, continue # with the loop if not @@ -1153,9 +1126,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # add flattened chirps to the list multiwindow_chirps_flat.extend(current_track_chirps) - multiwindow_ids_flat.extend( - list(np.ones_like(current_track_chirps) * track_id) - ) + multiwindow_ids_flat.extend(list(np.ones_like(current_track_chirps) * track_id)) # purge duplicates, i.e. chirps that are very close to each other # duplites arise due to overlapping windows @@ -1167,9 +1138,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: np.asarray(multiwindow_ids_flat) == track_id ] if len(tr_chirps) > 0: - tr_chirps_purged = purge_duplicates( - tr_chirps, config.chirp_window_threshold - ) + tr_chirps_purged = purge_duplicates(tr_chirps, config.chirp_window_threshold) purged_chirps.extend(list(tr_chirps_purged)) purged_ids.extend(list(np.ones_like(tr_chirps_purged) * track_id)) @@ -1189,4 +1158,5 @@ if __name__ == "__main__": # datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/" # datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/" datapath = "../data/2022-06-02-10_00/" - chirpdetection(datapath, plot="show", debug="false") + datapath = "../../../local_data/randgrid/" + chirpdetection(datapath, plot="save", debug="false")