diff --git a/custom_utils.py b/custom_utils.py new file mode 100644 index 0000000..f5adad2 --- /dev/null +++ b/custom_utils.py @@ -0,0 +1,6 @@ +def collate_fn(batch): + """ + To handle the data loading as different images may have different number + of objects and to handle varying size tensors as well. + """ + return tuple(zip(*batch)) \ No newline at end of file diff --git a/data/generate_dataset.py b/data/generate_dataset.py index f51e4fb..08a4670 100644 --- a/data/generate_dataset.py +++ b/data/generate_dataset.py @@ -9,6 +9,7 @@ import torchvision.transforms as T import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec from pathlib import Path +import pandas as pd from tqdm.auto import tqdm @@ -55,7 +56,9 @@ def load_data(folder): return fill_freqs, fill_times, fill_spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res): - fig_title = (f'{Path(folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0') + size = (7, 7) + dpi = 256 + fig_title = (f'{Path(folder).name}__{times[t_idx0]:.0f}s-{times[t_idx1]:.0f}s__{freq[f_idx0]:4.0f}-{freq[f_idx1]:4.0f}Hz.png').replace(' ', '0') fig = plt.figure(figsize=(7, 7), num=fig_title) gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0) # , bottom=0, left=0, right=1, top=1 gs2 = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) # @@ -64,9 +67,72 @@ def save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, extent=(times[t_idx0] / 3600, times[t_idx1] / 3600 + t_res, freq[f_idx0], freq[f_idx1] + f_res)) ax.axis(False) - plt.savefig(os.path.join('train', fig_title + '.png'), dpi=256) + plt.savefig(os.path.join('train', fig_title), dpi=256) plt.close() + return fig_title, (size[0]*dpi, size[1]*dpi) + +def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq, pic_save_str, + bbox_df, cols, width, height, t0, t1, f0, f1): + + times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1)) + for id_idx in range(len(fish_freq)): + rise_idx_oi = np.array(rise_idx[id_idx][ + (rise_idx[id_idx] >= times_v_idx0) & + (rise_idx[id_idx] <= times_v_idx1) & + (rise_size[id_idx] >= 10)], dtype=int) + rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) & + (rise_idx[id_idx] <= times_v_idx1) & + (rise_size[id_idx] >= 10)] + if len(rise_idx_oi) == 0: + continue + closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi])) + closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx] + + upper_freq_bound = closest_baseline_freq + rise_size_oi + lower_freq_bound = closest_baseline_freq + + left_time_bound = times_v[rise_idx_oi] + right_time_bound = np.zeros_like(left_time_bound) + + for enu, Ct_oi in enumerate(times_v[rise_idx_oi]): + Crise_size = rise_size_oi[enu] + Cblf = closest_baseline_freq[enu] + + rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] + if len(rise_end_t) == 0: + right_time_bound[enu] = np.nan + else: + right_time_bound[enu] = rise_end_t[0] + + dt_bbox = right_time_bound - left_time_bound + df_bbox = upper_freq_bound - lower_freq_bound + left_time_bound -= dt_bbox * 0.1 + right_time_bound += dt_bbox * 0.1 + lower_freq_bound -= df_bbox * 0.1 + upper_freq_bound += df_bbox * 0.1 + + x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int) + x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int) + y0 = np.array((1 - (upper_freq_bound - f0) / (f1 - f0)) * height, dtype=int) + y1 = np.array((1 - (lower_freq_bound - f0) / (f1 - f0)) * height, dtype=int) + + bbox = np.array([[pic_save_str for i in range(len(left_time_bound))], + left_time_bound, + right_time_bound, + lower_freq_bound, + upper_freq_bound, + x0, x1, y0, y1]) + # test_s = ['a', 'a', 'a', 'a'] + tmp_df = pd.DataFrame( + # index= [pic_save_str for i in range(len(left_time_bound))], + # index= test_s, + data=bbox.T, + columns=cols + ) + bbox_df = pd.concat([bbox_df, tmp_df], ignore_index=True) + # bbox_df.append(tmp_df) + return bbox_df def main(args): min_freq = 200 @@ -76,109 +142,129 @@ def main(args): d_time = 60*15 time_overlap = 60*5 - freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = ( - load_data(args.folder)) - f_res, t_res = freq[1] - freq[0], times[1] - times[0] - - unique_ids = np.unique(ident_v[~np.isnan(ident_v)]) - - pic_base = tqdm(itertools.product( - np.arange(0, times[-1], d_time), - np.arange(min_freq, max_freq, d_freq) - ), - total=int(((max_freq-min_freq)//d_freq) * (times[-1] // d_time)) - ) - - for t0, f0 in pic_base: - t1 = t0 + d_time + time_overlap - f1 = f0 + d_freq + freq_overlap - - present_freqs = EODf_v[(~np.isnan(ident_v)) & - (t0 <= times_v[idx_v]) & - (times_v[idx_v] <= t1) & - (EODf_v >= f0) & - (EODf_v <= f1)] - if len(present_freqs) == 0: - continue - - f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1)) - t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1)) - - s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32) - log_s = torch.log10(s) - transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s)) - s_trans = transformed(log_s.unsqueeze(0)) + if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')): + cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'x1', 'y0', 'y1'] + bbox_df = pd.DataFrame(columns=cols) + + else: + bbox_df = pd.read_csv(os.path.join('train', 'bbox_dataset.csv'), sep=',', index_col=0) + cols = list(bbox_df.keys()) + eval_files = [] + # ToDo: make sure not same file twice + for f in pd.unique(bbox_df['image']): + eval_files.append(f.split('__')[0]) + + folders = [args.folders] + + for enu, folder in enumerate(folders): + print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}') + + freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = ( + load_data(folder)) + f_res, t_res = freq[1] - freq[0], times[1] - times[0] + + pic_base = tqdm(itertools.product( + np.arange(0, times[-1], d_time), + np.arange(min_freq, max_freq, d_freq) + ), + total=int(((max_freq-min_freq)//d_freq) * (times[-1] // d_time)) + ) + + for t0, f0 in pic_base: + t1 = t0 + d_time + time_overlap + f1 = f0 + d_freq + freq_overlap + + present_freqs = EODf_v[(~np.isnan(ident_v)) & + (t0 <= times_v[idx_v]) & + (times_v[idx_v] <= t1) & + (EODf_v >= f0) & + (EODf_v <= f1)] + if len(present_freqs) == 0: + continue + + f_idx0, f_idx1 = np.argmin(np.abs(freq - f0)), np.argmin(np.abs(freq - f1)) + t_idx0, t_idx1 = np.argmin(np.abs(times - t0)), np.argmin(np.abs(times - t1)) + + s = torch.from_numpy(spec[f_idx0:f_idx1, t_idx0:t_idx1].copy()).type(torch.float32) + log_s = torch.log10(s) + transformed = T.Normalize(mean=torch.mean(log_s), std=torch.std(log_s)) + s_trans = transformed(log_s.unsqueeze(0)) + + if not args.dev: + pic_save_str, (width, height) = save_spec_pic(folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res) + + bbox_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, + fish_baseline_freq_time, fish_baseline_freq, + pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1) + else: + fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0') + fig = plt.figure(figsize=(10, 7), num=fig_title) + gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9, top=0.95) # , bottom=0, left=0, right=1, top=1 + ax = fig.add_subplot(gs[0, 0]) + cax = fig.add_subplot(gs[0, 1]) + im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower', + extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res)) + fig.colorbar(im, cax=cax, orientation='vertical') + + + times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1)) + for id_idx in range(len(fish_freq)): + ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k', markersize=4) + rise_idx_oi = np.array(rise_idx[id_idx][ + (rise_idx[id_idx] >= times_v_idx0) & + (rise_idx[id_idx] <= times_v_idx1) & + (rise_size[id_idx] >= 10)], dtype=int) + rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) & + (rise_idx[id_idx] <= times_v_idx1) & + (rise_size[id_idx] >= 10)] + + ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red') + + if len(rise_idx_oi) > 0: + closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi])) + closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx] + + upper_freq_bound = closest_baseline_freq + rise_size_oi + lower_freq_bound = closest_baseline_freq + + left_time_bound = times_v[rise_idx_oi] + right_time_bound = np.zeros_like(left_time_bound) + + for enu, Ct_oi in enumerate(times_v[rise_idx_oi]): + Crise_size = rise_size_oi[enu] + Cblf = closest_baseline_freq[enu] + + rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] + if len(rise_end_t) == 0: + right_time_bound[enu] = np.nan + else: + right_time_bound[enu] = rise_end_t[0] + + dt_bbox = right_time_bound - left_time_bound + df_bbox = upper_freq_bound - lower_freq_bound + left_time_bound -= dt_bbox*0.1 + right_time_bound += dt_bbox*0.1 + lower_freq_bound -= df_bbox*0.1 + upper_freq_bound += df_bbox*0.1 + + print(f'f0: {lower_freq_bound}') + print(f'f1: {upper_freq_bound}') + print(f't0: {left_time_bound}') + print(f't1: {right_time_bound}') + + for enu in range(len(left_time_bound)): + if np.isnan(right_time_bound[enu]): + continue + ax.add_patch( + Rectangle((left_time_bound[enu], lower_freq_bound[enu]), + (right_time_bound[enu] - left_time_bound[enu]), + (upper_freq_bound[enu] - lower_freq_bound[enu]), + fill=False, color="white", linewidth=2, zorder=10) + ) + plt.show() if not args.dev: - save_spec_pic(args.folder, s_trans, times, freq, t_idx0, t_idx1, f_idx0, f_idx1, t_res, f_res) - - else: - fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0') - fig = plt.figure(figsize=(10, 7), num=fig_title) - gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9, top=0.95) # , bottom=0, left=0, right=1, top=1 - ax = fig.add_subplot(gs[0, 0]) - cax = fig.add_subplot(gs[0, 1]) - im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower', - extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res)) - fig.colorbar(im, cax=cax, orientation='vertical') - - - times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1)) - for id_idx in range(len(fish_freq)): - ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k', markersize=4) - rise_idx_oi = np.array(rise_idx[id_idx][ - (rise_idx[id_idx] >= times_v_idx0) & - (rise_idx[id_idx] <= times_v_idx1) & - (rise_size[id_idx] >= 10)], dtype=int) - rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) & - (rise_idx[id_idx] <= times_v_idx1) & - (rise_size[id_idx] >= 10)] - - ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red') - - if len(rise_idx_oi) > 0: - closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi])) - closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx] - - upper_freq_bound = closest_baseline_freq + rise_size_oi - lower_freq_bound = closest_baseline_freq - - left_time_bound = times_v[rise_idx_oi] - right_time_bound = np.zeros_like(left_time_bound) - - for enu, Ct_oi in enumerate(times_v[rise_idx_oi]): - Crise_size = rise_size_oi[enu] - Cblf = closest_baseline_freq[enu] - - rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] - if len(rise_end_t) == 0: - right_time_bound[enu] = np.nan - else: - right_time_bound[enu] = rise_end_t[0] - - dt_bbox = right_time_bound - left_time_bound - df_bbox = upper_freq_bound - lower_freq_bound - left_time_bound -= dt_bbox*0.1 - right_time_bound += dt_bbox*0.1 - lower_freq_bound -= df_bbox*0.1 - upper_freq_bound += df_bbox*0.1 - - print(f'f0: {lower_freq_bound}') - print(f'f1: {upper_freq_bound}') - print(f't0: {left_time_bound}') - print(f't1: {right_time_bound}') - - for enu in range(len(left_time_bound)): - if np.isnan(right_time_bound[enu]): - continue - ax.add_patch( - Rectangle((left_time_bound[enu], lower_freq_bound[enu]), - (right_time_bound[enu] - left_time_bound[enu]), - (upper_freq_bound[enu] - lower_freq_bound[enu]), - fill=False, color="white", linewidth=2, zorder=10) - ) - plt.show() - + bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',') if __name__ == '__main__': parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') diff --git a/datasets.py b/datasets.py new file mode 100644 index 0000000..43ef3df --- /dev/null +++ b/datasets.py @@ -0,0 +1,105 @@ +import os +import glob + +import torch +import torchvision +import torchvision.transforms.functional as F +from torch.utils.data import Dataset, DataLoader + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.patches import Rectangle +from pathlib import Path +from tqdm.auto import tqdm +from PIL import Image + +from confic import (CLASSES, RESIZE_TO, TRAIN_DIR, BATCH_SIZE) +from custom_utils import collate_fn + +from IPython import embed + +class CustomDataset(Dataset): + def __init__(self, dir_path, use_idxs = None): + self.dir_path = dir_path + self.image_paths = glob.glob(f'{self.dir_path}/*.png') + self.all_images = [img_path.split(os.path.sep)[-1] for img_path in self.image_paths] + self.all_images = np.array(sorted(self.all_images), dtype=str) + if hasattr(use_idxs, '__len__'): + self.all_images = self.all_images[use_idxs] + self.bbox_df = pd.read_csv(os.path.join(dir_path, 'bbox_dataset.csv'), sep=',', index_col=0) + + def __getitem__(self, idx): + image_name = self.all_images[idx] + image_path = os.path.join(self.dir_path, image_name) + + img = Image.open(image_path) + img_tensor = F.to_tensor(img.convert('RGB')) + + Cbbox = self.bbox_df[self.bbox_df['image'] == image_name] + + labels = np.ones(len(Cbbox), dtype=int) + boxes = Cbbox.loc[:, ['x0', 'x1', 'y0', 'y1']].values + + target = {} + target["boxes"] = boxes + target["labels"] = labels + + return img_tensor, target + + def __len__(self): + return len(self.all_images) + +def create_train_test_dataset(path, test_size=0.2): + files = glob.glob(os.path.join(path, '*.png')) + train_test_idx = np.arange(len(files), dtype=int) + np.random.shuffle(train_test_idx) + + train_idx = train_test_idx[int(test_size*len(train_test_idx)):] + test_idx = train_test_idx[:int(test_size*len(train_test_idx))] + + train_data = CustomDataset(path, use_idxs=train_idx) + test_data = CustomDataset(path, use_idxs=test_idx) + + return train_data, test_data + +def create_train_loader(train_dataset, num_workers=0): + train_loader = DataLoader( + train_dataset, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=num_workers, + collate_fn=collate_fn + ) + return train_loader +def create_valid_loader(valid_dataset, num_workers=0): + valid_loader = DataLoader( + valid_dataset, + batch_size=BATCH_SIZE, + shuffle=False, + num_workers=num_workers, + collate_fn=collate_fn + ) + return valid_loader + + +if __name__ == '__main__': + + train_data, test_data = create_train_test_dataset(TRAIN_DIR) + + train_loader = create_train_loader(train_data) + test_loader = create_valid_loader(test_data) + + for samples, targets in train_loader: + for s, t in zip(samples, targets): + fig, ax = plt.subplots() + ax.imshow(s.permute(1, 2, 0), aspect='auto') + for (x0, x1, y0, y1), l in zip(t['boxes'], t['labels']): + print(x0, x1, y0, y1, l) + ax.add_patch( + Rectangle((x0, y0), + (x1 - x0), + (y1 - y0), + fill=False, color="white", linewidth=2, zorder=10) + ) + plt.show() \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..f9891fb --- /dev/null +++ b/train.py @@ -0,0 +1,3 @@ +from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS) +from model import create_model +