This commit is contained in:
Till Raab 2023-10-27 11:20:38 +02:00
parent cbbf839e72
commit e3c0286d21

View File

@ -18,7 +18,7 @@ import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
def plot_inference(img_tensor, img_name, output, detection_threshold): def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
fig = plt.figure(figsize=IMG_SIZE, num=img_name) fig = plt.figure(figsize=IMG_SIZE, num=img_name)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) # gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
@ -38,13 +38,13 @@ def plot_inference(img_tensor, img_name, output, detection_threshold):
) )
ax.set_axis_off() ax.set_axis_off()
plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI) plt.savefig(Path(INFERENCE_OUTDIR)/dataset_name/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
plt.close() plt.close()
# plt.show() # plt.show()
def infere_model(inference_loader, model, detection_th=0.8): def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
print('Inference') print(f'Inference on dataset: {dataset_name}')
prog_bar = tqdm(inference_loader, total=len(inference_loader)) prog_bar = tqdm(inference_loader, total=len(inference_loader))
for samples, targets in prog_bar: for samples, targets in prog_bar:
@ -57,7 +57,7 @@ def infere_model(inference_loader, model, detection_th=0.8):
outputs = model(images) outputs = model(images)
for image, img_name, output, target in zip(images, img_names, outputs, targets): for image, img_name, output, target in zip(images, img_names, outputs, targets):
plot_inference(image, img_name, output, detection_th) plot_inference(image, img_name, output, detection_th, dataset_name)
def main(args): def main(args):
@ -69,9 +69,9 @@ def main(args):
inference_data = InferenceDataset(args.folder) inference_data = InferenceDataset(args.folder)
inference_loader = create_inference_loader(inference_data) inference_loader = create_inference_loader(inference_data)
embed() dataset_name = Path(args.folder).name
quit()
infere_model(inference_loader, model) infere_model(inference_loader, model, dataset_name)
# detection_threshold = 0.8 # detection_threshold = 0.8
# frame_count = 0 # frame_count = 0