diff --git a/confic.py b/confic.py index acb0397..634bd5e 100644 --- a/confic.py +++ b/confic.py @@ -1,8 +1,9 @@ import torch +import pathlib BATCH_SIZE = 4 RESIZE_TO = 416 -NUM_EPOCHS = 10 +NUM_EPOCHS = 20 NUM_WORKERS = 4 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') @@ -13,4 +14,8 @@ CLASSES = ['__backgroud__', '1'] NUM_CLASSES = len(CLASSES) -OUTDIR = 'model_outputs' \ No newline at end of file +OUTDIR = 'model_outputs' + + +if not pathlib.Path(OUTDIR).exists(): + pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True) \ No newline at end of file diff --git a/custom_utils.py b/custom_utils.py index f5adad2..c470f6d 100644 --- a/custom_utils.py +++ b/custom_utils.py @@ -1,6 +1,84 @@ +import torch +import matplotlib.pyplot as plt +from confic import OUTDIR +class Averager: + def __init__(self): + self.current_total = 0.0 + self.iterations = 0.0 + + def send(self, value): + self.current_total += value + self.iterations += 1 + + @property + def value(self): + if self.iterations == 0: + return 0 + else: + return 1.0 * self.current_total / self.iterations + + def reset(self): + self.current_total = 0.0 + self.iterations = 0.0 + + +class SaveBestModel: + """ + Class to save the best model while training. If the current epoch's + validation loss is less than the previous least less, then save the + model state. + """ + + def __init__( + self, best_valid_loss=float('inf') + ): + self.best_valid_loss = best_valid_loss + + def __call__( + self, current_valid_loss, + epoch, model, optimizer + ): + if current_valid_loss < self.best_valid_loss: + self.best_valid_loss = current_valid_loss + print(f"\nBest validation loss: {self.best_valid_loss}") + print(f"\nSaving best model for epoch: {epoch + 1}\n") + torch.save({ + 'epoch': epoch + 1, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, f'./{OUTDIR}/best_model.pth') + + def collate_fn(batch): """ To handle the data loading as different images may have different number of objects and to handle varying size tensors as well. """ - return tuple(zip(*batch)) \ No newline at end of file + return tuple(zip(*batch)) + + +def save_model(epoch, model, optimizer): + """ + Function to save the trained model till current epoch, or whenver called + """ + torch.save({ + 'epoch': epoch+1, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, f'./{OUTDIR}/last_model.pth') + +def save_loss_plot(OUT_DIR, train_loss, val_loss): + figure_1, train_ax = plt.subplots() + figure_2, valid_ax = plt.subplots() + train_ax.plot(train_loss, color='tab:blue') + train_ax.set_xlabel('iterations') + train_ax.set_ylabel('train loss') + valid_ax.plot(val_loss, color='tab:red') + valid_ax.set_xlabel('iterations') + valid_ax.set_ylabel('validation loss') + figure_1.savefig(f"{OUT_DIR}/train_loss.png") + figure_2.savefig(f"{OUT_DIR}/valid_loss.png") + print('SAVING PLOTS COMPLETE...') + + plt.close('all') + diff --git a/model.py b/model.py index 4b67f7a..9804b6e 100644 --- a/model.py +++ b/model.py @@ -23,6 +23,6 @@ def create_model(num_classes: int) -> torch.nn.Module: in_features = model.roi_heads.box_predictor.cls_score.in_features - model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1) + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) return model \ No newline at end of file diff --git a/train.py b/train.py index ff30473..22fe661 100644 --- a/train.py +++ b/train.py @@ -2,7 +2,9 @@ from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, TRAIN_ from model import create_model from tqdm.auto import tqdm + from datasets import create_train_test_dataset, create_train_loader, create_valid_loader +from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot import torch import matplotlib.pyplot as plt @@ -10,34 +12,81 @@ import time from IPython import embed +def train(train_loader, model, optimizer): + print('Training') + global train_loss_list + + prog_bar = tqdm(train_loader, total=len(train_loader)) + for samples, targets in prog_bar: + images = list(image.to(DEVICE) for image in samples) + + targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] + + loss_dict = model(images, targets) + + losses = sum(loss for loss in loss_dict.values()) + loss_value = losses.item() + train_loss_hist.send(loss_value) # this is a global instance !!! + train_loss_list.append(loss_value) # check what exactly this does !!! + + optimizer.zero_grad() + losses.backward() + optimizer.step() + + prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") + + return train_loss_list + + if __name__ == '__main__': train_data, test_data = create_train_test_dataset(TRAIN_DIR) train_loader = create_train_loader(train_data) test_loader = create_train_loader(test_data) - model = create_model(num_classes=1) + model = create_model(num_classes=NUM_CLASSES) model = model.to(DEVICE) params = [p for p in model.parameters() if p.requires_grad] optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005) + train_loss_hist = Averager() + val_loss_hist = Averager() + # train_itr = 1 + # val_itr = 1 + train_loss_list = [] + val_loss_list = [] + + save_best_model = SaveBestModel() + for epoch in range(NUM_EPOCHS): - prog_bar = tqdm(train_loader, total=len(train_loader)) - for samples, targets in prog_bar: - images = list(image.to(DEVICE) for image in samples) - targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] - try: - loss_dict = model(images, targets) - except: - embed() - quit() + train_loss_hist.reset() + val_loss_hist.reset() + + train_loss = train(train_loader, model, optimizer) + # val_loss = validate(train_loader, model, optimizer) + + save_best_model( + val_loss_hist.value, epoch, model, optimizer + ) - losses = sum(loss for loss in loss_dict.values()) - loss_value = losses.item() + save_model(epoch, model, optimizer) - optimizer.zero_grad() - losses.backward() - optimizer.step() + save_loss_plot(OUTDIR, train_loss, val_loss) - prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") \ No newline at end of file + # prog_bar = tqdm(train_loader, total=len(train_loader)) + # for samples, targets in prog_bar: + # images = list(image.to(DEVICE) for image in samples) + # + # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] + # + # loss_dict = model(images, targets) + # + # losses = sum(loss for loss in loss_dict.values()) + # loss_value = losses.item() + # + # optimizer.zero_grad() + # losses.backward() + # optimizer.step() + # + # prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") \ No newline at end of file