added missing dataset generator
parent 9dcf54c139
commit 92ab342a65

generate_dataset.py | 324 lines | Normal file
@@ -0,0 +1,324 @@
import itertools
import sys
import os
import argparse

import torch
import torchvision.transforms as T

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from IPython import embed

from confic import (MIN_FREQ, MAX_FREQ, DELTA_FREQ, FREQ_OVERLAP, DELTA_TIME, TIME_OVERLAP, IMG_SIZE, IMG_DPI, DATA_DIR, LABEL_DIR)

def load_spec_data(folder: str):
    """
    Load the spectrogram of a given electrode-grid recording generated with the wavetracker package. The spectrograms
    may be too large to load as a whole, which is why memory mapping (numpy.memmap) is used.

    Parameters
    ----------
    folder: str
        Folder containing the fine spec numpy files generated for grid recordings with the wavetracker package.

    Returns
    -------
    fill_freqs: ndarray
        Frequencies corresponding to the 1st dimension of the spectrogram.
    fill_times: ndarray
        Times corresponding to the 2nd dimension of the spectrogram.
    fill_spec: ndarray
        Spectrogram of the recording referred to by the input folder.
    """
    fill_freqs, fill_times, fill_spec = [], [], []

    if os.path.exists(os.path.join(folder, 'fill_spec.npy')):
        fill_freqs = np.load(os.path.join(folder, 'fill_freqs.npy'))
        fill_times = np.load(os.path.join(folder, 'fill_times.npy'))
        fill_spec_shape = np.load(os.path.join(folder, 'fill_spec_shape.npy'))
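        # np.memmap keeps the (potentially huge) spectrogram on disk; only the slices
        # indexed later are actually read into memory.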
        fill_spec = np.memmap(os.path.join(folder, 'fill_spec.npy'), dtype='float', mode='r',
                              shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')

    elif os.path.exists(os.path.join(folder, 'fine_spec.npy')):
        fill_freqs = np.load(os.path.join(folder, 'fine_freqs.npy'))
        fill_times = np.load(os.path.join(folder, 'fine_times.npy'))
        fill_spec_shape = np.load(os.path.join(folder, 'fine_spec_shape.npy'))
        fill_spec = np.memmap(os.path.join(folder, 'fine_spec.npy'), dtype='float', mode='r',
                              shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F')

    return fill_freqs, fill_times, fill_spec

def load_tracking_data(folder):
    base_path = Path(folder)
    EODf_v = np.load(base_path / 'fund_v.npy')
    ident_v = np.load(base_path / 'ident_v.npy')
    idx_v = np.load(base_path / 'idx_v.npy')
    times_v = np.load(base_path / 'times.npy')

    return EODf_v, ident_v, idx_v, times_v


def load_trial_data(folder):
    base_path = Path(folder)
    fish_freq = np.load(base_path / 'analysis' / 'fish_freq.npy')
    rise_idx = np.load(base_path / 'analysis' / 'rise_idx.npy')
    rise_size = np.load(base_path / 'analysis' / 'rise_size.npy')

    fish_baseline_freq = np.load(base_path / 'analysis' / 'baseline_freqs.npy')
    fish_baseline_freq_time = np.load(base_path / 'analysis' / 'baseline_freq_times.npy')

    return fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time

def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, dataset_folder):
    f_res, t_res = freq[1] - freq[0], times[1] - times[0]

    fig_title = (f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0')
    fig = plt.figure(figsize=IMG_SIZE, num=fig_title)
    gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)
    ax = fig.add_subplot(gs[0, 0])
    ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
              extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res))
    ax.axis(False)
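    # The axes fill the whole canvas and are hidden, so the saved PNG contains nothing
    # but the spectrogram pixels; its size in pixels is IMG_SIZE * IMG_DPI, which is
    # returned alongside the file name.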

    # plt.savefig(os.path.join(dataset_folder, fig_title), dpi=IMG_DPI)
    plt.savefig(Path(DATA_DIR) / fig_title, dpi=IMG_DPI)
    plt.close()

    return fig_title, (IMG_SIZE[0] * IMG_DPI, IMG_SIZE[1] * IMG_DPI)

def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,
                     bbox_df, cols, width, height, t0, t1, f0, f1):

    times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))

    all_x_center = []
    all_y_center = []
    all_width = []
    all_height = []
    for id_idx in range(len(fish_freq)):
        rise_idx_oi = np.array(rise_idx[id_idx][
                                   (rise_idx[id_idx] >= times_v_idx0) &
                                   (rise_idx[id_idx] <= times_v_idx1) &
                                   (rise_size[id_idx] >= 10)], dtype=int)
        rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
                                         (rise_idx[id_idx] <= times_v_idx1) &
                                         (rise_size[id_idx] >= 10)]
        if len(rise_idx_oi) == 0:
            # np.savetxt(LABEL_DIR / Path(pic_save_str).with_suffix('.txt'), np.array([]))
            continue

        closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]

        upper_freq_bound = closest_baseline_freq + rise_size_oi
        lower_freq_bound = closest_baseline_freq

        left_time_bound = times_v[rise_idx_oi]
        right_time_bound = np.zeros_like(left_time_bound)

        for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
            Crise_size = rise_size_oi[enu]
            Cblf = closest_baseline_freq[enu]

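            # A rise is considered over once the frequency has dropped back below
            # baseline + 37 % of the rise amplitude; the first sample fulfilling this
            # becomes the right edge of the bounding box.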
            rise_end_t = times_v[(times_v > Ct_oi) &
                                 (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
            if len(rise_end_t) == 0:
                right_time_bound[enu] = np.nan
            else:
                right_time_bound[enu] = rise_end_t[0]

        mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.))
        left_time_bound = left_time_bound[mask]
        right_time_bound = right_time_bound[mask]
        lower_freq_bound = lower_freq_bound[mask]
        upper_freq_bound = upper_freq_bound[mask]

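        # Pad each box by a small fraction of the current window extent so the rise
        # onset and tail are not clipped exactly at the box border.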
        left_time_bound -= 0.01 * (t1 - t0)
        right_time_bound += 0.05 * (t1 - t0)
        lower_freq_bound -= 0.01 * (f1 - f0)
        upper_freq_bound += 0.05 * (f1 - f0)

        mask2 = ((left_time_bound >= t0) &
                 (right_time_bound <= t1) &
                 (lower_freq_bound >= f0) &
                 (upper_freq_bound <= f1))
        left_time_bound = left_time_bound[mask2]
        right_time_bound = right_time_bound[mask2]
        lower_freq_bound = lower_freq_bound[mask2]
        upper_freq_bound = upper_freq_bound[mask2]

        if len(left_time_bound) == 0:
            continue

        # x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int)
        # x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int)
        #
        # y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
        # y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int)

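        # Convert the absolute time/frequency corners into image-relative (0-1)
        # coordinates, then into the center/size representation expected by YOLO.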
        rel_x0 = np.array((left_time_bound - t0) / (t1 - t0), dtype=float)
        rel_x1 = np.array((right_time_bound - t0) / (t1 - t0), dtype=float)

        rel_y0 = np.array(1 - (upper_freq_bound - f0) / (f1 - f0), dtype=float)
        rel_y1 = np.array(1 - (lower_freq_bound - f0) / (f1 - f0), dtype=float)

        rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2
        rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2
        rel_width = rel_x1 - rel_x0
        rel_height = rel_y1 - rel_y0

        all_x_center.extend(rel_x_center)
        all_y_center.extend(rel_y_center)
        all_width.extend(rel_width)
        all_height.extend(rel_height)

        # bbox = np.array([[pic_save_str for i in range(len(left_time_bound))],
        #                  left_time_bound,
        #                  right_time_bound,
        #                  lower_freq_bound,
        #                  upper_freq_bound,
        #                  x0, y0, x1, y1])

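    # YOLO label layout: one row per box with [class_id, x_center, y_center, width, height],
    # all values relative to the image size; every rise is written with class id 1 here.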
    bbox_yolo_style = np.array([
        np.ones(len(all_x_center)),
        all_x_center,
        all_y_center,
        all_width,
        all_height
    ]).T

    np.savetxt(LABEL_DIR / Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style)
    return bbox_yolo_style

def main(args):
    folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))

    if not args.inference:
        print('generate training dataset only for files with detected rises')
        folders = [folder for folder in folders if (folder / 'analysis' / 'rise_idx.npy').exists()]
        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
        bbox_df = pd.DataFrame(columns=cols)

    else:
        print('generate inference dataset ... only image output')
        bbox_df = {}

    for enu, folder in enumerate(folders):
        print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')

        # load different categories of data
        freq, times, spec = load_spec_data(folder)
        EODf_v, ident_v, idx_v, times_v = load_tracking_data(folder)
        if not args.inference:
            fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
                load_trial_data(folder))

        # generate iterator for analysis window loop
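        # Windows tile the recording in steps of DELTA_TIME and DELTA_FREQ; each window
        # is later extended by TIME_OVERLAP / FREQ_OVERLAP so neighbouring snippets overlap.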
        pic_base = tqdm(itertools.product(
            np.arange(0, times[-1], DELTA_TIME),
            np.arange(MIN_FREQ, MAX_FREQ, DELTA_FREQ)
        ),
            total=int((((MAX_FREQ - MIN_FREQ) // DELTA_FREQ) + 1) * ((times[-1] // DELTA_TIME) + 1))
        )

        for t0, f0 in pic_base:

            t1 = t0 + DELTA_TIME + TIME_OVERLAP
            f1 = f0 + DELTA_FREQ + FREQ_OVERLAP

            present_freqs = EODf_v[(~np.isnan(ident_v)) &
                                   (t0 <= times_v[idx_v]) &
                                   (times_v[idx_v] <= t1) &
                                   (EODf_v >= f0) &
                                   (EODf_v <= f1)]
            if len(present_freqs) == 0:
                continue

            # get spec_idx for current spec snippet
            f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
            t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))

            # get spec snippet and create a torch.Tensor from it
            s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
            log_s = torch.log10(s)
            transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
            s_trans = transformed(log_s.unsqueeze(0))
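            # s_trans now holds the log10-compressed, zero-mean / unit-variance snippet
            # that is rendered to the grayscale training image below.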

            pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, args.dataset_folder)

            if not args.inference:
                bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
                                                   fish_baseline_freq_time, fish_baseline_freq,
                                                   pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)

            #######################################################################
            if False:
                if bbox_yolo_style.shape[0] >= 1:
                    f_res, t_res = freq[1] - freq[0], times[1] - times[0]

                    fig_title = (
                        f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(
                        ' ', '0')
                    fig = plt.figure(figsize=IMG_SIZE, num=fig_title)
                    gs = gridspec.GridSpec(1, 1, bottom=0.1, left=0.1, right=0.95, top=0.95)
                    ax = fig.add_subplot(gs[0, 0])
                    ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
                              extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res))
                    # ax.invert_yaxis()
                    # ax.axis(False)

                    for i in range(len(bbox_df)):
                        # Cbbox = np.array(bbox_df.loc[i, ['x0', 'y0', 'x1', 'y1']].values, dtype=np.float32)
                        Cbbox = bbox_df.loc[i, ['t0', 'f0', 't1', 'f1']]
                        ax.add_patch(
                            Rectangle((float(Cbbox['t0']) / 3600, float(Cbbox['f0'])),
                                      float(Cbbox['t1']) / 3600 - float(Cbbox['t0']) / 3600,
                                      float(Cbbox['f1']) - float(Cbbox['f0']),
                                      fill=False, color="white", linestyle='-', linewidth=2, zorder=10)
                        )

                    # print(bbox_yolo_style.T)

                    for bbox in bbox_yolo_style:
                        x0 = bbox[1] - bbox[3] / 2  # x_center - width/2
                        y0 = 1 - (bbox[2] + bbox[4] / 2)  # 1 - (y_center + height/2)
                        w = bbox[3]
                        h = bbox[4]
                        ax.add_patch(
                            Rectangle((x0, y0), w, h,
                                      fill=False, color="k", linestyle='--', linewidth=2, zorder=10,
                                      transform=ax.transAxes)
                        )
                    plt.show()
            #######################################################################

        # if not args.inference:
        #     print('save bboxes')
        #     bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate electrode array recordings with multiple fish.')
    parser.add_argument('folder', type=str, help='single recording analysis', default='')
    parser.add_argument('-d', "--dataset_folder", type=str, help='designated dataset folder', default=DATA_DIR)
    parser.add_argument('-i', "--inference", action="store_true", help="generate inference dataset. Img only")
    args = parser.parse_args()

    if not Path(args.dataset_folder).exists():
        Path(args.dataset_folder).mkdir(parents=True, exist_ok=True)

    main(args)
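
# Example invocation (a sketch; the recording path is hypothetical):
#   python generate_dataset.py /path/to/grid_recordings                 # training images + YOLO labels
#   python generate_dataset.py /path/to/grid_recordings -i              # inference images only
#   python generate_dataset.py /path/to/grid_recordings -d my_dataset   # custom dataset folder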