diff --git a/confic.py b/confic.py
index 634bd5e..3016472 100644
--- a/confic.py
+++ b/confic.py
@@ -1,7 +1,7 @@
 import torch
 import pathlib
 
-BATCH_SIZE = 4
+BATCH_SIZE = 32
 RESIZE_TO = 416
 NUM_EPOCHS = 20
 NUM_WORKERS = 4
diff --git a/data/generate_dataset.py b/data/generate_dataset.py
index 8b59eff..e3e0e91 100644
--- a/data/generate_dataset.py
+++ b/data/generate_dataset.py
@@ -114,10 +114,30 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
 
         dt_bbox = right_time_bound - left_time_bound
         df_bbox = upper_freq_bound - lower_freq_bound
-        left_time_bound -= dt_bbox * 0.1
-        right_time_bound += dt_bbox * 0.1
-        lower_freq_bound -= df_bbox * 0.1
-        upper_freq_bound += df_bbox * 0.1
+
+        # embed()
+        # quit()
+        # left_time_bound -= dt_bbox + 0.01 * (t1 - t0)
+        # right_time_bound += dt_bbox + 0.01 * (t1 - t0)
+        # lower_freq_bound -= df_bbox + 0.01 * (f1 - f0)
+        # upper_freq_bound += df_bbox + 0.01 * (f1 - f0)
+
+        left_time_bound -= 0.01 * (t1 - t0)
+        right_time_bound += 0.05 * (t1 - t0)
+        lower_freq_bound -= 0.01 * (f1 - f0)
+        upper_freq_bound += 0.05 * (f1 - f0)
+
+        # embed()
+        # quit()
+        mask2 = ((left_time_bound >= t0) &
+                (right_time_bound <= t1) &
+                (lower_freq_bound >= f0) &
+                (upper_freq_bound <= f1)
+        )
+        left_time_bound = left_time_bound[mask2]
+        right_time_bound = right_time_bound[mask2]
+        lower_freq_bound = lower_freq_bound[mask2]
+        upper_freq_bound = upper_freq_bound[mask2]
 
         x0 = np.array((left_time_bound - t0) / (t1 - t0) * width, dtype=int)
         x1 = np.array((right_time_bound - t0) / (t1 - t0) * width, dtype=int)
@@ -129,7 +149,7 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
                          right_time_bound,
                          lower_freq_bound,
                          upper_freq_bound,
-                         x0, x1, y0, y1])
+                         x0, y0, x1, y1])
         # test_s = ['a', 'a', 'a', 'a']
         tmp_df = pd.DataFrame(
             # index= [pic_save_str for i in range(len(left_time_bound))],
@@ -142,15 +162,53 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
     return bbox_df
 
 def main(args):
+    def development_fn():
+        fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
+        fig = plt.figure(figsize=(7, 7), num=fig_title)
+        gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9,
+                               top=0.95)  # , bottom=0, left=0, right=1, top=1
+        ax = fig.add_subplot(gs[0, 0])
+        cax = fig.add_subplot(gs[0, 1])
+        im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
+                       extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
+        fig.colorbar(im, cax=cax, orientation='vertical')
+
+
+        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
+        dev_df = pd.DataFrame(columns=cols)
+
+        dev_df = bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq_time, fish_baseline_freq,
+                                  fig_title, dev_df, cols, (7*256), (7*256), t0, t1, f0, f1)
+
+        # embed()
+        # quit()
+        time_freq_bbox = torch.as_tensor(dev_df.loc[:, ['t0', 'f0', 't1', 'f1']].values.astype(np.float32))
+
+        for bbox in time_freq_bbox:
+            Ct0, Cf0, Ct1, Cf1 = bbox
+            ax.add_patch(
+                Rectangle((Ct0, Cf0), Ct1-Ct0, Cf1-Cf0, fill=False, color="white", linewidth=2, zorder=10)
+            )
+        # for enu in range(len(left_time_bound)):
+        #     if np.isnan(right_time_bound[enu]):
+        #         continue
+        #     ax.add_patch(
+        #         Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
+        #                   (right_time_bound[enu] - left_time_bound[enu]),
+        #                   (upper_freq_bound[enu] - lower_freq_bound[enu]),
+        #                   fill=False, color="white", linewidth=2, zorder=10)
+        #     )
+        plt.show()
+
     min_freq = 200
     max_freq = 1500
     d_freq = 200
-    freq_overlap = 50
-    d_time = 60*15
-    time_overlap = 60*5
+    freq_overlap = 25
+    d_time = 60*10
+    time_overlap = 60*1
 
     if not os.path.exists(os.path.join('train', 'bbox_dataset.csv')):
-        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'x1', 'y0', 'y1']
+        cols = ['image', 't0', 't1', 'f0', 'f1', 'x0', 'y0', 'x1', 'y1']
         bbox_df = pd.DataFrame(columns=cols)
 
     else:
@@ -161,10 +219,15 @@ def main(args):
         for f in pd.unique(bbox_df['image']):
             eval_files.append(f.split('__')[0])
 
-    folders = [args.folder]
+    folders = list(f.parent for f in Path(args.folder).rglob('fill_times.npy'))
+
+    # embed()
+    # quit()
 
     for enu, folder in enumerate(folders):
         print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
+        if not (folder/'analysis'/'rise_idx.npy').exists():
+            continue
 
         freq, times, spec, EODf_v, ident_v, idx_v, times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq, fish_baseline_freq_time = (
             load_data(folder))
@@ -174,14 +237,15 @@ def main(args):
             np.arange(0, times[-1], d_time),
             np.arange(min_freq, max_freq, d_freq)
         ),
-            total=int(((max_freq-min_freq)//d_freq) * (times[-1] // d_time))
+            total=int((((max_freq-min_freq)//d_freq)+1) * ((times[-1] // d_time)+1))
         )
 
         for t0, f0 in pic_base:
+
             t1 = t0 + d_time + time_overlap
             f1 = f0 + d_freq + freq_overlap
 
-            present_freqs = EODf_v[(~np.isnan(ident_v)) &
+            present_freqs = EODf_v[(~np.isnan(ident_v))     &
                                    (t0 <= times_v[idx_v]) &
                                    (times_v[idx_v] <= t1) &
                                    (EODf_v >= f0) &
@@ -204,73 +268,10 @@ def main(args):
                                            fish_baseline_freq_time, fish_baseline_freq,
                                            pic_save_str, bbox_df, cols, width, height, t0, t1, f0, f1)
             else:
-                fig_title = (f'{Path(args.folder).name}__{t0:.0f}s-{t1:.0f}s__{f0:4.0f}-{f1:4.0f}Hz').replace(' ', '0')
-                fig = plt.figure(figsize=(10, 7), num=fig_title)
-                gs = gridspec.GridSpec(1, 2, width_ratios=(8, 1), wspace=0, left=0.1, bottom=0.1, right=0.9, top=0.95)  # , bottom=0, left=0, right=1, top=1
-                ax = fig.add_subplot(gs[0, 0])
-                cax = fig.add_subplot(gs[0, 1])
-                im = ax.imshow(s_trans.squeeze(), cmap='gray', aspect='auto', origin='lower',
-                               extent=(times[t_idx0], times[t_idx1] + t_res, freq[f_idx0], freq[f_idx1] + f_res))
-                fig.colorbar(im, cax=cax, orientation='vertical')
-
-
-                times_v_idx0, times_v_idx1 = np.argmin(np.abs(times_v - t0)), np.argmin(np.abs(times_v - t1))
-                for id_idx in range(len(fish_freq)):
-                    ax.plot(times_v[times_v_idx0:times_v_idx1], fish_freq[id_idx][times_v_idx0:times_v_idx1], marker='.', color='k', markersize=4)
-                    rise_idx_oi = np.array(rise_idx[id_idx][
-                                               (rise_idx[id_idx] >= times_v_idx0) &
-                                               (rise_idx[id_idx] <= times_v_idx1) &
-                                               (rise_size[id_idx] >= 10)], dtype=int)
-                    rise_size_oi = rise_size[id_idx][(rise_idx[id_idx] >= times_v_idx0) &
-                                                    (rise_idx[id_idx] <= times_v_idx1) &
-                                                    (rise_size[id_idx] >= 10)]
-
-                    ax.plot(times_v[rise_idx_oi], fish_freq[id_idx][rise_idx_oi], 'o', color='tab:red')
-
-                    if len(rise_idx_oi) > 0:
-                        closest_baseline_idx = list(map(lambda x: np.argmin(np.abs(fish_baseline_freq_time - x)), times_v[rise_idx_oi]))
-                        closest_baseline_freq = fish_baseline_freq[id_idx][closest_baseline_idx]
-
-                        upper_freq_bound = closest_baseline_freq + rise_size_oi
-                        lower_freq_bound = closest_baseline_freq
-
-                        left_time_bound = times_v[rise_idx_oi]
-                        right_time_bound = np.zeros_like(left_time_bound)
-
-                        for enu, Ct_oi in enumerate(times_v[rise_idx_oi]):
-                            Crise_size = rise_size_oi[enu]
-                            Cblf = closest_baseline_freq[enu]
-
-                            rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
-                            if len(rise_end_t) == 0:
-                                right_time_bound[enu] = np.nan
-                            else:
-                                right_time_bound[enu] = rise_end_t[0]
-
-                        dt_bbox = right_time_bound - left_time_bound
-                        df_bbox = upper_freq_bound - lower_freq_bound
-                        left_time_bound -= dt_bbox*0.1
-                        right_time_bound += dt_bbox*0.1
-                        lower_freq_bound -= df_bbox*0.1
-                        upper_freq_bound += df_bbox*0.1
-
-                        print(f'f0: {lower_freq_bound}')
-                        print(f'f1: {upper_freq_bound}')
-                        print(f't0: {left_time_bound}')
-                        print(f't1: {right_time_bound}')
-
-                        for enu in range(len(left_time_bound)):
-                            if np.isnan(right_time_bound[enu]):
-                                continue
-                            ax.add_patch(
-                                Rectangle((left_time_bound[enu], lower_freq_bound[enu]),
-                                                   (right_time_bound[enu] - left_time_bound[enu]),
-                                                   (upper_freq_bound[enu] - lower_freq_bound[enu]),
-                                                   fill=False, color="white", linewidth=2, zorder=10)
-                            )
-                plt.show()
+                development_fn()
 
         if not args.dev:
+            print('save')
             bbox_df.to_csv(os.path.join('train', 'bbox_dataset.csv'), columns=cols, sep=',')
 
 if __name__ == '__main__':
diff --git a/datasets.py b/datasets.py
index a9b8617..54e992c 100644
--- a/datasets.py
+++ b/datasets.py
@@ -102,8 +102,8 @@ if __name__ == '__main__':
         for s, t in zip(samples, targets):
             fig, ax = plt.subplots()
             ax.imshow(s.permute(1, 2, 0), aspect='auto')
-            for (x0, x1, y0, y1), l in zip(t['boxes'], t['labels']):
-                print(x0, x1, y0, y1, l)
+            for (x0, y0, x1, y1), l in zip(t['boxes'], t['labels']):
+                print(x0, y0, x1, y1, l)
                 ax.add_patch(
                     Rectangle((x0, y0),
                               (x1 - x0),
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..83a34c3
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,37 @@
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+import glob
+import os
+from PIL import Image
+
+from model import create_model
+from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR
+
+from IPython import embed
+from tqdm.auto import tqdm
+
+if __name__ == '__main__':
+    model = create_model(num_classes=NUM_CLASSES)
+    checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    model.to(DEVICE).eval()
+
+    DIR_TEST = 'data/train'
+    test_images = glob.glob(f"{DIR_TEST}/*.png")
+
+    detection_threshold = 0.8
+
+    frame_count = 0
+    total_fps = 0
+
+    for i in tqdm(np.arange(len(test_images))):
+        image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
+
+        img = Image.open(test_images[i])
+        img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
+
+        with torch.inference_mode():
+            outputs = model(img_tensor.to(DEVICE))
+
+        print(len(outputs[0]['boxes']))
\ No newline at end of file
diff --git a/train.py b/train.py
index 33ea54e..b46fc13 100644
--- a/train.py
+++ b/train.py
@@ -45,6 +45,8 @@ def validate(test_loader, model, val_loss):
 
         targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
 
+        embed()
+        quit()
         with torch.inference_mode():
             loss_dict = model(images, targets)