inference code copied here ... this accually was the validation step with plot output.
This commit is contained in:
		
							parent
							
								
									0243f560be
								
							
						
					
					
						commit
						cc6e97c2c8
					
				
							
								
								
									
										82
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										82
									
								
								train.py
									
									
									
									
									
								
							| @ -1,15 +1,18 @@ | |||||||
| from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR) | from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR, IMG_SIZE, IMG_DPI, INFERENCE_OUTDIR) | ||||||
| from model import create_model | from model import create_model | ||||||
| 
 |  | ||||||
| from tqdm.auto import tqdm |  | ||||||
| 
 |  | ||||||
| from datasets import create_train_loader, create_valid_loader, create_train_or_test_dataset | from datasets import create_train_loader, create_valid_loader, create_train_or_test_dataset | ||||||
| from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot | from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot | ||||||
| 
 | 
 | ||||||
|  | from tqdm.auto import tqdm | ||||||
|  | 
 | ||||||
|  | import os | ||||||
| import torch | import torch | ||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
|  | import matplotlib.gridspec as gridspec | ||||||
|  | from matplotlib.patches import Rectangle | ||||||
| import time | import time | ||||||
| 
 | 
 | ||||||
|  | from pathlib import Path | ||||||
| from IPython import embed | from IPython import embed | ||||||
| 
 | 
 | ||||||
| def train(train_loader, model, optimizer, train_loss): | def train(train_loader, model, optimizer, train_loss): | ||||||
| @ -18,8 +21,7 @@ def train(train_loader, model, optimizer, train_loss): | |||||||
|     prog_bar = tqdm(train_loader, total=len(train_loader)) |     prog_bar = tqdm(train_loader, total=len(train_loader)) | ||||||
|     for samples, targets in prog_bar: |     for samples, targets in prog_bar: | ||||||
|         images = list(image.to(DEVICE) for image in samples) |         images = list(image.to(DEVICE) for image in samples) | ||||||
| 
 |         # img_names = [t['image_name'] for t in targets] | ||||||
|         # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] |  | ||||||
|         targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets] |         targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets] | ||||||
| 
 | 
 | ||||||
|         loss_dict = model(images, targets) |         loss_dict = model(images, targets) | ||||||
| @ -43,11 +45,8 @@ def validate(test_loader, model, val_loss): | |||||||
|     prog_bar = tqdm(test_loader, total=len(test_loader)) |     prog_bar = tqdm(test_loader, total=len(test_loader)) | ||||||
|     for samples, targets in prog_bar: |     for samples, targets in prog_bar: | ||||||
|         images = list(image.to(DEVICE) for image in samples) |         images = list(image.to(DEVICE) for image in samples) | ||||||
| 
 |  | ||||||
|         # targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets] |  | ||||||
|         targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets] |         targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets] | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|         with torch.inference_mode(): |         with torch.inference_mode(): | ||||||
|             loss_dict = model(images, targets) |             loss_dict = model(images, targets) | ||||||
| 
 | 
 | ||||||
| @ -55,15 +54,67 @@ def validate(test_loader, model, val_loss): | |||||||
|         loss_value = losses.item() |         loss_value = losses.item() | ||||||
|         val_loss_hist.send(loss_value) # this is a global instance !!! |         val_loss_hist.send(loss_value) # this is a global instance !!! | ||||||
|         val_loss.append(loss_value) |         val_loss.append(loss_value) | ||||||
| 
 |  | ||||||
|         # optimizer.zero_grad() |  | ||||||
|         # losses.backward() |  | ||||||
|         # optimizer.step() |  | ||||||
| 
 |  | ||||||
|         prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") |         prog_bar.set_description(desc=f"Loss: {loss_value:.4f}") | ||||||
| 
 | 
 | ||||||
|     return val_loss |     return val_loss | ||||||
| 
 | 
 | ||||||
|  | def best_model_validation_with_plots(test_loader): | ||||||
|  |     model = create_model(num_classes=NUM_CLASSES) | ||||||
|  |     checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE) | ||||||
|  |     model.load_state_dict(checkpoint["model_state_dict"]) | ||||||
|  |     model.to(DEVICE).eval() | ||||||
|  | 
 | ||||||
|  |     validate_with_plots(test_loader, model) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def validate_with_plots(test_loader, model, detection_th=0.8): | ||||||
|  |     print('Final validation with image putput') | ||||||
|  | 
 | ||||||
|  |     prog_bar = tqdm(test_loader, total=len(test_loader)) | ||||||
|  |     for samples, targets in prog_bar: | ||||||
|  |         images = list(image.to(DEVICE) for image in samples) | ||||||
|  | 
 | ||||||
|  |         img_names = [t['image_name'] for t in targets] | ||||||
|  |         targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets] | ||||||
|  | 
 | ||||||
|  |         with torch.inference_mode(): | ||||||
|  |             outputs = model(images) | ||||||
|  | 
 | ||||||
|  |         for image, img_name, output, target in zip(images, img_names, outputs, targets): | ||||||
|  |             plot_validation(image, img_name, output, target, detection_th) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def plot_validation(img_tensor, img_name, output, target, detection_threshold): | ||||||
|  | 
 | ||||||
|  |     fig = plt.figure(figsize=IMG_SIZE, num=img_name) | ||||||
|  |     gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)  # | ||||||
|  |     ax = fig.add_subplot(gs[0, 0]) | ||||||
|  | 
 | ||||||
|  |     ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0),  aspect='auto') | ||||||
|  |     for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()): | ||||||
|  |         if score < detection_threshold: | ||||||
|  |             continue | ||||||
|  |     #     print(x0, y0, x1, y1, l) | ||||||
|  |         ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white') | ||||||
|  |         ax.add_patch( | ||||||
|  |             Rectangle((x0, y0), | ||||||
|  |                       (x1 - x0), | ||||||
|  |                       (y1 - y0), | ||||||
|  |                       fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10) | ||||||
|  |         ) | ||||||
|  |     for (x0, y0, x1, y1), l in zip(target['boxes'], target['labels']): | ||||||
|  |         ax.add_patch( | ||||||
|  |             Rectangle((x0, y0), | ||||||
|  |                       (x1 - x0), | ||||||
|  |                       (y1 - y0), | ||||||
|  |                       fill=False, color="white", linewidth=2, zorder=9) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     ax.set_axis_off() | ||||||
|  |     plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI) | ||||||
|  |     plt.close() | ||||||
|  |     # plt.show() | ||||||
|  | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     train_data = create_train_or_test_dataset(DATA_DIR) |     train_data = create_train_or_test_dataset(DATA_DIR) | ||||||
|     test_data = create_train_or_test_dataset(DATA_DIR, train=False) |     test_data = create_train_or_test_dataset(DATA_DIR, train=False) | ||||||
| @ -91,6 +142,7 @@ if __name__ == '__main__': | |||||||
|         val_loss_hist.reset() |         val_loss_hist.reset() | ||||||
| 
 | 
 | ||||||
|         train_loss = train(train_loader, model, optimizer, train_loss) |         train_loss = train(train_loader, model, optimizer, train_loss) | ||||||
|  | 
 | ||||||
|         val_loss = validate(test_loader, model, val_loss) |         val_loss = validate(test_loader, model, val_loss) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -102,3 +154,5 @@ if __name__ == '__main__': | |||||||
| 
 | 
 | ||||||
|         save_loss_plot(OUTDIR, train_loss, val_loss) |         save_loss_plot(OUTDIR, train_loss, val_loss) | ||||||
| 
 | 
 | ||||||
|  |     # load best model and perform inference with plot output | ||||||
|  |     best_model_validation_with_plots(test_loader) | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user