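"""Run a trained detection model on images and visualize its predictions.

Loads the best checkpoint from OUTDIR, runs inference on every PNG image
in TRAIN_DIR, and can plot each image with its predicted boxes and scores
via show_sample.
"""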
import glob
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as F
from matplotlib.patches import Rectangle
from PIL import Image
from tqdm.auto import tqdm
from IPython import embed

from model import create_model
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR
from datasets import create_train_or_test_dataset, create_valid_loader

def show_sample(img_tensor, outputs, detection_threshold):
    """Plot one image with its predicted boxes and confidence scores.

    Detections scoring below `detection_threshold` are skipped.
    """
    fig, ax = plt.subplots()
    # (1, C, H, W) tensor -> (H, W, C) array order for matplotlib.
    ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
    boxes = outputs[0]['boxes'].cpu()
    labels = outputs[0]['labels'].cpu()
    scores = outputs[0]['scores'].cpu()
    for (x0, y0, x1, y1), label, score in zip(boxes, labels, scores):
        if score < detection_threshold:
            continue
        # Score annotation above the top-left corner of the box.
        ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom',
                fontsize=12, color='white')
        ax.add_patch(
            Rectangle((x0, y0),
                      (x1 - x0),
                      (y1 - y0),
                      fill=False, color='white', linewidth=2, zorder=10)
        )
    plt.show()

def infere_model(test_loader, model, detection_th=0.8):
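    # A minimal sketch of batched inference, assuming the loader yields
    # (images, targets) batches of image tensors, as a typical torchvision
    # detection collate_fn produces; adjust to the actual datasets.py API.
    for images, _ in test_loader:
        images = [img.to(DEVICE) for img in images]
        with torch.inference_mode():
            outputs = model(images)
        for img, out in zip(images, outputs):
            show_sample(img.cpu().unsqueeze(0), [out], detection_th)
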
if __name__ == '__main__':
    # Build the model, load the best checkpoint, and switch to eval mode.
    model = create_model(num_classes=NUM_CLASSES)
    checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE).eval()

    # Alternative path: run batched inference through a DataLoader.
    # test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
    # test_loader = create_valid_loader(test_data)
    # infere_model(test_loader, model)

    detection_threshold = 0.8
    frame_count = 0  # images processed so far
    total_fps = 0    # accumulated per-image FPS, averaged after the loop
    test_images = glob.glob(f"{TRAIN_DIR}/*.png")

    for image_path in tqdm(test_images):
        image_name = image_path.split(os.path.sep)[-1].split('.')[0]

        # Load the image and convert it to a (1, C, H, W) float tensor.
        img = Image.open(image_path)
        img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)

        # Time the forward pass; frame_count and total_fps feed the
        # average-FPS report after the loop (the assumed intent of these
        # otherwise unused counters).
        start_time = time.time()
        with torch.inference_mode():
            outputs = model(img_tensor.to(DEVICE))
        fps = 1 / (time.time() - start_time)
        frame_count += 1
        total_fps += fps

        print(f"{image_name}: {len(outputs[0]['boxes'])} boxes")

        # show_sample(img_tensor, outputs, detection_threshold)
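
    if frame_count:
        print(f"Average FPS: {total_fps / frame_count:.2f}")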