Merge branch 'master' of github.com:weygoldt/grid-chirpdetection

This commit is contained in:
weygoldt 2023-04-14 21:02:11 +02:00
commit c9fe3bd912
No known key found for this signature in database
28 changed files with 2372 additions and 758 deletions

View File

@ -37,40 +37,12 @@
## About The Project ## About The Project
[![Product Name Screen Shot][product-screenshot]](https://example.com) Chirps are transient communication singals of many wave-type electric fish. Because they are so fast, detecting them when the recorded signal includes multiple individuals is hard. But to understand if, and what kind of information they transmit in a natural setting, analyzing chirps in multiple freely interacting individual is nessecary. This repository documents an approach to detect these signals on electrode grid recordings with many freely behaving individuals.
Here's a blank template to get started: To avoid retyping too much info. Do a search and replace with your text editor for the following: `github_username`, `repo_name`, `twitter_handle`, `linkedin_username`, `email_client`, `email`, `project_title`, `project_description` The majority of the code and its tests were part of a lab rotation with the [Neuroethology](https://github.com/bendalab) at the University of Tuebingen. It also contains a [poster](poster_printed/main.pdf) and a more thorough [lab protocol](protocol/main.pdf).
<p align="right">(<a href="#readme-top">back to top</a>)</p> <p align="right">(<a href="#readme-top">back to top</a>)</p>
## Getting Started
This is an example of how you may give instructions on setting up your project locally.
To get a local copy up and running follow these simple example steps.
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- USAGE EXAMPLES -->
## Usage
Use this space to show useful examples of how a project can be used. Additional screenshots, code examples and demos work well in this space. You may also link to more resources.
_For more examples, please refer to the [Documentation](https://example.com)_
<p align="right">(<a href="#readme-top">back to top</a>)</p>
## To do
- [ ] Feature 1
- [ ] Feature 2
- [ ] Feature 3
- [ ] Nested Feature
<p align="right">(<a href="#readme-top">back to top</a>)</p>
<!-- # Chirp detection - GP2023 --> <!-- # Chirp detection - GP2023 -->
<!-- ## Git-Repository and commands --> <!-- ## Git-Repository and commands -->

View File

@ -0,0 +1,389 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Why can the instantaneous frequency of a band-pass filtered chirp recording go down ...\n",
"... if a chirp is an up-modulation of the frequency? \n",
"\n",
"This is the question we try to answer in this notebook"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"QApplication: invalid style override passed, ignoring it.\n",
" Available styles: Windows, Fusion\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import thunderfish.fakefish as ff \n",
"from filters import instantaneous_frequency, bandpass_filter\n",
"%matplotlib qt\n",
"\n",
"# parameters that stay the same\n",
"samplerate = 20000\n",
"duration = 0.2\n",
"chirp_freq = 5\n",
"smooth = 3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"qt.qpa.wayland: Wayland does not support QWindow::requestActivate()\n"
]
}
],
"source": [
"def make_chirp(eodf, size, width, kurtosis, contrast, phase0):\n",
"\n",
" chirp_trace, amp = ff.chirps(\n",
" eodf = eodf,\n",
" samplerate = samplerate,\n",
" duration = duration,\n",
" chirp_freq = chirp_freq,\n",
" chirp_size = size,\n",
" chirp_width = width,\n",
" chirp_kurtosis = kurtosis,\n",
" chirp_contrast = contrast,\n",
" )\n",
"\n",
" chirp = ff.wavefish_eods(\n",
" fish = 'Alepto',\n",
" frequency = chirp_trace,\n",
" samplerate = samplerate,\n",
" duration = duration,\n",
" phase0 = phase0,\n",
" noise_std = 0,\n",
" )\n",
"\n",
" chirp *= amp\n",
"\n",
" return chirp_trace, chirp\n",
"\n",
"def filtered_chirp(eodf, size, width, kurtosis, contrast, phase0):\n",
"\n",
" time = np.arange(0, duration, 1/samplerate)\n",
" chirp_trace, chirp = make_chirp(\n",
" eodf = eodf, \n",
" size = size, \n",
" width = width, \n",
" kurtosis = kurtosis, \n",
" contrast = contrast, \n",
" phase0 = phase0,\n",
" )\n",
"\n",
" # apply filters\n",
" narrow_filtered = bandpass_filter(chirp, samplerate, eodf-10, eodf+10)\n",
" narrow_freqtime, narrow_freq = instantaneous_frequency(narrow_filtered, samplerate, smooth)\n",
" broad_filtered = bandpass_filter(chirp, samplerate, eodf-300, eodf+300)\n",
" broad_freqtime, broad_freq = instantaneous_frequency(broad_filtered, samplerate, smooth)\n",
"\n",
" original = (time, chirp_trace, chirp)\n",
" broad = (broad_freqtime, broad_freq, broad_filtered)\n",
" narrow = (narrow_freqtime, narrow_freq, narrow_filtered)\n",
"\n",
" return original, broad, narrow\n",
"\n",
"def plot(original, broad, narrow, axs):\n",
"\n",
" axs[0].plot(original[0], original[1], label='chirp trace')\n",
" axs[0].plot(broad[0], broad[1], label='broad filtered')\n",
" axs[0].plot(narrow[0], narrow[1], label='narrow filtered')\n",
" axs[1].plot(original[0], original[2], label='unfiltered')\n",
" axs[1].plot(original[0], broad[2], label='broad filtered')\n",
" axs[1].plot(original[0], narrow[2], label='narrow filtered')\n",
"\n",
"original, broad, narrow = filtered_chirp(600, 100, 0.02, 1, 0.1, 0)\n",
"fig, axs = plt.subplots(2, 1, figsize=(10, 5), sharex=True)\n",
"plot(original, broad, narrow, axs)\n",
"fig.align_labels()\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chirp size\n",
"now that we have established an easy way to simulate and plot the chirps, lets change the chirp size and see how the narrow-filtered instantaneous frequency changes."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"size 10 Hz; Integral 0.117\n",
"size 30 Hz; Integral 0.35\n",
"size 50 Hz; Integral 0.584\n",
"size 70 Hz; Integral 0.818\n",
"size 90 Hz; Integral 1.051\n",
"size 110 Hz; Integral 1.285\n",
"size 130 Hz; Integral 1.518\n",
"size 150 Hz; Integral 1.752\n",
"size 170 Hz; Integral 1.986\n",
"size 190 Hz; Integral 2.219\n",
"size 210 Hz; Integral 2.453\n",
"size 230 Hz; Integral 2.687\n",
"size 250 Hz; Integral 2.92\n",
"size 270 Hz; Integral 3.154\n",
"size 290 Hz; Integral 3.387\n",
"size 310 Hz; Integral 3.621\n",
"size 330 Hz; Integral 3.855\n",
"size 350 Hz; Integral 4.088\n",
"size 370 Hz; Integral 4.322\n",
"size 390 Hz; Integral 4.555\n",
"size 410 Hz; Integral 4.789\n",
"size 430 Hz; Integral 5.023\n",
"size 450 Hz; Integral 5.256\n",
"size 470 Hz; Integral 5.49\n",
"size 490 Hz; Integral 5.724\n",
"size 510 Hz; Integral 5.957\n",
"size 530 Hz; Integral 6.191\n",
"size 550 Hz; Integral 6.424\n",
"size 570 Hz; Integral 6.658\n",
"size 590 Hz; Integral 6.892\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"qt.qpa.wayland: Wayland does not support QWindow::requestActivate()\n"
]
}
],
"source": [
"sizes = np.arange(10, 600, 20)\n",
"fig, axs = plt.subplots(2, len(sizes), figsize=(10, 5), sharex=True, sharey='row')\n",
"integrals = []\n",
"\n",
"for i, size in enumerate(sizes):\n",
" original, broad, narrow = filtered_chirp(600, size, 0.02, 1, 0.1, 0)\n",
"\n",
" integral = np.sum(original[1]-600)/(20000)\n",
" integrals.append(integral)\n",
"\n",
" plot(original, broad, narrow, axs[:, i])\n",
" axs[:, i][0].set_xlim(0.06, 0.14)\n",
" axs[0, i].set_title(np.round(integral, 3))\n",
" print(f'size {size} Hz; Integral {np.round(integral,3)}')\n",
" \n",
"fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n",
"axs[0,0].set_ylabel('frequency [Hz]')\n",
"axs[1,0].set_ylabel('amplitude [a.u.]')\n",
"fig.supxlabel('time [s]')\n",
"fig.align_labels()\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chirp width"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"widths = np.arange(0.02, 0.08, 0.005)\n",
"fig, axs = plt.subplots(2, len(widths), figsize=(10, 5), sharex=True, sharey='row')\n",
"integrals = []\n",
"\n",
"for i, width in enumerate(widths):\n",
" if i > 9:\n",
" break\n",
"\n",
" original, broad, narrow = filtered_chirp(600, 100, width, 1, 0.1, 0)\n",
"\n",
" integral = np.sum(original[1]-600)/(20000)\n",
"\n",
" plot(original, broad, narrow, axs[:, i])\n",
" axs[:, i][0].set_xlim(0.06, 0.14)\n",
" axs[0, i].set_title(f'width {np.round(width, 2)} s')\n",
" print(f'width {width} s; Integral {np.round(integral, 3)}')\n",
" \n",
"fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n",
"axs[0,0].set_ylabel('frequency [Hz]')\n",
"axs[1,0].set_ylabel('amplitude [a.u.]')\n",
"fig.supxlabel('time [s]')\n",
"fig.align_labels()\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chirp kurtosis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kurtosiss = np.arange(0, 20, 1.6)\n",
"fig, axs = plt.subplots(2, len(kurtosiss), figsize=(10, 5), sharex=True, sharey='row')\n",
"integrals = []\n",
"\n",
"for i, kurtosis in enumerate(kurtosiss):\n",
"\n",
" original, broad, narrow = filtered_chirp(600, 100, 0.02, kurtosis, 0.1, 0)\n",
"\n",
" integral = np.sum(original[1]-600)/(20000)\n",
"\n",
" plot(original, broad, narrow, axs[:, i])\n",
" axs[:, i][0].set_xlim(0.06, 0.14)\n",
" axs[0, i].set_title(f'kurt {np.round(kurtosis, 2)}')\n",
" print(f'kurt {kurtosis}; Integral {np.round(integral, 3)}')\n",
" \n",
"fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n",
"axs[0,0].set_ylabel('frequency [Hz]')\n",
"axs[1,0].set_ylabel('amplitude [a.u.]')\n",
"fig.supxlabel('time [s]')\n",
"fig.align_labels()\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chirp contrast"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"contrasts = np.arange(0.0, 1.1, 0.1)\n",
"fig, axs = plt.subplots(2, len(sizes), figsize=(10, 5), sharex=True, sharey='row')\n",
"integrals = []\n",
"\n",
"for i, contrast in enumerate(contrasts):\n",
" if i > 9:\n",
" break\n",
" original, broad, narrow = filtered_chirp(600, 100, 0.02, 1, contrast, 0)\n",
"\n",
" integral = np.trapz(original[2], original[0])\n",
" integrals.append(integral)\n",
"\n",
" plot(original, broad, narrow, axs[:, i])\n",
" axs[:, i][0].set_xlim(0.06, 0.14)\n",
" axs[0, i].set_title(f'contr {np.round(contrast, 2)}')\n",
" \n",
"fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n",
"axs[0,0].set_ylabel('frequency [Hz]')\n",
"axs[1,0].set_ylabel('amplitude [a.u.]')\n",
"fig.supxlabel('time [s]')\n",
"fig.align_labels()\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chirp phase "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"phases = np.arange(0.0, 2 * np.pi, 0.2)\n",
"fig, axs = plt.subplots(2, len(sizes), figsize=(10, 5), sharex=True, sharey='row')\n",
"integrals = []\n",
"for i, phase in enumerate(phases):\n",
" if i > 9:\n",
" break\n",
"\n",
" original, broad, narrow = filtered_chirp(600, 100, 0.02, 1, 0.1, phase)\n",
"\n",
" integral = np.trapz(original[2], original[0])\n",
" integrals.append(integral)\n",
"\n",
" plot(original, broad, narrow, axs[:, i])\n",
" axs[:, i][0].set_xlim(0.06, 0.14)\n",
" axs[0, i].set_title(f'phase {np.round(phase, 2)}')\n",
"\n",
" \n",
"fig.legend(handles=axs[0,0].get_lines(), loc='upper center', ncol=3)\n",
"axs[0,0].set_ylabel('frequency [Hz]')\n",
"axs[1,0].set_ylabel('amplitude [a.u.]')\n",
"fig.supxlabel('time [s]')\n",
"fig.align_labels()\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"These experiments show, that the narrow filtered instantaneous freuqency only switches its sign, when the integral of the instantaneous frequency (that was used to make the signal)\n",
"changes. Specifically, when the instantaneous frequency is 0.57, 1.57, 2.57 etc., the sign swithes. "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chirpdetection",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,200 @@
from scipy.signal import butter, sosfiltfilt
from scipy.ndimage import gaussian_filter1d
import numpy as np
def instantaneous_frequency(
signal: np.ndarray,
samplerate: int,
smoothing_window: int,
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the instantaneous frequency of a signal that is approximately
sinusoidal and symmetric around 0.
Parameters
----------
signal : np.ndarray
Signal to compute the instantaneous frequency from.
samplerate : int
Samplerate of the signal.
smoothing_window : int
Window size for the gaussian filter.
Returns
-------
tuple[np.ndarray, np.ndarray]
"""
# calculate instantaneous frequency with zero crossings
roll_signal = np.roll(signal, shift=1)
time_signal = np.arange(len(signal)) / samplerate
period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][
1:-1
]
upper_bound = np.abs(signal[period_index])
lower_bound = np.abs(signal[period_index - 1])
upper_time = np.abs(time_signal[period_index])
lower_time = np.abs(time_signal[period_index - 1])
# create ratio
lower_ratio = lower_bound / (lower_bound + upper_bound)
# appy to time delta
time_delta = upper_time - lower_time
true_zero = lower_time + lower_ratio * time_delta
# create new time array
instantaneous_frequency_time = true_zero[:-1] + 0.5 * np.diff(true_zero)
# compute frequency
instantaneous_frequency = gaussian_filter1d(
1 / np.diff(true_zero), smoothing_window
)
return instantaneous_frequency_time, instantaneous_frequency
def inst_freq(signal, fs):
"""
Computes the instantaneous frequency of a periodic signal using zero-crossings.
Parameters:
-----------
signal : array-like
The input signal.
fs : float
The sampling frequency of the input signal.
Returns:
--------
freq : array-like
The instantaneous frequency of the input signal.
"""
# Compute the sign of the signal
sign = np.sign(signal)
# Compute the crossings of the sign signal with a zero line
crossings = np.where(np.diff(sign))[0]
# Compute the time differences between zero crossings
dt = np.diff(crossings) / fs
# Compute the instantaneous frequency as the reciprocal of the time differences
freq = 1 / dt
# Gaussian filter the signal
freq = gaussian_filter1d(freq, 10)
# Pad the frequency vector with zeros to match the length of the input signal
freq = np.pad(freq, (0, len(signal) - len(freq)))
return freq
def bandpass_filter(
signal: np.ndarray,
samplerate: float,
lowf: float,
highf: float,
) -> np.ndarray:
"""Bandpass filter a signal.
Parameters
----------
signal : np.ndarray
The data to be filtered
rate : float
The sampling rate
lowf : float
The low cutoff frequency
highf : float
The high cutoff frequency
Returns
-------
np.ndarray
The filtered data
"""
sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def highpass_filter(
signal: np.ndarray,
samplerate: float,
cutoff: float,
) -> np.ndarray:
"""Highpass filter a signal.
Parameters
----------
signal : np.ndarray
The data to be filtered
rate : float
The sampling rate
cutoff : float
The cutoff frequency
Returns
-------
np.ndarray
The filtered data
"""
sos = butter(2, cutoff, "highpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def lowpass_filter(
signal: np.ndarray, samplerate: float, cutoff: float
) -> np.ndarray:
"""Lowpass filter a signal.
Parameters
----------
data : np.ndarray
The data to be filtered
rate : float
The sampling rate
cutoff : float
The cutoff frequency
Returns
-------
np.ndarray
The filtered data
"""
sos = butter(2, cutoff, "lowpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def envelope(
signal: np.ndarray, samplerate: float, cutoff_frequency: float
) -> np.ndarray:
"""Calculate the envelope of a signal using a lowpass filter.
Parameters
----------
signal : np.ndarray
The signal to calculate the envelope of
samplingrate : float
The sampling rate of the signal
cutoff_frequency : float
The cutoff frequency of the lowpass filter
Returns
-------
np.ndarray
The envelope of the signal
"""
sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos")
envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(signal))
return envelope

View File

@ -0,0 +1,557 @@
import sys
from IPython import embed
import thunderfish.powerspectrum as ps
import numpy as np
species_name = dict(
Sine="Sinewave",
Alepto="Apteronotus leptorhynchus",
Arostratus="Apteronotus rostratus",
Eigenmannia="Eigenmannia spec.",
Sternarchella="Sternarchella terminalis",
Sternopygus="Sternopygus dariensis",
)
"""Translate species ids used by wavefish_harmonics and pulsefish_eodpeaks to full species names.
"""
def abbrv_genus(name):
"""Abbreviate genus in a species name.
Parameters
----------
name: string
Full species name of the form 'Genus species subspecies'
Returns
-------
name: string
The species name with abbreviated genus, i.e. 'G. species subspecies'
"""
ns = name.split()
return ns[0][0] + ". " + " ".join(ns[1:])
# Amplitudes and phases of various wavefish species:
Sine_harmonics = dict(amplitudes=(1.0,), phases=(0.5 * np.pi,))
Apteronotus_leptorhynchus_harmonics = dict(
amplitudes=(0.90062, 0.15311, 0.072049, 0.012609, 0.011708),
phases=(1.3623, 2.3246, 0.9869, 2.6492, -2.6885),
)
Apteronotus_rostratus_harmonics = dict(
amplitudes=(
0.64707,
0.43874,
0.063592,
0.07379,
0.040199,
0.023073,
0.0097678,
),
phases=(2.2988, 0.78876, -1.316, 2.2416, 2.0413, 1.1022, -2.0513),
)
Eigenmannia_harmonics = dict(
amplitudes=(1.0087, 0.23201, 0.060524, 0.020175, 0.010087, 0.0080699),
phases=(1.3414, 1.3228, 2.9242, 2.8157, 2.6871, -2.8415),
)
Sternarchella_terminalis_harmonics = dict(
amplitudes=(
0.11457,
0.4401,
0.41055,
0.20132,
0.061364,
0.011389,
0.0057985,
),
phases=(-2.7106, 2.4472, 1.6829, 0.79085, 0.119, -0.82355, -1.9956),
)
Sternopygus_dariensis_harmonics = dict(
amplitudes=(
0.98843,
0.41228,
0.047848,
0.11048,
0.022801,
0.030706,
0.019018,
),
phases=(1.4153, 1.3141, 3.1062, -2.3961, -1.9524, 0.54321, 1.6844),
)
wavefish_harmonics = dict(
Sine=Sine_harmonics,
Alepto=Apteronotus_leptorhynchus_harmonics,
Arostratus=Apteronotus_rostratus_harmonics,
Eigenmannia=Eigenmannia_harmonics,
Sternarchella=Sternarchella_terminalis_harmonics,
Sternopygus=Sternopygus_dariensis_harmonics,
)
"""Amplitudes and phases of EOD waveforms of various species of wave-type electric fish."""
def wavefish_spectrum(fish):
"""Amplitudes and phases of a wavefish EOD.
Parameters
----------
fish: string, dict or tuple of lists/arrays
Specify relative amplitudes and phases of the fundamental and its harmonics.
If string then take amplitudes and phases from the `wavefish_harmonics` dictionary.
If dictionary then take amplitudes and phases from the 'amlitudes' and 'phases' keys.
If tuple then the first element is the list of amplitudes and
the second one the list of relative phases in radians.
Returns
-------
amplitudes: array of floats
Amplitudes of the fundamental and its harmonics.
phases: array of floats
Phases in radians of the fundamental and its harmonics.
Raises
------
KeyError:
Unknown fish.
IndexError:
Amplitudes and phases differ in length.
"""
if isinstance(fish, (tuple, list)):
amplitudes = fish[0]
phases = fish[1]
elif isinstance(fish, dict):
amplitudes = fish["amplitudes"]
phases = fish["phases"]
else:
if fish not in wavefish_harmonics:
raise KeyError(
"unknown wavefish. Choose one of "
+ ", ".join(wavefish_harmonics.keys())
)
amplitudes = wavefish_harmonics[fish]["amplitudes"]
phases = wavefish_harmonics[fish]["phases"]
if len(amplitudes) != len(phases):
raise IndexError("need exactly as many phases as amplitudes")
# remove NaNs:
for k in reversed(range(len(amplitudes))):
if np.isfinite(amplitudes[k]) or np.isfinite(phases[k]):
amplitudes = amplitudes[: k + 1]
phases = phases[: k + 1]
break
return amplitudes, phases
def wavefish_eods(
fish="Eigenmannia",
frequency=100.0,
samplerate=44100.0,
duration=1.0,
phase0=0.0,
noise_std=0.05,
):
"""Simulate EOD waveform of a wave-type fish.
The waveform is constructed by superimposing sinewaves of integral
multiples of the fundamental frequency - the fundamental and its
harmonics. The fundamental frequency of the EOD (EODf) is given by
`frequency`. With `fish` relative amplitudes and phases of the
fundamental and its harmonics are specified.
The generated waveform is `duration` seconds long and is sampled with
`samplerate` Hertz. Gaussian white noise with a standard deviation of
`noise_std` is added to the generated waveform.
Parameters
----------
fish: string, dict or tuple of lists/arrays
Specify relative amplitudes and phases of the fundamental and its harmonics.
If string then take amplitudes and phases from the `wavefish_harmonics` dictionary.
If dictionary then take amplitudes and phases from the 'amlitudes' and 'phases' keys.
If tuple then the first element is the list of amplitudes and
the second one the list of relative phases in radians.
frequency: float or array of floats
EOD frequency of the fish in Hertz. Either fixed number or array for
time-dependent frequencies.
samplerate: float
Sampling rate in Hertz.
duration: float
Duration of the generated data in seconds. Only used if frequency is scalar.
phase0: float
Phase offset of the EOD waveform in radians.
noise_std: float
Standard deviation of additive Gaussian white noise.
Returns
-------
data: array of floats
Generated data of a wave-type fish.
Raises
------
KeyError:
Unknown fish.
IndexError:
Amplitudes and phases differ in length.
"""
# get relative amplitude and phases:
amplitudes, phases = wavefish_spectrum(fish)
# compute phase:
if np.isscalar(frequency):
phase = np.arange(0, duration, 1.0 / samplerate)
phase *= frequency
else:
phase = np.cumsum(frequency) / samplerate
# generate EOD:
data = np.zeros(len(phase))
for har, (ampl, phi) in enumerate(zip(amplitudes, phases)):
if np.isfinite(ampl) and np.isfinite(phi):
data += ampl * np.sin(
2 * np.pi * (har + 1) * phase + phi - (har + 1) * phase0
)
# add noise:
data += noise_std * np.random.randn(len(data))
return data
def normalize_wavefish(fish, mode="peak"):
"""Normalize amplitudes and phases of wave-type EOD waveform.
The amplitudes and phases of the Fourier components are adjusted
such that the resulting EOD waveform has a peak-to-peak amplitude
of two and the peak of the waveform is at time zero (mode is set
to 'peak') or that the fundamental has an amplitude of one and a
phase of 0 (mode is set to 'zero').
Parameters
----------
fish: string, dict or tuple of lists/arrays
Specify relative amplitudes and phases of the fundamental and its harmonics.
If string then take amplitudes and phases from the `wavefish_harmonics` dictionary.
If dictionary then take amplitudes and phases from the 'amlitudes' and 'phases' keys.
If tuple then the first element is the list of amplitudes and
the second one the list of relative phases in radians.
mode: 'peak' or 'zero'
How to normalize amplitude and phases:
- 'peak': normalize waveform to a peak-to-peak amplitude of two
and shift it such that its peak is at time zero.
- 'zero': scale amplitude of fundamental to one and its phase to zero.
Returns
-------
amplitudes: array of floats
Adjusted amplitudes of the fundamental and its harmonics.
phases: array of floats
Adjusted phases in radians of the fundamental and its harmonics.
"""
# get relative amplitude and phases:
amplitudes, phases = wavefish_spectrum(fish)
if mode == "zero":
newamplitudes = np.array(amplitudes) / amplitudes[0]
newphases = np.array(
[p + (k + 1) * (-phases[0]) for k, p in enumerate(phases)]
)
newphases %= 2.0 * np.pi
newphases[newphases > np.pi] -= 2.0 * np.pi
return newamplitudes, newphases
else:
# generate waveform:
eodf = 100.0
rate = 100000.0
data = wavefish_eods(fish, eodf, rate, 2.0 / eodf, noise_std=0.0)
# normalize amplitudes:
ampl = 0.5 * (np.max(data) - np.min(data))
newamplitudes = np.array(amplitudes) / ampl
# shift phases:
deltat = np.argmax(data[: int(rate / eodf)]) / rate
deltap = 2.0 * np.pi * deltat * eodf
newphases = np.array(
[p + (k + 1) * deltap for k, p in enumerate(phases)]
)
newphases %= 2.0 * np.pi
newphases[newphases > np.pi] -= 2.0 * np.pi
# return:
return newamplitudes, newphases
def export_wavefish(fish, name="Unknown_harmonics", file=None):
"""Serialize wavefish parameter to python code.
Add output to the wavefish_harmonics dictionary!
Parameters
----------
fish: string, dict or tuple of lists/arrays
Specify relative amplitudes and phases of the fundamental and its harmonics.
If string then take amplitudes and phases from the `wavefish_harmonics` dictionary.
If dictionary then take amplitudes and phases from the 'amlitudes' and 'phases' keys.
If tuple then the first element is the list of amplitudes and
the second one the list of relative phases in radians.
name: string
Name of the dictionary to be written.
file: string or file or None
File name or open file object where to write wavefish dictionary.
Returns
-------
fish: dict
Dictionary with amplitudes and phases.
"""
# get relative amplitude and phases:
amplitudes, phases = wavefish_spectrum(fish)
# write out dictionary:
if file is None:
file = sys.stdout
try:
file.write("")
closeit = False
except AttributeError:
file = open(file, "w")
closeit = True
n = 6
file.write(name + " = \\\n")
file.write(" dict(amplitudes=(")
file.write(", ".join(["%.5g" % a for a in amplitudes[:n]]))
for k in range(n, len(amplitudes), n):
file.write(",\n")
file.write(" " * (9 + 12))
file.write(", ".join(["%.5g" % a for a in amplitudes[k : k + n]]))
file.write("),\n")
file.write(" " * 9 + "phases=(")
file.write(", ".join(["%.5g" % p for p in phases[:n]]))
for k in range(n, len(phases), n):
file.write(",\n")
file.write(" " * (9 + 8))
file.write(", ".join(["%.5g" % p for p in phases[k : k + n]]))
file.write("))\n")
if closeit:
file.close()
# return dictionary:
harmonics = dict(amplitudes=amplitudes, phases=phases)
return harmonics
def chirps(
eodf=100.0,
samplerate=44100.0,
duration=1.0,
chirp_times=[0.5],
chirp_size=[100.0],
chirp_width=[0.01],
chirp_kurtosis=[1.0],
chirp_contrast=[0.05],
):
"""Simulate frequency trace with chirps.
A chirp is modeled as a Gaussian frequency modulation.
The first chirp is placed at 0.5/chirp_freq.
Parameters
----------
eodf: float
EOD frequency of the fish in Hertz.
samplerate: float
Sampling rate in Hertz.
duration: float
Duration of the generated data in seconds.
chirp_times: float
Timestamps of every single chirp in seconds.
chirp_size: list
Size of each chirp (maximum frequency increase above eodf) in Hertz.
chirp_width: list
Width of every single chirp at 10% height in seconds.
chirp_kurtosis: list:
Shape of every single chirp. =1: Gaussian, >1: more rectangular, <1: more peaked.
chirp_contrast: float
Maximum amplitude reduction of EOD during every respective chirp.
Returns
-------
frequency: array of floats
Generated frequency trace that can be passed on to wavefish_eods().
amplitude: array of floats
Generated amplitude modulation that can be used to multiply the trace generated by
wavefish_eods().
"""
# baseline eod frequency and amplitude modulation:
n = len(np.arange(0, duration, 1.0 / samplerate))
frequency = eodf * np.ones(n)
am = np.ones(n)
for time, width, size, kurtosis, contrast in zip(
chirp_times, chirp_width, chirp_size, chirp_kurtosis, chirp_contrast
):
# chirp frequency waveform:
chirp_t = np.arange(-2.0 * width, 2.0 * width, 1.0 / samplerate)
chirp_sig = 0.5 * width / (2.0 * np.log(10.0)) ** (0.5 / kurtosis)
gauss = np.exp(-0.5 * ((chirp_t / chirp_sig) ** 2.0) ** kurtosis)
# add chirps on baseline eodf:
index = int(time * samplerate)
i0 = index - len(gauss) // 2
i1 = i0 + len(gauss)
gi0 = 0
gi1 = len(gauss)
if i0 < 0:
gi0 -= i0
i0 = 0
if i1 >= len(frequency):
gi1 -= i1 - len(frequency)
i1 = len(frequency)
frequency[i0:i1] += size * gauss[gi0:gi1]
am[i0:i1] -= contrast * gauss[gi0:gi1]
return frequency, am
def rises(
eodf,
samplerate,
duration,
rise_times,
rise_size,
rise_tau,
decay_tau,
):
"""Simulate frequency trace with rises.
A rise is modeled as a double exponential frequency modulation.
Parameters
----------
eodf: float
EOD frequency of the fish in Hertz.
samplerate: float
Sampling rate in Hertz.
duration: float
Duration of the generated data in seconds.
rise_times: list
Timestamp of each of the rises in seconds.
rise_size: list
Size of the respective rise (frequency increase above eodf) in Hertz.
rise_tau: list
Time constant of the frequency increase of the respective rise in seconds.
decay_tau: list
Time constant of the frequency decay of the respective rise in seconds.
Returns
-------
data: array of floats
Generate frequency trace that can be passed on to wavefish_eods().
"""
n = len(np.arange(0, duration, 1.0 / samplerate))
# baseline eod frequency:
frequency = eodf * np.ones(n)
for time, size, riset, decayt in zip(
rise_times, rise_size, rise_tau, decay_tau
):
# rise frequency waveform:
rise_t = np.arange(0.0, 5.0 * decayt, 1.0 / samplerate)
rise = size * (1.0 - np.exp(-rise_t / riset)) * np.exp(-rise_t / decayt)
# add rises on baseline eodf:
index = int(time * samplerate)
if index + len(rise) > len(frequency):
rise_index = len(frequency) - index
frequency[index : index + rise_index] += rise[:rise_index]
break
else:
frequency[index : index + len(rise)] += rise
return frequency
class FishSignal:
def __init__(self, samplerate, duration, eodf, nchirps, nrises):
time = np.arange(0, duration, 1 / samplerate)
chirp_times = np.random.uniform(0, duration, nchirps)
rise_times = np.random.uniform(0, duration, nrises)
# pick random parameters for chirps
chirp_size = np.random.uniform(60, 200, nchirps)
chirp_width = np.random.uniform(0.01, 0.1, nchirps)
chirp_kurtosis = np.random.uniform(1, 1, nchirps)
chirp_contrast = np.random.uniform(0.1, 0.5, nchirps)
# pick random parameters for rises
rise_size = np.random.uniform(10, 100, nrises)
rise_tau = np.random.uniform(0.5, 1.5, nrises)
decay_tau = np.random.uniform(5, 15, nrises)
# generate frequency trace with chirps
chirp_trace, chirp_amp = chirps(
eodf=0.0,
samplerate=samplerate,
duration=duration,
chirp_times=chirp_times,
chirp_size=chirp_size,
chirp_width=chirp_width,
chirp_kurtosis=chirp_kurtosis,
chirp_contrast=chirp_contrast,
)
# generate frequency trace with rises
rise_trace = rises(
eodf=0.0,
samplerate=samplerate,
duration=duration,
rise_times=rise_times,
rise_size=rise_size,
rise_tau=rise_tau,
decay_tau=decay_tau,
)
# combine traces to one
full_trace = chirp_trace + rise_trace + eodf
# make the EOD from the frequency trace
fish = wavefish_eods(
fish="Alepto",
frequency=full_trace,
samplerate=samplerate,
duration=duration,
phase0=0.0,
noise_std=0.05,
)
signal = fish * chirp_amp
self.signal = signal
self.trace = full_trace
self.time = time
self.samplerate = samplerate
self.eodf = eodf
def visualize(self):
spec, freqs, spectime = ps.spectrogram(
data=self.signal,
ratetime=self.samplerate,
freq_resolution=0.5,
overlap_frac=0.5,
)
fig, (ax1, ax2) = plt.subplots(2, 1, height_ratios=[1, 4], sharex=True)
ax1.plot(self.time, self.signal)
ax1.set_ylabel("Amplitude")
ax1.set_xlabel("Time (s)")
ax1.set_title("EOD signal")
ax2.imshow(
ps.decibel(spec),
origin="lower",
aspect="auto",
extent=[spectime[0], spectime[-1], freqs[0], freqs[-1]],
)
ax2.set_ylabel("Frequency (Hz)")
ax2.set_xlabel("Time (s)")
ax2.set_title("Spectrogram")
ax2.set_ylim(0, 2000)
plt.show()

View File

@ -0,0 +1,150 @@
import matplotlib.pyplot as plt
import numpy as np
from filters import bandpass_filter, inst_freq, instantaneous_frequency
from fish_signal import chirps, wavefish_eods
from IPython import embed
def switch_test(test, defaultparams, testparams):
if test == "width":
defaultparams["chirp_width"] = testparams["chirp_width"]
key = "chirp_width"
elif test == "size":
defaultparams["chirp_size"] = testparams["chirp_size"]
key = "chirp_size"
elif test == "kurtosis":
defaultparams["chirp_kurtosis"] = testparams["chirp_kurtosis"]
key = "chirp_kurtosis"
elif test == "contrast":
defaultparams["chirp_contrast"] = testparams["chirp_contrast"]
key = "chirp_contrast"
else:
raise ValueError("Test not recognized")
return key, defaultparams
def extract_dict(dict, index):
return {key: value[index] for key, value in dict.items()}
def test(test1, test2, resolution=10):
assert test1 in [
"width",
"size",
"kurtosis",
"contrast",
], "Test1 not recognized"
assert test2 in [
"width",
"size",
"kurtosis",
"contrast",
], "Test2 not recognized"
# Define the parameters for the chirp simulations
ntest = resolution
defaultparams = dict(
chirp_size=np.ones(ntest) * 100,
chirp_width=np.ones(ntest) * 0.1,
chirp_kurtosis=np.ones(ntest) * 1.0,
chirp_contrast=np.ones(ntest) * 0.5,
)
testparams = dict(
chirp_width=np.linspace(0.01, 0.2, ntest),
chirp_size=np.linspace(50, 300, ntest),
chirp_kurtosis=np.linspace(0.5, 1.5, ntest),
chirp_contrast=np.linspace(0.01, 1.0, ntest),
)
key1, chirp_params = switch_test(test1, defaultparams, testparams)
key2, chirp_params = switch_test(test2, chirp_params, testparams)
# make the chirp trace
eodf = 500
samplerate = 20000
duration = 2
chirp_times = [0.5, 1, 1.5]
wide_cutoffs = 200
tight_cutoffs = 10
distances = np.full((ntest, ntest), np.nan)
fig, axs = plt.subplots(
ntest, ntest, figsize=(10, 10), sharex=True, sharey=True
)
axs = axs.flatten()
iter0 = 0
for iter1, test1_param in enumerate(chirp_params[key1]):
for iter2, test2_param in enumerate(chirp_params[key2]):
# get the chirp parameters for the current test
inner_chirp_params = extract_dict(chirp_params, iter2)
inner_chirp_params[key1] = test1_param
inner_chirp_params[key2] = test2_param
# make the chirp trace for the current chirp parameters
sizes = np.ones(len(chirp_times)) * inner_chirp_params["chirp_size"]
widths = (
np.ones(len(chirp_times)) * inner_chirp_params["chirp_width"]
)
kurtosis = (
np.ones(len(chirp_times)) * inner_chirp_params["chirp_kurtosis"]
)
contrast = (
np.ones(len(chirp_times)) * inner_chirp_params["chirp_contrast"]
)
# make the chirp trace
chirp_trace, ampmod = chirps(
eodf,
samplerate,
duration,
chirp_times,
sizes,
widths,
kurtosis,
contrast,
)
signal = wavefish_eods(
fish="Alepto",
frequency=chirp_trace,
samplerate=samplerate,
duration=duration,
phase0=0.0,
noise_std=0.05,
)
signal = signal * ampmod
# apply broadband filter
wide_signal = bandpass_filter(
signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs
)
tight_signal = bandpass_filter(
signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs
)
# get the instantaneous frequency
wide_frequency = inst_freq(wide_signal, samplerate)
tight_frequency = inst_freq(tight_signal, samplerate)
bool_mask = wide_frequency != 0
axs[iter0].plot(wide_frequency[bool_mask])
axs[iter0].plot(tight_frequency[bool_mask])
fig.supylabel(key1)
fig.supxlabel(key2)
iter0 += 1
plt.show()
def main():
test("contrast", "kurtosis")
if __name__ == "__main__":
main()

View File

@ -10,73 +10,84 @@ from modules.filters import bandpass_filter
def main(folder): def main(folder):
file = os.path.join(folder, 'traces-grid.raw') file = os.path.join(folder, "traces-grid.raw")
data = open_data(folder, 60.0, 0, channel=-1) data = open_data(folder, 60.0, 0, channel=-1)
time = np.load(folder + 'times.npy', allow_pickle=True) time = np.load(folder + "times.npy", allow_pickle=True)
freq = np.load(folder + 'fund_v.npy', allow_pickle=True) freq = np.load(folder + "fund_v.npy", allow_pickle=True)
ident = np.load(folder + 'ident_v.npy', allow_pickle=True) ident = np.load(folder + "ident_v.npy", allow_pickle=True)
idx = np.load(folder + 'idx_v.npy', allow_pickle=True) idx = np.load(folder + "idx_v.npy", allow_pickle=True)
t0 = 3*60*60 + 6*60 + 43.5 t0 = 3 * 60 * 60 + 6 * 60 + 43.5
dt = 60 dt = 60
data_oi = data[t0 * data.samplerate: (t0+dt)*data.samplerate, :] data_oi = data[t0 * data.samplerate : (t0 + dt) * data.samplerate, :]
for i in [10]: for i in [10]:
# getting the spectogramm # getting the spectogramm
spec_power, spec_freqs, spec_times = spectrogram( spec_power, spec_freqs, spec_times = spectrogram(
data_oi[:, i], ratetime=data.samplerate, freq_resolution=50, overlap_frac=0.0) data_oi[:, i],
fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54)) ratetime=data.samplerate,
ax.pcolormesh(spec_times, spec_freqs, decibel( freq_resolution=50,
spec_power), vmin=-100, vmax=-50) overlap_frac=0.0,
)
fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54))
ax.pcolormesh(
spec_times, spec_freqs, decibel(spec_power), vmin=-100, vmax=-50
)
for track_id in np.unique(ident): for track_id in np.unique(ident):
# window_index for time array in time window # window_index for time array in time window
window_index = np.arange(len(idx))[(ident == track_id) & window_index = np.arange(len(idx))[
(time[idx] >= t0) & (ident == track_id)
(time[idx] <= (t0+dt))] & (time[idx] >= t0)
& (time[idx] <= (t0 + dt))
]
freq_temp = freq[window_index] freq_temp = freq[window_index]
time_temp = time[idx[window_index]] time_temp = time[idx[window_index]]
#mean_freq = np.mean(freq_temp) # mean_freq = np.mean(freq_temp)
#fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200) # fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200)
ax.plot(time_temp - t0, freq_temp) ax.plot(time_temp - t0, freq_temp)
ax.set_ylim(500, 1000) ax.set_ylim(500, 1000)
plt.show() plt.show()
# filter plot # filter plot
id = 10. id = 10.0
i = 10 i = 10
window_index = np.arange(len(idx))[(ident == id) & window_index = np.arange(len(idx))[
(time[idx] >= t0) & (ident == id) & (time[idx] >= t0) & (time[idx] <= (t0 + dt))
(time[idx] <= (t0+dt))] ]
freq_temp = freq[window_index] freq_temp = freq[window_index]
time_temp = time[idx[window_index]] time_temp = time[idx[window_index]]
mean_freq = np.mean(freq_temp) mean_freq = np.mean(freq_temp)
fdata = bandpass_filter( fdata = bandpass_filter(
data_oi[:, i], rate=data.samplerate, lowf=mean_freq-5, highf=mean_freq+200) data_oi[:, i],
rate=data.samplerate,
lowf=mean_freq - 5,
highf=mean_freq + 200,
)
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.plot(np.arange(len(fdata))/data.samplerate, fdata, marker='*') ax.plot(np.arange(len(fdata)) / data.samplerate, fdata, marker="*")
# plt.show() # plt.show()
# freqency analyis of filtered data # freqency analyis of filtered data
time_fdata = np.arange(len(fdata))/data.samplerate time_fdata = np.arange(len(fdata)) / data.samplerate
roll_fdata = np.roll(fdata, shift=1) roll_fdata = np.roll(fdata, shift=1)
period_index = np.arange(len(fdata))[(roll_fdata < 0) & (fdata >= 0)] period_index = np.arange(len(fdata))[(roll_fdata < 0) & (fdata >= 0)]
plt.plot(time_fdata, fdata) plt.plot(time_fdata, fdata)
plt.scatter(time_fdata[period_index], fdata[period_index], c='r') plt.scatter(time_fdata[period_index], fdata[period_index], c="r")
plt.scatter(time_fdata[period_index-1], fdata[period_index-1], c='r') plt.scatter(time_fdata[period_index - 1], fdata[period_index - 1], c="r")
upper_bound = np.abs(fdata[period_index]) upper_bound = np.abs(fdata[period_index])
lower_bound = np.abs(fdata[period_index-1]) lower_bound = np.abs(fdata[period_index - 1])
upper_times = np.abs(time_fdata[period_index]) upper_times = np.abs(time_fdata[period_index])
lower_times = np.abs(time_fdata[period_index-1]) lower_times = np.abs(time_fdata[period_index - 1])
lower_ratio = lower_bound/(lower_bound+upper_bound) lower_ratio = lower_bound / (lower_bound + upper_bound)
upper_ratio = upper_bound/(lower_bound+upper_bound) upper_ratio = upper_bound / (lower_bound + upper_bound)
time_delta = upper_times-lower_times time_delta = upper_times - lower_times
true_zero = lower_times + time_delta*lower_ratio true_zero = lower_times + time_delta * lower_ratio
plt.scatter(true_zero, np.zeros(len(true_zero))) plt.scatter(true_zero, np.zeros(len(true_zero)))
@ -84,7 +95,7 @@ def main(folder):
inst_freq = 1 / np.diff(true_zero) inst_freq = 1 / np.diff(true_zero)
filtered_inst_freq = gaussian_filter1d(inst_freq, 0.005) filtered_inst_freq = gaussian_filter1d(inst_freq, 0.005)
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.plot(filtered_inst_freq, marker='.') ax.plot(filtered_inst_freq, marker=".")
# in 5 sekunden welcher fisch auf einer elektrode am # in 5 sekunden welcher fisch auf einer elektrode am
embed() embed()
@ -99,5 +110,7 @@ def main(folder):
pass pass
if __name__ == '__main__': if __name__ == "__main__":
main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/') main(
"/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/"
)

View File

@ -12,25 +12,27 @@ from modules.filehandling import LoadData
def main(folder): def main(folder):
data = LoadData(folder) data = LoadData(folder)
t0 = 3*60*60 + 6*60 + 43.5 t0 = 3 * 60 * 60 + 6 * 60 + 43.5
dt = 60 dt = 60
data_oi = data.raw[t0 * data.raw_rate: (t0+dt)*data.raw_rate, :] data_oi = data.raw[t0 * data.raw_rate : (t0 + dt) * data.raw_rate, :]
# good electrode # good electrode
electrode = 10 electrode = 10
data_oi = data_oi[:, electrode] data_oi = data_oi[:, electrode]
fig, axs = plt.subplots(2,1) fig, axs = plt.subplots(2, 1)
axs[0].plot( np.arange(data_oi.shape[0]) / data.raw_rate, data_oi) axs[0].plot(np.arange(data_oi.shape[0]) / data.raw_rate, data_oi)
for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
rack_window_index = np.arange(len(data.idx))[ rack_window_index = np.arange(len(data.idx))[
(data.ident == track_id) & (data.ident == track_id)
(data.time[data.idx] >= t0) & & (data.time[data.idx] >= t0)
(data.time[data.idx] <= (t0+dt))] & (data.time[data.idx] <= (t0 + dt))
]
freq_fish = data.freq[rack_window_index] freq_fish = data.freq[rack_window_index]
axs[1].plot(np.arange(freq_fish.shape[0]) / data.raw_rate, freq_fish) axs[1].plot(np.arange(freq_fish.shape[0]) / data.raw_rate, freq_fish)
plt.show() plt.show()
if __name__ == "__main__":
if __name__ == '__main__': main(
main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/') "/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/"
)

View File

@ -11,6 +11,7 @@ from scipy.ndimage import gaussian_filter1d
logger = makeLogger(__name__) logger = makeLogger(__name__)
class Behavior: class Behavior:
"""Load behavior data from csv file as class attributes """Load behavior data from csv file as class attributes
Attributes Attributes
@ -31,23 +32,35 @@ class Behavior:
""" """
def __init__(self, folder_path: str) -> None: def __init__(self, folder_path: str) -> None:
LED_on_time_BORIS = np.load(
os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True) )
self.time = np.load(os.path.join(folder_path, "times.npy"), allow_pickle=True) self.time = np.load(
csv_filename = [f for f in os.listdir(folder_path) if f.endswith('.csv')][0] # check if there are more than one csv file os.path.join(folder_path, "times.npy"), allow_pickle=True
)
csv_filename = [
f for f in os.listdir(folder_path) if f.endswith(".csv")
][
0
] # check if there are more than one csv file
self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
self.chirps = np.load(os.path.join(folder_path, 'chirps.npy'), allow_pickle=True) self.chirps = np.load(
self.chirps_ids = np.load(os.path.join(folder_path, 'chirps_ids.npy'), allow_pickle=True) os.path.join(folder_path, "chirps.npy"), allow_pickle=True
)
self.chirps_ids = np.load(
os.path.join(folder_path, "chirps_ids.npy"), allow_pickle=True
)
for k, key in enumerate(self.dataframe.keys()): for k, key in enumerate(self.dataframe.keys()):
key = key.lower() key = key.lower()
if ' ' in key: if " " in key:
key = key.replace(' ', '_') key = key.replace(" ", "_")
if '(' in key: if "(" in key:
key = key.replace('(', '') key = key.replace("(", "")
key = key.replace(')', '') key = key.replace(")", "")
setattr(self, key, np.array(self.dataframe[self.dataframe.keys()[k]])) setattr(
self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
)
last_LED_t_BORIS = LED_on_time_BORIS[-1] last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = self.time[-1] - self.time[0] real_time_range = self.time[-1] - self.time[0]
@ -56,6 +69,7 @@ class Behavior:
self.start_s = (self.start_s - shift) / factor self.start_s = (self.start_s - shift) / factor
self.stop_s = (self.stop_s - shift) / factor self.stop_s = (self.stop_s - shift) / factor
""" """
1 - chasing onset 1 - chasing onset
2 - chasing offset 2 - chasing offset
@ -83,74 +97,74 @@ temporal encpding needs to be corrected ... not exactly 25FPS.
behavior = data['Behavior'] behavior = data['Behavior']
""" """
def correct_chasing_events(
category: np.ndarray,
timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
onset_ids = np.arange( def correct_chasing_events(
len(category))[category == 0] category: np.ndarray, timestamps: np.ndarray
offset_ids = np.arange( ) -> tuple[np.ndarray, np.ndarray]:
len(category))[category == 1] onset_ids = np.arange(len(category))[category == 0]
offset_ids = np.arange(len(category))[category == 1]
# Check whether on- or offset is longer and calculate length difference # Check whether on- or offset is longer and calculate length difference
if len(onset_ids) > len(offset_ids): if len(onset_ids) > len(offset_ids):
len_diff = len(onset_ids) - len(offset_ids) len_diff = len(onset_ids) - len(offset_ids)
longer_array = onset_ids longer_array = onset_ids
shorter_array = offset_ids shorter_array = offset_ids
logger.info(f'Onsets are greater than offsets by {len_diff}') logger.info(f"Onsets are greater than offsets by {len_diff}")
elif len(onset_ids) < len(offset_ids): elif len(onset_ids) < len(offset_ids):
len_diff = len(offset_ids) - len(onset_ids) len_diff = len(offset_ids) - len(onset_ids)
longer_array = offset_ids longer_array = offset_ids
shorter_array = onset_ids shorter_array = onset_ids
logger.info(f'Offsets are greater than offsets by {len_diff}') logger.info(f"Offsets are greater than offsets by {len_diff}")
elif len(onset_ids) == len(offset_ids): elif len(onset_ids) == len(offset_ids):
logger.info('Chasing events are equal') logger.info("Chasing events are equal")
return category, timestamps return category, timestamps
# Correct the wrong chasing events; delete double events # Correct the wrong chasing events; delete double events
wrong_ids = [] wrong_ids = []
for i in range(len(longer_array)-(len_diff+1)): for i in range(len(longer_array) - (len_diff + 1)):
if (shorter_array[i] > longer_array[i]) & (shorter_array[i] < longer_array[i+1]): if (shorter_array[i] > longer_array[i]) & (
shorter_array[i] < longer_array[i + 1]
):
pass pass
else: else:
wrong_ids.append(longer_array[i]) wrong_ids.append(longer_array[i])
longer_array = np.delete(longer_array, i) longer_array = np.delete(longer_array, i)
category = np.delete( category = np.delete(category, wrong_ids)
category, wrong_ids) timestamps = np.delete(timestamps, wrong_ids)
timestamps = np.delete(
timestamps, wrong_ids)
return category, timestamps return category, timestamps
def event_triggered_chirps( def event_triggered_chirps(
event: np.ndarray, event: np.ndarray,
chirps:np.ndarray, chirps: np.ndarray,
time_before_event: int, time_before_event: int,
time_after_event: int time_after_event: int,
)-> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
event_chirps = [] # chirps that are in specified window around event event_chirps = [] # chirps that are in specified window around event
centered_chirps = [] # timestamps of chirps around event centered on the event timepoint centered_chirps = (
[]
) # timestamps of chirps around event centered on the event timepoint
for event_timestamp in event: for event_timestamp in event:
start = event_timestamp - time_before_event # timepoint of window start start = event_timestamp - time_before_event # timepoint of window start
stop = event_timestamp + time_after_event # timepoint of window ending stop = event_timestamp + time_after_event # timepoint of window ending
chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)] # get chirps that are in a -5 to +5 sec window around event chirps_around_event = [
c for c in chirps if (c >= start) & (c <= stop)
] # get chirps that are in a -5 to +5 sec window around event
event_chirps.append(chirps_around_event) event_chirps.append(chirps_around_event)
if len(chirps_around_event) == 0: if len(chirps_around_event) == 0:
continue continue
else: else:
centered_chirps.append(chirps_around_event - event_timestamp) centered_chirps.append(chirps_around_event - event_timestamp)
centered_chirps = np.concatenate(centered_chirps, axis=0) # convert list of arrays to one array for plotting centered_chirps = np.concatenate(
centered_chirps, axis=0
) # convert list of arrays to one array for plotting
return event_chirps, centered_chirps return event_chirps, centered_chirps
def main(datapath: str): def main(datapath: str):
# behavior is pandas dataframe with all the data # behavior is pandas dataframe with all the data
bh = Behavior(datapath) bh = Behavior(datapath)
@ -172,10 +186,34 @@ def main(datapath: str):
# First overview plot # First overview plot
fig1, ax1 = plt.subplots() fig1, ax1 = plt.subplots()
ax1.scatter(chirps, np.ones_like(chirps), marker='*', color='royalblue', label='Chirps') ax1.scatter(
ax1.scatter(chasing_onset, np.ones_like(chasing_onset)*2, marker='.', color='forestgreen', label='Chasing onset') chirps,
ax1.scatter(chasing_offset, np.ones_like(chasing_offset)*2.5, marker='.', color='firebrick', label='Chasing offset') np.ones_like(chirps),
ax1.scatter(physical_contact, np.ones_like(physical_contact)*3, marker='x', color='black', label='Physical contact') marker="*",
color="royalblue",
label="Chirps",
)
ax1.scatter(
chasing_onset,
np.ones_like(chasing_onset) * 2,
marker=".",
color="forestgreen",
label="Chasing onset",
)
ax1.scatter(
chasing_offset,
np.ones_like(chasing_offset) * 2.5,
marker=".",
color="firebrick",
label="Chasing offset",
)
ax1.scatter(
physical_contact,
np.ones_like(physical_contact) * 3,
marker="x",
color="black",
label="Physical contact",
)
plt.legend() plt.legend()
# plt.show() # plt.show()
plt.close() plt.close()
@ -194,22 +232,33 @@ def main(datapath: str):
# chirps = chirps[chirps_fish_ids == fish] # chirps = chirps[chirps_fish_ids == fish]
# print(fish) # print(fish)
chasing_chirps, centered_chasing_chirps = event_triggered_chirps(chasing_onset, chirps, time_around_event, time_around_event) chasing_chirps, centered_chasing_chirps = event_triggered_chirps(
physical_chirps, centered_physical_chirps = event_triggered_chirps(physical_contact, chirps, time_around_event, time_around_event) chasing_onset, chirps, time_around_event, time_around_event
)
physical_chirps, centered_physical_chirps = event_triggered_chirps(
physical_contact, chirps, time_around_event, time_around_event
)
# Kernel density estimation ??? # Kernel density estimation ???
# centered_chasing_chirps_convolved = gaussian_filter1d(centered_chasing_chirps, 5) # centered_chasing_chirps_convolved = gaussian_filter1d(centered_chasing_chirps, 5)
# centered_chasing = chasing_onset[0] - chasing_onset[0] ## get the 0 timepoint for plotting; set one chasing event to 0 # centered_chasing = chasing_onset[0] - chasing_onset[0] ## get the 0 timepoint for plotting; set one chasing event to 0
offsets = [0.5, 1] offsets = [0.5, 1]
fig4, ax4 = plt.subplots(figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True) fig4, ax4 = plt.subplots(
ax4.eventplot(np.array([centered_chasing_chirps, centered_physical_chirps]), lineoffsets=offsets, linelengths=0.25, colors=['g', 'r']) figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True
ax4.vlines(0, 0, 1.5, 'tab:grey', 'dashed', 'Timepoint of event') )
ax4.eventplot(
np.array([centered_chasing_chirps, centered_physical_chirps]),
lineoffsets=offsets,
linelengths=0.25,
colors=["g", "r"],
)
ax4.vlines(0, 0, 1.5, "tab:grey", "dashed", "Timepoint of event")
# ax4.plot(centered_chasing_chirps_convolved) # ax4.plot(centered_chasing_chirps_convolved)
ax4.set_yticks(offsets) ax4.set_yticks(offsets)
ax4.set_yticklabels(['Chasings', 'Physical \n contacts']) ax4.set_yticklabels(["Chasings", "Physical \n contacts"])
ax4.set_xlabel('Time[s]') ax4.set_xlabel("Time[s]")
ax4.set_ylabel('Type of event') ax4.set_ylabel("Type of event")
plt.show() plt.show()
# Associate chirps to inidividual fish # Associate chirps to inidividual fish
@ -227,14 +276,13 @@ def main(datapath: str):
#### Chirp counts per fish general ##### #### Chirp counts per fish general #####
fig2, ax2 = plt.subplots() fig2, ax2 = plt.subplots()
x = ['Fish1', 'Fish2'] x = ["Fish1", "Fish2"]
width = 0.35 width = 0.35
ax2.bar(x, fish, width=width) ax2.bar(x, fish, width=width)
ax2.set_ylabel('Chirp count') ax2.set_ylabel("Chirp count")
# plt.show() # plt.show()
plt.close() plt.close()
##### Count chirps emitted during chasing events and chirps emitted out of chasing events ##### ##### Count chirps emitted during chasing events and chirps emitted out of chasing events #####
chirps_in_chasings = [] chirps_in_chasings = []
for onset, offset in zip(chasing_onset, chasing_offset): for onset, offset in zip(chasing_onset, chasing_offset):
@ -251,23 +299,24 @@ def main(datapath: str):
counts_chirps_chasings += 1 counts_chirps_chasings += 1
# chirps in chasing events # chirps in chasing events
fig3 , ax3 = plt.subplots() fig3, ax3 = plt.subplots()
ax3.bar(['Chirps in chasing events', 'Chasing events without Chirps'], [counts_chirps_chasings, chasings_without_chirps], width=width) ax3.bar(
plt.ylabel('Count') ["Chirps in chasing events", "Chasing events without Chirps"],
[counts_chirps_chasings, chasings_without_chirps],
width=width,
)
plt.ylabel("Count")
# plt.show() # plt.show()
plt.close() plt.close()
# comparison between chasing events with and without chirps # comparison between chasing events with and without chirps
embed() embed()
exit() exit()
if __name__ == "__main__":
if __name__ == '__main__':
# Path to the data # Path to the data
datapath = '../data/mount_data/2020-05-13-10_00/' datapath = "../data/mount_data/2020-05-13-10_00/"
datapath = '../data/mount_data/2020-05-13-10_00/' datapath = "../data/mount_data/2020-05-13-10_00/"
main(datapath) main(datapath)

View File

@ -8,30 +8,27 @@ from modules.datahandling import instantaneous_frequency
from modules.simulations import create_chirp from modules.simulations import create_chirp
# trying thunderfish fakefish chirp simulation --------------------------------- # trying thunderfish fakefish chirp simulation ---------------------------------
samplerate = 44100 samplerate = 44100
freq, ampl = fakefish.chirps(eodf=500, chirp_contrast=0.2) freq, ampl = fakefish.chirps(eodf=500, chirp_contrast=0.2)
data = fakefish.wavefish_eods(fish='Alepto', frequency=freq, phase0=3, samplerate=samplerate) data = fakefish.wavefish_eods(
fish="Alepto", frequency=freq, phase0=3, samplerate=samplerate
)
# filter signal with bandpass_filter # filter signal with bandpass_filter
data_filterd = bandpass_filter(data*ampl+1, samplerate, 0.01, 1.99) data_filterd = bandpass_filter(data * ampl + 1, samplerate, 0.01, 1.99)
embed() embed()
data_freq_time, data_freq = instantaneous_frequency(data, samplerate, 5) data_freq_time, data_freq = instantaneous_frequency(data, samplerate, 5)
fig, ax = plt.subplots(4, 1, figsize=(20 / 2.54, 12 / 2.54), sharex=True) fig, ax = plt.subplots(4, 1, figsize=(20 / 2.54, 12 / 2.54), sharex=True)
ax[0].plot(np.arange(len(data))/samplerate, data*ampl) ax[0].plot(np.arange(len(data)) / samplerate, data * ampl)
#ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red') # ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red')
ax[1].plot(np.arange(len(data_filterd))/samplerate, data_filterd) ax[1].plot(np.arange(len(data_filterd)) / samplerate, data_filterd)
ax[2].plot(np.arange(len(freq))/samplerate, freq) ax[2].plot(np.arange(len(freq)) / samplerate, freq)
ax[3].plot(data_freq_time, data_freq) ax[3].plot(data_freq_time, data_freq)
plt.show() plt.show()
embed() embed()

View File

@ -1,6 +1,8 @@
from itertools import compress
from dataclasses import dataclass from dataclasses import dataclass
from itertools import compress
import matplotlib.gridspec as gr
import matplotlib.pyplot as plt
import numpy as np import numpy as np
from IPython import embed from IPython import embed
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -15,11 +17,17 @@ from modules.plotstyle import PlotStyle
from modules.logger import makeLogger from modules.logger import makeLogger
from modules.datahandling import ( from modules.datahandling import (
flatten, flatten,
purge_duplicates,
group_timestamps, group_timestamps,
instantaneous_frequency, instantaneous_frequency,
minmaxnorm, minmaxnorm,
purge_duplicates,
) )
from modules.filehandling import ConfLoader, LoadData, make_outputdir
from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.logger import makeLogger
from modules.plotstyle import PlotStyle
from scipy.signal import find_peaks
from thunderfish.powerspectrum import decibel, spectrogram
logger = makeLogger(__name__) logger = makeLogger(__name__)
@ -58,7 +66,6 @@ class ChirpPlotBuffer:
frequency_peaks: np.ndarray frequency_peaks: np.ndarray
def plot_buffer(self, chirps: np.ndarray, plot: str) -> None: def plot_buffer(self, chirps: np.ndarray, plot: str) -> None:
logger.debug("Starting plotting") logger.debug("Starting plotting")
# make data for plotting # make data for plotting
@ -134,7 +141,6 @@ class ChirpPlotBuffer:
ax0.set_ylim(np.min(self.frequency) - 100, np.max(self.frequency) + 200) ax0.set_ylim(np.min(self.frequency) - 100, np.max(self.frequency) + 200)
for track_id in self.data.ids: for track_id in self.data.ids:
t0_track = self.t0_old - 5 t0_track = self.t0_old - 5
dt_track = self.dt + 10 dt_track = self.dt + 10
window_idx = np.arange(len(self.data.idx))[ window_idx = np.arange(len(self.data.idx))[
@ -175,10 +181,16 @@ class ChirpPlotBuffer:
# ) # )
ax0.axhline( ax0.axhline(
q50 - self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed" q50 - self.config.minimal_bandwidth / 2,
color=ps.gblue1,
lw=1,
ls="dashed",
) )
ax0.axhline( ax0.axhline(
q50 + self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed" q50 + self.config.minimal_bandwidth / 2,
color=ps.gblue1,
lw=1,
ls="dashed",
) )
ax0.axhline(search_lower, color=ps.gblue2, lw=1, ls="dashed") ax0.axhline(search_lower, color=ps.gblue2, lw=1, ls="dashed")
ax0.axhline(search_upper, color=ps.gblue2, lw=1, ls="dashed") ax0.axhline(search_upper, color=ps.gblue2, lw=1, ls="dashed")
@ -204,7 +216,11 @@ class ChirpPlotBuffer:
# plot waveform of filtered signal # plot waveform of filtered signal
ax1.plot( ax1.plot(
self.time, self.baseline * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5 self.time,
self.baseline * waveform_scaler,
c=ps.gray,
lw=lw,
alpha=0.5,
) )
ax1.plot( ax1.plot(
self.time, self.time,
@ -215,7 +231,13 @@ class ChirpPlotBuffer:
) )
# plot waveform of filtered search signal # plot waveform of filtered search signal
ax2.plot(self.time, self.search * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5) ax2.plot(
self.time,
self.search * waveform_scaler,
c=ps.gray,
lw=lw,
alpha=0.5,
)
ax2.plot( ax2.plot(
self.time, self.time,
self.search_envelope_unfiltered * waveform_scaler, self.search_envelope_unfiltered * waveform_scaler,
@ -237,9 +259,7 @@ class ChirpPlotBuffer:
# ax4.plot( # ax4.plot(
# self.time, self.baseline_envelope * waveform_scaler, c=ps.gblue1, lw=lw # self.time, self.baseline_envelope * waveform_scaler, c=ps.gblue1, lw=lw
# ) # )
ax4.plot( ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=lw)
self.time, self.baseline_envelope, c=ps.gblue1, lw=lw
)
ax4.scatter( ax4.scatter(
(self.time)[self.baseline_peaks], (self.time)[self.baseline_peaks],
# (self.baseline_envelope * waveform_scaler)[self.baseline_peaks], # (self.baseline_envelope * waveform_scaler)[self.baseline_peaks],
@ -268,7 +288,9 @@ class ChirpPlotBuffer:
) )
# plot filtered instantaneous frequency # plot filtered instantaneous frequency
ax6.plot(self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw) ax6.plot(
self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw
)
ax6.scatter( ax6.scatter(
self.frequency_time[self.frequency_peaks], self.frequency_time[self.frequency_peaks],
self.frequency_filtered[self.frequency_peaks], self.frequency_filtered[self.frequency_peaks],
@ -302,7 +324,9 @@ class ChirpPlotBuffer:
# ax7.spines.bottom.set_bounds((0, 5)) # ax7.spines.bottom.set_bounds((0, 5))
ax0.set_xlim(0, self.config.window) ax0.set_xlim(0, self.config.window)
plt.subplots_adjust(left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2) plt.subplots_adjust(
left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2
)
fig.align_labels() fig.align_labels()
if plot == "show": if plot == "show":
@ -407,7 +431,9 @@ def extract_frequency_bands(
q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2 q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2
# filter baseline # filter baseline
filtered_baseline = bandpass_filter(raw_data, samplerate, lowf=q25, highf=q75) filtered_baseline = bandpass_filter(
raw_data, samplerate, lowf=q25, highf=q75
)
# filter search area # filter search area
filtered_search_freq = bandpass_filter( filtered_search_freq = bandpass_filter(
@ -452,12 +478,14 @@ def window_median_all_track_ids(
track_ids = [] track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
# the window index combines the track id and the time window # the window index combines the track id and the time window
window_idx = np.arange(len(data.idx))[ window_idx = np.arange(len(data.idx))[
(data.ident == track_id) (data.ident == track_id)
& (data.time[data.idx] >= window_start_seconds) & (data.time[data.idx] >= window_start_seconds)
& (data.time[data.idx] <= (window_start_seconds + window_duration_seconds)) & (
data.time[data.idx]
<= (window_start_seconds + window_duration_seconds)
)
] ]
if len(data.freq[window_idx]) > 0: if len(data.freq[window_idx]) > 0:
@ -594,15 +622,15 @@ def find_searchband(
# iterate through theses tracks # iterate through theses tracks
if check_track_ids.size != 0: if check_track_ids.size != 0:
for j, check_track_id in enumerate(check_track_ids): for j, check_track_id in enumerate(check_track_ids):
q25_temp = q25[percentiles_ids == check_track_id] q25_temp = q25[percentiles_ids == check_track_id]
q75_temp = q75[percentiles_ids == check_track_id] q75_temp = q75[percentiles_ids == check_track_id]
bool_lower[search_window > q25_temp - config.search_res] = False bool_lower[search_window > q25_temp - config.search_res] = False
bool_upper[search_window < q75_temp + config.search_res] = False bool_upper[search_window < q75_temp + config.search_res] = False
search_window_bool[(bool_lower == False) & (bool_upper == False)] = False search_window_bool[
(bool_lower == False) & (bool_upper == False)
] = False
# find gaps in search window # find gaps in search window
search_window_indices = np.arange(len(search_window)) search_window_indices = np.arange(len(search_window))
@ -621,7 +649,9 @@ def find_searchband(
# if the first value is -1, the array starst with true, so a gap # if the first value is -1, the array starst with true, so a gap
if nonzeros[0] == -1: if nonzeros[0] == -1:
stops = search_window_indices[search_window_gaps == -1] stops = search_window_indices[search_window_gaps == -1]
starts = np.append(0, search_window_indices[search_window_gaps == 1]) starts = np.append(
0, search_window_indices[search_window_gaps == 1]
)
# if the last value is -1, the array ends with true, so a gap # if the last value is -1, the array ends with true, so a gap
if nonzeros[-1] == 1: if nonzeros[-1] == 1:
@ -658,7 +688,6 @@ def find_searchband(
def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None: def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
assert plot in [ assert plot in [
"save", "save",
"show", "show",
@ -728,7 +757,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
multiwindow_ids = [] multiwindow_ids = []
for st, window_start_index in enumerate(window_start_indices): for st, window_start_index in enumerate(window_start_indices):
logger.info(f"Processing window {st+1} of {len(window_start_indices)}") logger.info(f"Processing window {st+1} of {len(window_start_indices)}")
window_start_seconds = window_start_index / data.raw_rate window_start_seconds = window_start_index / data.raw_rate
@ -743,8 +771,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
) )
# iterate through all fish # iterate through all fish
for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): for tr, track_id in enumerate(
np.unique(data.ident[~np.isnan(data.ident)])
):
logger.debug(f"Processing track {tr} of {len(data.ids)}") logger.debug(f"Processing track {tr} of {len(data.ids)}")
# get index of track data in this time window # get index of track data in this time window
@ -772,16 +801,17 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
nanchecker = np.unique(np.isnan(current_powers)) nanchecker = np.unique(np.isnan(current_powers))
if (len(nanchecker) == 1) and nanchecker[0] is True: if (len(nanchecker) == 1) and nanchecker[0] is True:
logger.warning( logger.warning(
f"No powers available for track {track_id} window {st}," "skipping." f"No powers available for track {track_id} window {st},"
"skipping."
) )
continue continue
# find the strongest electrodes for the current fish in the current # find the strongest electrodes for the current fish in the current
# window # window
best_electrode_index = np.argsort(np.nanmean(current_powers, axis=0))[ best_electrode_index = np.argsort(
-config.number_electrodes : np.nanmean(current_powers, axis=0)
] )[-config.number_electrodes :]
# find a frequency above the baseline of the current fish in which # find a frequency above the baseline of the current fish in which
# no other fish is active to search for chirps there # no other fish is active to search for chirps there
@ -801,9 +831,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
# iterate through electrodes # iterate through electrodes
for el, electrode_index in enumerate(best_electrode_index): for el, electrode_index in enumerate(best_electrode_index):
logger.debug( logger.debug(
f"Processing electrode {el+1} of " f"{len(best_electrode_index)}" f"Processing electrode {el+1} of "
f"{len(best_electrode_index)}"
) )
# LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------ # LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------
@ -812,7 +842,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
current_raw_data = data.raw[ current_raw_data = data.raw[
window_start_index:window_stop_index, electrode_index window_start_index:window_stop_index, electrode_index
] ]
current_raw_time = raw_time[window_start_index:window_stop_index] current_raw_time = raw_time[
window_start_index:window_stop_index
]
# EXTRACT FEATURES -------------------------------------------- # EXTRACT FEATURES --------------------------------------------
@ -838,8 +870,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
# because the instantaneous frequency is not reliable there # because the instantaneous frequency is not reliable there
amplitude_mask = mask_low_amplitudes( amplitude_mask = mask_low_amplitudes(
baseline_envelope_unfiltered, baseline_envelope_unfiltered, config.baseline_min_amplitude
config.baseline_min_amplitude
) )
# highpass filter baseline envelope to remove slower # highpass filter baseline envelope to remove slower
@ -878,7 +909,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
baseline_frequency = instantaneous_frequency( baseline_frequency = instantaneous_frequency(
baselineband, baselineband,
data.raw_rate, data.raw_rate,
config.baseline_frequency_smoothing config.baseline_frequency_smoothing,
) )
# Take the absolute of the instantaneous frequency to invert # Take the absolute of the instantaneous frequency to invert
@ -896,7 +927,10 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
# to enter normalization, where small changes due to noise # to enter normalization, where small changes due to noise
# would be amplified # would be amplified
if not has_chirp(baseline_frequency_filtered[amplitude_mask], config.baseline_frequency_peakheight): if not has_chirp(
baseline_frequency_filtered[amplitude_mask],
config.baseline_frequency_peakheight,
):
continue continue
# CUT OFF OVERLAP --------------------------------------------- # CUT OFF OVERLAP ---------------------------------------------
@ -911,14 +945,20 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
current_raw_time = current_raw_time[no_edges] current_raw_time = current_raw_time[no_edges]
baselineband = baselineband[no_edges] baselineband = baselineband[no_edges]
baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges] baseline_envelope_unfiltered = baseline_envelope_unfiltered[
no_edges
]
searchband = searchband[no_edges] searchband = searchband[no_edges]
baseline_envelope = baseline_envelope[no_edges] baseline_envelope = baseline_envelope[no_edges]
search_envelope_unfiltered = search_envelope_unfiltered[no_edges] search_envelope_unfiltered = search_envelope_unfiltered[
no_edges
]
search_envelope = search_envelope[no_edges] search_envelope = search_envelope[no_edges]
baseline_frequency = baseline_frequency[no_edges] baseline_frequency = baseline_frequency[no_edges]
baseline_frequency_filtered = baseline_frequency_filtered[no_edges] baseline_frequency_filtered = baseline_frequency_filtered[
no_edges
]
baseline_frequency_time = current_raw_time baseline_frequency_time = current_raw_time
# # get instantaneous frequency withoup edges # # get instantaneous frequency withoup edges
@ -959,13 +999,16 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
) )
# detect peaks inst_freq_filtered # detect peaks inst_freq_filtered
frequency_peak_indices, _ = find_peaks( frequency_peak_indices, _ = find_peaks(
baseline_frequency_filtered, prominence=config.frequency_prominence baseline_frequency_filtered,
prominence=config.frequency_prominence,
) )
# DETECT CHIRPS IN SEARCH WINDOW ------------------------------ # DETECT CHIRPS IN SEARCH WINDOW ------------------------------
# get the peak timestamps from the peak indices # get the peak timestamps from the peak indices
baseline_peak_timestamps = current_raw_time[baseline_peak_indices] baseline_peak_timestamps = current_raw_time[
baseline_peak_indices
]
search_peak_timestamps = current_raw_time[search_peak_indices] search_peak_timestamps = current_raw_time[search_peak_indices]
frequency_peak_timestamps = baseline_frequency_time[ frequency_peak_timestamps = baseline_frequency_time[
@ -1014,7 +1057,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
) )
if chirp_detected or (debug != "elecrode"): if chirp_detected or (debug != "elecrode"):
logger.debug("Detected chirp, ititialize buffer ...") logger.debug("Detected chirp, ititialize buffer ...")
# save data to Buffer # save data to Buffer
@ -1106,7 +1148,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
multiwindow_chirps_flat = [] multiwindow_chirps_flat = []
multiwindow_ids_flat = [] multiwindow_ids_flat = []
for track_id in np.unique(multiwindow_ids): for track_id in np.unique(multiwindow_ids):
# get chirps for this fish and flatten the list # get chirps for this fish and flatten the list
current_track_bool = np.asarray(multiwindow_ids) == track_id current_track_bool = np.asarray(multiwindow_ids) == track_id
current_track_chirps = flatten( current_track_chirps = flatten(
@ -1115,7 +1156,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
# add flattened chirps to the list # add flattened chirps to the list
multiwindow_chirps_flat.extend(current_track_chirps) multiwindow_chirps_flat.extend(current_track_chirps)
multiwindow_ids_flat.extend(list(np.ones_like(current_track_chirps) * track_id)) multiwindow_ids_flat.extend(
list(np.ones_like(current_track_chirps) * track_id)
)
# purge duplicates, i.e. chirps that are very close to each other # purge duplicates, i.e. chirps that are very close to each other
# duplites arise due to overlapping windows # duplites arise due to overlapping windows

View File

@ -44,4 +44,3 @@ frequency_prominence: 0.3 # peak prominence threshold for baseline freq
# Classify events as chirps if they are less than this time apart # Classify events as chirps if they are less than this time apart
chirp_window_threshold: 0.02 chirp_window_threshold: 0.02

View File

@ -35,28 +35,36 @@ class Behavior:
""" """
def __init__(self, folder_path: str) -> None: def __init__(self, folder_path: str) -> None:
print(f'{folder_path}') print(f"{folder_path}")
LED_on_time_BORIS = np.load(os.path.join( LED_on_time_BORIS = np.load(
folder_path, 'LED_on_time.npy'), allow_pickle=True) os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
self.time = np.load(os.path.join( )
folder_path, "times.npy"), allow_pickle=True) self.time = np.load(
csv_filename = [f for f in os.listdir(folder_path) if f.endswith( os.path.join(folder_path, "times.npy"), allow_pickle=True
'.csv')][0] # check if there are more than one csv file )
csv_filename = [
f for f in os.listdir(folder_path) if f.endswith(".csv")
][
0
] # check if there are more than one csv file
self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
self.chirps = np.load(os.path.join( self.chirps = np.load(
folder_path, 'chirps.npy'), allow_pickle=True) os.path.join(folder_path, "chirps.npy"), allow_pickle=True
self.chirps_ids = np.load(os.path.join( )
folder_path, 'chirp_ids.npy'), allow_pickle=True) self.chirps_ids = np.load(
os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True
)
for k, key in enumerate(self.dataframe.keys()): for k, key in enumerate(self.dataframe.keys()):
key = key.lower() key = key.lower()
if ' ' in key: if " " in key:
key = key.replace(' ', '_') key = key.replace(" ", "_")
if '(' in key: if "(" in key:
key = key.replace('(', '') key = key.replace("(", "")
key = key.replace(')', '') key = key.replace(")", "")
setattr(self, key, np.array( setattr(
self.dataframe[self.dataframe.keys()[k]])) self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
)
last_LED_t_BORIS = LED_on_time_BORIS[-1] last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = self.time[-1] - self.time[0] real_time_range = self.time[-1] - self.time[0]
@ -95,17 +103,14 @@ temporal encpding needs to be corrected ... not exactly 25FPS.
def correct_chasing_events( def correct_chasing_events(
category: np.ndarray, category: np.ndarray, timestamps: np.ndarray
timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
onset_ids = np.arange(len(category))[category == 0]
offset_ids = np.arange(len(category))[category == 1]
onset_ids = np.arange( wrong_bh = np.arange(len(category))[category != 2][:-1][
len(category))[category == 0] np.diff(category[category != 2]) == 0
offset_ids = np.arange( ]
len(category))[category == 1]
wrong_bh = np.arange(len(category))[
category != 2][:-1][np.diff(category[category != 2]) == 0]
if onset_ids[0] > offset_ids[0]: if onset_ids[0] > offset_ids[0]:
offset_ids = np.delete(offset_ids, 0) offset_ids = np.delete(offset_ids, 0)
help_index = offset_ids[0] help_index = offset_ids[0]
@ -117,12 +122,12 @@ def correct_chasing_events(
# Check whether on- or offset is longer and calculate length difference # Check whether on- or offset is longer and calculate length difference
if len(onset_ids) > len(offset_ids): if len(onset_ids) > len(offset_ids):
len_diff = len(onset_ids) - len(offset_ids) len_diff = len(onset_ids) - len(offset_ids)
logger.info(f'Onsets are greater than offsets by {len_diff}') logger.info(f"Onsets are greater than offsets by {len_diff}")
elif len(onset_ids) < len(offset_ids): elif len(onset_ids) < len(offset_ids):
len_diff = len(offset_ids) - len(onset_ids) len_diff = len(offset_ids) - len(onset_ids)
logger.info(f'Offsets are greater than onsets by {len_diff}') logger.info(f"Offsets are greater than onsets by {len_diff}")
elif len(onset_ids) == len(offset_ids): elif len(onset_ids) == len(offset_ids):
logger.info('Chasing events are equal') logger.info("Chasing events are equal")
return category, timestamps return category, timestamps
@ -135,7 +140,6 @@ def event_triggered_chirps(
dt: float, dt: float,
width: float, width: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
event_chirps = [] # chirps that are in specified window around event event_chirps = [] # chirps that are in specified window around event
# timestamps of chirps around event centered on the event timepoint # timestamps of chirps around event centered on the event timepoint
centered_chirps = [] centered_chirps = []
@ -159,16 +163,19 @@ def event_triggered_chirps(
else: else:
# convert list of arrays to one array for plotting # convert list of arrays to one array for plotting
centered_chirps = np.concatenate(centered_chirps, axis=0) centered_chirps = np.concatenate(centered_chirps, axis=0)
centered_chirps_convolved = (acausal_kde1d( centered_chirps_convolved = (
centered_chirps, time, width)) / len(event) acausal_kde1d(centered_chirps, time, width)
) / len(event)
return event_chirps, centered_chirps, centered_chirps_convolved return event_chirps, centered_chirps, centered_chirps_convolved
def main(datapath: str): def main(datapath: str):
foldernames = [ foldernames = [
datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath + x)] datapath + x + "/"
for x in os.listdir(datapath)
if os.path.isdir(datapath + x)
]
nrecording_chirps = [] nrecording_chirps = []
nrecording_chirps_fish_ids = [] nrecording_chirps_fish_ids = []
@ -179,7 +186,7 @@ def main(datapath: str):
# Iterate over all recordings and save chirp- and event-timestamps # Iterate over all recordings and save chirp- and event-timestamps
for folder in foldernames: for folder in foldernames:
# exclude folder with empty LED_on_time.npy # exclude folder with empty LED_on_time.npy
if folder == '../data/mount_data/2020-05-12-10_00/': if folder == "../data/mount_data/2020-05-12-10_00/":
continue continue
bh = Behavior(folder) bh = Behavior(folder)
@ -232,18 +239,47 @@ def main(datapath: str):
physical_contacts = nrecording_physicals[i] physical_contacts = nrecording_physicals[i]
# Chirps around chasing onsets # Chirps around chasing onsets
_, centered_chasing_onset_chirps, cc_chasing_onset_chirps = event_triggered_chirps( (
chasing_onsets, chirps, time_before_event, time_after_event, dt, recording_width) _,
centered_chasing_onset_chirps,
cc_chasing_onset_chirps,
) = event_triggered_chirps(
chasing_onsets,
chirps,
time_before_event,
time_after_event,
dt,
recording_width,
)
# Chirps around chasing offsets # Chirps around chasing offsets
_, centered_chasing_offset_chirps, cc_chasing_offset_chirps = event_triggered_chirps( (
chasing_offsets, chirps, time_before_event, time_after_event, dt, recording_width) _,
centered_chasing_offset_chirps,
cc_chasing_offset_chirps,
) = event_triggered_chirps(
chasing_offsets,
chirps,
time_before_event,
time_after_event,
dt,
recording_width,
)
# Chirps around physical contacts # Chirps around physical contacts
_, centered_physical_chirps, cc_physical_chirps = event_triggered_chirps( (
physical_contacts, chirps, time_before_event, time_after_event, dt, recording_width) _,
centered_physical_chirps,
cc_physical_chirps,
) = event_triggered_chirps(
physical_contacts,
chirps,
time_before_event,
time_after_event,
dt,
recording_width,
)
nrecording_centered_onset_chirps.append(centered_chasing_onset_chirps) nrecording_centered_onset_chirps.append(centered_chasing_onset_chirps)
nrecording_centered_offset_chirps.append( nrecording_centered_offset_chirps.append(centered_chasing_offset_chirps)
centered_chasing_offset_chirps)
nrecording_centered_physical_chirps.append(centered_physical_chirps) nrecording_centered_physical_chirps.append(centered_physical_chirps)
## Shuffled chirps ## ## Shuffled chirps ##
@ -331,12 +367,13 @@ def main(datapath: str):
# New bootstrapping approach # New bootstrapping approach
for n in range(nbootstrapping): for n in range(nbootstrapping):
diff_onset = np.diff( diff_onset = np.diff(np.sort(flatten(nrecording_centered_onset_chirps)))
np.sort(flatten(nrecording_centered_onset_chirps)))
diff_offset = np.diff( diff_offset = np.diff(
np.sort(flatten(nrecording_centered_offset_chirps))) np.sort(flatten(nrecording_centered_offset_chirps))
)
diff_physical = np.diff( diff_physical = np.diff(
np.sort(flatten(nrecording_centered_physical_chirps))) np.sort(flatten(nrecording_centered_physical_chirps))
)
np.random.shuffle(diff_onset) np.random.shuffle(diff_onset)
shuffled_onset = np.cumsum(diff_onset) shuffled_onset = np.cumsum(diff_onset)
@ -345,9 +382,11 @@ def main(datapath: str):
np.random.shuffle(diff_physical) np.random.shuffle(diff_physical)
shuffled_physical = np.cumsum(diff_physical) shuffled_physical = np.cumsum(diff_physical)
kde_onset (acausal_kde1d(shuffled_onset, time, width))/(27*100) kde_onset(acausal_kde1d(shuffled_onset, time, width)) / (27 * 100)
kde_offset = (acausal_kde1d(shuffled_offset, time, width))/(27*100) kde_offset = (acausal_kde1d(shuffled_offset, time, width)) / (27 * 100)
kde_physical = (acausal_kde1d(shuffled_physical, time, width))/(27*100) kde_physical = (acausal_kde1d(shuffled_physical, time, width)) / (
27 * 100
)
bootstrap_onset.append(kde_onset) bootstrap_onset.append(kde_onset)
bootstrap_offset.append(kde_offset) bootstrap_offset.append(kde_offset)
@ -355,11 +394,14 @@ def main(datapath: str):
# New shuffle approach q5, q50, q95 # New shuffle approach q5, q50, q95
onset_q5, onset_median, onset_q95 = np.percentile( onset_q5, onset_median, onset_q95 = np.percentile(
bootstrap_onset, [5, 50, 95], axis=0) bootstrap_onset, [5, 50, 95], axis=0
)
offset_q5, offset_median, offset_q95 = np.percentile( offset_q5, offset_median, offset_q95 = np.percentile(
bootstrap_offset, [5, 50, 95], axis=0) bootstrap_offset, [5, 50, 95], axis=0
)
physical_q5, physical_median, physical_q95 = np.percentile( physical_q5, physical_median, physical_q95 = np.percentile(
bootstrap_physical, [5, 50, 95], axis=0) bootstrap_physical, [5, 50, 95], axis=0
)
# vstack um 1. Dim zu cutten # vstack um 1. Dim zu cutten
# nrecording_shuffled_convolved_onset_chirps = np.vstack(nrecording_shuffled_convolved_onset_chirps) # nrecording_shuffled_convolved_onset_chirps = np.vstack(nrecording_shuffled_convolved_onset_chirps)
@ -378,45 +420,66 @@ def main(datapath: str):
# Flatten event timestamps # Flatten event timestamps
all_onsets = np.concatenate( all_onsets = np.concatenate(
nrecording_chasing_onsets).ravel() # not centered nrecording_chasing_onsets
).ravel() # not centered
all_offsets = np.concatenate( all_offsets = np.concatenate(
nrecording_chasing_offsets).ravel() # not centered nrecording_chasing_offsets
all_physicals = np.concatenate( ).ravel() # not centered
nrecording_physicals).ravel() # not centered all_physicals = np.concatenate(nrecording_physicals).ravel() # not centered
# Flatten all chirps around events # Flatten all chirps around events
all_onset_chirps = np.concatenate( all_onset_chirps = np.concatenate(
nrecording_centered_onset_chirps).ravel() # centered nrecording_centered_onset_chirps
).ravel() # centered
all_offset_chirps = np.concatenate( all_offset_chirps = np.concatenate(
nrecording_centered_offset_chirps).ravel() # centered nrecording_centered_offset_chirps
).ravel() # centered
all_physical_chirps = np.concatenate( all_physical_chirps = np.concatenate(
nrecording_centered_physical_chirps).ravel() # centered nrecording_centered_physical_chirps
).ravel() # centered
# Convolute all chirps # Convolute all chirps
# Divide by total number of each event over all recordings # Divide by total number of each event over all recordings
all_onset_chirps_convolved = (acausal_kde1d( all_onset_chirps_convolved = (
all_onset_chirps, time, width)) / len(all_onsets) acausal_kde1d(all_onset_chirps, time, width)
all_offset_chirps_convolved = (acausal_kde1d( ) / len(all_onsets)
all_offset_chirps, time, width)) / len(all_offsets) all_offset_chirps_convolved = (
all_physical_chirps_convolved = (acausal_kde1d( acausal_kde1d(all_offset_chirps, time, width)
all_physical_chirps, time, width)) / len(all_physicals) ) / len(all_offsets)
all_physical_chirps_convolved = (
acausal_kde1d(all_physical_chirps, time, width)
) / len(all_physicals)
# Plot all events with all shuffled # Plot all events with all shuffled
fig, ax = plt.subplots(1, 3, figsize=( fig, ax = plt.subplots(
28*ps.cm, 16*ps.cm, ), constrained_layout=True, sharey='all') 1,
3,
figsize=(
28 * ps.cm,
16 * ps.cm,
),
constrained_layout=True,
sharey="all",
)
# offsets = np.arange(1,28,1) # offsets = np.arange(1,28,1)
ax[0].set_xlabel('Time[s]') ax[0].set_xlabel("Time[s]")
# Plot chasing onsets # Plot chasing onsets
ax[0].set_ylabel('Chirp rate [Hz]') ax[0].set_ylabel("Chirp rate [Hz]")
ax[0].plot(time, all_onset_chirps_convolved, color=ps.yellow, zorder=2) ax[0].plot(time, all_onset_chirps_convolved, color=ps.yellow, zorder=2)
ax0 = ax[0].twinx() ax0 = ax[0].twinx()
nrecording_centered_onset_chirps = np.asarray( nrecording_centered_onset_chirps = np.asarray(
nrecording_centered_onset_chirps, dtype=object) nrecording_centered_onset_chirps, dtype=object
ax0.eventplot(np.array(nrecording_centered_onset_chirps), )
linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) ax0.eventplot(
ax0.vlines(0, 0, 1.5, ps.white, 'dashed') np.array(nrecording_centered_onset_chirps),
ax[0].set_zorder(ax0.get_zorder()+1) linelengths=0.5,
colors=ps.gray,
alpha=0.25,
zorder=1,
)
ax0.vlines(0, 0, 1.5, ps.white, "dashed")
ax[0].set_zorder(ax0.get_zorder() + 1)
ax[0].patch.set_visible(False) ax[0].patch.set_visible(False)
ax0.set_yticklabels([]) ax0.set_yticklabels([])
ax0.set_yticks([]) ax0.set_yticks([])
@ -426,15 +489,21 @@ def main(datapath: str):
ax[0].plot(time, onset_median, color=ps.black) ax[0].plot(time, onset_median, color=ps.black)
# Plot chasing offets # Plot chasing offets
ax[1].set_xlabel('Time[s]') ax[1].set_xlabel("Time[s]")
ax[1].plot(time, all_offset_chirps_convolved, color=ps.orange, zorder=2) ax[1].plot(time, all_offset_chirps_convolved, color=ps.orange, zorder=2)
ax1 = ax[1].twinx() ax1 = ax[1].twinx()
nrecording_centered_offset_chirps = np.asarray( nrecording_centered_offset_chirps = np.asarray(
nrecording_centered_offset_chirps, dtype=object) nrecording_centered_offset_chirps, dtype=object
ax1.eventplot(np.array(nrecording_centered_offset_chirps), )
linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) ax1.eventplot(
ax1.vlines(0, 0, 1.5, ps.white, 'dashed') np.array(nrecording_centered_offset_chirps),
ax[1].set_zorder(ax1.get_zorder()+1) linelengths=0.5,
colors=ps.gray,
alpha=0.25,
zorder=1,
)
ax1.vlines(0, 0, 1.5, ps.white, "dashed")
ax[1].set_zorder(ax1.get_zorder() + 1)
ax[1].patch.set_visible(False) ax[1].patch.set_visible(False)
ax1.set_yticklabels([]) ax1.set_yticklabels([])
ax1.set_yticks([]) ax1.set_yticks([])
@ -444,24 +513,31 @@ def main(datapath: str):
ax[1].plot(time, offset_median, color=ps.black) ax[1].plot(time, offset_median, color=ps.black)
# Plot physical contacts # Plot physical contacts
ax[2].set_xlabel('Time[s]') ax[2].set_xlabel("Time[s]")
ax[2].plot(time, all_physical_chirps_convolved, color=ps.maroon, zorder=2) ax[2].plot(time, all_physical_chirps_convolved, color=ps.maroon, zorder=2)
ax2 = ax[2].twinx() ax2 = ax[2].twinx()
nrecording_centered_physical_chirps = np.asarray( nrecording_centered_physical_chirps = np.asarray(
nrecording_centered_physical_chirps, dtype=object) nrecording_centered_physical_chirps, dtype=object
ax2.eventplot(np.array(nrecording_centered_physical_chirps), )
linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1) ax2.eventplot(
ax2.vlines(0, 0, 1.5, ps.white, 'dashed') np.array(nrecording_centered_physical_chirps),
ax[2].set_zorder(ax2.get_zorder()+1) linelengths=0.5,
colors=ps.gray,
alpha=0.25,
zorder=1,
)
ax2.vlines(0, 0, 1.5, ps.white, "dashed")
ax[2].set_zorder(ax2.get_zorder() + 1)
ax[2].patch.set_visible(False) ax[2].patch.set_visible(False)
ax2.set_yticklabels([]) ax2.set_yticklabels([])
ax2.set_yticks([]) ax2.set_yticks([])
# ax[2].fill_between(time, shuffled_q5_physical, shuffled_q95_physical, color=ps.gray, alpha=0.5) # ax[2].fill_between(time, shuffled_q5_physical, shuffled_q95_physical, color=ps.gray, alpha=0.5)
# ax[2].plot(time, shuffled_median_physical, ps.black) # ax[2].plot(time, shuffled_median_physical, ps.black)
ax[2].fill_between(time, physical_q5, physical_q95, ax[2].fill_between(
color=ps.gray, alpha=0.5) time, physical_q5, physical_q95, color=ps.gray, alpha=0.5
)
ax[2].plot(time, physical_median, ps.black) ax[2].plot(time, physical_median, ps.black)
fig.suptitle('All recordings') fig.suptitle("All recordings")
plt.show() plt.show()
plt.close() plt.close()
@ -587,7 +663,7 @@ def main(datapath: str):
#### Chirps around events, only losers, one recording #### #### Chirps around events, only losers, one recording ####
if __name__ == '__main__': if __name__ == "__main__":
# Path to the data # Path to the data
datapath = '../data/mount_data/' datapath = "../data/mount_data/"
main(datapath) main(datapath)

View File

@ -8,50 +8,51 @@ from IPython import embed
def get_valid_datasets(dataroot): def get_valid_datasets(dataroot):
datasets = sorted(
datasets = sorted([name for name in os.listdir(dataroot) if os.path.isdir( [
os.path.join(dataroot, name))]) name
for name in os.listdir(dataroot)
if os.path.isdir(os.path.join(dataroot, name))
]
)
valid_datasets = [] valid_datasets = []
for dataset in datasets: for dataset in datasets:
path = os.path.join(dataroot, dataset) path = os.path.join(dataroot, dataset)
csv_name = '-'.join(dataset.split('-')[:3]) + '.csv' csv_name = "-".join(dataset.split("-")[:3]) + ".csv"
if os.path.exists(os.path.join(path, csv_name)) is False: if os.path.exists(os.path.join(path, csv_name)) is False:
continue continue
if os.path.exists(os.path.join(path, 'ident_v.npy')) is False: if os.path.exists(os.path.join(path, "ident_v.npy")) is False:
continue continue
ident = np.load(os.path.join(path, 'ident_v.npy')) ident = np.load(os.path.join(path, "ident_v.npy"))
number_of_fish = len(np.unique(ident[~np.isnan(ident)])) number_of_fish = len(np.unique(ident[~np.isnan(ident)]))
if number_of_fish != 2: if number_of_fish != 2:
continue continue
valid_datasets.append(dataset) valid_datasets.append(dataset)
datapaths = [os.path.join(dataroot, dataset) + datapaths = [
'/' for dataset in valid_datasets] os.path.join(dataroot, dataset) + "/" for dataset in valid_datasets
]
return datapaths, valid_datasets return datapaths, valid_datasets
def main(datapaths): def main(datapaths):
for path in datapaths: for path in datapaths:
chirpdetection(path, plot='show') chirpdetection(path, plot="show")
if __name__ == '__main__':
dataroot = '../data/mount_data/'
if __name__ == "__main__":
dataroot = "../data/mount_data/"
datapaths, valid_datasets= get_valid_datasets(dataroot) datapaths, valid_datasets = get_valid_datasets(dataroot)
recs = pd.DataFrame(columns=['recording'], data=valid_datasets) recs = pd.DataFrame(columns=["recording"], data=valid_datasets)
recs.to_csv('../recs.csv', index=False) recs.to_csv("../recs.csv", index=False)
# datapaths = ['../data/mount_data/2020-03-25-10_00/'] # datapaths = ['../data/mount_data/2020-03-25-10_00/']
main(datapaths) main(datapaths)

View File

@ -7,29 +7,41 @@ from pandas import read_csv
ssh = SSHClient() ssh = SSHClient()
ssh.load_system_host_keys() ssh.load_system_host_keys()
ssh.connect(hostname='kraken', ssh.connect(
username='efish', hostname="kraken",
password='fwNix4U', username="efish",
) password="fwNix4U",
)
# SCPCLient takes a paramiko transport as its only argument # SCPCLient takes a paramiko transport as its only argument
scp = SCPClient(ssh.get_transport()) scp = SCPClient(ssh.get_transport())
data = read_csv('../recs.csv') data = read_csv("../recs.csv")
foldernames = data['recording'].values foldernames = data["recording"].values
directory = f'/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/' directory = f"/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/"
for foldername in foldernames: for foldername in foldernames:
if not os.path.exists(directory + foldername):
if not os.path.exists(directory+foldername): os.makedirs(directory + foldername)
os.makedirs(directory+foldername)
files = [
files = [('-').join(foldername.split('-')[:3])+'.csv','chirp_ids.npy', 'chirps.npy', 'fund_v.npy', 'ident_v.npy', 'idx_v.npy', 'times.npy', 'spec.npy', 'LED_on_time.npy', 'sign_v.npy'] ("-").join(foldername.split("-")[:3]) + ".csv",
"chirp_ids.npy",
"chirps.npy",
"fund_v.npy",
"ident_v.npy",
"idx_v.npy",
"times.npy",
"spec.npy",
"LED_on_time.npy",
"sign_v.npy",
]
for f in files: for f in files:
scp.get(f'/home/efish/behavior/2019_tube_competition/{foldername}/{f}', scp.get(
directory+foldername) f"/home/efish/behavior/2019_tube_competition/{foldername}/{f}",
directory + foldername,
)
scp.close() scp.close()

View File

@ -30,12 +30,12 @@ class Behavior:
""" """
def __init__(self, folder_path: str) -> None: def __init__(self, folder_path: str) -> None:
LED_on_time_BORIS = np.load(
LED_on_time_BORIS = np.load(os.path.join( os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
folder_path, 'LED_on_time.npy'), allow_pickle=True) )
csv_filename = os.path.split(folder_path[:-1])[-1] csv_filename = os.path.split(folder_path[:-1])[-1]
csv_filename = '-'.join(csv_filename.split('-')[:-1]) + '.csv' csv_filename = "-".join(csv_filename.split("-")[:-1]) + ".csv"
# embed() # embed()
# csv_filename = [f for f in os.listdir( # csv_filename = [f for f in os.listdir(
@ -43,31 +43,39 @@ class Behavior:
# logger.info(f'CSV file: {csv_filename}') # logger.info(f'CSV file: {csv_filename}')
self.dataframe = read_csv(os.path.join(folder_path, csv_filename)) self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
self.chirps = np.load(os.path.join( self.chirps = np.load(
folder_path, 'chirps.npy'), allow_pickle=True) os.path.join(folder_path, "chirps.npy"), allow_pickle=True
self.chirps_ids = np.load(os.path.join( )
folder_path, 'chirp_ids.npy'), allow_pickle=True) self.chirps_ids = np.load(
os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True
self.ident = np.load(os.path.join( )
folder_path, 'ident_v.npy'), allow_pickle=True)
self.idx = np.load(os.path.join( self.ident = np.load(
folder_path, 'idx_v.npy'), allow_pickle=True) os.path.join(folder_path, "ident_v.npy"), allow_pickle=True
self.freq = np.load(os.path.join( )
folder_path, 'fund_v.npy'), allow_pickle=True) self.idx = np.load(
self.time = np.load(os.path.join( os.path.join(folder_path, "idx_v.npy"), allow_pickle=True
folder_path, "times.npy"), allow_pickle=True) )
self.spec = np.load(os.path.join( self.freq = np.load(
folder_path, "spec.npy"), allow_pickle=True) os.path.join(folder_path, "fund_v.npy"), allow_pickle=True
)
self.time = np.load(
os.path.join(folder_path, "times.npy"), allow_pickle=True
)
self.spec = np.load(
os.path.join(folder_path, "spec.npy"), allow_pickle=True
)
for k, key in enumerate(self.dataframe.keys()): for k, key in enumerate(self.dataframe.keys()):
key = key.lower() key = key.lower()
if ' ' in key: if " " in key:
key = key.replace(' ', '_') key = key.replace(" ", "_")
if '(' in key: if "(" in key:
key = key.replace('(', '') key = key.replace("(", "")
key = key.replace(')', '') key = key.replace(")", "")
setattr(self, key, np.array( setattr(
self.dataframe[self.dataframe.keys()[k]])) self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
)
last_LED_t_BORIS = LED_on_time_BORIS[-1] last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = self.time[-1] - self.time[0] real_time_range = self.time[-1] - self.time[0]
@ -78,22 +86,19 @@ class Behavior:
def correct_chasing_events( def correct_chasing_events(
category: np.ndarray, category: np.ndarray, timestamps: np.ndarray
timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
onset_ids = np.arange(len(category))[category == 0]
offset_ids = np.arange(len(category))[category == 1]
onset_ids = np.arange( wrong_bh = np.arange(len(category))[category != 2][:-1][
len(category))[category == 0] np.diff(category[category != 2]) == 0
offset_ids = np.arange( ]
len(category))[category == 1]
wrong_bh = np.arange(len(category))[
category != 2][:-1][np.diff(category[category != 2]) == 0]
if category[category != 2][-1] == 0: if category[category != 2][-1] == 0:
wrong_bh = np.append( wrong_bh = np.append(
wrong_bh, wrong_bh, np.arange(len(category))[category != 2][-1]
np.arange(len(category))[category != 2][-1]) )
if onset_ids[0] > offset_ids[0]: if onset_ids[0] > offset_ids[0]:
offset_ids = np.delete(offset_ids, 0) offset_ids = np.delete(offset_ids, 0)
@ -103,18 +108,16 @@ def correct_chasing_events(
category = np.delete(category, wrong_bh) category = np.delete(category, wrong_bh)
timestamps = np.delete(timestamps, wrong_bh) timestamps = np.delete(timestamps, wrong_bh)
new_onset_ids = np.arange( new_onset_ids = np.arange(len(category))[category == 0]
len(category))[category == 0] new_offset_ids = np.arange(len(category))[category == 1]
new_offset_ids = np.arange(
len(category))[category == 1]
# Check whether on- or offset is longer and calculate length difference # Check whether on- or offset is longer and calculate length difference
if len(new_onset_ids) > len(new_offset_ids): if len(new_onset_ids) > len(new_offset_ids):
embed() embed()
logger.warning('Onsets are greater than offsets') logger.warning("Onsets are greater than offsets")
elif len(new_onset_ids) < len(new_offset_ids): elif len(new_onset_ids) < len(new_offset_ids):
logger.warning('Offsets are greater than onsets') logger.warning("Offsets are greater than onsets")
elif len(new_onset_ids) == len(new_offset_ids): elif len(new_onset_ids) == len(new_offset_ids):
# logger.info('Chasing events are equal') # logger.info('Chasing events are equal')
pass pass
@ -130,13 +133,11 @@ def center_chirps(
# dt: float, # dt: float,
# width: float, # width: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
event_chirps = [] # chirps that are in specified window around event event_chirps = [] # chirps that are in specified window around event
# timestamps of chirps around event centered on the event timepoint # timestamps of chirps around event centered on the event timepoint
centered_chirps = [] centered_chirps = []
for event_timestamp in events: for event_timestamp in events:
start = event_timestamp - time_before_event start = event_timestamp - time_before_event
stop = event_timestamp + time_after_event stop = event_timestamp + time_after_event
chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)] chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)]
@ -152,7 +153,8 @@ def center_chirps(
if len(centered_chirps) != len(event_chirps): if len(centered_chirps) != len(event_chirps):
raise ValueError( raise ValueError(
'Non centered chirps and centered chirps are not equal') "Non centered chirps and centered chirps are not equal"
)
# time = np.arange(-time_before_event, time_after_event, dt) # time = np.arange(-time_before_event, time_after_event, dt)

View File

@ -23,7 +23,9 @@ def minmaxnorm(data):
return (data - np.min(data)) / (np.max(data) - np.min(data)) return (data - np.min(data)) / (np.max(data) - np.min(data))
def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str = 'linear') -> np.ndarray: def instantaneous_frequency2(
signal: np.ndarray, fs: float, interpolation: str = "linear"
) -> np.ndarray:
""" """
Compute the instantaneous frequency of a periodic signal using zero crossings and resample the frequency using linear Compute the instantaneous frequency of a periodic signal using zero crossings and resample the frequency using linear
or cubic interpolation to match the dimensions of the input array. or cubic interpolation to match the dimensions of the input array.
@ -55,10 +57,10 @@ def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str =
orig_len = len(signal) orig_len = len(signal)
freq = resample(freq, orig_len) freq = resample(freq, orig_len)
if interpolation == 'linear': if interpolation == "linear":
freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq) freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq)
elif interpolation == 'cubic': elif interpolation == "cubic":
freq = resample(freq, orig_len, window='cubic') freq = resample(freq, orig_len, window="cubic")
return freq return freq
@ -67,7 +69,7 @@ def instantaneous_frequency(
signal: np.ndarray, signal: np.ndarray,
samplerate: int, samplerate: int,
smoothing_window: int, smoothing_window: int,
interpolation: str = 'linear', interpolation: str = "linear",
) -> np.ndarray: ) -> np.ndarray:
""" """
Compute the instantaneous frequency of a signal that is approximately Compute the instantaneous frequency of a signal that is approximately
@ -120,11 +122,10 @@ def instantaneous_frequency(
orig_len = len(signal) orig_len = len(signal)
freq = resample(instantaneous_frequency, orig_len) freq = resample(instantaneous_frequency, orig_len)
if interpolation == 'linear': if interpolation == "linear":
freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq) freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq)
elif interpolation == 'cubic': elif interpolation == "cubic":
freq = resample(freq, orig_len, window='cubic') freq = resample(freq, orig_len, window="cubic")
return freq return freq
@ -160,7 +161,6 @@ def purge_duplicates(
group = [timestamps[0]] group = [timestamps[0]]
for i in range(1, len(timestamps)): for i in range(1, len(timestamps)):
# check the difference between current timestamp and previous # check the difference between current timestamp and previous
# timestamp is less than the threshold # timestamp is less than the threshold
if timestamps[i] - timestamps[i - 1] < threshold: if timestamps[i] - timestamps[i - 1] < threshold:
@ -379,7 +379,6 @@ def acausal_kde1d(spikes, time, width):
if __name__ == "__main__": if __name__ == "__main__":
timestamps = [ timestamps = [
[1.2, 1.5, 1.3], [1.2, 1.5, 1.3],
[], [],

View File

@ -35,7 +35,6 @@ class LoadData:
""" """
def __init__(self, datapath: str) -> None: def __init__(self, datapath: str) -> None:
# load raw data # load raw data
self.datapath = datapath self.datapath = datapath
self.file = os.path.join(datapath, "traces-grid1.raw") self.file = os.path.join(datapath, "traces-grid1.raw")

View File

@ -60,9 +60,7 @@ def highpass_filter(
def lowpass_filter( def lowpass_filter(
signal: np.ndarray, signal: np.ndarray, samplerate: float, cutoff: float
samplerate: float,
cutoff: float
) -> np.ndarray: ) -> np.ndarray:
"""Lowpass filter a signal. """Lowpass filter a signal.
@ -86,10 +84,9 @@ def lowpass_filter(
return filtered_signal return filtered_signal
def envelope(signal: np.ndarray, def envelope(
samplerate: float, signal: np.ndarray, samplerate: float, cutoff_frequency: float
cutoff_frequency: float ) -> np.ndarray:
) -> np.ndarray:
"""Calculate the envelope of a signal using a lowpass filter. """Calculate the envelope of a signal using a lowpass filter.
Parameters Parameters

View File

@ -2,12 +2,13 @@ import logging
def makeLogger(name: str): def makeLogger(name: str):
# create logger formats for file and terminal # create logger formats for file and terminal
file_formatter = logging.Formatter( file_formatter = logging.Formatter(
"[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s") "[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s"
)
console_formatter = logging.Formatter( console_formatter = logging.Formatter(
"[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s") "[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s"
)
# create logging file if loglevel is debug # create logging file if loglevel is debug
file_handler = logging.FileHandler(f"gridtools_log.log", mode="w") file_handler = logging.FileHandler(f"gridtools_log.log", mode="w")
@ -29,7 +30,6 @@ def makeLogger(name: str):
if __name__ == "__main__": if __name__ == "__main__":
# initiate logger # initiate logger
mylogger = makeLogger(__name__) mylogger = makeLogger(__name__)

View File

@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None: def PlotStyle() -> None:
class style: class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units # units
@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center", va="center",
zorder=1000, zorder=1000,
bbox=dict( bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
), ),
) )
@classmethod @classmethod
def fade_cmap(cls, cmap): def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N)) my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap) my_cmap = ListedColormap(my_cmap)
@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__": if __name__ == "__main__":
s = PlotStyle() s = PlotStyle()
import matplotlib.cbook as cbook import matplotlib.cbook as cbook
@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs: for ax in axs:
ax.yaxis.grid(True) ax.yaxis.grid(True)
ax.set_xticks( ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] [y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
) )
ax.set_xlabel("Four separate samples") ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values") ax.set_ylabel("Observed values")
@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4) grid = np.random.rand(4, 4)
fig, axs = plt.subplots( fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
) )
for ax, interp_method in zip(axs.flat, methods): for ax, interp_method in zip(axs.flat, methods):

View File

@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None: def PlotStyle() -> None:
class style: class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units # units
@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center", va="center",
zorder=1000, zorder=1000,
bbox=dict( bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
), ),
) )
@classmethod @classmethod
def fade_cmap(cls, cmap): def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N)) my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap) my_cmap = ListedColormap(my_cmap)
@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__": if __name__ == "__main__":
s = PlotStyle() s = PlotStyle()
import matplotlib.cbook as cbook import matplotlib.cbook as cbook
@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs: for ax in axs:
ax.yaxis.grid(True) ax.yaxis.grid(True)
ax.set_xticks( ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] [y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
) )
ax.set_xlabel("Four separate samples") ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values") ax.set_ylabel("Observed values")
@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4) grid = np.random.rand(4, 4)
fig, axs = plt.subplots( fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
) )
for ax, interp_method in zip(axs.flat, methods): for ax, interp_method in zip(axs.flat, methods):

View File

@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None: def PlotStyle() -> None:
class style: class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8) # lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units # units
@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center", va="center",
zorder=1000, zorder=1000,
bbox=dict( bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1 boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
), ),
) )
@classmethod @classmethod
def fade_cmap(cls, cmap): def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N)) my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N) my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap) my_cmap = ListedColormap(my_cmap)
@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__": if __name__ == "__main__":
s = PlotStyle() s = PlotStyle()
import matplotlib.cbook as cbook import matplotlib.cbook as cbook
@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs: for ax in axs:
ax.yaxis.grid(True) ax.yaxis.grid(True)
ax.set_xticks( ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"] [y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
) )
ax.set_xlabel("Four separate samples") ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values") ax.set_ylabel("Observed values")
@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4) grid = np.random.rand(4, 4)
fig, axs = plt.subplots( fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []} nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
) )
for ax, interp_method in zip(axs.flat, methods): for ax, interp_method in zip(axs.flat, methods):

View File

@ -37,7 +37,7 @@ def create_chirp(
ck = 0 ck = 0
csig = 0.5 * chirpduration / np.power(2.0 * np.log(10.0), 0.5 / kurtosis) csig = 0.5 * chirpduration / np.power(2.0 * np.log(10.0), 0.5 / kurtosis)
#csig = csig*-1 # csig = csig*-1
for k, t in enumerate(time): for k, t in enumerate(time):
a = 1.0 a = 1.0
f = eodf f = eodf

View File

@ -16,26 +16,25 @@ logger = makeLogger(__name__)
def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
foldername = folder_name.split("/")[-2]
foldername = folder_name.split('/')[-2] winner_row = order_meta_df[order_meta_df["recording"] == foldername]
winner_row = order_meta_df[order_meta_df['recording'] == foldername] winner = winner_row["winner"].values[0].astype(int)
winner = winner_row['winner'].values[0].astype(int) winner_fish1 = winner_row["fish1"].values[0].astype(int)
winner_fish1 = winner_row['fish1'].values[0].astype(int) winner_fish2 = winner_row["fish2"].values[0].astype(int)
winner_fish2 = winner_row['fish2'].values[0].astype(int)
if winner > 0: if winner > 0:
if winner == winner_fish1: if winner == winner_fish1:
winner_fish_id = winner_row['rec_id1'].values[0] winner_fish_id = winner_row["rec_id1"].values[0]
loser_fish_id = winner_row['rec_id2'].values[0] loser_fish_id = winner_row["rec_id2"].values[0]
elif winner == winner_fish2: elif winner == winner_fish2:
winner_fish_id = winner_row['rec_id2'].values[0] winner_fish_id = winner_row["rec_id2"].values[0]
loser_fish_id = winner_row['rec_id1'].values[0] loser_fish_id = winner_row["rec_id1"].values[0]
chirp_winner = len( chirp_winner = len(
Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) Behavior.chirps[Behavior.chirps_ids == winner_fish_id]
chirp_loser = len( )
Behavior.chirps[Behavior.chirps_ids == loser_fish_id]) chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
return chirp_winner, chirp_loser return chirp_winner, chirp_loser
else: else:
@ -43,24 +42,24 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df): def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
foldername = folder_name.split("/")[-2]
foldername = folder_name.split('/')[-2] folder_row = order_meta_df[order_meta_df["recording"] == foldername]
folder_row = order_meta_df[order_meta_df['recording'] == foldername] fish1 = folder_row["fish1"].values[0].astype(int)
fish1 = folder_row['fish1'].values[0].astype(int) fish2 = folder_row["fish2"].values[0].astype(int)
fish2 = folder_row['fish2'].values[0].astype(int) winner = folder_row["winner"].values[0].astype(int)
winner = folder_row['winner'].values[0].astype(int)
groub = folder_row["group"].values[0].astype(int)
groub = folder_row['group'].values[0].astype(int) size_fish1_row = id_meta_df[
size_fish1_row = id_meta_df[(id_meta_df['group'] == groub) & ( (id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish1)
id_meta_df['fish'] == fish1)] ]
size_fish2_row = id_meta_df[(id_meta_df['group'] == groub) & ( size_fish2_row = id_meta_df[
id_meta_df['fish'] == fish2)] (id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish2)
]
size_winners = [size_fish1_row[col].values[0]
for col in ['l1', 'l2', 'l3']] size_winners = [size_fish1_row[col].values[0] for col in ["l1", "l2", "l3"]]
size_fish1 = np.nanmean(size_winners) size_fish1 = np.nanmean(size_winners)
size_losers = [size_fish2_row[col].values[0] for col in ['l1', 'l2', 'l3']] size_losers = [size_fish2_row[col].values[0] for col in ["l1", "l2", "l3"]]
size_fish2 = np.nanmean(size_losers) size_fish2 = np.nanmean(size_losers)
if winner == fish1: if winner == fish1:
@ -75,8 +74,8 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
size_diff_bigger = 0 size_diff_bigger = 0
size_diff_smaller = 0 size_diff_smaller = 0
winner_fish_id = folder_row['rec_id1'].values[0] winner_fish_id = folder_row["rec_id1"].values[0]
loser_fish_id = folder_row['rec_id2'].values[0] loser_fish_id = folder_row["rec_id2"].values[0]
elif winner == fish2: elif winner == fish2:
if size_fish2 > size_fish1: if size_fish2 > size_fish1:
@ -90,39 +89,39 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
size_diff_bigger = 0 size_diff_bigger = 0
size_diff_smaller = 0 size_diff_smaller = 0
winner_fish_id = folder_row['rec_id2'].values[0] winner_fish_id = folder_row["rec_id2"].values[0]
loser_fish_id = folder_row['rec_id1'].values[0] loser_fish_id = folder_row["rec_id1"].values[0]
else: else:
size_diff_bigger = np.nan size_diff_bigger = np.nan
size_diff_smaller = np.nan size_diff_smaller = np.nan
winner_fish_id = np.nan winner_fish_id = np.nan
loser_fish_id = np.nan loser_fish_id = np.nan
return size_diff_bigger, size_diff_smaller, winner_fish_id, loser_fish_id return (
size_diff_bigger,
size_diff_smaller,
winner_fish_id,
loser_fish_id,
)
chirp_winner = len( chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
chirp_loser = len(
Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser
def get_chirp_freq(folder_name, Behavior, order_meta_df): def get_chirp_freq(folder_name, Behavior, order_meta_df):
foldername = folder_name.split("/")[-2]
folder_row = order_meta_df[order_meta_df["recording"] == foldername]
fish1 = folder_row["fish1"].values[0].astype(int)
fish2 = folder_row["fish2"].values[0].astype(int)
foldername = folder_name.split('/')[-2] fish1_freq = folder_row["rec_id1"].values[0].astype(int)
folder_row = order_meta_df[order_meta_df['recording'] == foldername] fish2_freq = folder_row["rec_id2"].values[0].astype(int)
fish1 = folder_row['fish1'].values[0].astype(int)
fish2 = folder_row['fish2'].values[0].astype(int)
fish1_freq = folder_row['rec_id1'].values[0].astype(int)
fish2_freq = folder_row['rec_id2'].values[0].astype(int)
chirp_freq_fish1 = np.nanmedian( chirp_freq_fish1 = np.nanmedian(Behavior.freq[Behavior.ident == fish1_freq])
Behavior.freq[Behavior.ident == fish1_freq]) chirp_freq_fish2 = np.nanmedian(Behavior.freq[Behavior.ident == fish2_freq])
chirp_freq_fish2 = np.nanmedian( winner = folder_row["winner"].values[0].astype(int)
Behavior.freq[Behavior.ident == fish2_freq])
winner = folder_row['winner'].values[0].astype(int)
if winner == fish1: if winner == fish1:
# if chirp_freq_fish1 > chirp_freq_fish2: # if chirp_freq_fish1 > chirp_freq_fish2:
@ -138,9 +137,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
# winner_fish_id = np.nan # winner_fish_id = np.nan
# loser_fish_id = np.nan # loser_fish_id = np.nan
winner_fish_id = folder_row['rec_id1'].values[0] winner_fish_id = folder_row["rec_id1"].values[0]
winner_fish_freq = chirp_freq_fish1 winner_fish_freq = chirp_freq_fish1
loser_fish_id = folder_row['rec_id2'].values[0] loser_fish_id = folder_row["rec_id2"].values[0]
loser_fish_freq = chirp_freq_fish2 loser_fish_freq = chirp_freq_fish2
elif winner == fish2: elif winner == fish2:
@ -157,9 +156,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
# winner_fish_id = np.nan # winner_fish_id = np.nan
# loser_fish_id = np.nan # loser_fish_id = np.nan
winner_fish_id = folder_row['rec_id2'].values[0] winner_fish_id = folder_row["rec_id2"].values[0]
winner_fish_freq = chirp_freq_fish2 winner_fish_freq = chirp_freq_fish2
loser_fish_id = folder_row['rec_id1'].values[0] loser_fish_id = folder_row["rec_id1"].values[0]
loser_fish_freq = chirp_freq_fish1 loser_fish_freq = chirp_freq_fish1
else: else:
winner_fish_freq = np.nan winner_fish_freq = np.nan
@ -168,25 +167,25 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
loser_fish_id = np.nan loser_fish_id = np.nan
return winner_fish_freq, winner_fish_id, loser_fish_freq, loser_fish_id return winner_fish_freq, winner_fish_id, loser_fish_freq, loser_fish_id
chirp_winner = len( chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
Behavior.chirps[Behavior.chirps_ids == winner_fish_id]) chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
chirp_loser = len(
Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
return winner_fish_freq, chirp_winner, loser_fish_freq, chirp_loser return winner_fish_freq, chirp_winner, loser_fish_freq, chirp_loser
def main(datapath: str): def main(datapath: str):
foldernames = [ foldernames = [
datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] datapath + x + "/"
for x in os.listdir(datapath)
if os.path.isdir(datapath + x)
]
foldernames, _ = get_valid_datasets(datapath) foldernames, _ = get_valid_datasets(datapath)
path_order_meta = ( path_order_meta = ("/").join(
'/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv' foldernames[0].split("/")[:-2]
) + "/order_meta.csv"
order_meta_df = read_csv(path_order_meta) order_meta_df = read_csv(path_order_meta)
order_meta_df['recording'] = order_meta_df['recording'].str[1:-1] order_meta_df["recording"] = order_meta_df["recording"].str[1:-1]
path_id_meta = ( path_id_meta = ("/").join(foldernames[0].split("/")[:-2]) + "/id_meta.csv"
'/').join(foldernames[0].split('/')[:-2]) + '/id_meta.csv'
id_meta_df = read_csv(path_id_meta) id_meta_df = read_csv(path_id_meta)
chirps_winner = [] chirps_winner = []
@ -202,10 +201,9 @@ def main(datapath: str):
freq_chirps_winner = [] freq_chirps_winner = []
freq_chirps_loser = [] freq_chirps_loser = []
for foldername in foldernames: for foldername in foldernames:
# behabvior is pandas dataframe with all the data # behabvior is pandas dataframe with all the data
if foldername == '../data/mount_data/2020-05-12-10_00/': if foldername == "../data/mount_data/2020-05-12-10_00/":
continue continue
bh = Behavior(foldername) bh = Behavior(foldername)
# chirps are not sorted in time (presumably due to prior groupings) # chirps are not sorted in time (presumably due to prior groupings)
@ -217,15 +215,24 @@ def main(datapath: str):
category, timestamps = correct_chasing_events(category, timestamps) category, timestamps = correct_chasing_events(category, timestamps)
winner_chirp, loser_chirp = get_chirp_winner_loser( winner_chirp, loser_chirp = get_chirp_winner_loser(
foldername, bh, order_meta_df) foldername, bh, order_meta_df
)
chirps_winner.append(winner_chirp) chirps_winner.append(winner_chirp)
chirps_loser.append(loser_chirp) chirps_loser.append(loser_chirp)
size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser = get_chirp_size( (
foldername, bh, order_meta_df, id_meta_df) size_diff_bigger,
chirp_winner,
size_diff_smaller,
chirp_loser,
) = get_chirp_size(foldername, bh, order_meta_df, id_meta_df)
freq_winner, chirp_freq_winner, freq_loser, chirp_freq_loser = get_chirp_freq( (
foldername, bh, order_meta_df) freq_winner,
chirp_freq_winner,
freq_loser,
chirp_freq_loser,
) = get_chirp_freq(foldername, bh, order_meta_df)
freq_diffs_higher.append(freq_winner) freq_diffs_higher.append(freq_winner)
freq_diffs_lower.append(freq_loser) freq_diffs_lower.append(freq_loser)
@ -242,82 +249,124 @@ def main(datapath: str):
pearsonr(size_diffs_winner, size_chirps_winner) pearsonr(size_diffs_winner, size_chirps_winner)
pearsonr(size_diffs_loser, size_chirps_loser) pearsonr(size_diffs_loser, size_chirps_loser)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=( fig, (ax1, ax2, ax3) = plt.subplots(
21*ps.cm, 7*ps.cm), width_ratios=[1, 0.8, 0.8], sharey=True) 1,
plt.subplots_adjust(left=0.11, right=0.948, top=0.86, 3,
wspace=0.343, bottom=0.198) figsize=(21 * ps.cm, 7 * ps.cm),
width_ratios=[1, 0.8, 0.8],
sharey=True,
)
plt.subplots_adjust(
left=0.11, right=0.948, top=0.86, wspace=0.343, bottom=0.198
)
scatterwinner = 1.15 scatterwinner = 1.15
scatterloser = 1.85 scatterloser = 1.85
chirps_winner = np.asarray(chirps_winner)[~np.isnan(chirps_winner)] chirps_winner = np.asarray(chirps_winner)[~np.isnan(chirps_winner)]
chirps_loser = np.asarray(chirps_loser)[~np.isnan(chirps_loser)] chirps_loser = np.asarray(chirps_loser)[~np.isnan(chirps_loser)]
embed() embed()
exit() exit()
freq_diffs_higher = np.asarray( freq_diffs_higher = np.asarray(freq_diffs_higher)[
freq_diffs_higher)[~np.isnan(freq_diffs_higher)] ~np.isnan(freq_diffs_higher)
freq_diffs_lower = np.asarray(freq_diffs_lower)[ ]
~np.isnan(freq_diffs_lower)] freq_diffs_lower = np.asarray(freq_diffs_lower)[~np.isnan(freq_diffs_lower)]
freq_chirps_winner = np.asarray( freq_chirps_winner = np.asarray(freq_chirps_winner)[
freq_chirps_winner)[~np.isnan(freq_chirps_winner)] ~np.isnan(freq_chirps_winner)
freq_chirps_loser = np.asarray( ]
freq_chirps_loser)[~np.isnan(freq_chirps_loser)] freq_chirps_loser = np.asarray(freq_chirps_loser)[
~np.isnan(freq_chirps_loser)
]
stat = wilcoxon(chirps_winner, chirps_loser) stat = wilcoxon(chirps_winner, chirps_loser)
print(stat) print(stat)
winner_color = ps.gblue2 winner_color = ps.gblue2
loser_color = ps.gblue1 loser_color = ps.gblue1
bplot1 = ax1.boxplot(chirps_winner, positions=[ bplot1 = ax1.boxplot(
0.9], showfliers=False, patch_artist=True) chirps_winner, positions=[0.9], showfliers=False, patch_artist=True
)
bplot2 = ax1.boxplot(chirps_loser, positions=[
2.1], showfliers=False, patch_artist=True) bplot2 = ax1.boxplot(
chirps_loser, positions=[2.1], showfliers=False, patch_artist=True
ax1.scatter(np.ones(len(chirps_winner)) * )
scatterwinner, chirps_winner, color=winner_color)
ax1.scatter(np.ones(len(chirps_loser)) * ax1.scatter(
scatterloser, chirps_loser, color=loser_color) np.ones(len(chirps_winner)) * scatterwinner,
ax1.set_xticklabels(['Winner', 'Loser']) chirps_winner,
color=winner_color,
ax1.text(0.1, 0.85, f'n={len(chirps_loser)}', )
transform=ax1.transAxes, color=ps.white) ax1.scatter(
np.ones(len(chirps_loser)) * scatterloser,
chirps_loser,
color=loser_color,
)
ax1.set_xticklabels(["Winner", "Loser"])
ax1.text(
0.1,
0.85,
f"n={len(chirps_loser)}",
transform=ax1.transAxes,
color=ps.white,
)
for w, l in zip(chirps_winner, chirps_loser): for w, l in zip(chirps_winner, chirps_loser):
ax1.plot([scatterwinner, scatterloser], [w, l], ax1.plot(
color=ps.white, alpha=0.6, linewidth=1, zorder=-1) [scatterwinner, scatterloser],
ax1.set_ylabel('Chirp counts', color=ps.white) [w, l],
ax1.set_xlabel('Competition outcome', color=ps.white) color=ps.white,
alpha=0.6,
linewidth=1,
zorder=-1,
)
ax1.set_ylabel("Chirp counts", color=ps.white)
ax1.set_xlabel("Competition outcome", color=ps.white)
ps.set_boxplot_color(bplot1, winner_color) ps.set_boxplot_color(bplot1, winner_color)
ps.set_boxplot_color(bplot2, loser_color) ps.set_boxplot_color(bplot2, loser_color)
ax2.scatter(size_diffs_winner, size_chirps_winner, ax2.scatter(
color=winner_color, label='Winner') size_diffs_winner,
ax2.scatter(size_diffs_loser, size_chirps_loser, size_chirps_winner,
color=loser_color, label='Loser') color=winner_color,
label="Winner",
ax2.text(0.05, 0.85, f'n={len(size_chirps_loser)}', )
transform=ax2.transAxes, color=ps.white) ax2.scatter(
size_diffs_loser, size_chirps_loser, color=loser_color, label="Loser"
ax2.set_xlabel('Size difference [cm]') )
ax2.text(
0.05,
0.85,
f"n={len(size_chirps_loser)}",
transform=ax2.transAxes,
color=ps.white,
)
ax2.set_xlabel("Size difference [cm]")
# ax2.set_xticks(np.arange(-10, 10.1, 2)) # ax2.set_xticks(np.arange(-10, 10.1, 2))
ax3.scatter(freq_diffs_higher, freq_chirps_winner, color=winner_color) ax3.scatter(freq_diffs_higher, freq_chirps_winner, color=winner_color)
ax3.scatter(freq_diffs_lower, freq_chirps_loser, color=loser_color) ax3.scatter(freq_diffs_lower, freq_chirps_loser, color=loser_color)
ax3.text(0.1, 0.85, f'n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}', ax3.text(
transform=ax3.transAxes, color=ps.white) 0.1,
0.85,
f"n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}",
transform=ax3.transAxes,
color=ps.white,
)
ax3.set_xlabel('EODf [Hz]') ax3.set_xlabel("EODf [Hz]")
handles, labels = ax2.get_legend_handles_labels() handles, labels = ax2.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', fig.legend(
ncol=2, bbox_to_anchor=(0.5, 1.04)) handles, labels, loc="upper center", ncol=2, bbox_to_anchor=(0.5, 1.04)
)
# pearson r # pearson r
plt.savefig('../poster/figs/chirps_winner_loser.pdf') plt.savefig("../poster/figs/chirps_winner_loser.pdf")
plt.show() plt.show()
if __name__ == '__main__': if __name__ == "__main__":
# Path to the data # Path to the data
datapath = '../data/mount_data/' datapath = "../data/mount_data/"
main(datapath) main(datapath)

View File

@ -21,14 +21,16 @@ logger = makeLogger(__name__)
def main(datapath: str): def main(datapath: str):
foldernames = [ foldernames = [
datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] datapath + x + "/"
for x in os.listdir(datapath)
if os.path.isdir(datapath + x)
]
time_precents = [] time_precents = []
chirps_percents = [] chirps_percents = []
for foldername in foldernames: for foldername in foldernames:
# behabvior is pandas dataframe with all the data # behabvior is pandas dataframe with all the data
if foldername == '../data/mount_data/2020-05-12-10_00/': if foldername == "../data/mount_data/2020-05-12-10_00/":
continue continue
bh = Behavior(foldername) bh = Behavior(foldername)
@ -46,50 +48,70 @@ def main(datapath: str):
chirps_in_chasings = [] chirps_in_chasings = []
for onset, offset in zip(chasing_onset, chasing_offset): for onset, offset in zip(chasing_onset, chasing_offset):
chirps_in_chasing = [ chirps_in_chasing = [
c for c in bh.chirps if (c > onset) & (c < offset)] c for c in bh.chirps if (c > onset) & (c < offset)
]
chirps_in_chasings.append(chirps_in_chasing) chirps_in_chasings.append(chirps_in_chasing)
try: try:
time_chasing = np.sum( time_chasing = np.sum(
chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60]) chasing_offset[chasing_offset < 3 * 60 * 60]
- chasing_onset[chasing_onset < 3 * 60 * 60]
)
except: except:
time_chasing = np.sum( time_chasing = np.sum(
chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60][:-1]) chasing_offset[chasing_offset < 3 * 60 * 60]
- chasing_onset[chasing_onset < 3 * 60 * 60][:-1]
)
time_chasing_percent = (time_chasing/(3*60*60))*100 time_chasing_percent = (time_chasing / (3 * 60 * 60)) * 100
chirps_chasing = np.asarray(flatten(chirps_in_chasings)) chirps_chasing = np.asarray(flatten(chirps_in_chasings))
chirps_chasing_new = chirps_chasing[chirps_chasing < 3*60*60] chirps_chasing_new = chirps_chasing[chirps_chasing < 3 * 60 * 60]
chirps_percent = (len(chirps_chasing_new) / chirps_percent = (
len(bh.chirps[bh.chirps < 3*60*60]))*100 len(chirps_chasing_new) / len(bh.chirps[bh.chirps < 3 * 60 * 60])
) * 100
time_precents.append(time_chasing_percent) time_precents.append(time_chasing_percent)
chirps_percents.append(chirps_percent) chirps_percents.append(chirps_percent)
fig, ax = plt.subplots(1, 1, figsize=(7*ps.cm, 7*ps.cm)) fig, ax = plt.subplots(1, 1, figsize=(7 * ps.cm, 7 * ps.cm))
scatter_time = 1.20 scatter_time = 1.20
scatter_chirps = 1.80 scatter_chirps = 1.80
size = 10 size = 10
bplot1 = ax.boxplot([time_precents, chirps_percents], bplot1 = ax.boxplot(
showfliers=False, patch_artist=True) [time_precents, chirps_percents], showfliers=False, patch_artist=True
)
ps.set_boxplot_color(bplot1, ps.gray) ps.set_boxplot_color(bplot1, ps.gray)
ax.set_xticklabels(['Time \nchasing', 'Chirps \nin chasing']) ax.set_xticklabels(["Time \nchasing", "Chirps \nin chasing"])
ax.set_ylabel('Percent') ax.set_ylabel("Percent")
ax.scatter(np.ones(len(time_precents))*scatter_time, time_precents, ax.scatter(
facecolor=ps.white, s=size) np.ones(len(time_precents)) * scatter_time,
ax.scatter(np.ones(len(chirps_percents))*scatter_chirps, chirps_percents, time_precents,
facecolor=ps.white, s=size) facecolor=ps.white,
s=size,
)
ax.scatter(
np.ones(len(chirps_percents)) * scatter_chirps,
chirps_percents,
facecolor=ps.white,
s=size,
)
for i in range(len(time_precents)): for i in range(len(time_precents)):
ax.plot([scatter_time, scatter_chirps], [time_precents[i], ax.plot(
chirps_percents[i]], alpha=0.6, linewidth=1, color=ps.white) [scatter_time, scatter_chirps],
[time_precents[i], chirps_percents[i]],
ax.text(0.1, 0.9, f'n={len(time_precents)}', transform=ax.transAxes) alpha=0.6,
linewidth=1,
color=ps.white,
)
ax.text(0.1, 0.9, f"n={len(time_precents)}", transform=ax.transAxes)
plt.subplots_adjust(left=0.221, bottom=0.186, right=0.97, top=0.967) plt.subplots_adjust(left=0.221, bottom=0.186, right=0.97, top=0.967)
plt.savefig('../poster/figs/chirps_in_chasing.pdf') plt.savefig("../poster/figs/chirps_in_chasing.pdf")
plt.show() plt.show()
if __name__ == '__main__': if __name__ == "__main__":
# Path to the data # Path to the data
datapath = '../data/mount_data/' datapath = "../data/mount_data/"
main(datapath) main(datapath)

View File

@ -13,6 +13,7 @@ from modules.plotstyle import PlotStyle
from modules.behaviour_handling import Behavior, correct_chasing_events from modules.behaviour_handling import Behavior, correct_chasing_events
from extract_chirps import get_valid_datasets from extract_chirps import get_valid_datasets
ps = PlotStyle() ps = PlotStyle()
logger = makeLogger(__name__) logger = makeLogger(__name__)
@ -20,13 +21,16 @@ logger = makeLogger(__name__)
def main(datapath: str): def main(datapath: str):
foldernames = [ foldernames = [
datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)] datapath + x + "/"
for x in os.listdir(datapath)
if os.path.isdir(datapath + x)
]
foldernames, _ = get_valid_datasets(datapath) foldernames, _ = get_valid_datasets(datapath)
for foldername in foldernames[3:4]: for foldername in foldernames[3:4]:
print(foldername) print(foldername)
# foldername = foldernames[0] # foldername = foldernames[0]
if foldername == '../data/mount_data/2020-05-12-10_00/': if foldername == "../data/mount_data/2020-05-12-10_00/":
continue continue
# behabvior is pandas dataframe with all the data # behabvior is pandas dataframe with all the data
bh = Behavior(foldername) bh = Behavior(foldername)
@ -52,18 +56,43 @@ def main(datapath: str):
exit() exit()
fish1_color = ps.gblue2 fish1_color = ps.gblue2
fish2_color = ps.gblue1 fish2_color = ps.gblue1
fig, ax = plt.subplots(5, 1, figsize=( fig, ax = plt.subplots(
21*ps.cm, 10*ps.cm), height_ratios=[0.5, 0.5, 0.5, 0.2, 6], sharex=True) 5,
1,
figsize=(21 * ps.cm, 10 * ps.cm),
height_ratios=[0.5, 0.5, 0.5, 0.2, 6],
sharex=True,
)
# marker size # marker size
s = 80 s = 80
ax[0].scatter(physical_contact, np.ones( ax[0].scatter(
len(physical_contact)), color=ps.gray, marker='|', s=s) physical_contact,
ax[1].scatter(chasing_onset, np.ones(len(chasing_onset)), np.ones(len(physical_contact)),
color=ps.gray, marker='|', s=s) color=ps.gray,
ax[2].scatter(fish1, np.ones(len(fish1))-0.25, marker="|",
color=fish1_color, marker='|', s=s) s=s,
ax[2].scatter(fish2, np.zeros(len(fish2))+0.25, )
color=fish2_color, marker='|', s=s) ax[1].scatter(
chasing_onset,
np.ones(len(chasing_onset)),
color=ps.gray,
marker="|",
s=s,
)
ax[2].scatter(
fish1,
np.ones(len(fish1)) - 0.25,
color=fish1_color,
marker="|",
s=s,
)
ax[2].scatter(
fish2,
np.zeros(len(fish2)) + 0.25,
color=fish2_color,
marker="|",
s=s,
)
freq_temp = bh.freq[bh.ident == fish1_id] freq_temp = bh.freq[bh.ident == fish1_id]
time_temp = bh.time[bh.idx[bh.ident == fish1_id]] time_temp = bh.time[bh.idx[bh.ident == fish1_id]]
@ -94,35 +123,38 @@ def main(datapath: str):
ax[2].set_xticks([]) ax[2].set_xticks([])
ps.hide_ax(ax[2]) ps.hide_ax(ax[2])
ax[4].axvspan(0, 3, 0, 5, facecolor='grey', alpha=0.5) ax[4].axvspan(0, 3, 0, 5, facecolor="grey", alpha=0.5)
ax[4].set_xticks(np.arange(0, 6.1, 0.5)) ax[4].set_xticks(np.arange(0, 6.1, 0.5))
ps.hide_ax(ax[3]) ps.hide_ax(ax[3])
labelpad = 30 labelpad = 30
fsize = 12 fsize = 12
ax[0].set_ylabel('Contact', rotation=0, ax[0].set_ylabel(
labelpad=labelpad, fontsize=fsize) "Contact", rotation=0, labelpad=labelpad, fontsize=fsize
)
ax[0].yaxis.set_label_coords(-0.062, -0.08) ax[0].yaxis.set_label_coords(-0.062, -0.08)
ax[1].set_ylabel('Chasing', rotation=0, ax[1].set_ylabel(
labelpad=labelpad, fontsize=fsize) "Chasing", rotation=0, labelpad=labelpad, fontsize=fsize
)
ax[1].yaxis.set_label_coords(-0.06, -0.08) ax[1].yaxis.set_label_coords(-0.06, -0.08)
ax[2].set_ylabel('Chirps', rotation=0, ax[2].set_ylabel(
labelpad=labelpad, fontsize=fsize) "Chirps", rotation=0, labelpad=labelpad, fontsize=fsize
)
ax[2].yaxis.set_label_coords(-0.07, -0.08) ax[2].yaxis.set_label_coords(-0.07, -0.08)
ax[4].set_ylabel('EODf') ax[4].set_ylabel("EODf")
ax[4].set_xlabel('Time [h]') ax[4].set_xlabel("Time [h]")
# ax[0].set_title(foldername.split('/')[-2]) # ax[0].set_title(foldername.split('/')[-2])
# 2020-03-31-9_59 # 2020-03-31-9_59
plt.subplots_adjust(left=0.158, right=0.987, top=0.918, bottom=0.136) plt.subplots_adjust(left=0.158, right=0.987, top=0.918, bottom=0.136)
plt.savefig('../poster/figs/timeline.svg') plt.savefig("../poster/figs/timeline.svg")
plt.show() plt.show()
# plot chirps # plot chirps
if __name__ == '__main__': if __name__ == "__main__":
# Path to the data # Path to the data
datapath = '../data/mount_data/' datapath = "../data/mount_data/"
main(datapath) main(datapath)

View File

@ -11,7 +11,6 @@ ps = PlotStyle()
def main(): def main():
# Load data # Load data
datapath = "../data/2022-06-02-10_00/" datapath = "../data/2022-06-02-10_00/"
data = LoadData(datapath) data = LoadData(datapath)
@ -24,26 +23,31 @@ def main():
timescaler = 1000 timescaler = 1000
raw = data.raw[window_start_index:window_start_index + raw = data.raw[
window_duration_index, 10] window_start_index : window_start_index + window_duration_index, 10
]
fig, (ax1, ax2) = plt.subplots( fig, (ax1, ax2) = plt.subplots(
1, 2, figsize=(21 * ps.cm, 8*ps.cm), sharex=True, sharey=True) 1, 2, figsize=(21 * ps.cm, 8 * ps.cm), sharex=True, sharey=True
)
# plot instantaneous frequency # plot instantaneous frequency
filtered1 = bandpass_filter( filtered1 = bandpass_filter(
signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate) signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate
)
filtered2 = bandpass_filter( filtered2 = bandpass_filter(
signal=raw, lowf=550, highf=700, samplerate=data.raw_rate) signal=raw, lowf=550, highf=700, samplerate=data.raw_rate
)
freqtime1, freq1 = instantaneous_frequency( freqtime1, freq1 = instantaneous_frequency(
filtered1, data.raw_rate, smoothing_window=3) filtered1, data.raw_rate, smoothing_window=3
)
freqtime2, freq2 = instantaneous_frequency( freqtime2, freq2 = instantaneous_frequency(
filtered2, data.raw_rate, smoothing_window=3) filtered2, data.raw_rate, smoothing_window=3
)
ax1.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="Fish 1") ax1.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="Fish 1")
ax1.plot(freqtime2*timescaler, freq2, color=ps.gray, ax1.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="Fish 2")
lw=2, label="Fish 2")
# ax.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0) # ax.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
# # ps.hide_xax(ax1) # # ps.hide_xax(ax1)
@ -62,8 +66,8 @@ def main():
ax1.imshow( ax1.imshow(
decibel(spec_power[fmask, :]), decibel(spec_power[fmask, :]),
extent=[ extent=[
spec_times[0]*timescaler, spec_times[0] * timescaler,
spec_times[-1]*timescaler, spec_times[-1] * timescaler,
spec_freqs[fmask][0], spec_freqs[fmask][0],
spec_freqs[fmask][-1], spec_freqs[fmask][-1],
], ],
@ -87,8 +91,8 @@ def main():
ax2.imshow( ax2.imshow(
decibel(spec_power[fmask, :]), decibel(spec_power[fmask, :]),
extent=[ extent=[
spec_times[0]*timescaler, spec_times[0] * timescaler,
spec_times[-1]*timescaler, spec_times[-1] * timescaler,
spec_freqs[fmask][0], spec_freqs[fmask][0],
spec_freqs[fmask][-1], spec_freqs[fmask][-1],
], ],
@ -98,9 +102,8 @@ def main():
alpha=1, alpha=1,
) )
# ps.hide_xax(ax3) # ps.hide_xax(ax3)
ax2.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="_") ax2.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="_")
ax2.plot(freqtime2*timescaler, freq2, color=ps.gray, ax2.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="_")
lw=2, label="_")
ax2.set_xlim(75, 200) ax2.set_xlim(75, 200)
ax1.set_ylim(400, 1200) ax1.set_ylim(400, 1200)
@ -109,15 +112,22 @@ def main():
fig.supylabel("Frequency [Hz]", fontsize=14) fig.supylabel("Frequency [Hz]", fontsize=14)
handles, labels = ax1.get_legend_handles_labels() handles, labels = ax1.get_legend_handles_labels()
ax2.legend(handles, labels, bbox_to_anchor=(1.04, 1), loc="upper left", ncol=1,) ax2.legend(
handles,
labels,
bbox_to_anchor=(1.04, 1),
loc="upper left",
ncol=1,
)
ps.letter_subplots(xoffset=[-0.27, -0.1], yoffset=1.05) ps.letter_subplots(xoffset=[-0.27, -0.1], yoffset=1.05)
plt.subplots_adjust(left=0.12, right=0.85, top=0.89, plt.subplots_adjust(
bottom=0.18, hspace=0.35) left=0.12, right=0.85, top=0.89, bottom=0.18, hspace=0.35
)
plt.savefig('../poster/figs/introplot.pdf') plt.savefig("../poster/figs/introplot.pdf")
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -1,7 +1,9 @@
from modules.plotstyle import PlotStyle from modules.plotstyle import PlotStyle
from modules.behaviour_handling import ( from modules.behaviour_handling import (
Behavior, correct_chasing_events, center_chirps) Behavior,
correct_chasing_events,
center_chirps,
)
from modules.datahandling import flatten, causal_kde1d, acausal_kde1d from modules.datahandling import flatten, causal_kde1d, acausal_kde1d
from modules.logger import makeLogger from modules.logger import makeLogger
from pandas import read_csv from pandas import read_csv
@ -18,80 +20,93 @@ logger = makeLogger(__name__)
ps = PlotStyle() ps = PlotStyle()
def bootstrap(data, nresamples, kde_time, kernel_width, event_times, time_before, time_after): def bootstrap(
data,
nresamples,
kde_time,
kernel_width,
event_times,
time_before,
time_after,
):
bootstrapped_kdes = [] bootstrapped_kdes = []
data = data[data <= 3*60*60] # only night time data = data[data <= 3 * 60 * 60] # only night time
diff_data = np.diff(np.sort(data), prepend=0) diff_data = np.diff(np.sort(data), prepend=0)
# if len(data) != 0: # if len(data) != 0:
# mean_chirprate = (len(data) - 1) / (data[-1] - data[0]) # mean_chirprate = (len(data) - 1) / (data[-1] - data[0])
for i in tqdm(range(nresamples)): for i in tqdm(range(nresamples)):
np.random.shuffle(diff_data) np.random.shuffle(diff_data)
bootstrapped_data = np.cumsum(diff_data) bootstrapped_data = np.cumsum(diff_data)
# bootstrapped_data = data + np.random.randn(len(data)) * 10 # bootstrapped_data = data + np.random.randn(len(data)) * 10
bootstrap_data_centered = center_chirps( bootstrap_data_centered = center_chirps(
bootstrapped_data, event_times, time_before, time_after) bootstrapped_data, event_times, time_before, time_after
)
bootstrapped_kde = acausal_kde1d( bootstrapped_kde = acausal_kde1d(
bootstrap_data_centered, time=kde_time, width=kernel_width) bootstrap_data_centered, time=kde_time, width=kernel_width
)
bootstrapped_kde = list(np.asarray( bootstrapped_kde = list(np.asarray(bootstrapped_kde) / len(event_times))
bootstrapped_kde) / len(event_times))
bootstrapped_kdes.append(bootstrapped_kde) bootstrapped_kdes.append(bootstrapped_kde)
return bootstrapped_kdes return bootstrapped_kdes
def jackknife(data, nresamples, subsetsize, kde_time, kernel_width, event_times, time_before, time_after): def jackknife(
data,
nresamples,
subsetsize,
kde_time,
kernel_width,
event_times,
time_before,
time_after,
):
jackknife_kdes = [] jackknife_kdes = []
data = data[data <= 3*60*60] # only night time data = data[data <= 3 * 60 * 60] # only night time
subsetsize = int(len(data) * subsetsize) subsetsize = int(len(data) * subsetsize)
diff_data = np.diff(np.sort(data), prepend=0) diff_data = np.diff(np.sort(data), prepend=0)
for i in tqdm(range(nresamples)): for i in tqdm(range(nresamples)):
jackknifed_data = np.random.choice(diff_data, subsetsize, replace=False)
jackknifed_data = np.random.choice(
diff_data, subsetsize, replace=False)
jackknifed_data = np.cumsum(jackknifed_data) jackknifed_data = np.cumsum(jackknifed_data)
jackknifed_data_centered = center_chirps( jackknifed_data_centered = center_chirps(
jackknifed_data, event_times, time_before, time_after) jackknifed_data, event_times, time_before, time_after
)
jackknifed_kde = acausal_kde1d( jackknifed_kde = acausal_kde1d(
jackknifed_data_centered, time=kde_time, width=kernel_width) jackknifed_data_centered, time=kde_time, width=kernel_width
)
jackknifed_kde = list(np.asarray( jackknifed_kde = list(np.asarray(jackknifed_kde) / len(event_times))
jackknifed_kde) / len(event_times))
jackknife_kdes.append(jackknifed_kde) jackknife_kdes.append(jackknifed_kde)
return jackknife_kdes return jackknife_kdes
def get_chirp_winner_loser(folder_name, Behavior, order_meta_df): def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
foldername = folder_name.split("/")[-2]
foldername = folder_name.split('/')[-2] winner_row = order_meta_df[order_meta_df["recording"] == foldername]
winner_row = order_meta_df[order_meta_df['recording'] == foldername] winner = winner_row["winner"].values[0].astype(int)
winner = winner_row['winner'].values[0].astype(int) winner_fish1 = winner_row["fish1"].values[0].astype(int)
winner_fish1 = winner_row['fish1'].values[0].astype(int) winner_fish2 = winner_row["fish2"].values[0].astype(int)
winner_fish2 = winner_row['fish2'].values[0].astype(int)
if winner > 0: if winner > 0:
if winner == winner_fish1: if winner == winner_fish1:
winner_fish_id = winner_row['rec_id1'].values[0] winner_fish_id = winner_row["rec_id1"].values[0]
loser_fish_id = winner_row['rec_id2'].values[0] loser_fish_id = winner_row["rec_id2"].values[0]
elif winner == winner_fish2: elif winner == winner_fish2:
winner_fish_id = winner_row['rec_id2'].values[0] winner_fish_id = winner_row["rec_id2"].values[0]
loser_fish_id = winner_row['rec_id1'].values[0] loser_fish_id = winner_row["rec_id1"].values[0]
chirp_winner = Behavior.chirps[Behavior.chirps_ids == winner_fish_id] chirp_winner = Behavior.chirps[Behavior.chirps_ids == winner_fish_id]
chirp_loser = Behavior.chirps[Behavior.chirps_ids == loser_fish_id] chirp_loser = Behavior.chirps[Behavior.chirps_ids == loser_fish_id]
@ -101,7 +116,6 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
def main(dataroot): def main(dataroot):
foldernames, _ = np.asarray(get_valid_datasets(dataroot)) foldernames, _ = np.asarray(get_valid_datasets(dataroot))
plot_all = True plot_all = True
time_before = 90 time_before = 90
@ -111,10 +125,9 @@ def main(dataroot):
kde_time = np.arange(-time_before, time_after, dt) kde_time = np.arange(-time_before, time_after, dt)
nbootstraps = 50 nbootstraps = 50
meta_path = ( meta_path = ("/").join(foldernames[0].split("/")[:-2]) + "/order_meta.csv"
'/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv'
meta = pd.read_csv(meta_path) meta = pd.read_csv(meta_path)
meta['recording'] = meta['recording'].str[1:-1] meta["recording"] = meta["recording"].str[1:-1]
winner_onsets = [] winner_onsets = []
winner_offsets = [] winner_offsets = []
@ -143,24 +156,24 @@ def main(dataroot):
# loser_onset_chirpcount = 0 # loser_onset_chirpcount = 0
# loser_offset_chirpcount = 0 # loser_offset_chirpcount = 0
# loser_physical_chirpcount = 0 # loser_physical_chirpcount = 0
fig, ax = plt.subplots(1, 2, figsize=( fig, ax = plt.subplots(
14 * ps.cm, 7*ps.cm), sharey=True, sharex=True) 1, 2, figsize=(14 * ps.cm, 7 * ps.cm), sharey=True, sharex=True
)
# Iterate over all recordings and save chirp- and event-timestamps # Iterate over all recordings and save chirp- and event-timestamps
good_recs = np.asarray([0, 15]) good_recs = np.asarray([0, 15])
for i, folder in tqdm(enumerate(foldernames[good_recs])): for i, folder in tqdm(enumerate(foldernames[good_recs])):
foldername = folder.split("/")[-2]
foldername = folder.split('/')[-2]
# logger.info('Loading data from folder: {}'.format(foldername)) # logger.info('Loading data from folder: {}'.format(foldername))
broken_folders = ['../data/mount_data/2020-05-12-10_00/'] broken_folders = ["../data/mount_data/2020-05-12-10_00/"]
if folder in broken_folders: if folder in broken_folders:
continue continue
bh = Behavior(folder) bh = Behavior(folder)
category, timestamps = correct_chasing_events(bh.behavior, bh.start_s) category, timestamps = correct_chasing_events(bh.behavior, bh.start_s)
category = category[timestamps < 3*60*60] # only night time category = category[timestamps < 3 * 60 * 60] # only night time
timestamps = timestamps[timestamps < 3*60*60] # only night time timestamps = timestamps[timestamps < 3 * 60 * 60] # only night time
winner, loser = get_chirp_winner_loser(folder, bh, meta) winner, loser = get_chirp_winner_loser(folder, bh, meta)
if winner is None: if winner is None:
@ -168,27 +181,33 @@ def main(dataroot):
# winner_count += len(winner) # winner_count += len(winner)
# loser_count += len(loser) # loser_count += len(loser)
onsets = (timestamps[category == 0]) onsets = timestamps[category == 0]
offsets = (timestamps[category == 1]) offsets = timestamps[category == 1]
physicals = (timestamps[category == 2]) physicals = timestamps[category == 2]
onset_count += len(onsets) onset_count += len(onsets)
offset_count += len(offsets) offset_count += len(offsets)
physical_count += len(physicals) physical_count += len(physicals)
winner_onsets.append(center_chirps( winner_onsets.append(
winner, onsets, time_before, time_after)) center_chirps(winner, onsets, time_before, time_after)
winner_offsets.append(center_chirps( )
winner, offsets, time_before, time_after)) winner_offsets.append(
winner_physicals.append(center_chirps( center_chirps(winner, offsets, time_before, time_after)
winner, physicals, time_before, time_after)) )
winner_physicals.append(
loser_onsets.append(center_chirps( center_chirps(winner, physicals, time_before, time_after)
loser, onsets, time_before, time_after)) )
loser_offsets.append(center_chirps(
loser, offsets, time_before, time_after)) loser_onsets.append(
loser_physicals.append(center_chirps( center_chirps(loser, onsets, time_before, time_after)
loser, physicals, time_before, time_after)) )
loser_offsets.append(
center_chirps(loser, offsets, time_before, time_after)
)
loser_physicals.append(
center_chirps(loser, physicals, time_before, time_after)
)
# winner_onset_chirpcount += len(winner_onsets[-1]) # winner_onset_chirpcount += len(winner_onsets[-1])
# winner_offset_chirpcount += len(winner_offsets[-1]) # winner_offset_chirpcount += len(winner_offsets[-1])
@ -232,14 +251,17 @@ def main(dataroot):
# event_times=onsets, # event_times=onsets,
# time_before=time_before, # time_before=time_before,
# time_after=time_after)) # time_after=time_after))
loser_offsets_boot.append(bootstrap( loser_offsets_boot.append(
bootstrap(
loser, loser,
nresamples=nbootstraps, nresamples=nbootstraps,
kde_time=kde_time, kde_time=kde_time,
kernel_width=kernel_width, kernel_width=kernel_width,
event_times=offsets, event_times=offsets,
time_before=time_before, time_before=time_before,
time_after=time_after)) time_after=time_after,
)
)
# loser_physicals_boot.append(bootstrap( # loser_physicals_boot.append(bootstrap(
# loser, # loser,
# nresamples=nbootstraps, # nresamples=nbootstraps,
@ -249,18 +271,17 @@ def main(dataroot):
# time_before=time_before, # time_before=time_before,
# time_after=time_after)) # time_after=time_after))
# loser_offsets_jackknife = jackknife( # loser_offsets_jackknife = jackknife(
# loser, # loser,
# nresamples=nbootstraps, # nresamples=nbootstraps,
# subsetsize=0.9, # subsetsize=0.9,
# kde_time=kde_time, # kde_time=kde_time,
# kernel_width=kernel_width, # kernel_width=kernel_width,
# event_times=offsets, # event_times=offsets,
# time_before=time_before, # time_before=time_before,
# time_after=time_after) # time_after=time_after)
if plot_all: if plot_all:
# winner_onsets_conv = acausal_kde1d( # winner_onsets_conv = acausal_kde1d(
# winner_onsets[-1], kde_time, kernel_width) # winner_onsets[-1], kde_time, kernel_width)
# winner_offsets_conv = acausal_kde1d( # winner_offsets_conv = acausal_kde1d(
@ -271,24 +292,35 @@ def main(dataroot):
# loser_onsets_conv = acausal_kde1d( # loser_onsets_conv = acausal_kde1d(
# loser_onsets[-1], kde_time, kernel_width) # loser_onsets[-1], kde_time, kernel_width)
loser_offsets_conv = acausal_kde1d( loser_offsets_conv = acausal_kde1d(
loser_offsets[-1], kde_time, kernel_width) loser_offsets[-1], kde_time, kernel_width
)
# loser_physicals_conv = acausal_kde1d( # loser_physicals_conv = acausal_kde1d(
# loser_physicals[-1], kde_time, kernel_width) # loser_physicals[-1], kde_time, kernel_width)
ax[i].plot(kde_time, loser_offsets_conv / ax[i].plot(
len(offsets), lw=2, zorder=100, c=ps.gblue1) kde_time,
loser_offsets_conv / len(offsets),
lw=2,
zorder=100,
c=ps.gblue1,
)
ax[i].fill_between( ax[i].fill_between(
kde_time, kde_time,
np.percentile(loser_offsets_boot[-1], 1, axis=0), np.percentile(loser_offsets_boot[-1], 1, axis=0),
np.percentile(loser_offsets_boot[-1], 99, axis=0), np.percentile(loser_offsets_boot[-1], 99, axis=0),
color='gray', color="gray",
alpha=0.8) alpha=0.8,
)
ax[i].plot(kde_time, np.median(loser_offsets_boot[-1], axis=0), ax[i].plot(
color=ps.black, linewidth=2) kde_time,
np.median(loser_offsets_boot[-1], axis=0),
color=ps.black,
linewidth=2,
)
ax[i].axvline(0, color=ps.gray, linestyle='--') ax[i].axvline(0, color=ps.gray, linestyle="--")
# ax[i].fill_between( # ax[i].fill_between(
# kde_time, # kde_time,
@ -300,8 +332,8 @@ def main(dataroot):
# color=ps.white, linewidth=2) # color=ps.white, linewidth=2)
ax[i].set_xlim(-60, 60) ax[i].set_xlim(-60, 60)
fig.supylabel('Chirp rate (a.u.)', fontsize=14) fig.supylabel("Chirp rate (a.u.)", fontsize=14)
fig.supxlabel('Time (s)', fontsize=14) fig.supxlabel("Time (s)", fontsize=14)
# fig, ax = plt.subplots(2, 3, figsize=( # fig, ax = plt.subplots(2, 3, figsize=(
# 21*ps.cm, 10*ps.cm), sharey=True, sharex=True) # 21*ps.cm, 10*ps.cm), sharey=True, sharex=True)
@ -521,9 +553,9 @@ def main(dataroot):
# color=ps.gray, # color=ps.gray,
# alpha=0.5) # alpha=0.5)
plt.subplots_adjust(bottom=0.21, top=0.93) plt.subplots_adjust(bottom=0.21, top=0.93)
plt.savefig('../poster/figs/kde.pdf') plt.savefig("../poster/figs/kde.pdf")
plt.show() plt.show()
if __name__ == '__main__': if __name__ == "__main__":
main('../data/mount_data/') main("../data/mount_data/")