This commit is contained in:
Till Raab 2023-10-27 11:20:38 +02:00
parent cbbf839e72
commit e3c0286d21

View File

@ -18,7 +18,7 @@ import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
def plot_inference(img_tensor, img_name, output, detection_threshold):
def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
@ -38,13 +38,13 @@ def plot_inference(img_tensor, img_name, output, detection_threshold):
)
ax.set_axis_off()
plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
plt.savefig(Path(INFERENCE_OUTDIR)/dataset_name/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
plt.close()
# plt.show()
def infere_model(inference_loader, model, detection_th=0.8):
def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
print('Inference')
print(f'Inference on dataset: {dataset_name}')
prog_bar = tqdm(inference_loader, total=len(inference_loader))
for samples, targets in prog_bar:
@ -57,7 +57,7 @@ def infere_model(inference_loader, model, detection_th=0.8):
outputs = model(images)
for image, img_name, output, target in zip(images, img_names, outputs, targets):
plot_inference(image, img_name, output, detection_th)
plot_inference(image, img_name, output, detection_th, dataset_name)
def main(args):
@ -69,9 +69,9 @@ def main(args):
inference_data = InferenceDataset(args.folder)
inference_loader = create_inference_loader(inference_data)
embed()
quit()
infere_model(inference_loader, model)
dataset_name = Path(args.folder).name
infere_model(inference_loader, model, dataset_name)
# detection_threshold = 0.8
# frame_count = 0