From 30a9e71e765b9ba8a1600fe94bd50be8d05c0da4 Mon Sep 17 00:00:00 2001 From: Till Raab Date: Mon, 23 Oct 2023 15:26:39 +0200 Subject: [PATCH] training loop runs ... no feed it --- data/generate_dataset.py | 11 +++++++++-- datasets.py | 12 ++++++++++-- model.py | 2 +- train.py | 42 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 61 insertions(+), 6 deletions(-) diff --git a/data/generate_dataset.py b/data/generate_dataset.py index 08a4670..8b59eff 100644 --- a/data/generate_dataset.py +++ b/data/generate_dataset.py @@ -99,12 +99,19 @@ def bboxes_from_file(times_v, fish_freq, rise_idx, rise_size, fish_baseline_freq Crise_size = rise_size_oi[enu] Cblf = closest_baseline_freq[enu] - rise_end_t = times_v[(times_v > Ct_oi) & (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] + rise_end_t = times_v[(times_v > Ct_oi) & + (fish_freq[id_idx] < Cblf + Crise_size * 0.37)] if len(rise_end_t) == 0: right_time_bound[enu] = np.nan else: right_time_bound[enu] = rise_end_t[0] + mask = (~np.isnan(right_time_bound) & ((right_time_bound - left_time_bound) > 1.)) + left_time_bound = left_time_bound[mask] + right_time_bound = right_time_bound[mask] + lower_freq_bound = lower_freq_bound[mask] + upper_freq_bound = upper_freq_bound[mask] + dt_bbox = right_time_bound - left_time_bound df_bbox = upper_freq_bound - lower_freq_bound left_time_bound -= dt_bbox * 0.1 @@ -154,7 +161,7 @@ def main(args): for f in pd.unique(bbox_df['image']): eval_files.append(f.split('__')[0]) - folders = [args.folders] + folders = [args.folder] for enu, folder in enumerate(folders): print(f'DataSet generation from {folder} | {enu+1}/{len(folders)}') diff --git a/datasets.py b/datasets.py index 43ef3df..a9b8617 100644 --- a/datasets.py +++ b/datasets.py @@ -39,11 +39,19 @@ class CustomDataset(Dataset): Cbbox = self.bbox_df[self.bbox_df['image'] == image_name] labels = np.ones(len(Cbbox), dtype=int) - boxes = Cbbox.loc[:, ['x0', 'x1', 'y0', 'y1']].values + boxes = torch.as_tensor(Cbbox.loc[:, ['x0', 'y0', 'x1', 'y1']].values, dtype=torch.float32) + + area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) + # no crowd instances + iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64) target = {} target["boxes"] = boxes - target["labels"] = labels + target["labels"] = torch.as_tensor(labels, dtype=torch.int64) + target["area"] = area + target["iscrowd"] = iscrowd + image_id = torch.tensor([idx]) + target["image_id"] = image_id return img_tensor, target diff --git a/model.py b/model.py index 9804b6e..4b67f7a 100644 --- a/model.py +++ b/model.py @@ -23,6 +23,6 @@ def create_model(num_classes: int) -> torch.nn.Module: in_features = model.roi_heads.box_predictor.cls_score.in_features - model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1) return model \ No newline at end of file diff --git a/train.py b/train.py index f9891fb..ff30473 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,43 @@ -from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS) +from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, TRAIN_DIR) from model import create_model +from tqdm.auto import tqdm +from datasets import create_train_test_dataset, create_train_loader, create_valid_loader + +import torch +import matplotlib.pyplot as plt +import time + +from IPython import embed + +if __name__ == '__main__': + train_data, test_data = create_train_test_dataset(TRAIN_DIR) + train_loader = create_train_loader(train_data) + test_loader = create_train_loader(test_data) + + model = create_model(num_classes=1) + model = model.to(DEVICE) + + params = [p for p in model.parameters() if p.requires_grad] + optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005) + + for epoch in range(NUM_EPOCHS): + prog_bar = tqdm(train_loader, total=len(train_loader)) + for samples, targets in prog_bar: + images = list(image.to(DEVICE) for image in samples) + + targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] + try: + loss_dict = model(images, targets) + except: + embed() + quit() + + losses = sum(loss for loss in loss_dict.values()) + loss_value = losses.item() + + optimizer.zero_grad() + losses.backward() + optimizer.step() + + prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") \ No newline at end of file