diff --git a/confic.py b/confic.py index c08cab2..b3c93de 100644 --- a/confic.py +++ b/confic.py @@ -6,6 +6,9 @@ RESIZE_TO = 416 NUM_EPOCHS = 10 NUM_WORKERS = 4 +IMG_SIZE = (7, 7) # inches +IMG_DPI = 256 + DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') TRAIN_DIR = 'data/train' @@ -15,7 +18,9 @@ CLASSES = ['__backgroud__', '1'] NUM_CLASSES = len(CLASSES) OUTDIR = 'model_outputs' - - if not pathlib.Path(OUTDIR).exists(): - pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True) \ No newline at end of file + pathlib.Path(OUTDIR).mkdir(parents=True, exist_ok=True) + +INFERENCE_OUTDIR = 'inference_outputs' +if not pathlib.Path(INFERENCE_OUTDIR).exists(): + pathlib.Path(INFERENCE_OUTDIR).mkdir(parents=True, exist_ok=True) \ No newline at end of file diff --git a/inference.py b/inference.py index 558aaea..3501da1 100644 --- a/inference.py +++ b/inference.py @@ -6,16 +6,17 @@ import os from PIL import Image from model import create_model -from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR +from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR, INFERENCE_OUTDIR, IMG_DPI, IMG_SIZE from datasets import create_train_or_test_dataset, create_valid_loader from IPython import embed +from pathlib import Path from tqdm.auto import tqdm import matplotlib.pyplot as plt from matplotlib.patches import Rectangle def plot_inference(img_tensor, img_name, output, target, detection_threshold): - fig, ax = plt.subplots(figsize=(7, 7), num=img_name) + fig, ax = plt.subplots(figsize=IMG_SIZE, num=img_name) ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto') for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): if score < detection_threshold: @@ -36,7 +37,12 @@ def plot_inference(img_tensor, img_name, output, target, detection_threshold): fill=False, color="white", linewidth=2, zorder=9) ) - plt.show() + ax.set_axis_off() + embed() + quit() + plt.savefig(Path(INFERENCE_OUTDIR)/img_name/'_inferred.png', IMG_DPI) + plt.close() + # plt.show() def infere_model(test_loader, model, detection_th=0.8):