inference plots reworked

This commit is contained in:
Till Raab 2023-12-04 11:25:56 +01:00
parent 22b05aec76
commit cc4408d75a

View File

@ -19,22 +19,28 @@ from matplotlib.patches import Rectangle
def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name): def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name):
# embed()
# quit()
fig = plt.figure(figsize=IMG_SIZE, num=img_name) fig = plt.figure(figsize=IMG_SIZE, num=img_name)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) # gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
ax = fig.add_subplot(gs[0, 0]) ax = fig.add_subplot(gs[0, 0])
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto', cmap='afmhot') # ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto', cmap='afmhot')
ax.imshow(img_tensor.cpu().squeeze()[0], aspect='auto', cmap='afmhot', vmin=.2)
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
# embed()
# quit()
if score < detection_threshold: if score < detection_threshold:
continue continue
# print(x0, y0, x1, y1, l) # print(x0, y0, x1, y1, l)
ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white') # print(score)
ax.text(x0 + (x1 - x0) / 2, y0, f'{score:.2f}', ha='center', va='bottom', fontsize=12, color='tab:gray', rotation=90)
ax.add_patch( ax.add_patch(
Rectangle((x0, y0), Rectangle((x0, y0),
(x1 - x0), (x1 - x0),
(y1 - y0), (y1 - y0),
fill=False, color="tab:gray", linestyle='-', linewidth=2, zorder=10) fill=False, color="tab:gray", linestyle='-', linewidth=1, zorder=10, alpha=0.8)
) )
ax.set_axis_off() ax.set_axis_off()
@ -42,7 +48,7 @@ def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_na
plt.close() plt.close()
# plt.show() # plt.show()
def infere_model(inference_loader, model, dataset_name, detection_th=0.8): def infere_model(inference_loader, model, dataset_name, detection_th=0.8, figures_only=False):
print(f'Inference on dataset: {dataset_name}') print(f'Inference on dataset: {dataset_name}')
@ -75,6 +81,7 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8):
yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height, score]) yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height, score])
if not figures_only:
label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt') label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt')
np.savetxt(label_path, yolo_labels) np.savetxt(label_path, yolo_labels)
@ -96,8 +103,9 @@ def main(args):
if not (Path(INFERENCE_OUTDIR)/dataset_name).exists(): if not (Path(INFERENCE_OUTDIR)/dataset_name).exists():
Path(Path(INFERENCE_OUTDIR)/dataset_name).mkdir(parents=True, exist_ok=True) Path(Path(INFERENCE_OUTDIR)/dataset_name).mkdir(parents=True, exist_ok=True)
infere_model(inference_loader, model, dataset_name) infere_model(inference_loader, model, dataset_name, figures_only=args.figures_only)
if not args.figures_only:
if (Path('data').absolute() / dataset_name / 'file_dict.csv').exists(): if (Path('data').absolute() / dataset_name / 'file_dict.csv').exists():
(Path('data').absolute() / dataset_name / 'file_dict.csv').unlink() (Path('data').absolute() / dataset_name / 'file_dict.csv').unlink()
@ -124,6 +132,7 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.')
parser.add_argument('folder', type=str, help='folder to infer picutes', default='') parser.add_argument('folder', type=str, help='folder to infer picutes', default='')
parser.add_argument('-f', '--figures_only', action='store_true', help='only generate figures. keek possible existing labels')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)