136 lines
3.7 KiB
Python
136 lines
3.7 KiB
Python
"""Evaluate a digit classifier.
|
|
|
|
Usage:
|
|
digits -e
|
|
digits -e 10
|
|
digits models.features.FeatureClassifier
|
|
digits models.pixels.PixelClassifier
|
|
digits models.mlp.MLPClassifier
|
|
digits models.mlp.MLPClassifier --hidden 64 64
|
|
digits models.mlp.MLPClassifier -a
|
|
digits models.cnn.CNNClassifier --epochs 3
|
|
digits models.cnn.CNNClassifier -a 5
|
|
digits models.cnn.CNNClassifier --save cnn
|
|
digits cnn
|
|
digits cnn --run
|
|
digits models.cnn.CNNClassifier --full
|
|
"""
|
|
|
|
import argparse
|
|
import importlib
|
|
|
|
import cli.output as out
|
|
import cli.webcam as webcam
|
|
from cli.data import load_mnist
|
|
from cli.persistence import is_saved_model, load_model, save_model
|
|
|
|
|
|
def load_classifier(class_path, **kwargs):
    """Import a classifier class by its dotted path and instantiate it.

    Args:
        class_path: Fully-qualified class name such as
            ``models.mlp.MLPClassifier``. Everything before the last dot is
            imported as a module; the part after it is looked up on that
            module.
        **kwargs: Keyword arguments forwarded to the class constructor.
            Entries whose value is ``None`` are dropped before the call, so
            the class falls back to its own defaults for them.

    Returns:
        A new instance of the named class.

    Raises:
        ValueError: If ``class_path`` contains no dot and therefore cannot be
            split into a module and a class name.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    module_path, dot, class_name = class_path.rpartition(".")
    if not dot:
        # A bare name would otherwise fail with an opaque tuple-unpacking
        # error; report what was expected instead.
        raise ValueError(
            f"Invalid classifier {class_path!r}: expected a fully-qualified "
            "class path such as 'models.mlp.MLPClassifier'"
        )

    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)

    # Only pass options that were actually set, so classes that do not accept
    # a given option are not handed it as an explicit None.
    init_kwargs = {key: value for key, value in kwargs.items() if value is not None}
    return cls(**init_kwargs)
|
|
|
|
|
|
def _build_parser():
    """Create the argument parser for the ``digits`` command line."""
    parser = argparse.ArgumentParser(
        description="Train and evaluate a digit classifier."
    )
    parser.add_argument(
        "classifier",
        nargs="?",
        help="Fully-qualified class (e.g. models.mlp.MLPClassifier), "
        "or the path to a model saved with --save",
    )
    parser.add_argument(
        "-e", "--explore",
        type=int,
        nargs="?",
        const=3,
        default=None,
        metavar="N",
        help="Show N sample digits and the label distribution (default: 3)",
    )
    parser.add_argument(
        "-a", "--error-analysis",
        type=int,
        nargs="?",
        const=10,
        default=None,
        metavar="N",
        help="Show up to N misclassified digits (default: 10)",
    )
    parser.add_argument(
        "--hidden",
        type=int,
        nargs="+",
        default=None,
        metavar="N",
        help="Hidden layer sizes, e.g. --hidden 128 64 (MLPClassifier only)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=None,
        metavar="N",
        help="Number of training epochs (MLPClassifier and CNNClassifier only)",
    )
    parser.add_argument(
        "--save",
        metavar="NAME",
        help="After training, save the model to weights/NAME (e.g. --save cnn)",
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Train on the full MNIST dataset (60,000 examples) instead of the default 10,000-example subset",
    )
    parser.add_argument(
        "--run",
        action="store_true",
        help="Open the webcam and classify handwritten digits live",
    )
    return parser


def main(argv=None):
    """Run the digit-classifier command line.

    Parses the arguments, loads the dataset, optionally shows sample digits,
    then either loads a saved model or trains a new classifier, evaluates it
    on the test split, and optionally shows misclassified digits and starts
    the live webcam classifier.

    Args:
        argv: Sequence of command-line argument strings, without the program
            name. ``None`` (the default) makes argparse read ``sys.argv``,
            which keeps the behaviour of calling ``main()`` with no arguments.

    Returns:
        None. Prints the help text and returns early when neither a
        classifier nor ``--explore`` is given.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Nothing to do without a classifier or an explore request.
    if not args.classifier and args.explore is None:
        parser.print_help()
        return

    X_train, X_test, y_train, y_test = load_mnist(full=args.full)

    if args.explore is not None:
        out.explore(X_train, y_train, args.explore)
        # Explore-only invocation: stop before any training or evaluation.
        if not args.classifier:
            return

    out.dataset_summary(len(X_train), len(X_test))

    if is_saved_model(args.classifier):
        clf = load_model(args.classifier)
        print(f"Loaded saved model: {args.classifier}\n")
    else:
        clf = load_classifier(
            args.classifier,
            hidden_sizes=tuple(args.hidden) if args.hidden else None,
            epochs=args.epochs,
        )
        clf.fit(X_train, y_train)
        # NOTE(review): assumed to apply to freshly trained models only, in
        # line with the --save help text ("After training") — confirm.
        if args.save:
            save_model(clf, args.save)
            print(f"Saved model: {args.save}\n")

    y_pred = clf.predict(X_test)

    out.evaluation(y_test, y_pred, type(clf).__name__)

    if args.error_analysis is not None:
        out.error_analysis(X_test, y_test, y_pred, args.error_analysis)

    if args.run:
        webcam.run(clf)
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|