diff --git a/inference.py b/inference.py
index 58048ac..638b967 100644
--- a/inference.py
+++ b/inference.py
@@ -6,7 +6,8 @@ import os
 from PIL import Image
 
 from model import create_model
-from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR
+from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR
+from datasets import create_train_or_test_dataset, create_valid_loader
 
 from IPython import embed
 from tqdm.auto import tqdm
@@ -15,8 +16,6 @@ from matplotlib.patches import Rectangle
 
 def show_sample(img_tensor, outputs, detection_threshold):
-    # embed()
-    # quit()
     fig, ax = plt.subplots()
     ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto')
 
     for (x0, y0, x1, y1), l, score in zip(outputs[0]['boxes'].cpu(), outputs[0]['labels'].cpu(), outputs[0]['scores'].cpu()):
@@ -33,21 +32,25 @@ def show_sample(img_tensor, outputs, detection_threshold):
     )
     plt.show()
 
+def infere_model(test_loader, model, detection_th=0.8):
+    pass  # stub, see note below
+
 if __name__ == '__main__':
     model = create_model(num_classes=NUM_CLASSES)
 
     checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
     model.load_state_dict(checkpoint["model_state_dict"])
     model.to(DEVICE).eval()
 
-    DIR_TEST = 'data/train'
+    # test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
+    # test_loader = create_valid_loader(test_data)
 
-    test_images = glob.glob(f"{DIR_TEST}/*.png")
+    # infere_model(test_loader, model)
 
 
     detection_threshold = 0.8
-    frame_count = 0
     total_fps = 0
+    test_images = glob.glob(f"{TRAIN_DIR}/*.png")
 
     for i in tqdm(np.arange(len(test_images))):
         image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
@@ -60,4 +63,4 @@ if __name__ == '__main__':
 
         print(len(outputs[0]['boxes']))
 
-        show_sample(img_tensor, outputs, detection_threshold)
\ No newline at end of file
+        # show_sample(img_tensor, outputs, detection_threshold)
\ No newline at end of file
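Note: `infere_model` lands in this commit as a stub (the original hunk even omitted the trailing colon and body, which would have made inference.py fail to parse; fixed above), and the DataLoader-based path stays commented out in `__main__`. A minimal sketch of what the stub might grow into, assuming `create_valid_loader` yields `(images, targets)` batches the way torchvision's detection references do; `DEVICE`, `detection_th`, and the output dict keys follow the diff, everything else is illustrative and not part of this commit:

```python
import torch

from confic import DEVICE


def infere_model(test_loader, model, detection_th=0.8):
    """Run the detector over a test loader, keeping detections above threshold."""
    detections = []
    with torch.inference_mode():  # disable autograd bookkeeping during inference
        for images, _targets in test_loader:  # assumed batch layout, see note above
            # torchvision detection models expect a list of 3xHxW tensors
            images = [img.to(DEVICE) for img in images]
            outputs = model(images)
            for out in outputs:
                # drop low-confidence boxes below the detection threshold
                keep = out['scores'] >= detection_th
                detections.append({
                    'boxes': out['boxes'][keep].cpu(),
                    'labels': out['labels'][keep].cpu(),
                    'scores': out['scores'][keep].cpu(),
                })
    return detections
```

Once the commented-out loader lines are re-enabled, `__main__` could call `infere_model(test_loader, model)` in place of the current glob-based per-image loop.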