inference rework

This commit is contained in:
Till Raab 2023-10-27 11:15:53 +02:00
parent 43f0e0a27d
commit 3690907f93
2 changed files with 53 additions and 10 deletions

View File

@ -19,6 +19,25 @@ from custom_utils import collate_fn
from IPython import embed
class InferenceDataset(Dataset):
def __init__(self, dir_path):
self.dir_path = dir_path
self.all_images = sorted(list(Path(self.dir_path).rglob(f'*.png')))
def __len__(self):
return len(self.all_images)
def __getitem__(self, idx):
image_name = self.all_images[idx]
image_path = os.path.join(self.dir_path, image_name)
img = Image.open(image_path)
img_tensor = F.to_tensor(img.convert('RGB'))
return img_tensor
class CustomDataset(Dataset):
def __init__(self, dir_path, bbox_df):
self.dir_path = dir_path
@ -57,6 +76,7 @@ class CustomDataset(Dataset):
def __len__(self):
return len(self.all_images)
def create_train_or_test_dataset(path, train=True):
if train == True:
pfx='train'
@ -73,6 +93,7 @@ def create_train_or_test_dataset(path, train=True):
bboxes = pd.read_csv(csv_candidates[0], sep=',', index_col=0)
return CustomDataset(path, bboxes)
def create_train_loader(train_dataset, num_workers=0):
train_loader = DataLoader(
train_dataset,
@ -82,6 +103,8 @@ def create_train_loader(train_dataset, num_workers=0):
collate_fn=collate_fn
)
return train_loader
def create_valid_loader(valid_dataset, num_workers=0):
valid_loader = DataLoader(
valid_dataset,
@ -92,6 +115,16 @@ def create_valid_loader(valid_dataset, num_workers=0):
)
return valid_loader
def create_inference_loader(inference_dataset, num_workers=0)
inference_loader = DataLoader(
inference_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn
)
return inference_loader
if __name__ == '__main__':

View File

@ -4,10 +4,11 @@ import torchvision.transforms.functional as F
import glob
import os
from PIL import Image
import argparse
from model import create_model
from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, DATA_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE
from datasets import create_train_or_test_dataset, create_valid_loader
from datasets import InferenceDataset, create_inference_loader
from IPython import embed
from pathlib import Path
@ -41,11 +42,11 @@ def plot_inference(img_tensor, img_name, output, detection_threshold):
plt.close()
# plt.show()
def infere_model(test_loader, model, detection_th=0.8):
def infere_model(inference_loader, model, detection_th=0.8):
print('Validation')
print('Inference')
prog_bar = tqdm(test_loader, total=len(test_loader))
prog_bar = tqdm(inference_loader, total=len(inference_loader))
for samples, targets in prog_bar:
images = list(image.to(DEVICE) for image in samples)
@ -56,19 +57,21 @@ def infere_model(test_loader, model, detection_th=0.8):
outputs = model(images)
for image, img_name, output, target in zip(images, img_names, outputs, targets):
plot_inference(image, img_name, output, target, detection_th)
plot_inference(image, img_name, output, detection_th)
if __name__ == '__main__':
def main(args):
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
test_loader = create_valid_loader(test_data)
inference_data = InferenceDataset(args.folder)
inference_loader = create_inference_loader(inference_data)
infere_model(test_loader, model)
embed()
quit()
infere_model(inference_loader, model)
# detection_threshold = 0.8
# frame_count = 0
@ -86,4 +89,11 @@ if __name__ == '__main__':
#
# print(len(outputs[0]['boxes']))
# show_sample(img_tensor, outputs, detection_threshold)
# show_sample(img_tensor, outputs, detection_threshold)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
parser.add_argument('folder', type=str, help='folder to infer picutes', default='')
args = parser.parse_args()
main(args)