fixed the validation output. generation of folder was missing
This commit is contained in:
parent
4e1784a399
commit
61d5a7246b
@ -31,6 +31,7 @@ training and one for testing (both also stored in ./data/dataset).
|
|||||||
* on a long scale: only save raw file bounding boxes in frequency and time (t0, t1, f0, f1) and the hyperparameters of the corresponding spectrogram. USE THESE PARAMETERS IN DATASET_FN.
|
* on a long scale: only save raw file bounding boxes in frequency and time (t0, t1, f0, f1) and the hyperparameters of the corresponding spectrogram. USE THESE PARAMETERS IN DATASET_FN.
|
||||||
* rescale image input to (7, 7) * 256 --> width/height in inch * dpi
|
* rescale image input to (7, 7) * 256 --> width/height in inch * dpi
|
||||||
* when dataset input it spectrogram use resize transform.
|
* when dataset input it spectrogram use resize transform.
|
||||||
|
* replace datastructure with yolo structure ... per pic 1 .csv saved as .txt
|
||||||
|
|
||||||
## model.py
|
## model.py
|
||||||
|
|
||||||
|
5
train.py
5
train.py
@ -12,7 +12,9 @@ import matplotlib.gridspec as gridspec
|
|||||||
from matplotlib.patches import Rectangle
|
from matplotlib.patches import Rectangle
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import pathlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
|
|
||||||
def train(train_loader, model, optimizer, train_loss):
|
def train(train_loader, model, optimizer, train_loss):
|
||||||
@ -64,6 +66,8 @@ def best_model_validation_with_plots(test_loader):
|
|||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
model.to(DEVICE).eval()
|
model.to(DEVICE).eval()
|
||||||
|
|
||||||
|
if not pathlib.Path(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name).exists():
|
||||||
|
pathlib.Path(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name).mkdir(parents=True, exist_ok=True)
|
||||||
validate_with_plots(test_loader, model)
|
validate_with_plots(test_loader, model)
|
||||||
|
|
||||||
|
|
||||||
@ -111,6 +115,7 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold):
|
|||||||
)
|
)
|
||||||
|
|
||||||
ax.set_axis_off()
|
ax.set_axis_off()
|
||||||
|
|
||||||
plt.savefig(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name/(os.path.splitext(img_name)[0] +'_predicted.png'), dpi=IMG_DPI)
|
plt.savefig(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name/(os.path.splitext(img_name)[0] +'_predicted.png'), dpi=IMG_DPI)
|
||||||
plt.close()
|
plt.close()
|
||||||
# plt.show()
|
# plt.show()
|
||||||
|
Loading…
Reference in New Issue
Block a user