diff --git a/inference.py b/inference.py index 687c53f..4797a1f 100644 --- a/inference.py +++ b/inference.py @@ -20,8 +20,8 @@ def plot_inference(img_tensor, output, target, detection_threshold): quit() fig, ax = plt.subplots(figsize=(7, 7), num=target['image_id']) - ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto') - for (x0, y0, x1, y1), l, score in zip(output[0]['boxes'].cpu(), output[0]['labels'].cpu(), output[0]['scores'].cpu()): + ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto') + for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): if score < detection_threshold: continue # print(x0, y0, x1, y1, l) @@ -50,7 +50,7 @@ def infere_model(test_loader, model, detection_th=0.8): prog_bar = tqdm(test_loader, total=len(test_loader)) for samples, targets in prog_bar: images = list(image.to(DEVICE) for image in samples) - targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] + targets = [{k: v for k, v in t.items()} for t in targets] with torch.inference_mode(): outputs = model(images)