From 12b1fdccae2da9691065e7043214563d6a99c57e Mon Sep 17 00:00:00 2001 From: weygoldt <88969563+weygoldt@users.noreply.github.com> Date: Tue, 11 Apr 2023 15:29:57 +0200 Subject: [PATCH 1/5] added chirp inst freq exporation --- chirp_instantaneous_freq/filters.py | 202 +++++++ chirp_instantaneous_freq/fish_signal.py | 557 ++++++++++++++++++++ chirp_instantaneous_freq/test_parameters.py | 118 +++++ requirements.txt | 2 - 4 files changed, 877 insertions(+), 2 deletions(-) create mode 100644 chirp_instantaneous_freq/filters.py create mode 100644 chirp_instantaneous_freq/fish_signal.py create mode 100644 chirp_instantaneous_freq/test_parameters.py diff --git a/chirp_instantaneous_freq/filters.py b/chirp_instantaneous_freq/filters.py new file mode 100644 index 0000000..709b140 --- /dev/null +++ b/chirp_instantaneous_freq/filters.py @@ -0,0 +1,202 @@ +from scipy.signal import butter, sosfiltfilt +from scipy.ndimage import gaussian_filter1d +import numpy as np + + +def instantaneous_frequency( + signal: np.ndarray, + samplerate: int, + smoothing_window: int, +) -> tuple[np.ndarray, np.ndarray]: + """ + Compute the instantaneous frequency of a signal that is approximately + sinusoidal and symmetric around 0. + + Parameters + ---------- + signal : np.ndarray + Signal to compute the instantaneous frequency from. + samplerate : int + Samplerate of the signal. + smoothing_window : int + Window size for the gaussian filter. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + + """ + # calculate instantaneous frequency with zero crossings + roll_signal = np.roll(signal, shift=1) + time_signal = np.arange(len(signal)) / samplerate + period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][ + 1:-1 + ] + + upper_bound = np.abs(signal[period_index]) + lower_bound = np.abs(signal[period_index - 1]) + upper_time = np.abs(time_signal[period_index]) + lower_time = np.abs(time_signal[period_index - 1]) + + # create ratio + lower_ratio = lower_bound / (lower_bound + upper_bound) + + # appy to time delta + time_delta = upper_time - lower_time + true_zero = lower_time + lower_ratio * time_delta + + # create new time array + instantaneous_frequency_time = true_zero[:-1] + 0.5 * np.diff(true_zero) + + # compute frequency + instantaneous_frequency = gaussian_filter1d( + 1 / np.diff(true_zero), smoothing_window + ) + + return instantaneous_frequency_time, instantaneous_frequency + + +def inst_freq(signal, fs): + """ + Computes the instantaneous frequency of a periodic signal using zero-crossings. + + Parameters: + ----------- + signal : array-like + The input signal. + fs : float + The sampling frequency of the input signal. + + Returns: + -------- + freq : array-like + The instantaneous frequency of the input signal. + """ + # Compute the sign of the signal + sign = np.sign(signal) + + # Compute the crossings of the sign signal with a zero line + crossings = np.where(np.diff(sign))[0] + + # Compute the time differences between zero crossings + dt = np.diff(crossings) / fs + + # Compute the instantaneous frequency as the reciprocal of the time differences + freq = 1 / dt + + # Gaussian filter the signal + freq = gaussian_filter1d(freq, 10) + + # Pad the frequency vector with zeros to match the length of the input signal + freq = np.pad(freq, (0, len(signal) - len(freq))) + + return freq + +def bandpass_filter( + signal: np.ndarray, + samplerate: float, + lowf: float, + highf: float, +) -> np.ndarray: + """Bandpass filter a signal. + + Parameters + ---------- + signal : np.ndarray + The data to be filtered + rate : float + The sampling rate + lowf : float + The low cutoff frequency + highf : float + The high cutoff frequency + + Returns + ------- + np.ndarray + The filtered data + """ + sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos") + filtered_signal = sosfiltfilt(sos, signal) + + return filtered_signal + + +def highpass_filter( + signal: np.ndarray, + samplerate: float, + cutoff: float, +) -> np.ndarray: + """Highpass filter a signal. + + Parameters + ---------- + signal : np.ndarray + The data to be filtered + rate : float + The sampling rate + cutoff : float + The cutoff frequency + + Returns + ------- + np.ndarray + The filtered data + """ + sos = butter(2, cutoff, "highpass", fs=samplerate, output="sos") + filtered_signal = sosfiltfilt(sos, signal) + + return filtered_signal + + +def lowpass_filter( + signal: np.ndarray, + samplerate: float, + cutoff: float +) -> np.ndarray: + """Lowpass filter a signal. + + Parameters + ---------- + data : np.ndarray + The data to be filtered + rate : float + The sampling rate + cutoff : float + The cutoff frequency + + Returns + ------- + np.ndarray + The filtered data + """ + sos = butter(2, cutoff, "lowpass", fs=samplerate, output="sos") + filtered_signal = sosfiltfilt(sos, signal) + + return filtered_signal + + +def envelope(signal: np.ndarray, + samplerate: float, + cutoff_frequency: float + ) -> np.ndarray: + """Calculate the envelope of a signal using a lowpass filter. + + Parameters + ---------- + signal : np.ndarray + The signal to calculate the envelope of + samplingrate : float + The sampling rate of the signal + cutoff_frequency : float + The cutoff frequency of the lowpass filter + + Returns + ------- + np.ndarray + The envelope of the signal + """ + sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos") + envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(signal)) + + return envelope diff --git a/chirp_instantaneous_freq/fish_signal.py b/chirp_instantaneous_freq/fish_signal.py new file mode 100644 index 0000000..bf740b6 --- /dev/null +++ b/chirp_instantaneous_freq/fish_signal.py @@ -0,0 +1,557 @@ +import sys +from IPython import embed +import thunderfish.powerspectrum as ps +import numpy as np + +species_name = dict( + Sine="Sinewave", + Alepto="Apteronotus leptorhynchus", + Arostratus="Apteronotus rostratus", + Eigenmannia="Eigenmannia spec.", + Sternarchella="Sternarchella terminalis", + Sternopygus="Sternopygus dariensis", +) +"""Translate species ids used by wavefish_harmonics and pulsefish_eodpeaks to full species names. +""" + + +def abbrv_genus(name): + """Abbreviate genus in a species name. + + Parameters + ---------- + name: string + Full species name of the form 'Genus species subspecies' + + Returns + ------- + name: string + The species name with abbreviated genus, i.e. 'G. species subspecies' + """ + ns = name.split() + return ns[0][0] + ". " + " ".join(ns[1:]) + + +# Amplitudes and phases of various wavefish species: + +Sine_harmonics = dict(amplitudes=(1.0,), phases=(0.5 * np.pi,)) + +Apteronotus_leptorhynchus_harmonics = dict( + amplitudes=(0.90062, 0.15311, 0.072049, 0.012609, 0.011708), + phases=(1.3623, 2.3246, 0.9869, 2.6492, -2.6885), +) + +Apteronotus_rostratus_harmonics = dict( + amplitudes=( + 0.64707, + 0.43874, + 0.063592, + 0.07379, + 0.040199, + 0.023073, + 0.0097678, + ), + phases=(2.2988, 0.78876, -1.316, 2.2416, 2.0413, 1.1022, -2.0513), +) + +Eigenmannia_harmonics = dict( + amplitudes=(1.0087, 0.23201, 0.060524, 0.020175, 0.010087, 0.0080699), + phases=(1.3414, 1.3228, 2.9242, 2.8157, 2.6871, -2.8415), +) + +Sternarchella_terminalis_harmonics = dict( + amplitudes=( + 0.11457, + 0.4401, + 0.41055, + 0.20132, + 0.061364, + 0.011389, + 0.0057985, + ), + phases=(-2.7106, 2.4472, 1.6829, 0.79085, 0.119, -0.82355, -1.9956), +) + +Sternopygus_dariensis_harmonics = dict( + amplitudes=( + 0.98843, + 0.41228, + 0.047848, + 0.11048, + 0.022801, + 0.030706, + 0.019018, + ), + phases=(1.4153, 1.3141, 3.1062, -2.3961, -1.9524, 0.54321, 1.6844), +) + +wavefish_harmonics = dict( + Sine=Sine_harmonics, + Alepto=Apteronotus_leptorhynchus_harmonics, + Arostratus=Apteronotus_rostratus_harmonics, + Eigenmannia=Eigenmannia_harmonics, + Sternarchella=Sternarchella_terminalis_harmonics, + Sternopygus=Sternopygus_dariensis_harmonics, +) +"""Amplitudes and phases of EOD waveforms of various species of wave-type electric fish.""" + + +def wavefish_spectrum(fish): + """Amplitudes and phases of a wavefish EOD. + + Parameters + ---------- + fish: string, dict or tuple of lists/arrays + Specify relative amplitudes and phases of the fundamental and its harmonics. + If string then take amplitudes and phases from the `wavefish_harmonics` dictionary. + If dictionary then take amplitudes and phases from the 'amlitudes' and 'phases' keys. + If tuple then the first element is the list of amplitudes and + the second one the list of relative phases in radians. + + Returns + ------- + amplitudes: array of floats + Amplitudes of the fundamental and its harmonics. + phases: array of floats + Phases in radians of the fundamental and its harmonics. + + Raises + ------ + KeyError: + Unknown fish. + IndexError: + Amplitudes and phases differ in length. + """ + if isinstance(fish, (tuple, list)): + amplitudes = fish[0] + phases = fish[1] + elif isinstance(fish, dict): + amplitudes = fish["amplitudes"] + phases = fish["phases"] + else: + if fish not in wavefish_harmonics: + raise KeyError( + "unknown wavefish. Choose one of " + + ", ".join(wavefish_harmonics.keys()) + ) + amplitudes = wavefish_harmonics[fish]["amplitudes"] + phases = wavefish_harmonics[fish]["phases"] + if len(amplitudes) != len(phases): + raise IndexError("need exactly as many phases as amplitudes") + # remove NaNs: + for k in reversed(range(len(amplitudes))): + if np.isfinite(amplitudes[k]) or np.isfinite(phases[k]): + amplitudes = amplitudes[: k + 1] + phases = phases[: k + 1] + break + return amplitudes, phases + + +def wavefish_eods( + fish="Eigenmannia", + frequency=100.0, + samplerate=44100.0, + duration=1.0, + phase0=0.0, + noise_std=0.05, +): + """Simulate EOD waveform of a wave-type fish. + + The waveform is constructed by superimposing sinewaves of integral + multiples of the fundamental frequency - the fundamental and its + harmonics. The fundamental frequency of the EOD (EODf) is given by + `frequency`. With `fish` relative amplitudes and phases of the + fundamental and its harmonics are specified. + + The generated waveform is `duration` seconds long and is sampled with + `samplerate` Hertz. Gaussian white noise with a standard deviation of + `noise_std` is added to the generated waveform. + + Parameters + ---------- + fish: string, dict or tuple of lists/arrays + Specify relative amplitudes and phases of the fundamental and its harmonics. + If string then take amplitudes and phases from the `wavefish_harmonics` dictionary. + If dictionary then take amplitudes and phases from the 'amlitudes' and 'phases' keys. + If tuple then the first element is the list of amplitudes and + the second one the list of relative phases in radians. + frequency: float or array of floats + EOD frequency of the fish in Hertz. Either fixed number or array for + time-dependent frequencies. + samplerate: float + Sampling rate in Hertz. + duration: float + Duration of the generated data in seconds. Only used if frequency is scalar. + phase0: float + Phase offset of the EOD waveform in radians. + noise_std: float + Standard deviation of additive Gaussian white noise. + + Returns + ------- + data: array of floats + Generated data of a wave-type fish. + + Raises + ------ + KeyError: + Unknown fish. + IndexError: + Amplitudes and phases differ in length. + """ + # get relative amplitude and phases: + amplitudes, phases = wavefish_spectrum(fish) + # compute phase: + if np.isscalar(frequency): + phase = np.arange(0, duration, 1.0 / samplerate) + phase *= frequency + else: + phase = np.cumsum(frequency) / samplerate + # generate EOD: + data = np.zeros(len(phase)) + for har, (ampl, phi) in enumerate(zip(amplitudes, phases)): + if np.isfinite(ampl) and np.isfinite(phi): + data += ampl * np.sin( + 2 * np.pi * (har + 1) * phase + phi - (har + 1) * phase0 + ) + # add noise: + data += noise_std * np.random.randn(len(data)) + return data + + +def normalize_wavefish(fish, mode="peak"): + """Normalize amplitudes and phases of wave-type EOD waveform. + + The amplitudes and phases of the Fourier components are adjusted + such that the resulting EOD waveform has a peak-to-peak amplitude + of two and the peak of the waveform is at time zero (mode is set + to 'peak') or that the fundamental has an amplitude of one and a + phase of 0 (mode is set to 'zero'). + + Parameters + ---------- + fish: string, dict or tuple of lists/arrays + Specify relative amplitudes and phases of the fundamental and its harmonics. + If string then take amplitudes and phases from the `wavefish_harmonics` dictionary. + If dictionary then take amplitudes and phases from the 'amlitudes' and 'phases' keys. + If tuple then the first element is the list of amplitudes and + the second one the list of relative phases in radians. + mode: 'peak' or 'zero' + How to normalize amplitude and phases: + - 'peak': normalize waveform to a peak-to-peak amplitude of two + and shift it such that its peak is at time zero. + - 'zero': scale amplitude of fundamental to one and its phase to zero. + + Returns + ------- + amplitudes: array of floats + Adjusted amplitudes of the fundamental and its harmonics. + phases: array of floats + Adjusted phases in radians of the fundamental and its harmonics. + + """ + # get relative amplitude and phases: + amplitudes, phases = wavefish_spectrum(fish) + if mode == "zero": + newamplitudes = np.array(amplitudes) / amplitudes[0] + newphases = np.array( + [p + (k + 1) * (-phases[0]) for k, p in enumerate(phases)] + ) + newphases %= 2.0 * np.pi + newphases[newphases > np.pi] -= 2.0 * np.pi + return newamplitudes, newphases + else: + # generate waveform: + eodf = 100.0 + rate = 100000.0 + data = wavefish_eods(fish, eodf, rate, 2.0 / eodf, noise_std=0.0) + # normalize amplitudes: + ampl = 0.5 * (np.max(data) - np.min(data)) + newamplitudes = np.array(amplitudes) / ampl + # shift phases: + deltat = np.argmax(data[: int(rate / eodf)]) / rate + deltap = 2.0 * np.pi * deltat * eodf + newphases = np.array( + [p + (k + 1) * deltap for k, p in enumerate(phases)] + ) + newphases %= 2.0 * np.pi + newphases[newphases > np.pi] -= 2.0 * np.pi + # return: + return newamplitudes, newphases + + +def export_wavefish(fish, name="Unknown_harmonics", file=None): + """Serialize wavefish parameter to python code. + + Add output to the wavefish_harmonics dictionary! + + Parameters + ---------- + fish: string, dict or tuple of lists/arrays + Specify relative amplitudes and phases of the fundamental and its harmonics. + If string then take amplitudes and phases from the `wavefish_harmonics` dictionary. + If dictionary then take amplitudes and phases from the 'amlitudes' and 'phases' keys. + If tuple then the first element is the list of amplitudes and + the second one the list of relative phases in radians. + name: string + Name of the dictionary to be written. + file: string or file or None + File name or open file object where to write wavefish dictionary. + + Returns + ------- + fish: dict + Dictionary with amplitudes and phases. + """ + # get relative amplitude and phases: + amplitudes, phases = wavefish_spectrum(fish) + # write out dictionary: + if file is None: + file = sys.stdout + try: + file.write("") + closeit = False + except AttributeError: + file = open(file, "w") + closeit = True + n = 6 + file.write(name + " = \\\n") + file.write(" dict(amplitudes=(") + file.write(", ".join(["%.5g" % a for a in amplitudes[:n]])) + for k in range(n, len(amplitudes), n): + file.write(",\n") + file.write(" " * (9 + 12)) + file.write(", ".join(["%.5g" % a for a in amplitudes[k : k + n]])) + file.write("),\n") + file.write(" " * 9 + "phases=(") + file.write(", ".join(["%.5g" % p for p in phases[:n]])) + for k in range(n, len(phases), n): + file.write(",\n") + file.write(" " * (9 + 8)) + file.write(", ".join(["%.5g" % p for p in phases[k : k + n]])) + file.write("))\n") + if closeit: + file.close() + # return dictionary: + harmonics = dict(amplitudes=amplitudes, phases=phases) + return harmonics + + +def chirps( + eodf=100.0, + samplerate=44100.0, + duration=1.0, + chirp_times=[0.5], + chirp_size=[100.0], + chirp_width=[0.01], + chirp_kurtosis=[1.0], + chirp_contrast=[0.05], +): + """Simulate frequency trace with chirps. + + A chirp is modeled as a Gaussian frequency modulation. + The first chirp is placed at 0.5/chirp_freq. + + Parameters + ---------- + eodf: float + EOD frequency of the fish in Hertz. + samplerate: float + Sampling rate in Hertz. + duration: float + Duration of the generated data in seconds. + chirp_times: float + Timestamps of every single chirp in seconds. + chirp_size: list + Size of each chirp (maximum frequency increase above eodf) in Hertz. + chirp_width: list + Width of every single chirp at 10% height in seconds. + chirp_kurtosis: list: + Shape of every single chirp. =1: Gaussian, >1: more rectangular, <1: more peaked. + chirp_contrast: float + Maximum amplitude reduction of EOD during every respective chirp. + + Returns + ------- + frequency: array of floats + Generated frequency trace that can be passed on to wavefish_eods(). + amplitude: array of floats + Generated amplitude modulation that can be used to multiply the trace generated by + wavefish_eods(). + """ + # baseline eod frequency and amplitude modulation: + n = len(np.arange(0, duration, 1.0 / samplerate)) + frequency = eodf * np.ones(n) + am = np.ones(n) + + for time, width, size, kurtosis, contrast in zip(chirp_times, chirp_width, chirp_size, chirp_kurtosis, chirp_contrast): + + # chirp frequency waveform: + chirp_t = np.arange(-2.0 * width, 2.0 * width, 1.0 / samplerate) + chirp_sig = ( + 0.5 * width / (2.0 * np.log(10.0)) ** (0.5 / kurtosis) + ) + gauss = np.exp(-0.5 * ((chirp_t / chirp_sig) ** 2.0) ** kurtosis) + + + # add chirps on baseline eodf: + index = int(time * samplerate) + i0 = index - len(gauss) // 2 + i1 = i0 + len(gauss) + gi0 = 0 + gi1 = len(gauss) + if i0 < 0: + gi0 -= i0 + i0 = 0 + if i1 >= len(frequency): + gi1 -= i1 - len(frequency) + i1 = len(frequency) + frequency[i0:i1] += size * gauss[gi0:gi1] + am[i0:i1] -= contrast * gauss[gi0:gi1] + + return frequency, am + + +def rises( + eodf, + samplerate, + duration, + rise_times, + rise_size, + rise_tau, + decay_tau, +): + """Simulate frequency trace with rises. + + A rise is modeled as a double exponential frequency modulation. + + Parameters + ---------- + eodf: float + EOD frequency of the fish in Hertz. + samplerate: float + Sampling rate in Hertz. + duration: float + Duration of the generated data in seconds. + rise_times: list + Timestamp of each of the rises in seconds. + rise_size: list + Size of the respective rise (frequency increase above eodf) in Hertz. + rise_tau: list + Time constant of the frequency increase of the respective rise in seconds. + decay_tau: list + Time constant of the frequency decay of the respective rise in seconds. + + Returns + ------- + data: array of floats + Generate frequency trace that can be passed on to wavefish_eods(). + """ + n = len(np.arange(0, duration, 1.0 / samplerate)) + + # baseline eod frequency: + frequency = eodf * np.ones(n) + + for time, size, riset, decayt in zip(rise_times, rise_size, rise_tau, decay_tau): + + # rise frequency waveform: + rise_t = np.arange(0.0, 5.0 * decayt, 1.0 / samplerate) + rise = ( + size + * (1.0 - np.exp(-rise_t / riset)) + * np.exp(-rise_t / decayt) + ) + + # add rises on baseline eodf: + index = int(time * samplerate) + if index + len(rise) > len(frequency): + rise_index = len(frequency) - index + frequency[index : index + rise_index] += rise[:rise_index] + break + else: + frequency[index : index + len(rise)] += rise + return frequency + +class FishSignal: + def __init__(self, samplerate, duration, eodf, nchirps, nrises): + time = np.arange(0, duration, 1 / samplerate) + chirp_times = np.random.uniform(0, duration, nchirps) + rise_times = np.random.uniform(0, duration, nrises) + + # pick random parameters for chirps + chirp_size = np.random.uniform(60, 200, nchirps) + chirp_width = np.random.uniform(0.01, 0.1, nchirps) + chirp_kurtosis = np.random.uniform(1, 1, nchirps) + chirp_contrast = np.random.uniform(0.1, 0.5, nchirps) + + # pick random parameters for rises + rise_size = np.random.uniform(10, 100, nrises) + rise_tau = np.random.uniform(0.5, 1.5, nrises) + decay_tau = np.random.uniform(5, 15, nrises) + + # generate frequency trace with chirps + chirp_trace, chirp_amp = chirps( + eodf=0.0, + samplerate=samplerate, + duration=duration, + chirp_times=chirp_times, + chirp_size=chirp_size, + chirp_width=chirp_width, + chirp_kurtosis=chirp_kurtosis, + chirp_contrast=chirp_contrast, + ) + + # generate frequency trace with rises + rise_trace = rises( + eodf=0.0, + samplerate=samplerate, + duration=duration, + rise_times=rise_times, + rise_size=rise_size, + rise_tau=rise_tau, + decay_tau=decay_tau, + ) + + # combine traces to one + full_trace = chirp_trace + rise_trace + eodf + + # make the EOD from the frequency trace + fish = wavefish_eods( + fish="Alepto", + frequency=full_trace, + samplerate=samplerate, + duration=duration, + phase0=0.0, + noise_std=0.05, + ) + + signal = fish * chirp_amp + + self.signal = signal + self.trace = full_trace + self.time = time + self.samplerate = samplerate + self.eodf = eodf + + def visualize(self): + + spec, freqs, spectime = ps.spectrogram( + data=self.signal, + ratetime=self.samplerate, + freq_resolution=0.5, + overlap_frac=0.5, + ) + + fig, (ax1, ax2) = plt.subplots(2, 1, height_ratios=[1, 4], sharex=True) + + ax1.plot(self.time, self.signal) + ax1.set_ylabel("Amplitude") + ax1.set_xlabel("Time (s)") + ax1.set_title("EOD signal") + + ax2.imshow(ps.decibel(spec), origin='lower', aspect="auto", extent=[spectime[0], spectime[-1], freqs[0], freqs[-1]]) + ax2.set_ylabel("Frequency (Hz)") + ax2.set_xlabel("Time (s)") + ax2.set_title("Spectrogram") + ax2.set_ylim(0, 2000) + plt.show() diff --git a/chirp_instantaneous_freq/test_parameters.py b/chirp_instantaneous_freq/test_parameters.py new file mode 100644 index 0000000..9c4ab5f --- /dev/null +++ b/chirp_instantaneous_freq/test_parameters.py @@ -0,0 +1,118 @@ +import numpy as np +import matplotlib.pyplot as plt +from fish_signal import chirps, wavefish_eods +from filters import bandpass_filter, instantaneous_frequency, inst_freq +from IPython import embed + + +def switch_test(test, defaultparams, testparams): + if test == 'width': + defaultparams['chirp_width'] = testparams['chirp_width'] + key = 'chirp_width' + elif test == 'size': + defaultparams['chirp_size'] = testparams['chirp_size'] + key = 'chirp_size' + elif test == 'kurtosis': + defaultparams['chirp_kurtosis'] = testparams['chirp_kurtosis'] + key = 'chirp_kurtosis' + elif test == 'contrast': + defaultparams['chirp_contrast'] = testparams['chirp_contrast'] + key = 'chirp_contrast' + else: + raise ValueError("Test not recognized") + + return key, defaultparams + + +def extract_dict(dict, index): + return {key: value[index] for key, value in dict.items()} + + +def main(test1, test2, resolution=10): + + assert test1 in ['width', 'size', 'kurtosis', 'contrast'], "Test1 not recognized" + assert test2 in ['width', 'size', 'kurtosis', 'contrast'], "Test2 not recognized" + + # Define the parameters for the chirp simulations + ntest = resolution + + defaultparams = dict( + chirp_size = np.ones(ntest) * 100, + chirp_width = np.ones(ntest) * 0.1, + chirp_kurtosis = np.ones(ntest) * 1.0, + chirp_contrast = np.ones(ntest) * 0.5, + ) + + testparams = dict( + chirp_width = np.linspace(0.01, 0.2, ntest), + chirp_size = np.linspace(50, 300, ntest), + chirp_kurtosis = np.linspace(0.5, 1.5, ntest), + chirp_contrast = np.linspace(0.01, 1.0, ntest), + ) + + key1, chirp_params = switch_test(test1, defaultparams, testparams) + key2, chirp_params = switch_test(test2, chirp_params, testparams) + + # make the chirp trace + eodf = 500 + samplerate = 20000 + duration = 2 + chirp_times = [0.5, 1, 1.5] + + wide_cutoffs = 200 + tight_cutoffs = 10 + + distances = np.full((ntest, ntest), np.nan) + + fig, axs = plt.subplots(ntest, ntest, figsize = (10, 10), sharex = True, sharey = True) + axs = axs.flatten() + + iter0 = 0 + for iter1, test1_param in enumerate(chirp_params[key1]): + for iter2, test2_param in enumerate(chirp_params[key2]): + + # get the chirp parameters for the current test + inner_chirp_params = extract_dict(chirp_params, iter2) + inner_chirp_params[key1] = test1_param + inner_chirp_params[key2] = test2_param + + # make the chirp trace for the current chirp parameters + sizes = np.ones(len(chirp_times)) * inner_chirp_params['chirp_size'] + widths = np.ones(len(chirp_times)) * inner_chirp_params['chirp_width'] + kurtosis = np.ones(len(chirp_times)) * inner_chirp_params['chirp_kurtosis'] + contrast = np.ones(len(chirp_times)) * inner_chirp_params['chirp_contrast'] + + # make the chirp trace + chirp_trace, ampmod = chirps(eodf, samplerate, duration, chirp_times, sizes, widths, kurtosis, contrast) + signal = wavefish_eods( + fish="Alepto", + frequency=chirp_trace, + samplerate=samplerate, + duration=duration, + phase0=0.0, + noise_std=0.05 + ) + signal = signal * ampmod + + # apply broadband filter + wide_signal = bandpass_filter(signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs) + tight_signal = bandpass_filter(signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs) + + # get the instantaneous frequency + wide_frequency = inst_freq(wide_signal, samplerate) + tight_frequency = inst_freq(tight_signal, samplerate) + + bool_mask = wide_frequency != 0 + axs[iter0].plot(wide_frequency[bool_mask]) + axs[iter0].plot(tight_frequency[bool_mask]) + fig.supylabel(key1) + fig.supxlabel(key2) + + iter0 += 1 + + fig, ax = plt.subplots() + ax.imshow(distances, cmap = 'jet') + plt.show() + +if __name__ == "__main__": + main('width', 'size') diff --git a/requirements.txt b/requirements.txt index b6a16c7..00bb2e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -audioio==0.10.0 cmocean==3.0.3 cycler==0.11.0 ipython==8.12.0 @@ -9,5 +8,4 @@ paramiko==3.1.0 PyYAML==6.0 scipy==1.10.1 scp==0.14.5 -thunderfish==1.9.10 tqdm==4.65.0 From 282c846b05cbbfca07d2079c2651bdf90577bbe5 Mon Sep 17 00:00:00 2001 From: weygoldt <88969563+weygoldt@users.noreply.github.com> Date: Tue, 11 Apr 2023 15:33:07 +0200 Subject: [PATCH 2/5] reformat --- chirp_instantaneous_freq/filters.py | 36 ++- chirp_instantaneous_freq/fish_signal.py | 34 +-- chirp_instantaneous_freq/test_parameters.py | 124 +++++---- code/analysis.py | 81 +++--- code/band_pass_problem.py | 26 +- code/behavior.py | 237 +++++++++------- code/chirp_sim.py | 19 +- code/chirpdetection.py | 138 ++++++---- code/chirpdetector_conf.yml | 41 ++- code/eventchirpsplots.py | 266 +++++++++++------- code/extract_chirps.py | 37 +-- code/get_behaviour.py | 44 +-- code/modules/behaviour_handling.py | 96 +++---- code/modules/datahandling.py | 21 +- code/modules/filehandling.py | 1 - code/modules/filters.py | 19 +- code/modules/logger.py | 8 +- code/modules/plotstyle.py | 16 +- code/modules/plotstyle1.py | 16 +- code/modules/plotstyle_dark.py | 16 +- code/modules/simulations.py | 2 +- code/plot_chirp_size.py | 291 ++++++++++++-------- code/plot_chirps_in_chasing.py | 74 +++-- code/plot_event_timeline.py | 80 ++++-- code/plot_introduction_specs.py | 56 ++-- code/plot_kdes.py | 208 ++++++++------ 26 files changed, 1177 insertions(+), 810 deletions(-) diff --git a/chirp_instantaneous_freq/filters.py b/chirp_instantaneous_freq/filters.py index 709b140..a0d4001 100644 --- a/chirp_instantaneous_freq/filters.py +++ b/chirp_instantaneous_freq/filters.py @@ -59,14 +59,14 @@ def instantaneous_frequency( def inst_freq(signal, fs): """ Computes the instantaneous frequency of a periodic signal using zero-crossings. - + Parameters: ----------- signal : array-like The input signal. fs : float The sampling frequency of the input signal. - + Returns: -------- freq : array-like @@ -74,29 +74,30 @@ def inst_freq(signal, fs): """ # Compute the sign of the signal sign = np.sign(signal) - + # Compute the crossings of the sign signal with a zero line crossings = np.where(np.diff(sign))[0] - + # Compute the time differences between zero crossings dt = np.diff(crossings) / fs - + # Compute the instantaneous frequency as the reciprocal of the time differences freq = 1 / dt - # Gaussian filter the signal + # Gaussian filter the signal freq = gaussian_filter1d(freq, 10) - + # Pad the frequency vector with zeros to match the length of the input signal freq = np.pad(freq, (0, len(signal) - len(freq))) - + return freq + def bandpass_filter( - signal: np.ndarray, - samplerate: float, - lowf: float, - highf: float, + signal: np.ndarray, + samplerate: float, + lowf: float, + highf: float, ) -> np.ndarray: """Bandpass filter a signal. @@ -150,9 +151,7 @@ def highpass_filter( def lowpass_filter( - signal: np.ndarray, - samplerate: float, - cutoff: float + signal: np.ndarray, samplerate: float, cutoff: float ) -> np.ndarray: """Lowpass filter a signal. @@ -176,10 +175,9 @@ def lowpass_filter( return filtered_signal -def envelope(signal: np.ndarray, - samplerate: float, - cutoff_frequency: float - ) -> np.ndarray: +def envelope( + signal: np.ndarray, samplerate: float, cutoff_frequency: float +) -> np.ndarray: """Calculate the envelope of a signal using a lowpass filter. Parameters diff --git a/chirp_instantaneous_freq/fish_signal.py b/chirp_instantaneous_freq/fish_signal.py index bf740b6..d7830bf 100644 --- a/chirp_instantaneous_freq/fish_signal.py +++ b/chirp_instantaneous_freq/fish_signal.py @@ -384,16 +384,14 @@ def chirps( frequency = eodf * np.ones(n) am = np.ones(n) - for time, width, size, kurtosis, contrast in zip(chirp_times, chirp_width, chirp_size, chirp_kurtosis, chirp_contrast): - + for time, width, size, kurtosis, contrast in zip( + chirp_times, chirp_width, chirp_size, chirp_kurtosis, chirp_contrast + ): # chirp frequency waveform: chirp_t = np.arange(-2.0 * width, 2.0 * width, 1.0 / samplerate) - chirp_sig = ( - 0.5 * width / (2.0 * np.log(10.0)) ** (0.5 / kurtosis) - ) + chirp_sig = 0.5 * width / (2.0 * np.log(10.0)) ** (0.5 / kurtosis) gauss = np.exp(-0.5 * ((chirp_t / chirp_sig) ** 2.0) ** kurtosis) - # add chirps on baseline eodf: index = int(time * samplerate) i0 = index - len(gauss) // 2 @@ -433,7 +431,7 @@ def rises( Sampling rate in Hertz. duration: float Duration of the generated data in seconds. - rise_times: list + rise_times: list Timestamp of each of the rises in seconds. rise_size: list Size of the respective rise (frequency increase above eodf) in Hertz. @@ -452,15 +450,12 @@ def rises( # baseline eod frequency: frequency = eodf * np.ones(n) - for time, size, riset, decayt in zip(rise_times, rise_size, rise_tau, decay_tau): - + for time, size, riset, decayt in zip( + rise_times, rise_size, rise_tau, decay_tau + ): # rise frequency waveform: rise_t = np.arange(0.0, 5.0 * decayt, 1.0 / samplerate) - rise = ( - size - * (1.0 - np.exp(-rise_t / riset)) - * np.exp(-rise_t / decayt) - ) + rise = size * (1.0 - np.exp(-rise_t / riset)) * np.exp(-rise_t / decayt) # add rises on baseline eodf: index = int(time * samplerate) @@ -472,13 +467,14 @@ def rises( frequency[index : index + len(rise)] += rise return frequency + class FishSignal: def __init__(self, samplerate, duration, eodf, nchirps, nrises): time = np.arange(0, duration, 1 / samplerate) chirp_times = np.random.uniform(0, duration, nchirps) rise_times = np.random.uniform(0, duration, nrises) - # pick random parameters for chirps + # pick random parameters for chirps chirp_size = np.random.uniform(60, 200, nchirps) chirp_width = np.random.uniform(0.01, 0.1, nchirps) chirp_kurtosis = np.random.uniform(1, 1, nchirps) @@ -534,7 +530,6 @@ class FishSignal: self.eodf = eodf def visualize(self): - spec, freqs, spectime = ps.spectrogram( data=self.signal, ratetime=self.samplerate, @@ -549,7 +544,12 @@ class FishSignal: ax1.set_xlabel("Time (s)") ax1.set_title("EOD signal") - ax2.imshow(ps.decibel(spec), origin='lower', aspect="auto", extent=[spectime[0], spectime[-1], freqs[0], freqs[-1]]) + ax2.imshow( + ps.decibel(spec), + origin="lower", + aspect="auto", + extent=[spectime[0], spectime[-1], freqs[0], freqs[-1]], + ) ax2.set_ylabel("Frequency (Hz)") ax2.set_xlabel("Time (s)") ax2.set_title("Spectrogram") diff --git a/chirp_instantaneous_freq/test_parameters.py b/chirp_instantaneous_freq/test_parameters.py index 9c4ab5f..bad1e45 100644 --- a/chirp_instantaneous_freq/test_parameters.py +++ b/chirp_instantaneous_freq/test_parameters.py @@ -1,4 +1,4 @@ -import numpy as np +import numpy as np import matplotlib.pyplot as plt from fish_signal import chirps, wavefish_eods from filters import bandpass_filter, instantaneous_frequency, inst_freq @@ -6,18 +6,18 @@ from IPython import embed def switch_test(test, defaultparams, testparams): - if test == 'width': - defaultparams['chirp_width'] = testparams['chirp_width'] - key = 'chirp_width' - elif test == 'size': - defaultparams['chirp_size'] = testparams['chirp_size'] - key = 'chirp_size' - elif test == 'kurtosis': - defaultparams['chirp_kurtosis'] = testparams['chirp_kurtosis'] - key = 'chirp_kurtosis' - elif test == 'contrast': - defaultparams['chirp_contrast'] = testparams['chirp_contrast'] - key = 'chirp_contrast' + if test == "width": + defaultparams["chirp_width"] = testparams["chirp_width"] + key = "chirp_width" + elif test == "size": + defaultparams["chirp_size"] = testparams["chirp_size"] + key = "chirp_size" + elif test == "kurtosis": + defaultparams["chirp_kurtosis"] = testparams["chirp_kurtosis"] + key = "chirp_kurtosis" + elif test == "contrast": + defaultparams["chirp_contrast"] = testparams["chirp_contrast"] + key = "chirp_contrast" else: raise ValueError("Test not recognized") @@ -29,31 +29,40 @@ def extract_dict(dict, index): def main(test1, test2, resolution=10): - - assert test1 in ['width', 'size', 'kurtosis', 'contrast'], "Test1 not recognized" - assert test2 in ['width', 'size', 'kurtosis', 'contrast'], "Test2 not recognized" - - # Define the parameters for the chirp simulations + assert test1 in [ + "width", + "size", + "kurtosis", + "contrast", + ], "Test1 not recognized" + assert test2 in [ + "width", + "size", + "kurtosis", + "contrast", + ], "Test2 not recognized" + + # Define the parameters for the chirp simulations ntest = resolution defaultparams = dict( - chirp_size = np.ones(ntest) * 100, - chirp_width = np.ones(ntest) * 0.1, - chirp_kurtosis = np.ones(ntest) * 1.0, - chirp_contrast = np.ones(ntest) * 0.5, + chirp_size=np.ones(ntest) * 100, + chirp_width=np.ones(ntest) * 0.1, + chirp_kurtosis=np.ones(ntest) * 1.0, + chirp_contrast=np.ones(ntest) * 0.5, ) testparams = dict( - chirp_width = np.linspace(0.01, 0.2, ntest), - chirp_size = np.linspace(50, 300, ntest), - chirp_kurtosis = np.linspace(0.5, 1.5, ntest), - chirp_contrast = np.linspace(0.01, 1.0, ntest), + chirp_width=np.linspace(0.01, 0.2, ntest), + chirp_size=np.linspace(50, 300, ntest), + chirp_kurtosis=np.linspace(0.5, 1.5, ntest), + chirp_contrast=np.linspace(0.01, 1.0, ntest), ) key1, chirp_params = switch_test(test1, defaultparams, testparams) key2, chirp_params = switch_test(test2, chirp_params, testparams) - # make the chirp trace + # make the chirp trace eodf = 500 samplerate = 20000 duration = 2 @@ -63,40 +72,60 @@ def main(test1, test2, resolution=10): tight_cutoffs = 10 distances = np.full((ntest, ntest), np.nan) - - fig, axs = plt.subplots(ntest, ntest, figsize = (10, 10), sharex = True, sharey = True) + + fig, axs = plt.subplots( + ntest, ntest, figsize=(10, 10), sharex=True, sharey=True + ) axs = axs.flatten() iter0 = 0 for iter1, test1_param in enumerate(chirp_params[key1]): for iter2, test2_param in enumerate(chirp_params[key2]): - # get the chirp parameters for the current test inner_chirp_params = extract_dict(chirp_params, iter2) inner_chirp_params[key1] = test1_param inner_chirp_params[key2] = test2_param # make the chirp trace for the current chirp parameters - sizes = np.ones(len(chirp_times)) * inner_chirp_params['chirp_size'] - widths = np.ones(len(chirp_times)) * inner_chirp_params['chirp_width'] - kurtosis = np.ones(len(chirp_times)) * inner_chirp_params['chirp_kurtosis'] - contrast = np.ones(len(chirp_times)) * inner_chirp_params['chirp_contrast'] + sizes = np.ones(len(chirp_times)) * inner_chirp_params["chirp_size"] + widths = ( + np.ones(len(chirp_times)) * inner_chirp_params["chirp_width"] + ) + kurtosis = ( + np.ones(len(chirp_times)) * inner_chirp_params["chirp_kurtosis"] + ) + contrast = ( + np.ones(len(chirp_times)) * inner_chirp_params["chirp_contrast"] + ) # make the chirp trace - chirp_trace, ampmod = chirps(eodf, samplerate, duration, chirp_times, sizes, widths, kurtosis, contrast) + chirp_trace, ampmod = chirps( + eodf, + samplerate, + duration, + chirp_times, + sizes, + widths, + kurtosis, + contrast, + ) signal = wavefish_eods( - fish="Alepto", - frequency=chirp_trace, - samplerate=samplerate, - duration=duration, - phase0=0.0, - noise_std=0.05 - ) + fish="Alepto", + frequency=chirp_trace, + samplerate=samplerate, + duration=duration, + phase0=0.0, + noise_std=0.05, + ) signal = signal * ampmod - # apply broadband filter - wide_signal = bandpass_filter(signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs) - tight_signal = bandpass_filter(signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs) + # apply broadband filter + wide_signal = bandpass_filter( + signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs + ) + tight_signal = bandpass_filter( + signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs + ) # get the instantaneous frequency wide_frequency = inst_freq(wide_signal, samplerate) @@ -111,8 +140,9 @@ def main(test1, test2, resolution=10): iter0 += 1 fig, ax = plt.subplots() - ax.imshow(distances, cmap = 'jet') + ax.imshow(distances, cmap="jet") plt.show() + if __name__ == "__main__": - main('width', 'size') + main("width", "size") diff --git a/code/analysis.py b/code/analysis.py index 787a53e..2e32671 100644 --- a/code/analysis.py +++ b/code/analysis.py @@ -10,73 +10,84 @@ from modules.filters import bandpass_filter def main(folder): - file = os.path.join(folder, 'traces-grid.raw') + file = os.path.join(folder, "traces-grid.raw") data = open_data(folder, 60.0, 0, channel=-1) - time = np.load(folder + 'times.npy', allow_pickle=True) - freq = np.load(folder + 'fund_v.npy', allow_pickle=True) - ident = np.load(folder + 'ident_v.npy', allow_pickle=True) - idx = np.load(folder + 'idx_v.npy', allow_pickle=True) + time = np.load(folder + "times.npy", allow_pickle=True) + freq = np.load(folder + "fund_v.npy", allow_pickle=True) + ident = np.load(folder + "ident_v.npy", allow_pickle=True) + idx = np.load(folder + "idx_v.npy", allow_pickle=True) - t0 = 3*60*60 + 6*60 + 43.5 + t0 = 3 * 60 * 60 + 6 * 60 + 43.5 dt = 60 - data_oi = data[t0 * data.samplerate: (t0+dt)*data.samplerate, :] + data_oi = data[t0 * data.samplerate : (t0 + dt) * data.samplerate, :] for i in [10]: # getting the spectogramm spec_power, spec_freqs, spec_times = spectrogram( - data_oi[:, i], ratetime=data.samplerate, freq_resolution=50, overlap_frac=0.0) - fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54)) - ax.pcolormesh(spec_times, spec_freqs, decibel( - spec_power), vmin=-100, vmax=-50) + data_oi[:, i], + ratetime=data.samplerate, + freq_resolution=50, + overlap_frac=0.0, + ) + fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54)) + ax.pcolormesh( + spec_times, spec_freqs, decibel(spec_power), vmin=-100, vmax=-50 + ) for track_id in np.unique(ident): # window_index for time array in time window - window_index = np.arange(len(idx))[(ident == track_id) & - (time[idx] >= t0) & - (time[idx] <= (t0+dt))] + window_index = np.arange(len(idx))[ + (ident == track_id) + & (time[idx] >= t0) + & (time[idx] <= (t0 + dt)) + ] freq_temp = freq[window_index] time_temp = time[idx[window_index]] - #mean_freq = np.mean(freq_temp) - #fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200) + # mean_freq = np.mean(freq_temp) + # fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200) ax.plot(time_temp - t0, freq_temp) ax.set_ylim(500, 1000) plt.show() # filter plot - id = 10. + id = 10.0 i = 10 - window_index = np.arange(len(idx))[(ident == id) & - (time[idx] >= t0) & - (time[idx] <= (t0+dt))] + window_index = np.arange(len(idx))[ + (ident == id) & (time[idx] >= t0) & (time[idx] <= (t0 + dt)) + ] freq_temp = freq[window_index] time_temp = time[idx[window_index]] mean_freq = np.mean(freq_temp) fdata = bandpass_filter( - data_oi[:, i], rate=data.samplerate, lowf=mean_freq-5, highf=mean_freq+200) + data_oi[:, i], + rate=data.samplerate, + lowf=mean_freq - 5, + highf=mean_freq + 200, + ) fig, ax = plt.subplots() - ax.plot(np.arange(len(fdata))/data.samplerate, fdata, marker='*') + ax.plot(np.arange(len(fdata)) / data.samplerate, fdata, marker="*") # plt.show() # freqency analyis of filtered data - time_fdata = np.arange(len(fdata))/data.samplerate + time_fdata = np.arange(len(fdata)) / data.samplerate roll_fdata = np.roll(fdata, shift=1) period_index = np.arange(len(fdata))[(roll_fdata < 0) & (fdata >= 0)] plt.plot(time_fdata, fdata) - plt.scatter(time_fdata[period_index], fdata[period_index], c='r') - plt.scatter(time_fdata[period_index-1], fdata[period_index-1], c='r') + plt.scatter(time_fdata[period_index], fdata[period_index], c="r") + plt.scatter(time_fdata[period_index - 1], fdata[period_index - 1], c="r") upper_bound = np.abs(fdata[period_index]) - lower_bound = np.abs(fdata[period_index-1]) + lower_bound = np.abs(fdata[period_index - 1]) upper_times = np.abs(time_fdata[period_index]) - lower_times = np.abs(time_fdata[period_index-1]) + lower_times = np.abs(time_fdata[period_index - 1]) - lower_ratio = lower_bound/(lower_bound+upper_bound) - upper_ratio = upper_bound/(lower_bound+upper_bound) + lower_ratio = lower_bound / (lower_bound + upper_bound) + upper_ratio = upper_bound / (lower_bound + upper_bound) - time_delta = upper_times-lower_times - true_zero = lower_times + time_delta*lower_ratio + time_delta = upper_times - lower_times + true_zero = lower_times + time_delta * lower_ratio plt.scatter(true_zero, np.zeros(len(true_zero))) @@ -84,7 +95,7 @@ def main(folder): inst_freq = 1 / np.diff(true_zero) filtered_inst_freq = gaussian_filter1d(inst_freq, 0.005) fig, ax = plt.subplots() - ax.plot(filtered_inst_freq, marker='.') + ax.plot(filtered_inst_freq, marker=".") # in 5 sekunden welcher fisch auf einer elektrode am embed() @@ -99,5 +110,7 @@ def main(folder): pass -if __name__ == '__main__': - main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/') +if __name__ == "__main__": + main( + "/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/" + ) diff --git a/code/band_pass_problem.py b/code/band_pass_problem.py index fc6a55e..f553ff2 100644 --- a/code/band_pass_problem.py +++ b/code/band_pass_problem.py @@ -12,25 +12,27 @@ from modules.filehandling import LoadData def main(folder): data = LoadData(folder) - t0 = 3*60*60 + 6*60 + 43.5 + t0 = 3 * 60 * 60 + 6 * 60 + 43.5 dt = 60 - data_oi = data.raw[t0 * data.raw_rate: (t0+dt)*data.raw_rate, :] - # good electrode - electrode = 10 + data_oi = data.raw[t0 * data.raw_rate : (t0 + dt) * data.raw_rate, :] + # good electrode + electrode = 10 data_oi = data_oi[:, electrode] - fig, axs = plt.subplots(2,1) - axs[0].plot( np.arange(data_oi.shape[0]) / data.raw_rate, data_oi) + fig, axs = plt.subplots(2, 1) + axs[0].plot(np.arange(data_oi.shape[0]) / data.raw_rate, data_oi) for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): rack_window_index = np.arange(len(data.idx))[ - (data.ident == track_id) & - (data.time[data.idx] >= t0) & - (data.time[data.idx] <= (t0+dt))] + (data.ident == track_id) + & (data.time[data.idx] >= t0) + & (data.time[data.idx] <= (t0 + dt)) + ] freq_fish = data.freq[rack_window_index] axs[1].plot(np.arange(freq_fish.shape[0]) / data.raw_rate, freq_fish) plt.show() - -if __name__ == '__main__': - main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/') \ No newline at end of file +if __name__ == "__main__": + main( + "/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/" + ) diff --git a/code/behavior.py b/code/behavior.py index 71c0926..4f16543 100644 --- a/code/behavior.py +++ b/code/behavior.py @@ -1,8 +1,8 @@ -import os -import os +import os +import os import numpy as np -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt from IPython import embed from pandas import read_csv @@ -11,51 +11,65 @@ from scipy.ndimage import gaussian_filter1d logger = makeLogger(__name__) + class Behavior: """Load behavior data from csv file as class attributes Attributes ---------- behavior: 0: chasing onset, 1: chasing offset, 2: physical contact - behavior_type: - behavioral_category: - comment_start: - comment_stop: - dataframe: pandas dataframe with all the data - duration_s: - media_file: - observation_date: - observation_id: - start_s: start time of the event in seconds - stop_s: stop time of the event in seconds - total_length: + behavior_type: + behavioral_category: + comment_start: + comment_stop: + dataframe: pandas dataframe with all the data + duration_s: + media_file: + observation_date: + observation_id: + start_s: start time of the event in seconds + stop_s: stop time of the event in seconds + total_length: """ def __init__(self, folder_path: str) -> None: - - - LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True) - self.time = np.load(os.path.join(folder_path, "times.npy"), allow_pickle=True) - csv_filename = [f for f in os.listdir(folder_path) if f.endswith('.csv')][0] # check if there are more than one csv file + LED_on_time_BORIS = np.load( + os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True + ) + self.time = np.load( + os.path.join(folder_path, "times.npy"), allow_pickle=True + ) + csv_filename = [ + f for f in os.listdir(folder_path) if f.endswith(".csv") + ][ + 0 + ] # check if there are more than one csv file self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) - self.chirps = np.load(os.path.join(folder_path, 'chirps.npy'), allow_pickle=True) - self.chirps_ids = np.load(os.path.join(folder_path, 'chirps_ids.npy'), allow_pickle=True) + self.chirps = np.load( + os.path.join(folder_path, "chirps.npy"), allow_pickle=True + ) + self.chirps_ids = np.load( + os.path.join(folder_path, "chirps_ids.npy"), allow_pickle=True + ) for k, key in enumerate(self.dataframe.keys()): - key = key.lower() - if ' ' in key: - key = key.replace(' ', '_') - if '(' in key: - key = key.replace('(', '') - key = key.replace(')', '') - setattr(self, key, np.array(self.dataframe[self.dataframe.keys()[k]])) - + key = key.lower() + if " " in key: + key = key.replace(" ", "_") + if "(" in key: + key = key.replace("(", "") + key = key.replace(")", "") + setattr( + self, key, np.array(self.dataframe[self.dataframe.keys()[k]]) + ) + last_LED_t_BORIS = LED_on_time_BORIS[-1] real_time_range = self.time[-1] - self.time[0] factor = 1.034141 shift = last_LED_t_BORIS - real_time_range * factor self.start_s = (self.start_s - shift) / factor self.stop_s = (self.stop_s - shift) / factor - + + """ 1 - chasing onset 2 - chasing offset @@ -83,77 +97,77 @@ temporal encpding needs to be corrected ... not exactly 25FPS. behavior = data['Behavior'] """ -def correct_chasing_events( - category: np.ndarray, - timestamps: np.ndarray - ) -> tuple[np.ndarray, np.ndarray]: - onset_ids = np.arange( - len(category))[category == 0] - offset_ids = np.arange( - len(category))[category == 1] +def correct_chasing_events( + category: np.ndarray, timestamps: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + onset_ids = np.arange(len(category))[category == 0] + offset_ids = np.arange(len(category))[category == 1] # Check whether on- or offset is longer and calculate length difference if len(onset_ids) > len(offset_ids): len_diff = len(onset_ids) - len(offset_ids) longer_array = onset_ids shorter_array = offset_ids - logger.info(f'Onsets are greater than offsets by {len_diff}') + logger.info(f"Onsets are greater than offsets by {len_diff}") elif len(onset_ids) < len(offset_ids): len_diff = len(offset_ids) - len(onset_ids) longer_array = offset_ids shorter_array = onset_ids - logger.info(f'Offsets are greater than offsets by {len_diff}') + logger.info(f"Offsets are greater than offsets by {len_diff}") elif len(onset_ids) == len(offset_ids): - logger.info('Chasing events are equal') + logger.info("Chasing events are equal") return category, timestamps # Correct the wrong chasing events; delete double events wrong_ids = [] - for i in range(len(longer_array)-(len_diff+1)): - if (shorter_array[i] > longer_array[i]) & (shorter_array[i] < longer_array[i+1]): + for i in range(len(longer_array) - (len_diff + 1)): + if (shorter_array[i] > longer_array[i]) & ( + shorter_array[i] < longer_array[i + 1] + ): pass else: wrong_ids.append(longer_array[i]) longer_array = np.delete(longer_array, i) - - category = np.delete( - category, wrong_ids) - timestamps = np.delete( - timestamps, wrong_ids) + + category = np.delete(category, wrong_ids) + timestamps = np.delete(timestamps, wrong_ids) return category, timestamps def event_triggered_chirps( - event: np.ndarray, - chirps:np.ndarray, + event: np.ndarray, + chirps: np.ndarray, time_before_event: int, - time_after_event: int - )-> tuple[np.ndarray, np.ndarray]: - - - event_chirps = [] # chirps that are in specified window around event - centered_chirps = [] # timestamps of chirps around event centered on the event timepoint + time_after_event: int, +) -> tuple[np.ndarray, np.ndarray]: + event_chirps = [] # chirps that are in specified window around event + centered_chirps = ( + [] + ) # timestamps of chirps around event centered on the event timepoint for event_timestamp in event: - start = event_timestamp - time_before_event # timepoint of window start - stop = event_timestamp + time_after_event # timepoint of window ending - chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)] # get chirps that are in a -5 to +5 sec window around event + start = event_timestamp - time_before_event # timepoint of window start + stop = event_timestamp + time_after_event # timepoint of window ending + chirps_around_event = [ + c for c in chirps if (c >= start) & (c <= stop) + ] # get chirps that are in a -5 to +5 sec window around event event_chirps.append(chirps_around_event) if len(chirps_around_event) == 0: continue - else: + else: centered_chirps.append(chirps_around_event - event_timestamp) - centered_chirps = np.concatenate(centered_chirps, axis=0) # convert list of arrays to one array for plotting + centered_chirps = np.concatenate( + centered_chirps, axis=0 + ) # convert list of arrays to one array for plotting return event_chirps, centered_chirps def main(datapath: str): - # behavior is pandas dataframe with all the data bh = Behavior(datapath) - + # chirps are not sorted in time (presumably due to prior groupings) # get and sort chirps and corresponding fish_ids of the chirps chirps = bh.chirps[np.argsort(bh.chirps)] @@ -172,10 +186,34 @@ def main(datapath: str): # First overview plot fig1, ax1 = plt.subplots() - ax1.scatter(chirps, np.ones_like(chirps), marker='*', color='royalblue', label='Chirps') - ax1.scatter(chasing_onset, np.ones_like(chasing_onset)*2, marker='.', color='forestgreen', label='Chasing onset') - ax1.scatter(chasing_offset, np.ones_like(chasing_offset)*2.5, marker='.', color='firebrick', label='Chasing offset') - ax1.scatter(physical_contact, np.ones_like(physical_contact)*3, marker='x', color='black', label='Physical contact') + ax1.scatter( + chirps, + np.ones_like(chirps), + marker="*", + color="royalblue", + label="Chirps", + ) + ax1.scatter( + chasing_onset, + np.ones_like(chasing_onset) * 2, + marker=".", + color="forestgreen", + label="Chasing onset", + ) + ax1.scatter( + chasing_offset, + np.ones_like(chasing_offset) * 2.5, + marker=".", + color="firebrick", + label="Chasing offset", + ) + ax1.scatter( + physical_contact, + np.ones_like(physical_contact) * 3, + marker="x", + color="black", + label="Physical contact", + ) plt.legend() # plt.show() plt.close() @@ -187,29 +225,40 @@ def main(datapath: str): # Evaluate how many chirps were emitted in specific time window around the chasing onset events # Iterate over chasing onsets (later over fish) - time_around_event = 5 # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event + time_around_event = 5 # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event #### Loop crashes at concatenate in function #### # for i in range(len(fish_ids)): # fish = fish_ids[i] # chirps = chirps[chirps_fish_ids == fish] # print(fish) - chasing_chirps, centered_chasing_chirps = event_triggered_chirps(chasing_onset, chirps, time_around_event, time_around_event) - physical_chirps, centered_physical_chirps = event_triggered_chirps(physical_contact, chirps, time_around_event, time_around_event) + chasing_chirps, centered_chasing_chirps = event_triggered_chirps( + chasing_onset, chirps, time_around_event, time_around_event + ) + physical_chirps, centered_physical_chirps = event_triggered_chirps( + physical_contact, chirps, time_around_event, time_around_event + ) # Kernel density estimation ??? # centered_chasing_chirps_convolved = gaussian_filter1d(centered_chasing_chirps, 5) - + # centered_chasing = chasing_onset[0] - chasing_onset[0] ## get the 0 timepoint for plotting; set one chasing event to 0 offsets = [0.5, 1] - fig4, ax4 = plt.subplots(figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True) - ax4.eventplot(np.array([centered_chasing_chirps, centered_physical_chirps]), lineoffsets=offsets, linelengths=0.25, colors=['g', 'r']) - ax4.vlines(0, 0, 1.5, 'tab:grey', 'dashed', 'Timepoint of event') + fig4, ax4 = plt.subplots( + figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True + ) + ax4.eventplot( + np.array([centered_chasing_chirps, centered_physical_chirps]), + lineoffsets=offsets, + linelengths=0.25, + colors=["g", "r"], + ) + ax4.vlines(0, 0, 1.5, "tab:grey", "dashed", "Timepoint of event") # ax4.plot(centered_chasing_chirps_convolved) ax4.set_yticks(offsets) - ax4.set_yticklabels(['Chasings', 'Physical \n contacts']) - ax4.set_xlabel('Time[s]') - ax4.set_ylabel('Type of event') + ax4.set_yticklabels(["Chasings", "Physical \n contacts"]) + ax4.set_xlabel("Time[s]") + ax4.set_ylabel("Type of event") plt.show() # Associate chirps to inidividual fish @@ -219,22 +268,21 @@ def main(datapath: str): ### Plots: # 1. All recordings, all fish, all chirps - # One CTC, one PTC + # One CTC, one PTC # 2. All recordings, only winners - # One CTC, one PTC + # One CTC, one PTC # 3. All recordings, all losers - # One CTC, one PTC + # One CTC, one PTC #### Chirp counts per fish general ##### fig2, ax2 = plt.subplots() - x = ['Fish1', 'Fish2'] + x = ["Fish1", "Fish2"] width = 0.35 ax2.bar(x, fish, width=width) - ax2.set_ylabel('Chirp count') + ax2.set_ylabel("Chirp count") # plt.show() plt.close() - ##### Count chirps emitted during chasing events and chirps emitted out of chasing events ##### chirps_in_chasings = [] for onset, offset in zip(chasing_onset, chasing_offset): @@ -251,23 +299,24 @@ def main(datapath: str): counts_chirps_chasings += 1 # chirps in chasing events - fig3 , ax3 = plt.subplots() - ax3.bar(['Chirps in chasing events', 'Chasing events without Chirps'], [counts_chirps_chasings, chasings_without_chirps], width=width) - plt.ylabel('Count') + fig3, ax3 = plt.subplots() + ax3.bar( + ["Chirps in chasing events", "Chasing events without Chirps"], + [counts_chirps_chasings, chasings_without_chirps], + width=width, + ) + plt.ylabel("Count") # plt.show() - plt.close() + plt.close() # comparison between chasing events with and without chirps - - embed() exit() - -if __name__ == '__main__': +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/2020-05-13-10_00/' - datapath = '../data/mount_data/2020-05-13-10_00/' + datapath = "../data/mount_data/2020-05-13-10_00/" + datapath = "../data/mount_data/2020-05-13-10_00/" main(datapath) diff --git a/code/chirp_sim.py b/code/chirp_sim.py index 5433b36..d7023a5 100644 --- a/code/chirp_sim.py +++ b/code/chirp_sim.py @@ -8,30 +8,27 @@ from modules.datahandling import instantaneous_frequency from modules.simulations import create_chirp - # trying thunderfish fakefish chirp simulation --------------------------------- samplerate = 44100 freq, ampl = fakefish.chirps(eodf=500, chirp_contrast=0.2) -data = fakefish.wavefish_eods(fish='Alepto', frequency=freq, phase0=3, samplerate=samplerate) +data = fakefish.wavefish_eods( + fish="Alepto", frequency=freq, phase0=3, samplerate=samplerate +) # filter signal with bandpass_filter -data_filterd = bandpass_filter(data*ampl+1, samplerate, 0.01, 1.99) +data_filterd = bandpass_filter(data * ampl + 1, samplerate, 0.01, 1.99) embed() data_freq_time, data_freq = instantaneous_frequency(data, samplerate, 5) fig, ax = plt.subplots(4, 1, figsize=(20 / 2.54, 12 / 2.54), sharex=True) -ax[0].plot(np.arange(len(data))/samplerate, data*ampl) -#ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red') -ax[1].plot(np.arange(len(data_filterd))/samplerate, data_filterd) -ax[2].plot(np.arange(len(freq))/samplerate, freq) +ax[0].plot(np.arange(len(data)) / samplerate, data * ampl) +# ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red') +ax[1].plot(np.arange(len(data_filterd)) / samplerate, data_filterd) +ax[2].plot(np.arange(len(freq)) / samplerate, freq) ax[3].plot(data_freq_time, data_freq) plt.show() embed() - - - - diff --git a/code/chirpdetection.py b/code/chirpdetection.py index 95800df..937bde4 100755 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import matplotlib.gridspec as gr from scipy.signal import find_peaks from thunderfish.powerspectrum import spectrogram, decibel + # from sklearn.preprocessing import normalize from modules.filters import bandpass_filter, envelope, highpass_filter @@ -18,7 +19,7 @@ from modules.datahandling import ( purge_duplicates, group_timestamps, instantaneous_frequency, - instantaneous_frequency2, + instantaneous_frequency2, minmaxnorm, ) @@ -59,7 +60,6 @@ class ChirpPlotBuffer: frequency_peaks: np.ndarray def plot_buffer(self, chirps: np.ndarray, plot: str) -> None: - logger.debug("Starting plotting") # make data for plotting @@ -135,7 +135,6 @@ class ChirpPlotBuffer: ax0.set_ylim(np.min(self.frequency) - 100, np.max(self.frequency) + 200) for track_id in self.data.ids: - t0_track = self.t0_old - 5 dt_track = self.dt + 10 window_idx = np.arange(len(self.data.idx))[ @@ -176,10 +175,16 @@ class ChirpPlotBuffer: # ) ax0.axhline( - q50 - self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed" + q50 - self.config.minimal_bandwidth / 2, + color=ps.gblue1, + lw=1, + ls="dashed", ) ax0.axhline( - q50 + self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed" + q50 + self.config.minimal_bandwidth / 2, + color=ps.gblue1, + lw=1, + ls="dashed", ) ax0.axhline(search_lower, color=ps.gblue2, lw=1, ls="dashed") ax0.axhline(search_upper, color=ps.gblue2, lw=1, ls="dashed") @@ -205,7 +210,11 @@ class ChirpPlotBuffer: # plot waveform of filtered signal ax1.plot( - self.time, self.baseline * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5 + self.time, + self.baseline * waveform_scaler, + c=ps.gray, + lw=lw, + alpha=0.5, ) ax1.plot( self.time, @@ -216,7 +225,13 @@ class ChirpPlotBuffer: ) # plot waveform of filtered search signal - ax2.plot(self.time, self.search * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5) + ax2.plot( + self.time, + self.search * waveform_scaler, + c=ps.gray, + lw=lw, + alpha=0.5, + ) ax2.plot( self.time, self.search_envelope_unfiltered * waveform_scaler, @@ -238,9 +253,7 @@ class ChirpPlotBuffer: # ax4.plot( # self.time, self.baseline_envelope * waveform_scaler, c=ps.gblue1, lw=lw # ) - ax4.plot( - self.time, self.baseline_envelope, c=ps.gblue1, lw=lw - ) + ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=lw) ax4.scatter( (self.time)[self.baseline_peaks], # (self.baseline_envelope * waveform_scaler)[self.baseline_peaks], @@ -269,7 +282,9 @@ class ChirpPlotBuffer: ) # plot filtered instantaneous frequency - ax6.plot(self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw) + ax6.plot( + self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw + ) ax6.scatter( self.frequency_time[self.frequency_peaks], self.frequency_filtered[self.frequency_peaks], @@ -303,7 +318,9 @@ class ChirpPlotBuffer: # ax7.spines.bottom.set_bounds((0, 5)) ax0.set_xlim(0, self.config.window) - plt.subplots_adjust(left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2) + plt.subplots_adjust( + left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2 + ) fig.align_labels() if plot == "show": @@ -408,7 +425,9 @@ def extract_frequency_bands( q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2 # filter baseline - filtered_baseline = bandpass_filter(raw_data, samplerate, lowf=q25, highf=q75) + filtered_baseline = bandpass_filter( + raw_data, samplerate, lowf=q25, highf=q75 + ) # filter search area filtered_search_freq = bandpass_filter( @@ -453,12 +472,14 @@ def window_median_all_track_ids( track_ids = [] for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): - # the window index combines the track id and the time window window_idx = np.arange(len(data.idx))[ (data.ident == track_id) & (data.time[data.idx] >= window_start_seconds) - & (data.time[data.idx] <= (window_start_seconds + window_duration_seconds)) + & ( + data.time[data.idx] + <= (window_start_seconds + window_duration_seconds) + ) ] if len(data.freq[window_idx]) > 0: @@ -595,15 +616,15 @@ def find_searchband( # iterate through theses tracks if check_track_ids.size != 0: - for j, check_track_id in enumerate(check_track_ids): - q25_temp = q25[percentiles_ids == check_track_id] q75_temp = q75[percentiles_ids == check_track_id] bool_lower[search_window > q25_temp - config.search_res] = False bool_upper[search_window < q75_temp + config.search_res] = False - search_window_bool[(bool_lower == False) & (bool_upper == False)] = False + search_window_bool[ + (bool_lower == False) & (bool_upper == False) + ] = False # find gaps in search window search_window_indices = np.arange(len(search_window)) @@ -622,7 +643,9 @@ def find_searchband( # if the first value is -1, the array starst with true, so a gap if nonzeros[0] == -1: stops = search_window_indices[search_window_gaps == -1] - starts = np.append(0, search_window_indices[search_window_gaps == 1]) + starts = np.append( + 0, search_window_indices[search_window_gaps == 1] + ) # if the last value is -1, the array ends with true, so a gap if nonzeros[-1] == 1: @@ -659,7 +682,6 @@ def find_searchband( def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: - assert plot in [ "save", "show", @@ -729,7 +751,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: multiwindow_ids = [] for st, window_start_index in enumerate(window_start_indices): - logger.info(f"Processing window {st+1} of {len(window_start_indices)}") window_start_seconds = window_start_index / data.raw_rate @@ -744,8 +765,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: ) # iterate through all fish - for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): - + for tr, track_id in enumerate( + np.unique(data.ident[~np.isnan(data.ident)]) + ): logger.debug(f"Processing track {tr} of {len(data.ids)}") # get index of track data in this time window @@ -773,16 +795,17 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: nanchecker = np.unique(np.isnan(current_powers)) if (len(nanchecker) == 1) and nanchecker[0] is True: logger.warning( - f"No powers available for track {track_id} window {st}," "skipping." + f"No powers available for track {track_id} window {st}," + "skipping." ) continue # find the strongest electrodes for the current fish in the current # window - best_electrode_index = np.argsort(np.nanmean(current_powers, axis=0))[ - -config.number_electrodes : - ] + best_electrode_index = np.argsort( + np.nanmean(current_powers, axis=0) + )[-config.number_electrodes :] # find a frequency above the baseline of the current fish in which # no other fish is active to search for chirps there @@ -802,9 +825,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # iterate through electrodes for el, electrode_index in enumerate(best_electrode_index): - logger.debug( - f"Processing electrode {el+1} of " f"{len(best_electrode_index)}" + f"Processing electrode {el+1} of " + f"{len(best_electrode_index)}" ) # LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------ @@ -813,7 +836,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: current_raw_data = data.raw[ window_start_index:window_stop_index, electrode_index ] - current_raw_time = raw_time[window_start_index:window_stop_index] + current_raw_time = raw_time[ + window_start_index:window_stop_index + ] # EXTRACT FEATURES -------------------------------------------- @@ -839,8 +864,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # because the instantaneous frequency is not reliable there amplitude_mask = mask_low_amplitudes( - baseline_envelope_unfiltered, - config.baseline_min_amplitude + baseline_envelope_unfiltered, config.baseline_min_amplitude ) # highpass filter baseline envelope to remove slower @@ -877,27 +901,30 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # filtered baseline such as the one we are working with. baseline_frequency = instantaneous_frequency( - baselineband, - data.raw_rate, - config.baseline_frequency_smoothing + baselineband, + data.raw_rate, + config.baseline_frequency_smoothing, ) # Take the absolute of the instantaneous frequency to invert - # troughs into peaks. This is nessecary since the narrow + # troughs into peaks. This is nessecary since the narrow # pass band introduces these anomalies. Also substract by the # median to set it to 0. - + baseline_frequency_filtered = np.abs( baseline_frequency - np.median(baseline_frequency) ) - # check if there is at least one superthreshold peak on the - # instantaneous and exit the loop if not. This is used to - # prevent windows that do definetely not include a chirp - # to enter normalization, where small changes due to noise - # would be amplified + # check if there is at least one superthreshold peak on the + # instantaneous and exit the loop if not. This is used to + # prevent windows that do definetely not include a chirp + # to enter normalization, where small changes due to noise + # would be amplified - if not has_chirp(baseline_frequency_filtered[amplitude_mask], config.baseline_frequency_peakheight): + if not has_chirp( + baseline_frequency_filtered[amplitude_mask], + config.baseline_frequency_peakheight, + ): continue # CUT OFF OVERLAP --------------------------------------------- @@ -912,14 +939,20 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: current_raw_time = current_raw_time[no_edges] baselineband = baselineband[no_edges] - baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges] + baseline_envelope_unfiltered = baseline_envelope_unfiltered[ + no_edges + ] searchband = searchband[no_edges] baseline_envelope = baseline_envelope[no_edges] - search_envelope_unfiltered = search_envelope_unfiltered[no_edges] + search_envelope_unfiltered = search_envelope_unfiltered[ + no_edges + ] search_envelope = search_envelope[no_edges] baseline_frequency = baseline_frequency[no_edges] - baseline_frequency_filtered = baseline_frequency_filtered[no_edges] + baseline_frequency_filtered = baseline_frequency_filtered[ + no_edges + ] baseline_frequency_time = current_raw_time # # get instantaneous frequency withoup edges @@ -960,13 +993,16 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: ) # detect peaks inst_freq_filtered frequency_peak_indices, _ = find_peaks( - baseline_frequency_filtered, prominence=config.frequency_prominence + baseline_frequency_filtered, + prominence=config.frequency_prominence, ) # DETECT CHIRPS IN SEARCH WINDOW ------------------------------ # get the peak timestamps from the peak indices - baseline_peak_timestamps = current_raw_time[baseline_peak_indices] + baseline_peak_timestamps = current_raw_time[ + baseline_peak_indices + ] search_peak_timestamps = current_raw_time[search_peak_indices] frequency_peak_timestamps = baseline_frequency_time[ @@ -1015,7 +1051,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: ) if chirp_detected or (debug != "elecrode"): - logger.debug("Detected chirp, ititialize buffer ...") # save data to Buffer @@ -1107,7 +1142,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: multiwindow_chirps_flat = [] multiwindow_ids_flat = [] for track_id in np.unique(multiwindow_ids): - # get chirps for this fish and flatten the list current_track_bool = np.asarray(multiwindow_ids) == track_id current_track_chirps = flatten( @@ -1116,7 +1150,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # add flattened chirps to the list multiwindow_chirps_flat.extend(current_track_chirps) - multiwindow_ids_flat.extend(list(np.ones_like(current_track_chirps) * track_id)) + multiwindow_ids_flat.extend( + list(np.ones_like(current_track_chirps) * track_id) + ) # purge duplicates, i.e. chirps that are very close to each other # duplites arise due to overlapping windows diff --git a/code/chirpdetector_conf.yml b/code/chirpdetector_conf.yml index 371326e..bb4598b 100755 --- a/code/chirpdetector_conf.yml +++ b/code/chirpdetector_conf.yml @@ -1,37 +1,37 @@ # Path setup ------------------------------------------------------------------ -dataroot: "../data/" # path to data -outputdir: "../output/" # path to save plots to +dataroot: "../data/" # path to data +outputdir: "../output/" # path to save plots to # Rolling window parameters --------------------------------------------------- -window: 5 # rolling window length in seconds +window: 5 # rolling window length in seconds overlap: 1 # window overlap in seconds edge: 0.25 # window edge cufoffs to mitigate filter edge effects # Electrode iteration parameters ---------------------------------------------- -number_electrodes: 2 # number of electrodes to go over -minimum_electrodes: 1 # mimumun number of electrodes a chirp must be on +number_electrodes: 2 # number of electrodes to go over +minimum_electrodes: 1 # mimumun number of electrodes a chirp must be on # Feature extraction parameters ----------------------------------------------- -search_df_lower: 20 # start searching this far above the baseline -search_df_upper: 100 # stop searching this far above the baseline -search_res: 1 # search window resolution -default_search_freq: 60 # search here if no need for a search frequency -minimal_bandwidth: 10 # minimal bandpass filter width for baseline -search_bandwidth: 10 # minimal bandpass filter width for search frequency -baseline_frequency_smoothing: 3 # instantaneous frequency smoothing +search_df_lower: 20 # start searching this far above the baseline +search_df_upper: 100 # stop searching this far above the baseline +search_res: 1 # search window resolution +default_search_freq: 60 # search here if no need for a search frequency +minimal_bandwidth: 10 # minimal bandpass filter width for baseline +search_bandwidth: 10 # minimal bandpass filter width for search frequency +baseline_frequency_smoothing: 3 # instantaneous frequency smoothing # Feature processing parameters ----------------------------------------------- -baseline_frequency_peakheight: 5 # the min peak height of the baseline instfreq -baseline_min_amplitude: 0.0001 # the minimal value of the baseline envelope -baseline_envelope_cutoff: 25 # envelope estimation cutoff -baseline_envelope_bandpass_lowf: 2 # envelope badpass lower cutoff -baseline_envelope_bandpass_highf: 100 # envelope bandbass higher cutoff -search_envelope_cutoff: 10 # search envelope estimation cufoff +baseline_frequency_peakheight: 5 # the min peak height of the baseline instfreq +baseline_min_amplitude: 0.0001 # the minimal value of the baseline envelope +baseline_envelope_cutoff: 25 # envelope estimation cutoff +baseline_envelope_bandpass_lowf: 2 # envelope badpass lower cutoff +baseline_envelope_bandpass_highf: 100 # envelope bandbass higher cutoff +search_envelope_cutoff: 10 # search envelope estimation cufoff # Peak detecion parameters ---------------------------------------------------- # baseline_prominence: 0.00005 # peak prominence threshold for baseline envelope @@ -39,9 +39,8 @@ search_envelope_cutoff: 10 # search envelope estimation cufoff # frequency_prominence: 2 # peak prominence threshold for baseline freq baseline_prominence: 0.3 # peak prominence threshold for baseline envelope -search_prominence: 0.3 # peak prominence threshold for search envelope -frequency_prominence: 0.3 # peak prominence threshold for baseline freq +search_prominence: 0.3 # peak prominence threshold for search envelope +frequency_prominence: 0.3 # peak prominence threshold for baseline freq # Classify events as chirps if they are less than this time apart chirp_window_threshold: 0.02 - diff --git a/code/eventchirpsplots.py b/code/eventchirpsplots.py index 4ebaa66..5003a7d 100644 --- a/code/eventchirpsplots.py +++ b/code/eventchirpsplots.py @@ -35,28 +35,36 @@ class Behavior: """ def __init__(self, folder_path: str) -> None: - print(f'{folder_path}') - LED_on_time_BORIS = np.load(os.path.join( - folder_path, 'LED_on_time.npy'), allow_pickle=True) - self.time = np.load(os.path.join( - folder_path, "times.npy"), allow_pickle=True) - csv_filename = [f for f in os.listdir(folder_path) if f.endswith( - '.csv')][0] # check if there are more than one csv file + print(f"{folder_path}") + LED_on_time_BORIS = np.load( + os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True + ) + self.time = np.load( + os.path.join(folder_path, "times.npy"), allow_pickle=True + ) + csv_filename = [ + f for f in os.listdir(folder_path) if f.endswith(".csv") + ][ + 0 + ] # check if there are more than one csv file self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) - self.chirps = np.load(os.path.join( - folder_path, 'chirps.npy'), allow_pickle=True) - self.chirps_ids = np.load(os.path.join( - folder_path, 'chirp_ids.npy'), allow_pickle=True) + self.chirps = np.load( + os.path.join(folder_path, "chirps.npy"), allow_pickle=True + ) + self.chirps_ids = np.load( + os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True + ) for k, key in enumerate(self.dataframe.keys()): key = key.lower() - if ' ' in key: - key = key.replace(' ', '_') - if '(' in key: - key = key.replace('(', '') - key = key.replace(')', '') - setattr(self, key, np.array( - self.dataframe[self.dataframe.keys()[k]])) + if " " in key: + key = key.replace(" ", "_") + if "(" in key: + key = key.replace("(", "") + key = key.replace(")", "") + setattr( + self, key, np.array(self.dataframe[self.dataframe.keys()[k]]) + ) last_LED_t_BORIS = LED_on_time_BORIS[-1] real_time_range = self.time[-1] - self.time[0] @@ -95,17 +103,14 @@ temporal encpding needs to be corrected ... not exactly 25FPS. def correct_chasing_events( - category: np.ndarray, - timestamps: np.ndarray + category: np.ndarray, timestamps: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: + onset_ids = np.arange(len(category))[category == 0] + offset_ids = np.arange(len(category))[category == 1] - onset_ids = np.arange( - len(category))[category == 0] - offset_ids = np.arange( - len(category))[category == 1] - - wrong_bh = np.arange(len(category))[ - category != 2][:-1][np.diff(category[category != 2]) == 0] + wrong_bh = np.arange(len(category))[category != 2][:-1][ + np.diff(category[category != 2]) == 0 + ] if onset_ids[0] > offset_ids[0]: offset_ids = np.delete(offset_ids, 0) help_index = offset_ids[0] @@ -117,12 +122,12 @@ def correct_chasing_events( # Check whether on- or offset is longer and calculate length difference if len(onset_ids) > len(offset_ids): len_diff = len(onset_ids) - len(offset_ids) - logger.info(f'Onsets are greater than offsets by {len_diff}') + logger.info(f"Onsets are greater than offsets by {len_diff}") elif len(onset_ids) < len(offset_ids): len_diff = len(offset_ids) - len(onset_ids) - logger.info(f'Offsets are greater than onsets by {len_diff}') + logger.info(f"Offsets are greater than onsets by {len_diff}") elif len(onset_ids) == len(offset_ids): - logger.info('Chasing events are equal') + logger.info("Chasing events are equal") return category, timestamps @@ -135,8 +140,7 @@ def event_triggered_chirps( dt: float, width: float, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - - event_chirps = [] # chirps that are in specified window around event + event_chirps = [] # chirps that are in specified window around event # timestamps of chirps around event centered on the event timepoint centered_chirps = [] @@ -159,16 +163,19 @@ def event_triggered_chirps( else: # convert list of arrays to one array for plotting centered_chirps = np.concatenate(centered_chirps, axis=0) - centered_chirps_convolved = (acausal_kde1d( - centered_chirps, time, width)) / len(event) + centered_chirps_convolved = ( + acausal_kde1d(centered_chirps, time, width) + ) / len(event) return event_chirps, centered_chirps, centered_chirps_convolved def main(datapath: str): - foldernames = [ - datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath + x)] + datapath + x + "/" + for x in os.listdir(datapath) + if os.path.isdir(datapath + x) + ] nrecording_chirps = [] nrecording_chirps_fish_ids = [] @@ -179,7 +186,7 @@ def main(datapath: str): # Iterate over all recordings and save chirp- and event-timestamps for folder in foldernames: # exclude folder with empty LED_on_time.npy - if folder == '../data/mount_data/2020-05-12-10_00/': + if folder == "../data/mount_data/2020-05-12-10_00/": continue bh = Behavior(folder) @@ -209,7 +216,7 @@ def main(datapath: str): time_before_event = 30 time_after_event = 60 dt = 0.01 - width = 1.5 # width of kernel for all recordings, currently gaussian kernel + width = 1.5 # width of kernel for all recordings, currently gaussian kernel recording_width = 2 # width of kernel for each recording time = np.arange(-time_before_event, time_after_event, dt) @@ -232,18 +239,47 @@ def main(datapath: str): physical_contacts = nrecording_physicals[i] # Chirps around chasing onsets - _, centered_chasing_onset_chirps, cc_chasing_onset_chirps = event_triggered_chirps( - chasing_onsets, chirps, time_before_event, time_after_event, dt, recording_width) + ( + _, + centered_chasing_onset_chirps, + cc_chasing_onset_chirps, + ) = event_triggered_chirps( + chasing_onsets, + chirps, + time_before_event, + time_after_event, + dt, + recording_width, + ) # Chirps around chasing offsets - _, centered_chasing_offset_chirps, cc_chasing_offset_chirps = event_triggered_chirps( - chasing_offsets, chirps, time_before_event, time_after_event, dt, recording_width) + ( + _, + centered_chasing_offset_chirps, + cc_chasing_offset_chirps, + ) = event_triggered_chirps( + chasing_offsets, + chirps, + time_before_event, + time_after_event, + dt, + recording_width, + ) # Chirps around physical contacts - _, centered_physical_chirps, cc_physical_chirps = event_triggered_chirps( - physical_contacts, chirps, time_before_event, time_after_event, dt, recording_width) + ( + _, + centered_physical_chirps, + cc_physical_chirps, + ) = event_triggered_chirps( + physical_contacts, + chirps, + time_before_event, + time_after_event, + dt, + recording_width, + ) nrecording_centered_onset_chirps.append(centered_chasing_onset_chirps) - nrecording_centered_offset_chirps.append( - centered_chasing_offset_chirps) + nrecording_centered_offset_chirps.append(centered_chasing_offset_chirps) nrecording_centered_physical_chirps.append(centered_physical_chirps) ## Shuffled chirps ## @@ -331,12 +367,13 @@ def main(datapath: str): # New bootstrapping approach for n in range(nbootstrapping): - diff_onset = np.diff( - np.sort(flatten(nrecording_centered_onset_chirps))) + diff_onset = np.diff(np.sort(flatten(nrecording_centered_onset_chirps))) diff_offset = np.diff( - np.sort(flatten(nrecording_centered_offset_chirps))) + np.sort(flatten(nrecording_centered_offset_chirps)) + ) diff_physical = np.diff( - np.sort(flatten(nrecording_centered_physical_chirps))) + np.sort(flatten(nrecording_centered_physical_chirps)) + ) np.random.shuffle(diff_onset) shuffled_onset = np.cumsum(diff_onset) @@ -345,9 +382,11 @@ def main(datapath: str): np.random.shuffle(diff_physical) shuffled_physical = np.cumsum(diff_physical) - kde_onset (acausal_kde1d(shuffled_onset, time, width))/(27*100) - kde_offset = (acausal_kde1d(shuffled_offset, time, width))/(27*100) - kde_physical = (acausal_kde1d(shuffled_physical, time, width))/(27*100) + kde_onset(acausal_kde1d(shuffled_onset, time, width)) / (27 * 100) + kde_offset = (acausal_kde1d(shuffled_offset, time, width)) / (27 * 100) + kde_physical = (acausal_kde1d(shuffled_physical, time, width)) / ( + 27 * 100 + ) bootstrap_onset.append(kde_onset) bootstrap_offset.append(kde_offset) @@ -355,11 +394,14 @@ def main(datapath: str): # New shuffle approach q5, q50, q95 onset_q5, onset_median, onset_q95 = np.percentile( - bootstrap_onset, [5, 50, 95], axis=0) + bootstrap_onset, [5, 50, 95], axis=0 + ) offset_q5, offset_median, offset_q95 = np.percentile( - bootstrap_offset, [5, 50, 95], axis=0) + bootstrap_offset, [5, 50, 95], axis=0 + ) physical_q5, physical_median, physical_q95 = np.percentile( - bootstrap_physical, [5, 50, 95], axis=0) + bootstrap_physical, [5, 50, 95], axis=0 + ) # vstack um 1. Dim zu cutten # nrecording_shuffled_convolved_onset_chirps = np.vstack(nrecording_shuffled_convolved_onset_chirps) @@ -378,45 +420,66 @@ def main(datapath: str): # Flatten event timestamps all_onsets = np.concatenate( - nrecording_chasing_onsets).ravel() # not centered + nrecording_chasing_onsets + ).ravel() # not centered all_offsets = np.concatenate( - nrecording_chasing_offsets).ravel() # not centered - all_physicals = np.concatenate( - nrecording_physicals).ravel() # not centered + nrecording_chasing_offsets + ).ravel() # not centered + all_physicals = np.concatenate(nrecording_physicals).ravel() # not centered # Flatten all chirps around events all_onset_chirps = np.concatenate( - nrecording_centered_onset_chirps).ravel() # centered + nrecording_centered_onset_chirps + ).ravel() # centered all_offset_chirps = np.concatenate( - nrecording_centered_offset_chirps).ravel() # centered + nrecording_centered_offset_chirps + ).ravel() # centered all_physical_chirps = np.concatenate( - nrecording_centered_physical_chirps).ravel() # centered + nrecording_centered_physical_chirps + ).ravel() # centered # Convolute all chirps # Divide by total number of each event over all recordings - all_onset_chirps_convolved = (acausal_kde1d( - all_onset_chirps, time, width)) / len(all_onsets) - all_offset_chirps_convolved = (acausal_kde1d( - all_offset_chirps, time, width)) / len(all_offsets) - all_physical_chirps_convolved = (acausal_kde1d( - all_physical_chirps, time, width)) / len(all_physicals) + all_onset_chirps_convolved = ( + acausal_kde1d(all_onset_chirps, time, width) + ) / len(all_onsets) + all_offset_chirps_convolved = ( + acausal_kde1d(all_offset_chirps, time, width) + ) / len(all_offsets) + all_physical_chirps_convolved = ( + acausal_kde1d(all_physical_chirps, time, width) + ) / len(all_physicals) # Plot all events with all shuffled - fig, ax = plt.subplots(1, 3, figsize=( - 28*ps.cm, 16*ps.cm, ), constrained_layout=True, sharey='all') + fig, ax = plt.subplots( + 1, + 3, + figsize=( + 28 * ps.cm, + 16 * ps.cm, + ), + constrained_layout=True, + sharey="all", + ) # offsets = np.arange(1,28,1) - ax[0].set_xlabel('Time[s]') + ax[0].set_xlabel("Time[s]") # Plot chasing onsets - ax[0].set_ylabel('Chirp rate [Hz]') + ax[0].set_ylabel("Chirp rate [Hz]") ax[0].plot(time, all_onset_chirps_convolved, color=ps.yellow, zorder=2) ax0 = ax[0].twinx() nrecording_centered_onset_chirps = np.asarray( - nrecording_centered_onset_chirps, dtype=object) - ax0.eventplot(np.array(nrecording_centered_onset_chirps), - linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) - ax0.vlines(0, 0, 1.5, ps.white, 'dashed') - ax[0].set_zorder(ax0.get_zorder()+1) + nrecording_centered_onset_chirps, dtype=object + ) + ax0.eventplot( + np.array(nrecording_centered_onset_chirps), + linelengths=0.5, + colors=ps.gray, + alpha=0.25, + zorder=1, + ) + ax0.vlines(0, 0, 1.5, ps.white, "dashed") + ax[0].set_zorder(ax0.get_zorder() + 1) ax[0].patch.set_visible(False) ax0.set_yticklabels([]) ax0.set_yticks([]) @@ -426,15 +489,21 @@ def main(datapath: str): ax[0].plot(time, onset_median, color=ps.black) # Plot chasing offets - ax[1].set_xlabel('Time[s]') + ax[1].set_xlabel("Time[s]") ax[1].plot(time, all_offset_chirps_convolved, color=ps.orange, zorder=2) ax1 = ax[1].twinx() nrecording_centered_offset_chirps = np.asarray( - nrecording_centered_offset_chirps, dtype=object) - ax1.eventplot(np.array(nrecording_centered_offset_chirps), - linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) - ax1.vlines(0, 0, 1.5, ps.white, 'dashed') - ax[1].set_zorder(ax1.get_zorder()+1) + nrecording_centered_offset_chirps, dtype=object + ) + ax1.eventplot( + np.array(nrecording_centered_offset_chirps), + linelengths=0.5, + colors=ps.gray, + alpha=0.25, + zorder=1, + ) + ax1.vlines(0, 0, 1.5, ps.white, "dashed") + ax[1].set_zorder(ax1.get_zorder() + 1) ax[1].patch.set_visible(False) ax1.set_yticklabels([]) ax1.set_yticks([]) @@ -444,24 +513,31 @@ def main(datapath: str): ax[1].plot(time, offset_median, color=ps.black) # Plot physical contacts - ax[2].set_xlabel('Time[s]') + ax[2].set_xlabel("Time[s]") ax[2].plot(time, all_physical_chirps_convolved, color=ps.maroon, zorder=2) ax2 = ax[2].twinx() nrecording_centered_physical_chirps = np.asarray( - nrecording_centered_physical_chirps, dtype=object) - ax2.eventplot(np.array(nrecording_centered_physical_chirps), - linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) - ax2.vlines(0, 0, 1.5, ps.white, 'dashed') - ax[2].set_zorder(ax2.get_zorder()+1) + nrecording_centered_physical_chirps, dtype=object + ) + ax2.eventplot( + np.array(nrecording_centered_physical_chirps), + linelengths=0.5, + colors=ps.gray, + alpha=0.25, + zorder=1, + ) + ax2.vlines(0, 0, 1.5, ps.white, "dashed") + ax[2].set_zorder(ax2.get_zorder() + 1) ax[2].patch.set_visible(False) ax2.set_yticklabels([]) ax2.set_yticks([]) # ax[2].fill_between(time, shuffled_q5_physical, shuffled_q95_physical, color=ps.gray, alpha=0.5) # ax[2].plot(time, shuffled_median_physical, ps.black) - ax[2].fill_between(time, physical_q5, physical_q95, - color=ps.gray, alpha=0.5) + ax[2].fill_between( + time, physical_q5, physical_q95, color=ps.gray, alpha=0.5 + ) ax[2].plot(time, physical_median, ps.black) - fig.suptitle('All recordings') + fig.suptitle("All recordings") plt.show() plt.close() @@ -587,7 +663,7 @@ def main(datapath: str): #### Chirps around events, only losers, one recording #### -if __name__ == '__main__': +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/' + datapath = "../data/mount_data/" main(datapath) diff --git a/code/extract_chirps.py b/code/extract_chirps.py index 77e3e8d..900f0a2 100644 --- a/code/extract_chirps.py +++ b/code/extract_chirps.py @@ -8,50 +8,51 @@ from IPython import embed def get_valid_datasets(dataroot): - - datasets = sorted([name for name in os.listdir(dataroot) if os.path.isdir( - os.path.join(dataroot, name))]) + datasets = sorted( + [ + name + for name in os.listdir(dataroot) + if os.path.isdir(os.path.join(dataroot, name)) + ] + ) valid_datasets = [] for dataset in datasets: - path = os.path.join(dataroot, dataset) - csv_name = '-'.join(dataset.split('-')[:3]) + '.csv' + csv_name = "-".join(dataset.split("-")[:3]) + ".csv" if os.path.exists(os.path.join(path, csv_name)) is False: continue - if os.path.exists(os.path.join(path, 'ident_v.npy')) is False: + if os.path.exists(os.path.join(path, "ident_v.npy")) is False: continue - ident = np.load(os.path.join(path, 'ident_v.npy')) + ident = np.load(os.path.join(path, "ident_v.npy")) number_of_fish = len(np.unique(ident[~np.isnan(ident)])) if number_of_fish != 2: continue valid_datasets.append(dataset) - datapaths = [os.path.join(dataroot, dataset) + - '/' for dataset in valid_datasets] + datapaths = [ + os.path.join(dataroot, dataset) + "/" for dataset in valid_datasets + ] return datapaths, valid_datasets def main(datapaths): - for path in datapaths: - chirpdetection(path, plot='show') - - -if __name__ == '__main__': + chirpdetection(path, plot="show") - dataroot = '../data/mount_data/' +if __name__ == "__main__": + dataroot = "../data/mount_data/" - datapaths, valid_datasets= get_valid_datasets(dataroot) + datapaths, valid_datasets = get_valid_datasets(dataroot) - recs = pd.DataFrame(columns=['recording'], data=valid_datasets) - recs.to_csv('../recs.csv', index=False) + recs = pd.DataFrame(columns=["recording"], data=valid_datasets) + recs.to_csv("../recs.csv", index=False) # datapaths = ['../data/mount_data/2020-03-25-10_00/'] main(datapaths) diff --git a/code/get_behaviour.py b/code/get_behaviour.py index 36311ca..3513c1b 100644 --- a/code/get_behaviour.py +++ b/code/get_behaviour.py @@ -1,4 +1,4 @@ -import os +import os from paramiko import SSHClient from scp import SCPClient from IPython import embed @@ -7,29 +7,41 @@ from pandas import read_csv ssh = SSHClient() ssh.load_system_host_keys() -ssh.connect(hostname='kraken', - username='efish', - password='fwNix4U', - ) +ssh.connect( + hostname="kraken", + username="efish", + password="fwNix4U", +) # SCPCLient takes a paramiko transport as its only argument scp = SCPClient(ssh.get_transport()) -data = read_csv('../recs.csv') -foldernames = data['recording'].values +data = read_csv("../recs.csv") +foldernames = data["recording"].values -directory = f'/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/' +directory = f"/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/" for foldername in foldernames: - - if not os.path.exists(directory+foldername): - os.makedirs(directory+foldername) - - files = [('-').join(foldername.split('-')[:3])+'.csv','chirp_ids.npy', 'chirps.npy', 'fund_v.npy', 'ident_v.npy', 'idx_v.npy', 'times.npy', 'spec.npy', 'LED_on_time.npy', 'sign_v.npy'] - + if not os.path.exists(directory + foldername): + os.makedirs(directory + foldername) + + files = [ + ("-").join(foldername.split("-")[:3]) + ".csv", + "chirp_ids.npy", + "chirps.npy", + "fund_v.npy", + "ident_v.npy", + "idx_v.npy", + "times.npy", + "spec.npy", + "LED_on_time.npy", + "sign_v.npy", + ] for f in files: - scp.get(f'/home/efish/behavior/2019_tube_competition/{foldername}/{f}', - directory+foldername) + scp.get( + f"/home/efish/behavior/2019_tube_competition/{foldername}/{f}", + directory + foldername, + ) scp.close() diff --git a/code/modules/behaviour_handling.py b/code/modules/behaviour_handling.py index 94a0ca1..a50d67a 100644 --- a/code/modules/behaviour_handling.py +++ b/code/modules/behaviour_handling.py @@ -30,12 +30,12 @@ class Behavior: """ def __init__(self, folder_path: str) -> None: - - LED_on_time_BORIS = np.load(os.path.join( - folder_path, 'LED_on_time.npy'), allow_pickle=True) + LED_on_time_BORIS = np.load( + os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True + ) csv_filename = os.path.split(folder_path[:-1])[-1] - csv_filename = '-'.join(csv_filename.split('-')[:-1]) + '.csv' + csv_filename = "-".join(csv_filename.split("-")[:-1]) + ".csv" # embed() # csv_filename = [f for f in os.listdir( @@ -43,31 +43,39 @@ class Behavior: # logger.info(f'CSV file: {csv_filename}') self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) - self.chirps = np.load(os.path.join( - folder_path, 'chirps.npy'), allow_pickle=True) - self.chirps_ids = np.load(os.path.join( - folder_path, 'chirp_ids.npy'), allow_pickle=True) - - self.ident = np.load(os.path.join( - folder_path, 'ident_v.npy'), allow_pickle=True) - self.idx = np.load(os.path.join( - folder_path, 'idx_v.npy'), allow_pickle=True) - self.freq = np.load(os.path.join( - folder_path, 'fund_v.npy'), allow_pickle=True) - self.time = np.load(os.path.join( - folder_path, "times.npy"), allow_pickle=True) - self.spec = np.load(os.path.join( - folder_path, "spec.npy"), allow_pickle=True) + self.chirps = np.load( + os.path.join(folder_path, "chirps.npy"), allow_pickle=True + ) + self.chirps_ids = np.load( + os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True + ) + + self.ident = np.load( + os.path.join(folder_path, "ident_v.npy"), allow_pickle=True + ) + self.idx = np.load( + os.path.join(folder_path, "idx_v.npy"), allow_pickle=True + ) + self.freq = np.load( + os.path.join(folder_path, "fund_v.npy"), allow_pickle=True + ) + self.time = np.load( + os.path.join(folder_path, "times.npy"), allow_pickle=True + ) + self.spec = np.load( + os.path.join(folder_path, "spec.npy"), allow_pickle=True + ) for k, key in enumerate(self.dataframe.keys()): key = key.lower() - if ' ' in key: - key = key.replace(' ', '_') - if '(' in key: - key = key.replace('(', '') - key = key.replace(')', '') - setattr(self, key, np.array( - self.dataframe[self.dataframe.keys()[k]])) + if " " in key: + key = key.replace(" ", "_") + if "(" in key: + key = key.replace("(", "") + key = key.replace(")", "") + setattr( + self, key, np.array(self.dataframe[self.dataframe.keys()[k]]) + ) last_LED_t_BORIS = LED_on_time_BORIS[-1] real_time_range = self.time[-1] - self.time[0] @@ -78,22 +86,19 @@ class Behavior: def correct_chasing_events( - category: np.ndarray, - timestamps: np.ndarray + category: np.ndarray, timestamps: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: + onset_ids = np.arange(len(category))[category == 0] + offset_ids = np.arange(len(category))[category == 1] - onset_ids = np.arange( - len(category))[category == 0] - offset_ids = np.arange( - len(category))[category == 1] - - wrong_bh = np.arange(len(category))[ - category != 2][:-1][np.diff(category[category != 2]) == 0] + wrong_bh = np.arange(len(category))[category != 2][:-1][ + np.diff(category[category != 2]) == 0 + ] if category[category != 2][-1] == 0: wrong_bh = np.append( - wrong_bh, - np.arange(len(category))[category != 2][-1]) + wrong_bh, np.arange(len(category))[category != 2][-1] + ) if onset_ids[0] > offset_ids[0]: offset_ids = np.delete(offset_ids, 0) @@ -103,18 +108,16 @@ def correct_chasing_events( category = np.delete(category, wrong_bh) timestamps = np.delete(timestamps, wrong_bh) - new_onset_ids = np.arange( - len(category))[category == 0] - new_offset_ids = np.arange( - len(category))[category == 1] + new_onset_ids = np.arange(len(category))[category == 0] + new_offset_ids = np.arange(len(category))[category == 1] # Check whether on- or offset is longer and calculate length difference if len(new_onset_ids) > len(new_offset_ids): embed() - logger.warning('Onsets are greater than offsets') + logger.warning("Onsets are greater than offsets") elif len(new_onset_ids) < len(new_offset_ids): - logger.warning('Offsets are greater than onsets') + logger.warning("Offsets are greater than onsets") elif len(new_onset_ids) == len(new_offset_ids): # logger.info('Chasing events are equal') pass @@ -130,13 +133,11 @@ def center_chirps( # dt: float, # width: float, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - - event_chirps = [] # chirps that are in specified window around event + event_chirps = [] # chirps that are in specified window around event # timestamps of chirps around event centered on the event timepoint centered_chirps = [] for event_timestamp in events: - start = event_timestamp - time_before_event stop = event_timestamp + time_after_event chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)] @@ -152,7 +153,8 @@ def center_chirps( if len(centered_chirps) != len(event_chirps): raise ValueError( - 'Non centered chirps and centered chirps are not equal') + "Non centered chirps and centered chirps are not equal" + ) # time = np.arange(-time_before_event, time_after_event, dt) diff --git a/code/modules/datahandling.py b/code/modules/datahandling.py index 0a240ab..68e73cd 100644 --- a/code/modules/datahandling.py +++ b/code/modules/datahandling.py @@ -23,7 +23,9 @@ def minmaxnorm(data): return (data - np.min(data)) / (np.max(data) - np.min(data)) -def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str = 'linear') -> np.ndarray: +def instantaneous_frequency2( + signal: np.ndarray, fs: float, interpolation: str = "linear" +) -> np.ndarray: """ Compute the instantaneous frequency of a periodic signal using zero crossings and resample the frequency using linear or cubic interpolation to match the dimensions of the input array. @@ -55,10 +57,10 @@ def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str = orig_len = len(signal) freq = resample(freq, orig_len) - if interpolation == 'linear': + if interpolation == "linear": freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq) - elif interpolation == 'cubic': - freq = resample(freq, orig_len, window='cubic') + elif interpolation == "cubic": + freq = resample(freq, orig_len, window="cubic") return freq @@ -67,7 +69,7 @@ def instantaneous_frequency( signal: np.ndarray, samplerate: int, smoothing_window: int, - interpolation: str = 'linear', + interpolation: str = "linear", ) -> np.ndarray: """ Compute the instantaneous frequency of a signal that is approximately @@ -120,11 +122,10 @@ def instantaneous_frequency( orig_len = len(signal) freq = resample(instantaneous_frequency, orig_len) - if interpolation == 'linear': + if interpolation == "linear": freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq) - elif interpolation == 'cubic': - freq = resample(freq, orig_len, window='cubic') - + elif interpolation == "cubic": + freq = resample(freq, orig_len, window="cubic") return freq @@ -160,7 +161,6 @@ def purge_duplicates( group = [timestamps[0]] for i in range(1, len(timestamps)): - # check the difference between current timestamp and previous # timestamp is less than the threshold if timestamps[i] - timestamps[i - 1] < threshold: @@ -379,7 +379,6 @@ def acausal_kde1d(spikes, time, width): if __name__ == "__main__": - timestamps = [ [1.2, 1.5, 1.3], [], diff --git a/code/modules/filehandling.py b/code/modules/filehandling.py index c3c71f2..382a49d 100644 --- a/code/modules/filehandling.py +++ b/code/modules/filehandling.py @@ -35,7 +35,6 @@ class LoadData: """ def __init__(self, datapath: str) -> None: - # load raw data self.datapath = datapath self.file = os.path.join(datapath, "traces-grid1.raw") diff --git a/code/modules/filters.py b/code/modules/filters.py index e6d9896..06fe236 100644 --- a/code/modules/filters.py +++ b/code/modules/filters.py @@ -3,10 +3,10 @@ import numpy as np def bandpass_filter( - signal: np.ndarray, - samplerate: float, - lowf: float, - highf: float, + signal: np.ndarray, + samplerate: float, + lowf: float, + highf: float, ) -> np.ndarray: """Bandpass filter a signal. @@ -60,9 +60,7 @@ def highpass_filter( def lowpass_filter( - signal: np.ndarray, - samplerate: float, - cutoff: float + signal: np.ndarray, samplerate: float, cutoff: float ) -> np.ndarray: """Lowpass filter a signal. @@ -86,10 +84,9 @@ def lowpass_filter( return filtered_signal -def envelope(signal: np.ndarray, - samplerate: float, - cutoff_frequency: float - ) -> np.ndarray: +def envelope( + signal: np.ndarray, samplerate: float, cutoff_frequency: float +) -> np.ndarray: """Calculate the envelope of a signal using a lowpass filter. Parameters diff --git a/code/modules/logger.py b/code/modules/logger.py index 5dabf80..ed6d93e 100644 --- a/code/modules/logger.py +++ b/code/modules/logger.py @@ -2,12 +2,13 @@ import logging def makeLogger(name: str): - # create logger formats for file and terminal file_formatter = logging.Formatter( - "[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s") + "[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s" + ) console_formatter = logging.Formatter( - "[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s") + "[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s" + ) # create logging file if loglevel is debug file_handler = logging.FileHandler(f"gridtools_log.log", mode="w") @@ -29,7 +30,6 @@ def makeLogger(name: str): if __name__ == "__main__": - # initiate logger mylogger = makeLogger(__name__) diff --git a/code/modules/plotstyle.py b/code/modules/plotstyle.py index 22b14c6..43d12ac 100644 --- a/code/modules/plotstyle.py +++ b/code/modules/plotstyle.py @@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap def PlotStyle() -> None: class style: - # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # units @@ -76,13 +75,15 @@ def PlotStyle() -> None: va="center", zorder=1000, bbox=dict( - boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 + boxstyle=f"circle, pad={padding}", + fc="white", + ec="black", + lw=1, ), ) @classmethod def fade_cmap(cls, cmap): - my_cmap = cmap(np.arange(cmap.N)) my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap = ListedColormap(my_cmap) @@ -295,7 +296,6 @@ def PlotStyle() -> None: if __name__ == "__main__": - s = PlotStyle() import matplotlib.cbook as cbook @@ -347,7 +347,8 @@ if __name__ == "__main__": for ax in axs: ax.yaxis.grid(True) ax.set_xticks( - [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] + [y + 1 for y in range(len(all_data))], + labels=["x1", "x2", "x3", "x4"], ) ax.set_xlabel("Four separate samples") ax.set_ylabel("Observed values") @@ -396,7 +397,10 @@ if __name__ == "__main__": grid = np.random.rand(4, 4) fig, axs = plt.subplots( - nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} + nrows=3, + ncols=6, + figsize=(9, 6), + subplot_kw={"xticks": [], "yticks": []}, ) for ax, interp_method in zip(axs.flat, methods): diff --git a/code/modules/plotstyle1.py b/code/modules/plotstyle1.py index 32af4d2..237996b 100644 --- a/code/modules/plotstyle1.py +++ b/code/modules/plotstyle1.py @@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap def PlotStyle() -> None: class style: - # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # units @@ -76,13 +75,15 @@ def PlotStyle() -> None: va="center", zorder=1000, bbox=dict( - boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 + boxstyle=f"circle, pad={padding}", + fc="white", + ec="black", + lw=1, ), ) @classmethod def fade_cmap(cls, cmap): - my_cmap = cmap(np.arange(cmap.N)) my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap = ListedColormap(my_cmap) @@ -295,7 +296,6 @@ def PlotStyle() -> None: if __name__ == "__main__": - s = PlotStyle() import matplotlib.cbook as cbook @@ -347,7 +347,8 @@ if __name__ == "__main__": for ax in axs: ax.yaxis.grid(True) ax.set_xticks( - [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] + [y + 1 for y in range(len(all_data))], + labels=["x1", "x2", "x3", "x4"], ) ax.set_xlabel("Four separate samples") ax.set_ylabel("Observed values") @@ -396,7 +397,10 @@ if __name__ == "__main__": grid = np.random.rand(4, 4) fig, axs = plt.subplots( - nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} + nrows=3, + ncols=6, + figsize=(9, 6), + subplot_kw={"xticks": [], "yticks": []}, ) for ax, interp_method in zip(axs.flat, methods): diff --git a/code/modules/plotstyle_dark.py b/code/modules/plotstyle_dark.py index d5b9557..d767e24 100644 --- a/code/modules/plotstyle_dark.py +++ b/code/modules/plotstyle_dark.py @@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap def PlotStyle() -> None: class style: - # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # units @@ -76,13 +75,15 @@ def PlotStyle() -> None: va="center", zorder=1000, bbox=dict( - boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 + boxstyle=f"circle, pad={padding}", + fc="white", + ec="black", + lw=1, ), ) @classmethod def fade_cmap(cls, cmap): - my_cmap = cmap(np.arange(cmap.N)) my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap = ListedColormap(my_cmap) @@ -295,7 +296,6 @@ def PlotStyle() -> None: if __name__ == "__main__": - s = PlotStyle() import matplotlib.cbook as cbook @@ -347,7 +347,8 @@ if __name__ == "__main__": for ax in axs: ax.yaxis.grid(True) ax.set_xticks( - [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] + [y + 1 for y in range(len(all_data))], + labels=["x1", "x2", "x3", "x4"], ) ax.set_xlabel("Four separate samples") ax.set_ylabel("Observed values") @@ -396,7 +397,10 @@ if __name__ == "__main__": grid = np.random.rand(4, 4) fig, axs = plt.subplots( - nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} + nrows=3, + ncols=6, + figsize=(9, 6), + subplot_kw={"xticks": [], "yticks": []}, ) for ax, interp_method in zip(axs.flat, methods): diff --git a/code/modules/simulations.py b/code/modules/simulations.py index 473bac8..a074801 100644 --- a/code/modules/simulations.py +++ b/code/modules/simulations.py @@ -37,7 +37,7 @@ def create_chirp( ck = 0 csig = 0.5 * chirpduration / np.power(2.0 * np.log(10.0), 0.5 / kurtosis) - #csig = csig*-1 + # csig = csig*-1 for k, t in enumerate(time): a = 1.0 f = eodf diff --git a/code/plot_chirp_size.py b/code/plot_chirp_size.py index 95b2a95..1153ff5 100644 --- a/code/plot_chirp_size.py +++ b/code/plot_chirp_size.py @@ -16,26 +16,25 @@ logger = makeLogger(__name__) def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): - - foldername = folder_name.split('/')[-2] - winner_row = order_meta_df[order_meta_df['recording'] == foldername] - winner = winner_row['winner'].values[0].astype(int) - winner_fish1 = winner_row['fish1'].values[0].astype(int) - winner_fish2 = winner_row['fish2'].values[0].astype(int) + foldername = folder_name.split("/")[-2] + winner_row = order_meta_df[order_meta_df["recording"] == foldername] + winner = winner_row["winner"].values[0].astype(int) + winner_fish1 = winner_row["fish1"].values[0].astype(int) + winner_fish2 = winner_row["fish2"].values[0].astype(int) if winner > 0: if winner == winner_fish1: - winner_fish_id = winner_row['rec_id1'].values[0] - loser_fish_id = winner_row['rec_id2'].values[0] + winner_fish_id = winner_row["rec_id1"].values[0] + loser_fish_id = winner_row["rec_id2"].values[0] elif winner == winner_fish2: - winner_fish_id = winner_row['rec_id2'].values[0] - loser_fish_id = winner_row['rec_id1'].values[0] + winner_fish_id = winner_row["rec_id2"].values[0] + loser_fish_id = winner_row["rec_id1"].values[0] chirp_winner = len( - Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) - chirp_loser = len( - Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) + Behavior.chirps[Behavior.chirps_ids == winner_fish_id] + ) + chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) return chirp_winner, chirp_loser else: @@ -43,24 +42,24 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df): - - foldername = folder_name.split('/')[-2] - folder_row = order_meta_df[order_meta_df['recording'] == foldername] - fish1 = folder_row['fish1'].values[0].astype(int) - fish2 = folder_row['fish2'].values[0].astype(int) - winner = folder_row['winner'].values[0].astype(int) - - groub = folder_row['group'].values[0].astype(int) - size_fish1_row = id_meta_df[(id_meta_df['group'] == groub) & ( - id_meta_df['fish'] == fish1)] - size_fish2_row = id_meta_df[(id_meta_df['group'] == groub) & ( - id_meta_df['fish'] == fish2)] - - size_winners = [size_fish1_row[col].values[0] - for col in ['l1', 'l2', 'l3']] + foldername = folder_name.split("/")[-2] + folder_row = order_meta_df[order_meta_df["recording"] == foldername] + fish1 = folder_row["fish1"].values[0].astype(int) + fish2 = folder_row["fish2"].values[0].astype(int) + winner = folder_row["winner"].values[0].astype(int) + + groub = folder_row["group"].values[0].astype(int) + size_fish1_row = id_meta_df[ + (id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish1) + ] + size_fish2_row = id_meta_df[ + (id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish2) + ] + + size_winners = [size_fish1_row[col].values[0] for col in ["l1", "l2", "l3"]] size_fish1 = np.nanmean(size_winners) - size_losers = [size_fish2_row[col].values[0] for col in ['l1', 'l2', 'l3']] + size_losers = [size_fish2_row[col].values[0] for col in ["l1", "l2", "l3"]] size_fish2 = np.nanmean(size_losers) if winner == fish1: @@ -75,8 +74,8 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df): size_diff_bigger = 0 size_diff_smaller = 0 - winner_fish_id = folder_row['rec_id1'].values[0] - loser_fish_id = folder_row['rec_id2'].values[0] + winner_fish_id = folder_row["rec_id1"].values[0] + loser_fish_id = folder_row["rec_id2"].values[0] elif winner == fish2: if size_fish2 > size_fish1: @@ -90,39 +89,39 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df): size_diff_bigger = 0 size_diff_smaller = 0 - winner_fish_id = folder_row['rec_id2'].values[0] - loser_fish_id = folder_row['rec_id1'].values[0] + winner_fish_id = folder_row["rec_id2"].values[0] + loser_fish_id = folder_row["rec_id1"].values[0] else: size_diff_bigger = np.nan size_diff_smaller = np.nan winner_fish_id = np.nan loser_fish_id = np.nan - return size_diff_bigger, size_diff_smaller, winner_fish_id, loser_fish_id + return ( + size_diff_bigger, + size_diff_smaller, + winner_fish_id, + loser_fish_id, + ) - chirp_winner = len( - Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) - chirp_loser = len( - Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) + chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) + chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) - return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser + return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser def get_chirp_freq(folder_name, Behavior, order_meta_df): + foldername = folder_name.split("/")[-2] + folder_row = order_meta_df[order_meta_df["recording"] == foldername] + fish1 = folder_row["fish1"].values[0].astype(int) + fish2 = folder_row["fish2"].values[0].astype(int) - foldername = folder_name.split('/')[-2] - folder_row = order_meta_df[order_meta_df['recording'] == foldername] - fish1 = folder_row['fish1'].values[0].astype(int) - fish2 = folder_row['fish2'].values[0].astype(int) + fish1_freq = folder_row["rec_id1"].values[0].astype(int) + fish2_freq = folder_row["rec_id2"].values[0].astype(int) - fish1_freq = folder_row['rec_id1'].values[0].astype(int) - fish2_freq = folder_row['rec_id2'].values[0].astype(int) - - chirp_freq_fish1 = np.nanmedian( - Behavior.freq[Behavior.ident == fish1_freq]) - chirp_freq_fish2 = np.nanmedian( - Behavior.freq[Behavior.ident == fish2_freq]) - winner = folder_row['winner'].values[0].astype(int) + chirp_freq_fish1 = np.nanmedian(Behavior.freq[Behavior.ident == fish1_freq]) + chirp_freq_fish2 = np.nanmedian(Behavior.freq[Behavior.ident == fish2_freq]) + winner = folder_row["winner"].values[0].astype(int) if winner == fish1: # if chirp_freq_fish1 > chirp_freq_fish2: @@ -138,9 +137,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df): # winner_fish_id = np.nan # loser_fish_id = np.nan - winner_fish_id = folder_row['rec_id1'].values[0] + winner_fish_id = folder_row["rec_id1"].values[0] winner_fish_freq = chirp_freq_fish1 - loser_fish_id = folder_row['rec_id2'].values[0] + loser_fish_id = folder_row["rec_id2"].values[0] loser_fish_freq = chirp_freq_fish2 elif winner == fish2: @@ -157,9 +156,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df): # winner_fish_id = np.nan # loser_fish_id = np.nan - winner_fish_id = folder_row['rec_id2'].values[0] + winner_fish_id = folder_row["rec_id2"].values[0] winner_fish_freq = chirp_freq_fish2 - loser_fish_id = folder_row['rec_id1'].values[0] + loser_fish_id = folder_row["rec_id1"].values[0] loser_fish_freq = chirp_freq_fish1 else: winner_fish_freq = np.nan @@ -168,25 +167,25 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df): loser_fish_id = np.nan return winner_fish_freq, winner_fish_id, loser_fish_freq, loser_fish_id - chirp_winner = len( - Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) - chirp_loser = len( - Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) + chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) + chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) return winner_fish_freq, chirp_winner, loser_fish_freq, chirp_loser def main(datapath: str): - foldernames = [ - datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] + datapath + x + "/" + for x in os.listdir(datapath) + if os.path.isdir(datapath + x) + ] foldernames, _ = get_valid_datasets(datapath) - path_order_meta = ( - '/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv' + path_order_meta = ("/").join( + foldernames[0].split("/")[:-2] + ) + "/order_meta.csv" order_meta_df = read_csv(path_order_meta) - order_meta_df['recording'] = order_meta_df['recording'].str[1:-1] - path_id_meta = ( - '/').join(foldernames[0].split('/')[:-2]) + '/id_meta.csv' + order_meta_df["recording"] = order_meta_df["recording"].str[1:-1] + path_id_meta = ("/").join(foldernames[0].split("/")[:-2]) + "/id_meta.csv" id_meta_df = read_csv(path_id_meta) chirps_winner = [] @@ -202,10 +201,9 @@ def main(datapath: str): freq_chirps_winner = [] freq_chirps_loser = [] - for foldername in foldernames: # behabvior is pandas dataframe with all the data - if foldername == '../data/mount_data/2020-05-12-10_00/': + if foldername == "../data/mount_data/2020-05-12-10_00/": continue bh = Behavior(foldername) # chirps are not sorted in time (presumably due to prior groupings) @@ -217,15 +215,24 @@ def main(datapath: str): category, timestamps = correct_chasing_events(category, timestamps) winner_chirp, loser_chirp = get_chirp_winner_loser( - foldername, bh, order_meta_df) + foldername, bh, order_meta_df + ) chirps_winner.append(winner_chirp) chirps_loser.append(loser_chirp) - size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser = get_chirp_size( - foldername, bh, order_meta_df, id_meta_df) + ( + size_diff_bigger, + chirp_winner, + size_diff_smaller, + chirp_loser, + ) = get_chirp_size(foldername, bh, order_meta_df, id_meta_df) - freq_winner, chirp_freq_winner, freq_loser, chirp_freq_loser = get_chirp_freq( - foldername, bh, order_meta_df) + ( + freq_winner, + chirp_freq_winner, + freq_loser, + chirp_freq_loser, + ) = get_chirp_freq(foldername, bh, order_meta_df) freq_diffs_higher.append(freq_winner) freq_diffs_lower.append(freq_loser) @@ -242,82 +249,124 @@ def main(datapath: str): pearsonr(size_diffs_winner, size_chirps_winner) pearsonr(size_diffs_loser, size_chirps_loser) - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=( - 21*ps.cm, 7*ps.cm), width_ratios=[1, 0.8, 0.8], sharey=True) - plt.subplots_adjust(left=0.11, right=0.948, top=0.86, - wspace=0.343, bottom=0.198) + fig, (ax1, ax2, ax3) = plt.subplots( + 1, + 3, + figsize=(21 * ps.cm, 7 * ps.cm), + width_ratios=[1, 0.8, 0.8], + sharey=True, + ) + plt.subplots_adjust( + left=0.11, right=0.948, top=0.86, wspace=0.343, bottom=0.198 + ) scatterwinner = 1.15 scatterloser = 1.85 chirps_winner = np.asarray(chirps_winner)[~np.isnan(chirps_winner)] chirps_loser = np.asarray(chirps_loser)[~np.isnan(chirps_loser)] embed() exit() - freq_diffs_higher = np.asarray( - freq_diffs_higher)[~np.isnan(freq_diffs_higher)] - freq_diffs_lower = np.asarray(freq_diffs_lower)[ - ~np.isnan(freq_diffs_lower)] - freq_chirps_winner = np.asarray( - freq_chirps_winner)[~np.isnan(freq_chirps_winner)] - freq_chirps_loser = np.asarray( - freq_chirps_loser)[~np.isnan(freq_chirps_loser)] + freq_diffs_higher = np.asarray(freq_diffs_higher)[ + ~np.isnan(freq_diffs_higher) + ] + freq_diffs_lower = np.asarray(freq_diffs_lower)[~np.isnan(freq_diffs_lower)] + freq_chirps_winner = np.asarray(freq_chirps_winner)[ + ~np.isnan(freq_chirps_winner) + ] + freq_chirps_loser = np.asarray(freq_chirps_loser)[ + ~np.isnan(freq_chirps_loser) + ] stat = wilcoxon(chirps_winner, chirps_loser) print(stat) winner_color = ps.gblue2 loser_color = ps.gblue1 - bplot1 = ax1.boxplot(chirps_winner, positions=[ - 0.9], showfliers=False, patch_artist=True) - - bplot2 = ax1.boxplot(chirps_loser, positions=[ - 2.1], showfliers=False, patch_artist=True) - - ax1.scatter(np.ones(len(chirps_winner)) * - scatterwinner, chirps_winner, color=winner_color) - ax1.scatter(np.ones(len(chirps_loser)) * - scatterloser, chirps_loser, color=loser_color) - ax1.set_xticklabels(['Winner', 'Loser']) - - ax1.text(0.1, 0.85, f'n={len(chirps_loser)}', - transform=ax1.transAxes, color=ps.white) + bplot1 = ax1.boxplot( + chirps_winner, positions=[0.9], showfliers=False, patch_artist=True + ) + + bplot2 = ax1.boxplot( + chirps_loser, positions=[2.1], showfliers=False, patch_artist=True + ) + + ax1.scatter( + np.ones(len(chirps_winner)) * scatterwinner, + chirps_winner, + color=winner_color, + ) + ax1.scatter( + np.ones(len(chirps_loser)) * scatterloser, + chirps_loser, + color=loser_color, + ) + ax1.set_xticklabels(["Winner", "Loser"]) + + ax1.text( + 0.1, + 0.85, + f"n={len(chirps_loser)}", + transform=ax1.transAxes, + color=ps.white, + ) for w, l in zip(chirps_winner, chirps_loser): - ax1.plot([scatterwinner, scatterloser], [w, l], - color=ps.white, alpha=0.6, linewidth=1, zorder=-1) - ax1.set_ylabel('Chirp counts', color=ps.white) - ax1.set_xlabel('Competition outcome', color=ps.white) + ax1.plot( + [scatterwinner, scatterloser], + [w, l], + color=ps.white, + alpha=0.6, + linewidth=1, + zorder=-1, + ) + ax1.set_ylabel("Chirp counts", color=ps.white) + ax1.set_xlabel("Competition outcome", color=ps.white) ps.set_boxplot_color(bplot1, winner_color) ps.set_boxplot_color(bplot2, loser_color) - ax2.scatter(size_diffs_winner, size_chirps_winner, - color=winner_color, label='Winner') - ax2.scatter(size_diffs_loser, size_chirps_loser, - color=loser_color, label='Loser') - - ax2.text(0.05, 0.85, f'n={len(size_chirps_loser)}', - transform=ax2.transAxes, color=ps.white) - - ax2.set_xlabel('Size difference [cm]') + ax2.scatter( + size_diffs_winner, + size_chirps_winner, + color=winner_color, + label="Winner", + ) + ax2.scatter( + size_diffs_loser, size_chirps_loser, color=loser_color, label="Loser" + ) + + ax2.text( + 0.05, + 0.85, + f"n={len(size_chirps_loser)}", + transform=ax2.transAxes, + color=ps.white, + ) + + ax2.set_xlabel("Size difference [cm]") # ax2.set_xticks(np.arange(-10, 10.1, 2)) ax3.scatter(freq_diffs_higher, freq_chirps_winner, color=winner_color) ax3.scatter(freq_diffs_lower, freq_chirps_loser, color=loser_color) - ax3.text(0.1, 0.85, f'n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}', - transform=ax3.transAxes, color=ps.white) + ax3.text( + 0.1, + 0.85, + f"n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}", + transform=ax3.transAxes, + color=ps.white, + ) - ax3.set_xlabel('EODf [Hz]') + ax3.set_xlabel("EODf [Hz]") handles, labels = ax2.get_legend_handles_labels() - fig.legend(handles, labels, loc='upper center', - ncol=2, bbox_to_anchor=(0.5, 1.04)) + fig.legend( + handles, labels, loc="upper center", ncol=2, bbox_to_anchor=(0.5, 1.04) + ) # pearson r - plt.savefig('../poster/figs/chirps_winner_loser.pdf') + plt.savefig("../poster/figs/chirps_winner_loser.pdf") plt.show() -if __name__ == '__main__': - +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/' + datapath = "../data/mount_data/" main(datapath) diff --git a/code/plot_chirps_in_chasing.py b/code/plot_chirps_in_chasing.py index ee43196..ef0e5a7 100644 --- a/code/plot_chirps_in_chasing.py +++ b/code/plot_chirps_in_chasing.py @@ -21,14 +21,16 @@ logger = makeLogger(__name__) def main(datapath: str): - foldernames = [ - datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] + datapath + x + "/" + for x in os.listdir(datapath) + if os.path.isdir(datapath + x) + ] time_precents = [] chirps_percents = [] for foldername in foldernames: # behabvior is pandas dataframe with all the data - if foldername == '../data/mount_data/2020-05-12-10_00/': + if foldername == "../data/mount_data/2020-05-12-10_00/": continue bh = Behavior(foldername) @@ -46,50 +48,70 @@ def main(datapath: str): chirps_in_chasings = [] for onset, offset in zip(chasing_onset, chasing_offset): chirps_in_chasing = [ - c for c in bh.chirps if (c > onset) & (c < offset)] + c for c in bh.chirps if (c > onset) & (c < offset) + ] chirps_in_chasings.append(chirps_in_chasing) try: time_chasing = np.sum( - chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60]) + chasing_offset[chasing_offset < 3 * 60 * 60] + - chasing_onset[chasing_onset < 3 * 60 * 60] + ) except: time_chasing = np.sum( - chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60][:-1]) + chasing_offset[chasing_offset < 3 * 60 * 60] + - chasing_onset[chasing_onset < 3 * 60 * 60][:-1] + ) - time_chasing_percent = (time_chasing/(3*60*60))*100 + time_chasing_percent = (time_chasing / (3 * 60 * 60)) * 100 chirps_chasing = np.asarray(flatten(chirps_in_chasings)) - chirps_chasing_new = chirps_chasing[chirps_chasing < 3*60*60] - chirps_percent = (len(chirps_chasing_new) / - len(bh.chirps[bh.chirps < 3*60*60]))*100 + chirps_chasing_new = chirps_chasing[chirps_chasing < 3 * 60 * 60] + chirps_percent = ( + len(chirps_chasing_new) / len(bh.chirps[bh.chirps < 3 * 60 * 60]) + ) * 100 time_precents.append(time_chasing_percent) chirps_percents.append(chirps_percent) - fig, ax = plt.subplots(1, 1, figsize=(7*ps.cm, 7*ps.cm)) + fig, ax = plt.subplots(1, 1, figsize=(7 * ps.cm, 7 * ps.cm)) scatter_time = 1.20 scatter_chirps = 1.80 size = 10 - bplot1 = ax.boxplot([time_precents, chirps_percents], - showfliers=False, patch_artist=True) + bplot1 = ax.boxplot( + [time_precents, chirps_percents], showfliers=False, patch_artist=True + ) ps.set_boxplot_color(bplot1, ps.gray) - ax.set_xticklabels(['Time \nchasing', 'Chirps \nin chasing']) - ax.set_ylabel('Percent') - ax.scatter(np.ones(len(time_precents))*scatter_time, time_precents, - facecolor=ps.white, s=size) - ax.scatter(np.ones(len(chirps_percents))*scatter_chirps, chirps_percents, - facecolor=ps.white, s=size) + ax.set_xticklabels(["Time \nchasing", "Chirps \nin chasing"]) + ax.set_ylabel("Percent") + ax.scatter( + np.ones(len(time_precents)) * scatter_time, + time_precents, + facecolor=ps.white, + s=size, + ) + ax.scatter( + np.ones(len(chirps_percents)) * scatter_chirps, + chirps_percents, + facecolor=ps.white, + s=size, + ) for i in range(len(time_precents)): - ax.plot([scatter_time, scatter_chirps], [time_precents[i], - chirps_percents[i]], alpha=0.6, linewidth=1, color=ps.white) - - ax.text(0.1, 0.9, f'n={len(time_precents)}', transform=ax.transAxes) + ax.plot( + [scatter_time, scatter_chirps], + [time_precents[i], chirps_percents[i]], + alpha=0.6, + linewidth=1, + color=ps.white, + ) + + ax.text(0.1, 0.9, f"n={len(time_precents)}", transform=ax.transAxes) plt.subplots_adjust(left=0.221, bottom=0.186, right=0.97, top=0.967) - plt.savefig('../poster/figs/chirps_in_chasing.pdf') + plt.savefig("../poster/figs/chirps_in_chasing.pdf") plt.show() -if __name__ == '__main__': +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/' + datapath = "../data/mount_data/" main(datapath) diff --git a/code/plot_event_timeline.py b/code/plot_event_timeline.py index cb75cd9..ab408ee 100644 --- a/code/plot_event_timeline.py +++ b/code/plot_event_timeline.py @@ -13,6 +13,7 @@ from modules.plotstyle import PlotStyle from modules.behaviour_handling import Behavior, correct_chasing_events from extract_chirps import get_valid_datasets + ps = PlotStyle() logger = makeLogger(__name__) @@ -20,13 +21,16 @@ logger = makeLogger(__name__) def main(datapath: str): foldernames = [ - datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] + datapath + x + "/" + for x in os.listdir(datapath) + if os.path.isdir(datapath + x) + ] foldernames, _ = get_valid_datasets(datapath) for foldername in foldernames[3:4]: print(foldername) # foldername = foldernames[0] - if foldername == '../data/mount_data/2020-05-12-10_00/': + if foldername == "../data/mount_data/2020-05-12-10_00/": continue # behabvior is pandas dataframe with all the data bh = Behavior(foldername) @@ -52,18 +56,43 @@ def main(datapath: str): exit() fish1_color = ps.gblue2 fish2_color = ps.gblue1 - fig, ax = plt.subplots(5, 1, figsize=( - 21*ps.cm, 10*ps.cm), height_ratios=[0.5, 0.5, 0.5, 0.2, 6], sharex=True) + fig, ax = plt.subplots( + 5, + 1, + figsize=(21 * ps.cm, 10 * ps.cm), + height_ratios=[0.5, 0.5, 0.5, 0.2, 6], + sharex=True, + ) # marker size s = 80 - ax[0].scatter(physical_contact, np.ones( - len(physical_contact)), color=ps.gray, marker='|', s=s) - ax[1].scatter(chasing_onset, np.ones(len(chasing_onset)), - color=ps.gray, marker='|', s=s) - ax[2].scatter(fish1, np.ones(len(fish1))-0.25, - color=fish1_color, marker='|', s=s) - ax[2].scatter(fish2, np.zeros(len(fish2))+0.25, - color=fish2_color, marker='|', s=s) + ax[0].scatter( + physical_contact, + np.ones(len(physical_contact)), + color=ps.gray, + marker="|", + s=s, + ) + ax[1].scatter( + chasing_onset, + np.ones(len(chasing_onset)), + color=ps.gray, + marker="|", + s=s, + ) + ax[2].scatter( + fish1, + np.ones(len(fish1)) - 0.25, + color=fish1_color, + marker="|", + s=s, + ) + ax[2].scatter( + fish2, + np.zeros(len(fish2)) + 0.25, + color=fish2_color, + marker="|", + s=s, + ) freq_temp = bh.freq[bh.ident == fish1_id] time_temp = bh.time[bh.idx[bh.ident == fish1_id]] @@ -94,35 +123,38 @@ def main(datapath: str): ax[2].set_xticks([]) ps.hide_ax(ax[2]) - ax[4].axvspan(0, 3, 0, 5, facecolor='grey', alpha=0.5) + ax[4].axvspan(0, 3, 0, 5, facecolor="grey", alpha=0.5) ax[4].set_xticks(np.arange(0, 6.1, 0.5)) ps.hide_ax(ax[3]) labelpad = 30 fsize = 12 - ax[0].set_ylabel('Contact', rotation=0, - labelpad=labelpad, fontsize=fsize) + ax[0].set_ylabel( + "Contact", rotation=0, labelpad=labelpad, fontsize=fsize + ) ax[0].yaxis.set_label_coords(-0.062, -0.08) - ax[1].set_ylabel('Chasing', rotation=0, - labelpad=labelpad, fontsize=fsize) + ax[1].set_ylabel( + "Chasing", rotation=0, labelpad=labelpad, fontsize=fsize + ) ax[1].yaxis.set_label_coords(-0.06, -0.08) - ax[2].set_ylabel('Chirps', rotation=0, - labelpad=labelpad, fontsize=fsize) + ax[2].set_ylabel( + "Chirps", rotation=0, labelpad=labelpad, fontsize=fsize + ) ax[2].yaxis.set_label_coords(-0.07, -0.08) - ax[4].set_ylabel('EODf') + ax[4].set_ylabel("EODf") - ax[4].set_xlabel('Time [h]') + ax[4].set_xlabel("Time [h]") # ax[0].set_title(foldername.split('/')[-2]) # 2020-03-31-9_59 plt.subplots_adjust(left=0.158, right=0.987, top=0.918, bottom=0.136) - plt.savefig('../poster/figs/timeline.svg') + plt.savefig("../poster/figs/timeline.svg") plt.show() # plot chirps -if __name__ == '__main__': +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/' + datapath = "../data/mount_data/" main(datapath) diff --git a/code/plot_introduction_specs.py b/code/plot_introduction_specs.py index d7e6f4a..0c8e2b4 100644 --- a/code/plot_introduction_specs.py +++ b/code/plot_introduction_specs.py @@ -11,7 +11,6 @@ ps = PlotStyle() def main(): - # Load data datapath = "../data/2022-06-02-10_00/" data = LoadData(datapath) @@ -24,26 +23,31 @@ def main(): timescaler = 1000 - raw = data.raw[window_start_index:window_start_index + - window_duration_index, 10] + raw = data.raw[ + window_start_index : window_start_index + window_duration_index, 10 + ] fig, (ax1, ax2) = plt.subplots( - 1, 2, figsize=(21 * ps.cm, 8*ps.cm), sharex=True, sharey=True) + 1, 2, figsize=(21 * ps.cm, 8 * ps.cm), sharex=True, sharey=True + ) # plot instantaneous frequency filtered1 = bandpass_filter( - signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate) + signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate + ) filtered2 = bandpass_filter( - signal=raw, lowf=550, highf=700, samplerate=data.raw_rate) + signal=raw, lowf=550, highf=700, samplerate=data.raw_rate + ) freqtime1, freq1 = instantaneous_frequency( - filtered1, data.raw_rate, smoothing_window=3) + filtered1, data.raw_rate, smoothing_window=3 + ) freqtime2, freq2 = instantaneous_frequency( - filtered2, data.raw_rate, smoothing_window=3) + filtered2, data.raw_rate, smoothing_window=3 + ) - ax1.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="Fish 1") - ax1.plot(freqtime2*timescaler, freq2, color=ps.gray, - lw=2, label="Fish 2") + ax1.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="Fish 1") + ax1.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="Fish 2") # ax.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0) # # ps.hide_xax(ax1) @@ -62,8 +66,8 @@ def main(): ax1.imshow( decibel(spec_power[fmask, :]), extent=[ - spec_times[0]*timescaler, - spec_times[-1]*timescaler, + spec_times[0] * timescaler, + spec_times[-1] * timescaler, spec_freqs[fmask][0], spec_freqs[fmask][-1], ], @@ -87,8 +91,8 @@ def main(): ax2.imshow( decibel(spec_power[fmask, :]), extent=[ - spec_times[0]*timescaler, - spec_times[-1]*timescaler, + spec_times[0] * timescaler, + spec_times[-1] * timescaler, spec_freqs[fmask][0], spec_freqs[fmask][-1], ], @@ -98,9 +102,8 @@ def main(): alpha=1, ) # ps.hide_xax(ax3) - ax2.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="_") - ax2.plot(freqtime2*timescaler, freq2, color=ps.gray, - lw=2, label="_") + ax2.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="_") + ax2.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="_") ax2.set_xlim(75, 200) ax1.set_ylim(400, 1200) @@ -109,15 +112,22 @@ def main(): fig.supylabel("Frequency [Hz]", fontsize=14) handles, labels = ax1.get_legend_handles_labels() - ax2.legend(handles, labels, bbox_to_anchor=(1.04, 1), loc="upper left", ncol=1,) + ax2.legend( + handles, + labels, + bbox_to_anchor=(1.04, 1), + loc="upper left", + ncol=1, + ) ps.letter_subplots(xoffset=[-0.27, -0.1], yoffset=1.05) - plt.subplots_adjust(left=0.12, right=0.85, top=0.89, - bottom=0.18, hspace=0.35) + plt.subplots_adjust( + left=0.12, right=0.85, top=0.89, bottom=0.18, hspace=0.35 + ) - plt.savefig('../poster/figs/introplot.pdf') + plt.savefig("../poster/figs/introplot.pdf") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/code/plot_kdes.py b/code/plot_kdes.py index 0f3082b..bc1bc98 100644 --- a/code/plot_kdes.py +++ b/code/plot_kdes.py @@ -1,7 +1,9 @@ - from modules.plotstyle import PlotStyle from modules.behaviour_handling import ( - Behavior, correct_chasing_events, center_chirps) + Behavior, + correct_chasing_events, + center_chirps, +) from modules.datahandling import flatten, causal_kde1d, acausal_kde1d from modules.logger import makeLogger from pandas import read_csv @@ -18,80 +20,93 @@ logger = makeLogger(__name__) ps = PlotStyle() -def bootstrap(data, nresamples, kde_time, kernel_width, event_times, time_before, time_after): - +def bootstrap( + data, + nresamples, + kde_time, + kernel_width, + event_times, + time_before, + time_after, +): bootstrapped_kdes = [] - data = data[data <= 3*60*60] # only night time + data = data[data <= 3 * 60 * 60] # only night time diff_data = np.diff(np.sort(data), prepend=0) # if len(data) != 0: # mean_chirprate = (len(data) - 1) / (data[-1] - data[0]) for i in tqdm(range(nresamples)): - np.random.shuffle(diff_data) bootstrapped_data = np.cumsum(diff_data) # bootstrapped_data = data + np.random.randn(len(data)) * 10 bootstrap_data_centered = center_chirps( - bootstrapped_data, event_times, time_before, time_after) + bootstrapped_data, event_times, time_before, time_after + ) bootstrapped_kde = acausal_kde1d( - bootstrap_data_centered, time=kde_time, width=kernel_width) + bootstrap_data_centered, time=kde_time, width=kernel_width + ) - bootstrapped_kde = list(np.asarray( - bootstrapped_kde) / len(event_times)) + bootstrapped_kde = list(np.asarray(bootstrapped_kde) / len(event_times)) bootstrapped_kdes.append(bootstrapped_kde) return bootstrapped_kdes -def jackknife(data, nresamples, subsetsize, kde_time, kernel_width, event_times, time_before, time_after): - +def jackknife( + data, + nresamples, + subsetsize, + kde_time, + kernel_width, + event_times, + time_before, + time_after, +): jackknife_kdes = [] - data = data[data <= 3*60*60] # only night time + data = data[data <= 3 * 60 * 60] # only night time subsetsize = int(len(data) * subsetsize) diff_data = np.diff(np.sort(data), prepend=0) for i in tqdm(range(nresamples)): - - jackknifed_data = np.random.choice( - diff_data, subsetsize, replace=False) + jackknifed_data = np.random.choice(diff_data, subsetsize, replace=False) jackknifed_data = np.cumsum(jackknifed_data) jackknifed_data_centered = center_chirps( - jackknifed_data, event_times, time_before, time_after) + jackknifed_data, event_times, time_before, time_after + ) jackknifed_kde = acausal_kde1d( - jackknifed_data_centered, time=kde_time, width=kernel_width) + jackknifed_data_centered, time=kde_time, width=kernel_width + ) - jackknifed_kde = list(np.asarray( - jackknifed_kde) / len(event_times)) + jackknifed_kde = list(np.asarray(jackknifed_kde) / len(event_times)) jackknife_kdes.append(jackknifed_kde) return jackknife_kdes def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): - - foldername = folder_name.split('/')[-2] - winner_row = order_meta_df[order_meta_df['recording'] == foldername] - winner = winner_row['winner'].values[0].astype(int) - winner_fish1 = winner_row['fish1'].values[0].astype(int) - winner_fish2 = winner_row['fish2'].values[0].astype(int) + foldername = folder_name.split("/")[-2] + winner_row = order_meta_df[order_meta_df["recording"] == foldername] + winner = winner_row["winner"].values[0].astype(int) + winner_fish1 = winner_row["fish1"].values[0].astype(int) + winner_fish2 = winner_row["fish2"].values[0].astype(int) if winner > 0: if winner == winner_fish1: - winner_fish_id = winner_row['rec_id1'].values[0] - loser_fish_id = winner_row['rec_id2'].values[0] + winner_fish_id = winner_row["rec_id1"].values[0] + loser_fish_id = winner_row["rec_id2"].values[0] elif winner == winner_fish2: - winner_fish_id = winner_row['rec_id2'].values[0] - loser_fish_id = winner_row['rec_id1'].values[0] + winner_fish_id = winner_row["rec_id2"].values[0] + loser_fish_id = winner_row["rec_id1"].values[0] chirp_winner = Behavior.chirps[Behavior.chirps_ids == winner_fish_id] chirp_loser = Behavior.chirps[Behavior.chirps_ids == loser_fish_id] @@ -101,7 +116,6 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): def main(dataroot): - foldernames, _ = np.asarray(get_valid_datasets(dataroot)) plot_all = True time_before = 90 @@ -111,10 +125,9 @@ def main(dataroot): kde_time = np.arange(-time_before, time_after, dt) nbootstraps = 50 - meta_path = ( - '/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv' + meta_path = ("/").join(foldernames[0].split("/")[:-2]) + "/order_meta.csv" meta = pd.read_csv(meta_path) - meta['recording'] = meta['recording'].str[1:-1] + meta["recording"] = meta["recording"].str[1:-1] winner_onsets = [] winner_offsets = [] @@ -143,24 +156,24 @@ def main(dataroot): # loser_onset_chirpcount = 0 # loser_offset_chirpcount = 0 # loser_physical_chirpcount = 0 - fig, ax = plt.subplots(1, 2, figsize=( - 14 * ps.cm, 7*ps.cm), sharey=True, sharex=True) + fig, ax = plt.subplots( + 1, 2, figsize=(14 * ps.cm, 7 * ps.cm), sharey=True, sharex=True + ) # Iterate over all recordings and save chirp- and event-timestamps good_recs = np.asarray([0, 15]) for i, folder in tqdm(enumerate(foldernames[good_recs])): - - foldername = folder.split('/')[-2] + foldername = folder.split("/")[-2] # logger.info('Loading data from folder: {}'.format(foldername)) - broken_folders = ['../data/mount_data/2020-05-12-10_00/'] + broken_folders = ["../data/mount_data/2020-05-12-10_00/"] if folder in broken_folders: continue bh = Behavior(folder) category, timestamps = correct_chasing_events(bh.behavior, bh.start_s) - category = category[timestamps < 3*60*60] # only night time - timestamps = timestamps[timestamps < 3*60*60] # only night time + category = category[timestamps < 3 * 60 * 60] # only night time + timestamps = timestamps[timestamps < 3 * 60 * 60] # only night time winner, loser = get_chirp_winner_loser(folder, bh, meta) if winner is None: @@ -168,27 +181,33 @@ def main(dataroot): # winner_count += len(winner) # loser_count += len(loser) - onsets = (timestamps[category == 0]) - offsets = (timestamps[category == 1]) - physicals = (timestamps[category == 2]) + onsets = timestamps[category == 0] + offsets = timestamps[category == 1] + physicals = timestamps[category == 2] onset_count += len(onsets) offset_count += len(offsets) physical_count += len(physicals) - winner_onsets.append(center_chirps( - winner, onsets, time_before, time_after)) - winner_offsets.append(center_chirps( - winner, offsets, time_before, time_after)) - winner_physicals.append(center_chirps( - winner, physicals, time_before, time_after)) - - loser_onsets.append(center_chirps( - loser, onsets, time_before, time_after)) - loser_offsets.append(center_chirps( - loser, offsets, time_before, time_after)) - loser_physicals.append(center_chirps( - loser, physicals, time_before, time_after)) + winner_onsets.append( + center_chirps(winner, onsets, time_before, time_after) + ) + winner_offsets.append( + center_chirps(winner, offsets, time_before, time_after) + ) + winner_physicals.append( + center_chirps(winner, physicals, time_before, time_after) + ) + + loser_onsets.append( + center_chirps(loser, onsets, time_before, time_after) + ) + loser_offsets.append( + center_chirps(loser, offsets, time_before, time_after) + ) + loser_physicals.append( + center_chirps(loser, physicals, time_before, time_after) + ) # winner_onset_chirpcount += len(winner_onsets[-1]) # winner_offset_chirpcount += len(winner_offsets[-1]) @@ -232,14 +251,17 @@ def main(dataroot): # event_times=onsets, # time_before=time_before, # time_after=time_after)) - loser_offsets_boot.append(bootstrap( - loser, - nresamples=nbootstraps, - kde_time=kde_time, - kernel_width=kernel_width, - event_times=offsets, - time_before=time_before, - time_after=time_after)) + loser_offsets_boot.append( + bootstrap( + loser, + nresamples=nbootstraps, + kde_time=kde_time, + kernel_width=kernel_width, + event_times=offsets, + time_before=time_before, + time_after=time_after, + ) + ) # loser_physicals_boot.append(bootstrap( # loser, # nresamples=nbootstraps, @@ -249,18 +271,17 @@ def main(dataroot): # time_before=time_before, # time_after=time_after)) -# loser_offsets_jackknife = jackknife( -# loser, -# nresamples=nbootstraps, -# subsetsize=0.9, -# kde_time=kde_time, -# kernel_width=kernel_width, -# event_times=offsets, -# time_before=time_before, -# time_after=time_after) + # loser_offsets_jackknife = jackknife( + # loser, + # nresamples=nbootstraps, + # subsetsize=0.9, + # kde_time=kde_time, + # kernel_width=kernel_width, + # event_times=offsets, + # time_before=time_before, + # time_after=time_after) if plot_all: - # winner_onsets_conv = acausal_kde1d( # winner_onsets[-1], kde_time, kernel_width) # winner_offsets_conv = acausal_kde1d( @@ -271,24 +292,35 @@ def main(dataroot): # loser_onsets_conv = acausal_kde1d( # loser_onsets[-1], kde_time, kernel_width) loser_offsets_conv = acausal_kde1d( - loser_offsets[-1], kde_time, kernel_width) + loser_offsets[-1], kde_time, kernel_width + ) # loser_physicals_conv = acausal_kde1d( # loser_physicals[-1], kde_time, kernel_width) - ax[i].plot(kde_time, loser_offsets_conv / - len(offsets), lw=2, zorder=100, c=ps.gblue1) + ax[i].plot( + kde_time, + loser_offsets_conv / len(offsets), + lw=2, + zorder=100, + c=ps.gblue1, + ) ax[i].fill_between( kde_time, np.percentile(loser_offsets_boot[-1], 1, axis=0), np.percentile(loser_offsets_boot[-1], 99, axis=0), - color='gray', - alpha=0.8) + color="gray", + alpha=0.8, + ) - ax[i].plot(kde_time, np.median(loser_offsets_boot[-1], axis=0), - color=ps.black, linewidth=2) + ax[i].plot( + kde_time, + np.median(loser_offsets_boot[-1], axis=0), + color=ps.black, + linewidth=2, + ) - ax[i].axvline(0, color=ps.gray, linestyle='--') + ax[i].axvline(0, color=ps.gray, linestyle="--") # ax[i].fill_between( # kde_time, @@ -300,8 +332,8 @@ def main(dataroot): # color=ps.white, linewidth=2) ax[i].set_xlim(-60, 60) - fig.supylabel('Chirp rate (a.u.)', fontsize=14) - fig.supxlabel('Time (s)', fontsize=14) + fig.supylabel("Chirp rate (a.u.)", fontsize=14) + fig.supxlabel("Time (s)", fontsize=14) # fig, ax = plt.subplots(2, 3, figsize=( # 21*ps.cm, 10*ps.cm), sharey=True, sharex=True) @@ -521,9 +553,9 @@ def main(dataroot): # color=ps.gray, # alpha=0.5) plt.subplots_adjust(bottom=0.21, top=0.93) - plt.savefig('../poster/figs/kde.pdf') + plt.savefig("../poster/figs/kde.pdf") plt.show() -if __name__ == '__main__': - main('../data/mount_data/') +if __name__ == "__main__": + main("../data/mount_data/") From 95a256b517174826b4b22519aa43c39298ef4777 Mon Sep 17 00:00:00 2001 From: weygoldt <88969563+weygoldt@users.noreply.github.com> Date: Tue, 11 Apr 2023 15:44:15 +0200 Subject: [PATCH 3/5] something changed! --- code/chirpdetection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/code/chirpdetection.py b/code/chirpdetection.py index 937bde4..7afaa0c 100755 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -24,7 +24,6 @@ from modules.datahandling import ( ) logger = makeLogger(__name__) - ps = PlotStyle() From c4f05d989123ae02da3eb406b433d40a5a950d5a Mon Sep 17 00:00:00 2001 From: weygoldt <88969563+weygoldt@users.noreply.github.com> Date: Thu, 13 Apr 2023 15:14:10 +0200 Subject: [PATCH 4/5] added notebook --- .../chirp_exploration.ipynb | 389 ++++++++++++++++++ chirp_instantaneous_freq/test_parameters.py | 14 +- 2 files changed, 397 insertions(+), 6 deletions(-) create mode 100644 chirp_instantaneous_freq/chirp_exploration.ipynb diff --git a/chirp_instantaneous_freq/chirp_exploration.ipynb b/chirp_instantaneous_freq/chirp_exploration.ipynb new file mode 100644 index 0000000..8c27cf3 --- /dev/null +++ b/chirp_instantaneous_freq/chirp_exploration.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Why can the instantaneous frequency of a band-pass filtered chirp recording go down ...\n", + "... if a chirp is an up-modulation of the frequency? \n", + "\n", + "This is the question we try to answer in this notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "QApplication: invalid style override passed, ignoring it.\n", + " Available styles: Windows, Fusion\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import thunderfish.fakefish as ff \n", + "from filters import instantaneous_frequency, bandpass_filter\n", + "%matplotlib qt\n", + "\n", + "# parameters that stay the same\n", + "samplerate = 20000\n", + "duration = 0.2\n", + "chirp_freq = 5\n", + "smooth = 3" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "qt.qpa.wayland: Wayland does not support QWindow::requestActivate()\n" + ] + } + ], + "source": [ + "def make_chirp(eodf, size, width, kurtosis, contrast, phase0):\n", + "\n", + " chirp_trace, amp = ff.chirps(\n", + " eodf = eodf,\n", + " samplerate = samplerate,\n", + " duration = duration,\n", + " chirp_freq = chirp_freq,\n", + " chirp_size = size,\n", + " chirp_width = width,\n", + " chirp_kurtosis = kurtosis,\n", + " chirp_contrast = contrast,\n", + " )\n", + "\n", + " chirp = ff.wavefish_eods(\n", + " fish = 'Alepto',\n", + " frequency = chirp_trace,\n", + " samplerate = samplerate,\n", + " duration = duration,\n", + " phase0 = phase0,\n", + " noise_std = 0,\n", + " )\n", + "\n", + " chirp *= amp\n", + "\n", + " return chirp_trace, chirp\n", + "\n", + "def filtered_chirp(eodf, size, width, kurtosis, contrast, phase0):\n", + "\n", + " time = np.arange(0, duration, 1/samplerate)\n", + " chirp_trace, chirp = make_chirp(\n", + " eodf = eodf, \n", + " size = size, \n", + " width = width, \n", + " kurtosis = kurtosis, \n", + " contrast = contrast, \n", + " phase0 = phase0,\n", + " )\n", + "\n", + " # apply filters\n", + " narrow_filtered = bandpass_filter(chirp, samplerate, eodf-10, eodf+10)\n", + " narrow_freqtime, narrow_freq = instantaneous_frequency(narrow_filtered, samplerate, smooth)\n", + " broad_filtered = bandpass_filter(chirp, samplerate, eodf-300, eodf+300)\n", + " broad_freqtime, broad_freq = instantaneous_frequency(broad_filtered, samplerate, smooth)\n", + "\n", + " original = (time, chirp_trace, chirp)\n", + " broad = (broad_freqtime, broad_freq, broad_filtered)\n", + " narrow = (narrow_freqtime, narrow_freq, narrow_filtered)\n", + "\n", + " return original, broad, narrow\n", + "\n", + "def plot(original, broad, narrow, axs):\n", + "\n", + " axs[0].plot(original[0], original[1], label='chirp trace')\n", + " axs[0].plot(broad[0], broad[1], label='broad filtered')\n", + " axs[0].plot(narrow[0], narrow[1], label='narrow filtered')\n", + " axs[1].plot(original[0], original[2], label='unfiltered')\n", + " axs[1].plot(original[0], broad[2], label='broad filtered')\n", + " axs[1].plot(original[0], narrow[2], label='narrow filtered')\n", + "\n", + "original, broad, narrow = filtered_chirp(600, 100, 0.02, 1, 0.1, 0)\n", + "fig, axs = plt.subplots(2, 1, figsize=(10, 5), sharex=True)\n", + "plot(original, broad, narrow, axs)\n", + "fig.align_labels()\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chirp size\n", + "now that we have established an easy way to simulate and plot the chirps, lets change the chirp size and see how the narrow-filtered instantaneous frequency changes." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "size 10 Hz; Integral 0.117\n", + "size 30 Hz; Integral 0.35\n", + "size 50 Hz; Integral 0.584\n", + "size 70 Hz; Integral 0.818\n", + "size 90 Hz; Integral 1.051\n", + "size 110 Hz; Integral 1.285\n", + "size 130 Hz; Integral 1.518\n", + "size 150 Hz; Integral 1.752\n", + "size 170 Hz; Integral 1.986\n", + "size 190 Hz; Integral 2.219\n", + "size 210 Hz; Integral 2.453\n", + "size 230 Hz; Integral 2.687\n", + "size 250 Hz; Integral 2.92\n", + "size 270 Hz; Integral 3.154\n", + "size 290 Hz; Integral 3.387\n", + "size 310 Hz; Integral 3.621\n", + "size 330 Hz; Integral 3.855\n", + "size 350 Hz; Integral 4.088\n", + "size 370 Hz; Integral 4.322\n", + "size 390 Hz; Integral 4.555\n", + "size 410 Hz; Integral 4.789\n", + "size 430 Hz; Integral 5.023\n", + "size 450 Hz; Integral 5.256\n", + "size 470 Hz; Integral 5.49\n", + "size 490 Hz; Integral 5.724\n", + "size 510 Hz; Integral 5.957\n", + "size 530 Hz; Integral 6.191\n", + "size 550 Hz; Integral 6.424\n", + "size 570 Hz; Integral 6.658\n", + "size 590 Hz; Integral 6.892\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "qt.qpa.wayland: Wayland does not support QWindow::requestActivate()\n" + ] + } + ], + "source": [ + "sizes = np.arange(10, 600, 20)\n", + "fig, axs = plt.subplots(2, len(sizes), figsize=(10, 5), sharex=True, sharey='row')\n", + "integrals = []\n", + "\n", + "for i, size in enumerate(sizes):\n", + " original, broad, narrow = filtered_chirp(600, size, 0.02, 1, 0.1, 0)\n", + "\n", + " integral = np.sum(original[1]-600)/(20000)\n", + " integrals.append(integral)\n", + "\n", + " plot(original, broad, narrow, axs[:, i])\n", + " axs[:, i][0].set_xlim(0.06, 0.14)\n", + " axs[0, i].set_title(np.round(integral, 3))\n", + " print(f'size {size} Hz; Integral {np.round(integral,3)}')\n", + " \n", + "fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n", + "axs[0,0].set_ylabel('frequency [Hz]')\n", + "axs[1,0].set_ylabel('amplitude [a.u.]')\n", + "fig.supxlabel('time [s]')\n", + "fig.align_labels()\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chirp width" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "widths = np.arange(0.02, 0.08, 0.005)\n", + "fig, axs = plt.subplots(2, len(widths), figsize=(10, 5), sharex=True, sharey='row')\n", + "integrals = []\n", + "\n", + "for i, width in enumerate(widths):\n", + " if i > 9:\n", + " break\n", + "\n", + " original, broad, narrow = filtered_chirp(600, 100, width, 1, 0.1, 0)\n", + "\n", + " integral = np.sum(original[1]-600)/(20000)\n", + "\n", + " plot(original, broad, narrow, axs[:, i])\n", + " axs[:, i][0].set_xlim(0.06, 0.14)\n", + " axs[0, i].set_title(f'width {np.round(width, 2)} s')\n", + " print(f'width {width} s; Integral {np.round(integral, 3)}')\n", + " \n", + "fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n", + "axs[0,0].set_ylabel('frequency [Hz]')\n", + "axs[1,0].set_ylabel('amplitude [a.u.]')\n", + "fig.supxlabel('time [s]')\n", + "fig.align_labels()\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chirp kurtosis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "kurtosiss = np.arange(0, 20, 1.6)\n", + "fig, axs = plt.subplots(2, len(kurtosiss), figsize=(10, 5), sharex=True, sharey='row')\n", + "integrals = []\n", + "\n", + "for i, kurtosis in enumerate(kurtosiss):\n", + "\n", + " original, broad, narrow = filtered_chirp(600, 100, 0.02, kurtosis, 0.1, 0)\n", + "\n", + " integral = np.sum(original[1]-600)/(20000)\n", + "\n", + " plot(original, broad, narrow, axs[:, i])\n", + " axs[:, i][0].set_xlim(0.06, 0.14)\n", + " axs[0, i].set_title(f'kurt {np.round(kurtosis, 2)}')\n", + " print(f'kurt {kurtosis}; Integral {np.round(integral, 3)}')\n", + " \n", + "fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n", + "axs[0,0].set_ylabel('frequency [Hz]')\n", + "axs[1,0].set_ylabel('amplitude [a.u.]')\n", + "fig.supxlabel('time [s]')\n", + "fig.align_labels()\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chirp contrast" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "contrasts = np.arange(0.0, 1.1, 0.1)\n", + "fig, axs = plt.subplots(2, len(sizes), figsize=(10, 5), sharex=True, sharey='row')\n", + "integrals = []\n", + "\n", + "for i, contrast in enumerate(contrasts):\n", + " if i > 9:\n", + " break\n", + " original, broad, narrow = filtered_chirp(600, 100, 0.02, 1, contrast, 0)\n", + "\n", + " integral = np.trapz(original[2], original[0])\n", + " integrals.append(integral)\n", + "\n", + " plot(original, broad, narrow, axs[:, i])\n", + " axs[:, i][0].set_xlim(0.06, 0.14)\n", + " axs[0, i].set_title(f'contr {np.round(contrast, 2)}')\n", + " \n", + "fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n", + "axs[0,0].set_ylabel('frequency [Hz]')\n", + "axs[1,0].set_ylabel('amplitude [a.u.]')\n", + "fig.supxlabel('time [s]')\n", + "fig.align_labels()\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chirp phase " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "phases = np.arange(0.0, 2 * np.pi, 0.2)\n", + "fig, axs = plt.subplots(2, len(sizes), figsize=(10, 5), sharex=True, sharey='row')\n", + "integrals = []\n", + "for i, phase in enumerate(phases):\n", + " if i > 9:\n", + " break\n", + "\n", + " original, broad, narrow = filtered_chirp(600, 100, 0.02, 1, 0.1, phase)\n", + "\n", + " integral = np.trapz(original[2], original[0])\n", + " integrals.append(integral)\n", + "\n", + " plot(original, broad, narrow, axs[:, i])\n", + " axs[:, i][0].set_xlim(0.06, 0.14)\n", + " axs[0, i].set_title(f'phase {np.round(phase, 2)}')\n", + "\n", + " \n", + "fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n", + "axs[0,0].set_ylabel('frequency [Hz]')\n", + "axs[1,0].set_ylabel('amplitude [a.u.]')\n", + "fig.supxlabel('time [s]')\n", + "fig.align_labels()\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These experiments show, that the narrow filtered instantaneous freuqency only switches its sign, when the integral of the instantaneous frequency (that was used to make the signal)\n", + "changes. Specifically, when the instantaneous frequency is 0.57, 1.57, 2.57 etc., the sign swithes. " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chirpdetection", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/chirp_instantaneous_freq/test_parameters.py b/chirp_instantaneous_freq/test_parameters.py index bad1e45..7348496 100644 --- a/chirp_instantaneous_freq/test_parameters.py +++ b/chirp_instantaneous_freq/test_parameters.py @@ -1,7 +1,7 @@ -import numpy as np import matplotlib.pyplot as plt +import numpy as np +from filters import bandpass_filter, inst_freq, instantaneous_frequency from fish_signal import chirps, wavefish_eods -from filters import bandpass_filter, instantaneous_frequency, inst_freq from IPython import embed @@ -28,13 +28,14 @@ def extract_dict(dict, index): return {key: value[index] for key, value in dict.items()} -def main(test1, test2, resolution=10): +def test(test1, test2, resolution=10): assert test1 in [ "width", "size", "kurtosis", "contrast", ], "Test1 not recognized" + assert test2 in [ "width", "size", @@ -139,10 +140,11 @@ def main(test1, test2, resolution=10): iter0 += 1 - fig, ax = plt.subplots() - ax.imshow(distances, cmap="jet") plt.show() +def main(): + test("contrast", "kurtosis") + if __name__ == "__main__": - main("width", "size") + main() From 86d05bdb80cb8c036582d4f537bf48ccb6d694aa Mon Sep 17 00:00:00 2001 From: Patrick Weygoldt <88969563+weygoldt@users.noreply.github.com> Date: Fri, 14 Apr 2023 20:53:16 +0200 Subject: [PATCH 5/5] Update README.md --- README.md | 32 ++------------------------------ 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 8e79b27..d35ed7e 100644 --- a/README.md +++ b/README.md @@ -37,40 +37,12 @@ ## About The Project -[![Product Name Screen Shot][product-screenshot]](https://example.com) +Chirps are transient communication singals of many wave-type electric fish. Because they are so fast, detecting them when the recorded signal includes multiple individuals is hard. But to understand if, and what kind of information they transmit in a natural setting, analyzing chirps in multiple freely interacting individual is nessecary. This repository documents an approach to detect these signals on electrode grid recordings with many freely behaving individuals. -Here's a blank template to get started: To avoid retyping too much info. Do a search and replace with your text editor for the following: `github_username`, `repo_name`, `twitter_handle`, `linkedin_username`, `email_client`, `email`, `project_title`, `project_description` +The majority of the code and its tests were part of a lab rotation with the [Neuroethology](https://github.com/bendalab) at the University of Tuebingen. It also contains a [poster](poster_printed/main.pdf) and a more thorough [lab protocol](protocol/main.pdf).
- -## Getting Started - -This is an example of how you may give instructions on setting up your project locally. -To get a local copy up and running follow these simple example steps. - - - - -## Usage - -Use this space to show useful examples of how a project can be used. Additional screenshots, code examples and demos work well in this space. You may also link to more resources. - -_For more examples, please refer to the [Documentation](https://example.com)_ - - - - -## To do - -- [ ] Feature 1 -- [ ] Feature 2 -- [ ] Feature 3 - - [ ] Nested Feature - - - -