Revisions

This commit is contained in:
Chris Proctor
2026-06-08 15:15:52 -04:00
parent 1c99ca8bd3
commit 49c4e43f45
12 changed files with 209 additions and 34 deletions

View File

@@ -2,20 +2,26 @@
Usage:
digits -e
digits models.handpicked.HandPickedClassifier
digits -e 10
digits models.features.FeatureClassifier
digits models.pixels.PixelClassifier
digits models.mlp.MLPClassifier
digits models.mlp.MLPClassifier --hidden 64 64
digits models.mlp.MLPClassifier -a
digits models.cnn.CNNClassifier --epochs 3
digits models.cnn.CNNClassifier -a 5
digits models.cnn.CNNClassifier --save weights/cnn
digits weights/cnn
digits weights/cnn --run
"""
import argparse
import importlib
import cli.output as out
import cli.webcam as webcam
from cli.data import load_mnist
from cli.persistence import is_saved_model, load_model, save_model
def load_classifier(class_path, **kwargs):
@@ -33,12 +39,17 @@ def main():
parser.add_argument(
"classifier",
nargs="?",
help="Fully-qualified class, e.g. models.mlp.MLPClassifier",
help="Fully-qualified class (e.g. models.mlp.MLPClassifier), "
"or the path to a model saved with --save",
)
parser.add_argument(
"-e", "--explore",
action="store_true",
help="Show sample digits and the label distribution",
type=int,
nargs="?",
const=3,
default=None,
metavar="N",
help="Show N sample digits and the label distribution (default: 3)",
)
parser.add_argument(
"-a", "--error-analysis",
@@ -64,27 +75,45 @@ def main():
metavar="N",
help="Number of training epochs (MLPClassifier and CNNClassifier only)",
)
parser.add_argument(
"--save",
metavar="DIR",
help="After training, save the model's configuration and weights to DIR",
)
parser.add_argument(
"--run",
action="store_true",
help="Open the webcam and classify handwritten digits live",
)
args = parser.parse_args()
if not args.classifier and not args.explore:
if not args.classifier and args.explore is None:
parser.print_help()
return
X_train, X_test, y_train, y_test = load_mnist()
if args.explore:
out.explore(X_train, y_train)
if args.explore is not None:
out.explore(X_train, y_train, args.explore)
if not args.classifier:
return
out.dataset_summary(len(X_train), len(X_test))
clf = load_classifier(
args.classifier,
hidden_sizes=tuple(args.hidden) if args.hidden else None,
epochs=args.epochs,
)
clf.fit(X_train, y_train)
if is_saved_model(args.classifier):
clf = load_model(args.classifier)
print(f"Loaded saved model from {args.classifier}\n")
else:
clf = load_classifier(
args.classifier,
hidden_sizes=tuple(args.hidden) if args.hidden else None,
epochs=args.epochs,
)
clf.fit(X_train, y_train)
if args.save:
save_model(clf, args.save)
print(f"Saved model to {args.save}\n")
y_pred = clf.predict(X_test)
out.evaluation(y_test, y_pred, type(clf).__name__)
@@ -92,6 +121,9 @@ def main():
if args.error_analysis is not None:
out.error_analysis(X_test, y_test, y_pred, args.error_analysis)
if args.run:
webcam.run(clf)
if __name__ == "__main__":
main()