refactoring finished for now

This commit is contained in:
weygoldt 2023-01-20 13:56:26 +01:00
parent 9985686d53
commit ddf7bd545a
4 changed files with 401 additions and 278 deletions

533
code/chirpdetection.py Normal file → Executable file
View File

@ -1,20 +1,22 @@
from itertools import compress from itertools import compress
from dataclasses import dataclass from dataclasses import dataclass
from IPython import embed
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.signal import find_peaks from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
from thunderfish.dataloader import DataLoader
from thunderfish.powerspectrum import spectrogram, decibel from thunderfish.powerspectrum import spectrogram, decibel
from sklearn.preprocessing import normalize from sklearn.preprocessing import normalize
from modules.filters import bandpass_filter, envelope, highpass_filter from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.filehandling import ConfLoader, LoadData, make_outputdir from modules.filehandling import ConfLoader, LoadData, make_outputdir
from modules.datahandling import flatten, purge_duplicates, group_timestamps
from modules.plotstyle import PlotStyle from modules.plotstyle import PlotStyle
from modules.logger import makeLogger from modules.logger import makeLogger
from modules.datahandling import (
flatten,
purge_duplicates,
group_timestamps,
instantaneous_frequency,
)
logger = makeLogger(__name__) logger = makeLogger(__name__)
@ -28,6 +30,7 @@ class PlotBuffer:
Buffer to save data that is created in the main detection loop Buffer to save data that is created in the main detection loop
and plot it outside the detection loop. and plot it outside the detection loop.
""" """
config: ConfLoader config: ConfLoader
t0: float t0: float
dt: float dt: float
@ -85,8 +88,9 @@ class PlotBuffer:
plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0) plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0)
for chirp in chirps: for chirp in chirps:
axs[0].scatter(chirp, np.median(self.frequency), axs[0].scatter(
c=ps.black, marker="x") chirp, np.median(self.frequency), c=ps.black, marker="x"
)
# plot waveform of filtered signal # plot waveform of filtered signal
axs[1].plot(self.time, self.baseline, c=ps.green) axs[1].plot(self.time, self.baseline, c=ps.green)
@ -94,7 +98,7 @@ class PlotBuffer:
# plot waveform of filtered search signal # plot waveform of filtered search signal
axs[2].plot(self.time, self.search) axs[2].plot(self.time, self.search)
# plot baseline instantaneos frequency # plot baseline instantaneous frequency
axs[3].plot(self.frequency_time, self.frequency) axs[3].plot(self.frequency_time, self.frequency)
# plot filtered and rectified envelope # plot filtered and rectified envelope
@ -145,7 +149,7 @@ class PlotBuffer:
def plot_spectrogram( def plot_spectrogram(
axis, signal: np.ndarray, samplerate: float, t0: float axis, signal: np.ndarray, samplerate: float, window_start_seconds: float
) -> None: ) -> None:
""" """
Plot a spectrogram of a signal. Plot a spectrogram of a signal.
@ -158,7 +162,7 @@ def plot_spectrogram(
Signal to plot the spectrogram from. Signal to plot the spectrogram from.
samplerate : float samplerate : float
Samplerate of the signal. Samplerate of the signal.
t0 : float window_start_seconds : float
Start time of the signal. Start time of the signal.
""" """
@ -172,73 +176,26 @@ def plot_spectrogram(
overlap_frac=0.5, overlap_frac=0.5,
) )
# axis.pcolormesh(
# spec_times + t0,
# spec_freqs,
# decibel(spec_power),
# )
axis.imshow( axis.imshow(
decibel(spec_power), decibel(spec_power),
extent=[spec_times[0] + t0, spec_times[-1] + extent=[
t0, spec_freqs[0], spec_freqs[-1]], spec_times[0] + window_start_seconds,
spec_times[-1] + window_start_seconds,
spec_freqs[0],
spec_freqs[-1],
],
aspect="auto", aspect="auto",
origin="lower", origin="lower",
interpolation="gaussian", interpolation="gaussian",
) )
def instantaneos_frequency( def extract_frequency_bands(
signal: np.ndarray, samplerate: int raw_data: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]: samplerate: int,
""" baseline_track: np.ndarray,
Compute the instantaneous frequency of a signal. searchband_center: float,
minimal_bandwidth: float,
Parameters
----------
signal : np.ndarray
Signal to compute the instantaneous frequency from.
samplerate : int
Samplerate of the signal.
Returns
-------
tuple[np.ndarray, np.ndarray]
"""
# calculate instantaneos frequency with zero crossings
roll_signal = np.roll(signal, shift=1)
time_signal = np.arange(len(signal)) / samplerate
period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][
1:-1
]
upper_bound = np.abs(signal[period_index])
lower_bound = np.abs(signal[period_index - 1])
upper_time = np.abs(time_signal[period_index])
lower_time = np.abs(time_signal[period_index - 1])
# create ratio
lower_ratio = lower_bound / (lower_bound + upper_bound)
# appy to time delta
time_delta = upper_time - lower_time
true_zero = lower_time + lower_ratio * time_delta
# create new time array
inst_freq_time = true_zero[:-1] + 0.5 * np.diff(true_zero)
# compute frequency
inst_freq = gaussian_filter1d(1 / np.diff(true_zero), 5)
return inst_freq_time, inst_freq
def double_bandpass(
data: DataLoader,
samplerate: int,
freqs: np.ndarray,
search_freq: float,
config: ConfLoader
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Apply a bandpass filter to the baseline of a signal and a second bandpass Apply a bandpass filter to the baseline of a signal and a second bandpass
@ -246,14 +203,16 @@ def double_bandpass(
Parameters Parameters
---------- ----------
data : DataLoader raw_data : np.ndarray
Data to apply the filter to. Data to apply the filter to.
samplerate : int samplerate : int
Samplerate of the signal. Samplerate of the signal.
freqs : np.ndarray baseline_track : np.ndarray
Tracked fundamental frequencies of the signal. Tracked fundamental frequencies of the signal.
search_freq : float searchband_center: float
Frequency to search for above or below the baseline. Frequency to search for above or below the baseline.
minimal_bandwidth : float
Minimal bandwidth of the filter.
Returns Returns
------- -------
@ -261,28 +220,30 @@ def double_bandpass(
""" """
# compute boundaries to filter baseline # compute boundaries to filter baseline
q25, q50, q75 = np.percentile(freqs, [25, 50, 75]) q25, q50, q75 = np.percentile(baseline_track, [25, 50, 75])
# check if percentile delta is too small # check if percentile delta is too small
if q75 - q25 < 5: if q75 - q25 < 10:
median = np.median(freqs) q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2
q25, q75 = median - 2.5, median + 2.5
# filter baseline # filter baseline
filtered_baseline = bandpass_filter(data, samplerate, lowf=q25, highf=q75) filtered_baseline = bandpass_filter(
raw_data, samplerate, lowf=q25, highf=q75
)
# filter search area # filter search area
filtered_search_freq = bandpass_filter( filtered_search_freq = bandpass_filter(
data, samplerate, raw_data,
lowf=search_freq + q50 - config.search_bandwidth / 2, samplerate,
highf=search_freq + q50 + config.search_bandwidth / 2 lowf=searchband_center + q50 - minimal_bandwidth / 2,
highf=searchband_center + q50 + minimal_bandwidth / 2,
) )
return filtered_baseline, filtered_search_freq return filtered_baseline, filtered_search_freq
def freqmedian_allfish( def window_median_all_track_ids(
data: LoadData, t0: float, dt: float data: LoadData, window_start_seconds: float, window_duration_seconds: float
) -> tuple[float, list[int]]: ) -> tuple[float, list[int]]:
""" """
Calculate the median frequency of all fish in a given time window. Calculate the median frequency of all fish in a given time window.
@ -291,9 +252,9 @@ def freqmedian_allfish(
---------- ----------
data : LoadData data : LoadData
Data to calculate the median frequency from. Data to calculate the median frequency from.
t0 : float window_start_seconds : float
Start time of the window. Start time of the window.
dt : float window_duration_seconds : float
Duration of the window. Duration of the window.
Returns Returns
@ -308,8 +269,11 @@ def freqmedian_allfish(
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
window_idx = np.arange(len(data.idx))[ window_idx = np.arange(len(data.idx))[
(data.ident == track_id) (data.ident == track_id)
& (data.time[data.idx] >= t0) & (data.time[data.idx] >= window_start_seconds)
& (data.time[data.idx] <= (t0 + dt)) & (
data.time[data.idx]
<= (window_start_seconds + window_duration_seconds)
)
] ]
if len(data.freq[window_idx]) > 0: if len(data.freq[window_idx]) > 0:
@ -323,7 +287,7 @@ def freqmedian_allfish(
return median_freq, track_ids return median_freq, track_ids
def find_search_freq( def find_searchband(
freq_temp: np.ndarray, freq_temp: np.ndarray,
median_ids: np.ndarray, median_ids: np.ndarray,
median_freq: np.ndarray, median_freq: np.ndarray,
@ -331,15 +295,16 @@ def find_search_freq(
data: LoadData, data: LoadData,
) -> float: ) -> float:
""" """
Find the search frequency for each fish by checking which fish EODs are Find the search frequency band for each fish by checking which fish EODs
above the current EOD and finding a gap in them. are above the current EOD and finding a gap in them.
Parameters Parameters
---------- ----------
freq_temp : np.ndarray freq_temp : np.ndarray
Current EOD frequency array / the current fish of interest. Current EOD frequency array / the current fish of interest.
median_ids : np.ndarray median_ids : np.ndarray
Array of track IDs of the medians of all other fish in the current window. Array of track IDs of the medians of all other fish in the current
window.
median_freq : np.ndarray median_freq : np.ndarray
Array of median frequencies of all other fish in the current window. Array of median frequencies of all other fish in the current window.
config : ConfLoader config : ConfLoader
@ -421,7 +386,8 @@ def find_search_freq(
longest_search_window = search_windows[np.argmax(search_windows_lens)] longest_search_window = search_windows[np.argmax(search_windows_lens)]
search_freq = ( search_freq = (
longest_search_window[-1] - longest_search_window[0]) / 2 longest_search_window[-1] - longest_search_window[0]
) / 2
else: else:
search_freq = config.default_search_freq search_freq = config.default_search_freq
@ -431,7 +397,11 @@ def find_search_freq(
def main(datapath: str, plot: str) -> None: def main(datapath: str, plot: str) -> None:
assert plot in ["save", "show", "false"] assert plot in [
"save",
"show",
"false",
], "plot must be 'save', 'show' or 'false'"
# load raw file # load raw file
data = LoadData(datapath) data = LoadData(datapath)
@ -444,13 +414,15 @@ def main(datapath: str, plot: str) -> None:
window_overlap = config.overlap * data.raw_rate window_overlap = config.overlap * data.raw_rate
window_edge = config.edge * data.raw_rate window_edge = config.edge * data.raw_rate
# check if window duration is even # check if window duration and window overlap are even, otherwise the half
# of the duration or window overlap would return a float, thus an
# invalid index
if window_duration % 2 == 0: if window_duration % 2 == 0:
window_duration = int(window_duration) window_duration = int(window_duration)
else: else:
raise ValueError("Window duration must be even.") raise ValueError("Window duration must be even.")
# check if window overlap is even
if window_overlap % 2 == 0: if window_overlap % 2 == 0:
window_overlap = int(window_overlap) window_overlap = int(window_overlap)
else: else:
@ -460,16 +432,16 @@ def main(datapath: str, plot: str) -> None:
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# good chirp times for data: 2022-06-02-10_00 # good chirp times for data: 2022-06-02-10_00
t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate window_start_seconds = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
dt = 60 * data.raw_rate window_duration_seconds = 60 * data.raw_rate
# t0 = 0 # t0 = 0
# dt = data.raw.shape[0] # dt = data.raw.shape[0]
# generate starting points of rolling window # generate starting points of rolling window
window_starts = np.arange( window_start_indices = np.arange(
t0, window_start_seconds,
t0 + dt, window_start_seconds + window_duration_seconds,
window_duration - (window_overlap + 2 * window_edge), window_duration - (window_overlap + 2 * window_edge),
dtype=int, dtype=int,
) )
@ -478,19 +450,20 @@ def main(datapath: str, plot: str) -> None:
multiwindow_chirps = [] multiwindow_chirps = []
multiwindow_ids = [] multiwindow_ids = []
for st, start_index in enumerate(window_starts): for st, window_start_index in enumerate(window_start_indices):
logger.info(f"Processing window {st} of {len(window_starts)}") logger.info(f"Processing window {st+1} of {len(window_start_indices)}")
# make t0 and dt window_start_seconds = window_start_index / data.raw_rate
t0 = start_index / data.raw_rate window_duration_seconds = window_duration / data.raw_rate
dt = window_duration / data.raw_rate
# set index window # set index window
stop_index = start_index + window_duration window_stop_index = window_start_index + window_duration
# calculate median of fish frequencies in window # calculate median of fish frequencies in window
median_freq, median_ids = freqmedian_allfish(data, t0, dt) median_freq, median_ids = window_median_all_track_ids(
data, window_start_seconds, window_duration_seconds
)
# iterate through all fish # iterate through all fish
for tr, track_id in enumerate( for tr, track_id in enumerate(
@ -500,48 +473,57 @@ def main(datapath: str, plot: str) -> None:
logger.debug(f"Processing track {tr} of {len(data.ids)}") logger.debug(f"Processing track {tr} of {len(data.ids)}")
# get index of track data in this time window # get index of track data in this time window
window_idx = np.arange(len(data.idx))[ track_window_index = np.arange(len(data.idx))[
(data.ident == track_id) (data.ident == track_id)
& (data.time[data.idx] >= t0) & (data.time[data.idx] >= window_start_seconds)
& (data.time[data.idx] <= (t0 + dt)) & (
data.time[data.idx]
<= (window_start_seconds + window_duration_seconds)
)
] ]
# get tracked frequencies and their times # get tracked frequencies and their times
freq_temp = data.freq[window_idx] current_frequencies = data.freq[track_window_index]
powers_temp = data.powers[window_idx, :] current_powers = data.powers[track_window_index, :]
# approximate sampling rate to compute expected durations if there # approximate sampling rate to compute expected durations if there
# is data available for this time window for this fish id # is data available for this time window for this fish id
track_samplerate = np.mean(1 / np.diff(data.time)) track_samplerate = np.mean(1 / np.diff(data.time))
expected_duration = ((t0 + dt) - t0) * track_samplerate expected_duration = (
(window_start_seconds + window_duration_seconds)
- window_start_seconds
) * track_samplerate
# check if tracked data available in this window # check if tracked data available in this window
if len(freq_temp) < expected_duration * 0.5: if len(current_frequencies) < expected_duration / 2:
logger.warning( logger.warning(
f"Track {track_id} has no data in window {st}, skipping." f"Track {track_id} has no data in window {st}, skipping."
) )
continue continue
# check if there are powers available in this window # check if there are powers available in this window
nanchecker = np.unique(np.isnan(powers_temp)) nanchecker = np.unique(np.isnan(current_powers))
if (len(nanchecker) == 1) and nanchecker[0]: if (len(nanchecker) == 1) and nanchecker[0] is True:
logger.warning( logger.warning(
f"No powers available for track {track_id} window {st}, \ f"No powers available for track {track_id} window {st},"
skipping." "skipping."
) )
continue continue
# find the strongest electrodes for the current fish in the current # find the strongest electrodes for the current fish in the current
# window # window
best_electrodes = np.argsort(np.nanmean(powers_temp, axis=0))[
-config.number_electrodes: best_electrode_index = np.argsort(
] np.nanmean(current_powers, axis=0)
)[-config.number_electrodes:]
# find a frequency above the baseline of the current fish in which # find a frequency above the baseline of the current fish in which
# no other fish is active to search for chirps there # no other fish is active to search for chirps there
search_freq = find_search_freq(
search_frequency = find_searchband(
config=config, config=config,
freq_temp=freq_temp, freq_temp=current_frequencies,
median_ids=median_ids, median_ids=median_ids,
data=data, data=data,
median_freq=median_freq, median_freq=median_freq,
@ -549,153 +531,219 @@ def main(datapath: str, plot: str) -> None:
# add all chirps that are detected on multiple electrodes for one # add all chirps that are detected on multiple electrodes for one
# fish in one window to this list # fish in one window to this list
multielectrode_chirps = [] multielectrode_chirps = []
# iterate through electrodes # iterate through electrodes
for el, electrode in enumerate(best_electrodes): for el, electrode_index in enumerate(best_electrode_index):
logger.debug( logger.debug(
f"Processing electrode {el} of {len(best_electrodes)}" f"Processing electrode {el+1} of "
f"{len(best_electrode_index)}"
) )
# LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------
# load region of interest of raw data file # load region of interest of raw data file
data_oi = data.raw[start_index:stop_index, :] current_raw_data = data.raw[
time_oi = raw_time[start_index:stop_index] window_start_index:window_stop_index, electrode_index
]
current_raw_time = raw_time[
window_start_index:window_stop_index
]
# EXTRACT FEATURES --------------------------------------------
# filter baseline and above # filter baseline and above
baseline, search = double_bandpass( baselineband, searchband = extract_frequency_bands(
data_oi[:, electrode], raw_data=current_raw_data,
data.raw_rate, samplerate=data.raw_rate,
freq_temp, baseline_track=current_frequencies,
search_freq, searchband_center=search_frequency,
config=config, minimal_bandwidth=config.minimal_bandwidth,
) )
# compute instantaneous frequency on narrow signal # compute envelope of baseline band to find dips
baseline_freq_time, baseline_freq = instantaneos_frequency( # in the baseline envelope
baseline, data.raw_rate
)
# compute envelopes
baseline_envelope_unfiltered = envelope( baseline_envelope_unfiltered = envelope(
baseline, data.raw_rate, config.envelope_cutoff signal=baselineband,
samplerate=data.raw_rate,
cutoff_frequency=config.baseline_envelope_cutoff,
)
# highpass filter baseline envelope to remove slower
# fluctuations e.g. due to motion envelope
baseline_envelope = bandpass_filter(
signal=baseline_envelope_unfiltered,
samplerate=data.raw_rate,
lowf=config.baseline_envelope_bandpass_lowf,
highf=config.baseline_envelope_bandpass_highf,
)
# highpass filter introduced filter effects, i.e. oscillations
# around peaks. Compute the envelope of the highpass filtered
# and inverted baseline envelope to remove these oscillations
baseline_envelope = -baseline_envelope
baseline_envelope = envelope(
signal=baseline_envelope,
samplerate=data.raw_rate,
cutoff_frequency=config.baseline_envelope_envelope_cutoff,
) )
# compute the envelope of the search band. Peaks in the search
# band envelope correspond to troughs in the baseline envelope
# during chirps
search_envelope = envelope( search_envelope = envelope(
search, data.raw_rate, config.envelope_cutoff signal=searchband,
samplerate=data.raw_rate,
cutoff_frequency=config.search_envelope_cutoff,
) )
# highpass filter envelopes # compute instantaneous frequency of the baseline band to find
baseline_envelope = highpass_filter( # anomalies during a chirp, i.e. a frequency jump upwards or
baseline_envelope_unfiltered, # sometimes downwards. We do not fully understand why the
data.raw_rate, # instantaneous frequency can also jump downwards during a
config.envelope_highpass_cutoff, # chirp. This phenomenon is only observed on chirps on a narrow
# filtered baseline such as the one we are working with.
(
baseline_frequency_time,
baseline_frequency,
) = instantaneous_frequency(
signal=baselineband,
samplerate=data.raw_rate,
smoothing_window=config.baseline_frequency_smoothing,
) )
# envelopes of filtered envelope of filtered baseline # bandpass filter the instantaneous frequency to remove slow
baseline_envelope = envelope( # fluctuations. Just as with the baseline envelope, we then
np.abs(baseline_envelope), # compute the envelope of the signal to remove the oscillations
data.raw_rate, # around the peaks
config.envelope_envelope_cutoff,
baseline_frequency_samplerate = np.mean(
np.diff(baseline_frequency_time)
)
baseline_frequency_filtered = np.abs(
baseline_frequency - np.median(baseline_frequency)
)
baseline_frequency_filtered = highpass_filter(
signal=baseline_frequency_filtered,
samplerate=baseline_frequency_samplerate,
cutoff=config.baseline_frequency_highpass_cutoff,
) )
# bandpass filter the instantaneous frequency to put it to 0 baseline_frequency_filtered = envelope(
inst_freq_filtered = bandpass_filter( signal=-baseline_frequency_filtered,
baseline_freq, samplerate=baseline_frequency_samplerate,
data.raw_rate, cutoff_frequency=config.baseline_frequency_envelope_cutoff,
lowf=config.instantaneous_lowf,
highf=config.instantaneous_highf,
) )
# CUT OFF OVERLAP --------------------------------------------- # CUT OFF OVERLAP ---------------------------------------------
# overwrite raw time to valid region, i.e. cut off snippet at # cut off snippet at start and end of each window to remove
# start and end of each window to remove filter effects # filter effects
valid = np.arange(
# get arrays with raw samplerate without edges
no_edges = np.arange(
int(window_edge), len(baseline_envelope) - int(window_edge) int(window_edge), len(baseline_envelope) - int(window_edge)
) )
baseline_envelope_unfiltered = baseline_envelope_unfiltered[ current_raw_time = current_raw_time[no_edges]
valid baselineband = baselineband[no_edges]
] searchband = searchband[no_edges]
baseline_envelope = baseline_envelope[valid] baseline_envelope = baseline_envelope[no_edges]
search_envelope = search_envelope[valid] search_envelope = search_envelope[no_edges]
# get inst freq valid snippet # get instantaneous frequency without edges
valid_t0 = int(window_edge) / data.raw_rate no_edges_t0 = int(window_edge) / data.raw_rate
valid_t1 = baseline_freq_time[-1] - ( no_edges_t1 = baseline_frequency_time[-1] - (
int(window_edge) / data.raw_rate int(window_edge) / data.raw_rate
) )
no_edges = (baseline_frequency_time >= no_edges_t0) & (
baseline_frequency_time <= no_edges_t1
)
inst_freq_filtered = inst_freq_filtered[ baseline_frequency_filtered = baseline_frequency_filtered[
(baseline_freq_time >= valid_t0) no_edges
& (baseline_freq_time <= valid_t1)
]
baseline_freq = baseline_freq[
(baseline_freq_time >= valid_t0)
& (baseline_freq_time <= valid_t1)
] ]
baseline_frequency = baseline_frequency[no_edges]
baseline_freq_time = ( baseline_frequency_time = (
baseline_freq_time[ baseline_frequency_time[no_edges] + window_start_seconds
(baseline_freq_time >= valid_t0)
& (baseline_freq_time <= valid_t1)
]
+ t0
) )
time_oi = time_oi[valid]
baseline = baseline[valid]
search = search[valid]
# NORMALIZE --------------------------------------------------- # NORMALIZE ---------------------------------------------------
# normalize all three feature arrays to the same range to make
# peak detection simpler
baseline_envelope = normalize([baseline_envelope])[0] baseline_envelope = normalize([baseline_envelope])[0]
search_envelope = normalize([search_envelope])[0] search_envelope = normalize([search_envelope])[0]
inst_freq_filtered = normalize([np.abs(inst_freq_filtered)])[0] baseline_frequency_filtered = normalize(
[baseline_frequency_filtered]
)[0]
# PEAK DETECTION ---------------------------------------------- # PEAK DETECTION ----------------------------------------------
prominence = config.prominence
# detect peaks baseline_envelope # detect peaks baseline_envelope
baseline_peaks, _ = find_peaks( baseline_peak_indices, _ = find_peaks(
baseline_envelope, prominence=prominence baseline_envelope, prominence=config.prominence
) )
# detect peaks search_envelope # detect peaks search_envelope
search_peaks, _ = find_peaks( search_peak_indices, _ = find_peaks(
search_envelope, prominence=prominence search_envelope, prominence=config.prominence
) )
# detect peaks inst_freq_filtered # detect peaks inst_freq_filtered
inst_freq_peaks, _ = find_peaks( frequency_peak_indices, _ = find_peaks(
inst_freq_filtered, prominence=prominence baseline_frequency_filtered, prominence=config.prominence
) )
# DETECT CHIRPS IN SEARCH WINDOW ------------------------------ # DETECT CHIRPS IN SEARCH WINDOW ------------------------------
# get the peak timestamps from the peak indices # get the peak timestamps from the peak indices
baseline_ts = time_oi[baseline_peaks] baseline_peak_timestamps = current_raw_time[
search_ts = time_oi[search_peaks] baseline_peak_indices
freq_ts = baseline_freq_time[inst_freq_peaks] ]
search_peak_timestamps = current_raw_time[search_peak_indices]
frequency_peak_timestamps = baseline_frequency_time[
frequency_peak_indices
]
# check if one list is empty and if so, skip to the next # check if one list is empty and if so, skip to the next
# electrode because a chirp cannot be detected if one is empty # electrode because a chirp cannot be detected if one is empty
if (
len(baseline_ts) == 0 one_feature_empty = (
or len(search_ts) == 0 len(baseline_peak_timestamps) == 0
or len(freq_ts) == 0 or len(search_peak_timestamps) == 0
): or len(frequency_peak_timestamps) == 0
)
if one_feature_empty:
continue continue
# group peak across feature arrays but only if they # group peak across feature arrays but only if they
# occur in all 3 feature arrays # occur in all 3 feature arrays
sublists = [
list(baseline_peak_timestamps),
list(search_peak_timestamps),
list(frequency_peak_timestamps),
]
singleelectrode_chirps = group_timestamps( singleelectrode_chirps = group_timestamps(
[list(baseline_ts), list(search_ts), list(freq_ts)], sublists=sublists,
3, at_least_in=3,
config.chirp_window_threshold, difference_threshold=config.chirp_window_threshold,
) )
# check if there are chirps detected after grouping, continue # check if there are chirps detected after grouping, continue
# with the loop if not # with the loop if not
if len(singleelectrode_chirps) == 0: if len(singleelectrode_chirps) == 0:
continue continue
@ -703,57 +751,62 @@ def main(datapath: str, plot: str) -> None:
multielectrode_chirps.append(singleelectrode_chirps) multielectrode_chirps.append(singleelectrode_chirps)
# only initialize the plotting buffer if chirps are detected # only initialize the plotting buffer if chirps are detected
if ( chirp_detected = (
(el == config.number_electrodes - 1) (el == config.number_electrodes - 1)
& (len(singleelectrode_chirps) > 0) & (len(singleelectrode_chirps) > 0)
& (plot in ["show", "save"]) & (plot in ["show", "save"])
): )
if chirp_detected:
logger.debug("Detected chirp, ititialize buffer ...") logger.debug("Detected chirp, ititialize buffer ...")
# save data to Buffer # save data to Buffer
buffer = PlotBuffer( buffer = PlotBuffer(
config=config, config=config,
t0=t0, t0=window_start_seconds,
dt=dt, dt=window_duration_seconds,
electrode=electrode, electrode=electrode_index,
track_id=track_id, track_id=track_id,
data=data, data=data,
time=time_oi, time=current_raw_time,
baseline=baseline, baseline=baselineband,
baseline_envelope=baseline_envelope, baseline_envelope=baseline_envelope,
baseline_peaks=baseline_peaks, baseline_peaks=baseline_peak_indices,
search=search, search=searchband,
search_envelope=search_envelope, search_envelope=search_envelope,
search_peaks=search_peaks, search_peaks=search_peak_indices,
frequency_time=baseline_freq_time, frequency_time=baseline_frequency_time,
frequency=baseline_freq, frequency=baseline_frequency,
frequency_filtered=inst_freq_filtered, frequency_filtered=baseline_frequency_filtered,
frequency_peaks=inst_freq_peaks, frequency_peaks=frequency_peak_indices,
) )
logger.debug("Buffer initialized!") logger.debug("Buffer initialized!")
logger.debug( logger.debug(
f"Processed all electrodes for fish {track_id} for this \ f"Processed all electrodes for fish {track_id} for this"
window, sorting chirps ..." "window, sorting chirps ..."
) )
# check if there are chirps detected in multiple electrodes and # check if there are chirps detected in multiple electrodes and
# continue the loop if not # continue the loop if not
if len(multielectrode_chirps) == 0: if len(multielectrode_chirps) == 0:
continue continue
# validate multielectrode chirps, i.e. check if they are # validate multielectrode chirps, i.e. check if they are
# detected in at least 'config.min_electrodes' electrodes # detected in at least 'config.min_electrodes' electrodes
multielectrode_chirps_validated = group_timestamps( multielectrode_chirps_validated = group_timestamps(
multielectrode_chirps, sublists=multielectrode_chirps,
config.minimum_electrodes, at_least_in=config.minimum_electrodes,
config.chirp_window_threshold difference_threshold=config.chirp_window_threshold,
) )
# add validated chirps to the list that tracks chirps across the # add validated chirps to the list that tracks chirps across the
# rolling time windows # rolling time windows
multiwindow_chirps.append(multielectrode_chirps_validated) multiwindow_chirps.append(multielectrode_chirps_validated)
multiwindow_ids.append(track_id) multiwindow_ids.append(track_id)
@ -763,6 +816,7 @@ def main(datapath: str, plot: str) -> None:
) )
# if chirps are detected and the plot flag is set, plot the # if chirps are detected and the plot flag is set, plot the
# chirps, otherwise try to delete the buffer if it exists # chirps, otherwise try to delete the buffer if it exists
if len(multielectrode_chirps_validated) > 0: if len(multielectrode_chirps_validated) > 0:
try: try:
buffer.plot_buffer(multielectrode_chirps_validated, plot) buffer.plot_buffer(multielectrode_chirps_validated, plot)
@ -776,27 +830,38 @@ def main(datapath: str, plot: str) -> None:
# flatten list of lists containing chirps and create # flatten list of lists containing chirps and create
# an array of fish ids that correspond to the chirps # an array of fish ids that correspond to the chirps
multiwindow_chirps_flat = [] multiwindow_chirps_flat = []
multiwindow_ids_flat = [] multiwindow_ids_flat = []
for tr in np.unique(multiwindow_ids): for track_id in np.unique(multiwindow_ids):
tr_index = np.asarray(multiwindow_ids) == tr
ts = flatten(list(compress(multiwindow_chirps, tr_index))) # get chirps for this fish and flatten the list
multiwindow_chirps_flat.extend(ts) current_track_bool = np.asarray(multiwindow_ids) == track_id
multiwindow_ids_flat.extend(list(np.ones_like(ts) * tr)) current_track_chirps = flatten(
list(compress(multiwindow_chirps, current_track_bool))
)
# add flattened chirps to the list
multiwindow_chirps_flat.extend(current_track_chirps)
multiwindow_ids_flat.extend(
list(np.ones_like(current_track_chirps) * track_id)
)
# purge duplicates, i.e. chirps that are very close to each other # purge duplicates, i.e. chirps that are very close to each other
# duplicates arise due to overlapping windows # duplicates arise due to overlapping windows
purged_chirps = [] purged_chirps = []
purged_ids = [] purged_ids = []
for tr in np.unique(multiwindow_ids_flat): for track_id in np.unique(multiwindow_ids_flat):
tr_chirps = np.asarray(multiwindow_chirps_flat)[ tr_chirps = np.asarray(multiwindow_chirps_flat)[
np.asarray(multiwindow_ids_flat) == tr] np.asarray(multiwindow_ids_flat) == track_id
]
if len(tr_chirps) > 0: if len(tr_chirps) > 0:
tr_chirps_purged = purge_duplicates( tr_chirps_purged = purge_duplicates(
tr_chirps, config.chirp_window_threshold tr_chirps, config.chirp_window_threshold
) )
purged_chirps.extend(list(tr_chirps_purged)) purged_chirps.extend(list(tr_chirps_purged))
purged_ids.extend(list(np.ones_like(tr_chirps_purged) * tr)) purged_ids.extend(list(np.ones_like(tr_chirps_purged) * track_id))
# sort chirps by time # sort chirps by time
purged_chirps = np.asarray(purged_chirps) purged_chirps = np.asarray(purged_chirps)

View File

@ -1,3 +1,4 @@
# directory setup
dataroot: "../data/" dataroot: "../data/"
outputdir: "../output/" outputdir: "../output/"
@ -10,30 +11,26 @@ edge: 0.25
number_electrodes: 3 number_electrodes: 3
minimum_electrodes: 2 minimum_electrodes: 2
# Search window bandwidth # Search window bandwidth and minimal baseline bandwidth
minimal_bandwidth: 10
# Cutoff frequency for envelope estimation by lowpass filter # Instantaneous frequency smoothing using a gaussian kernel of this width
envelope_cutoff: 25 baseline_frequency_smoothing: 5
# Cutoff frequency for envelope highpass filter # Baseline processing parameters
envelope_highpass_cutoff: 3 baseline_envelope_cutoff: 25
baseline_envelope_bandpass_lowf: 4
baseline_envelope_bandpass_highf: 100
baseline_envelope_envelope_cutoff: 4
# Cutoff frequency for envelope of envelope # search envelope processing parameters
envelope_envelope_cutoff: 5 search_envelope_cutoff: 5
# Instantaneous frequency bandpass filter cutoff frequencies # Instantaneous frequency bandpass filter cutoff frequencies
instantaneous_lowf: 15 baseline_frequency_highpass_cutoff: 0.000005
instantaneous_highf: 8000 baseline_frequency_envelope_cutoff: 0.000005
# Baseline envelope peak detection parameters
# baseline_prominence_percentile: 90
# Search envelope peak detection parameters
# search_prominence_percentile: 90
# Instantaneous frequency peak detection parameters
# instantaneous_prominence_percentile: 90
# peak detection parameters
prominence: 0.005 prominence: 0.005
# search freq parameter # search freq parameter

View File

@ -1,5 +1,59 @@
import numpy as np import numpy as np
from typing import List, Any from typing import List, Any
from scipy.ndimage import gaussian_filter1d
def instantaneous_frequency(
    signal: np.ndarray,
    samplerate: int,
    smoothing_window: int,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Estimate the instantaneous frequency of a signal from its zero crossings.

    The signal is expected to be approximately sinusoidal and symmetric
    around 0. Every rising zero crossing (a negative sample followed by a
    non-negative one) is located with sub-sample precision by linear
    interpolation between the two samples. The time between consecutive
    crossings is one period, and its inverse is the frequency estimate.

    Parameters
    ----------
    signal : np.ndarray
        One-dimensional signal to compute the instantaneous frequency from.
        Any array-like is accepted and converted with `np.asarray`.
    samplerate : int
        Samplerate of the signal.
    smoothing_window : int
        Width of the gaussian smoothing kernel. It is passed to
        `scipy.ndimage.gaussian_filter1d` as `sigma`, i.e. it is the
        standard deviation of the kernel measured in frequency estimates
        (one estimate per detected period), not in samples of `signal`.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        `(time, frequency)`. `time` holds the midpoint between each pair
        of consecutive zero crossings, computed as sample index divided
        by `samplerate`; `frequency` holds the smoothed inverse period
        for that pair. Both arrays have the same length. Two empty arrays
        are returned if fewer than two usable zero crossings are found.
    """
    signal = np.asarray(signal)
    sample_index = np.arange(len(signal))
    time_signal = sample_index / samplerate

    # A rising zero crossing is a sample >= 0 whose predecessor is < 0.
    # np.roll wraps around, so index 0 is compared against the last sample
    # and may be reported as a crossing. The first and the last detected
    # crossing are discarded, which removes that wrap-around artifact and
    # guarantees that `period_index - 1` never becomes negative.
    roll_signal = np.roll(signal, shift=1)
    period_index = sample_index[(roll_signal < 0) & (signal >= 0)][1:-1]

    # At least two crossings are needed to measure a single period.
    if period_index.size < 2:
        return np.array([], dtype=float), np.array([], dtype=float)

    # Magnitudes of the samples right after and right before each crossing.
    upper_bound = np.abs(signal[period_index])
    lower_bound = np.abs(signal[period_index - 1])
    upper_time = time_signal[period_index]
    lower_time = time_signal[period_index - 1]

    # Linear interpolation: the crossing lies proportionally closer to the
    # sample with the smaller magnitude. The denominator cannot be zero
    # because the sample before a crossing is strictly negative.
    lower_ratio = lower_bound / (lower_bound + upper_bound)
    true_zero = lower_time + lower_ratio * (upper_time - lower_time)

    # One period per pair of consecutive crossings, time-stamped at the
    # midpoint of the pair.
    period = np.diff(true_zero)
    frequency_time = true_zero[:-1] + 0.5 * period
    frequency = gaussian_filter1d(1 / period, smoothing_window)

    return frequency_time, frequency
def purge_duplicates( def purge_duplicates(
@ -64,7 +118,7 @@ def purge_duplicates(
def group_timestamps( def group_timestamps(
sublists: List[List[float]], n: int, threshold: float sublists: List[List[float]], at_least_in: int, difference_threshold: float
) -> List[float]: ) -> List[float]:
""" """
Groups timestamps that are less than `threshold` milliseconds apart from Groups timestamps that are less than `threshold` milliseconds apart from
@ -100,7 +154,7 @@ def group_timestamps(
# Group timestamps that are less than threshold milliseconds apart # Group timestamps that are less than threshold milliseconds apart
for i in range(1, len(timestamps)): for i in range(1, len(timestamps)):
if timestamps[i] - timestamps[i - 1] < threshold: if timestamps[i] - timestamps[i - 1] < difference_threshold:
current_group.append(timestamps[i]) current_group.append(timestamps[i])
else: else:
groups.append(current_group) groups.append(current_group)
@ -111,7 +165,7 @@ def group_timestamps(
# Retain only groups that contain at least n timestamps # Retain only groups that contain at least n timestamps
final_groups = [] final_groups = []
for group in groups: for group in groups:
if len(group) >= n: if len(group) >= at_least_in:
final_groups.append(group) final_groups.append(group)
# Calculate the mean of each group # Calculate the mean of each group

View File

@ -3,8 +3,8 @@ import numpy as np
def bandpass_filter( def bandpass_filter(
data: np.ndarray, signal: np.ndarray,
rate: float, samplerate: float,
lowf: float, lowf: float,
highf: float, highf: float,
) -> np.ndarray: ) -> np.ndarray:
@ -12,7 +12,7 @@ def bandpass_filter(
Parameters Parameters
---------- ----------
data : np.ndarray signal : np.ndarray
The data to be filtered The data to be filtered
rate : float rate : float
The sampling rate The sampling rate
@ -26,21 +26,22 @@ def bandpass_filter(
np.ndarray np.ndarray
The filtered data The filtered data
""" """
sos = butter(2, (lowf, highf), "bandpass", fs=rate, output="sos") sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos")
fdata = sosfiltfilt(sos, data) filtered_signal = sosfiltfilt(sos, signal)
return fdata
return filtered_signal
def highpass_filter( def highpass_filter(
data: np.ndarray, signal: np.ndarray,
rate: float, samplerate: float,
cutoff: float, cutoff: float,
) -> np.ndarray: ) -> np.ndarray:
"""Highpass filter a signal. """Highpass filter a signal.
Parameters Parameters
---------- ----------
data : np.ndarray signal : np.ndarray
The data to be filtered The data to be filtered
rate : float rate : float
The sampling rate The sampling rate
@ -52,14 +53,15 @@ def highpass_filter(
np.ndarray np.ndarray
The filtered data The filtered data
""" """
sos = butter(2, cutoff, "highpass", fs=rate, output="sos") sos = butter(2, cutoff, "highpass", fs=samplerate, output="sos")
fdata = sosfiltfilt(sos, data) filtered_signal = sosfiltfilt(sos, signal)
return fdata
return filtered_signal
def lowpass_filter( def lowpass_filter(
data: np.ndarray, signal: np.ndarray,
rate: float, samplerate: float,
cutoff: float cutoff: float
) -> np.ndarray: ) -> np.ndarray:
"""Lowpass filter a signal. """Lowpass filter a signal.
@ -78,21 +80,25 @@ def lowpass_filter(
np.ndarray np.ndarray
The filtered data The filtered data
""" """
sos = butter(2, cutoff, "lowpass", fs=rate, output="sos") sos = butter(2, cutoff, "lowpass", fs=samplerate, output="sos")
fdata = sosfiltfilt(sos, data) filtered_signal = sosfiltfilt(sos, signal)
return fdata
return filtered_signal
def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
def envelope(signal: np.ndarray,
samplerate: float,
cutoff_frequency: float
) -> np.ndarray:
"""Calculate the envelope of a signal using a lowpass filter. """Calculate the envelope of a signal using a lowpass filter.
Parameters Parameters
---------- ----------
data : np.ndarray signal : np.ndarray
The signal to calculate the envelope of The signal to calculate the envelope of
rate : float samplerate : float
The sampling rate of the signal The sampling rate of the signal
freq : float cutoff_frequency : float
The cutoff frequency of the lowpass filter The cutoff frequency of the lowpass filter
Returns Returns
@ -100,6 +106,7 @@ def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
np.ndarray np.ndarray
The envelope of the signal The envelope of the signal
""" """
sos = butter(2, freq, "lowpass", fs=rate, output="sos") sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos")
envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(data)) envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(signal))
return envelope return envelope