Merge branch 'master' into eventtriggeredchirps

zsh:1: command not found: q
This commit is contained in:
sprause 2023-01-24 11:08:21 +01:00
commit dc2074222c
6 changed files with 178 additions and 61 deletions

View File

@ -18,6 +18,7 @@ from modules.datahandling import (
purge_duplicates, purge_duplicates,
group_timestamps, group_timestamps,
instantaneous_frequency, instantaneous_frequency,
minmaxnorm
) )
logger = makeLogger(__name__) logger = makeLogger(__name__)
@ -26,7 +27,7 @@ ps = PlotStyle()
@dataclass @dataclass
class PlotBuffer: class ChirpPlotBuffer:
""" """
Buffer to save data that is created in the main detection loop Buffer to save data that is created in the main detection loop
@ -83,6 +84,7 @@ class PlotBuffer:
q50 + self.search_frequency + self.config.minimal_bandwidth / 2, q50 + self.search_frequency + self.config.minimal_bandwidth / 2,
q50 + self.search_frequency - self.config.minimal_bandwidth / 2, q50 + self.search_frequency - self.config.minimal_bandwidth / 2,
) )
print(search_upper, search_lower)
# get indices on raw data # get indices on raw data
start_idx = (self.t0 - 5) * self.data.raw_rate start_idx = (self.t0 - 5) * self.data.raw_rate
@ -94,7 +96,8 @@ class PlotBuffer:
self.time = self.time - self.t0 self.time = self.time - self.t0
self.frequency_time = self.frequency_time - self.t0 self.frequency_time = self.frequency_time - self.t0
chirps = np.asarray(chirps) - self.t0 if len(chirps) > 0:
chirps = np.asarray(chirps) - self.t0
self.t0_old = self.t0 self.t0_old = self.t0
self.t0 = 0 self.t0 = 0
@ -130,7 +133,7 @@ class PlotBuffer:
data_oi, data_oi,
self.data.raw_rate, self.data.raw_rate,
self.t0 - 5, self.t0 - 5,
[np.min(self.frequency) - 200, np.max(self.frequency) + 200] [np.min(self.frequency) - 100, np.max(self.frequency) + 200]
) )
for track_id in self.data.ids: for track_id in self.data.ids:
@ -181,10 +184,11 @@ class PlotBuffer:
# spec_times[0], spec_times[-1], # spec_times[0], spec_times[-1],
# color=ps.gblue2, lw=2, ls="dashed") # color=ps.gblue2, lw=2, ls="dashed")
for chirp in chirps: if len(chirps) > 0:
ax0.scatter( for chirp in chirps:
chirp, np.median(self.frequency) + 150, c=ps.black, marker="v" ax0.scatter(
) chirp, np.median(self.frequency) + 150, c=ps.black, marker="v"
)
# plot waveform of filtered signal # plot waveform of filtered signal
ax1.plot(self.time, self.baseline * waveform_scaler, ax1.plot(self.time, self.baseline * waveform_scaler,
@ -319,7 +323,7 @@ def plot_spectrogram(
aspect="auto", aspect="auto",
origin="lower", origin="lower",
interpolation="gaussian", interpolation="gaussian",
alpha=1, alpha=0.6,
) )
# axis.use_sticky_edges = False # axis.use_sticky_edges = False
return spec_times return spec_times
@ -432,6 +436,28 @@ def window_median_all_track_ids(
return frequency_percentiles, track_ids return frequency_percentiles, track_ids
def array_center(array: np.ndarray) -> float:
"""
Return the center value of an array.
If the array length is even, returns
the mean of the two center values.
Parameters
----------
array : np.ndarray
Array to calculate the center from.
Returns
-------
float
"""
if len(array) % 2 == 0:
return np.mean(array[int(len(array) / 2) - 1:int(len(array) / 2) + 1])
else:
return array[int(len(array) / 2)]
def find_searchband( def find_searchband(
current_frequency: np.ndarray, current_frequency: np.ndarray,
percentiles_ids: np.ndarray, percentiles_ids: np.ndarray,
@ -465,10 +491,10 @@ def find_searchband(
# frequency window where second filter filters is potentially allowed # frequency window where second filter filters is potentially allowed
# to filter. This is the search window, in which we want to find # to filter. This is the search window, in which we want to find
# a gap in the other fish's EODs. # a gap in the other fish's EODs.
current_median = np.median(current_frequency)
search_window = np.arange( search_window = np.arange(
np.median(current_frequency) + config.search_df_lower, current_median + config.search_df_lower,
np.median(current_frequency) + config.search_df_upper, current_median + config.search_df_upper,
config.search_res, config.search_res,
) )
@ -483,7 +509,7 @@ def find_searchband(
# get tracks that fall into search window # get tracks that fall into search window
check_track_ids = percentiles_ids[ check_track_ids = percentiles_ids[
(q25 > search_window[0]) & ( (q25 > current_median) & (
q75 < search_window[-1]) q75 < search_window[-1])
] ]
@ -511,6 +537,9 @@ def find_searchband(
nonzeros = search_window_gaps[np.nonzero(search_window_gaps)[0]] nonzeros = search_window_gaps[np.nonzero(search_window_gaps)[0]]
nonzeros = nonzeros[~np.isnan(nonzeros)] nonzeros = nonzeros[~np.isnan(nonzeros)]
if len(nonzeros) == 0:
return config.default_search_freq
# if the first value is -1, the array starst with true, so a gap # if the first value is -1, the array starst with true, so a gap
if nonzeros[0] == -1: if nonzeros[0] == -1:
stops = search_window_indices[search_window_gaps == -1] stops = search_window_indices[search_window_gaps == -1]
@ -545,16 +574,14 @@ def find_searchband(
# the center of the search frequency band is then the center of # the center of the search frequency band is then the center of
# the longest gap # the longest gap
search_freq = ( search_freq = array_center(longest_search_window) - current_median
longest_search_window[-1] - longest_search_window[0]
) / 2
return search_freq return search_freq
return config.default_search_freq return config.default_search_freq
def chirpdetection(datapath: str, plot: str) -> None: def chirpdetection(datapath: str, plot: str, debug: str = 'false') -> None:
assert plot in [ assert plot in [
"save", "save",
@ -562,6 +589,15 @@ def chirpdetection(datapath: str, plot: str) -> None:
"false", "false",
], "plot must be 'save', 'show' or 'false'" ], "plot must be 'save', 'show' or 'false'"
assert debug in [
"false",
"electrode",
"fish",
], "debug must be 'false', 'electrode' or 'fish'"
if debug != "false":
assert plot == "show", "debug mode only runs when plot is 'show'"
# load raw file # load raw file
print('datapath', datapath) print('datapath', datapath)
data = LoadData(datapath) data = LoadData(datapath)
@ -592,8 +628,8 @@ def chirpdetection(datapath: str, plot: str) -> None:
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# good chirp times for data: 2022-06-02-10_00 # good chirp times for data: 2022-06-02-10_00
window_start_index = (3 * 60 * 60 + 6 * 60 + 43.5 + 5) * data.raw_rate # window_start_index = (3 * 60 * 60 + 6 * 60 + 43.5 + 5) * data.raw_rate
window_duration_index = 60 * data.raw_rate # window_duration_index = 60 * data.raw_rate
# t0 = 0 # t0 = 0
# dt = data.raw.shape[0] # dt = data.raw.shape[0]
@ -753,11 +789,11 @@ def chirpdetection(datapath: str, plot: str) -> None:
baseline_envelope = -baseline_envelope baseline_envelope = -baseline_envelope
baseline_envelope = envelope( # baseline_envelope = envelope(
signal=baseline_envelope, # signal=baseline_envelope,
samplerate=data.raw_rate, # samplerate=data.raw_rate,
cutoff_frequency=config.baseline_envelope_envelope_cutoff, # cutoff_frequency=config.baseline_envelope_envelope_cutoff,
) # )
# compute the envelope of the search band. Peaks in the search # compute the envelope of the search band. Peaks in the search
# band envelope correspond to troughs in the baseline envelope # band envelope correspond to troughs in the baseline envelope
@ -791,25 +827,25 @@ def chirpdetection(datapath: str, plot: str) -> None:
# compute the envelope of the signal to remove the oscillations # compute the envelope of the signal to remove the oscillations
# around the peaks # around the peaks
baseline_frequency_samplerate = np.mean( # baseline_frequency_samplerate = np.mean(
np.diff(baseline_frequency_time) # np.diff(baseline_frequency_time)
) # )
baseline_frequency_filtered = np.abs( baseline_frequency_filtered = np.abs(
baseline_frequency - np.median(baseline_frequency) baseline_frequency - np.median(baseline_frequency)
) )
baseline_frequency_filtered = highpass_filter( # baseline_frequency_filtered = highpass_filter(
signal=baseline_frequency_filtered, # signal=baseline_frequency_filtered,
samplerate=baseline_frequency_samplerate, # samplerate=baseline_frequency_samplerate,
cutoff=config.baseline_frequency_highpass_cutoff, # cutoff=config.baseline_frequency_highpass_cutoff,
) # )
baseline_frequency_filtered = envelope( # baseline_frequency_filtered = envelope(
signal=-baseline_frequency_filtered, # signal=-baseline_frequency_filtered,
samplerate=baseline_frequency_samplerate, # samplerate=baseline_frequency_samplerate,
cutoff_frequency=config.baseline_frequency_envelope_cutoff, # cutoff_frequency=config.baseline_frequency_envelope_cutoff,
) # )
# CUT OFF OVERLAP --------------------------------------------- # CUT OFF OVERLAP ---------------------------------------------
@ -850,9 +886,9 @@ def chirpdetection(datapath: str, plot: str) -> None:
# normalize all three feature arrays to the same range to make # normalize all three feature arrays to the same range to make
# peak detection simpler # peak detection simpler
baseline_envelope = normalize([baseline_envelope])[0] baseline_envelope = minmaxnorm([baseline_envelope])[0]
search_envelope = normalize([search_envelope])[0] search_envelope = minmaxnorm([search_envelope])[0]
baseline_frequency_filtered = normalize( baseline_frequency_filtered = minmaxnorm(
[baseline_frequency_filtered] [baseline_frequency_filtered]
)[0] )[0]
@ -893,7 +929,7 @@ def chirpdetection(datapath: str, plot: str) -> None:
or len(frequency_peak_timestamps) == 0 or len(frequency_peak_timestamps) == 0
) )
if one_feature_empty: if one_feature_empty and (debug == 'false'):
continue continue
# group peak across feature arrays but only if they # group peak across feature arrays but only if they
@ -914,7 +950,7 @@ def chirpdetection(datapath: str, plot: str) -> None:
# check it there are chirps detected after grouping, continue # check it there are chirps detected after grouping, continue
# with the loop if not # with the loop if not
if len(singleelectrode_chirps) == 0: if (len(singleelectrode_chirps) == 0) and (debug == 'false'):
continue continue
# append chirps from this electrode to the multilectrode list # append chirps from this electrode to the multilectrode list
@ -925,12 +961,12 @@ def chirpdetection(datapath: str, plot: str) -> None:
& (plot in ["show", "save"]) & (plot in ["show", "save"])
) )
if chirp_detected: if chirp_detected or (debug != 'elecrode'):
logger.debug("Detected chirp, ititialize buffer ...") logger.debug("Detected chirp, ititialize buffer ...")
# save data to Buffer # save data to Buffer
buffer = PlotBuffer( buffer = ChirpPlotBuffer(
config=config, config=config,
t0=window_start_seconds, t0=window_start_seconds,
dt=window_duration_seconds, dt=window_duration_seconds,
@ -955,6 +991,11 @@ def chirpdetection(datapath: str, plot: str) -> None:
logger.debug("Buffer initialized!") logger.debug("Buffer initialized!")
if debug == "electrode":
logger.info(f'Plotting electrode {el} ...')
buffer.plot_buffer(
chirps=singleelectrode_chirps, plot=plot)
logger.debug( logger.debug(
f"Processed all electrodes for fish {track_id} for this" f"Processed all electrodes for fish {track_id} for this"
"window, sorting chirps ..." "window, sorting chirps ..."
@ -963,7 +1004,7 @@ def chirpdetection(datapath: str, plot: str) -> None:
# check if there are chirps detected in multiple electrodes and # check if there are chirps detected in multiple electrodes and
# continue the loop if not # continue the loop if not
if len(multielectrode_chirps) == 0: if (len(multielectrode_chirps) == 0) and (debug == 'false'):
continue continue
# validate multielectrode chirps, i.e. check if they are # validate multielectrode chirps, i.e. check if they are
@ -988,12 +1029,17 @@ def chirpdetection(datapath: str, plot: str) -> None:
# if chirps are detected and the plot flag is set, plot the # if chirps are detected and the plot flag is set, plot the
# chirps, otheswise try to delete the buffer if it exists # chirps, otheswise try to delete the buffer if it exists
if ((len(multielectrode_chirps_validated) > 0) & (plot in ["show", "save"])): if debug == "fish":
logger.info(f'Plotting fish {track_id} ...')
buffer.plot_buffer(multielectrode_chirps_validated, plot)
if ((len(multielectrode_chirps_validated) > 0) &
(plot in ["show", "save"]) & (debug == 'false')):
try: try:
buffer.plot_buffer(multielectrode_chirps_validated, plot) buffer.plot_buffer(multielectrode_chirps_validated, plot)
del buffer del buffer
except NameError: except NameError:
embed() pass
else: else:
try: try:
del buffer del buffer
@ -1051,4 +1097,4 @@ if __name__ == "__main__":
datapath = "../data/2022-06-02-10_00/" datapath = "../data/2022-06-02-10_00/"
# datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/" # datapath = "/home/weygoldt/Data/uni/efishdata/2016-colombia/fishgrid/2016-04-09-22_25/"
# datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/" # datapath = "/home/weygoldt/Data/uni/chirpdetection/GP2023_chirp_detection/data/mount_data/2020-03-13-10_00/"
chirpdetection(datapath, plot="show") chirpdetection(datapath, plot="show", debug="fish")

View File

@ -19,29 +19,29 @@ baseline_frequency_smoothing: 5
# Baseline processing parameters # Baseline processing parameters
baseline_envelope_cutoff: 25 baseline_envelope_cutoff: 25
baseline_envelope_bandpass_lowf: 4 baseline_envelope_bandpass_lowf: 2
baseline_envelope_bandpass_highf: 100 baseline_envelope_bandpass_highf: 100
baseline_envelope_envelope_cutoff: 4 # baseline_envelope_envelope_cutoff: 4
# search envelope processing parameters # search envelope processing parameters
search_envelope_cutoff: 5 search_envelope_cutoff: 10
# Instantaneous frequency bandpass filter cutoff frequencies # Instantaneous frequency bandpass filter cutoff frequencies
baseline_frequency_highpass_cutoff: 0.000005 # baseline_frequency_highpass_cutoff: 0.000005
baseline_frequency_envelope_cutoff: 0.000005 # baseline_frequency_envelope_cutoff: 0.000005
# peak detecion parameters # peak detecion parameters
prominence: 0.005 prominence: 0.7
# search freq parameter # search freq parameter
search_df_lower: 20 search_df_lower: 20
search_df_upper: 100 search_df_upper: 100
search_res: 1 search_res: 1
search_bandwidth: 10 search_bandwidth: 20
default_search_freq: 50 default_search_freq: 60
# Classify events as chirps if they are less than this time apart # Classify events as chirps if they are less than this time apart
chirp_window_threshold: 0.05 chirp_window_threshold: 0.015

View File

@ -1,4 +1,5 @@
import os import os
import pandas as pd
import numpy as np import numpy as np
from chirpdetection import chirpdetection from chirpdetection import chirpdetection
from IPython import embed from IPython import embed
@ -7,7 +8,7 @@ from IPython import embed
def main(datapaths): def main(datapaths):
for path in datapaths: for path in datapaths:
chirpdetection(path, plot='show') chirpdetection(path, plot='show', debug='electrode')
if __name__ == '__main__': if __name__ == '__main__':
@ -39,6 +40,9 @@ if __name__ == '__main__':
datapaths = [os.path.join(dataroot, dataset) + datapaths = [os.path.join(dataroot, dataset) +
'/' for dataset in valid_datasets] '/' for dataset in valid_datasets]
embed()
main(datapaths[3]) recs = pd.DataFrame(columns=['recording'], data=valid_datasets)
recs.to_csv('../recs.csv', index=False)
# main(datapaths)
# window 1524 + 244 in dataset index 4 is nice example

35
code/get_behaviour.py Normal file
View File

@ -0,0 +1,35 @@
import os
from paramiko import SSHClient
from scp import SCPClient
from IPython import embed
from pandas import read_csv
ssh = SSHClient()
ssh.load_system_host_keys()
ssh.connect(hostname='kraken',
username='efish',
password='fwNix4U',
)
# SCPCLient takes a paramiko transport as its only argument
scp = SCPClient(ssh.get_transport())
data = read_csv('../recs.csv')
foldernames = data['recording'].values
directory = f'/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/'
for foldername in foldernames:
if not os.path.exists(directory+foldername):
os.makedirs(directory+foldername)
files = [('-').join(foldername.split('-')[:3])+'.csv','chirp_ids.npy', 'chirps.npy', 'fund_v.npy', 'ident_v.npy', 'idx_v.npy', 'times.npy', 'spec.npy', 'LED_on_time.npy', 'sign_v.npy']
for f in files:
scp.get(f'/home/efish/behavior/2019_tube_competition/{foldername}/{f}',
directory+foldername)
scp.close()

View File

@ -4,7 +4,7 @@ from scipy.ndimage import gaussian_filter1d
from scipy.stats import gamma, norm from scipy.stats import gamma, norm
def scale01(data): def minmaxnorm(data):
""" """
Normalize data to [0, 1] Normalize data to [0, 1]
@ -19,7 +19,7 @@ def scale01(data):
Normalized data. Normalized data.
""" """
return (2*((data - np.min(data)) / (np.max(data) - np.min(data)))) - 1 return (data - np.min(data)) / (np.max(data) - np.min(data))
def instantaneous_frequency( def instantaneous_frequency(
@ -168,6 +168,9 @@ def group_timestamps(
] ]
timestamps.sort() timestamps.sort()
if len(timestamps) == 0:
return []
groups = [] groups = []
current_group = [timestamps[0]] current_group = [timestamps[0]]

29
recs.csv Normal file
View File

@ -0,0 +1,29 @@
recording
2020-03-13-10_00
2020-03-16-10_00
2020-03-19-10_00
2020-03-20-10_00
2020-03-23-09_58
2020-03-24-10_00
2020-03-25-10_00
2020-03-31-09_59
2020-05-11-10_00
2020-05-12-10_00
2020-05-13-10_00
2020-05-14-10_00
2020-05-15-10_00
2020-05-18-10_00
2020-05-19-10_00
2020-05-21-10_00
2020-05-25-10_00
2020-05-27-10_00
2020-05-28-10_00
2020-05-29-10_00
2020-06-02-10_00
2020-06-03-10_10
2020-06-04-10_00
2020-06-05-10_00
2020-06-08-10_00
2020-06-09-10_00
2020-06-10-10_00
2020-06-11-10_00
1 recording
2 2020-03-13-10_00
3 2020-03-16-10_00
4 2020-03-19-10_00
5 2020-03-20-10_00
6 2020-03-23-09_58
7 2020-03-24-10_00
8 2020-03-25-10_00
9 2020-03-31-09_59
10 2020-05-11-10_00
11 2020-05-12-10_00
12 2020-05-13-10_00
13 2020-05-14-10_00
14 2020-05-15-10_00
15 2020-05-18-10_00
16 2020-05-19-10_00
17 2020-05-21-10_00
18 2020-05-25-10_00
19 2020-05-27-10_00
20 2020-05-28-10_00
21 2020-05-29-10_00
22 2020-06-02-10_00
23 2020-06-03-10_10
24 2020-06-04-10_00
25 2020-06-05-10_00
26 2020-06-08-10_00
27 2020-06-09-10_00
28 2020-06-10-10_00
29 2020-06-11-10_00