plot works

This commit is contained in:
weygoldt 2023-01-19 13:47:18 +01:00
parent a79fc86ab9
commit c2de6c7060
5 changed files with 109 additions and 232 deletions

View File

@ -1,8 +1,7 @@
from itertools import combinations, compress
from itertools import compress
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm
from IPython import embed
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
@ -12,17 +11,19 @@ from thunderfish.powerspectrum import spectrogram, decibel
from sklearn.preprocessing import normalize
from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.filehandling import ConfLoader, LoadData
from modules.filehandling import ConfLoader, LoadData, make_outputdir
from modules.datahandling import flatten, purge_duplicates, group_timestamps
from modules.plotstyle import PlotStyle
from modules.logger import makeLogger
logger = makeLogger(__name__)
ps = PlotStyle()
@dataclass
class PlotBuffer:
config: ConfLoader
t0: float
dt: float
track_id: float
@ -42,20 +43,20 @@ class PlotBuffer:
frequency_filtered: np.ndarray
frequency_peaks: np.ndarray
def plot_buffer(self, chirps) -> None:
def plot_buffer(self, chirps: np.ndarray, plot: str) -> None:
logger.debug("Starting plotting")
# make data for plotting
# get index of track data in this time window
window_idx = np.arange(len(self.data.idx))[
(self.data.ident == self.track_id) & (self.data.time[self.data.idx] >= self.t0) & (
self.data.time[self.data.idx] <= (self.t0 + self.dt))
]
# # get index of track data in this time window
# window_idx = np.arange(len(self.data.idx))[
# (self.data.ident == self.track_id) & (self.data.time[self.data.idx] >= self.t0) & (
# self.data.time[self.data.idx] <= (self.t0 + self.dt))
# ]
# get tracked frequencies and their times
freq_temp = self.data.freq[window_idx]
# freq_temp = self.data.freq[window_idx]
# time_temp = self.data.times[window_idx]
# get indices on raw data
@ -113,7 +114,8 @@ class PlotBuffer:
self.frequency_filtered[self.frequency_peaks],
c=ps.red,
)
axs[0].set_ylim(np.max(self.frequency)-200, top=np.max(self.frequency)+200)
axs[0].set_ylim(np.max(self.frequency)-200,
top=np.max(self.frequency)+200)
axs[6].set_xlabel("Time [s]")
axs[0].set_title("Spectrogram")
axs[1].set_title("Fitered baseline")
@ -123,7 +125,16 @@ class PlotBuffer:
axs[5].set_title("Search envelope")
axs[6].set_title(
"Filtered absolute instantaneous frequency")
plt.show()
if plot == 'show':
plt.show()
elif plot == 'save':
make_outputdir(self.config.outputdir)
out = make_outputdir(self.config.outputdir +
self.data.datapath.split('/')[-2] + '/')
plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf")
plt.close()
def instantaneos_frequency(
@ -248,6 +259,45 @@ def double_bandpass(
return (filtered_baseline, filtered_search_freq)
def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, list[int]]:
"""
Calculate the median frequency of all fish in a given time window.
Parameters
----------
data : LoadData
Data to calculate the median frequency from.
t0 : float
Start time of the window.
dt : float
Duration of the window.
Returns
-------
tuple[float, list[int]]
"""
median_freq = []
track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
window_idx = np.arange(len(data.idx))[
(data.ident == track_id) & (data.time[data.idx] >= t0) & (
data.time[data.idx] <= (t0 + dt))
]
if len(data.freq[window_idx]) > 0:
median_freq.append(np.median(data.freq[window_idx]))
track_ids.append(track_id)
# convert to numpy array
median_freq = np.asarray(median_freq)
track_ids = np.asarray(track_ids)
return median_freq, track_ids
def main(datapath: str, plot: str) -> None:
assert plot in ["save", "show", "false"]
@ -279,7 +329,7 @@ def main(datapath: str, plot: str) -> None:
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# # good chirp times for data: 2022-06-02-10_00
#t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
# t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
# dt = 60 * data.raw_rate
t0 = 0
@ -293,18 +343,13 @@ def main(datapath: str, plot: str) -> None:
dtype=int
)
# # ask how many windows should be calulated
# nwindows = int(
# input("How many windows should be calculated (integer number)? "))
# ititialize lists to store data
chirps = []
fish_ids = []
for st, start_index in tqdm(enumerate(window_starts)):
#print(f"Processing window {st/data.raw_rate} of {len(window_starts/data.raw_rate)}")
for st, start_index in enumerate(window_starts):
logger.debug(f"Processing window {st} of {len(window_starts)}")
logger.info(f"Processing window {st} of {len(window_starts)}")
# make t0 and dt
t0 = start_index / data.raw_rate
@ -314,25 +359,12 @@ def main(datapath: str, plot: str) -> None:
stop_index = start_index + window_duration
# calucate median of fish frequencies in window
median_freq = []
track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
window_idx = np.arange(len(data.idx))[
(data.ident == track_id) & (data.time[data.idx] >= t0) & (
data.time[data.idx] <= (t0 + dt))
]
median_freq.append(np.median(data.freq[window_idx]))
track_ids.append(track_id)
# convert to numpy array
median_freq = np.asarray(median_freq)
track_ids = np.asarray(track_ids)
median_freq, median_ids = freqmedian_allfish(data, t0, dt)
# iterate through all fish
for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
logger.debug(f"Processing track {tr} of {len(track_ids)}")
logger.debug(f"Processing track {tr} of {len(data.ids)}")
# get index of track data in this time window
window_idx = np.arange(len(data.idx))[
@ -350,10 +382,18 @@ def main(datapath: str, plot: str) -> None:
expected_duration = ((t0 + dt) - t0) * track_samplerate
# check if tracked data available in this window
if len(freq_temp) < expected_duration * 0.9:
if len(freq_temp) < expected_duration * 0.5:
logger.warning(
f"Track {track_id} has no data in window {st}, skipping.")
continue
# check if there are powers available in this window
nanchecker = np.unique(np.isnan(powers_temp))
if (len(nanchecker) == 1) and nanchecker[0] == True:
logger.warning(
f"No powers available for track {track_id} window {st}, skipping.")
continue
# get best electrode
best_electrodes = np.argsort(np.nanmean(
powers_temp, axis=0))[-config.number_electrodes:]
@ -366,7 +406,7 @@ def main(datapath: str, plot: str) -> None:
search_window_bool = np.ones(len(search_window), dtype=bool)
# get tracks that fall into search window
check_track_ids = track_ids[(median_freq > search_window[0]) & (
check_track_ids = median_ids[(median_freq > search_window[0]) & (
median_freq < search_window[-1])]
# iterate through theses tracks
@ -429,10 +469,8 @@ def main(datapath: str, plot: str) -> None:
else:
search_freq = config.default_search_freq
#print(f"Search frequency: {search_freq}")
# ----------- chrips on the two best electrodes-----------
chirps_electrodes = []
electrodes_of_chirps = []
# iterate through electrodes
for el, electrode in enumerate(best_electrodes):
@ -560,77 +598,6 @@ def main(datapath: str, plot: str) -> None:
prominence=prominence
)
# # PLOT --------------------------------------------------------
# # plot spectrogram
# plot_spectrogram(
# axs[0, el], data_oi[:, electrode], data.raw_rate, t0)
# # plot baseline instantaneos frequency
# axs[1, el].plot(baseline_freq_time, baseline_freq -
# np.median(baseline_freq))
# # plot waveform of filtered signal
# axs[2, el].plot(time_oi, baseline, c=ps.green)
# # plot broad filtered baseline
# axs[2, el].plot(
# time_oi,
# broad_baseline,
# )
# # plot narrow filtered baseline envelope
# axs[2, el].plot(
# time_oi,
# baseline_envelope_unfiltered,
# c=ps.red
# )
# # plot waveform of filtered search signal
# axs[3, el].plot(time_oi, search)
# # plot envelope of search signal
# axs[3, el].plot(
# time_oi,
# search_envelope,
# c=ps.red
# )
# # plot filtered and rectified envelope
# axs[4, el].plot(time_oi, baseline_envelope)
# axs[4, el].scatter(
# (time_oi)[baseline_peaks],
# baseline_envelope[baseline_peaks],
# c=ps.red,
# )
# # plot envelope of search signal
# axs[5, el].plot(time_oi, search_envelope)
# axs[5, el].scatter(
# (time_oi)[search_peaks],
# search_envelope[search_peaks],
# c=ps.red,
# )
# # plot filtered instantaneous frequency
# axs[6, el].plot(baseline_freq_time, np.abs(inst_freq_filtered))
# axs[6, el].scatter(
# baseline_freq_time[inst_freq_peaks],
# np.abs(inst_freq_filtered)[inst_freq_peaks],
# c=ps.red,
# )
# axs[6, el].set_xlabel("Time [s]")
# axs[0, el].set_title("Spectrogram")
# axs[1, el].set_title("Fitered baseline instanenous frequency")
# axs[2, el].set_title("Fitered baseline")
# axs[3, el].set_title("Fitered above")
# axs[4, el].set_title("Filtered envelope of baseline envelope")
# axs[5, el].set_title("Search envelope")
# axs[6, el].set_title(
# "Filtered absolute instantaneous frequency")
# DETECT CHIRPS IN SEARCH WINDOW -------------------------------
baseline_ts = time_oi[baseline_peaks]
@ -641,69 +608,14 @@ def main(datapath: str, plot: str) -> None:
if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0:
continue
# current_chirps = group_timestamps_v2(
# [list(baseline_ts), list(search_ts), list(freq_ts)], 3)
# get index for each feature
baseline_idx = np.zeros_like(baseline_ts)
search_idx = np.ones_like(search_ts)
freq_idx = np.ones_like(freq_ts) * 2
timestamps_features = np.hstack(
[baseline_idx, search_idx, freq_idx])
timestamps = np.hstack([baseline_ts, search_ts, freq_ts])
# sort timestamps
timestamps_idx = np.arange(len(timestamps))
timestamps_features = timestamps_features[np.argsort(
timestamps)]
timestamps = timestamps[np.argsort(timestamps)]
# # get chirps
# diff = np.empty(timestamps.shape)
# diff[0] = np.inf # always retain the 1st element
# diff[1:] = np.diff(timestamps)
# mask = diff < config.chirp_window_threshold
# shared_peak_indices = timestamp_idx[mask]
current_chirps = []
bool_timestamps = np.ones_like(timestamps, dtype=bool)
for bo, tt in enumerate(timestamps):
if bool_timestamps[bo] is False:
continue
cm = timestamps_idx[(timestamps >= tt) & (
timestamps <= tt + config.chirp_window_threshold)]
if set([0, 1, 2]).issubset(timestamps_features[cm]):
current_chirps.append(np.mean(timestamps[cm]))
electrodes_of_chirps.append(el)
bool_timestamps[cm] = False
current_chirps = group_timestamps(
[list(baseline_ts), list(search_ts), list(freq_ts)], 3, config.chirp_window_threshold)
# for checking if there are chirps on multiple electrodes
if len(current_chirps) == 0:
continue
chirps_electrodes.append(current_chirps)
# for ct in current_chirps:
# axs[0, el].axvline(ct, color='r', lw=1)
# axs[0, el].scatter(
# baseline_freq_time[inst_freq_peaks],
# np.ones_like(baseline_freq_time[inst_freq_peaks]) * 600,
# c=ps.red,
# )
# axs[0, el].scatter(
# (time_oi)[search_peaks],
# np.ones_like((time_oi)[search_peaks]) * 600,
# c=ps.red,
# )
# axs[0, el].scatter(
# (time_oi)[baseline_peaks],
# np.ones_like((time_oi)[baseline_peaks]) * 600,
# c=ps.red,
# )
if (el == config.number_electrodes - 1) & \
(len(current_chirps) > 0) & \
(plot in ["show", "save"]):
@ -712,6 +624,7 @@ def main(datapath: str, plot: str) -> None:
# save data to Buffer
buffer = PlotBuffer(
config=config,
t0=t0,
dt=dt,
electrode=electrode,
@ -735,70 +648,19 @@ def main(datapath: str, plot: str) -> None:
logger.debug(
f"Processed all electrodes for fish {track_id} for this window, sorting chirps ...")
# continue if no chirps for current fish
# make one array
# chirps_electrodes = np.concatenate(chirps_electrodes)
# make shure they are numpy arrays
# electrodes_of_chirps = np.asarray(electrodes_of_chirps)
# # sort them
# sort_chirps_electrodes = chirps_electrodes[np.argsort(
# chirps_electrodes)]
# sort_electrodes = electrodes_of_chirps[np.argsort(
# chirps_electrodes)]
# bool_vector = np.ones(len(sort_chirps_electrodes), dtype=bool)
# # make index vector
# index_vector = np.arange(len(sort_chirps_electrodes))
# # make it more than only two electrodes for the search after chirps
# combinations_best_elctrodes = list(
# combinations(range(3), 2))
if len(chirps_electrodes) == 0:
continue
the_real_chirps = group_timestamps(chirps_electrodes, 2, 0.05)
# for chirp_index, seoc in enumerate(sort_chirps_electrodes):
# if bool_vector[chirp_index] is False:
# continue
# cm = index_vector[(sort_chirps_electrodes >= seoc) & (
# sort_chirps_electrodes <= seoc + config.chirp_window_threshold)]
# chirps_unique = []
# for combination in combinations_best_elctrodes:
# if set(combination).issubset(sort_electrodes[cm]):
# chirps_unique.append(
# np.mean(sort_chirps_electrodes[cm]))
# the_real_chirps.append(np.mean(chirps_unique))
# """
# if set([0,1]).issubset(sort_electrodes[cm]):
# the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
# elif set([1,0]).issubset(sort_electrodes[cm]):
# the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
# elif set([0,2]).issubset(sort_electrodes[cm]):
# the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
# elif set([1,2]).issubset(sort_electrodes[cm]):
# the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
# """
# bool_vector[cm] = False
chirps.append(the_real_chirps)
fish_ids.append(track_id)
# for ct in the_real_chirps:
# axs[0, el].axvline(ct, color='b', lw=1)
logger.debug('Found %d chirps, starting plotting ... ' %
len(the_real_chirps))
if len(the_real_chirps) > 0:
try:
buffer.plot_buffer(the_real_chirps)
buffer.plot_buffer(the_real_chirps, plot)
except NameError:
pass
else:
@ -807,14 +669,6 @@ def main(datapath: str, plot: str) -> None:
except NameError:
pass
# fig, ax = plt.subplots()
# t0 = (3 * 60 * 60 + 6 * 60 + 43.5)
# data_oi = data.raw[window_starts[0]:window_starts[-1] + int(dt*data.raw_rate), 10]
# plot_spectrogram(ax, data_oi, data.raw_rate, t0)
# chirps_concat = np.concatenate(chirps)
# for ch in chirps_concat:
# ax. axvline(ch, color='b', lw=1)
chirps_new = []
chirps_ids = []
for tr in np.unique(fish_ids):
@ -837,4 +691,4 @@ def main(datapath: str, plot: str) -> None:
if __name__ == "__main__":
datapath = "../data/2022-06-02-10_00/"
main(datapath, plot="show")
main(datapath, plot="save")

4
code/chirpdetector_conf.yml Normal file → Executable file
View File

@ -1,3 +1,6 @@
dataroot: "../data/"
outputdir: "../output/"
# Duration and overlap of the analysis window in seconds
window: 5
overlap: 1
@ -40,7 +43,6 @@ search_freq_percentiles:
- 95
default_search_freq: 50
chirp_window_threshold: 0.05

View File

@ -1,5 +1,5 @@
import numpy as np
from typing import List, Union, Any
from typing import List, Any
def purge_duplicates(

View File

@ -36,6 +36,7 @@ class LoadData:
def __init__(self, datapath: str) -> None:
# load raw data
self.datapath = datapath
self.file = os.path.join(datapath, "traces-grid1.raw")
self.raw = DataLoader(self.file, 60.0, 0, channel=-1)
self.raw_rate = self.raw.samplerate
@ -53,3 +54,23 @@ class LoadData:
def __str__(self) -> str:
return f"LoadData({self.file})"
def make_outputdir(path: str) -> str:
"""
Creates a new directory where the path leads if it does not already exist.
Parameters
----------
path : string
path to the new output directory
Returns
-------
string
path of the newly created output directory
"""
if os.path.isdir(path) == False:
os.mkdir(path)
return path

View File

@ -23,7 +23,7 @@ def makeLogger(name: str):
logger = logging.getLogger(name)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
logger.setLevel(logging.DEBUG)
logger.setLevel(logging.INFO)
return logger