From 61d5a7246b21a4810f9031c1b4a4c0fc839d7589 Mon Sep 17 00:00:00 2001 From: Till Raab Date: Mon, 6 Nov 2023 10:57:04 +0100 Subject: [PATCH] fixed the validation output. generation of folder was missing --- README.md | 1 + train.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/README.md b/README.md index ec92bd2..6dcc87e 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ training and one for testing (both also stored in ./data/dataset). * on a long scale: only save raw file bounding boxes in frequency and time (t0, t1, f0, f1) and the hyperparameters of the corresponding spectrogram. USE THESE PARAMETERS IN DATASET_FN. * rescale image input to (7, 7) * 256 --> width/height in inch * dpi * when dataset input it spectrogram use resize transform. +* replace datastructure with yolo structure ... per pic 1 .csv saved as .txt ## model.py diff --git a/train.py b/train.py index 4a99a3e..56b7389 100644 --- a/train.py +++ b/train.py @@ -12,7 +12,9 @@ import matplotlib.gridspec as gridspec from matplotlib.patches import Rectangle import time +import pathlib from pathlib import Path + from IPython import embed def train(train_loader, model, optimizer, train_loss): @@ -64,6 +66,8 @@ def best_model_validation_with_plots(test_loader): model.load_state_dict(checkpoint["model_state_dict"]) model.to(DEVICE).eval() + if not pathlib.Path(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name).exists(): + pathlib.Path(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name).mkdir(parents=True, exist_ok=True) validate_with_plots(test_loader, model) @@ -111,6 +115,7 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold): ) ax.set_axis_off() + plt.savefig(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name/(os.path.splitext(img_name)[0] +'_predicted.png'), dpi=IMG_DPI) plt.close() # plt.show()