refactoring finished for now

weygoldt 2023-01-20 13:56:26 +01:00
parent 9985686d53
commit ddf7bd545a
4 changed files with 401 additions and 278 deletions

code/chirpdetection.py Normal file → Executable file

@ -1,20 +1,22 @@
from itertools import compress
from dataclasses import dataclass
from IPython import embed
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
from thunderfish.dataloader import DataLoader
from thunderfish.powerspectrum import spectrogram, decibel
from sklearn.preprocessing import normalize
from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.filehandling import ConfLoader, LoadData, make_outputdir
from modules.datahandling import flatten, purge_duplicates, group_timestamps
from modules.plotstyle import PlotStyle
from modules.logger import makeLogger
from modules.datahandling import (
flatten,
purge_duplicates,
group_timestamps,
instantaneous_frequency,
)
logger = makeLogger(__name__)
@ -28,6 +30,7 @@ class PlotBuffer:
Buffer to save data that is created in the main detection loop
and plot it outside the detection loop.
"""
config: ConfLoader
t0: float
dt: float
@ -85,8 +88,9 @@ class PlotBuffer:
plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0)
for chirp in chirps:
axs[0].scatter(chirp, np.median(self.frequency),
c=ps.black, marker="x")
axs[0].scatter(
chirp, np.median(self.frequency), c=ps.black, marker="x"
)
# plot waveform of filtered signal
axs[1].plot(self.time, self.baseline, c=ps.green)
@ -94,7 +98,7 @@ class PlotBuffer:
# plot waveform of filtered search signal
axs[2].plot(self.time, self.search)
# plot baseline instantaneos frequency
# plot baseline instantaneous frequency
axs[3].plot(self.frequency_time, self.frequency)
# plot filtered and rectified envelope
@ -145,7 +149,7 @@ class PlotBuffer:
def plot_spectrogram(
axis, signal: np.ndarray, samplerate: float, t0: float
axis, signal: np.ndarray, samplerate: float, window_start_seconds: float
) -> None:
"""
Plot a spectrogram of a signal.
@ -158,7 +162,7 @@ def plot_spectrogram(
Signal to plot the spectrogram from.
samplerate : float
Samplerate of the signal.
t0 : float
window_start_seconds : float
Start time of the signal.
"""
@ -172,73 +176,26 @@ def plot_spectrogram(
overlap_frac=0.5,
)
# axis.pcolormesh(
# spec_times + t0,
# spec_freqs,
# decibel(spec_power),
# )
axis.imshow(
decibel(spec_power),
extent=[spec_times[0] + t0, spec_times[-1] +
t0, spec_freqs[0], spec_freqs[-1]],
extent=[
spec_times[0] + window_start_seconds,
spec_times[-1] + window_start_seconds,
spec_freqs[0],
spec_freqs[-1],
],
aspect="auto",
origin="lower",
interpolation="gaussian",
)
def instantaneos_frequency(
signal: np.ndarray, samplerate: int
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the instantaneous frequency of a signal.
Parameters
----------
signal : np.ndarray
Signal to compute the instantaneous frequency from.
samplerate : int
Samplerate of the signal.
Returns
-------
tuple[np.ndarray, np.ndarray]
"""
# calculate instantaneos frequency with zero crossings
roll_signal = np.roll(signal, shift=1)
time_signal = np.arange(len(signal)) / samplerate
period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][
1:-1
]
upper_bound = np.abs(signal[period_index])
lower_bound = np.abs(signal[period_index - 1])
upper_time = np.abs(time_signal[period_index])
lower_time = np.abs(time_signal[period_index - 1])
# create ratio
lower_ratio = lower_bound / (lower_bound + upper_bound)
# appy to time delta
time_delta = upper_time - lower_time
true_zero = lower_time + lower_ratio * time_delta
# create new time array
inst_freq_time = true_zero[:-1] + 0.5 * np.diff(true_zero)
# compute frequency
inst_freq = gaussian_filter1d(1 / np.diff(true_zero), 5)
return inst_freq_time, inst_freq
def double_bandpass(
data: DataLoader,
def extract_frequency_bands(
raw_data: np.ndarray,
samplerate: int,
freqs: np.ndarray,
search_freq: float,
config: ConfLoader
baseline_track: np.ndarray,
searchband_center: float,
minimal_bandwidth: float,
) -> tuple[np.ndarray, np.ndarray]:
"""
Apply a bandpass filter to the baseline of a signal and a second bandpass
@ -246,14 +203,16 @@ def double_bandpass(
Parameters
----------
data : DataLoader
raw_data : np.ndarray
Data to apply the filter to.
samplerate : int
Samplerate of the signal.
freqs : np.ndarray
baseline_track : np.ndarray
Tracked fundamental frequencies of the signal.
search_freq : float
searchband_center: float
Frequency to search for above or below the baseline.
minimal_bandwidth : float
Minimal bandwidth of the filter.
Returns
-------
@ -261,28 +220,30 @@ def double_bandpass(
"""
# compute boundaries to filter baseline
q25, q50, q75 = np.percentile(freqs, [25, 50, 75])
q25, q50, q75 = np.percentile(baseline_track, [25, 50, 75])
# check if percentile delta is too small
if q75 - q25 < 5:
median = np.median(freqs)
q25, q75 = median - 2.5, median + 2.5
if q75 - q25 < 10:
q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2
# filter baseline
filtered_baseline = bandpass_filter(data, samplerate, lowf=q25, highf=q75)
filtered_baseline = bandpass_filter(
raw_data, samplerate, lowf=q25, highf=q75
)
# filter search area
filtered_search_freq = bandpass_filter(
data, samplerate,
lowf=search_freq + q50 - config.search_bandwidth / 2,
highf=search_freq + q50 + config.search_bandwidth / 2
raw_data,
samplerate,
lowf=searchband_center + q50 - minimal_bandwidth / 2,
highf=searchband_center + q50 + minimal_bandwidth / 2,
)
return filtered_baseline, filtered_search_freq
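# Illustrative sketch (not part of this commit): how the band edges above are
# chosen. The baseline band spans the 25th-75th percentile of the tracked
# frequencies and is widened to `minimal_bandwidth` when the track is nearly
# flat; the search band is centred `searchband_center` Hz above the median.
# All numbers below are made up.
import numpy as np

baseline_track = np.random.default_rng(1).normal(620.0, 0.5, 300)  # flat track
minimal_bandwidth = 10.0   # assumed config value
searchband_center = 50.0   # assumed offset above the baseline

q25, q50, q75 = np.percentile(baseline_track, [25, 50, 75])
if q75 - q25 < 10:
    # track too flat: fall back to a fixed band around the median
    q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2

baseline_band = (q25, q75)                       # ≈ (615, 625) Hz
search_band = (
    searchband_center + q50 - minimal_bandwidth / 2,
    searchband_center + q50 + minimal_bandwidth / 2,
)                                                 # ≈ (665, 675) Hz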
def freqmedian_allfish(
data: LoadData, t0: float, dt: float
def window_median_all_track_ids(
data: LoadData, window_start_seconds: float, window_duration_seconds: float
) -> tuple[float, list[int]]:
"""
Calculate the median frequency of all fish in a given time window.
@ -291,9 +252,9 @@ def freqmedian_allfish(
----------
data : LoadData
Data to calculate the median frequency from.
t0 : float
window_start_seconds : float
Start time of the window.
dt : float
window_duration_seconds : float
Duration of the window.
Returns
@ -308,8 +269,11 @@ def freqmedian_allfish(
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
window_idx = np.arange(len(data.idx))[
(data.ident == track_id)
& (data.time[data.idx] >= t0)
& (data.time[data.idx] <= (t0 + dt))
& (data.time[data.idx] >= window_start_seconds)
& (
data.time[data.idx]
<= (window_start_seconds + window_duration_seconds)
)
]
if len(data.freq[window_idx]) > 0:
@ -323,7 +287,7 @@ def freqmedian_allfish(
return median_freq, track_ids
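# Illustrative sketch (not part of this commit): the per-window median over all
# track ids, using a hypothetical SimpleNamespace stand-in for LoadData with
# the same field names as above (time, idx, ident, freq).
import numpy as np
from types import SimpleNamespace

time = np.arange(10.0)                           # 10 s of 1 Hz tracking
data = SimpleNamespace(
    time=time,
    idx=np.tile(np.arange(10), 2),               # index into `time` per detection
    ident=np.repeat([1.0, 2.0], 10),             # track id per detection
    freq=np.concatenate([np.full(10, 600.0), np.full(10, 720.0)]),
)

window_start_seconds, window_duration_seconds = 2.0, 5.0
median_freq, track_ids = [], []
for track_id in np.unique(data.ident[~np.isnan(data.ident)]):
    in_window = np.arange(len(data.idx))[
        (data.ident == track_id)
        & (data.time[data.idx] >= window_start_seconds)
        & (data.time[data.idx] <= window_start_seconds + window_duration_seconds)
    ]
    if len(data.freq[in_window]) > 0:
        median_freq.append(float(np.median(data.freq[in_window])))
        track_ids.append(float(track_id))
print(median_freq, track_ids)                    # [600.0, 720.0] [1.0, 2.0]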
def find_search_freq(
def find_searchband(
freq_temp: np.ndarray,
median_ids: np.ndarray,
median_freq: np.ndarray,
@ -331,15 +295,16 @@ def find_search_freq(
data: LoadData,
) -> float:
"""
Find the search frequency for each fish by checking which fish EODs are
above the current EOD and finding a gap in them.
Find the search frequency band for each fish by checking which fish EODs
are above the current EOD and finding a gap in them.
Parameters
----------
freq_temp : np.ndarray
Current EOD frequency array / the current fish of interest.
median_ids : np.ndarray
Array of track IDs of the medians of all other fish in the current window.
Array of track IDs of the medians of all other fish in the current
window.
median_freq : np.ndarray
Array of median frequencies of all other fish in the current window.
config : ConfLoader
@ -421,7 +386,8 @@ def find_search_freq(
longest_search_window = search_windows[np.argmax(search_windows_lens)]
search_freq = (
longest_search_window[-1] - longest_search_window[0]) / 2
longest_search_window[-1] - longest_search_window[0]
) / 2
else:
search_freq = config.default_search_freq
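# Illustrative sketch (not part of this commit): a simplified version of the
# gap search that find_searchband performs. Among the median frequencies of
# the other fish above the focal fish, pick the centre of the widest free
# band; all numbers and the 300 Hz upper limit are assumptions.
import numpy as np

focal_median = 600.0
other_medians = np.array([640.0, 705.0, 820.0])
upper_limit = focal_median + 300.0

edges = np.sort(
    np.concatenate(
        ([focal_median], other_medians[other_medians > focal_median], [upper_limit])
    )
)
gaps = np.diff(edges)
widest = np.argmax(gaps)                           # gap between 705 and 820 Hz
search_frequency = (edges[widest] + edges[widest + 1]) / 2 - focal_median
print(search_frequency)                            # 162.5 Hz above the baseline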
@ -431,7 +397,11 @@ def find_search_freq(
def main(datapath: str, plot: str) -> None:
assert plot in ["save", "show", "false"]
assert plot in [
"save",
"show",
"false",
], "plot must be 'save', 'show' or 'false'"
# load raw file
data = LoadData(datapath)
@ -444,13 +414,15 @@ def main(datapath: str, plot: str) -> None:
window_overlap = config.overlap * data.raw_rate
window_edge = config.edge * data.raw_rate
# check if window duration is even
# check if window duration and window overlap are even; otherwise half
# of the duration or the overlap would be a float and thus an
# invalid index
if window_duration % 2 == 0:
window_duration = int(window_duration)
else:
raise ValueError("Window duration must be even.")
# check if window ovelap is even
if window_overlap % 2 == 0:
window_overlap = int(window_overlap)
else:
@ -460,16 +432,16 @@ def main(datapath: str, plot: str) -> None:
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# good chirp times for data: 2022-06-02-10_00
t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
dt = 60 * data.raw_rate
window_start_seconds = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
window_duration_seconds = 60 * data.raw_rate
# t0 = 0
# dt = data.raw.shape[0]
# generate starting points of rolling window
window_starts = np.arange(
t0,
t0 + dt,
window_start_indices = np.arange(
window_start_seconds,
window_start_seconds + window_duration_seconds,
window_duration - (window_overlap + 2 * window_edge),
dtype=int,
)
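# Illustrative sketch (not part of this commit): with concrete numbers, the
# rolling windows generated above advance by the window duration minus the
# overlap and the two edge regions. Values here are assumed; only `edge: 0.25`
# appears in the config shown further below.
import numpy as np

raw_rate = 20_000
window_duration = int(5 * raw_rate)      # 5 s analysis window
window_overlap = int(1 * raw_rate)       # 1 s overlap
window_edge = int(0.25 * raw_rate)       # 0.25 s cut off at each side

starts = np.arange(
    0,
    int(60 * raw_rate),
    window_duration - (window_overlap + 2 * window_edge),
    dtype=int,
)
print(np.diff(starts)[:3] / raw_rate)    # [3.5 3.5 3.5] -> a new window every 3.5 s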
@ -478,19 +450,20 @@ def main(datapath: str, plot: str) -> None:
multiwindow_chirps = []
multiwindow_ids = []
for st, start_index in enumerate(window_starts):
for st, window_start_index in enumerate(window_start_indices):
logger.info(f"Processing window {st} of {len(window_starts)}")
logger.info(f"Processing window {st+1} of {len(window_start_indices)}")
# make t0 and dt
t0 = start_index / data.raw_rate
dt = window_duration / data.raw_rate
window_start_seconds = window_start_index / data.raw_rate
window_duration_seconds = window_duration / data.raw_rate
# set index window
stop_index = start_index + window_duration
window_stop_index = window_start_index + window_duration
# calculate median of fish frequencies in window
median_freq, median_ids = freqmedian_allfish(data, t0, dt)
median_freq, median_ids = window_median_all_track_ids(
data, window_start_seconds, window_duration_seconds
)
# iterate through all fish
for tr, track_id in enumerate(
@ -500,48 +473,57 @@ def main(datapath: str, plot: str) -> None:
logger.debug(f"Processing track {tr} of {len(data.ids)}")
# get index of track data in this time window
window_idx = np.arange(len(data.idx))[
track_window_index = np.arange(len(data.idx))[
(data.ident == track_id)
& (data.time[data.idx] >= t0)
& (data.time[data.idx] <= (t0 + dt))
& (data.time[data.idx] >= window_start_seconds)
& (
data.time[data.idx]
<= (window_start_seconds + window_duration_seconds)
)
]
# get tracked frequencies and their times
freq_temp = data.freq[window_idx]
powers_temp = data.powers[window_idx, :]
current_frequencies = data.freq[track_window_index]
current_powers = data.powers[track_window_index, :]
# approximate sampling rate to compute expected durations if there
# is data available for this time window for this fish id
track_samplerate = np.mean(1 / np.diff(data.time))
expected_duration = ((t0 + dt) - t0) * track_samplerate
expected_duration = (
(window_start_seconds + window_duration_seconds)
- window_start_seconds
) * track_samplerate
# check if tracked data available in this window
if len(freq_temp) < expected_duration * 0.5:
if len(current_frequencies) < expected_duration / 2:
logger.warning(
f"Track {track_id} has no data in window {st}, skipping."
)
continue
# check if there are powers available in this window
nanchecker = np.unique(np.isnan(powers_temp))
if (len(nanchecker) == 1) and nanchecker[0]:
nanchecker = np.unique(np.isnan(current_powers))
if (len(nanchecker) == 1) and bool(nanchecker[0]):
logger.warning(
f"No powers available for track {track_id} window {st}, \
skipping."
f"No powers available for track {track_id} window {st},"
"skipping."
)
continue
# find the strongest electrodes for the current fish in the current
# window
best_electrodes = np.argsort(np.nanmean(powers_temp, axis=0))[
-config.number_electrodes:
]
best_electrode_index = np.argsort(
np.nanmean(current_powers, axis=0)
)[-config.number_electrodes:]
# find a frequency above the baseline of the current fish in which
# no other fish is active to search for chirps there
search_freq = find_search_freq(
search_frequency = find_searchband(
config=config,
freq_temp=freq_temp,
freq_temp=current_frequencies,
median_ids=median_ids,
data=data,
median_freq=median_freq,
@ -549,153 +531,219 @@ def main(datapath: str, plot: str) -> None:
# add all chirps that are detected on multiple electrodes for one
# fish in one window to this list
multielectrode_chirps = []
# iterate through electrodes
for el, electrode in enumerate(best_electrodes):
for el, electrode_index in enumerate(best_electrode_index):
logger.debug(
f"Processing electrode {el} of {len(best_electrodes)}"
f"Processing electrode {el+1} of "
f"{len(best_electrode_index)}"
)
# LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------
# load region of interest of raw data file
data_oi = data.raw[start_index:stop_index, :]
time_oi = raw_time[start_index:stop_index]
current_raw_data = data.raw[
window_start_index:window_stop_index, electrode_index
]
current_raw_time = raw_time[
window_start_index:window_stop_index
]
# EXTRACT FEATURES --------------------------------------------
# filter baseline and above
baseline, search = double_bandpass(
data_oi[:, electrode],
data.raw_rate,
freq_temp,
search_freq,
config=config,
baselineband, searchband = extract_frequency_bands(
raw_data=current_raw_data,
samplerate=data.raw_rate,
baseline_track=current_frequencies,
searchband_center=search_frequency,
minimal_bandwidth=config.minimal_bandwidth,
)
# compute instantaneous frequency on narrow signal
baseline_freq_time, baseline_freq = instantaneos_frequency(
baseline, data.raw_rate
)
# compute envelope of baseline band to find dips
# in the baseline envelope
# compute envelopes
baseline_envelope_unfiltered = envelope(
baseline, data.raw_rate, config.envelope_cutoff
signal=baselineband,
samplerate=data.raw_rate,
cutoff_frequency=config.baseline_envelope_cutoff,
)
# highpass filter baseline envelope to remove slower
# fluctuations e.g. due to motion envelope
baseline_envelope = bandpass_filter(
signal=baseline_envelope_unfiltered,
samplerate=data.raw_rate,
lowf=config.baseline_envelope_bandpass_lowf,
highf=config.baseline_envelope_bandpass_highf,
)
# highpass filter introduced filter effects, i.e. oscillations
# around peaks. Compute the envelope of the highpass filtered
# and inverted baseline envelope to remove these oscillations
baseline_envelope = -baseline_envelope
baseline_envelope = envelope(
signal=baseline_envelope,
samplerate=data.raw_rate,
cutoff_frequency=config.baseline_envelope_envelope_cutoff,
)
# compute the envelope of the search band. Peaks in the search
# band envelope correspond to troughs in the baseline envelope
# during chirps
search_envelope = envelope(
search, data.raw_rate, config.envelope_cutoff
signal=searchband,
samplerate=data.raw_rate,
cutoff_frequency=config.search_envelope_cutoff,
)
# highpass filter envelopes
baseline_envelope = highpass_filter(
baseline_envelope_unfiltered,
data.raw_rate,
config.envelope_highpass_cutoff,
# compute instantaneous frequency of the baseline band to find
# anomalies during a chirp, i.e. a frequency jump upwards or
# sometimes downwards. We do not fully understand why the
# instantaneous frequency can also jump downwards during a
# chirp. This phenomenon is only observed for chirps on a narrowly
# filtered baseline such as the one we are working with.
(
baseline_frequency_time,
baseline_frequency,
) = instantaneous_frequency(
signal=baselineband,
samplerate=data.raw_rate,
smoothing_window=config.baseline_frequency_smoothing,
)
# envelopes of filtered envelope of filtered baseline
baseline_envelope = envelope(
np.abs(baseline_envelope),
data.raw_rate,
config.envelope_envelope_cutoff,
# highpass filter the instantaneous frequency to remove slow
# fluctuations. Just as with the baseline envelope, we then
# compute the envelope of the signal to remove the oscillations
# around the peaks
baseline_frequency_samplerate = np.mean(
np.diff(baseline_frequency_time)
)
baseline_frequency_filtered = np.abs(
baseline_frequency - np.median(baseline_frequency)
)
# bandpass filter the instantaneous frequency to put it to 0
inst_freq_filtered = bandpass_filter(
baseline_freq,
data.raw_rate,
lowf=config.instantaneous_lowf,
highf=config.instantaneous_highf,
baseline_frequency_filtered = highpass_filter(
signal=baseline_frequency_filtered,
samplerate=baseline_frequency_samplerate,
cutoff=config.baseline_frequency_highpass_cutoff,
)
baseline_frequency_filtered = envelope(
signal=-baseline_frequency_filtered,
samplerate=baseline_frequency_samplerate,
cutoff_frequency=config.baseline_frequency_envelope_cutoff,
)
# CUT OFF OVERLAP ---------------------------------------------
# overwrite raw time to valid region, i.e. cut off snippet at
# start and end of each window to remove filter effects
valid = np.arange(
# cut off snippet at start and end of each window to remove
# filter effects
# get arrays with raw samplerate without edges
no_edges = np.arange(
int(window_edge), len(baseline_envelope) - int(window_edge)
)
baseline_envelope_unfiltered = baseline_envelope_unfiltered[
valid
]
baseline_envelope = baseline_envelope[valid]
search_envelope = search_envelope[valid]
current_raw_time = current_raw_time[no_edges]
baselineband = baselineband[no_edges]
searchband = searchband[no_edges]
baseline_envelope = baseline_envelope[no_edges]
search_envelope = search_envelope[no_edges]
# get inst freq valid snippet
valid_t0 = int(window_edge) / data.raw_rate
valid_t1 = baseline_freq_time[-1] - (
# get instantaneous frequency without edges
no_edges_t0 = int(window_edge) / data.raw_rate
no_edges_t1 = baseline_frequency_time[-1] - (
int(window_edge) / data.raw_rate
)
no_edges = (baseline_frequency_time >= no_edges_t0) & (
baseline_frequency_time <= no_edges_t1
)
inst_freq_filtered = inst_freq_filtered[
(baseline_freq_time >= valid_t0)
& (baseline_freq_time <= valid_t1)
]
baseline_freq = baseline_freq[
(baseline_freq_time >= valid_t0)
& (baseline_freq_time <= valid_t1)
]
baseline_freq_time = (
baseline_freq_time[
(baseline_freq_time >= valid_t0)
& (baseline_freq_time <= valid_t1)
baseline_frequency_filtered = baseline_frequency_filtered[
no_edges
]
+ t0
baseline_frequency = baseline_frequency[no_edges]
baseline_frequency_time = (
baseline_frequency_time[no_edges] + window_start_seconds
)
time_oi = time_oi[valid]
baseline = baseline[valid]
search = search[valid]
# NORMALIZE ---------------------------------------------------
# normalize all three feature arrays to the same range to make
# peak detection simpler
baseline_envelope = normalize([baseline_envelope])[0]
search_envelope = normalize([search_envelope])[0]
inst_freq_filtered = normalize([np.abs(inst_freq_filtered)])[0]
baseline_frequency_filtered = normalize(
[baseline_frequency_filtered]
)[0]
# PEAK DETECTION ----------------------------------------------
prominence = config.prominence
# detect peaks in the baseline envelope
baseline_peaks, _ = find_peaks(
baseline_envelope, prominence=prominence
baseline_peak_indices, _ = find_peaks(
baseline_envelope, prominence=config.prominence
)
# detect peaks in the search envelope
search_peaks, _ = find_peaks(
search_envelope, prominence=prominence
search_peak_indices, _ = find_peaks(
search_envelope, prominence=config.prominence
)
# detect peaks in the filtered instantaneous frequency
inst_freq_peaks, _ = find_peaks(
inst_freq_filtered, prominence=prominence
frequency_peak_indices, _ = find_peaks(
baseline_frequency_filtered, prominence=config.prominence
)
# DETECT CHIRPS IN SEARCH WINDOW ------------------------------
# get the peak timestamps from the peak indices
baseline_ts = time_oi[baseline_peaks]
search_ts = time_oi[search_peaks]
freq_ts = baseline_freq_time[inst_freq_peaks]
baseline_peak_timestamps = current_raw_time[
baseline_peak_indices
]
search_peak_timestamps = current_raw_time[search_peak_indices]
frequency_peak_timestamps = baseline_frequency_time[
frequency_peak_indices
]
# check if one list is empty and if so, skip to the next
# electrode because a chirp cannot be detected if one is empty
if (
len(baseline_ts) == 0
or len(search_ts) == 0
or len(freq_ts) == 0
):
one_feature_empty = (
len(baseline_peak_timestamps) == 0
or len(search_peak_timestamps) == 0
or len(frequency_peak_timestamps) == 0
)
if one_feature_empty:
continue
# group peaks across feature arrays, but only keep those that
# occur in all 3 feature arrays
sublists = [
list(baseline_peak_timestamps),
list(search_peak_timestamps),
list(frequency_peak_timestamps),
]
singleelectrode_chirps = group_timestamps(
[list(baseline_ts), list(search_ts), list(freq_ts)],
3,
config.chirp_window_threshold,
sublists=sublists,
at_least_in=3,
difference_threshold=config.chirp_window_threshold,
)
# check if there are chirps detected after grouping, continue
# with the loop if not
if len(singleelectrode_chirps) == 0:
continue
@ -703,57 +751,62 @@ def main(datapath: str, plot: str) -> None:
multielectrode_chirps.append(singleelectrode_chirps)
# only initialize the plotting buffer if chirps are detected
if (
chirp_detected = (
(el == config.number_electrodes - 1)
& (len(singleelectrode_chirps) > 0)
& (plot in ["show", "save"])
):
)
if chirp_detected:
logger.debug("Detected chirp, ititialize buffer ...")
# save data to Buffer
buffer = PlotBuffer(
config=config,
t0=t0,
dt=dt,
electrode=electrode,
t0=window_start_seconds,
dt=window_duration_seconds,
electrode=electrode_index,
track_id=track_id,
data=data,
time=time_oi,
baseline=baseline,
time=current_raw_time,
baseline=baselineband,
baseline_envelope=baseline_envelope,
baseline_peaks=baseline_peaks,
search=search,
baseline_peaks=baseline_peak_indices,
search=searchband,
search_envelope=search_envelope,
search_peaks=search_peaks,
frequency_time=baseline_freq_time,
frequency=baseline_freq,
frequency_filtered=inst_freq_filtered,
frequency_peaks=inst_freq_peaks,
search_peaks=search_peak_indices,
frequency_time=baseline_frequency_time,
frequency=baseline_frequency,
frequency_filtered=baseline_frequency_filtered,
frequency_peaks=frequency_peak_indices,
)
logger.debug("Buffer initialized!")
logger.debug(
f"Processed all electrodes for fish {track_id} for this \
window, sorting chirps ..."
f"Processed all electrodes for fish {track_id} for this"
"window, sorting chirps ..."
)
# check if there are chirps detected in multiple electrodes and
# continue the loop if not
if len(multielectrode_chirps) == 0:
continue
# validate multielectrode chirps, i.e. check if they are
# detected on at least 'config.minimum_electrodes' electrodes
multielectrode_chirps_validated = group_timestamps(
multielectrode_chirps,
config.minimum_electrodes,
config.chirp_window_threshold
sublists=multielectrode_chirps,
at_least_in=config.minimum_electrodes,
difference_threshold=config.chirp_window_threshold,
)
# add validated chirps to the list that tracks chirps across the
# rolling time windows
multiwindow_chirps.append(multielectrode_chirps_validated)
multiwindow_ids.append(track_id)
@ -763,6 +816,7 @@ def main(datapath: str, plot: str) -> None:
)
# if chirps are detected and the plot flag is set, plot the
# chirps, otherwise try to delete the buffer if it exists
if len(multielectrode_chirps_validated) > 0:
try:
buffer.plot_buffer(multielectrode_chirps_validated, plot)
@ -776,27 +830,38 @@ def main(datapath: str, plot: str) -> None:
# flatten list of lists containing chirps and create
# an array of fish ids that correspond to the chirps
multiwindow_chirps_flat = []
multiwindow_ids_flat = []
for tr in np.unique(multiwindow_ids):
tr_index = np.asarray(multiwindow_ids) == tr
ts = flatten(list(compress(multiwindow_chirps, tr_index)))
multiwindow_chirps_flat.extend(ts)
multiwindow_ids_flat.extend(list(np.ones_like(ts) * tr))
for track_id in np.unique(multiwindow_ids):
# get chirps for this fish and flatten the list
current_track_bool = np.asarray(multiwindow_ids) == track_id
current_track_chirps = flatten(
list(compress(multiwindow_chirps, current_track_bool))
)
# add flattened chirps to the list
multiwindow_chirps_flat.extend(current_track_chirps)
multiwindow_ids_flat.extend(
list(np.ones_like(current_track_chirps) * track_id)
)
# purge duplicates, i.e. chirps that are very close to each other
# duplicates arise due to the overlapping windows
purged_chirps = []
purged_ids = []
for tr in np.unique(multiwindow_ids_flat):
for track_id in np.unique(multiwindow_ids_flat):
tr_chirps = np.asarray(multiwindow_chirps_flat)[
np.asarray(multiwindow_ids_flat) == tr]
np.asarray(multiwindow_ids_flat) == track_id
]
if len(tr_chirps) > 0:
tr_chirps_purged = purge_duplicates(
tr_chirps, config.chirp_window_threshold
)
purged_chirps.extend(list(tr_chirps_purged))
purged_ids.extend(list(np.ones_like(tr_chirps_purged) * tr))
purged_ids.extend(list(np.ones_like(tr_chirps_purged) * track_id))
# sort chirps by time
purged_chirps = np.asarray(purged_chirps)
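# Illustrative sketch (not part of this commit): the core detection idea inside
# the window loop above on synthetic data. Normalise the three feature traces,
# find peaks in each, and keep only peak times that occur in all three; the
# simple coincidence check at the end is a stand-in for group_timestamps.
import numpy as np
from scipy.signal import find_peaks

rng = np.random.default_rng(0)
samplerate = 1000
time = np.arange(0, 2.0, 1 / samplerate)

event = np.exp(-((time - 1.0) ** 2) / (2 * 0.01**2))        # chirp-like bump at 1 s
features = [event + 0.05 * rng.standard_normal(time.size) for _ in range(3)]

candidates = []
for feature in features:
    feature = (feature - feature.min()) / np.ptp(feature)   # normalise to [0, 1]
    peaks, _ = find_peaks(feature, prominence=0.5)
    candidates.append(time[peaks])

threshold = 0.02                                            # assumed coincidence window (s)
chirps = [
    t for t in candidates[0]
    if all(np.any(np.abs(other - t) < threshold) for other in candidates[1:])
]
print(chirps)                                               # ≈ [1.0]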

chirpdetection configuration (YAML)

@ -1,3 +1,4 @@
# directory setup
dataroot: "../data/"
outputdir: "../output/"
@ -10,30 +11,26 @@ edge: 0.25
number_electrodes: 3
minimum_electrodes: 2
# Search window bandwidth
# Search window bandwidth and minimal baseline bandwidth
minimal_bandwidth: 10
# Cutoff frequency for envelope estimation by lowpass filter
envelope_cutoff: 25
# Instantaneous frequency smoothing using a Gaussian kernel of this width
baseline_frequency_smoothing: 5
# Cutoff frequency for envelope highpass filter
envelope_highpass_cutoff: 3
# Baseline processing parameters
baseline_envelope_cutoff: 25
baseline_envelope_bandpass_lowf: 4
baseline_envelope_bandpass_highf: 100
baseline_envelope_envelope_cutoff: 4
# Cutoff frequency for envelope of envelope
envelope_envelope_cutoff: 5
# search envelope processing parameters
search_envelope_cutoff: 5
# Instantaneous frequency bandpass filter cutoff frequencies
instantaneous_lowf: 15
instantaneous_highf: 8000
# Baseline envelope peak detection parameters
# baseline_prominence_percentile: 90
# Search envelope peak detection parameters
# search_prominence_percentile: 90
# Instantaneous frequency peak detection parameters
# instantaneous_prominence_percentile: 90
baseline_frequency_highpass_cutoff: 0.000005
baseline_frequency_envelope_cutoff: 0.000005
# peak detection parameters
prominence: 0.005
# search freq parameter
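# Illustrative sketch (not part of this commit): how a YAML config like the one
# above can be exposed as attributes. ConfLoader itself lives in
# modules/filehandling and is not shown in this diff; this stand-in assumes
# PyYAML and an illustrative file name.
import yaml

class SimpleConf:
    """Hypothetical stand-in for ConfLoader: YAML keys become attributes."""

    def __init__(self, path: str) -> None:
        with open(path) as stream:
            for key, value in yaml.safe_load(stream).items():
                setattr(self, key, value)

# config = SimpleConf("chirpdetection.yaml")   # file name is illustrative
# config.minimal_bandwidth  -> 10
# config.prominence         -> 0.005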

modules/datahandling.py

@ -1,5 +1,59 @@
import numpy as np
from typing import List, Any
from scipy.ndimage import gaussian_filter1d
def instantaneous_frequency(
signal: np.ndarray,
samplerate: int,
smoothing_window: int,
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the instantaneous frequency of a signal that is approximately
sinusoidal and symmetric around 0.
Parameters
----------
signal : np.ndarray
Signal to compute the instantaneous frequency from.
samplerate : int
Samplerate of the signal.
smoothing_window : int
Window size for the Gaussian filter.
Returns
-------
tuple[np.ndarray, np.ndarray]
"""
# calculate instantaneous frequency with zero crossings
roll_signal = np.roll(signal, shift=1)
time_signal = np.arange(len(signal)) / samplerate
period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][
1:-1
]
upper_bound = np.abs(signal[period_index])
lower_bound = np.abs(signal[period_index - 1])
upper_time = np.abs(time_signal[period_index])
lower_time = np.abs(time_signal[period_index - 1])
# create ratio
lower_ratio = lower_bound / (lower_bound + upper_bound)
# apply the ratio to the time delta
time_delta = upper_time - lower_time
true_zero = lower_time + lower_ratio * time_delta
# create new time array
instantaneous_frequency_time = true_zero[:-1] + 0.5 * np.diff(true_zero)
# compute frequency
instantaneous_frequency = gaussian_filter1d(
1 / np.diff(true_zero), smoothing_window
)
return instantaneous_frequency_time, instantaneous_frequency
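# Illustrative sketch (not part of this commit): a quick self-check of the
# zero-crossing estimator defined above on a pure 600 Hz tone; run it inside
# this module or import the function from modules.datahandling. It should
# recover roughly 600 Hz everywhere.
import numpy as np

samplerate = 20_000
time = np.arange(0, 1.0, 1 / samplerate)
tone = np.sin(2 * np.pi * 600.0 * time)

freq_time, freq = instantaneous_frequency(
    signal=tone, samplerate=samplerate, smoothing_window=5
)
print(round(float(np.mean(freq))))   # ≈ 600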
def purge_duplicates(
@ -64,7 +118,7 @@ def purge_duplicates(
def group_timestamps(
sublists: List[List[float]], n: int, threshold: float
sublists: List[List[float]], at_least_in: int, difference_threshold: float
) -> List[float]:
"""
Groups timestamps that are less than `threshold` milliseconds apart from
@ -100,7 +154,7 @@ def group_timestamps(
# Group timestamps that are less than threshold milliseconds apart
for i in range(1, len(timestamps)):
if timestamps[i] - timestamps[i - 1] < threshold:
if timestamps[i] - timestamps[i - 1] < difference_threshold:
current_group.append(timestamps[i])
else:
groups.append(current_group)
@ -111,7 +165,7 @@ def group_timestamps(
# Retain only groups that contain at least `at_least_in` timestamps
final_groups = []
for group in groups:
if len(group) >= n:
if len(group) >= at_least_in:
final_groups.append(group)
# Calculate the mean of each group
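# Illustrative sketch (not part of this commit): the renamed group_timestamps
# arguments in use. Only the event present in all three feature lists within
# the difference threshold survives; values are made up and given in seconds,
# matching how the caller in chirpdetection.py uses the function.
from modules.datahandling import group_timestamps

baseline_peaks = [1.001, 4.500]
search_peaks = [0.999, 2.300]
frequency_peaks = [1.002]

chirps = group_timestamps(
    sublists=[baseline_peaks, search_peaks, frequency_peaks],
    at_least_in=3,
    difference_threshold=0.02,
)
print(chirps)   # roughly [1.0]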

modules/filters.py

@ -3,8 +3,8 @@ import numpy as np
def bandpass_filter(
data: np.ndarray,
rate: float,
signal: np.ndarray,
samplerate: float,
lowf: float,
highf: float,
) -> np.ndarray:
@ -12,7 +12,7 @@ def bandpass_filter(
Parameters
----------
data : np.ndarray
signal : np.ndarray
The data to be filtered
rate : float
The sampling rate
@ -26,21 +26,22 @@ def bandpass_filter(
np.ndarray
The filtered data
"""
sos = butter(2, (lowf, highf), "bandpass", fs=rate, output="sos")
fdata = sosfiltfilt(sos, data)
return fdata
sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def highpass_filter(
data: np.ndarray,
rate: float,
signal: np.ndarray,
samplerate: float,
cutoff: float,
) -> np.ndarray:
"""Highpass filter a signal.
Parameters
----------
data : np.ndarray
signal : np.ndarray
The data to be filtered
rate : float
The sampling rate
@ -52,14 +53,15 @@ def highpass_filter(
np.ndarray
The filtered data
"""
sos = butter(2, cutoff, "highpass", fs=rate, output="sos")
fdata = sosfiltfilt(sos, data)
return fdata
sos = butter(2, cutoff, "highpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def lowpass_filter(
data: np.ndarray,
rate: float,
signal: np.ndarray,
samplerate: float,
cutoff: float
) -> np.ndarray:
"""Lowpass filter a signal.
@ -78,21 +80,25 @@ def lowpass_filter(
np.ndarray
The filtered data
"""
sos = butter(2, cutoff, "lowpass", fs=rate, output="sos")
fdata = sosfiltfilt(sos, data)
return fdata
sos = butter(2, cutoff, "lowpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
def envelope(signal: np.ndarray,
samplerate: float,
cutoff_frequency: float
) -> np.ndarray:
"""Calculate the envelope of a signal using a lowpass filter.
Parameters
----------
data : np.ndarray
signal : np.ndarray
The signal to calculate the envelope of
rate : float
samplerate : float
The sampling rate of the signal
freq : float
cutoff_frequency : float
The cutoff frequency of the lowpass filter
Returns
@ -100,6 +106,7 @@ def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
np.ndarray
The envelope of the signal
"""
sos = butter(2, freq, "lowpass", fs=rate, output="sos")
envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(data))
sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos")
envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(signal))
return envelope
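# Illustrative sketch (not part of this commit): the renamed filter helpers on
# an amplitude-modulated tone, assuming the functions above are in scope (or
# imported from modules.filters). Frequencies and cutoffs are made-up values.
import numpy as np

samplerate = 20_000
time = np.arange(0, 2.0, 1 / samplerate)
# 600 Hz carrier, slowly amplitude-modulated at 2 Hz
signal = (1.0 + 0.5 * np.sin(2 * np.pi * 2.0 * time)) * np.sin(
    2 * np.pi * 600.0 * time
)

narrowband = bandpass_filter(
    signal=signal, samplerate=samplerate, lowf=595, highf=605
)
amplitude = envelope(signal=narrowband, samplerate=samplerate, cutoff_frequency=25)
# `amplitude` now follows the slow 2 Hz modulation instead of the 600 Hz carrier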