from itertools import compress
from dataclasses import dataclass

import numpy as np
from IPython import embed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gr
from scipy.signal import find_peaks
from thunderfish.powerspectrum import spectrogram, decibel

# from sklearn.preprocessing import normalize

from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.filehandling import ConfLoader, LoadData, make_outputdir
from modules.plotstyle import PlotStyle
from modules.logger import makeLogger
from modules.datahandling import (
    flatten,
    purge_duplicates,
    group_timestamps,
    instantaneous_frequency,
    minmaxnorm,
)

logger = makeLogger(__name__)

ps = PlotStyle()


def _track_window_index(
    data: LoadData,
    track_id: float,
    window_start_seconds: float,
    window_duration_seconds: float,
) -> np.ndarray:
    """
    Return the indices of one track's data points inside a time window.

    Parameters
    ----------
    data : LoadData
        Dataset providing the `ident`, `idx` and `time` arrays.
    track_id : float
        Identifier that `data.ident` is compared against.
    window_start_seconds : float
        Start of the window (inclusive).
    window_duration_seconds : float
        Length of the window; the end is inclusive as well.

    Returns
    -------
    np.ndarray
        Integer positions into `data.idx` / `data.ident` / `data.freq`.
    """
    point_times = data.time[data.idx]
    in_window = (
        (data.ident == track_id)
        & (point_times >= window_start_seconds)
        & (point_times <= (window_start_seconds + window_duration_seconds))
    )
    return np.arange(len(data.idx))[in_window]


@dataclass
class ChirpPlotBuffer:
    """
    Buffer to save data that is created in the main detection loop
    and plot it outside the detection loop.
    """

    config: ConfLoader
    t0: float
    dt: float
    track_id: float
    electrode: int
    data: LoadData

    time: np.ndarray
    baseline: np.ndarray
    baseline_envelope_unfiltered: np.ndarray
    baseline_envelope: np.ndarray
    baseline_peaks: np.ndarray
    search_frequency: float
    search: np.ndarray
    search_envelope_unfiltered: np.ndarray
    search_envelope: np.ndarray
    search_peaks: np.ndarray

    frequency_time: np.ndarray
    frequency: np.ndarray
    frequency_filtered: np.ndarray
    frequency_peaks: np.ndarray

    def plot_buffer(self, chirps: np.ndarray, plot: str) -> None:
        """
        Plot the spectrogram, the filtered bands and all three feature
        arrays of the buffered window, with time relative to the window
        start.

        The buffer itself is not modified, so the method can be called
        repeatedly on the same instance.

        Parameters
        ----------
        chirps : np.ndarray
            Chirp times in the same (absolute) time base as `self.t0`.
        plot : str
            "show" displays the figure, "save" writes a pdf and an svg
            into the configured output directory. Any other value only
            builds the figure.
        """
        logger.debug("Starting plotting")

        t0_abs = self.t0
        rate = self.data.raw_rate
        padding = 5  # seconds of raw data shown before and after the window

        # get tracked frequencies of this track in this time window
        window_idx = _track_window_index(
            self.data, self.track_id, t0_abs, self.dt
        )
        freq_temp = self.data.freq[window_idx]

        # remake the band we filtered in
        q50 = np.percentile(freq_temp, 50)
        half_band = self.config.minimal_bandwidth / 2
        search_upper = q50 + self.search_frequency + half_band
        search_lower = q50 + self.search_frequency - half_band
        logger.debug(f"Search band: {search_lower} - {search_upper}")

        # Indices on the raw data. They must be integers, and the start
        # must not become negative (which would wrap around to the end of
        # the recording) for windows starting less than `padding` seconds
        # into the recording.
        start_idx = max(int((t0_abs - padding) * rate), 0)
        stop_idx = int((t0_abs + self.dt + padding) * rate)
        data_oi = self.data.raw[start_idx:stop_idx, self.electrode]

        # start of the raw snippet relative to the window start
        # (equals -padding unless the start index was clamped)
        spec_start = start_idx / rate - t0_abs

        # shift everything so that the window starts at t = 0
        rel_time = self.time - t0_abs
        rel_frequency_time = self.frequency_time - t0_abs
        rel_chirps = np.asarray(chirps) - t0_abs if len(chirps) > 0 else chirps

        fig = plt.figure(figsize=(14 * ps.cm, 18 * ps.cm))

        gs0 = gr.GridSpec(3, 1, figure=fig, height_ratios=[1, 1, 1])
        gs1 = gs0[0].subgridspec(1, 1)
        gs2 = gs0[1].subgridspec(3, 1, hspace=0.4)
        gs3 = gs0[2].subgridspec(3, 1, hspace=0.4)

        # the bottom axis is created first, all others share its x axis
        ax6 = fig.add_subplot(gs3[2, 0])
        ax0 = fig.add_subplot(gs1[0, 0], sharex=ax6)
        ax1 = fig.add_subplot(gs2[0, 0], sharex=ax6)
        ax2 = fig.add_subplot(gs2[1, 0], sharex=ax6)
        ax3 = fig.add_subplot(gs2[2, 0], sharex=ax6)
        ax4 = fig.add_subplot(gs3[0, 0], sharex=ax6)
        ax5 = fig.add_subplot(gs3[1, 0], sharex=ax6)

        waveform_scaler = 1000
        lw = 1.5

        # plot spectrogram
        _ = plot_spectrogram(
            ax0,
            data_oi,
            rate,
            spec_start,
            [np.min(self.frequency) - 300, np.max(self.frequency) + 300],
        )
        ax0.set_ylim(
            np.min(self.frequency) - 100, np.max(self.frequency) + 200
        )

        # plot the frequency tracks of all fish, the current one highlighted
        for track_id in self.data.ids:
            track_idx = _track_window_index(
                self.data, track_id, t0_abs - padding, self.dt + 2 * padding
            )
            f = self.data.freq[track_idx]
            t = self.data.time[self.data.idx[track_idx]]
            color = ps.gblue1 if track_id == self.track_id else ps.black
            ax0.plot(t - t0_abs, f, lw=lw, zorder=10, color=color)

        # mark the baseline band and the search band
        ax0.axhline(q50 - half_band, color=ps.gblue1, lw=1, ls="dashed")
        ax0.axhline(q50 + half_band, color=ps.gblue1, lw=1, ls="dashed")
        ax0.axhline(search_lower, color=ps.gblue2, lw=1, ls="dashed")
        ax0.axhline(search_upper, color=ps.gblue2, lw=1, ls="dashed")

        # mark detected chirps
        if len(rel_chirps) > 0:
            ax0.scatter(
                rel_chirps,
                np.full(len(rel_chirps), np.median(self.frequency)),
                c=ps.red,
                marker=".",
                edgecolors=ps.black,
                zorder=10,
                s=70,
            )

        # plot waveform of filtered signal
        ax1.plot(
            rel_time,
            self.baseline * waveform_scaler,
            c=ps.gray,
            lw=lw,
            alpha=0.5,
        )
        ax1.plot(
            rel_time,
            self.baseline_envelope_unfiltered * waveform_scaler,
            c=ps.gblue1,
            lw=lw,
            label="baseline envelope",
        )

        # plot waveform of filtered search signal
        ax2.plot(
            rel_time,
            self.search * waveform_scaler,
            c=ps.gray,
            lw=lw,
            alpha=0.5,
        )
        ax2.plot(
            rel_time,
            self.search_envelope_unfiltered * waveform_scaler,
            c=ps.gblue2,
            lw=lw,
            label="search envelope",
        )

        # plot baseline instantaneous frequency
        ax3.plot(
            rel_frequency_time,
            self.frequency,
            c=ps.gblue3,
            lw=lw,
            label="baseline inst. freq.",
        )

        peak_style = dict(
            edgecolors=ps.black,
            facecolors=ps.red,
            zorder=10,
            marker=".",
            s=70,
        )

        # plot filtered and rectified envelope
        ax4.plot(rel_time, self.baseline_envelope, c=ps.gblue1, lw=lw)
        ax4.scatter(
            rel_time[self.baseline_peaks],
            self.baseline_envelope[self.baseline_peaks],
            **peak_style,
        )

        # plot envelope of search signal
        ax5.plot(rel_time, self.search_envelope, c=ps.gblue2, lw=lw)
        ax5.scatter(
            rel_time[self.search_peaks],
            self.search_envelope[self.search_peaks],
            **peak_style,
        )

        # plot filtered instantaneous frequency
        ax6.plot(
            rel_frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw
        )
        ax6.scatter(
            rel_frequency_time[self.frequency_peaks],
            self.frequency_filtered[self.frequency_peaks],
            **peak_style,
        )

        ax0.set_ylabel("Frequency [Hz]")
        ax1.set_ylabel(r"$\mu$V")
        ax2.set_ylabel(r"$\mu$V")
        ax3.set_ylabel("Hz")
        ax4.set_ylabel(r"$\mu$V")
        ax5.set_ylabel(r"$\mu$V")
        ax6.set_ylabel("Hz")
        ax6.set_xlabel("Time [s]")

        # only the bottom axis shows tick labels
        for axis in (ax0, ax1, ax2, ax3, ax4, ax5):
            plt.setp(axis.get_xticklabels(), visible=False)

        ax0.set_xlim(0, self.config.window)
        plt.subplots_adjust(
            left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2
        )
        fig.align_labels()

        if plot == "show":
            plt.show()
        elif plot == "save":
            make_outputdir(self.config.outputdir)
            out = make_outputdir(
                self.config.outputdir
                + self.data.datapath.split("/")[-2]
                + "/"
            )
            plt.savefig(f"{out}{self.track_id}_{t0_abs}.pdf")
            plt.savefig(f"{out}{self.track_id}_{t0_abs}.svg")
            plt.close()


def plot_spectrogram(
    axis,
    signal: np.ndarray,
    samplerate: float,
    window_start_seconds: float,
    ylims: list[float],
) -> np.ndarray:
    """
    Plot a spectrogram of a signal.

    Parameters
    ----------
    axis : matplotlib axis
        Axis to plot the spectrogram on.
    signal : np.ndarray
        Signal to plot the spectrogram from.
    samplerate : float
        Samplerate of the signal.
    window_start_seconds : float
        Start time of the signal; added to the spectrogram times to place
        the image on the x axis.
    ylims : list[float]
        Lower and upper frequency limit; only spectrogram rows strictly
        between the two values are drawn.

    Returns
    -------
    np.ndarray
        Time axis of the spectrogram as returned by `spectrogram`
        (not shifted by `window_start_seconds`).
    """
    logger.debug("Plotting spectrogram")

    # compute spectrogram
    spec_power, spec_freqs, spec_times = spectrogram(
        signal,
        ratetime=samplerate,
        freq_resolution=10,
        overlap_frac=0.5,
    )

    # restrict the image to the requested frequency range
    fmask = (spec_freqs > ylims[0]) & (spec_freqs < ylims[1])
    visible_freqs = spec_freqs[fmask]

    axis.imshow(
        decibel(spec_power[fmask, :]),
        extent=[
            spec_times[0] + window_start_seconds,
            spec_times[-1] + window_start_seconds,
            visible_freqs[0],
            visible_freqs[-1],
        ],
        aspect="auto",
        origin="lower",
        interpolation="gaussian",
    )

    return spec_times


def extract_frequency_bands(
    raw_data: np.ndarray,
    samplerate: int,
    baseline_track: np.ndarray,
    searchband_center: float,
    minimal_bandwidth: float,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Apply a bandpass filter to the baseline of a signal and a second
    bandpass filter above or below the baseline, as specified by the
    search frequency.

    Parameters
    ----------
    raw_data : np.ndarray
        Data to apply the filter to.
    samplerate : int
        Samplerate of the signal.
    baseline_track : np.ndarray
        Tracked fundamental frequencies of the signal.
    searchband_center: float
        Frequency to search for above or below the baseline, given as an
        offset to the median of `baseline_track`.
    minimal_bandwidth : float
        Minimal bandwidth of the filter.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The baseline-band signal and the search-band signal.
    """
    # compute boundaries to filter baseline
    q25, q50, q75 = np.percentile(baseline_track, [25, 50, 75])
    half_band = minimal_bandwidth / 2

    # if the interquartile range is too narrow, fall back to a band of
    # minimal bandwidth centered on the median
    if q75 - q25 < 10:
        q25, q75 = q50 - half_band, q50 + half_band

    # filter baseline
    filtered_baseline = bandpass_filter(
        raw_data, samplerate, lowf=q25, highf=q75
    )

    # filter search area
    filtered_search_freq = bandpass_filter(
        raw_data,
        samplerate,
        lowf=searchband_center + q50 - half_band,
        highf=searchband_center + q50 + half_band,
    )

    return filtered_baseline, filtered_search_freq


def window_median_all_track_ids(
    data: LoadData, window_start_seconds: float, window_duration_seconds: float
) -> tuple[np.ndarray, np.ndarray]:
    """
    Calculate the median and quantiles of the frequency of all fish in a
    given time window.

    Iterate over all track ids and calculate the 25, 50 and 75 percentile
    in a given time window to pass this data to the 'find_searchband'
    function, which then determines whether other fish in the current
    window fall within the searchband of the current fish.

    Parameters
    ----------
    data : LoadData
        Data to calculate the median frequency from.
    window_start_seconds : float
        Start time of the window.
    window_duration_seconds : float
        Duration of the window.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        One row of [q25, q50, q75] per track that has data in the window,
        and the matching track ids in the same order.
    """
    frequency_percentiles = []
    track_ids = []

    for track_id in np.unique(data.ident[~np.isnan(data.ident)]):
        # the window index combines the track id and the time window
        window_idx = _track_window_index(
            data, track_id, window_start_seconds, window_duration_seconds
        )
        window_freqs = data.freq[window_idx]

        # tracks without data in this window are left out
        if len(window_freqs) > 0:
            frequency_percentiles.append(
                np.percentile(window_freqs, [25, 50, 75])
            )
            track_ids.append(track_id)

    return np.asarray(frequency_percentiles), np.asarray(track_ids)


def array_center(array: np.ndarray) -> float:
    """
    Return the center value of an array.
    If the array length is even, returns
    the mean of the two center values.

    Parameters
    ----------
    array : np.ndarray
        Array to calculate the center from.

    Returns
    -------
    float
    """
    middle = len(array) // 2
    if len(array) % 2 == 0:
        return np.mean(array[middle - 1 : middle + 1])
    return array[middle]


def has_chirp(baseline_frequency: np.ndarray, peak_height: float) -> bool:
    """
    Check if a fish has a chirp.

    Parameters
    ----------
    baseline_frequency : np.ndarray
        Baseline frequency of the fish.
    peak_height : float
        Minimal peak height of a chirp on the instant. freq.

    Returns
    -------
    bool
        True if at least one peak of the required height is found,
        False otherwise.
    """
    peaks, _ = find_peaks(baseline_frequency, height=peak_height)
    return len(peaks) > 0


def mask_low_amplitudes(envelope, threshold):
    """
    Mask low amplitudes in the envelope.

    Parameters
    ----------
    envelope : np.ndarray
        Envelope of the signal.
    threshold : float
        Threshold to mask low amplitudes.

    Returns
    -------
    np.ndarray
        Boolean array that is False where the envelope is below the
        threshold and True everywhere else (including NaN samples).
    """
    # negating "<" instead of using ">=" keeps NaN samples unmasked
    return ~(np.asarray(envelope) < threshold)


def find_searchband(
    current_frequency: np.ndarray,
    percentiles_ids: np.ndarray,
    frequency_percentiles: np.ndarray,
    config: ConfLoader,
    data: LoadData,
) -> float:
    """
    Find the search frequency band for each fish by checking which fish
    EODs are above the current EOD and finding a gap in them.

    Parameters
    ----------
    current_frequency : np.ndarray
        Current EOD frequency array / the current fish of interest.
    percentiles_ids : np.ndarray
        Array of track IDs of the medians of all other fish in the current
        window.
    frequency_percentiles : np.ndarray
        Array of percentiles frequencies of all other fish in the current
        window.
    config : ConfLoader
        Configuration file.
    data : LoadData
        Data to find the search frequency from (currently unused).

    Returns
    -------
    float
        Offset of the search band center relative to the median of
        `current_frequency`. `config.default_search_freq` is returned if
        the search window is either completely free or completely
        occupied.
    """
    # frequency window where second filter filters is potentially allowed
    # to filter. This is the search window, in which we want to find
    # a gap in the other fish's EODs.
    current_median = np.median(current_frequency)
    search_window = np.arange(
        current_median + config.search_df_lower,
        current_median + config.search_df_upper,
        config.search_res,
    )

    # True marks frequencies that are free of other fish
    search_window_bool = np.ones_like(search_window, dtype=bool)

    # make separate arrays from the quartiles
    q25 = np.asarray([i[0] for i in frequency_percentiles])
    q75 = np.asarray([i[2] for i in frequency_percentiles])

    # get tracks that fall into search window
    check_track_ids = percentiles_ids[
        (q25 > current_median) & (q75 < search_window[-1])
    ]

    # Mask the band occupied by each of these tracks. The band is computed
    # per track, so that free space between two neighbouring fish stays
    # available as a gap.
    for check_track_id in check_track_ids:
        is_track = percentiles_ids == check_track_id
        occupied = (search_window > q25[is_track] - config.search_res) & (
            search_window < q75[is_track] + config.search_res
        )
        search_window_bool[occupied] = False

    # without any transition between free and occupied there is no gap to
    # choose from, so fall back to the default
    if search_window_bool.all() or not search_window_bool.any():
        return config.default_search_freq

    # Find runs of True. Padding with False on both sides makes runs that
    # touch the borders of the search window show up as transitions too:
    # +1 marks the first index of a run, -1 the index after its last one.
    padded = np.concatenate(([False], search_window_bool, [False]))
    transitions = np.diff(padded.astype(np.int8))
    starts = np.flatnonzero(transitions == 1)
    stops = np.flatnonzero(transitions == -1)

    # get the frequency ranges of the gaps
    search_windows = [search_window[x:y] for x, y in zip(starts, stops)]
    search_windows_lens = [len(x) for x in search_windows]
    longest_search_window = search_windows[np.argmax(search_windows_lens)]

    # the center of the search frequency band is then the center of
    # the longest gap
    return array_center(longest_search_window) - current_median


def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
    """
    Detect chirps for all tracked fish of one recording and save them.

    The recording is processed in overlapping windows. For every fish and
    every window, chirps are detected on the strongest electrodes and kept
    if they appear on enough electrodes. Chirp times and the matching
    track ids are written to 'chirps.npy' and 'chirp_ids.npy' in
    `datapath`.

    Parameters
    ----------
    datapath : str
        Path of the recording; also used as prefix of the output files.
    plot : str
        'save', 'show' or 'false'.
    debug : str
        'false', 'electrode' (plot every electrode) or 'fish' (plot every
        fish and window). Debug modes require plot == 'show'.
    """
    assert plot in [
        "save",
        "show",
        "false",
    ], "plot must be 'save', 'show' or 'false'"

    assert debug in [
        "false",
        "electrode",
        "fish",
    ], "debug must be 'false', 'electrode' or 'fish'"

    if debug != "false":
        assert plot == "show", "debug mode only runs when plot is 'show'"

    # load raw file
    logger.info(f"datapath {datapath}")
    data = LoadData(datapath)

    # load config file
    config = ConfLoader("chirpdetector_conf.yml")

    # set time window
    window_duration = config.window * data.raw_rate
    window_overlap = config.overlap * data.raw_rate
    window_edge = int(config.edge * data.raw_rate)

    # check if window duration and window overlap is even, otherwise the
    # half of the duration or window overlap would return a float, thus an
    # invalid index
    if window_duration % 2 == 0:
        window_duration = int(window_duration)
    else:
        raise ValueError("Window duration must be even.")

    if window_overlap % 2 == 0:
        window_overlap = int(window_overlap)
    else:
        raise ValueError("Window overlap must be even.")

    # make time array for raw data
    raw_time = np.arange(data.raw.shape[0]) / data.raw_rate

    window_start_index = 0
    window_duration_index = data.raw.shape[0]

    # generate starting points of rolling window
    window_start_indices = np.arange(
        window_start_index,
        window_start_index + window_duration_index,
        window_duration - (window_overlap + 2 * window_edge),
        dtype=int,
    )

    # the set of tracks does not change between windows
    all_track_ids = np.unique(data.ident[~np.isnan(data.ident)])

    # initialize lists to store data
    multiwindow_chirps = []
    multiwindow_ids = []

    for st, window_start_index in enumerate(window_start_indices):
        logger.info(f"Processing window {st+1} of {len(window_start_indices)}")

        window_start_seconds = window_start_index / data.raw_rate
        window_duration_seconds = window_duration / data.raw_rate

        # set index window
        window_stop_index = window_start_index + window_duration

        # calculate median of fish frequencies in window
        median_freq, median_ids = window_median_all_track_ids(
            data, window_start_seconds, window_duration_seconds
        )

        # iterate through all fish
        for tr, track_id in enumerate(all_track_ids):
            logger.debug(f"Processing track {tr} of {len(data.ids)}")

            # get index of track data in this time window
            track_window_index = _track_window_index(
                data, track_id, window_start_seconds, window_duration_seconds
            )

            # get tracked frequencies and their powers
            current_frequencies = data.freq[track_window_index]
            current_powers = data.powers[track_window_index, :]

            # check if tracked data available in this window
            if len(current_frequencies) < 3:
                logger.warning(
                    f"Track {track_id} has no data in window {st}, skipping."
                )
                continue

            # check if there are powers available in this window
            if np.isnan(current_powers).all():
                logger.warning(
                    f"No powers available for track {track_id} window {st}, "
                    "skipping."
                )
                continue

            # find the strongest electrodes for the current fish in the
            # current window
            best_electrode_index = np.argsort(
                np.nanmean(current_powers, axis=0)
            )[-config.number_electrodes :]

            # find a frequency above the baseline of the current fish in
            # which no other fish is active to search for chirps there
            search_frequency = find_searchband(
                config=config,
                current_frequency=current_frequencies,
                percentiles_ids=median_ids,
                data=data,
                frequency_percentiles=median_freq,
            )

            # add all chirps that are detected on multiple electrodes for
            # one fish in one window to this list
            multielectrode_chirps = []

            # plotting buffer of the most recent electrode that reached
            # the end of the detection; None if there is nothing to plot
            buffer = None

            # iterate through electrodes
            for el, electrode_index in enumerate(best_electrode_index):
                logger.debug(
                    f"Processing electrode {el+1} of "
                    f"{len(best_electrode_index)}"
                )

                # LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ----------

                # load region of interest of raw data file
                current_raw_data = data.raw[
                    window_start_index:window_stop_index, electrode_index
                ]
                current_raw_time = raw_time[
                    window_start_index:window_stop_index
                ]

                # EXTRACT FEATURES ------------------------------------------

                # filter baseline and above
                baselineband, searchband = extract_frequency_bands(
                    raw_data=current_raw_data,
                    samplerate=data.raw_rate,
                    baseline_track=current_frequencies,
                    searchband_center=search_frequency,
                    minimal_bandwidth=config.minimal_bandwidth,
                )

                # compute envelope of baseline band to find dips
                # in the baseline envelope
                baseline_envelope_unfiltered = envelope(
                    signal=baselineband,
                    samplerate=data.raw_rate,
                    cutoff_frequency=config.baseline_envelope_cutoff,
                )

                # create a mask that removes areas where amplitudes are
                # very low because the instantaneous frequency is not
                # reliable there
                amplitude_mask = mask_low_amplitudes(
                    baseline_envelope_unfiltered, config.baseline_min_amplitude
                )

                # bandpass filter baseline envelope to remove slower
                # fluctuations e.g. due to motion
                baseline_envelope = bandpass_filter(
                    signal=baseline_envelope_unfiltered,
                    samplerate=data.raw_rate,
                    lowf=config.baseline_envelope_bandpass_lowf,
                    highf=config.baseline_envelope_bandpass_highf,
                )

                # invert baseline envelope to find troughs in the baseline
                baseline_envelope = -baseline_envelope

                # compute the envelope of the search band. Peaks in the
                # search band envelope correspond to troughs in the
                # baseline envelope during chirps
                search_envelope_unfiltered = envelope(
                    signal=searchband,
                    samplerate=data.raw_rate,
                    cutoff_frequency=config.search_envelope_cutoff,
                )
                search_envelope = search_envelope_unfiltered

                # compute instantaneous frequency of the baseline band to
                # find anomalies during a chirp, i.e. a frequency jump
                # upwards or sometimes downwards. We do not fully
                # understand why the instantaneous frequency can also jump
                # downwards during a chirp. This phenomenon is only
                # observed on chirps on a narrow filtered baseline such as
                # the one we are working with.
                baseline_frequency = instantaneous_frequency(
                    baselineband,
                    data.raw_rate,
                    config.baseline_frequency_smoothing,
                )

                # Take the absolute of the instantaneous frequency to
                # invert troughs into peaks. This is necessary since the
                # narrow pass band introduces these anomalies. Also
                # subtract the median to set it to 0.
                baseline_frequency_filtered = np.abs(
                    baseline_frequency - np.median(baseline_frequency)
                )

                # check if there is at least one superthreshold peak on
                # the instantaneous frequency and skip the electrode if
                # not. This prevents windows that definitely do not
                # include a chirp from entering normalization, where small
                # changes due to noise would be amplified
                if not has_chirp(
                    baseline_frequency_filtered[amplitude_mask],
                    config.baseline_frequency_peakheight,
                ):
                    continue

                # CUT OFF OVERLAP -------------------------------------------

                # cut off snippet at start and end of each window to remove
                # filter effects. Index arrays (instead of slices) are used
                # on purpose so that every cut array is an independent copy
                no_edges = np.arange(
                    window_edge, len(baseline_envelope) - window_edge
                )
                current_raw_time = current_raw_time[no_edges]
                baselineband = baselineband[no_edges]
                baseline_envelope_unfiltered = baseline_envelope_unfiltered[
                    no_edges
                ]
                searchband = searchband[no_edges]
                baseline_envelope = baseline_envelope[no_edges]
                search_envelope_unfiltered = search_envelope_unfiltered[
                    no_edges
                ]
                search_envelope = search_envelope[no_edges]

                baseline_frequency = baseline_frequency[no_edges]
                baseline_frequency_filtered = baseline_frequency_filtered[
                    no_edges
                ]
                baseline_frequency_time = current_raw_time

                # NORMALIZE -------------------------------------------------

                # normalize all three feature arrays to the same range to
                # make peak detection simpler
                baseline_envelope = minmaxnorm([baseline_envelope])[0]
                search_envelope = minmaxnorm([search_envelope])[0]
                baseline_frequency_filtered = minmaxnorm(
                    [baseline_frequency_filtered]
                )[0]

                # PEAK DETECTION --------------------------------------------

                # detect peaks baseline_envelope
                baseline_peak_indices, _ = find_peaks(
                    baseline_envelope, prominence=config.baseline_prominence
                )
                # detect peaks search_envelope
                search_peak_indices, _ = find_peaks(
                    search_envelope, prominence=config.search_prominence
                )
                # detect peaks inst_freq_filtered
                frequency_peak_indices, _ = find_peaks(
                    baseline_frequency_filtered,
                    prominence=config.frequency_prominence,
                )

                # DETECT CHIRPS IN SEARCH WINDOW ----------------------------

                # get the peak timestamps from the peak indices
                baseline_peak_timestamps = current_raw_time[
                    baseline_peak_indices
                ]
                search_peak_timestamps = current_raw_time[search_peak_indices]
                frequency_peak_timestamps = baseline_frequency_time[
                    frequency_peak_indices
                ]

                # check if one list is empty and if so, skip to the next
                # electrode because a chirp cannot be detected if one is
                # empty
                one_feature_empty = (
                    len(baseline_peak_timestamps) == 0
                    or len(search_peak_timestamps) == 0
                    or len(frequency_peak_timestamps) == 0
                )
                if one_feature_empty and (debug == "false"):
                    continue

                # group peak across feature arrays but only if they
                # occur in all 3 feature arrays
                sublists = [
                    list(baseline_peak_timestamps),
                    list(search_peak_timestamps),
                    list(frequency_peak_timestamps),
                ]
                singleelectrode_chirps = group_timestamps(
                    sublists=sublists,
                    at_least_in=3,
                    difference_threshold=config.chirp_window_threshold,
                )

                # check if there are chirps detected after grouping,
                # continue with the loop if not
                if (len(singleelectrode_chirps) == 0) and (debug == "false"):
                    continue

                # append chirps from this electrode to the multielectrode
                # list
                multielectrode_chirps.append(singleelectrode_chirps)

                # Only fill the plotting buffer if something will be
                # plotted. Debug modes require plot == "show", so they are
                # covered by this check as well.
                if plot in ("show", "save"):
                    logger.debug("Detected chirp, ititialize buffer ...")

                    # save data to Buffer
                    buffer = ChirpPlotBuffer(
                        config=config,
                        t0=window_start_seconds,
                        dt=window_duration_seconds,
                        electrode=electrode_index,
                        track_id=track_id,
                        data=data,
                        time=current_raw_time,
                        baseline_envelope_unfiltered=(
                            baseline_envelope_unfiltered
                        ),
                        baseline=baselineband,
                        baseline_envelope=baseline_envelope,
                        baseline_peaks=baseline_peak_indices,
                        search_frequency=search_frequency,
                        search=searchband,
                        search_envelope_unfiltered=search_envelope_unfiltered,
                        search_envelope=search_envelope,
                        search_peaks=search_peak_indices,
                        frequency_time=baseline_frequency_time,
                        frequency=baseline_frequency,
                        frequency_filtered=baseline_frequency_filtered,
                        frequency_peaks=frequency_peak_indices,
                    )

                    logger.debug("Buffer initialized!")

                    if debug == "electrode":
                        logger.info(f"Plotting electrode {el} ...")
                        buffer.plot_buffer(
                            chirps=singleelectrode_chirps, plot=plot
                        )

            logger.debug(
                f"Processed all electrodes for fish {track_id} for this "
                "window, sorting chirps ..."
            )

            # check if there are chirps detected in multiple electrodes and
            # continue the loop if not
            if (len(multielectrode_chirps) == 0) and (debug == "false"):
                continue

            # validate multielectrode chirps, i.e. check if they are
            # detected in at least 'config.minimum_electrodes' electrodes
            multielectrode_chirps_validated = group_timestamps(
                sublists=multielectrode_chirps,
                at_least_in=config.minimum_electrodes,
                difference_threshold=config.chirp_window_threshold,
            )

            # add validated chirps to the list that tracks chirps across
            # the rolling time windows
            multiwindow_chirps.append(multielectrode_chirps_validated)
            multiwindow_ids.append(track_id)

            logger.info(
                f"Found {len(multielectrode_chirps_validated)}"
                f" chirps for fish {track_id} in this window!"
            )

            # plot the buffered electrode: always in 'fish' debug mode,
            # otherwise only if validated chirps were found
            if buffer is None:
                continue
            if debug == "fish":
                logger.info(f"Plotting fish {track_id} ...")
                buffer.plot_buffer(multielectrode_chirps_validated, plot)
            elif (debug == "false") and (
                len(multielectrode_chirps_validated) > 0
            ):
                buffer.plot_buffer(multielectrode_chirps_validated, plot)

    # flatten list of lists containing chirps and create
    # an array of fish ids that correspond to the chirps
    multiwindow_chirps_flat = []
    multiwindow_ids_flat = []
    multiwindow_ids_array = np.asarray(multiwindow_ids)
    for track_id in np.unique(multiwindow_ids):
        # get chirps for this fish and flatten the list
        current_track_bool = multiwindow_ids_array == track_id
        current_track_chirps = flatten(
            list(compress(multiwindow_chirps, current_track_bool))
        )

        # add flattened chirps to the list
        multiwindow_chirps_flat.extend(current_track_chirps)
        multiwindow_ids_flat.extend(
            list(np.ones_like(current_track_chirps) * track_id)
        )

    # purge duplicates, i.e. chirps that are very close to each other
    # duplicates arise due to overlapping windows
    purged_chirps = []
    purged_ids = []
    for track_id in np.unique(multiwindow_ids_flat):
        tr_chirps = np.asarray(multiwindow_chirps_flat)[
            np.asarray(multiwindow_ids_flat) == track_id
        ]
        if len(tr_chirps) > 0:
            tr_chirps_purged = purge_duplicates(
                tr_chirps, config.chirp_window_threshold
            )
            purged_chirps.extend(list(tr_chirps_purged))
            purged_ids.extend(list(np.ones_like(tr_chirps_purged) * track_id))

    # sort chirps by time
    purged_chirps = np.asarray(purged_chirps)
    purged_ids = np.asarray(purged_ids)
    time_order = np.argsort(purged_chirps)
    purged_ids = purged_ids[time_order]
    purged_chirps = purged_chirps[time_order]

    # save them into the data directory
    np.save(datapath + "chirps.npy", purged_chirps)
    np.save(datapath + "chirp_ids.npy", purged_ids)


if __name__ == "__main__":
    datapath = "../data/2022-06-02-10_00/"
    chirpdetection(datapath, plot="show", debug="false")