bf
This commit is contained in:
parent
68b3967a21
commit
47cd574354
@ -20,8 +20,8 @@ def plot_inference(img_tensor, output, target, detection_threshold):
|
|||||||
quit()
|
quit()
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(7, 7), num=target['image_id'])
|
fig, ax = plt.subplots(figsize=(7, 7), num=target['image_id'])
|
||||||
ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
|
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
|
||||||
for (x0, y0, x1, y1), l, score in zip(output[0]['boxes'].cpu(), output[0]['labels'].cpu(), output[0]['scores'].cpu()):
|
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
|
||||||
if score < detection_threshold:
|
if score < detection_threshold:
|
||||||
continue
|
continue
|
||||||
# print(x0, y0, x1, y1, l)
|
# print(x0, y0, x1, y1, l)
|
||||||
@ -50,7 +50,7 @@ def infere_model(test_loader, model, detection_th=0.8):
|
|||||||
prog_bar = tqdm(test_loader, total=len(test_loader))
|
prog_bar = tqdm(test_loader, total=len(test_loader))
|
||||||
for samples, targets in prog_bar:
|
for samples, targets in prog_bar:
|
||||||
images = list(image.to(DEVICE) for image in samples)
|
images = list(image.to(DEVICE) for image in samples)
|
||||||
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
|
targets = [{k: v for k, v in t.items()} for t in targets]
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
|
Loading…
Reference in New Issue
Block a user