"""Generate a spectrogram image dataset (optionally with bounding boxes).

The script walks a recording directory, cuts the stored spectrogram into
overlapping time/frequency windows, saves each window as a grayscale PNG and,
unless run in inference mode, writes a CSV with one bounding box per detected
rise that lies completely inside a window.
"""
import argparse
import itertools
import os
import sys
import time
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from IPython import embed
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from torch import nn
from tqdm.auto import tqdm


def load_spec_data(folder):
    """Load the spectrogram of a recording folder.

    The ``fill_*`` file set is preferred; if ``fill_spec.npy`` does not exist
    the ``fine_*`` file set is used instead.

    Parameters
    ----------
    folder : str or Path
        Directory containing ``<prefix>_spec.npy``, ``<prefix>_freqs.npy``,
        ``<prefix>_times.npy`` and ``<prefix>_spec_shape.npy``.

    Returns
    -------
    tuple
        ``(freqs, times, spec)``. ``spec`` is a read-only, Fortran-ordered
        ``np.memmap`` whose 2-D shape is read from the shape file. If neither
        spectrogram file exists, three empty lists are returned.
    """
    for prefix in ('fill', 'fine'):
        spec_path = os.path.join(folder, f'{prefix}_spec.npy')
        if not os.path.exists(spec_path):
            continue

        freqs = np.load(os.path.join(folder, f'{prefix}_freqs.npy'))
        times = np.load(os.path.join(folder, f'{prefix}_times.npy'))
        spec_shape = np.load(os.path.join(folder, f'{prefix}_spec_shape.npy'))
        # The spectrogram is stored as raw values, hence memmap instead of np.load.
        spec = np.memmap(spec_path, dtype='float', mode='r',
                         shape=(spec_shape[0], spec_shape[1]), order='F')
        return freqs, times, spec

    return [], [], []


def load_tracking_data(folder):
    """Load the tracking arrays stored in ``folder``.

    Returns
    -------
    tuple
        ``(EODf_v, ident_v, idx_v, times_v)`` loaded from ``fund_v.npy``,
        ``ident_v.npy``, ``idx_v.npy`` and ``times.npy`` respectively.
    """
    base_path = Path(folder)
    EODf_v = np.load(base_path / 'fund_v.npy')
    ident_v = np.load(base_path / 'ident_v.npy')
    idx_v = np.load(base_path / 'idx_v.npy')
    times_v = np.load(base_path / 'times.npy')
    return EODf_v, ident_v, idx_v, times_v


def load_trial_data(folder):
    """Load the per-fish analysis arrays from ``folder / 'analysis'``.

    Returns
    -------
    tuple
        ``(fish_freq, rise_idx, rise_size, fish_baseline_freq,
        fish_baseline_freq_time)`` loaded from ``fish_freq.npy``,
        ``rise_idx.npy``, ``rise_size.npy``, ``baseline_freqs.npy`` and
        ``baseline_freq_times.npy``.
    """
    analysis_path = Path(folder) / 'analysis'
    fish_freq = np.load(analysis_path / 'fish_freq.npy')
    rise_idx = np.load(analysis_path / 'rise_idx.npy')
    rise_size = np.load(analysis_path / 'rise_size.npy')
    fish_baseline_freq = np.load(analysis_path / 'baseline_freqs.npy')
    fish_baseline_freq_time = np.load(analysis_path / 'baseline_freq_times.npy')
    return fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time


def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, dataset_folder):
    """Save one spectrogram snippet as a borderless grayscale PNG.

    Parameters
    ----------
    folder : str or Path
        Recording folder; only its name is used for the image file name.
    s_trans : array-like
        Snippet to plot; it is squeezed before being passed to ``imshow``.
    times, freq : sequence
        Time and frequency axes of the full spectrogram.
    t_idx0, t_idx1, f_idx0, f_idx1 : int
        Indices into ``times`` / ``freq`` delimiting the snippet.
    dataset_folder : str or Path
        Directory the image is written to.

    Returns
    -------
    tuple
        ``(file_name, (width_px, height_px))`` where the pixel size is the
        figure size in inches multiplied by the dpi used for saving.
    """
    size = (7, 7)
    dpi = 256
    f_res, t_res = freq[1] - freq[0], times[1] - times[0]

    # Spaces produced by the fixed-width frequency format are zero-padded.
    fig_title = (f'{Path(folder).name}__{times[t_idx0]:.0f}s-{times[t_idx1]:.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0')

    fig = plt.figure(figsize=size, num=fig_title)
    # A single axes filling the whole canvas, so the image has no margins.
    gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)
    ax = fig.add_subplot(gs[0, 0])
    ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
              extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res,
                      freq[f_idx0], freq[f_idx1] + f_res))
    ax.axis(False)

    fig.savefig(os.path.join(dataset_folder, fig_title), dpi=dpi)
    # Close this specific figure so figures do not accumulate across windows.
    plt.close(fig)

    return fig_title, (size[0] * dpi, size[1] * dpi)


def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq,
                     pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1):
    """Append bounding boxes of all rises inside one analysis window.

    For every fish, rises whose index lies within the window and whose size
    is at least 10 are selected. A box spans from the rise time to the first
    later time at which the fish frequency has dropped below
    ``baseline + 0.37 * rise_size`` (presumably a ~1/e decay criterion), and
    from the closest baseline frequency to baseline plus rise size. Boxes
    shorter than 1 (in the unit of ``times_v``) or without a detected end are
    dropped, the rest is padded and kept only if fully inside the window.

    Parameters
    ----------
    times_v : np.ndarray
        Time axis of the tracking data.
    fish_freq, rise_idx, rise_size, fish_baseline_freq : indexable
        Per-fish arrays, indexed by fish first.
    fish_baseline_freq_time : np.ndarray
        Time axis belonging to ``fish_baseline_freq``.
    pic_save_str : str
        Image file name stored with every box.
    bbox_df : pd.DataFrame
        Existing boxes; it is not modified in place.
    cols : list of str
        Column names for the nine box fields.
    width, height : int
        Image size in pixels.
    t0, t1, f0, f1 : float
        Time and frequency limits of the window.

    Returns
    -------
    pd.DataFrame
        ``bbox_df`` extended by the boxes found in this window.
    """
    times_v_idx0 = np.argmin(np.abs(times_v - t0))
    times_v_idx1 = np.argmin(np.abs(times_v - t1))
    dt_window = t1 - t0
    df_window = f1 - f0

    # Collected per fish and concatenated once at the end.
    new_frames = []

    for id_idx in range(len(fish_freq)):
        rise_in_window = ((rise_idx[id_idx] >= times_v_idx0) &
                          (rise_idx[id_idx] <= times_v_idx1) &
                          (rise_size[id_idx] >= 10))
        rise_idx_oi = np.array(rise_idx[id_idx][rise_in_window], dtype=int)
        rise_size_oi = rise_size[id_idx][rise_in_window]
        if len(rise_idx_oi) == 0:
            continue

        rise_times = times_v[rise_idx_oi]
        closest_baseline_idx = [np.argmin(np.abs(fish_baseline_freq_time - rise_t))
                                for rise_t in rise_times]
        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]

        upper_freq_bound = closest_baseline_freq + rise_size_oi
        lower_freq_bound = closest_baseline_freq
        left_time_bound = rise_times
        right_time_bound = np.zeros_like(left_time_bound)

        for enu, (rise_t, c_rise_size, c_baseline) in enumerate(
                zip(rise_times, rise_size_oi, closest_baseline_freq)):
            rise_end_t = times_v[(times_v > rise_t) &
                                 (fish_freq[id_idx] < c_baseline + c_rise_size * 0.37)]
            # NaN marks rises whose end was never reached.
            right_time_bound[enu] = rise_end_t[0] if len(rise_end_t) > 0 else np.nan

        # Keep only rises with a detected end that last longer than 1.
        mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.))
        left_time_bound = left_time_bound[mask]
        right_time_bound = right_time_bound[mask]
        lower_freq_bound = lower_freq_bound[mask]
        upper_freq_bound = upper_freq_bound[mask]

        # Pad the boxes relative to the window size (1 % before/below, 5 % after/above).
        left_time_bound = left_time_bound - 0.01 * dt_window
        right_time_bound = right_time_bound + 0.05 * dt_window
        lower_freq_bound = lower_freq_bound - 0.01 * df_window
        upper_freq_bound = upper_freq_bound + 0.05 * df_window

        # Only boxes that lie completely inside the window are kept.
        mask2 = ((left_time_bound >= t0) &
                 (right_time_bound <= t1) &
                 (lower_freq_bound >= f0) &
                 (upper_freq_bound <= f1))
        left_time_bound = left_time_bound[mask2]
        right_time_bound = right_time_bound[mask2]
        lower_freq_bound = lower_freq_bound[mask2]
        upper_freq_bound = upper_freq_bound[mask2]

        if len(left_time_bound) == 0:
            continue

        # Pixel coordinates; y is measured from the top, so the frequency axis is flipped.
        x0 = np.array((left_time_bound - t0) / dt_window * width, dtype=int)
        x1 = np.array((right_time_bound - t0) / dt_window * width, dtype=int)
        y0 = np.array((1 - (upper_freq_bound - f0) / df_window) * height, dtype=int)
        y1 = np.array((1 - (lower_freq_bound - f0) / df_window) * height, dtype=int)

        bbox = np.array([[pic_save_str] * len(left_time_bound),
                         left_time_bound,
                         right_time_bound,
                         lower_freq_bound,
                         upper_freq_bound,
                         x0, y0, x1, y1])
        new_frames.append(pd.DataFrame(data=bbox.T, columns=cols))

    if not new_frames:
        return bbox_df
    return pd.concat([bbox_df, *new_frames], ignore_index=True)


def main(args):
    """Create the image dataset (and the bounding-box CSV in training mode).

    Parameters
    ----------
    args : argparse.Namespace
        Needs the attributes ``folder`` (searched recursively for
        ``fill_times.npy``), ``dataset_folder`` (output directory) and
        ``inference`` (if true, only images are written).
    """
    # Hyperparameter
    min_freq = 200
    max_freq = 1500
    d_freq = 200
    freq_overlap = 25
    d_time = 60*10
    time_overlap = 60*1

    cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']

    # NOTE(review): folders are discovered via 'fill_times.npy' only, although
    # load_spec_data can fall back to the 'fine_*' files - confirm this is intended.
    folders = [f.parent for f in Path(args.folder).rglob('fill_times.npy')]

    if not args.inference:
        print('generate training dataset only for files with detected rises')
        folders = [folder for folder in folders if (folder / 'analysis' / 'rise_idx.npy').exists()]
        bbox_df = pd.DataFrame(columns=cols)
    else:
        print('generate inference dataset ... only image output')
        bbox_df = None

    for enu, folder in enumerate(folders):
        print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')

        # load different categories of data
        freq, times, spec = load_spec_data(folder)
        if len(times) == 0:
            # load_spec_data returns empty lists when no spectrogram file exists.
            print(f'no spectrogram data found in {folder} ... skipping')
            continue

        EODf_v, ident_v, idx_v, times_v = load_tracking_data(folder)
        if not args.inference:
            fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
                load_trial_data(folder))

        # Loop-invariant parts of the "is any tracked signal present" test.
        valid_ident = ~np.isnan(ident_v)
        tracked_times = times_v[idx_v]

        # generate iterator for analysis window loop
        window_t0s = np.arange(0, times[-1], d_time)
        window_f0s = np.arange(min_freq, max_freq, d_freq)
        pic_base = tqdm(itertools.product(window_t0s, window_f0s),
                        total=len(window_t0s) * len(window_f0s))

        for t0, f0 in pic_base:
            t1 = t0 + d_time + time_overlap
            f1 = f0 + d_freq + freq_overlap

            present_freqs = EODf_v[valid_ident &
                                   (t0 <= tracked_times) &
                                   (tracked_times <= t1) &
                                   (EODf_v >= f0) &
                                   (EODf_v <= f1)]
            if len(present_freqs) == 0:
                continue

            # get spec_idx for current spec snippet
            f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
            t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))
            if f_idx1 <= f_idx0 or t_idx1 <= t_idx0:
                # An empty snippet cannot be normalized or plotted.
                continue

            # get spec snippet and create torch.tensor from it
            s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
            log_s = torch.log10(s)
            transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
            s_trans = transformed(log_s.unsqueeze(0))

            pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1,
                                                          f_idx0, f_idx1, args.dataset_folder)

            if not args.inference:
                bbox_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
                                           fish_baseline_freq_time, fish_baseline_freq,
                                           pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)

    if not args.inference:
        print('save bboxes')
        bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
    parser.add_argument('folder', type=str, help='single recording analysis', default='')
    parser.add_argument('-d', "--dataset_folder", type=str, help='designated datasef folder', default='dataset')
    parser.add_argument('-i', "--inference", action="store_true", help="generate inference dataset. Img only")
    args = parser.parse_args()

    # exist_ok makes a separate existence check unnecessary.
    Path(args.dataset_folder).mkdir(parents=True, exist_ok=True)

    main(args)