diff --git a/inference.py b/inference.py index 4c1f83f..22ed687 100644 --- a/inference.py +++ b/inference.py @@ -18,7 +18,7 @@ import matplotlib.gridspec as gridspec from matplotlib.patches import Rectangle -def plot_inference(img_tensor, img_name, output, detection_threshold): +def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name): fig = plt.figure(figsize=IMG_SIZE, num=img_name) gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) # @@ -38,13 +38,13 @@ def plot_inference(img_tensor, img_name, output, detection_threshold): ) ax.set_axis_off() - plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI) + plt.savefig(Path(INFERENCE_OUTDIR)/dataset_name/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI) plt.close() # plt.show() -def infere_model(inference_loader, model, detection_th=0.8): +def infere_model(inference_loader, model, dataset_name, detection_th=0.8): - print('Inference') + print(f'Inference on dataset: {dataset_name}') prog_bar = tqdm(inference_loader, total=len(inference_loader)) for samples, targets in prog_bar: @@ -57,7 +57,7 @@ def infere_model(inference_loader, model, detection_th=0.8): outputs = model(images) for image, img_name, output, target in zip(images, img_names, outputs, targets): - plot_inference(image, img_name, output, detection_th) + plot_inference(image, img_name, output, detection_th, dataset_name) def main(args): @@ -69,9 +69,9 @@ def main(args): inference_data = InferenceDataset(args.folder) inference_loader = create_inference_loader(inference_data) - embed() - quit() - infere_model(inference_loader, model) + dataset_name = Path(args.folder).name + + infere_model(inference_loader, model, dataset_name) # detection_threshold = 0.8 # frame_count = 0