This commit is contained in:
Chris Proctor
2026-06-22 16:08:23 -04:00
parent 95278c854d
commit 255c189d2f
9 changed files with 111 additions and 68 deletions

View File

@@ -10,9 +10,10 @@ Usage:
digits models.mlp.MLPClassifier -a
digits models.cnn.CNNClassifier --epochs 3
digits models.cnn.CNNClassifier -a 5
digits models.cnn.CNNClassifier --save weights/cnn
digits weights/cnn
digits weights/cnn --run
digits models.cnn.CNNClassifier --save cnn
digits cnn
digits cnn --run
digits models.cnn.CNNClassifier --full
"""
import argparse
@@ -77,8 +78,13 @@ def main():
)
parser.add_argument(
"--save",
metavar="DIR",
help="After training, save the model's configuration and weights to DIR",
metavar="NAME",
help="After training, save the model to weights/NAME (e.g. --save cnn)",
)
parser.add_argument(
"--full",
action="store_true",
help="Train on the full MNIST dataset (60,000 examples) instead of the default 10,000-example subset",
)
parser.add_argument(
"--run",
@@ -91,7 +97,7 @@ def main():
parser.print_help()
return
X_train, X_test, y_train, y_test = load_mnist()
X_train, X_test, y_train, y_test = load_mnist(full=args.full)
if args.explore is not None:
out.explore(X_train, y_train, args.explore)
@@ -102,7 +108,7 @@ def main():
if is_saved_model(args.classifier):
clf = load_model(args.classifier)
print(f"Loaded saved model from {args.classifier}\n")
print(f"Loaded saved model: {args.classifier}\n")
else:
clf = load_classifier(
args.classifier,
@@ -112,7 +118,7 @@ def main():
clf.fit(X_train, y_train)
if args.save:
save_model(clf, args.save)
print(f"Saved model to {args.save}\n")
print(f"Saved model: {args.save}\n")
y_pred = clf.predict(X_test)