diff --git a/train.py b/train.py index 9f941ed..db2c584 100644 --- a/train.py +++ b/train.py @@ -1,15 +1,18 @@ -from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR) +from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR, IMG_SIZE, IMG_DPI, INFERENCE_OUTDIR) from model import create_model - -from tqdm.auto import tqdm - from datasets import create_train_loader, create_valid_loader, create_train_or_test_dataset from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot +from tqdm.auto import tqdm + +import os import torch import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from matplotlib.patches import Rectangle import time +from pathlib import Path from IPython import embed def train(train_loader, model, optimizer, train_loss): @@ -18,8 +21,7 @@ def train(train_loader, model, optimizer, train_loss): prog_bar = tqdm(train_loader, total=len(train_loader)) for samples, targets in prog_bar: images = list(image.to(DEVICE) for image in samples) - - # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] + # img_names = [t['image_name'] for t in targets] targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets] loss_dict = model(images, targets) @@ -43,11 +45,8 @@ def validate(test_loader, model, val_loss): prog_bar = tqdm(test_loader, total=len(test_loader)) for samples, targets in prog_bar: images = list(image.to(DEVICE) for image in samples) - - # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets] - with torch.inference_mode(): loss_dict = model(images, targets) @@ -55,15 +54,67 @@ def validate(test_loader, model, val_loss): loss_value = losses.item() val_loss_hist.send(loss_value) # this is a global instance !!! val_loss.append(loss_value) - - # optimizer.zero_grad() - # losses.backward() - # optimizer.step() - prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") return val_loss +def best_model_validation_with_plots(test_loader): + model = create_model(num_classes=NUM_CLASSES) + checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE) + model.load_state_dict(checkpoint["model_state_dict"]) + model.to(DEVICE).eval() + + validate_with_plots(test_loader, model) + + +def validate_with_plots(test_loader, model, detection_th=0.8): + print('Final validation with image putput') + + prog_bar = tqdm(test_loader, total=len(test_loader)) + for samples, targets in prog_bar: + images = list(image.to(DEVICE) for image in samples) + + img_names = [t['image_name'] for t in targets] + targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets] + + with torch.inference_mode(): + outputs = model(images) + + for image, img_name, output, target in zip(images, img_names, outputs, targets): + plot_validation(image, img_name, output, target, detection_th) + + +def plot_validation(img_tensor, img_name, output, target, detection_threshold): + + fig = plt.figure(figsize=IMG_SIZE, num=img_name) + gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) # + ax = fig.add_subplot(gs[0, 0]) + + ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto') + for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): + if score < detection_threshold: + continue + # print(x0, y0, x1, y1, l) + ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white') + ax.add_patch( + Rectangle((x0, y0), + (x1 - x0), + (y1 - y0), + fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10) + ) + for (x0, y0, x1, y1), l in zip(target['boxes'], target['labels']): + ax.add_patch( + Rectangle((x0, y0), + (x1 - x0), + (y1 - y0), + fill=False, color="white", linewidth=2, zorder=9) + ) + + ax.set_axis_off() + plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI) + plt.close() + # plt.show() + if __name__ == '__main__': train_data = create_train_or_test_dataset(DATA_DIR) test_data = create_train_or_test_dataset(DATA_DIR, train=False) @@ -91,6 +142,7 @@ if __name__ == '__main__': val_loss_hist.reset() train_loss = train(train_loader, model, optimizer, train_loss) + val_loss = validate(test_loader, model, val_loss) @@ -102,3 +154,5 @@ if __name__ == '__main__': save_loss_plot(OUTDIR, train_loss, val_loss) + # load best model and perform inference with plot output + best_model_validation_with_plots(test_loader) \ No newline at end of file