# GP2023_chirp_detection/code/chirpdetection.py
from dataclasses import dataclass
from itertools import compress
import matplotlib.gridspec as gr
import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
from modules.datahandling import (
flatten,
group_timestamps,
instantaneous_frequency,
minmaxnorm,
purge_duplicates,
)
from modules.filehandling import ConfLoader, LoadData, make_outputdir
from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.logger import makeLogger
from modules.plotstyle import PlotStyle
from scipy.signal import find_peaks
from thunderfish.powerspectrum import decibel, spectrogram
# from sklearn.preprocessing import normalize
logger = makeLogger(__name__)
ps = PlotStyle()
@dataclass
class ChirpPlotBuffer:
"""
Buffer to save data that is created in the main detection loop
    and plot it outside the detection loop.
"""
config: ConfLoader
t0: float
dt: float
track_id: float
electrode: int
data: LoadData
time: np.ndarray
baseline: np.ndarray
baseline_envelope_unfiltered: np.ndarray
baseline_envelope: np.ndarray
baseline_peaks: np.ndarray
search_frequency: float
search: np.ndarray
search_envelope_unfiltered: np.ndarray
search_envelope: np.ndarray
search_peaks: np.ndarray
frequency_time: np.ndarray
frequency: np.ndarray
frequency_filtered: np.ndarray
frequency_peaks: np.ndarray
def plot_buffer(self, chirps: np.ndarray, plot: str) -> None:
logger.debug("Starting plotting")
# make data for plotting
# get index of track data in this time window
window_idx = np.arange(len(self.data.idx))[
(self.data.ident == self.track_id)
& (self.data.time[self.data.idx] >= self.t0)
& (self.data.time[self.data.idx] <= (self.t0 + self.dt))
]
# get tracked frequencies and their times
freq_temp = self.data.freq[window_idx]
# time_temp = self.data.time[
# self.data.idx[self.data.ident == self.track_id]][
# (self.data.time >= self.t0)
# & (self.data.time <= (self.t0 + self.dt))
# ]
# remake the band we filtered in
q25, q50, q75 = np.percentile(freq_temp, [25, 50, 75])
search_upper, search_lower = (
q50 + self.search_frequency + self.config.minimal_bandwidth / 2,
q50 + self.search_frequency - self.config.minimal_bandwidth / 2,
)
        logger.debug(f"Search band: {search_lower:.2f} - {search_upper:.2f} Hz")
# get indices on raw data
start_idx = int((self.t0 - 5) * self.data.raw_rate)
window_duration = (self.dt + 10) * self.data.raw_rate
stop_idx = int(start_idx + window_duration)
if start_idx < 0:
start_idx = 0
# get raw data
try:
data_oi = self.data.raw[start_idx:stop_idx, self.electrode]
        except Exception:
            # drop into an IPython shell to inspect the failing slice
            embed()
self.time = self.time - self.t0
self.frequency_time = self.frequency_time - self.t0
if len(chirps) > 0:
chirps = np.asarray(chirps) - self.t0
self.t0_old = self.t0
self.t0 = 0
fig = plt.figure(figsize=(14 * ps.cm, 18 * ps.cm))
gs0 = gr.GridSpec(3, 1, figure=fig, height_ratios=[1, 1, 1])
gs1 = gs0[0].subgridspec(1, 1)
gs2 = gs0[1].subgridspec(3, 1, hspace=0.4)
gs3 = gs0[2].subgridspec(3, 1, hspace=0.4)
# gs4 = gs0[5].subgridspec(1, 1)
ax6 = fig.add_subplot(gs3[2, 0])
ax0 = fig.add_subplot(gs1[0, 0], sharex=ax6)
ax1 = fig.add_subplot(gs2[0, 0], sharex=ax6)
ax2 = fig.add_subplot(gs2[1, 0], sharex=ax6)
ax3 = fig.add_subplot(gs2[2, 0], sharex=ax6)
ax4 = fig.add_subplot(gs3[0, 0], sharex=ax6)
ax5 = fig.add_subplot(gs3[1, 0], sharex=ax6)
# ax7 = fig.add_subplot(gs4[0, 0], sharex=ax0)
# ax_leg = fig.add_subplot(gs0[1, 0])
waveform_scaler = 1000
lw = 1.5
# plot spectrogram
_ = plot_spectrogram(
ax0,
data_oi,
self.data.raw_rate,
self.t0 - 5,
[np.min(self.frequency) - 300, np.max(self.frequency) + 300],
)
ax0.set_ylim(np.min(self.frequency) - 100, np.max(self.frequency) + 200)
for track_id in self.data.ids:
t0_track = self.t0_old - 5
dt_track = self.dt + 10
window_idx = np.arange(len(self.data.idx))[
(self.data.ident == track_id)
& (self.data.time[self.data.idx] >= t0_track)
& (self.data.time[self.data.idx] <= (t0_track + dt_track))
]
# get tracked frequencies and their times
f = self.data.freq[window_idx]
# t = self.data.time[
# self.data.idx[self.data.ident == self.track_id]]
# tmask = (t >= t0_track) & (t <= (t0_track + dt_track))
t = self.data.time[self.data.idx[window_idx]]
if track_id == self.track_id:
ax0.plot(t - self.t0_old, f, lw=lw, zorder=10, color=ps.gblue1)
else:
ax0.plot(t - self.t0_old, f, lw=lw, zorder=10, color=ps.black)
# ax0.fill_between(
# np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
# q50 - self.config.minimal_bandwidth / 2,
# q50 + self.config.minimal_bandwidth / 2,
# color=ps.gblue1,
# lw=1,
# ls="dashed",
# alpha=0.5,
# )
# ax0.fill_between(
# np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
# search_lower,
# search_upper,
# color=ps.gblue2,
# lw=1,
# ls="dashed",
# alpha=0.5,
# )
ax0.axhline(
q50 - self.config.minimal_bandwidth / 2,
color=ps.gblue1,
lw=1,
ls="dashed",
)
ax0.axhline(
q50 + self.config.minimal_bandwidth / 2,
color=ps.gblue1,
lw=1,
ls="dashed",
)
ax0.axhline(search_lower, color=ps.gblue2, lw=1, ls="dashed")
ax0.axhline(search_upper, color=ps.gblue2, lw=1, ls="dashed")
# ax0.axhline(q50, spec_times[0], spec_times[-1],
# color=ps.gblue1, lw=2, ls="dashed")
# ax0.axhline(q50 + self.search_frequency,
# spec_times[0], spec_times[-1],
# color=ps.gblue2, lw=2, ls="dashed")
if len(chirps) > 0:
for chirp in chirps:
ax0.scatter(
chirp,
np.median(self.frequency),
c=ps.red,
marker=".",
edgecolors=ps.black,
facecolors=ps.red,
zorder=10,
s=70,
)
# plot waveform of filtered signal
ax1.plot(
self.time,
self.baseline * waveform_scaler,
c=ps.gray,
lw=lw,
alpha=0.5,
)
ax1.plot(
self.time,
self.baseline_envelope_unfiltered * waveform_scaler,
c=ps.gblue1,
lw=lw,
label="baseline envelope",
)
# plot waveform of filtered search signal
ax2.plot(
self.time,
self.search * waveform_scaler,
c=ps.gray,
lw=lw,
alpha=0.5,
)
ax2.plot(
self.time,
self.search_envelope_unfiltered * waveform_scaler,
c=ps.gblue2,
lw=lw,
label="search envelope",
)
# plot baseline instantaneous frequency
ax3.plot(
self.frequency_time,
self.frequency,
c=ps.gblue3,
lw=lw,
label="baseline inst. freq.",
)
# plot filtered and rectified envelope
# ax4.plot(
# self.time, self.baseline_envelope * waveform_scaler, c=ps.gblue1, lw=lw
# )
ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=lw)
ax4.scatter(
(self.time)[self.baseline_peaks],
# (self.baseline_envelope * waveform_scaler)[self.baseline_peaks],
(self.baseline_envelope)[self.baseline_peaks],
edgecolors=ps.black,
facecolors=ps.red,
zorder=10,
marker=".",
s=70,
# facecolors="none",
)
# plot envelope of search signal
# ax5.plot(self.time, self.search_envelope * waveform_scaler, c=ps.gblue2, lw=lw)
ax5.plot(self.time, self.search_envelope, c=ps.gblue2, lw=lw)
ax5.scatter(
(self.time)[self.search_peaks],
# (self.search_envelope * waveform_scaler)[self.search_peaks],
(self.search_envelope)[self.search_peaks],
edgecolors=ps.black,
facecolors=ps.red,
zorder=10,
marker=".",
s=70,
# facecolors="none",
)
# plot filtered instantaneous frequency
ax6.plot(self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw)
ax6.scatter(
self.frequency_time[self.frequency_peaks],
self.frequency_filtered[self.frequency_peaks],
edgecolors=ps.black,
facecolors=ps.red,
zorder=10,
marker=".",
s=70,
# facecolors="none",
)
ax0.set_ylabel("Frequency [Hz]")
ax1.set_ylabel(r"$\mu$V")
ax2.set_ylabel(r"$\mu$V")
ax3.set_ylabel("Hz")
ax4.set_ylabel(r"$\mu$V")
ax5.set_ylabel(r"$\mu$V")
ax6.set_ylabel("Hz")
ax6.set_xlabel("Time [s]")
plt.setp(ax0.get_xticklabels(), visible=False)
plt.setp(ax1.get_xticklabels(), visible=False)
plt.setp(ax2.get_xticklabels(), visible=False)
plt.setp(ax3.get_xticklabels(), visible=False)
plt.setp(ax4.get_xticklabels(), visible=False)
plt.setp(ax5.get_xticklabels(), visible=False)
# ps.letter_subplots([ax0, ax1, ax4], xoffset=-0.21)
# ax7.set_xticks(np.arange(0, 5.5, 1))
# ax7.spines.bottom.set_bounds((0, 5))
ax0.set_xlim(0, self.config.window)
plt.subplots_adjust(left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2)
fig.align_labels()
if plot == "show":
plt.show()
elif plot == "save":
make_outputdir(self.config.outputdir)
out = make_outputdir(
self.config.outputdir + self.data.datapath.split("/")[-2] + "/"
)
# plt.savefig(f"{out}{self.track_id}_{self.t0_old}.pdf")
# plt.savefig(f"{out}{self.track_id}_{self.t0_old}.svg")
plt.savefig(f"{out}{self.track_id}_{self.t0_old}.png")
plt.close()
def plot_spectrogram(
axis,
signal: np.ndarray,
samplerate: float,
window_start_seconds: float,
ylims: list[float],
) -> np.ndarray:
"""
Plot a spectrogram of a signal.
Parameters
----------
axis : matplotlib axis
Axis to plot the spectrogram on.
signal : np.ndarray
Signal to plot the spectrogram from.
samplerate : float
Samplerate of the signal.
    window_start_seconds : float
        Start time of the signal.
    ylims : list[float]
        Lower and upper frequency limit of the plotted spectrogram.
    Returns
    -------
    np.ndarray
        Times of the spectrogram bins.
    """
logger.debug("Plotting spectrogram")
# compute spectrogram
spec_power, spec_freqs, spec_times = spectrogram(
signal,
ratetime=samplerate,
freq_resolution=10,
overlap_frac=0.5,
)
fmask = np.zeros(spec_freqs.shape, dtype=bool)
fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True
axis.imshow(
decibel(spec_power[fmask, :]),
extent=[
spec_times[0] + window_start_seconds,
spec_times[-1] + window_start_seconds,
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
aspect="auto",
origin="lower",
interpolation="gaussian",
# alpha=0.6,
)
# axis.use_sticky_edges = False
return spec_times
def extract_frequency_bands(
raw_data: np.ndarray,
samplerate: int,
baseline_track: np.ndarray,
searchband_center: float,
minimal_bandwidth: float,
) -> tuple[np.ndarray, np.ndarray]:
"""
Apply a bandpass filter to the baseline of a signal and a second bandpass
filter above or below the baseline, as specified by the search frequency.
Parameters
----------
raw_data : np.ndarray
Data to apply the filter to.
samplerate : int
Samplerate of the signal.
baseline_track : np.ndarray
Tracked fundamental frequencies of the signal.
searchband_center: float
Frequency to search for above or below the baseline.
minimal_bandwidth : float
Minimal bandwidth of the filter.
Returns
-------
tuple[np.ndarray, np.ndarray]
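    Examples
    --------
    A minimal sketch on synthetic data; the 100 Hz bandwidth and the 250 Hz
    searchband offset are arbitrary example values. `base` retains the
    600 Hz component, `search` the 850 Hz component.
    >>> import numpy as np
    >>> samplerate = 20000
    >>> t = np.arange(0, 1, 1 / samplerate)
    >>> raw = np.sin(2 * np.pi * 600 * t) + 0.5 * np.sin(2 * np.pi * 850 * t)
    >>> track = np.full(20, 600.0)
    >>> base, search = extract_frequency_bands(
    ...     raw, samplerate, track, searchband_center=250.0, minimal_bandwidth=100.0
    ... )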
"""
# compute boundaries to filter baseline
q25, q50, q75 = np.percentile(baseline_track, [25, 50, 75])
# check if percentile delta is too small
if q75 - q25 < 10:
q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2
# filter baseline
filtered_baseline = bandpass_filter(raw_data, samplerate, lowf=q25, highf=q75)
# filter search area
filtered_search_freq = bandpass_filter(
raw_data,
samplerate,
lowf=searchband_center + q50 - minimal_bandwidth / 2,
highf=searchband_center + q50 + minimal_bandwidth / 2,
)
return filtered_baseline, filtered_search_freq
def window_median_all_track_ids(
data: LoadData, window_start_seconds: float, window_duration_seconds: float
) -> tuple[list[tuple[float, float, float]], list[int]]:
"""
Calculate the median and quantiles of the frequency of all fish in a
given time window.
    Iterate over all track ids and calculate the 25th, 50th and 75th
    frequency percentiles in the given time window. This data is passed to
    the 'find_searchband' function, which checks whether other fish in the
    current window fall within the searchband of the current fish and
    determines the gaps that lie outside of their percentile ranges.
Parameters
----------
data : LoadData
Data to calculate the median frequency from.
window_start_seconds : float
Start time of the window.
window_duration_seconds : float
Duration of the window.
Returns
-------
tuple[list[tuple[float, float, float]], list[int]]
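    Examples
    --------
    A minimal usage sketch for a hypothetical 5 s window starting at 0 s;
    it needs a recording loaded with LoadData and is therefore skipped
    here. `percentiles` holds one (q25, q50, q75) row per entry in `ids`.
    >>> data = LoadData("../data/2022-06-02-10_00/")  # doctest: +SKIP
    >>> percentiles, ids = window_median_all_track_ids(data, 0.0, 5.0)  # doctest: +SKIP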
"""
frequency_percentiles = []
track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
# the window index combines the track id and the time window
window_idx = np.arange(len(data.idx))[
(data.ident == track_id)
& (data.time[data.idx] >= window_start_seconds)
& (data.time[data.idx] <= (window_start_seconds + window_duration_seconds))
]
if len(data.freq[window_idx]) > 0:
frequency_percentiles.append(
np.percentile(data.freq[window_idx], [25, 50, 75])
)
track_ids.append(track_id)
# convert to numpy array
frequency_percentiles = np.asarray(frequency_percentiles)
track_ids = np.asarray(track_ids)
return frequency_percentiles, track_ids
def array_center(array: np.ndarray) -> float:
"""
Return the center value of an array.
If the array length is even, returns
the mean of the two center values.
Parameters
----------
array : np.ndarray
Array to calculate the center from.
Returns
-------
float
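    Examples
    --------
    A minimal sketch of the odd and even branches:
    >>> import numpy as np
    >>> float(array_center(np.array([1.0, 2.0, 3.0])))
    2.0
    >>> float(array_center(np.array([1.0, 2.0, 3.0, 4.0])))
    2.5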
"""
if len(array) % 2 == 0:
return np.mean(array[int(len(array) / 2) - 1 : int(len(array) / 2) + 1])
else:
return array[int(len(array) / 2)]
def has_chirp(baseline_frequency: np.ndarray, peak_height: float) -> bool:
"""
Check if a fish has a chirp.
Parameters
----------
baseline_frequency : np.ndarray
Baseline frequency of the fish.
peak_height : float
Minimal peak height of a chirp on the instant. freq.
Returns
-------
bool: True if the fish has a chirp, False otherwise.
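    Examples
    --------
    A minimal sketch with a synthetic frequency trace (values are
    arbitrary):
    >>> import numpy as np
    >>> has_chirp(np.array([0.0, 0.0, 5.0, 0.0, 0.0]), peak_height=2.0)
    True
    >>> has_chirp(np.array([0.0, 0.0, 1.0, 0.0, 0.0]), peak_height=2.0)
    False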
"""
peaks, _ = find_peaks(baseline_frequency, height=peak_height)
if len(peaks) > 0:
return True
else:
return False
def mask_low_amplitudes(envelope: np.ndarray, threshold: float) -> np.ndarray:
"""
Mask low amplitudes in the envelope.
Parameters
----------
envelope : np.ndarray
Envelope of the signal.
threshold : float
Threshold to mask low amplitudes.
Returns
-------
np.ndarray
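    Examples
    --------
    A minimal sketch; samples below the threshold are masked out:
    >>> import numpy as np
    >>> mask_low_amplitudes(np.array([0.1, 0.5, 0.05]), 0.2).tolist()
    [False, True, False]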
"""
mask = np.ones_like(envelope, dtype=bool)
mask[envelope < threshold] = False
return mask
def find_searchband(
current_frequency: np.ndarray,
percentiles_ids: np.ndarray,
frequency_percentiles: np.ndarray,
config: ConfLoader,
data: LoadData,
) -> float:
"""
Find the search frequency band for each fish by checking which fish EODs
are above the current EOD and finding a gap in them.
Parameters
----------
current_frequency : np.ndarray
Current EOD frequency array / the current fish of interest.
percentiles_ids : np.ndarray
Array of track IDs of the medians of all other fish in the current
window.
frequency_percentiles : np.ndarray
Array of percentiles frequencies of all other fish in the current window.
config : ConfLoader
Configuration file.
data : LoadData
Data to find the search frequency from.
Returns
-------
float
"""
    # frequency window in which the second bandpass filter is potentially
    # allowed to operate. This is the search window, in which we want to
    # find a gap in the other fish's EODs.
current_median = np.median(current_frequency)
search_window = np.arange(
current_median + config.search_df_lower,
current_median + config.search_df_upper,
config.search_res,
)
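    # e.g. with hypothetical config values search_df_lower=10,
    # search_df_upper=100 and search_res=1, a fish with a 600 Hz median
    # gets a candidate search window covering 610-699 Hz in 1 Hz steps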
# search window in boolean
bool_lower = np.ones_like(search_window, dtype=bool)
bool_upper = np.ones_like(search_window, dtype=bool)
search_window_bool = np.ones_like(search_window, dtype=bool)
    # make separate arrays from the quartiles
q25 = np.asarray([i[0] for i in frequency_percentiles])
q75 = np.asarray([i[2] for i in frequency_percentiles])
# get tracks that fall into search window
check_track_ids = percentiles_ids[(q25 > current_median) & (q75 < search_window[-1])]
    # iterate through these tracks
if check_track_ids.size != 0:
for j, check_track_id in enumerate(check_track_ids):
q25_temp = q25[percentiles_ids == check_track_id]
q75_temp = q75[percentiles_ids == check_track_id]
bool_lower[search_window > q25_temp - config.search_res] = False
bool_upper[search_window < q75_temp + config.search_res] = False
            search_window_bool[~bool_lower & ~bool_upper] = False
# find gaps in search window
search_window_indices = np.arange(len(search_window))
# get search window gaps
# taking the diff of a boolean array gives non zero values where the
# array changes from true to false or vice versa
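        # e.g. [True, True, False, True] with np.nan appended is cast to
        # float and yields diff = [0., -1., 1., nan]; -1 marks a
        # True -> False edge (a gap ends), +1 a False -> True edge
        # (a gap starts)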
search_window_gaps = np.diff(search_window_bool, append=np.nan)
nonzeros = search_window_gaps[np.nonzero(search_window_gaps)[0]]
nonzeros = nonzeros[~np.isnan(nonzeros)]
if len(nonzeros) == 0:
return config.default_search_freq
        # if the first value is -1, the array starts with True, so a gap
if nonzeros[0] == -1:
stops = search_window_indices[search_window_gaps == -1]
starts = np.append(0, search_window_indices[search_window_gaps == 1])
            # if the last value is 1, the array ends with True, so a gap
if nonzeros[-1] == 1:
stops = np.append(
search_window_indices[search_window_gaps == -1],
len(search_window) - 1,
)
# else it starts with false, so no gap
if nonzeros[0] == 1:
stops = search_window_indices[search_window_gaps == -1]
starts = search_window_indices[search_window_gaps == 1]
            # if the last value is 1, the array ends with True, so a gap
if nonzeros[-1] == 1:
stops = np.append(
search_window_indices[search_window_gaps == -1],
len(search_window),
)
# get the frequency ranges of the gaps
search_windows = [search_window[x:y] for x, y in zip(starts, stops)]
search_windows_lens = [len(x) for x in search_windows]
longest_search_window = search_windows[np.argmax(search_windows_lens)]
# the center of the search frequency band is then the center of
# the longest gap
search_freq = array_center(longest_search_window) - current_median
return search_freq
return config.default_search_freq
def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
assert plot in [
"save",
"show",
"false",
], "plot must be 'save', 'show' or 'false'"
assert debug in [
"false",
"electrode",
"fish",
], "debug must be 'false', 'electrode' or 'fish'"
if debug != "false":
assert plot == "show", "debug mode only runs when plot is 'show'"
# load raw file
print("datapath", datapath)
data = LoadData(datapath)
# load config file
config = ConfLoader("chirpdetector_conf.yml")
# set time window
window_duration = config.window * data.raw_rate
window_overlap = config.overlap * data.raw_rate
window_edge = config.edge * data.raw_rate
    # check if window duration and window overlap are even, otherwise half
    # of the duration or the overlap would be a float and thus an invalid
    # index
if window_duration % 2 == 0:
window_duration = int(window_duration)
else:
raise ValueError("Window duration must be even.")
if window_overlap % 2 == 0:
window_overlap = int(window_overlap)
else:
raise ValueError("Window overlap must be even.")
# make time array for raw data
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# good chirp times for data: 2022-06-02-10_00
# window_start_index = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
# window_duration_index = 60 * data.raw_rate
# t0 = 0
# dt = data.raw.shape[0]
# window_start_seconds = (23495 + ((28336-23495)/3)) * data.raw_rate
# window_duration_seconds = (28336 - 23495) * data.raw_rate
window_start_index = 0
window_duration_index = data.raw.shape[0]
# generate starting points of rolling window
window_start_indices = np.arange(
window_start_index,
window_start_index + window_duration_index,
window_duration - (window_overlap + 2 * window_edge),
dtype=int,
)
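    # e.g. with hypothetical config values window=5 s, overlap=1 s and
    # edge=0.25 s at a raw_rate of 20 kHz, consecutive window starts are
    # 100000 - (20000 + 2 * 5000) = 70000 samples (3.5 s) apart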
    # initialize lists to store data
multiwindow_chirps = []
multiwindow_ids = []
for st, window_start_index in enumerate(window_start_indices):
logger.info(f"Processing window {st} of {len(window_start_indices)}")
window_start_seconds = window_start_index / data.raw_rate
window_duration_seconds = window_duration / data.raw_rate
# set index window
window_stop_index = window_start_index + window_duration
        # calculate median of fish frequencies in window
median_freq, median_ids = window_median_all_track_ids(
data, window_start_seconds, window_duration_seconds
)
# iterate through all fish
for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
logger.debug(f"Processing track {tr} of {len(data.ids)}")
# get index of track data in this time window
track_window_index = np.arange(len(data.idx))[
(data.ident == track_id)
& (data.time[data.idx] >= window_start_seconds)
& (
data.time[data.idx]
<= (window_start_seconds + window_duration_seconds)
)
]
# get tracked frequencies and their times
current_frequencies = data.freq[track_window_index]
current_powers = data.powers[track_window_index, :]
# check if tracked data available in this window
if len(current_frequencies) < 3:
logger.warning(f"Track {track_id} has no data in window {st}, skipping.")
continue
# check if there are powers available in this window
nanchecker = np.unique(np.isnan(current_powers))
            if (len(nanchecker) == 1) and nanchecker[0]:
                logger.warning(
                    f"No powers available for track {track_id} in window {st}, skipping."
                )
continue
# find the strongest electrodes for the current fish in the current
# window
best_electrode_index = np.argsort(np.nanmean(current_powers, axis=0))[
-config.number_electrodes :
]
# find a frequency above the baseline of the current fish in which
# no other fish is active to search for chirps there
search_frequency = find_searchband(
config=config,
current_frequency=current_frequencies,
percentiles_ids=median_ids,
data=data,
frequency_percentiles=median_freq,
)
            # add all chirps that are detected on multiple electrodes for
            # one fish in one window to this list
multielectrode_chirps = []
# iterate through electrodes
for el, electrode_index in enumerate(best_electrode_index):
logger.debug(
f"Processing electrode {el+1} of " f"{len(best_electrode_index)}"
)
# LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------
# load region of interest of raw data file
current_raw_data = data.raw[
window_start_index:window_stop_index, electrode_index
]
current_raw_time = raw_time[window_start_index:window_stop_index]
# EXTRACT FEATURES --------------------------------------------
# filter baseline and above
baselineband, searchband = extract_frequency_bands(
raw_data=current_raw_data,
samplerate=data.raw_rate,
baseline_track=current_frequencies,
searchband_center=search_frequency,
minimal_bandwidth=config.minimal_bandwidth,
)
# compute envelope of baseline band to find dips
# in the baseline envelope
baseline_envelope_unfiltered = envelope(
signal=baselineband,
samplerate=data.raw_rate,
cutoff_frequency=config.baseline_envelope_cutoff,
)
                # create a mask that removes areas where amplitudes are very
                # low because the instantaneous frequency is not reliable there
amplitude_mask = mask_low_amplitudes(
baseline_envelope_unfiltered, config.baseline_min_amplitude
)
# highpass filter baseline envelope to remove slower
# fluctuations e.g. due to motion envelope
baseline_envelope = bandpass_filter(
signal=baseline_envelope_unfiltered,
samplerate=data.raw_rate,
lowf=config.baseline_envelope_bandpass_lowf,
highf=config.baseline_envelope_bandpass_highf,
)
# invert baseline envelope to find troughs in the baseline
baseline_envelope = -baseline_envelope
# compute the envelope of the search band. Peaks in the search
# band envelope correspond to troughs in the baseline envelope
# during chirps
search_envelope_unfiltered = envelope(
signal=searchband,
samplerate=data.raw_rate,
cutoff_frequency=config.search_envelope_cutoff,
)
search_envelope = search_envelope_unfiltered
# compute instantaneous frequency of the baseline band to find
# anomalies during a chirp, i.e. a frequency jump upwards or
# sometimes downwards. We do not fully understand why the
# instantaneous frequency can also jump downwards during a
# chirp. This phenomenon is only observed on chirps on a narrow
# filtered baseline such as the one we are working with.
baseline_frequency = instantaneous_frequency(
baselineband,
data.raw_rate,
config.baseline_frequency_smoothing,
)
                # Take the absolute of the instantaneous frequency to invert
                # troughs into peaks. This is necessary since the narrow
                # pass band introduces these anomalies. Also subtract the
                # median to center it around 0.
baseline_frequency_filtered = np.abs(
baseline_frequency - np.median(baseline_frequency)
)
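                # e.g. an instantaneous frequency trace of
                # [600, 600, 620, 590, 600] Hz has a median of 600 Hz, so
                # the filtered trace is [0, 0, 20, 10, 0]: the upward and
                # the downward excursion both become positive peaks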
# # check if there is at least one superthreshold peak on the
# # instantaneous and exit the loop if not. This is used to
                # # prevent windows that definitely do not include a chirp
# # to enter normalization, where small changes due to noise
# # would be amplified
# if not has_chirp(
# baseline_frequency_filtered[amplitude_mask],
# config.baseline_frequency_peakheight,
# ):
# logger.warning(
# f"Amplitude to small for the chirp detection of track {track_id} window {st},")
# continue
# CUT OFF OVERLAP ---------------------------------------------
# cut off snippet at start and end of each window to remove
# filter effects
# get arrays with raw samplerate without edges
no_edges = np.arange(
int(window_edge), len(baseline_envelope) - int(window_edge)
)
current_raw_time = current_raw_time[no_edges]
baselineband = baselineband[no_edges]
baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges]
searchband = searchband[no_edges]
baseline_envelope = baseline_envelope[no_edges]
search_envelope_unfiltered = search_envelope_unfiltered[no_edges]
search_envelope = search_envelope[no_edges]
baseline_frequency = baseline_frequency[no_edges]
baseline_frequency_filtered = baseline_frequency_filtered[no_edges]
baseline_frequency_time = current_raw_time
# # get instantaneous frequency withoup edges
# no_edges_t0 = int(window_edge) / data.raw_rate
# no_edges_t1 = baseline_frequency_time[-1] - (
# int(window_edge) / data.raw_rate
# )
# no_edges = (baseline_frequency_time >= no_edges_t0) & (
# baseline_frequency_time <= no_edges_t1
# )
# baseline_frequency_filtered = baseline_frequency_filtered[no_edges]
# baseline_frequency = baseline_frequency[no_edges]
# baseline_frequency_time = (
# baseline_frequency_time[no_edges] + window_start_seconds
# )
# NORMALIZE ---------------------------------------------------
# normalize all three feature arrays to the same range to make
# peak detection simpler
# baseline_envelope = minmaxnorm([baseline_envelope])[0]
# search_envelope = minmaxnorm([search_envelope])[0]
# baseline_frequency_filtered = minmaxnorm(
# [baseline_frequency_filtered]
# )[0]
# PEAK DETECTION ----------------------------------------------
                # detect peaks in baseline_envelope
baseline_peak_indices, _ = find_peaks(
baseline_envelope, prominence=config.baseline_prominence
)
# detect peaks search_envelope
search_peak_indices, _ = find_peaks(
search_envelope, prominence=config.search_prominence
)
# detect peaks inst_freq_filtered
frequency_peak_indices, _ = find_peaks(
baseline_frequency_filtered,
prominence=config.frequency_prominence,
)
# DETECT CHIRPS IN SEARCH WINDOW ------------------------------
# get the peak timestamps from the peak indices
baseline_peak_timestamps = current_raw_time[baseline_peak_indices]
search_peak_timestamps = current_raw_time[search_peak_indices]
frequency_peak_timestamps = baseline_frequency_time[
frequency_peak_indices
]
# check if one list is empty and if so, skip to the next
# electrode because a chirp cannot be detected if one is empty
one_feature_empty = (
len(baseline_peak_timestamps) == 0
or len(search_peak_timestamps) == 0
# or len(frequency_peak_timestamps) == 0
)
if one_feature_empty and (debug == "false"):
continue
                # group peaks across the feature arrays, but only keep
                # those that occur in both of them (instantaneous-frequency
                # peaks are currently not part of the grouping)
sublists = [
list(baseline_peak_timestamps),
list(search_peak_timestamps),
# list(frequency_peak_timestamps),
]
singleelectrode_chirps = group_timestamps(
sublists=sublists,
at_least_in=2,
difference_threshold=config.chirp_window_threshold,
)
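                # e.g. baseline peaks at [1.02, 4.50] s and search peaks at
                # [1.03] s with a threshold of ~0.05 s should collapse into
                # a single grouped chirp near 1.02-1.03 s (the exact merged
                # timestamp is up to group_timestamps); the lone 4.50 s
                # peak is dropped because it only occurs in one feature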
                # check if there are chirps detected after grouping, continue
# with the loop if not
if (len(singleelectrode_chirps) == 0) and (debug == "false"):
continue
                # append chirps from this electrode to the multielectrode list
multielectrode_chirps.append(singleelectrode_chirps)
                # initialize the plotting buffer on the last electrode or
                # whenever a debug mode is active
                chirp_detected = (el == config.number_electrodes - 1) and (
                    plot in ["show", "save"]
                )
                if chirp_detected or (debug != "false"):
logger.debug("Detected chirp, ititialize buffer ...")
# save data to Buffer
buffer = ChirpPlotBuffer(
config=config,
t0=window_start_seconds,
dt=window_duration_seconds,
electrode=electrode_index,
track_id=track_id,
data=data,
time=current_raw_time,
baseline_envelope_unfiltered=baseline_envelope_unfiltered,
baseline=baselineband,
baseline_envelope=baseline_envelope,
baseline_peaks=baseline_peak_indices,
search_frequency=search_frequency,
search=searchband,
search_envelope_unfiltered=search_envelope_unfiltered,
search_envelope=search_envelope,
search_peaks=search_peak_indices,
frequency_time=baseline_frequency_time,
frequency=baseline_frequency,
frequency_filtered=baseline_frequency_filtered,
frequency_peaks=frequency_peak_indices,
)
logger.debug("Buffer initialized!")
if debug == "electrode":
logger.info(f"Plotting electrode {el} ...")
buffer.plot_buffer(chirps=singleelectrode_chirps, plot=plot)
            logger.debug(
                f"Processed all electrodes for fish {track_id} for this "
                "window, sorting chirps ..."
            )
# check if there are chirps detected in multiple electrodes and
# continue the loop if not
if (len(multielectrode_chirps) == 0) and (debug == "false"):
continue
# validate multielectrode chirps, i.e. check if they are
# detected in at least 'config.min_electrodes' electrodes
multielectrode_chirps_validated = group_timestamps(
sublists=multielectrode_chirps,
at_least_in=config.minimum_electrodes,
difference_threshold=config.chirp_window_threshold,
)
            # add validated chirps to the list that tracks chirps across the
            # rolling time windows
multiwindow_chirps.append(multielectrode_chirps_validated)
multiwindow_ids.append(track_id)
logger.info(
f"Found {len(multielectrode_chirps_validated)}"
f" chirps for fish {track_id} in this window!"
)
            # if chirps are detected and the plot flag is set, plot the
            # chirps, otherwise try to delete the buffer if it exists
if debug == "fish":
logger.info(f"Plotting fish {track_id} ...")
buffer.plot_buffer(multielectrode_chirps_validated, plot)
            if (
                (len(multielectrode_chirps_validated) > 0)
                and (plot in ["show", "save"])
                and (debug == "false")
            ):
try:
buffer.plot_buffer(multielectrode_chirps_validated, plot)
del buffer
except NameError:
pass
else:
try:
del buffer
except NameError:
pass
# flatten list of lists containing chirps and create
# an array of fish ids that correspond to the chirps
multiwindow_chirps_flat = []
multiwindow_ids_flat = []
for track_id in np.unique(multiwindow_ids):
# get chirps for this fish and flatten the list
current_track_bool = np.asarray(multiwindow_ids) == track_id
current_track_chirps = flatten(
list(compress(multiwindow_chirps, current_track_bool))
)
# add flattened chirps to the list
multiwindow_chirps_flat.extend(current_track_chirps)
multiwindow_ids_flat.extend(list(np.ones_like(current_track_chirps) * track_id))
    # purge duplicates, i.e. chirps that are very close to each other;
    # duplicates arise due to overlapping windows
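    # e.g. if overlapping windows yielded chirps at [10.01, 10.02, 42.00] s
    # for one fish, purging with chirp_window_threshold should keep one
    # detection near 10 s plus the one at 42 s (which representative is
    # kept is up to purge_duplicates)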
purged_chirps = []
purged_ids = []
for track_id in np.unique(multiwindow_ids_flat):
tr_chirps = np.asarray(multiwindow_chirps_flat)[
np.asarray(multiwindow_ids_flat) == track_id
]
if len(tr_chirps) > 0:
tr_chirps_purged = purge_duplicates(tr_chirps, config.chirp_window_threshold)
purged_chirps.extend(list(tr_chirps_purged))
purged_ids.extend(list(np.ones_like(tr_chirps_purged) * track_id))
# sort chirps by time
purged_chirps = np.asarray(purged_chirps)
purged_ids = np.asarray(purged_ids)
purged_ids = purged_ids[np.argsort(purged_chirps)]
purged_chirps = purged_chirps[np.argsort(purged_chirps)]
# save them into the data directory
np.save(datapath + "chirp_times_gp.npy", purged_chirps)
np.save(datapath + "chirp_ids_gp.npy", purged_ids)
if __name__ == "__main__":
# datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-05-13-10_00/"
# datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/"
# datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/"
datapath = "../data/2022-06-02-10_00/"
datapath = "../../../local_data/randgrid/"
chirpdetection(datapath, plot="save", debug="false")