Merge branch 'master' into behaviour

This commit is contained in:
wendtalexander
2023-01-23 09:41:37 +01:00
17 changed files with 3174 additions and 197 deletions

View File

@@ -2,7 +2,9 @@ from itertools import compress
from dataclasses import dataclass
import numpy as np
from IPython import embed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gr
from scipy.signal import find_peaks
from thunderfish.powerspectrum import spectrogram, decibel
from sklearn.preprocessing import normalize
@@ -40,9 +42,12 @@ class PlotBuffer:
time: np.ndarray
baseline: np.ndarray
baseline_envelope_unfiltered: np.ndarray
baseline_envelope: np.ndarray
baseline_peaks: np.ndarray
search_frequency: float
search: np.ndarray
search_envelope_unfiltered: np.ndarray
search_envelope: np.ndarray
search_peaks: np.ndarray
@@ -57,84 +62,202 @@ class PlotBuffer:
# make data for plotting
# # get index of track data in this time window
# window_idx = np.arange(len(self.data.idx))[
# (self.data.ident == self.track_id) & (self.data.time[self.data.idx] >= self.t0) & (
# self.data.time[self.data.idx] <= (self.t0 + self.dt))
# ]
# get index of track data in this time window
window_idx = np.arange(len(self.data.idx))[
(self.data.ident == self.track_id)
& (self.data.time[self.data.idx] >= self.t0)
& (self.data.time[self.data.idx] <= (self.t0 + self.dt))
]
# get tracked frequencies and their times
# freq_temp = self.data.freq[window_idx]
# time_temp = self.data.times[window_idx]
freq_temp = self.data.freq[window_idx]
# time_temp = self.data.time[
# self.data.idx[self.data.ident == self.track_id]][
# (self.data.time >= self.t0)
# & (self.data.time <= (self.t0 + self.dt))
# ]
# remake the band we filtered in
q25, q50, q75 = np.percentile(freq_temp, [25, 50, 75])
search_upper, search_lower = (
q50 + self.search_frequency + self.config.minimal_bandwidth / 2,
q50 + self.search_frequency - self.config.minimal_bandwidth / 2,
)
# get indices on raw data
start_idx = self.t0 * self.data.raw_rate
window_duration = self.dt * self.data.raw_rate
start_idx = (self.t0 - 5) * self.data.raw_rate
window_duration = (self.dt + 10) * self.data.raw_rate
stop_idx = start_idx + window_duration
# get raw data
data_oi = self.data.raw[start_idx:stop_idx, self.electrode]
fig, axs = plt.subplots(
7,
1,
figsize=(20 / 2.54, 12 / 2.54),
constrained_layout=True,
sharex=True,
sharey="row",
self.time = self.time - self.t0
self.frequency_time = self.frequency_time - self.t0
chirps = np.asarray(chirps) - self.t0
self.t0_old = self.t0
self.t0 = 0
fig = plt.figure(
figsize=(14 / 2.54, 20 / 2.54)
)
gs0 = gr.GridSpec(
3, 1, figure=fig, height_ratios=[1, 1, 1]
)
gs1 = gs0[0].subgridspec(1, 1)
gs2 = gs0[1].subgridspec(3, 1, hspace=0.4)
gs3 = gs0[2].subgridspec(3, 1, hspace=0.4)
# gs4 = gs0[5].subgridspec(1, 1)
ax6 = fig.add_subplot(gs3[2, 0])
ax0 = fig.add_subplot(gs1[0, 0], sharex=ax6)
ax1 = fig.add_subplot(gs2[0, 0], sharex=ax6)
ax2 = fig.add_subplot(gs2[1, 0], sharex=ax6)
ax3 = fig.add_subplot(gs2[2, 0], sharex=ax6)
ax4 = fig.add_subplot(gs3[0, 0], sharex=ax6)
ax5 = fig.add_subplot(gs3[1, 0], sharex=ax6)
# ax7 = fig.add_subplot(gs4[0, 0], sharex=ax0)
# ax_leg = fig.add_subplot(gs0[1, 0])
waveform_scaler = 1000
lw = 1.5
# plot spectrogram
plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0)
_ = plot_spectrogram(
ax0,
data_oi,
self.data.raw_rate,
self.t0 - 5,
[np.max(self.frequency) - 200, np.max(self.frequency) + 200]
)
for track_id in self.data.ids:
t0_track = self.t0_old - 5
dt_track = self.dt + 10
window_idx = np.arange(len(self.data.idx))[
(self.data.ident == track_id)
& (self.data.time[self.data.idx] >= t0_track)
& (self.data.time[self.data.idx] <= (t0_track + dt_track))
]
# get tracked frequencies and their times
f = self.data.freq[window_idx]
t = self.data.time[
self.data.idx[self.data.ident == self.track_id]]
tmask = (t >= t0_track) & (t <= (t0_track + dt_track))
if track_id == self.track_id:
ax0.plot(t[tmask]-self.t0_old, f, lw=lw,
zorder=10, color=ps.gblue1)
else:
ax0.plot(t[tmask]-self.t0_old, f, lw=lw,
zorder=10, color=ps.gray, alpha=0.5)
ax0.fill_between(
np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
q50 - self.config.minimal_bandwidth / 2,
q50 + self.config.minimal_bandwidth / 2,
color=ps.gblue1,
lw=1,
ls="dashed",
alpha=0.5,
)
ax0.fill_between(
np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
search_lower,
search_upper,
color=ps.gblue2,
lw=1,
ls="dashed",
alpha=0.5,
)
# ax0.axhline(q50, spec_times[0], spec_times[-1],
# color=ps.gblue1, lw=2, ls="dashed")
# ax0.axhline(q50 + self.search_frequency,
# spec_times[0], spec_times[-1],
# color=ps.gblue2, lw=2, ls="dashed")
for chirp in chirps:
axs[0].scatter(
chirp, np.median(self.frequency), c=ps.black, marker="x"
ax0.scatter(
chirp, np.median(self.frequency) + 150, c=ps.black, marker="v"
)
# plot waveform of filtered signal
axs[1].plot(self.time, self.baseline, c=ps.green)
ax1.plot(self.time, self.baseline * waveform_scaler,
c=ps.gray, lw=lw, alpha=0.5)
ax1.plot(self.time, self.baseline_envelope_unfiltered *
waveform_scaler, c=ps.gblue1, lw=lw, label="baseline envelope")
# plot waveform of filtered search signal
axs[2].plot(self.time, self.search)
ax2.plot(self.time, self.search * waveform_scaler,
c=ps.gray, lw=lw, alpha=0.5)
ax2.plot(self.time, self.search_envelope_unfiltered *
waveform_scaler, c=ps.gblue2, lw=lw, label="search envelope")
# plot baseline instantaneous frequency
axs[3].plot(self.frequency_time, self.frequency)
ax3.plot(self.frequency_time, self.frequency,
c=ps.gblue3, lw=lw, label="baseline inst. freq.")
# plot filtered and rectified envelope
axs[4].plot(self.time, self.baseline_envelope)
axs[4].scatter(
ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=lw)
ax4.scatter(
(self.time)[self.baseline_peaks],
self.baseline_envelope[self.baseline_peaks],
c=ps.red,
edgecolors=ps.red,
zorder=10,
marker="o",
facecolors="none",
)
# plot envelope of search signal
axs[5].plot(self.time, self.search_envelope)
axs[5].scatter(
ax5.plot(self.time, self.search_envelope, c=ps.gblue2, lw=lw)
ax5.scatter(
(self.time)[self.search_peaks],
self.search_envelope[self.search_peaks],
c=ps.red,
edgecolors=ps.red,
zorder=10,
marker="o",
facecolors="none",
)
# plot filtered instantaneous frequency
axs[6].plot(self.frequency_time, self.frequency_filtered)
axs[6].scatter(
ax6.plot(self.frequency_time,
self.frequency_filtered, c=ps.gblue3, lw=lw)
ax6.scatter(
self.frequency_time[self.frequency_peaks],
self.frequency_filtered[self.frequency_peaks],
c=ps.red,
edgecolors=ps.red,
zorder=10,
marker="o",
facecolors="none",
)
axs[0].set_ylim(
np.max(self.frequency) - 200, top=np.max(self.frequency) + 200
)
axs[6].set_xlabel("Time [s]")
axs[0].set_title("Spectrogram")
axs[1].set_title("Fitered baseline")
axs[2].set_title("Fitered above")
axs[3].set_title("Fitered baseline instanenous frequency")
axs[4].set_title("Filtered envelope of baseline envelope")
axs[5].set_title("Search envelope")
axs[6].set_title("Filtered absolute instantaneous frequency")
ax0.set_ylabel("frequency [Hz]")
ax1.set_ylabel("a.u.")
ax2.set_ylabel("a.u.")
ax3.set_ylabel("Hz")
ax5.set_ylabel("a.u.")
ax6.set_xlabel("time [s]")
plt.setp(ax0.get_xticklabels(), visible=False)
plt.setp(ax1.get_xticklabels(), visible=False)
plt.setp(ax2.get_xticklabels(), visible=False)
plt.setp(ax3.get_xticklabels(), visible=False)
plt.setp(ax4.get_xticklabels(), visible=False)
plt.setp(ax5.get_xticklabels(), visible=False)
# ps.letter_subplots([ax0, ax1, ax4], xoffset=-0.21)
# ax7.set_xticks(np.arange(0, 5.5, 1))
# ax7.spines.bottom.set_bounds((0, 5))
ax0.set_xlim(0, self.config.window)
plt.subplots_adjust(left=0.165, right=0.975,
top=0.98, bottom=0.074, hspace=0.2)
fig.align_labels()
if plot == "show":
plt.show()
@@ -144,13 +267,18 @@ class PlotBuffer:
self.config.outputdir + self.data.datapath.split("/")[-2] + "/"
)
plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf")
plt.savefig(f"{out}{self.track_id}_{self.t0_old}.pdf")
plt.savefig(f"{out}{self.track_id}_{self.t0_old}.svg")
plt.close()
def plot_spectrogram(
axis, signal: np.ndarray, samplerate: float, window_start_seconds: float
) -> None:
axis,
signal: np.ndarray,
samplerate: float,
window_start_seconds: float,
ylims: list[float]
) -> np.ndarray:
"""
Plot a spectrogram of a signal.
@@ -172,22 +300,28 @@ def plot_spectrogram(
spec_power, spec_freqs, spec_times = spectrogram(
signal,
ratetime=samplerate,
freq_resolution=20,
freq_resolution=10,
overlap_frac=0.5,
)
fmask = np.zeros(spec_freqs.shape, dtype=bool)
fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True
axis.imshow(
decibel(spec_power),
decibel(spec_power[fmask, :]),
extent=[
spec_times[0] + window_start_seconds,
spec_times[-1] + window_start_seconds,
spec_freqs[0],
spec_freqs[-1],
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
aspect="auto",
origin="lower",
interpolation="gaussian",
alpha=1,
)
# axis.use_sticky_edges = False
return spec_times
def extract_frequency_bands(
@@ -244,9 +378,16 @@ def extract_frequency_bands(
def window_median_all_track_ids(
data: LoadData, window_start_seconds: float, window_duration_seconds: float
) -> tuple[float, list[int]]:
) -> tuple[list[tuple[float, float, float]], list[int]]:
"""
Calculate the median frequency of all fish in a given time window.
Calculate the median and quantiles of the frequency of all fish in a
given time window.
Iterate over all track ids and calculate the 25, 50 and 75 percentile
in a given time window to pass this data to 'find_searchband' function,
which then determines whether other fish in the current window fall
within the searchband of the current fish and then determine the
gaps that are outside of the percentile ranges.
Parameters
----------
@@ -259,14 +400,16 @@ def window_median_all_track_ids(
Returns
-------
tuple[float, list[int]]
tuple[list[tuple[float, float, float]], list[int]]
"""
median_freq = []
frequency_percentiles = []
track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
# the window index combines the track id and the time window
window_idx = np.arange(len(data.idx))[
(data.ident == track_id)
& (data.time[data.idx] >= window_start_seconds)
@@ -277,20 +420,21 @@ def window_median_all_track_ids(
]
if len(data.freq[window_idx]) > 0:
median_freq.append(np.median(data.freq[window_idx]))
frequency_percentiles.append(
np.percentile(data.freq[window_idx], [25, 50, 75]))
track_ids.append(track_id)
# convert to numpy array
median_freq = np.asarray(median_freq)
frequency_percentiles = np.asarray(frequency_percentiles)
track_ids = np.asarray(track_ids)
return median_freq, track_ids
return frequency_percentiles, track_ids
def find_searchband(
freq_temp: np.ndarray,
median_ids: np.ndarray,
median_freq: np.ndarray,
current_frequency: np.ndarray,
percentiles_ids: np.ndarray,
frequency_percentiles: np.ndarray,
config: ConfLoader,
data: LoadData,
) -> float:
@@ -300,13 +444,13 @@ def find_searchband(
Parameters
----------
freq_temp : np.ndarray
current_frequency : np.ndarray
Current EOD frequency array / the current fish of interest.
median_ids : np.ndarray
percentiles_ids : np.ndarray
Array of track IDs of the medians of all other fish in the current
window.
median_freq : np.ndarray
Array of median frequencies of all other fish in the current window.
frequency_percentiles : np.ndarray
Array of percentiles frequencies of all other fish in the current window.
config : ConfLoader
Configuration file.
data : LoadData
@@ -317,19 +461,27 @@ def find_searchband(
float
"""
# frequency where second filter filters
# frequency window where second filter filters is potentially allowed
# to filter. This is the search window, in which we want to find
# a gap in the other fish's EODs.
search_window = np.arange(
np.median(freq_temp) + config.search_df_lower,
np.median(freq_temp) + config.search_df_upper,
np.median(current_frequency) + config.search_df_lower,
np.median(current_frequency) + config.search_df_upper,
config.search_res,
)
# search window in boolean
search_window_bool = np.ones(len(search_window), dtype=bool)
search_window_bool = np.ones_like(len(search_window), dtype=bool)
# make seperate arrays from the qartiles
q25 = np.asarray([i[0] for i in frequency_percentiles])
q75 = np.asarray([i[2] for i in frequency_percentiles])
# get tracks that fall into search window
check_track_ids = median_ids[
(median_freq > search_window[0]) & (median_freq < search_window[-1])
check_track_ids = percentiles_ids[
(q25 > search_window[0]) & (
q75 < search_window[-1])
]
# iterate through theses tracks
@@ -337,19 +489,22 @@ def find_searchband(
for j, check_track_id in enumerate(check_track_ids):
q1, q2 = np.percentile(
data.freq[data.ident == check_track_id],
config.search_freq_percentiles,
)
q25_temp = q25[percentiles_ids == check_track_id]
q75_temp = q75[percentiles_ids == check_track_id]
print(q25_temp, q75_temp)
search_window_bool[
(search_window > q1) & (search_window < q2)
(search_window > q25_temp) & (search_window < q75_temp)
] = False
# find gaps in search window
search_window_indices = np.arange(len(search_window))
# get search window gaps
# taking the diff of a boolean array gives non zero values where the
# array changes from true to false or vice versa
search_window_gaps = np.diff(search_window_bool, append=np.nan)
nonzeros = search_window_gaps[np.nonzero(search_window_gaps)[0]]
nonzeros = nonzeros[~np.isnan(nonzeros)]
@@ -385,14 +540,16 @@ def find_searchband(
search_windows_lens = [len(x) for x in search_windows]
longest_search_window = search_windows[np.argmax(search_windows_lens)]
# the center of the search frequency band is then the center of
# the longest gap
search_freq = (
longest_search_window[-1] - longest_search_window[0]
) / 2
else:
search_freq = config.default_search_freq
return search_freq
return search_freq
return config.default_search_freq
def main(datapath: str, plot: str) -> None:
@@ -432,16 +589,21 @@ def main(datapath: str, plot: str) -> None:
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# good chirp times for data: 2022-06-02-10_00
window_start_seconds = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
window_duration_seconds = 60 * data.raw_rate
window_start_index = (3 * 60 * 60 + 6 * 60 + 43.5 + 5) * data.raw_rate
window_duration_index = 60 * data.raw_rate
# t0 = 0
# dt = data.raw.shape[0]
# window_start_seconds = (23495 + ((28336-23495)/3)) * data.raw_rate
# window_duration_seconds = (28336 - 23495) * data.raw_rate
# window_start_index = 0
# window_duration_index = data.raw.shape[0]
# generate starting points of rolling window
window_start_indices = np.arange(
window_start_seconds,
window_start_seconds + window_duration_seconds,
window_start_index,
window_start_index + window_duration_index,
window_duration - (window_overlap + 2 * window_edge),
dtype=int,
)
@@ -523,10 +685,10 @@ def main(datapath: str, plot: str) -> None:
search_frequency = find_searchband(
config=config,
freq_temp=current_frequencies,
median_ids=median_ids,
current_frequency=current_frequencies,
percentiles_ids=median_ids,
data=data,
median_freq=median_freq,
frequency_percentiles=median_freq,
)
# add all chirps that are detected on mulitple electrodes for one
@@ -598,11 +760,12 @@ def main(datapath: str, plot: str) -> None:
# band envelope correspond to troughs in the baseline envelope
# during chirps
search_envelope = envelope(
search_envelope_unfiltered = envelope(
signal=searchband,
samplerate=data.raw_rate,
cutoff_frequency=config.search_envelope_cutoff,
)
search_envelope = search_envelope_unfiltered
# compute instantaneous frequency of the baseline band to find
# anomalies during a chirp, i.e. a frequency jump upwards or
@@ -656,8 +819,10 @@ def main(datapath: str, plot: str) -> None:
)
current_raw_time = current_raw_time[no_edges]
baselineband = baselineband[no_edges]
baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges]
searchband = searchband[no_edges]
baseline_envelope = baseline_envelope[no_edges]
search_envelope_unfiltered = search_envelope_unfiltered[no_edges]
search_envelope = search_envelope[no_edges]
# get instantaneous frequency withoup edges
@@ -709,7 +874,9 @@ def main(datapath: str, plot: str) -> None:
baseline_peak_timestamps = current_raw_time[
baseline_peak_indices
]
search_peak_timestamps = current_raw_time[search_peak_indices]
search_peak_timestamps = current_raw_time[
search_peak_indices]
frequency_peak_timestamps = baseline_frequency_time[
frequency_peak_indices
]
@@ -770,10 +937,13 @@ def main(datapath: str, plot: str) -> None:
track_id=track_id,
data=data,
time=current_raw_time,
baseline_envelope_unfiltered=baseline_envelope_unfiltered,
baseline=baselineband,
baseline_envelope=baseline_envelope,
baseline_peaks=baseline_peak_indices,
search_frequency=search_frequency,
search=searchband,
search_envelope_unfiltered=search_envelope_unfiltered,
search_envelope=search_envelope,
search_peaks=search_peak_indices,
frequency_time=baseline_frequency_time,
@@ -810,9 +980,9 @@ def main(datapath: str, plot: str) -> None:
multiwindow_chirps.append(multielectrode_chirps_validated)
multiwindow_ids.append(track_id)
logger.debug(
"Found %d chirps, starting plotting ... "
% len(multielectrode_chirps_validated)
logger.info(
f"Found {len(multielectrode_chirps_validated)}"
f" chirps for fish {track_id} in this window!"
)
# if chirps are detected and the plot flag is set, plot the
# chirps, otheswise try to delete the buffer if it exists
@@ -877,4 +1047,6 @@ def main(datapath: str, plot: str) -> None:
if __name__ == "__main__":
# datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-05-13-10_00/"
datapath = "../data/2022-06-02-10_00/"
main(datapath, plot="show")
# datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/"
# datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/"
main(datapath, plot="save")

View File

@@ -3,7 +3,7 @@ dataroot: "../data/"
outputdir: "../output/"
# Duration and overlap of the analysis window in seconds
window: 5
window: 10
overlap: 1
edge: 0.25
@@ -12,7 +12,7 @@ number_electrodes: 3
minimum_electrodes: 2
# Search window bandwidth and minimal baseline bandwidth
minimal_bandwidth: 10
minimal_bandwidth: 20
# Instantaneous frequency smoothing usint a gaussian kernel of this width
baseline_frequency_smoothing: 5

View File

@@ -3,6 +3,24 @@ from typing import List, Any
from scipy.ndimage import gaussian_filter1d
def norm(data):
"""
Normalize data to [0, 1]
Parameters
----------
data : np.ndarray
Data to normalize.
Returns
-------
np.ndarray
Normalized data.
"""
return (2*((data - np.min(data)) / (np.max(data) - np.min(data)))) - 1
def instantaneous_frequency(
signal: np.ndarray,
samplerate: int,

View File

@@ -30,10 +30,14 @@ def PlotStyle() -> None:
purple = "#cba6f7"
pink = "#f5c2e7"
lavender = "#b4befe"
gblue1 = "#8cb8ff"
gblue2 = "#7cdcdc"
gblue3 = "#82e896"
@classmethod
def lims(cls, track1, track2):
"""Helper function to get frequency y axis limits from two fundamental frequency tracks.
"""Helper function to get frequency y axis limits from two
fundamental frequency tracks.
Args:
track1 (array): First track
@@ -91,6 +95,16 @@ def PlotStyle() -> None:
ax.tick_params(left=False, labelleft=False)
ax.patch.set_visible(False)
@classmethod
def hide_xax(cls, ax):
ax.xaxis.set_visible(False)
ax.spines["bottom"].set_visible(False)
@classmethod
def hide_yax(cls, ax):
ax.yaxis.set_visible(False)
ax.spines["left"].set_visible(False)
@classmethod
def set_boxplot_color(cls, bp, color):
plt.setp(bp["boxes"], color=color)
@@ -216,8 +230,8 @@ def PlotStyle() -> None:
plt.rc("figure", titlesize=BIGGER_SIZE) # fontsize of the figure title
plt.rcParams["image.cmap"] = 'cmo.haline'
# plt.rcParams["axes.xmargin"] = 0.1
# plt.rcParams["axes.ymargin"] = 0.15
plt.rcParams["axes.xmargin"] = 0.05
plt.rcParams["axes.ymargin"] = 0.1
plt.rcParams["axes.titlelocation"] = "left"
plt.rcParams["axes.titlesize"] = BIGGER_SIZE
# plt.rcParams["axes.titlepad"] = -10
@@ -230,9 +244,9 @@ def PlotStyle() -> None:
plt.rcParams["legend.borderaxespad"] = 0.5
plt.rcParams["legend.fancybox"] = False
# specify the custom font to use
#plt.rcParams["font.family"] = "sans-serif"
#plt.rcParams["font.sans-serif"] = "Helvetica Now Text"
# # specify the custom font to use
# plt.rcParams["font.family"] = "sans-serif"
# plt.rcParams["font.sans-serif"] = "Helvetica Now Text"
# dark mode modifications
plt.rcParams["boxplot.flierprops.color"] = white
@@ -271,7 +285,7 @@ def PlotStyle() -> None:
plt.rcParams["ytick.color"] = gray # color of the ticks
plt.rcParams["grid.color"] = dark_gray # grid color
plt.rcParams["figure.facecolor"] = black # figure face color
plt.rcParams["figure.edgecolor"] = "#555169" # figure edge color
plt.rcParams["figure.edgecolor"] = black # figure edge color
plt.rcParams["savefig.facecolor"] = black # figure face color when saving
return style

View File

@@ -0,0 +1,121 @@
import numpy as np
import matplotlib.pyplot as plt
from thunderfish.powerspectrum import spectrogram, decibel
from modules.filehandling import LoadData
from modules.datahandling import instantaneous_frequency
from modules.filters import bandpass_filter
from modules.plotstyle import PlotStyle
ps = PlotStyle()
def main():
# Load data
datapath = "../data/2022-06-02-10_00/"
data = LoadData(datapath)
# good chirp times for data: 2022-06-02-10_00
window_start_seconds = 3 * 60 * 60 + 6 * 60 + 43.5 + 9 + 6.25
window_start_index = window_start_seconds * data.raw_rate
window_duration_seconds = 0.2
window_duration_index = window_duration_seconds * data.raw_rate
timescaler = 1000
raw = data.raw[window_start_index:window_start_index +
window_duration_index, 10]
fig, (ax1, ax2, ax3) = plt.subplots(
3, 1, figsize=(12 * ps.cm, 10*ps.cm), sharex=True, sharey=True)
# plot instantaneous frequency
filtered1 = bandpass_filter(
signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate)
filtered2 = bandpass_filter(
signal=raw, lowf=550, highf=700, samplerate=data.raw_rate)
freqtime1, freq1 = instantaneous_frequency(
filtered1, data.raw_rate, smoothing_window=3)
freqtime2, freq2 = instantaneous_frequency(
filtered2, data.raw_rate, smoothing_window=3)
ax1.plot(freqtime1*timescaler, freq1, color=ps.gblue1,
lw=2, label=f"fish 1, {np.median(freq1):.0f} Hz")
ax1.plot(freqtime2*timescaler, freq2, color=ps.gblue3,
lw=2, label=f"fish 2, {np.median(freq2):.0f} Hz")
ax1.legend(bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower center",
mode="normal", borderaxespad=0, ncol=2)
ps.hide_xax(ax1)
# plot fine spectrogram
spec_power, spec_freqs, spec_times = spectrogram(
raw,
ratetime=data.raw_rate,
freq_resolution=150,
overlap_frac=0.2,
)
ylims = [300, 1200]
fmask = np.zeros(spec_freqs.shape, dtype=bool)
fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True
ax2.imshow(
decibel(spec_power[fmask, :]),
extent=[
spec_times[0]*timescaler,
spec_times[-1]*timescaler,
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
aspect="auto",
origin="lower",
interpolation="gaussian",
alpha=1,
)
ps.hide_xax(ax2)
# plot coarse spectrogram
spec_power, spec_freqs, spec_times = spectrogram(
raw,
ratetime=data.raw_rate,
freq_resolution=10,
overlap_frac=0.3,
)
fmask = np.zeros(spec_freqs.shape, dtype=bool)
fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True
ax3.imshow(
decibel(spec_power[fmask, :]),
extent=[
spec_times[0]*timescaler,
spec_times[-1]*timescaler,
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
aspect="auto",
origin="lower",
interpolation="gaussian",
alpha=1,
)
# ps.hide_xax(ax3)
ax3.set_xlabel("time [ms]")
ax2.set_ylabel("frequency [Hz]")
ax1.set_yticks(np.arange(400, 1201, 400))
ax1.spines.left.set_bounds((400, 1200))
ax2.set_yticks(np.arange(400, 1201, 400))
ax2.spines.left.set_bounds((400, 1200))
ax3.set_yticks(np.arange(400, 1201, 400))
ax3.spines.left.set_bounds((400, 1200))
plt.subplots_adjust(left=0.17, right=0.98, top=0.9,
bottom=0.14, hspace=0.35)
plt.savefig('../poster/figs/introplot.pdf')
plt.show()
if __name__ == '__main__':
main()