From 8a6e7df57f4e27d71a1597bcf1a5efa06f5cb417 Mon Sep 17 00:00:00 2001
From: Till Raab
Date: Tue, 24 Oct 2023 15:15:19 +0200
Subject: [PATCH] generated huge dataset

---
 confic.py                |   2 +-
 data/generate_dataset.py | 155 ++++++++++++++++++++-------------
 datasets.py              |   4 +-
 inference.py             |  37 ++++
 train.py                 |   2 +
 5 files changed, 120 insertions(+), 80 deletions(-)
 create mode 100644 inference.py

diff --git a/confic.py b/confic.py
index 634bd5e..3016472 100644
--- a/confic.py
+++ b/confic.py
@@ -1,7 +1,7 @@
 import torch
 import pathlib
 
-BATCH_SIZE = 4
+BATCH_SIZE = 32
 RESIZE_TO = 416
 NUM_EPOCHS = 20
 NUM_WORKERS = 4
diff --git a/data/generate_dataset.py b/data/generate_dataset.py
index 8b59eff..e3e0e91 100644
--- a/data/generate_dataset.py
+++ b/data/generate_dataset.py
@@ -114,10 +114,30 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
     dt_bbox = right_time_bound - left_time_bound
     df_bbox = upper_freq_bound - lower_freq_bound
 
-    left_time_bound -= dt_bbox * 0.1
-    right_time_bound += dt_bbox * 0.1
-    lower_freq_bound -= df_bbox * 0.1
-    upper_freq_bound += df_bbox * 0.1
+
+    # embed()
+    # quit()
+    # left_time_bound -= dt_bbox + 0.01 * (t1 - t0)
+    # right_time_bound += dt_bbox + 0.01 * (t1 - t0)
+    # lower_freq_bound -= df_bbox + 0.01 * (f1 - f0)
+    # upper_freq_bound += df_bbox + 0.01 * (f1 - f0)
+
+    left_time_bound -= 0.01 * (t1 - t0)
+    right_time_bound += 0.05 * (t1 - t0)
+    lower_freq_bound -= 0.01 * (f1 - f0)
+    upper_freq_bound += 0.05 * (f1 - f0)
+
+    # embed()
+    # quit()
+    mask2 = ((left_time_bound >= t0) &
+             (right_time_bound <= t1) &
+             (lower_freq_bound >= f0) &
+             (upper_freq_bound <= f1)
+             )
+    left_time_bound = left_time_bound[mask2]
+    right_time_bound = right_time_bound[mask2]
+    lower_freq_bound = lower_freq_bound[mask2]
+    upper_freq_bound = upper_freq_bound[mask2]
 
     x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int)
     x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int)
@@ -129,7 +149,7 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
                            right_time_bound,
                            lower_freq_bound,
                            upper_freq_bound,
-                           x0, x1, y0, y1])
+                           x0, y0, x1, y1])
     # test_s = ['a', 'a', 'a', 'a']
     tmp_df = pd.DataFrame(
         # index= [pic_save_str for i in range(len(left_time_bound))],
@@ -142,15 +162,53 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
     return bbox_df
 
 def main(args):
+    def development_fn():
+        fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
+        fig = plt.figure(figsize=(7, 7), num=fig_title)
+        gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9,
+                               top=0.95)  # , bottom=0, left=0, right=1, top=1
+        ax = fig.add_subplot(gs[0, 0])
+        cax = fig.add_subplot(gs[0, 1])
+        im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
+                       extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
+        fig.colorbar(im, cax=cax, orientation='vertical')
+
+
+        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
+        dev_df = pd.DataFrame(columns=cols)
+
+        dev_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq,
+                                  fig_title, dev_df, cols, (7*256), (7*256), t0, t1, f0, f1)
+
+        # embed()
+        # quit()
+        time_freq_bbox = torch.as_tensor(dev_df.loc[:, ['t0', 'f0', 't1', 'f1']].values.astype(np.float32))
+
+        for bbox in time_freq_bbox:
+            Ct0, Cf0, Ct1, Cf1 = bbox
+            ax.add_patch(
+                Rectangle((Ct0, Cf0), Ct1-Ct0, Cf1-Cf0, fill=False, color="white", linewidth=2, zorder=10)
+            )
+        # for enu in range(len(left_time_bound)):
+        #     if np.isnan(right_time_bound[enu]):
+        #         continue
+        #     ax.add_patch(
+        #         Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
+        #                   (right_time_bound[enu] - left_time_bound[enu]),
+        #                   (upper_freq_bound[enu] - lower_freq_bound[enu]),
+        #                   fill=False, color="white", linewidth=2, zorder=10)
+        #     )
+        plt.show()
+
     min_freq = 200
     max_freq = 1500
     d_freq = 200
-    freq_overlap = 50
-    d_time = 60*15
-    time_overlap = 60*5
+    freq_overlap = 25
+    d_time = 60*10
+    time_overlap = 60*1
 
     if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')):
-        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'x1', 'y0', 'y1']
+        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
         bbox_df = pd.DataFrame(columns=cols)
 
     else:
@@ -161,10 +219,15 @@ def main(args):
         for f in pd.unique(bbox_df['image']):
             eval_files.append(f.split('__')[0])
 
-    folders = [args.folder]
+    folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
+
+    # embed()
+    # quit()
 
     for enu, folder in enumerate(folders):
         print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
+        if not (folder/'analysis'/'rise_idx.npy').exists():
+            continue
         freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
             load_data(folder))
 
@@ -174,14 +237,15 @@ def main(args):
             np.arange(0, times[-1], d_time),
             np.arange(min_freq, max_freq, d_freq)
         ),
-            total=int(((max_freq-min_freq)//d_freq) * (times[-1] // d_time))
+            total=int((((max_freq-min_freq)//d_freq)+1) * ((times[-1] // d_time)+1))
         )
 
         for t0, f0 in pic_base:
+
             t1 = t0 + d_time + time_overlap
             f1 = f0 + d_freq + freq_overlap
 
-            present_freqs = EODf_v[(~np.isnan(ident_v)) &
+            present_freqs = EODf_v[(~np.isnan(ident_v)) &
                                    (t0 <= times_v[idx_v]) &
                                    (times_v[idx_v] <= t1) &
                                    (EODf_v >= f0) &
@@ -204,73 +268,10 @@ def main(args):
                                              fish_baseline_freq_time, fish_baseline_freq,
                                              pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)
             else:
-                fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
-                fig = plt.figure(figsize=(10, 7), num=fig_title)
-                gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9, top=0.95)  # , bottom=0, left=0, right=1, top=1
-                ax = fig.add_subplot(gs[0, 0])
-                cax = fig.add_subplot(gs[0, 1])
-                im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
-                               extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
-                fig.colorbar(im, cax=cax, orientation='vertical')
-
-
-                times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
-                for id_idx in range(len(fish_freq)):
-                    ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k', markersize=4)
-                    rise_idx_oi = np.array(rise_idx[id_idx][
-                                               (rise_idx[id_idx] >= times_v_idx0) &
-                                               (rise_idx[id_idx] <= times_v_idx1) &
-                                               (rise_size[id_idx] >= 10)], dtype=int)
-                    rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
-                                                     (rise_idx[id_idx] <= times_v_idx1) &
-                                                     (rise_size[id_idx] >= 10)]
-
-                    ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red')
-
-                    if len(rise_idx_oi) > 0:
-                        closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
-                        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]
-
-                        upper_freq_bound = closest_baseline_freq + rise_size_oi
-                        lower_freq_bound = closest_baseline_freq
-
-                        left_time_bound = times_v[rise_idx_oi]
-                        right_time_bound = np.zeros_like(left_time_bound)
-
-                        for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
-                            Crise_size = rise_size_oi[enu]
-                            Cblf = closest_baseline_freq[enu]
-
-                            rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
-                            if len(rise_end_t) == 0:
-                                right_time_bound[enu] = np.nan
-                            else:
-                                right_time_bound[enu] = rise_end_t[0]
-
-                        dt_bbox = right_time_bound - left_time_bound
-                        df_bbox = upper_freq_bound - lower_freq_bound
-                        left_time_bound -= dt_bbox*0.1
-                        right_time_bound += dt_bbox*0.1
-                        lower_freq_bound -= df_bbox*0.1
-                        upper_freq_bound += df_bbox*0.1
-
-                        print(f'f0: {lower_freq_bound}')
-                        print(f'f1: {upper_freq_bound}')
-                        print(f't0: {left_time_bound}')
-                        print(f't1: {right_time_bound}')
-
-                        for enu in range(len(left_time_bound)):
-                            if np.isnan(right_time_bound[enu]):
-                                continue
-                            ax.add_patch(
-                                Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
-                                          (right_time_bound[enu] - left_time_bound[enu]),
-                                          (upper_freq_bound[enu] - lower_freq_bound[enu]),
-                                          fill=False, color="white", linewidth=2, zorder=10)
-                            )
-                plt.show()
+                development_fn()
 
     if not args.dev:
+        print('save')
         bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',')
 
 if __name__ == '__main__':
diff --git a/datasets.py b/datasets.py
index a9b8617..54e992c 100644
--- a/datasets.py
+++ b/datasets.py
@@ -102,8 +102,8 @@ if __name__ == '__main__':
         for s, t in zip(samples, targets):
             fig, ax = plt.subplots()
             ax.imshow(s.permute(1, 2, 0), aspect='auto')
-            for (x0, x1, y0, y1), l in zip(t['boxes'], t['labels']):
-                print(x0, x1, y0, y1, l)
+            for (x0, y0, x1, y1), l in zip(t['boxes'], t['labels']):
+                print(x0, y0, x1, y1, l)
                 ax.add_patch(
                     Rectangle((x0, y0),
                               (x1 - x0),
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..83a34c3
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,37 @@
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+import glob
+import os
+from PIL import Image
+
+from model import create_model
+from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR
+
+from IPython import embed
+from tqdm.auto import tqdm
+
+if __name__ == '__main__':
+    model = create_model(num_classes=NUM_CLASSES)
+    checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    model.to(DEVICE).eval()
+
+    DIR_TEST = 'data/train'
+    test_images = glob.glob(f"{DIR_TEST}/*.png")
+
+    detection_threshold = 0.8
+
+    frame_count = 0
+    total_fps = 0
+
+    for i in tqdm(np.arange(len(test_images))):
+        image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
+
+        img = Image.open(test_images[i])
+        img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
+
+        with torch.inference_mode():
+            outputs = model(img_tensor.to(DEVICE))
+
+        print(len(outputs[0]['boxes']))
\ No newline at end of file
diff --git a/train.py b/train.py
index 33ea54e..b46fc13 100644
--- a/train.py
+++ b/train.py
@@ -45,6 +45,8 @@ def validate(test_loader, model, val_loss):
 
         targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
 
+        embed()
+        quit()
         with torch.inference_mode():
             loss_dict = model(images, targets)
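
Note on the bounding-box reordering: writing the CSV columns as (x0, y0, x1, y1)
in bboxes_from_file and unpacking them in the same order in datasets.py matches
the (xmin, ymin, xmax, ymax) box convention that torchvision detection models
expect; the old (x0, x1, y0, y1) order mixed coordinates when read back as boxes.
A small sanity check one could run over the generated CSV (a sketch; column names
and separator as written by this patch):

    import pandas as pd

    df = pd.read_csv('train/bbox_dataset.csv', sep=',')
    # boxes with swapped or collapsed corners indicate a column-order regression
    bad = df[(df['x1'] <= df['x0']) | (df['y1'] <= df['y0'])]
    print(f'{len(bad)} of {len(df)} boxes are degenerate')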
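Why the tqdm total gained the +1 terms: np.arange(0, stop, step) yields
ceil(stop / step) values, so plain floor division undercounts whenever stop is
not an exact multiple of step. A quick check:

    import numpy as np

    stop, step = 100.0, 15.0
    print(len(np.arange(0, stop, step)))  # 7 windows: 0, 15, ..., 90
    print(int(stop // step))              # 6 -- the old total ended the bar early

(floor + 1 overshoots by one when stop is an exact multiple of step, which is
harmless for a progress-bar total.)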
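The new inference.py defines detection_threshold = 0.8 but so far only prints the
raw box count per image. A sketch of how the threshold could be applied inside
the loop, assuming create_model returns a torchvision-style detector whose
eval-mode outputs carry 'boxes', 'scores' and 'labels':

    # after `outputs = model(img_tensor.to(DEVICE))`
    scores = outputs[0]['scores'].cpu().numpy()
    boxes = outputs[0]['boxes'].cpu().numpy()  # (N, 4) as (x0, y0, x1, y1)
    keep = scores >= detection_threshold
    print(f'{image_name}: {keep.sum()} of {len(boxes)} boxes above threshold')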