From 323d79b27a88068f40413cc514debbb68802aa64 Mon Sep 17 00:00:00 2001 From: Till Raab Date: Fri, 27 Oct 2023 11:27:31 +0200 Subject: [PATCH] bf --- datasets.py | 3 ++- inference.py | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datasets.py b/datasets.py index 138841f..284efe8 100644 --- a/datasets.py +++ b/datasets.py @@ -29,10 +29,11 @@ class InferenceDataset(Dataset): def __getitem__(self, idx): image_path = self.all_images[idx] + image_name = image_path.name img = Image.open(image_path) img_tensor = F.to_tensor(img.convert('RGB')) - return img_tensor + return img_tensor, image_name class CustomDataset(Dataset): diff --git a/inference.py b/inference.py index 22ed687..702d4ae 100644 --- a/inference.py +++ b/inference.py @@ -47,16 +47,15 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8): print(f'Inference on dataset: {dataset_name}') prog_bar = tqdm(inference_loader, total=len(inference_loader)) - for samples, targets in prog_bar: + for samples, img_names in prog_bar: images = list(image.to(DEVICE) for image in samples) - img_names = [t['image_name'] for t in targets] - targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets] + # img_names = [t['image_name'] for t in targets] with torch.inference_mode(): outputs = model(images) - for image, img_name, output, target in zip(images, img_names, outputs, targets): + for image, img_name, output in zip(images, img_names, outputs): plot_inference(image, img_name, output, detection_th, dataset_name)