From ad74322f94f5eb737b1d91e5d278a2e6ebb9ebc6 Mon Sep 17 00:00:00 2001 From: Till Raab Date: Wed, 25 Oct 2023 09:09:20 +0200 Subject: [PATCH] the model is working. however, in need to generate it and split data in test and train beforehand ... --- confic.py | 4 ++-- inference.py | 28 +++++++++++++++++++++++++++- train.py | 2 -- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/confic.py b/confic.py index 3016472..c08cab2 100644 --- a/confic.py +++ b/confic.py @@ -1,9 +1,9 @@ import torch import pathlib -BATCH_SIZE = 32 +BATCH_SIZE = 8 RESIZE_TO = 416 -NUM_EPOCHS = 20 +NUM_EPOCHS = 10 NUM_WORKERS = 4 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') diff --git a/inference.py b/inference.py index 83a34c3..58048ac 100644 --- a/inference.py +++ b/inference.py @@ -10,6 +10,28 @@ from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR from IPython import embed from tqdm.auto import tqdm +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle + +def show_sample(img_tensor, outputs, detection_threshold): + + # embed() + # quit() + fig, ax = plt.subplots() + ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto') + for (x0, y0, x1, y1), l, score in zip(outputs[0]['boxes'].cpu(), outputs[0]['labels'].cpu(), outputs[0]['scores'].cpu()): + + if score < detection_threshold: + continue + # print(x0, y0, x1, y1, l) + ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white') + ax.add_patch( + Rectangle((x0, y0), + (x1 - x0), + (y1 - y0), + fill=False, color="white", linewidth=2, zorder=10) + ) + plt.show() if __name__ == '__main__': model = create_model(num_classes=NUM_CLASSES) @@ -18,6 +40,8 @@ if __name__ == '__main__': model.to(DEVICE).eval() DIR_TEST = 'data/train' + + test_images = glob.glob(f"{DIR_TEST}/*.png") detection_threshold = 0.8 @@ -34,4 +58,6 @@ if __name__ == '__main__': with torch.inference_mode(): outputs = model(img_tensor.to(DEVICE)) - print(len(outputs[0]['boxes'])) \ No newline at end of file + print(len(outputs[0]['boxes'])) + + show_sample(img_tensor, outputs, detection_threshold) \ No newline at end of file diff --git a/train.py b/train.py index b46fc13..33ea54e 100644 --- a/train.py +++ b/train.py @@ -45,8 +45,6 @@ def validate(test_loader, model, val_loss): targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] - embed() - quit() with torch.inference_mode(): loss_dict = model(images, targets)