From e3c0286d21d75eaa9e6186b48e79a9ee810bb0e8 Mon Sep 17 00:00:00 2001
From: Till Raab <till.raab@uni-tuebingen.de>
Date: Fri, 27 Oct 2023 11:20:38 +0200
Subject: [PATCH] bf

---
 inference.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/inference.py b/inference.py
index 4c1f83f..22ed687 100644
--- a/inference.py
+++ b/inference.py
@@ -18,7 +18,7 @@ import matplotlib.gridspec as gridspec
 from matplotlib.patches import Rectangle
 
 
-def plot_inference(img_tensor, img_name, output, detection_threshold):
+def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
 
     fig = plt.figure(figsize=IMG_SIZE, num=img_name)
     gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)  #
@@ -38,13 +38,13 @@ def plot_inference(img_tensor, img_name, output, detection_threshold):
         )
 
     ax.set_axis_off()
-    plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
+    plt.savefig(Path(INFERENCE_OUTDIR)/dataset_name/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
     plt.close()
     # plt.show()
 
-def infere_model(inference_loader, model, detection_th=0.8):
+def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
 
-    print('Inference')
+    print(f'Inference on dataset: {dataset_name}')
 
     prog_bar = tqdm(inference_loader, total=len(inference_loader))
     for samples, targets in prog_bar:
@@ -57,7 +57,7 @@ def infere_model(inference_loader, model, detection_th=0.8):
             outputs = model(images)
 
         for image, img_name, output, target in zip(images, img_names, outputs, targets):
-            plot_inference(image, img_name, output, detection_th)
+            plot_inference(image, img_name, output, detection_th, dataset_name)
 
 
 def main(args):
@@ -69,9 +69,9 @@ def main(args):
     inference_data = InferenceDataset(args.folder)
     inference_loader = create_inference_loader(inference_data)
 
-    embed()
-    quit()
-    infere_model(inference_loader, model)
+    dataset_name = Path(args.folder).name
+
+    infere_model(inference_loader, model, dataset_name)
 
     # detection_threshold = 0.8
     # frame_count = 0