diff --git a/data/generate_dataset.py b/data/generate_dataset.py
index 39e3afc..8a8570e 100644
--- a/data/generate_dataset.py
+++ b/data/generate_dataset.py
@@ -22,8 +22,7 @@ from IPython import embed
 from matplotlib.patches import Rectangle
 from matplotlib.collections import PatchCollection
 
-
-def load_data(folder):
+def load_spec_data(folder):
     fill_freqs, fill_times, fill_spec = [], [], []
 
     if os.path.exists(os.path.join(folder, 'fill_spec.npy')):
@@ -40,12 +39,19 @@ def load_data(folder):
         fill_spec = np.memmap(os.path.join(folder, 'fine_spec.npy'), dtype='float', mode='r',
                               shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')
 
+    return fill_freqs, fill_times, fill_spec
+
+def load_tracking_data(folder):
     base_path = Path(folder)
     EODf_v = np.load(base_path / 'fund_v.npy')
     ident_v = np.load(base_path / 'ident_v.npy')
     idx_v = np.load(base_path / 'idx_v.npy')
     times_v = np.load(base_path / 'times.npy')
 
+    return EODf_v, ident_v, idx_v, times_v
+
+def load_trial_data(folder):
+    base_path = Path(folder)
     fish_freq = np.load(base_path / 'analysis' / 'fish_freq.npy')
     rise_idx = np.load(base_path / 'analysis' / 'rise_idx.npy')
     rise_size = np.load(base_path / 'analysis' / 'rise_size.npy')
@@ -53,21 +59,22 @@ def load_data(folder):
     fish_baseline_freq = np.load(base_path / 'analysis' / 'baseline_freqs.npy')
     fish_baseline_freq_time = np.load(base_path / 'analysis' / 'baseline_freq_times.npy')
 
-    return fill_freqs, fill_times, fill_spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time
+    return fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time
 
-def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res):
+def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, dataset_folder):
     size = (7, 7)
     dpi = 256
 
+    f_res, t_res = freq[1] - freq[0], times[1] - times[0]
+
     fig_title = (f'{Path(folder).name}__{times[t_idx0]:.0f}s-{times[t_idx1]:.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0')
     fig = plt.figure(figsize=(7, 7), num=fig_title)
-    gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0)  # , bottom=0, left=0, right=1, top=1
-    gs2 = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)  #
-    ax = fig.add_subplot(gs2[0, 0])
-    im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
-                   extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
+    gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)  #
+    ax = fig.add_subplot(gs[0, 0])
+    ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
+              extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
     ax.axis(False)
-    plt.savefig(os.path.join('dataset', fig_title), dpi=256)
+    plt.savefig(os.path.join(dataset_folder, fig_title), dpi=256)
     plt.close()
 
     return fig_title, (size[0]*dpi, size[1]*dpi)
@@ -148,45 +155,8 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
         bbox_df = pd.concat([bbox_df, tmp_df], ignore_index=True)
     return bbox_df
 
 
-def main(args):
-    def development_fn():
-        fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
-        fig = plt.figure(figsize=(7, 7), num=fig_title)
-        gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9,
-                               top=0.95)  # , bottom=0, left=0, right=1, top=1
-        ax = fig.add_subplot(gs[0, 0])
-        cax = fig.add_subplot(gs[0, 1])
-        im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
-                       extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
-        fig.colorbar(im, cax=cax, orientation='vertical')
-
-
-        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
-        dev_df = pd.DataFrame(columns=cols)
-
-        dev_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq,
-                                  fig_title, dev_df, cols, (7*256), (7*256), t0, t1, f0, f1)
-
-        # embed()
-        # quit()
-        time_freq_bbox = torch.as_tensor(dev_df.loc[:, ['t0', 'f0', 't1', 'f1']].values.astype(np.float32))
-
-        for bbox in time_freq_bbox:
-            Ct0, Cf0, Ct1, Cf1 = bbox
-            ax.add_patch(
-                Rectangle((Ct0, Cf0), Ct1-Ct0, Cf1-Cf0, fill=False, color="white", linewidth=2, zorder=10)
-            )
-        # for enu in range(len(left_time_bound)):
-        #     if np.isnan(right_time_bound[enu]):
-        #         continue
-        #     ax.add_patch(
-        #         Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
-        #                   (right_time_bound[enu] - left_time_bound[enu]),
-        #                   (upper_freq_bound[enu] - lower_freq_bound[enu]),
-        #                   fill=False, color="white", linewidth=2, zorder=10)
-        #     )
-        plt.show()
+def main(args):
     # Hyperparameter
     min_freq = 200
     max_freq = 1500
@@ -195,38 +165,30 @@ def main(args):
     d_time = 60*10
     time_overlap = 60*1
 
-    # init dataframe if not existent so far
-    eval_files = []
-    if not os.path.exists(os.path.join('dataset', 'bbox_dataset.csv')):
+    folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
+
+    if not args.inference:
+        print('generate training dataset only for files with detected rises')
+        folders = [folder for folder in folders if (folder / 'analysis' / 'rise_idx.npy').exists()]
         cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
         bbox_df = pd.DataFrame(columns=cols)
-
-    # else load datafile ... and check for already regarded files (eval_files)
     else:
-        bbox_df = pd.read_csv(os.path.join('dataset', 'bbox_dataset.csv'), sep=',', index_col=0)
-        cols = list(bbox_df.keys())
-        # ToDo: make sure not same file twice
-        for f in pd.unique(bbox_df['image']):
-            eval_files.append(f.split('__')[0])
-
-    # find folders that have fine_specs...
-    folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
-
+        print('generate inference dataset ... only image output')
+        bbox_df = {}
 
 
     for enu, folder in enumerate(folders):
         print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
-        # check for those folders where rises are detected
-        if not (folder/'analysis'/'rise_idx.npy').exists():
-            continue
-
-        # embed()
-        # quit()
-        # ToDo: check if folder in eval_files ... is so: continue
-        freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
-            load_data(folder))
-        f_res, t_res = freq[1] - freq[0], times[1] - times[0]
+        # load different categories of data
+        freq, times, spec = (
+            load_spec_data(folder))
+        EODf_v, ident_v, idx_v, times_v = (
+            load_tracking_data(folder))
+        if not args.inference:
+            fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
+                load_trial_data(folder))
 
+        # generate iterator for analysis window loop
         pic_base = tqdm(itertools.product(
             np.arange(0, times[-1], d_time),
             np.arange(min_freq, max_freq, d_freq)
@@ -247,30 +209,35 @@
             if len(present_freqs) == 0:
                 continue
 
+            # get spec_idx for current spec snippet
             f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
             t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))
 
+            # get spec snippet and create torch.tensor from it
             s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
             log_s = torch.log10(s)
             transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
             s_trans = transformed(log_s.unsqueeze(0))
 
-            if not args.dev:
-                pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res)
+            pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, args.dataset_folder)
 
+            if not args.inference:
                 bbox_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq,
                                            pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)
-            else:
-                development_fn()
 
-    if not args.dev:
-        print('save')
-        bbox_df.to_csv(os.path.join('dataset', 'bbox_dataset.csv'), columns=cols, sep=',')
+    if not args.inference:
+        print('save bboxes')
+        bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',')
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
     parser.add_argument('folder', type=str, help='single recording analysis', default='')
-    parser.add_argument('-d', "--dev", action="store_true", help="developer mode; no data saved")
+    parser.add_argument('-d', "--dataset_folder", type=str, help='designated dataset folder', default='dataset')
+    parser.add_argument('-i', "--inference", action="store_true", help="generate inference dataset. Img only")
 
     args = parser.parse_args()
+
+    if not Path(args.dataset_folder).exists():
+        Path(args.dataset_folder).mkdir(parents=True, exist_ok=True)
+
     main(args)
\ No newline at end of file
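
Note on checking the generated dataset: the removed development_fn previously overlaid the rise bounding boxes on the spectrogram for visual inspection. Below is a minimal stand-alone sketch of the same check run against the saved output instead; it is illustrative only and not part of this patch, and it assumes the default --dataset_folder of 'dataset'. bbox_dataset.csv holds, per detected rise, the image name, the time/frequency bounds (t0, t1, f0, f1), and pixel coordinates (x0, y0, x1, y1) that refer to the 1792 x 1792 px images (7 x 7 inches at dpi 256) written by save_spec_pic.

# sketch: overlay saved bounding boxes on one generated spectrogram image
# (illustrative only -- not part of this patch; assumes the default
#  --dataset_folder 'dataset' and that this script has already been run)
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Rectangle

dataset_folder = Path('dataset')
bbox_df = pd.read_csv(dataset_folder / 'bbox_dataset.csv', sep=',', index_col=0)

image_name = bbox_df['image'].iloc[0]                    # e.g. '<recording>__<t0>s-<t1>s__<f0>-<f1>Hz.png'
img = plt.imread(str(dataset_folder / image_name))       # 1792 x 1792 px (7 in x 7 in at dpi=256)

fig, ax = plt.subplots(figsize=(7, 7))
ax.imshow(img, cmap='gray')
for _, row in bbox_df[bbox_df['image'] == image_name].iterrows():
    # x0/y0/x1/y1 are pixel coordinates; depending on how bboxes_from_file maps
    # frequency to pixel rows, the y values may need flipping (height - y).
    ax.add_patch(Rectangle((row['x0'], row['y0']), row['x1'] - row['x0'], row['y1'] - row['y0'],
                           fill=False, color='white', linewidth=2))
plt.show()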