309 lines
13 KiB
Python
309 lines
13 KiB
Python
import itertools
|
|
import sys
|
|
import os
|
|
import argparse
|
|
|
|
import torch
|
|
import torchvision.transforms as T
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.gridspec as gridspec
|
|
from matplotlib.patches import Rectangle
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from tqdm.auto import tqdm
|
|
from IPython import embed
|
|
|
|
from confic import (MIN_FREQ, MAX_FREQ, DELTA_FREQ, FREQ_OVERLAP, DELTA_TIME, TIME_OVERLAP, IMG_SIZE, IMG_DPI, DATA_DIR, LABEL_DIR)
|
|
|
|
|
|
def load_spec_data(folder: str):
    """
    Load the spectrogram of an electrode-grid recording generated with the wavetracker package.

    Spectrograms may be too large to load completely, so the spectrogram itself is opened
    read-only through ``numpy.memmap`` (float64, Fortran order) instead of being read into memory.
    Files with the ``fill_`` prefix are preferred; ``fine_`` files are used as a fallback.

    Parameters
    ----------
    folder: str
        Folder containing the spectrogram numpy files generated with the wavetracker package.

    Returns
    -------
    fill_freqs: ndarray
        Frequencies corresponding to the 1st dimension of the spectrogram.
    fill_times: ndarray
        Times corresponding to the 2nd dimension of the spectrogram.
    fill_spec: ndarray
        Memory-mapped spectrogram of the recording in ``folder``.

    If neither ``fill_spec.npy`` nor ``fine_spec.npy`` exists, three empty lists are returned.
    """
    # (freqs, times, shape, spec) file names per naming scheme, in order of preference.
    candidates = (
        ('fill_freqs.npy', 'fill_times.npy', 'fill_spec_shape.npy', 'fill_spec.npy'),
        ('fine_freqs.npy', 'fine_times.npy', 'fine_spec_shape.npy', 'fine_spec.npy'),
    )

    for freqs_name, times_name, shape_name, spec_name in candidates:
        spec_path = os.path.join(folder, spec_name)
        if not os.path.exists(spec_path):
            continue

        freqs = np.load(os.path.join(folder, freqs_name))
        times = np.load(os.path.join(folder, times_name))
        spec_shape = np.load(os.path.join(folder, shape_name))
        # The *_spec.npy file is a raw buffer (no .npy header); its shape is stored separately.
        spec = np.memmap(spec_path, dtype='float', mode='r',
                         shape=(spec_shape[0], spec_shape[1]), order='F')
        return freqs, times, spec

    return [], [], []
|
|
|
|
|
|
def load_tracking_data(folder):
    """
    Load the tracking arrays stored in ``folder``.

    Parameters
    ----------
    folder: str or Path
        Recording folder containing ``fund_v.npy``, ``ident_v.npy``, ``idx_v.npy`` and ``times.npy``.

    Returns
    -------
    tuple of ndarray
        ``(EODf_v, ident_v, idx_v, times_v)``, loaded from the four files above in that order.
        In this script ``idx_v`` is used to index ``times_v`` and NaN entries of ``ident_v``
        mark points without an identity.
    """
    recording_dir = Path(folder)
    file_names = ('fund_v.npy', 'ident_v.npy', 'idx_v.npy', 'times.npy')
    return tuple(np.load(recording_dir / file_name) for file_name in file_names)
|
|
|
|
|
|
def load_trial_data(folder):
    """
    Load the rise-analysis results stored in ``<folder>/analysis``.

    Parameters
    ----------
    folder: str or Path
        Recording folder whose ``analysis`` sub-folder contains ``fish_freq.npy``, ``rise_idx.npy``,
        ``rise_size.npy``, ``baseline_freqs.npy`` and ``baseline_freq_times.npy``.

    Returns
    -------
    tuple of ndarray
        ``(fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time)``,
        loaded from the five files above in that order.
    """
    analysis_dir = Path(folder) / 'analysis'
    file_names = ('fish_freq.npy', 'rise_idx.npy', 'rise_size.npy',
                  'baseline_freqs.npy', 'baseline_freq_times.npy')
    return tuple(np.load(analysis_dir / file_name) for file_name in file_names)
|
|
|
|
|
|
def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, pic_save_folder):
    """
    Render a spectrogram snippet as a borderless grayscale image and save it to ``pic_save_folder``.

    Parameters
    ----------
    folder: str or Path
        Recording folder; its name prefixes the image file name.
    s_trans: torch.Tensor
        Normalized spectrogram snippet; singleton dimensions are squeezed away before plotting.
    times, freq: ndarray
        Time and frequency axes of the full spectrogram. Each needs at least two samples,
        because its resolution is taken from the first two entries.
    t_idx0, t_idx1, f_idx0, f_idx1: int
        Indices into ``times`` / ``freq`` of the snippet bounds. They are used for the file name
        and the image extent only.
    pic_save_folder: str or Path
        Output folder for the image.

    Returns
    -------
    fig_title: str
        File name of the saved image, ``<folder name>__<t0>s-<t1>s__<f0>-<f1>Hz.png``.
        Times are zero-padded to a width of 5 and frequencies to a width of 4.
    """
    f_res, t_res = freq[1] - freq[0], times[1] - times[0]

    # Zero-pad via the format spec. The former ``.replace(' ', '0')`` on the whole title also
    # replaced spaces inside the folder name. Output is identical for names without spaces.
    fig_title = (f'{Path(folder).name}__{times[t_idx0]:05.0f}s-{times[t_idx1]:05.0f}s__'
                 f'{freq[f_idx0]:04.0f}-{freq[f_idx1]:04.0f}Hz.png')

    fig = plt.figure(figsize=IMG_SIZE, num=fig_title)
    try:
        # A single axes spanning the whole figure -> image without margins.
        gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)
        ax = fig.add_subplot(gs[0, 0])
        # Extent is given in hours (time) and Hz (frequency); the axis itself is hidden below.
        ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
                  extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600,
                          freq[f_idx0], freq[f_idx1] + f_res))
        ax.axis(False)
        fig.savefig(Path(pic_save_folder) / fig_title, dpi=IMG_DPI)
    finally:
        # Close exactly this figure, also when saving fails, so figures do not pile up across windows.
        plt.close(fig)

    return fig_title
|
|
|
|
|
|
def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq,
                     pic_save_str, t0, t1, f0, f1, *, min_rise_size=10, rise_end_fraction=0.37,
                     min_rise_duration=1., pad_before=0.01, pad_after=0.05, class_id=1):
    """
    Derive YOLO-style bounding boxes of frequency rises in one analysis window and save them as a label file.

    For every fish, rises whose onset index lies within the window and whose size is at least
    ``min_rise_size`` are turned into boxes, as follows:

    - The box starts at the rise onset and at the baseline frequency sampled closest in time to the onset.
    - It reaches ``rise_size`` above that baseline.
    - It ends at the first time after the onset at which the fish frequency is below
      ``baseline + rise_end_fraction * rise_size``.

    Rises without such an end, or lasting no longer than ``min_rise_duration``, are dropped.
    Boxes are then padded by ``pad_before`` (start / lower edge) and ``pad_after`` (end / upper edge)
    of the window extent. Boxes that no longer fit inside the window are discarded.

    Parameters
    ----------
    times_v: ndarray
        Time stamps of the tracking data; ``rise_idx`` holds indices into it.
    fish_freq: sequence of ndarray
        Per-fish frequency traces; ``fish_freq[i]`` must align element-wise with ``times_v``.
    rise_idx: sequence of ndarray
        Per-fish indices into ``times_v`` of rise onsets.
    rise_size: sequence of ndarray
        Per-fish rise sizes, aligned with ``rise_idx`` (same unit as the frequencies).
    fish_baseline_freq_time: ndarray
        Time stamps at which ``fish_baseline_freq`` is sampled.
    fish_baseline_freq: sequence of ndarray
        Per-fish baseline frequencies sampled at ``fish_baseline_freq_time``.
    pic_save_str: str
        File name of the window image. The label file is ``LABEL_DIR / pic_save_str``
        with its suffix replaced by ``.txt``.
    t0, t1: float
        Time bounds of the window (same unit as ``times_v``).
    f0, f1: float
        Frequency bounds of the window.
    min_rise_size, rise_end_fraction, min_rise_duration, pad_before, pad_after, class_id:
        Keyword-only tuning parameters; the defaults reproduce the previously hard-coded values.
        NOTE(review): YOLO class ids are usually 0-based; confirm ``class_id=1`` matches the dataset config.

    Returns
    -------
    bbox_yolo_style: ndarray, shape (n_boxes, 5)
        Columns ``class, x_center, y_center, width, height``, all relative to the window.
        y is flipped: the upper frequency bound f1 maps to 0 and f0 maps to 1.
        The same array is written to the label file, which is empty when there are no boxes.
    """
    window_dt = t1 - t0
    window_df = f1 - f0
    times_v_idx0 = np.argmin(np.abs(times_v - t0))
    times_v_idx1 = np.argmin(np.abs(times_v - t1))

    all_x_center, all_y_center, all_width, all_height = [], [], [], []

    for id_idx in range(len(fish_freq)):
        # Rises of this fish with an onset inside the window and a sufficient size (mask computed once).
        rise_mask = ((rise_idx[id_idx] >= times_v_idx0) &
                     (rise_idx[id_idx] <= times_v_idx1) &
                     (rise_size[id_idx] >= min_rise_size))
        rise_idx_oi = np.array(rise_idx[id_idx][rise_mask], dtype=int)
        rise_size_oi = rise_size[id_idx][rise_mask]
        if len(rise_idx_oi) == 0:
            continue

        rise_t = times_v[rise_idx_oi]

        # Baseline frequency sampled closest in time to each rise onset.
        closest_baseline_idx = [np.argmin(np.abs(fish_baseline_freq_time - t)) for t in rise_t]
        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]

        lower_freq_bound = closest_baseline_freq
        upper_freq_bound = closest_baseline_freq + rise_size_oi
        left_time_bound = rise_t
        right_time_bound = np.zeros_like(left_time_bound)

        for enu, (Ct_oi, Crise_size, Cblf) in enumerate(zip(rise_t, rise_size_oi, closest_baseline_freq)):
            # The rise ends once the frequency has decayed below baseline + rise_end_fraction * size.
            rise_end_t = times_v[(times_v > Ct_oi) &
                                 (fish_freq[id_idx] < Cblf + Crise_size * rise_end_fraction)]
            right_time_bound[enu] = rise_end_t[0] if len(rise_end_t) > 0 else np.nan

        # Drop rises without a detected end or that are too short.
        mask = ~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > min_rise_duration)
        left_time_bound = left_time_bound[mask]
        right_time_bound = right_time_bound[mask]
        lower_freq_bound = lower_freq_bound[mask]
        upper_freq_bound = upper_freq_bound[mask]

        # Pad the boxes relative to the window extent (in place on the masked copies).
        left_time_bound -= pad_before * window_dt
        right_time_bound += pad_after * window_dt
        lower_freq_bound -= pad_before * window_df
        upper_freq_bound += pad_after * window_df

        # Keep only boxes that still fit inside the window after padding.
        inside = ((left_time_bound >= t0) &
                  (right_time_bound <= t1) &
                  (lower_freq_bound >= f0) &
                  (upper_freq_bound <= f1))
        if not np.any(inside):
            continue
        left_time_bound = left_time_bound[inside]
        right_time_bound = right_time_bound[inside]
        lower_freq_bound = lower_freq_bound[inside]
        upper_freq_bound = upper_freq_bound[inside]

        rel_x0 = np.array((left_time_bound - t0) / window_dt, dtype=float)
        rel_x1 = np.array((right_time_bound - t0) / window_dt, dtype=float)
        # Flip frequencies so that y runs from f1 (0) down to f0 (1).
        rel_y0 = np.array(1 - (upper_freq_bound - f0) / window_df, dtype=float)
        rel_y1 = np.array(1 - (lower_freq_bound - f0) / window_df, dtype=float)

        all_x_center.extend(rel_x1 - (rel_x1 - rel_x0) / 2)
        all_y_center.extend(rel_y1 - (rel_y1 - rel_y0) / 2)
        all_width.extend(rel_x1 - rel_x0)
        all_height.extend(rel_y1 - rel_y0)

    bbox_yolo_style = np.array([
        np.full(len(all_x_center), class_id, dtype=int),
        all_x_center,
        all_y_center,
        all_width,
        all_height,
    ]).T

    np.savetxt(LABEL_DIR / Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
    return bbox_yolo_style
|
|
|
|
|
|
def _normalized_log_snippet(spec, f_idx0, f_idx1, t_idx0, t_idx1):
    """Cut ``spec[f_idx0:f_idx1, t_idx0:t_idx1]``, take log10 and z-score it with its own mean/std.

    Returns a float32 tensor with a leading singleton dimension, i.e. shape (1, n_freq, n_time).
    NOTE(review): zero-valued bins give -inf after log10, which would turn the normalization
    into NaN/inf; confirm the spectrogram is strictly positive.
    """
    snippet = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
    log_s = torch.log10(snippet)
    normalize = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
    return normalize(log_s.unsqueeze(0))


def main(args):
    """
    Generate spectrogram image tiles for all recordings found below ``args.folder``.

    A recording is every folder below ``args.folder`` that contains ``fill_times.npy``.

    Training mode (``args.inference`` false):
        Only recordings with ``analysis/rise_idx.npy`` are used. Images go to ``DATA_DIR``,
        and YOLO label files are written via ``bboxes_from_file``.

    Inference mode:
        Only images are written, to ``data/<name of args.folder>``.

    Each recording is tiled into windows. Window starts are every ``DELTA_TIME`` from 0 to the
    last spectrogram time, and every ``DELTA_FREQ`` from ``MIN_FREQ`` to ``MAX_FREQ``. Each window
    is extended by ``TIME_OVERLAP`` / ``FREQ_OVERLAP``. Windows without any tracked point that has
    a valid (non-NaN) identity are skipped.
    """
    folders = [f.parent for f in Path(args.folder).rglob('fill_times.npy')]
    inference_folder = Path('data') / Path(args.folder).name
    pic_save_folder = inference_folder if args.inference else DATA_DIR

    if len(folders) == 0:
        print('no datasets containing fill_times.npy found')

    if not args.inference:
        print('generate training dataset only for files with detected rises')
        folders = [folder for folder in folders if (folder / 'analysis' / 'rise_idx.npy').exists()]
        # Previously only the inference folder was created. savefig / np.savetxt fail on a
        # missing directory, so make sure both training outputs exist.
        Path(pic_save_folder).mkdir(parents=True, exist_ok=True)
        Path(LABEL_DIR).mkdir(parents=True, exist_ok=True)
    else:
        print('generate inference dataset ... only image output')
        inference_folder.mkdir(parents=True, exist_ok=True)

    for enu, folder in enumerate(folders):
        print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')

        # load different categories of data
        freq, times, spec = load_spec_data(folder)
        EODf_v, ident_v, idx_v, times_v = load_tracking_data(folder)
        if not args.inference:
            fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
                load_trial_data(folder))

        # Loop-invariant per recording: tracked points with an identity, and their time stamps.
        has_ident = ~np.isnan(ident_v)
        point_times = times_v[idx_v]

        # Window start grids. The progress-bar total is taken from the actual grids; the former
        # floor-division formula overcounted by one per axis for exact multiples of the step.
        window_t0s = np.arange(0, times[-1], DELTA_TIME)
        window_f0s = np.arange(MIN_FREQ, MAX_FREQ, DELTA_FREQ)
        pic_base = tqdm(itertools.product(window_t0s, window_f0s),
                        total=len(window_t0s) * len(window_f0s))

        for t0, f0 in pic_base:
            # Windows overlap their neighbours by TIME_OVERLAP / FREQ_OVERLAP.
            t1 = t0 + DELTA_TIME + TIME_OVERLAP
            f1 = f0 + DELTA_FREQ + FREQ_OVERLAP

            in_window = (has_ident &
                         (t0 <= point_times) & (point_times <= t1) &
                         (EODf_v >= f0) & (EODf_v <= f1))
            if not np.any(in_window):
                continue

            # spectrogram indices closest to the window bounds
            f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
            t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))

            s_trans = _normalized_log_snippet(spec, f_idx0, f_idx1, t_idx0, t_idx1)
            pic_save_str = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1,
                                         f_idx0, f_idx1, pic_save_folder)

            if not args.inference:
                bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
                                 fish_baseline_freq_time, fish_baseline_freq,
                                 pic_save_str, t0, t1, f0, f1)
|
|
|
|
if __name__ == '__main__':
    # Command-line entry: one positional recording folder plus an optional inference switch.
    # Note: ``default=''`` has no effect on a required positional argument.
    cli = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
    cli.add_argument('folder', type=str, default='', help='single recording analysis')
    cli.add_argument('-i', "--inference", action="store_true")
    main(cli.parse_args())