From ddf7bd545a849b32765e0c701d4f4852b71431f3 Mon Sep 17 00:00:00 2001 From: weygoldt <88969563+weygoldt@users.noreply.github.com> Date: Fri, 20 Jan 2023 13:56:26 +0100 Subject: [PATCH] refactoring finished for now --- code/chirpdetection.py | 533 ++++++++++++++++++++--------------- code/chirpdetector_conf.yml | 33 +-- code/modules/datahandling.py | 60 +++- code/modules/filters.py | 53 ++-- 4 files changed, 401 insertions(+), 278 deletions(-) mode change 100644 => 100755 code/chirpdetection.py diff --git a/code/chirpdetection.py b/code/chirpdetection.py old mode 100644 new mode 100755 index d6340d8..2a48025 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -1,20 +1,22 @@ from itertools import compress from dataclasses import dataclass -from IPython import embed import numpy as np import matplotlib.pyplot as plt from scipy.signal import find_peaks -from scipy.ndimage import gaussian_filter1d -from thunderfish.dataloader import DataLoader from thunderfish.powerspectrum import spectrogram, decibel from sklearn.preprocessing import normalize from modules.filters import bandpass_filter, envelope, highpass_filter from modules.filehandling import ConfLoader, LoadData, make_outputdir -from modules.datahandling import flatten, purge_duplicates, group_timestamps from modules.plotstyle import PlotStyle from modules.logger import makeLogger +from modules.datahandling import ( + flatten, + purge_duplicates, + group_timestamps, + instantaneous_frequency, +) logger = makeLogger(__name__) @@ -28,6 +30,7 @@ class PlotBuffer: Buffer to save data that is created in the main detection loop and plot it outside the detecion loop. """ + config: ConfLoader t0: float dt: float @@ -85,8 +88,9 @@ class PlotBuffer: plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0) for chirp in chirps: - axs[0].scatter(chirp, np.median(self.frequency), - c=ps.black, marker="x") + axs[0].scatter( + chirp, np.median(self.frequency), c=ps.black, marker="x" + ) # plot waveform of filtered signal axs[1].plot(self.time, self.baseline, c=ps.green) @@ -94,7 +98,7 @@ class PlotBuffer: # plot waveform of filtered search signal axs[2].plot(self.time, self.search) - # plot baseline instantaneos frequency + # plot baseline instantaneous frequency axs[3].plot(self.frequency_time, self.frequency) # plot filtered and rectified envelope @@ -145,7 +149,7 @@ class PlotBuffer: def plot_spectrogram( - axis, signal: np.ndarray, samplerate: float, t0: float + axis, signal: np.ndarray, samplerate: float, window_start_seconds: float ) -> None: """ Plot a spectrogram of a signal. @@ -158,7 +162,7 @@ def plot_spectrogram( Signal to plot the spectrogram from. samplerate : float Samplerate of the signal. - t0 : float + window_start_seconds : float Start time of the signal. """ @@ -172,73 +176,26 @@ def plot_spectrogram( overlap_frac=0.5, ) - # axis.pcolormesh( - # spec_times + t0, - # spec_freqs, - # decibel(spec_power), - # ) axis.imshow( decibel(spec_power), - extent=[spec_times[0] + t0, spec_times[-1] + - t0, spec_freqs[0], spec_freqs[-1]], + extent=[ + spec_times[0] + window_start_seconds, + spec_times[-1] + window_start_seconds, + spec_freqs[0], + spec_freqs[-1], + ], aspect="auto", origin="lower", interpolation="gaussian", ) -def instantaneos_frequency( - signal: np.ndarray, samplerate: int -) -> tuple[np.ndarray, np.ndarray]: - """ - Compute the instantaneous frequency of a signal. - - Parameters - ---------- - signal : np.ndarray - Signal to compute the instantaneous frequency from. - samplerate : int - Samplerate of the signal. 
- - Returns - ------- - tuple[np.ndarray, np.ndarray] - - """ - # calculate instantaneos frequency with zero crossings - roll_signal = np.roll(signal, shift=1) - time_signal = np.arange(len(signal)) / samplerate - period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][ - 1:-1 - ] - - upper_bound = np.abs(signal[period_index]) - lower_bound = np.abs(signal[period_index - 1]) - upper_time = np.abs(time_signal[period_index]) - lower_time = np.abs(time_signal[period_index - 1]) - - # create ratio - lower_ratio = lower_bound / (lower_bound + upper_bound) - - # appy to time delta - time_delta = upper_time - lower_time - true_zero = lower_time + lower_ratio * time_delta - - # create new time array - inst_freq_time = true_zero[:-1] + 0.5 * np.diff(true_zero) - - # compute frequency - inst_freq = gaussian_filter1d(1 / np.diff(true_zero), 5) - - return inst_freq_time, inst_freq - - -def double_bandpass( - data: DataLoader, - samplerate: int, - freqs: np.ndarray, - search_freq: float, - config: ConfLoader +def extract_frequency_bands( + raw_data: np.ndarray, + samplerate: int, + baseline_track: np.ndarray, + searchband_center: float, + minimal_bandwidth: float, ) -> tuple[np.ndarray, np.ndarray]: """ Apply a bandpass filter to the baseline of a signal and a second bandpass @@ -246,14 +203,16 @@ def double_bandpass( Parameters ---------- - data : DataLoader + raw_data : np.ndarray Data to apply the filter to. samplerate : int Samplerate of the signal. - freqs : np.ndarray + baseline_track : np.ndarray Tracked fundamental frequencies of the signal. - search_freq : float + searchband_center: float Frequency to search for above or below the baseline. + minimal_bandwidth : float + Minimal bandwidth of the filter. Returns ------- @@ -261,28 +220,30 @@ def double_bandpass( """ # compute boundaries to filter baseline - q25, q50, q75 = np.percentile(freqs, [25, 50, 75]) + q25, q50, q75 = np.percentile(baseline_track, [25, 50, 75]) # check if percentile delta is too small - if q75 - q25 < 5: - median = np.median(freqs) - q25, q75 = median - 2.5, median + 2.5 + if q75 - q25 < 10: + q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2 # filter baseline - filtered_baseline = bandpass_filter(data, samplerate, lowf=q25, highf=q75) + filtered_baseline = bandpass_filter( + raw_data, samplerate, lowf=q25, highf=q75 + ) # filter search area filtered_search_freq = bandpass_filter( - data, samplerate, - lowf=search_freq + q50 - config.search_bandwidth / 2, - highf=search_freq + q50 + config.search_bandwidth / 2 + raw_data, + samplerate, + lowf=searchband_center + q50 - minimal_bandwidth / 2, + highf=searchband_center + q50 + minimal_bandwidth / 2, ) return filtered_baseline, filtered_search_freq -def freqmedian_allfish( - data: LoadData, t0: float, dt: float +def window_median_all_track_ids( + data: LoadData, window_start_seconds: float, window_duration_seconds: float ) -> tuple[float, list[int]]: """ Calculate the median frequency of all fish in a given time window. @@ -291,9 +252,9 @@ def freqmedian_allfish( ---------- data : LoadData Data to calculate the median frequency from. - t0 : float + window_start_seconds : float Start time of the window. - dt : float + window_duration_seconds : float Duration of the window. 
Returns @@ -308,8 +269,11 @@ def freqmedian_allfish( for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): window_idx = np.arange(len(data.idx))[ (data.ident == track_id) - & (data.time[data.idx] >= t0) - & (data.time[data.idx] <= (t0 + dt)) + & (data.time[data.idx] >= window_start_seconds) + & ( + data.time[data.idx] + <= (window_start_seconds + window_duration_seconds) + ) ] if len(data.freq[window_idx]) > 0: @@ -323,7 +287,7 @@ def freqmedian_allfish( return median_freq, track_ids -def find_search_freq( +def find_searchband( freq_temp: np.ndarray, median_ids: np.ndarray, median_freq: np.ndarray, @@ -331,15 +295,16 @@ def find_search_freq( data: LoadData, ) -> float: """ - Find the search frequency for each fish by checking which fish EODs are - above the current EOD and finding a gap in them. + Find the search frequency band for each fish by checking which fish EODs + are above the current EOD and finding a gap in them. Parameters ---------- freq_temp : np.ndarray Current EOD frequency array / the current fish of interest. median_ids : np.ndarray - Array of track IDs of the medians of all other fish in the current window. + Array of track IDs of the medians of all other fish in the current + window. median_freq : np.ndarray Array of median frequencies of all other fish in the current window. config : ConfLoader @@ -421,7 +386,8 @@ def find_search_freq( longest_search_window = search_windows[np.argmax(search_windows_lens)] search_freq = ( - longest_search_window[-1] - longest_search_window[0]) / 2 + longest_search_window[-1] - longest_search_window[0] + ) / 2 else: search_freq = config.default_search_freq @@ -431,7 +397,11 @@ def find_search_freq( def main(datapath: str, plot: str) -> None: - assert plot in ["save", "show", "false"] + assert plot in [ + "save", + "show", + "false", + ], "plot must be 'save', 'show' or 'false'" # load raw file data = LoadData(datapath) @@ -444,13 +414,15 @@ def main(datapath: str, plot: str) -> None: window_overlap = config.overlap * data.raw_rate window_edge = config.edge * data.raw_rate - # check if window duration is even + # check if window duration and window ovelap is even, otherwise the half + # of the duration or window overlap would return a float, thus an + # invalid index + if window_duration % 2 == 0: window_duration = int(window_duration) else: raise ValueError("Window duration must be even.") - # check if window ovelap is even if window_overlap % 2 == 0: window_overlap = int(window_overlap) else: @@ -460,16 +432,16 @@ def main(datapath: str, plot: str) -> None: raw_time = np.arange(data.raw.shape[0]) / data.raw_rate # good chirp times for data: 2022-06-02-10_00 - t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate - dt = 60 * data.raw_rate + window_start_seconds = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate + window_duration_seconds = 60 * data.raw_rate -# t0 = 0 -# dt = data.raw.shape[0] + # t0 = 0 + # dt = data.raw.shape[0] # generate starting points of rolling window - window_starts = np.arange( - t0, - t0 + dt, + window_start_indices = np.arange( + window_start_seconds, + window_start_seconds + window_duration_seconds, window_duration - (window_overlap + 2 * window_edge), dtype=int, ) @@ -478,19 +450,20 @@ def main(datapath: str, plot: str) -> None: multiwindow_chirps = [] multiwindow_ids = [] - for st, start_index in enumerate(window_starts): + for st, window_start_index in enumerate(window_start_indices): - logger.info(f"Processing window {st} of {len(window_starts)}") + logger.info(f"Processing window {st+1} of 
{len(window_start_indices)}")

-        # make t0 and dt
-        t0 = start_index / data.raw_rate
-        dt = window_duration / data.raw_rate
+        window_start_seconds = window_start_index / data.raw_rate
+        window_duration_seconds = window_duration / data.raw_rate

         # set index window
-        stop_index = start_index + window_duration
+        window_stop_index = window_start_index + window_duration

         # calculate median of fish frequencies in window
-        median_freq, median_ids = freqmedian_allfish(data, t0, dt)
+        median_freq, median_ids = window_median_all_track_ids(
+            data, window_start_seconds, window_duration_seconds
+        )

         # iterate through all fish
         for tr, track_id in enumerate(
@@ -500,48 +473,57 @@ def main(datapath: str, plot: str) -> None:

             logger.debug(f"Processing track {tr} of {len(data.ids)}")

             # get index of track data in this time window
-            window_idx = np.arange(len(data.idx))[
+            track_window_index = np.arange(len(data.idx))[
                 (data.ident == track_id)
-                & (data.time[data.idx] >= t0)
-                & (data.time[data.idx] <= (t0 + dt))
+                & (data.time[data.idx] >= window_start_seconds)
+                & (
+                    data.time[data.idx]
+                    <= (window_start_seconds + window_duration_seconds)
+                )
             ]

             # get tracked frequencies and their times
-            freq_temp = data.freq[window_idx]
-            powers_temp = data.powers[window_idx, :]
+            current_frequencies = data.freq[track_window_index]
+            current_powers = data.powers[track_window_index, :]

             # approximate sampling rate to compute expected durations if there
             # is data available for this time window for this fish id
+
             track_samplerate = np.mean(1 / np.diff(data.time))
-            expected_duration = ((t0 + dt) - t0) * track_samplerate
+            expected_duration = (
+                (window_start_seconds + window_duration_seconds)
+                - window_start_seconds
+            ) * track_samplerate

             # check if tracked data available in this window
-            if len(freq_temp) < expected_duration * 0.5:
+            if len(current_frequencies) < expected_duration / 2:
                 logger.warning(
                     f"Track {track_id} has no data in window {st}, skipping."
                 )
                 continue

             # check if there are powers available in this window
-            nanchecker = np.unique(np.isnan(powers_temp))
-            if (len(nanchecker) == 1) and nanchecker[0]:
+            nanchecker = np.unique(np.isnan(current_powers))
+            if (len(nanchecker) == 1) and bool(nanchecker[0]):
                 logger.warning(
-                    f"No powers available for track {track_id} window {st}, \
-                        skipping."
+                    f"No powers available for track {track_id} window {st}, "
+                    "skipping."
) continue # find the strongest electrodes for the current fish in the current # window - best_electrodes = np.argsort(np.nanmean(powers_temp, axis=0))[ - -config.number_electrodes: - ] + + best_electrode_index = np.argsort( + np.nanmean(current_powers, axis=0) + )[-config.number_electrodes:] # find a frequency above the baseline of the current fish in which # no other fish is active to search for chirps there - search_freq = find_search_freq( + + search_frequency = find_searchband( config=config, - freq_temp=freq_temp, + freq_temp=current_frequencies, median_ids=median_ids, data=data, median_freq=median_freq, @@ -549,153 +531,219 @@ def main(datapath: str, plot: str) -> None: # add all chirps that are detected on mulitple electrodes for one # fish fish in one window to this list + multielectrode_chirps = [] # iterate through electrodes - for el, electrode in enumerate(best_electrodes): + for el, electrode_index in enumerate(best_electrode_index): logger.debug( - f"Processing electrode {el} of {len(best_electrodes)}" + f"Processing electrode {el+1} of " + f"{len(best_electrode_index)}" ) + # LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------ + # load region of interest of raw data file - data_oi = data.raw[start_index:stop_index, :] - time_oi = raw_time[start_index:stop_index] + current_raw_data = data.raw[ + window_start_index:window_stop_index, electrode_index + ] + current_raw_time = raw_time[ + window_start_index:window_stop_index + ] + + # EXTRACT FEATURES -------------------------------------------- # filter baseline and above - baseline, search = double_bandpass( - data_oi[:, electrode], - data.raw_rate, - freq_temp, - search_freq, - config=config, + baselineband, searchband = extract_frequency_bands( + raw_data=current_raw_data, + samplerate=data.raw_rate, + baseline_track=current_frequencies, + searchband_center=search_frequency, + minimal_bandwidth=config.minimal_bandwidth, ) - # compute instantaneous frequency on narrow signal - baseline_freq_time, baseline_freq = instantaneos_frequency( - baseline, data.raw_rate - ) + # compute envelope of baseline band to find dips + # in the baseline envelope - # compute envelopes baseline_envelope_unfiltered = envelope( - baseline, data.raw_rate, config.envelope_cutoff + signal=baselineband, + samplerate=data.raw_rate, + cutoff_frequency=config.baseline_envelope_cutoff, + ) + + # highpass filter baseline envelope to remove slower + # fluctuations e.g. due to motion envelope + + baseline_envelope = bandpass_filter( + signal=baseline_envelope_unfiltered, + samplerate=data.raw_rate, + lowf=config.baseline_envelope_bandpass_lowf, + highf=config.baseline_envelope_bandpass_highf, + ) + + # highbass filter introduced filter effects, i.e. oscillations + # around peaks. Compute the envelope of the highpass filtered + # and inverted baseline envelope to remove these oscillations + + baseline_envelope = -baseline_envelope + + baseline_envelope = envelope( + signal=baseline_envelope, + samplerate=data.raw_rate, + cutoff_frequency=config.baseline_envelope_envelope_cutoff, ) + + # compute the envelope of the search band. 
Peaks in the search + # band envelope correspond to troughs in the baseline envelope + # during chirps + search_envelope = envelope( - search, data.raw_rate, config.envelope_cutoff + signal=searchband, + samplerate=data.raw_rate, + cutoff_frequency=config.search_envelope_cutoff, ) - # highpass filter envelopes - baseline_envelope = highpass_filter( - baseline_envelope_unfiltered, - data.raw_rate, - config.envelope_highpass_cutoff, + # compute instantaneous frequency of the baseline band to find + # anomalies during a chirp, i.e. a frequency jump upwards or + # sometimes downwards. We do not fully understand why the + # instantaneous frequency can also jump downwards during a + # chirp. This phenomenon is only observed on chirps on a narrow + # filtered baseline such as the one we are working with. + + ( + baseline_frequency_time, + baseline_frequency, + ) = instantaneous_frequency( + signal=baselineband, + samplerate=data.raw_rate, + smoothing_window=config.baseline_frequency_smoothing, ) - # envelopes of filtered envelope of filtered baseline - baseline_envelope = envelope( - np.abs(baseline_envelope), - data.raw_rate, - config.envelope_envelope_cutoff, + # bandpass filter the instantaneous frequency to remove slow + # fluctuations. Just as with the baseline envelope, we then + # compute the envelope of the signal to remove the oscillations + # around the peaks + + baseline_frequency_samplerate = np.mean( + np.diff(baseline_frequency_time) + ) + + baseline_frequency_filtered = np.abs( + baseline_frequency - np.median(baseline_frequency) + ) + + baseline_frequency_filtered = highpass_filter( + signal=baseline_frequency_filtered, + samplerate=baseline_frequency_samplerate, + cutoff=config.baseline_frequency_highpass_cutoff, ) - # bandpass filter the instantaneous frequency to put it to 0 - inst_freq_filtered = bandpass_filter( - baseline_freq, - data.raw_rate, - lowf=config.instantaneous_lowf, - highf=config.instantaneous_highf, + baseline_frequency_filtered = envelope( + signal=-baseline_frequency_filtered, + samplerate=baseline_frequency_samplerate, + cutoff_frequency=config.baseline_frequency_envelope_cutoff, ) # CUT OFF OVERLAP --------------------------------------------- - # overwrite raw time to valid region, i.e. 
cut off snippet at - # start and end of each window to remove filter effects - valid = np.arange( + # cut off snippet at start and end of each window to remove + # filter effects + + # get arrays with raw samplerate without edges + no_edges = np.arange( int(window_edge), len(baseline_envelope) - int(window_edge) ) - baseline_envelope_unfiltered = baseline_envelope_unfiltered[ - valid - ] - baseline_envelope = baseline_envelope[valid] - search_envelope = search_envelope[valid] - - # get inst freq valid snippet - valid_t0 = int(window_edge) / data.raw_rate - valid_t1 = baseline_freq_time[-1] - ( + current_raw_time = current_raw_time[no_edges] + baselineband = baselineband[no_edges] + searchband = searchband[no_edges] + baseline_envelope = baseline_envelope[no_edges] + search_envelope = search_envelope[no_edges] + + # get instantaneous frequency withoup edges + no_edges_t0 = int(window_edge) / data.raw_rate + no_edges_t1 = baseline_frequency_time[-1] - ( int(window_edge) / data.raw_rate ) + no_edges = (baseline_frequency_time >= no_edges_t0) & ( + baseline_frequency_time <= no_edges_t1 + ) - inst_freq_filtered = inst_freq_filtered[ - (baseline_freq_time >= valid_t0) - & (baseline_freq_time <= valid_t1) - ] - - baseline_freq = baseline_freq[ - (baseline_freq_time >= valid_t0) - & (baseline_freq_time <= valid_t1) + baseline_frequency_filtered = baseline_frequency_filtered[ + no_edges ] - - baseline_freq_time = ( - baseline_freq_time[ - (baseline_freq_time >= valid_t0) - & (baseline_freq_time <= valid_t1) - ] - + t0 + baseline_frequency = baseline_frequency[no_edges] + baseline_frequency_time = ( + baseline_frequency_time[no_edges] + window_start_seconds ) - time_oi = time_oi[valid] - baseline = baseline[valid] - search = search[valid] - # NORMALIZE --------------------------------------------------- + # normalize all three feature arrays to the same range to make + # peak detection simpler + baseline_envelope = normalize([baseline_envelope])[0] search_envelope = normalize([search_envelope])[0] - inst_freq_filtered = normalize([np.abs(inst_freq_filtered)])[0] + baseline_frequency_filtered = normalize( + [baseline_frequency_filtered] + )[0] # PEAK DETECTION ---------------------------------------------- - prominence = config.prominence - # detect peaks baseline_enelope - baseline_peaks, _ = find_peaks( - baseline_envelope, prominence=prominence + baseline_peak_indices, _ = find_peaks( + baseline_envelope, prominence=config.prominence ) # detect peaks search_envelope - search_peaks, _ = find_peaks( - search_envelope, prominence=prominence + search_peak_indices, _ = find_peaks( + search_envelope, prominence=config.prominence ) # detect peaks inst_freq_filtered - inst_freq_peaks, _ = find_peaks( - inst_freq_filtered, prominence=prominence + frequency_peak_indices, _ = find_peaks( + baseline_frequency_filtered, prominence=config.prominence ) # DETECT CHIRPS IN SEARCH WINDOW ------------------------------ # get the peak timestamps from the peak indices - baseline_ts = time_oi[baseline_peaks] - search_ts = time_oi[search_peaks] - freq_ts = baseline_freq_time[inst_freq_peaks] + baseline_peak_timestamps = current_raw_time[ + baseline_peak_indices + ] + search_peak_timestamps = current_raw_time[search_peak_indices] + frequency_peak_timestamps = baseline_frequency_time[ + frequency_peak_indices + ] # check if one list is empty and if so, skip to the next # electrode because a chirp cannot be detected if one is empty - if ( - len(baseline_ts) == 0 - or len(search_ts) == 0 - or len(freq_ts) == 0 - ): + + 
one_feature_empty = (
+                    len(baseline_peak_timestamps) == 0
+                    or len(search_peak_timestamps) == 0
+                    or len(frequency_peak_timestamps) == 0
+                )
+
+                if one_feature_empty:
                     continue

                 # group peaks across feature arrays but only if they
                 # occur in all 3 feature arrays
+
+                sublists = [
+                    list(baseline_peak_timestamps),
+                    list(search_peak_timestamps),
+                    list(frequency_peak_timestamps),
+                ]
+
                 singleelectrode_chirps = group_timestamps(
-                    [list(baseline_ts), list(search_ts), list(freq_ts)],
-                    3,
-                    config.chirp_window_threshold,
+                    sublists=sublists,
+                    at_least_in=3,
+                    difference_threshold=config.chirp_window_threshold,
                 )

                 # check if there are chirps detected after grouping, continue
                 # with the loop if not
+
                 if len(singleelectrode_chirps) == 0:
                     continue

@@ -703,57 +751,62 @@ def main(datapath: str, plot: str) -> None:
                 multielectrode_chirps.append(singleelectrode_chirps)

                 # only initialize the plotting buffer if chirps are detected
-                if (
+                chirp_detected = (
                     (el == config.number_electrodes - 1)
                     & (len(singleelectrode_chirps) > 0)
                     & (plot in ["show", "save"])
-                ):
+                )
+
+                if chirp_detected:
                     logger.debug("Detected chirp, initialize buffer ...")

                     # save data to Buffer
                     buffer = PlotBuffer(
                         config=config,
-                        t0=t0,
-                        dt=dt,
-                        electrode=electrode,
+                        t0=window_start_seconds,
+                        dt=window_duration_seconds,
+                        electrode=electrode_index,
                         track_id=track_id,
                         data=data,
-                        time=time_oi,
-                        baseline=baseline,
+                        time=current_raw_time,
+                        baseline=baselineband,
                         baseline_envelope=baseline_envelope,
-                        baseline_peaks=baseline_peaks,
-                        search=search,
+                        baseline_peaks=baseline_peak_indices,
+                        search=searchband,
                         search_envelope=search_envelope,
-                        search_peaks=search_peaks,
-                        frequency_time=baseline_freq_time,
-                        frequency=baseline_freq,
-                        frequency_filtered=inst_freq_filtered,
-                        frequency_peaks=inst_freq_peaks,
+                        search_peaks=search_peak_indices,
+                        frequency_time=baseline_frequency_time,
+                        frequency=baseline_frequency,
+                        frequency_filtered=baseline_frequency_filtered,
+                        frequency_peaks=frequency_peak_indices,
                     )

                     logger.debug("Buffer initialized!")

             logger.debug(
-                f"Processed all electrodes for fish {track_id} for this \
-                    window, sorting chirps ..."
+                f"Processed all electrodes for fish {track_id} for this "
+                "window, sorting chirps ..."
             )

             # check if there are chirps detected in multiple electrodes and
             # continue the loop if not
+
             if len(multielectrode_chirps) == 0:
                 continue

             # validate multielectrode chirps, i.e. 
check if they are # detected in at least 'config.min_electrodes' electrodes + multielectrode_chirps_validated = group_timestamps( - multielectrode_chirps, - config.minimum_electrodes, - config.chirp_window_threshold + sublists=multielectrode_chirps, + at_least_in=config.minimum_electrodes, + difference_threshold=config.chirp_window_threshold, ) # add validated chirps to the list that tracks chirps across there # rolling time windows + multiwindow_chirps.append(multielectrode_chirps_validated) multiwindow_ids.append(track_id) @@ -763,6 +816,7 @@ def main(datapath: str, plot: str) -> None: ) # if chirps are detected and the plot flag is set, plot the # chirps, otheswise try to delete the buffer if it exists + if len(multielectrode_chirps_validated) > 0: try: buffer.plot_buffer(multielectrode_chirps_validated, plot) @@ -776,27 +830,38 @@ def main(datapath: str, plot: str) -> None: # flatten list of lists containing chirps and create # an array of fish ids that correspond to the chirps + multiwindow_chirps_flat = [] multiwindow_ids_flat = [] - for tr in np.unique(multiwindow_ids): - tr_index = np.asarray(multiwindow_ids) == tr - ts = flatten(list(compress(multiwindow_chirps, tr_index))) - multiwindow_chirps_flat.extend(ts) - multiwindow_ids_flat.extend(list(np.ones_like(ts) * tr)) + for track_id in np.unique(multiwindow_ids): + + # get chirps for this fish and flatten the list + current_track_bool = np.asarray(multiwindow_ids) == track_id + current_track_chirps = flatten( + list(compress(multiwindow_chirps, current_track_bool)) + ) + + # add flattened chirps to the list + multiwindow_chirps_flat.extend(current_track_chirps) + multiwindow_ids_flat.extend( + list(np.ones_like(current_track_chirps) * track_id) + ) # purge duplicates, i.e. chirps that are very close to each other # duplites arise due to overlapping windows + purged_chirps = [] purged_ids = [] - for tr in np.unique(multiwindow_ids_flat): + for track_id in np.unique(multiwindow_ids_flat): tr_chirps = np.asarray(multiwindow_chirps_flat)[ - np.asarray(multiwindow_ids_flat) == tr] + np.asarray(multiwindow_ids_flat) == track_id + ] if len(tr_chirps) > 0: tr_chirps_purged = purge_duplicates( tr_chirps, config.chirp_window_threshold ) purged_chirps.extend(list(tr_chirps_purged)) - purged_ids.extend(list(np.ones_like(tr_chirps_purged) * tr)) + purged_ids.extend(list(np.ones_like(tr_chirps_purged) * track_id)) # sort chirps by time purged_chirps = np.asarray(purged_chirps) diff --git a/code/chirpdetector_conf.yml b/code/chirpdetector_conf.yml index e12b904..0292fd6 100755 --- a/code/chirpdetector_conf.yml +++ b/code/chirpdetector_conf.yml @@ -1,3 +1,4 @@ +# directory setup dataroot: "../data/" outputdir: "../output/" @@ -10,30 +11,26 @@ edge: 0.25 number_electrodes: 3 minimum_electrodes: 2 -# Search window bandwidth +# Search window bandwidth and minimal baseline bandwidth +minimal_bandwidth: 10 -# Cutoff frequency for envelope estimation by lowpass filter -envelope_cutoff: 25 +# Instantaneous frequency smoothing usint a gaussian kernel of this width +baseline_frequency_smoothing: 5 -# Cutoff frequency for envelope highpass filter -envelope_highpass_cutoff: 3 +# Baseline processing parameters +baseline_envelope_cutoff: 25 +baseline_envelope_bandpass_lowf: 4 +baseline_envelope_bandpass_highf: 100 +baseline_envelope_envelope_cutoff: 4 -# Cutoff frequency for envelope of envelope -envelope_envelope_cutoff: 5 +# search envelope processing parameters +search_envelope_cutoff: 5 # Instantaneous frequency bandpass filter cutoff frequencies 
-instantaneous_lowf: 15 -instantaneous_highf: 8000 - -# Baseline envelope peak detection parameters -# baseline_prominence_percentile: 90 - -# Search envelope peak detection parameters -# search_prominence_percentile: 90 - -# Instantaneous frequency peak detection parameters -# instantaneous_prominence_percentile: 90 +baseline_frequency_highpass_cutoff: 0.000005 +baseline_frequency_envelope_cutoff: 0.000005 +# peak detecion parameters prominence: 0.005 # search freq parameter diff --git a/code/modules/datahandling.py b/code/modules/datahandling.py index 1de68d8..72e9caf 100644 --- a/code/modules/datahandling.py +++ b/code/modules/datahandling.py @@ -1,5 +1,59 @@ import numpy as np from typing import List, Any +from scipy.ndimage import gaussian_filter1d + + +def instantaneous_frequency( + signal: np.ndarray, + samplerate: int, + smoothing_window: int, +) -> tuple[np.ndarray, np.ndarray]: + """ + Compute the instantaneous frequency of a signal that is approximately + sinusoidal and symmetric around 0. + + Parameters + ---------- + signal : np.ndarray + Signal to compute the instantaneous frequency from. + samplerate : int + Samplerate of the signal. + smoothing_window : int + Window size for the gaussian filter. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + + """ + # calculate instantaneous frequency with zero crossings + roll_signal = np.roll(signal, shift=1) + time_signal = np.arange(len(signal)) / samplerate + period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][ + 1:-1 + ] + + upper_bound = np.abs(signal[period_index]) + lower_bound = np.abs(signal[period_index - 1]) + upper_time = np.abs(time_signal[period_index]) + lower_time = np.abs(time_signal[period_index - 1]) + + # create ratio + lower_ratio = lower_bound / (lower_bound + upper_bound) + + # appy to time delta + time_delta = upper_time - lower_time + true_zero = lower_time + lower_ratio * time_delta + + # create new time array + instantaneous_frequency_time = true_zero[:-1] + 0.5 * np.diff(true_zero) + + # compute frequency + instantaneous_frequency = gaussian_filter1d( + 1 / np.diff(true_zero), smoothing_window + ) + + return instantaneous_frequency_time, instantaneous_frequency def purge_duplicates( @@ -64,7 +118,7 @@ def purge_duplicates( def group_timestamps( - sublists: List[List[float]], n: int, threshold: float + sublists: List[List[float]], at_least_in: int, difference_threshold: float ) -> List[float]: """ Groups timestamps that are less than `threshold` milliseconds apart from @@ -100,7 +154,7 @@ def group_timestamps( # Group timestamps that are less than threshold milliseconds apart for i in range(1, len(timestamps)): - if timestamps[i] - timestamps[i - 1] < threshold: + if timestamps[i] - timestamps[i - 1] < difference_threshold: current_group.append(timestamps[i]) else: groups.append(current_group) @@ -111,7 +165,7 @@ def group_timestamps( # Retain only groups that contain at least n timestamps final_groups = [] for group in groups: - if len(group) >= n: + if len(group) >= at_least_in: final_groups.append(group) # Calculate the mean of each group diff --git a/code/modules/filters.py b/code/modules/filters.py index 5192cdc..e6d9896 100644 --- a/code/modules/filters.py +++ b/code/modules/filters.py @@ -3,8 +3,8 @@ import numpy as np def bandpass_filter( - data: np.ndarray, - rate: float, + signal: np.ndarray, + samplerate: float, lowf: float, highf: float, ) -> np.ndarray: @@ -12,7 +12,7 @@ def bandpass_filter( Parameters ---------- - data : np.ndarray + signal : np.ndarray The data to be 
filtered rate : float The sampling rate @@ -26,21 +26,22 @@ def bandpass_filter( np.ndarray The filtered data """ - sos = butter(2, (lowf, highf), "bandpass", fs=rate, output="sos") - fdata = sosfiltfilt(sos, data) - return fdata + sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos") + filtered_signal = sosfiltfilt(sos, signal) + + return filtered_signal def highpass_filter( - data: np.ndarray, - rate: float, + signal: np.ndarray, + samplerate: float, cutoff: float, ) -> np.ndarray: """Highpass filter a signal. Parameters ---------- - data : np.ndarray + signal : np.ndarray The data to be filtered rate : float The sampling rate @@ -52,14 +53,15 @@ def highpass_filter( np.ndarray The filtered data """ - sos = butter(2, cutoff, "highpass", fs=rate, output="sos") - fdata = sosfiltfilt(sos, data) - return fdata + sos = butter(2, cutoff, "highpass", fs=samplerate, output="sos") + filtered_signal = sosfiltfilt(sos, signal) + + return filtered_signal def lowpass_filter( - data: np.ndarray, - rate: float, + signal: np.ndarray, + samplerate: float, cutoff: float ) -> np.ndarray: """Lowpass filter a signal. @@ -78,21 +80,25 @@ def lowpass_filter( np.ndarray The filtered data """ - sos = butter(2, cutoff, "lowpass", fs=rate, output="sos") - fdata = sosfiltfilt(sos, data) - return fdata + sos = butter(2, cutoff, "lowpass", fs=samplerate, output="sos") + filtered_signal = sosfiltfilt(sos, signal) + return filtered_signal -def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray: + +def envelope(signal: np.ndarray, + samplerate: float, + cutoff_frequency: float + ) -> np.ndarray: """Calculate the envelope of a signal using a lowpass filter. Parameters ---------- - data : np.ndarray + signal : np.ndarray The signal to calculate the envelope of - rate : float + samplingrate : float The sampling rate of the signal - freq : float + cutoff_frequency : float The cutoff frequency of the lowpass filter Returns @@ -100,6 +106,7 @@ def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray: np.ndarray The envelope of the signal """ - sos = butter(2, freq, "lowpass", fs=rate, output="sos") - envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(data)) + sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos") + envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(signal)) + return envelope
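
A quick way to sanity-check the zero-crossing estimator that moved into modules/datahandling.py is to run the same logic on a synthetic tone. The following is only a sketch: it assumes numpy and scipy are installed, and the helper name instantaneous_frequency_sketch and the 600 Hz test tone are invented for the example, not part of the module.

import numpy as np
from scipy.ndimage import gaussian_filter1d

def instantaneous_frequency_sketch(signal, samplerate, smoothing_window):
    # indices where the signal crosses zero from negative to positive
    prev = np.roll(signal, shift=1)
    time = np.arange(len(signal)) / samplerate
    crossings = np.arange(len(signal))[(prev < 0) & (signal >= 0)][1:-1]

    # linearly interpolate the true zero between the two samples around it
    upper = np.abs(signal[crossings])
    lower = np.abs(signal[crossings - 1])
    ratio = lower / (lower + upper)
    true_zero = time[crossings - 1] + ratio * (time[crossings] - time[crossings - 1])

    # one frequency estimate per period, smoothed with a gaussian kernel
    frequency_time = true_zero[:-1] + 0.5 * np.diff(true_zero)
    frequency = gaussian_filter1d(1 / np.diff(true_zero), smoothing_window)
    return frequency_time, frequency

if __name__ == "__main__":
    samplerate = 20000
    t = np.arange(0, 1, 1 / samplerate)
    tone = np.sin(2 * np.pi * 600 * t)  # 600 Hz test tone
    _, freq = instantaneous_frequency_sketch(tone, samplerate, smoothing_window=5)
    print(round(float(np.median(freq))))  # ~600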
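
The percentile logic in extract_frequency_bands is easy to misread, so here is a worked example of the fallback to minimal_bandwidth when the tracked baseline barely varies. baseline_band below is a hypothetical helper that only reproduces the boundary computation; the 10 Hz check mirrors the hard-coded value in the patch, and the example frequencies are invented.

import numpy as np

def baseline_band(baseline_track, minimal_bandwidth):
    q25, q50, q75 = np.percentile(baseline_track, [25, 50, 75])
    # if the interquartile range is too narrow, widen the band around the
    # median so the bandpass filter does not collapse
    if q75 - q25 < 10:
        q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2
    return float(q25), float(q75)

low, high = baseline_band(np.array([601.0, 601.2, 601.4]), minimal_bandwidth=10)
print(round(low, 1), round(high, 1))  # 596.2 606.2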
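
The renamed filter helpers in code/modules/filters.py carry the feature extraction, so a minimal end-to-end example may help: bandpass the raw signal around the baseline EOD frequency, then take the envelope to expose the amplitude dip of a chirp. The synthetic signal and the 595-605 Hz band are invented for the example; the 25 Hz envelope cutoff mirrors baseline_envelope_cutoff in the new config.

import numpy as np
from scipy.signal import butter, sosfiltfilt

def bandpass_filter(signal, samplerate, lowf, highf):
    sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos")
    return sosfiltfilt(sos, signal)

def envelope(signal, samplerate, cutoff_frequency):
    sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos")
    return np.sqrt(2) * sosfiltfilt(sos, np.abs(signal))

samplerate = 20000
t = np.arange(0, 2, 1 / samplerate)
# a 600 Hz "EOD" with a brief amplitude dip at t = 1 s, plus a 750 Hz neighbour
eod = np.sin(2 * np.pi * 600 * t) * (1 - 0.6 * np.exp(-((t - 1) ** 2) / 1e-4))
raw = eod + 0.5 * np.sin(2 * np.pi * 750 * t)

baseline = bandpass_filter(raw, samplerate, lowf=595, highf=605)
baseline_envelope = envelope(baseline, samplerate, cutoff_frequency=25)

# the envelope is clearly lower at the dip than away from it
print(round(float(baseline_envelope[int(0.5 * samplerate)]), 2),
      round(float(baseline_envelope[int(1.0 * samplerate)]), 2))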
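
Finally, chirp candidates only survive if peaks from all three feature traces coincide; group_timestamps does this by pooling the peak times and averaging groups that lie within chirp_window_threshold of each other. The sketch below re-implements that idea for illustration only; the pooling-and-sorting step and the example numbers are assumptions, not a copy of the module.

import numpy as np

def group_timestamps_sketch(sublists, at_least_in, difference_threshold):
    # pool all peak times and sort them (assumed preprocessing)
    timestamps = np.sort(np.concatenate([np.asarray(s) for s in sublists]))
    if len(timestamps) == 0:
        return []
    groups, current = [], [timestamps[0]]
    for ts in timestamps[1:]:
        if ts - current[-1] < difference_threshold:
            current.append(ts)
        else:
            groups.append(current)
            current = [ts]
    groups.append(current)
    # keep groups with enough members and return their mean time
    return [float(np.mean(g)) for g in groups if len(g) >= at_least_in]

baseline_peaks = [10.01, 42.00]
search_peaks = [10.02, 55.30]
frequency_peaks = [10.03, 55.31]
print(group_timestamps_sketch(
    [baseline_peaks, search_peaks, frequency_peaks],
    at_least_in=3,
    difference_threshold=0.05,
))  # one grouped chirp near 10.02 s; 42.0 and 55.3 are dropped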