inference updated
This commit is contained in:
parent
fe6e51d625
commit
b1e5b9e665
@ -59,6 +59,7 @@ class CustomDataset(Dataset):
|
|||||||
target["iscrowd"] = iscrowd
|
target["iscrowd"] = iscrowd
|
||||||
image_id = torch.tensor([idx])
|
image_id = torch.tensor([idx])
|
||||||
target["image_id"] = image_id
|
target["image_id"] = image_id
|
||||||
|
target["image_name"] = image_name
|
||||||
|
|
||||||
return img_tensor, target
|
return img_tensor, target
|
||||||
|
|
||||||
|
21
inference.py
21
inference.py
@ -14,12 +14,11 @@ from tqdm.auto import tqdm
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.patches import Rectangle
|
from matplotlib.patches import Rectangle
|
||||||
|
|
||||||
def show_sample(img_tensor, outputs, detection_threshold):
|
def plot_inference(img_tensor, output, target, detection_threshold):
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots(figsize=(7, 7), num=target['image_name'])
|
||||||
ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
|
ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
|
||||||
for (x0, y0, x1, y1), l, score in zip(outputs[0]['boxes'].cpu(), outputs[0]['labels'].cpu(), outputs[0]['scores'].cpu()):
|
for (x0, y0, x1, y1), l, score in zip(output[0]['boxes'].cpu(), output[0]['labels'].cpu(), output[0]['scores'].cpu()):
|
||||||
|
|
||||||
if score < detection_threshold:
|
if score < detection_threshold:
|
||||||
continue
|
continue
|
||||||
# print(x0, y0, x1, y1, l)
|
# print(x0, y0, x1, y1, l)
|
||||||
@ -28,9 +27,18 @@ def show_sample(img_tensor, outputs, detection_threshold):
|
|||||||
Rectangle((x0, y0),
|
Rectangle((x0, y0),
|
||||||
(x1 - x0),
|
(x1 - x0),
|
||||||
(y1 - y0),
|
(y1 - y0),
|
||||||
fill=False, color="white", linewidth=2, zorder=10)
|
fill=False, color="tab:green", linewidth=2, zorder=10)
|
||||||
)
|
)
|
||||||
|
for (x0, y0, x1, y1), l in zip(target['boxes'], target['labels']):
|
||||||
|
ax.add_patch(
|
||||||
|
Rectangle((x0, y0),
|
||||||
|
(x1 - x0),
|
||||||
|
(y1 - y0),
|
||||||
|
fill=False, color="white", linewidth=2, zorder=9)
|
||||||
|
)
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
embed()
|
||||||
|
|
||||||
def infere_model(test_loader, model, detection_th=0.8):
|
def infere_model(test_loader, model, detection_th=0.8):
|
||||||
|
|
||||||
@ -44,7 +52,8 @@ def infere_model(test_loader, model, detection_th=0.8):
|
|||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
|
|
||||||
embed()
|
for image, output, target in zip(images, outputs, targets):
|
||||||
|
plot_inference(image, output, target, detection_th)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user