bf
This commit is contained in:
parent
f3e88078c4
commit
323d79b27a
@ -29,10 +29,11 @@ class InferenceDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image_path = self.all_images[idx]
|
||||
image_name = image_path.name
|
||||
img = Image.open(image_path)
|
||||
img_tensor = F.to_tensor(img.convert('RGB'))
|
||||
|
||||
return img_tensor
|
||||
return img_tensor, image_name
|
||||
|
||||
|
||||
class CustomDataset(Dataset):
|
||||
|
@ -47,16 +47,15 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
|
||||
print(f'Inference on dataset: {dataset_name}')
|
||||
|
||||
prog_bar = tqdm(inference_loader, total=len(inference_loader))
|
||||
for samples, targets in prog_bar:
|
||||
for samples, img_names in prog_bar:
|
||||
images = list(image.to(DEVICE) for image in samples)
|
||||
|
||||
img_names = [t['image_name'] for t in targets]
|
||||
targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets]
|
||||
# img_names = [t['image_name'] for t in targets]
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = model(images)
|
||||
|
||||
for image, img_name, output, target in zip(images, img_names, outputs, targets):
|
||||
for image, img_name, output in zip(images, img_names, outputs):
|
||||
plot_inference(image, img_name, output, detection_th, dataset_name)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user