From 282c846b05cbbfca07d2079c2651bdf90577bbe5 Mon Sep 17 00:00:00 2001
From: weygoldt <88969563+weygoldt@users.noreply.github.com>
Date: Tue, 11 Apr 2023 15:33:07 +0200
Subject: [PATCH] reformat

---
 chirp_instantaneous_freq/filters.py         |  36 ++-
 chirp_instantaneous_freq/fish_signal.py     |  34 +--
 chirp_instantaneous_freq/test_parameters.py | 122 +++++----
 code/analysis.py                            |  81 +++---
 code/band_pass_problem.py                   |  26 +-
 code/behavior.py                            | 237 ++++++++++-------
 code/chirp_sim.py                           |  19 +-
 code/chirpdetection.py                      | 138 ++++++----
 code/chirpdetector_conf.yml                 |  41 ++-
 code/eventchirpsplots.py                    | 266 ++++++++++++-------
 code/extract_chirps.py                      |  37 +--
 code/get_behaviour.py                       |  42 +--
 code/modules/behaviour_handling.py          |  94 +++----
 code/modules/datahandling.py                |  21 +-
 code/modules/filehandling.py                |   1 -
 code/modules/filters.py                     |  19 +-
 code/modules/logger.py                      |   8 +-
 code/modules/plotstyle.py                   |  16 +-
 code/modules/plotstyle1.py                  |  16 +-
 code/modules/plotstyle_dark.py              |  16 +-
 code/modules/simulations.py                 |   2 +-
 code/plot_chirp_size.py                     | 277 ++++++++++++--------
 code/plot_chirps_in_chasing.py              |  72 +++--
 code/plot_event_timeline.py                 |  80 ++++--
 code/plot_introduction_specs.py             |  56 ++--
 code/plot_kdes.py                           | 206 +++++++++------
 26 files changed, 1165 insertions(+), 798 deletions(-)

diff --git a/chirp_instantaneous_freq/filters.py b/chirp_instantaneous_freq/filters.py
index 709b140..a0d4001 100644
--- a/chirp_instantaneous_freq/filters.py
+++ b/chirp_instantaneous_freq/filters.py
@@ -59,14 +59,14 @@ def instantaneous_frequency(
 def inst_freq(signal, fs):
     """
     Computes the instantaneous frequency of a periodic signal using zero-crossings.
-    
+
     Parameters:
     -----------
     signal : array-like
         The input signal.
     fs : float
         The sampling frequency of the input signal.
-    
+
     Returns:
     --------
     freq : array-like
@@ -74,29 +74,30 @@ def inst_freq(signal, fs):
     """
     # Compute the sign of the signal
     sign = np.sign(signal)
-    
+
     # Compute the crossings of the sign signal with a zero line
     crossings = np.where(np.diff(sign))[0]
-    
+
     # Compute the time differences between zero crossings
     dt = np.diff(crossings) / fs
-    
+
     # Compute the instantaneous frequency as the reciprocal of the time differences
     freq = 1 / dt
 
-    # Gaussian filter the signal 
+    # Gaussian filter the signal
     freq = gaussian_filter1d(freq, 10)
-    
+
     # Pad the frequency vector with zeros to match the length of the input signal
     freq = np.pad(freq, (0, len(signal) - len(freq)))
-    
+
     return freq
 
+
 def bandpass_filter(
-        signal: np.ndarray,
-        samplerate: float,
-        lowf: float,
-        highf: float,
+    signal: np.ndarray,
+    samplerate: float,
+    lowf: float,
+    highf: float,
 ) -> np.ndarray:
     """Bandpass filter a signal.
 
@@ -150,9 +151,7 @@ def highpass_filter(
 
 
 def lowpass_filter(
-    signal: np.ndarray,
-    samplerate: float,
-    cutoff: float
+    signal: np.ndarray, samplerate: float, cutoff: float
 ) -> np.ndarray:
     """Lowpass filter a signal.
 
@@ -176,10 +175,9 @@ def lowpass_filter(
     return filtered_signal
 
 
-def envelope(signal: np.ndarray,
-             samplerate: float,
-             cutoff_frequency: float
-             ) -> np.ndarray:
+def envelope(
+    signal: np.ndarray, samplerate: float, cutoff_frequency: float
+) -> np.ndarray:
     """Calculate the envelope of a signal using a lowpass filter.
 
     Parameters
diff --git a/chirp_instantaneous_freq/fish_signal.py b/chirp_instantaneous_freq/fish_signal.py
index bf740b6..d7830bf 100644
--- a/chirp_instantaneous_freq/fish_signal.py
+++ b/chirp_instantaneous_freq/fish_signal.py
@@ -384,16 +384,14 @@ def chirps(
     frequency = eodf * np.ones(n)
     am = np.ones(n)
 
-    for time, width, size, kurtosis, contrast in zip(chirp_times, chirp_width, chirp_size, chirp_kurtosis, chirp_contrast):
-
+    for time, width, size, kurtosis, contrast in zip(
+        chirp_times, chirp_width, chirp_size, chirp_kurtosis, chirp_contrast
+    ):
         # chirp frequency waveform:
         chirp_t = np.arange(-2.0 * width, 2.0 * width, 1.0 / samplerate)
-        chirp_sig = (
-            0.5 * width / (2.0 * np.log(10.0)) ** (0.5 / kurtosis)
-        )
+        chirp_sig = 0.5 * width / (2.0 * np.log(10.0)) ** (0.5 / kurtosis)
         gauss = np.exp(-0.5 * ((chirp_t / chirp_sig) ** 2.0) ** kurtosis)
 
-
         # add chirps on baseline eodf:
         index = int(time * samplerate)
         i0 = index - len(gauss) // 2
@@ -433,7 +431,7 @@ def rises(
         Sampling rate in Hertz.
     duration: float
         Duration of the generated data in seconds.
-    rise_times: list 
+    rise_times: list
         Timestamp of each of the rises in seconds.
     rise_size: list
         Size of the respective rise (frequency increase above eodf) in Hertz.
@@ -452,15 +450,12 @@ def rises(
     # baseline eod frequency:
     frequency = eodf * np.ones(n)
 
-    for time, size, riset, decayt in zip(rise_times, rise_size, rise_tau, decay_tau):  
-
+    for time, size, riset, decayt in zip(
+        rise_times, rise_size, rise_tau, decay_tau
+    ):
         # rise frequency waveform:
         rise_t = np.arange(0.0, 5.0 * decayt, 1.0 / samplerate)
-        rise = (
-            size
-            * (1.0 - np.exp(-rise_t / riset))
-            * np.exp(-rise_t / decayt)
-        )
+        rise = size * (1.0 - np.exp(-rise_t / riset)) * np.exp(-rise_t / decayt)
 
         # add rises on baseline eodf:
         index = int(time * samplerate)
@@ -472,13 +467,14 @@ def rises(
             frequency[index : index + len(rise)] += rise
     return frequency
 
+
 class FishSignal:
     def __init__(self, samplerate, duration, eodf, nchirps, nrises):
         time = np.arange(0, duration, 1 / samplerate)
         chirp_times = np.random.uniform(0, duration, nchirps)
         rise_times = np.random.uniform(0, duration, nrises)
 
-        # pick random parameters for chirps 
+        # pick random parameters for chirps
         chirp_size = np.random.uniform(60, 200, nchirps)
         chirp_width = np.random.uniform(0.01, 0.1, nchirps)
         chirp_kurtosis = np.random.uniform(1, 1, nchirps)
@@ -534,7 +530,6 @@ class FishSignal:
         self.eodf = eodf
 
     def visualize(self):
-
         spec, freqs, spectime = ps.spectrogram(
             data=self.signal,
             ratetime=self.samplerate,
@@ -549,7 +544,12 @@ class FishSignal:
         ax1.set_xlabel("Time (s)")
         ax1.set_title("EOD signal")
 
-        ax2.imshow(ps.decibel(spec), origin='lower', aspect="auto", extent=[spectime[0], spectime[-1], freqs[0], freqs[-1]])
+        ax2.imshow(
+            ps.decibel(spec),
+            origin="lower",
+            aspect="auto",
+            extent=[spectime[0], spectime[-1], freqs[0], freqs[-1]],
+        )
         ax2.set_ylabel("Frequency (Hz)")
         ax2.set_xlabel("Time (s)")
         ax2.set_title("Spectrogram")
diff --git a/chirp_instantaneous_freq/test_parameters.py b/chirp_instantaneous_freq/test_parameters.py
index 9c4ab5f..bad1e45 100644
--- a/chirp_instantaneous_freq/test_parameters.py
+++ b/chirp_instantaneous_freq/test_parameters.py
@@ -1,4 +1,4 @@
-import numpy as np 
+import numpy as np
 import matplotlib.pyplot as plt
 from fish_signal import chirps, wavefish_eods
 from filters import bandpass_filter, instantaneous_frequency, inst_freq
@@ -6,18 +6,18 @@ from IPython import embed
 
 
 def switch_test(test, defaultparams, testparams):
-    if test == 'width':
-        defaultparams['chirp_width'] = testparams['chirp_width']
-        key = 'chirp_width'
-    elif test == 'size':
-        defaultparams['chirp_size'] = testparams['chirp_size']
-        key = 'chirp_size'
-    elif test == 'kurtosis':
-        defaultparams['chirp_kurtosis'] = testparams['chirp_kurtosis']
-        key = 'chirp_kurtosis'
-    elif test == 'contrast':
-        defaultparams['chirp_contrast'] = testparams['chirp_contrast']
-        key = 'chirp_contrast'
+    if test == "width":
+        defaultparams["chirp_width"] = testparams["chirp_width"]
+        key = "chirp_width"
+    elif test == "size":
+        defaultparams["chirp_size"] = testparams["chirp_size"]
+        key = "chirp_size"
+    elif test == "kurtosis":
+        defaultparams["chirp_kurtosis"] = testparams["chirp_kurtosis"]
+        key = "chirp_kurtosis"
+    elif test == "contrast":
+        defaultparams["chirp_contrast"] = testparams["chirp_contrast"]
+        key = "chirp_contrast"
     else:
         raise ValueError("Test not recognized")
 
@@ -29,31 +29,40 @@ def extract_dict(dict, index):
 
 
 def main(test1, test2, resolution=10):
+    assert test1 in [
+        "width",
+        "size",
+        "kurtosis",
+        "contrast",
+    ], "Test1 not recognized"
+    assert test2 in [
+        "width",
+        "size",
+        "kurtosis",
+        "contrast",
+    ], "Test2 not recognized"
 
-    assert test1 in ['width', 'size', 'kurtosis', 'contrast'], "Test1 not recognized"
-    assert test2 in ['width', 'size', 'kurtosis', 'contrast'], "Test2 not recognized"
-
-    # Define the parameters for the chirp simulations 
+    # Define the parameters for the chirp simulations
     ntest = resolution
 
     defaultparams = dict(
-        chirp_size = np.ones(ntest) * 100, 
-        chirp_width = np.ones(ntest) * 0.1, 
-        chirp_kurtosis = np.ones(ntest) * 1.0, 
-        chirp_contrast = np.ones(ntest) * 0.5, 
+        chirp_size=np.ones(ntest) * 100,
+        chirp_width=np.ones(ntest) * 0.1,
+        chirp_kurtosis=np.ones(ntest) * 1.0,
+        chirp_contrast=np.ones(ntest) * 0.5,
     )
 
     testparams = dict(
-        chirp_width = np.linspace(0.01, 0.2, ntest), 
-        chirp_size = np.linspace(50, 300, ntest), 
-        chirp_kurtosis = np.linspace(0.5, 1.5, ntest), 
-        chirp_contrast = np.linspace(0.01, 1.0, ntest), 
+        chirp_width=np.linspace(0.01, 0.2, ntest),
+        chirp_size=np.linspace(50, 300, ntest),
+        chirp_kurtosis=np.linspace(0.5, 1.5, ntest),
+        chirp_contrast=np.linspace(0.01, 1.0, ntest),
     )
 
     key1, chirp_params = switch_test(test1, defaultparams, testparams)
     key2, chirp_params = switch_test(test2, chirp_params, testparams)
 
-    # make the chirp trace 
+    # make the chirp trace
     eodf = 500
     samplerate = 20000
     duration = 2
@@ -63,40 +72,60 @@ def main(test1, test2, resolution=10):
     tight_cutoffs = 10
 
     distances = np.full((ntest, ntest), np.nan)
-    
-    fig, axs = plt.subplots(ntest, ntest, figsize = (10, 10), sharex = True, sharey = True)
+
+    fig, axs = plt.subplots(
+        ntest, ntest, figsize=(10, 10), sharex=True, sharey=True
+    )
     axs = axs.flatten()
 
     iter0 = 0
     for iter1, test1_param in enumerate(chirp_params[key1]):
         for iter2, test2_param in enumerate(chirp_params[key2]):
-
             # get the chirp parameters for the current test
             inner_chirp_params = extract_dict(chirp_params, iter2)
             inner_chirp_params[key1] = test1_param
             inner_chirp_params[key2] = test2_param
 
             # make the chirp trace for the current chirp parameters
-            sizes = np.ones(len(chirp_times)) * inner_chirp_params['chirp_size']
-            widths = np.ones(len(chirp_times)) * inner_chirp_params['chirp_width']
-            kurtosis = np.ones(len(chirp_times)) * inner_chirp_params['chirp_kurtosis']
-            contrast = np.ones(len(chirp_times)) * inner_chirp_params['chirp_contrast']
+            sizes = np.ones(len(chirp_times)) * inner_chirp_params["chirp_size"]
+            widths = (
+                np.ones(len(chirp_times)) * inner_chirp_params["chirp_width"]
+            )
+            kurtosis = (
+                np.ones(len(chirp_times)) * inner_chirp_params["chirp_kurtosis"]
+            )
+            contrast = (
+                np.ones(len(chirp_times)) * inner_chirp_params["chirp_contrast"]
+            )
 
             # make the chirp trace
-            chirp_trace, ampmod = chirps(eodf, samplerate, duration, chirp_times, sizes, widths, kurtosis, contrast)
+            chirp_trace, ampmod = chirps(
+                eodf,
+                samplerate,
+                duration,
+                chirp_times,
+                sizes,
+                widths,
+                kurtosis,
+                contrast,
+            )
             signal = wavefish_eods(
-                    fish="Alepto", 
-                    frequency=chirp_trace, 
-                    samplerate=samplerate, 
-                    duration=duration, 
-                    phase0=0.0, 
-                    noise_std=0.05
-            )           
+                fish="Alepto",
+                frequency=chirp_trace,
+                samplerate=samplerate,
+                duration=duration,
+                phase0=0.0,
+                noise_std=0.05,
+            )
             signal = signal * ampmod
 
-            # apply broadband filter 
-            wide_signal = bandpass_filter(signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs)
-            tight_signal = bandpass_filter(signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs)
+            # apply broadband filter
+            wide_signal = bandpass_filter(
+                signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs
+            )
+            tight_signal = bandpass_filter(
+                signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs
+            )
 
             # get the instantaneous frequency
             wide_frequency = inst_freq(wide_signal, samplerate)
@@ -111,8 +140,9 @@ def main(test1, test2, resolution=10):
             iter0 += 1
 
     fig, ax = plt.subplots()
-    ax.imshow(distances, cmap = 'jet')
+    ax.imshow(distances, cmap="jet")
     plt.show()
 
+
 if __name__ == "__main__":
-    main('width', 'size') 
+    main("width", "size")
diff --git a/code/analysis.py b/code/analysis.py
index 787a53e..2e32671 100644
--- a/code/analysis.py
+++ b/code/analysis.py
@@ -10,73 +10,84 @@ from modules.filters import bandpass_filter
 
 
 def main(folder):
-    file = os.path.join(folder, 'traces-grid.raw')
+    file = os.path.join(folder, "traces-grid.raw")
     data = open_data(folder, 60.0, 0, channel=-1)
-    time = np.load(folder + 'times.npy', allow_pickle=True)
-    freq = np.load(folder + 'fund_v.npy', allow_pickle=True)
-    ident = np.load(folder + 'ident_v.npy', allow_pickle=True)
-    idx = np.load(folder + 'idx_v.npy', allow_pickle=True)
+    time = np.load(folder + "times.npy", allow_pickle=True)
+    freq = np.load(folder + "fund_v.npy", allow_pickle=True)
+    ident = np.load(folder + "ident_v.npy", allow_pickle=True)
+    idx = np.load(folder + "idx_v.npy", allow_pickle=True)
 
-    t0 = 3*60*60 + 6*60 + 43.5
+    t0 = 3 * 60 * 60 + 6 * 60 + 43.5
     dt = 60
-    data_oi = data[t0 * data.samplerate: (t0+dt)*data.samplerate, :]
+    data_oi = data[t0 * data.samplerate : (t0 + dt) * data.samplerate, :]
 
     for i in [10]:
         # getting the spectogramm
         spec_power, spec_freqs, spec_times = spectrogram(
-            data_oi[:, i], ratetime=data.samplerate, freq_resolution=50, overlap_frac=0.0)
-        fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54))
-        ax.pcolormesh(spec_times, spec_freqs, decibel(
-            spec_power), vmin=-100, vmax=-50)
+            data_oi[:, i],
+            ratetime=data.samplerate,
+            freq_resolution=50,
+            overlap_frac=0.0,
+        )
+        fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54))
+        ax.pcolormesh(
+            spec_times, spec_freqs, decibel(spec_power), vmin=-100, vmax=-50
+        )
 
         for track_id in np.unique(ident):
             # window_index for time array in time window
-            window_index = np.arange(len(idx))[(ident == track_id) &
-                                               (time[idx] >= t0) &
-                                               (time[idx] <= (t0+dt))]
+            window_index = np.arange(len(idx))[
+                (ident == track_id)
+                & (time[idx] >= t0)
+                & (time[idx] <= (t0 + dt))
+            ]
             freq_temp = freq[window_index]
             time_temp = time[idx[window_index]]
-            #mean_freq = np.mean(freq_temp)
-            #fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200)
+            # mean_freq = np.mean(freq_temp)
+            # fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200)
             ax.plot(time_temp - t0, freq_temp)
 
     ax.set_ylim(500, 1000)
     plt.show()
     # filter plot
-    id = 10.
+    id = 10.0
     i = 10
-    window_index = np.arange(len(idx))[(ident == id) &
-                                       (time[idx] >= t0) &
-                                       (time[idx] <= (t0+dt))]
+    window_index = np.arange(len(idx))[
+        (ident == id) & (time[idx] >= t0) & (time[idx] <= (t0 + dt))
+    ]
     freq_temp = freq[window_index]
     time_temp = time[idx[window_index]]
     mean_freq = np.mean(freq_temp)
     fdata = bandpass_filter(
-        data_oi[:, i], rate=data.samplerate, lowf=mean_freq-5, highf=mean_freq+200)
+        data_oi[:, i],
+        rate=data.samplerate,
+        lowf=mean_freq - 5,
+        highf=mean_freq + 200,
+    )
     fig, ax = plt.subplots()
-    ax.plot(np.arange(len(fdata))/data.samplerate, fdata, marker='*')
+    ax.plot(np.arange(len(fdata)) / data.samplerate, fdata, marker="*")
     # plt.show()
     # freqency analyis of filtered data
 
-    time_fdata = np.arange(len(fdata))/data.samplerate
+    time_fdata = np.arange(len(fdata)) / data.samplerate
     roll_fdata = np.roll(fdata, shift=1)
     period_index = np.arange(len(fdata))[(roll_fdata < 0) & (fdata >= 0)]
 
     plt.plot(time_fdata, fdata)
-    plt.scatter(time_fdata[period_index], fdata[period_index], c='r')
-    plt.scatter(time_fdata[period_index-1], fdata[period_index-1], c='r')
+    plt.scatter(time_fdata[period_index], fdata[period_index], c="r")
+    plt.scatter(time_fdata[period_index - 1], fdata[period_index - 1], c="r")
 
     upper_bound = np.abs(fdata[period_index])
-    lower_bound = np.abs(fdata[period_index-1])
+    lower_bound = np.abs(fdata[period_index - 1])
 
     upper_times = np.abs(time_fdata[period_index])
-    lower_times = np.abs(time_fdata[period_index-1])
+    lower_times = np.abs(time_fdata[period_index - 1])
 
-    lower_ratio = lower_bound/(lower_bound+upper_bound)
-    upper_ratio = upper_bound/(lower_bound+upper_bound)
+    lower_ratio = lower_bound / (lower_bound + upper_bound)
+    upper_ratio = upper_bound / (lower_bound + upper_bound)
 
-    time_delta = upper_times-lower_times
-    true_zero = lower_times + time_delta*lower_ratio
+    time_delta = upper_times - lower_times
+    true_zero = lower_times + time_delta * lower_ratio
 
     plt.scatter(true_zero, np.zeros(len(true_zero)))
 
@@ -84,7 +95,7 @@ def main(folder):
     inst_freq = 1 / np.diff(true_zero)
     filtered_inst_freq = gaussian_filter1d(inst_freq, 0.005)
     fig, ax = plt.subplots()
-    ax.plot(filtered_inst_freq, marker='.')
+    ax.plot(filtered_inst_freq, marker=".")
     # in 5 sekunden welcher fisch auf einer elektrode am
 
     embed()
@@ -99,5 +110,7 @@ def main(folder):
     pass
 
 
-if __name__ == '__main__':
-    main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/')
+if __name__ == "__main__":
+    main(
+        "/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/"
+    )
diff --git a/code/band_pass_problem.py b/code/band_pass_problem.py
index fc6a55e..f553ff2 100644
--- a/code/band_pass_problem.py
+++ b/code/band_pass_problem.py
@@ -12,25 +12,27 @@ from modules.filehandling import LoadData
 def main(folder):
     data = LoadData(folder)
 
-    t0 = 3*60*60 + 6*60 + 43.5
+    t0 = 3 * 60 * 60 + 6 * 60 + 43.5
     dt = 60
-    data_oi = data.raw[t0 * data.raw_rate: (t0+dt)*data.raw_rate, :]
-    # good electrode 
-    electrode = 10 
+    data_oi = data.raw[t0 * data.raw_rate : (t0 + dt) * data.raw_rate, :]
+    # good electrode
+    electrode = 10
     data_oi = data_oi[:, electrode]
-    fig, axs = plt.subplots(2,1)
-    axs[0].plot( np.arange(data_oi.shape[0]) / data.raw_rate, data_oi)
+    fig, axs = plt.subplots(2, 1)
+    axs[0].plot(np.arange(data_oi.shape[0]) / data.raw_rate, data_oi)
     for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
         rack_window_index = np.arange(len(data.idx))[
-                (data.ident == track_id) &
-                                        (data.time[data.idx] >= t0) &
-                                        (data.time[data.idx] <= (t0+dt))]
+            (data.ident == track_id)
+            & (data.time[data.idx] >= t0)
+            & (data.time[data.idx] <= (t0 + dt))
+        ]
         freq_fish = data.freq[rack_window_index]
         axs[1].plot(np.arange(freq_fish.shape[0]) / data.raw_rate, freq_fish)
 
     plt.show()
 
 
-
-if __name__ == '__main__':
-    main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/')
\ No newline at end of file
+if __name__ == "__main__":
+    main(
+        "/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/"
+    )
diff --git a/code/behavior.py b/code/behavior.py
index 71c0926..4f16543 100644
--- a/code/behavior.py
+++ b/code/behavior.py
@@ -1,8 +1,8 @@
-import os 
-import os 
+import os
+import os
 
 import numpy as np
-import matplotlib.pyplot as plt 
+import matplotlib.pyplot as plt
 
 from IPython import embed
 from pandas import read_csv
@@ -11,51 +11,65 @@ from scipy.ndimage import gaussian_filter1d
 
 logger = makeLogger(__name__)
 
+
 class Behavior:
     """Load behavior data from csv file as class attributes
         Attributes
     ----------
     behavior: 0: chasing onset, 1: chasing offset, 2: physical contact
-    behavior_type:         
-    behavioral_category:   
-    comment_start:         
-    comment_stop:          
-    dataframe: pandas dataframe with all the data            
-    duration_s:             
-    media_file:            
-    observation_date:      
-    observation_id:        
-    start_s: start time of the event in seconds               
-    stop_s:  stop time of the event in seconds               
-    total_length:          
+    behavior_type:
+    behavioral_category:
+    comment_start:
+    comment_stop:
+    dataframe: pandas dataframe with all the data
+    duration_s:
+    media_file:
+    observation_date:
+    observation_id:
+    start_s: start time of the event in seconds
+    stop_s:  stop time of the event in seconds
+    total_length:
     """
 
     def __init__(self, folder_path: str) -> None:
-        
-
-        LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True)
-        self.time = np.load(os.path.join(folder_path, "times.npy"), allow_pickle=True)
-        csv_filename = [f for f in os.listdir(folder_path) if f.endswith('.csv')][0] # check if there are more than one csv file
+        LED_on_time_BORIS = np.load(
+            os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
+        )
+        self.time = np.load(
+            os.path.join(folder_path, "times.npy"), allow_pickle=True
+        )
+        csv_filename = [
+            f for f in os.listdir(folder_path) if f.endswith(".csv")
+        ][
+            0
+        ]  # check if there are more than one csv file
         self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
-        self.chirps = np.load(os.path.join(folder_path, 'chirps.npy'), allow_pickle=True)
-        self.chirps_ids = np.load(os.path.join(folder_path, 'chirps_ids.npy'), allow_pickle=True)
+        self.chirps = np.load(
+            os.path.join(folder_path, "chirps.npy"), allow_pickle=True
+        )
+        self.chirps_ids = np.load(
+            os.path.join(folder_path, "chirps_ids.npy"), allow_pickle=True
+        )
 
         for k, key in enumerate(self.dataframe.keys()):
-            key = key.lower() 
-            if ' ' in key:
-                key = key.replace(' ', '_')
-                if '(' in key:
-                    key = key.replace('(', '')
-                    key = key.replace(')', '')
-            setattr(self, key, np.array(self.dataframe[self.dataframe.keys()[k]]))
-        
+            key = key.lower()
+            if " " in key:
+                key = key.replace(" ", "_")
+                if "(" in key:
+                    key = key.replace("(", "")
+                    key = key.replace(")", "")
+            setattr(
+                self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
+            )
+
         last_LED_t_BORIS = LED_on_time_BORIS[-1]
         real_time_range = self.time[-1] - self.time[0]
         factor = 1.034141
         shift = last_LED_t_BORIS - real_time_range * factor
         self.start_s = (self.start_s - shift) / factor
         self.stop_s = (self.stop_s - shift) / factor
-  
+
+
 """
 1 - chasing onset
 2 - chasing offset
@@ -83,77 +97,77 @@ temporal encpding needs to be corrected ... not exactly 25FPS.
     behavior = data['Behavior']
 """
 
-def correct_chasing_events(
-    category: np.ndarray, 
-    timestamps: np.ndarray
-    ) -> tuple[np.ndarray, np.ndarray]:
 
-    onset_ids = np.arange(
-        len(category))[category == 0]
-    offset_ids = np.arange(
-        len(category))[category == 1]
+def correct_chasing_events(
+    category: np.ndarray, timestamps: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    onset_ids = np.arange(len(category))[category == 0]
+    offset_ids = np.arange(len(category))[category == 1]
 
     # Check whether on- or offset is longer and calculate length difference
     if len(onset_ids) > len(offset_ids):
         len_diff = len(onset_ids) - len(offset_ids)
         longer_array = onset_ids
         shorter_array = offset_ids
-        logger.info(f'Onsets are greater than offsets by {len_diff}')
+        logger.info(f"Onsets are greater than offsets by {len_diff}")
     elif len(onset_ids) < len(offset_ids):
         len_diff = len(offset_ids) - len(onset_ids)
         longer_array = offset_ids
         shorter_array = onset_ids
-        logger.info(f'Offsets are greater than offsets by {len_diff}')
+        logger.info(f"Offsets are greater than offsets by {len_diff}")
     elif len(onset_ids) == len(offset_ids):
-        logger.info('Chasing events are equal')
+        logger.info("Chasing events are equal")
         return category, timestamps
 
     # Correct the wrong chasing events; delete double events
     wrong_ids = []
-    for i in range(len(longer_array)-(len_diff+1)):
-        if (shorter_array[i] > longer_array[i]) & (shorter_array[i] < longer_array[i+1]):
+    for i in range(len(longer_array) - (len_diff + 1)):
+        if (shorter_array[i] > longer_array[i]) & (
+            shorter_array[i] < longer_array[i + 1]
+        ):
             pass
         else:
             wrong_ids.append(longer_array[i])
             longer_array = np.delete(longer_array, i)
-        
-    category = np.delete(
-        category, wrong_ids)
-    timestamps = np.delete(
-        timestamps, wrong_ids)
+
+    category = np.delete(category, wrong_ids)
+    timestamps = np.delete(timestamps, wrong_ids)
     return category, timestamps
 
 
 def event_triggered_chirps(
-    event: np.ndarray, 
-    chirps:np.ndarray,
+    event: np.ndarray,
+    chirps: np.ndarray,
     time_before_event: int,
-    time_after_event: int
-    )-> tuple[np.ndarray, np.ndarray]:
-
-
-    event_chirps = []   # chirps that are in specified window around event
-    centered_chirps = []    # timestamps of chirps around event centered on the event timepoint
+    time_after_event: int,
+) -> tuple[np.ndarray, np.ndarray]:
+    event_chirps = []  # chirps that are in specified window around event
+    centered_chirps = (
+        []
+    )  # timestamps of chirps around event centered on the event timepoint
 
     for event_timestamp in event:
-        start = event_timestamp - time_before_event    # timepoint of window start
-        stop = event_timestamp + time_after_event    # timepoint of window ending
-        chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)]     # get chirps that are in a -5 to +5 sec window around event
+        start = event_timestamp - time_before_event  # timepoint of window start
+        stop = event_timestamp + time_after_event  # timepoint of window ending
+        chirps_around_event = [
+            c for c in chirps if (c >= start) & (c <= stop)
+        ]  # get chirps that are in a -5 to +5 sec window around event
         event_chirps.append(chirps_around_event)
         if len(chirps_around_event) == 0:
             continue
-        else: 
+        else:
             centered_chirps.append(chirps_around_event - event_timestamp)
-    centered_chirps = np.concatenate(centered_chirps, axis=0)   # convert list of arrays to one array for plotting
+    centered_chirps = np.concatenate(
+        centered_chirps, axis=0
+    )  # convert list of arrays to one array for plotting
 
     return event_chirps, centered_chirps
 
 
 def main(datapath: str):
-
     # behavior is pandas dataframe with all the data
     bh = Behavior(datapath)
-    
+
     # chirps are not sorted in time (presumably due to prior groupings)
     # get and sort chirps and corresponding fish_ids of the chirps
     chirps = bh.chirps[np.argsort(bh.chirps)]
@@ -172,10 +186,34 @@ def main(datapath: str):
 
     # First overview plot
     fig1, ax1 = plt.subplots()
-    ax1.scatter(chirps, np.ones_like(chirps), marker='*', color='royalblue', label='Chirps')
-    ax1.scatter(chasing_onset, np.ones_like(chasing_onset)*2, marker='.', color='forestgreen', label='Chasing onset')
-    ax1.scatter(chasing_offset, np.ones_like(chasing_offset)*2.5, marker='.', color='firebrick', label='Chasing offset')
-    ax1.scatter(physical_contact, np.ones_like(physical_contact)*3, marker='x', color='black', label='Physical contact')
+    ax1.scatter(
+        chirps,
+        np.ones_like(chirps),
+        marker="*",
+        color="royalblue",
+        label="Chirps",
+    )
+    ax1.scatter(
+        chasing_onset,
+        np.ones_like(chasing_onset) * 2,
+        marker=".",
+        color="forestgreen",
+        label="Chasing onset",
+    )
+    ax1.scatter(
+        chasing_offset,
+        np.ones_like(chasing_offset) * 2.5,
+        marker=".",
+        color="firebrick",
+        label="Chasing offset",
+    )
+    ax1.scatter(
+        physical_contact,
+        np.ones_like(physical_contact) * 3,
+        marker="x",
+        color="black",
+        label="Physical contact",
+    )
     plt.legend()
     # plt.show()
     plt.close()
@@ -187,29 +225,40 @@ def main(datapath: str):
     # Evaluate how many chirps were emitted in specific time window around the chasing onset events
 
     # Iterate over chasing onsets (later over fish)
-    time_around_event = 5    # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event
+    time_around_event = 5  # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event
     #### Loop crashes at concatenate in function ####
     # for i in range(len(fish_ids)):
     #     fish = fish_ids[i]
     #     chirps = chirps[chirps_fish_ids == fish]
     #     print(fish)
 
-    chasing_chirps, centered_chasing_chirps = event_triggered_chirps(chasing_onset, chirps, time_around_event, time_around_event)
-    physical_chirps, centered_physical_chirps = event_triggered_chirps(physical_contact, chirps, time_around_event, time_around_event)
+    chasing_chirps, centered_chasing_chirps = event_triggered_chirps(
+        chasing_onset, chirps, time_around_event, time_around_event
+    )
+    physical_chirps, centered_physical_chirps = event_triggered_chirps(
+        physical_contact, chirps, time_around_event, time_around_event
+    )
 
     # Kernel density estimation ???
     # centered_chasing_chirps_convolved = gaussian_filter1d(centered_chasing_chirps, 5)
-    
+
     # centered_chasing = chasing_onset[0] - chasing_onset[0]   ## get the 0 timepoint for plotting; set one chasing event to 0
     offsets = [0.5, 1]
-    fig4, ax4 = plt.subplots(figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True)
-    ax4.eventplot(np.array([centered_chasing_chirps, centered_physical_chirps]), lineoffsets=offsets, linelengths=0.25, colors=['g', 'r'])
-    ax4.vlines(0, 0, 1.5, 'tab:grey', 'dashed', 'Timepoint of event')
+    fig4, ax4 = plt.subplots(
+        figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True
+    )
+    ax4.eventplot(
+        np.array([centered_chasing_chirps, centered_physical_chirps]),
+        lineoffsets=offsets,
+        linelengths=0.25,
+        colors=["g", "r"],
+    )
+    ax4.vlines(0, 0, 1.5, "tab:grey", "dashed", "Timepoint of event")
     # ax4.plot(centered_chasing_chirps_convolved)
     ax4.set_yticks(offsets)
-    ax4.set_yticklabels(['Chasings', 'Physical \n contacts'])
-    ax4.set_xlabel('Time[s]')
-    ax4.set_ylabel('Type of event')
+    ax4.set_yticklabels(["Chasings", "Physical \n contacts"])
+    ax4.set_xlabel("Time[s]")
+    ax4.set_ylabel("Type of event")
     plt.show()
 
     # Associate chirps to inidividual fish
@@ -219,22 +268,21 @@ def main(datapath: str):
 
     ### Plots:
     # 1. All recordings, all fish, all chirps
-        # One CTC, one PTC
+    # One CTC, one PTC
     # 2. All recordings, only winners
-        # One CTC, one PTC
+    # One CTC, one PTC
     # 3. All recordings, all losers
-        # One CTC, one PTC
+    # One CTC, one PTC
 
     #### Chirp counts per fish general #####
     fig2, ax2 = plt.subplots()
-    x = ['Fish1', 'Fish2']
+    x = ["Fish1", "Fish2"]
     width = 0.35
     ax2.bar(x, fish, width=width)
-    ax2.set_ylabel('Chirp count')
+    ax2.set_ylabel("Chirp count")
     # plt.show()
     plt.close()
 
- 
     ##### Count chirps emitted during chasing events and chirps emitted out of chasing events #####
     chirps_in_chasings = []
     for onset, offset in zip(chasing_onset, chasing_offset):
@@ -251,23 +299,24 @@ def main(datapath: str):
             counts_chirps_chasings += 1
 
     # chirps in chasing events
-    fig3 , ax3 = plt.subplots()
-    ax3.bar(['Chirps in chasing events',  'Chasing events without Chirps'], [counts_chirps_chasings, chasings_without_chirps], width=width)
-    plt.ylabel('Count')
+    fig3, ax3 = plt.subplots()
+    ax3.bar(
+        ["Chirps in chasing events", "Chasing events without Chirps"],
+        [counts_chirps_chasings, chasings_without_chirps],
+        width=width,
+    )
+    plt.ylabel("Count")
     # plt.show()
-    plt.close()  
+    plt.close()
 
     # comparison between chasing events with and without chirps
 
-
-    
     embed()
     exit()
 
 
-
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Path to the data
-    datapath = '../data/mount_data/2020-05-13-10_00/'
-    datapath = '../data/mount_data/2020-05-13-10_00/'
+    datapath = "../data/mount_data/2020-05-13-10_00/"
+    datapath = "../data/mount_data/2020-05-13-10_00/"
     main(datapath)
diff --git a/code/chirp_sim.py b/code/chirp_sim.py
index 5433b36..d7023a5 100644
--- a/code/chirp_sim.py
+++ b/code/chirp_sim.py
@@ -8,30 +8,27 @@ from modules.datahandling import instantaneous_frequency
 from modules.simulations import create_chirp
 
 
-
 # trying thunderfish fakefish chirp simulation ---------------------------------
 samplerate = 44100
 freq, ampl = fakefish.chirps(eodf=500, chirp_contrast=0.2)
-data = fakefish.wavefish_eods(fish='Alepto', frequency=freq, phase0=3, samplerate=samplerate)
+data = fakefish.wavefish_eods(
+    fish="Alepto", frequency=freq, phase0=3, samplerate=samplerate
+)
 
 # filter signal with bandpass_filter
-data_filterd = bandpass_filter(data*ampl+1, samplerate, 0.01, 1.99)
+data_filterd = bandpass_filter(data * ampl + 1, samplerate, 0.01, 1.99)
 embed()
 data_freq_time, data_freq = instantaneous_frequency(data, samplerate, 5)
 
 
 fig, ax = plt.subplots(4, 1, figsize=(20 / 2.54, 12 / 2.54), sharex=True)
 
-ax[0].plot(np.arange(len(data))/samplerate, data*ampl)
-#ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red')
-ax[1].plot(np.arange(len(data_filterd))/samplerate, data_filterd)
-ax[2].plot(np.arange(len(freq))/samplerate, freq)
+ax[0].plot(np.arange(len(data)) / samplerate, data * ampl)
+# ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red')
+ax[1].plot(np.arange(len(data_filterd)) / samplerate, data_filterd)
+ax[2].plot(np.arange(len(freq)) / samplerate, freq)
 ax[3].plot(data_freq_time, data_freq)
 
 
 plt.show()
 embed()
-
-
-
-
diff --git a/code/chirpdetection.py b/code/chirpdetection.py
index 95800df..937bde4 100755
--- a/code/chirpdetection.py
+++ b/code/chirpdetection.py
@@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
 import matplotlib.gridspec as gr
 from scipy.signal import find_peaks
 from thunderfish.powerspectrum import spectrogram, decibel
+
 # from sklearn.preprocessing import normalize
 
 from modules.filters import bandpass_filter, envelope, highpass_filter
@@ -18,7 +19,7 @@ from modules.datahandling import (
     purge_duplicates,
     group_timestamps,
     instantaneous_frequency,
-    instantaneous_frequency2, 
+    instantaneous_frequency2,
     minmaxnorm,
 )
 
@@ -59,7 +60,6 @@ class ChirpPlotBuffer:
     frequency_peaks: np.ndarray
 
     def plot_buffer(self, chirps: np.ndarray, plot: str) -> None:
-
         logger.debug("Starting plotting")
 
         # make data for plotting
@@ -135,7 +135,6 @@ class ChirpPlotBuffer:
         ax0.set_ylim(np.min(self.frequency) - 100, np.max(self.frequency) + 200)
 
         for track_id in self.data.ids:
-
             t0_track = self.t0_old - 5
             dt_track = self.dt + 10
             window_idx = np.arange(len(self.data.idx))[
@@ -176,10 +175,16 @@ class ChirpPlotBuffer:
         # )
 
         ax0.axhline(
-            q50 - self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed"
+            q50 - self.config.minimal_bandwidth / 2,
+            color=ps.gblue1,
+            lw=1,
+            ls="dashed",
         )
         ax0.axhline(
-            q50 + self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed"
+            q50 + self.config.minimal_bandwidth / 2,
+            color=ps.gblue1,
+            lw=1,
+            ls="dashed",
         )
         ax0.axhline(search_lower, color=ps.gblue2, lw=1, ls="dashed")
         ax0.axhline(search_upper, color=ps.gblue2, lw=1, ls="dashed")
@@ -205,7 +210,11 @@ class ChirpPlotBuffer:
 
         # plot waveform of filtered signal
         ax1.plot(
-            self.time, self.baseline * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5
+            self.time,
+            self.baseline * waveform_scaler,
+            c=ps.gray,
+            lw=lw,
+            alpha=0.5,
         )
         ax1.plot(
             self.time,
@@ -216,7 +225,13 @@ class ChirpPlotBuffer:
         )
 
         # plot waveform of filtered search signal
-        ax2.plot(self.time, self.search * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5)
+        ax2.plot(
+            self.time,
+            self.search * waveform_scaler,
+            c=ps.gray,
+            lw=lw,
+            alpha=0.5,
+        )
         ax2.plot(
             self.time,
             self.search_envelope_unfiltered * waveform_scaler,
@@ -238,9 +253,7 @@ class ChirpPlotBuffer:
         # ax4.plot(
         #     self.time, self.baseline_envelope * waveform_scaler, c=ps.gblue1, lw=lw
         # )
-        ax4.plot(
-            self.time, self.baseline_envelope, c=ps.gblue1, lw=lw
-        )
+        ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=lw)
         ax4.scatter(
             (self.time)[self.baseline_peaks],
             # (self.baseline_envelope * waveform_scaler)[self.baseline_peaks],
@@ -269,7 +282,9 @@ class ChirpPlotBuffer:
         )
 
         # plot filtered instantaneous frequency
-        ax6.plot(self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw)
+        ax6.plot(
+            self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw
+        )
         ax6.scatter(
             self.frequency_time[self.frequency_peaks],
             self.frequency_filtered[self.frequency_peaks],
@@ -303,7 +318,9 @@ class ChirpPlotBuffer:
         # ax7.spines.bottom.set_bounds((0, 5))
 
         ax0.set_xlim(0, self.config.window)
-        plt.subplots_adjust(left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2)
+        plt.subplots_adjust(
+            left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2
+        )
         fig.align_labels()
 
         if plot == "show":
@@ -408,7 +425,9 @@ def extract_frequency_bands(
         q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2
 
     # filter baseline
-    filtered_baseline = bandpass_filter(raw_data, samplerate, lowf=q25, highf=q75)
+    filtered_baseline = bandpass_filter(
+        raw_data, samplerate, lowf=q25, highf=q75
+    )
 
     # filter search area
     filtered_search_freq = bandpass_filter(
@@ -453,12 +472,14 @@ def window_median_all_track_ids(
     track_ids = []
 
     for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
-
         # the window index combines the track id and the time window
         window_idx = np.arange(len(data.idx))[
             (data.ident == track_id)
             & (data.time[data.idx] >= window_start_seconds)
-            & (data.time[data.idx] <= (window_start_seconds + window_duration_seconds))
+            & (
+                data.time[data.idx]
+                <= (window_start_seconds + window_duration_seconds)
+            )
         ]
 
         if len(data.freq[window_idx]) > 0:
@@ -595,15 +616,15 @@ def find_searchband(
 
     # iterate through theses tracks
     if check_track_ids.size != 0:
-
         for j, check_track_id in enumerate(check_track_ids):
-
             q25_temp = q25[percentiles_ids == check_track_id]
             q75_temp = q75[percentiles_ids == check_track_id]
 
             bool_lower[search_window > q25_temp - config.search_res] = False
             bool_upper[search_window < q75_temp + config.search_res] = False
-            search_window_bool[(bool_lower == False) & (bool_upper == False)] = False
+            search_window_bool[
+                (bool_lower == False) & (bool_upper == False)
+            ] = False
 
         # find gaps in search window
         search_window_indices = np.arange(len(search_window))
@@ -622,7 +643,9 @@ def find_searchband(
         # if the first value is -1, the array starst with true, so a gap
         if nonzeros[0] == -1:
             stops = search_window_indices[search_window_gaps == -1]
-            starts = np.append(0, search_window_indices[search_window_gaps == 1])
+            starts = np.append(
+                0, search_window_indices[search_window_gaps == 1]
+            )
 
             # if the last value is -1, the array ends with true, so a gap
             if nonzeros[-1] == 1:
@@ -659,7 +682,6 @@ def find_searchband(
 
 
 def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
-
     assert plot in [
         "save",
         "show",
@@ -729,7 +751,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
     multiwindow_ids = []
 
     for st, window_start_index in enumerate(window_start_indices):
-
         logger.info(f"Processing window {st+1} of {len(window_start_indices)}")
 
         window_start_seconds = window_start_index / data.raw_rate
@@ -744,8 +765,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
         )
 
         # iterate through all fish
-        for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
-
+        for tr, track_id in enumerate(
+            np.unique(data.ident[~np.isnan(data.ident)])
+        ):
             logger.debug(f"Processing track {tr} of {len(data.ids)}")
 
             # get index of track data in this time window
@@ -773,16 +795,17 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
             nanchecker = np.unique(np.isnan(current_powers))
             if (len(nanchecker) == 1) and nanchecker[0] is True:
                 logger.warning(
-                    f"No powers available for track {track_id} window {st}," "skipping."
+                    f"No powers available for track {track_id} window {st},"
+                    "skipping."
                 )
                 continue
 
             # find the strongest electrodes for the current fish in the current
             # window
 
-            best_electrode_index = np.argsort(np.nanmean(current_powers, axis=0))[
-                -config.number_electrodes :
-            ]
+            best_electrode_index = np.argsort(
+                np.nanmean(current_powers, axis=0)
+            )[-config.number_electrodes :]
 
             # find a frequency above the baseline of the current fish in which
             # no other fish is active to search for chirps there
@@ -802,9 +825,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
 
             # iterate through electrodes
             for el, electrode_index in enumerate(best_electrode_index):
-
                 logger.debug(
-                    f"Processing electrode {el+1} of " f"{len(best_electrode_index)}"
+                    f"Processing electrode {el+1} of "
+                    f"{len(best_electrode_index)}"
                 )
 
                 # LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------
@@ -813,7 +836,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
                 current_raw_data = data.raw[
                     window_start_index:window_stop_index, electrode_index
                 ]
-                current_raw_time = raw_time[window_start_index:window_stop_index]
+                current_raw_time = raw_time[
+                    window_start_index:window_stop_index
+                ]
 
                 # EXTRACT FEATURES --------------------------------------------
 
@@ -839,8 +864,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
                 # because the instantaneous frequency is not reliable there
 
                 amplitude_mask = mask_low_amplitudes(
-                        baseline_envelope_unfiltered,
-                        config.baseline_min_amplitude
+                    baseline_envelope_unfiltered, config.baseline_min_amplitude
                 )
 
                 # highpass filter baseline envelope to remove slower
@@ -877,27 +901,30 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
                 # filtered baseline such as the one we are working with.
 
                 baseline_frequency = instantaneous_frequency(
-                        baselineband, 
-                        data.raw_rate, 
-                        config.baseline_frequency_smoothing
+                    baselineband,
+                    data.raw_rate,
+                    config.baseline_frequency_smoothing,
                 )
 
                 # Take the absolute of the instantaneous frequency to invert
-                # troughs into peaks. This is nessecary since the narrow 
+                # troughs into peaks. This is nessecary since the narrow
                 # pass band introduces these anomalies. Also substract by the
                 # median to set it to 0.
-                
+
                 baseline_frequency_filtered = np.abs(
                     baseline_frequency - np.median(baseline_frequency)
                 )
 
-                # check if there is at least one superthreshold peak on the 
-                # instantaneous and exit the loop if not. This is used to 
-                # prevent windows that do definetely not include a chirp 
-                # to enter normalization, where small changes due to noise 
-                # would be amplified 
+                # check if there is at least one superthreshold peak on the
+                # instantaneous and exit the loop if not. This is used to
+                # prevent windows that do definetely not include a chirp
+                # to enter normalization, where small changes due to noise
+                # would be amplified
 
-                if not has_chirp(baseline_frequency_filtered[amplitude_mask], config.baseline_frequency_peakheight):
+                if not has_chirp(
+                    baseline_frequency_filtered[amplitude_mask],
+                    config.baseline_frequency_peakheight,
+                ):
                     continue
 
                 # CUT OFF OVERLAP ---------------------------------------------
@@ -912,14 +939,20 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
 
                 current_raw_time = current_raw_time[no_edges]
                 baselineband = baselineband[no_edges]
-                baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges]
+                baseline_envelope_unfiltered = baseline_envelope_unfiltered[
+                    no_edges
+                ]
                 searchband = searchband[no_edges]
                 baseline_envelope = baseline_envelope[no_edges]
-                search_envelope_unfiltered = search_envelope_unfiltered[no_edges]
+                search_envelope_unfiltered = search_envelope_unfiltered[
+                    no_edges
+                ]
                 search_envelope = search_envelope[no_edges]
 
                 baseline_frequency = baseline_frequency[no_edges]
-                baseline_frequency_filtered = baseline_frequency_filtered[no_edges]
+                baseline_frequency_filtered = baseline_frequency_filtered[
+                    no_edges
+                ]
                 baseline_frequency_time = current_raw_time
 
                 # # get instantaneous frequency withoup edges
@@ -960,13 +993,16 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
                 )
                 # detect peaks inst_freq_filtered
                 frequency_peak_indices, _ = find_peaks(
-                    baseline_frequency_filtered, prominence=config.frequency_prominence
+                    baseline_frequency_filtered,
+                    prominence=config.frequency_prominence,
                 )
 
                 # DETECT CHIRPS IN SEARCH WINDOW ------------------------------
 
                 # get the peak timestamps from the peak indices
-                baseline_peak_timestamps = current_raw_time[baseline_peak_indices]
+                baseline_peak_timestamps = current_raw_time[
+                    baseline_peak_indices
+                ]
                 search_peak_timestamps = current_raw_time[search_peak_indices]
 
                 frequency_peak_timestamps = baseline_frequency_time[
@@ -1015,7 +1051,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
                 )
 
                 if chirp_detected or (debug != "elecrode"):
-
                     logger.debug("Detected chirp, ititialize buffer ...")
 
                     # save data to Buffer
@@ -1107,7 +1142,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
     multiwindow_chirps_flat = []
     multiwindow_ids_flat = []
     for track_id in np.unique(multiwindow_ids):
-
         # get chirps for this fish and flatten the list
         current_track_bool = np.asarray(multiwindow_ids) == track_id
         current_track_chirps = flatten(
@@ -1116,7 +1150,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
 
         # add flattened chirps to the list
         multiwindow_chirps_flat.extend(current_track_chirps)
-        multiwindow_ids_flat.extend(list(np.ones_like(current_track_chirps) * track_id))
+        multiwindow_ids_flat.extend(
+            list(np.ones_like(current_track_chirps) * track_id)
+        )
 
     # purge duplicates, i.e. chirps that are very close to each other
     # duplites arise due to overlapping windows
diff --git a/code/chirpdetector_conf.yml b/code/chirpdetector_conf.yml
index 371326e..bb4598b 100755
--- a/code/chirpdetector_conf.yml
+++ b/code/chirpdetector_conf.yml
@@ -1,37 +1,37 @@
 # Path setup ------------------------------------------------------------------
 
-dataroot: "../data/"      # path to data
-outputdir: "../output/"   # path to save plots to
+dataroot: "../data/"  # path to data
+outputdir: "../output/"  # path to save plots to
 
 # Rolling window parameters ---------------------------------------------------
 
-window: 5   # rolling window length in seconds
+window: 5  # rolling window length in seconds
 overlap: 1  # window overlap in seconds
 edge: 0.25  # window edge cufoffs to mitigate filter edge effects
 
 # Electrode iteration parameters ----------------------------------------------
 
-number_electrodes: 2    # number of electrodes to go over
-minimum_electrodes: 1   # mimumun number of electrodes a chirp must be on
+number_electrodes: 2  # number of electrodes to go over
+minimum_electrodes: 1  # mimumun number of electrodes a chirp must be on
 
 # Feature extraction parameters -----------------------------------------------
 
-search_df_lower: 20     # start searching this far above the baseline
-search_df_upper: 100    # stop searching this far above the baseline
-search_res: 1           # search window resolution
-default_search_freq: 60 # search here if no need for a search frequency
-minimal_bandwidth: 10   # minimal bandpass filter width for baseline
-search_bandwidth: 10    # minimal bandpass filter width for search frequency
-baseline_frequency_smoothing: 3 # instantaneous frequency smoothing
+search_df_lower: 20  # start searching this far above the baseline
+search_df_upper: 100  # stop searching this far above the baseline
+search_res: 1  # search window resolution
+default_search_freq: 60  # search here if no need for a search frequency
+minimal_bandwidth: 10  # minimal bandpass filter width for baseline
+search_bandwidth: 10  # minimal bandpass filter width for search frequency
+baseline_frequency_smoothing: 3  # instantaneous frequency smoothing
 
 # Feature processing parameters -----------------------------------------------
 
-baseline_frequency_peakheight: 5 # the min peak height of the baseline instfreq
-baseline_min_amplitude: 0.0001 # the minimal value of the baseline envelope
-baseline_envelope_cutoff: 25            # envelope estimation cutoff
-baseline_envelope_bandpass_lowf: 2      # envelope badpass lower cutoff
-baseline_envelope_bandpass_highf: 100   # envelope bandbass higher cutoff
-search_envelope_cutoff: 10              # search envelope estimation cufoff
+baseline_frequency_peakheight: 5  # the min peak height of the baseline instfreq
+baseline_min_amplitude: 0.0001  # the minimal value of the baseline envelope
+baseline_envelope_cutoff: 25  # envelope estimation cutoff
+baseline_envelope_bandpass_lowf: 2  # envelope badpass lower cutoff
+baseline_envelope_bandpass_highf: 100  # envelope bandbass higher cutoff
+search_envelope_cutoff: 10  # search envelope estimation cufoff
 
 # Peak detecion parameters ----------------------------------------------------
 # baseline_prominence: 0.00005  # peak prominence threshold for baseline envelope
@@ -39,9 +39,8 @@ search_envelope_cutoff: 10              # search envelope estimation cufoff
 # frequency_prominence: 2       # peak prominence threshold for baseline freq
 
 baseline_prominence: 0.3  # peak prominence threshold for baseline envelope
-search_prominence: 0.3   # peak prominence threshold for search envelope
-frequency_prominence: 0.3       # peak prominence threshold for baseline freq
+search_prominence: 0.3  # peak prominence threshold for search envelope
+frequency_prominence: 0.3  # peak prominence threshold for baseline freq
 
 # Classify events as chirps if they are less than this time apart
 chirp_window_threshold: 0.02
-
diff --git a/code/eventchirpsplots.py b/code/eventchirpsplots.py
index 4ebaa66..5003a7d 100644
--- a/code/eventchirpsplots.py
+++ b/code/eventchirpsplots.py
@@ -35,28 +35,36 @@ class Behavior:
     """
 
     def __init__(self, folder_path: str) -> None:
-        print(f'{folder_path}')
-        LED_on_time_BORIS = np.load(os.path.join(
-            folder_path, 'LED_on_time.npy'), allow_pickle=True)
-        self.time = np.load(os.path.join(
-            folder_path, "times.npy"), allow_pickle=True)
-        csv_filename = [f for f in os.listdir(folder_path) if f.endswith(
-            '.csv')][0]  # check if there are more than one csv file
+        print(f"{folder_path}")
+        LED_on_time_BORIS = np.load(
+            os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
+        )
+        self.time = np.load(
+            os.path.join(folder_path, "times.npy"), allow_pickle=True
+        )
+        csv_filename = [
+            f for f in os.listdir(folder_path) if f.endswith(".csv")
+        ][
+            0
+        ]  # check if there are more than one csv file
         self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
-        self.chirps = np.load(os.path.join(
-            folder_path, 'chirps.npy'), allow_pickle=True)
-        self.chirps_ids = np.load(os.path.join(
-            folder_path, 'chirp_ids.npy'), allow_pickle=True)
+        self.chirps = np.load(
+            os.path.join(folder_path, "chirps.npy"), allow_pickle=True
+        )
+        self.chirps_ids = np.load(
+            os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True
+        )
 
         for k, key in enumerate(self.dataframe.keys()):
             key = key.lower()
-            if ' ' in key:
-                key = key.replace(' ', '_')
-                if '(' in key:
-                    key = key.replace('(', '')
-                    key = key.replace(')', '')
-            setattr(self, key, np.array(
-                self.dataframe[self.dataframe.keys()[k]]))
+            if " " in key:
+                key = key.replace(" ", "_")
+                if "(" in key:
+                    key = key.replace("(", "")
+                    key = key.replace(")", "")
+            setattr(
+                self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
+            )
 
         last_LED_t_BORIS = LED_on_time_BORIS[-1]
         real_time_range = self.time[-1] - self.time[0]
@@ -95,17 +103,14 @@ temporal encpding needs to be corrected ... not exactly 25FPS.
 
 
 def correct_chasing_events(
-    category: np.ndarray,
-    timestamps: np.ndarray
+    category: np.ndarray, timestamps: np.ndarray
 ) -> tuple[np.ndarray, np.ndarray]:
+    onset_ids = np.arange(len(category))[category == 0]
+    offset_ids = np.arange(len(category))[category == 1]
 
-    onset_ids = np.arange(
-        len(category))[category == 0]
-    offset_ids = np.arange(
-        len(category))[category == 1]
-
-    wrong_bh = np.arange(len(category))[
-        category != 2][:-1][np.diff(category[category != 2]) == 0]
+    wrong_bh = np.arange(len(category))[category != 2][:-1][
+        np.diff(category[category != 2]) == 0
+    ]
     if onset_ids[0] > offset_ids[0]:
         offset_ids = np.delete(offset_ids, 0)
         help_index = offset_ids[0]
@@ -117,12 +122,12 @@ def correct_chasing_events(
     # Check whether on- or offset is longer and calculate length difference
     if len(onset_ids) > len(offset_ids):
         len_diff = len(onset_ids) - len(offset_ids)
-        logger.info(f'Onsets are greater than offsets by {len_diff}')
+        logger.info(f"Onsets are greater than offsets by {len_diff}")
     elif len(onset_ids) < len(offset_ids):
         len_diff = len(offset_ids) - len(onset_ids)
-        logger.info(f'Offsets are greater than onsets by {len_diff}')
+        logger.info(f"Offsets are greater than onsets by {len_diff}")
     elif len(onset_ids) == len(offset_ids):
-        logger.info('Chasing events are equal')
+        logger.info("Chasing events are equal")
 
     return category, timestamps
 
@@ -135,8 +140,7 @@ def event_triggered_chirps(
     dt: float,
     width: float,
 ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-
-    event_chirps = []   # chirps that are in specified window around event
+    event_chirps = []  # chirps that are in specified window around event
     # timestamps of chirps around event centered on the event timepoint
     centered_chirps = []
 
@@ -159,16 +163,19 @@ def event_triggered_chirps(
     else:
         # convert list of arrays to one array for plotting
         centered_chirps = np.concatenate(centered_chirps, axis=0)
-        centered_chirps_convolved = (acausal_kde1d(
-            centered_chirps, time, width)) / len(event)
+        centered_chirps_convolved = (
+            acausal_kde1d(centered_chirps, time, width)
+        ) / len(event)
 
     return event_chirps, centered_chirps, centered_chirps_convolved
 
 
 def main(datapath: str):
-
     foldernames = [
-        datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath + x)]
+        datapath + x + "/"
+        for x in os.listdir(datapath)
+        if os.path.isdir(datapath + x)
+    ]
 
     nrecording_chirps = []
     nrecording_chirps_fish_ids = []
@@ -179,7 +186,7 @@ def main(datapath: str):
     # Iterate over all recordings and save chirp- and event-timestamps
     for folder in foldernames:
         # exclude folder with empty LED_on_time.npy
-        if folder == '../data/mount_data/2020-05-12-10_00/':
+        if folder == "../data/mount_data/2020-05-12-10_00/":
             continue
 
         bh = Behavior(folder)
@@ -209,7 +216,7 @@ def main(datapath: str):
     time_before_event = 30
     time_after_event = 60
     dt = 0.01
-    width = 1.5   # width of kernel for all recordings, currently gaussian kernel
+    width = 1.5  # width of kernel for all recordings, currently gaussian kernel
     recording_width = 2  # width of kernel for each recording
     time = np.arange(-time_before_event, time_after_event, dt)
 
@@ -232,18 +239,47 @@ def main(datapath: str):
         physical_contacts = nrecording_physicals[i]
 
         # Chirps around chasing onsets
-        _, centered_chasing_onset_chirps, cc_chasing_onset_chirps = event_triggered_chirps(
-            chasing_onsets, chirps, time_before_event, time_after_event, dt, recording_width)
+        (
+            _,
+            centered_chasing_onset_chirps,
+            cc_chasing_onset_chirps,
+        ) = event_triggered_chirps(
+            chasing_onsets,
+            chirps,
+            time_before_event,
+            time_after_event,
+            dt,
+            recording_width,
+        )
         # Chirps around chasing offsets
-        _, centered_chasing_offset_chirps, cc_chasing_offset_chirps = event_triggered_chirps(
-            chasing_offsets, chirps, time_before_event, time_after_event, dt, recording_width)
+        (
+            _,
+            centered_chasing_offset_chirps,
+            cc_chasing_offset_chirps,
+        ) = event_triggered_chirps(
+            chasing_offsets,
+            chirps,
+            time_before_event,
+            time_after_event,
+            dt,
+            recording_width,
+        )
         # Chirps around physical contacts
-        _, centered_physical_chirps, cc_physical_chirps = event_triggered_chirps(
-            physical_contacts, chirps, time_before_event, time_after_event, dt, recording_width)
+        (
+            _,
+            centered_physical_chirps,
+            cc_physical_chirps,
+        ) = event_triggered_chirps(
+            physical_contacts,
+            chirps,
+            time_before_event,
+            time_after_event,
+            dt,
+            recording_width,
+        )
 
         nrecording_centered_onset_chirps.append(centered_chasing_onset_chirps)
-        nrecording_centered_offset_chirps.append(
-            centered_chasing_offset_chirps)
+        nrecording_centered_offset_chirps.append(centered_chasing_offset_chirps)
         nrecording_centered_physical_chirps.append(centered_physical_chirps)
 
         ## Shuffled chirps ##
@@ -331,12 +367,13 @@ def main(datapath: str):
 
     # New bootstrapping approach
     for n in range(nbootstrapping):
-        diff_onset = np.diff(
-            np.sort(flatten(nrecording_centered_onset_chirps)))
+        diff_onset = np.diff(np.sort(flatten(nrecording_centered_onset_chirps)))
         diff_offset = np.diff(
-            np.sort(flatten(nrecording_centered_offset_chirps)))
+            np.sort(flatten(nrecording_centered_offset_chirps))
+        )
         diff_physical = np.diff(
-            np.sort(flatten(nrecording_centered_physical_chirps)))
+            np.sort(flatten(nrecording_centered_physical_chirps))
+        )
 
         np.random.shuffle(diff_onset)
         shuffled_onset = np.cumsum(diff_onset)
@@ -345,9 +382,11 @@ def main(datapath: str):
         np.random.shuffle(diff_physical)
         shuffled_physical = np.cumsum(diff_physical)
 
-        kde_onset (acausal_kde1d(shuffled_onset, time, width))/(27*100)
-        kde_offset = (acausal_kde1d(shuffled_offset, time, width))/(27*100)
-        kde_physical = (acausal_kde1d(shuffled_physical, time, width))/(27*100)
+        kde_onset(acausal_kde1d(shuffled_onset, time, width)) / (27 * 100)
+        kde_offset = (acausal_kde1d(shuffled_offset, time, width)) / (27 * 100)
+        kde_physical = (acausal_kde1d(shuffled_physical, time, width)) / (
+            27 * 100
+        )
 
         bootstrap_onset.append(kde_onset)
         bootstrap_offset.append(kde_offset)
@@ -355,11 +394,14 @@ def main(datapath: str):
 
     # New shuffle approach q5, q50, q95
     onset_q5, onset_median, onset_q95 = np.percentile(
-        bootstrap_onset, [5, 50, 95], axis=0)
+        bootstrap_onset, [5, 50, 95], axis=0
+    )
     offset_q5, offset_median, offset_q95 = np.percentile(
-        bootstrap_offset, [5, 50, 95], axis=0)
+        bootstrap_offset, [5, 50, 95], axis=0
+    )
     physical_q5, physical_median, physical_q95 = np.percentile(
-        bootstrap_physical, [5, 50, 95], axis=0)
+        bootstrap_physical, [5, 50, 95], axis=0
+    )
 
     #  vstack um 1. Dim zu cutten
     # nrecording_shuffled_convolved_onset_chirps = np.vstack(nrecording_shuffled_convolved_onset_chirps)
@@ -378,45 +420,66 @@ def main(datapath: str):
 
     # Flatten event timestamps
     all_onsets = np.concatenate(
-        nrecording_chasing_onsets).ravel()  # not centered
+        nrecording_chasing_onsets
+    ).ravel()  # not centered
     all_offsets = np.concatenate(
-        nrecording_chasing_offsets).ravel()  # not centered
-    all_physicals = np.concatenate(
-        nrecording_physicals).ravel()  # not centered
+        nrecording_chasing_offsets
+    ).ravel()  # not centered
+    all_physicals = np.concatenate(nrecording_physicals).ravel()  # not centered
 
     # Flatten all chirps around events
     all_onset_chirps = np.concatenate(
-        nrecording_centered_onset_chirps).ravel()   # centered
+        nrecording_centered_onset_chirps
+    ).ravel()  # centered
     all_offset_chirps = np.concatenate(
-        nrecording_centered_offset_chirps).ravel()  # centered
+        nrecording_centered_offset_chirps
+    ).ravel()  # centered
     all_physical_chirps = np.concatenate(
-        nrecording_centered_physical_chirps).ravel()  # centered
+        nrecording_centered_physical_chirps
+    ).ravel()  # centered
 
     # Convolute all chirps
     # Divide by total number of each event over all recordings
-    all_onset_chirps_convolved = (acausal_kde1d(
-        all_onset_chirps, time, width)) / len(all_onsets)
-    all_offset_chirps_convolved = (acausal_kde1d(
-        all_offset_chirps, time, width)) / len(all_offsets)
-    all_physical_chirps_convolved = (acausal_kde1d(
-        all_physical_chirps, time, width)) / len(all_physicals)
+    all_onset_chirps_convolved = (
+        acausal_kde1d(all_onset_chirps, time, width)
+    ) / len(all_onsets)
+    all_offset_chirps_convolved = (
+        acausal_kde1d(all_offset_chirps, time, width)
+    ) / len(all_offsets)
+    all_physical_chirps_convolved = (
+        acausal_kde1d(all_physical_chirps, time, width)
+    ) / len(all_physicals)
 
     # Plot all events with all shuffled
-    fig, ax = plt.subplots(1, 3, figsize=(
-        28*ps.cm, 16*ps.cm, ), constrained_layout=True, sharey='all')
+    fig, ax = plt.subplots(
+        1,
+        3,
+        figsize=(
+            28 * ps.cm,
+            16 * ps.cm,
+        ),
+        constrained_layout=True,
+        sharey="all",
+    )
     # offsets = np.arange(1,28,1)
-    ax[0].set_xlabel('Time[s]')
+    ax[0].set_xlabel("Time[s]")
 
     # Plot chasing onsets
-    ax[0].set_ylabel('Chirp rate [Hz]')
+    ax[0].set_ylabel("Chirp rate [Hz]")
     ax[0].plot(time, all_onset_chirps_convolved, color=ps.yellow, zorder=2)
     ax0 = ax[0].twinx()
     nrecording_centered_onset_chirps = np.asarray(
-        nrecording_centered_onset_chirps, dtype=object)
-    ax0.eventplot(np.array(nrecording_centered_onset_chirps),
-                  linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1)
-    ax0.vlines(0, 0, 1.5, ps.white, 'dashed')
-    ax[0].set_zorder(ax0.get_zorder()+1)
+        nrecording_centered_onset_chirps, dtype=object
+    )
+    ax0.eventplot(
+        np.array(nrecording_centered_onset_chirps),
+        linelengths=0.5,
+        colors=ps.gray,
+        alpha=0.25,
+        zorder=1,
+    )
+    ax0.vlines(0, 0, 1.5, ps.white, "dashed")
+    ax[0].set_zorder(ax0.get_zorder() + 1)
     ax[0].patch.set_visible(False)
     ax0.set_yticklabels([])
     ax0.set_yticks([])
@@ -426,15 +489,21 @@ def main(datapath: str):
     ax[0].plot(time, onset_median, color=ps.black)
 
     # Plot chasing offets
-    ax[1].set_xlabel('Time[s]')
+    ax[1].set_xlabel("Time[s]")
     ax[1].plot(time, all_offset_chirps_convolved, color=ps.orange, zorder=2)
     ax1 = ax[1].twinx()
     nrecording_centered_offset_chirps = np.asarray(
-        nrecording_centered_offset_chirps, dtype=object)
-    ax1.eventplot(np.array(nrecording_centered_offset_chirps),
-                  linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1)
-    ax1.vlines(0, 0, 1.5, ps.white, 'dashed')
-    ax[1].set_zorder(ax1.get_zorder()+1)
+        nrecording_centered_offset_chirps, dtype=object
+    )
+    ax1.eventplot(
+        np.array(nrecording_centered_offset_chirps),
+        linelengths=0.5,
+        colors=ps.gray,
+        alpha=0.25,
+        zorder=1,
+    )
+    ax1.vlines(0, 0, 1.5, ps.white, "dashed")
+    ax[1].set_zorder(ax1.get_zorder() + 1)
     ax[1].patch.set_visible(False)
     ax1.set_yticklabels([])
     ax1.set_yticks([])
@@ -444,24 +513,31 @@ def main(datapath: str):
     ax[1].plot(time, offset_median, color=ps.black)
 
     # Plot physical contacts
-    ax[2].set_xlabel('Time[s]')
+    ax[2].set_xlabel("Time[s]")
     ax[2].plot(time, all_physical_chirps_convolved, color=ps.maroon, zorder=2)
     ax2 = ax[2].twinx()
     nrecording_centered_physical_chirps = np.asarray(
-        nrecording_centered_physical_chirps, dtype=object)
-    ax2.eventplot(np.array(nrecording_centered_physical_chirps),
-                  linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1)
-    ax2.vlines(0, 0, 1.5, ps.white, 'dashed')
-    ax[2].set_zorder(ax2.get_zorder()+1)
+        nrecording_centered_physical_chirps, dtype=object
+    )
+    ax2.eventplot(
+        np.array(nrecording_centered_physical_chirps),
+        linelengths=0.5,
+        colors=ps.gray,
+        alpha=0.25,
+        zorder=1,
+    )
+    ax2.vlines(0, 0, 1.5, ps.white, "dashed")
+    ax[2].set_zorder(ax2.get_zorder() + 1)
     ax[2].patch.set_visible(False)
     ax2.set_yticklabels([])
     ax2.set_yticks([])
     # ax[2].fill_between(time, shuffled_q5_physical, shuffled_q95_physical, color=ps.gray, alpha=0.5)
     # ax[2].plot(time, shuffled_median_physical, ps.black)
-    ax[2].fill_between(time, physical_q5, physical_q95,
-                       color=ps.gray, alpha=0.5)
+    ax[2].fill_between(
+        time, physical_q5, physical_q95, color=ps.gray, alpha=0.5
+    )
     ax[2].plot(time, physical_median, ps.black)
-    fig.suptitle('All recordings')
+    fig.suptitle("All recordings")
     plt.show()
     plt.close()
 
@@ -587,7 +663,7 @@ def main(datapath: str):
     #### Chirps around events, only losers, one recording ####
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Path to the data
-    datapath = '../data/mount_data/'
+    datapath = "../data/mount_data/"
     main(datapath)
diff --git a/code/extract_chirps.py b/code/extract_chirps.py
index 77e3e8d..900f0a2 100644
--- a/code/extract_chirps.py
+++ b/code/extract_chirps.py
@@ -8,50 +8,51 @@ from IPython import embed
 
 
 def get_valid_datasets(dataroot):
-
-    datasets = sorted([name for name in os.listdir(dataroot) if os.path.isdir(
-        os.path.join(dataroot, name))])
+    datasets = sorted(
+        [
+            name
+            for name in os.listdir(dataroot)
+            if os.path.isdir(os.path.join(dataroot, name))
+        ]
+    )
 
     valid_datasets = []
     for dataset in datasets:
-
         path = os.path.join(dataroot, dataset)
-        csv_name = '-'.join(dataset.split('-')[:3]) + '.csv'
+        csv_name = "-".join(dataset.split("-")[:3]) + ".csv"
 
         if os.path.exists(os.path.join(path, csv_name)) is False:
             continue
 
-        if os.path.exists(os.path.join(path, 'ident_v.npy')) is False:
+        if os.path.exists(os.path.join(path, "ident_v.npy")) is False:
             continue
 
-        ident = np.load(os.path.join(path, 'ident_v.npy'))
+        ident = np.load(os.path.join(path, "ident_v.npy"))
         number_of_fish = len(np.unique(ident[~np.isnan(ident)]))
         if number_of_fish != 2:
             continue
 
         valid_datasets.append(dataset)
 
-    datapaths = [os.path.join(dataroot, dataset) +
-                 '/' for dataset in valid_datasets]
+    datapaths = [
+        os.path.join(dataroot, dataset) + "/" for dataset in valid_datasets
+    ]
 
     return datapaths, valid_datasets
 
 
 def main(datapaths):
-
     for path in datapaths:
-        chirpdetection(path, plot='show')
+        chirpdetection(path, plot="show")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
+    dataroot = "../data/mount_data/"
 
-    dataroot = '../data/mount_data/'
+    datapaths, valid_datasets = get_valid_datasets(dataroot)
 
-
-    datapaths, valid_datasets= get_valid_datasets(dataroot)
-
-    recs = pd.DataFrame(columns=['recording'], data=valid_datasets)
-    recs.to_csv('../recs.csv', index=False)
+    recs = pd.DataFrame(columns=["recording"], data=valid_datasets)
+    recs.to_csv("../recs.csv", index=False)
     # datapaths = ['../data/mount_data/2020-03-25-10_00/']
     main(datapaths)
 
diff --git a/code/get_behaviour.py b/code/get_behaviour.py
index 36311ca..3513c1b 100644
--- a/code/get_behaviour.py
+++ b/code/get_behaviour.py
@@ -1,4 +1,4 @@
-import os 
+import os
 from paramiko import SSHClient
 from scp import SCPClient
 from IPython import embed
@@ -7,29 +7,41 @@ from pandas import read_csv
 ssh = SSHClient()
 ssh.load_system_host_keys()
 
-ssh.connect(hostname='kraken',
-            username='efish',
-            password='fwNix4U',
-            )
+ssh.connect(
+    hostname="kraken",
+    username="efish",
+    password="fwNix4U",
+)
 
 
 # SCPCLient takes a paramiko transport as its only argument
 scp = SCPClient(ssh.get_transport())
 
-data = read_csv('../recs.csv')
-foldernames = data['recording'].values
+data = read_csv("../recs.csv")
+foldernames = data["recording"].values
 
-directory = f'/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/'
+directory = f"/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/"
 for foldername in foldernames:
+    if not os.path.exists(directory + foldername):
+        os.makedirs(directory + foldername)
 
-    if not os.path.exists(directory+foldername):
-        os.makedirs(directory+foldername)
-
-    files = [('-').join(foldername.split('-')[:3])+'.csv','chirp_ids.npy', 'chirps.npy', 'fund_v.npy', 'ident_v.npy', 'idx_v.npy', 'times.npy', 'spec.npy', 'LED_on_time.npy', 'sign_v.npy']
-
+    files = [
+        ("-").join(foldername.split("-")[:3]) + ".csv",
+        "chirp_ids.npy",
+        "chirps.npy",
+        "fund_v.npy",
+        "ident_v.npy",
+        "idx_v.npy",
+        "times.npy",
+        "spec.npy",
+        "LED_on_time.npy",
+        "sign_v.npy",
+    ]
 
     for f in files:
-        scp.get(f'/home/efish/behavior/2019_tube_competition/{foldername}/{f}',
-                directory+foldername)
+        scp.get(
+            f"/home/efish/behavior/2019_tube_competition/{foldername}/{f}",
+            directory + foldername,
+        )
 
 scp.close()
diff --git a/code/modules/behaviour_handling.py b/code/modules/behaviour_handling.py
index 94a0ca1..a50d67a 100644
--- a/code/modules/behaviour_handling.py
+++ b/code/modules/behaviour_handling.py
@@ -30,12 +30,12 @@ class Behavior:
     """
 
     def __init__(self, folder_path: str) -> None:
-
-        LED_on_time_BORIS = np.load(os.path.join(
-            folder_path, 'LED_on_time.npy'), allow_pickle=True)
+        LED_on_time_BORIS = np.load(
+            os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
+        )
 
         csv_filename = os.path.split(folder_path[:-1])[-1]
-        csv_filename = '-'.join(csv_filename.split('-')[:-1]) + '.csv'
+        csv_filename = "-".join(csv_filename.split("-")[:-1]) + ".csv"
         # embed()
 
         # csv_filename = [f for f in os.listdir(
@@ -43,31 +43,39 @@ class Behavior:
         # logger.info(f'CSV file: {csv_filename}')
         self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
 
-        self.chirps = np.load(os.path.join(
-            folder_path, 'chirps.npy'), allow_pickle=True)
-        self.chirps_ids = np.load(os.path.join(
-            folder_path, 'chirp_ids.npy'), allow_pickle=True)
+        self.chirps = np.load(
+            os.path.join(folder_path, "chirps.npy"), allow_pickle=True
+        )
+        self.chirps_ids = np.load(
+            os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True
+        )
 
-        self.ident = np.load(os.path.join(
-            folder_path, 'ident_v.npy'), allow_pickle=True)
-        self.idx = np.load(os.path.join(
-            folder_path, 'idx_v.npy'), allow_pickle=True)
-        self.freq = np.load(os.path.join(
-            folder_path, 'fund_v.npy'), allow_pickle=True)
-        self.time = np.load(os.path.join(
-            folder_path, "times.npy"), allow_pickle=True)
-        self.spec = np.load(os.path.join(
-            folder_path, "spec.npy"), allow_pickle=True)
+        self.ident = np.load(
+            os.path.join(folder_path, "ident_v.npy"), allow_pickle=True
+        )
+        self.idx = np.load(
+            os.path.join(folder_path, "idx_v.npy"), allow_pickle=True
+        )
+        self.freq = np.load(
+            os.path.join(folder_path, "fund_v.npy"), allow_pickle=True
+        )
+        self.time = np.load(
+            os.path.join(folder_path, "times.npy"), allow_pickle=True
+        )
+        self.spec = np.load(
+            os.path.join(folder_path, "spec.npy"), allow_pickle=True
+        )
 
         for k, key in enumerate(self.dataframe.keys()):
             key = key.lower()
-            if ' ' in key:
-                key = key.replace(' ', '_')
-                if '(' in key:
-                    key = key.replace('(', '')
-                    key = key.replace(')', '')
-            setattr(self, key, np.array(
-                self.dataframe[self.dataframe.keys()[k]]))
+            if " " in key:
+                key = key.replace(" ", "_")
+                if "(" in key:
+                    key = key.replace("(", "")
+                    key = key.replace(")", "")
+            setattr(
+                self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
+            )
 
         last_LED_t_BORIS = LED_on_time_BORIS[-1]
         real_time_range = self.time[-1] - self.time[0]
@@ -78,22 +86,19 @@ class Behavior:
 
 
 def correct_chasing_events(
-    category: np.ndarray,
-    timestamps: np.ndarray
+    category: np.ndarray, timestamps: np.ndarray
 ) -> tuple[np.ndarray, np.ndarray]:
+    onset_ids = np.arange(len(category))[category == 0]
+    offset_ids = np.arange(len(category))[category == 1]
 
-    onset_ids = np.arange(
-        len(category))[category == 0]
-    offset_ids = np.arange(
-        len(category))[category == 1]
-
-    wrong_bh = np.arange(len(category))[
-        category != 2][:-1][np.diff(category[category != 2]) == 0]
+    wrong_bh = np.arange(len(category))[category != 2][:-1][
+        np.diff(category[category != 2]) == 0
+    ]
 
     if category[category != 2][-1] == 0:
         wrong_bh = np.append(
-            wrong_bh,
-            np.arange(len(category))[category != 2][-1])
+            wrong_bh, np.arange(len(category))[category != 2][-1]
+        )
 
     if onset_ids[0] > offset_ids[0]:
         offset_ids = np.delete(offset_ids, 0)
@@ -103,18 +108,16 @@ def correct_chasing_events(
     category = np.delete(category, wrong_bh)
     timestamps = np.delete(timestamps, wrong_bh)
 
-    new_onset_ids = np.arange(
-        len(category))[category == 0]
-    new_offset_ids = np.arange(
-        len(category))[category == 1]
+    new_onset_ids = np.arange(len(category))[category == 0]
+    new_offset_ids = np.arange(len(category))[category == 1]
 
     # Check whether on- or offset is longer and calculate length difference
 
     if len(new_onset_ids) > len(new_offset_ids):
         embed()
-        logger.warning('Onsets are greater than offsets')
+        logger.warning("Onsets are greater than offsets")
     elif len(new_onset_ids) < len(new_offset_ids):
-        logger.warning('Offsets are greater than onsets')
+        logger.warning("Offsets are greater than onsets")
     elif len(new_onset_ids) == len(new_offset_ids):
         # logger.info('Chasing events are equal')
         pass
@@ -130,13 +133,11 @@ def center_chirps(
     # dt: float,
     # width: float,
 ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-
-    event_chirps = []   # chirps that are in specified window around event
+    event_chirps = []  # chirps that are in specified window around event
     # timestamps of chirps around event centered on the event timepoint
     centered_chirps = []
 
     for event_timestamp in events:
-
         start = event_timestamp - time_before_event
         stop = event_timestamp + time_after_event
         chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)]
@@ -152,7 +153,8 @@ def center_chirps(
 
     if len(centered_chirps) != len(event_chirps):
         raise ValueError(
-            'Non centered chirps and centered chirps are not equal')
+            "Non centered chirps and centered chirps are not equal"
+        )
 
     # time = np.arange(-time_before_event, time_after_event, dt)
 
diff --git a/code/modules/datahandling.py b/code/modules/datahandling.py
index 0a240ab..68e73cd 100644
--- a/code/modules/datahandling.py
+++ b/code/modules/datahandling.py
@@ -23,7 +23,9 @@ def minmaxnorm(data):
     return (data - np.min(data)) / (np.max(data) - np.min(data))
 
 
-def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str = 'linear') -> np.ndarray:
+def instantaneous_frequency2(
+    signal: np.ndarray, fs: float, interpolation: str = "linear"
+) -> np.ndarray:
     """
     Compute the instantaneous frequency of a periodic signal using zero crossings and resample the frequency using linear
     or cubic interpolation to match the dimensions of the input array.
@@ -55,10 +57,10 @@ def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str =
     orig_len = len(signal)
     freq = resample(freq, orig_len)
 
-    if interpolation == 'linear':
+    if interpolation == "linear":
         freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq)
-    elif interpolation == 'cubic':
-        freq = resample(freq, orig_len, window='cubic')
+    elif interpolation == "cubic":
+        freq = resample(freq, orig_len, window="cubic")
 
     return freq
 
@@ -67,7 +69,7 @@ def instantaneous_frequency(
     signal: np.ndarray,
     samplerate: int,
     smoothing_window: int,
-    interpolation: str = 'linear',
+    interpolation: str = "linear",
 ) -> np.ndarray:
     """
     Compute the instantaneous frequency of a signal that is approximately
@@ -120,11 +122,10 @@ def instantaneous_frequency(
     orig_len = len(signal)
     freq = resample(instantaneous_frequency, orig_len)
 
-    if interpolation == 'linear':
+    if interpolation == "linear":
         freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq)
-    elif interpolation == 'cubic':
-        freq = resample(freq, orig_len, window='cubic')
-
+    elif interpolation == "cubic":
+        freq = resample(freq, orig_len, window="cubic")
 
     return freq
 
@@ -160,7 +161,6 @@ def purge_duplicates(
     group = [timestamps[0]]
 
     for i in range(1, len(timestamps)):
-
         # check the difference between current timestamp and previous
         # timestamp is less than the threshold
         if timestamps[i] - timestamps[i - 1] < threshold:
@@ -379,7 +379,6 @@ def acausal_kde1d(spikes, time, width):
 
 
 if __name__ == "__main__":
-
     timestamps = [
         [1.2, 1.5, 1.3],
         [],
diff --git a/code/modules/filehandling.py b/code/modules/filehandling.py
index c3c71f2..382a49d 100644
--- a/code/modules/filehandling.py
+++ b/code/modules/filehandling.py
@@ -35,7 +35,6 @@ class LoadData:
     """
 
     def __init__(self, datapath: str) -> None:
-
         # load raw data
         self.datapath = datapath
         self.file = os.path.join(datapath, "traces-grid1.raw")
diff --git a/code/modules/filters.py b/code/modules/filters.py
index e6d9896..06fe236 100644
--- a/code/modules/filters.py
+++ b/code/modules/filters.py
@@ -3,10 +3,10 @@ import numpy as np
 
 
 def bandpass_filter(
-        signal: np.ndarray,
-        samplerate: float,
-        lowf: float,
-        highf: float,
+    signal: np.ndarray,
+    samplerate: float,
+    lowf: float,
+    highf: float,
 ) -> np.ndarray:
     """Bandpass filter a signal.
 
@@ -60,9 +60,7 @@ def highpass_filter(
 
 
 def lowpass_filter(
-    signal: np.ndarray,
-    samplerate: float,
-    cutoff: float
+    signal: np.ndarray, samplerate: float, cutoff: float
 ) -> np.ndarray:
     """Lowpass filter a signal.
 
@@ -86,10 +84,9 @@ def lowpass_filter(
     return filtered_signal
 
 
-def envelope(signal: np.ndarray,
-             samplerate: float,
-             cutoff_frequency: float
-             ) -> np.ndarray:
+def envelope(
+    signal: np.ndarray, samplerate: float, cutoff_frequency: float
+) -> np.ndarray:
     """Calculate the envelope of a signal using a lowpass filter.
 
     Parameters
diff --git a/code/modules/logger.py b/code/modules/logger.py
index 5dabf80..ed6d93e 100644
--- a/code/modules/logger.py
+++ b/code/modules/logger.py
@@ -2,12 +2,13 @@ import logging
 
 
 def makeLogger(name: str):
-
     # create logger formats for file and terminal
     file_formatter = logging.Formatter(
-        "[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s")
+        "[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s"
+    )
     console_formatter = logging.Formatter(
-        "[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s")
+        "[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s"
+    )
 
     # create logging file if loglevel is debug
     file_handler = logging.FileHandler(f"gridtools_log.log", mode="w")
@@ -29,7 +30,6 @@ def makeLogger(name: str):
 
 
 if __name__ == "__main__":
-
     # initiate logger
     mylogger = makeLogger(__name__)
 
diff --git a/code/modules/plotstyle.py b/code/modules/plotstyle.py
index 22b14c6..43d12ac 100644
--- a/code/modules/plotstyle.py
+++ b/code/modules/plotstyle.py
@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
 
 def PlotStyle() -> None:
     class style:
-
         # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
 
         # units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
                 va="center",
                 zorder=1000,
                 bbox=dict(
-                    boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
+                    boxstyle=f"circle, pad={padding}",
+                    fc="white",
+                    ec="black",
+                    lw=1,
                 ),
             )
 
         @classmethod
         def fade_cmap(cls, cmap):
-
             my_cmap = cmap(np.arange(cmap.N))
             my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
             my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
 
 
 if __name__ == "__main__":
-
     s = PlotStyle()
 
     import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
     for ax in axs:
         ax.yaxis.grid(True)
         ax.set_xticks(
-            [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
+            [y + 1 for y in range(len(all_data))],
+            labels=["x1", "x2", "x3", "x4"],
         )
         ax.set_xlabel("Four separate samples")
         ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
     grid = np.random.rand(4, 4)
 
     fig, axs = plt.subplots(
-        nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
+        nrows=3,
+        ncols=6,
+        figsize=(9, 6),
+        subplot_kw={"xticks": [], "yticks": []},
     )
 
     for ax, interp_method in zip(axs.flat, methods):
diff --git a/code/modules/plotstyle1.py b/code/modules/plotstyle1.py
index 32af4d2..237996b 100644
--- a/code/modules/plotstyle1.py
+++ b/code/modules/plotstyle1.py
@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
 
 def PlotStyle() -> None:
     class style:
-
         # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
 
         # units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
                 va="center",
                 zorder=1000,
                 bbox=dict(
-                    boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
+                    boxstyle=f"circle, pad={padding}",
+                    fc="white",
+                    ec="black",
+                    lw=1,
                 ),
             )
 
         @classmethod
         def fade_cmap(cls, cmap):
-
             my_cmap = cmap(np.arange(cmap.N))
             my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
             my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
 
 
 if __name__ == "__main__":
-
     s = PlotStyle()
 
     import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
     for ax in axs:
         ax.yaxis.grid(True)
         ax.set_xticks(
-            [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
+            [y + 1 for y in range(len(all_data))],
+            labels=["x1", "x2", "x3", "x4"],
         )
         ax.set_xlabel("Four separate samples")
         ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
     grid = np.random.rand(4, 4)
 
     fig, axs = plt.subplots(
-        nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
+        nrows=3,
+        ncols=6,
+        figsize=(9, 6),
+        subplot_kw={"xticks": [], "yticks": []},
     )
 
     for ax, interp_method in zip(axs.flat, methods):
diff --git a/code/modules/plotstyle_dark.py b/code/modules/plotstyle_dark.py
index d5b9557..d767e24 100644
--- a/code/modules/plotstyle_dark.py
+++ b/code/modules/plotstyle_dark.py
@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
 
 def PlotStyle() -> None:
     class style:
-
         # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
 
         # units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
                 va="center",
                 zorder=1000,
                 bbox=dict(
-                    boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
+                    boxstyle=f"circle, pad={padding}",
+                    fc="white",
+                    ec="black",
+                    lw=1,
                 ),
             )
 
         @classmethod
         def fade_cmap(cls, cmap):
-
             my_cmap = cmap(np.arange(cmap.N))
             my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
             my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
 
 
 if __name__ == "__main__":
-
     s = PlotStyle()
 
     import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
     for ax in axs:
         ax.yaxis.grid(True)
         ax.set_xticks(
-            [y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
+            [y + 1 for y in range(len(all_data))],
+            labels=["x1", "x2", "x3", "x4"],
         )
         ax.set_xlabel("Four separate samples")
         ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
     grid = np.random.rand(4, 4)
 
     fig, axs = plt.subplots(
-        nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
+        nrows=3,
+        ncols=6,
+        figsize=(9, 6),
+        subplot_kw={"xticks": [], "yticks": []},
     )
 
     for ax, interp_method in zip(axs.flat, methods):
diff --git a/code/modules/simulations.py b/code/modules/simulations.py
index 473bac8..a074801 100644
--- a/code/modules/simulations.py
+++ b/code/modules/simulations.py
@@ -37,7 +37,7 @@ def create_chirp(
 
     ck = 0
     csig = 0.5 * chirpduration / np.power(2.0 * np.log(10.0), 0.5 / kurtosis)
-    #csig = csig*-1
+    # csig = csig*-1
     for k, t in enumerate(time):
         a = 1.0
         f = eodf
diff --git a/code/plot_chirp_size.py b/code/plot_chirp_size.py
index 95b2a95..1153ff5 100644
--- a/code/plot_chirp_size.py
+++ b/code/plot_chirp_size.py
@@ -16,26 +16,25 @@ logger = makeLogger(__name__)
 
 
 def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
-
-    foldername = folder_name.split('/')[-2]
-    winner_row = order_meta_df[order_meta_df['recording'] == foldername]
-    winner = winner_row['winner'].values[0].astype(int)
-    winner_fish1 = winner_row['fish1'].values[0].astype(int)
-    winner_fish2 = winner_row['fish2'].values[0].astype(int)
+    foldername = folder_name.split("/")[-2]
+    winner_row = order_meta_df[order_meta_df["recording"] == foldername]
+    winner = winner_row["winner"].values[0].astype(int)
+    winner_fish1 = winner_row["fish1"].values[0].astype(int)
+    winner_fish2 = winner_row["fish2"].values[0].astype(int)
 
     if winner > 0:
         if winner == winner_fish1:
-            winner_fish_id = winner_row['rec_id1'].values[0]
-            loser_fish_id = winner_row['rec_id2'].values[0]
+            winner_fish_id = winner_row["rec_id1"].values[0]
+            loser_fish_id = winner_row["rec_id2"].values[0]
 
         elif winner == winner_fish2:
-            winner_fish_id = winner_row['rec_id2'].values[0]
-            loser_fish_id = winner_row['rec_id1'].values[0]
+            winner_fish_id = winner_row["rec_id2"].values[0]
+            loser_fish_id = winner_row["rec_id1"].values[0]
 
         chirp_winner = len(
-            Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
-        chirp_loser = len(
-            Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
+            Behavior.chirps[Behavior.chirps_ids == winner_fish_id]
+        )
+        chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
 
         return chirp_winner, chirp_loser
     else:
@@ -43,24 +42,24 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
 
 
 def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
+    foldername = folder_name.split("/")[-2]
+    folder_row = order_meta_df[order_meta_df["recording"] == foldername]
+    fish1 = folder_row["fish1"].values[0].astype(int)
+    fish2 = folder_row["fish2"].values[0].astype(int)
+    winner = folder_row["winner"].values[0].astype(int)
 
-    foldername = folder_name.split('/')[-2]
-    folder_row = order_meta_df[order_meta_df['recording'] == foldername]
-    fish1 = folder_row['fish1'].values[0].astype(int)
-    fish2 = folder_row['fish2'].values[0].astype(int)
-    winner = folder_row['winner'].values[0].astype(int)
+    groub = folder_row["group"].values[0].astype(int)
+    size_fish1_row = id_meta_df[
+        (id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish1)
+    ]
+    size_fish2_row = id_meta_df[
+        (id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish2)
+    ]
 
-    groub = folder_row['group'].values[0].astype(int)
-    size_fish1_row = id_meta_df[(id_meta_df['group'] == groub) & (
-        id_meta_df['fish'] == fish1)]
-    size_fish2_row = id_meta_df[(id_meta_df['group'] == groub) & (
-        id_meta_df['fish'] == fish2)]
-
-    size_winners = [size_fish1_row[col].values[0]
-                    for col in ['l1', 'l2', 'l3']]
+    size_winners = [size_fish1_row[col].values[0] for col in ["l1", "l2", "l3"]]
     size_fish1 = np.nanmean(size_winners)
 
-    size_losers = [size_fish2_row[col].values[0] for col in ['l1', 'l2', 'l3']]
+    size_losers = [size_fish2_row[col].values[0] for col in ["l1", "l2", "l3"]]
     size_fish2 = np.nanmean(size_losers)
 
     if winner == fish1:
@@ -75,8 +74,8 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
             size_diff_bigger = 0
             size_diff_smaller = 0
 
-        winner_fish_id = folder_row['rec_id1'].values[0]
-        loser_fish_id = folder_row['rec_id2'].values[0]
+        winner_fish_id = folder_row["rec_id1"].values[0]
+        loser_fish_id = folder_row["rec_id2"].values[0]
 
     elif winner == fish2:
         if size_fish2 > size_fish1:
@@ -90,39 +89,39 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
             size_diff_bigger = 0
             size_diff_smaller = 0
 
-        winner_fish_id = folder_row['rec_id2'].values[0]
-        loser_fish_id = folder_row['rec_id1'].values[0]
+        winner_fish_id = folder_row["rec_id2"].values[0]
+        loser_fish_id = folder_row["rec_id1"].values[0]
     else:
         size_diff_bigger = np.nan
         size_diff_smaller = np.nan
         winner_fish_id = np.nan
         loser_fish_id = np.nan
 
-        return size_diff_bigger, size_diff_smaller, winner_fish_id, loser_fish_id
+        return (
+            size_diff_bigger,
+            size_diff_smaller,
+            winner_fish_id,
+            loser_fish_id,
+        )
 
-    chirp_winner = len(
-        Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
-    chirp_loser = len(
-        Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
+    chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
+    chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
 
-    return size_diff_bigger, chirp_winner,  size_diff_smaller, chirp_loser
+    return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser
 
 
 def get_chirp_freq(folder_name, Behavior, order_meta_df):
+    foldername = folder_name.split("/")[-2]
+    folder_row = order_meta_df[order_meta_df["recording"] == foldername]
+    fish1 = folder_row["fish1"].values[0].astype(int)
+    fish2 = folder_row["fish2"].values[0].astype(int)
 
-    foldername = folder_name.split('/')[-2]
-    folder_row = order_meta_df[order_meta_df['recording'] == foldername]
-    fish1 = folder_row['fish1'].values[0].astype(int)
-    fish2 = folder_row['fish2'].values[0].astype(int)
+    fish1_freq = folder_row["rec_id1"].values[0].astype(int)
+    fish2_freq = folder_row["rec_id2"].values[0].astype(int)
 
-    fish1_freq = folder_row['rec_id1'].values[0].astype(int)
-    fish2_freq = folder_row['rec_id2'].values[0].astype(int)
-
-    chirp_freq_fish1 = np.nanmedian(
-        Behavior.freq[Behavior.ident == fish1_freq])
-    chirp_freq_fish2 = np.nanmedian(
-        Behavior.freq[Behavior.ident == fish2_freq])
-    winner = folder_row['winner'].values[0].astype(int)
+    chirp_freq_fish1 = np.nanmedian(Behavior.freq[Behavior.ident == fish1_freq])
+    chirp_freq_fish2 = np.nanmedian(Behavior.freq[Behavior.ident == fish2_freq])
+    winner = folder_row["winner"].values[0].astype(int)
 
     if winner == fish1:
         # if chirp_freq_fish1 > chirp_freq_fish2:
@@ -138,9 +137,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
         #     winner_fish_id = np.nan
         #     loser_fish_id = np.nan
 
-        winner_fish_id = folder_row['rec_id1'].values[0]
+        winner_fish_id = folder_row["rec_id1"].values[0]
         winner_fish_freq = chirp_freq_fish1
-        loser_fish_id = folder_row['rec_id2'].values[0]
+        loser_fish_id = folder_row["rec_id2"].values[0]
         loser_fish_freq = chirp_freq_fish2
 
     elif winner == fish2:
@@ -157,9 +156,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
         #     winner_fish_id = np.nan
         #     loser_fish_id = np.nan
 
-        winner_fish_id = folder_row['rec_id2'].values[0]
+        winner_fish_id = folder_row["rec_id2"].values[0]
         winner_fish_freq = chirp_freq_fish2
-        loser_fish_id = folder_row['rec_id1'].values[0]
+        loser_fish_id = folder_row["rec_id1"].values[0]
         loser_fish_freq = chirp_freq_fish1
     else:
         winner_fish_freq = np.nan
@@ -168,25 +167,25 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
         loser_fish_id = np.nan
         return winner_fish_freq, winner_fish_id, loser_fish_freq, loser_fish_id
 
-    chirp_winner = len(
-        Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
-    chirp_loser = len(
-        Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
+    chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
+    chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
 
     return winner_fish_freq, chirp_winner, loser_fish_freq, chirp_loser
 
 
 def main(datapath: str):
-
     foldernames = [
-        datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)]
+        datapath + x + "/"
+        for x in os.listdir(datapath)
+        if os.path.isdir(datapath + x)
+    ]
     foldernames, _ = get_valid_datasets(datapath)
-    path_order_meta = (
-        '/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv'
+    path_order_meta = ("/").join(
+        foldernames[0].split("/")[:-2]
+    ) + "/order_meta.csv"
     order_meta_df = read_csv(path_order_meta)
-    order_meta_df['recording'] = order_meta_df['recording'].str[1:-1]
-    path_id_meta = (
-        '/').join(foldernames[0].split('/')[:-2]) + '/id_meta.csv'
+    order_meta_df["recording"] = order_meta_df["recording"].str[1:-1]
+    path_id_meta = ("/").join(foldernames[0].split("/")[:-2]) + "/id_meta.csv"
     id_meta_df = read_csv(path_id_meta)
 
     chirps_winner = []
@@ -202,10 +201,9 @@ def main(datapath: str):
     freq_chirps_winner = []
     freq_chirps_loser = []
 
-
     for foldername in foldernames:
         # behabvior is pandas dataframe with all the data
-        if foldername == '../data/mount_data/2020-05-12-10_00/':
+        if foldername == "../data/mount_data/2020-05-12-10_00/":
             continue
         bh = Behavior(foldername)
         # chirps are not sorted in time (presumably due to prior groupings)
@@ -217,15 +215,24 @@ def main(datapath: str):
         category, timestamps = correct_chasing_events(category, timestamps)
 
         winner_chirp, loser_chirp = get_chirp_winner_loser(
-            foldername,  bh, order_meta_df)
+            foldername, bh, order_meta_df
+        )
         chirps_winner.append(winner_chirp)
         chirps_loser.append(loser_chirp)
 
-        size_diff_bigger, chirp_winner,  size_diff_smaller, chirp_loser = get_chirp_size(
-            foldername, bh, order_meta_df, id_meta_df)
+        (
+            size_diff_bigger,
+            chirp_winner,
+            size_diff_smaller,
+            chirp_loser,
+        ) = get_chirp_size(foldername, bh, order_meta_df, id_meta_df)
 
-        freq_winner, chirp_freq_winner, freq_loser, chirp_freq_loser = get_chirp_freq(
-            foldername, bh, order_meta_df)
+        (
+            freq_winner,
+            chirp_freq_winner,
+            freq_loser,
+            chirp_freq_loser,
+        ) = get_chirp_freq(foldername, bh, order_meta_df)
 
         freq_diffs_higher.append(freq_winner)
         freq_diffs_lower.append(freq_loser)
@@ -242,82 +249,124 @@ def main(datapath: str):
     pearsonr(size_diffs_winner, size_chirps_winner)
     pearsonr(size_diffs_loser, size_chirps_loser)
 
-    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(
-        21*ps.cm, 7*ps.cm), width_ratios=[1, 0.8, 0.8], sharey=True)
-    plt.subplots_adjust(left=0.11, right=0.948, top=0.86,
-                        wspace=0.343, bottom=0.198)
+    fig, (ax1, ax2, ax3) = plt.subplots(
+        1,
+        3,
+        figsize=(21 * ps.cm, 7 * ps.cm),
+        width_ratios=[1, 0.8, 0.8],
+        sharey=True,
+    )
+    plt.subplots_adjust(
+        left=0.11, right=0.948, top=0.86, wspace=0.343, bottom=0.198
+    )
     scatterwinner = 1.15
     scatterloser = 1.85
     chirps_winner = np.asarray(chirps_winner)[~np.isnan(chirps_winner)]
     chirps_loser = np.asarray(chirps_loser)[~np.isnan(chirps_loser)]
     embed()
     exit()
-    freq_diffs_higher = np.asarray(
-        freq_diffs_higher)[~np.isnan(freq_diffs_higher)]
-    freq_diffs_lower = np.asarray(freq_diffs_lower)[
-        ~np.isnan(freq_diffs_lower)]
-    freq_chirps_winner = np.asarray(
-        freq_chirps_winner)[~np.isnan(freq_chirps_winner)]
-    freq_chirps_loser = np.asarray(
-        freq_chirps_loser)[~np.isnan(freq_chirps_loser)]
+    freq_diffs_higher = np.asarray(freq_diffs_higher)[
+        ~np.isnan(freq_diffs_higher)
+    ]
+    freq_diffs_lower = np.asarray(freq_diffs_lower)[~np.isnan(freq_diffs_lower)]
+    freq_chirps_winner = np.asarray(freq_chirps_winner)[
+        ~np.isnan(freq_chirps_winner)
+    ]
+    freq_chirps_loser = np.asarray(freq_chirps_loser)[
+        ~np.isnan(freq_chirps_loser)
+    ]
 
     stat = wilcoxon(chirps_winner, chirps_loser)
     print(stat)
     winner_color = ps.gblue2
     loser_color = ps.gblue1
 
-    bplot1 = ax1.boxplot(chirps_winner, positions=[
-        0.9], showfliers=False, patch_artist=True)
+    bplot1 = ax1.boxplot(
+        chirps_winner, positions=[0.9], showfliers=False, patch_artist=True
+    )
 
-    bplot2 = ax1.boxplot(chirps_loser,  positions=[
-        2.1], showfliers=False, patch_artist=True)
+    bplot2 = ax1.boxplot(
+        chirps_loser, positions=[2.1], showfliers=False, patch_artist=True
+    )
 
-    ax1.scatter(np.ones(len(chirps_winner)) *
-                scatterwinner, chirps_winner, color=winner_color)
-    ax1.scatter(np.ones(len(chirps_loser)) *
-                scatterloser, chirps_loser, color=loser_color)
-    ax1.set_xticklabels(['Winner', 'Loser'])
+    ax1.scatter(
+        np.ones(len(chirps_winner)) * scatterwinner,
+        chirps_winner,
+        color=winner_color,
+    )
+    ax1.scatter(
+        np.ones(len(chirps_loser)) * scatterloser,
+        chirps_loser,
+        color=loser_color,
+    )
+    ax1.set_xticklabels(["Winner", "Loser"])
 
-    ax1.text(0.1, 0.85, f'n={len(chirps_loser)}',
-             transform=ax1.transAxes, color=ps.white)
+    ax1.text(
+        0.1,
+        0.85,
+        f"n={len(chirps_loser)}",
+        transform=ax1.transAxes,
+        color=ps.white,
+    )
 
     for w, l in zip(chirps_winner, chirps_loser):
-        ax1.plot([scatterwinner, scatterloser], [w, l],
-                 color=ps.white, alpha=0.6, linewidth=1, zorder=-1)
-    ax1.set_ylabel('Chirp counts', color=ps.white)
-    ax1.set_xlabel('Competition outcome',    color=ps.white)
+        ax1.plot(
+            [scatterwinner, scatterloser],
+            [w, l],
+            color=ps.white,
+            alpha=0.6,
+            linewidth=1,
+            zorder=-1,
+        )
+    ax1.set_ylabel("Chirp counts", color=ps.white)
+    ax1.set_xlabel("Competition outcome", color=ps.white)
 
     ps.set_boxplot_color(bplot1, winner_color)
     ps.set_boxplot_color(bplot2, loser_color)
 
-    ax2.scatter(size_diffs_winner, size_chirps_winner,
-                color=winner_color, label='Winner')
-    ax2.scatter(size_diffs_loser, size_chirps_loser,
-                color=loser_color, label='Loser')
+    ax2.scatter(
+        size_diffs_winner,
+        size_chirps_winner,
+        color=winner_color,
+        label="Winner",
+    )
+    ax2.scatter(
+        size_diffs_loser, size_chirps_loser, color=loser_color, label="Loser"
+    )
 
-    ax2.text(0.05, 0.85, f'n={len(size_chirps_loser)}',
-             transform=ax2.transAxes, color=ps.white)
+    ax2.text(
+        0.05,
+        0.85,
+        f"n={len(size_chirps_loser)}",
+        transform=ax2.transAxes,
+        color=ps.white,
+    )
 
-    ax2.set_xlabel('Size difference [cm]')
+    ax2.set_xlabel("Size difference [cm]")
     # ax2.set_xticks(np.arange(-10, 10.1, 2))
     ax3.scatter(freq_diffs_higher, freq_chirps_winner, color=winner_color)
     ax3.scatter(freq_diffs_lower, freq_chirps_loser, color=loser_color)
 
-    ax3.text(0.1, 0.85, f'n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}',
-             transform=ax3.transAxes, color=ps.white)
+    ax3.text(
+        0.1,
+        0.85,
+        f"n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}",
+        transform=ax3.transAxes,
+        color=ps.white,
+    )
 
-    ax3.set_xlabel('EODf [Hz]')
+    ax3.set_xlabel("EODf [Hz]")
     handles, labels = ax2.get_legend_handles_labels()
-    fig.legend(handles, labels, loc='upper center',
-               ncol=2, bbox_to_anchor=(0.5, 1.04))
+    fig.legend(
+        handles, labels, loc="upper center", ncol=2, bbox_to_anchor=(0.5, 1.04)
+    )
     # pearson r
-    plt.savefig('../poster/figs/chirps_winner_loser.pdf')
+    plt.savefig("../poster/figs/chirps_winner_loser.pdf")
     plt.show()
 
 
-if __name__ == '__main__':
-
+if __name__ == "__main__":
     # Path to the data
-    datapath = '../data/mount_data/'
+    datapath = "../data/mount_data/"
 
     main(datapath)
diff --git a/code/plot_chirps_in_chasing.py b/code/plot_chirps_in_chasing.py
index ee43196..ef0e5a7 100644
--- a/code/plot_chirps_in_chasing.py
+++ b/code/plot_chirps_in_chasing.py
@@ -21,14 +21,16 @@ logger = makeLogger(__name__)
 
 
 def main(datapath: str):
-
     foldernames = [
-        datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)]
+        datapath + x + "/"
+        for x in os.listdir(datapath)
+        if os.path.isdir(datapath + x)
+    ]
     time_precents = []
     chirps_percents = []
     for foldername in foldernames:
         # behabvior is pandas dataframe with all the data
-        if foldername == '../data/mount_data/2020-05-12-10_00/':
+        if foldername == "../data/mount_data/2020-05-12-10_00/":
             continue
         bh = Behavior(foldername)
 
@@ -46,50 +48,70 @@ def main(datapath: str):
         chirps_in_chasings = []
         for onset, offset in zip(chasing_onset, chasing_offset):
             chirps_in_chasing = [
-                c for c in bh.chirps if (c > onset) & (c < offset)]
+                c for c in bh.chirps if (c > onset) & (c < offset)
+            ]
             chirps_in_chasings.append(chirps_in_chasing)
 
         try:
             time_chasing = np.sum(
-                chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60])
+                chasing_offset[chasing_offset < 3 * 60 * 60]
+                - chasing_onset[chasing_onset < 3 * 60 * 60]
+            )
         except:
             time_chasing = np.sum(
-                chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60][:-1])
+                chasing_offset[chasing_offset < 3 * 60 * 60]
+                - chasing_onset[chasing_onset < 3 * 60 * 60][:-1]
+            )
 
-        time_chasing_percent = (time_chasing/(3*60*60))*100
+        time_chasing_percent = (time_chasing / (3 * 60 * 60)) * 100
         chirps_chasing = np.asarray(flatten(chirps_in_chasings))
-        chirps_chasing_new = chirps_chasing[chirps_chasing < 3*60*60]
-        chirps_percent = (len(chirps_chasing_new) /
-                          len(bh.chirps[bh.chirps < 3*60*60]))*100
+        chirps_chasing_new = chirps_chasing[chirps_chasing < 3 * 60 * 60]
+        chirps_percent = (
+            len(chirps_chasing_new) / len(bh.chirps[bh.chirps < 3 * 60 * 60])
+        ) * 100
 
         time_precents.append(time_chasing_percent)
         chirps_percents.append(chirps_percent)
 
-    fig, ax = plt.subplots(1, 1, figsize=(7*ps.cm, 7*ps.cm))
+    fig, ax = plt.subplots(1, 1, figsize=(7 * ps.cm, 7 * ps.cm))
     scatter_time = 1.20
     scatter_chirps = 1.80
     size = 10
-    bplot1 = ax.boxplot([time_precents, chirps_percents],
-                        showfliers=False, patch_artist=True)
+    bplot1 = ax.boxplot(
+        [time_precents, chirps_percents], showfliers=False, patch_artist=True
+    )
     ps.set_boxplot_color(bplot1, ps.gray)
-    ax.set_xticklabels(['Time \nchasing', 'Chirps \nin chasing'])
-    ax.set_ylabel('Percent')
-    ax.scatter(np.ones(len(time_precents))*scatter_time, time_precents,
-               facecolor=ps.white, s=size)
-    ax.scatter(np.ones(len(chirps_percents))*scatter_chirps, chirps_percents,
-               facecolor=ps.white, s=size)
+    ax.set_xticklabels(["Time \nchasing", "Chirps \nin chasing"])
+    ax.set_ylabel("Percent")
+    ax.scatter(
+        np.ones(len(time_precents)) * scatter_time,
+        time_precents,
+        facecolor=ps.white,
+        s=size,
+    )
+    ax.scatter(
+        np.ones(len(chirps_percents)) * scatter_chirps,
+        chirps_percents,
+        facecolor=ps.white,
+        s=size,
+    )
 
     for i in range(len(time_precents)):
-        ax.plot([scatter_time, scatter_chirps], [time_precents[i],
-                chirps_percents[i]], alpha=0.6, linewidth=1, color=ps.white)
+        ax.plot(
+            [scatter_time, scatter_chirps],
+            [time_precents[i], chirps_percents[i]],
+            alpha=0.6,
+            linewidth=1,
+            color=ps.white,
+        )
 
-    ax.text(0.1, 0.9, f'n={len(time_precents)}', transform=ax.transAxes)
+    ax.text(0.1, 0.9, f"n={len(time_precents)}", transform=ax.transAxes)
     plt.subplots_adjust(left=0.221, bottom=0.186, right=0.97, top=0.967)
-    plt.savefig('../poster/figs/chirps_in_chasing.pdf')
+    plt.savefig("../poster/figs/chirps_in_chasing.pdf")
     plt.show()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Path to the data
-    datapath = '../data/mount_data/'
+    datapath = "../data/mount_data/"
     main(datapath)
diff --git a/code/plot_event_timeline.py b/code/plot_event_timeline.py
index cb75cd9..ab408ee 100644
--- a/code/plot_event_timeline.py
+++ b/code/plot_event_timeline.py
@@ -13,6 +13,7 @@ from modules.plotstyle import PlotStyle
 from modules.behaviour_handling import Behavior, correct_chasing_events
 
 from extract_chirps import get_valid_datasets
+
 ps = PlotStyle()
 
 logger = makeLogger(__name__)
@@ -20,13 +21,16 @@ logger = makeLogger(__name__)
 
 def main(datapath: str):
     foldernames = [
-        datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)]
+        datapath + x + "/"
+        for x in os.listdir(datapath)
+        if os.path.isdir(datapath + x)
+    ]
     foldernames, _ = get_valid_datasets(datapath)
 
     for foldername in foldernames[3:4]:
         print(foldername)
         # foldername = foldernames[0]
-        if foldername == '../data/mount_data/2020-05-12-10_00/':
+        if foldername == "../data/mount_data/2020-05-12-10_00/":
             continue
         # behabvior is pandas dataframe with all the data
         bh = Behavior(foldername)
@@ -52,18 +56,43 @@ def main(datapath: str):
         exit()
         fish1_color = ps.gblue2
         fish2_color = ps.gblue1
-        fig, ax = plt.subplots(5, 1, figsize=(
-            21*ps.cm, 10*ps.cm), height_ratios=[0.5, 0.5, 0.5, 0.2, 6], sharex=True)
+        fig, ax = plt.subplots(
+            5,
+            1,
+            figsize=(21 * ps.cm, 10 * ps.cm),
+            height_ratios=[0.5, 0.5, 0.5, 0.2, 6],
+            sharex=True,
+        )
         # marker size
         s = 80
-        ax[0].scatter(physical_contact, np.ones(
-            len(physical_contact)), color=ps.gray, marker='|', s=s)
-        ax[1].scatter(chasing_onset, np.ones(len(chasing_onset)),
-                      color=ps.gray, marker='|', s=s)
-        ax[2].scatter(fish1, np.ones(len(fish1))-0.25,
-                      color=fish1_color, marker='|', s=s)
-        ax[2].scatter(fish2, np.zeros(len(fish2))+0.25,
-                      color=fish2_color, marker='|', s=s)
+        ax[0].scatter(
+            physical_contact,
+            np.ones(len(physical_contact)),
+            color=ps.gray,
+            marker="|",
+            s=s,
+        )
+        ax[1].scatter(
+            chasing_onset,
+            np.ones(len(chasing_onset)),
+            color=ps.gray,
+            marker="|",
+            s=s,
+        )
+        ax[2].scatter(
+            fish1,
+            np.ones(len(fish1)) - 0.25,
+            color=fish1_color,
+            marker="|",
+            s=s,
+        )
+        ax[2].scatter(
+            fish2,
+            np.zeros(len(fish2)) + 0.25,
+            color=fish2_color,
+            marker="|",
+            s=s,
+        )
 
         freq_temp = bh.freq[bh.ident == fish1_id]
         time_temp = bh.time[bh.idx[bh.ident == fish1_id]]
@@ -94,35 +123,38 @@ def main(datapath: str):
         ax[2].set_xticks([])
         ps.hide_ax(ax[2])
 
-        ax[4].axvspan(0, 3, 0, 5, facecolor='grey', alpha=0.5)
+        ax[4].axvspan(0, 3, 0, 5, facecolor="grey", alpha=0.5)
         ax[4].set_xticks(np.arange(0, 6.1, 0.5))
         ps.hide_ax(ax[3])
 
         labelpad = 30
         fsize = 12
 
-        ax[0].set_ylabel('Contact', rotation=0,
-                         labelpad=labelpad, fontsize=fsize)
+        ax[0].set_ylabel(
+            "Contact", rotation=0, labelpad=labelpad, fontsize=fsize
+        )
         ax[0].yaxis.set_label_coords(-0.062, -0.08)
-        ax[1].set_ylabel('Chasing', rotation=0,
-                         labelpad=labelpad, fontsize=fsize)
+        ax[1].set_ylabel(
+            "Chasing", rotation=0, labelpad=labelpad, fontsize=fsize
+        )
         ax[1].yaxis.set_label_coords(-0.06, -0.08)
-        ax[2].set_ylabel('Chirps', rotation=0,
-                         labelpad=labelpad, fontsize=fsize)
+        ax[2].set_ylabel(
+            "Chirps", rotation=0, labelpad=labelpad, fontsize=fsize
+        )
         ax[2].yaxis.set_label_coords(-0.07, -0.08)
-        ax[4].set_ylabel('EODf')
+        ax[4].set_ylabel("EODf")
 
-        ax[4].set_xlabel('Time [h]')
+        ax[4].set_xlabel("Time [h]")
         # ax[0].set_title(foldername.split('/')[-2])
         # 2020-03-31-9_59
         plt.subplots_adjust(left=0.158, right=0.987, top=0.918, bottom=0.136)
-        plt.savefig('../poster/figs/timeline.svg')
+        plt.savefig("../poster/figs/timeline.svg")
         plt.show()
 
     # plot chirps
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Path to the data
-    datapath = '../data/mount_data/'
+    datapath = "../data/mount_data/"
     main(datapath)
diff --git a/code/plot_introduction_specs.py b/code/plot_introduction_specs.py
index d7e6f4a..0c8e2b4 100644
--- a/code/plot_introduction_specs.py
+++ b/code/plot_introduction_specs.py
@@ -11,7 +11,6 @@ ps = PlotStyle()
 
 
 def main():
-
     # Load data
     datapath = "../data/2022-06-02-10_00/"
     data = LoadData(datapath)
@@ -24,26 +23,31 @@ def main():
 
     timescaler = 1000
 
-    raw = data.raw[window_start_index:window_start_index +
-                   window_duration_index, 10]
+    raw = data.raw[
+        window_start_index : window_start_index + window_duration_index, 10
+    ]
 
     fig, (ax1, ax2) = plt.subplots(
-        1, 2, figsize=(21 * ps.cm, 8*ps.cm), sharex=True, sharey=True)
+        1, 2, figsize=(21 * ps.cm, 8 * ps.cm), sharex=True, sharey=True
+    )
 
     # plot instantaneous frequency
     filtered1 = bandpass_filter(
-        signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate)
+        signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate
+    )
     filtered2 = bandpass_filter(
-        signal=raw, lowf=550, highf=700, samplerate=data.raw_rate)
+        signal=raw, lowf=550, highf=700, samplerate=data.raw_rate
+    )
 
     freqtime1, freq1 = instantaneous_frequency(
-        filtered1, data.raw_rate, smoothing_window=3)
+        filtered1, data.raw_rate, smoothing_window=3
+    )
     freqtime2, freq2 = instantaneous_frequency(
-        filtered2, data.raw_rate, smoothing_window=3)
+        filtered2, data.raw_rate, smoothing_window=3
+    )
 
-    ax1.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="Fish 1") 
-    ax1.plot(freqtime2*timescaler, freq2, color=ps.gray,
-            lw=2, label="Fish 2")
+    ax1.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="Fish 1")
+    ax1.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="Fish 2")
     # ax.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
     # # ps.hide_xax(ax1)
 
@@ -62,8 +66,8 @@ def main():
     ax1.imshow(
         decibel(spec_power[fmask, :]),
         extent=[
-            spec_times[0]*timescaler,
-            spec_times[-1]*timescaler,
+            spec_times[0] * timescaler,
+            spec_times[-1] * timescaler,
             spec_freqs[fmask][0],
             spec_freqs[fmask][-1],
         ],
@@ -87,8 +91,8 @@ def main():
     ax2.imshow(
         decibel(spec_power[fmask, :]),
         extent=[
-            spec_times[0]*timescaler,
-            spec_times[-1]*timescaler,
+            spec_times[0] * timescaler,
+            spec_times[-1] * timescaler,
             spec_freqs[fmask][0],
             spec_freqs[fmask][-1],
         ],
@@ -98,9 +102,8 @@ def main():
         alpha=1,
     )
     # ps.hide_xax(ax3)
-    ax2.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="_") 
-    ax2.plot(freqtime2*timescaler, freq2, color=ps.gray,
-            lw=2, label="_")
+    ax2.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="_")
+    ax2.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="_")
 
     ax2.set_xlim(75, 200)
     ax1.set_ylim(400, 1200)
@@ -109,15 +112,22 @@ def main():
     fig.supylabel("Frequency [Hz]", fontsize=14)
 
     handles, labels = ax1.get_legend_handles_labels()
-    ax2.legend(handles, labels, bbox_to_anchor=(1.04, 1), loc="upper left", ncol=1,)
+    ax2.legend(
+        handles,
+        labels,
+        bbox_to_anchor=(1.04, 1),
+        loc="upper left",
+        ncol=1,
+    )
 
     ps.letter_subplots(xoffset=[-0.27, -0.1], yoffset=1.05)
 
-    plt.subplots_adjust(left=0.12, right=0.85, top=0.89,
-                        bottom=0.18, hspace=0.35)
+    plt.subplots_adjust(
+        left=0.12, right=0.85, top=0.89, bottom=0.18, hspace=0.35
+    )
 
-    plt.savefig('../poster/figs/introplot.pdf')
+    plt.savefig("../poster/figs/introplot.pdf")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/code/plot_kdes.py b/code/plot_kdes.py
index 0f3082b..bc1bc98 100644
--- a/code/plot_kdes.py
+++ b/code/plot_kdes.py
@@ -1,7 +1,9 @@
-
 from modules.plotstyle import PlotStyle
 from modules.behaviour_handling import (
-    Behavior, correct_chasing_events, center_chirps)
+    Behavior,
+    correct_chasing_events,
+    center_chirps,
+)
 from modules.datahandling import flatten, causal_kde1d, acausal_kde1d
 from modules.logger import makeLogger
 from pandas import read_csv
@@ -18,80 +20,93 @@ logger = makeLogger(__name__)
 ps = PlotStyle()
 
 
-def bootstrap(data, nresamples, kde_time, kernel_width, event_times, time_before, time_after):
-
+def bootstrap(
+    data,
+    nresamples,
+    kde_time,
+    kernel_width,
+    event_times,
+    time_before,
+    time_after,
+):
     bootstrapped_kdes = []
-    data = data[data <= 3*60*60]  # only night time
+    data = data[data <= 3 * 60 * 60]  # only night time
 
     diff_data = np.diff(np.sort(data), prepend=0)
     # if len(data) != 0:
     #     mean_chirprate = (len(data) - 1) / (data[-1] - data[0])
 
     for i in tqdm(range(nresamples)):
-
         np.random.shuffle(diff_data)
 
         bootstrapped_data = np.cumsum(diff_data)
         # bootstrapped_data = data + np.random.randn(len(data)) * 10
 
         bootstrap_data_centered = center_chirps(
-            bootstrapped_data, event_times, time_before, time_after)
+            bootstrapped_data, event_times, time_before, time_after
+        )
 
         bootstrapped_kde = acausal_kde1d(
-            bootstrap_data_centered, time=kde_time, width=kernel_width)
+            bootstrap_data_centered, time=kde_time, width=kernel_width
+        )
 
-        bootstrapped_kde = list(np.asarray(
-            bootstrapped_kde) / len(event_times))
+        bootstrapped_kde = list(np.asarray(bootstrapped_kde) / len(event_times))
 
         bootstrapped_kdes.append(bootstrapped_kde)
 
     return bootstrapped_kdes
 
 
-def jackknife(data, nresamples, subsetsize, kde_time, kernel_width, event_times, time_before, time_after):
-
+def jackknife(
+    data,
+    nresamples,
+    subsetsize,
+    kde_time,
+    kernel_width,
+    event_times,
+    time_before,
+    time_after,
+):
     jackknife_kdes = []
-    data = data[data <= 3*60*60]  # only night time
+    data = data[data <= 3 * 60 * 60]  # only night time
     subsetsize = int(len(data) * subsetsize)
 
     diff_data = np.diff(np.sort(data), prepend=0)
 
     for i in tqdm(range(nresamples)):
-
-        jackknifed_data = np.random.choice(
-            diff_data, subsetsize, replace=False)
+        jackknifed_data = np.random.choice(diff_data, subsetsize, replace=False)
 
         jackknifed_data = np.cumsum(jackknifed_data)
 
         jackknifed_data_centered = center_chirps(
-            jackknifed_data, event_times, time_before, time_after)
+            jackknifed_data, event_times, time_before, time_after
+        )
 
         jackknifed_kde = acausal_kde1d(
-            jackknifed_data_centered, time=kde_time, width=kernel_width)
+            jackknifed_data_centered, time=kde_time, width=kernel_width
+        )
 
-        jackknifed_kde = list(np.asarray(
-            jackknifed_kde) / len(event_times))
+        jackknifed_kde = list(np.asarray(jackknifed_kde) / len(event_times))
 
         jackknife_kdes.append(jackknifed_kde)
     return jackknife_kdes
 
 
 def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
-
-    foldername = folder_name.split('/')[-2]
-    winner_row = order_meta_df[order_meta_df['recording'] == foldername]
-    winner = winner_row['winner'].values[0].astype(int)
-    winner_fish1 = winner_row['fish1'].values[0].astype(int)
-    winner_fish2 = winner_row['fish2'].values[0].astype(int)
+    foldername = folder_name.split("/")[-2]
+    winner_row = order_meta_df[order_meta_df["recording"] == foldername]
+    winner = winner_row["winner"].values[0].astype(int)
+    winner_fish1 = winner_row["fish1"].values[0].astype(int)
+    winner_fish2 = winner_row["fish2"].values[0].astype(int)
 
     if winner > 0:
         if winner == winner_fish1:
-            winner_fish_id = winner_row['rec_id1'].values[0]
-            loser_fish_id = winner_row['rec_id2'].values[0]
+            winner_fish_id = winner_row["rec_id1"].values[0]
+            loser_fish_id = winner_row["rec_id2"].values[0]
 
         elif winner == winner_fish2:
-            winner_fish_id = winner_row['rec_id2'].values[0]
-            loser_fish_id = winner_row['rec_id1'].values[0]
+            winner_fish_id = winner_row["rec_id2"].values[0]
+            loser_fish_id = winner_row["rec_id1"].values[0]
 
         chirp_winner = Behavior.chirps[Behavior.chirps_ids == winner_fish_id]
         chirp_loser = Behavior.chirps[Behavior.chirps_ids == loser_fish_id]
@@ -101,7 +116,6 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
 
 
 def main(dataroot):
-
     foldernames, _ = np.asarray(get_valid_datasets(dataroot))
     plot_all = True
     time_before = 90
@@ -111,10 +125,9 @@ def main(dataroot):
     kde_time = np.arange(-time_before, time_after, dt)
     nbootstraps = 50
 
-    meta_path = (
-        '/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv'
+    meta_path = ("/").join(foldernames[0].split("/")[:-2]) + "/order_meta.csv"
     meta = pd.read_csv(meta_path)
-    meta['recording'] = meta['recording'].str[1:-1]
+    meta["recording"] = meta["recording"].str[1:-1]
 
     winner_onsets = []
     winner_offsets = []
@@ -143,24 +156,24 @@ def main(dataroot):
     # loser_onset_chirpcount = 0
     # loser_offset_chirpcount = 0
     # loser_physical_chirpcount = 0
-    fig, ax = plt.subplots(1, 2, figsize=(
-        14 * ps.cm, 7*ps.cm), sharey=True, sharex=True)
+    fig, ax = plt.subplots(
+        1, 2, figsize=(14 * ps.cm, 7 * ps.cm), sharey=True, sharex=True
+    )
     # Iterate over all recordings and save chirp- and event-timestamps
     good_recs = np.asarray([0, 15])
     for i, folder in tqdm(enumerate(foldernames[good_recs])):
-
-        foldername = folder.split('/')[-2]
+        foldername = folder.split("/")[-2]
         # logger.info('Loading data from folder: {}'.format(foldername))
 
-        broken_folders = ['../data/mount_data/2020-05-12-10_00/']
+        broken_folders = ["../data/mount_data/2020-05-12-10_00/"]
         if folder in broken_folders:
             continue
 
         bh = Behavior(folder)
         category, timestamps = correct_chasing_events(bh.behavior, bh.start_s)
 
-        category = category[timestamps < 3*60*60]  # only night time
-        timestamps = timestamps[timestamps < 3*60*60]  # only night time
+        category = category[timestamps < 3 * 60 * 60]  # only night time
+        timestamps = timestamps[timestamps < 3 * 60 * 60]  # only night time
 
         winner, loser = get_chirp_winner_loser(folder, bh, meta)
         if winner is None:
@@ -168,27 +181,33 @@ def main(dataroot):
         # winner_count += len(winner)
         # loser_count += len(loser)
 
-        onsets = (timestamps[category == 0])
-        offsets = (timestamps[category == 1])
-        physicals = (timestamps[category == 2])
+        onsets = timestamps[category == 0]
+        offsets = timestamps[category == 1]
+        physicals = timestamps[category == 2]
 
         onset_count += len(onsets)
         offset_count += len(offsets)
         physical_count += len(physicals)
 
-        winner_onsets.append(center_chirps(
-            winner, onsets, time_before, time_after))
-        winner_offsets.append(center_chirps(
-            winner, offsets, time_before, time_after))
-        winner_physicals.append(center_chirps(
-            winner, physicals, time_before, time_after))
+        winner_onsets.append(
+            center_chirps(winner, onsets, time_before, time_after)
+        )
+        winner_offsets.append(
+            center_chirps(winner, offsets, time_before, time_after)
+        )
+        winner_physicals.append(
+            center_chirps(winner, physicals, time_before, time_after)
+        )
 
-        loser_onsets.append(center_chirps(
-            loser, onsets, time_before, time_after))
-        loser_offsets.append(center_chirps(
-            loser, offsets, time_before, time_after))
-        loser_physicals.append(center_chirps(
-            loser, physicals, time_before, time_after))
+        loser_onsets.append(
+            center_chirps(loser, onsets, time_before, time_after)
+        )
+        loser_offsets.append(
+            center_chirps(loser, offsets, time_before, time_after)
+        )
+        loser_physicals.append(
+            center_chirps(loser, physicals, time_before, time_after)
+        )
 
         # winner_onset_chirpcount += len(winner_onsets[-1])
         # winner_offset_chirpcount += len(winner_offsets[-1])
@@ -232,14 +251,17 @@ def main(dataroot):
         #     event_times=onsets,
         #     time_before=time_before,
         #     time_after=time_after))
-        loser_offsets_boot.append(bootstrap(
-            loser,
-            nresamples=nbootstraps,
-            kde_time=kde_time,
-            kernel_width=kernel_width,
-            event_times=offsets,
-            time_before=time_before,
-            time_after=time_after))
+        loser_offsets_boot.append(
+            bootstrap(
+                loser,
+                nresamples=nbootstraps,
+                kde_time=kde_time,
+                kernel_width=kernel_width,
+                event_times=offsets,
+                time_before=time_before,
+                time_after=time_after,
+            )
+        )
         # loser_physicals_boot.append(bootstrap(
         #     loser,
         #     nresamples=nbootstraps,
@@ -249,18 +271,17 @@ def main(dataroot):
         #     time_before=time_before,
         #     time_after=time_after))
 
-#         loser_offsets_jackknife = jackknife(
-#             loser,
-#             nresamples=nbootstraps,
-#             subsetsize=0.9,
-#             kde_time=kde_time,
-#             kernel_width=kernel_width,
-#             event_times=offsets,
-#             time_before=time_before,
-#             time_after=time_after)
+        #         loser_offsets_jackknife = jackknife(
+        #             loser,
+        #             nresamples=nbootstraps,
+        #             subsetsize=0.9,
+        #             kde_time=kde_time,
+        #             kernel_width=kernel_width,
+        #             event_times=offsets,
+        #             time_before=time_before,
+        #             time_after=time_after)
 
         if plot_all:
-
             # winner_onsets_conv = acausal_kde1d(
             #     winner_onsets[-1], kde_time, kernel_width)
             # winner_offsets_conv = acausal_kde1d(
@@ -271,24 +292,35 @@ def main(dataroot):
             # loser_onsets_conv = acausal_kde1d(
             #     loser_onsets[-1], kde_time, kernel_width)
             loser_offsets_conv = acausal_kde1d(
-                loser_offsets[-1], kde_time, kernel_width)
+                loser_offsets[-1], kde_time, kernel_width
+            )
             # loser_physicals_conv = acausal_kde1d(
             #     loser_physicals[-1], kde_time, kernel_width)
 
-            ax[i].plot(kde_time, loser_offsets_conv /
-                       len(offsets), lw=2, zorder=100, c=ps.gblue1)
+            ax[i].plot(
+                kde_time,
+                loser_offsets_conv / len(offsets),
+                lw=2,
+                zorder=100,
+                c=ps.gblue1,
+            )
 
             ax[i].fill_between(
                 kde_time,
                 np.percentile(loser_offsets_boot[-1], 1, axis=0),
                 np.percentile(loser_offsets_boot[-1], 99, axis=0),
-                color='gray',
-                alpha=0.8)
+                color="gray",
+                alpha=0.8,
+            )
 
-            ax[i].plot(kde_time, np.median(loser_offsets_boot[-1], axis=0),
-                       color=ps.black, linewidth=2)
+            ax[i].plot(
+                kde_time,
+                np.median(loser_offsets_boot[-1], axis=0),
+                color=ps.black,
+                linewidth=2,
+            )
 
-            ax[i].axvline(0, color=ps.gray, linestyle='--')
+            ax[i].axvline(0, color=ps.gray, linestyle="--")
 
             # ax[i].fill_between(
             #     kde_time,
@@ -300,8 +332,8 @@ def main(dataroot):
             #            color=ps.white, linewidth=2)
 
             ax[i].set_xlim(-60, 60)
-            fig.supylabel('Chirp rate (a.u.)', fontsize=14)
-            fig.supxlabel('Time (s)', fontsize=14)
+            fig.supylabel("Chirp rate (a.u.)", fontsize=14)
+            fig.supxlabel("Time (s)", fontsize=14)
 
             # fig, ax = plt.subplots(2, 3, figsize=(
             #     21*ps.cm, 10*ps.cm), sharey=True, sharex=True)
@@ -521,9 +553,9 @@ def main(dataroot):
     #                       color=ps.gray,
     #                       alpha=0.5)
     plt.subplots_adjust(bottom=0.21, top=0.93)
-    plt.savefig('../poster/figs/kde.pdf')
+    plt.savefig("../poster/figs/kde.pdf")
     plt.show()
 
 
-if __name__ == '__main__':
-    main('../data/mount_data/')
+if __name__ == "__main__":
+    main("../data/mount_data/")