purge duplicates

weygoldt 2023-01-18 12:42:41 +01:00
parent ba3803bb56
commit 9422af9fb0
2 changed files with 81 additions and 39 deletions

View File

@@ -1,4 +1,4 @@
-import itertools
+from itertools import combinations, compress
import numpy as np
from IPython import embed
@@ -11,6 +11,7 @@ from sklearn.preprocessing import normalize
from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.filehandling import ConfLoader, LoadData
+from modules.datahandling import flatten, purge_duplicates
from modules.plotstyle import PlotStyle
@@ -517,7 +518,6 @@ def main(datapath: str) -> None:
axs[6, el].set_title(
    "Filtered absolute instantaneous frequency")
# DETECT CHIRPS IN SEARCH WINDOW -------------------------------
baseline_ts = time_oi[baseline_peaks]
@@ -528,10 +528,9 @@ def main(datapath: str) -> None:
if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0:
    continue
-#current_chirps = group_timestamps_v2(
+# current_chirps = group_timestamps_v2(
# [list(baseline_ts), list(search_ts), list(freq_ts)], 3)
# get index for each feature
baseline_idx = np.zeros_like(baseline_ts)
search_idx = np.ones_like(search_ts)
@@ -565,12 +564,11 @@ def main(datapath: str) -> None:
current_chirps.append(np.mean(timestamps[cm]))
electrodes_of_chirps.append(el)
bool_timestamps[cm] = False
# for checking if there are chirps on multiple electrodes
chirps_electrodes.append(current_chirps)
for ct in current_chirps:
    axs[0, el].axvline(ct, color='r', lw=1)
@@ -607,7 +605,7 @@ def main(datapath: str) -> None:
index_vector = np.arange(len(sort_chirps_electrodes))
# make it more than only two electrodes for the search after chirps
combinations_best_elctrodes = list(
-    itertools.combinations(range(3), 2))
+    combinations(range(3), 2))
the_real_chirps = []
for chirp_index, seoc in enumerate(sort_chirps_electrodes):
@@ -616,15 +614,14 @@ def main(datapath: str) -> None:
cm = index_vector[(sort_chirps_electrodes >= seoc) & (
    sort_chirps_electrodes <= seoc + config.chirp_window_threshold)]
chirps_unique = []
for combination in combinations_best_elctrodes:
    if set(combination).issubset(sort_electrodes[cm]):
-        chirps_unique.append(np.mean(sort_chirps_electrodes[cm]))
+        chirps_unique.append(
+            np.mean(sort_chirps_electrodes[cm]))
the_real_chirps.append(np.mean(chirps_unique))
"""
if set([0,1]).issubset(sort_electrodes[cm]):
    the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
@@ -638,16 +635,14 @@ def main(datapath: str) -> None:
bool_vector[cm] = False
chirps.append(the_real_chirps)
fish_ids.append(track_id)
for ct in the_real_chirps:
    axs[0, el].axvline(ct, color='b', lw=1)
plt.close()
-embed()
fig, ax = plt.subplots()
t0 = (3 * 60 * 60 + 6 * 60 + 43.5)
-data_oi = data.raw[window_starts[0]:window_starts[-1]+ int(dt*data.raw_rate), 10]
+data_oi = data.raw[window_starts[0]:window_starts[-1] + int(dt*data.raw_rate), 10]
plot_spectrogram(ax, data_oi, data.raw_rate, t0)
chirps_concat = np.concatenate(chirps)
for ch in chirps_concat:
@@ -655,15 +650,24 @@ def main(datapath: str) -> None:
chirps_new = []
chirps_ids = []
-[chirps[x] for x in tr_index]
for tr in np.unique(fish_ids):
    tr_index = np.asarray(fish_ids) == tr
-    ts = list(np.ravel(chirps[fish_ids == int(tr)]))
+    ts = flatten(list(compress(chirps, tr_index)))
    chirps_new.extend(ts)
    chirps_ids.extend(list(np.ones_like(ts)*tr))
+# purge duplicates
+purged_chirps = []
+purged_chirps_ids = []
+for tr in np.unique(fish_ids):
+    tr_chirps = np.asarray(chirps_new)[np.asarray(chirps_ids) == tr]
+    if len(tr_chirps) > 0:
+        tr_chirps_purged = purge_duplicates(
+            tr_chirps, config.chirp_window_threshold)
+        purged_chirps.extend(list(tr_chirps_purged))
+        purged_chirps_ids.extend(list(np.ones_like(tr_chirps_purged)*tr))
+embed()

if __name__ == "__main__":
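Side note on the changed line `ts = flatten(list(compress(chirps, tr_index)))`: `compress` picks the per-window chirp lists that belong to one fish and `flatten` (added to `modules.datahandling` below) concatenates them into a single list of chirp times, even when the per-window lists have different lengths. A minimal sketch with invented values:

    from itertools import compress

    import numpy as np

    from modules.datahandling import flatten

    chirps = [[10.0, 10.5], [11.2], [12.0, 12.3, 12.9]]  # toy chirp lists, one per window
    fish_ids = [1, 2, 1]                                 # toy track id of each window

    tr_index = np.asarray(fish_ids) == 1
    ts = flatten(list(compress(chirps, tr_index)))
    # ts == [10.0, 10.5, 12.0, 12.3, 12.9]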

View File

@@ -1,11 +1,15 @@
import numpy as np
-from typing import List
+from typing import List, Union, Any
-def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[float]:
+def purge_duplicates(
+    timestamps: List[float], threshold: float = 0.5
+) -> List[float]:
    """
-    Compute the mean of groups of timestamps that are closer to the previous or consecutive timestamp than the threshold,
-    and return all timestamps that are further apart from the previous or consecutive timestamp than the threshold in a single list.
+    Compute the mean of groups of timestamps that are closer to the previous
+    or consecutive timestamp than the threshold, and return all timestamps that
+    are further apart from the previous or consecutive timestamp than the
+    threshold in a single list.

    Parameters
    ----------
@@ -17,10 +21,12 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl
Returns
-------
List[float]
-    A list containing a list of timestamps that are further apart than the threshold
-    and a list of means of the groups of timestamps that are closer to the previous or consecutive timestamp than the threshold.
+    A list containing a list of timestamps that are further apart than
+    the threshold and a list of means of the groups of timestamps that
+    are closer to the previous or consecutive timestamp than the threshold.
"""
-# Initialize an empty list to store the groups of timestamps that are closer to the previous or consecutive timestamp than the threshold
+# Initialize an empty list to store the groups of timestamps that are
+# closer to the previous or consecutive timestamp than the threshold
groups = []
# initialize the first group with the first timestamp
@@ -28,8 +34,9 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl
for i in range(1, len(timestamps)):
-    # check the difference between current timestamp and previous timestamp is less than the threshold
-    if timestamps[i] - timestamps[i-1] < threshold:
+    # check the difference between current timestamp and previous
+    # timestamp is less than the threshold
+    if timestamps[i] - timestamps[i - 1] < threshold:
        # add the current timestamp to the current group
        group.append(timestamps[i])
    else:
@@ -40,22 +47,28 @@ def purge_duplicates(timestamps: List[float], threshold: float = 0.5) -> List[fl
# start a new group with the current timestamp
group = [timestamps[i]]
-# after iterating through all the timestamps, add the last group to the groups list
+# after iterating through all the timestamps, add the last group to the
+# groups list
groups.append(group)
-# get the mean of each group and only include the ones that have more than 1 timestamp
+# get the mean of each group and only include the ones that have more
+# than 1 timestamp
means = [np.mean(group) for group in groups if len(group) > 1]
-# get the timestamps that are outliers, i.e. the ones that are alone in a group
+# get the timestamps that are outliers, i.e. the ones that are alone
+# in a group
outliers = [ts for group in groups for ts in group if len(group) == 1]
# return the outliers and means in a single list
return outliers + means
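To make the grouping concrete, a small sketch of what `purge_duplicates` returns for a sorted toy input (numbers invented for illustration):

    from modules.datahandling import purge_duplicates

    ts = [1.00, 1.02, 1.03, 5.00, 9.00, 9.30]
    print(purge_duplicates(ts, 0.05))
    # 1.00, 1.02 and 1.03 are closer than 0.05 and collapse to their mean,
    # the remaining timestamps stand alone:
    # -> [5.0, 9.0, 9.3, 1.0166...]  (outliers first, then the group mean)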
-def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> List[float]:
+def group_timestamps(
+    sublists: List[List[float]], n: int, threshold: float
+) -> List[float]:
    """
-    Groups timestamps that are less than `threshold` milliseconds apart from at least `n` other sublists.
+    Groups timestamps that are less than `threshold` milliseconds apart from
+    at least `n` other sublists.
    Returns a list of the mean of each group.
    If any of the sublists is empty, it will be ignored.
@@ -64,9 +77,11 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
sublists : List[List[float]]
    a list of sublists, each containing timestamps
n : int
-    minimum number of sublists that a timestamp must be close to in order to be grouped
+    minimum number of sublists that a timestamp must be close to in order
+    to be grouped
threshold : float
-    the maximum difference in milliseconds between timestamps to be considered a match
+    the maximum difference in milliseconds between timestamps to be
+    considered a match
Returns
-------
@@ -76,7 +91,8 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
"""
# Flatten the sublists and sort the timestamps
timestamps = [
-    timestamp for sublist in sublists if sublist for timestamp in sublist]
+    timestamp for sublist in sublists if sublist for timestamp in sublist
+]
timestamps.sort()
groups = []
@@ -84,7 +100,7 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
# Group timestamps that are less than threshold milliseconds apart
for i in range(1, len(timestamps)):
-    if timestamps[i] - timestamps[i-1] < threshold:
+    if timestamps[i] - timestamps[i - 1] < threshold:
        current_group.append(timestamps[i])
    else:
        groups.append(current_group)
@@ -104,10 +120,32 @@ def group_timestamps(sublists: List[List[float]], n: int, threshold: float) -> L
return means
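Only part of `group_timestamps` is visible in the hunks above (flattening, sorting, and grouping of nearby timestamps); the later filtering by `n` falls outside the shown context. A sketch of just that grouping pass, with invented inputs and an assumed threshold of 0.05:

    sublists = [[1.20, 1.50], [], [1.21, 1.49]]
    timestamps = sorted(t for sub in sublists if sub for t in sub)
    # -> [1.20, 1.21, 1.49, 1.50]

    groups, current_group = [], [timestamps[0]]
    for i in range(1, len(timestamps)):
        if timestamps[i] - timestamps[i - 1] < 0.05:  # assumed threshold
            current_group.append(timestamps[i])
        else:
            groups.append(current_group)
            current_group = [timestamps[i]]
    groups.append(current_group)
    # groups -> [[1.20, 1.21], [1.49, 1.50]]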
+def flatten(list: List[List[Any]]) -> List:
+    """
+    Flattens a list / array of lists.
+
+    Parameters
+    ----------
+    l : array or list of lists
+        The list to be flattened
+
+    Returns
+    -------
+    list
+        The flattened list
+    """
+    return [item for sublist in list for item in sublist]

if __name__ == "__main__":
-    timestamps = [[1.2, 1.5, 1.3], [],
-                  [1.21, 1.51, 1.31], [1.19, 1.49, 1.29], [1.22, 1.52, 1.32], [1.2, 1.5, 1.3]]
+    timestamps = [
+        [1.2, 1.5, 1.3],
+        [],
+        [1.21, 1.51, 1.31],
+        [1.19, 1.49, 1.29],
+        [1.22, 1.52, 1.32],
+        [1.2, 1.5, 1.3],
+    ]
    print(group_timestamps(timestamps, 2, 0.05))
-    print(purge_duplicates(
-        [1, 2, 3, 4, 5, 6, 6.02, 7, 8, 8.02], 0.05))
+    print(purge_duplicates([1, 2, 3, 4, 5, 6, 6.02, 7, 8, 8.02], 0.05))