diff --git a/chirp_instantaneous_freq/filters.py b/chirp_instantaneous_freq/filters.py index 709b140..a0d4001 100644 --- a/chirp_instantaneous_freq/filters.py +++ b/chirp_instantaneous_freq/filters.py @@ -59,14 +59,14 @@ def instantaneous_frequency( def inst_freq(signal, fs): """ Computes the instantaneous frequency of a periodic signal using zero-crossings. - + Parameters: ----------- signal : array-like The input signal. fs : float The sampling frequency of the input signal. - + Returns: -------- freq : array-like @@ -74,29 +74,30 @@ def inst_freq(signal, fs): """ # Compute the sign of the signal sign = np.sign(signal) - + # Compute the crossings of the sign signal with a zero line crossings = np.where(np.diff(sign))[0] - + # Compute the time differences between zero crossings dt = np.diff(crossings) / fs - + # Compute the instantaneous frequency as the reciprocal of the time differences freq = 1 / dt - # Gaussian filter the signal + # Gaussian filter the signal freq = gaussian_filter1d(freq, 10) - + # Pad the frequency vector with zeros to match the length of the input signal freq = np.pad(freq, (0, len(signal) - len(freq))) - + return freq + def bandpass_filter( - signal: np.ndarray, - samplerate: float, - lowf: float, - highf: float, + signal: np.ndarray, + samplerate: float, + lowf: float, + highf: float, ) -> np.ndarray: """Bandpass filter a signal. @@ -150,9 +151,7 @@ def highpass_filter( def lowpass_filter( - signal: np.ndarray, - samplerate: float, - cutoff: float + signal: np.ndarray, samplerate: float, cutoff: float ) -> np.ndarray: """Lowpass filter a signal. @@ -176,10 +175,9 @@ def lowpass_filter( return filtered_signal -def envelope(signal: np.ndarray, - samplerate: float, - cutoff_frequency: float - ) -> np.ndarray: +def envelope( + signal: np.ndarray, samplerate: float, cutoff_frequency: float +) -> np.ndarray: """Calculate the envelope of a signal using a lowpass filter. Parameters diff --git a/chirp_instantaneous_freq/fish_signal.py b/chirp_instantaneous_freq/fish_signal.py index bf740b6..d7830bf 100644 --- a/chirp_instantaneous_freq/fish_signal.py +++ b/chirp_instantaneous_freq/fish_signal.py @@ -384,16 +384,14 @@ def chirps( frequency = eodf * np.ones(n) am = np.ones(n) - for time, width, size, kurtosis, contrast in zip(chirp_times, chirp_width, chirp_size, chirp_kurtosis, chirp_contrast): - + for time, width, size, kurtosis, contrast in zip( + chirp_times, chirp_width, chirp_size, chirp_kurtosis, chirp_contrast + ): # chirp frequency waveform: chirp_t = np.arange(-2.0 * width, 2.0 * width, 1.0 / samplerate) - chirp_sig = ( - 0.5 * width / (2.0 * np.log(10.0)) ** (0.5 / kurtosis) - ) + chirp_sig = 0.5 * width / (2.0 * np.log(10.0)) ** (0.5 / kurtosis) gauss = np.exp(-0.5 * ((chirp_t / chirp_sig) ** 2.0) ** kurtosis) - # add chirps on baseline eodf: index = int(time * samplerate) i0 = index - len(gauss) // 2 @@ -433,7 +431,7 @@ def rises( Sampling rate in Hertz. duration: float Duration of the generated data in seconds. - rise_times: list + rise_times: list Timestamp of each of the rises in seconds. rise_size: list Size of the respective rise (frequency increase above eodf) in Hertz. @@ -452,15 +450,12 @@ def rises( # baseline eod frequency: frequency = eodf * np.ones(n) - for time, size, riset, decayt in zip(rise_times, rise_size, rise_tau, decay_tau): - + for time, size, riset, decayt in zip( + rise_times, rise_size, rise_tau, decay_tau + ): # rise frequency waveform: rise_t = np.arange(0.0, 5.0 * decayt, 1.0 / samplerate) - rise = ( - size - * (1.0 - np.exp(-rise_t / riset)) - * np.exp(-rise_t / decayt) - ) + rise = size * (1.0 - np.exp(-rise_t / riset)) * np.exp(-rise_t / decayt) # add rises on baseline eodf: index = int(time * samplerate) @@ -472,13 +467,14 @@ def rises( frequency[index : index + len(rise)] += rise return frequency + class FishSignal: def __init__(self, samplerate, duration, eodf, nchirps, nrises): time = np.arange(0, duration, 1 / samplerate) chirp_times = np.random.uniform(0, duration, nchirps) rise_times = np.random.uniform(0, duration, nrises) - # pick random parameters for chirps + # pick random parameters for chirps chirp_size = np.random.uniform(60, 200, nchirps) chirp_width = np.random.uniform(0.01, 0.1, nchirps) chirp_kurtosis = np.random.uniform(1, 1, nchirps) @@ -534,7 +530,6 @@ class FishSignal: self.eodf = eodf def visualize(self): - spec, freqs, spectime = ps.spectrogram( data=self.signal, ratetime=self.samplerate, @@ -549,7 +544,12 @@ class FishSignal: ax1.set_xlabel("Time (s)") ax1.set_title("EOD signal") - ax2.imshow(ps.decibel(spec), origin='lower', aspect="auto", extent=[spectime[0], spectime[-1], freqs[0], freqs[-1]]) + ax2.imshow( + ps.decibel(spec), + origin="lower", + aspect="auto", + extent=[spectime[0], spectime[-1], freqs[0], freqs[-1]], + ) ax2.set_ylabel("Frequency (Hz)") ax2.set_xlabel("Time (s)") ax2.set_title("Spectrogram") diff --git a/chirp_instantaneous_freq/test_parameters.py b/chirp_instantaneous_freq/test_parameters.py index 9c4ab5f..bad1e45 100644 --- a/chirp_instantaneous_freq/test_parameters.py +++ b/chirp_instantaneous_freq/test_parameters.py @@ -1,4 +1,4 @@ -import numpy as np +import numpy as np import matplotlib.pyplot as plt from fish_signal import chirps, wavefish_eods from filters import bandpass_filter, instantaneous_frequency, inst_freq @@ -6,18 +6,18 @@ from IPython import embed def switch_test(test, defaultparams, testparams): - if test == 'width': - defaultparams['chirp_width'] = testparams['chirp_width'] - key = 'chirp_width' - elif test == 'size': - defaultparams['chirp_size'] = testparams['chirp_size'] - key = 'chirp_size' - elif test == 'kurtosis': - defaultparams['chirp_kurtosis'] = testparams['chirp_kurtosis'] - key = 'chirp_kurtosis' - elif test == 'contrast': - defaultparams['chirp_contrast'] = testparams['chirp_contrast'] - key = 'chirp_contrast' + if test == "width": + defaultparams["chirp_width"] = testparams["chirp_width"] + key = "chirp_width" + elif test == "size": + defaultparams["chirp_size"] = testparams["chirp_size"] + key = "chirp_size" + elif test == "kurtosis": + defaultparams["chirp_kurtosis"] = testparams["chirp_kurtosis"] + key = "chirp_kurtosis" + elif test == "contrast": + defaultparams["chirp_contrast"] = testparams["chirp_contrast"] + key = "chirp_contrast" else: raise ValueError("Test not recognized") @@ -29,31 +29,40 @@ def extract_dict(dict, index): def main(test1, test2, resolution=10): - - assert test1 in ['width', 'size', 'kurtosis', 'contrast'], "Test1 not recognized" - assert test2 in ['width', 'size', 'kurtosis', 'contrast'], "Test2 not recognized" - - # Define the parameters for the chirp simulations + assert test1 in [ + "width", + "size", + "kurtosis", + "contrast", + ], "Test1 not recognized" + assert test2 in [ + "width", + "size", + "kurtosis", + "contrast", + ], "Test2 not recognized" + + # Define the parameters for the chirp simulations ntest = resolution defaultparams = dict( - chirp_size = np.ones(ntest) * 100, - chirp_width = np.ones(ntest) * 0.1, - chirp_kurtosis = np.ones(ntest) * 1.0, - chirp_contrast = np.ones(ntest) * 0.5, + chirp_size=np.ones(ntest) * 100, + chirp_width=np.ones(ntest) * 0.1, + chirp_kurtosis=np.ones(ntest) * 1.0, + chirp_contrast=np.ones(ntest) * 0.5, ) testparams = dict( - chirp_width = np.linspace(0.01, 0.2, ntest), - chirp_size = np.linspace(50, 300, ntest), - chirp_kurtosis = np.linspace(0.5, 1.5, ntest), - chirp_contrast = np.linspace(0.01, 1.0, ntest), + chirp_width=np.linspace(0.01, 0.2, ntest), + chirp_size=np.linspace(50, 300, ntest), + chirp_kurtosis=np.linspace(0.5, 1.5, ntest), + chirp_contrast=np.linspace(0.01, 1.0, ntest), ) key1, chirp_params = switch_test(test1, defaultparams, testparams) key2, chirp_params = switch_test(test2, chirp_params, testparams) - # make the chirp trace + # make the chirp trace eodf = 500 samplerate = 20000 duration = 2 @@ -63,40 +72,60 @@ def main(test1, test2, resolution=10): tight_cutoffs = 10 distances = np.full((ntest, ntest), np.nan) - - fig, axs = plt.subplots(ntest, ntest, figsize = (10, 10), sharex = True, sharey = True) + + fig, axs = plt.subplots( + ntest, ntest, figsize=(10, 10), sharex=True, sharey=True + ) axs = axs.flatten() iter0 = 0 for iter1, test1_param in enumerate(chirp_params[key1]): for iter2, test2_param in enumerate(chirp_params[key2]): - # get the chirp parameters for the current test inner_chirp_params = extract_dict(chirp_params, iter2) inner_chirp_params[key1] = test1_param inner_chirp_params[key2] = test2_param # make the chirp trace for the current chirp parameters - sizes = np.ones(len(chirp_times)) * inner_chirp_params['chirp_size'] - widths = np.ones(len(chirp_times)) * inner_chirp_params['chirp_width'] - kurtosis = np.ones(len(chirp_times)) * inner_chirp_params['chirp_kurtosis'] - contrast = np.ones(len(chirp_times)) * inner_chirp_params['chirp_contrast'] + sizes = np.ones(len(chirp_times)) * inner_chirp_params["chirp_size"] + widths = ( + np.ones(len(chirp_times)) * inner_chirp_params["chirp_width"] + ) + kurtosis = ( + np.ones(len(chirp_times)) * inner_chirp_params["chirp_kurtosis"] + ) + contrast = ( + np.ones(len(chirp_times)) * inner_chirp_params["chirp_contrast"] + ) # make the chirp trace - chirp_trace, ampmod = chirps(eodf, samplerate, duration, chirp_times, sizes, widths, kurtosis, contrast) + chirp_trace, ampmod = chirps( + eodf, + samplerate, + duration, + chirp_times, + sizes, + widths, + kurtosis, + contrast, + ) signal = wavefish_eods( - fish="Alepto", - frequency=chirp_trace, - samplerate=samplerate, - duration=duration, - phase0=0.0, - noise_std=0.05 - ) + fish="Alepto", + frequency=chirp_trace, + samplerate=samplerate, + duration=duration, + phase0=0.0, + noise_std=0.05, + ) signal = signal * ampmod - # apply broadband filter - wide_signal = bandpass_filter(signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs) - tight_signal = bandpass_filter(signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs) + # apply broadband filter + wide_signal = bandpass_filter( + signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs + ) + tight_signal = bandpass_filter( + signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs + ) # get the instantaneous frequency wide_frequency = inst_freq(wide_signal, samplerate) @@ -111,8 +140,9 @@ def main(test1, test2, resolution=10): iter0 += 1 fig, ax = plt.subplots() - ax.imshow(distances, cmap = 'jet') + ax.imshow(distances, cmap="jet") plt.show() + if __name__ == "__main__": - main('width', 'size') + main("width", "size") diff --git a/code/analysis.py b/code/analysis.py index 787a53e..2e32671 100644 --- a/code/analysis.py +++ b/code/analysis.py @@ -10,73 +10,84 @@ from modules.filters import bandpass_filter def main(folder): - file = os.path.join(folder, 'traces-grid.raw') + file = os.path.join(folder, "traces-grid.raw") data = open_data(folder, 60.0, 0, channel=-1) - time = np.load(folder + 'times.npy', allow_pickle=True) - freq = np.load(folder + 'fund_v.npy', allow_pickle=True) - ident = np.load(folder + 'ident_v.npy', allow_pickle=True) - idx = np.load(folder + 'idx_v.npy', allow_pickle=True) + time = np.load(folder + "times.npy", allow_pickle=True) + freq = np.load(folder + "fund_v.npy", allow_pickle=True) + ident = np.load(folder + "ident_v.npy", allow_pickle=True) + idx = np.load(folder + "idx_v.npy", allow_pickle=True) - t0 = 3*60*60 + 6*60 + 43.5 + t0 = 3 * 60 * 60 + 6 * 60 + 43.5 dt = 60 - data_oi = data[t0 * data.samplerate: (t0+dt)*data.samplerate, :] + data_oi = data[t0 * data.samplerate : (t0 + dt) * data.samplerate, :] for i in [10]: # getting the spectogramm spec_power, spec_freqs, spec_times = spectrogram( - data_oi[:, i], ratetime=data.samplerate, freq_resolution=50, overlap_frac=0.0) - fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54)) - ax.pcolormesh(spec_times, spec_freqs, decibel( - spec_power), vmin=-100, vmax=-50) + data_oi[:, i], + ratetime=data.samplerate, + freq_resolution=50, + overlap_frac=0.0, + ) + fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54)) + ax.pcolormesh( + spec_times, spec_freqs, decibel(spec_power), vmin=-100, vmax=-50 + ) for track_id in np.unique(ident): # window_index for time array in time window - window_index = np.arange(len(idx))[(ident == track_id) & - (time[idx] >= t0) & - (time[idx] <= (t0+dt))] + window_index = np.arange(len(idx))[ + (ident == track_id) + & (time[idx] >= t0) + & (time[idx] <= (t0 + dt)) + ] freq_temp = freq[window_index] time_temp = time[idx[window_index]] - #mean_freq = np.mean(freq_temp) - #fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200) + # mean_freq = np.mean(freq_temp) + # fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200) ax.plot(time_temp - t0, freq_temp) ax.set_ylim(500, 1000) plt.show() # filter plot - id = 10. + id = 10.0 i = 10 - window_index = np.arange(len(idx))[(ident == id) & - (time[idx] >= t0) & - (time[idx] <= (t0+dt))] + window_index = np.arange(len(idx))[ + (ident == id) & (time[idx] >= t0) & (time[idx] <= (t0 + dt)) + ] freq_temp = freq[window_index] time_temp = time[idx[window_index]] mean_freq = np.mean(freq_temp) fdata = bandpass_filter( - data_oi[:, i], rate=data.samplerate, lowf=mean_freq-5, highf=mean_freq+200) + data_oi[:, i], + rate=data.samplerate, + lowf=mean_freq - 5, + highf=mean_freq + 200, + ) fig, ax = plt.subplots() - ax.plot(np.arange(len(fdata))/data.samplerate, fdata, marker='*') + ax.plot(np.arange(len(fdata)) / data.samplerate, fdata, marker="*") # plt.show() # freqency analyis of filtered data - time_fdata = np.arange(len(fdata))/data.samplerate + time_fdata = np.arange(len(fdata)) / data.samplerate roll_fdata = np.roll(fdata, shift=1) period_index = np.arange(len(fdata))[(roll_fdata < 0) & (fdata >= 0)] plt.plot(time_fdata, fdata) - plt.scatter(time_fdata[period_index], fdata[period_index], c='r') - plt.scatter(time_fdata[period_index-1], fdata[period_index-1], c='r') + plt.scatter(time_fdata[period_index], fdata[period_index], c="r") + plt.scatter(time_fdata[period_index - 1], fdata[period_index - 1], c="r") upper_bound = np.abs(fdata[period_index]) - lower_bound = np.abs(fdata[period_index-1]) + lower_bound = np.abs(fdata[period_index - 1]) upper_times = np.abs(time_fdata[period_index]) - lower_times = np.abs(time_fdata[period_index-1]) + lower_times = np.abs(time_fdata[period_index - 1]) - lower_ratio = lower_bound/(lower_bound+upper_bound) - upper_ratio = upper_bound/(lower_bound+upper_bound) + lower_ratio = lower_bound / (lower_bound + upper_bound) + upper_ratio = upper_bound / (lower_bound + upper_bound) - time_delta = upper_times-lower_times - true_zero = lower_times + time_delta*lower_ratio + time_delta = upper_times - lower_times + true_zero = lower_times + time_delta * lower_ratio plt.scatter(true_zero, np.zeros(len(true_zero))) @@ -84,7 +95,7 @@ def main(folder): inst_freq = 1 / np.diff(true_zero) filtered_inst_freq = gaussian_filter1d(inst_freq, 0.005) fig, ax = plt.subplots() - ax.plot(filtered_inst_freq, marker='.') + ax.plot(filtered_inst_freq, marker=".") # in 5 sekunden welcher fisch auf einer elektrode am embed() @@ -99,5 +110,7 @@ def main(folder): pass -if __name__ == '__main__': - main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/') +if __name__ == "__main__": + main( + "/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/" + ) diff --git a/code/band_pass_problem.py b/code/band_pass_problem.py index fc6a55e..f553ff2 100644 --- a/code/band_pass_problem.py +++ b/code/band_pass_problem.py @@ -12,25 +12,27 @@ from modules.filehandling import LoadData def main(folder): data = LoadData(folder) - t0 = 3*60*60 + 6*60 + 43.5 + t0 = 3 * 60 * 60 + 6 * 60 + 43.5 dt = 60 - data_oi = data.raw[t0 * data.raw_rate: (t0+dt)*data.raw_rate, :] - # good electrode - electrode = 10 + data_oi = data.raw[t0 * data.raw_rate : (t0 + dt) * data.raw_rate, :] + # good electrode + electrode = 10 data_oi = data_oi[:, electrode] - fig, axs = plt.subplots(2,1) - axs[0].plot( np.arange(data_oi.shape[0]) / data.raw_rate, data_oi) + fig, axs = plt.subplots(2, 1) + axs[0].plot(np.arange(data_oi.shape[0]) / data.raw_rate, data_oi) for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): rack_window_index = np.arange(len(data.idx))[ - (data.ident == track_id) & - (data.time[data.idx] >= t0) & - (data.time[data.idx] <= (t0+dt))] + (data.ident == track_id) + & (data.time[data.idx] >= t0) + & (data.time[data.idx] <= (t0 + dt)) + ] freq_fish = data.freq[rack_window_index] axs[1].plot(np.arange(freq_fish.shape[0]) / data.raw_rate, freq_fish) plt.show() - -if __name__ == '__main__': - main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/') \ No newline at end of file +if __name__ == "__main__": + main( + "/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/" + ) diff --git a/code/behavior.py b/code/behavior.py index 71c0926..4f16543 100644 --- a/code/behavior.py +++ b/code/behavior.py @@ -1,8 +1,8 @@ -import os -import os +import os +import os import numpy as np -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt from IPython import embed from pandas import read_csv @@ -11,51 +11,65 @@ from scipy.ndimage import gaussian_filter1d logger = makeLogger(__name__) + class Behavior: """Load behavior data from csv file as class attributes Attributes ---------- behavior: 0: chasing onset, 1: chasing offset, 2: physical contact - behavior_type: - behavioral_category: - comment_start: - comment_stop: - dataframe: pandas dataframe with all the data - duration_s: - media_file: - observation_date: - observation_id: - start_s: start time of the event in seconds - stop_s: stop time of the event in seconds - total_length: + behavior_type: + behavioral_category: + comment_start: + comment_stop: + dataframe: pandas dataframe with all the data + duration_s: + media_file: + observation_date: + observation_id: + start_s: start time of the event in seconds + stop_s: stop time of the event in seconds + total_length: """ def __init__(self, folder_path: str) -> None: - - - LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True) - self.time = np.load(os.path.join(folder_path, "times.npy"), allow_pickle=True) - csv_filename = [f for f in os.listdir(folder_path) if f.endswith('.csv')][0] # check if there are more than one csv file + LED_on_time_BORIS = np.load( + os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True + ) + self.time = np.load( + os.path.join(folder_path, "times.npy"), allow_pickle=True + ) + csv_filename = [ + f for f in os.listdir(folder_path) if f.endswith(".csv") + ][ + 0 + ] # check if there are more than one csv file self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) - self.chirps = np.load(os.path.join(folder_path, 'chirps.npy'), allow_pickle=True) - self.chirps_ids = np.load(os.path.join(folder_path, 'chirps_ids.npy'), allow_pickle=True) + self.chirps = np.load( + os.path.join(folder_path, "chirps.npy"), allow_pickle=True + ) + self.chirps_ids = np.load( + os.path.join(folder_path, "chirps_ids.npy"), allow_pickle=True + ) for k, key in enumerate(self.dataframe.keys()): - key = key.lower() - if ' ' in key: - key = key.replace(' ', '_') - if '(' in key: - key = key.replace('(', '') - key = key.replace(')', '') - setattr(self, key, np.array(self.dataframe[self.dataframe.keys()[k]])) - + key = key.lower() + if " " in key: + key = key.replace(" ", "_") + if "(" in key: + key = key.replace("(", "") + key = key.replace(")", "") + setattr( + self, key, np.array(self.dataframe[self.dataframe.keys()[k]]) + ) + last_LED_t_BORIS = LED_on_time_BORIS[-1] real_time_range = self.time[-1] - self.time[0] factor = 1.034141 shift = last_LED_t_BORIS - real_time_range * factor self.start_s = (self.start_s - shift) / factor self.stop_s = (self.stop_s - shift) / factor - + + """ 1 - chasing onset 2 - chasing offset @@ -83,77 +97,77 @@ temporal encpding needs to be corrected ... not exactly 25FPS. behavior = data['Behavior'] """ -def correct_chasing_events( - category: np.ndarray, - timestamps: np.ndarray - ) -> tuple[np.ndarray, np.ndarray]: - onset_ids = np.arange( - len(category))[category == 0] - offset_ids = np.arange( - len(category))[category == 1] +def correct_chasing_events( + category: np.ndarray, timestamps: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + onset_ids = np.arange(len(category))[category == 0] + offset_ids = np.arange(len(category))[category == 1] # Check whether on- or offset is longer and calculate length difference if len(onset_ids) > len(offset_ids): len_diff = len(onset_ids) - len(offset_ids) longer_array = onset_ids shorter_array = offset_ids - logger.info(f'Onsets are greater than offsets by {len_diff}') + logger.info(f"Onsets are greater than offsets by {len_diff}") elif len(onset_ids) < len(offset_ids): len_diff = len(offset_ids) - len(onset_ids) longer_array = offset_ids shorter_array = onset_ids - logger.info(f'Offsets are greater than offsets by {len_diff}') + logger.info(f"Offsets are greater than offsets by {len_diff}") elif len(onset_ids) == len(offset_ids): - logger.info('Chasing events are equal') + logger.info("Chasing events are equal") return category, timestamps # Correct the wrong chasing events; delete double events wrong_ids = [] - for i in range(len(longer_array)-(len_diff+1)): - if (shorter_array[i] > longer_array[i]) & (shorter_array[i] < longer_array[i+1]): + for i in range(len(longer_array) - (len_diff + 1)): + if (shorter_array[i] > longer_array[i]) & ( + shorter_array[i] < longer_array[i + 1] + ): pass else: wrong_ids.append(longer_array[i]) longer_array = np.delete(longer_array, i) - - category = np.delete( - category, wrong_ids) - timestamps = np.delete( - timestamps, wrong_ids) + + category = np.delete(category, wrong_ids) + timestamps = np.delete(timestamps, wrong_ids) return category, timestamps def event_triggered_chirps( - event: np.ndarray, - chirps:np.ndarray, + event: np.ndarray, + chirps: np.ndarray, time_before_event: int, - time_after_event: int - )-> tuple[np.ndarray, np.ndarray]: - - - event_chirps = [] # chirps that are in specified window around event - centered_chirps = [] # timestamps of chirps around event centered on the event timepoint + time_after_event: int, +) -> tuple[np.ndarray, np.ndarray]: + event_chirps = [] # chirps that are in specified window around event + centered_chirps = ( + [] + ) # timestamps of chirps around event centered on the event timepoint for event_timestamp in event: - start = event_timestamp - time_before_event # timepoint of window start - stop = event_timestamp + time_after_event # timepoint of window ending - chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)] # get chirps that are in a -5 to +5 sec window around event + start = event_timestamp - time_before_event # timepoint of window start + stop = event_timestamp + time_after_event # timepoint of window ending + chirps_around_event = [ + c for c in chirps if (c >= start) & (c <= stop) + ] # get chirps that are in a -5 to +5 sec window around event event_chirps.append(chirps_around_event) if len(chirps_around_event) == 0: continue - else: + else: centered_chirps.append(chirps_around_event - event_timestamp) - centered_chirps = np.concatenate(centered_chirps, axis=0) # convert list of arrays to one array for plotting + centered_chirps = np.concatenate( + centered_chirps, axis=0 + ) # convert list of arrays to one array for plotting return event_chirps, centered_chirps def main(datapath: str): - # behavior is pandas dataframe with all the data bh = Behavior(datapath) - + # chirps are not sorted in time (presumably due to prior groupings) # get and sort chirps and corresponding fish_ids of the chirps chirps = bh.chirps[np.argsort(bh.chirps)] @@ -172,10 +186,34 @@ def main(datapath: str): # First overview plot fig1, ax1 = plt.subplots() - ax1.scatter(chirps, np.ones_like(chirps), marker='*', color='royalblue', label='Chirps') - ax1.scatter(chasing_onset, np.ones_like(chasing_onset)*2, marker='.', color='forestgreen', label='Chasing onset') - ax1.scatter(chasing_offset, np.ones_like(chasing_offset)*2.5, marker='.', color='firebrick', label='Chasing offset') - ax1.scatter(physical_contact, np.ones_like(physical_contact)*3, marker='x', color='black', label='Physical contact') + ax1.scatter( + chirps, + np.ones_like(chirps), + marker="*", + color="royalblue", + label="Chirps", + ) + ax1.scatter( + chasing_onset, + np.ones_like(chasing_onset) * 2, + marker=".", + color="forestgreen", + label="Chasing onset", + ) + ax1.scatter( + chasing_offset, + np.ones_like(chasing_offset) * 2.5, + marker=".", + color="firebrick", + label="Chasing offset", + ) + ax1.scatter( + physical_contact, + np.ones_like(physical_contact) * 3, + marker="x", + color="black", + label="Physical contact", + ) plt.legend() # plt.show() plt.close() @@ -187,29 +225,40 @@ def main(datapath: str): # Evaluate how many chirps were emitted in specific time window around the chasing onset events # Iterate over chasing onsets (later over fish) - time_around_event = 5 # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event + time_around_event = 5 # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event #### Loop crashes at concatenate in function #### # for i in range(len(fish_ids)): # fish = fish_ids[i] # chirps = chirps[chirps_fish_ids == fish] # print(fish) - chasing_chirps, centered_chasing_chirps = event_triggered_chirps(chasing_onset, chirps, time_around_event, time_around_event) - physical_chirps, centered_physical_chirps = event_triggered_chirps(physical_contact, chirps, time_around_event, time_around_event) + chasing_chirps, centered_chasing_chirps = event_triggered_chirps( + chasing_onset, chirps, time_around_event, time_around_event + ) + physical_chirps, centered_physical_chirps = event_triggered_chirps( + physical_contact, chirps, time_around_event, time_around_event + ) # Kernel density estimation ??? # centered_chasing_chirps_convolved = gaussian_filter1d(centered_chasing_chirps, 5) - + # centered_chasing = chasing_onset[0] - chasing_onset[0] ## get the 0 timepoint for plotting; set one chasing event to 0 offsets = [0.5, 1] - fig4, ax4 = plt.subplots(figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True) - ax4.eventplot(np.array([centered_chasing_chirps, centered_physical_chirps]), lineoffsets=offsets, linelengths=0.25, colors=['g', 'r']) - ax4.vlines(0, 0, 1.5, 'tab:grey', 'dashed', 'Timepoint of event') + fig4, ax4 = plt.subplots( + figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True + ) + ax4.eventplot( + np.array([centered_chasing_chirps, centered_physical_chirps]), + lineoffsets=offsets, + linelengths=0.25, + colors=["g", "r"], + ) + ax4.vlines(0, 0, 1.5, "tab:grey", "dashed", "Timepoint of event") # ax4.plot(centered_chasing_chirps_convolved) ax4.set_yticks(offsets) - ax4.set_yticklabels(['Chasings', 'Physical \n contacts']) - ax4.set_xlabel('Time[s]') - ax4.set_ylabel('Type of event') + ax4.set_yticklabels(["Chasings", "Physical \n contacts"]) + ax4.set_xlabel("Time[s]") + ax4.set_ylabel("Type of event") plt.show() # Associate chirps to inidividual fish @@ -219,22 +268,21 @@ def main(datapath: str): ### Plots: # 1. All recordings, all fish, all chirps - # One CTC, one PTC + # One CTC, one PTC # 2. All recordings, only winners - # One CTC, one PTC + # One CTC, one PTC # 3. All recordings, all losers - # One CTC, one PTC + # One CTC, one PTC #### Chirp counts per fish general ##### fig2, ax2 = plt.subplots() - x = ['Fish1', 'Fish2'] + x = ["Fish1", "Fish2"] width = 0.35 ax2.bar(x, fish, width=width) - ax2.set_ylabel('Chirp count') + ax2.set_ylabel("Chirp count") # plt.show() plt.close() - ##### Count chirps emitted during chasing events and chirps emitted out of chasing events ##### chirps_in_chasings = [] for onset, offset in zip(chasing_onset, chasing_offset): @@ -251,23 +299,24 @@ def main(datapath: str): counts_chirps_chasings += 1 # chirps in chasing events - fig3 , ax3 = plt.subplots() - ax3.bar(['Chirps in chasing events', 'Chasing events without Chirps'], [counts_chirps_chasings, chasings_without_chirps], width=width) - plt.ylabel('Count') + fig3, ax3 = plt.subplots() + ax3.bar( + ["Chirps in chasing events", "Chasing events without Chirps"], + [counts_chirps_chasings, chasings_without_chirps], + width=width, + ) + plt.ylabel("Count") # plt.show() - plt.close() + plt.close() # comparison between chasing events with and without chirps - - embed() exit() - -if __name__ == '__main__': +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/2020-05-13-10_00/' - datapath = '../data/mount_data/2020-05-13-10_00/' + datapath = "../data/mount_data/2020-05-13-10_00/" + datapath = "../data/mount_data/2020-05-13-10_00/" main(datapath) diff --git a/code/chirp_sim.py b/code/chirp_sim.py index 5433b36..d7023a5 100644 --- a/code/chirp_sim.py +++ b/code/chirp_sim.py @@ -8,30 +8,27 @@ from modules.datahandling import instantaneous_frequency from modules.simulations import create_chirp - # trying thunderfish fakefish chirp simulation --------------------------------- samplerate = 44100 freq, ampl = fakefish.chirps(eodf=500, chirp_contrast=0.2) -data = fakefish.wavefish_eods(fish='Alepto', frequency=freq, phase0=3, samplerate=samplerate) +data = fakefish.wavefish_eods( + fish="Alepto", frequency=freq, phase0=3, samplerate=samplerate +) # filter signal with bandpass_filter -data_filterd = bandpass_filter(data*ampl+1, samplerate, 0.01, 1.99) +data_filterd = bandpass_filter(data * ampl + 1, samplerate, 0.01, 1.99) embed() data_freq_time, data_freq = instantaneous_frequency(data, samplerate, 5) fig, ax = plt.subplots(4, 1, figsize=(20 / 2.54, 12 / 2.54), sharex=True) -ax[0].plot(np.arange(len(data))/samplerate, data*ampl) -#ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red') -ax[1].plot(np.arange(len(data_filterd))/samplerate, data_filterd) -ax[2].plot(np.arange(len(freq))/samplerate, freq) +ax[0].plot(np.arange(len(data)) / samplerate, data * ampl) +# ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red') +ax[1].plot(np.arange(len(data_filterd)) / samplerate, data_filterd) +ax[2].plot(np.arange(len(freq)) / samplerate, freq) ax[3].plot(data_freq_time, data_freq) plt.show() embed() - - - - diff --git a/code/chirpdetection.py b/code/chirpdetection.py index 95800df..937bde4 100755 --- a/code/chirpdetection.py +++ b/code/chirpdetection.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import matplotlib.gridspec as gr from scipy.signal import find_peaks from thunderfish.powerspectrum import spectrogram, decibel + # from sklearn.preprocessing import normalize from modules.filters import bandpass_filter, envelope, highpass_filter @@ -18,7 +19,7 @@ from modules.datahandling import ( purge_duplicates, group_timestamps, instantaneous_frequency, - instantaneous_frequency2, + instantaneous_frequency2, minmaxnorm, ) @@ -59,7 +60,6 @@ class ChirpPlotBuffer: frequency_peaks: np.ndarray def plot_buffer(self, chirps: np.ndarray, plot: str) -> None: - logger.debug("Starting plotting") # make data for plotting @@ -135,7 +135,6 @@ class ChirpPlotBuffer: ax0.set_ylim(np.min(self.frequency) - 100, np.max(self.frequency) + 200) for track_id in self.data.ids: - t0_track = self.t0_old - 5 dt_track = self.dt + 10 window_idx = np.arange(len(self.data.idx))[ @@ -176,10 +175,16 @@ class ChirpPlotBuffer: # ) ax0.axhline( - q50 - self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed" + q50 - self.config.minimal_bandwidth / 2, + color=ps.gblue1, + lw=1, + ls="dashed", ) ax0.axhline( - q50 + self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed" + q50 + self.config.minimal_bandwidth / 2, + color=ps.gblue1, + lw=1, + ls="dashed", ) ax0.axhline(search_lower, color=ps.gblue2, lw=1, ls="dashed") ax0.axhline(search_upper, color=ps.gblue2, lw=1, ls="dashed") @@ -205,7 +210,11 @@ class ChirpPlotBuffer: # plot waveform of filtered signal ax1.plot( - self.time, self.baseline * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5 + self.time, + self.baseline * waveform_scaler, + c=ps.gray, + lw=lw, + alpha=0.5, ) ax1.plot( self.time, @@ -216,7 +225,13 @@ class ChirpPlotBuffer: ) # plot waveform of filtered search signal - ax2.plot(self.time, self.search * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5) + ax2.plot( + self.time, + self.search * waveform_scaler, + c=ps.gray, + lw=lw, + alpha=0.5, + ) ax2.plot( self.time, self.search_envelope_unfiltered * waveform_scaler, @@ -238,9 +253,7 @@ class ChirpPlotBuffer: # ax4.plot( # self.time, self.baseline_envelope * waveform_scaler, c=ps.gblue1, lw=lw # ) - ax4.plot( - self.time, self.baseline_envelope, c=ps.gblue1, lw=lw - ) + ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=lw) ax4.scatter( (self.time)[self.baseline_peaks], # (self.baseline_envelope * waveform_scaler)[self.baseline_peaks], @@ -269,7 +282,9 @@ class ChirpPlotBuffer: ) # plot filtered instantaneous frequency - ax6.plot(self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw) + ax6.plot( + self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw + ) ax6.scatter( self.frequency_time[self.frequency_peaks], self.frequency_filtered[self.frequency_peaks], @@ -303,7 +318,9 @@ class ChirpPlotBuffer: # ax7.spines.bottom.set_bounds((0, 5)) ax0.set_xlim(0, self.config.window) - plt.subplots_adjust(left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2) + plt.subplots_adjust( + left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2 + ) fig.align_labels() if plot == "show": @@ -408,7 +425,9 @@ def extract_frequency_bands( q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2 # filter baseline - filtered_baseline = bandpass_filter(raw_data, samplerate, lowf=q25, highf=q75) + filtered_baseline = bandpass_filter( + raw_data, samplerate, lowf=q25, highf=q75 + ) # filter search area filtered_search_freq = bandpass_filter( @@ -453,12 +472,14 @@ def window_median_all_track_ids( track_ids = [] for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): - # the window index combines the track id and the time window window_idx = np.arange(len(data.idx))[ (data.ident == track_id) & (data.time[data.idx] >= window_start_seconds) - & (data.time[data.idx] <= (window_start_seconds + window_duration_seconds)) + & ( + data.time[data.idx] + <= (window_start_seconds + window_duration_seconds) + ) ] if len(data.freq[window_idx]) > 0: @@ -595,15 +616,15 @@ def find_searchband( # iterate through theses tracks if check_track_ids.size != 0: - for j, check_track_id in enumerate(check_track_ids): - q25_temp = q25[percentiles_ids == check_track_id] q75_temp = q75[percentiles_ids == check_track_id] bool_lower[search_window > q25_temp - config.search_res] = False bool_upper[search_window < q75_temp + config.search_res] = False - search_window_bool[(bool_lower == False) & (bool_upper == False)] = False + search_window_bool[ + (bool_lower == False) & (bool_upper == False) + ] = False # find gaps in search window search_window_indices = np.arange(len(search_window)) @@ -622,7 +643,9 @@ def find_searchband( # if the first value is -1, the array starst with true, so a gap if nonzeros[0] == -1: stops = search_window_indices[search_window_gaps == -1] - starts = np.append(0, search_window_indices[search_window_gaps == 1]) + starts = np.append( + 0, search_window_indices[search_window_gaps == 1] + ) # if the last value is -1, the array ends with true, so a gap if nonzeros[-1] == 1: @@ -659,7 +682,6 @@ def find_searchband( def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: - assert plot in [ "save", "show", @@ -729,7 +751,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: multiwindow_ids = [] for st, window_start_index in enumerate(window_start_indices): - logger.info(f"Processing window {st+1} of {len(window_start_indices)}") window_start_seconds = window_start_index / data.raw_rate @@ -744,8 +765,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: ) # iterate through all fish - for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): - + for tr, track_id in enumerate( + np.unique(data.ident[~np.isnan(data.ident)]) + ): logger.debug(f"Processing track {tr} of {len(data.ids)}") # get index of track data in this time window @@ -773,16 +795,17 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: nanchecker = np.unique(np.isnan(current_powers)) if (len(nanchecker) == 1) and nanchecker[0] is True: logger.warning( - f"No powers available for track {track_id} window {st}," "skipping." + f"No powers available for track {track_id} window {st}," + "skipping." ) continue # find the strongest electrodes for the current fish in the current # window - best_electrode_index = np.argsort(np.nanmean(current_powers, axis=0))[ - -config.number_electrodes : - ] + best_electrode_index = np.argsort( + np.nanmean(current_powers, axis=0) + )[-config.number_electrodes :] # find a frequency above the baseline of the current fish in which # no other fish is active to search for chirps there @@ -802,9 +825,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # iterate through electrodes for el, electrode_index in enumerate(best_electrode_index): - logger.debug( - f"Processing electrode {el+1} of " f"{len(best_electrode_index)}" + f"Processing electrode {el+1} of " + f"{len(best_electrode_index)}" ) # LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------ @@ -813,7 +836,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: current_raw_data = data.raw[ window_start_index:window_stop_index, electrode_index ] - current_raw_time = raw_time[window_start_index:window_stop_index] + current_raw_time = raw_time[ + window_start_index:window_stop_index + ] # EXTRACT FEATURES -------------------------------------------- @@ -839,8 +864,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # because the instantaneous frequency is not reliable there amplitude_mask = mask_low_amplitudes( - baseline_envelope_unfiltered, - config.baseline_min_amplitude + baseline_envelope_unfiltered, config.baseline_min_amplitude ) # highpass filter baseline envelope to remove slower @@ -877,27 +901,30 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # filtered baseline such as the one we are working with. baseline_frequency = instantaneous_frequency( - baselineband, - data.raw_rate, - config.baseline_frequency_smoothing + baselineband, + data.raw_rate, + config.baseline_frequency_smoothing, ) # Take the absolute of the instantaneous frequency to invert - # troughs into peaks. This is nessecary since the narrow + # troughs into peaks. This is nessecary since the narrow # pass band introduces these anomalies. Also substract by the # median to set it to 0. - + baseline_frequency_filtered = np.abs( baseline_frequency - np.median(baseline_frequency) ) - # check if there is at least one superthreshold peak on the - # instantaneous and exit the loop if not. This is used to - # prevent windows that do definetely not include a chirp - # to enter normalization, where small changes due to noise - # would be amplified + # check if there is at least one superthreshold peak on the + # instantaneous and exit the loop if not. This is used to + # prevent windows that do definetely not include a chirp + # to enter normalization, where small changes due to noise + # would be amplified - if not has_chirp(baseline_frequency_filtered[amplitude_mask], config.baseline_frequency_peakheight): + if not has_chirp( + baseline_frequency_filtered[amplitude_mask], + config.baseline_frequency_peakheight, + ): continue # CUT OFF OVERLAP --------------------------------------------- @@ -912,14 +939,20 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: current_raw_time = current_raw_time[no_edges] baselineband = baselineband[no_edges] - baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges] + baseline_envelope_unfiltered = baseline_envelope_unfiltered[ + no_edges + ] searchband = searchband[no_edges] baseline_envelope = baseline_envelope[no_edges] - search_envelope_unfiltered = search_envelope_unfiltered[no_edges] + search_envelope_unfiltered = search_envelope_unfiltered[ + no_edges + ] search_envelope = search_envelope[no_edges] baseline_frequency = baseline_frequency[no_edges] - baseline_frequency_filtered = baseline_frequency_filtered[no_edges] + baseline_frequency_filtered = baseline_frequency_filtered[ + no_edges + ] baseline_frequency_time = current_raw_time # # get instantaneous frequency withoup edges @@ -960,13 +993,16 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: ) # detect peaks inst_freq_filtered frequency_peak_indices, _ = find_peaks( - baseline_frequency_filtered, prominence=config.frequency_prominence + baseline_frequency_filtered, + prominence=config.frequency_prominence, ) # DETECT CHIRPS IN SEARCH WINDOW ------------------------------ # get the peak timestamps from the peak indices - baseline_peak_timestamps = current_raw_time[baseline_peak_indices] + baseline_peak_timestamps = current_raw_time[ + baseline_peak_indices + ] search_peak_timestamps = current_raw_time[search_peak_indices] frequency_peak_timestamps = baseline_frequency_time[ @@ -1015,7 +1051,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: ) if chirp_detected or (debug != "elecrode"): - logger.debug("Detected chirp, ititialize buffer ...") # save data to Buffer @@ -1107,7 +1142,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: multiwindow_chirps_flat = [] multiwindow_ids_flat = [] for track_id in np.unique(multiwindow_ids): - # get chirps for this fish and flatten the list current_track_bool = np.asarray(multiwindow_ids) == track_id current_track_chirps = flatten( @@ -1116,7 +1150,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: # add flattened chirps to the list multiwindow_chirps_flat.extend(current_track_chirps) - multiwindow_ids_flat.extend(list(np.ones_like(current_track_chirps) * track_id)) + multiwindow_ids_flat.extend( + list(np.ones_like(current_track_chirps) * track_id) + ) # purge duplicates, i.e. chirps that are very close to each other # duplites arise due to overlapping windows diff --git a/code/chirpdetector_conf.yml b/code/chirpdetector_conf.yml index 371326e..bb4598b 100755 --- a/code/chirpdetector_conf.yml +++ b/code/chirpdetector_conf.yml @@ -1,37 +1,37 @@ # Path setup ------------------------------------------------------------------ -dataroot: "../data/" # path to data -outputdir: "../output/" # path to save plots to +dataroot: "../data/" # path to data +outputdir: "../output/" # path to save plots to # Rolling window parameters --------------------------------------------------- -window: 5 # rolling window length in seconds +window: 5 # rolling window length in seconds overlap: 1 # window overlap in seconds edge: 0.25 # window edge cufoffs to mitigate filter edge effects # Electrode iteration parameters ---------------------------------------------- -number_electrodes: 2 # number of electrodes to go over -minimum_electrodes: 1 # mimumun number of electrodes a chirp must be on +number_electrodes: 2 # number of electrodes to go over +minimum_electrodes: 1 # mimumun number of electrodes a chirp must be on # Feature extraction parameters ----------------------------------------------- -search_df_lower: 20 # start searching this far above the baseline -search_df_upper: 100 # stop searching this far above the baseline -search_res: 1 # search window resolution -default_search_freq: 60 # search here if no need for a search frequency -minimal_bandwidth: 10 # minimal bandpass filter width for baseline -search_bandwidth: 10 # minimal bandpass filter width for search frequency -baseline_frequency_smoothing: 3 # instantaneous frequency smoothing +search_df_lower: 20 # start searching this far above the baseline +search_df_upper: 100 # stop searching this far above the baseline +search_res: 1 # search window resolution +default_search_freq: 60 # search here if no need for a search frequency +minimal_bandwidth: 10 # minimal bandpass filter width for baseline +search_bandwidth: 10 # minimal bandpass filter width for search frequency +baseline_frequency_smoothing: 3 # instantaneous frequency smoothing # Feature processing parameters ----------------------------------------------- -baseline_frequency_peakheight: 5 # the min peak height of the baseline instfreq -baseline_min_amplitude: 0.0001 # the minimal value of the baseline envelope -baseline_envelope_cutoff: 25 # envelope estimation cutoff -baseline_envelope_bandpass_lowf: 2 # envelope badpass lower cutoff -baseline_envelope_bandpass_highf: 100 # envelope bandbass higher cutoff -search_envelope_cutoff: 10 # search envelope estimation cufoff +baseline_frequency_peakheight: 5 # the min peak height of the baseline instfreq +baseline_min_amplitude: 0.0001 # the minimal value of the baseline envelope +baseline_envelope_cutoff: 25 # envelope estimation cutoff +baseline_envelope_bandpass_lowf: 2 # envelope badpass lower cutoff +baseline_envelope_bandpass_highf: 100 # envelope bandbass higher cutoff +search_envelope_cutoff: 10 # search envelope estimation cufoff # Peak detecion parameters ---------------------------------------------------- # baseline_prominence: 0.00005 # peak prominence threshold for baseline envelope @@ -39,9 +39,8 @@ search_envelope_cutoff: 10 # search envelope estimation cufoff # frequency_prominence: 2 # peak prominence threshold for baseline freq baseline_prominence: 0.3 # peak prominence threshold for baseline envelope -search_prominence: 0.3 # peak prominence threshold for search envelope -frequency_prominence: 0.3 # peak prominence threshold for baseline freq +search_prominence: 0.3 # peak prominence threshold for search envelope +frequency_prominence: 0.3 # peak prominence threshold for baseline freq # Classify events as chirps if they are less than this time apart chirp_window_threshold: 0.02 - diff --git a/code/eventchirpsplots.py b/code/eventchirpsplots.py index 4ebaa66..5003a7d 100644 --- a/code/eventchirpsplots.py +++ b/code/eventchirpsplots.py @@ -35,28 +35,36 @@ class Behavior: """ def __init__(self, folder_path: str) -> None: - print(f'{folder_path}') - LED_on_time_BORIS = np.load(os.path.join( - folder_path, 'LED_on_time.npy'), allow_pickle=True) - self.time = np.load(os.path.join( - folder_path, "times.npy"), allow_pickle=True) - csv_filename = [f for f in os.listdir(folder_path) if f.endswith( - '.csv')][0] # check if there are more than one csv file + print(f"{folder_path}") + LED_on_time_BORIS = np.load( + os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True + ) + self.time = np.load( + os.path.join(folder_path, "times.npy"), allow_pickle=True + ) + csv_filename = [ + f for f in os.listdir(folder_path) if f.endswith(".csv") + ][ + 0 + ] # check if there are more than one csv file self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) - self.chirps = np.load(os.path.join( - folder_path, 'chirps.npy'), allow_pickle=True) - self.chirps_ids = np.load(os.path.join( - folder_path, 'chirp_ids.npy'), allow_pickle=True) + self.chirps = np.load( + os.path.join(folder_path, "chirps.npy"), allow_pickle=True + ) + self.chirps_ids = np.load( + os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True + ) for k, key in enumerate(self.dataframe.keys()): key = key.lower() - if ' ' in key: - key = key.replace(' ', '_') - if '(' in key: - key = key.replace('(', '') - key = key.replace(')', '') - setattr(self, key, np.array( - self.dataframe[self.dataframe.keys()[k]])) + if " " in key: + key = key.replace(" ", "_") + if "(" in key: + key = key.replace("(", "") + key = key.replace(")", "") + setattr( + self, key, np.array(self.dataframe[self.dataframe.keys()[k]]) + ) last_LED_t_BORIS = LED_on_time_BORIS[-1] real_time_range = self.time[-1] - self.time[0] @@ -95,17 +103,14 @@ temporal encpding needs to be corrected ... not exactly 25FPS. def correct_chasing_events( - category: np.ndarray, - timestamps: np.ndarray + category: np.ndarray, timestamps: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: + onset_ids = np.arange(len(category))[category == 0] + offset_ids = np.arange(len(category))[category == 1] - onset_ids = np.arange( - len(category))[category == 0] - offset_ids = np.arange( - len(category))[category == 1] - - wrong_bh = np.arange(len(category))[ - category != 2][:-1][np.diff(category[category != 2]) == 0] + wrong_bh = np.arange(len(category))[category != 2][:-1][ + np.diff(category[category != 2]) == 0 + ] if onset_ids[0] > offset_ids[0]: offset_ids = np.delete(offset_ids, 0) help_index = offset_ids[0] @@ -117,12 +122,12 @@ def correct_chasing_events( # Check whether on- or offset is longer and calculate length difference if len(onset_ids) > len(offset_ids): len_diff = len(onset_ids) - len(offset_ids) - logger.info(f'Onsets are greater than offsets by {len_diff}') + logger.info(f"Onsets are greater than offsets by {len_diff}") elif len(onset_ids) < len(offset_ids): len_diff = len(offset_ids) - len(onset_ids) - logger.info(f'Offsets are greater than onsets by {len_diff}') + logger.info(f"Offsets are greater than onsets by {len_diff}") elif len(onset_ids) == len(offset_ids): - logger.info('Chasing events are equal') + logger.info("Chasing events are equal") return category, timestamps @@ -135,8 +140,7 @@ def event_triggered_chirps( dt: float, width: float, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - - event_chirps = [] # chirps that are in specified window around event + event_chirps = [] # chirps that are in specified window around event # timestamps of chirps around event centered on the event timepoint centered_chirps = [] @@ -159,16 +163,19 @@ def event_triggered_chirps( else: # convert list of arrays to one array for plotting centered_chirps = np.concatenate(centered_chirps, axis=0) - centered_chirps_convolved = (acausal_kde1d( - centered_chirps, time, width)) / len(event) + centered_chirps_convolved = ( + acausal_kde1d(centered_chirps, time, width) + ) / len(event) return event_chirps, centered_chirps, centered_chirps_convolved def main(datapath: str): - foldernames = [ - datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath + x)] + datapath + x + "/" + for x in os.listdir(datapath) + if os.path.isdir(datapath + x) + ] nrecording_chirps = [] nrecording_chirps_fish_ids = [] @@ -179,7 +186,7 @@ def main(datapath: str): # Iterate over all recordings and save chirp- and event-timestamps for folder in foldernames: # exclude folder with empty LED_on_time.npy - if folder == '../data/mount_data/2020-05-12-10_00/': + if folder == "../data/mount_data/2020-05-12-10_00/": continue bh = Behavior(folder) @@ -209,7 +216,7 @@ def main(datapath: str): time_before_event = 30 time_after_event = 60 dt = 0.01 - width = 1.5 # width of kernel for all recordings, currently gaussian kernel + width = 1.5 # width of kernel for all recordings, currently gaussian kernel recording_width = 2 # width of kernel for each recording time = np.arange(-time_before_event, time_after_event, dt) @@ -232,18 +239,47 @@ def main(datapath: str): physical_contacts = nrecording_physicals[i] # Chirps around chasing onsets - _, centered_chasing_onset_chirps, cc_chasing_onset_chirps = event_triggered_chirps( - chasing_onsets, chirps, time_before_event, time_after_event, dt, recording_width) + ( + _, + centered_chasing_onset_chirps, + cc_chasing_onset_chirps, + ) = event_triggered_chirps( + chasing_onsets, + chirps, + time_before_event, + time_after_event, + dt, + recording_width, + ) # Chirps around chasing offsets - _, centered_chasing_offset_chirps, cc_chasing_offset_chirps = event_triggered_chirps( - chasing_offsets, chirps, time_before_event, time_after_event, dt, recording_width) + ( + _, + centered_chasing_offset_chirps, + cc_chasing_offset_chirps, + ) = event_triggered_chirps( + chasing_offsets, + chirps, + time_before_event, + time_after_event, + dt, + recording_width, + ) # Chirps around physical contacts - _, centered_physical_chirps, cc_physical_chirps = event_triggered_chirps( - physical_contacts, chirps, time_before_event, time_after_event, dt, recording_width) + ( + _, + centered_physical_chirps, + cc_physical_chirps, + ) = event_triggered_chirps( + physical_contacts, + chirps, + time_before_event, + time_after_event, + dt, + recording_width, + ) nrecording_centered_onset_chirps.append(centered_chasing_onset_chirps) - nrecording_centered_offset_chirps.append( - centered_chasing_offset_chirps) + nrecording_centered_offset_chirps.append(centered_chasing_offset_chirps) nrecording_centered_physical_chirps.append(centered_physical_chirps) ## Shuffled chirps ## @@ -331,12 +367,13 @@ def main(datapath: str): # New bootstrapping approach for n in range(nbootstrapping): - diff_onset = np.diff( - np.sort(flatten(nrecording_centered_onset_chirps))) + diff_onset = np.diff(np.sort(flatten(nrecording_centered_onset_chirps))) diff_offset = np.diff( - np.sort(flatten(nrecording_centered_offset_chirps))) + np.sort(flatten(nrecording_centered_offset_chirps)) + ) diff_physical = np.diff( - np.sort(flatten(nrecording_centered_physical_chirps))) + np.sort(flatten(nrecording_centered_physical_chirps)) + ) np.random.shuffle(diff_onset) shuffled_onset = np.cumsum(diff_onset) @@ -345,9 +382,11 @@ def main(datapath: str): np.random.shuffle(diff_physical) shuffled_physical = np.cumsum(diff_physical) - kde_onset (acausal_kde1d(shuffled_onset, time, width))/(27*100) - kde_offset = (acausal_kde1d(shuffled_offset, time, width))/(27*100) - kde_physical = (acausal_kde1d(shuffled_physical, time, width))/(27*100) + kde_onset(acausal_kde1d(shuffled_onset, time, width)) / (27 * 100) + kde_offset = (acausal_kde1d(shuffled_offset, time, width)) / (27 * 100) + kde_physical = (acausal_kde1d(shuffled_physical, time, width)) / ( + 27 * 100 + ) bootstrap_onset.append(kde_onset) bootstrap_offset.append(kde_offset) @@ -355,11 +394,14 @@ def main(datapath: str): # New shuffle approach q5, q50, q95 onset_q5, onset_median, onset_q95 = np.percentile( - bootstrap_onset, [5, 50, 95], axis=0) + bootstrap_onset, [5, 50, 95], axis=0 + ) offset_q5, offset_median, offset_q95 = np.percentile( - bootstrap_offset, [5, 50, 95], axis=0) + bootstrap_offset, [5, 50, 95], axis=0 + ) physical_q5, physical_median, physical_q95 = np.percentile( - bootstrap_physical, [5, 50, 95], axis=0) + bootstrap_physical, [5, 50, 95], axis=0 + ) # vstack um 1. Dim zu cutten # nrecording_shuffled_convolved_onset_chirps = np.vstack(nrecording_shuffled_convolved_onset_chirps) @@ -378,45 +420,66 @@ def main(datapath: str): # Flatten event timestamps all_onsets = np.concatenate( - nrecording_chasing_onsets).ravel() # not centered + nrecording_chasing_onsets + ).ravel() # not centered all_offsets = np.concatenate( - nrecording_chasing_offsets).ravel() # not centered - all_physicals = np.concatenate( - nrecording_physicals).ravel() # not centered + nrecording_chasing_offsets + ).ravel() # not centered + all_physicals = np.concatenate(nrecording_physicals).ravel() # not centered # Flatten all chirps around events all_onset_chirps = np.concatenate( - nrecording_centered_onset_chirps).ravel() # centered + nrecording_centered_onset_chirps + ).ravel() # centered all_offset_chirps = np.concatenate( - nrecording_centered_offset_chirps).ravel() # centered + nrecording_centered_offset_chirps + ).ravel() # centered all_physical_chirps = np.concatenate( - nrecording_centered_physical_chirps).ravel() # centered + nrecording_centered_physical_chirps + ).ravel() # centered # Convolute all chirps # Divide by total number of each event over all recordings - all_onset_chirps_convolved = (acausal_kde1d( - all_onset_chirps, time, width)) / len(all_onsets) - all_offset_chirps_convolved = (acausal_kde1d( - all_offset_chirps, time, width)) / len(all_offsets) - all_physical_chirps_convolved = (acausal_kde1d( - all_physical_chirps, time, width)) / len(all_physicals) + all_onset_chirps_convolved = ( + acausal_kde1d(all_onset_chirps, time, width) + ) / len(all_onsets) + all_offset_chirps_convolved = ( + acausal_kde1d(all_offset_chirps, time, width) + ) / len(all_offsets) + all_physical_chirps_convolved = ( + acausal_kde1d(all_physical_chirps, time, width) + ) / len(all_physicals) # Plot all events with all shuffled - fig, ax = plt.subplots(1, 3, figsize=( - 28*ps.cm, 16*ps.cm, ), constrained_layout=True, sharey='all') + fig, ax = plt.subplots( + 1, + 3, + figsize=( + 28 * ps.cm, + 16 * ps.cm, + ), + constrained_layout=True, + sharey="all", + ) # offsets = np.arange(1,28,1) - ax[0].set_xlabel('Time[s]') + ax[0].set_xlabel("Time[s]") # Plot chasing onsets - ax[0].set_ylabel('Chirp rate [Hz]') + ax[0].set_ylabel("Chirp rate [Hz]") ax[0].plot(time, all_onset_chirps_convolved, color=ps.yellow, zorder=2) ax0 = ax[0].twinx() nrecording_centered_onset_chirps = np.asarray( - nrecording_centered_onset_chirps, dtype=object) - ax0.eventplot(np.array(nrecording_centered_onset_chirps), - linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) - ax0.vlines(0, 0, 1.5, ps.white, 'dashed') - ax[0].set_zorder(ax0.get_zorder()+1) + nrecording_centered_onset_chirps, dtype=object + ) + ax0.eventplot( + np.array(nrecording_centered_onset_chirps), + linelengths=0.5, + colors=ps.gray, + alpha=0.25, + zorder=1, + ) + ax0.vlines(0, 0, 1.5, ps.white, "dashed") + ax[0].set_zorder(ax0.get_zorder() + 1) ax[0].patch.set_visible(False) ax0.set_yticklabels([]) ax0.set_yticks([]) @@ -426,15 +489,21 @@ def main(datapath: str): ax[0].plot(time, onset_median, color=ps.black) # Plot chasing offets - ax[1].set_xlabel('Time[s]') + ax[1].set_xlabel("Time[s]") ax[1].plot(time, all_offset_chirps_convolved, color=ps.orange, zorder=2) ax1 = ax[1].twinx() nrecording_centered_offset_chirps = np.asarray( - nrecording_centered_offset_chirps, dtype=object) - ax1.eventplot(np.array(nrecording_centered_offset_chirps), - linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) - ax1.vlines(0, 0, 1.5, ps.white, 'dashed') - ax[1].set_zorder(ax1.get_zorder()+1) + nrecording_centered_offset_chirps, dtype=object + ) + ax1.eventplot( + np.array(nrecording_centered_offset_chirps), + linelengths=0.5, + colors=ps.gray, + alpha=0.25, + zorder=1, + ) + ax1.vlines(0, 0, 1.5, ps.white, "dashed") + ax[1].set_zorder(ax1.get_zorder() + 1) ax[1].patch.set_visible(False) ax1.set_yticklabels([]) ax1.set_yticks([]) @@ -444,24 +513,31 @@ def main(datapath: str): ax[1].plot(time, offset_median, color=ps.black) # Plot physical contacts - ax[2].set_xlabel('Time[s]') + ax[2].set_xlabel("Time[s]") ax[2].plot(time, all_physical_chirps_convolved, color=ps.maroon, zorder=2) ax2 = ax[2].twinx() nrecording_centered_physical_chirps = np.asarray( - nrecording_centered_physical_chirps, dtype=object) - ax2.eventplot(np.array(nrecording_centered_physical_chirps), - linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) - ax2.vlines(0, 0, 1.5, ps.white, 'dashed') - ax[2].set_zorder(ax2.get_zorder()+1) + nrecording_centered_physical_chirps, dtype=object + ) + ax2.eventplot( + np.array(nrecording_centered_physical_chirps), + linelengths=0.5, + colors=ps.gray, + alpha=0.25, + zorder=1, + ) + ax2.vlines(0, 0, 1.5, ps.white, "dashed") + ax[2].set_zorder(ax2.get_zorder() + 1) ax[2].patch.set_visible(False) ax2.set_yticklabels([]) ax2.set_yticks([]) # ax[2].fill_between(time, shuffled_q5_physical, shuffled_q95_physical, color=ps.gray, alpha=0.5) # ax[2].plot(time, shuffled_median_physical, ps.black) - ax[2].fill_between(time, physical_q5, physical_q95, - color=ps.gray, alpha=0.5) + ax[2].fill_between( + time, physical_q5, physical_q95, color=ps.gray, alpha=0.5 + ) ax[2].plot(time, physical_median, ps.black) - fig.suptitle('All recordings') + fig.suptitle("All recordings") plt.show() plt.close() @@ -587,7 +663,7 @@ def main(datapath: str): #### Chirps around events, only losers, one recording #### -if __name__ == '__main__': +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/' + datapath = "../data/mount_data/" main(datapath) diff --git a/code/extract_chirps.py b/code/extract_chirps.py index 77e3e8d..900f0a2 100644 --- a/code/extract_chirps.py +++ b/code/extract_chirps.py @@ -8,50 +8,51 @@ from IPython import embed def get_valid_datasets(dataroot): - - datasets = sorted([name for name in os.listdir(dataroot) if os.path.isdir( - os.path.join(dataroot, name))]) + datasets = sorted( + [ + name + for name in os.listdir(dataroot) + if os.path.isdir(os.path.join(dataroot, name)) + ] + ) valid_datasets = [] for dataset in datasets: - path = os.path.join(dataroot, dataset) - csv_name = '-'.join(dataset.split('-')[:3]) + '.csv' + csv_name = "-".join(dataset.split("-")[:3]) + ".csv" if os.path.exists(os.path.join(path, csv_name)) is False: continue - if os.path.exists(os.path.join(path, 'ident_v.npy')) is False: + if os.path.exists(os.path.join(path, "ident_v.npy")) is False: continue - ident = np.load(os.path.join(path, 'ident_v.npy')) + ident = np.load(os.path.join(path, "ident_v.npy")) number_of_fish = len(np.unique(ident[~np.isnan(ident)])) if number_of_fish != 2: continue valid_datasets.append(dataset) - datapaths = [os.path.join(dataroot, dataset) + - '/' for dataset in valid_datasets] + datapaths = [ + os.path.join(dataroot, dataset) + "/" for dataset in valid_datasets + ] return datapaths, valid_datasets def main(datapaths): - for path in datapaths: - chirpdetection(path, plot='show') - - -if __name__ == '__main__': + chirpdetection(path, plot="show") - dataroot = '../data/mount_data/' +if __name__ == "__main__": + dataroot = "../data/mount_data/" - datapaths, valid_datasets= get_valid_datasets(dataroot) + datapaths, valid_datasets = get_valid_datasets(dataroot) - recs = pd.DataFrame(columns=['recording'], data=valid_datasets) - recs.to_csv('../recs.csv', index=False) + recs = pd.DataFrame(columns=["recording"], data=valid_datasets) + recs.to_csv("../recs.csv", index=False) # datapaths = ['../data/mount_data/2020-03-25-10_00/'] main(datapaths) diff --git a/code/get_behaviour.py b/code/get_behaviour.py index 36311ca..3513c1b 100644 --- a/code/get_behaviour.py +++ b/code/get_behaviour.py @@ -1,4 +1,4 @@ -import os +import os from paramiko import SSHClient from scp import SCPClient from IPython import embed @@ -7,29 +7,41 @@ from pandas import read_csv ssh = SSHClient() ssh.load_system_host_keys() -ssh.connect(hostname='kraken', - username='efish', - password='fwNix4U', - ) +ssh.connect( + hostname="kraken", + username="efish", + password="fwNix4U", +) # SCPCLient takes a paramiko transport as its only argument scp = SCPClient(ssh.get_transport()) -data = read_csv('../recs.csv') -foldernames = data['recording'].values +data = read_csv("../recs.csv") +foldernames = data["recording"].values -directory = f'/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/' +directory = f"/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/" for foldername in foldernames: - - if not os.path.exists(directory+foldername): - os.makedirs(directory+foldername) - - files = [('-').join(foldername.split('-')[:3])+'.csv','chirp_ids.npy', 'chirps.npy', 'fund_v.npy', 'ident_v.npy', 'idx_v.npy', 'times.npy', 'spec.npy', 'LED_on_time.npy', 'sign_v.npy'] - + if not os.path.exists(directory + foldername): + os.makedirs(directory + foldername) + + files = [ + ("-").join(foldername.split("-")[:3]) + ".csv", + "chirp_ids.npy", + "chirps.npy", + "fund_v.npy", + "ident_v.npy", + "idx_v.npy", + "times.npy", + "spec.npy", + "LED_on_time.npy", + "sign_v.npy", + ] for f in files: - scp.get(f'/home/efish/behavior/2019_tube_competition/{foldername}/{f}', - directory+foldername) + scp.get( + f"/home/efish/behavior/2019_tube_competition/{foldername}/{f}", + directory + foldername, + ) scp.close() diff --git a/code/modules/behaviour_handling.py b/code/modules/behaviour_handling.py index 94a0ca1..a50d67a 100644 --- a/code/modules/behaviour_handling.py +++ b/code/modules/behaviour_handling.py @@ -30,12 +30,12 @@ class Behavior: """ def __init__(self, folder_path: str) -> None: - - LED_on_time_BORIS = np.load(os.path.join( - folder_path, 'LED_on_time.npy'), allow_pickle=True) + LED_on_time_BORIS = np.load( + os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True + ) csv_filename = os.path.split(folder_path[:-1])[-1] - csv_filename = '-'.join(csv_filename.split('-')[:-1]) + '.csv' + csv_filename = "-".join(csv_filename.split("-")[:-1]) + ".csv" # embed() # csv_filename = [f for f in os.listdir( @@ -43,31 +43,39 @@ class Behavior: # logger.info(f'CSV file: {csv_filename}') self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) - self.chirps = np.load(os.path.join( - folder_path, 'chirps.npy'), allow_pickle=True) - self.chirps_ids = np.load(os.path.join( - folder_path, 'chirp_ids.npy'), allow_pickle=True) - - self.ident = np.load(os.path.join( - folder_path, 'ident_v.npy'), allow_pickle=True) - self.idx = np.load(os.path.join( - folder_path, 'idx_v.npy'), allow_pickle=True) - self.freq = np.load(os.path.join( - folder_path, 'fund_v.npy'), allow_pickle=True) - self.time = np.load(os.path.join( - folder_path, "times.npy"), allow_pickle=True) - self.spec = np.load(os.path.join( - folder_path, "spec.npy"), allow_pickle=True) + self.chirps = np.load( + os.path.join(folder_path, "chirps.npy"), allow_pickle=True + ) + self.chirps_ids = np.load( + os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True + ) + + self.ident = np.load( + os.path.join(folder_path, "ident_v.npy"), allow_pickle=True + ) + self.idx = np.load( + os.path.join(folder_path, "idx_v.npy"), allow_pickle=True + ) + self.freq = np.load( + os.path.join(folder_path, "fund_v.npy"), allow_pickle=True + ) + self.time = np.load( + os.path.join(folder_path, "times.npy"), allow_pickle=True + ) + self.spec = np.load( + os.path.join(folder_path, "spec.npy"), allow_pickle=True + ) for k, key in enumerate(self.dataframe.keys()): key = key.lower() - if ' ' in key: - key = key.replace(' ', '_') - if '(' in key: - key = key.replace('(', '') - key = key.replace(')', '') - setattr(self, key, np.array( - self.dataframe[self.dataframe.keys()[k]])) + if " " in key: + key = key.replace(" ", "_") + if "(" in key: + key = key.replace("(", "") + key = key.replace(")", "") + setattr( + self, key, np.array(self.dataframe[self.dataframe.keys()[k]]) + ) last_LED_t_BORIS = LED_on_time_BORIS[-1] real_time_range = self.time[-1] - self.time[0] @@ -78,22 +86,19 @@ class Behavior: def correct_chasing_events( - category: np.ndarray, - timestamps: np.ndarray + category: np.ndarray, timestamps: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: + onset_ids = np.arange(len(category))[category == 0] + offset_ids = np.arange(len(category))[category == 1] - onset_ids = np.arange( - len(category))[category == 0] - offset_ids = np.arange( - len(category))[category == 1] - - wrong_bh = np.arange(len(category))[ - category != 2][:-1][np.diff(category[category != 2]) == 0] + wrong_bh = np.arange(len(category))[category != 2][:-1][ + np.diff(category[category != 2]) == 0 + ] if category[category != 2][-1] == 0: wrong_bh = np.append( - wrong_bh, - np.arange(len(category))[category != 2][-1]) + wrong_bh, np.arange(len(category))[category != 2][-1] + ) if onset_ids[0] > offset_ids[0]: offset_ids = np.delete(offset_ids, 0) @@ -103,18 +108,16 @@ def correct_chasing_events( category = np.delete(category, wrong_bh) timestamps = np.delete(timestamps, wrong_bh) - new_onset_ids = np.arange( - len(category))[category == 0] - new_offset_ids = np.arange( - len(category))[category == 1] + new_onset_ids = np.arange(len(category))[category == 0] + new_offset_ids = np.arange(len(category))[category == 1] # Check whether on- or offset is longer and calculate length difference if len(new_onset_ids) > len(new_offset_ids): embed() - logger.warning('Onsets are greater than offsets') + logger.warning("Onsets are greater than offsets") elif len(new_onset_ids) < len(new_offset_ids): - logger.warning('Offsets are greater than onsets') + logger.warning("Offsets are greater than onsets") elif len(new_onset_ids) == len(new_offset_ids): # logger.info('Chasing events are equal') pass @@ -130,13 +133,11 @@ def center_chirps( # dt: float, # width: float, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - - event_chirps = [] # chirps that are in specified window around event + event_chirps = [] # chirps that are in specified window around event # timestamps of chirps around event centered on the event timepoint centered_chirps = [] for event_timestamp in events: - start = event_timestamp - time_before_event stop = event_timestamp + time_after_event chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)] @@ -152,7 +153,8 @@ def center_chirps( if len(centered_chirps) != len(event_chirps): raise ValueError( - 'Non centered chirps and centered chirps are not equal') + "Non centered chirps and centered chirps are not equal" + ) # time = np.arange(-time_before_event, time_after_event, dt) diff --git a/code/modules/datahandling.py b/code/modules/datahandling.py index 0a240ab..68e73cd 100644 --- a/code/modules/datahandling.py +++ b/code/modules/datahandling.py @@ -23,7 +23,9 @@ def minmaxnorm(data): return (data - np.min(data)) / (np.max(data) - np.min(data)) -def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str = 'linear') -> np.ndarray: +def instantaneous_frequency2( + signal: np.ndarray, fs: float, interpolation: str = "linear" +) -> np.ndarray: """ Compute the instantaneous frequency of a periodic signal using zero crossings and resample the frequency using linear or cubic interpolation to match the dimensions of the input array. @@ -55,10 +57,10 @@ def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str = orig_len = len(signal) freq = resample(freq, orig_len) - if interpolation == 'linear': + if interpolation == "linear": freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq) - elif interpolation == 'cubic': - freq = resample(freq, orig_len, window='cubic') + elif interpolation == "cubic": + freq = resample(freq, orig_len, window="cubic") return freq @@ -67,7 +69,7 @@ def instantaneous_frequency( signal: np.ndarray, samplerate: int, smoothing_window: int, - interpolation: str = 'linear', + interpolation: str = "linear", ) -> np.ndarray: """ Compute the instantaneous frequency of a signal that is approximately @@ -120,11 +122,10 @@ def instantaneous_frequency( orig_len = len(signal) freq = resample(instantaneous_frequency, orig_len) - if interpolation == 'linear': + if interpolation == "linear": freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq) - elif interpolation == 'cubic': - freq = resample(freq, orig_len, window='cubic') - + elif interpolation == "cubic": + freq = resample(freq, orig_len, window="cubic") return freq @@ -160,7 +161,6 @@ def purge_duplicates( group = [timestamps[0]] for i in range(1, len(timestamps)): - # check the difference between current timestamp and previous # timestamp is less than the threshold if timestamps[i] - timestamps[i - 1] < threshold: @@ -379,7 +379,6 @@ def acausal_kde1d(spikes, time, width): if __name__ == "__main__": - timestamps = [ [1.2, 1.5, 1.3], [], diff --git a/code/modules/filehandling.py b/code/modules/filehandling.py index c3c71f2..382a49d 100644 --- a/code/modules/filehandling.py +++ b/code/modules/filehandling.py @@ -35,7 +35,6 @@ class LoadData: """ def __init__(self, datapath: str) -> None: - # load raw data self.datapath = datapath self.file = os.path.join(datapath, "traces-grid1.raw") diff --git a/code/modules/filters.py b/code/modules/filters.py index e6d9896..06fe236 100644 --- a/code/modules/filters.py +++ b/code/modules/filters.py @@ -3,10 +3,10 @@ import numpy as np def bandpass_filter( - signal: np.ndarray, - samplerate: float, - lowf: float, - highf: float, + signal: np.ndarray, + samplerate: float, + lowf: float, + highf: float, ) -> np.ndarray: """Bandpass filter a signal. @@ -60,9 +60,7 @@ def highpass_filter( def lowpass_filter( - signal: np.ndarray, - samplerate: float, - cutoff: float + signal: np.ndarray, samplerate: float, cutoff: float ) -> np.ndarray: """Lowpass filter a signal. @@ -86,10 +84,9 @@ def lowpass_filter( return filtered_signal -def envelope(signal: np.ndarray, - samplerate: float, - cutoff_frequency: float - ) -> np.ndarray: +def envelope( + signal: np.ndarray, samplerate: float, cutoff_frequency: float +) -> np.ndarray: """Calculate the envelope of a signal using a lowpass filter. Parameters diff --git a/code/modules/logger.py b/code/modules/logger.py index 5dabf80..ed6d93e 100644 --- a/code/modules/logger.py +++ b/code/modules/logger.py @@ -2,12 +2,13 @@ import logging def makeLogger(name: str): - # create logger formats for file and terminal file_formatter = logging.Formatter( - "[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s") + "[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s" + ) console_formatter = logging.Formatter( - "[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s") + "[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s" + ) # create logging file if loglevel is debug file_handler = logging.FileHandler(f"gridtools_log.log", mode="w") @@ -29,7 +30,6 @@ def makeLogger(name: str): if __name__ == "__main__": - # initiate logger mylogger = makeLogger(__name__) diff --git a/code/modules/plotstyle.py b/code/modules/plotstyle.py index 22b14c6..43d12ac 100644 --- a/code/modules/plotstyle.py +++ b/code/modules/plotstyle.py @@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap def PlotStyle() -> None: class style: - # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # units @@ -76,13 +75,15 @@ def PlotStyle() -> None: va="center", zorder=1000, bbox=dict( - boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 + boxstyle=f"circle, pad={padding}", + fc="white", + ec="black", + lw=1, ), ) @classmethod def fade_cmap(cls, cmap): - my_cmap = cmap(np.arange(cmap.N)) my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap = ListedColormap(my_cmap) @@ -295,7 +296,6 @@ def PlotStyle() -> None: if __name__ == "__main__": - s = PlotStyle() import matplotlib.cbook as cbook @@ -347,7 +347,8 @@ if __name__ == "__main__": for ax in axs: ax.yaxis.grid(True) ax.set_xticks( - [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] + [y + 1 for y in range(len(all_data))], + labels=["x1", "x2", "x3", "x4"], ) ax.set_xlabel("Four separate samples") ax.set_ylabel("Observed values") @@ -396,7 +397,10 @@ if __name__ == "__main__": grid = np.random.rand(4, 4) fig, axs = plt.subplots( - nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} + nrows=3, + ncols=6, + figsize=(9, 6), + subplot_kw={"xticks": [], "yticks": []}, ) for ax, interp_method in zip(axs.flat, methods): diff --git a/code/modules/plotstyle1.py b/code/modules/plotstyle1.py index 32af4d2..237996b 100644 --- a/code/modules/plotstyle1.py +++ b/code/modules/plotstyle1.py @@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap def PlotStyle() -> None: class style: - # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # units @@ -76,13 +75,15 @@ def PlotStyle() -> None: va="center", zorder=1000, bbox=dict( - boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 + boxstyle=f"circle, pad={padding}", + fc="white", + ec="black", + lw=1, ), ) @classmethod def fade_cmap(cls, cmap): - my_cmap = cmap(np.arange(cmap.N)) my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap = ListedColormap(my_cmap) @@ -295,7 +296,6 @@ def PlotStyle() -> None: if __name__ == "__main__": - s = PlotStyle() import matplotlib.cbook as cbook @@ -347,7 +347,8 @@ if __name__ == "__main__": for ax in axs: ax.yaxis.grid(True) ax.set_xticks( - [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] + [y + 1 for y in range(len(all_data))], + labels=["x1", "x2", "x3", "x4"], ) ax.set_xlabel("Four separate samples") ax.set_ylabel("Observed values") @@ -396,7 +397,10 @@ if __name__ == "__main__": grid = np.random.rand(4, 4) fig, axs = plt.subplots( - nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} + nrows=3, + ncols=6, + figsize=(9, 6), + subplot_kw={"xticks": [], "yticks": []}, ) for ax, interp_method in zip(axs.flat, methods): diff --git a/code/modules/plotstyle_dark.py b/code/modules/plotstyle_dark.py index d5b9557..d767e24 100644 --- a/code/modules/plotstyle_dark.py +++ b/code/modules/plotstyle_dark.py @@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap def PlotStyle() -> None: class style: - # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # units @@ -76,13 +75,15 @@ def PlotStyle() -> None: va="center", zorder=1000, bbox=dict( - boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 + boxstyle=f"circle, pad={padding}", + fc="white", + ec="black", + lw=1, ), ) @classmethod def fade_cmap(cls, cmap): - my_cmap = cmap(np.arange(cmap.N)) my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap = ListedColormap(my_cmap) @@ -295,7 +296,6 @@ def PlotStyle() -> None: if __name__ == "__main__": - s = PlotStyle() import matplotlib.cbook as cbook @@ -347,7 +347,8 @@ if __name__ == "__main__": for ax in axs: ax.yaxis.grid(True) ax.set_xticks( - [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] + [y + 1 for y in range(len(all_data))], + labels=["x1", "x2", "x3", "x4"], ) ax.set_xlabel("Four separate samples") ax.set_ylabel("Observed values") @@ -396,7 +397,10 @@ if __name__ == "__main__": grid = np.random.rand(4, 4) fig, axs = plt.subplots( - nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} + nrows=3, + ncols=6, + figsize=(9, 6), + subplot_kw={"xticks": [], "yticks": []}, ) for ax, interp_method in zip(axs.flat, methods): diff --git a/code/modules/simulations.py b/code/modules/simulations.py index 473bac8..a074801 100644 --- a/code/modules/simulations.py +++ b/code/modules/simulations.py @@ -37,7 +37,7 @@ def create_chirp( ck = 0 csig = 0.5 * chirpduration / np.power(2.0 * np.log(10.0), 0.5 / kurtosis) - #csig = csig*-1 + # csig = csig*-1 for k, t in enumerate(time): a = 1.0 f = eodf diff --git a/code/plot_chirp_size.py b/code/plot_chirp_size.py index 95b2a95..1153ff5 100644 --- a/code/plot_chirp_size.py +++ b/code/plot_chirp_size.py @@ -16,26 +16,25 @@ logger = makeLogger(__name__) def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): - - foldername = folder_name.split('/')[-2] - winner_row = order_meta_df[order_meta_df['recording'] == foldername] - winner = winner_row['winner'].values[0].astype(int) - winner_fish1 = winner_row['fish1'].values[0].astype(int) - winner_fish2 = winner_row['fish2'].values[0].astype(int) + foldername = folder_name.split("/")[-2] + winner_row = order_meta_df[order_meta_df["recording"] == foldername] + winner = winner_row["winner"].values[0].astype(int) + winner_fish1 = winner_row["fish1"].values[0].astype(int) + winner_fish2 = winner_row["fish2"].values[0].astype(int) if winner > 0: if winner == winner_fish1: - winner_fish_id = winner_row['rec_id1'].values[0] - loser_fish_id = winner_row['rec_id2'].values[0] + winner_fish_id = winner_row["rec_id1"].values[0] + loser_fish_id = winner_row["rec_id2"].values[0] elif winner == winner_fish2: - winner_fish_id = winner_row['rec_id2'].values[0] - loser_fish_id = winner_row['rec_id1'].values[0] + winner_fish_id = winner_row["rec_id2"].values[0] + loser_fish_id = winner_row["rec_id1"].values[0] chirp_winner = len( - Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) - chirp_loser = len( - Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) + Behavior.chirps[Behavior.chirps_ids == winner_fish_id] + ) + chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) return chirp_winner, chirp_loser else: @@ -43,24 +42,24 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df): - - foldername = folder_name.split('/')[-2] - folder_row = order_meta_df[order_meta_df['recording'] == foldername] - fish1 = folder_row['fish1'].values[0].astype(int) - fish2 = folder_row['fish2'].values[0].astype(int) - winner = folder_row['winner'].values[0].astype(int) - - groub = folder_row['group'].values[0].astype(int) - size_fish1_row = id_meta_df[(id_meta_df['group'] == groub) & ( - id_meta_df['fish'] == fish1)] - size_fish2_row = id_meta_df[(id_meta_df['group'] == groub) & ( - id_meta_df['fish'] == fish2)] - - size_winners = [size_fish1_row[col].values[0] - for col in ['l1', 'l2', 'l3']] + foldername = folder_name.split("/")[-2] + folder_row = order_meta_df[order_meta_df["recording"] == foldername] + fish1 = folder_row["fish1"].values[0].astype(int) + fish2 = folder_row["fish2"].values[0].astype(int) + winner = folder_row["winner"].values[0].astype(int) + + groub = folder_row["group"].values[0].astype(int) + size_fish1_row = id_meta_df[ + (id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish1) + ] + size_fish2_row = id_meta_df[ + (id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish2) + ] + + size_winners = [size_fish1_row[col].values[0] for col in ["l1", "l2", "l3"]] size_fish1 = np.nanmean(size_winners) - size_losers = [size_fish2_row[col].values[0] for col in ['l1', 'l2', 'l3']] + size_losers = [size_fish2_row[col].values[0] for col in ["l1", "l2", "l3"]] size_fish2 = np.nanmean(size_losers) if winner == fish1: @@ -75,8 +74,8 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df): size_diff_bigger = 0 size_diff_smaller = 0 - winner_fish_id = folder_row['rec_id1'].values[0] - loser_fish_id = folder_row['rec_id2'].values[0] + winner_fish_id = folder_row["rec_id1"].values[0] + loser_fish_id = folder_row["rec_id2"].values[0] elif winner == fish2: if size_fish2 > size_fish1: @@ -90,39 +89,39 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df): size_diff_bigger = 0 size_diff_smaller = 0 - winner_fish_id = folder_row['rec_id2'].values[0] - loser_fish_id = folder_row['rec_id1'].values[0] + winner_fish_id = folder_row["rec_id2"].values[0] + loser_fish_id = folder_row["rec_id1"].values[0] else: size_diff_bigger = np.nan size_diff_smaller = np.nan winner_fish_id = np.nan loser_fish_id = np.nan - return size_diff_bigger, size_diff_smaller, winner_fish_id, loser_fish_id + return ( + size_diff_bigger, + size_diff_smaller, + winner_fish_id, + loser_fish_id, + ) - chirp_winner = len( - Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) - chirp_loser = len( - Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) + chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) + chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) - return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser + return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser def get_chirp_freq(folder_name, Behavior, order_meta_df): + foldername = folder_name.split("/")[-2] + folder_row = order_meta_df[order_meta_df["recording"] == foldername] + fish1 = folder_row["fish1"].values[0].astype(int) + fish2 = folder_row["fish2"].values[0].astype(int) - foldername = folder_name.split('/')[-2] - folder_row = order_meta_df[order_meta_df['recording'] == foldername] - fish1 = folder_row['fish1'].values[0].astype(int) - fish2 = folder_row['fish2'].values[0].astype(int) + fish1_freq = folder_row["rec_id1"].values[0].astype(int) + fish2_freq = folder_row["rec_id2"].values[0].astype(int) - fish1_freq = folder_row['rec_id1'].values[0].astype(int) - fish2_freq = folder_row['rec_id2'].values[0].astype(int) - - chirp_freq_fish1 = np.nanmedian( - Behavior.freq[Behavior.ident == fish1_freq]) - chirp_freq_fish2 = np.nanmedian( - Behavior.freq[Behavior.ident == fish2_freq]) - winner = folder_row['winner'].values[0].astype(int) + chirp_freq_fish1 = np.nanmedian(Behavior.freq[Behavior.ident == fish1_freq]) + chirp_freq_fish2 = np.nanmedian(Behavior.freq[Behavior.ident == fish2_freq]) + winner = folder_row["winner"].values[0].astype(int) if winner == fish1: # if chirp_freq_fish1 > chirp_freq_fish2: @@ -138,9 +137,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df): # winner_fish_id = np.nan # loser_fish_id = np.nan - winner_fish_id = folder_row['rec_id1'].values[0] + winner_fish_id = folder_row["rec_id1"].values[0] winner_fish_freq = chirp_freq_fish1 - loser_fish_id = folder_row['rec_id2'].values[0] + loser_fish_id = folder_row["rec_id2"].values[0] loser_fish_freq = chirp_freq_fish2 elif winner == fish2: @@ -157,9 +156,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df): # winner_fish_id = np.nan # loser_fish_id = np.nan - winner_fish_id = folder_row['rec_id2'].values[0] + winner_fish_id = folder_row["rec_id2"].values[0] winner_fish_freq = chirp_freq_fish2 - loser_fish_id = folder_row['rec_id1'].values[0] + loser_fish_id = folder_row["rec_id1"].values[0] loser_fish_freq = chirp_freq_fish1 else: winner_fish_freq = np.nan @@ -168,25 +167,25 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df): loser_fish_id = np.nan return winner_fish_freq, winner_fish_id, loser_fish_freq, loser_fish_id - chirp_winner = len( - Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) - chirp_loser = len( - Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) + chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) + chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) return winner_fish_freq, chirp_winner, loser_fish_freq, chirp_loser def main(datapath: str): - foldernames = [ - datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] + datapath + x + "/" + for x in os.listdir(datapath) + if os.path.isdir(datapath + x) + ] foldernames, _ = get_valid_datasets(datapath) - path_order_meta = ( - '/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv' + path_order_meta = ("/").join( + foldernames[0].split("/")[:-2] + ) + "/order_meta.csv" order_meta_df = read_csv(path_order_meta) - order_meta_df['recording'] = order_meta_df['recording'].str[1:-1] - path_id_meta = ( - '/').join(foldernames[0].split('/')[:-2]) + '/id_meta.csv' + order_meta_df["recording"] = order_meta_df["recording"].str[1:-1] + path_id_meta = ("/").join(foldernames[0].split("/")[:-2]) + "/id_meta.csv" id_meta_df = read_csv(path_id_meta) chirps_winner = [] @@ -202,10 +201,9 @@ def main(datapath: str): freq_chirps_winner = [] freq_chirps_loser = [] - for foldername in foldernames: # behabvior is pandas dataframe with all the data - if foldername == '../data/mount_data/2020-05-12-10_00/': + if foldername == "../data/mount_data/2020-05-12-10_00/": continue bh = Behavior(foldername) # chirps are not sorted in time (presumably due to prior groupings) @@ -217,15 +215,24 @@ def main(datapath: str): category, timestamps = correct_chasing_events(category, timestamps) winner_chirp, loser_chirp = get_chirp_winner_loser( - foldername, bh, order_meta_df) + foldername, bh, order_meta_df + ) chirps_winner.append(winner_chirp) chirps_loser.append(loser_chirp) - size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser = get_chirp_size( - foldername, bh, order_meta_df, id_meta_df) + ( + size_diff_bigger, + chirp_winner, + size_diff_smaller, + chirp_loser, + ) = get_chirp_size(foldername, bh, order_meta_df, id_meta_df) - freq_winner, chirp_freq_winner, freq_loser, chirp_freq_loser = get_chirp_freq( - foldername, bh, order_meta_df) + ( + freq_winner, + chirp_freq_winner, + freq_loser, + chirp_freq_loser, + ) = get_chirp_freq(foldername, bh, order_meta_df) freq_diffs_higher.append(freq_winner) freq_diffs_lower.append(freq_loser) @@ -242,82 +249,124 @@ def main(datapath: str): pearsonr(size_diffs_winner, size_chirps_winner) pearsonr(size_diffs_loser, size_chirps_loser) - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=( - 21*ps.cm, 7*ps.cm), width_ratios=[1, 0.8, 0.8], sharey=True) - plt.subplots_adjust(left=0.11, right=0.948, top=0.86, - wspace=0.343, bottom=0.198) + fig, (ax1, ax2, ax3) = plt.subplots( + 1, + 3, + figsize=(21 * ps.cm, 7 * ps.cm), + width_ratios=[1, 0.8, 0.8], + sharey=True, + ) + plt.subplots_adjust( + left=0.11, right=0.948, top=0.86, wspace=0.343, bottom=0.198 + ) scatterwinner = 1.15 scatterloser = 1.85 chirps_winner = np.asarray(chirps_winner)[~np.isnan(chirps_winner)] chirps_loser = np.asarray(chirps_loser)[~np.isnan(chirps_loser)] embed() exit() - freq_diffs_higher = np.asarray( - freq_diffs_higher)[~np.isnan(freq_diffs_higher)] - freq_diffs_lower = np.asarray(freq_diffs_lower)[ - ~np.isnan(freq_diffs_lower)] - freq_chirps_winner = np.asarray( - freq_chirps_winner)[~np.isnan(freq_chirps_winner)] - freq_chirps_loser = np.asarray( - freq_chirps_loser)[~np.isnan(freq_chirps_loser)] + freq_diffs_higher = np.asarray(freq_diffs_higher)[ + ~np.isnan(freq_diffs_higher) + ] + freq_diffs_lower = np.asarray(freq_diffs_lower)[~np.isnan(freq_diffs_lower)] + freq_chirps_winner = np.asarray(freq_chirps_winner)[ + ~np.isnan(freq_chirps_winner) + ] + freq_chirps_loser = np.asarray(freq_chirps_loser)[ + ~np.isnan(freq_chirps_loser) + ] stat = wilcoxon(chirps_winner, chirps_loser) print(stat) winner_color = ps.gblue2 loser_color = ps.gblue1 - bplot1 = ax1.boxplot(chirps_winner, positions=[ - 0.9], showfliers=False, patch_artist=True) - - bplot2 = ax1.boxplot(chirps_loser, positions=[ - 2.1], showfliers=False, patch_artist=True) - - ax1.scatter(np.ones(len(chirps_winner)) * - scatterwinner, chirps_winner, color=winner_color) - ax1.scatter(np.ones(len(chirps_loser)) * - scatterloser, chirps_loser, color=loser_color) - ax1.set_xticklabels(['Winner', 'Loser']) - - ax1.text(0.1, 0.85, f'n={len(chirps_loser)}', - transform=ax1.transAxes, color=ps.white) + bplot1 = ax1.boxplot( + chirps_winner, positions=[0.9], showfliers=False, patch_artist=True + ) + + bplot2 = ax1.boxplot( + chirps_loser, positions=[2.1], showfliers=False, patch_artist=True + ) + + ax1.scatter( + np.ones(len(chirps_winner)) * scatterwinner, + chirps_winner, + color=winner_color, + ) + ax1.scatter( + np.ones(len(chirps_loser)) * scatterloser, + chirps_loser, + color=loser_color, + ) + ax1.set_xticklabels(["Winner", "Loser"]) + + ax1.text( + 0.1, + 0.85, + f"n={len(chirps_loser)}", + transform=ax1.transAxes, + color=ps.white, + ) for w, l in zip(chirps_winner, chirps_loser): - ax1.plot([scatterwinner, scatterloser], [w, l], - color=ps.white, alpha=0.6, linewidth=1, zorder=-1) - ax1.set_ylabel('Chirp counts', color=ps.white) - ax1.set_xlabel('Competition outcome', color=ps.white) + ax1.plot( + [scatterwinner, scatterloser], + [w, l], + color=ps.white, + alpha=0.6, + linewidth=1, + zorder=-1, + ) + ax1.set_ylabel("Chirp counts", color=ps.white) + ax1.set_xlabel("Competition outcome", color=ps.white) ps.set_boxplot_color(bplot1, winner_color) ps.set_boxplot_color(bplot2, loser_color) - ax2.scatter(size_diffs_winner, size_chirps_winner, - color=winner_color, label='Winner') - ax2.scatter(size_diffs_loser, size_chirps_loser, - color=loser_color, label='Loser') - - ax2.text(0.05, 0.85, f'n={len(size_chirps_loser)}', - transform=ax2.transAxes, color=ps.white) - - ax2.set_xlabel('Size difference [cm]') + ax2.scatter( + size_diffs_winner, + size_chirps_winner, + color=winner_color, + label="Winner", + ) + ax2.scatter( + size_diffs_loser, size_chirps_loser, color=loser_color, label="Loser" + ) + + ax2.text( + 0.05, + 0.85, + f"n={len(size_chirps_loser)}", + transform=ax2.transAxes, + color=ps.white, + ) + + ax2.set_xlabel("Size difference [cm]") # ax2.set_xticks(np.arange(-10, 10.1, 2)) ax3.scatter(freq_diffs_higher, freq_chirps_winner, color=winner_color) ax3.scatter(freq_diffs_lower, freq_chirps_loser, color=loser_color) - ax3.text(0.1, 0.85, f'n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}', - transform=ax3.transAxes, color=ps.white) + ax3.text( + 0.1, + 0.85, + f"n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}", + transform=ax3.transAxes, + color=ps.white, + ) - ax3.set_xlabel('EODf [Hz]') + ax3.set_xlabel("EODf [Hz]") handles, labels = ax2.get_legend_handles_labels() - fig.legend(handles, labels, loc='upper center', - ncol=2, bbox_to_anchor=(0.5, 1.04)) + fig.legend( + handles, labels, loc="upper center", ncol=2, bbox_to_anchor=(0.5, 1.04) + ) # pearson r - plt.savefig('../poster/figs/chirps_winner_loser.pdf') + plt.savefig("../poster/figs/chirps_winner_loser.pdf") plt.show() -if __name__ == '__main__': - +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/' + datapath = "../data/mount_data/" main(datapath) diff --git a/code/plot_chirps_in_chasing.py b/code/plot_chirps_in_chasing.py index ee43196..ef0e5a7 100644 --- a/code/plot_chirps_in_chasing.py +++ b/code/plot_chirps_in_chasing.py @@ -21,14 +21,16 @@ logger = makeLogger(__name__) def main(datapath: str): - foldernames = [ - datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] + datapath + x + "/" + for x in os.listdir(datapath) + if os.path.isdir(datapath + x) + ] time_precents = [] chirps_percents = [] for foldername in foldernames: # behabvior is pandas dataframe with all the data - if foldername == '../data/mount_data/2020-05-12-10_00/': + if foldername == "../data/mount_data/2020-05-12-10_00/": continue bh = Behavior(foldername) @@ -46,50 +48,70 @@ def main(datapath: str): chirps_in_chasings = [] for onset, offset in zip(chasing_onset, chasing_offset): chirps_in_chasing = [ - c for c in bh.chirps if (c > onset) & (c < offset)] + c for c in bh.chirps if (c > onset) & (c < offset) + ] chirps_in_chasings.append(chirps_in_chasing) try: time_chasing = np.sum( - chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60]) + chasing_offset[chasing_offset < 3 * 60 * 60] + - chasing_onset[chasing_onset < 3 * 60 * 60] + ) except: time_chasing = np.sum( - chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60][:-1]) + chasing_offset[chasing_offset < 3 * 60 * 60] + - chasing_onset[chasing_onset < 3 * 60 * 60][:-1] + ) - time_chasing_percent = (time_chasing/(3*60*60))*100 + time_chasing_percent = (time_chasing / (3 * 60 * 60)) * 100 chirps_chasing = np.asarray(flatten(chirps_in_chasings)) - chirps_chasing_new = chirps_chasing[chirps_chasing < 3*60*60] - chirps_percent = (len(chirps_chasing_new) / - len(bh.chirps[bh.chirps < 3*60*60]))*100 + chirps_chasing_new = chirps_chasing[chirps_chasing < 3 * 60 * 60] + chirps_percent = ( + len(chirps_chasing_new) / len(bh.chirps[bh.chirps < 3 * 60 * 60]) + ) * 100 time_precents.append(time_chasing_percent) chirps_percents.append(chirps_percent) - fig, ax = plt.subplots(1, 1, figsize=(7*ps.cm, 7*ps.cm)) + fig, ax = plt.subplots(1, 1, figsize=(7 * ps.cm, 7 * ps.cm)) scatter_time = 1.20 scatter_chirps = 1.80 size = 10 - bplot1 = ax.boxplot([time_precents, chirps_percents], - showfliers=False, patch_artist=True) + bplot1 = ax.boxplot( + [time_precents, chirps_percents], showfliers=False, patch_artist=True + ) ps.set_boxplot_color(bplot1, ps.gray) - ax.set_xticklabels(['Time \nchasing', 'Chirps \nin chasing']) - ax.set_ylabel('Percent') - ax.scatter(np.ones(len(time_precents))*scatter_time, time_precents, - facecolor=ps.white, s=size) - ax.scatter(np.ones(len(chirps_percents))*scatter_chirps, chirps_percents, - facecolor=ps.white, s=size) + ax.set_xticklabels(["Time \nchasing", "Chirps \nin chasing"]) + ax.set_ylabel("Percent") + ax.scatter( + np.ones(len(time_precents)) * scatter_time, + time_precents, + facecolor=ps.white, + s=size, + ) + ax.scatter( + np.ones(len(chirps_percents)) * scatter_chirps, + chirps_percents, + facecolor=ps.white, + s=size, + ) for i in range(len(time_precents)): - ax.plot([scatter_time, scatter_chirps], [time_precents[i], - chirps_percents[i]], alpha=0.6, linewidth=1, color=ps.white) - - ax.text(0.1, 0.9, f'n={len(time_precents)}', transform=ax.transAxes) + ax.plot( + [scatter_time, scatter_chirps], + [time_precents[i], chirps_percents[i]], + alpha=0.6, + linewidth=1, + color=ps.white, + ) + + ax.text(0.1, 0.9, f"n={len(time_precents)}", transform=ax.transAxes) plt.subplots_adjust(left=0.221, bottom=0.186, right=0.97, top=0.967) - plt.savefig('../poster/figs/chirps_in_chasing.pdf') + plt.savefig("../poster/figs/chirps_in_chasing.pdf") plt.show() -if __name__ == '__main__': +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/' + datapath = "../data/mount_data/" main(datapath) diff --git a/code/plot_event_timeline.py b/code/plot_event_timeline.py index cb75cd9..ab408ee 100644 --- a/code/plot_event_timeline.py +++ b/code/plot_event_timeline.py @@ -13,6 +13,7 @@ from modules.plotstyle import PlotStyle from modules.behaviour_handling import Behavior, correct_chasing_events from extract_chirps import get_valid_datasets + ps = PlotStyle() logger = makeLogger(__name__) @@ -20,13 +21,16 @@ logger = makeLogger(__name__) def main(datapath: str): foldernames = [ - datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] + datapath + x + "/" + for x in os.listdir(datapath) + if os.path.isdir(datapath + x) + ] foldernames, _ = get_valid_datasets(datapath) for foldername in foldernames[3:4]: print(foldername) # foldername = foldernames[0] - if foldername == '../data/mount_data/2020-05-12-10_00/': + if foldername == "../data/mount_data/2020-05-12-10_00/": continue # behabvior is pandas dataframe with all the data bh = Behavior(foldername) @@ -52,18 +56,43 @@ def main(datapath: str): exit() fish1_color = ps.gblue2 fish2_color = ps.gblue1 - fig, ax = plt.subplots(5, 1, figsize=( - 21*ps.cm, 10*ps.cm), height_ratios=[0.5, 0.5, 0.5, 0.2, 6], sharex=True) + fig, ax = plt.subplots( + 5, + 1, + figsize=(21 * ps.cm, 10 * ps.cm), + height_ratios=[0.5, 0.5, 0.5, 0.2, 6], + sharex=True, + ) # marker size s = 80 - ax[0].scatter(physical_contact, np.ones( - len(physical_contact)), color=ps.gray, marker='|', s=s) - ax[1].scatter(chasing_onset, np.ones(len(chasing_onset)), - color=ps.gray, marker='|', s=s) - ax[2].scatter(fish1, np.ones(len(fish1))-0.25, - color=fish1_color, marker='|', s=s) - ax[2].scatter(fish2, np.zeros(len(fish2))+0.25, - color=fish2_color, marker='|', s=s) + ax[0].scatter( + physical_contact, + np.ones(len(physical_contact)), + color=ps.gray, + marker="|", + s=s, + ) + ax[1].scatter( + chasing_onset, + np.ones(len(chasing_onset)), + color=ps.gray, + marker="|", + s=s, + ) + ax[2].scatter( + fish1, + np.ones(len(fish1)) - 0.25, + color=fish1_color, + marker="|", + s=s, + ) + ax[2].scatter( + fish2, + np.zeros(len(fish2)) + 0.25, + color=fish2_color, + marker="|", + s=s, + ) freq_temp = bh.freq[bh.ident == fish1_id] time_temp = bh.time[bh.idx[bh.ident == fish1_id]] @@ -94,35 +123,38 @@ def main(datapath: str): ax[2].set_xticks([]) ps.hide_ax(ax[2]) - ax[4].axvspan(0, 3, 0, 5, facecolor='grey', alpha=0.5) + ax[4].axvspan(0, 3, 0, 5, facecolor="grey", alpha=0.5) ax[4].set_xticks(np.arange(0, 6.1, 0.5)) ps.hide_ax(ax[3]) labelpad = 30 fsize = 12 - ax[0].set_ylabel('Contact', rotation=0, - labelpad=labelpad, fontsize=fsize) + ax[0].set_ylabel( + "Contact", rotation=0, labelpad=labelpad, fontsize=fsize + ) ax[0].yaxis.set_label_coords(-0.062, -0.08) - ax[1].set_ylabel('Chasing', rotation=0, - labelpad=labelpad, fontsize=fsize) + ax[1].set_ylabel( + "Chasing", rotation=0, labelpad=labelpad, fontsize=fsize + ) ax[1].yaxis.set_label_coords(-0.06, -0.08) - ax[2].set_ylabel('Chirps', rotation=0, - labelpad=labelpad, fontsize=fsize) + ax[2].set_ylabel( + "Chirps", rotation=0, labelpad=labelpad, fontsize=fsize + ) ax[2].yaxis.set_label_coords(-0.07, -0.08) - ax[4].set_ylabel('EODf') + ax[4].set_ylabel("EODf") - ax[4].set_xlabel('Time [h]') + ax[4].set_xlabel("Time [h]") # ax[0].set_title(foldername.split('/')[-2]) # 2020-03-31-9_59 plt.subplots_adjust(left=0.158, right=0.987, top=0.918, bottom=0.136) - plt.savefig('../poster/figs/timeline.svg') + plt.savefig("../poster/figs/timeline.svg") plt.show() # plot chirps -if __name__ == '__main__': +if __name__ == "__main__": # Path to the data - datapath = '../data/mount_data/' + datapath = "../data/mount_data/" main(datapath) diff --git a/code/plot_introduction_specs.py b/code/plot_introduction_specs.py index d7e6f4a..0c8e2b4 100644 --- a/code/plot_introduction_specs.py +++ b/code/plot_introduction_specs.py @@ -11,7 +11,6 @@ ps = PlotStyle() def main(): - # Load data datapath = "../data/2022-06-02-10_00/" data = LoadData(datapath) @@ -24,26 +23,31 @@ def main(): timescaler = 1000 - raw = data.raw[window_start_index:window_start_index + - window_duration_index, 10] + raw = data.raw[ + window_start_index : window_start_index + window_duration_index, 10 + ] fig, (ax1, ax2) = plt.subplots( - 1, 2, figsize=(21 * ps.cm, 8*ps.cm), sharex=True, sharey=True) + 1, 2, figsize=(21 * ps.cm, 8 * ps.cm), sharex=True, sharey=True + ) # plot instantaneous frequency filtered1 = bandpass_filter( - signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate) + signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate + ) filtered2 = bandpass_filter( - signal=raw, lowf=550, highf=700, samplerate=data.raw_rate) + signal=raw, lowf=550, highf=700, samplerate=data.raw_rate + ) freqtime1, freq1 = instantaneous_frequency( - filtered1, data.raw_rate, smoothing_window=3) + filtered1, data.raw_rate, smoothing_window=3 + ) freqtime2, freq2 = instantaneous_frequency( - filtered2, data.raw_rate, smoothing_window=3) + filtered2, data.raw_rate, smoothing_window=3 + ) - ax1.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="Fish 1") - ax1.plot(freqtime2*timescaler, freq2, color=ps.gray, - lw=2, label="Fish 2") + ax1.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="Fish 1") + ax1.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="Fish 2") # ax.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0) # # ps.hide_xax(ax1) @@ -62,8 +66,8 @@ def main(): ax1.imshow( decibel(spec_power[fmask, :]), extent=[ - spec_times[0]*timescaler, - spec_times[-1]*timescaler, + spec_times[0] * timescaler, + spec_times[-1] * timescaler, spec_freqs[fmask][0], spec_freqs[fmask][-1], ], @@ -87,8 +91,8 @@ def main(): ax2.imshow( decibel(spec_power[fmask, :]), extent=[ - spec_times[0]*timescaler, - spec_times[-1]*timescaler, + spec_times[0] * timescaler, + spec_times[-1] * timescaler, spec_freqs[fmask][0], spec_freqs[fmask][-1], ], @@ -98,9 +102,8 @@ def main(): alpha=1, ) # ps.hide_xax(ax3) - ax2.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="_") - ax2.plot(freqtime2*timescaler, freq2, color=ps.gray, - lw=2, label="_") + ax2.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="_") + ax2.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="_") ax2.set_xlim(75, 200) ax1.set_ylim(400, 1200) @@ -109,15 +112,22 @@ def main(): fig.supylabel("Frequency [Hz]", fontsize=14) handles, labels = ax1.get_legend_handles_labels() - ax2.legend(handles, labels, bbox_to_anchor=(1.04, 1), loc="upper left", ncol=1,) + ax2.legend( + handles, + labels, + bbox_to_anchor=(1.04, 1), + loc="upper left", + ncol=1, + ) ps.letter_subplots(xoffset=[-0.27, -0.1], yoffset=1.05) - plt.subplots_adjust(left=0.12, right=0.85, top=0.89, - bottom=0.18, hspace=0.35) + plt.subplots_adjust( + left=0.12, right=0.85, top=0.89, bottom=0.18, hspace=0.35 + ) - plt.savefig('../poster/figs/introplot.pdf') + plt.savefig("../poster/figs/introplot.pdf") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/code/plot_kdes.py b/code/plot_kdes.py index 0f3082b..bc1bc98 100644 --- a/code/plot_kdes.py +++ b/code/plot_kdes.py @@ -1,7 +1,9 @@ - from modules.plotstyle import PlotStyle from modules.behaviour_handling import ( - Behavior, correct_chasing_events, center_chirps) + Behavior, + correct_chasing_events, + center_chirps, +) from modules.datahandling import flatten, causal_kde1d, acausal_kde1d from modules.logger import makeLogger from pandas import read_csv @@ -18,80 +20,93 @@ logger = makeLogger(__name__) ps = PlotStyle() -def bootstrap(data, nresamples, kde_time, kernel_width, event_times, time_before, time_after): - +def bootstrap( + data, + nresamples, + kde_time, + kernel_width, + event_times, + time_before, + time_after, +): bootstrapped_kdes = [] - data = data[data <= 3*60*60] # only night time + data = data[data <= 3 * 60 * 60] # only night time diff_data = np.diff(np.sort(data), prepend=0) # if len(data) != 0: # mean_chirprate = (len(data) - 1) / (data[-1] - data[0]) for i in tqdm(range(nresamples)): - np.random.shuffle(diff_data) bootstrapped_data = np.cumsum(diff_data) # bootstrapped_data = data + np.random.randn(len(data)) * 10 bootstrap_data_centered = center_chirps( - bootstrapped_data, event_times, time_before, time_after) + bootstrapped_data, event_times, time_before, time_after + ) bootstrapped_kde = acausal_kde1d( - bootstrap_data_centered, time=kde_time, width=kernel_width) + bootstrap_data_centered, time=kde_time, width=kernel_width + ) - bootstrapped_kde = list(np.asarray( - bootstrapped_kde) / len(event_times)) + bootstrapped_kde = list(np.asarray(bootstrapped_kde) / len(event_times)) bootstrapped_kdes.append(bootstrapped_kde) return bootstrapped_kdes -def jackknife(data, nresamples, subsetsize, kde_time, kernel_width, event_times, time_before, time_after): - +def jackknife( + data, + nresamples, + subsetsize, + kde_time, + kernel_width, + event_times, + time_before, + time_after, +): jackknife_kdes = [] - data = data[data <= 3*60*60] # only night time + data = data[data <= 3 * 60 * 60] # only night time subsetsize = int(len(data) * subsetsize) diff_data = np.diff(np.sort(data), prepend=0) for i in tqdm(range(nresamples)): - - jackknifed_data = np.random.choice( - diff_data, subsetsize, replace=False) + jackknifed_data = np.random.choice(diff_data, subsetsize, replace=False) jackknifed_data = np.cumsum(jackknifed_data) jackknifed_data_centered = center_chirps( - jackknifed_data, event_times, time_before, time_after) + jackknifed_data, event_times, time_before, time_after + ) jackknifed_kde = acausal_kde1d( - jackknifed_data_centered, time=kde_time, width=kernel_width) + jackknifed_data_centered, time=kde_time, width=kernel_width + ) - jackknifed_kde = list(np.asarray( - jackknifed_kde) / len(event_times)) + jackknifed_kde = list(np.asarray(jackknifed_kde) / len(event_times)) jackknife_kdes.append(jackknifed_kde) return jackknife_kdes def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): - - foldername = folder_name.split('/')[-2] - winner_row = order_meta_df[order_meta_df['recording'] == foldername] - winner = winner_row['winner'].values[0].astype(int) - winner_fish1 = winner_row['fish1'].values[0].astype(int) - winner_fish2 = winner_row['fish2'].values[0].astype(int) + foldername = folder_name.split("/")[-2] + winner_row = order_meta_df[order_meta_df["recording"] == foldername] + winner = winner_row["winner"].values[0].astype(int) + winner_fish1 = winner_row["fish1"].values[0].astype(int) + winner_fish2 = winner_row["fish2"].values[0].astype(int) if winner > 0: if winner == winner_fish1: - winner_fish_id = winner_row['rec_id1'].values[0] - loser_fish_id = winner_row['rec_id2'].values[0] + winner_fish_id = winner_row["rec_id1"].values[0] + loser_fish_id = winner_row["rec_id2"].values[0] elif winner == winner_fish2: - winner_fish_id = winner_row['rec_id2'].values[0] - loser_fish_id = winner_row['rec_id1'].values[0] + winner_fish_id = winner_row["rec_id2"].values[0] + loser_fish_id = winner_row["rec_id1"].values[0] chirp_winner = Behavior.chirps[Behavior.chirps_ids == winner_fish_id] chirp_loser = Behavior.chirps[Behavior.chirps_ids == loser_fish_id] @@ -101,7 +116,6 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): def main(dataroot): - foldernames, _ = np.asarray(get_valid_datasets(dataroot)) plot_all = True time_before = 90 @@ -111,10 +125,9 @@ def main(dataroot): kde_time = np.arange(-time_before, time_after, dt) nbootstraps = 50 - meta_path = ( - '/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv' + meta_path = ("/").join(foldernames[0].split("/")[:-2]) + "/order_meta.csv" meta = pd.read_csv(meta_path) - meta['recording'] = meta['recording'].str[1:-1] + meta["recording"] = meta["recording"].str[1:-1] winner_onsets = [] winner_offsets = [] @@ -143,24 +156,24 @@ def main(dataroot): # loser_onset_chirpcount = 0 # loser_offset_chirpcount = 0 # loser_physical_chirpcount = 0 - fig, ax = plt.subplots(1, 2, figsize=( - 14 * ps.cm, 7*ps.cm), sharey=True, sharex=True) + fig, ax = plt.subplots( + 1, 2, figsize=(14 * ps.cm, 7 * ps.cm), sharey=True, sharex=True + ) # Iterate over all recordings and save chirp- and event-timestamps good_recs = np.asarray([0, 15]) for i, folder in tqdm(enumerate(foldernames[good_recs])): - - foldername = folder.split('/')[-2] + foldername = folder.split("/")[-2] # logger.info('Loading data from folder: {}'.format(foldername)) - broken_folders = ['../data/mount_data/2020-05-12-10_00/'] + broken_folders = ["../data/mount_data/2020-05-12-10_00/"] if folder in broken_folders: continue bh = Behavior(folder) category, timestamps = correct_chasing_events(bh.behavior, bh.start_s) - category = category[timestamps < 3*60*60] # only night time - timestamps = timestamps[timestamps < 3*60*60] # only night time + category = category[timestamps < 3 * 60 * 60] # only night time + timestamps = timestamps[timestamps < 3 * 60 * 60] # only night time winner, loser = get_chirp_winner_loser(folder, bh, meta) if winner is None: @@ -168,27 +181,33 @@ def main(dataroot): # winner_count += len(winner) # loser_count += len(loser) - onsets = (timestamps[category == 0]) - offsets = (timestamps[category == 1]) - physicals = (timestamps[category == 2]) + onsets = timestamps[category == 0] + offsets = timestamps[category == 1] + physicals = timestamps[category == 2] onset_count += len(onsets) offset_count += len(offsets) physical_count += len(physicals) - winner_onsets.append(center_chirps( - winner, onsets, time_before, time_after)) - winner_offsets.append(center_chirps( - winner, offsets, time_before, time_after)) - winner_physicals.append(center_chirps( - winner, physicals, time_before, time_after)) - - loser_onsets.append(center_chirps( - loser, onsets, time_before, time_after)) - loser_offsets.append(center_chirps( - loser, offsets, time_before, time_after)) - loser_physicals.append(center_chirps( - loser, physicals, time_before, time_after)) + winner_onsets.append( + center_chirps(winner, onsets, time_before, time_after) + ) + winner_offsets.append( + center_chirps(winner, offsets, time_before, time_after) + ) + winner_physicals.append( + center_chirps(winner, physicals, time_before, time_after) + ) + + loser_onsets.append( + center_chirps(loser, onsets, time_before, time_after) + ) + loser_offsets.append( + center_chirps(loser, offsets, time_before, time_after) + ) + loser_physicals.append( + center_chirps(loser, physicals, time_before, time_after) + ) # winner_onset_chirpcount += len(winner_onsets[-1]) # winner_offset_chirpcount += len(winner_offsets[-1]) @@ -232,14 +251,17 @@ def main(dataroot): # event_times=onsets, # time_before=time_before, # time_after=time_after)) - loser_offsets_boot.append(bootstrap( - loser, - nresamples=nbootstraps, - kde_time=kde_time, - kernel_width=kernel_width, - event_times=offsets, - time_before=time_before, - time_after=time_after)) + loser_offsets_boot.append( + bootstrap( + loser, + nresamples=nbootstraps, + kde_time=kde_time, + kernel_width=kernel_width, + event_times=offsets, + time_before=time_before, + time_after=time_after, + ) + ) # loser_physicals_boot.append(bootstrap( # loser, # nresamples=nbootstraps, @@ -249,18 +271,17 @@ def main(dataroot): # time_before=time_before, # time_after=time_after)) -# loser_offsets_jackknife = jackknife( -# loser, -# nresamples=nbootstraps, -# subsetsize=0.9, -# kde_time=kde_time, -# kernel_width=kernel_width, -# event_times=offsets, -# time_before=time_before, -# time_after=time_after) + # loser_offsets_jackknife = jackknife( + # loser, + # nresamples=nbootstraps, + # subsetsize=0.9, + # kde_time=kde_time, + # kernel_width=kernel_width, + # event_times=offsets, + # time_before=time_before, + # time_after=time_after) if plot_all: - # winner_onsets_conv = acausal_kde1d( # winner_onsets[-1], kde_time, kernel_width) # winner_offsets_conv = acausal_kde1d( @@ -271,24 +292,35 @@ def main(dataroot): # loser_onsets_conv = acausal_kde1d( # loser_onsets[-1], kde_time, kernel_width) loser_offsets_conv = acausal_kde1d( - loser_offsets[-1], kde_time, kernel_width) + loser_offsets[-1], kde_time, kernel_width + ) # loser_physicals_conv = acausal_kde1d( # loser_physicals[-1], kde_time, kernel_width) - ax[i].plot(kde_time, loser_offsets_conv / - len(offsets), lw=2, zorder=100, c=ps.gblue1) + ax[i].plot( + kde_time, + loser_offsets_conv / len(offsets), + lw=2, + zorder=100, + c=ps.gblue1, + ) ax[i].fill_between( kde_time, np.percentile(loser_offsets_boot[-1], 1, axis=0), np.percentile(loser_offsets_boot[-1], 99, axis=0), - color='gray', - alpha=0.8) + color="gray", + alpha=0.8, + ) - ax[i].plot(kde_time, np.median(loser_offsets_boot[-1], axis=0), - color=ps.black, linewidth=2) + ax[i].plot( + kde_time, + np.median(loser_offsets_boot[-1], axis=0), + color=ps.black, + linewidth=2, + ) - ax[i].axvline(0, color=ps.gray, linestyle='--') + ax[i].axvline(0, color=ps.gray, linestyle="--") # ax[i].fill_between( # kde_time, @@ -300,8 +332,8 @@ def main(dataroot): # color=ps.white, linewidth=2) ax[i].set_xlim(-60, 60) - fig.supylabel('Chirp rate (a.u.)', fontsize=14) - fig.supxlabel('Time (s)', fontsize=14) + fig.supylabel("Chirp rate (a.u.)", fontsize=14) + fig.supxlabel("Time (s)", fontsize=14) # fig, ax = plt.subplots(2, 3, figsize=( # 21*ps.cm, 10*ps.cm), sharey=True, sharex=True) @@ -521,9 +553,9 @@ def main(dataroot): # color=ps.gray, # alpha=0.5) plt.subplots_adjust(bottom=0.21, top=0.93) - plt.savefig('../poster/figs/kde.pdf') + plt.savefig("../poster/figs/kde.pdf") plt.show() -if __name__ == '__main__': - main('../data/mount_data/') +if __name__ == "__main__": + main("../data/mount_data/")