diff --git a/train.py b/train.py index db2c584..491149a 100644 --- a/train.py +++ b/train.py @@ -111,7 +111,7 @@ def plot_validation(img_tensor, img_name, output, target, detection_threshold): ) ax.set_axis_off() - plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_inferred.png'), dpi=IMG_DPI) + plt.savefig(Path(INFERENCE_OUTDIR)/(os.path.splitext(img_name)[0] +'_predicted.png'), dpi=IMG_DPI) plt.close() # plt.show()