check if model still finds stuff
This commit is contained in:
		
							parent
							
								
									e1a97ac493
								
							
						
					
					
						commit
						a98f488813
					
				
							
								
								
									
										16
									
								
								inference.py
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								inference.py
									
									
									
									
									
								
							| @ -6,7 +6,8 @@ import os | ||||
| from PIL import Image | ||||
| 
 | ||||
| from model import create_model | ||||
| from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR | ||||
| from confic import NUM_CLASSES, DEVICE, CLASSES, OUTDIR, TRAIN_DIR | ||||
| from datasets import create_train_or_test_dataset, create_valid_loader | ||||
| 
 | ||||
| from IPython import embed | ||||
| from tqdm.auto import tqdm | ||||
| @ -15,8 +16,6 @@ from matplotlib.patches import Rectangle | ||||
| 
 | ||||
| def show_sample(img_tensor, outputs, detection_threshold): | ||||
| 
 | ||||
|     # embed() | ||||
|     # quit() | ||||
|     fig, ax = plt.subplots() | ||||
|     ax.imshow(img_tensor.squeeze().permute(1, 2, 0), aspect='auto') | ||||
|     for (x0, y0, x1, y1), l, score in zip(outputs[0]['boxes'].cpu(), outputs[0]['labels'].cpu(), outputs[0]['scores'].cpu()): | ||||
| @ -33,21 +32,24 @@ def show_sample(img_tensor, outputs, detection_threshold): | ||||
|         ) | ||||
|     plt.show() | ||||
| 
 | ||||
| def infere_model(test_loader, model, detection_th=0.8) | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     model = create_model(num_classes=NUM_CLASSES) | ||||
|     checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE) | ||||
|     model.load_state_dict(checkpoint["model_state_dict"]) | ||||
|     model.to(DEVICE).eval() | ||||
| 
 | ||||
|     DIR_TEST = 'data/train' | ||||
|     # test_data = create_train_or_test_dataset(TRAIN_DIR, train=False) | ||||
|     # test_loader = create_valid_loader(test_data) | ||||
| 
 | ||||
| 
 | ||||
|     test_images = glob.glob(f"{DIR_TEST}/*.png") | ||||
|     # infere_model(test_loader, model) | ||||
| 
 | ||||
|     detection_threshold = 0.8 | ||||
| 
 | ||||
|     frame_count = 0 | ||||
|     total_fps = 0 | ||||
|     test_images = glob.glob(f"{TRAIN_DIR}/*.png") | ||||
| 
 | ||||
|     for i in tqdm(np.arange(len(test_images))): | ||||
|         image_name = test_images[i].split(os.path.sep)[-1].split('.')[0] | ||||
| @ -60,4 +62,4 @@ if __name__ == '__main__': | ||||
| 
 | ||||
|         print(len(outputs[0]['boxes'])) | ||||
| 
 | ||||
|         show_sample(img_tensor, outputs, detection_threshold) | ||||
|         # show_sample(img_tensor, outputs, detection_threshold) | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user