plot works

This commit is contained in:
weygoldt 2023-01-19 13:47:18 +01:00
parent a79fc86ab9
commit c2de6c7060
5 changed files with 109 additions and 232 deletions

View File

@ -1,8 +1,7 @@
from itertools import combinations, compress from itertools import compress
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
from tqdm import tqdm
from IPython import embed from IPython import embed
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.signal import find_peaks from scipy.signal import find_peaks
@ -12,17 +11,19 @@ from thunderfish.powerspectrum import spectrogram, decibel
from sklearn.preprocessing import normalize from sklearn.preprocessing import normalize
from modules.filters import bandpass_filter, envelope, highpass_filter from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.filehandling import ConfLoader, LoadData from modules.filehandling import ConfLoader, LoadData, make_outputdir
from modules.datahandling import flatten, purge_duplicates, group_timestamps from modules.datahandling import flatten, purge_duplicates, group_timestamps
from modules.plotstyle import PlotStyle from modules.plotstyle import PlotStyle
from modules.logger import makeLogger from modules.logger import makeLogger
logger = makeLogger(__name__) logger = makeLogger(__name__)
ps = PlotStyle() ps = PlotStyle()
@dataclass @dataclass
class PlotBuffer: class PlotBuffer:
config: ConfLoader
t0: float t0: float
dt: float dt: float
track_id: float track_id: float
@ -42,20 +43,20 @@ class PlotBuffer:
frequency_filtered: np.ndarray frequency_filtered: np.ndarray
frequency_peaks: np.ndarray frequency_peaks: np.ndarray
def plot_buffer(self, chirps) -> None: def plot_buffer(self, chirps: np.ndarray, plot: str) -> None:
logger.debug("Starting plotting") logger.debug("Starting plotting")
# make data for plotting # make data for plotting
# get index of track data in this time window # # get index of track data in this time window
window_idx = np.arange(len(self.data.idx))[ # window_idx = np.arange(len(self.data.idx))[
(self.data.ident == self.track_id) & (self.data.time[self.data.idx] >= self.t0) & ( # (self.data.ident == self.track_id) & (self.data.time[self.data.idx] >= self.t0) & (
self.data.time[self.data.idx] <= (self.t0 + self.dt)) # self.data.time[self.data.idx] <= (self.t0 + self.dt))
] # ]
# get tracked frequencies and their times # get tracked frequencies and their times
freq_temp = self.data.freq[window_idx] # freq_temp = self.data.freq[window_idx]
# time_temp = self.data.times[window_idx] # time_temp = self.data.times[window_idx]
# get indices on raw data # get indices on raw data
@ -113,7 +114,8 @@ class PlotBuffer:
self.frequency_filtered[self.frequency_peaks], self.frequency_filtered[self.frequency_peaks],
c=ps.red, c=ps.red,
) )
axs[0].set_ylim(np.max(self.frequency)-200, top=np.max(self.frequency)+200) axs[0].set_ylim(np.max(self.frequency)-200,
top=np.max(self.frequency)+200)
axs[6].set_xlabel("Time [s]") axs[6].set_xlabel("Time [s]")
axs[0].set_title("Spectrogram") axs[0].set_title("Spectrogram")
axs[1].set_title("Fitered baseline") axs[1].set_title("Fitered baseline")
@ -123,7 +125,16 @@ class PlotBuffer:
axs[5].set_title("Search envelope") axs[5].set_title("Search envelope")
axs[6].set_title( axs[6].set_title(
"Filtered absolute instantaneous frequency") "Filtered absolute instantaneous frequency")
plt.show()
if plot == 'show':
plt.show()
elif plot == 'save':
make_outputdir(self.config.outputdir)
out = make_outputdir(self.config.outputdir +
self.data.datapath.split('/')[-2] + '/')
plt.savefig(f"{out}{self.track_id}_{self.t0}.pdf")
plt.close()
def instantaneos_frequency( def instantaneos_frequency(
@ -248,6 +259,45 @@ def double_bandpass(
return (filtered_baseline, filtered_search_freq) return (filtered_baseline, filtered_search_freq)
def freqmedian_allfish(data: LoadData, t0: float, dt: float) -> tuple[float, list[int]]:
"""
Calculate the median frequency of all fish in a given time window.
Parameters
----------
data : LoadData
Data to calculate the median frequency from.
t0 : float
Start time of the window.
dt : float
Duration of the window.
Returns
-------
tuple[float, list[int]]
"""
median_freq = []
track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
window_idx = np.arange(len(data.idx))[
(data.ident == track_id) & (data.time[data.idx] >= t0) & (
data.time[data.idx] <= (t0 + dt))
]
if len(data.freq[window_idx]) > 0:
median_freq.append(np.median(data.freq[window_idx]))
track_ids.append(track_id)
# convert to numpy array
median_freq = np.asarray(median_freq)
track_ids = np.asarray(track_ids)
return median_freq, track_ids
def main(datapath: str, plot: str) -> None: def main(datapath: str, plot: str) -> None:
assert plot in ["save", "show", "false"] assert plot in ["save", "show", "false"]
@ -279,7 +329,7 @@ def main(datapath: str, plot: str) -> None:
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# # good chirp times for data: 2022-06-02-10_00 # # good chirp times for data: 2022-06-02-10_00
#t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate # t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
# dt = 60 * data.raw_rate # dt = 60 * data.raw_rate
t0 = 0 t0 = 0
@ -293,18 +343,13 @@ def main(datapath: str, plot: str) -> None:
dtype=int dtype=int
) )
# # ask how many windows should be calulated
# nwindows = int(
# input("How many windows should be calculated (integer number)? "))
# ititialize lists to store data # ititialize lists to store data
chirps = [] chirps = []
fish_ids = [] fish_ids = []
for st, start_index in tqdm(enumerate(window_starts)): for st, start_index in enumerate(window_starts):
#print(f"Processing window {st/data.raw_rate} of {len(window_starts/data.raw_rate)}")
logger.debug(f"Processing window {st} of {len(window_starts)}") logger.info(f"Processing window {st} of {len(window_starts)}")
# make t0 and dt # make t0 and dt
t0 = start_index / data.raw_rate t0 = start_index / data.raw_rate
@ -314,25 +359,12 @@ def main(datapath: str, plot: str) -> None:
stop_index = start_index + window_duration stop_index = start_index + window_duration
# calucate median of fish frequencies in window # calucate median of fish frequencies in window
median_freq = [] median_freq, median_ids = freqmedian_allfish(data, t0, dt)
track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
window_idx = np.arange(len(data.idx))[
(data.ident == track_id) & (data.time[data.idx] >= t0) & (
data.time[data.idx] <= (t0 + dt))
]
median_freq.append(np.median(data.freq[window_idx]))
track_ids.append(track_id)
# convert to numpy array
median_freq = np.asarray(median_freq)
track_ids = np.asarray(track_ids)
# iterate through all fish # iterate through all fish
for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])): for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
logger.debug(f"Processing track {tr} of {len(track_ids)}") logger.debug(f"Processing track {tr} of {len(data.ids)}")
# get index of track data in this time window # get index of track data in this time window
window_idx = np.arange(len(data.idx))[ window_idx = np.arange(len(data.idx))[
@ -350,10 +382,18 @@ def main(datapath: str, plot: str) -> None:
expected_duration = ((t0 + dt) - t0) * track_samplerate expected_duration = ((t0 + dt) - t0) * track_samplerate
# check if tracked data available in this window # check if tracked data available in this window
if len(freq_temp) < expected_duration * 0.9: if len(freq_temp) < expected_duration * 0.5:
logger.warning(
f"Track {track_id} has no data in window {st}, skipping.")
continue
# check if there are powers available in this window
nanchecker = np.unique(np.isnan(powers_temp))
if (len(nanchecker) == 1) and nanchecker[0] == True:
logger.warning(
f"No powers available for track {track_id} window {st}, skipping.")
continue continue
# get best electrode
best_electrodes = np.argsort(np.nanmean( best_electrodes = np.argsort(np.nanmean(
powers_temp, axis=0))[-config.number_electrodes:] powers_temp, axis=0))[-config.number_electrodes:]
@ -366,7 +406,7 @@ def main(datapath: str, plot: str) -> None:
search_window_bool = np.ones(len(search_window), dtype=bool) search_window_bool = np.ones(len(search_window), dtype=bool)
# get tracks that fall into search window # get tracks that fall into search window
check_track_ids = track_ids[(median_freq > search_window[0]) & ( check_track_ids = median_ids[(median_freq > search_window[0]) & (
median_freq < search_window[-1])] median_freq < search_window[-1])]
# iterate through theses tracks # iterate through theses tracks
@ -429,10 +469,8 @@ def main(datapath: str, plot: str) -> None:
else: else:
search_freq = config.default_search_freq search_freq = config.default_search_freq
#print(f"Search frequency: {search_freq}")
# ----------- chrips on the two best electrodes----------- # ----------- chrips on the two best electrodes-----------
chirps_electrodes = [] chirps_electrodes = []
electrodes_of_chirps = []
# iterate through electrodes # iterate through electrodes
for el, electrode in enumerate(best_electrodes): for el, electrode in enumerate(best_electrodes):
@ -560,77 +598,6 @@ def main(datapath: str, plot: str) -> None:
prominence=prominence prominence=prominence
) )
# # PLOT --------------------------------------------------------
# # plot spectrogram
# plot_spectrogram(
# axs[0, el], data_oi[:, electrode], data.raw_rate, t0)
# # plot baseline instantaneos frequency
# axs[1, el].plot(baseline_freq_time, baseline_freq -
# np.median(baseline_freq))
# # plot waveform of filtered signal
# axs[2, el].plot(time_oi, baseline, c=ps.green)
# # plot broad filtered baseline
# axs[2, el].plot(
# time_oi,
# broad_baseline,
# )
# # plot narrow filtered baseline envelope
# axs[2, el].plot(
# time_oi,
# baseline_envelope_unfiltered,
# c=ps.red
# )
# # plot waveform of filtered search signal
# axs[3, el].plot(time_oi, search)
# # plot envelope of search signal
# axs[3, el].plot(
# time_oi,
# search_envelope,
# c=ps.red
# )
# # plot filtered and rectified envelope
# axs[4, el].plot(time_oi, baseline_envelope)
# axs[4, el].scatter(
# (time_oi)[baseline_peaks],
# baseline_envelope[baseline_peaks],
# c=ps.red,
# )
# # plot envelope of search signal
# axs[5, el].plot(time_oi, search_envelope)
# axs[5, el].scatter(
# (time_oi)[search_peaks],
# search_envelope[search_peaks],
# c=ps.red,
# )
# # plot filtered instantaneous frequency
# axs[6, el].plot(baseline_freq_time, np.abs(inst_freq_filtered))
# axs[6, el].scatter(
# baseline_freq_time[inst_freq_peaks],
# np.abs(inst_freq_filtered)[inst_freq_peaks],
# c=ps.red,
# )
# axs[6, el].set_xlabel("Time [s]")
# axs[0, el].set_title("Spectrogram")
# axs[1, el].set_title("Fitered baseline instanenous frequency")
# axs[2, el].set_title("Fitered baseline")
# axs[3, el].set_title("Fitered above")
# axs[4, el].set_title("Filtered envelope of baseline envelope")
# axs[5, el].set_title("Search envelope")
# axs[6, el].set_title(
# "Filtered absolute instantaneous frequency")
# DETECT CHIRPS IN SEARCH WINDOW ------------------------------- # DETECT CHIRPS IN SEARCH WINDOW -------------------------------
baseline_ts = time_oi[baseline_peaks] baseline_ts = time_oi[baseline_peaks]
@ -641,69 +608,14 @@ def main(datapath: str, plot: str) -> None:
if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0: if len(baseline_ts) == 0 or len(search_ts) == 0 or len(freq_ts) == 0:
continue continue
# current_chirps = group_timestamps_v2( current_chirps = group_timestamps(
# [list(baseline_ts), list(search_ts), list(freq_ts)], 3) [list(baseline_ts), list(search_ts), list(freq_ts)], 3, config.chirp_window_threshold)
# get index for each feature
baseline_idx = np.zeros_like(baseline_ts)
search_idx = np.ones_like(search_ts)
freq_idx = np.ones_like(freq_ts) * 2
timestamps_features = np.hstack(
[baseline_idx, search_idx, freq_idx])
timestamps = np.hstack([baseline_ts, search_ts, freq_ts])
# sort timestamps
timestamps_idx = np.arange(len(timestamps))
timestamps_features = timestamps_features[np.argsort(
timestamps)]
timestamps = timestamps[np.argsort(timestamps)]
# # get chirps
# diff = np.empty(timestamps.shape)
# diff[0] = np.inf # always retain the 1st element
# diff[1:] = np.diff(timestamps)
# mask = diff < config.chirp_window_threshold
# shared_peak_indices = timestamp_idx[mask]
current_chirps = []
bool_timestamps = np.ones_like(timestamps, dtype=bool)
for bo, tt in enumerate(timestamps):
if bool_timestamps[bo] is False:
continue
cm = timestamps_idx[(timestamps >= tt) & (
timestamps <= tt + config.chirp_window_threshold)]
if set([0, 1, 2]).issubset(timestamps_features[cm]):
current_chirps.append(np.mean(timestamps[cm]))
electrodes_of_chirps.append(el)
bool_timestamps[cm] = False
# for checking if there are chirps on multiple electrodes # for checking if there are chirps on multiple electrodes
if len(current_chirps) == 0: if len(current_chirps) == 0:
continue continue
chirps_electrodes.append(current_chirps) chirps_electrodes.append(current_chirps)
# for ct in current_chirps:
# axs[0, el].axvline(ct, color='r', lw=1)
# axs[0, el].scatter(
# baseline_freq_time[inst_freq_peaks],
# np.ones_like(baseline_freq_time[inst_freq_peaks]) * 600,
# c=ps.red,
# )
# axs[0, el].scatter(
# (time_oi)[search_peaks],
# np.ones_like((time_oi)[search_peaks]) * 600,
# c=ps.red,
# )
# axs[0, el].scatter(
# (time_oi)[baseline_peaks],
# np.ones_like((time_oi)[baseline_peaks]) * 600,
# c=ps.red,
# )
if (el == config.number_electrodes - 1) & \ if (el == config.number_electrodes - 1) & \
(len(current_chirps) > 0) & \ (len(current_chirps) > 0) & \
(plot in ["show", "save"]): (plot in ["show", "save"]):
@ -712,6 +624,7 @@ def main(datapath: str, plot: str) -> None:
# save data to Buffer # save data to Buffer
buffer = PlotBuffer( buffer = PlotBuffer(
config=config,
t0=t0, t0=t0,
dt=dt, dt=dt,
electrode=electrode, electrode=electrode,
@ -735,70 +648,19 @@ def main(datapath: str, plot: str) -> None:
logger.debug( logger.debug(
f"Processed all electrodes for fish {track_id} for this window, sorting chirps ...") f"Processed all electrodes for fish {track_id} for this window, sorting chirps ...")
# continue if no chirps for current fish
# make one array
# chirps_electrodes = np.concatenate(chirps_electrodes)
# make shure they are numpy arrays
# electrodes_of_chirps = np.asarray(electrodes_of_chirps)
# # sort them
# sort_chirps_electrodes = chirps_electrodes[np.argsort(
# chirps_electrodes)]
# sort_electrodes = electrodes_of_chirps[np.argsort(
# chirps_electrodes)]
# bool_vector = np.ones(len(sort_chirps_electrodes), dtype=bool)
# # make index vector
# index_vector = np.arange(len(sort_chirps_electrodes))
# # make it more than only two electrodes for the search after chirps
# combinations_best_elctrodes = list(
# combinations(range(3), 2))
if len(chirps_electrodes) == 0: if len(chirps_electrodes) == 0:
continue continue
the_real_chirps = group_timestamps(chirps_electrodes, 2, 0.05) the_real_chirps = group_timestamps(chirps_electrodes, 2, 0.05)
# for chirp_index, seoc in enumerate(sort_chirps_electrodes):
# if bool_vector[chirp_index] is False:
# continue
# cm = index_vector[(sort_chirps_electrodes >= seoc) & (
# sort_chirps_electrodes <= seoc + config.chirp_window_threshold)]
# chirps_unique = []
# for combination in combinations_best_elctrodes:
# if set(combination).issubset(sort_electrodes[cm]):
# chirps_unique.append(
# np.mean(sort_chirps_electrodes[cm]))
# the_real_chirps.append(np.mean(chirps_unique))
# """
# if set([0,1]).issubset(sort_electrodes[cm]):
# the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
# elif set([1,0]).issubset(sort_electrodes[cm]):
# the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
# elif set([0,2]).issubset(sort_electrodes[cm]):
# the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
# elif set([1,2]).issubset(sort_electrodes[cm]):
# the_real_chirps.append(np.mean(sort_chirps_electrodes[cm]))
# """
# bool_vector[cm] = False
chirps.append(the_real_chirps) chirps.append(the_real_chirps)
fish_ids.append(track_id) fish_ids.append(track_id)
# for ct in the_real_chirps:
# axs[0, el].axvline(ct, color='b', lw=1)
logger.debug('Found %d chirps, starting plotting ... ' % logger.debug('Found %d chirps, starting plotting ... ' %
len(the_real_chirps)) len(the_real_chirps))
if len(the_real_chirps) > 0: if len(the_real_chirps) > 0:
try: try:
buffer.plot_buffer(the_real_chirps) buffer.plot_buffer(the_real_chirps, plot)
except NameError: except NameError:
pass pass
else: else:
@ -807,14 +669,6 @@ def main(datapath: str, plot: str) -> None:
except NameError: except NameError:
pass pass
# fig, ax = plt.subplots()
# t0 = (3 * 60 * 60 + 6 * 60 + 43.5)
# data_oi = data.raw[window_starts[0]:window_starts[-1] + int(dt*data.raw_rate), 10]
# plot_spectrogram(ax, data_oi, data.raw_rate, t0)
# chirps_concat = np.concatenate(chirps)
# for ch in chirps_concat:
# ax. axvline(ch, color='b', lw=1)
chirps_new = [] chirps_new = []
chirps_ids = [] chirps_ids = []
for tr in np.unique(fish_ids): for tr in np.unique(fish_ids):
@ -837,4 +691,4 @@ def main(datapath: str, plot: str) -> None:
if __name__ == "__main__": if __name__ == "__main__":
datapath = "../data/2022-06-02-10_00/" datapath = "../data/2022-06-02-10_00/"
main(datapath, plot="show") main(datapath, plot="save")

4
code/chirpdetector_conf.yml Normal file → Executable file
View File

@ -1,3 +1,6 @@
dataroot: "../data/"
outputdir: "../output/"
# Duration and overlap of the analysis window in seconds # Duration and overlap of the analysis window in seconds
window: 5 window: 5
overlap: 1 overlap: 1
@ -40,7 +43,6 @@ search_freq_percentiles:
- 95 - 95
default_search_freq: 50 default_search_freq: 50
chirp_window_threshold: 0.05 chirp_window_threshold: 0.05

View File

@ -1,5 +1,5 @@
import numpy as np import numpy as np
from typing import List, Union, Any from typing import List, Any
def purge_duplicates( def purge_duplicates(

View File

@ -36,6 +36,7 @@ class LoadData:
def __init__(self, datapath: str) -> None: def __init__(self, datapath: str) -> None:
# load raw data # load raw data
self.datapath = datapath
self.file = os.path.join(datapath, "traces-grid1.raw") self.file = os.path.join(datapath, "traces-grid1.raw")
self.raw = DataLoader(self.file, 60.0, 0, channel=-1) self.raw = DataLoader(self.file, 60.0, 0, channel=-1)
self.raw_rate = self.raw.samplerate self.raw_rate = self.raw.samplerate
@ -53,3 +54,23 @@ class LoadData:
def __str__(self) -> str: def __str__(self) -> str:
return f"LoadData({self.file})" return f"LoadData({self.file})"
def make_outputdir(path: str) -> str:
"""
Creates a new directory where the path leads if it does not already exist.
Parameters
----------
path : string
path to the new output directory
Returns
-------
string
path of the newly created output directory
"""
if os.path.isdir(path) == False:
os.mkdir(path)
return path

View File

@ -23,7 +23,7 @@ def makeLogger(name: str):
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.addHandler(file_handler) logger.addHandler(file_handler)
logger.addHandler(console_handler) logger.addHandler(console_handler)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.INFO)
return logger return logger