From 85e675fb48ffc33f1c2245ce061cd7b427dd0d3a Mon Sep 17 00:00:00 2001
From: Till Raab <till.raab@uni-tuebingen.de>
Date: Tue, 24 Oct 2023 08:56:35 +0200
Subject: [PATCH] refactor training loop; add loss tracking, model checkpointing, and loss plots

---
 confic.py       |  9 ++++--
 custom_utils.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++-
 model.py        |  2 +-
 train.py        | 81 +++++++++++++++++++++++++++++++++++++++----------
 4 files changed, 152 insertions(+), 20 deletions(-)

diff --git a/confic.py b/confic.py
index acb0397..634bd5e 100644
--- a/confic.py
+++ b/confic.py
@@ -1,8 +1,9 @@
 import torch
+import pathlib
 
 BATCH_SIZE = 4
 RESIZE_TO = 416
-NUM_EPOCHS = 10
+NUM_EPOCHS = 20
 NUM_WORKERS = 4
 
 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@@ -13,4 +14,8 @@ CLASSES = ['__backgroud__', '1']
 
 NUM_CLASSES = len(CLASSES)
 
-OUTDIR = 'model_outputs'
\ No newline at end of file
+OUTDIR = 'model_outputs'
+
+
+# create the model output directory if it does not exist yet
+pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True)
\ No newline at end of file
diff --git a/custom_utils.py b/custom_utils.py
index f5adad2..c470f6d 100644
--- a/custom_utils.py
+++ b/custom_utils.py
@@ -1,6 +1,84 @@
+import torch
+import matplotlib.pyplot as plt
+from confic import OUTDIR
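+
+
+# Running-average helper: send() accumulates values (here, per-batch losses),
+# value returns their mean, and reset() is called at the start of each epoch.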
+class Averager:
+    def __init__(self):
+        self.current_total = 0.0
+        self.iterations = 0.0
+
+    def send(self, value):
+        self.current_total += value
+        self.iterations += 1
+
+    @property
+    def value(self):
+        if self.iterations == 0:
+            return 0
+        else:
+            return 1.0 * self.current_total / self.iterations
+
+    def reset(self):
+        self.current_total = 0.0
+        self.iterations = 0.0
+
+
+class SaveBestModel:
+    """
+    Class to save the best model while training. If the current epoch's
+    validation loss is less than the previous least less, then save the
+    model state.
+    """
+
+    def __init__(
+            self, best_valid_loss=float('inf')
+    ):
+        self.best_valid_loss = best_valid_loss
+
+    def __call__(
+            self, current_valid_loss,
+            epoch, model, optimizer
+    ):
+        if current_valid_loss < self.best_valid_loss:
+            self.best_valid_loss = current_valid_loss
+            print(f"\nBest validation loss: {self.best_valid_loss}")
+            print(f"\nSaving best model for epoch: {epoch + 1}\n")
+            torch.save({
+                'epoch': epoch + 1,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+            }, f'./{OUTDIR}/best_model.pth')
+
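+# Intended usage (see train.py): create one instance before the epoch loop and
+# call it once per epoch with that epoch's validation loss:
+#   save_best_model = SaveBestModel()
+#   save_best_model(val_loss_hist.value, epoch, model, optimizer)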
+
 def collate_fn(batch):
     """
     To handle the data loading as different images may have different number
     of objects and to handle varying size tensors as well.
     """
-    return tuple(zip(*batch))
\ No newline at end of file
+    return tuple(zip(*batch))
+
+
+def save_model(epoch, model, optimizer):
+    """
+    Function to save the trained model up to the current epoch, or whenever called.
+    """
+    torch.save({
+                'epoch': epoch+1,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                }, f'./{OUTDIR}/last_model.pth')
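+
+# A checkpoint written this way can be restored later with, for example:
+#   checkpoint = torch.load(f'./{OUTDIR}/last_model.pth')
+#   model.load_state_dict(checkpoint['model_state_dict'])
+#   optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+#   start_epoch = checkpoint['epoch']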
+
+def save_loss_plot(OUT_DIR, train_loss, val_loss):
+    """Save per-iteration training and validation loss curves as separate PNGs."""
+    figure_1, train_ax = plt.subplots()
+    figure_2, valid_ax = plt.subplots()
+    train_ax.plot(train_loss, color='tab:blue')
+    train_ax.set_xlabel('iterations')
+    train_ax.set_ylabel('train loss')
+    valid_ax.plot(val_loss, color='tab:red')
+    valid_ax.set_xlabel('iterations')
+    valid_ax.set_ylabel('validation loss')
+    figure_1.savefig(f"{OUT_DIR}/train_loss.png")
+    figure_2.savefig(f"{OUT_DIR}/valid_loss.png")
+    print('SAVING PLOTS COMPLETE...')
+
+    plt.close('all')
+
diff --git a/model.py b/model.py
index 4b67f7a..9804b6e 100644
--- a/model.py
+++ b/model.py
@@ -23,6 +23,6 @@ def create_model(num_classes: int) -> torch.nn.Module:
 
     in_features = model.roi_heads.box_predictor.cls_score.in_features
 
-    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1)
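+    # num_classes passed from train.py is NUM_CLASSES = len(CLASSES), which
+    # already includes the '__backgroud__' background entry, so no extra +1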
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
 
     return model
\ No newline at end of file
diff --git a/train.py b/train.py
index ff30473..22fe661 100644
--- a/train.py
+++ b/train.py
@@ -2,7 +2,9 @@ from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, TRAIN_
 from model import create_model
 
 from tqdm.auto import tqdm
+
 from datasets import create_train_test_dataset, create_train_loader, create_valid_loader
+from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot
 
 import torch
 import matplotlib.pyplot as plt
@@ -10,34 +12,81 @@ import time
 
 from IPython import embed
 
+def train(train_loader, model, optimizer):
+    """Run one training epoch and return the list of all per-batch losses so far."""
+    print('Training')
+    global train_loss_list
+
+    prog_bar = tqdm(train_loader, total=len(train_loader))
+    for samples, targets in prog_bar:
+        images = list(image.to(DEVICE) for image in samples)
+
+        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+
+        loss_dict = model(images, targets)
+
+        losses = sum(loss for loss in loss_dict.values())
+        loss_value = losses.item()
+        train_loss_hist.send(loss_value)  # update the global per-epoch loss average
+        train_loss_list.append(loss_value)  # keep every batch loss for the loss plot
+
+        optimizer.zero_grad()
+        losses.backward()
+        optimizer.step()
+
+        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
+
+    return train_loss_list
+
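+# A validation pass is not implemented in this patch (its call below is still
+# commented out). A minimal sketch, assuming the global val_loss_hist and
+# val_loss_list and that the model stays in train mode so torchvision's
+# Faster R-CNN returns a loss dict, could look like:
+#
+# def validate(valid_loader, model):
+#     print('Validating')
+#     prog_bar = tqdm(valid_loader, total=len(valid_loader))
+#     for samples, targets in prog_bar:
+#         images = [image.to(DEVICE) for image in samples]
+#         targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+#         with torch.no_grad():
+#             loss_dict = model(images, targets)
+#         loss_value = sum(loss for loss in loss_dict.values()).item()
+#         val_loss_hist.send(loss_value)
+#         val_loss_list.append(loss_value)
+#         prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
+#     return val_loss_list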
+
 if __name__ == '__main__':
     train_data, test_data = create_train_test_dataset(TRAIN_DIR)
     train_loader = create_train_loader(train_data)
     test_loader = create_train_loader(test_data)
 
-    model = create_model(num_classes=1)
+    model = create_model(num_classes=NUM_CLASSES)
     model = model.to(DEVICE)
 
     params = [p for p in model.parameters() if p.requires_grad]
     optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
 
+    train_loss_hist = Averager()
+    val_loss_hist = Averager()
+    # train_itr = 1
+    # val_itr = 1
+    train_loss_list = []
+    val_loss_list = []
+
+    save_best_model = SaveBestModel()
+
     for epoch in range(NUM_EPOCHS):
-        prog_bar = tqdm(train_loader, total=len(train_loader))
-        for samples, targets in prog_bar:
-            images = list(image.to(DEVICE) for image in samples)
 
-            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
-            try:
-                loss_dict = model(images, targets)
-            except:
-                embed()
-                quit()
+        train_loss_hist.reset()
+        val_loss_hist.reset()
+
+        train_loss = train(train_loader, model, optimizer)
+        # val_loss = validate(train_loader, model, optimizer)
+
+        save_best_model(
+            val_loss_hist.value, epoch, model, optimizer
+        )
 
-            losses = sum(loss for loss in loss_dict.values())
-            loss_value = losses.item()
+        save_model(epoch, model, optimizer)
 
-            optimizer.zero_grad()
-            losses.backward()
-            optimizer.step()
+        save_loss_plot(OUTDIR, train_loss, val_loss_list)
 
-            prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
\ No newline at end of file