export data to subsubsublists

This commit is contained in:
weygoldt 2023-01-16 18:53:45 +01:00
parent 812122f4f8
commit 12d1c3c1af
2 changed files with 94 additions and 61 deletions

View File

@ -139,8 +139,6 @@ def main(datapath: str) -> None:
# load raw file
file = os.path.join(datapath, "traces-grid1.raw")
# data = DataLoader(file, 60.0, 0, channel=-1)
data = LoadData(datapath)
# ititialize data collection
@ -150,17 +148,10 @@ def main(datapath: str) -> None:
fish_ids = []
electrodes = []
# load wavetracker files
# time = np.load(datapath + "times.npy", allow_pickle=True)
# freq = np.load(datapath + "fund_v.npy", allow_pickle=True)
# powers = np.load(datapath + "sign_v.npy", allow_pickle=True)
# idx = np.load(datapath + "idx_v.npy", allow_pickle=True)
# ident = np.load(datapath + "ident_v.npy", allow_pickle=True)
# load config file
config = ConfLoader("chirpdetector_conf.yml")
# set time window # <------------------------ Iterate through windows here
# set time window
window_duration = config.window * data.raw_rate
window_overlap = config.overlap * data.raw_rate
window_edge = config.edge * data.raw_rate
@ -177,18 +168,21 @@ def main(datapath: str) -> None:
else:
raise ValueError("Window overlap must be even.")
# make time array for raw data
raw_time = np.arange(data.raw.shape[0]) / data.raw_rate
# good chirp times for data: 2022-06-02-10_00
t0 = (3 * 60 * 60 + 6 * 60 + 43.5) * data.raw_rate
dt = 60 * data.raw_rate
# generate starting points of rolling window
window_starts = np.arange(
t0,
t0 + dt,
window_duration - (window_overlap + 2 * window_edge),
dtype=int
)
# ask how many windows should be calulated
nwindows = int(
input("How many windows should be calculated (integer number)? "))
@ -202,12 +196,7 @@ def main(datapath: str) -> None:
# set index window
stop_index = start_index + window_duration
# t0 = 3 * 60 * 60 + 6 * 60 + 43.5
# dt = 60
# start_index = t0 * data.raw_rate
# stop_index = (t0 + dt) * data.raw_rate
# calucate frequencies in wndow
# calucate median of fish frequencies in window
median_freq = []
track_ids = []
for i, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
@ -217,24 +206,36 @@ def main(datapath: str) -> None:
]
median_freq.append(np.median(data.freq[window_index]))
track_ids.append(track_id)
# convert to numpy array
median_freq = np.asarray(median_freq)
track_ids = np.asarray(track_ids)
# make empty lists for data collection
baseline_ts_sub = []
search_ts_sub = []
freq_ts_sub = []
electrodes_sub = []
fish_ids_sub = []
# iterate through all fish
for i, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
print(f"Track ID: {track_id}")
# get index of track data in this time window
window_index = np.arange(len(data.idx))[
(data.ident == track_id) & (data.time[data.idx] >= t0) & (
data.time[data.idx] <= (t0 + dt))
data.time[data.idx] <= (t0 + dt))
]
# get tracked frequencies and their times
freq_temp = data.freq[window_index]
powers_temp = data.powers[window_index, :]
# time_temp = time[idx[window_index]]
# approximate sampling rate to compute expected durations if there
# is data available for this time window for this fish id
track_samplerate = np.mean(1 / np.diff(data.time))
expected_duration = ((t0 + dt) - t0) * track_samplerate
@ -250,6 +251,7 @@ def main(datapath: str) -> None:
sharex=True,
sharey='row',
)
# get best electrode
best_electrodes = np.argsort(np.nanmean(
powers_temp, axis=0))[-config.electrodes:]
@ -321,31 +323,20 @@ def main(datapath: str) -> None:
print(f"Search frequency: {search_freq}")
# ititialize sublists to collect electrodes for this fish in this
# time window
baseline_ts_subsub = []
search_ts_subsub = []
freq_ts_subsub = []
electrodes_subsub = []
# iterate through electrodes
for i, electrode in enumerate(best_electrodes):
# load region of interest of raw data file
data_oi = data.raw[start_index:stop_index, :]
time_oi = raw_time[start_index:stop_index]
# plot wavetracker tracks to spectrogram
# for track_id in np.unique(ident): # <---------- Find freq gaps later
# here
# # get indices for time array in time window
# window_index = np.arange(len(idx))[
# (ident == track_id) &
# (time[idx] >= t0) &
# (time[idx] <= (t0 + dt))
# ]
# freq_temp = freq[window_index]
# time_temp = time[idx[window_index]]
# axs[0].plot(time_temp-t0, freq_temp, lw=2)
# axs[0].set_ylim(500, 1000)
# track_id = ids
# filter baseline and above
baseline, search = double_bandpass(
data_oi[:, electrode], data.raw_rate, freq_temp, search_freq
@ -377,13 +368,6 @@ def main(datapath: str) -> None:
config.envelope_highpass_cutoff
)
# baseline_envelope = np.abs(baseline_envelope)
# search_envelope = highpass_filter(
# search_envelope,
# data.raw_rate,
# config.envelope_highpass_cutoff
# )
# envelopes of filtered envelope of filtered baseline
baseline_envelope = envelope(
np.abs(baseline_envelope),
@ -391,9 +375,6 @@ def main(datapath: str) -> None:
config.envelope_envelope_cutoff
)
# search_envelope = bandpass_filter(
# search_envelope, data.raw_rate, lowf=lowf, highf=highf)
# bandpass filter the instantaneous
inst_freq_filtered = bandpass_filter(
baseline_freq,
@ -402,11 +383,7 @@ def main(datapath: str) -> None:
highf=config.instantaneous_highf
)
# test taking the log of the envelopes
# baseline_envelope = np.log(baseline_envelope)
# search_envelope = np.log(search_envelope)
# CUT OFF OVERLAP -------------------------------------------------
# CUT OFF OVERLAP ---------------------------------------------
# cut off first and last 0.5 * overlap at start and end
valid = np.arange(
@ -417,7 +394,7 @@ def main(datapath: str) -> None:
baseline_envelope = baseline_envelope[valid]
search_envelope = search_envelope[valid]
# get inst freq valid snippet
# get inst freq valid snippet
valid_t0 = int(window_edge) / data.raw_rate
valid_t1 = baseline_freq_time[-1] - \
(int(window_edge) / data.raw_rate)
@ -437,13 +414,13 @@ def main(datapath: str) -> None:
broad_baseline = broad_baseline[valid]
search = search[valid]
# NORMALIZE ----------------------------------------------------
# NORMALIZE ---------------------------------------------------
baseline_envelope = normalize([baseline_envelope])[0]
search_envelope = normalize([search_envelope])[0]
inst_freq_filtered = normalize([inst_freq_filtered])[0]
# PEAK DETECTION -----------------------------------------------
# PEAK DETECTION ----------------------------------------------
# detect peaks baseline_enelope
prominence = np.percentile(
@ -465,12 +442,11 @@ def main(datapath: str) -> None:
# SAVE DATA ----------------------------------------------------
baseline_ts.append(time_oi[baseline_peaks].tolist())
search_ts.append(time_oi[search_peaks].tolist())
freq_ts.append(baseline_freq_time[inst_freq_peaks].tolist())
fish_ids.append(track_id)
electrodes.append(electrode)
embed()
baseline_ts_subsub.append(time_oi[baseline_peaks].tolist())
search_ts_subsub.append(time_oi[search_peaks].tolist())
freq_ts_subsub.append(baseline_freq_time[inst_freq_peaks].tolist())
electrodes_subsub.append(electrode)
# PLOT ------------------------------------------------------------
# plot spectrogram
@ -543,6 +519,22 @@ def main(datapath: str) -> None:
plt.show()
baseline_ts_sub.append(baseline_ts_subsub)
search_ts_sub.append(search_ts_subsub)
freq_ts_sub.append(freq_ts_subsub)
electrodes_sub.append(electrodes_subsub)
fish_ids_sub.append(track_id)
baseline_ts.append(baseline_ts_sub)
search_ts.append(search_ts_sub)
freq_ts.append(freq_ts_sub)
fish_ids.append(fish_ids_sub)
electrodes.append(electrodes_sub)
embed()
if __name__ == "__main__":
datapath = "../data/2022-06-02-10_00/"

41
code/modules/logger.py Normal file
View File

@ -0,0 +1,41 @@
import logging
def makeLogger(name: str):
# create logger formats for file and terminal
file_formatter = logging.Formatter(
"[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s")
console_formatter = logging.Formatter(
"[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s")
# create logging file if loglevel is debug
file_handler = logging.FileHandler(f"gridtools_log.log", mode="w")
file_handler.setLevel(logging.WARN)
file_handler.setFormatter(file_formatter)
# create stream handler for terminal output
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_formatter)
console_handler.setLevel(logging.INFO)
# create script specific logger
logger = logging.getLogger(name)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
return logger
if __name__ == "__main__":
# initiate logger
mylogger = makeLogger(__name__)
# test logger levels
mylogger.debug("This is for debugging!")
mylogger.info("This is an info.")
mylogger.warning("This is a warning.")
mylogger.error("This is an error.")
mylogger.critical("This is a critical error!")