"""Evaluate a digit classifier.

Usage:
    digits -e
    digits -e 10
    digits models.features.FeatureClassifier
    digits models.pixels.PixelClassifier
    digits models.mlp.MLPClassifier
    digits models.mlp.MLPClassifier --hidden 64 64
    digits models.mlp.MLPClassifier -a
    digits models.cnn.CNNClassifier --epochs 3
    digits models.cnn.CNNClassifier -a 5
    digits models.cnn.CNNClassifier --save weights/cnn
    digits weights/cnn
    digits weights/cnn --run
"""

import argparse
import importlib

import cli.output as out
import cli.webcam as webcam
from cli.data import load_mnist
from cli.persistence import is_saved_model, load_model, save_model


def load_classifier(class_path: str, **kwargs):
    """Import and instantiate the classifier named by *class_path*.

    Args:
        class_path: Dotted path of the form ``package.module.ClassName``.
        **kwargs: Constructor arguments. Entries whose value is ``None`` are
            dropped so the class's own defaults apply.

    Returns:
        A new instance of the resolved class.

    Raises:
        ValueError: If *class_path* has no module part or no class part.
        ImportError: If the module cannot be imported.
        AttributeError: If the module does not define the class.
    """
    module_path, _, class_name = class_path.rpartition(".")
    # Without this guard a dot-less path (e.g. a mistyped saved-model
    # directory) would surface as an opaque tuple-unpacking error.
    if not (module_path and class_name):
        raise ValueError(
            "Expected a fully-qualified class path such as "
            f"'models.mlp.MLPClassifier', got {class_path!r}"
        )

    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)

    # Options left unset on the command line arrive as None; omit them rather
    # than overriding the constructor's defaults.
    kwargs = {key: value for key, value in kwargs.items() if value is not None}
    return cls(**kwargs)


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for the ``digits`` tool."""
    parser = argparse.ArgumentParser(
        description="Train and evaluate a digit classifier."
    )
    parser.add_argument(
        "classifier",
        nargs="?",
        help="Fully-qualified class (e.g. models.mlp.MLPClassifier), "
        "or the path to a model saved with --save",
    )
    parser.add_argument(
        "-e",
        "--explore",
        type=int,
        nargs="?",
        const=3,
        default=None,
        metavar="N",
        help="Show N sample digits and the label distribution (default: 3)",
    )
    parser.add_argument(
        "-a",
        "--error-analysis",
        type=int,
        nargs="?",
        const=10,
        default=None,
        metavar="N",
        help="Show up to N misclassified digits (default: 10)",
    )
    parser.add_argument(
        "--hidden",
        type=int,
        nargs="+",
        default=None,
        metavar="N",
        help="Hidden layer sizes, e.g. --hidden 128 64 (MLPClassifier only)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=None,
        metavar="N",
        help="Number of training epochs (MLPClassifier and CNNClassifier only)",
    )
    parser.add_argument(
        "--save",
        metavar="DIR",
        help="After training, save the model's configuration and weights to DIR",
    )
    parser.add_argument(
        "--run",
        action="store_true",
        help="Open the webcam and classify handwritten digits live",
    )
    return parser


def main() -> None:
    """Parse the command line, then explore, train/load, and evaluate."""
    parser = _build_parser()
    args = parser.parse_args()

    # Nothing to do without a classifier or an explore request.
    if not args.classifier and args.explore is None:
        parser.print_help()
        return

    X_train, X_test, y_train, y_test = load_mnist()

    if args.explore is not None:
        out.explore(X_train, y_train, args.explore)
    if not args.classifier:
        return

    out.dataset_summary(len(X_train), len(X_test))

    if is_saved_model(args.classifier):
        clf = load_model(args.classifier)
        print(f"Loaded saved model from {args.classifier}\n")
    else:
        # A value that is not a saved model must be a dotted class path;
        # report anything else as a usage error instead of a traceback.
        if "." not in args.classifier:
            parser.error(
                f"{args.classifier!r} is neither a saved model nor a "
                "fully-qualified class path (e.g. models.mlp.MLPClassifier)"
            )
        clf = load_classifier(
            args.classifier,
            hidden_sizes=tuple(args.hidden) if args.hidden else None,
            epochs=args.epochs,
        )
        clf.fit(X_train, y_train)
        # NOTE(review): the original indentation was lost; --save is applied
        # only after training here, matching its help text -- confirm.
        if args.save:
            save_model(clf, args.save)
            print(f"Saved model to {args.save}\n")

    y_pred = clf.predict(X_test)
    out.evaluation(y_test, y_pred, type(clf).__name__)

    if args.error_analysis is not None:
        out.error_analysis(X_test, y_test, y_pred, args.error_analysis)

    if args.run:
        webcam.run(clf)


if __name__ == "__main__":
    main()