diff --git a/code/chirpdetection.py b/code/chirpdetection.py
index 3272900..3aa9d2d 100644
--- a/code/chirpdetection.py
+++ b/code/chirpdetection.py
@@ -1,4 +1,4 @@
-import itertools
+from itertools import combinations, compress
 
 import numpy as np
 from IPython import embed
@@ -11,6 +11,7 @@ from sklearn.preprocessing import normalize
 
 from modules.filters import bandpass_filter, envelope, highpass_filter
 from modules.filehandling import ConfLoader, LoadData
+from modules.datahandling import flatten, purge_duplicates
 from modules.plotstyle import PlotStyle
 
 
@@ -517,7 +518,6 @@ def main(datapath: str) -> None:
                 axs[6, el].set_title(
                     "Filtered absolute instantaneous frequency")
 
-
                 # DETECT CHIRPS IN SEARCH WINDOW -------------------------------
 
                 baseline_ts = time_oi[baseline_peaks]
@@ -528,10 +528,9 @@ def main(datapath: str) -> None:
                 if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0:
                     continue
 
-                #current_chirps = group_timestamps_v2(
+                # current_chirps = group_timestamps_v2(
                 #    [list(baseline_ts), list(search_ts), list(freq_ts)], 3)
 
-               
                 # get index for each feature
                 baseline_idx = np.zeros_like(baseline_ts)
                 search_idx = np.ones_like(search_ts)
@@ -565,12 +564,11 @@ def main(datapath: str) -> None:
                         current_chirps.append(np.mean(timestamps[cm]))
                         electrodes_of_chirps.append(el)
                     bool_timestamps[cm] = False
-                
+
                 # for checking if there are chirps on multiple electrodes
 
                 chirps_electrodes.append(current_chirps)
 
-
                 for ct in current_chirps:
                     axs[0, el].axvline(ct, color='r', lw=1)
 
@@ -607,7 +605,7 @@ def main(datapath: str) -> None:
             index_vector = np.arange(len(sort_chirps_electrodes))
             # make it more than only two electrodes for the search after chirps
             combinations_best_elctrodes = list(
-                itertools.combinations(range(3), 2))
+                combinations(range(3), 2))
 
             the_real_chirps = []
             for chirp_index, seoc in enumerate(sort_chirps_electrodes):
@@ -616,15 +614,14 @@ def main(datapath: str) -> None:
                 cm = index_vector[(sort_chirps_electrodes >= seoc) & (
                     sort_chirps_electrodes <= seoc + config.chirp_window_threshold)]
 
-                
                 chirps_unique = []
                 for combination in combinations_best_elctrodes:
                     if set(combination).issubset(sort_electrodes[cm]):
-                        chirps_unique.append(np.mean(sort_chirps_electrodes[cm]))
+                        chirps_unique.append(
+                            np.mean(sort_chirps_electrodes[cm]))
 
                 the_real_chirps.append(np.mean(chirps_unique))
 
-
                 """
                 if set([0,1]).issubset(sort_electrodes[cm]):
                     the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
@@ -638,16 +635,14 @@ def main(datapath: str) -> None:
                 bool_vector[cm] = False
             chirps.append(the_real_chirps)
             fish_ids.append(track_id)
-            
 
             for ct in the_real_chirps:
                 axs[0, el].axvline(ct, color='b', lw=1)
-            
+
     plt.close()
-    embed()
     fig, ax = plt.subplots()
     t0 = (3 * 60 * 60 + 6 * 60 + 43.5)
-    data_oi = data.raw[window_starts[0]:window_starts[-1]+ int(dt*data.raw_rate), 10]
+    data_oi = data.raw[window_starts[0]:window_starts[-1] + int(dt*data.raw_rate), 10]
     plot_spectrogram(ax, data_oi, data.raw_rate, t0)
     chirps_concat = np.concatenate(chirps)
     for ch in chirps_concat:
@@ -655,15 +650,24 @@ def main(datapath: str) -> None:
 
     chirps_new = []
     chirps_ids = []
-    [chirps[x] for x in tr_index]
     for tr in np.unique(fish_ids):
         tr_index = np.asarray(fish_ids) == tr
-
-        ts = list(np.ravel(chirps[fish_ids == int(tr)]))
+        ts = flatten(list(compress(chirps, tr_index)))
         chirps_new.extend(ts)
         chirps_ids.extend(list(np.ones_like(ts)*tr))
 
+    # purge duplicates
+    purged_chirps = []
+    purged_chirps_ids = []
+    for tr in np.unique(fish_ids):
+        tr_chirps = np.asarray(chirps_new)[np.asarray(chirps_ids) == tr]
+        if len(tr_chirps) > 0:
+            tr_chirps_purged = purge_duplicates(
+                tr_chirps, config.chirp_window_threshold)
+            purged_chirps.extend(list(tr_chirps_purged))
+            purged_chirps_ids.extend(list(np.ones_like(tr_chirps_purged)*tr))
 
+    embed()
 
 
 if __name__ == "__main__":
diff --git a/code/modules/timestamps.py b/code/modules/datahandling.py
similarity index 64%
rename from code/modules/timestamps.py
rename to code/modules/datahandling.py
index 80c3d5a..53778ff 100644
--- a/code/modules/timestamps.py
+++ b/code/modules/datahandling.py
@@ -1,11 +1,15 @@
 import numpy as np
-from typing import List
+from typing import List, Union, Any
 
 
-def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[float]:
+def purge_duplicates(
+    timestamps: List[float], threshold: float = 0.5
+) -> List[float]:
     """
-    Compute the mean of groups of timestamps that are closer to the previous or consecutive timestamp than the threshold,
-    and return all timestamps that are further apart from the previous or consecutive timestamp than the threshold in a single list.
+    Compute the mean of groups of timestamps that are closer to the previous
+    or consecutive timestamp than the threshold, and return all timestamps that
+    are further apart from the previous or consecutive timestamp than the
+    threshold in a single list.
 
     Parameters
     ----------
@@ -17,10 +21,12 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl
     Returns
     -------
     List[float]
-        A list containing a list of timestamps that are further apart than the threshold
-        and a list of means of the groups of timestamps that are closer to the previous or consecutive timestamp than the threshold.
+        A list containing a list of timestamps that are further apart than
+        the threshold and a list of means of the groups of timestamps that
+        are closer to the previous or consecutive timestamp than the threshold.
     """
-    # Initialize an empty list to store the groups of timestamps that are closer to the previous or consecutive timestamp than the threshold
+    # Initialize an empty list to store the groups of timestamps that are
+    # closer to the previous or consecutive timestamp than the threshold
     groups = []
 
     # initialize the first group with the first timestamp
@@ -28,8 +34,9 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl
 
     for i in range(1, len(timestamps)):
 
-        # check the difference between current timestamp and previous timestamp is less than the threshold
-        if timestamps[i] - timestamps[i-1] < threshold:
+        # check the difference between current timestamp and previous
+        # timestamp is less than the threshold
+        if timestamps[i] - timestamps[i - 1] < threshold:
             # add the current timestamp to the current group
             group.append(timestamps[i])
         else:
@@ -40,22 +47,28 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl
             # start a new group with the current timestamp
             group = [timestamps[i]]
 
-    # after iterating through all the timestamps, add the last group to the groups list
+    # after iterating through all the timestamps, add the last group to the
+    # groups list
     groups.append(group)
 
-    # get the mean of each group and only include the ones that have more than 1 timestamp
+    # get the mean of each group and only include the ones that have more
+    # than 1 timestamp
     means = [np.mean(group) for group in groups if len(group) > 1]
 
-    # get the timestamps that are outliers, i.e. the ones that are alone in a group
+    # get the timestamps that are outliers, i.e. the ones that are alone
+    # in a group
     outliers = [ts for group in groups for ts in group if len(group) == 1]
 
     # return the outliers and means in a single list
     return outliers + means
 
 
-def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> List[float]:
+def group_timestamps(
+    sublists: List[List[float]], n: int, threshold: float
+) -> List[float]:
     """
-    Groups timestamps that are less than `threshold` milliseconds apart from at least `n` other sublists.
+    Groups timestamps that are less than `threshold` milliseconds apart from
+    at least `n` other sublists.
     Returns a list of the mean of each group.
     If any of the sublists is empty, it will be ignored.
 
@@ -64,9 +77,11 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
     sublists : List[List[float]]
         a list of sublists, each containing timestamps
     n : int
-        minimum number of sublists that a timestamp must be close to in order to be grouped
+        minimum number of sublists that a timestamp must be close to in order
+        to be grouped
     threshold : float
-        the maximum difference in milliseconds between timestamps to be considered a match
+        the maximum difference in milliseconds between timestamps to be
+        considered a match
 
     Returns
     -------
@@ -76,7 +91,8 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
     """
     # Flatten the sublists and sort the timestamps
     timestamps = [
-        timestamp for sublist in sublists if sublist for timestamp in sublist]
+        timestamp for sublist in sublists if sublist for timestamp in sublist
+    ]
     timestamps.sort()
 
     groups = []
@@ -84,7 +100,7 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
 
     # Group timestamps that are less than threshold milliseconds apart
     for i in range(1, len(timestamps)):
-        if timestamps[i] - timestamps[i-1] < threshold:
+        if timestamps[i] - timestamps[i - 1] < threshold:
             current_group.append(timestamps[i])
         else:
             groups.append(current_group)
@@ -104,10 +120,32 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
     return means
 
 
+def flatten(list: List[List[Any]]) -> List:
+    """
+    Flattens a list / array of lists.
+
+    Parameters
+    ----------
+    l : array or list of lists
+        The list to be flattened
+
+    Returns
+    -------
+    list
+        The flattened list
+    """
+    return [item for sublist in list for item in sublist]
+
+
 if __name__ == "__main__":
 
-    timestamps = [[1.2, 1.5, 1.3], [],
-                  [1.21, 1.51, 1.31], [1.19, 1.49, 1.29], [1.22, 1.52, 1.32], [1.2, 1.5, 1.3]]
+    timestamps = [
+        [1.2, 1.5, 1.3],
+        [],
+        [1.21, 1.51, 1.31],
+        [1.19, 1.49, 1.29],
+        [1.22, 1.52, 1.32],
+        [1.2, 1.5, 1.3],
+    ]
     print(group_timestamps(timestamps, 2, 0.05))
-    print(purge_duplicates(
-        [1, 2, 3, 4, 5, 6, 6.02, 7, 8, 8.02], 0.05))
+    print(purge_duplicates([1, 2, 3, 4, 5, 6, 6.02, 7, 8, 8.02], 0.05))