diff --git a/data/generate_dataset.py b/data/generate_dataset.py
index 08a4670..8b59eff 100644
--- a/data/generate_dataset.py
+++ b/data/generate_dataset.py
@@ -99,12 +99,19 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq
             Crise_size = rise_size_oi[enu]
             Cblf = closest_baseline_freq[enu]
 
-            rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
+            rise_end_t = times_v[(times_v > Ct_oi) &
+                                 (fish_freq[id_idx] < Cblf + Crise_size * 0.37)]
             if len(rise_end_t) == 0:
                 right_time_bound[enu] = np.nan
             else:
                 right_time_bound[enu] = rise_end_t[0]
 
+        mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.))
+        left_time_bound = left_time_bound[mask]
+        right_time_bound = right_time_bound[mask]
+        lower_freq_bound = lower_freq_bound[mask]
+        upper_freq_bound = upper_freq_bound[mask]
+
         dt_bbox = right_time_bound - left_time_bound
         df_bbox = upper_freq_bound - lower_freq_bound
         left_time_bound -= dt_bbox * 0.1
@@ -154,7 +161,7 @@ def main(args):
         for f in pd.unique(bbox_df['image']):
             eval_files.append(f.split('__')[0])
 
-    folders = [args.folders]
+    folders = [args.folder]
 
     for enu, folder in enumerate(folders):
         print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}')
diff --git a/datasets.py b/datasets.py
index 43ef3df..a9b8617 100644
--- a/datasets.py
+++ b/datasets.py
@@ -39,11 +39,19 @@ class CustomDataset(Dataset):
         Cbbox = self.bbox_df[self.bbox_df['image'] == image_name]
 
         labels = np.ones(len(Cbbox), dtype=int)
-        boxes = Cbbox.loc[:, ['x0', 'x1', 'y0', 'y1']].values
+        boxes = torch.as_tensor(Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].values, dtype=torch.float32)
+
+        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
+        # no crowd instances
+        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
 
         target = {}
         target["boxes"] = boxes
-        target["labels"] = labels
+        target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
+        target["area"] = area
+        target["iscrowd"] = iscrowd
+        image_id = torch.tensor([idx])
+        target["image_id"] = image_id
 
         return img_tensor, target
 
diff --git a/model.py b/model.py
index 9804b6e..4b67f7a 100644
--- a/model.py
+++ b/model.py
@@ -23,6 +23,6 @@ def create_model(num_classes: int) -> torch.nn.Module:
 
     in_features = model.roi_heads.box_predictor.cls_score.in_features
 
-    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1)
 
     return model
\ No newline at end of file
diff --git a/train.py b/train.py
index f9891fb..ff30473 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,43 @@
-from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS)
+from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, TRAIN_DIR)
 from model import create_model
 
+from tqdm.auto import tqdm
+from datasets import create_train_test_dataset, create_train_loader, create_valid_loader
+
+import torch
+import matplotlib.pyplot as plt
+import time
+
+from IPython import embed
+
+if __name__ == '__main__':
+    train_data, test_data = create_train_test_dataset(TRAIN_DIR)
+    train_loader = create_train_loader(train_data)
+    test_loader = create_train_loader(test_data)
+
+    model = create_model(num_classes=1)
+    model = model.to(DEVICE)
+
+    params = [p for p in model.parameters() if p.requires_grad]
+    optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
+
+    for epoch in range(NUM_EPOCHS):
+        prog_bar = tqdm(train_loader, total=len(train_loader))
+        for samples, targets in prog_bar:
+            images = list(image.to(DEVICE) for image in samples)
+
+            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+            try:
+                loss_dict = model(images, targets)
+            except:
+                embed()
+                quit()
+
+            losses = sum(loss for loss in loss_dict.values())
+            loss_value = losses.item()
+
+            optimizer.zero_grad()
+            losses.backward()
+            optimizer.step()
+
+            prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
\ No newline at end of file