inference code copied here ... this accually was the validation step with plot output.

This commit is contained in:
Till Raab 2023-10-27 09:33:07 +02:00
parent 0243f560be
commit cc6e97c2c8

View File

@ -1,15 +1,18 @@
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR)
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR, IMG_SIZE, IMG_DPI, INFERENCE_OUTDIR)
from model import create_model
from tqdm.auto import tqdm
from datasets import create_train_loader, create_valid_loader, create_train_or_test_dataset
from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot
from tqdm.auto import tqdm
import os
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
import time
from pathlib import Path
from IPython import embed
def train(train_loader, model, optimizer, train_loss):
@ -18,8 +21,7 @@ def train(train_loader, model, optimizer, train_loss):
prog_bar = tqdm(train_loader, total=len(train_loader))
for samples, targets in prog_bar:
images = list(image.to(DEVICE) for image in samples)
# targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
# img_names = [t['image_name'] for t in targets]
targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
loss_dict = model(images, targets)
@ -43,11 +45,8 @@ def validate(test_loader, model, val_loss):
prog_bar = tqdm(test_loader, total=len(test_loader))
for samples, targets in prog_bar:
images = list(image.to(DEVICE) for image in samples)
# targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
with torch.inference_mode():
loss_dict = model(images, targets)
@ -55,15 +54,67 @@ def validate(test_loader, model, val_loss):
loss_value = losses.item()
val_loss_hist.send(loss_value) # this is a global instance !!!
val_loss.append(loss_value)
# optimizer.zero_grad()
# losses.backward()
# optimizer.step()
prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
return val_loss
def best_model_validation_with_plots(test_loader):
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
validate_with_plots(test_loader, model)
def validate_with_plots(test_loader, model, detection_th=0.8):
print('Final validation with image putput')
prog_bar = tqdm(test_loader, total=len(test_loader))
for samples, targets in prog_bar:
images = list(image.to(DEVICE) for image in samples)
img_names = [t['image_name'] for t in targets]
targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets]
with torch.inference_mode():
outputs = model(images)
for image, img_name, output, target in zip(images, img_names, outputs, targets):
plot_validation(image, img_name, output, target, detection_th)
def plot_validation(img_tensor, img_name, output, target, detection_threshold):
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
ax = fig.add_subplot(gs[0, 0])
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
if score < detection_threshold:
continue
# print(x0, y0, x1, y1, l)
ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white')
ax.add_patch(
Rectangle((x0, y0),
(x1 - x0),
(y1 - y0),
fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10)
)
for (x0, y0, x1, y1), l in zip(target['boxes'], target['labels']):
ax.add_patch(
Rectangle((x0, y0),
(x1 - x0),
(y1 - y0),
fill=False, color="white", linewidth=2, zorder=9)
)
ax.set_axis_off()
plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
plt.close()
# plt.show()
if __name__ == '__main__':
train_data = create_train_or_test_dataset(DATA_DIR)
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
@ -91,6 +142,7 @@ if __name__ == '__main__':
val_loss_hist.reset()
train_loss = train(train_loader, model, optimizer, train_loss)
val_loss = validate(test_loader, model, val_loss)
@ -102,3 +154,5 @@ if __name__ == '__main__':
save_loss_plot(OUTDIR, train_loss, val_loss)
# load best model and perform inference with plot output
best_model_validation_with_plots(test_loader)