diff --git a/datasets.py b/datasets.py
index df0ee01..57e810a 100644
--- a/datasets.py
+++ b/datasets.py
@@ -59,6 +59,7 @@ class CustomDataset(Dataset):
         target["iscrowd"] = iscrowd
         image_id = torch.tensor([idx])
         target["image_id"] = image_id
+        target["image_name"] = image_name
 
         return img_tensor, target
 
diff --git a/inference.py b/inference.py
index 36953cd..2cbed5d 100644
--- a/inference.py
+++ b/inference.py
@@ -14,12 +14,11 @@ from tqdm.auto import tqdm
 import matplotlib.pyplot as plt
 from matplotlib.patches import Rectangle
 
-def show_sample(img_tensor, outputs, detection_threshold):
+def plot_inference(img_tensor, output, target, detection_threshold):
-    fig, ax = plt.subplots()
+    fig, ax = plt.subplots(figsize=(7, 7), num=target['image_name'])
     ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
-    for (x0, y0, x1, y1), l, score in zip(outputs[0]['boxes'].cpu(), outputs[0]['labels'].cpu(), outputs[0]['scores'].cpu()):
-
+    for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
         if score < detection_threshold:
             continue
         # print(x0, y0, x1, y1, l)
@@ -28,9 +27,18 @@ def show_sample(img_tensor, outputs, detection_threshold):
         ax.add_patch(
             Rectangle((x0, y0),
                       (x1 - x0),
                       (y1 - y0),
-                      fill=False, color="white", linewidth=2, zorder=10)
+                      fill=False, color="tab:green", linewidth=2, zorder=10)
         )
+    for (x0, y0, x1, y1), l in zip(target['boxes'], target['labels']):
+        ax.add_patch(
+            Rectangle((x0, y0),
+                      (x1 - x0),
+                      (y1 - y0),
+                      fill=False, color="white", linewidth=2, zorder=9)
+        )
+    plt.show()
+    embed()
 
 
 def infere_model(test_loader, model, detection_th=0.8):
@@ -44,7 +52,8 @@ def infere_model(test_loader, model, detection_th=0.8):
 
         with torch.inference_mode():
             outputs = model(images)
-        embed()
+        for image, output, target in zip(images, outputs, targets):
+            plot_inference(image, output, target, detection_th)
 
 
 if __name__ == '__main__':