1189 lines
40 KiB
Python
Executable File
1189 lines
40 KiB
Python
Executable File
from dataclasses import dataclass
|
|
from itertools import compress
|
|
|
|
import matplotlib.gridspec as gr
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from IPython import embed
|
|
from modules.datahandling import (
|
|
flatten,
|
|
group_timestamps,
|
|
instantaneous_frequency,
|
|
minmaxnorm,
|
|
purge_duplicates,
|
|
)
|
|
from modules.filehandling import ConfLoader, LoadData, make_outputdir
|
|
from modules.filters import bandpass_filter, envelope, highpass_filter
|
|
from modules.logger import makeLogger
|
|
from modules.plotstyle import PlotStyle
|
|
from scipy.signal import find_peaks
|
|
from thunderfish.powerspectrum import decibel, spectrogram
|
|
|
|
# from sklearn.preprocessing import normalize
|
|
|
|
|
|
# Module-level logger (created via the project's makeLogger helper); used
# for the debug/info/warning messages emitted throughout this module.
logger = makeLogger(__name__)

# Shared plot style object. Its attributes provide the figure size unit
# (ps.cm) and the colours (ps.gblue1, ps.red, ps.black, ...) used by the
# plotting code in this module.
ps = PlotStyle()
|
|
|
|
|
|
@dataclass
class ChirpPlotBuffer:
    """
    Buffer to save data that is created in the main detection loop
    and plot it outside the detection loop.

    The ``*_peaks`` fields hold indices into the corresponding traces:
    ``baseline_peaks`` and ``search_peaks`` index ``time``, and
    ``frequency_peaks`` indexes ``frequency_time``.
    """

    config: ConfLoader
    t0: float
    dt: float
    track_id: float
    electrode: int
    data: LoadData

    time: np.ndarray
    baseline: np.ndarray
    baseline_envelope_unfiltered: np.ndarray
    baseline_envelope: np.ndarray
    baseline_peaks: np.ndarray
    search_frequency: float
    search: np.ndarray
    search_envelope_unfiltered: np.ndarray
    search_envelope: np.ndarray
    search_peaks: np.ndarray

    frequency_time: np.ndarray
    frequency: np.ndarray
    frequency_filtered: np.ndarray
    frequency_peaks: np.ndarray

    def plot_buffer(self, chirps: np.ndarray, plot: str) -> None:
        """
        Plot the buffered detection features of one track and electrode.

        The top panel shows a spectrogram of the raw recording (window plus
        5 s padding on each side) with all tracked frequencies and the
        baseline/search filter bands drawn on top. The middle panels show
        the filtered baseline and search bands with their envelopes and the
        baseline instantaneous frequency. The bottom panels show the three
        processed feature traces with their detected peaks.

        Parameters
        ----------
        chirps : np.ndarray
            Chirp times on the same time base as ``t0`` and ``time``.
        plot : str
            ``"show"`` opens the figure interactively, ``"save"`` writes a
            PDF and an SVG to ``config.outputdir``. For any other value the
            figure is built and discarded.
        """
        logger.debug("Starting plotting")

        # Seconds of extra context shown before and after the window.
        pad = 5

        # Everything is plotted relative to the window start. Local copies
        # are used so that calling this method again does not shift the
        # buffered data a second time.
        t0 = self.t0
        self.t0_old = t0  # kept for code that reads it after plotting
        time = self.time - t0
        frequency_time = self.frequency_time - t0
        chirps = np.asarray(chirps) - t0

        # Sample times of all tracked data points (shared by every track).
        track_times = self.data.time[self.data.idx]

        # Tracked frequencies of the focal fish inside the analysis window,
        # used to redraw the baseline and search bands that were filtered.
        focal_idx = np.flatnonzero(
            (self.data.ident == self.track_id)
            & (track_times >= t0)
            & (track_times <= t0 + self.dt)
        )
        focal_freqs = self.data.freq[focal_idx]

        # Raw-data sample range (window plus padding) as integer indices
        # clipped to the recording. numpy rejects float slice bounds and
        # wraps negative ones around to the end of the array.
        n_samples = self.data.raw.shape[0]
        start_idx = max(int(round((t0 - pad) * self.data.raw_rate)), 0)
        stop_idx = min(
            int(round((t0 + self.dt + pad) * self.data.raw_rate)), n_samples
        )
        data_oi = self.data.raw[start_idx:stop_idx, self.electrode]
        # Start of the raw snippet relative to t0 (-pad unless clipped).
        spec_start = start_idx / self.data.raw_rate - t0

        fig = plt.figure(figsize=(14 * ps.cm, 18 * ps.cm))
        gs0 = gr.GridSpec(3, 1, figure=fig, height_ratios=[1, 1, 1])
        gs1 = gs0[0].subgridspec(1, 1)
        gs2 = gs0[1].subgridspec(3, 1, hspace=0.4)
        gs3 = gs0[2].subgridspec(3, 1, hspace=0.4)

        # ax6 is created first so all other axes can share its x-axis.
        ax6 = fig.add_subplot(gs3[2, 0])
        ax0 = fig.add_subplot(gs1[0, 0], sharex=ax6)
        ax1 = fig.add_subplot(gs2[0, 0], sharex=ax6)
        ax2 = fig.add_subplot(gs2[1, 0], sharex=ax6)
        ax3 = fig.add_subplot(gs2[2, 0], sharex=ax6)
        ax4 = fig.add_subplot(gs3[0, 0], sharex=ax6)
        ax5 = fig.add_subplot(gs3[1, 0], sharex=ax6)

        waveform_scaler = 1000
        lw = 1.5
        marker_kwargs = dict(
            edgecolors=ps.black, facecolors=ps.red, zorder=10, marker=".", s=70
        )

        # spectrogram of the raw data with all tracked frequencies on top
        plot_spectrogram(
            ax0,
            data_oi,
            self.data.raw_rate,
            spec_start,
            [np.min(self.frequency) - 300, np.max(self.frequency) + 300],
        )
        ax0.set_ylim(np.min(self.frequency) - 100, np.max(self.frequency) + 200)

        in_view = (track_times >= t0 - pad) & (track_times <= t0 + self.dt + pad)
        for track_id in self.data.ids:
            idx = np.flatnonzero((self.data.ident == track_id) & in_view)
            color = ps.gblue1 if track_id == self.track_id else ps.black
            ax0.plot(
                self.data.time[self.data.idx[idx]] - t0,
                self.data.freq[idx],
                lw=lw,
                zorder=10,
                color=color,
            )

        if focal_freqs.size > 0:
            # dashed lines marking the baseline band and the search band
            q50 = np.median(focal_freqs)
            half_band = self.config.minimal_bandwidth / 2
            search_center = q50 + self.search_frequency
            logger.debug(
                f"Search band: {search_center - half_band} - "
                f"{search_center + half_band} Hz"
            )
            for y in (q50 - half_band, q50 + half_band):
                ax0.axhline(y, color=ps.gblue1, lw=1, ls="dashed")
            for y in (search_center - half_band, search_center + half_band):
                ax0.axhline(y, color=ps.gblue2, lw=1, ls="dashed")
        else:
            logger.warning(
                f"No tracked frequencies for track {self.track_id} in this "
                "window, not drawing filter bands."
            )

        if chirps.size > 0:
            ax0.scatter(
                chirps,
                np.full(chirps.shape, np.median(self.frequency)),
                **marker_kwargs,
            )

        # filtered baseline band and its envelope
        ax1.plot(time, self.baseline * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5)
        ax1.plot(
            time,
            self.baseline_envelope_unfiltered * waveform_scaler,
            c=ps.gblue1,
            lw=lw,
            label="baseline envelope",
        )

        # filtered search band and its envelope
        ax2.plot(time, self.search * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5)
        ax2.plot(
            time,
            self.search_envelope_unfiltered * waveform_scaler,
            c=ps.gblue2,
            lw=lw,
            label="search envelope",
        )

        # instantaneous frequency of the baseline band
        ax3.plot(
            frequency_time,
            self.frequency,
            c=ps.gblue3,
            lw=lw,
            label="baseline inst. freq.",
        )

        # processed feature traces with their detected peaks
        features = (
            (ax4, time, self.baseline_envelope, self.baseline_peaks, ps.gblue1),
            (ax5, time, self.search_envelope, self.search_peaks, ps.gblue2),
            (
                ax6,
                frequency_time,
                self.frequency_filtered,
                self.frequency_peaks,
                ps.gblue3,
            ),
        )
        for ax, x, y, peaks, color in features:
            ax.plot(x, y, c=color, lw=lw)
            ax.scatter(x[peaks], y[peaks], **marker_kwargs)

        ax0.set_ylabel("Frequency [Hz]")
        for ax in (ax1, ax2, ax4, ax5):
            ax.set_ylabel(r"$\mu$V")
        for ax in (ax3, ax6):
            ax.set_ylabel("Hz")
        ax6.set_xlabel("Time [s]")

        # only the bottom axis keeps its x tick labels
        for ax in (ax0, ax1, ax2, ax3, ax4, ax5):
            plt.setp(ax.get_xticklabels(), visible=False)

        ax0.set_xlim(0, self.config.window)
        fig.subplots_adjust(
            left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2
        )
        fig.align_labels()

        if plot == "show":
            plt.show()
        elif plot == "save":
            make_outputdir(self.config.outputdir)
            out = make_outputdir(
                self.config.outputdir + self.data.datapath.split("/")[-2] + "/"
            )
            fig.savefig(f"{out}{self.track_id}_{t0}.pdf")
            fig.savefig(f"{out}{self.track_id}_{t0}.svg")

        # Release the figure in every mode so repeated calls do not
        # accumulate open figures.
        plt.close(fig)
|
|
|
|
|
|
def plot_spectrogram(
    axis,
    signal: np.ndarray,
    samplerate: float,
    window_start_seconds: float,
    ylims: list[float],
    freq_resolution: float = 10,
    overlap_frac: float = 0.5,
) -> np.ndarray:
    """
    Plot a spectrogram of a signal.

    Parameters
    ----------
    axis : matplotlib axis
        Axis to plot the spectrogram on.
    signal : np.ndarray
        Signal to plot the spectrogram from.
    samplerate : float
        Samplerate of the signal.
    window_start_seconds : float
        Start time of the signal. It is added to the spectrogram time axis
        so the image is drawn at the signal's position on the x-axis.
    ylims : list[float]
        Lower and upper frequency bound in Hz (both exclusive). Only
        frequency bins strictly inside this range are drawn.
    freq_resolution : float, optional
        Frequency resolution passed to thunderfish's ``spectrogram``
        (default 10).
    overlap_frac : float, optional
        Segment overlap fraction passed to thunderfish's ``spectrogram``
        (default 0.5).

    Returns
    -------
    np.ndarray
        Segment times as returned by ``spectrogram``, i.e. without
        ``window_start_seconds`` added.

    Raises
    ------
    ValueError
        If no frequency bin lies strictly inside ``ylims``.
    """
    logger.debug("Plotting spectrogram")

    # compute spectrogram
    spec_power, spec_freqs, spec_times = spectrogram(
        signal,
        ratetime=samplerate,
        freq_resolution=freq_resolution,
        overlap_frac=overlap_frac,
    )

    # keep only the frequency bins inside the requested range
    fmask = (spec_freqs > ylims[0]) & (spec_freqs < ylims[1])
    if not np.any(fmask):
        raise ValueError(
            f"No spectrogram frequency bins between {ylims[0]} and "
            f"{ylims[1]} Hz."
        )
    shown_freqs = spec_freqs[fmask]

    axis.imshow(
        decibel(spec_power[fmask, :]),
        extent=[
            spec_times[0] + window_start_seconds,
            spec_times[-1] + window_start_seconds,
            shown_freqs[0],
            shown_freqs[-1],
        ],
        aspect="auto",
        origin="lower",
        interpolation="gaussian",
    )
    return spec_times
|
|
|
|
|
|
def extract_frequency_bands(
    raw_data: np.ndarray,
    samplerate: int,
    baseline_track: np.ndarray,
    searchband_center: float,
    minimal_bandwidth: float,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Bandpass-filter a signal twice: once around the tracked baseline
    frequency and once around a search band offset from that baseline.

    The baseline band spans the interquartile range of ``baseline_track``.
    If that range is narrower than 10 Hz, a band of ``minimal_bandwidth``
    centred on the median is used instead. The search band always has a
    width of ``minimal_bandwidth`` and is centred on
    ``median + searchband_center``.

    Parameters
    ----------
    raw_data : np.ndarray
        Data to apply the filters to.
    samplerate : int
        Samplerate of the signal.
    baseline_track : np.ndarray
        Tracked fundamental frequencies of the signal.
    searchband_center : float
        Offset (Hz) of the search band centre from the baseline median.
    minimal_bandwidth : float
        Width of the search band, and fallback width of the baseline band.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The baseline-band filtered signal and the search-band filtered
        signal.
    """
    half_width = minimal_bandwidth / 2
    lower_q, median_freq, upper_q = np.percentile(baseline_track, [25, 50, 75])

    # An interquartile range this narrow would give an overly narrow
    # baseline filter, so fall back to the minimal bandwidth.
    if upper_q - lower_q < 10:
        lower_q = median_freq - half_width
        upper_q = median_freq + half_width

    baseline_signal = bandpass_filter(
        raw_data, samplerate, lowf=lower_q, highf=upper_q
    )

    search_mid = median_freq + searchband_center
    search_signal = bandpass_filter(
        raw_data,
        samplerate,
        lowf=search_mid - half_width,
        highf=search_mid + half_width,
    )

    return baseline_signal, search_signal
|
|
|
|
|
|
def window_median_all_track_ids(
    data: "LoadData", window_start_seconds: float, window_duration_seconds: float
) -> tuple[np.ndarray, np.ndarray]:
    """
    Calculate the 25th, 50th and 75th frequency percentiles of every track
    in a given time window.

    The result is passed to ``find_searchband``, which uses it to decide
    which other fish fall inside the search band of the current fish and
    where the gaps between them are.

    Parameters
    ----------
    data : LoadData
        Tracking data. Must provide ``ident``, ``idx``, ``time`` and
        ``freq``. ``data.time[data.idx]`` gives the time of each tracked
        sample, and ``ident``/``freq`` are aligned with ``idx``.
    window_start_seconds : float
        Start time of the window.
    window_duration_seconds : float
        Duration of the window (the end point is inclusive).

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        An array of shape ``(n_tracks, 3)`` holding the percentiles
        ``[q25, q50, q75]`` per track, and an array of the matching track
        ids. Tracks without data in the window are omitted, so both arrays
        are empty when no track has data.
    """
    # The sample times and the window mask are identical for every track,
    # so they are built once instead of once per track.
    sample_times = data.time[data.idx]
    in_window = (sample_times >= window_start_seconds) & (
        sample_times <= window_start_seconds + window_duration_seconds
    )

    frequency_percentiles = []
    track_ids = []
    for track_id in np.unique(data.ident[~np.isnan(data.ident)]):
        window_idx = np.flatnonzero((data.ident == track_id) & in_window)
        if window_idx.size == 0:
            continue
        frequency_percentiles.append(
            np.percentile(data.freq[window_idx], [25, 50, 75])
        )
        track_ids.append(track_id)

    return np.asarray(frequency_percentiles), np.asarray(track_ids)
|
|
|
|
|
|
def array_center(array: np.ndarray) -> float:
    """
    Return the middle element of ``array``.

    An even-length array has no single middle element, so the mean of the
    two central elements is returned instead.

    Parameters
    ----------
    array : np.ndarray
        Array to take the center from.

    Returns
    -------
    float
    """
    half, remainder = divmod(len(array), 2)
    if remainder:
        return array[half]
    return np.mean(array[half - 1 : half + 1])
|
|
|
|
|
|
def has_chirp(baseline_frequency: np.ndarray, peak_height: float) -> bool:
    """
    Tell whether a trace contains at least one peak of sufficient height.

    Parameters
    ----------
    baseline_frequency : np.ndarray
        Baseline frequency trace of the fish.
    peak_height : float
        Minimal peak height of a chirp on the instantaneous frequency.

    Returns
    -------
    bool: True if ``scipy.signal.find_peaks`` finds at least one peak of
    at least ``peak_height``, False otherwise.
    """
    peak_indices = find_peaks(baseline_frequency, height=peak_height)[0]
    return peak_indices.size > 0
|
|
|
|
|
|
def mask_low_amplitudes(envelope, threshold):
    """
    Build a boolean mask that excludes low-amplitude samples.

    Parameters
    ----------
    envelope : np.ndarray
        Envelope of the signal.
    threshold : float
        Samples strictly below this value are masked out.

    Returns
    -------
    np.ndarray
        Boolean array with the shape of ``envelope``. It is False where
        ``envelope < threshold`` and True everywhere else, including NaN
        samples, since NaN never compares as below the threshold.
    """
    too_quiet = envelope < threshold
    return np.logical_not(too_quiet)
|
|
|
|
|
|
def find_searchband(
    current_frequency: np.ndarray,
    percentiles_ids: np.ndarray,
    frequency_percentiles: np.ndarray,
    config: ConfLoader,
    data: LoadData,
) -> float:
    """
    Find the search frequency band for a fish by checking which other fish
    EODs lie above the current EOD and finding a gap between them.

    The search window spans ``median + config.search_df_lower`` to
    ``median + config.search_df_upper`` in steps of ``config.search_res``.
    Each other fish whose lower quartile lies above the current median and
    below the top of the window blocks the frequencies within one
    resolution step of its interquartile range. The search band is centred
    on the widest remaining gap.

    Parameters
    ----------
    current_frequency : np.ndarray
        Current EOD frequency array / the current fish of interest.
    percentiles_ids : np.ndarray
        Array of track IDs of all fish in the current window (one per row
        of ``frequency_percentiles``).
    frequency_percentiles : np.ndarray
        Rows of ``[q25, q50, q75]`` frequencies of all fish in the current
        window.
    config : ConfLoader
        Configuration providing ``search_df_lower``, ``search_df_upper``,
        ``search_res`` and ``default_search_freq``.
    data : LoadData
        Unused; kept for interface compatibility.

    Returns
    -------
    float
        Offset (Hz) of the search band centre relative to the median of
        ``current_frequency``. ``config.default_search_freq`` is returned
        in the following cases:

        * the search window is empty;
        * no other fish lies in the window;
        * the other fish block none of the window;
        * the other fish block all of the window.
    """
    current_median = np.median(current_frequency)

    # Candidate search frequencies above the current fish.
    search_window = np.arange(
        current_median + config.search_df_lower,
        current_median + config.search_df_upper,
        config.search_res,
    )
    if search_window.size == 0 or len(frequency_percentiles) == 0:
        return config.default_search_freq

    q25 = np.asarray([p[0] for p in frequency_percentiles])
    q75 = np.asarray([p[2] for p in frequency_percentiles])

    # Other fish above the current median whose band starts below the top
    # of the search window. This includes fish that only partially overlap
    # the upper end of the window, so the band is never placed on them.
    # The current fish itself has q25 <= median and is therefore excluded.
    candidates = (q25 > current_median) & (q25 < search_window[-1])
    if not np.any(candidates):
        return config.default_search_freq

    # Block the frequencies within one resolution step of each fish's
    # interquartile range. Each fish blocks only its own range, so gaps
    # between neighbouring fish stay available.
    res = config.search_res
    free = np.ones(search_window.shape, dtype=bool)
    for lower, upper in zip(q25[candidates], q75[candidates]):
        free &= ~((search_window > lower - res) & (search_window < upper + res))

    # Nothing blocked or everything blocked: there is no gap to choose.
    if free.all() or not free.any():
        return config.default_search_freq

    # Locate runs of free frequencies. Padding with False on both sides
    # makes every run begin with a +1 step and end with a -1 step. Stops
    # are exclusive, so search_window[start:stop] is exactly one free run.
    steps = np.diff(np.concatenate(([0], free.astype(int), [0])))
    starts = np.flatnonzero(steps == 1)
    stops = np.flatnonzero(steps == -1)
    widest = np.argmax(stops - starts)
    widest_gap = search_window[starts[widest] : stops[widest]]

    # The search band is centred on the middle of the widest gap.
    return array_center(widest_gap) - current_median
|
|
|
|
|
|
def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
|
|
assert plot in [
|
|
"save",
|
|
"show",
|
|
"false",
|
|
], "plot must be 'save', 'show' or 'false'"
|
|
|
|
assert debug in [
|
|
"false",
|
|
"electrode",
|
|
"fish",
|
|
], "debug must be 'false', 'electrode' or 'fish'"
|
|
|
|
if debug != "false":
|
|
assert plot == "show", "debug mode only runs when plot is 'show'"
|
|
|
|
# load raw file
|
|
print("datapath", datapath)
|
|
data = LoadData(datapath)
|
|
|
|
# load config file
|
|
config = ConfLoader("chirpdetector_conf.yml")
|
|
|
|
# set time window
|
|
window_duration = config.window * data.raw_rate
|
|
window_overlap = config.overlap * data.raw_rate
|
|
window_edge = config.edge * data.raw_rate
|
|
|
|
# check if window duration and window ovelap is even, otherwise the half
|
|
# of the duration or window overlap would return a float, thus an
|
|
# invalid index
|
|
|
|
if window_duration % 2 == 0:
|
|
window_duration = int(window_duration)
|
|
else:
|
|
raise ValueError("Window duration must be even.")
|
|
|
|
if window_overlap % 2 == 0:
|
|
window_overlap = int(window_overlap)
|
|
else:
|
|
raise ValueError("Window overlap must be even.")
|
|
|
|
# make time array for raw data
|
|
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
|
|
|
|
# good chirp times for data: 2022-06-02-10_00
|
|
# window_start_index = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
|
|
# window_duration_index = 60 * data.raw_rate
|
|
|
|
# t0 = 0
|
|
# dt = data.raw.shape[0]
|
|
# window_start_seconds = (23495 + ((28336-23495)/3)) * data.raw_rate
|
|
# window_duration_seconds = (28336 - 23495) * data.raw_rate
|
|
|
|
window_start_index = 0
|
|
window_duration_index = data.raw.shape[0]
|
|
|
|
# generate starting points of rolling window
|
|
window_start_indices = np.arange(
|
|
window_start_index,
|
|
window_start_index + window_duration_index,
|
|
window_duration - (window_overlap + 2 * window_edge),
|
|
dtype=int,
|
|
)
|
|
|
|
# ititialize lists to store data
|
|
multiwindow_chirps = []
|
|
multiwindow_ids = []
|
|
|
|
for st, window_start_index in enumerate(window_start_indices):
|
|
logger.info(f"Processing window {st+1} of {len(window_start_indices)}")
|
|
|
|
window_start_seconds = window_start_index / data.raw_rate
|
|
window_duration_seconds = window_duration / data.raw_rate
|
|
|
|
# set index window
|
|
window_stop_index = window_start_index + window_duration
|
|
|
|
# calucate median of fish frequencies in window
|
|
median_freq, median_ids = window_median_all_track_ids(
|
|
data, window_start_seconds, window_duration_seconds
|
|
)
|
|
|
|
# iterate through all fish
|
|
for tr, track_id in enumerate(
|
|
np.unique(data.ident[~np.isnan(data.ident)])
|
|
):
|
|
logger.debug(f"Processing track {tr} of {len(data.ids)}")
|
|
|
|
# get index of track data in this time window
|
|
track_window_index = np.arange(len(data.idx))[
|
|
(data.ident == track_id)
|
|
& (data.time[data.idx] >= window_start_seconds)
|
|
& (
|
|
data.time[data.idx]
|
|
<= (window_start_seconds + window_duration_seconds)
|
|
)
|
|
]
|
|
|
|
# get tracked frequencies and their times
|
|
current_frequencies = data.freq[track_window_index]
|
|
current_powers = data.powers[track_window_index, :]
|
|
|
|
# check if tracked data available in this window
|
|
if len(current_frequencies) < 3:
|
|
logger.warning(
|
|
f"Track {track_id} has no data in window {st}, skipping."
|
|
)
|
|
continue
|
|
|
|
# check if there are powers available in this window
|
|
nanchecker = np.unique(np.isnan(current_powers))
|
|
if (len(nanchecker) == 1) and nanchecker[0] is True:
|
|
logger.warning(
|
|
f"No powers available for track {track_id} window {st},"
|
|
"skipping."
|
|
)
|
|
continue
|
|
|
|
# find the strongest electrodes for the current fish in the current
|
|
# window
|
|
|
|
best_electrode_index = np.argsort(
|
|
np.nanmean(current_powers, axis=0)
|
|
)[-config.number_electrodes :]
|
|
|
|
# find a frequency above the baseline of the current fish in which
|
|
# no other fish is active to search for chirps there
|
|
|
|
search_frequency = find_searchband(
|
|
config=config,
|
|
current_frequency=current_frequencies,
|
|
percentiles_ids=median_ids,
|
|
data=data,
|
|
frequency_percentiles=median_freq,
|
|
)
|
|
|
|
# add all chirps that are detected on mulitple electrodes for one
|
|
# fish fish in one window to this list
|
|
|
|
multielectrode_chirps = []
|
|
|
|
# iterate through electrodes
|
|
for el, electrode_index in enumerate(best_electrode_index):
|
|
logger.debug(
|
|
f"Processing electrode {el+1} of "
|
|
f"{len(best_electrode_index)}"
|
|
)
|
|
|
|
# LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------
|
|
|
|
# load region of interest of raw data file
|
|
current_raw_data = data.raw[
|
|
window_start_index:window_stop_index, electrode_index
|
|
]
|
|
current_raw_time = raw_time[
|
|
window_start_index:window_stop_index
|
|
]
|
|
|
|
# EXTRACT FEATURES --------------------------------------------
|
|
|
|
# filter baseline and above
|
|
baselineband, searchband = extract_frequency_bands(
|
|
raw_data=current_raw_data,
|
|
samplerate=data.raw_rate,
|
|
baseline_track=current_frequencies,
|
|
searchband_center=search_frequency,
|
|
minimal_bandwidth=config.minimal_bandwidth,
|
|
)
|
|
|
|
# compute envelope of baseline band to find dips
|
|
# in the baseline envelope
|
|
|
|
baseline_envelope_unfiltered = envelope(
|
|
signal=baselineband,
|
|
samplerate=data.raw_rate,
|
|
cutoff_frequency=config.baseline_envelope_cutoff,
|
|
)
|
|
|
|
# create a mask that removes areas where amplitudes are very
|
|
# because the instantaneous frequency is not reliable there
|
|
|
|
amplitude_mask = mask_low_amplitudes(
|
|
baseline_envelope_unfiltered, config.baseline_min_amplitude
|
|
)
|
|
|
|
# highpass filter baseline envelope to remove slower
|
|
# fluctuations e.g. due to motion envelope
|
|
|
|
baseline_envelope = bandpass_filter(
|
|
signal=baseline_envelope_unfiltered,
|
|
samplerate=data.raw_rate,
|
|
lowf=config.baseline_envelope_bandpass_lowf,
|
|
highf=config.baseline_envelope_bandpass_highf,
|
|
)
|
|
|
|
# invert baseline envelope to find troughs in the baseline
|
|
|
|
baseline_envelope = -baseline_envelope
|
|
|
|
# compute the envelope of the search band. Peaks in the search
|
|
# band envelope correspond to troughs in the baseline envelope
|
|
# during chirps
|
|
|
|
search_envelope_unfiltered = envelope(
|
|
signal=searchband,
|
|
samplerate=data.raw_rate,
|
|
cutoff_frequency=config.search_envelope_cutoff,
|
|
)
|
|
|
|
search_envelope = search_envelope_unfiltered
|
|
|
|
# compute instantaneous frequency of the baseline band to find
|
|
# anomalies during a chirp, i.e. a frequency jump upwards or
|
|
# sometimes downwards. We do not fully understand why the
|
|
# instantaneous frequency can also jump downwards during a
|
|
# chirp. This phenomenon is only observed on chirps on a narrow
|
|
# filtered baseline such as the one we are working with.
|
|
|
|
baseline_frequency = instantaneous_frequency(
|
|
baselineband,
|
|
data.raw_rate,
|
|
config.baseline_frequency_smoothing,
|
|
)
|
|
|
|
# Take the absolute of the instantaneous frequency to invert
|
|
# troughs into peaks. This is nessecary since the narrow
|
|
# pass band introduces these anomalies. Also substract by the
|
|
# median to set it to 0.
|
|
|
|
baseline_frequency_filtered = np.abs(
|
|
baseline_frequency - np.median(baseline_frequency)
|
|
)
|
|
|
|
# check if there is at least one superthreshold peak on the
|
|
# instantaneous and exit the loop if not. This is used to
|
|
# prevent windows that do definetely not include a chirp
|
|
# to enter normalization, where small changes due to noise
|
|
# would be amplified
|
|
|
|
if not has_chirp(
|
|
baseline_frequency_filtered[amplitude_mask],
|
|
config.baseline_frequency_peakheight,
|
|
):
|
|
continue
|
|
|
|
# CUT OFF OVERLAP ---------------------------------------------
|
|
|
|
# cut off snippet at start and end of each window to remove
|
|
# filter effects
|
|
|
|
# get arrays with raw samplerate without edges
|
|
no_edges = np.arange(
|
|
int(window_edge), len(baseline_envelope) - int(window_edge)
|
|
)
|
|
|
|
current_raw_time = current_raw_time[no_edges]
|
|
baselineband = baselineband[no_edges]
|
|
baseline_envelope_unfiltered = baseline_envelope_unfiltered[
|
|
no_edges
|
|
]
|
|
searchband = searchband[no_edges]
|
|
baseline_envelope = baseline_envelope[no_edges]
|
|
search_envelope_unfiltered = search_envelope_unfiltered[
|
|
no_edges
|
|
]
|
|
search_envelope = search_envelope[no_edges]
|
|
|
|
baseline_frequency = baseline_frequency[no_edges]
|
|
baseline_frequency_filtered = baseline_frequency_filtered[
|
|
no_edges
|
|
]
|
|
baseline_frequency_time = current_raw_time
|
|
|
|
# # get instantaneous frequency withoup edges
|
|
# no_edges_t0 = int(window_edge) / data.raw_rate
|
|
# no_edges_t1 = baseline_frequency_time[-1] - (
|
|
# int(window_edge) / data.raw_rate
|
|
# )
|
|
# no_edges = (baseline_frequency_time >= no_edges_t0) & (
|
|
# baseline_frequency_time <= no_edges_t1
|
|
# )
|
|
|
|
# baseline_frequency_filtered = baseline_frequency_filtered[no_edges]
|
|
# baseline_frequency = baseline_frequency[no_edges]
|
|
# baseline_frequency_time = (
|
|
# baseline_frequency_time[no_edges] + window_start_seconds
|
|
# )
|
|
|
|
# NORMALIZE ---------------------------------------------------
|
|
|
|
# normalize all three feature arrays to the same range to make
|
|
# peak detection simpler
|
|
|
|
baseline_envelope = minmaxnorm([baseline_envelope])[0]
|
|
search_envelope = minmaxnorm([search_envelope])[0]
|
|
baseline_frequency_filtered = minmaxnorm(
|
|
[baseline_frequency_filtered]
|
|
)[0]
|
|
|
|
# PEAK DETECTION ----------------------------------------------
|
|
|
|
# detect peaks baseline_enelope
|
|
baseline_peak_indices, _ = find_peaks(
|
|
baseline_envelope, prominence=config.baseline_prominence
|
|
)
|
|
# detect peaks search_envelope
|
|
search_peak_indices, _ = find_peaks(
|
|
search_envelope, prominence=config.search_prominence
|
|
)
|
|
# detect peaks inst_freq_filtered
|
|
frequency_peak_indices, _ = find_peaks(
|
|
baseline_frequency_filtered,
|
|
prominence=config.frequency_prominence,
|
|
)
|
|
|
|
# DETECT CHIRPS IN SEARCH WINDOW ------------------------------
|
|
|
|
# get the peak timestamps from the peak indices
|
|
baseline_peak_timestamps = current_raw_time[
|
|
baseline_peak_indices
|
|
]
|
|
search_peak_timestamps = current_raw_time[search_peak_indices]
|
|
|
|
frequency_peak_timestamps = baseline_frequency_time[
|
|
frequency_peak_indices
|
|
]
|
|
|
|
# check if one list is empty and if so, skip to the next
|
|
# electrode because a chirp cannot be detected if one is empty
|
|
|
|
one_feature_empty = (
|
|
len(baseline_peak_timestamps) == 0
|
|
or len(search_peak_timestamps) == 0
|
|
or len(frequency_peak_timestamps) == 0
|
|
)
|
|
|
|
if one_feature_empty and (debug == "false"):
|
|
continue
|
|
|
|
# group peak across feature arrays but only if they
|
|
# occur in all 3 feature arrays
|
|
|
|
sublists = [
|
|
list(baseline_peak_timestamps),
|
|
list(search_peak_timestamps),
|
|
list(frequency_peak_timestamps),
|
|
]
|
|
|
|
singleelectrode_chirps = group_timestamps(
|
|
sublists=sublists,
|
|
at_least_in=3,
|
|
difference_threshold=config.chirp_window_threshold,
|
|
)
|
|
|
|
# check it there are chirps detected after grouping, continue
|
|
# with the loop if not
|
|
|
|
if (len(singleelectrode_chirps) == 0) and (debug == "false"):
|
|
continue
|
|
|
|
# append chirps from this electrode to the multilectrode list
|
|
multielectrode_chirps.append(singleelectrode_chirps)
|
|
|
|
# only initialize the plotting buffer if chirps are detected
|
|
chirp_detected = el == (config.number_electrodes - 1) & (
|
|
plot in ["show", "save"]
|
|
)
|
|
|
|
if chirp_detected or (debug != "elecrode"):
|
|
logger.debug("Detected chirp, ititialize buffer ...")
|
|
|
|
# save data to Buffer
|
|
buffer = ChirpPlotBuffer(
|
|
config=config,
|
|
t0=window_start_seconds,
|
|
dt=window_duration_seconds,
|
|
electrode=electrode_index,
|
|
track_id=track_id,
|
|
data=data,
|
|
time=current_raw_time,
|
|
baseline_envelope_unfiltered=baseline_envelope_unfiltered,
|
|
baseline=baselineband,
|
|
baseline_envelope=baseline_envelope,
|
|
baseline_peaks=baseline_peak_indices,
|
|
search_frequency=search_frequency,
|
|
search=searchband,
|
|
search_envelope_unfiltered=search_envelope_unfiltered,
|
|
search_envelope=search_envelope,
|
|
search_peaks=search_peak_indices,
|
|
frequency_time=baseline_frequency_time,
|
|
frequency=baseline_frequency,
|
|
frequency_filtered=baseline_frequency_filtered,
|
|
frequency_peaks=frequency_peak_indices,
|
|
)
|
|
|
|
logger.debug("Buffer initialized!")
|
|
|
|
if debug == "electrode":
|
|
logger.info(f"Plotting electrode {el} ...")
|
|
buffer.plot_buffer(chirps=singleelectrode_chirps, plot=plot)
|
|
|
|
logger.debug(
|
|
f"Processed all electrodes for fish {track_id} for this"
|
|
"window, sorting chirps ..."
|
|
)
|
|
|
|
# check if there are chirps detected in multiple electrodes and
|
|
# continue the loop if not
|
|
|
|
if (len(multielectrode_chirps) == 0) and (debug == "false"):
|
|
continue
|
|
|
|
# validate multielectrode chirps, i.e. check if they are
|
|
# detected in at least 'config.min_electrodes' electrodes
|
|
|
|
multielectrode_chirps_validated = group_timestamps(
|
|
sublists=multielectrode_chirps,
|
|
at_least_in=config.minimum_electrodes,
|
|
difference_threshold=config.chirp_window_threshold,
|
|
)
|
|
|
|
# add validated chirps to the list that tracks chirps across there
|
|
# rolling time windows
|
|
|
|
multiwindow_chirps.append(multielectrode_chirps_validated)
|
|
multiwindow_ids.append(track_id)
|
|
|
|
logger.info(
|
|
f"Found {len(multielectrode_chirps_validated)}"
|
|
f" chirps for fish {track_id} in this window!"
|
|
)
|
|
# if chirps are detected and the plot flag is set, plot the
|
|
# chirps, otheswise try to delete the buffer if it exists
|
|
|
|
if debug == "fish":
|
|
logger.info(f"Plotting fish {track_id} ...")
|
|
buffer.plot_buffer(multielectrode_chirps_validated, plot)
|
|
|
|
if (
|
|
(len(multielectrode_chirps_validated) > 0)
|
|
& (plot in ["show", "save"])
|
|
& (debug == "false")
|
|
):
|
|
try:
|
|
buffer.plot_buffer(multielectrode_chirps_validated, plot)
|
|
del buffer
|
|
except NameError:
|
|
pass
|
|
else:
|
|
try:
|
|
del buffer
|
|
except NameError:
|
|
pass
|
|
|
|
# flatten list of lists containing chirps and create
|
|
# an array of fish ids that correspond to the chirps
|
|
|
|
multiwindow_chirps_flat = []
|
|
multiwindow_ids_flat = []
|
|
for track_id in np.unique(multiwindow_ids):
|
|
# get chirps for this fish and flatten the list
|
|
current_track_bool = np.asarray(multiwindow_ids) == track_id
|
|
current_track_chirps = flatten(
|
|
list(compress(multiwindow_chirps, current_track_bool))
|
|
)
|
|
|
|
# add flattened chirps to the list
|
|
multiwindow_chirps_flat.extend(current_track_chirps)
|
|
multiwindow_ids_flat.extend(
|
|
list(np.ones_like(current_track_chirps) * track_id)
|
|
)
|
|
|
|
# purge duplicates, i.e. chirps that are very close to each other
|
|
# duplites arise due to overlapping windows
|
|
|
|
purged_chirps = []
|
|
purged_ids = []
|
|
for track_id in np.unique(multiwindow_ids_flat):
|
|
tr_chirps = np.asarray(multiwindow_chirps_flat)[
|
|
np.asarray(multiwindow_ids_flat) == track_id
|
|
]
|
|
if len(tr_chirps) > 0:
|
|
tr_chirps_purged = purge_duplicates(
|
|
tr_chirps, config.chirp_window_threshold
|
|
)
|
|
purged_chirps.extend(list(tr_chirps_purged))
|
|
purged_ids.extend(list(np.ones_like(tr_chirps_purged) * track_id))
|
|
|
|
# sort chirps by time
|
|
purged_chirps = np.asarray(purged_chirps)
|
|
purged_ids = np.asarray(purged_ids)
|
|
purged_ids = purged_ids[np.argsort(purged_chirps)]
|
|
purged_chirps = purged_chirps[np.argsort(purged_chirps)]
|
|
|
|
# save them into the data directory
|
|
np.save(datapath + "chirps.npy", purged_chirps)
|
|
np.save(datapath + "chirp_ids.npy", purged_ids)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point for running the chirp detection on a single
    # recording directory. Called without arguments it behaves exactly like
    # the former hard-coded call:
    #     chirpdetection("../data/2022-06-02-10_00/", plot="show", debug="false")
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Run chirp detection on a recording directory."
    )
    parser.add_argument(
        "datapath",
        nargs="?",
        default="../data/2022-06-02-10_00/",
        help="recording directory (default: %(default)s)",
    )
    parser.add_argument(
        "--plot",
        default="show",
        help='plotting mode, e.g. "show" or "save" (default: %(default)s)',
    )
    parser.add_argument(
        "--debug",
        default="false",
        help='debug mode, e.g. "false", "electrode" or "fish" '
        "(default: %(default)s)",
    )
    args = parser.parse_args()

    # chirpdetection builds output paths by plain string concatenation
    # (datapath + "chirps.npy"), so guarantee a trailing path separator.
    # For the default path (already ending in "/") this is a no-op.
    datapath = os.path.join(args.datapath, "")

    chirpdetection(datapath, plot=args.plot, debug=args.debug)
|