diff --git a/custom_utils.py b/custom_utils.py
index c470f6d..fff2284 100644
--- a/custom_utils.py
+++ b/custom_utils.py
@@ -41,7 +41,7 @@ class SaveBestModel:
         if current_valid_loss < self.best_valid_loss:
             self.best_valid_loss = current_valid_loss
             print(f"\nBest validation loss: {self.best_valid_loss}")
-            print(f"\nSaving best model for epoch: {epoch + 1}\n")
+            print(f"Saving best model for epoch: {epoch + 1}")
             torch.save({
                 'epoch': epoch + 1,
                 'model_state_dict': model.state_dict(),
@@ -78,7 +78,6 @@ def save_loss_plot(OUT_DIR, train_loss, val_loss):
     valid_ax.set_ylabel('validation loss')
     figure_1.savefig(f"{OUT_DIR}/train_loss.png")
     figure_2.savefig(f"{OUT_DIR}/valid_loss.png")
-    print('SAVING PLOTS COMPLETE...')
 
     plt.close('all')
 
diff --git a/train.py b/train.py
index 22fe661..33ea54e 100644
--- a/train.py
+++ b/train.py
@@ -12,9 +12,8 @@ import time
 
 from IPython import embed
 
-def train(train_loader, model, optimizer):
+def train(train_loader, model, optimizer, train_loss):
     print('Training')
-    global train_loss_list
 
     prog_bar = tqdm(train_loader, total=len(train_loader))
     for samples, targets in prog_bar:
@@ -27,7 +26,7 @@ def train(train_loader, model, optimizer):
         losses = sum(loss for loss in loss_dict.values())
         loss_value = losses.item()
         train_loss_hist.send(loss_value) # this is a global instance !!!
-        train_loss_list.append(loss_value) # check what exactly this does !!!
+        train_loss.append(loss_value)
 
         optimizer.zero_grad()
         losses.backward()
@@ -35,8 +34,32 @@ def train(train_loader, model, optimizer):
 
         prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
 
-    return train_loss_list
+    return train_loss
 
+def validate(test_loader, model, val_loss):
+    print('Validation')
+
+    prog_bar = tqdm(test_loader, total=len(test_loader))
+    for samples, targets in prog_bar:
+        images = list(image.to(DEVICE) for image in samples)
+
+        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+
+        with torch.inference_mode():
+            loss_dict = model(images, targets)
+
+        losses = sum(loss for loss in loss_dict.values())
+        loss_value = losses.item()
+        val_loss_hist.send(loss_value) # this is a global instance !!!
+        val_loss.append(loss_value)
+
+        # optimizer.zero_grad()
+        # losses.backward()
+        # optimizer.step()
+
+        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
+
+    return val_loss
 
 if __name__ == '__main__':
     train_data, test_data = create_train_test_dataset(TRAIN_DIR)
@@ -53,18 +76,20 @@ if __name__ == '__main__':
     val_loss_hist = Averager()
     # train_itr = 1
     # val_itr = 1
-    train_loss_list = []
-    val_loss_list = []
+    train_loss = []
+    val_loss = []
 
     save_best_model = SaveBestModel()
 
     for epoch in range(NUM_EPOCHS):
+        print(f'\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---')
 
         train_loss_hist.reset()
         val_loss_hist.reset()
 
-        train_loss = train(train_loader, model, optimizer)
-        # val_loss = validate(train_loader, model, optimizer)
+        train_loss = train(train_loader, model, optimizer, train_loss)
+        val_loss = validate(test_loader, model, val_loss)
+
 
         save_best_model(
             val_loss_hist.value, epoch, model, optimizer
@@ -74,19 +99,3 @@ if __name__ == '__main__':
 
         save_loss_plot(OUTDIR, train_loss, val_loss)
 
-        # prog_bar = tqdm(train_loader, total=len(train_loader))
-        # for samples, targets in prog_bar:
-        #     images = list(image.to(DEVICE) for image in samples)
-        #
-        #     targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
-        #
-        #     loss_dict = model(images, targets)
-        #
-        #     losses = sum(loss for loss in loss_dict.values())
-        #     loss_value = losses.item()
-        #
-        #     optimizer.zero_grad()
-        #     losses.backward()
-        #     optimizer.step()
-        #
-        #     prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
\ No newline at end of file