inference images are now saved in a corresponding folder

This commit is contained in:
Till Raab 2023-10-26 10:13:48 +02:00
parent 391379551a
commit 76cee2c050

View File

@ -13,10 +13,16 @@ from IPython import embed
from pathlib import Path from pathlib import Path
from tqdm.auto import tqdm from tqdm.auto import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
def plot_inference(img_tensor, img_name, output, target, detection_threshold): def plot_inference(img_tensor, img_name, output, target, detection_threshold):
fig, ax = plt.subplots(figsize=IMG_SIZE, num=img_name)
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
ax = fig.add_subplot(gs[0, 0])
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto') ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
if score < detection_threshold: if score < detection_threshold:
@ -38,9 +44,7 @@ def plot_inference(img_tensor, img_name, output, target, detection_threshold):
) )
ax.set_axis_off() ax.set_axis_off()
embed() plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), IMG_DPI)
quit()
plt.savefig(Path(INFERENCE_OUTDIR)/img_name/'_inferred.png', IMG_DPI)
plt.close() plt.close()
# plt.show() # plt.show()