diff --git a/inference.py b/inference.py
index fafb828..558aaea 100644
--- a/inference.py
+++ b/inference.py
@@ -14,8 +14,8 @@ from tqdm.auto import tqdm
 import matplotlib.pyplot as plt
 from matplotlib.patches import Rectangle
 
-def plot_inference(img_tensor, output, target, detection_threshold):
-    fig, ax = plt.subplots(figsize=(7, 7), num=target['image_id'])
+def plot_inference(img_tensor, img_name, output, target, detection_threshold):
+    fig, ax = plt.subplots(figsize=(7, 7), num=img_name)
     ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
     for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
         if score < detection_threshold:
@@ -46,15 +46,14 @@ def infere_model(test_loader, model, detection_th=0.8):
     for samples, targets in prog_bar:
         images = list(image.to(DEVICE) for image in samples)
 
-        embed()
-        quit()
-        targets = [{k: v for k, v in t.items()} for t in targets]
+        img_names = [t['image_name'] for t in targets]
+        targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets]
 
         with torch.inference_mode():
             outputs = model(images)
 
-        for image, output, target in zip(images, outputs, targets):
-            plot_inference(image, output, target, detection_th)
+        for image, img_name, output, target in zip(images, img_names, outputs, targets):
+            plot_inference(image, img_name, output, target, detection_th)
 
 
 if __name__ == '__main__':
diff --git a/train.py b/train.py
index b7bc197..9061287 100644
--- a/train.py
+++ b/train.py
@@ -19,7 +19,8 @@ def train(train_loader, model, optimizer, train_loss):
 
     for samples, targets in prog_bar:
         images = list(image.to(DEVICE) for image in samples)
-        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+        # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+        targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
 
         loss_dict = model(images, targets)
 
@@ -43,7 +44,9 @@ def validate(test_loader, model, val_loss):
 
     for samples, targets in prog_bar:
         images = list(image.to(DEVICE) for image in samples)
-        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+        # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
+        targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
+
         with torch.inference_mode():
             loss_dict = model(images, targets)
 
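
Note on the pattern above (not part of the patch): torchvision-style detection models expect each target dict to hold only tensor values ('boxes', 'labels', ...), so a string entry such as 'image_name' must be stripped before the per-key .to(DEVICE) call, which does not exist on str. A minimal sketch of the same filtering as a reusable helper; split_targets is a hypothetical name, and DEVICE mirrors the constant these scripts already use:

import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def split_targets(targets, meta_key='image_name'):
    """Separate per-image metadata from tensor-only target dicts.

    Returns the metadata list and the targets with only tensor values,
    moved to DEVICE so they match the images passed to the model.
    """
    meta = [t.get(meta_key) for t in targets]
    clean = [{k: v.to(DEVICE) for k, v in t.items() if k != meta_key}
             for t in targets]
    return meta, clean

# Usage sketch, mirroring the loops in train.py / inference.py:
#   img_names, targets = split_targets(targets)
#   loss_dict = model(images, targets)  # targets now contain tensors only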