A lot of work. Dataset and loader not yet created.
This commit is contained in:
parent 5ca527e39f
commit 3f66b4a39a

custom_utils.py (Normal file, +6)
							| @ -0,0 +1,6 @@ | ||||
| def collate_fn(batch): | ||||
|     """ | ||||
|     To handle the data loading as different images may have different number | ||||
|     of objects and to handle varying size tensors as well. | ||||
|     """ | ||||
|     return tuple(zip(*batch)) | ||||
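For context, a minimal self-contained sketch (toy data, not part of this commit) of why this collate_fn is needed: the default DataLoader collation tries to stack all targets into a single tensor, which fails once images carry different numbers of boxes, while tuple(zip(*batch)) keeps per-image (image, target) pairs as parallel tuples.

    # Minimal sketch (toy data, not part of this commit) of the custom collate_fn.
    import torch
    from torch.utils.data import DataLoader, Dataset

    def collate_fn(batch):
        # same behaviour as custom_utils.collate_fn
        return tuple(zip(*batch))

    class ToyDetectionSet(Dataset):
        """Hypothetical dataset: each sample carries a different number of boxes."""
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            img = torch.rand(3, 64, 64)
            n_boxes = idx + 1  # varying number of objects per image
            target = {"boxes": torch.rand(n_boxes, 4),
                      "labels": torch.ones(n_boxes, dtype=torch.int64)}
            return img, target

    loader = DataLoader(ToyDetectionSet(), batch_size=2, collate_fn=collate_fn)
    for images, targets in loader:
        # images and targets stay per-image tuples of length batch_size
        print(len(images), [t["boxes"].shape for t in targets])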
| @ -9,6 +9,7 @@ import torchvision.transforms as T | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.gridspec as gridspec | ||||
| from pathlib import Path | ||||
| import pandas as pd | ||||
| 
 | ||||
| from tqdm.auto import tqdm | ||||
| 
 | ||||
| @ -55,7 +56,9 @@ def load_data(folder): | ||||
|     return fill_freqs, fill_times, fill_spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time | ||||
| 
 | ||||
| def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res): | ||||
|     fig_title = (f'{Path(folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0') | ||||
|     size = (7, 7) | ||||
|     dpi = 256 | ||||
|     fig_title = (f'{Path(folder).name}__{times[t_idx0]:.0f}s-{times[t_idx1]:.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0') | ||||
|     fig = plt.figure(figsize=(7, 7), num=fig_title) | ||||
|     gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0)  # , bottom=0, left=0, right=1, top=1 | ||||
|     gs2 = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)  # | ||||
| @ -64,9 +67,72 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, | ||||
|                    extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res)) | ||||
|     ax.axis(False) | ||||
| 
 | ||||
|     plt.savefig(os.path.join('train', fig_title + '.png'), dpi=256) | ||||
|     plt.savefig(os.path.join('train', fig_title), dpi=256) | ||||
|     plt.close() | ||||
| 
 | ||||
|     return fig_title, (size[0]*dpi, size[1]*dpi) | ||||
| 
 | ||||
| def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str, | ||||
|                      bbox_df, cols, width, height, t0, t1, f0, f1): | ||||
| 
 | ||||
|     times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1)) | ||||
|     for id_idx in range(len(fish_freq)): | ||||
|         rise_idx_oi = np.array(rise_idx[id_idx][ | ||||
|                                    (rise_idx[id_idx] >= times_v_idx0) & | ||||
|                                    (rise_idx[id_idx] <= times_v_idx1) & | ||||
|                                    (rise_size[id_idx] >= 10)], dtype=int) | ||||
|         rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) & | ||||
|                                          (rise_idx[id_idx] <= times_v_idx1) & | ||||
|                                          (rise_size[id_idx] >= 10)] | ||||
|         if len(rise_idx_oi) == 0: | ||||
|             continue | ||||
|         closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi])) | ||||
|         closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx] | ||||
| 
 | ||||
|         upper_freq_bound = closest_baseline_freq + rise_size_oi | ||||
|         lower_freq_bound = closest_baseline_freq | ||||
| 
 | ||||
|         left_time_bound = times_v[rise_idx_oi] | ||||
|         right_time_bound = np.zeros_like(left_time_bound) | ||||
| 
 | ||||
|         for enu, Ct_oi in enumerate(times_v[rise_idx_oi]): | ||||
|             Crise_size = rise_size_oi[enu] | ||||
|             Cblf = closest_baseline_freq[enu] | ||||
| 
 | ||||
|             rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] | ||||
|             if len(rise_end_t) == 0: | ||||
|                 right_time_bound[enu] = np.nan | ||||
|             else: | ||||
|                 right_time_bound[enu] = rise_end_t[0] | ||||
| 
 | ||||
|         dt_bbox = right_time_bound - left_time_bound | ||||
|         df_bbox = upper_freq_bound - lower_freq_bound | ||||
|         left_time_bound -= dt_bbox * 0.1 | ||||
|         right_time_bound += dt_bbox * 0.1 | ||||
|         lower_freq_bound -= df_bbox * 0.1 | ||||
|         upper_freq_bound += df_bbox * 0.1 | ||||
| 
 | ||||
|         x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int) | ||||
|         x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int) | ||||
|         y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int) | ||||
|         y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int) | ||||
| 
 | ||||
|         bbox = np.array([[pic_save_str for i in range(len(left_time_bound))], | ||||
|                          left_time_bound, | ||||
|                          right_time_bound, | ||||
|                          lower_freq_bound, | ||||
|                          upper_freq_bound, | ||||
|                          x0, x1, y0, y1]) | ||||
|         # test_s = ['a', 'a', 'a', 'a'] | ||||
|         tmp_df = pd.DataFrame( | ||||
|             # index= [pic_save_str for i in range(len(left_time_bound))], | ||||
|             # index= test_s, | ||||
|             data=bbox.T, | ||||
|             columns=cols | ||||
|         ) | ||||
|         bbox_df = pd.concat([bbox_df, tmp_df], ignore_index=True) | ||||
|         # bbox_df.append(tmp_df) | ||||
|     return bbox_df | ||||
| 
 | ||||
| def main(args): | ||||
|     min_freq = 200 | ||||
| @ -76,11 +142,26 @@ def main(args): | ||||
|     d_time = 60*15 | ||||
|     time_overlap = 60*5 | ||||
| 
 | ||||
|     freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = ( | ||||
|         load_data(args.folder)) | ||||
|     f_res, t_res = freq[1] - freq[0], times[1] - times[0] | ||||
|     if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')): | ||||
|         cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'x1', 'y0', 'y1'] | ||||
|         bbox_df = pd.DataFrame(columns=cols) | ||||
| 
 | ||||
|     unique_ids = np.unique(ident_v[~np.isnan(ident_v)]) | ||||
|     else: | ||||
|         bbox_df = pd.read_csv(os.path.join('train', 'bbox_dataset.csv'), sep=',', index_col=0) | ||||
|         cols = list(bbox_df.keys()) | ||||
|         eval_files = [] | ||||
|         # ToDo: make sure not same file twice | ||||
|         for f in pd.unique(bbox_df['image']): | ||||
|             eval_files.append(f.split('__')[0]) | ||||
| 
 | ||||
|     folders = [args.folders] | ||||
| 
 | ||||
|     for enu, folder in enumerate(folders): | ||||
|         print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}') | ||||
| 
 | ||||
|         freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = ( | ||||
|             load_data(folder)) | ||||
|         f_res, t_res = freq[1] - freq[0], times[1] - times[0] | ||||
| 
 | ||||
|         pic_base = tqdm(itertools.product( | ||||
|             np.arange(0, times[-1], d_time), | ||||
| @ -110,8 +191,11 @@ def main(args): | ||||
|             s_trans = transformed(log_s.unsqueeze(0)) | ||||
| 
 | ||||
|             if not args.dev: | ||||
|             save_spec_pic(args.folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res) | ||||
|                 pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res) | ||||
| 
 | ||||
|                 bbox_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, | ||||
|                                            fish_baseline_freq_time, fish_baseline_freq, | ||||
|                                            pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1) | ||||
|             else: | ||||
|                 fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0') | ||||
|                 fig = plt.figure(figsize=(10, 7), num=fig_title) | ||||
| @ -179,6 +263,8 @@ def main(args): | ||||
|                             ) | ||||
|                 plt.show() | ||||
| 
 | ||||
|         if not args.dev: | ||||
|             bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',') | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') | ||||
|  | ||||
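For reference, a small worked example (all numbers made up, not from this commit) of the time/frequency-to-pixel mapping applied in bboxes_from_file: time maps linearly onto image columns, while frequency is flipped into row indices because row 0 is the top of the saved spectrogram image; width and height come from save_spec_pic as figure size times dpi (7 * 256).

    # Worked example (example values) of the time/frequency -> pixel mapping
    # used in bboxes_from_file.
    import numpy as np

    width, height = 7 * 256, 7 * 256             # pixels: figure size * dpi
    t0, t1 = 0.0, 900.0                          # 15 min time window (s), example
    f0, f1 = 400.0, 800.0                        # frequency window (Hz), example

    left_time_bound = np.array([100.0, 500.0])   # rise onsets (s)
    right_time_bound = np.array([160.0, 620.0])  # rise ends (s)
    lower_freq_bound = np.array([620.0, 700.0])  # baseline freq (Hz)
    upper_freq_bound = np.array([650.0, 740.0])  # baseline + rise size (Hz)

    # time scales linearly onto columns; frequency is flipped into rows,
    # because row 0 is the top (highest frequency) of the saved image
    x0 = ((left_time_bound - t0) / (t1 - t0) * width).astype(int)
    x1 = ((right_time_bound - t0) / (t1 - t0) * width).astype(int)
    y0 = ((1 - (upper_freq_bound - f0) / (f1 - f0)) * height).astype(int)
    y1 = ((1 - (lower_freq_bound - f0) / (f1 - f0)) * height).astype(int)
    print(list(zip(x0, x1, y0, y1)))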
							
								
								
									
datasets.py (Normal file, +105)
							| @ -0,0 +1,105 @@ | ||||
| import os | ||||
| import glob | ||||
| 
 | ||||
| import torch | ||||
| import torchvision | ||||
| import torchvision.transforms.functional as F | ||||
| from torch.utils.data import Dataset, DataLoader | ||||
| 
 | ||||
| import numpy as np | ||||
| import matplotlib.pyplot as plt | ||||
| import pandas as pd | ||||
| from matplotlib.patches import Rectangle | ||||
| from pathlib import Path | ||||
| from tqdm.auto import tqdm | ||||
| from PIL import Image | ||||
| 
 | ||||
| from confic import (CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE) | ||||
| from custom_utils import collate_fn | ||||
| 
 | ||||
| from IPython import embed | ||||
| 
 | ||||
| class CustomDataset(Dataset): | ||||
|     def __init__(self, dir_path, use_idxs = None): | ||||
|         self.dir_path = dir_path | ||||
|         self.image_paths = glob.glob(f'{self.dir_path}/*.png') | ||||
|         self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths] | ||||
|         self.all_images = np.array(sorted(self.all_images), dtype=str) | ||||
|         if hasattr(use_idxs, '__len__'): | ||||
|             self.all_images = self.all_images[use_idxs] | ||||
|         self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0) | ||||
| 
 | ||||
|     def __getitem__(self, idx): | ||||
|         image_name = self.all_images[idx] | ||||
|         image_path = os.path.join(self.dir_path, image_name) | ||||
| 
 | ||||
|         img = Image.open(image_path) | ||||
|         img_tensor = F.to_tensor(img.convert('RGB')) | ||||
| 
 | ||||
|         Cbbox = self.bbox_df[self.bbox_df['image'] == image_name] | ||||
| 
 | ||||
|         labels = np.ones(len(Cbbox), dtype=int) | ||||
|         boxes = Cbbox.loc[:, ['x0', 'x1', 'y0', 'y1']].values | ||||
| 
 | ||||
|         target = {} | ||||
|         target["boxes"] = boxes | ||||
|         target["labels"] = labels | ||||
| 
 | ||||
|         return img_tensor, target | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return len(self.all_images) | ||||
| 
 | ||||
| def create_train_test_dataset(path, test_size=0.2): | ||||
|     files = glob.glob(os.path.join(path, '*.png')) | ||||
|     train_test_idx = np.arange(len(files), dtype=int) | ||||
|     np.random.shuffle(train_test_idx) | ||||
| 
 | ||||
|     train_idx = train_test_idx[int(test_size*len(train_test_idx)):] | ||||
|     test_idx = train_test_idx[:int(test_size*len(train_test_idx))] | ||||
| 
 | ||||
|     train_data = CustomDataset(path, use_idxs=train_idx) | ||||
|     test_data = CustomDataset(path, use_idxs=test_idx) | ||||
| 
 | ||||
|     return train_data, test_data | ||||
| 
 | ||||
| def create_train_loader(train_dataset, num_workers=0): | ||||
|     train_loader = DataLoader( | ||||
|         train_dataset, | ||||
|         batch_size=BATCH_SIZE, | ||||
|         shuffle=True, | ||||
|         num_workers=num_workers, | ||||
|         collate_fn=collate_fn | ||||
|     ) | ||||
|     return train_loader | ||||
| def create_valid_loader(valid_dataset, num_workers=0): | ||||
|     valid_loader = DataLoader( | ||||
|         valid_dataset, | ||||
|         batch_size=BATCH_SIZE, | ||||
|         shuffle=False, | ||||
|         num_workers=num_workers, | ||||
|         collate_fn=collate_fn | ||||
|     ) | ||||
|     return valid_loader | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
| 
 | ||||
|     train_data, test_data = create_train_test_dataset(TRAIN_DIR) | ||||
| 
 | ||||
|     train_loader = create_train_loader(train_data) | ||||
|     test_loader = create_valid_loader(test_data) | ||||
| 
 | ||||
|     for samples, targets in train_loader: | ||||
|         for s, t in zip(samples, targets): | ||||
|             fig, ax = plt.subplots() | ||||
|             ax.imshow(s.permute(1, 2, 0), aspect='auto') | ||||
|             for (x0, x1, y0, y1), l in zip(t['boxes'], t['labels']): | ||||
|                 print(x0, x1, y0, y1, l) | ||||
|                 ax.add_patch( | ||||
|                     Rectangle((x0, y0), | ||||
|                               (x1 - x0), | ||||
|                               (y1 - y0), | ||||
|                               fill=False, color="white", linewidth=2, zorder=10) | ||||
|                 ) | ||||
|             plt.show() | ||||
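The loaders above yield tuples of images and target dicts via collate_fn. Below is a minimal sketch (an assumption, not confirmed by this commit) of how such batches would typically feed a torchvision detection model such as fasterrcnn_resnet50_fpn. Note that torchvision expects per-image float tensors "boxes" ordered [x_min, y_min, x_max, y_max] and int64 "labels", whereas CustomDataset currently returns numpy arrays in [x0, x1, y0, y1] column order, so they would need converting and reordering first.

    # Sketch (assumption, not part of this commit) of feeding such batches to a
    # torchvision detection model (torchvision >= 0.13 API).
    import torch
    import torchvision

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, num_classes=2)
    model.train()

    # shapes/values as the custom collate_fn would deliver them (example data)
    images = [torch.rand(3, 600, 600), torch.rand(3, 600, 600)]
    targets = [
        {"boxes": torch.tensor([[10., 20., 200., 300.]]), "labels": torch.tensor([1])},
        {"boxes": torch.tensor([[50., 60., 120., 220.]]), "labels": torch.tensor([1])},
    ]
    loss_dict = model(images, targets)  # dict of RPN / ROI-head losses
    print({k: float(v) for k, v in loss_dict.items()})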