fixed the validation output. generation of folder was missing

This commit is contained in:
Till Raab 2023-11-06 10:57:04 +01:00
parent 4e1784a399
commit 61d5a7246b
2 changed files with 6 additions and 0 deletions

View File

@ -31,6 +31,7 @@ training and one for testing (both also stored in ./data/dataset).
* on a long scale: only save raw file bounding boxes in frequency and time (t0, t1, f0, f1) and the hyperparameters of the corresponding spectrogram. USE THESE PARAMETERS IN DATASET_FN. * on a long scale: only save raw file bounding boxes in frequency and time (t0, t1, f0, f1) and the hyperparameters of the corresponding spectrogram. USE THESE PARAMETERS IN DATASET_FN.
* rescale image input to (7, 7) * 256 --> width/height in inch * dpi * rescale image input to (7, 7) * 256 --> width/height in inch * dpi
* when dataset input it spectrogram use resize transform. * when dataset input it spectrogram use resize transform.
* replace datastructure with yolo structure ... per pic 1 .csv saved as .txt
## model.py ## model.py

View File

@ -12,7 +12,9 @@ import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
import time import time
import pathlib
from pathlib import Path from pathlib import Path
from IPython import embed from IPython import embed
def train(train_loader, model, optimizer, train_loss): def train(train_loader, model, optimizer, train_loss):
@ -64,6 +66,8 @@ def best_model_validation_with_plots(test_loader):
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval() model.to(DEVICE).eval()
if not pathlib.Path(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name).exists():
pathlib.Path(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name).mkdir(parents=True, exist_ok=True)
validate_with_plots(test_loader, model) validate_with_plots(test_loader, model)
@ -111,6 +115,7 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold):
) )
ax.set_axis_off() ax.set_axis_off()
plt.savefig(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name/(os.path.splitext(img_name)[0] +'_predicted.png'), dpi=IMG_DPI) plt.savefig(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name/(os.path.splitext(img_name)[0] +'_predicted.png'), dpi=IMG_DPI)
plt.close() plt.close()
# plt.show() # plt.show()