diff --git a/inference.py b/inference.py index 9615779..b840c0f 100644 --- a/inference.py +++ b/inference.py @@ -19,22 +19,28 @@ from matplotlib.patches import Rectangle def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name): - + # embed() + # quit() fig = plt.figure(figsize=IMG_SIZE, num=img_name) gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) # ax = fig.add_subplot(gs[0, 0]) - ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto', cmap='afmhot') + # ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto', cmap='afmhot') + ax.imshow(img_tensor.cpu().squeeze()[0], aspect='auto', cmap='afmhot', vmin=.2) + for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): + # embed() + # quit() if score < detection_threshold: continue # print(x0, y0, x1, y1, l) - ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white') + # print(score) + ax.text(x0 + (x1 - x0) / 2, y0, f'{score:.2f}', ha='center', va='bottom', fontsize=12, color='tab:gray', rotation=90) ax.add_patch( Rectangle((x0, y0), (x1 - x0), (y1 - y0), - fill=False, color="tab:gray", linestyle='-', linewidth=2, zorder=10) + fill=False, color="tab:gray", linestyle='-', linewidth=1, zorder=10, alpha=0.8) ) ax.set_axis_off() @@ -42,7 +48,7 @@ def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_na plt.close() # plt.show() -def infere_model(inference_loader, model, dataset_name, detection_th=0.8): +def infere_model(inference_loader, model, dataset_name, detection_th=0.8, figures_only=False): print(f'Inference on dataset: {dataset_name}') @@ -75,8 +81,9 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8): yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height, score]) - label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt') - np.savetxt(label_path, yolo_labels) + if not figures_only: + label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt') + np.savetxt(label_path, yolo_labels) plot_inference(image, img_name, output, detection_th, dataset_name) @@ -96,10 +103,11 @@ def main(args): if not (Path(INFERENCE_OUTDIR)/dataset_name).exists(): Path(Path(INFERENCE_OUTDIR)/dataset_name).mkdir(parents=True, exist_ok=True) - infere_model(inference_loader, model, dataset_name) + infere_model(inference_loader, model, dataset_name, figures_only=args.figures_only) - if (Path('data').absolute() / dataset_name / 'file_dict.csv').exists(): - (Path('data').absolute() / dataset_name / 'file_dict.csv').unlink() + if not args.figures_only: + if (Path('data').absolute() / dataset_name / 'file_dict.csv').exists(): + (Path('data').absolute() / dataset_name / 'file_dict.csv').unlink() @@ -124,6 +132,7 @@ def main(args): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') parser.add_argument('folder', type=str, help='folder to infer picutes', default='') + parser.add_argument('-f', '--figures_only', action='store_true', help='only generate figures. keek possible existing labels') args = parser.parse_args() main(args)