new variable names

This commit is contained in:
weygoldt 2023-01-19 18:21:59 +01:00
parent f6326de0b7
commit 9985686d53
2 changed files with 355 additions and 236 deletions

View File

@ -1,8 +1,8 @@
from itertools import compress from itertools import compress
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np
from IPython import embed from IPython import embed
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.signal import find_peaks from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d from scipy.ndimage import gaussian_filter1d
@ -23,6 +23,11 @@ ps = PlotStyle()
@dataclass @dataclass
class PlotBuffer: class PlotBuffer:
"""
Buffer to save data that is created in the main detection loop
and plot it outside the detecion loop.
"""
config: ConfLoader config: ConfLoader
t0: float t0: float
dt: float dt: float
@ -73,14 +78,15 @@ class PlotBuffer:
figsize=(20 / 2.54, 12 / 2.54), figsize=(20 / 2.54, 12 / 2.54),
constrained_layout=True, constrained_layout=True,
sharex=True, sharex=True,
sharey='row', sharey="row",
) )
# plot spectrogram # plot spectrogram
plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0) plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0)
for chirp in chirps: for chirp in chirps:
axs[0].scatter(chirp, np.median(self.frequency), c=ps.red) axs[0].scatter(chirp, np.median(self.frequency),
c=ps.black, marker="x")
# plot waveform of filtered signal # plot waveform of filtered signal
axs[1].plot(self.time, self.baseline, c=ps.green) axs[1].plot(self.time, self.baseline, c=ps.green)
@ -114,8 +120,9 @@ class PlotBuffer:
self.frequency_filtered[self.frequency_peaks], self.frequency_filtered[self.frequency_peaks],
c=ps.red, c=ps.red,
) )
axs[0].set_ylim(np.max(self.frequency)-200, axs[0].set_ylim(
top=np.max(self.frequency)+200) np.max(self.frequency) - 200, top=np.max(self.frequency) + 200
)
axs[6].set_xlabel("Time [s]") axs[6].set_xlabel("Time [s]")
axs[0].set_title("Spectrogram") axs[0].set_title("Spectrogram")
axs[1].set_title("Fitered baseline") axs[1].set_title("Fitered baseline")
@ -123,20 +130,63 @@ class PlotBuffer:
axs[3].set_title("Fitered baseline instanenous frequency") axs[3].set_title("Fitered baseline instanenous frequency")
axs[4].set_title("Filtered envelope of baseline envelope") axs[4].set_title("Filtered envelope of baseline envelope")
axs[5].set_title("Search envelope") axs[5].set_title("Search envelope")
axs[6].set_title( axs[6].set_title("Filtered absolute instantaneous frequency")
"Filtered absolute instantaneous frequency")
if plot == 'show': if plot == "show":
plt.show() plt.show()
elif plot == 'save': elif plot == "save":
make_outputdir(self.config.outputdir) make_outputdir(self.config.outputdir)
out = make_outputdir(self.config.outputdir + out = make_outputdir(
self.data.datapath.split('/')[-2] + '/') self.config.outputdir + self.data.datapath.split("/")[-2] + "/"
)
plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf") plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf")
plt.close() plt.close()
def plot_spectrogram(
axis, signal: np.ndarray, samplerate: float, t0: float
) -> None:
"""
Plot a spectrogram of a signal.
Parameters
----------
axis : matplotlib axis
Axis to plot the spectrogram on.
signal : np.ndarray
Signal to plot the spectrogram from.
samplerate : float
Samplerate of the signal.
t0 : float
Start time of the signal.
"""
logger.debug("Plotting spectrogram")
# compute spectrogram
spec_power, spec_freqs, spec_times = spectrogram(
signal,
ratetime=samplerate,
freq_resolution=20,
overlap_frac=0.5,
)
# axis.pcolormesh(
# spec_times + t0,
# spec_freqs,
# decibel(spec_power),
# )
axis.imshow(
decibel(spec_power),
extent=[spec_times[0] + t0, spec_times[-1] +
t0, spec_freqs[0], spec_freqs[-1]],
aspect="auto",
origin="lower",
interpolation="gaussian",
)
def instantaneos_frequency( def instantaneos_frequency(
signal: np.ndarray, samplerate: int signal: np.ndarray, samplerate: int
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
@ -158,8 +208,9 @@ def instantaneos_frequency(
# calculate instantaneos frequency with zero crossings # calculate instantaneos frequency with zero crossings
roll_signal = np.roll(signal, shift=1) roll_signal = np.roll(signal, shift=1)
time_signal = np.arange(len(signal)) / samplerate time_signal = np.arange(len(signal)) / samplerate
period_index = np.arange(len(signal))[( period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][
roll_signal < 0) & (signal >= 0)][1:-1] 1:-1
]
upper_bound = np.abs(signal[period_index]) upper_bound = np.abs(signal[period_index])
lower_bound = np.abs(signal[period_index - 1]) lower_bound = np.abs(signal[period_index - 1])
@ -182,43 +233,12 @@ def instantaneos_frequency(
return inst_freq_time, inst_freq return inst_freq_time, inst_freq
def plot_spectrogram(axis, signal: np.ndarray, samplerate: float, t0: float) -> None:
"""
Plot a spectrogram of a signal.
Parameters
----------
axis : matplotlib axis
Axis to plot the spectrogram on.
signal : np.ndarray
Signal to plot the spectrogram from.
samplerate : float
Samplerate of the signal.
t0 : float
Start time of the signal.
"""
logger.debug("Plotting spectrogram")
# compute spectrogram
spec_power, spec_freqs, spec_times = spectrogram(
signal,
ratetime=samplerate,
freq_resolution=50,
overlap_frac=0.2,
)
axis.pcolormesh(
spec_times + t0,
spec_freqs,
decibel(spec_power),
)
axis.set_ylim(200, 1200)
def double_bandpass( def double_bandpass(
data: DataLoader, samplerate: int, freqs: np.ndarray, search_freq: float data: DataLoader,
samplerate: int,
freqs: np.ndarray,
search_freq: float,
config: ConfLoader
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Apply a bandpass filter to the baseline of a signal and a second bandpass Apply a bandpass filter to the baseline of a signal and a second bandpass
@ -241,7 +261,7 @@ def double_bandpass(
""" """
# compute boundaries to filter baseline # compute boundaries to filter baseline
q25, q75 = np.percentile(freqs, [25, 75]) q25, q50, q75 = np.percentile(freqs, [25, 50, 75])
# check if percentile delta is too small # check if percentile delta is too small
if q75 - q25 < 5: if q75 - q25 < 5:
@ -253,13 +273,17 @@ def double_bandpass(
# filter search area # filter search area
filtered_search_freq = bandpass_filter( filtered_search_freq = bandpass_filter(
data, samplerate, lowf=q25 + search_freq, highf=q75 + search_freq data, samplerate,
lowf=search_freq + q50 - config.search_bandwidth / 2,
highf=search_freq + q50 + config.search_bandwidth / 2
) )
return (filtered_baseline, filtered_search_freq) return filtered_baseline, filtered_search_freq
def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, list[int]]: def freqmedian_allfish(
data: LoadData, t0: float, dt: float
) -> tuple[float, list[int]]:
""" """
Calculate the median frequency of all fish in a given time window. Calculate the median frequency of all fish in a given time window.
@ -283,8 +307,9 @@ def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, lis
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
window_idx = np.arange(len(data.idx))[ window_idx = np.arange(len(data.idx))[
(data.ident == track_id) & (data.time[data.idx] >= t0) & ( (data.ident == track_id)
data.time[data.idx] <= (t0 + dt)) & (data.time[data.idx] >= t0)
& (data.time[data.idx] <= (t0 + dt))
] ]
if len(data.freq[window_idx]) > 0: if len(data.freq[window_idx]) > 0:
@ -298,6 +323,112 @@ def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, lis
return median_freq, track_ids return median_freq, track_ids
def find_search_freq(
freq_temp: np.ndarray,
median_ids: np.ndarray,
median_freq: np.ndarray,
config: ConfLoader,
data: LoadData,
) -> float:
"""
Find the search frequency for each fish by checking which fish EODs are
above the current EOD and finding a gap in them.
Parameters
----------
freq_temp : np.ndarray
Current EOD frequency array / the current fish of interest.
median_ids : np.ndarray
Array of track IDs of the medians of all other fish in the current window.
median_freq : np.ndarray
Array of median frequencies of all other fish in the current window.
config : ConfLoader
Configuration file.
data : LoadData
Data to find the search frequency from.
Returns
-------
float
"""
# frequency where second filter filters
search_window = np.arange(
np.median(freq_temp) + config.search_df_lower,
np.median(freq_temp) + config.search_df_upper,
config.search_res,
)
# search window in boolean
search_window_bool = np.ones(len(search_window), dtype=bool)
# get tracks that fall into search window
check_track_ids = median_ids[
(median_freq > search_window[0]) & (median_freq < search_window[-1])
]
# iterate through theses tracks
if check_track_ids.size != 0:
for j, check_track_id in enumerate(check_track_ids):
q1, q2 = np.percentile(
data.freq[data.ident == check_track_id],
config.search_freq_percentiles,
)
search_window_bool[
(search_window > q1) & (search_window < q2)
] = False
# find gaps in search window
search_window_indices = np.arange(len(search_window))
# get search window gaps
search_window_gaps = np.diff(search_window_bool, append=np.nan)
nonzeros = search_window_gaps[np.nonzero(search_window_gaps)[0]]
nonzeros = nonzeros[~np.isnan(nonzeros)]
# if the first value is -1, the array starst with true, so a gap
if nonzeros[0] == -1:
stops = search_window_indices[search_window_gaps == -1]
starts = np.append(
0, search_window_indices[search_window_gaps == 1]
)
# if the last value is -1, the array ends with true, so a gap
if nonzeros[-1] == 1:
stops = np.append(
search_window_indices[search_window_gaps == -1],
len(search_window) - 1,
)
# else it starts with false, so no gap
if nonzeros[0] == 1:
stops = search_window_indices[search_window_gaps == -1]
starts = search_window_indices[search_window_gaps == 1]
# if the last value is -1, the array ends with true, so a gap
if nonzeros[-1] == 1:
stops = np.append(
search_window_indices[search_window_gaps == -1],
len(search_window),
)
# get the frequency ranges of the gaps
search_windows = [search_window[x:y] for x, y in zip(starts, stops)]
search_windows_lens = [len(x) for x in search_windows]
longest_search_window = search_windows[np.argmax(search_windows_lens)]
search_freq = (
longest_search_window[-1] - longest_search_window[0]) / 2
else:
search_freq = config.default_search_freq
return search_freq
def main(datapath: str, plot: str) -> None: def main(datapath: str, plot: str) -> None:
assert plot in ["save", "show", "false"] assert plot in ["save", "show", "false"]
@ -328,24 +459,24 @@ def main(datapath: str, plot: str) -> None:
# make time array for raw data # make time array for raw data
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# # good chirp times for data: 2022-06-02-10_00 # good chirp times for data: 2022-06-02-10_00
# t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
# dt = 60 * data.raw_rate dt = 60 * data.raw_rate
t0 = 0 # t0 = 0
dt = data.raw.shape[0] # dt = data.raw.shape[0]
# generate starting points of rolling window # generate starting points of rolling window
window_starts = np.arange( window_starts = np.arange(
t0, t0,
t0 + dt, t0 + dt,
window_duration - (window_overlap + 2 * window_edge), window_duration - (window_overlap + 2 * window_edge),
dtype=int dtype=int,
) )
# ititialize lists to store data # ititialize lists to store data
chirps = [] multiwindow_chirps = []
fish_ids = [] multiwindow_ids = []
for st, start_index in enumerate(window_starts): for st, start_index in enumerate(window_starts):
@ -362,14 +493,17 @@ def main(datapath: str, plot: str) -> None:
median_freq, median_ids = freqmedian_allfish(data, t0, dt) median_freq, median_ids = freqmedian_allfish(data, t0, dt)
# iterate through all fish # iterate through all fish
for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): for tr, track_id in enumerate(
np.unique(data.ident[~np.isnan(data.ident)])
):
logger.debug(f"Processing track {tr} of {len(data.ids)}") logger.debug(f"Processing track {tr} of {len(data.ids)}")
# get index of track data in this time window # get index of track data in this time window
window_idx = np.arange(len(data.idx))[ window_idx = np.arange(len(data.idx))[
(data.ident == track_id) & (data.time[data.idx] >= t0) & ( (data.ident == track_id)
data.time[data.idx] <= (t0 + dt)) & (data.time[data.idx] >= t0)
& (data.time[data.idx] <= (t0 + dt))
] ]
# get tracked frequencies and their times # get tracked frequencies and their times
@ -384,99 +518,45 @@ def main(datapath: str, plot: str) -> None:
# check if tracked data available in this window # check if tracked data available in this window
if len(freq_temp) < expected_duration * 0.5: if len(freq_temp) < expected_duration * 0.5:
logger.warning( logger.warning(
f"Track {track_id} has no data in window {st}, skipping.") f"Track {track_id} has no data in window {st}, skipping."
)
continue continue
# check if there are powers available in this window # check if there are powers available in this window
nanchecker = np.unique(np.isnan(powers_temp)) nanchecker = np.unique(np.isnan(powers_temp))
if (len(nanchecker) == 1) and nanchecker[0] == True: if (len(nanchecker) == 1) and nanchecker[0]:
logger.warning( logger.warning(
f"No powers available for track {track_id} window {st}, skipping.") f"No powers available for track {track_id} window {st}, \
continue skipping."
best_electrodes = np.argsort(np.nanmean(
powers_temp, axis=0))[-config.number_electrodes:]
# frequency where second filter filters
search_window = np.arange(
np.median(freq_temp)+config.search_df_lower, np.median(
freq_temp)+config.search_df_upper, config.search_res)
# search window in boolean
search_window_bool = np.ones(len(search_window), dtype=bool)
# get tracks that fall into search window
check_track_ids = median_ids[(median_freq > search_window[0]) & (
median_freq < search_window[-1])]
# iterate through theses tracks
if check_track_ids.size != 0:
for j, check_track_id in enumerate(check_track_ids):
q1, q2 = np.percentile(
data.freq[data.ident == check_track_id],
config.search_freq_percentiles
)
search_window_bool[(search_window > q1) & (
search_window < q2)] = False
# find gaps in search window
search_window_indices = np.arange(len(search_window))
# get search window gaps
search_window_gaps = np.diff(search_window_bool, append=np.nan)
nonzeros = search_window_gaps[np.nonzero(
search_window_gaps)[0]]
nonzeros = nonzeros[~np.isnan(nonzeros)]
# if the first value is -1, the array starst with true, so a gap
if nonzeros[0] == -1:
stops = search_window_indices[search_window_gaps == -1]
starts = np.append(
0, search_window_indices[search_window_gaps == 1])
# if the last value is -1, the array ends with true, so a gap
if nonzeros[-1] == 1:
stops = np.append(
search_window_indices[search_window_gaps == -1],
len(search_window) - 1
) )
continue
# else it starts with false, so no gap # find the strongest electrodes for the current fish in the current
if nonzeros[0] == 1: # window
stops = search_window_indices[search_window_gaps == -1] best_electrodes = np.argsort(np.nanmean(powers_temp, axis=0))[
starts = search_window_indices[search_window_gaps == 1] -config.number_electrodes:
]
# if the last value is -1, the array ends with true, so a gap # find a frequency above the baseline of the current fish in which
if nonzeros[-1] == 1: # no other fish is active to search for chirps there
stops = np.append( search_freq = find_search_freq(
search_window_indices[search_window_gaps == -1], config=config,
len(search_window) freq_temp=freq_temp,
median_ids=median_ids,
data=data,
median_freq=median_freq,
) )
# get the frequency ranges of the gaps # add all chirps that are detected on mulitple electrodes for one
search_windows = [search_window[x:y] # fish fish in one window to this list
for x, y in zip(starts, stops)] multielectrode_chirps = []
search_windows_lens = [len(x) for x in search_windows]
longest_search_window = search_windows[np.argmax(
search_windows_lens)]
search_freq = (
longest_search_window[1] - longest_search_window[0]) / 2
else:
search_freq = config.default_search_freq
# ----------- chrips on the two best electrodes-----------
chirps_electrodes = []
# iterate through electrodes # iterate through electrodes
for el, electrode in enumerate(best_electrodes): for el, electrode in enumerate(best_electrodes):
logger.debug( logger.debug(
f"Processing electrode {el} of {len(best_electrodes)}") f"Processing electrode {el} of {len(best_electrodes)}"
)
# load region of interest of raw data file # load region of interest of raw data file
data_oi = data.raw[start_index:stop_index, :] data_oi = data.raw[start_index:stop_index, :]
@ -487,15 +567,8 @@ def main(datapath: str, plot: str) -> None:
data_oi[:, electrode], data_oi[:, electrode],
data.raw_rate, data.raw_rate,
freq_temp, freq_temp,
search_freq search_freq,
) config=config,
# compute instantaneous frequency on broad signal
broad_baseline = bandpass_filter(
data_oi[:, electrode],
data.raw_rate,
lowf=np.mean(freq_temp)-5,
highf=np.mean(freq_temp)+100
) )
# compute instantaneous frequency on narrow signal # compute instantaneous frequency on narrow signal
@ -505,67 +578,73 @@ def main(datapath: str, plot: str) -> None:
# compute envelopes # compute envelopes
baseline_envelope_unfiltered = envelope( baseline_envelope_unfiltered = envelope(
baseline, data.raw_rate, config.envelope_cutoff) baseline, data.raw_rate, config.envelope_cutoff
)
search_envelope = envelope( search_envelope = envelope(
search, data.raw_rate, config.envelope_cutoff) search, data.raw_rate, config.envelope_cutoff
)
# highpass filter envelopes # highpass filter envelopes
baseline_envelope = highpass_filter( baseline_envelope = highpass_filter(
baseline_envelope_unfiltered, baseline_envelope_unfiltered,
data.raw_rate, data.raw_rate,
config.envelope_highpass_cutoff config.envelope_highpass_cutoff,
) )
# envelopes of filtered envelope of filtered baseline # envelopes of filtered envelope of filtered baseline
baseline_envelope = envelope( baseline_envelope = envelope(
np.abs(baseline_envelope), np.abs(baseline_envelope),
data.raw_rate, data.raw_rate,
config.envelope_envelope_cutoff config.envelope_envelope_cutoff,
) )
# bandpass filter the instantaneous # bandpass filter the instantaneous frequency to put it to 0
inst_freq_filtered = bandpass_filter( inst_freq_filtered = bandpass_filter(
baseline_freq, baseline_freq,
data.raw_rate, data.raw_rate,
lowf=config.instantaneous_lowf, lowf=config.instantaneous_lowf,
highf=config.instantaneous_highf highf=config.instantaneous_highf,
) )
# CUT OFF OVERLAP --------------------------------------------- # CUT OFF OVERLAP ---------------------------------------------
# cut off first and last 0.5 * overlap at start and end # overwrite raw time to valid region, i.e. cut off snippet at
# start and end of each window to remove filter effects
valid = np.arange( valid = np.arange(
int(window_edge), len(baseline_envelope) - int(window_edge), len(baseline_envelope) - int(window_edge)
int(window_edge)
) )
baseline_envelope_unfiltered = baseline_envelope_unfiltered[valid] baseline_envelope_unfiltered = baseline_envelope_unfiltered[
valid
]
baseline_envelope = baseline_envelope[valid] baseline_envelope = baseline_envelope[valid]
search_envelope = search_envelope[valid] search_envelope = search_envelope[valid]
# get inst freq valid snippet # get inst freq valid snippet
valid_t0 = int(window_edge) / data.raw_rate valid_t0 = int(window_edge) / data.raw_rate
valid_t1 = baseline_freq_time[-1] - \ valid_t1 = baseline_freq_time[-1] - (
(int(window_edge) / data.raw_rate) int(window_edge) / data.raw_rate
)
inst_freq_filtered = inst_freq_filtered[ inst_freq_filtered = inst_freq_filtered[
(baseline_freq_time >= valid_t0) & ( (baseline_freq_time >= valid_t0)
baseline_freq_time <= valid_t1) & (baseline_freq_time <= valid_t1)
] ]
baseline_freq = baseline_freq[ baseline_freq = baseline_freq[
(baseline_freq_time >= valid_t0) & ( (baseline_freq_time >= valid_t0)
baseline_freq_time <= valid_t1) & (baseline_freq_time <= valid_t1)
] ]
baseline_freq_time = baseline_freq_time[ baseline_freq_time = (
(baseline_freq_time >= valid_t0) & ( baseline_freq_time[
baseline_freq_time <= valid_t1) (baseline_freq_time >= valid_t0)
] + t0 & (baseline_freq_time <= valid_t1)
]
+ t0
)
# overwrite raw time to valid region
time_oi = time_oi[valid] time_oi = time_oi[valid]
baseline = baseline[valid] baseline = baseline[valid]
broad_baseline = broad_baseline[valid]
search = search[valid] search = search[valid]
# NORMALIZE --------------------------------------------------- # NORMALIZE ---------------------------------------------------
@ -576,49 +655,59 @@ def main(datapath: str, plot: str) -> None:
# PEAK DETECTION ---------------------------------------------- # PEAK DETECTION ----------------------------------------------
prominence = config.prominence
# detect peaks baseline_enelope # detect peaks baseline_enelope
prominence = np.percentile(
baseline_envelope, config.baseline_prominence_percentile)
baseline_peaks, _ = find_peaks( baseline_peaks, _ = find_peaks(
baseline_envelope, prominence=prominence) baseline_envelope, prominence=prominence
)
# detect peaks search_envelope # detect peaks search_envelope
prominence = np.percentile(
search_envelope, config.search_prominence_percentile)
search_peaks, _ = find_peaks( search_peaks, _ = find_peaks(
search_envelope, prominence=prominence) search_envelope, prominence=prominence
# detect peaks inst_freq_filtered
prominence = np.percentile(
inst_freq_filtered,
config.instantaneous_prominence_percentile
) )
# detect peaks inst_freq_filtered
inst_freq_peaks, _ = find_peaks( inst_freq_peaks, _ = find_peaks(
inst_freq_filtered, inst_freq_filtered, prominence=prominence
prominence=prominence
) )
# DETECT CHIRPS IN SEARCH WINDOW ------------------------------- # DETECT CHIRPS IN SEARCH WINDOW ------------------------------
# get the peak timestamps from the peak indices
baseline_ts = time_oi[baseline_peaks] baseline_ts = time_oi[baseline_peaks]
search_ts = time_oi[search_peaks] search_ts = time_oi[search_peaks]
freq_ts = baseline_freq_time[inst_freq_peaks] freq_ts = baseline_freq_time[inst_freq_peaks]
# check if one list is empty # check if one list is empty and if so, skip to the next
if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0: # electrode because a chirp cannot be detected if one is empty
if (
len(baseline_ts) == 0
or len(search_ts) == 0
or len(freq_ts) == 0
):
continue continue
current_chirps = group_timestamps( # group peak across feature arrays but only if they
[list(baseline_ts), list(search_ts), list(freq_ts)], 3, config.chirp_window_threshold) # occur in all 3 feature arrays
# for checking if there are chirps on multiple electrodes singleelectrode_chirps = group_timestamps(
if len(current_chirps) == 0: [list(baseline_ts), list(search_ts), list(freq_ts)],
3,
config.chirp_window_threshold,
)
# check it there are chirps detected after grouping, continue
# with the loop if not
if len(singleelectrode_chirps) == 0:
continue continue
chirps_electrodes.append(current_chirps) # append chirps from this electrode to the multilectrode list
multielectrode_chirps.append(singleelectrode_chirps)
if (el == config.number_electrodes - 1) & \ # only initialize the plotting buffer if chirps are detected
(len(current_chirps) > 0) & \ if (
(plot in ["show", "save"]): (el == config.number_electrodes - 1)
& (len(singleelectrode_chirps) > 0)
& (plot in ["show", "save"])
):
logger.debug("Detected chirp, ititialize buffer ...") logger.debug("Detected chirp, ititialize buffer ...")
@ -646,21 +735,37 @@ def main(datapath: str, plot: str) -> None:
logger.debug("Buffer initialized!") logger.debug("Buffer initialized!")
logger.debug( logger.debug(
f"Processed all electrodes for fish {track_id} for this window, sorting chirps ...") f"Processed all electrodes for fish {track_id} for this \
window, sorting chirps ..."
)
if len(chirps_electrodes) == 0: # check if there are chirps detected in multiple electrodes and
# continue the loop if not
if len(multielectrode_chirps) == 0:
continue continue
the_real_chirps = group_timestamps(chirps_electrodes, 2, 0.05) # validate multielectrode chirps, i.e. check if they are
# detected in at least 'config.min_electrodes' electrodes
multielectrode_chirps_validated = group_timestamps(
multielectrode_chirps,
config.minimum_electrodes,
config.chirp_window_threshold
)
chirps.append(the_real_chirps) # add validated chirps to the list that tracks chirps across there
fish_ids.append(track_id) # rolling time windows
multiwindow_chirps.append(multielectrode_chirps_validated)
multiwindow_ids.append(track_id)
logger.debug('Found %d chirps, starting plotting ... ' % logger.debug(
len(the_real_chirps)) "Found %d chirps, starting plotting ... "
if len(the_real_chirps) > 0: % len(multielectrode_chirps_validated)
)
# if chirps are detected and the plot flag is set, plot the
# chirps, otheswise try to delete the buffer if it exists
if len(multielectrode_chirps_validated) > 0:
try: try:
buffer.plot_buffer(the_real_chirps, plot) buffer.plot_buffer(multielectrode_chirps_validated, plot)
except NameError: except NameError:
pass pass
else: else:
@ -669,29 +774,42 @@ def main(datapath: str, plot: str) -> None:
except NameError: except NameError:
pass pass
chirps_new = [] # flatten list of lists containing chirps and create
chirps_ids = [] # an array of fish ids that correspond to the chirps
for tr in np.unique(fish_ids): multiwindow_chirps_flat = []
tr_index = np.asarray(fish_ids) == tr multiwindow_ids_flat = []
ts = flatten(list(compress(chirps, tr_index))) for tr in np.unique(multiwindow_ids):
chirps_new.extend(ts) tr_index = np.asarray(multiwindow_ids) == tr
chirps_ids.extend(list(np.ones_like(ts)*tr)) ts = flatten(list(compress(multiwindow_chirps, tr_index)))
multiwindow_chirps_flat.extend(ts)
# purge duplicates multiwindow_ids_flat.extend(list(np.ones_like(ts) * tr))
# purge duplicates, i.e. chirps that are very close to each other
# duplites arise due to overlapping windows
purged_chirps = [] purged_chirps = []
purged_chirps_ids = [] purged_ids = []
for tr in np.unique(fish_ids): for tr in np.unique(multiwindow_ids_flat):
tr_chirps = np.asarray(chirps_new)[np.asarray(chirps_ids) == tr] tr_chirps = np.asarray(multiwindow_chirps_flat)[
np.asarray(multiwindow_ids_flat) == tr]
if len(tr_chirps) > 0: if len(tr_chirps) > 0:
tr_chirps_purged = purge_duplicates( tr_chirps_purged = purge_duplicates(
tr_chirps, config.chirp_window_threshold) tr_chirps, config.chirp_window_threshold
)
purged_chirps.extend(list(tr_chirps_purged)) purged_chirps.extend(list(tr_chirps_purged))
purged_chirps_ids.extend(list(np.ones_like(tr_chirps_purged)*tr)) purged_ids.extend(list(np.ones_like(tr_chirps_purged) * tr))
# sort chirps by time
purged_chirps = np.asarray(purged_chirps)
purged_ids = np.asarray(purged_ids)
purged_ids = purged_ids[np.argsort(purged_chirps)]
purged_chirps = purged_chirps[np.argsort(purged_chirps)]
np.save(datapath + 'chirps.npy', purged_chirps) # save them into the data directory
np.save(datapath + 'chirps_ids.npy', purged_chirps_ids) np.save(datapath + "chirps.npy", purged_chirps)
np.save(datapath + "chirp_ids.npy", purged_ids)
if __name__ == "__main__": if __name__ == "__main__":
# datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-05-13-10_00/"
datapath = "../data/2022-06-02-10_00/" datapath = "../data/2022-06-02-10_00/"
main(datapath, plot="save") main(datapath, plot="show")

View File

@ -8,9 +8,9 @@ edge: 0.25
# Number of electrodes to go over # Number of electrodes to go over
number_electrodes: 3 number_electrodes: 3
minimum_electrodes: 2
# Boundary for search frequency in Hz # Search window bandwidth
search_boundary: 100
# Cutoff frequency for envelope estimation by lowpass filter # Cutoff frequency for envelope estimation by lowpass filter
envelope_cutoff: 25 envelope_cutoff: 25
@ -26,23 +26,24 @@ instantaneous_lowf: 15
instantaneous_highf: 8000 instantaneous_highf: 8000
# Baseline envelope peak detection parameters # Baseline envelope peak detection parameters
baseline_prominence_percentile: 90 # baseline_prominence_percentile: 90
# Search envelope peak detection parameters # Search envelope peak detection parameters
search_prominence_percentile: 90 # search_prominence_percentile: 90
# Instantaneous frequency peak detection parameters # Instantaneous frequency peak detection parameters
instantaneous_prominence_percentile: 90 # instantaneous_prominence_percentile: 90
prominence: 0.005
# search freq parameter # search freq parameter
search_df_lower: 25 search_df_lower: 20
search_df_upper: 100 search_df_upper: 100
search_res: 1 search_res: 1
search_freq_percentiles: search_bandwidth: 10
- 5
- 95
default_search_freq: 50 default_search_freq: 50
# Classify events as chirps if they are less than this time apart
chirp_window_threshold: 0.05 chirp_window_threshold: 0.05