A lot of work; dataset and loader not created.
parent 5ca527e39f
commit 3f66b4a39a

custom_utils.py (new file, 6 lines added)
@@ -0,0 +1,6 @@
+def collate_fn(batch):
+    """
+    Handle data loading when different images contain different numbers of
+    objects, and when the target tensors therefore vary in size.
+    """
+    return tuple(zip(*batch))
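A minimal, self-contained sketch (not part of this commit; the dummy dataset below is hypothetical) of why this collate function is needed: the default collation tries to stack targets with unequal numbers of boxes and fails, while tuple(zip(*batch)) simply keeps the per-sample targets side by side.

    # Usage sketch for collate_fn with a varying number of boxes per image.
    import torch
    from torch.utils.data import Dataset, DataLoader
    from custom_utils import collate_fn

    class DummyDetectionSet(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            img = torch.zeros(3, 64, 64)
            # idx + 1 boxes per sample -> the default collation would fail on unequal shapes
            target = {"boxes": torch.zeros(idx + 1, 4),
                      "labels": torch.ones(idx + 1, dtype=torch.int64)}
            return img, target

    loader = DataLoader(DummyDetectionSet(), batch_size=2, collate_fn=collate_fn)
    for images, targets in loader:
        # images and targets are tuples of length batch_size
        print(len(images), [t["boxes"].shape for t in targets])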
@@ -9,6 +9,7 @@ import torchvision.transforms as T
 import matplotlib.pyplot as plt
 import matplotlib.gridspec as gridspec
 from pathlib import Path
+import pandas as pd

 from tqdm.auto import tqdm

@@ -55,7 +56,9 @@ def load_data(folder):
     return fill_freqs, fill_times, fill_spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time

 def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res):
-    fig_title = (f'{Path(folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
+    size = (7, 7)
+    dpi = 256
+    fig_title = (f'{Path(folder).name}__{times[t_idx0]:.0f}s-{times[t_idx1]:.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0')
     fig = plt.figure(figsize=(7, 7), num=fig_title)
     gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0)  # , bottom=0, left=0, right=1, top=1
     gs2 = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)  #
@@ -64,9 +67,72 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1,
                    extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res))
     ax.axis(False)

-    plt.savefig(os.path.join('train', fig_title + '.png'), dpi=256)
+    plt.savefig(os.path.join('train', fig_title), dpi=256)
     plt.close()

+    return fig_title, (size[0]*dpi, size[1]*dpi)
+
+def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str,
+                     bbox_df, cols, width, height, t0, t1, f0, f1):
+
+    times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
+    for id_idx in range(len(fish_freq)):
+        rise_idx_oi = np.array(rise_idx[id_idx][
+                                   (rise_idx[id_idx] >= times_v_idx0) &
+                                   (rise_idx[id_idx] <= times_v_idx1) &
+                                   (rise_size[id_idx] >= 10)], dtype=int)
+        rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
+                                         (rise_idx[id_idx] <= times_v_idx1) &
+                                         (rise_size[id_idx] >= 10)]
+        if len(rise_idx_oi) == 0:
+            continue
+        closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
+        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]
+
+        upper_freq_bound = closest_baseline_freq + rise_size_oi
+        lower_freq_bound = closest_baseline_freq
+
+        left_time_bound = times_v[rise_idx_oi]
+        right_time_bound = np.zeros_like(left_time_bound)
+
+        for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
+            Crise_size = rise_size_oi[enu]
+            Cblf = closest_baseline_freq[enu]
+
+            rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
+            if len(rise_end_t) == 0:
+                right_time_bound[enu] = np.nan
+            else:
+                right_time_bound[enu] = rise_end_t[0]
+
+        dt_bbox = right_time_bound - left_time_bound
+        df_bbox = upper_freq_bound - lower_freq_bound
+        left_time_bound -= dt_bbox * 0.1
+        right_time_bound += dt_bbox * 0.1
+        lower_freq_bound -= df_bbox * 0.1
+        upper_freq_bound += df_bbox * 0.1
+
+        x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int)
+        x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int)
+        y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
+        y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
+
+        bbox = np.array([[pic_save_str for i in range(len(left_time_bound))],
+                         left_time_bound,
+                         right_time_bound,
+                         lower_freq_bound,
+                         upper_freq_bound,
+                         x0, x1, y0, y1])
+        # test_s = ['a', 'a', 'a', 'a']
+        tmp_df = pd.DataFrame(
+            # index= [pic_save_str for i in range(len(left_time_bound))],
+            # index= test_s,
+            data=bbox.T,
+            columns=cols
+        )
+        bbox_df = pd.concat([bbox_df, tmp_df], ignore_index=True)
+        # bbox_df.append(tmp_df)
+    return bbox_df

 def main(args):
     min_freq = 200
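For orientation, a small self-contained sketch (illustrative numbers only, not taken from any recording) of the time/frequency-to-pixel mapping used in bboxes_from_file above: x grows with time across the image width, and y is flipped because pixel row 0 sits at the highest frequency.

    # Illustrative sketch of the pixel mapping in bboxes_from_file (made-up values).
    import numpy as np

    t0, t1 = 0.0, 1200.0                  # time window of one spectrogram image (s)
    f0, f1 = 400.0, 600.0                 # frequency window (Hz)
    width, height = 7 * 256, 7 * 256      # figure size (7, 7) inches at dpi=256

    left_time_bound = np.array([300.0])   # rise onset in seconds
    upper_freq_bound = np.array([520.0])  # baseline frequency + rise size in Hz

    x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int)
    y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int)
    print(x0, y0)  # [448] [716]: later in time -> larger x, higher frequency -> smaller y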
@@ -76,109 +142,129 @@ def main(args):
     d_time = 60*15
     time_overlap = 60*5

-    freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
-        load_data(args.folder))
-    f_res, t_res = freq[1] - freq[0], times[1] - times[0]
+    if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')):
+        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'x1', 'y0', 'y1']
+        bbox_df = pd.DataFrame(columns=cols)

-    unique_ids = np.unique(ident_v[~np.isnan(ident_v)])
+    else:
+        bbox_df = pd.read_csv(os.path.join('train', 'bbox_dataset.csv'), sep=',', index_col=0)
+        cols = list(bbox_df.keys())
+        eval_files = []
+        # ToDo: make sure not same file twice
+        for f in pd.unique(bbox_df['image']):
+            eval_files.append(f.split('__')[0])

-    pic_base = tqdm(itertools.product(
-        np.arange(0, times[-1], d_time),
-        np.arange(min_freq, max_freq, d_freq)
-    ),
-        total=int(((max_freq-min_freq)//d_freq) * (times[-1] // d_time))
-    )
+    folders = [args.folders]

-    for t0, f0 in pic_base:
-        t1 = t0 + d_time + time_overlap
-        f1 = f0 + d_freq + freq_overlap
+    for enu, folder in enumerate(folders):
+        print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')

-        present_freqs = EODf_v[(~np.isnan(ident_v)) &
-                               (t0 <= times_v[idx_v]) &
-                               (times_v[idx_v] <= t1) &
-                               (EODf_v >= f0) &
-                               (EODf_v <= f1)]
-        if len(present_freqs) == 0:
-            continue
+        freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
+            load_data(folder))
+        f_res, t_res = freq[1] - freq[0], times[1] - times[0]

-        f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
-        t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))
+        pic_base = tqdm(itertools.product(
+            np.arange(0, times[-1], d_time),
+            np.arange(min_freq, max_freq, d_freq)
+        ),
+            total=int(((max_freq-min_freq)//d_freq) * (times[-1] // d_time))
+        )

-        s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
-        log_s = torch.log10(s)
-        transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
-        s_trans = transformed(log_s.unsqueeze(0))
+        for t0, f0 in pic_base:
+            t1 = t0 + d_time + time_overlap
+            f1 = f0 + d_freq + freq_overlap
+
+            present_freqs = EODf_v[(~np.isnan(ident_v)) &
+                                   (t0 <= times_v[idx_v]) &
+                                   (times_v[idx_v] <= t1) &
+                                   (EODf_v >= f0) &
+                                   (EODf_v <= f1)]
+            if len(present_freqs) == 0:
+                continue
+
+            f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1))
+            t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1))
+
+            s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32)
+            log_s = torch.log10(s)
+            transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s))
+            s_trans = transformed(log_s.unsqueeze(0))
+
+            if not args.dev:
+                pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res)
+
+                bbox_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size,
+                                           fish_baseline_freq_time, fish_baseline_freq,
+                                           pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)
+            else:
+                fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
+                fig = plt.figure(figsize=(10, 7), num=fig_title)
+                gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9, top=0.95)  # , bottom=0, left=0, right=1, top=1
+                ax = fig.add_subplot(gs[0, 0])
+                cax = fig.add_subplot(gs[0, 1])
+                im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
+                               extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
+                fig.colorbar(im, cax=cax, orientation='vertical')
+
+                times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
+                for id_idx in range(len(fish_freq)):
+                    ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k', markersize=4)
+                    rise_idx_oi = np.array(rise_idx[id_idx][
+                                               (rise_idx[id_idx] >= times_v_idx0) &
+                                               (rise_idx[id_idx] <= times_v_idx1) &
+                                               (rise_size[id_idx] >= 10)], dtype=int)
+                    rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
+                                                     (rise_idx[id_idx] <= times_v_idx1) &
+                                                     (rise_size[id_idx] >= 10)]
+
+                    ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red')
+
+                    if len(rise_idx_oi) > 0:
+                        closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
+                        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]
+
+                        upper_freq_bound = closest_baseline_freq + rise_size_oi
+                        lower_freq_bound = closest_baseline_freq
+
+                        left_time_bound = times_v[rise_idx_oi]
+                        right_time_bound = np.zeros_like(left_time_bound)
+
+                        for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
+                            Crise_size = rise_size_oi[enu]
+                            Cblf = closest_baseline_freq[enu]
+
+                            rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
+                            if len(rise_end_t) == 0:
+                                right_time_bound[enu] = np.nan
+                            else:
+                                right_time_bound[enu] = rise_end_t[0]
+
+                        dt_bbox = right_time_bound - left_time_bound
+                        df_bbox = upper_freq_bound - lower_freq_bound
+                        left_time_bound -= dt_bbox*0.1
+                        right_time_bound += dt_bbox*0.1
+                        lower_freq_bound -= df_bbox*0.1
+                        upper_freq_bound += df_bbox*0.1
+
+                        print(f'f0: {lower_freq_bound}')
+                        print(f'f1: {upper_freq_bound}')
+                        print(f't0: {left_time_bound}')
+                        print(f't1: {right_time_bound}')
+
+                        for enu in range(len(left_time_bound)):
+                            if np.isnan(right_time_bound[enu]):
+                                continue
+                            ax.add_patch(
+                                Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
+                                          (right_time_bound[enu] - left_time_bound[enu]),
+                                          (upper_freq_bound[enu] - lower_freq_bound[enu]),
+                                          fill=False, color="white", linewidth=2, zorder=10)
+                            )
+                plt.show()

         if not args.dev:
-            save_spec_pic(args.folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res)
-
-        else:
-            fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
-            fig = plt.figure(figsize=(10, 7), num=fig_title)
-            gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9, top=0.95)  # , bottom=0, left=0, right=1, top=1
-            ax = fig.add_subplot(gs[0, 0])
-            cax = fig.add_subplot(gs[0, 1])
-            im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
-                           extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
-            fig.colorbar(im, cax=cax, orientation='vertical')
-
-
-            times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
-            for id_idx in range(len(fish_freq)):
-                ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k', markersize=4)
-                rise_idx_oi = np.array(rise_idx[id_idx][
-                                           (rise_idx[id_idx] >= times_v_idx0) &
-                                           (rise_idx[id_idx] <= times_v_idx1) &
-                                           (rise_size[id_idx] >= 10)], dtype=int)
-                rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
-                                                 (rise_idx[id_idx] <= times_v_idx1) &
-                                                 (rise_size[id_idx] >= 10)]
-
-                ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red')
-
-                if len(rise_idx_oi) > 0:
-                    closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
-                    closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]
-
-                    upper_freq_bound = closest_baseline_freq + rise_size_oi
-                    lower_freq_bound = closest_baseline_freq
-
-                    left_time_bound = times_v[rise_idx_oi]
-                    right_time_bound = np.zeros_like(left_time_bound)
-
-                    for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
-                        Crise_size = rise_size_oi[enu]
-                        Cblf = closest_baseline_freq[enu]
-
-                        rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
-                        if len(rise_end_t) == 0:
-                            right_time_bound[enu] = np.nan
-                        else:
-                            right_time_bound[enu] = rise_end_t[0]
-
-                    dt_bbox = right_time_bound - left_time_bound
-                    df_bbox = upper_freq_bound - lower_freq_bound
-                    left_time_bound -= dt_bbox*0.1
-                    right_time_bound += dt_bbox*0.1
-                    lower_freq_bound -= df_bbox*0.1
-                    upper_freq_bound += df_bbox*0.1
-
-                    print(f'f0: {lower_freq_bound}')
-                    print(f'f1: {upper_freq_bound}')
-                    print(f't0: {left_time_bound}')
-                    print(f't1: {right_time_bound}')
-
-                    for enu in range(len(left_time_bound)):
-                        if np.isnan(right_time_bound[enu]):
-                            continue
-                        ax.add_patch(
-                            Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
-                                      (right_time_bound[enu] - left_time_bound[enu]),
-                                      (upper_freq_bound[enu] - lower_freq_bound[enu]),
-                                      fill=False, color="white", linewidth=2, zorder=10)
-                        )
-            plt.show()
-
+            bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',')

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
datasets.py (new file, 105 lines added)

@@ -0,0 +1,105 @@
+import os
+import glob
+
+import torch
+import torchvision
+import torchvision.transforms.functional as F
+from torch.utils.data import Dataset, DataLoader
+
+import numpy as np
+import matplotlib.pyplot as plt
+import pandas as pd
+from matplotlib.patches import Rectangle
+from pathlib import Path
+from tqdm.auto import tqdm
+from PIL import Image
+
+from confic import (CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE)
+from custom_utils import collate_fn
+
+from IPython import embed
+
+class CustomDataset(Dataset):
+    def __init__(self, dir_path, use_idxs=None):
+        self.dir_path = dir_path
+        self.image_paths = glob.glob(f'{self.dir_path}/*.png')
+        self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths]
+        self.all_images = np.array(sorted(self.all_images), dtype=str)
+        if hasattr(use_idxs, '__len__'):
+            self.all_images = self.all_images[use_idxs]
+        self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0)
+
+    def __getitem__(self, idx):
+        image_name = self.all_images[idx]
+        image_path = os.path.join(self.dir_path, image_name)
+
+        img = Image.open(image_path)
+        img_tensor = F.to_tensor(img.convert('RGB'))
+
+        Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]
+
+        labels = np.ones(len(Cbbox), dtype=int)
+        boxes = Cbbox.loc[:, ['x0', 'x1', 'y0', 'y1']].values
+
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = labels
+
+        return img_tensor, target
+
+    def __len__(self):
+        return len(self.all_images)
+
+def create_train_test_dataset(path, test_size=0.2):
+    files = glob.glob(os.path.join(path, '*.png'))
+    train_test_idx = np.arange(len(files), dtype=int)
+    np.random.shuffle(train_test_idx)
+
+    train_idx = train_test_idx[int(test_size*len(train_test_idx)):]
+    test_idx = train_test_idx[:int(test_size*len(train_test_idx))]
+
+    train_data = CustomDataset(path, use_idxs=train_idx)
+    test_data = CustomDataset(path, use_idxs=test_idx)
+
+    return train_data, test_data
+
+def create_train_loader(train_dataset, num_workers=0):
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        num_workers=num_workers,
+        collate_fn=collate_fn
+    )
+    return train_loader
+def create_valid_loader(valid_dataset, num_workers=0):
+    valid_loader = DataLoader(
+        valid_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        num_workers=num_workers,
+        collate_fn=collate_fn
+    )
+    return valid_loader
+
+
+if __name__ == '__main__':
+
+    train_data, test_data = create_train_test_dataset(TRAIN_DIR)
+
+    train_loader = create_train_loader(train_data)
+    test_loader = create_valid_loader(test_data)
+
+    for samples, targets in train_loader:
+        for s, t in zip(samples, targets):
+            fig, ax = plt.subplots()
+            ax.imshow(s.permute(1, 2, 0), aspect='auto')
+            for (x0, x1, y0, y1), l in zip(t['boxes'], t['labels']):
+                print(x0, x1, y0, y1, l)
+                ax.add_patch(
+                    Rectangle((x0, y0),
+                              (x1 - x0),
+                              (y1 - y0),
+                              fill=False, color="white", linewidth=2, zorder=10)
+                )
+            plt.show()
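Not part of this commit, but relevant for the next step: torchvision's detection models expect target['boxes'] as a float tensor in (x_min, y_min, x_max, y_max) order and target['labels'] as an int64 tensor, whereas __getitem__ above currently returns numpy arrays in (x0, x1, y0, y1) order. A possible conversion, with a hypothetical helper name, could look like this:

    # Hedged sketch (assumption, not in this commit): convert one image's csv rows
    # into the tensor layout torchvision's detection models expect.
    import torch
    import pandas as pd

    def df_to_target(Cbbox: pd.DataFrame) -> dict:
        # The csv stores x0 < x1 and y0 < y1 in pixel coordinates, so reordering the
        # columns to (x_min, y_min, x_max, y_max) is sufficient here.
        boxes = Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].to_numpy(dtype='float32')
        return {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32),
            "labels": torch.ones(len(Cbbox), dtype=torch.int64),  # single foreground class
        }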