inference plots reworked
This commit is contained in:
		
							parent
							
								
									22b05aec76
								
							
						
					
					
						commit
						cc4408d75a
					
				
							
								
								
									
										21
									
								
								inference.py
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								inference.py
									
									
									
									
									
								
							| @ -19,22 +19,28 @@ from matplotlib.patches import Rectangle | ||||
| 
 | ||||
| 
 | ||||
| def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_name): | ||||
| 
 | ||||
|     # embed() | ||||
|     # quit() | ||||
|     fig = plt.figure(figsize=IMG_SIZE, num=img_name) | ||||
|     gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)  # | ||||
|     ax = fig.add_subplot(gs[0, 0]) | ||||
| 
 | ||||
|     ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto', cmap='afmhot') | ||||
|     # ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto', cmap='afmhot') | ||||
|     ax.imshow(img_tensor.cpu().squeeze()[0], aspect='auto', cmap='afmhot', vmin=.2) | ||||
| 
 | ||||
|     for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): | ||||
|         # embed() | ||||
|         # quit() | ||||
|         if score < detection_threshold: | ||||
|             continue | ||||
|     #     print(x0, y0, x1, y1, l) | ||||
|         ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white') | ||||
|     #     print(score) | ||||
|         ax.text(x0 + (x1 - x0) / 2, y0, f'{score:.2f}', ha='center', va='bottom', fontsize=12, color='tab:gray', rotation=90) | ||||
|         ax.add_patch( | ||||
|             Rectangle((x0, y0), | ||||
|                       (x1 - x0), | ||||
|                       (y1 - y0), | ||||
|                       fill=False, color="tab:gray", linestyle='-', linewidth=2, zorder=10) | ||||
|                       fill=False, color="tab:gray", linestyle='-', linewidth=1, zorder=10, alpha=0.8) | ||||
|         ) | ||||
| 
 | ||||
|     ax.set_axis_off() | ||||
| @ -42,7 +48,7 @@ def plot_inference(img_tensor, img_name, output, detection_threshold, dataset_na | ||||
|     plt.close() | ||||
|     # plt.show() | ||||
| 
 | ||||
| def infere_model(inference_loader, model, dataset_name, detection_th=0.8): | ||||
| def infere_model(inference_loader, model, dataset_name, detection_th=0.8, figures_only=False): | ||||
| 
 | ||||
|     print(f'Inference on dataset: {dataset_name}') | ||||
| 
 | ||||
| @ -75,6 +81,7 @@ def infere_model(inference_loader, model, dataset_name, detection_th=0.8): | ||||
| 
 | ||||
|                 yolo_labels.append([1, rel_x_center, rel_y_center, rel_width, rel_height, score]) | ||||
| 
 | ||||
|             if not figures_only: | ||||
|                 label_path = Path('data') / dataset_name / 'labels' / Path(img_name).with_suffix('.txt') | ||||
|                 np.savetxt(label_path, yolo_labels) | ||||
| 
 | ||||
| @ -96,8 +103,9 @@ def main(args): | ||||
|     if not (Path(INFERENCE_OUTDIR)/dataset_name).exists(): | ||||
|         Path(Path(INFERENCE_OUTDIR)/dataset_name).mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
|     infere_model(inference_loader, model, dataset_name) | ||||
|     infere_model(inference_loader, model, dataset_name, figures_only=args.figures_only) | ||||
| 
 | ||||
|     if not args.figures_only: | ||||
|         if (Path('data').absolute() / dataset_name / 'file_dict.csv').exists(): | ||||
|             (Path('data').absolute() / dataset_name / 'file_dict.csv').unlink() | ||||
| 
 | ||||
| @ -124,6 +132,7 @@ def main(args): | ||||
| if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser(description='Evaluated electrode array recordings with multiple fish.') | ||||
|     parser.add_argument('folder', type=str, help='folder to infer picutes', default='') | ||||
|     parser.add_argument('-f', '--figures_only', action='store_true', help='only generate figures. keek possible existing labels') | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
|     main(args) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user