This commit is contained in:
Till Raab 2023-10-26 08:40:44 +02:00
parent b89cf321f9
commit 7386bd83a5

View File

@ -16,7 +16,7 @@ from matplotlib.patches import Rectangle
def plot_inference(img_tensor, output, target, detection_threshold): def plot_inference(img_tensor, output, target, detection_threshold):
fig, ax = plt.subplots(figsize=(7, 7), num=target['image_name']) fig, ax = plt.subplots(figsize=(7, 7), num=target['image_id'])
ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto') ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
for (x0, y0, x1, y1), l, score in zip(output[0]['boxes'].cpu(), output[0]['labels'].cpu(), output[0]['scores'].cpu()): for (x0, y0, x1, y1), l, score in zip(output[0]['boxes'].cpu(), output[0]['labels'].cpu(), output[0]['scores'].cpu()):
if score < detection_threshold: if score < detection_threshold: