check if model still finds stuff
This commit is contained in:
parent
e1a97ac493
commit
a98f488813
16
inference.py
16
inference.py
@ -6,7 +6,8 @@ import os
|
||||
from PIL import Image
|
||||
|
||||
from model import create_model
|
||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR
|
||||
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR
|
||||
from datasets import create_train_or_test_dataset, create_valid_loader
|
||||
|
||||
from IPython import embed
|
||||
from tqdm.auto import tqdm
|
||||
@ -15,8 +16,6 @@ from matplotlib.patches import Rectangle
|
||||
|
||||
def show_sample(img_tensor, outputs, detection_threshold):
|
||||
|
||||
# embed()
|
||||
# quit()
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
|
||||
for (x0, y0, x1, y1), l, score in zip(outputs[0]['boxes'].cpu(), outputs[0]['labels'].cpu(), outputs[0]['scores'].cpu()):
|
||||
@ -33,21 +32,24 @@ def show_sample(img_tensor, outputs, detection_threshold):
|
||||
)
|
||||
plt.show()
|
||||
|
||||
def infere_model(test_loader, model, detection_th=0.8)
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = create_model(num_classes=NUM_CLASSES)
|
||||
checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.to(DEVICE).eval()
|
||||
|
||||
DIR_TEST = 'data/train'
|
||||
# test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||
# test_loader = create_valid_loader(test_data)
|
||||
|
||||
|
||||
test_images = glob.glob(f"{DIR_TEST}/*.png")
|
||||
# infere_model(test_loader, model)
|
||||
|
||||
detection_threshold = 0.8
|
||||
|
||||
frame_count = 0
|
||||
total_fps = 0
|
||||
test_images = glob.glob(f"{TRAIN_DIR}/*.png")
|
||||
|
||||
for i in tqdm(np.arange(len(test_images))):
|
||||
image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
|
||||
@ -60,4 +62,4 @@ if __name__ == '__main__':
|
||||
|
||||
print(len(outputs[0]['boxes']))
|
||||
|
||||
show_sample(img_tensor, outputs, detection_threshold)
|
||||
# show_sample(img_tensor, outputs, detection_threshold)
|
Loading…
Reference in New Issue
Block a user