Updates
This commit is contained in:
22
cli/main.py
22
cli/main.py
@@ -10,9 +10,10 @@ Usage:
|
||||
digits models.mlp.MLPClassifier -a
|
||||
digits models.cnn.CNNClassifier --epochs 3
|
||||
digits models.cnn.CNNClassifier -a 5
|
||||
digits models.cnn.CNNClassifier --save weights/cnn
|
||||
digits weights/cnn
|
||||
digits weights/cnn --run
|
||||
digits models.cnn.CNNClassifier --save cnn
|
||||
digits cnn
|
||||
digits cnn --run
|
||||
digits models.cnn.CNNClassifier --full
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -77,8 +78,13 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
metavar="DIR",
|
||||
help="After training, save the model's configuration and weights to DIR",
|
||||
metavar="NAME",
|
||||
help="After training, save the model to weights/NAME (e.g. --save cnn)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full",
|
||||
action="store_true",
|
||||
help="Train on the full MNIST dataset (60,000 examples) instead of the default 10,000-example subset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run",
|
||||
@@ -91,7 +97,7 @@ def main():
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
X_train, X_test, y_train, y_test = load_mnist()
|
||||
X_train, X_test, y_train, y_test = load_mnist(full=args.full)
|
||||
|
||||
if args.explore is not None:
|
||||
out.explore(X_train, y_train, args.explore)
|
||||
@@ -102,7 +108,7 @@ def main():
|
||||
|
||||
if is_saved_model(args.classifier):
|
||||
clf = load_model(args.classifier)
|
||||
print(f"Loaded saved model from {args.classifier}\n")
|
||||
print(f"Loaded saved model: {args.classifier}\n")
|
||||
else:
|
||||
clf = load_classifier(
|
||||
args.classifier,
|
||||
@@ -112,7 +118,7 @@ def main():
|
||||
clf.fit(X_train, y_train)
|
||||
if args.save:
|
||||
save_model(clf, args.save)
|
||||
print(f"Saved model to {args.save}\n")
|
||||
print(f"Saved model: {args.save}\n")
|
||||
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user