"""Evaluate a digit classifier. Usage: digits -e digits models.handpicked.HandPickedClassifier digits models.pixels.PixelClassifier digits models.mlp.MLPClassifier digits models.mlp.MLPClassifier --hidden 64 64 digits models.mlp.MLPClassifier -a digits models.cnn.CNNClassifier --epochs 3 digits models.cnn.CNNClassifier -a 5 """ import argparse import importlib import cli.output as out from cli.data import load_mnist def load_classifier(class_path, **kwargs): module_path, class_name = class_path.rsplit(".", 1) module = importlib.import_module(module_path) cls = getattr(module, class_name) kwargs = {key: value for key, value in kwargs.items() if value is not None} return cls(**kwargs) def main(): parser = argparse.ArgumentParser( description="Train and evaluate a digit classifier." ) parser.add_argument( "classifier", nargs="?", help="Fully-qualified class, e.g. models.mlp.MLPClassifier", ) parser.add_argument( "-e", "--explore", action="store_true", help="Show sample digits and the label distribution", ) parser.add_argument( "-a", "--error-analysis", type=int, nargs="?", const=10, default=None, metavar="N", help="Show up to N misclassified digits (default: 10)", ) parser.add_argument( "--hidden", type=int, nargs="+", default=None, metavar="N", help="Hidden layer sizes, e.g. --hidden 128 64 (MLPClassifier only)", ) parser.add_argument( "--epochs", type=int, default=None, metavar="N", help="Number of training epochs (MLPClassifier and CNNClassifier only)", ) args = parser.parse_args() if not args.classifier and not args.explore: parser.print_help() return X_train, X_test, y_train, y_test = load_mnist() if args.explore: out.explore(X_train, y_train) if not args.classifier: return out.dataset_summary(len(X_train), len(X_test)) clf = load_classifier( args.classifier, hidden_sizes=tuple(args.hidden) if args.hidden else None, epochs=args.epochs, ) clf.fit(X_train, y_train) y_pred = clf.predict(X_test) out.evaluation(y_test, y_pred, type(clf).__name__) if args.error_analysis is not None: out.error_analysis(X_test, y_test, y_pred, args.error_analysis) if __name__ == "__main__": main()