import os

import torch
import matplotlib.pyplot as plt

# NOTE(review): module name is 'confic' -- confirm this is intended
# (and not a typo for 'config').
from confic import OUTDIR


class Averager:
    """Keep a running arithmetic mean of the values passed to ``send``."""

    def __init__(self):
        self.current_total = 0.0
        self.iterations = 0.0

    def send(self, value):
        """Add ``value`` to the running total and count one more sample."""
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        """Mean of all values sent since the last reset, or 0.0 if none."""
        if self.iterations == 0:
            return 0.0
        return self.current_total / self.iterations

    def reset(self):
        """Discard all accumulated values."""
        self.current_total = 0.0
        self.iterations = 0.0


def _save_checkpoint(epoch, model, optimizer, filename):
    """
    Write a checkpoint dict to ``OUTDIR/filename`` with ``torch.save``.

    The stored ``'epoch'`` is ``epoch + 1`` (presumably converting a 0-based
    loop index into a 1-based epoch number -- confirm against the caller).
    ``os.path.join`` is used so that an absolute ``OUTDIR`` is respected.
    """
    torch.save(
        {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        os.path.join(OUTDIR, filename),
    )


class SaveBestModel:
    """
    Class to save the best model while training.

    If the current epoch's validation loss is lower than the lowest loss
    seen so far, record it and save the model and optimizer state to
    ``OUTDIR/best_model.pth``.
    """

    def __init__(self, best_valid_loss=float('inf')):
        # Lowest validation loss seen so far. With the default of +inf,
        # the first finite loss passed in is always saved.
        self.best_valid_loss = best_valid_loss

    def __call__(self, current_valid_loss, epoch, model, optimizer):
        """Save a checkpoint if ``current_valid_loss`` beats the best so far."""
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"Saving best model for epoch: {epoch + 1}")
            _save_checkpoint(epoch, model, optimizer, 'best_model.pth')


def collate_fn(batch):
    """
    To handle the data loading as different images may have different number
    of objects and to handle varying size tensors as well.

    Transposes a list of samples into a tuple of per-field tuples, so no
    stacking (and thus no equal-size requirement) is imposed.
    """
    return tuple(zip(*batch))


def save_model(epoch, model, optimizer):
    """
    Save the model and optimizer state up to the current epoch (or whenever
    called) to ``OUTDIR/last_model.pth``.
    """
    _save_checkpoint(epoch, model, optimizer, 'last_model.pth')


def save_loss_plot(OUT_DIR, train_loss, val_loss):
    """
    Save the train and validation loss curves as two separate PNG files.

    Writes ``train_loss.png`` and ``valid_loss.png`` into ``OUT_DIR``.
    Only the two figures created here are closed (even if saving fails),
    so any figures the caller has open are left untouched.
    """
    figure_1, train_ax = plt.subplots()
    figure_2, valid_ax = plt.subplots()
    try:
        # NOTE(review): x-axis is labelled 'iterations'; confirm whether the
        # caller passes per-iteration or per-epoch loss values.
        train_ax.plot(train_loss, color='tab:blue')
        train_ax.set_xlabel('iterations')
        train_ax.set_ylabel('train loss')
        valid_ax.plot(val_loss, color='tab:red')
        valid_ax.set_xlabel('iterations')
        valid_ax.set_ylabel('validation loss')
        figure_1.savefig(os.path.join(OUT_DIR, 'train_loss.png'))
        figure_2.savefig(os.path.join(OUT_DIR, 'valid_loss.png'))
    finally:
        plt.close(figure_1)
        plt.close(figure_2)