From 3690907f93709cf0930bd43498ec38023f6e7917 Mon Sep 17 00:00:00 2001 From: Till Raab Date: Fri, 27 Oct 2023 11:15:53 +0200 Subject: [PATCH] inference rework --- datasets.py | 33 +++++++++++++++++++++++++++++++++ inference.py | 30 ++++++++++++++++++++---------- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/datasets.py b/datasets.py index 8fc3020..d171391 100644 --- a/datasets.py +++ b/datasets.py @@ -19,6 +19,25 @@ from custom_utils import collate_fn from IPython import embed + +class InferenceDataset(Dataset): + def __init__(self, dir_path): + self.dir_path = dir_path + self.all_images = sorted(list(Path(self.dir_path).rglob(f'*.png'))) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, idx): + image_name = self.all_images[idx] + image_path = os.path.join(self.dir_path, image_name) + + img = Image.open(image_path) + img_tensor = F.to_tensor(img.convert('RGB')) + + return img_tensor + + class CustomDataset(Dataset): def __init__(self, dir_path, bbox_df): self.dir_path = dir_path @@ -57,6 +76,7 @@ class CustomDataset(Dataset): def __len__(self): return len(self.all_images) + def create_train_or_test_dataset(path, train=True): if train == True: pfx='train' @@ -73,6 +93,7 @@ def create_train_or_test_dataset(path, train=True): bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0) return CustomDataset(path, bboxes) + def create_train_loader(train_dataset, num_workers=0): train_loader = DataLoader( train_dataset, @@ -82,6 +103,8 @@ def create_train_loader(train_dataset, num_workers=0): collate_fn=collate_fn ) return train_loader + + def create_valid_loader(valid_dataset, num_workers=0): valid_loader = DataLoader( valid_dataset, @@ -92,6 +115,16 @@ def create_valid_loader(valid_dataset, num_workers=0): ) return valid_loader +def create_inference_loader(inference_dataset, num_workers=0) + inference_loader = DataLoader( + inference_dataset, + batch_size=BATCH_SIZE, + shuffle=False, + num_workers=num_workers, + collate_fn=collate_fn + ) + return inference_loader + if __name__ == '__main__': diff --git a/inference.py b/inference.py index fa3d22c..4c1f83f 100644 --- a/inference.py +++ b/inference.py @@ -4,10 +4,11 @@ import torchvision.transforms.functional as F import glob import os from PIL import Image +import argparse from model import create_model from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, DATA_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE -from datasets import create_train_or_test_dataset, create_valid_loader +from datasets import InferenceDataset, create_inference_loader from IPython import embed from pathlib import Path @@ -41,11 +42,11 @@ def plot_inference(img_tensor, img_name, output, detection_threshold): plt.close() # plt.show() -def infere_model(test_loader, model, detection_th=0.8): +def infere_model(inference_loader, model, detection_th=0.8): - print('Validation') + print('Inference') - prog_bar = tqdm(test_loader, total=len(test_loader)) + prog_bar = tqdm(inference_loader, total=len(inference_loader)) for samples, targets in prog_bar: images = list(image.to(DEVICE) for image in samples) @@ -56,19 +57,21 @@ def infere_model(test_loader, model, detection_th=0.8): outputs = model(images) for image, img_name, output, target in zip(images, img_names, outputs, targets): - plot_inference(image, img_name, output, target, detection_th) + plot_inference(image, img_name, output, detection_th) -if __name__ == '__main__': +def main(args): model = create_model(num_classes=NUM_CLASSES) checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE) model.load_state_dict(checkpoint["model_state_dict"]) model.to(DEVICE).eval() - test_data = create_train_or_test_dataset(DATA_DIR, train=False) - test_loader = create_valid_loader(test_data) + inference_data = InferenceDataset(args.folder) + inference_loader = create_inference_loader(inference_data) - infere_model(test_loader, model) + embed() + quit() + infere_model(inference_loader, model) # detection_threshold = 0.8 # frame_count = 0 @@ -86,4 +89,11 @@ if __name__ == '__main__': # # print(len(outputs[0]['boxes'])) - # show_sample(img_tensor, outputs, detection_threshold) \ No newline at end of file + # show_sample(img_tensor, outputs, detection_threshold) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') + parser.add_argument('folder', type=str, help='folder to infer picutes', default='') + args = parser.parse_args() + + main(args)