Initial commit

This commit is contained in:
Chris Proctor
2026-06-08 12:27:01 -04:00
commit 395180d6b2
19 changed files with 1461 additions and 0 deletions

97
cli/main.py Normal file
View File

@@ -0,0 +1,97 @@
"""Evaluate a digit classifier.
Usage:
digits -e
digits models.handpicked.HandPickedClassifier
digits models.pixels.PixelClassifier
digits models.mlp.MLPClassifier
digits models.mlp.MLPClassifier --hidden 64 64
digits models.mlp.MLPClassifier -a
digits models.cnn.CNNClassifier --epochs 3
digits models.cnn.CNNClassifier -a 5
"""
import argparse
import importlib
import cli.output as out
from cli.data import load_mnist
def load_classifier(class_path, **kwargs):
module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
cls = getattr(module, class_name)
kwargs = {key: value for key, value in kwargs.items() if value is not None}
return cls(**kwargs)
def main():
parser = argparse.ArgumentParser(
description="Train and evaluate a digit classifier."
)
parser.add_argument(
"classifier",
nargs="?",
help="Fully-qualified class, e.g. models.mlp.MLPClassifier",
)
parser.add_argument(
"-e", "--explore",
action="store_true",
help="Show sample digits and the label distribution",
)
parser.add_argument(
"-a", "--error-analysis",
type=int,
nargs="?",
const=10,
default=None,
metavar="N",
help="Show up to N misclassified digits (default: 10)",
)
parser.add_argument(
"--hidden",
type=int,
nargs="+",
default=None,
metavar="N",
help="Hidden layer sizes, e.g. --hidden 128 64 (MLPClassifier only)",
)
parser.add_argument(
"--epochs",
type=int,
default=None,
metavar="N",
help="Number of training epochs (MLPClassifier and CNNClassifier only)",
)
args = parser.parse_args()
if not args.classifier and not args.explore:
parser.print_help()
return
X_train, X_test, y_train, y_test = load_mnist()
if args.explore:
out.explore(X_train, y_train)
if not args.classifier:
return
out.dataset_summary(len(X_train), len(X_test))
clf = load_classifier(
args.classifier,
hidden_sizes=tuple(args.hidden) if args.hidden else None,
epochs=args.epochs,
)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
out.evaluation(y_test, y_pred, type(clf).__name__)
if args.error_analysis is not None:
out.error_analysis(X_test, y_test, y_pred, args.error_analysis)
if __name__ == "__main__":
main()