small fix
This commit is contained in:
parent
c2256c7001
commit
4bbf3607e9
@ -67,6 +67,8 @@ gradient tracking) is computed and used to infer whether the model is better tha
|
|||||||
of the previous epochs. If the new model is the best model, the model.state_dict is saved in
|
of the previous epochs. If the new model is the best model, the model.state_dict is saved in
|
||||||
./model_outputs as best_model.pth.
|
./model_outputs as best_model.pth.
|
||||||
|
|
||||||
|
## ToDos:
|
||||||
|
*
|
||||||
|
|
||||||
### inference.py
|
### inference.py
|
||||||
|
|
||||||
|
2
train.py
2
train.py
@ -69,7 +69,7 @@ if __name__ == '__main__':
|
|||||||
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
test_data = create_train_or_test_dataset(TRAIN_DIR, train=False)
|
||||||
|
|
||||||
train_loader = create_train_loader(train_data)
|
train_loader = create_train_loader(train_data)
|
||||||
test_loader = create_train_loader(test_data)
|
test_loader = create_valid_loader(test_data)
|
||||||
|
|
||||||
model = create_model(num_classes=NUM_CLASSES)
|
model = create_model(num_classes=NUM_CLASSES)
|
||||||
model = model.to(DEVICE)
|
model = model.to(DEVICE)
|
||||||
|
Loading…
Reference in New Issue
Block a user