diff --git a/datasets.py b/datasets.py
index 138841f..284efe8 100644
--- a/datasets.py
+++ b/datasets.py
@@ -29,10 +29,11 @@ class InferenceDataset(Dataset):
 
     def __getitem__(self, idx):
         image_path = self.all_images[idx]
+        image_name = image_path.name
         img = Image.open(image_path)
         img_tensor = F.to_tensor(img.convert('RGB'))
 
-        return img_tensor
+        return img_tensor, image_name
 
 
 class CustomDataset(Dataset):
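With __getitem__ now returning a (img_tensor, image_name) tuple per item, the DataLoader has to batch variable-size images without stacking them. Below is a minimal sketch (not part of the patch) of the wiring the new inference loop assumes; the dataset arguments, batch size and collate_fn are illustrative, not taken from the repo.

from torch.utils.data import DataLoader

def collate_fn(batch):
    # [(img_tensor, name), ...] -> ((img_tensor, ...), (name, ...)),
    # so the loop can unpack `for samples, img_names in prog_bar`.
    return tuple(zip(*batch))

inference_dataset = InferenceDataset('data/test_images')  # hypothetical args
inference_loader = DataLoader(
    inference_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)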
diff --git a/inference.py b/inference.py
index 22ed687..702d4ae 100644
--- a/inference.py
+++ b/inference.py
@@ -47,16 +47,14 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
     print(f'Inference on dataset: {dataset_name}')
 
     prog_bar = tqdm(inference_loader, total=len(inference_loader))
-    for samples, targets in prog_bar:
+    for samples, img_names in prog_bar:
         images = list(image.to(DEVICE) for image in samples)
 
-        img_names = [t['image_name'] for t in targets]
-        targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets]
 
         with torch.inference_mode():
             outputs = model(images)
 
-        for image, img_name, output, target in zip(images, img_names, outputs, targets):
+        for image, img_name, output in zip(images, img_names, outputs):
             plot_inference(image, img_name, output, detection_th, dataset_name)
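For reference, plot_inference itself is not shown in this diff. Assuming `model` is a torchvision detection model, each `output` is a dict of 'boxes', 'labels' and 'scores' tensors; the sketch below shows the kind of score filtering that detection_th implies. filter_detections is a hypothetical helper, not a function from the repo.

def filter_detections(output, detection_th):
    # Keep only detections whose confidence reaches the threshold.
    keep = output['scores'] >= detection_th
    return {
        'boxes': output['boxes'][keep],
        'labels': output['labels'][keep],
        'scores': output['scores'][keep],
    }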