diff --git a/generate_dataset.py b/generate_dataset.py new file mode 100644 index 0000000..bcd7e4f --- /dev/null +++ b/generate_dataset.py @@ -0,0 +1,324 @@ +import itertools +import sys +import os +import argparse + +import torch +import torchvision.transforms as T + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from matplotlib.patches import Rectangle +import pandas as pd +from pathlib import Path +from tqdm.auto import tqdm +from IPython import embed + +from confic import (MIN_FREQ, MAX_FREQ, DELTA_FREQ, FREQ_OVERLAP, DELTA_TIME, TIME_OVERLAP, IMG_SIZE, IMG_DPI, DATA_DIR, LABEL_DIR) + + +def load_spec_data(folder: str): + """ + Load spectrogram of a given electrode-grid recording generated with the wavetracker package. The spectrograms may + be to large to load in total, thats why memmory mapping is used (numpy.memmap). + + Parameters + ---------- + folder: str + Folder where fine spec numpy files generated for grid recordings with the wavetracker package can be found. + + Returns + ------- + fill_freqs: ndarray + Freuqencies corresponding to 1st dimension of the spectrogram. + fill_times: ndarray + Times corresponding to the 2nd dimenstion if the spectrigram. + fill_spec: ndarray + Spectrigram of the recording refered to in the input folder. + """ + fill_freqs, fill_times, fill_spec = [], [], [] + + if os.path.exists(os.path.join(folder, 'fill_spec.npy')): + fill_freqs = np.load(os.path.join(folder, 'fill_freqs.npy')) + fill_times = np.load(os.path.join(folder, 'fill_times.npy')) + fill_spec_shape = np.load(os.path.join(folder, 'fill_spec_shape.npy')) + fill_spec = np.memmap(os.path.join(folder, 'fill_spec.npy'), dtype='float', mode='r', + shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F') + + elif os.path.exists(os.path.join(folder, 'fine_spec.npy')): + fill_freqs = np.load(os.path.join(folder, 'fine_freqs.npy')) + fill_times = np.load(os.path.join(folder, 'fine_times.npy')) + fill_spec_shape = np.load(os.path.join(folder, 'fine_spec_shape.npy')) + fill_spec = np.memmap(os.path.join(folder, 'fine_spec.npy'), dtype='float', mode='r', + shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F') + + return fill_freqs, fill_times, fill_spec + + +def load_tracking_data(folder): + base_path = Path(folder) + EODf_v = np.load(base_path / 'fund_v.npy') + ident_v = np.load(base_path / 'ident_v.npy') + idx_v = np.load(base_path / 'idx_v.npy') + times_v = np.load(base_path / 'times.npy') + + return EODf_v, ident_v, idx_v, times_v + + +def load_trial_data(folder): + base_path = Path(folder) + fish_freq = np.load(base_path / 'analysis' / 'fish_freq.npy') + rise_idx = np.load(base_path / 'analysis' / 'rise_idx.npy') + rise_size = np.load(base_path / 'analysis' / 'rise_size.npy') + + fish_baseline_freq = np.load(base_path / 'analysis' / 'baseline_freqs.npy') + fish_baseline_freq_time = np.load(base_path / 'analysis' / 'baseline_freq_times.npy') + + return fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time + + +def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, dataset_folder): + f_res, t_res = freq[1] - freq[0], times[1] - times[0] + + fig_title = (f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0') + fig = plt.figure(figsize=IMG_SIZE, num=fig_title) + gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) # + ax = fig.add_subplot(gs[0, 0]) + ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower', + extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res)) + ax.axis(False) + + # plt.savefig(os.path.join(dataset_folder, fig_title), dpi=IMG_DPI) + plt.savefig(Path(DATA_DIR)/fig_title, dpi=IMG_DPI) + plt.close() + + return fig_title, (IMG_SIZE[0]*IMG_DPI, IMG_SIZE[1]*IMG_DPI) + + +def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str, + bbox_df, cols, width, height, t0, t1, f0, f1): + + times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1)) + + all_x_center = [] + all_y_center = [] + all_width = [] + all_height= [] + for id_idx in range(len(fish_freq)): + rise_idx_oi = np.array(rise_idx[id_idx][ + (rise_idx[id_idx] >= times_v_idx0) & + (rise_idx[id_idx] <= times_v_idx1) & + (rise_size[id_idx] >= 10)], dtype=int) + rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) & + (rise_idx[id_idx] <= times_v_idx1) & + (rise_size[id_idx] >= 10)] + if len(rise_idx_oi) == 0: + # np.savetxt(LABEL_DIR / Path(pic_save_str).with_suffix('.txt'), np.array([])) + continue + + closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi])) + closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx] + + upper_freq_bound = closest_baseline_freq + rise_size_oi + lower_freq_bound = closest_baseline_freq + + left_time_bound = times_v[rise_idx_oi] + right_time_bound = np.zeros_like(left_time_bound) + + for enu, Ct_oi in enumerate(times_v[rise_idx_oi]): + Crise_size = rise_size_oi[enu] + Cblf = closest_baseline_freq[enu] + + rise_end_t = times_v[(times_v > Ct_oi) & + (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] + if len(rise_end_t) == 0: + right_time_bound[enu] = np.nan + else: + right_time_bound[enu] = rise_end_t[0] + + mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.)) + left_time_bound = left_time_bound[mask] + right_time_bound = right_time_bound[mask] + lower_freq_bound = lower_freq_bound[mask] + upper_freq_bound = upper_freq_bound[mask] + + left_time_bound -= 0.01 * (t1 - t0) + right_time_bound += 0.05 * (t1 - t0) + lower_freq_bound -= 0.01 * (f1 - f0) + upper_freq_bound += 0.05 * (f1 - f0) + + mask2 = ((left_time_bound >= t0) & + (right_time_bound <= t1) & + (lower_freq_bound >= f0) & + (upper_freq_bound <= f1) + ) + left_time_bound = left_time_bound[mask2] + right_time_bound = right_time_bound[mask2] + lower_freq_bound = lower_freq_bound[mask2] + upper_freq_bound = upper_freq_bound[mask2] + + if len(left_time_bound) == 0: + continue + + # x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int) + # x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int) + # + # y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int) + # y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int) + + + rel_x0 = np.array((left_time_bound - t0) / (t1 - t0), dtype=float) + rel_x1 = np.array((right_time_bound - t0) / (t1 - t0), dtype=float) + + rel_y0 = np.array(1 - (upper_freq_bound - f0) / (f1 - f0), dtype=float) + rel_y1 = np.array(1 - (lower_freq_bound - f0) / (f1 - f0), dtype=float) + + rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2 + rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2 + rel_width = rel_x1 - rel_x0 + rel_height = rel_y1 - rel_y0 + + all_x_center.extend(rel_x_center) + all_y_center.extend(rel_y_center) + all_width.extend(rel_width) + all_height.extend(rel_height) + + # bbox = np.array([[pic_save_str for i in range(len(left_time_bound))], + # left_time_bound, + # right_time_bound, + # lower_freq_bound, + # upper_freq_bound, + # x0, y0, x1, y1]) + + bbox_yolo_style = np.array([ + np.ones(len(all_x_center)), + all_x_center, + all_y_center, + all_width, + all_height + ]).T + + np.savetxt(LABEL_DIR/ Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style) + return bbox_yolo_style + + +def main(args): + folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy')) + + if not args.inference: + print('generate training dataset only for files with detected rises') + folders = [folder for folder in folders if (folder / 'analysis' / 'rise_idx.npy').exists()] + cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1'] + bbox_df = pd.DataFrame(columns=cols) + + else: + print('generate inference dataset ... only image output') + bbox_df = {} + + for enu, folder in enumerate(folders): + print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}') + + # load different categories of data + freq, times, spec = ( + load_spec_data(folder)) + EODf_v, ident_v, idx_v, times_v = ( + load_tracking_data(folder)) + if not args.inference: + fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = ( + load_trial_data(folder)) + + # generate iterator for analysis window loop + pic_base = tqdm(itertools.product( + np.arange(0, times[-1], DELTA_TIME), + np.arange(MIN_FREQ, MAX_FREQ, DELTA_FREQ) + ), + total=int((((MAX_FREQ-MIN_FREQ)//DELTA_FREQ)+1) * ((times[-1] // DELTA_TIME)+1)) + ) + + for t0, f0 in pic_base: + + t1 = t0 + DELTA_TIME + TIME_OVERLAP + f1 = f0 + DELTA_FREQ + FREQ_OVERLAP + + present_freqs = EODf_v[(~np.isnan(ident_v)) & + (t0 <= times_v[idx_v]) & + (times_v[idx_v] <= t1) & + (EODf_v >= f0) & + (EODf_v <= f1)] + if len(present_freqs) == 0: + continue + + # get spec_idx for current spec snippet + f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1)) + t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1)) + + # get spec snippet and create torch.tensfor from it + s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32) + log_s = torch.log10(s) + transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s)) + s_trans = transformed(log_s.unsqueeze(0)) + + pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, args.dataset_folder) + + if not args.inference: + bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, + fish_baseline_freq_time, fish_baseline_freq, + pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1) + + ####################################################################### + if False: + if bbox_yolo_style.shape[0] >= 1: + f_res, t_res = freq[1] - freq[0], times[1] - times[0] + + fig_title = ( + f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace( + ' ', '0') + fig = plt.figure(figsize=IMG_SIZE, num=fig_title) + gs = gridspec.GridSpec(1, 1, bottom=0.1, left=0.1, right=0.95, top=0.95) # + ax = fig.add_subplot(gs[0, 0]) + ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower', + extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res)) + # ax.invert_yaxis() + # ax.axis(False) + + for i in range(len(bbox_df)): + # Cbbox = np.array(bbox_df.loc[i, ['x0', 'y0', 'x1', 'y1']].values, dtype=np.float32) + Cbbox = bbox_df.loc[i, ['t0', 'f0', 't1', 'f1']] + ax.add_patch( + Rectangle((float(Cbbox['t0']) / 3600, float(Cbbox['f0'])), + float(Cbbox['t1']) / 3600 - float(Cbbox['t0']) / 3600, + float(Cbbox['f1']) - float(Cbbox['f0']), + fill=False, color="white", linestyle='-', linewidth=2, zorder=10) + ) + + # print(bbox_yolo_style.T) + + for bbox in bbox_yolo_style: + x0 = bbox[1] - bbox[3]/2 # x_center - width/2 + y0 = 1 - (bbox[2] + bbox[4]/2) # x_center - width/2 + w = bbox[3] + h = bbox[4] + ax.add_patch( + Rectangle((x0, y0), w, h, + fill=False, color="k", linestyle='--', linewidth=2, zorder=10, + transform=ax.transAxes) + ) + plt.show() + ####################################################################### + + # if not args.inference: + # print('save bboxes') + # bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') + parser.add_argument('folder', type=str, help='single recording analysis', default='') + parser.add_argument('-d', "--dataset_folder", type=str, help='designated datasef folder', default=DATA_DIR) + parser.add_argument('-i', "--inference", action="store_true", help="generate inference dataset. Img only") + args = parser.parse_args() + + if not Path(args.dataset_folder).exists(): + Path(args.dataset_folder).mkdir(parents=True, exist_ok=True) + + main(args) \ No newline at end of file