bf
This commit is contained in:
parent
cbbf839e72
commit
e3c0286d21
16
inference.py
16
inference.py
@ -18,7 +18,7 @@ import matplotlib.gridspec as gridspec
|
|||||||
from matplotlib.patches import Rectangle
|
from matplotlib.patches import Rectangle
|
||||||
|
|
||||||
|
|
||||||
def plot_inference(img_tensor, img_name, output, detection_threshold):
|
def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
|
||||||
|
|
||||||
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
|
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
|
||||||
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
|
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
|
||||||
@ -38,13 +38,13 @@ def plot_inference(img_tensor, img_name, output, detection_threshold):
|
|||||||
)
|
)
|
||||||
|
|
||||||
ax.set_axis_off()
|
ax.set_axis_off()
|
||||||
plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
|
plt.savefig(Path(INFERENCE_OUTDIR)/dataset_name/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
|
||||||
plt.close()
|
plt.close()
|
||||||
# plt.show()
|
# plt.show()
|
||||||
|
|
||||||
def infere_model(inference_loader, model, detection_th=0.8):
|
def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
|
||||||
|
|
||||||
print('Inference')
|
print(f'Inference on dataset: {dataset_name}')
|
||||||
|
|
||||||
prog_bar = tqdm(inference_loader, total=len(inference_loader))
|
prog_bar = tqdm(inference_loader, total=len(inference_loader))
|
||||||
for samples, targets in prog_bar:
|
for samples, targets in prog_bar:
|
||||||
@ -57,7 +57,7 @@ def infere_model(inference_loader, model, detection_th=0.8):
|
|||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
|
|
||||||
for image, img_name, output, target in zip(images, img_names, outputs, targets):
|
for image, img_name, output, target in zip(images, img_names, outputs, targets):
|
||||||
plot_inference(image, img_name, output, detection_th)
|
plot_inference(image, img_name, output, detection_th, dataset_name)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@ -69,9 +69,9 @@ def main(args):
|
|||||||
inference_data = InferenceDataset(args.folder)
|
inference_data = InferenceDataset(args.folder)
|
||||||
inference_loader = create_inference_loader(inference_data)
|
inference_loader = create_inference_loader(inference_data)
|
||||||
|
|
||||||
embed()
|
dataset_name = Path(args.folder).name
|
||||||
quit()
|
|
||||||
infere_model(inference_loader, model)
|
infere_model(inference_loader, model, dataset_name)
|
||||||
|
|
||||||
# detection_threshold = 0.8
|
# detection_threshold = 0.8
|
||||||
# frame_count = 0
|
# frame_count = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user