diff --git a/train.py b/train.py index 491149a..4a99a3e 100644 --- a/train.py +++ b/train.py @@ -111,7 +111,7 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold): ) ax.set_axis_off() - plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_predicted.png'), dpi=IMG_DPI) + plt.savefig(Path(INFERENCE_OUTDIR)/Path(DATA_DIR).name/(os.path.splitext(img_name)[0] +'_predicted.png'), dpi=IMG_DPI) plt.close() # plt.show()