From fe6e51d6250f12ddcb5d42ee1509dd8e7fbc117d Mon Sep 17 00:00:00 2001
From: Till Raab
Date: Thu, 26 Oct 2023 07:35:41 +0200
Subject: [PATCH] inference function with dataloader to optimize inference.
 now generate plots ...

---
 inference.py | 48 ++++++++++++++++++++++++++++++------------------
 1 file changed, 30 insertions(+), 18 deletions(-)

diff --git a/inference.py b/inference.py
index 01e298f..36953cd 100644
--- a/inference.py
+++ b/inference.py
@@ -33,7 +33,19 @@ def show_sample(img_tensor, outputs, detection_threshold):
     plt.show()
 
 def infere_model(test_loader, model, detection_th=0.8):
-    pass
+
+    print('Validation')
+
+    prog_bar = tqdm(test_loader, total=len(test_loader))
+    for samples, targets in prog_bar:
+        images = list(image.to(DEVICE) for image in samples)
+        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+
+        with torch.inference_mode():
+            outputs = model(images)
+
+        embed()
+
 
 if __name__ == '__main__':
     model = create_model(num_classes=NUM_CLASSES)
@@ -41,26 +53,26 @@ if __name__ == '__main__':
     model.load_state_dict(checkpoint["model_state_dict"])
     model.to(DEVICE).eval()
 
-    # test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
-    # test_loader = create_valid_loader(test_data)
-
-
-    # infere_model(test_loader, model)
+    test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
+    test_loader = create_valid_loader(test_data)
 
 
-    detection_threshold = 0.8
-    frame_count = 0
-    total_fps = 0
-    test_images = glob.glob(f"{TRAIN_DIR}/*.png")
-    for i in tqdm(np.arange(len(test_images))):
-        image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
+    infere_model(test_loader, model)
 
 
-        img = Image.open(test_images[i])
-        img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
-
-        with torch.inference_mode():
-            outputs = model(img_tensor.to(DEVICE))
+    # detection_threshold = 0.8
+    # frame_count = 0
+    # total_fps = 0
+    # test_images = glob.glob(f"{TRAIN_DIR}/*.png")
 
-        print(len(outputs[0]['boxes']))
+    # for i in tqdm(np.arange(len(test_images))):
+    #     image_name = test_images[i].split(os.path.sep)[-1].split('.')[0]
+    #
+    #     img = Image.open(test_images[i])
+    #     img_tensor = F.to_tensor(img.convert('RGB')).unsqueeze(dim=0)
+    #
+    #     with torch.inference_mode():
+    #         outputs = model(img_tensor.to(DEVICE))
+    #
+    #     print(len(outputs[0]['boxes']))
     # show_sample(img_tensor, outputs, detection_threshold)
\ No newline at end of file
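
Side note on the new infere_model(): after the forward pass the loop currently drops into embed() (an IPython shell), so the plots mentioned in the subject are not produced yet. The sketch below shows one way the batch loop could be finished to plot detections above the threshold. It is illustrative, not part of the patch: infer_and_plot is a hypothetical name, the function is meant to live inside inference.py so that DEVICE and show_sample() are in scope, and it assumes the model returns torchvision-style output dicts with 'boxes', 'scores' and 'labels'.

# Sketch (not part of the patch): finish the DataLoader loop with plotting
# instead of embed(). Assumes DEVICE and show_sample() from inference.py and
# torchvision-style output dicts ('boxes', 'scores', 'labels').
import torch
from tqdm import tqdm


def infer_and_plot(test_loader, model, detection_th=0.8):
    prog_bar = tqdm(test_loader, total=len(test_loader))
    for samples, _targets in prog_bar:
        images = [img.to(DEVICE) for img in samples]

        with torch.inference_mode():
            outputs = model(images)

        # drop low-confidence detections, then reuse the existing plot helper
        # one image at a time (it expects a batched tensor and a list of dicts)
        for img, out in zip(images, outputs):
            keep = out['scores'] >= detection_th
            filtered = [{k: out[k][keep].cpu() for k in ('boxes', 'scores', 'labels')}]
            show_sample(img.unsqueeze(0).cpu(), filtered, detection_th)

Filtering on the returned scores takes over the role of the detection_threshold from the removed glob-based loop while keeping the batched DataLoader path introduced by this patch.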