diff --git a/README.md b/README.md index 912d513..3caaf16 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,8 @@ gradient tracking) is computed and used to infer whether the model is better tha of the previous epochs. If the new model is the best model, the model.state_dict is saved in ./model_outputs as best_model.pth. +## ToDos: +* ### inference.py diff --git a/train.py b/train.py index ca57852..16380a1 100644 --- a/train.py +++ b/train.py @@ -69,7 +69,7 @@ if __name__ == '__main__': test_data = create_train_or_test_dataset(TRAIN_DIR, train=False) train_loader = create_train_loader(train_data) - test_loader = create_train_loader(test_data) + test_loader = create_valid_loader(test_data) model = create_model(num_classes=NUM_CLASSES) model = model.to(DEVICE)