diff --git a/code/chirpdetection.py b/code/chirpdetection.py index ca92f40..900e251 100644 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -10,41 +10,7 @@ from thunderfish.dataloader import DataLoader from thunderfish.powerspectrum import spectrogram, decibel from modules.filters import bandpass_filter, envelope, highpass_filter - - -class LoadData: - """ - - Attributes - ---------- - data : DataLoader object containing raw data - samplerate : sampling rate of raw data - time : array of time for tracked fundamental frequency - freq : array of fundamental frequency - idx : array of indices to access time array - ident : array of identifiers for each tracked fundamental frequency - ids : array of unique identifiers exluding NaNs - """ - - def __init__(self, datapath: str) -> None: - - # load raw data - self.file = os.path.join(datapath, "traces-grid1.raw") - self.data = DataLoader(self.file, 60.0, 0, channel=-1) - self.samplerate = self.data.samplerate - - # load wavetracker files - self.time = np.load(datapath + "times.npy", allow_pickle=True) - self.freq = np.load(datapath + "fund_v.npy", allow_pickle=True) - self.idx = np.load(datapath + "idx_v.npy", allow_pickle=True) - self.ident = np.load(datapath + "ident_v.npy", allow_pickle=True) - self.ids = np.unique(self.ident[~np.isnan(self.ident)]) - - def __repr__(self) -> str: - return f"LoadData({self.file})" - - def __str__(self) -> str: - return f"LoadData({self.file})" +from modules.filehandling import ConfLoader def instantaneos_frequency( @@ -179,9 +145,12 @@ def main(datapath: str) -> None: idx = np.load(datapath + "idx_v.npy", allow_pickle=True) ident = np.load(datapath + "ident_v.npy", allow_pickle=True) + # load config file + config = ConfLoader("chirpdetector_conf.yml") + # set time window # <------------------------ Iterate through windows here - window_duration = 5 * data.samplerate # 5 seconds window - window_overlap = 0.5 * data.samplerate # 0.5 seconds overlap + window_duration = 
config.window * data.samplerate + window_overlap = config.overlap * data.samplerate # check if window duration is even if window_duration % 2 == 0: @@ -196,6 +165,7 @@ def main(datapath: str) -> None: raise ValueError("Window overlap must be even.") raw_time = np.arange(data.shape[0]) / data.samplerate + # good chirp times for data: 2022-06-02-10_00 t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.samplerate dt = 60 * data.samplerate @@ -242,7 +212,7 @@ def main(datapath: str) -> None: # get tracked frequencies and their times freq_temp = freq[window_index] powers_temp = powers[window_index, :] - time_temp = time[idx[window_index]] + # time_temp = time[idx[window_index]] track_samplerate = np.mean(1 / np.diff(time)) expected_duration = ((t0 + dt) - t0) * track_samplerate @@ -252,7 +222,6 @@ def main(datapath: str) -> None: # get best electrode electrode = np.argsort(np.nanmean(powers_temp, axis=0))[-1] - # <------------------------------------------ Iterate through electrodes # plot wavetracker tracks to spectrogram @@ -289,52 +258,68 @@ def main(datapath: str) -> None: lowf=np.mean(freq_temp)-5, highf=np.mean(freq_temp)+100 ) + # compute instantaneous frequency on narrow signal baseline_freq_time, baseline_freq = instantaneos_frequency( baseline, data.samplerate ) # compute envelopes - cutoff = 25 - baseline_envelope = envelope(baseline, data.samplerate, cutoff) - search_envelope = envelope(search, data.samplerate, cutoff) + baseline_envelope = envelope( + baseline, data.samplerate, config.envelope_cutoff) + search_envelope = envelope( + search, data.samplerate, config.envelope_cutoff) # highpass filter envelopes - cutoff = 5 baseline_envelope = highpass_filter( - baseline_envelope, data.samplerate, cutoff=cutoff + baseline_envelope, + data.samplerate, + config.envelope_highpass_cutoff ) + baseline_envelope = np.abs(baseline_envelope) # search_envelope = highpass_filter( - # search_envelope, data.samplerate, cutoff=cutoff) + # search_envelope, + # data.samplerate, + # 
config.envelope_highpass_cutoff + # ) # envelopes of filtered envelope of filtered baseline - # baseline_envelope = envelope( - # np.abs(baseline_envelope), data.samplerate, cutoff - # ) + baseline_envelope = envelope( + np.abs(baseline_envelope), + data.samplerate, + config.envelope_envelope_cutoff + ) - # search_envelope = bandpass_filter( - # search_envelope, data.samplerate, lowf=lowf, highf=highf) +# search_envelope = bandpass_filter( +# search_envelope, data.samplerate, lowf=lowf, highf=highf) # bandpass filter the instantaneous inst_freq_filtered = bandpass_filter( - baseline_freq, data.samplerate, lowf=15, highf=8000 + baseline_freq, + data.samplerate, + lowf=config.instantaneous_lowf, + highf=config.instantaneous_highf ) + # test taking the log of the envelopes + # baseline_envelope = np.log(baseline_envelope) + # search_envelope = np.log(search_envelope) + # CUT OFF OVERLAP ------------------------------------------------- # cut off first and last 0.5 * overlap at start and end valid = np.arange( - int(0.5 * window_overlap), len(baseline_envelope) - - int(0.5 * window_overlap) + int(window_overlap / 2), len(baseline_envelope) - + int(window_overlap / 2) ) baseline_envelope = baseline_envelope[valid] search_envelope = search_envelope[valid] # get inst freq valid snippet - valid_t0 = int(0.5 * window_overlap) / data.samplerate + valid_t0 = int(window_overlap / 2) / data.samplerate valid_t1 = baseline_freq_time[-1] - \ - (int(0.5 * window_overlap) / data.samplerate) + (int(window_overlap / 2) / data.samplerate) inst_freq_filtered = inst_freq_filtered[(baseline_freq_time >= valid_t0) & ( baseline_freq_time <= valid_t1)] @@ -354,24 +339,28 @@ def main(datapath: str) -> None: # PEAK DETECTION -------------------------------------------------- # detect peaks baseline_enelope - prominence = np.percentile(baseline_envelope, 90) + prominence = np.percentile( + baseline_envelope, config.baseline_prominence_percentile) baseline_peaks, _ = find_peaks( 
np.abs(baseline_envelope), prominence=prominence) # detect peaks search_envelope - prominence = np.percentile(search_envelope, 75) + prominence = np.percentile( + search_envelope, config.search_prominence_percentile) search_peaks, _ = find_peaks( search_envelope, prominence=prominence) # detect peaks inst_freq_filtered - prominence = 2 + prominence = np.percentile( + inst_freq_filtered, config.instantaneous_prominence_percentile) inst_freq_peaks, _ = find_peaks( np.abs(inst_freq_filtered), prominence=prominence) # PLOT ------------------------------------------------------------ # plot spectrogram - plot_spectrogram(axs[0, i], data_oi[:, electrode], data.samplerate, t0) + plot_spectrogram( + axs[0, i], data_oi[:, electrode], data.samplerate, t0) # plot baseline instantaneos frequency axs[1, i].plot(baseline_freq_time, baseline_freq - diff --git a/code/chirpdetector_conf.yml b/code/chirpdetector_conf.yml new file mode 100644 index 0000000..55a6d08 --- /dev/null +++ b/code/chirpdetector_conf.yml @@ -0,0 +1,32 @@ +# Duration and overlap of the analysis window in seconds +window: 5 +overlap: 0.5 + +# Number of electrodes to go over +electrodes: 3 + +# Boundary for search frequency in Hz +search_boundary: 100 + +# Cutoff frequency for envelope estimation by lowpass filter +envelope_cutoff: 25 + +# Cutoff frequency for envelope highpass filter +envelope_highpass_cutoff: 5 + +# Cutoff frequency for envelope of envelope +envelope_envelope_cutoff: 5 + +# Instantaneous frequency bandpass filter cutoff frequencies +instantaneous_lowf: 15 +instantaneous_highf: 8000 + +# Baseline envelope peak detection parameters +baseline_prominence_percentile: 90 + +# Search envelope peak detection parameters +search_prominence_percentile: 75 + +# Instantaneous frequency peak detection parameters +instantaneous_prominence_percentile: 90 + diff --git a/code/modules/filehandling.py b/code/modules/filehandling.py new file mode 100644 index 0000000..8a74fa4 --- /dev/null +++ 
b/code/modules/filehandling.py
@@ -0,0 +1,89 @@
+import os
+
+import yaml
+import numpy as np
+from thunderfish.dataloader import DataLoader
+
+
+class ConfLoader:
+    """
+    Load configuration from a yaml file as instance attributes.
+
+    Every top-level key of the yaml mapping becomes an attribute of the
+    instance, e.g. ``window: 5`` is available as ``conf.window``.
+
+    Parameters
+    ----------
+    path : path to the yaml configuration file
+
+    Raises
+    ------
+    yaml.YAMLError : if the file is not valid yaml
+    TypeError : if the file is empty or its top level is not a mapping
+    """
+
+    def __init__(self, path: str) -> None:
+        with open(path, encoding="utf-8") as file:
+            conf = yaml.safe_load(file)
+
+        # safe_load returns None for an empty file and a list or scalar
+        # for non-mapping documents; only a mapping has keys to expose
+        if not isinstance(conf, dict):
+            raise TypeError(
+                f"{path} must contain a top-level mapping, "
+                f"got {type(conf).__name__}"
+            )
+
+        for key, value in conf.items():
+            setattr(self, key, value)
+
+
+class LoadData:
+    """
+    Load the raw recording and the wavetracker files of one dataset.
+
+    Parameters
+    ----------
+    datapath : directory containing "traces-grid1.raw", "times.npy",
+        "fund_v.npy", "idx_v.npy" and "ident_v.npy"
+
+    Attributes
+    ----------
+    file : path to the raw data file
+    data : DataLoader object containing raw data
+    samplerate : sampling rate of raw data
+    time : array of time for tracked fundamental frequency
+    freq : array of fundamental frequency
+    idx : array of indices to access time array
+    ident : array of identifiers for each tracked fundamental frequency
+    ids : array of unique identifiers excluding NaNs
+    """
+
+    def __init__(self, datapath: str) -> None:
+
+        # load raw data
+        # NOTE(review): 60.0 and 0 are presumably the buffer and back
+        # size of the DataLoader -- confirm against thunderfish
+        self.file = os.path.join(datapath, "traces-grid1.raw")
+        self.data = DataLoader(self.file, 60.0, 0, channel=-1)
+        self.samplerate = self.data.samplerate
+
+        # load wavetracker files
+        self.time = self._load(datapath, "times.npy")
+        self.freq = self._load(datapath, "fund_v.npy")
+        self.idx = self._load(datapath, "idx_v.npy")
+        self.ident = self._load(datapath, "ident_v.npy")
+        self.ids = np.unique(self.ident[~np.isnan(self.ident)])
+
+    @staticmethod
+    def _load(datapath: str, filename: str) -> np.ndarray:
+        """Load one wavetracker array from the dataset directory."""
+        # joined like the raw file path so that datapath works with or
+        # without a trailing separator
+        # NOTE: allow_pickle=True can execute arbitrary code on load,
+        # only use this on files from a trusted source
+        return np.load(os.path.join(datapath, filename), allow_pickle=True)
+
+    def __repr__(self) -> str:
+        # also used by str(), which falls back to __repr__
+        return f"LoadData({self.file})"