added miussing dataset generator
This commit is contained in:
		
							parent
							
								
									9dcf54c139
								
							
						
					
					
						commit
						92ab342a65
					
				
							
								
								
									
										324
									
								
								generate_dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										324
									
								
								generate_dataset.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,324 @@ | ||||
| import itertools | ||||
| import sys | ||||
| import os | ||||
| import argparse | ||||
| 
 | ||||
| import torch | ||||
| import torchvision.transforms as T | ||||
| 
 | ||||
| import numpy as np | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.gridspec as gridspec | ||||
| from matplotlib.patches import Rectangle | ||||
| import pandas as pd | ||||
| from pathlib import Path | ||||
| from tqdm.auto import tqdm | ||||
| from IPython import embed | ||||
| 
 | ||||
| from confic import (MIN_FREQ, MAX_FREQ, DELTA_FREQ, FREQ_OVERLAP, DELTA_TIME, TIME_OVERLAP, IMG_SIZE, IMG_DPI, DATA_DIR, LABEL_DIR) | ||||
| 
 | ||||
| 
 | ||||
| def load_spec_data(folder: str): | ||||
|     """ | ||||
|     Load spectrogram of a given electrode-grid recording generated with the wavetracker package. The spectrograms may | ||||
|     be to large to load in total, thats why memmory mapping is used (numpy.memmap). | ||||
| 
 | ||||
|     Parameters | ||||
|     ---------- | ||||
|     folder: str | ||||
|         Folder where fine spec numpy files generated for grid recordings with the wavetracker package can be found. | ||||
| 
 | ||||
|     Returns | ||||
|     ------- | ||||
|     fill_freqs: ndarray | ||||
|         Freuqencies corresponding to 1st dimension of the spectrogram. | ||||
|     fill_times: ndarray | ||||
|         Times corresponding to the 2nd dimenstion if the spectrigram. | ||||
|     fill_spec: ndarray | ||||
|         Spectrigram of the recording refered to in the input folder. | ||||
|     """ | ||||
|     fill_freqs, fill_times, fill_spec = [], [], [] | ||||
| 
 | ||||
|     if os.path.exists(os.path.join(folder, 'fill_spec.npy')): | ||||
|         fill_freqs = np.load(os.path.join(folder, 'fill_freqs.npy')) | ||||
|         fill_times = np.load(os.path.join(folder, 'fill_times.npy')) | ||||
|         fill_spec_shape = np.load(os.path.join(folder, 'fill_spec_shape.npy')) | ||||
|         fill_spec = np.memmap(os.path.join(folder, 'fill_spec.npy'), dtype='float', mode='r', | ||||
|                                    shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F') | ||||
| 
 | ||||
|     elif os.path.exists(os.path.join(folder, 'fine_spec.npy')): | ||||
|         fill_freqs = np.load(os.path.join(folder, 'fine_freqs.npy')) | ||||
|         fill_times = np.load(os.path.join(folder, 'fine_times.npy')) | ||||
|         fill_spec_shape = np.load(os.path.join(folder, 'fine_spec_shape.npy')) | ||||
|         fill_spec = np.memmap(os.path.join(folder, 'fine_spec.npy'), dtype='float', mode='r', | ||||
|                                    shape=(fill_spec_shape[0], fill_spec_shape[1]), order='F') | ||||
| 
 | ||||
|     return fill_freqs, fill_times, fill_spec | ||||
| 
 | ||||
| 
 | ||||
| def load_tracking_data(folder): | ||||
|     base_path = Path(folder) | ||||
|     EODf_v = np.load(base_path / 'fund_v.npy') | ||||
|     ident_v = np.load(base_path / 'ident_v.npy') | ||||
|     idx_v = np.load(base_path / 'idx_v.npy') | ||||
|     times_v = np.load(base_path / 'times.npy') | ||||
| 
 | ||||
|     return EODf_v, ident_v, idx_v, times_v | ||||
| 
 | ||||
| 
 | ||||
| def load_trial_data(folder): | ||||
|     base_path = Path(folder) | ||||
|     fish_freq = np.load(base_path / 'analysis' / 'fish_freq.npy') | ||||
|     rise_idx = np.load(base_path / 'analysis' / 'rise_idx.npy') | ||||
|     rise_size = np.load(base_path / 'analysis' / 'rise_size.npy') | ||||
| 
 | ||||
|     fish_baseline_freq = np.load(base_path / 'analysis' / 'baseline_freqs.npy') | ||||
|     fish_baseline_freq_time = np.load(base_path / 'analysis' / 'baseline_freq_times.npy') | ||||
| 
 | ||||
|     return fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time | ||||
| 
 | ||||
| 
 | ||||
| def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, dataset_folder): | ||||
|     f_res, t_res = freq[1] - freq[0], times[1] - times[0] | ||||
| 
 | ||||
|     fig_title = (f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0') | ||||
|     fig = plt.figure(figsize=IMG_SIZE, num=fig_title) | ||||
|     gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)  # | ||||
|     ax = fig.add_subplot(gs[0, 0]) | ||||
|     ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower', | ||||
|               extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res)) | ||||
|     ax.axis(False) | ||||
| 
 | ||||
|     # plt.savefig(os.path.join(dataset_folder, fig_title), dpi=IMG_DPI) | ||||
|     plt.savefig(Path(DATA_DIR)/fig_title, dpi=IMG_DPI) | ||||
|     plt.close() | ||||
| 
 | ||||
|     return fig_title, (IMG_SIZE[0]*IMG_DPI, IMG_SIZE[1]*IMG_DPI) | ||||
| 
 | ||||
| 
 | ||||
| def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str, | ||||
|                      bbox_df, cols, width, height, t0, t1, f0, f1): | ||||
| 
 | ||||
|     times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1)) | ||||
| 
 | ||||
|     all_x_center = [] | ||||
|     all_y_center = [] | ||||
|     all_width = [] | ||||
|     all_height= [] | ||||
|     for id_idx in range(len(fish_freq)): | ||||
|         rise_idx_oi = np.array(rise_idx[id_idx][ | ||||
|                                    (rise_idx[id_idx] >= times_v_idx0) & | ||||
|                                    (rise_idx[id_idx] <= times_v_idx1) & | ||||
|                                    (rise_size[id_idx] >= 10)], dtype=int) | ||||
|         rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) & | ||||
|                                          (rise_idx[id_idx] <= times_v_idx1) & | ||||
|                                          (rise_size[id_idx] >= 10)] | ||||
|         if len(rise_idx_oi) == 0: | ||||
|             # np.savetxt(LABEL_DIR / Path(pic_save_str).with_suffix('.txt'), np.array([])) | ||||
|             continue | ||||
| 
 | ||||
|         closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi])) | ||||
|         closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx] | ||||
| 
 | ||||
|         upper_freq_bound = closest_baseline_freq + rise_size_oi | ||||
|         lower_freq_bound = closest_baseline_freq | ||||
| 
 | ||||
|         left_time_bound = times_v[rise_idx_oi] | ||||
|         right_time_bound = np.zeros_like(left_time_bound) | ||||
| 
 | ||||
|         for enu, Ct_oi in enumerate(times_v[rise_idx_oi]): | ||||
|             Crise_size = rise_size_oi[enu] | ||||
|             Cblf = closest_baseline_freq[enu] | ||||
| 
 | ||||
|             rise_end_t = times_v[(times_v > Ct_oi) & | ||||
|                                  (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] | ||||
|             if len(rise_end_t) == 0: | ||||
|                 right_time_bound[enu] = np.nan | ||||
|             else: | ||||
|                 right_time_bound[enu] = rise_end_t[0] | ||||
| 
 | ||||
|         mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.)) | ||||
|         left_time_bound = left_time_bound[mask] | ||||
|         right_time_bound = right_time_bound[mask] | ||||
|         lower_freq_bound = lower_freq_bound[mask] | ||||
|         upper_freq_bound = upper_freq_bound[mask] | ||||
| 
 | ||||
|         left_time_bound -= 0.01 * (t1 - t0) | ||||
|         right_time_bound += 0.05 * (t1 - t0) | ||||
|         lower_freq_bound -= 0.01 * (f1 - f0) | ||||
|         upper_freq_bound += 0.05 * (f1 - f0) | ||||
| 
 | ||||
|         mask2 = ((left_time_bound >= t0) & | ||||
|                 (right_time_bound <= t1) & | ||||
|                 (lower_freq_bound >= f0) & | ||||
|                 (upper_freq_bound <= f1) | ||||
|         ) | ||||
|         left_time_bound = left_time_bound[mask2] | ||||
|         right_time_bound = right_time_bound[mask2] | ||||
|         lower_freq_bound = lower_freq_bound[mask2] | ||||
|         upper_freq_bound = upper_freq_bound[mask2] | ||||
| 
 | ||||
|         if len(left_time_bound) == 0: | ||||
|             continue | ||||
| 
 | ||||
|         # x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int) | ||||
|         # x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int) | ||||
|         # | ||||
|         # y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int) | ||||
|         # y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int) | ||||
| 
 | ||||
| 
 | ||||
|         rel_x0 = np.array((left_time_bound - t0) / (t1 - t0), dtype=float) | ||||
|         rel_x1 = np.array((right_time_bound - t0) / (t1 - t0), dtype=float) | ||||
| 
 | ||||
|         rel_y0 = np.array(1 - (upper_freq_bound - f0) / (f1 - f0), dtype=float) | ||||
|         rel_y1 = np.array(1 - (lower_freq_bound - f0) / (f1 - f0), dtype=float) | ||||
| 
 | ||||
|         rel_x_center = rel_x1 - (rel_x1 - rel_x0) / 2 | ||||
|         rel_y_center = rel_y1 - (rel_y1 - rel_y0) / 2 | ||||
|         rel_width = rel_x1 - rel_x0 | ||||
|         rel_height = rel_y1 - rel_y0 | ||||
| 
 | ||||
|         all_x_center.extend(rel_x_center) | ||||
|         all_y_center.extend(rel_y_center) | ||||
|         all_width.extend(rel_width) | ||||
|         all_height.extend(rel_height) | ||||
| 
 | ||||
|         # bbox = np.array([[pic_save_str for i in range(len(left_time_bound))], | ||||
|         #                  left_time_bound, | ||||
|         #                  right_time_bound, | ||||
|         #                  lower_freq_bound, | ||||
|         #                  upper_freq_bound, | ||||
|         #                  x0, y0, x1, y1]) | ||||
| 
 | ||||
|     bbox_yolo_style = np.array([ | ||||
|         np.ones(len(all_x_center)), | ||||
|         all_x_center, | ||||
|         all_y_center, | ||||
|         all_width, | ||||
|         all_height | ||||
|     ]).T | ||||
| 
 | ||||
|     np.savetxt(LABEL_DIR/ Path(pic_save_str).with_suffix('.txt'), bbox_yolo_style) | ||||
|     return bbox_yolo_style | ||||
| 
 | ||||
| 
 | ||||
| def main(args): | ||||
|     folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy')) | ||||
| 
 | ||||
|     if not args.inference: | ||||
|         print('generate training dataset only for files with detected rises') | ||||
|         folders = [folder for folder in folders if (folder / 'analysis' / 'rise_idx.npy').exists()] | ||||
|         cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1'] | ||||
|         bbox_df = pd.DataFrame(columns=cols) | ||||
| 
 | ||||
|     else: | ||||
|         print('generate inference dataset ... only image output') | ||||
|         bbox_df = {} | ||||
| 
 | ||||
|     for enu, folder in enumerate(folders): | ||||
|         print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}') | ||||
| 
 | ||||
|         # load different categories of data | ||||
|         freq, times, spec = ( | ||||
|             load_spec_data(folder)) | ||||
|         EODf_v, ident_v, idx_v, times_v = ( | ||||
|             load_tracking_data(folder)) | ||||
|         if not args.inference: | ||||
|             fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = ( | ||||
|                 load_trial_data(folder)) | ||||
| 
 | ||||
|         # generate iterator for analysis window loop | ||||
|         pic_base = tqdm(itertools.product( | ||||
|             np.arange(0, times[-1], DELTA_TIME), | ||||
|             np.arange(MIN_FREQ, MAX_FREQ, DELTA_FREQ) | ||||
|         ), | ||||
|             total=int((((MAX_FREQ-MIN_FREQ)//DELTA_FREQ)+1) * ((times[-1] // DELTA_TIME)+1)) | ||||
|         ) | ||||
| 
 | ||||
|         for t0, f0 in pic_base: | ||||
| 
 | ||||
|             t1 = t0 + DELTA_TIME + TIME_OVERLAP | ||||
|             f1 = f0 + DELTA_FREQ + FREQ_OVERLAP | ||||
| 
 | ||||
|             present_freqs = EODf_v[(~np.isnan(ident_v))     & | ||||
|                                    (t0 <= times_v[idx_v]) & | ||||
|                                    (times_v[idx_v] <= t1) & | ||||
|                                    (EODf_v >= f0) & | ||||
|                                    (EODf_v <= f1)] | ||||
|             if len(present_freqs) == 0: | ||||
|                 continue | ||||
| 
 | ||||
|             # get spec_idx for current spec snippet | ||||
|             f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1)) | ||||
|             t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1)) | ||||
| 
 | ||||
|             # get spec snippet and create torch.tensfor from it | ||||
|             s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32) | ||||
|             log_s = torch.log10(s) | ||||
|             transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s)) | ||||
|             s_trans = transformed(log_s.unsqueeze(0)) | ||||
| 
 | ||||
|             pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, args.dataset_folder) | ||||
| 
 | ||||
|             if not args.inference: | ||||
|                 bbox_yolo_style = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, | ||||
|                                                    fish_baseline_freq_time, fish_baseline_freq, | ||||
|                                                    pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1) | ||||
| 
 | ||||
|             ####################################################################### | ||||
|             if False: | ||||
|                 if bbox_yolo_style.shape[0] >= 1: | ||||
|                     f_res, t_res = freq[1] - freq[0], times[1] - times[0] | ||||
| 
 | ||||
|                     fig_title = ( | ||||
|                         f'{Path(folder).name}__{times[t_idx0]:5.0f}s-{times[t_idx1]:5.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace( | ||||
|                         ' ', '0') | ||||
|                     fig = plt.figure(figsize=IMG_SIZE, num=fig_title) | ||||
|                     gs = gridspec.GridSpec(1, 1, bottom=0.1, left=0.1, right=0.95, top=0.95)  # | ||||
|                     ax = fig.add_subplot(gs[0, 0]) | ||||
|                     ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower', | ||||
|                               extent=(times[t_idx0] / 3600, (times[t_idx1] + t_res) / 3600, freq[f_idx0], freq[f_idx1] + f_res)) | ||||
|                     # ax.invert_yaxis() | ||||
|                     # ax.axis(False) | ||||
| 
 | ||||
|                     for i in range(len(bbox_df)): | ||||
|                         # Cbbox = np.array(bbox_df.loc[i, ['x0', 'y0', 'x1', 'y1']].values, dtype=np.float32) | ||||
|                         Cbbox = bbox_df.loc[i, ['t0', 'f0', 't1', 'f1']] | ||||
|                         ax.add_patch( | ||||
|                             Rectangle((float(Cbbox['t0']) / 3600, float(Cbbox['f0'])), | ||||
|                                       float(Cbbox['t1']) / 3600 - float(Cbbox['t0']) / 3600, | ||||
|                                       float(Cbbox['f1']) - float(Cbbox['f0']), | ||||
|                                       fill=False, color="white", linestyle='-', linewidth=2, zorder=10) | ||||
|                         ) | ||||
| 
 | ||||
|                     # print(bbox_yolo_style.T) | ||||
| 
 | ||||
|                     for bbox in bbox_yolo_style: | ||||
|                         x0 = bbox[1] - bbox[3]/2 # x_center - width/2 | ||||
|                         y0 = 1 - (bbox[2] + bbox[4]/2) # x_center - width/2 | ||||
|                         w = bbox[3] | ||||
|                         h = bbox[4] | ||||
|                         ax.add_patch( | ||||
|                             Rectangle((x0, y0), w, h, | ||||
|                                       fill=False, color="k", linestyle='--', linewidth=2, zorder=10, | ||||
|                                       transform=ax.transAxes) | ||||
|                         ) | ||||
|                     plt.show() | ||||
|             ####################################################################### | ||||
| 
 | ||||
|         # if not args.inference: | ||||
|         #     print('save bboxes') | ||||
|             # bbox_df.to_csv(os.path.join(args.dataset_folder, 'bbox_dataset.csv'), columns=cols, sep=',') | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') | ||||
|     parser.add_argument('folder', type=str, help='single recording analysis', default='') | ||||
|     parser.add_argument('-d', "--dataset_folder", type=str, help='designated datasef folder', default=DATA_DIR) | ||||
|     parser.add_argument('-i', "--inference", action="store_true", help="generate inference dataset. Img only") | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
|     if not Path(args.dataset_folder).exists(): | ||||
|         Path(args.dataset_folder).mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
|     main(args) | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user