diff --git a/code/chirpdetection.py b/code/chirpdetection.py old mode 100644 new mode 100755 index 200a6a4..2a48025 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -2,19 +2,21 @@ from itertools import compress from dataclasses import dataclass import numpy as np -from IPython import embed import matplotlib.pyplot as plt from scipy.signal import find_peaks -from scipy.ndimage import gaussian_filter1d -from thunderfish.dataloader import DataLoader from thunderfish.powerspectrum import spectrogram, decibel from sklearn.preprocessing import normalize from modules.filters import bandpass_filter, envelope, highpass_filter from modules.filehandling import ConfLoader, LoadData, make_outputdir -from modules.datahandling import flatten, purge_duplicates, group_timestamps from modules.plotstyle import PlotStyle from modules.logger import makeLogger +from modules.datahandling import ( + flatten, + purge_duplicates, + group_timestamps, + instantaneous_frequency, +) logger = makeLogger(__name__) @@ -23,6 +25,12 @@ ps = PlotStyle() @dataclass class PlotBuffer: + + """ + Buffer to save data that is created in the main detection loop + and plot it outside the detecion loop. + """ + config: ConfLoader t0: float dt: float @@ -73,14 +81,16 @@ class PlotBuffer: figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True, sharex=True, - sharey='row', + sharey="row", ) # plot spectrogram plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0) for chirp in chirps: - axs[0].scatter(chirp, np.median(self.frequency), c=ps.red) + axs[0].scatter( + chirp, np.median(self.frequency), c=ps.black, marker="x" + ) # plot waveform of filtered signal axs[1].plot(self.time, self.baseline, c=ps.green) @@ -88,7 +98,7 @@ class PlotBuffer: # plot waveform of filtered search signal axs[2].plot(self.time, self.search) - # plot baseline instantaneos frequency + # plot baseline instantaneous frequency axs[3].plot(self.frequency_time, self.frequency) # plot filtered and rectified envelope @@ -114,8 +124,9 @@ class PlotBuffer: self.frequency_filtered[self.frequency_peaks], c=ps.red, ) - axs[0].set_ylim(np.max(self.frequency)-200, - top=np.max(self.frequency)+200) + axs[0].set_ylim( + np.max(self.frequency) - 200, top=np.max(self.frequency) + 200 + ) axs[6].set_xlabel("Time [s]") axs[0].set_title("Spectrogram") axs[1].set_title("Fitered baseline") @@ -123,66 +134,23 @@ class PlotBuffer: axs[3].set_title("Fitered baseline instanenous frequency") axs[4].set_title("Filtered envelope of baseline envelope") axs[5].set_title("Search envelope") - axs[6].set_title( - "Filtered absolute instantaneous frequency") + axs[6].set_title("Filtered absolute instantaneous frequency") - if plot == 'show': + if plot == "show": plt.show() - elif plot == 'save': + elif plot == "save": make_outputdir(self.config.outputdir) - out = make_outputdir(self.config.outputdir + - self.data.datapath.split('/')[-2] + '/') + out = make_outputdir( + self.config.outputdir + self.data.datapath.split("/")[-2] + "/" + ) plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf") plt.close() -def instantaneos_frequency( - signal: np.ndarray, samplerate: int -) -> tuple[np.ndarray, np.ndarray]: - """ - Compute the instantaneous frequency of a signal. - - Parameters - ---------- - signal : np.ndarray - Signal to compute the instantaneous frequency from. - samplerate : int - Samplerate of the signal. 
- - Returns - ------- - tuple[np.ndarray, np.ndarray] - - """ - # calculate instantaneos frequency with zero crossings - roll_signal = np.roll(signal, shift=1) - time_signal = np.arange(len(signal)) / samplerate - period_index = np.arange(len(signal))[( - roll_signal < 0) & (signal >= 0)][1:-1] - - upper_bound = np.abs(signal[period_index]) - lower_bound = np.abs(signal[period_index - 1]) - upper_time = np.abs(time_signal[period_index]) - lower_time = np.abs(time_signal[period_index - 1]) - - # create ratio - lower_ratio = lower_bound / (lower_bound + upper_bound) - - # appy to time delta - time_delta = upper_time - lower_time - true_zero = lower_time + lower_ratio * time_delta - - # create new time array - inst_freq_time = true_zero[:-1] + 0.5 * np.diff(true_zero) - - # compute frequency - inst_freq = gaussian_filter1d(1 / np.diff(true_zero), 5) - - return inst_freq_time, inst_freq - - -def plot_spectrogram(axis, signal: np.ndarray, samplerate: float, t0: float) -> None: +def plot_spectrogram( + axis, signal: np.ndarray, samplerate: float, window_start_seconds: float +) -> None: """ Plot a spectrogram of a signal. @@ -194,7 +162,7 @@ def plot_spectrogram(axis, signal: np.ndarray, samplerate: float, t0: float) -> Signal to plot the spectrogram from. samplerate : float Samplerate of the signal. - t0 : float + window_start_seconds : float Start time of the signal. """ @@ -204,21 +172,30 @@ def plot_spectrogram(axis, signal: np.ndarray, samplerate: float, t0: float) -> spec_power, spec_freqs, spec_times = spectrogram( signal, ratetime=samplerate, - freq_resolution=50, - overlap_frac=0.2, + freq_resolution=20, + overlap_frac=0.5, ) - axis.pcolormesh( - spec_times + t0, - spec_freqs, + axis.imshow( decibel(spec_power), + extent=[ + spec_times[0] + window_start_seconds, + spec_times[-1] + window_start_seconds, + spec_freqs[0], + spec_freqs[-1], + ], + aspect="auto", + origin="lower", + interpolation="gaussian", ) - axis.set_ylim(200, 1200) - -def double_bandpass( - data: DataLoader, samplerate: int, freqs: np.ndarray, search_freq: float +def extract_frequency_bands( + raw_data: np.ndarray, + samplerate: int, + baseline_track: np.ndarray, + searchband_center: float, + minimal_bandwidth: float, ) -> tuple[np.ndarray, np.ndarray]: """ Apply a bandpass filter to the baseline of a signal and a second bandpass @@ -226,14 +203,16 @@ def double_bandpass( Parameters ---------- - data : DataLoader + raw_data : np.ndarray Data to apply the filter to. samplerate : int Samplerate of the signal. - freqs : np.ndarray + baseline_track : np.ndarray Tracked fundamental frequencies of the signal. - search_freq : float + searchband_center: float Frequency to search for above or below the baseline. + minimal_bandwidth : float + Minimal bandwidth of the filter. 
Returns ------- @@ -241,25 +220,31 @@ def double_bandpass( """ # compute boundaries to filter baseline - q25, q75 = np.percentile(freqs, [25, 75]) + q25, q50, q75 = np.percentile(baseline_track, [25, 50, 75]) # check if percentile delta is too small - if q75 - q25 < 5: - median = np.median(freqs) - q25, q75 = median - 2.5, median + 2.5 + if q75 - q25 < 10: + q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2 # filter baseline - filtered_baseline = bandpass_filter(data, samplerate, lowf=q25, highf=q75) + filtered_baseline = bandpass_filter( + raw_data, samplerate, lowf=q25, highf=q75 + ) # filter search area filtered_search_freq = bandpass_filter( - data, samplerate, lowf=q25 + search_freq, highf=q75 + search_freq + raw_data, + samplerate, + lowf=searchband_center + q50 - minimal_bandwidth / 2, + highf=searchband_center + q50 + minimal_bandwidth / 2, ) - return (filtered_baseline, filtered_search_freq) + return filtered_baseline, filtered_search_freq -def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, list[int]]: +def window_median_all_track_ids( + data: LoadData, window_start_seconds: float, window_duration_seconds: float +) -> tuple[float, list[int]]: """ Calculate the median frequency of all fish in a given time window. @@ -267,9 +252,9 @@ def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, lis ---------- data : LoadData Data to calculate the median frequency from. - t0 : float + window_start_seconds : float Start time of the window. - dt : float + window_duration_seconds : float Duration of the window. Returns @@ -283,8 +268,12 @@ def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, lis for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): window_idx = np.arange(len(data.idx))[ - (data.ident == track_id) & (data.time[data.idx] >= t0) & ( - data.time[data.idx] <= (t0 + dt)) + (data.ident == track_id) + & (data.time[data.idx] >= window_start_seconds) + & ( + data.time[data.idx] + <= (window_start_seconds + window_duration_seconds) + ) ] if len(data.freq[window_idx]) > 0: @@ -298,9 +287,121 @@ def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, lis return median_freq, track_ids +def find_searchband( + freq_temp: np.ndarray, + median_ids: np.ndarray, + median_freq: np.ndarray, + config: ConfLoader, + data: LoadData, +) -> float: + """ + Find the search frequency band for each fish by checking which fish EODs + are above the current EOD and finding a gap in them. + + Parameters + ---------- + freq_temp : np.ndarray + Current EOD frequency array / the current fish of interest. + median_ids : np.ndarray + Array of track IDs of the medians of all other fish in the current + window. + median_freq : np.ndarray + Array of median frequencies of all other fish in the current window. + config : ConfLoader + Configuration file. + data : LoadData + Data to find the search frequency from. 
+ + Returns + ------- + float + + """ + # frequency where second filter filters + search_window = np.arange( + np.median(freq_temp) + config.search_df_lower, + np.median(freq_temp) + config.search_df_upper, + config.search_res, + ) + + # search window in boolean + search_window_bool = np.ones(len(search_window), dtype=bool) + + # get tracks that fall into search window + check_track_ids = median_ids[ + (median_freq > search_window[0]) & (median_freq < search_window[-1]) + ] + + # iterate through theses tracks + if check_track_ids.size != 0: + + for j, check_track_id in enumerate(check_track_ids): + + q1, q2 = np.percentile( + data.freq[data.ident == check_track_id], + config.search_freq_percentiles, + ) + + search_window_bool[ + (search_window > q1) & (search_window < q2) + ] = False + + # find gaps in search window + search_window_indices = np.arange(len(search_window)) + + # get search window gaps + search_window_gaps = np.diff(search_window_bool, append=np.nan) + nonzeros = search_window_gaps[np.nonzero(search_window_gaps)[0]] + nonzeros = nonzeros[~np.isnan(nonzeros)] + + # if the first value is -1, the array starst with true, so a gap + if nonzeros[0] == -1: + stops = search_window_indices[search_window_gaps == -1] + starts = np.append( + 0, search_window_indices[search_window_gaps == 1] + ) + + # if the last value is -1, the array ends with true, so a gap + if nonzeros[-1] == 1: + stops = np.append( + search_window_indices[search_window_gaps == -1], + len(search_window) - 1, + ) + + # else it starts with false, so no gap + if nonzeros[0] == 1: + stops = search_window_indices[search_window_gaps == -1] + starts = search_window_indices[search_window_gaps == 1] + + # if the last value is -1, the array ends with true, so a gap + if nonzeros[-1] == 1: + stops = np.append( + search_window_indices[search_window_gaps == -1], + len(search_window), + ) + + # get the frequency ranges of the gaps + search_windows = [search_window[x:y] for x, y in zip(starts, stops)] + search_windows_lens = [len(x) for x in search_windows] + longest_search_window = search_windows[np.argmax(search_windows_lens)] + + search_freq = ( + longest_search_window[-1] - longest_search_window[0] + ) / 2 + + else: + search_freq = config.default_search_freq + + return search_freq + + def main(datapath: str, plot: str) -> None: - assert plot in ["save", "show", "false"] + assert plot in [ + "save", + "show", + "false", + ], "plot must be 'save', 'show' or 'false'" # load raw file data = LoadData(datapath) @@ -313,13 +414,15 @@ def main(datapath: str, plot: str) -> None: window_overlap = config.overlap * data.raw_rate window_edge = config.edge * data.raw_rate - # check if window duration is even + # check if window duration and window ovelap is even, otherwise the half + # of the duration or window overlap would return a float, thus an + # invalid index + if window_duration % 2 == 0: window_duration = int(window_duration) else: raise ValueError("Window duration must be even.") - # check if window ovelap is even if window_overlap % 2 == 0: window_overlap = int(window_overlap) else: @@ -328,339 +431,395 @@ def main(datapath: str, plot: str) -> None: # make time array for raw data raw_time = np.arange(data.raw.shape[0]) / data.raw_rate - # # good chirp times for data: 2022-06-02-10_00 - # t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate - # dt = 60 * data.raw_rate + # good chirp times for data: 2022-06-02-10_00 + window_start_seconds = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate + window_duration_seconds = 60 * data.raw_rate - 
t0 = 0
-    dt = data.raw.shape[0]
+    # t0 = 0
+    # dt = data.raw.shape[0]
 
     # generate starting points of rolling window
-    window_starts = np.arange(
-        t0,
-        t0 + dt,
+    window_start_indices = np.arange(
+        window_start_seconds,
+        window_start_seconds + window_duration_seconds,
         window_duration - (window_overlap + 2 * window_edge),
-        dtype=int
+        dtype=int,
     )
 
     # ititialize lists to store data
-    chirps = []
-    fish_ids = []
+    multiwindow_chirps = []
+    multiwindow_ids = []
 
-    for st, start_index in enumerate(window_starts):
+    for st, window_start_index in enumerate(window_start_indices):
 
-        logger.info(f"Processing window {st} of {len(window_starts)}")
+        logger.info(f"Processing window {st+1} of {len(window_start_indices)}")
 
-        # make t0 and dt
-        t0 = start_index / data.raw_rate
-        dt = window_duration / data.raw_rate
+        window_start_seconds = window_start_index / data.raw_rate
+        window_duration_seconds = window_duration / data.raw_rate
 
         # set index window
-        stop_index = start_index + window_duration
+        window_stop_index = window_start_index + window_duration
 
         # calucate median of fish frequencies in window
-        median_freq, median_ids = freqmedian_allfish(data, t0, dt)
+        median_freq, median_ids = window_median_all_track_ids(
+            data, window_start_seconds, window_duration_seconds
+        )
 
         # iterate through all fish
-        for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
+        for tr, track_id in enumerate(
+            np.unique(data.ident[~np.isnan(data.ident)])
+        ):
 
             logger.debug(f"Processing track {tr} of {len(data.ids)}")
 
             # get index of track data in this time window
-            window_idx = np.arange(len(data.idx))[
-                (data.ident == track_id) & (data.time[data.idx] >= t0) & (
-                    data.time[data.idx] <= (t0 + dt))
+            track_window_index = np.arange(len(data.idx))[
+                (data.ident == track_id)
+                & (data.time[data.idx] >= window_start_seconds)
+                & (
+                    data.time[data.idx]
+                    <= (window_start_seconds + window_duration_seconds)
+                )
             ]
 
             # get tracked frequencies and their times
-            freq_temp = data.freq[window_idx]
-            powers_temp = data.powers[window_idx, :]
+            current_frequencies = data.freq[track_window_index]
+            current_powers = data.powers[track_window_index, :]
 
             # approximate sampling rate to compute expected durations if there
             # is data available for this time window for this fish id
+
             track_samplerate = np.mean(1 / np.diff(data.time))
-            expected_duration = ((t0 + dt) - t0) * track_samplerate
+            expected_duration = (
+                (window_start_seconds + window_duration_seconds)
+                - window_start_seconds
+            ) * track_samplerate
 
             # check if tracked data available in this window
-            if len(freq_temp) < expected_duration * 0.5:
+            if len(current_frequencies) < expected_duration / 2:
                 logger.warning(
-                    f"Track {track_id} has no data in window {st}, skipping.")
+                    f"Track {track_id} has no data in window {st}, skipping."
+                )
                 continue
 
             # check if there are powers available in this window
-            nanchecker = np.unique(np.isnan(powers_temp))
-            if (len(nanchecker) == 1) and nanchecker[0] == True:
+            nanchecker = np.unique(np.isnan(current_powers))
+            if (len(nanchecker) == 1) and nanchecker[0]:
                 logger.warning(
-                    f"No powers available for track {track_id} window {st}, skipping.")
+                    f"No powers available for track {track_id} window {st}, "
+                    "skipping."
+ ) continue - best_electrodes = np.argsort(np.nanmean( - powers_temp, axis=0))[-config.number_electrodes:] + # find the strongest electrodes for the current fish in the current + # window - # frequency where second filter filters - search_window = np.arange( - np.median(freq_temp)+config.search_df_lower, np.median( - freq_temp)+config.search_df_upper, config.search_res) + best_electrode_index = np.argsort( + np.nanmean(current_powers, axis=0) + )[-config.number_electrodes:] - # search window in boolean - search_window_bool = np.ones(len(search_window), dtype=bool) + # find a frequency above the baseline of the current fish in which + # no other fish is active to search for chirps there - # get tracks that fall into search window - check_track_ids = median_ids[(median_freq > search_window[0]) & ( - median_freq < search_window[-1])] + search_frequency = find_searchband( + config=config, + freq_temp=current_frequencies, + median_ids=median_ids, + data=data, + median_freq=median_freq, + ) - # iterate through theses tracks - if check_track_ids.size != 0: + # add all chirps that are detected on mulitple electrodes for one + # fish fish in one window to this list - for j, check_track_id in enumerate(check_track_ids): - - q1, q2 = np.percentile( - data.freq[data.ident == check_track_id], - config.search_freq_percentiles - ) - - search_window_bool[(search_window > q1) & ( - search_window < q2)] = False - - # find gaps in search window - search_window_indices = np.arange(len(search_window)) - - # get search window gaps - search_window_gaps = np.diff(search_window_bool, append=np.nan) - nonzeros = search_window_gaps[np.nonzero( - search_window_gaps)[0]] - nonzeros = nonzeros[~np.isnan(nonzeros)] - - # if the first value is -1, the array starst with true, so a gap - if nonzeros[0] == -1: - stops = search_window_indices[search_window_gaps == -1] - starts = np.append( - 0, search_window_indices[search_window_gaps == 1]) - - # if the last value is -1, the array ends with true, so a gap - if nonzeros[-1] == 1: - stops = np.append( - search_window_indices[search_window_gaps == -1], - len(search_window) - 1 - ) - - # else it starts with false, so no gap - if nonzeros[0] == 1: - stops = search_window_indices[search_window_gaps == -1] - starts = search_window_indices[search_window_gaps == 1] - - # if the last value is -1, the array ends with true, so a gap - if nonzeros[-1] == 1: - stops = np.append( - search_window_indices[search_window_gaps == -1], - len(search_window) - ) - - # get the frequency ranges of the gaps - search_windows = [search_window[x:y] - for x, y in zip(starts, stops)] - search_windows_lens = [len(x) for x in search_windows] - longest_search_window = search_windows[np.argmax( - search_windows_lens)] - - search_freq = ( - longest_search_window[1] - longest_search_window[0]) / 2 - - else: - search_freq = config.default_search_freq - - # ----------- chrips on the two best electrodes----------- - chirps_electrodes = [] + multielectrode_chirps = [] # iterate through electrodes - for el, electrode in enumerate(best_electrodes): + for el, electrode_index in enumerate(best_electrode_index): logger.debug( - f"Processing electrode {el} of {len(best_electrodes)}") + f"Processing electrode {el+1} of " + f"{len(best_electrode_index)}" + ) + + # LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------ # load region of interest of raw data file - data_oi = data.raw[start_index:stop_index, :] - time_oi = raw_time[start_index:stop_index] + current_raw_data = data.raw[ + 
window_start_index:window_stop_index, electrode_index + ] + current_raw_time = raw_time[ + window_start_index:window_stop_index + ] + + # EXTRACT FEATURES -------------------------------------------- # filter baseline and above - baseline, search = double_bandpass( - data_oi[:, electrode], - data.raw_rate, - freq_temp, - search_freq + baselineband, searchband = extract_frequency_bands( + raw_data=current_raw_data, + samplerate=data.raw_rate, + baseline_track=current_frequencies, + searchband_center=search_frequency, + minimal_bandwidth=config.minimal_bandwidth, ) - # compute instantaneous frequency on broad signal - broad_baseline = bandpass_filter( - data_oi[:, electrode], - data.raw_rate, - lowf=np.mean(freq_temp)-5, - highf=np.mean(freq_temp)+100 + # compute envelope of baseline band to find dips + # in the baseline envelope + + baseline_envelope_unfiltered = envelope( + signal=baselineband, + samplerate=data.raw_rate, + cutoff_frequency=config.baseline_envelope_cutoff, ) - # compute instantaneous frequency on narrow signal - baseline_freq_time, baseline_freq = instantaneos_frequency( - baseline, data.raw_rate + # highpass filter baseline envelope to remove slower + # fluctuations e.g. due to motion envelope + + baseline_envelope = bandpass_filter( + signal=baseline_envelope_unfiltered, + samplerate=data.raw_rate, + lowf=config.baseline_envelope_bandpass_lowf, + highf=config.baseline_envelope_bandpass_highf, ) - # compute envelopes - baseline_envelope_unfiltered = envelope( - baseline, data.raw_rate, config.envelope_cutoff) + # highbass filter introduced filter effects, i.e. oscillations + # around peaks. Compute the envelope of the highpass filtered + # and inverted baseline envelope to remove these oscillations + + baseline_envelope = -baseline_envelope + + baseline_envelope = envelope( + signal=baseline_envelope, + samplerate=data.raw_rate, + cutoff_frequency=config.baseline_envelope_envelope_cutoff, + ) + + # compute the envelope of the search band. Peaks in the search + # band envelope correspond to troughs in the baseline envelope + # during chirps + search_envelope = envelope( - search, data.raw_rate, config.envelope_cutoff) + signal=searchband, + samplerate=data.raw_rate, + cutoff_frequency=config.search_envelope_cutoff, + ) - # highpass filter envelopes - baseline_envelope = highpass_filter( - baseline_envelope_unfiltered, - data.raw_rate, - config.envelope_highpass_cutoff + # compute instantaneous frequency of the baseline band to find + # anomalies during a chirp, i.e. a frequency jump upwards or + # sometimes downwards. We do not fully understand why the + # instantaneous frequency can also jump downwards during a + # chirp. This phenomenon is only observed on chirps on a narrow + # filtered baseline such as the one we are working with. + + ( + baseline_frequency_time, + baseline_frequency, + ) = instantaneous_frequency( + signal=baselineband, + samplerate=data.raw_rate, + smoothing_window=config.baseline_frequency_smoothing, ) - # envelopes of filtered envelope of filtered baseline - baseline_envelope = envelope( - np.abs(baseline_envelope), - data.raw_rate, - config.envelope_envelope_cutoff + # bandpass filter the instantaneous frequency to remove slow + # fluctuations. 
Just as with the baseline envelope, we then + # compute the envelope of the signal to remove the oscillations + # around the peaks + + baseline_frequency_samplerate = np.mean( + np.diff(baseline_frequency_time) ) - # bandpass filter the instantaneous - inst_freq_filtered = bandpass_filter( - baseline_freq, - data.raw_rate, - lowf=config.instantaneous_lowf, - highf=config.instantaneous_highf + baseline_frequency_filtered = np.abs( + baseline_frequency - np.median(baseline_frequency) ) - # CUT OFF OVERLAP --------------------------------------------- + baseline_frequency_filtered = highpass_filter( + signal=baseline_frequency_filtered, + samplerate=baseline_frequency_samplerate, + cutoff=config.baseline_frequency_highpass_cutoff, + ) - # cut off first and last 0.5 * overlap at start and end - valid = np.arange( - int(window_edge), len(baseline_envelope) - - int(window_edge) + baseline_frequency_filtered = envelope( + signal=-baseline_frequency_filtered, + samplerate=baseline_frequency_samplerate, + cutoff_frequency=config.baseline_frequency_envelope_cutoff, ) - baseline_envelope_unfiltered = baseline_envelope_unfiltered[valid] - baseline_envelope = baseline_envelope[valid] - search_envelope = search_envelope[valid] - - # get inst freq valid snippet - valid_t0 = int(window_edge) / data.raw_rate - valid_t1 = baseline_freq_time[-1] - \ - (int(window_edge) / data.raw_rate) - - inst_freq_filtered = inst_freq_filtered[ - (baseline_freq_time >= valid_t0) & ( - baseline_freq_time <= valid_t1) - ] - baseline_freq = baseline_freq[ - (baseline_freq_time >= valid_t0) & ( - baseline_freq_time <= valid_t1) - ] + # CUT OFF OVERLAP --------------------------------------------- - baseline_freq_time = baseline_freq_time[ - (baseline_freq_time >= valid_t0) & ( - baseline_freq_time <= valid_t1) - ] + t0 + # cut off snippet at start and end of each window to remove + # filter effects - # overwrite raw time to valid region - time_oi = time_oi[valid] - baseline = baseline[valid] - broad_baseline = broad_baseline[valid] - search = search[valid] + # get arrays with raw samplerate without edges + no_edges = np.arange( + int(window_edge), len(baseline_envelope) - int(window_edge) + ) + current_raw_time = current_raw_time[no_edges] + baselineband = baselineband[no_edges] + searchband = searchband[no_edges] + baseline_envelope = baseline_envelope[no_edges] + search_envelope = search_envelope[no_edges] + + # get instantaneous frequency withoup edges + no_edges_t0 = int(window_edge) / data.raw_rate + no_edges_t1 = baseline_frequency_time[-1] - ( + int(window_edge) / data.raw_rate + ) + no_edges = (baseline_frequency_time >= no_edges_t0) & ( + baseline_frequency_time <= no_edges_t1 + ) + + baseline_frequency_filtered = baseline_frequency_filtered[ + no_edges + ] + baseline_frequency = baseline_frequency[no_edges] + baseline_frequency_time = ( + baseline_frequency_time[no_edges] + window_start_seconds + ) # NORMALIZE --------------------------------------------------- + # normalize all three feature arrays to the same range to make + # peak detection simpler + baseline_envelope = normalize([baseline_envelope])[0] search_envelope = normalize([search_envelope])[0] - inst_freq_filtered = normalize([np.abs(inst_freq_filtered)])[0] + baseline_frequency_filtered = normalize( + [baseline_frequency_filtered] + )[0] # PEAK DETECTION ---------------------------------------------- # detect peaks baseline_enelope - prominence = np.percentile( - baseline_envelope, config.baseline_prominence_percentile) - baseline_peaks, _ = find_peaks( 
- baseline_envelope, prominence=prominence) - + baseline_peak_indices, _ = find_peaks( + baseline_envelope, prominence=config.prominence + ) # detect peaks search_envelope - prominence = np.percentile( - search_envelope, config.search_prominence_percentile) - search_peaks, _ = find_peaks( - search_envelope, prominence=prominence) - - # detect peaks inst_freq_filtered - prominence = np.percentile( - inst_freq_filtered, - config.instantaneous_prominence_percentile + search_peak_indices, _ = find_peaks( + search_envelope, prominence=config.prominence ) - inst_freq_peaks, _ = find_peaks( - inst_freq_filtered, - prominence=prominence + # detect peaks inst_freq_filtered + frequency_peak_indices, _ = find_peaks( + baseline_frequency_filtered, prominence=config.prominence ) - # DETECT CHIRPS IN SEARCH WINDOW ------------------------------- + # DETECT CHIRPS IN SEARCH WINDOW ------------------------------ - baseline_ts = time_oi[baseline_peaks] - search_ts = time_oi[search_peaks] - freq_ts = baseline_freq_time[inst_freq_peaks] + # get the peak timestamps from the peak indices + baseline_peak_timestamps = current_raw_time[ + baseline_peak_indices + ] + search_peak_timestamps = current_raw_time[search_peak_indices] + frequency_peak_timestamps = baseline_frequency_time[ + frequency_peak_indices + ] - # check if one list is empty - if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0: + # check if one list is empty and if so, skip to the next + # electrode because a chirp cannot be detected if one is empty + + one_feature_empty = ( + len(baseline_peak_timestamps) == 0 + or len(search_peak_timestamps) == 0 + or len(frequency_peak_timestamps) == 0 + ) + + if one_feature_empty: continue - current_chirps = group_timestamps( - [list(baseline_ts), list(search_ts), list(freq_ts)], 3, config.chirp_window_threshold) - # for checking if there are chirps on multiple electrodes - if len(current_chirps) == 0: + # group peak across feature arrays but only if they + # occur in all 3 feature arrays + + sublists = [ + list(baseline_peak_timestamps), + list(search_peak_timestamps), + list(frequency_peak_timestamps), + ] + + singleelectrode_chirps = group_timestamps( + sublists=sublists, + at_least_in=3, + difference_threshold=config.chirp_window_threshold, + ) + + # check it there are chirps detected after grouping, continue + # with the loop if not + + if len(singleelectrode_chirps) == 0: continue - chirps_electrodes.append(current_chirps) + # append chirps from this electrode to the multilectrode list + multielectrode_chirps.append(singleelectrode_chirps) + + # only initialize the plotting buffer if chirps are detected + chirp_detected = ( + (el == config.number_electrodes - 1) + & (len(singleelectrode_chirps) > 0) + & (plot in ["show", "save"]) + ) - if (el == config.number_electrodes - 1) & \ - (len(current_chirps) > 0) & \ - (plot in ["show", "save"]): + if chirp_detected: logger.debug("Detected chirp, ititialize buffer ...") # save data to Buffer buffer = PlotBuffer( config=config, - t0=t0, - dt=dt, - electrode=electrode, + t0=window_start_seconds, + dt=window_duration_seconds, + electrode=electrode_index, track_id=track_id, data=data, - time=time_oi, - baseline=baseline, + time=current_raw_time, + baseline=baselineband, baseline_envelope=baseline_envelope, - baseline_peaks=baseline_peaks, - search=search, + baseline_peaks=baseline_peak_indices, + search=searchband, search_envelope=search_envelope, - search_peaks=search_peaks, - frequency_time=baseline_freq_time, - frequency=baseline_freq, - 
frequency_filtered=inst_freq_filtered,
-                            frequency_peaks=inst_freq_peaks,
+                            search_peaks=search_peak_indices,
+                            frequency_time=baseline_frequency_time,
+                            frequency=baseline_frequency,
+                            frequency_filtered=baseline_frequency_filtered,
+                            frequency_peaks=frequency_peak_indices,
                         )
 
                         logger.debug("Buffer initialized!")
 
             logger.debug(
-                f"Processed all electrodes for fish {track_id} for this window, sorting chirps ...")
+                f"Processed all electrodes for fish {track_id} for this "
+                "window, sorting chirps ..."
+            )
+
+            # check if there are chirps detected in multiple electrodes and
+            # continue the loop if not
 
-            if len(chirps_electrodes) == 0:
+            if len(multielectrode_chirps) == 0:
                 continue
 
-            the_real_chirps = group_timestamps(chirps_electrodes, 2, 0.05)
+            # validate multielectrode chirps, i.e. check if they are
+            # detected in at least 'config.minimum_electrodes' electrodes
 
-            chirps.append(the_real_chirps)
-            fish_ids.append(track_id)
+            multielectrode_chirps_validated = group_timestamps(
+                sublists=multielectrode_chirps,
+                at_least_in=config.minimum_electrodes,
+                difference_threshold=config.chirp_window_threshold,
+            )
+
+            # add validated chirps to the list that tracks chirps across the
+            # rolling time windows
+
+            multiwindow_chirps.append(multielectrode_chirps_validated)
+            multiwindow_ids.append(track_id)
+
+            logger.debug(
+                "Found %d chirps, starting plotting ... "
+                % len(multielectrode_chirps_validated)
+            )
+            # if chirps are detected and the plot flag is set, plot the
+            # chirps, otherwise try to delete the buffer if it exists
 
-            logger.debug('Found %d chirps, starting plotting ... ' %
-                         len(the_real_chirps))
-            if len(the_real_chirps) > 0:
+            if len(multielectrode_chirps_validated) > 0:
                 try:
-                    buffer.plot_buffer(the_real_chirps, plot)
+                    buffer.plot_buffer(multielectrode_chirps_validated, plot)
                 except NameError:
                     pass
             else:
@@ -669,29 +828,53 @@ def main(datapath: str, plot: str) -> None:
                 except NameError:
                     pass
 
-    chirps_new = []
-    chirps_ids = []
-    for tr in np.unique(fish_ids):
-        tr_index = np.asarray(fish_ids) == tr
-        ts = flatten(list(compress(chirps, tr_index)))
-        chirps_new.extend(ts)
-        chirps_ids.extend(list(np.ones_like(ts)*tr))
+    # flatten list of lists containing chirps and create
+    # an array of fish ids that correspond to the chirps
+
+    multiwindow_chirps_flat = []
+    multiwindow_ids_flat = []
+    for track_id in np.unique(multiwindow_ids):
+
+        # get chirps for this fish and flatten the list
+        current_track_bool = np.asarray(multiwindow_ids) == track_id
+        current_track_chirps = flatten(
+            list(compress(multiwindow_chirps, current_track_bool))
+        )
+
+        # add flattened chirps to the list
+        multiwindow_chirps_flat.extend(current_track_chirps)
+        multiwindow_ids_flat.extend(
+            list(np.ones_like(current_track_chirps) * track_id)
+        )
+
+    # purge duplicates, i.e. chirps that are very close to each other
+    # duplicates arise due to overlapping windows
 
-    # purge duplicates
     purged_chirps = []
-    purged_chirps_ids = []
-    for tr in np.unique(fish_ids):
-        tr_chirps = np.asarray(chirps_new)[np.asarray(chirps_ids) == tr]
+    purged_ids = []
+    for track_id in np.unique(multiwindow_ids_flat):
+        tr_chirps = np.asarray(multiwindow_chirps_flat)[
+            np.asarray(multiwindow_ids_flat) == track_id
+        ]
 
         if len(tr_chirps) > 0:
             tr_chirps_purged = purge_duplicates(
-                tr_chirps, config.chirp_window_threshold)
+                tr_chirps, config.chirp_window_threshold
+            )
             purged_chirps.extend(list(tr_chirps_purged))
-            purged_chirps_ids.extend(list(np.ones_like(tr_chirps_purged)*tr))
+            purged_ids.extend(list(np.ones_like(tr_chirps_purged) * track_id))
+
+    # sort chirps by time
+    purged_chirps = np.asarray(purged_chirps)
+    purged_ids = np.asarray(purged_ids)
+    purged_ids = purged_ids[np.argsort(purged_chirps)]
+    purged_chirps = purged_chirps[np.argsort(purged_chirps)]
 
-    np.save(datapath + 'chirps.npy', purged_chirps)
-    np.save(datapath + 'chirps_ids.npy', purged_chirps_ids)
+    # save them into the data directory
+    np.save(datapath + "chirps.npy", purged_chirps)
+    np.save(datapath + "chirp_ids.npy", purged_ids)
 
 
 if __name__ == "__main__":
+    # datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-05-13-10_00/"
     datapath = "../data/2022-06-02-10_00/"
-    main(datapath, plot="save")
+    main(datapath, plot="show")
diff --git a/code/chirpdetector_conf.yml b/code/chirpdetector_conf.yml
index 2c30fa7..0292fd6 100755
--- a/code/chirpdetector_conf.yml
+++ b/code/chirpdetector_conf.yml
@@ -1,3 +1,4 @@
+# directory setup
 dataroot: "../data/"
 outputdir: "../output/"
 
@@ -8,41 +9,38 @@ edge: 0.25
 # Number of electrodes to go over
 number_electrodes: 3
+minimum_electrodes: 2
 
-# Boundary for search frequency in Hz
-search_boundary: 100
+# Search window bandwidth and minimal baseline bandwidth
+minimal_bandwidth: 10
 
-# Cutoff frequency for envelope estimation by lowpass filter
-envelope_cutoff: 25
+# Instantaneous frequency smoothing using a gaussian kernel of this width
+baseline_frequency_smoothing: 5
 
-# Cutoff frequency for envelope highpass filter
-envelope_highpass_cutoff: 3
+# Baseline processing parameters
+baseline_envelope_cutoff: 25
+baseline_envelope_bandpass_lowf: 4
+baseline_envelope_bandpass_highf: 100
+baseline_envelope_envelope_cutoff: 4
 
-# Cutoff frequency for envelope of envelope
-envelope_envelope_cutoff: 5
+# search envelope processing parameters
+search_envelope_cutoff: 5
 
 # Instantaneous frequency bandpass filter cutoff frequencies
-instantaneous_lowf: 15
-instantaneous_highf: 8000
+baseline_frequency_highpass_cutoff: 0.000005
+baseline_frequency_envelope_cutoff: 0.000005
 
-# Baseline envelope peak detection parameters
-baseline_prominence_percentile: 90
-
-# Search envelope peak detection parameters
-search_prominence_percentile: 90
-
-# Instantaneous frequency peak detection parameters
-instantaneous_prominence_percentile: 90
+# peak detection parameters
+prominence: 0.005
 
 # search freq parameter
-search_df_lower: 25
+search_df_lower: 20
 search_df_upper: 100
 search_res: 1
-search_freq_percentiles:
-  - 5
-  - 95
+search_bandwidth: 10
 default_search_freq: 50
 
+# Classify events as chirps if they are less than this time apart
 chirp_window_threshold: 0.05
diff --git a/code/modules/datahandling.py b/code/modules/datahandling.py
index 1de68d8..72e9caf 100644
--- a/code/modules/datahandling.py
+++ b/code/modules/datahandling.py
@@ -1,5 +1,59 @@
 import numpy as np
 from typing import List, Any
+from scipy.ndimage import gaussian_filter1d
+
+
+def instantaneous_frequency(
+    signal: np.ndarray,
+    samplerate: int,
+    smoothing_window: int,
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Compute the instantaneous frequency of a signal that is approximately
+    sinusoidal and symmetric around 0.
+
+    Parameters
+    ----------
+    signal : np.ndarray
+        Signal to compute the instantaneous frequency from.
+    samplerate : int
+        Samplerate of the signal.
+    smoothing_window : int
+        Window size for the gaussian filter.
+
+    Returns
+    -------
+    tuple[np.ndarray, np.ndarray]
+
+    """
+    # calculate instantaneous frequency with zero crossings
+    roll_signal = np.roll(signal, shift=1)
+    time_signal = np.arange(len(signal)) / samplerate
+    period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][
+        1:-1
+    ]
+
+    upper_bound = np.abs(signal[period_index])
+    lower_bound = np.abs(signal[period_index - 1])
+    upper_time = np.abs(time_signal[period_index])
+    lower_time = np.abs(time_signal[period_index - 1])
+
+    # create ratio
+    lower_ratio = lower_bound / (lower_bound + upper_bound)
+
+    # apply to time delta
+    time_delta = upper_time - lower_time
+    true_zero = lower_time + lower_ratio * time_delta
+
+    # create new time array
+    instantaneous_frequency_time = true_zero[:-1] + 0.5 * np.diff(true_zero)
+
+    # compute frequency
+    instantaneous_frequency = gaussian_filter1d(
+        1 / np.diff(true_zero), smoothing_window
+    )
+
+    return instantaneous_frequency_time, instantaneous_frequency
 
 
 def purge_duplicates(
@@ -64,7 +118,7 @@ def purge_duplicates(
 
 
 def group_timestamps(
-    sublists: List[List[float]], n: int, threshold: float
+    sublists: List[List[float]], at_least_in: int, difference_threshold: float
 ) -> List[float]:
     """
     Groups timestamps that are less than `threshold` milliseconds apart from
@@ -100,7 +154,7 @@ def group_timestamps(
 
     # Group timestamps that are less than threshold milliseconds apart
     for i in range(1, len(timestamps)):
-        if timestamps[i] - timestamps[i - 1] < threshold:
+        if timestamps[i] - timestamps[i - 1] < difference_threshold:
             current_group.append(timestamps[i])
         else:
             groups.append(current_group)
@@ -111,7 +165,7 @@ def group_timestamps(
     # Retain only groups that contain at least n timestamps
     final_groups = []
    for group in groups:
-        if len(group) >= n:
+        if len(group) >= at_least_in:
            final_groups.append(group)
 
     # Calculate the mean of each group
diff --git a/code/modules/filters.py b/code/modules/filters.py
index 5192cdc..e6d9896 100644
--- a/code/modules/filters.py
+++ b/code/modules/filters.py
@@ -3,8 +3,8 @@ import numpy as np
 
 
 def bandpass_filter(
-    data: np.ndarray,
-    rate: float,
+    signal: np.ndarray,
+    samplerate: float,
     lowf: float,
     highf: float,
 ) -> np.ndarray:
@@ -12,7 +12,7 @@
 
     Parameters
    ----------
-    data : np.ndarray
+    signal : np.ndarray
        The data to be filtered
    rate : float
        The sampling rate
@@ -26,21 +26,22 @@
    np.ndarray
        The filtered data
    """
-    sos = butter(2, (lowf, highf), "bandpass", fs=rate, output="sos")
-    fdata = sosfiltfilt(sos, data)
-    return fdata
+    sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos")
+    filtered_signal = sosfiltfilt(sos, signal)
+
+    return filtered_signal
 
 
 def highpass_filter(
-    data: np.ndarray,
-    rate: float,
+    signal: np.ndarray,
+    samplerate: float,
     cutoff: float,
 ) -> np.ndarray:
    """Highpass filter a signal.
 
    Parameters
    ----------
-    data : np.ndarray
+    signal : np.ndarray
        The data to be filtered
    rate : float
        The sampling rate
@@ -52,14 +53,15 @@
    np.ndarray
        The filtered data
    """
-    sos = butter(2, cutoff, "highpass", fs=rate, output="sos")
-    fdata = sosfiltfilt(sos, data)
-    return fdata
+    sos = butter(2, cutoff, "highpass", fs=samplerate, output="sos")
+    filtered_signal = sosfiltfilt(sos, signal)
+
+    return filtered_signal
 
 
 def lowpass_filter(
-    data: np.ndarray,
-    rate: float,
+    signal: np.ndarray,
+    samplerate: float,
     cutoff: float
 ) -> np.ndarray:
    """Lowpass filter a signal.
@@ -78,21 +80,25 @@
    np.ndarray
        The filtered data
    """
-    sos = butter(2, cutoff, "lowpass", fs=rate, output="sos")
-    fdata = sosfiltfilt(sos, data)
-    return fdata
+    sos = butter(2, cutoff, "lowpass", fs=samplerate, output="sos")
+    filtered_signal = sosfiltfilt(sos, signal)
+    return filtered_signal
 
-def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
+
+def envelope(signal: np.ndarray,
+             samplerate: float,
+             cutoff_frequency: float
+             ) -> np.ndarray:
    """Calculate the envelope of a signal using a lowpass filter.
 
    Parameters
    ----------
-    data : np.ndarray
+    signal : np.ndarray
        The signal to calculate the envelope of
-    rate : float
+    samplerate : float
        The sampling rate of the signal
-    freq : float
+    cutoff_frequency : float
        The cutoff frequency of the lowpass filter
 
    Returns
@@ -100,6 +106,7 @@ def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
    np.ndarray
        The envelope of the signal
    """
-    sos = butter(2, freq, "lowpass", fs=rate, output="sos")
-    envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(data))
+    sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos")
+    envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(signal))
+
+    return envelope
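
A minimal usage sketch of the renamed helpers touched by this patch, for sanity-checking them outside the detection loop. It assumes the repository's modules package is importable (e.g. run from code/) and feeds a synthetic frequency-modulated sine instead of electrode data; the sampling rate, filter cutoffs and the hard-coded peak times are illustrative values, not the tuned parameters from chirpdetector_conf.yml.

import numpy as np
from modules.datahandling import group_timestamps, instantaneous_frequency
from modules.filters import bandpass_filter, envelope

samplerate = 20000  # Hz, illustrative
time = np.arange(0, 2.0, 1 / samplerate)

# synthetic 600 Hz "EOD" with a brief frequency excursion around t = 1 s
frequency = 600 + 80 * np.exp(-((time - 1.0) ** 2) / (2 * 0.01 ** 2))
signal = np.sin(2 * np.pi * np.cumsum(frequency) / samplerate)

# narrow band around the baseline, similar to what extract_frequency_bands returns
baselineband = bandpass_filter(signal=signal, samplerate=samplerate, lowf=595, highf=605)

# envelope of the baseline band; a chirp shows up as a dip here
baseline_envelope = envelope(signal=baselineband, samplerate=samplerate, cutoff_frequency=25)

# zero-crossing based instantaneous frequency of the narrow band
frequency_time, inst_freq = instantaneous_frequency(
    signal=baselineband, samplerate=samplerate, smoothing_window=5
)

# group timestamps from the three feature traces (hard-coded here) if all three agree
chirps = group_timestamps(
    sublists=[[1.001], [0.998], [1.003]],
    at_least_in=3,
    difference_threshold=0.05,
)
print(chirps)  # one grouped chirp time near 1.0 s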