Merge branch 'master' into plot_event_timeline

This commit is contained in:
wendtalexander
2023-01-23 09:45:48 +01:00
18 changed files with 3360 additions and 215 deletions

View File

@@ -1,16 +1,19 @@
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
from pandas import read_csv
from modules.logger import makeLogger
logger = makeLogger(__name__)
class Behavior:
"""Load behavior data from csv file as class attributes
Attributes
----------
behavior: 0: chasing onset, 1: chasing offset, 2: physical contact
behavior_type:
behavioral_category:
comment_start:
@@ -20,30 +23,195 @@ class Behavior:
media_file:
observation_date:
observation_id:
start_s:
stop_s:
start_s: start time of the event in seconds
stop_s: stop time of the event in seconds
total_length:
"""
def __init__(self, datapath: str) -> None:
csv_file = str(sorted(Path(datapath).glob('**/*.csv'))[0])
self.dataframe = read_csv(csv_file, delimiter=',')
for key in self.dataframe:
def __init__(self, folder_path: str) -> None:
LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True)
self.time = np.load(os.path.join(folder_path, "times.npy"), allow_pickle=True)
csv_filename = [f for f in os.listdir(folder_path) if f.endswith('.csv')][0] # check if there are more than one csv file
self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
self.chirps = np.load(os.path.join(folder_path, 'chirps.npy'), allow_pickle=True)
self.chirps_ids = np.load(os.path.join(folder_path, 'chirps_ids.npy'), allow_pickle=True)
for k, key in enumerate(self.dataframe.keys()):
key = key.lower()
if ' ' in key:
new_key = key.replace(' ', '_')
if '(' in new_key:
new_key = new_key.replace('(', '')
new_key = new_key.replace(')', '')
new_key = new_key.lower()
setattr(self, new_key, np.array(self.dataframe[key]))
key = key.replace(' ', '_')
if '(' in key:
key = key.replace('(', '')
key = key.replace(')', '')
setattr(self, key, np.array(self.dataframe[self.dataframe.keys()[k]]))
last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = self.time[-1] - self.time[0]
factor = 1.034141
shift = last_LED_t_BORIS - real_time_range * factor
self.start_s = (self.start_s - shift) / factor
self.stop_s = (self.stop_s - shift) / factor
"""
1 - chasing onset
2 - chasing offset
3 - physical contact event
temporal encoding needs to be corrected ... not exactly 25 FPS.
### corresponding python code ###
factor = 1.034141
LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True)
last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = times[-1] - times[0]
shift = last_LED_t_BORIS - real_time_range * factor
data = pd.read_csv(os.path.join(folder_path, file[1:-7] + '.csv'))
boris_times = data['Start (s)']
data_times = []
for Cevent_t in boris_times:
Cevent_boris_times = (Cevent_t - shift) / factor
data_times.append(Cevent_boris_times)
data_times = np.array(data_times)
behavior = data['Behavior']
"""
def correct_chasing_events(
    category: np.ndarray,
    timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Remove unpaired chasing on-/offsets so both occur equally often.

    Entries of ``category`` equal to 0 are treated as chasing onsets and
    entries equal to 1 as chasing offsets. If both occur equally often the
    inputs are returned unchanged. Otherwise the more frequent event type
    is scanned and every event that is not followed by an event of the
    other type before its own next occurrence is dropped.

    Parameters
    ----------
    category : np.ndarray
        Event code per annotated event.
    timestamps : np.ndarray
        Event times; entries are removed at the same positions as in
        ``category``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``category`` and ``timestamps`` without the events judged wrong.
    """
    # positions of the onset (0) and offset (1) events within `category`
    onset_ids = np.flatnonzero(category == 0)
    offset_ids = np.flatnonzero(category == 1)

    # nothing to correct if every onset can have its own offset
    if len(onset_ids) == len(offset_ids):
        logger.info('Chasing events are equal')
        return category, timestamps

    # Check whether on- or offset is longer and calculate length difference
    if len(onset_ids) > len(offset_ids):
        len_diff = len(onset_ids) - len(offset_ids)
        longer_array = onset_ids
        shorter_array = offset_ids
        logger.info(f'Onsets are greater than offsets by {len_diff}')
    else:
        len_diff = len(offset_ids) - len(onset_ids)
        longer_array = offset_ids
        shorter_array = onset_ids
        logger.info(f'Offsets are greater than onsets by {len_diff}')

    # Correct the wrong chasing events; delete double events
    wrong_ids = []
    for i in range(len(longer_array) - (len_diff + 1)):
        # a valid pair has the i-th event of the shorter list strictly
        # between the i-th and the (i+1)-th event of the longer list
        is_paired = (
            (shorter_array[i] > longer_array[i])
            & (shorter_array[i] < longer_array[i + 1])
        )
        if not is_paired:
            wrong_ids.append(longer_array[i])
            # drop the double event so the following entries move up by one
            longer_array = np.delete(longer_array, i)

    category = np.delete(category, wrong_ids)
    timestamps = np.delete(timestamps, wrong_ids)
    return category, timestamps
def main(datapath: str):
    """Relate chirp times to chasing and physical contact events.

    Loads the behavior annotations and chirps via ``Behavior``, removes
    unpaired chasing on-/offsets, draws three overview figures and finally
    opens an interactive IPython shell.

    Parameters
    ----------
    datapath : str
        Folder that is passed on to ``Behavior``.
    """
    # bh provides the behavior table as well as the chirps and their fish ids
    bh = Behavior(datapath)

    # chirps are not sorted in time (presumably due to prior groupings)
    # get and sort chirps and corresponding fish_ids of the chirps
    chirp_order = np.argsort(bh.chirps)
    chirps = bh.chirps[chirp_order]
    chirps_fish_ids = bh.chirps_ids[chirp_order]
    category = bh.behavior
    timestamps = bh.start_s

    # Correct for doubles in chasing on- and offsets to get the right
    # on-/offset pairs
    # Get rid of tracking faults (two onsets or two offsets after another)
    category, timestamps = correct_chasing_events(category, timestamps)

    # split categories
    chasing_onset = timestamps[category == 0]
    chasing_offset = timestamps[category == 1]
    physical_contact = timestamps[category == 2]

    ##### TODO Physical contact-triggered chirps (PTC) with raster plot #####
    # Probability of physical contact given a chirp and vice versa
    # Chasing-triggered chirps (CTC) with raster plot
    # Probability of chasing given a chirp and vice versa

    # First overview plot
    fig1, ax1 = plt.subplots()
    ax1.scatter(chirps, np.ones_like(chirps), marker='*',
                color='royalblue', label='Chirps')
    ax1.scatter(chasing_onset, np.ones_like(chasing_onset)*2, marker='.',
                color='forestgreen', label='Chasing onset')
    ax1.scatter(chasing_offset, np.ones_like(chasing_offset)*2.5, marker='.',
                color='firebrick', label='Chasing offset')
    ax1.scatter(physical_contact, np.ones_like(physical_contact)*3,
                marker='x', color='black', label='Physical contact')
    plt.legend()
    # plt.show()
    plt.close()

    # Get fish ids
    all_fish_ids = np.unique(chirps_fish_ids)

    # Associate chirps to individual fish
    fish1 = chirps[chirps_fish_ids == all_fish_ids[0]]
    fish2 = chirps[chirps_fish_ids == all_fish_ids[1]]
    fish = [len(fish1), len(fish2)]

    #### Chirp counts per fish general #####
    fig2, ax2 = plt.subplots()
    x = ['Fish1', 'Fish2']
    width = 0.35
    ax2.bar(x, fish, width=width)
    ax2.set_ylabel('Chirp count')
    # plt.show()
    plt.close()

    ##### Count chirps emitted during chasing events and chirps emitted out of chasing events #####
    # one list of chirp times per onset/offset pair
    chirps_in_chasings = []
    for onset, offset in zip(chasing_onset, chasing_offset):
        chirps_in_chasing = [c for c in chirps if (c > onset) & (c < offset)]
        chirps_in_chasings.append(chirps_in_chasing)

    # count chasing events that contain at least one chirp and those
    # that contain none
    counts_chirps_chasings = 0
    chasings_without_chirps = 0
    for chirps_in_chasing in chirps_in_chasings:
        if chirps_in_chasing:
            # non-empty list: at least one chirp fell into this chasing event
            counts_chirps_chasings += 1
        else:
            chasings_without_chirps += 1

    # comparison between chasing events with and without chirps
    fig3, ax3 = plt.subplots()
    ax3.bar(['Chirps in chasing events', 'Chasing events without Chirps'],
            [counts_chirps_chasings, chasings_without_chirps], width=width)
    plt.ylabel('Count')
    plt.show()
    plt.close()

    # drop into an interactive shell to inspect the local variables
    embed()
if __name__ == '__main__':
# Path to the data
datapath = '../data/mount_data/2020-03-13-10_00/'
datapath = '../data/mount_data/2020-05-13-10_00/'
main(datapath)

View File

@@ -2,7 +2,9 @@ from itertools import compress
from dataclasses import dataclass
import numpy as np
from IPython import embed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gr
from scipy.signal import find_peaks
from thunderfish.powerspectrum import spectrogram, decibel
from sklearn.preprocessing import normalize
@@ -40,9 +42,12 @@ class PlotBuffer:
time: np.ndarray
baseline: np.ndarray
baseline_envelope_unfiltered: np.ndarray
baseline_envelope: np.ndarray
baseline_peaks: np.ndarray
search_frequency: float
search: np.ndarray
search_envelope_unfiltered: np.ndarray
search_envelope: np.ndarray
search_peaks: np.ndarray
@@ -57,84 +62,202 @@ class PlotBuffer:
# make data for plotting
# # get index of track data in this time window
# window_idx = np.arange(len(self.data.idx))[
# (self.data.ident == self.track_id) & (self.data.time[self.data.idx] >= self.t0) & (
# self.data.time[self.data.idx] <= (self.t0 + self.dt))
# ]
# get index of track data in this time window
window_idx = np.arange(len(self.data.idx))[
(self.data.ident == self.track_id)
& (self.data.time[self.data.idx] >= self.t0)
& (self.data.time[self.data.idx] <= (self.t0 + self.dt))
]
# get tracked frequencies and their times
# freq_temp = self.data.freq[window_idx]
# time_temp = self.data.times[window_idx]
freq_temp = self.data.freq[window_idx]
# time_temp = self.data.time[
# self.data.idx[self.data.ident == self.track_id]][
# (self.data.time >= self.t0)
# & (self.data.time <= (self.t0 + self.dt))
# ]
# remake the band we filtered in
q25, q50, q75 = np.percentile(freq_temp, [25, 50, 75])
search_upper, search_lower = (
q50 + self.search_frequency + self.config.minimal_bandwidth / 2,
q50 + self.search_frequency - self.config.minimal_bandwidth / 2,
)
# get indices on raw data
start_idx = self.t0 * self.data.raw_rate
window_duration = self.dt * self.data.raw_rate
start_idx = (self.t0 - 5) * self.data.raw_rate
window_duration = (self.dt + 10) * self.data.raw_rate
stop_idx = start_idx + window_duration
# get raw data
data_oi = self.data.raw[start_idx:stop_idx, self.electrode]
fig, axs = plt.subplots(
7,
1,
figsize=(20 / 2.54, 12 / 2.54),
constrained_layout=True,
sharex=True,
sharey="row",
self.time = self.time - self.t0
self.frequency_time = self.frequency_time - self.t0
chirps = np.asarray(chirps) - self.t0
self.t0_old = self.t0
self.t0 = 0
fig = plt.figure(
figsize=(14 / 2.54, 20 / 2.54)
)
gs0 = gr.GridSpec(
3, 1, figure=fig, height_ratios=[1, 1, 1]
)
gs1 = gs0[0].subgridspec(1, 1)
gs2 = gs0[1].subgridspec(3, 1, hspace=0.4)
gs3 = gs0[2].subgridspec(3, 1, hspace=0.4)
# gs4 = gs0[5].subgridspec(1, 1)
ax6 = fig.add_subplot(gs3[2, 0])
ax0 = fig.add_subplot(gs1[0, 0], sharex=ax6)
ax1 = fig.add_subplot(gs2[0, 0], sharex=ax6)
ax2 = fig.add_subplot(gs2[1, 0], sharex=ax6)
ax3 = fig.add_subplot(gs2[2, 0], sharex=ax6)
ax4 = fig.add_subplot(gs3[0, 0], sharex=ax6)
ax5 = fig.add_subplot(gs3[1, 0], sharex=ax6)
# ax7 = fig.add_subplot(gs4[0, 0], sharex=ax0)
# ax_leg = fig.add_subplot(gs0[1, 0])
waveform_scaler = 1000
lw = 1.5
# plot spectrogram
plot_spectrogram(axs[0], data_oi, self.data.raw_rate, self.t0)
_ = plot_spectrogram(
ax0,
data_oi,
self.data.raw_rate,
self.t0 - 5,
[np.max(self.frequency) - 200, np.max(self.frequency) + 200]
)
for track_id in self.data.ids:
t0_track = self.t0_old - 5
dt_track = self.dt + 10
window_idx = np.arange(len(self.data.idx))[
(self.data.ident == track_id)
& (self.data.time[self.data.idx] >= t0_track)
& (self.data.time[self.data.idx] <= (t0_track + dt_track))
]
# get tracked frequencies and their times
f = self.data.freq[window_idx]
t = self.data.time[
self.data.idx[self.data.ident == self.track_id]]
tmask = (t >= t0_track) & (t <= (t0_track + dt_track))
if track_id == self.track_id:
ax0.plot(t[tmask]-self.t0_old, f, lw=lw,
zorder=10, color=ps.gblue1)
else:
ax0.plot(t[tmask]-self.t0_old, f, lw=lw,
zorder=10, color=ps.gray, alpha=0.5)
ax0.fill_between(
np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
q50 - self.config.minimal_bandwidth / 2,
q50 + self.config.minimal_bandwidth / 2,
color=ps.gblue1,
lw=1,
ls="dashed",
alpha=0.5,
)
ax0.fill_between(
np.arange(self.t0, self.t0 + self.dt, 1 / self.data.raw_rate),
search_lower,
search_upper,
color=ps.gblue2,
lw=1,
ls="dashed",
alpha=0.5,
)
# ax0.axhline(q50, spec_times[0], spec_times[-1],
# color=ps.gblue1, lw=2, ls="dashed")
# ax0.axhline(q50 + self.search_frequency,
# spec_times[0], spec_times[-1],
# color=ps.gblue2, lw=2, ls="dashed")
for chirp in chirps:
axs[0].scatter(
chirp, np.median(self.frequency), c=ps.black, marker="x"
ax0.scatter(
chirp, np.median(self.frequency) + 150, c=ps.black, marker="v"
)
# plot waveform of filtered signal
axs[1].plot(self.time, self.baseline, c=ps.green)
ax1.plot(self.time, self.baseline * waveform_scaler,
c=ps.gray, lw=lw, alpha=0.5)
ax1.plot(self.time, self.baseline_envelope_unfiltered *
waveform_scaler, c=ps.gblue1, lw=lw, label="baseline envelope")
# plot waveform of filtered search signal
axs[2].plot(self.time, self.search)
ax2.plot(self.time, self.search * waveform_scaler,
c=ps.gray, lw=lw, alpha=0.5)
ax2.plot(self.time, self.search_envelope_unfiltered *
waveform_scaler, c=ps.gblue2, lw=lw, label="search envelope")
# plot baseline instantaneous frequency
axs[3].plot(self.frequency_time, self.frequency)
ax3.plot(self.frequency_time, self.frequency,
c=ps.gblue3, lw=lw, label="baseline inst. freq.")
# plot filtered and rectified envelope
axs[4].plot(self.time, self.baseline_envelope)
axs[4].scatter(
ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=lw)
ax4.scatter(
(self.time)[self.baseline_peaks],
self.baseline_envelope[self.baseline_peaks],
c=ps.red,
edgecolors=ps.red,
zorder=10,
marker="o",
facecolors="none",
)
# plot envelope of search signal
axs[5].plot(self.time, self.search_envelope)
axs[5].scatter(
ax5.plot(self.time, self.search_envelope, c=ps.gblue2, lw=lw)
ax5.scatter(
(self.time)[self.search_peaks],
self.search_envelope[self.search_peaks],
c=ps.red,
edgecolors=ps.red,
zorder=10,
marker="o",
facecolors="none",
)
# plot filtered instantaneous frequency
axs[6].plot(self.frequency_time, self.frequency_filtered)
axs[6].scatter(
ax6.plot(self.frequency_time,
self.frequency_filtered, c=ps.gblue3, lw=lw)
ax6.scatter(
self.frequency_time[self.frequency_peaks],
self.frequency_filtered[self.frequency_peaks],
c=ps.red,
edgecolors=ps.red,
zorder=10,
marker="o",
facecolors="none",
)
axs[0].set_ylim(
np.max(self.frequency) - 200, top=np.max(self.frequency) + 200
)
axs[6].set_xlabel("Time [s]")
axs[0].set_title("Spectrogram")
axs[1].set_title("Fitered baseline")
axs[2].set_title("Fitered above")
axs[3].set_title("Fitered baseline instanenous frequency")
axs[4].set_title("Filtered envelope of baseline envelope")
axs[5].set_title("Search envelope")
axs[6].set_title("Filtered absolute instantaneous frequency")
ax0.set_ylabel("frequency [Hz]")
ax1.set_ylabel("a.u.")
ax2.set_ylabel("a.u.")
ax3.set_ylabel("Hz")
ax5.set_ylabel("a.u.")
ax6.set_xlabel("time [s]")
plt.setp(ax0.get_xticklabels(), visible=False)
plt.setp(ax1.get_xticklabels(), visible=False)
plt.setp(ax2.get_xticklabels(), visible=False)
plt.setp(ax3.get_xticklabels(), visible=False)
plt.setp(ax4.get_xticklabels(), visible=False)
plt.setp(ax5.get_xticklabels(), visible=False)
# ps.letter_subplots([ax0, ax1, ax4], xoffset=-0.21)
# ax7.set_xticks(np.arange(0, 5.5, 1))
# ax7.spines.bottom.set_bounds((0, 5))
ax0.set_xlim(0, self.config.window)
plt.subplots_adjust(left=0.165, right=0.975,
top=0.98, bottom=0.074, hspace=0.2)
fig.align_labels()
if plot == "show":
plt.show()
@@ -144,13 +267,18 @@ class PlotBuffer:
self.config.outputdir + self.data.datapath.split("/")[-2] + "/"
)
plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf")
plt.savefig(f"{out}{self.track_id}_{self.t0_old}.pdf")
plt.savefig(f"{out}{self.track_id}_{self.t0_old}.svg")
plt.close()
def plot_spectrogram(
axis, signal: np.ndarray, samplerate: float, window_start_seconds: float
) -> None:
axis,
signal: np.ndarray,
samplerate: float,
window_start_seconds: float,
ylims: list[float]
) -> np.ndarray:
"""
Plot a spectrogram of a signal.
@@ -172,22 +300,28 @@ def plot_spectrogram(
spec_power, spec_freqs, spec_times = spectrogram(
signal,
ratetime=samplerate,
freq_resolution=20,
freq_resolution=10,
overlap_frac=0.5,
)
fmask = np.zeros(spec_freqs.shape, dtype=bool)
fmask[(spec_freqs > ylims[0]) & (spec_freqs < ylims[1])] = True
axis.imshow(
decibel(spec_power),
decibel(spec_power[fmask, :]),
extent=[
spec_times[0] + window_start_seconds,
spec_times[-1] + window_start_seconds,
spec_freqs[0],
spec_freqs[-1],
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
aspect="auto",
origin="lower",
interpolation="gaussian",
alpha=1,
)
# axis.use_sticky_edges = False
return spec_times
def extract_frequency_bands(
@@ -244,9 +378,16 @@ def extract_frequency_bands(
def window_median_all_track_ids(
data: LoadData, window_start_seconds: float, window_duration_seconds: float
) -> tuple[float, list[int]]:
) -> tuple[list[tuple[float, float, float]], list[int]]:
"""
Calculate the median frequency of all fish in a given time window.
Calculate the median and quantiles of the frequency of all fish in a
given time window.
Iterate over all track ids and calculate the 25, 50 and 75 percentile
in a given time window to pass this data to 'find_searchband' function,
which then determines whether other fish in the current window fall
within the searchband of the current fish and then determine the
gaps that are outside of the percentile ranges.
Parameters
----------
@@ -259,14 +400,16 @@ def window_median_all_track_ids(
Returns
-------
tuple[float, list[int]]
tuple[list[tuple[float, float, float]], list[int]]
"""
median_freq = []
frequency_percentiles = []
track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
# the window index combines the track id and the time window
window_idx = np.arange(len(data.idx))[
(data.ident == track_id)
& (data.time[data.idx] >= window_start_seconds)
@@ -277,20 +420,21 @@ def window_median_all_track_ids(
]
if len(data.freq[window_idx]) > 0:
median_freq.append(np.median(data.freq[window_idx]))
frequency_percentiles.append(
np.percentile(data.freq[window_idx], [25, 50, 75]))
track_ids.append(track_id)
# convert to numpy array
median_freq = np.asarray(median_freq)
frequency_percentiles = np.asarray(frequency_percentiles)
track_ids = np.asarray(track_ids)
return median_freq, track_ids
return frequency_percentiles, track_ids
def find_searchband(
freq_temp: np.ndarray,
median_ids: np.ndarray,
median_freq: np.ndarray,
current_frequency: np.ndarray,
percentiles_ids: np.ndarray,
frequency_percentiles: np.ndarray,
config: ConfLoader,
data: LoadData,
) -> float:
@@ -300,13 +444,13 @@ def find_searchband(
Parameters
----------
freq_temp : np.ndarray
current_frequency : np.ndarray
Current EOD frequency array / the current fish of interest.
median_ids : np.ndarray
percentiles_ids : np.ndarray
Array of track IDs of the medians of all other fish in the current
window.
median_freq : np.ndarray
Array of median frequencies of all other fish in the current window.
frequency_percentiles : np.ndarray
Array of percentiles frequencies of all other fish in the current window.
config : ConfLoader
Configuration file.
data : LoadData
@@ -317,19 +461,27 @@ def find_searchband(
float
"""
# frequency where second filter filters
# frequency window in which the second filter is potentially allowed
# to filter. This is the search window, in which we want to find
# a gap in the other fish's EODs.
search_window = np.arange(
np.median(freq_temp) + config.search_df_lower,
np.median(freq_temp) + config.search_df_upper,
np.median(current_frequency) + config.search_df_lower,
np.median(current_frequency) + config.search_df_upper,
config.search_res,
)
# search window in boolean
search_window_bool = np.ones(len(search_window), dtype=bool)
search_window_bool = np.ones_like(len(search_window), dtype=bool)
# make separate arrays from the quartiles
q25 = np.asarray([i[0] for i in frequency_percentiles])
q75 = np.asarray([i[2] for i in frequency_percentiles])
# get tracks that fall into search window
check_track_ids = median_ids[
(median_freq > search_window[0]) & (median_freq < search_window[-1])
check_track_ids = percentiles_ids[
(q25 > search_window[0]) & (
q75 < search_window[-1])
]
# iterate through these tracks
@@ -337,19 +489,22 @@ def find_searchband(
for j, check_track_id in enumerate(check_track_ids):
q1, q2 = np.percentile(
data.freq[data.ident == check_track_id],
config.search_freq_percentiles,
)
q25_temp = q25[percentiles_ids == check_track_id]
q75_temp = q75[percentiles_ids == check_track_id]
print(q25_temp, q75_temp)
search_window_bool[
(search_window > q1) & (search_window < q2)
(search_window > q25_temp) & (search_window < q75_temp)
] = False
# find gaps in search window
search_window_indices = np.arange(len(search_window))
# get search window gaps
# taking the diff of a boolean array gives non zero values where the
# array changes from true to false or vice versa
search_window_gaps = np.diff(search_window_bool, append=np.nan)
nonzeros = search_window_gaps[np.nonzero(search_window_gaps)[0]]
nonzeros = nonzeros[~np.isnan(nonzeros)]
@@ -385,14 +540,16 @@ def find_searchband(
search_windows_lens = [len(x) for x in search_windows]
longest_search_window = search_windows[np.argmax(search_windows_lens)]
# the center of the search frequency band is then the center of
# the longest gap
search_freq = (
longest_search_window[-1] - longest_search_window[0]
) / 2
else:
search_freq = config.default_search_freq
return search_freq
return search_freq
return config.default_search_freq
def main(datapath: str, plot: str) -> None:
@@ -432,16 +589,21 @@ def main(datapath: str, plot: str) -> None:
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# good chirp times for data: 2022-06-02-10_00
window_start_seconds = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
window_duration_seconds = 60 * data.raw_rate
window_start_index = (3 * 60 * 60 + 6 * 60 + 43.5 + 5) * data.raw_rate
window_duration_index = 60 * data.raw_rate
# t0 = 0
# dt = data.raw.shape[0]
# window_start_seconds = (23495 + ((28336-23495)/3)) * data.raw_rate
# window_duration_seconds = (28336 - 23495) * data.raw_rate
# window_start_index = 0
# window_duration_index = data.raw.shape[0]
# generate starting points of rolling window
window_start_indices = np.arange(
window_start_seconds,
window_start_seconds + window_duration_seconds,
window_start_index,
window_start_index + window_duration_index,
window_duration - (window_overlap + 2 * window_edge),
dtype=int,
)
@@ -523,10 +685,10 @@ def main(datapath: str, plot: str) -> None:
search_frequency = find_searchband(
config=config,
freq_temp=current_frequencies,
median_ids=median_ids,
current_frequency=current_frequencies,
percentiles_ids=median_ids,
data=data,
median_freq=median_freq,
frequency_percentiles=median_freq,
)
# add all chirps that are detected on multiple electrodes for one
@@ -598,11 +760,12 @@ def main(datapath: str, plot: str) -> None:
# band envelope correspond to troughs in the baseline envelope
# during chirps
search_envelope = envelope(
search_envelope_unfiltered = envelope(
signal=searchband,
samplerate=data.raw_rate,
cutoff_frequency=config.search_envelope_cutoff,
)
search_envelope = search_envelope_unfiltered
# compute instantaneous frequency of the baseline band to find
# anomalies during a chirp, i.e. a frequency jump upwards or
@@ -656,8 +819,10 @@ def main(datapath: str, plot: str) -> None:
)
current_raw_time = current_raw_time[no_edges]
baselineband = baselineband[no_edges]
baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges]
searchband = searchband[no_edges]
baseline_envelope = baseline_envelope[no_edges]
search_envelope_unfiltered = search_envelope_unfiltered[no_edges]
search_envelope = search_envelope[no_edges]
# get instantaneous frequency without edges
@@ -709,7 +874,9 @@ def main(datapath: str, plot: str) -> None:
baseline_peak_timestamps = current_raw_time[
baseline_peak_indices
]
search_peak_timestamps = current_raw_time[search_peak_indices]
search_peak_timestamps = current_raw_time[
search_peak_indices]
frequency_peak_timestamps = baseline_frequency_time[
frequency_peak_indices
]
@@ -770,10 +937,13 @@ def main(datapath: str, plot: str) -> None:
track_id=track_id,
data=data,
time=current_raw_time,
baseline_envelope_unfiltered=baseline_envelope_unfiltered,
baseline=baselineband,
baseline_envelope=baseline_envelope,
baseline_peaks=baseline_peak_indices,
search_frequency=search_frequency,
search=searchband,
search_envelope_unfiltered=search_envelope_unfiltered,
search_envelope=search_envelope,
search_peaks=search_peak_indices,
frequency_time=baseline_frequency_time,
@@ -810,9 +980,9 @@ def main(datapath: str, plot: str) -> None:
multiwindow_chirps.append(multielectrode_chirps_validated)
multiwindow_ids.append(track_id)
logger.debug(
"Found %d chirps, starting plotting ... "
% len(multielectrode_chirps_validated)
logger.info(
f"Found {len(multielectrode_chirps_validated)}"
f" chirps for fish {track_id} in this window!"
)
# if chirps are detected and the plot flag is set, plot the
# chirps, otherwise try to delete the buffer if it exists
@@ -877,4 +1047,6 @@ def main(datapath: str, plot: str) -> None:
if __name__ == "__main__":
# datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-05-13-10_00/"
datapath = "../data/2022-06-02-10_00/"
main(datapath, plot="show")
# datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/"
# datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/"
main(datapath, plot="save")

View File

@@ -3,7 +3,7 @@ dataroot: "../data/"
outputdir: "../output/"
# Duration and overlap of the analysis window in seconds
window: 5
window: 10
overlap: 1
edge: 0.25
@@ -12,7 +12,7 @@ number_electrodes: 3
minimum_electrodes: 2
# Search window bandwidth and minimal baseline bandwidth
minimal_bandwidth: 10
minimal_bandwidth: 20
# Instantaneous frequency smoothing using a gaussian kernel of this width
baseline_frequency_smoothing: 5

View File

@@ -3,6 +3,24 @@ from typing import List, Any
from scipy.ndimage import gaussian_filter1d
def norm(data):
    """
    Normalize data to [-1, 1]

    The minimum of the data is mapped to -1 and the maximum to 1. If all
    values are equal the range is zero and the division is undefined
    (NaN for NumPy float input).

    Parameters
    ----------
    data : np.ndarray
        Data to normalize.

    Returns
    -------
    np.ndarray
        Normalized data.
    """
    lowest = np.min(data)
    value_range = np.max(data) - lowest
    # scale to [0, 1] first, then stretch and shift to [-1, 1]
    return (2 * ((data - lowest) / value_range)) - 1
def instantaneous_frequency(
signal: np.ndarray,
samplerate: int,

View File

@@ -30,10 +30,14 @@ def PlotStyle() -> None:
purple = "#cba6f7"
pink = "#f5c2e7"
lavender = "#b4befe"
gblue1 = "#8cb8ff"
gblue2 = "#7cdcdc"
gblue3 = "#82e896"
@classmethod
def lims(cls, track1, track2):
"""Helper function to get frequency y axis limits from two fundamental frequency tracks.
"""Helper function to get frequency y axis limits from two
fundamental frequency tracks.
Args:
track1 (array): First track
@@ -91,6 +95,16 @@ def PlotStyle() -> None:
ax.tick_params(left=False, labelleft=False)
ax.patch.set_visible(False)
@classmethod
def hide_xax(cls, ax):
    """Make the x axis and the bottom spine of ``ax`` invisible."""
    x_axis = ax.xaxis
    x_axis.set_visible(False)
    bottom_spine = ax.spines["bottom"]
    bottom_spine.set_visible(False)
@classmethod
def hide_yax(cls, ax):
    """Make the y axis and the left spine of ``ax`` invisible."""
    y_axis = ax.yaxis
    y_axis.set_visible(False)
    left_spine = ax.spines["left"]
    left_spine.set_visible(False)
@classmethod
def set_boxplot_color(cls, bp, color):
plt.setp(bp["boxes"], color=color)
@@ -216,8 +230,8 @@ def PlotStyle() -> None:
plt.rc("figure", titlesize=BIGGER_SIZE) # fontsize of the figure title
plt.rcParams["image.cmap"] = 'cmo.haline'
# plt.rcParams["axes.xmargin"] = 0.1
# plt.rcParams["axes.ymargin"] = 0.15
plt.rcParams["axes.xmargin"] = 0.05
plt.rcParams["axes.ymargin"] = 0.1
plt.rcParams["axes.titlelocation"] = "left"
plt.rcParams["axes.titlesize"] = BIGGER_SIZE
# plt.rcParams["axes.titlepad"] = -10
@@ -230,9 +244,9 @@ def PlotStyle() -> None:
plt.rcParams["legend.borderaxespad"] = 0.5
plt.rcParams["legend.fancybox"] = False
# specify the custom font to use
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Helvetica Now Text"
# # specify the custom font to use
# plt.rcParams["font.family"] = "sans-serif"
# plt.rcParams["font.sans-serif"] = "Helvetica Now Text"
# dark mode modifications
plt.rcParams["boxplot.flierprops.color"] = white
@@ -271,7 +285,7 @@ def PlotStyle() -> None:
plt.rcParams["ytick.color"] = gray # color of the ticks
plt.rcParams["grid.color"] = dark_gray # grid color
plt.rcParams["figure.facecolor"] = black # figure face color
plt.rcParams["figure.edgecolor"] = "#555169" # figure edge color
plt.rcParams["figure.edgecolor"] = black # figure edge color
plt.rcParams["savefig.facecolor"] = black # figure face color when saving
return style

View File

@@ -0,0 +1,121 @@
import numpy as np
import matplotlib.pyplot as plt
from thunderfish.powerspectrum import spectrogram, decibel
from modules.filehandling import LoadData
from modules.datahandling import instantaneous_frequency
from modules.filters import bandpass_filter
from modules.plotstyle import PlotStyle
ps = PlotStyle()
def _plot_band_limited_spectrogram(
    axis, signal, samplerate, freq_resolution, overlap_frac, ylims, timescaler
):
    """Draw the spectrogram of ``signal`` between ``ylims`` onto ``axis``."""
    spec_power, spec_freqs, spec_times = spectrogram(
        signal,
        ratetime=samplerate,
        freq_resolution=freq_resolution,
        overlap_frac=overlap_frac,
    )

    # keep only the frequency rows strictly inside the requested band
    fmask = (spec_freqs > ylims[0]) & (spec_freqs < ylims[1])

    axis.imshow(
        decibel(spec_power[fmask, :]),
        extent=[
            spec_times[0]*timescaler,
            spec_times[-1]*timescaler,
            spec_freqs[fmask][0],
            spec_freqs[fmask][-1],
        ],
        aspect="auto",
        origin="lower",
        interpolation="gaussian",
        alpha=1,
    )


def main():
    """Create the intro figure: instantaneous frequency and two spectrograms.

    Cuts a 0.2 s snippet out of one recording channel, plots the
    instantaneous frequency of two band-pass filtered copies and two
    spectrograms computed with different frequency resolutions, saves the
    figure as pdf and shows it.
    """

    # Load data
    datapath = "../data/2022-06-02-10_00/"
    data = LoadData(datapath)

    # good chirp times for data: 2022-06-02-10_00
    window_start_seconds = 3 * 60 * 60 + 6 * 60 + 43.5 + 9 + 6.25
    window_duration_seconds = 0.2

    # seconds * rate is a float; slice bounds have to be integers
    window_start_index = int(round(window_start_seconds * data.raw_rate))
    window_duration_index = int(
        round(window_duration_seconds * data.raw_rate))

    # factor that converts the time axes from seconds to milliseconds
    timescaler = 1000

    raw = data.raw[window_start_index:window_start_index +
                   window_duration_index, 10]

    fig, (ax1, ax2, ax3) = plt.subplots(
        3, 1, figsize=(12 * ps.cm, 10*ps.cm), sharex=True, sharey=True)

    # plot instantaneous frequency
    filtered1 = bandpass_filter(
        signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate)
    filtered2 = bandpass_filter(
        signal=raw, lowf=550, highf=700, samplerate=data.raw_rate)

    freqtime1, freq1 = instantaneous_frequency(
        filtered1, data.raw_rate, smoothing_window=3)
    freqtime2, freq2 = instantaneous_frequency(
        filtered2, data.raw_rate, smoothing_window=3)

    ax1.plot(freqtime1*timescaler, freq1, color=ps.gblue1,
             lw=2, label=f"fish 1, {np.median(freq1):.0f} Hz")
    ax1.plot(freqtime2*timescaler, freq2, color=ps.gblue3,
             lw=2, label=f"fish 2, {np.median(freq2):.0f} Hz")
    ax1.legend(bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower center",
               mode="normal", borderaxespad=0, ncol=2)
    ps.hide_xax(ax1)

    ylims = [300, 1200]

    # plot fine spectrogram
    _plot_band_limited_spectrogram(
        ax2, raw, data.raw_rate,
        freq_resolution=150, overlap_frac=0.2,
        ylims=ylims, timescaler=timescaler,
    )
    ps.hide_xax(ax2)

    # plot coarse spectrogram
    _plot_band_limited_spectrogram(
        ax3, raw, data.raw_rate,
        freq_resolution=10, overlap_frac=0.3,
        ylims=ylims, timescaler=timescaler,
    )
    # ps.hide_xax(ax3)

    ax3.set_xlabel("time [ms]")
    ax2.set_ylabel("frequency [Hz]")

    # identical ticks and spine bounds on all three panels
    for axis in (ax1, ax2, ax3):
        axis.set_yticks(np.arange(400, 1201, 400))
        axis.spines.left.set_bounds((400, 1200))

    plt.subplots_adjust(left=0.17, right=0.98, top=0.9,
                        bottom=0.14, hspace=0.35)

    plt.savefig('../poster/figs/introplot.pdf')
    plt.show()
if __name__ == '__main__':
main()