diff --git a/code/chirpdetection.py b/code/chirpdetection.py index 348851a..425664e 100755 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -15,7 +15,6 @@ from modules.plotstyle import PlotStyle from modules.logger import makeLogger from modules.datahandling import ( flatten, - norm, purge_duplicates, group_timestamps, instantaneous_frequency, @@ -351,9 +350,16 @@ def extract_frequency_bands( def window_median_all_track_ids( data: LoadData, window_start_seconds: float, window_duration_seconds: float -) -> tuple[float, list[int]]: +) -> tuple[list[tuple[float, float, float]], list[int]]: """ - Calculate the median frequency of all fish in a given time window. + Calculate the median and quantiles of the frequency of all fish in a + given time window. + + Iterate over all track ids and calculate the 25, 50 and 75 percentile + in a given time window to pass this data to 'find_searchband' function, + which then determines whether other fish in the current window fall + within the searchband of the current fish and then determine the + gaps that are outside of the percentile ranges. Parameters ---------- @@ -366,14 +372,16 @@ def window_median_all_track_ids( Returns ------- - tuple[float, list[int]] + tuple[list[tuple[float, float, float]], list[int]] """ - median_freq = [] + frequency_percentiles = [] track_ids = [] for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): + + # the window index combines the track id and the time window window_idx = np.arange(len(data.idx))[ (data.ident == track_id) & (data.time[data.idx] >= window_start_seconds) @@ -384,20 +392,21 @@ def window_median_all_track_ids( ] if len(data.freq[window_idx]) > 0: - median_freq.append(np.median(data.freq[window_idx])) + frequency_percentiles.append( + np.percentile(data.freq[window_idx], [25, 50, 75])) track_ids.append(track_id) # convert to numpy array - median_freq = np.asarray(median_freq) + frequency_percentiles = np.asarray(frequency_percentiles) track_ids = np.asarray(track_ids) - return median_freq, track_ids + return frequency_percentiles, track_ids def find_searchband( - freq_temp: np.ndarray, - median_ids: np.ndarray, - median_freq: np.ndarray, + current_frequency: np.ndarray, + percentiles_ids: np.ndarray, + frequency_percentiles: np.ndarray, config: ConfLoader, data: LoadData, ) -> float: @@ -407,13 +416,13 @@ def find_searchband( Parameters ---------- - freq_temp : np.ndarray + current_frequency : np.ndarray Current EOD frequency array / the current fish of interest. - median_ids : np.ndarray + percentiles_ids : np.ndarray Array of track IDs of the medians of all other fish in the current window. - median_freq : np.ndarray - Array of median frequencies of all other fish in the current window. + frequency_percentiles : np.ndarray + Array of percentiles frequencies of all other fish in the current window. config : ConfLoader Configuration file. data : LoadData @@ -424,19 +433,27 @@ def find_searchband( float """ - # frequency where second filter filters + # frequency window where second filter filters is potentially allowed + # to filter. This is the search window, in which we want to find + # a gap in the other fish's EODs. + search_window = np.arange( - np.median(freq_temp) + config.search_df_lower, - np.median(freq_temp) + config.search_df_upper, + np.median(current_frequency) + config.search_df_lower, + np.median(current_frequency) + config.search_df_upper, config.search_res, ) # search window in boolean - search_window_bool = np.ones(len(search_window), dtype=bool) + search_window_bool = np.ones_like(len(search_window), dtype=bool) + + # make seperate arrays from the qartiles + q25 = np.asarray([i[0] for i in frequency_percentiles]) + q75 = np.asarray([i[2] for i in frequency_percentiles]) # get tracks that fall into search window - check_track_ids = median_ids[ - (median_freq > search_window[0]) & (median_freq < search_window[-1]) + check_track_ids = percentiles_ids[ + (q25 > search_window[0]) & ( + q75 < search_window[-1]) ] # iterate through theses tracks @@ -444,25 +461,26 @@ def find_searchband( for j, check_track_id in enumerate(check_track_ids): - q1, q2 = np.percentile( - data.freq[data.ident == check_track_id], [25, 75] - ) - print(q1, q2) + q25_temp = q25[percentiles_ids == check_track_id] + q75_temp = q75[percentiles_ids == check_track_id] + + print(q25_temp, q75_temp) search_window_bool[ - (search_window > q1) & (search_window < q2) + (search_window > q25_temp) & (search_window < q75_temp) ] = False # find gaps in search window search_window_indices = np.arange(len(search_window)) # get search window gaps + # taking the diff of a boolean array gives non zero values where the + # array changes from true to false or vice versa + search_window_gaps = np.diff(search_window_bool, append=np.nan) nonzeros = search_window_gaps[np.nonzero(search_window_gaps)[0]] nonzeros = nonzeros[~np.isnan(nonzeros)] - embed() - # if the first value is -1, the array starst with true, so a gap if nonzeros[0] == -1: stops = search_window_indices[search_window_gaps == -1] @@ -494,14 +512,16 @@ def find_searchband( search_windows_lens = [len(x) for x in search_windows] longest_search_window = search_windows[np.argmax(search_windows_lens)] + # the center of the search frequency band is then the center of + # the longest gap + search_freq = ( longest_search_window[-1] - longest_search_window[0] ) / 2 - else: - search_freq = config.default_search_freq + return search_freq - return search_freq + return config.default_search_freq def main(datapath: str, plot: str) -> None: @@ -637,10 +657,10 @@ def main(datapath: str, plot: str) -> None: search_frequency = find_searchband( config=config, - freq_temp=current_frequencies, - median_ids=median_ids, + current_frequency=current_frequencies, + percentiles_ids=median_ids, data=data, - median_freq=median_freq, + frequency_percentiles=median_freq, ) # add all chirps that are detected on mulitple electrodes for one @@ -1001,4 +1021,4 @@ if __name__ == "__main__": datapath = "../data/2022-06-02-10_00/" # datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/" # datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/" - main(datapath, plot="save") + main(datapath, plot="show")