From 76cee2c0504c390257cb069ff9393adde0233fdf Mon Sep 17 00:00:00 2001 From: Till Raab Date: Thu, 26 Oct 2023 10:13:48 +0200 Subject: [PATCH] inference images are now saved in a corresponding folder --- inference.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/inference.py b/inference.py index 3501da1..619b9b3 100644 --- a/inference.py +++ b/inference.py @@ -13,10 +13,16 @@ from IPython import embed from pathlib import Path from tqdm.auto import tqdm import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec from matplotlib.patches import Rectangle + def plot_inference(img_tensor, img_name, output, target, detection_threshold): - fig, ax = plt.subplots(figsize=IMG_SIZE, num=img_name) + + fig = plt.figure(figsize=IMG_SIZE, num=img_name) + gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) # + ax = fig.add_subplot(gs[0, 0]) + ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto') for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): if score < detection_threshold: @@ -38,9 +44,7 @@ def plot_inference(img_tensor, img_name, output, target, detection_threshold): ) ax.set_axis_off() - embed() - quit() - plt.savefig(Path(INFERENCE_OUTDIR)/img_name/'_inferred.png', IMG_DPI) + plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), IMG_DPI) plt.close() # plt.show()