From 48d84d37938cc1de06ce4222f353f601c105c5f5 Mon Sep 17 00:00:00 2001 From: Till Raab Date: Tue, 24 Oct 2023 09:17:31 +0200 Subject: [PATCH] looks better ... write inference code and feed the model --- custom_utils.py | 3 +-- train.py | 57 ++++++++++++++++++++++++++++--------------------- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/custom_utils.py b/custom_utils.py index c470f6d..fff2284 100644 --- a/custom_utils.py +++ b/custom_utils.py @@ -41,7 +41,7 @@ class SaveBestModel: if current_valid_loss < self.best_valid_loss: self.best_valid_loss = current_valid_loss print(f"\nBest validation loss: {self.best_valid_loss}") - print(f"\nSaving best model for epoch: {epoch + 1}\n") + print(f"Saving best model for epoch: {epoch + 1}") torch.save({ 'epoch': epoch + 1, 'model_state_dict': model.state_dict(), @@ -78,7 +78,6 @@ def save_loss_plot(OUT_DIR, train_loss, val_loss): valid_ax.set_ylabel('validation loss') figure_1.savefig(f"{OUT_DIR}/train_loss.png") figure_2.savefig(f"{OUT_DIR}/valid_loss.png") - print('SAVING PLOTS COMPLETE...') plt.close('all') diff --git a/train.py b/train.py index 22fe661..33ea54e 100644 --- a/train.py +++ b/train.py @@ -12,9 +12,8 @@ import time from IPython import embed -def train(train_loader, model, optimizer): +def train(train_loader, model, optimizer, train_loss): print('Training') - global train_loss_list prog_bar = tqdm(train_loader, total=len(train_loader)) for samples, targets in prog_bar: @@ -27,7 +26,7 @@ def train(train_loader, model, optimizer): losses = sum(loss for loss in loss_dict.values()) loss_value = losses.item() train_loss_hist.send(loss_value) # this is a global instance !!! - train_loss_list.append(loss_value) # check what exactly this does !!! + train_loss.append(loss_value) optimizer.zero_grad() losses.backward() @@ -35,8 +34,32 @@ def train(train_loader, model, optimizer): prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") - return train_loss_list + return train_loss +def validate(test_loader, model, val_loss): + print('Validation') + + prog_bar = tqdm(test_loader, total=len(test_loader)) + for samples, targets in prog_bar: + images = list(image.to(DEVICE) for image in samples) + + targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] + + with torch.inference_mode(): + loss_dict = model(images, targets) + + losses = sum(loss for loss in loss_dict.values()) + loss_value = losses.item() + val_loss_hist.send(loss_value) # this is a global instance !!! + val_loss.append(loss_value) + + # optimizer.zero_grad() + # losses.backward() + # optimizer.step() + + prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") + + return val_loss if __name__ == '__main__': train_data, test_data = create_train_test_dataset(TRAIN_DIR) @@ -53,18 +76,20 @@ if __name__ == '__main__': val_loss_hist = Averager() # train_itr = 1 # val_itr = 1 - train_loss_list = [] - val_loss_list = [] + train_loss = [] + val_loss = [] save_best_model = SaveBestModel() for epoch in range(NUM_EPOCHS): + print(f'\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---') train_loss_hist.reset() val_loss_hist.reset() - train_loss = train(train_loader, model, optimizer) - # val_loss = validate(train_loader, model, optimizer) + train_loss = train(train_loader, model, optimizer, train_loss) + val_loss = validate(test_loader, model, val_loss) + save_best_model( val_loss_hist.value, epoch, model, optimizer @@ -74,19 +99,3 @@ if __name__ == '__main__': save_loss_plot(OUTDIR, train_loss, val_loss) - # prog_bar = tqdm(train_loader, total=len(train_loader)) - # for samples, targets in prog_bar: - # images = list(image.to(DEVICE) for image in samples) - # - # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] - # - # loss_dict = model(images, targets) - # - # losses = sum(loss for loss in loss_dict.values()) - # loss_value = losses.item() - # - # optimizer.zero_grad() - # losses.backward() - # optimizer.step() - # - # prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") \ No newline at end of file