inference updated

This commit is contained in:
Till Raab 2023-10-26 08:38:14 +02:00
parent fe6e51d625
commit b1e5b9e665
2 changed files with 16 additions and 6 deletions

View File

@ -59,6 +59,7 @@ class CustomDataset(Dataset):
target["iscrowd"] = iscrowd target["iscrowd"] = iscrowd
image_id = torch.tensor([idx]) image_id = torch.tensor([idx])
target["image_id"] = image_id target["image_id"] = image_id
target["image_name"] = image_name
return img_tensor, target return img_tensor, target

View File

@ -14,12 +14,11 @@ from tqdm.auto import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
def show_sample(img_tensor, outputs, detection_threshold): def plot_inference(img_tensor, output, target, detection_threshold):
fig, ax = plt.subplots() fig, ax = plt.subplots(figsize=(7, 7), num=target['image_name'])
ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto') ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
for (x0, y0, x1, y1), l, score in zip(outputs[0]['boxes'].cpu(), outputs[0]['labels'].cpu(), outputs[0]['scores'].cpu()): for (x0, y0, x1, y1), l, score in zip(output[0]['boxes'].cpu(), output[0]['labels'].cpu(), output[0]['scores'].cpu()):
if score < detection_threshold: if score < detection_threshold:
continue continue
# print(x0, y0, x1, y1, l) # print(x0, y0, x1, y1, l)
@ -28,9 +27,18 @@ def show_sample(img_tensor, outputs, detection_threshold):
Rectangle((x0, y0), Rectangle((x0, y0),
(x1 - x0), (x1 - x0),
(y1 - y0), (y1 - y0),
fill=False, color="white", linewidth=2, zorder=10) fill=False, color="tab:green", linewidth=2, zorder=10)
) )
for (x0, y0, x1, y1), l in zip(target['boxes'], target['labels']):
ax.add_patch(
Rectangle((x0, y0),
(x1 - x0),
(y1 - y0),
fill=False, color="white", linewidth=2, zorder=9)
)
plt.show() plt.show()
embed()
def infere_model(test_loader, model, detection_th=0.8): def infere_model(test_loader, model, detection_th=0.8):
@ -44,7 +52,8 @@ def infere_model(test_loader, model, detection_th=0.8):
with torch.inference_mode(): with torch.inference_mode():
outputs = model(images) outputs = model(images)
embed() for image, output, target in zip(images, outputs, targets):
plot_inference(image, output, target, detection_th)
if __name__ == '__main__': if __name__ == '__main__':