From 9422af9fb09de98c8018a0b112e0f47e758942df Mon Sep 17 00:00:00 2001
From: weygoldt <88969563+weygoldt@users.noreply.github.com>
Date: Wed, 18 Jan 2023 12:42:41 +0100
Subject: [PATCH] purge duplicates

---
 code/chirpdetection.py                 | 38 +++++----
 .../{timestamps.py => datahandling.py} | 82 ++++++++++++++-----
 2 files changed, 81 insertions(+), 39 deletions(-)
 rename code/modules/{timestamps.py => datahandling.py} (64%)

diff --git a/code/chirpdetection.py b/code/chirpdetection.py
index 3272900..3aa9d2d 100644
--- a/code/chirpdetection.py
+++ b/code/chirpdetection.py
@@ -1,4 +1,4 @@
-import itertools
+from itertools import combinations, compress
 
 import numpy as np
 from IPython import embed
@@ -11,6 +11,7 @@ from sklearn.preprocessing import normalize
 
 from modules.filters import bandpass_filter, envelope, highpass_filter
 from modules.filehandling import ConfLoader, LoadData
+from modules.datahandling import flatten, purge_duplicates
 from modules.plotstyle import PlotStyle
 
 
@@ -517,7 +518,6 @@ def main(datapath: str) -> None:
             axs[6, el].set_title(
                 "Filtered absolute instantaneous frequency")
 
-
             # DETECT CHIRPS IN SEARCH WINDOW -------------------------------
 
             baseline_ts = time_oi[baseline_peaks]
@@ -528,10 +528,9 @@ def main(datapath: str) -> None:
             if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0:
                 continue
 
-            #current_chirps = group_timestamps_v2(
+            # current_chirps = group_timestamps_v2(
             #     [list(baseline_ts), list(search_ts), list(freq_ts)], 3)
 
-
             # get index for each feature
             baseline_idx = np.zeros_like(baseline_ts)
             search_idx = np.ones_like(search_ts)
@@ -565,12 +564,11 @@ def main(datapath: str) -> None:
                 current_chirps.append(np.mean(timestamps[cm]))
                 electrodes_of_chirps.append(el)
                 bool_timestamps[cm] = False
-
+
             # for checking if there are chirps on multiple electrodes
             chirps_electrodes.append(current_chirps)
 
-
             for ct in current_chirps:
                 axs[0, el].axvline(ct, color='r', lw=1)
 
@@ -607,7 +605,7 @@ def main(datapath: str) -> None:
         index_vector = np.arange(len(sort_chirps_electrodes))
         # make it more than only two electrodes for the search after chirps
         combinations_best_elctrodes = list(
-            itertools.combinations(range(3), 2))
+            combinations(range(3), 2))
 
         the_real_chirps = []
         for chirp_index, seoc in enumerate(sort_chirps_electrodes):
@@ -616,15 +614,14 @@ def main(datapath: str) -> None:
             cm = index_vector[(sort_chirps_electrodes >= seoc) & (
                 sort_chirps_electrodes <= seoc + config.chirp_window_threshold)]
-
             chirps_unique = []
             for combination in combinations_best_elctrodes:
                 if set(combination).issubset(sort_electrodes[cm]):
-                    chirps_unique.append(np.mean(sort_chirps_electrodes[cm]))
+                    chirps_unique.append(
+                        np.mean(sort_chirps_electrodes[cm]))
 
             the_real_chirps.append(np.mean(chirps_unique))
-
             """
             if set([0,1]).issubset(sort_electrodes[cm]):
                 the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
@@ -638,16 +635,14 @@ def main(datapath: str) -> None:
             bool_vector[cm] = False
         chirps.append(the_real_chirps)
         fish_ids.append(track_id)
-
         for ct in the_real_chirps:
             axs[0, el].axvline(ct, color='b', lw=1)
-
+
         plt.close()
-        embed()
 
     fig, ax = plt.subplots()
    t0 = (3 * 60 * 60 + 6 * 60 + 43.5)
-    data_oi = data.raw[window_starts[0]:window_starts[-1]+ int(dt*data.raw_rate), 10]
+    data_oi = data.raw[window_starts[0]:window_starts[-1] + int(dt*data.raw_rate), 10]
     plot_spectrogram(ax, data_oi, data.raw_rate, t0)
     chirps_concat = np.concatenate(chirps)
     for ch in chirps_concat:
@@ -655,15 +650,24 @@ def main(datapath: str) -> None:
     chirps_new = []
     chirps_ids = []
-    [chirps[x] for x in tr_index]
     for tr in np.unique(fish_ids):
         tr_index = np.asarray(fish_ids) == tr
-
-        ts = list(np.ravel(chirps[fish_ids == int(tr)]))
+        ts = flatten(list(compress(chirps, tr_index)))
         chirps_new.extend(ts)
         chirps_ids.extend(list(np.ones_like(ts)*tr))
+
+    # purge duplicates
+    purged_chirps = []
+    purged_chirps_ids = []
+    for tr in np.unique(fish_ids):
+        tr_chirps = np.asarray(chirps_new)[np.asarray(chirps_ids) == tr]
+        if len(tr_chirps) > 0:
+            tr_chirps_purged = purge_duplicates(
+                tr_chirps, config.chirp_window_threshold)
+            purged_chirps.extend(list(tr_chirps_purged))
+            purged_chirps_ids.extend(list(np.ones_like(tr_chirps_purged)*tr))
+
     embed()
 
 
 if __name__ == "__main__":
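Review note on the masking rewrite in the last hunk above: `chirps` is a ragged Python list of per-window timestamp lists and `fish_ids` is a plain list, so in the removed line `fish_ids == int(tr)` evaluates to `False` and `chirps[False]` silently indexes `chirps[0]`, whatever fish is requested; the stray `[chirps[x] for x in tr_index]` that the patch deletes was likewise dead code. A minimal sketch contrasting the old and new selection on made-up values, where the comprehension stands in for the `flatten` helper introduced below:

    from itertools import compress

    import numpy as np

    # made-up per-window chirp timestamps and the track id of each window
    chirps = [[1.20, 2.50], [1.21], [5.00]]
    fish_ids = [0, 0, 1]
    tr = 0

    # old selection: a plain list compared to an int is False, and
    # chirps[False] is just chirps[0], so only the first window is ever used
    old = list(np.ravel(chirps[fish_ids == int(tr)]))

    # new selection: a boolean mask plus compress picks every window of
    # this fish, and flattening joins them into one list
    tr_index = np.asarray(fish_ids) == tr
    new = [t for window in compress(chirps, tr_index) for t in window]
    print(old, new)  # the new list also contains 1.21 from the second window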
diff --git a/code/modules/timestamps.py b/code/modules/datahandling.py
similarity index 64%
rename from code/modules/timestamps.py
rename to code/modules/datahandling.py
index 80c3d5a..53778ff 100644
--- a/code/modules/timestamps.py
+++ b/code/modules/datahandling.py
@@ -1,11 +1,15 @@
 import numpy as np
-from typing import List
+from typing import List, Union, Any
 
 
-def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[float]:
+def purge_duplicates(
+    timestamps: List[float], threshold: float = 0.5
+) -> List[float]:
     """
-    Compute the mean of groups of timestamps that are closer to the previous or consecutive timestamp than the threshold,
-    and return all timestamps that are further apart from the previous or consecutive timestamp than the threshold in a single list.
+    Compute the mean of each group of timestamps that are closer to the
+    previous or next timestamp than the threshold, and return these means
+    together with all timestamps that are further apart than the threshold
+    in a single list.
 
     Parameters
     ----------
@@ -17,10 +21,12 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl
     Returns
     -------
     List[float]
-        A list containing a list of timestamps that are further apart than the threshold
-        and a list of means of the groups of timestamps that are closer to the previous or consecutive timestamp than the threshold.
+        A single list containing the timestamps that are further apart than
+        the threshold, followed by the means of the groups of timestamps
+        that are closer to the previous or next timestamp than the threshold.
""" - # Initialize an empty list to store the groups of timestamps that are closer to the previous or consecutive timestamp than the threshold + # Initialize an empty list to store the groups of timestamps that are + # closer to the previous or consecutive timestamp than the threshold groups = [] # initialize the first group with the first timestamp @@ -28,8 +34,9 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl for i in range(1, len(timestamps)): - # check the difference between current timestamp and previous timestamp is less than the threshold - if timestamps[i] - timestamps[i-1] < threshold: + # check the difference between current timestamp and previous + # timestamp is less than the threshold + if timestamps[i] - timestamps[i - 1] < threshold: # add the current timestamp to the current group group.append(timestamps[i]) else: @@ -40,22 +47,28 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl # start a new group with the current timestamp group = [timestamps[i]] - # after iterating through all the timestamps, add the last group to the groups list + # after iterating through all the timestamps, add the last group to the + # groups list groups.append(group) - # get the mean of each group and only include the ones that have more than 1 timestamp + # get the mean of each group and only include the ones that have more + # than 1 timestamp means = [np.mean(group) for group in groups if len(group) > 1] - # get the timestamps that are outliers, i.e. the ones that are alone in a group + # get the timestamps that are outliers, i.e. the ones that are alone + # in a group outliers = [ts for group in groups for ts in group if len(group) == 1] # return the outliers and means in a single list return outliers + means -def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> List[float]: +def group_timestamps( + sublists: List[List[float]], n: int, threshold: float +) -> List[float]: """ - Groups timestamps that are less than `threshold` milliseconds apart from at least `n` other sublists. + Groups timestamps that are less than `threshold` milliseconds apart from + at least `n` other sublists. Returns a list of the mean of each group. If any of the sublists is empty, it will be ignored. 
@@ -64,9 +77,11 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
     sublists : List[List[float]]
         a list of sublists, each containing timestamps
     n : int
-        minimum number of sublists that a timestamp must be close to in order to be grouped
+        minimum number of sublists that a timestamp must be close to in order
+        to be grouped
     threshold : float
-        the maximum difference in milliseconds between timestamps to be considered a match
+        the maximum difference in milliseconds between timestamps to be
+        considered a match
 
     Returns
     -------
@@ -76,7 +91,8 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
     """
     # Flatten the sublists and sort the timestamps
     timestamps = [
-        timestamp for sublist in sublists if sublist for timestamp in sublist]
+        timestamp for sublist in sublists if sublist for timestamp in sublist
+    ]
     timestamps.sort()
 
     groups = []
@@ -84,7 +100,7 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
 
     # Group timestamps that are less than threshold milliseconds apart
     for i in range(1, len(timestamps)):
-        if timestamps[i] - timestamps[i-1] < threshold:
+        if timestamps[i] - timestamps[i - 1] < threshold:
             current_group.append(timestamps[i])
         else:
             groups.append(current_group)
@@ -104,10 +120,32 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
     return means
 
 
+def flatten(lst: List[List[Any]]) -> List[Any]:
+    """
+    Flattens a list / array of lists.
+
+    Parameters
+    ----------
+    lst : array or list of lists
+        The list to be flattened
+
+    Returns
+    -------
+    list
+        The flattened list
+    """
+    return [item for sublist in lst for item in sublist]
+
+
 if __name__ == "__main__":
-    timestamps = [[1.2, 1.5, 1.3], [],
-                  [1.21, 1.51, 1.31], [1.19, 1.49, 1.29], [1.22, 1.52, 1.32], [1.2, 1.5, 1.3]]
+    timestamps = [
+        [1.2, 1.5, 1.3],
+        [],
+        [1.21, 1.51, 1.31],
+        [1.19, 1.49, 1.29],
+        [1.22, 1.52, 1.32],
+        [1.2, 1.5, 1.3],
+    ]
     print(group_timestamps(timestamps, 2, 0.05))
-    print(purge_duplicates(
-        [1, 2, 3, 4, 5, 6, 6.02, 7, 8, 8.02], 0.05))
+    print(purge_duplicates([1, 2, 3, 4, 5, 6, 6.02, 7, 8, 8.02], 0.05))
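A minimal usage sketch of how the renamed `datahandling` helpers compose in the detection pipeline, assuming it is run from the repository's `code/` directory so that `modules.datahandling` imports; the timestamps and the 0.05 threshold are made-up stand-ins for real detections and `config.chirp_window_threshold`:

    from itertools import compress

    import numpy as np

    from modules.datahandling import flatten, purge_duplicates

    # made-up per-window chirp timestamps and the track id of each window
    chirps = [[1.20, 2.50], [1.21, 2.51], [5.00]]
    fish_ids = [0, 0, 1]

    for tr in np.unique(fish_ids):
        # boolean mask selecting the windows that belong to this fish
        tr_index = np.asarray(fish_ids) == tr
        # one flat, sorted list of this fish's timestamps; sorting matters
        # because purge_duplicates only compares consecutive entries
        ts = sorted(flatten(list(compress(chirps, tr_index))))
        # detections closer together than the threshold collapse to their mean
        print(tr, purge_duplicates(ts, 0.05))
    # fish 0 -> the two near-coincident pairs collapse to means near
    # 1.205 and 2.505; fish 1 -> the lone 5.0 passes through untouched

The explicit `sorted` is the one deviation from the patched loop: `purge_duplicates` groups by the difference between consecutive entries, so its output is only meaningful for ordered input.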