inference code copied here ... this accually was the validation step with plot output.
This commit is contained in:
parent
0243f560be
commit
cc6e97c2c8
82
train.py
82
train.py
@ -1,15 +1,18 @@
|
|||||||
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR)
|
from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR, IMG_SIZE, IMG_DPI, INFERENCE_OUTDIR)
|
||||||
from model import create_model
|
from model import create_model
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
from datasets import create_train_loader, create_valid_loader, create_train_or_test_dataset
|
from datasets import create_train_loader, create_valid_loader, create_train_or_test_dataset
|
||||||
from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot
|
from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot
|
||||||
|
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.gridspec as gridspec
|
||||||
|
from matplotlib.patches import Rectangle
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
|
|
||||||
def train(train_loader, model, optimizer, train_loss):
|
def train(train_loader, model, optimizer, train_loss):
|
||||||
@ -18,8 +21,7 @@ def train(train_loader, model, optimizer, train_loss):
|
|||||||
prog_bar = tqdm(train_loader, total=len(train_loader))
|
prog_bar = tqdm(train_loader, total=len(train_loader))
|
||||||
for samples, targets in prog_bar:
|
for samples, targets in prog_bar:
|
||||||
images = list(image.to(DEVICE) for image in samples)
|
images = list(image.to(DEVICE) for image in samples)
|
||||||
|
# img_names = [t['image_name'] for t in targets]
|
||||||
# targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
|
|
||||||
targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
|
targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
|
||||||
|
|
||||||
loss_dict = model(images, targets)
|
loss_dict = model(images, targets)
|
||||||
@ -43,11 +45,8 @@ def validate(test_loader, model, val_loss):
|
|||||||
prog_bar = tqdm(test_loader, total=len(test_loader))
|
prog_bar = tqdm(test_loader, total=len(test_loader))
|
||||||
for samples, targets in prog_bar:
|
for samples, targets in prog_bar:
|
||||||
images = list(image.to(DEVICE) for image in samples)
|
images = list(image.to(DEVICE) for image in samples)
|
||||||
|
|
||||||
# targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
|
|
||||||
targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
|
targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'} for t in targets]
|
||||||
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
loss_dict = model(images, targets)
|
loss_dict = model(images, targets)
|
||||||
|
|
||||||
@ -55,15 +54,67 @@ def validate(test_loader, model, val_loss):
|
|||||||
loss_value = losses.item()
|
loss_value = losses.item()
|
||||||
val_loss_hist.send(loss_value) # this is a global instance !!!
|
val_loss_hist.send(loss_value) # this is a global instance !!!
|
||||||
val_loss.append(loss_value)
|
val_loss.append(loss_value)
|
||||||
|
|
||||||
# optimizer.zero_grad()
|
|
||||||
# losses.backward()
|
|
||||||
# optimizer.step()
|
|
||||||
|
|
||||||
prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
|
prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
|
||||||
|
|
||||||
return val_loss
|
return val_loss
|
||||||
|
|
||||||
|
def best_model_validation_with_plots(test_loader):
|
||||||
|
model = create_model(num_classes=NUM_CLASSES)
|
||||||
|
checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
model.to(DEVICE).eval()
|
||||||
|
|
||||||
|
validate_with_plots(test_loader, model)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_with_plots(test_loader, model, detection_th=0.8):
|
||||||
|
print('Final validation with image putput')
|
||||||
|
|
||||||
|
prog_bar = tqdm(test_loader, total=len(test_loader))
|
||||||
|
for samples, targets in prog_bar:
|
||||||
|
images = list(image.to(DEVICE) for image in samples)
|
||||||
|
|
||||||
|
img_names = [t['image_name'] for t in targets]
|
||||||
|
targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets]
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
outputs = model(images)
|
||||||
|
|
||||||
|
for image, img_name, output, target in zip(images, img_names, outputs, targets):
|
||||||
|
plot_validation(image, img_name, output, target, detection_th)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_validation(img_tensor, img_name, output, target, detection_threshold):
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=IMG_SIZE, num=img_name)
|
||||||
|
gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1) #
|
||||||
|
ax = fig.add_subplot(gs[0, 0])
|
||||||
|
|
||||||
|
ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')
|
||||||
|
for (x0, y0, x1, y1), l, score in zip(output['boxes'].cpu(), output['labels'].cpu(), output['scores'].cpu()):
|
||||||
|
if score < detection_threshold:
|
||||||
|
continue
|
||||||
|
# print(x0, y0, x1, y1, l)
|
||||||
|
ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white')
|
||||||
|
ax.add_patch(
|
||||||
|
Rectangle((x0, y0),
|
||||||
|
(x1 - x0),
|
||||||
|
(y1 - y0),
|
||||||
|
fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10)
|
||||||
|
)
|
||||||
|
for (x0, y0, x1, y1), l in zip(target['boxes'], target['labels']):
|
||||||
|
ax.add_patch(
|
||||||
|
Rectangle((x0, y0),
|
||||||
|
(x1 - x0),
|
||||||
|
(y1 - y0),
|
||||||
|
fill=False, color="white", linewidth=2, zorder=9)
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_axis_off()
|
||||||
|
plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI)
|
||||||
|
plt.close()
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train_data = create_train_or_test_dataset(DATA_DIR)
|
train_data = create_train_or_test_dataset(DATA_DIR)
|
||||||
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
test_data = create_train_or_test_dataset(DATA_DIR, train=False)
|
||||||
@ -91,6 +142,7 @@ if __name__ == '__main__':
|
|||||||
val_loss_hist.reset()
|
val_loss_hist.reset()
|
||||||
|
|
||||||
train_loss = train(train_loader, model, optimizer, train_loss)
|
train_loss = train(train_loader, model, optimizer, train_loss)
|
||||||
|
|
||||||
val_loss = validate(test_loader, model, val_loss)
|
val_loss = validate(test_loader, model, val_loss)
|
||||||
|
|
||||||
|
|
||||||
@ -102,3 +154,5 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
save_loss_plot(OUTDIR, train_loss, val_loss)
|
save_loss_plot(OUTDIR, train_loss, val_loss)
|
||||||
|
|
||||||
|
# load best model and perform inference with plot output
|
||||||
|
best_model_validation_with_plots(test_loader)
|
Loading…
Reference in New Issue
Block a user