merge master

This commit is contained in:
2023-01-23 09:58:48 +01:00
19 changed files with 3859 additions and 526 deletions

View File

@@ -1,4 +1,5 @@
import os
import os
import numpy as np
import matplotlib.pyplot as plt
@@ -268,4 +269,5 @@ def main(datapath: str):
if __name__ == '__main__':
# Path to the data
datapath = '../data/mount_data/2020-05-13-10_00/'
datapath = '../data/mount_data/2020-05-13-10_00/'
main(datapath)

1099
code/chirpdetection.py Normal file → Executable file

File diff suppressed because it is too large Load Diff

View File

@@ -1,48 +1,46 @@
# directory setup
dataroot: "../data/"
outputdir: "../output/"
# Duration and overlap of the analysis window in seconds
window: 5
window: 10
overlap: 1
edge: 0.25
# Number of electrodes to go over
number_electrodes: 3
minimum_electrodes: 2
# Boundary for search frequency in Hz
search_boundary: 100
# Search window bandwidth and minimal baseline bandwidth
minimal_bandwidth: 20
# Cutoff frequency for envelope estimation by lowpass filter
envelope_cutoff: 25
# Instantaneous frequency smoothing usint a gaussian kernel of this width
baseline_frequency_smoothing: 5
# Cutoff frequency for envelope highpass filter
envelope_highpass_cutoff: 3
# Baseline processing parameters
baseline_envelope_cutoff: 25
baseline_envelope_bandpass_lowf: 4
baseline_envelope_bandpass_highf: 100
baseline_envelope_envelope_cutoff: 4
# Cutoff frequency for envelope of envelope
envelope_envelope_cutoff: 5
# search envelope processing parameters
search_envelope_cutoff: 5
# Instantaneous frequency bandpass filter cutoff frequencies
instantaneous_lowf: 15
instantaneous_highf: 8000
baseline_frequency_highpass_cutoff: 0.000005
baseline_frequency_envelope_cutoff: 0.000005
# Baseline envelope peak detection parameters
baseline_prominence_percentile: 90
# Search envelope peak detection parameters
search_prominence_percentile: 90
# Instantaneous frequency peak detection parameters
instantaneous_prominence_percentile: 90
# peak detecion parameters
prominence: 0.005
# search freq parameter
search_df_lower: 25
search_df_lower: 20
search_df_upper: 100
search_res: 1
search_freq_percentiles:
- 5
- 95
search_bandwidth: 10
default_search_freq: 50
# Classify events as chirps if they are less than this time apart
chirp_window_threshold: 0.05

View File

@@ -1,5 +1,78 @@
import numpy as np
from typing import List, Any
from scipy.ndimage import gaussian_filter1d
from scipy.stats import gamma, norm
def scale01(data):
"""
Normalize data to [0, 1]
Parameters
----------
data : np.ndarray
Data to normalize.
Returns
-------
np.ndarray
Normalized data.
"""
return (2*((data - np.min(data)) / (np.max(data) - np.min(data)))) - 1
def instantaneous_frequency(
signal: np.ndarray,
samplerate: int,
smoothing_window: int,
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the instantaneous frequency of a signal that is approximately
sinusoidal and symmetric around 0.
Parameters
----------
signal : np.ndarray
Signal to compute the instantaneous frequency from.
samplerate : int
Samplerate of the signal.
smoothing_window : int
Window size for the gaussian filter.
Returns
-------
tuple[np.ndarray, np.ndarray]
"""
# calculate instantaneous frequency with zero crossings
roll_signal = np.roll(signal, shift=1)
time_signal = np.arange(len(signal)) / samplerate
period_index = np.arange(len(signal))[(roll_signal < 0) & (signal >= 0)][
1:-1
]
upper_bound = np.abs(signal[period_index])
lower_bound = np.abs(signal[period_index - 1])
upper_time = np.abs(time_signal[period_index])
lower_time = np.abs(time_signal[period_index - 1])
# create ratio
lower_ratio = lower_bound / (lower_bound + upper_bound)
# appy to time delta
time_delta = upper_time - lower_time
true_zero = lower_time + lower_ratio * time_delta
# create new time array
instantaneous_frequency_time = true_zero[:-1] + 0.5 * np.diff(true_zero)
# compute frequency
instantaneous_frequency = gaussian_filter1d(
1 / np.diff(true_zero), smoothing_window
)
return instantaneous_frequency_time, instantaneous_frequency
def purge_duplicates(
@@ -64,7 +137,7 @@ def purge_duplicates(
def group_timestamps(
sublists: List[List[float]], n: int, threshold: float
sublists: List[List[float]], at_least_in: int, difference_threshold: float
) -> List[float]:
"""
Groups timestamps that are less than `threshold` milliseconds apart from
@@ -100,7 +173,7 @@ def group_timestamps(
# Group timestamps that are less than threshold milliseconds apart
for i in range(1, len(timestamps)):
if timestamps[i] - timestamps[i - 1] < threshold:
if timestamps[i] - timestamps[i - 1] < difference_threshold:
current_group.append(timestamps[i])
else:
groups.append(current_group)
@@ -111,7 +184,7 @@ def group_timestamps(
# Retain only groups that contain at least n timestamps
final_groups = []
for group in groups:
if len(group) >= n:
if len(group) >= at_least_in:
final_groups.append(group)
# Calculate the mean of each group
@@ -137,6 +210,117 @@ def flatten(list: List[List[Any]]) -> List:
return [item for sublist in list for item in sublist]
def causal_kde1d(spikes, time, width, shape=2):
"""
causalkde computes a kernel density estimate using a causal kernel (i.e. exponential or gamma distribution).
A shape of 1 turns the gamma distribution into an exponential.
Parameters
----------
spikes : array-like
spike times
time : array-like
sampling time
width : float
kernel width
shape : int, optional
shape of gamma distribution, by default 1
Returns
-------
rate : array-like
instantaneous firing rate
"""
# compute dt
dt = time[1] - time[0]
# time on which to compute kernel:
tmax = 10 * width
# kernel not wider than time
if 2 * tmax > time[-1] - time[0]:
tmax = 0.5 * (time[-1] - time[0])
# kernel time
ktime = np.arange(-tmax, tmax, dt)
# gamma kernel centered in ktime:
kernel = gamma.pdf(
x=ktime,
a=shape,
loc=0,
scale=width,
)
# indices of spikes in time array:
indices = np.asarray((spikes - time[0]) / dt, dtype=int)
# binary spike train:
brate = np.zeros(len(time))
brate[indices[(indices >= 0) & (indices < len(time))]] = 1.0
# convolution with kernel:
rate = np.convolve(brate, kernel, mode="same")
return rate
def acausal_kde1d(spikes, time, width):
"""
causalkde computes a kernel density estimate using a causal kernel (i.e. exponential or gamma distribution).
A shape of 1 turns the gamma distribution into an exponential.
Parameters
----------
spikes : array-like
spike times
time : array-like
sampling time
width : float
kernel width
shape : int, optional
shape of gamma distribution, by default 1
Returns
-------
rate : array-like
instantaneous firing rate
"""
# compute dt
dt = time[1] - time[0]
# time on which to compute kernel:
tmax = 10 * width
# kernel not wider than time
if 2 * tmax > time[-1] - time[0]:
tmax = 0.5 * (time[-1] - time[0])
# kernel time
ktime = np.arange(-tmax, tmax, dt)
# gamma kernel centered in ktime:
kernel = norm.pdf(
x=ktime,
loc=0,
scale=width,
)
# indices of spikes in time array:
indices = np.asarray((spikes - time[0]) / dt, dtype=int)
# binary spike train:
brate = np.zeros(len(time))
brate[indices[(indices >= 0) & (indices < len(time))]] = 1.0
# convolution with kernel:
rate = np.convolve(brate, kernel, mode="same")
return rate
if __name__ == "__main__":
timestamps = [

View File

@@ -3,8 +3,8 @@ import numpy as np
def bandpass_filter(
data: np.ndarray,
rate: float,
signal: np.ndarray,
samplerate: float,
lowf: float,
highf: float,
) -> np.ndarray:
@@ -12,7 +12,7 @@ def bandpass_filter(
Parameters
----------
data : np.ndarray
signal : np.ndarray
The data to be filtered
rate : float
The sampling rate
@@ -26,21 +26,22 @@ def bandpass_filter(
np.ndarray
The filtered data
"""
sos = butter(2, (lowf, highf), "bandpass", fs=rate, output="sos")
fdata = sosfiltfilt(sos, data)
return fdata
sos = butter(2, (lowf, highf), "bandpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def highpass_filter(
data: np.ndarray,
rate: float,
signal: np.ndarray,
samplerate: float,
cutoff: float,
) -> np.ndarray:
"""Highpass filter a signal.
Parameters
----------
data : np.ndarray
signal : np.ndarray
The data to be filtered
rate : float
The sampling rate
@@ -52,14 +53,15 @@ def highpass_filter(
np.ndarray
The filtered data
"""
sos = butter(2, cutoff, "highpass", fs=rate, output="sos")
fdata = sosfiltfilt(sos, data)
return fdata
sos = butter(2, cutoff, "highpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def lowpass_filter(
data: np.ndarray,
rate: float,
signal: np.ndarray,
samplerate: float,
cutoff: float
) -> np.ndarray:
"""Lowpass filter a signal.
@@ -78,21 +80,25 @@ def lowpass_filter(
np.ndarray
The filtered data
"""
sos = butter(2, cutoff, "lowpass", fs=rate, output="sos")
fdata = sosfiltfilt(sos, data)
return fdata
sos = butter(2, cutoff, "lowpass", fs=samplerate, output="sos")
filtered_signal = sosfiltfilt(sos, signal)
return filtered_signal
def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
def envelope(signal: np.ndarray,
samplerate: float,
cutoff_frequency: float
) -> np.ndarray:
"""Calculate the envelope of a signal using a lowpass filter.
Parameters
----------
data : np.ndarray
signal : np.ndarray
The signal to calculate the envelope of
rate : float
samplingrate : float
The sampling rate of the signal
freq : float
cutoff_frequency : float
The cutoff frequency of the lowpass filter
Returns
@@ -100,6 +106,7 @@ def envelope(data: np.ndarray, rate: float, freq: float) -> np.ndarray:
np.ndarray
The envelope of the signal
"""
sos = butter(2, freq, "lowpass", fs=rate, output="sos")
envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(data))
sos = butter(2, cutoff_frequency, "lowpass", fs=samplerate, output="sos")
envelope = np.sqrt(2) * sosfiltfilt(sos, np.abs(signal))
return envelope

View File

@@ -30,10 +30,14 @@ def PlotStyle() -> None:
purple = "#cba6f7"
pink = "#f5c2e7"
lavender = "#b4befe"
gblue1 = "#8cb8ff"
gblue2 = "#7cdcdc"
gblue3 = "#82e896"
@classmethod
def lims(cls, track1, track2):
"""Helper function to get frequency y axis limits from two fundamental frequency tracks.
"""Helper function to get frequency y axis limits from two
fundamental frequency tracks.
Args:
track1 (array): First track
@@ -91,6 +95,16 @@ def PlotStyle() -> None:
ax.tick_params(left=False, labelleft=False)
ax.patch.set_visible(False)
@classmethod
def hide_xax(cls, ax):
ax.xaxis.set_visible(False)
ax.spines["bottom"].set_visible(False)
@classmethod
def hide_yax(cls, ax):
ax.yaxis.set_visible(False)
ax.spines["left"].set_visible(False)
@classmethod
def set_boxplot_color(cls, bp, color):
plt.setp(bp["boxes"], color=color)
@@ -216,8 +230,8 @@ def PlotStyle() -> None:
plt.rc("figure", titlesize=BIGGER_SIZE) # fontsize of the figure title
plt.rcParams["image.cmap"] = 'cmo.haline'
# plt.rcParams["axes.xmargin"] = 0.1
# plt.rcParams["axes.ymargin"] = 0.15
plt.rcParams["axes.xmargin"] = 0.05
plt.rcParams["axes.ymargin"] = 0.1
plt.rcParams["axes.titlelocation"] = "left"
plt.rcParams["axes.titlesize"] = BIGGER_SIZE
# plt.rcParams["axes.titlepad"] = -10
@@ -230,9 +244,9 @@ def PlotStyle() -> None:
plt.rcParams["legend.borderaxespad"] = 0.5
plt.rcParams["legend.fancybox"] = False
# specify the custom font to use
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Helvetica Now Text"
# # specify the custom font to use
# plt.rcParams["font.family"] = "sans-serif"
# plt.rcParams["font.sans-serif"] = "Helvetica Now Text"
# dark mode modifications
plt.rcParams["boxplot.flierprops.color"] = white
@@ -271,7 +285,7 @@ def PlotStyle() -> None:
plt.rcParams["ytick.color"] = gray # color of the ticks
plt.rcParams["grid.color"] = dark_gray # grid color
plt.rcParams["figure.facecolor"] = black # figure face color
plt.rcParams["figure.edgecolor"] = "#555169" # figure edge color
plt.rcParams["figure.edgecolor"] = black # figure edge color
plt.rcParams["savefig.facecolor"] = black # figure face color when saving
return style

View File

@@ -0,0 +1,121 @@
import numpy as np
import matplotlib.pyplot as plt
from thunderfish.powerspectrum import spectrogram, decibel
from modules.filehandling import LoadData
from modules.datahandling import instantaneous_frequency
from modules.filters import bandpass_filter
from modules.plotstyle import PlotStyle
ps = PlotStyle()
def main():
# Load data
datapath = "../data/2022-06-02-10_00/"
data = LoadData(datapath)
# good chirp times for data: 2022-06-02-10_00
window_start_seconds = 3 * 60 * 60 + 6 * 60 + 43.5 + 9 + 6.25
window_start_index = window_start_seconds * data.raw_rate
window_duration_seconds = 0.2
window_duration_index = window_duration_seconds * data.raw_rate
timescaler = 1000
raw = data.raw[window_start_index:window_start_index +
window_duration_index, 10]
fig, (ax1, ax2, ax3) = plt.subplots(
3, 1, figsize=(12 * ps.cm, 10*ps.cm), sharex=True, sharey=True)
# plot instantaneous frequency
filtered1 = bandpass_filter(
signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate)
filtered2 = bandpass_filter(
signal=raw, lowf=550, highf=700, samplerate=data.raw_rate)
freqtime1, freq1 = instantaneous_frequency(
filtered1, data.raw_rate, smoothing_window=3)
freqtime2, freq2 = instantaneous_frequency(
filtered2, data.raw_rate, smoothing_window=3)
ax1.plot(freqtime1*timescaler, freq1, color=ps.gblue1,
lw=2, label=f"fish 1, {np.median(freq1):.0f} Hz")
ax1.plot(freqtime2*timescaler, freq2, color=ps.gblue3,
lw=2, label=f"fish 2, {np.median(freq2):.0f} Hz")
ax1.legend(bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower center",
mode="normal", borderaxespad=0, ncol=2)
ps.hide_xax(ax1)
# plot fine spectrogram
spec_power, spec_freqs, spec_times = spectrogram(
raw,
ratetime=data.raw_rate,
freq_resolution=150,
overlap_frac=0.2,
)
ylims = [300, 1200]
fmask = np.zeros(spec_freqs.shape, dtype=bool)
fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True
ax2.imshow(
decibel(spec_power[fmask, :]),
extent=[
spec_times[0]*timescaler,
spec_times[-1]*timescaler,
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
aspect="auto",
origin="lower",
interpolation="gaussian",
alpha=1,
)
ps.hide_xax(ax2)
# plot coarse spectrogram
spec_power, spec_freqs, spec_times = spectrogram(
raw,
ratetime=data.raw_rate,
freq_resolution=10,
overlap_frac=0.3,
)
fmask = np.zeros(spec_freqs.shape, dtype=bool)
fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True
ax3.imshow(
decibel(spec_power[fmask, :]),
extent=[
spec_times[0]*timescaler,
spec_times[-1]*timescaler,
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
aspect="auto",
origin="lower",
interpolation="gaussian",
alpha=1,
)
# ps.hide_xax(ax3)
ax3.set_xlabel("time [ms]")
ax2.set_ylabel("frequency [Hz]")
ax1.set_yticks(np.arange(400, 1201, 400))
ax1.spines.left.set_bounds((400, 1200))
ax2.set_yticks(np.arange(400, 1201, 400))
ax2.spines.left.set_bounds((400, 1200))
ax3.set_yticks(np.arange(400, 1201, 400))
ax3.spines.left.set_bounds((400, 1200))
plt.subplots_adjust(left=0.17, right=0.98, top=0.9,
bottom=0.14, hspace=0.35)
plt.savefig('../poster/figs/introplot.pdf')
plt.show()
if __name__ == '__main__':
main()