bf
This commit is contained in:
parent
cbbf839e72
commit
e3c0286d21
16
inference.py
16
inference.py
@ -18,7 +18,7 @@ import matplotlib.gridspec as gridspec
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
|
||||
def plot_inference(img_tensor, img_name, output, detection_threshold):
|
||||
def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
|
||||
|
||||
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
|
||||
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
|
||||
@ -38,13 +38,13 @@ def plot_inference(img_tensor, img_name, output, detection_threshold):
|
||||
)
|
||||
|
||||
ax.set_axis_off()
|
||||
plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
|
||||
plt.savefig(Path(INFERENCE_OUTDIR)/dataset_name/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
|
||||
plt.close()
|
||||
# plt.show()
|
||||
|
||||
def infere_model(inference_loader, model, detection_th=0.8):
|
||||
def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
|
||||
|
||||
print('Inference')
|
||||
print(f'Inference on dataset: {dataset_name}')
|
||||
|
||||
prog_bar = tqdm(inference_loader, total=len(inference_loader))
|
||||
for samples, targets in prog_bar:
|
||||
@ -57,7 +57,7 @@ def infere_model(inference_loader, model, detection_th=0.8):
|
||||
outputs = model(images)
|
||||
|
||||
for image, img_name, output, target in zip(images, img_names, outputs, targets):
|
||||
plot_inference(image, img_name, output, detection_th)
|
||||
plot_inference(image, img_name, output, detection_th, dataset_name)
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -69,9 +69,9 @@ def main(args):
|
||||
inference_data = InferenceDataset(args.folder)
|
||||
inference_loader = create_inference_loader(inference_data)
|
||||
|
||||
embed()
|
||||
quit()
|
||||
infere_model(inference_loader, model)
|
||||
dataset_name = Path(args.folder).name
|
||||
|
||||
infere_model(inference_loader, model, dataset_name)
|
||||
|
||||
# detection_threshold = 0.8
|
||||
# frame_count = 0
|
||||
|
Loading…
Reference in New Issue
Block a user