efishSignalDetector/train.py

163 lines
5.5 KiB
Python

from confic import (DEVICE, NUM_CLASSES, NUM_EPOCHS, OUTDIR, NUM_WORKERS, DATA_DIR, IMG_SIZE, IMG_DPI, INFERENCE_OUTDIR)
from model import create_model
from datasets import create_train_loader, create_valid_loader, custom_train_test_split
from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot
from tqdm.auto import tqdm
import os
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
import time
import pathlib
from pathlib import Path
from IPython import embed
def train(train_loader, model, optimizer, train_loss, loss_hist=None):
    """Run one training epoch over ``train_loader``.

    For every batch the images and targets are moved to ``DEVICE``. The model
    is called with ``(images, targets)``, and the values of the returned loss
    dict are summed into one loss. That loss is back-propagated, and the
    optimizer takes one step.

    Args:
        train_loader: Iterable yielding ``(samples, targets)`` batches.
            ``targets`` is a sequence of dicts of tensors. The
            ``'image_name'`` key, if present, is dropped before the forward
            pass.
        model: Model that returns a dict of loss tensors when called with
            targets (presumably a torchvision-style detection model — confirm).
        optimizer: Optimizer updating the model's parameters.
        train_loss: List that receives the scalar loss of every batch. It is
            mutated in place and also returned.
        loss_hist: Running accumulator exposing ``send(value)``. Defaults to
            the module-level ``train_loss_hist`` created in the ``__main__``
            block, so existing four-argument calls keep working.

    Returns:
        The same ``train_loss`` list, extended by one entry per batch.
    """
    if loss_hist is None:
        # Backwards-compatible fallback to the script-level global instance.
        loss_hist = train_loss_hist
    # The loss-dict forward pass below needs training mode; with
    # torchvision-style detection models eval() mode returns detections
    # instead of losses (assumption about the model type — confirm).
    model.train()
    print('Training')
    prog_bar = tqdm(train_loader, total=len(train_loader))
    for samples, targets in prog_bar:
        images = [image.to(DEVICE) for image in samples]
        targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'}
                   for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        loss_hist.send(loss_value)
        train_loss.append(loss_value)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
    return train_loss
def validate(test_loader, model, val_loss, loss_hist=None):
    """Compute the per-batch loss over ``test_loader`` without updating the model.

    The forward pass runs under ``torch.inference_mode()``, so no autograd
    graph is built and no optimizer step is taken. The model is called with
    ``(images, targets)`` and must return a dict of loss tensors. For a
    torchvision-style detection model this presumably means it must still be
    in training mode; the ``__main__`` loop never calls ``eval()`` before
    this (confirm for other models).

    Args:
        test_loader: Iterable yielding ``(samples, targets)`` batches. The
            ``'image_name'`` key is dropped from each target dict.
        model: Model returning a loss dict when given targets.
        val_loss: List that receives the scalar loss of every batch. It is
            mutated in place and also returned.
        loss_hist: Running accumulator exposing ``send(value)``. Defaults to
            the module-level ``val_loss_hist`` created in the ``__main__``
            block, so existing three-argument calls keep working.

    Returns:
        The same ``val_loss`` list, extended by one entry per batch.
    """
    if loss_hist is None:
        # Backwards-compatible fallback to the script-level global instance.
        loss_hist = val_loss_hist
    print('Validation')
    prog_bar = tqdm(test_loader, total=len(test_loader))
    for samples, targets in prog_bar:
        images = [image.to(DEVICE) for image in samples]
        targets = [{k: v.to(DEVICE) for k, v in t.items() if k != 'image_name'}
                   for t in targets]

        with torch.inference_mode():
            loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        loss_hist.send(loss_value)
        val_loss.append(loss_value)

        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
    return val_loss
def best_model_validation_with_plots(test_loader):
    """Reload the best checkpoint and write prediction plots for ``test_loader``.

    The function works in four steps:

    1. Build a fresh model with ``create_model``.
    2. Load ``checkpoint["model_state_dict"]`` from ``{OUTDIR}/best_model.pth``,
       mapped to ``DEVICE``.
    3. Move the model to ``DEVICE`` and switch it to eval mode.
    4. Ensure ``INFERENCE_OUTDIR/<name of DATA_DIR>`` exists, then delegate to
       ``validate_with_plots``.

    Args:
        test_loader: Loader forwarded unchanged to ``validate_with_plots``.
    """
    model = create_model(num_classes=NUM_CLASSES)
    checkpoint = torch.load(f'{OUTDIR}/best_model.pth', map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE).eval()

    # mkdir(exist_ok=True) already tolerates an existing directory, so a
    # separate exists() check is redundant.
    (Path(INFERENCE_OUTDIR) / Path(DATA_DIR).name).mkdir(parents=True, exist_ok=True)

    validate_with_plots(test_loader, model)
def validate_with_plots(test_loader, model, detection_th=0.8):
    """Run ``model`` on every batch of ``test_loader`` and save one plot per image.

    Args:
        test_loader: Iterable of ``(samples, targets)`` batches. Each target
            dict must contain an ``'image_name'`` entry, which names the
            output figure.
        model: Model called as ``model(images)`` without targets. It is
            expected to be in eval mode already, as set by
            ``best_model_validation_with_plots``.
        detection_th: Minimum prediction score for a box to be drawn.
    """
    print('Final validation with image output')
    prog_bar = tqdm(test_loader, total=len(test_loader))
    for samples, targets in prog_bar:
        images = [image.to(DEVICE) for image in samples]
        img_names = [t['image_name'] for t in targets]
        # Ground-truth tensors are not moved to DEVICE; they are only plotted.
        targets = [{k: v for k, v in t.items() if k != 'image_name'} for t in targets]
        with torch.inference_mode():
            outputs = model(images)
        for image, img_name, output, target in zip(images, img_names, outputs, targets):
            plot_validation(image, img_name, output, target, detection_th)
def plot_validation(img_tensor, img_name, output, target, detection_threshold):
    """Draw predicted and ground-truth boxes over one image and save it as PNG.

    Predictions with ``score >= detection_threshold`` are drawn as dashed
    green rectangles annotated with their score. Every ground-truth box is
    drawn as a solid white rectangle underneath them. The figure is written to
    ``INFERENCE_OUTDIR/<name of DATA_DIR>/<img_name without extension>_predicted.png``.

    Args:
        img_tensor: Image tensor, presumably (C, H, W). It is permuted to
            (H, W, C) for ``imshow``; confirm the channel count with the dataset.
        img_name: Image file name, used as the figure id and as the output
            file stem.
        output: Prediction dict with ``'boxes'`` given as (x0, y0, x1, y1),
            plus ``'labels'`` and ``'scores'`` tensors.
        target: Ground-truth dict with ``'boxes'`` and ``'labels'``.
        detection_threshold: Predictions scoring below this are skipped.
    """
    fig = plt.figure(figsize=IMG_SIZE, num=img_name)
    try:
        # One axes spanning the whole canvas, so the saved image has no margins.
        gs = gridspec.GridSpec(1, 1, bottom=0, left=0, right=1, top=1)
        ax = fig.add_subplot(gs[0, 0])
        ax.imshow(img_tensor.cpu().squeeze().permute(1, 2, 0), aspect='auto')

        # Labels are zipped only to keep the original pairing; they are not drawn.
        for (x0, y0, x1, y1), _label, score in zip(output['boxes'].cpu(),
                                                    output['labels'].cpu(),
                                                    output['scores'].cpu()):
            if score < detection_threshold:
                continue
            ax.text(x0, y0, f'{score:.2f}', ha='left', va='bottom', fontsize=12, color='white')
            ax.add_patch(
                Rectangle((x0, y0), x1 - x0, y1 - y0,
                          fill=False, color="tab:green", linestyle='--', linewidth=2, zorder=10)
            )

        for (x0, y0, x1, y1), _label in zip(target['boxes'], target['labels']):
            ax.add_patch(
                Rectangle((x0, y0), x1 - x0, y1 - y0,
                          fill=False, color="white", linewidth=2, zorder=9)
            )

        ax.set_axis_off()
        out_path = (Path(INFERENCE_OUTDIR) / Path(DATA_DIR).name
                    / (os.path.splitext(img_name)[0] + '_predicted.png'))
        fig.savefig(out_path, dpi=IMG_DPI)
    finally:
        # Always release the figure, even when drawing or saving fails, so a
        # long validation run does not accumulate open figures.
        plt.close(fig)
if __name__ == '__main__':
    # Script entry point: split the data, train for NUM_EPOCHS epochs with
    # per-epoch checkpointing, then plot predictions of the best checkpoint.

    # --- data -----------------------------------------------------------
    train_data, test_data = custom_train_test_split()
    train_loader = create_train_loader(train_data)
    test_loader = create_valid_loader(test_data)

    # --- model and optimizer --------------------------------------------
    model = create_model(num_classes=NUM_CLASSES).to(DEVICE)
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)

    # --- bookkeeping ----------------------------------------------------
    # train() and validate() read these two averagers as module-level
    # globals, so they must keep exactly these names.
    train_loss_hist, val_loss_hist = Averager(), Averager()
    train_loss, val_loss = [], []
    save_best_model = SaveBestModel()

    # --- epoch loop -----------------------------------------------------
    for epoch in range(NUM_EPOCHS):
        print(f'\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---')
        train_loss_hist.reset()
        val_loss_hist.reset()

        train_loss = train(train_loader, model, optimizer, train_loss)
        val_loss = validate(test_loader, model, val_loss)

        # Best-so-far checkpoint is keyed on val_loss_hist.value (presumably
        # the epoch's running-average validation loss); the latest epoch and
        # the loss curves are always written.
        save_best_model(val_loss_hist.value, epoch, model, optimizer)
        save_model(epoch, model, optimizer)
        save_loss_plot(OUTDIR, train_loss, val_loss)

    # --- final inference with plots -------------------------------------
    best_model_validation_with_plots(test_loader)