Initial commit
This commit is contained in:
0
cli/__init__.py
Normal file
0
cli/__init__.py
Normal file
BIN
cli/__pycache__/__init__.cpython-314.pyc
Normal file
BIN
cli/__pycache__/__init__.cpython-314.pyc
Normal file
Binary file not shown.
BIN
cli/__pycache__/data.cpython-314.pyc
Normal file
BIN
cli/__pycache__/data.cpython-314.pyc
Normal file
Binary file not shown.
BIN
cli/__pycache__/main.cpython-314.pyc
Normal file
BIN
cli/__pycache__/main.cpython-314.pyc
Normal file
Binary file not shown.
BIN
cli/__pycache__/output.cpython-314.pyc
Normal file
BIN
cli/__pycache__/output.cpython-314.pyc
Normal file
Binary file not shown.
23
cli/data.py
Normal file
23
cli/data.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import numpy as np
|
||||
from sklearn.datasets import fetch_openml
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
def load_mnist(n_train=10000, n_test=2000):
    """Fetch MNIST via scikit-learn and return a stratified train/test split.

    The data is retrieved with ``fetch_openml`` (downloaded on first run).
    Only a subset is used by default to keep things fast; pass
    ``n_train=60000`` and ``n_test=10000`` to use the full dataset.

    Args:
        n_train: Number of samples to place in the training split.
        n_test: Number of samples to place in the test split.

    Returns:
        The result of ``train_test_split``, in the order
        ``X_train, X_test, y_train, y_test``:

        * ``X_train`` / ``X_test``: float arrays of shape (n, 784) with
          values in [0, 1].
        * ``y_train`` / ``y_test``: int arrays of digit labels 0-9.
    """
    print("Loading MNIST (this may take a minute on first run)...")
    dataset = fetch_openml("mnist_784", version=1, as_frame=False)

    # Labels are cast to int; pixels are scaled from 0..255 down to 0..1.
    labels = dataset.target.astype(int)
    pixels = dataset.data.astype(np.float32) / 255.0

    # A fixed seed keeps the subset reproducible; stratifying on the labels
    # preserves the class proportions in both splits.
    split = train_test_split(
        pixels,
        labels,
        train_size=n_train,
        test_size=n_test,
        random_state=42,
        stratify=labels,
    )
    return split
|
||||
97
cli/main.py
Normal file
97
cli/main.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Evaluate a digit classifier.
|
||||
|
||||
Usage:
|
||||
digits -e
|
||||
digits models.handpicked.HandPickedClassifier
|
||||
digits models.pixels.PixelClassifier
|
||||
digits models.mlp.MLPClassifier
|
||||
digits models.mlp.MLPClassifier --hidden 64 64
|
||||
digits models.mlp.MLPClassifier -a
|
||||
digits models.cnn.CNNClassifier --epochs 3
|
||||
digits models.cnn.CNNClassifier -a 5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
import cli.output as out
|
||||
from cli.data import load_mnist
|
||||
|
||||
|
||||
def load_classifier(class_path, **kwargs):
    """Import a class from its dotted path and return a new instance of it.

    Args:
        class_path: Fully-qualified class name of the form
            ``"package.module.ClassName"``.
        **kwargs: Keyword arguments for the class constructor. Entries whose
            value is ``None`` are dropped so the class's own defaults apply.

    Returns:
        An instance of the named class.

    Raises:
        ValueError: If ``class_path`` contains no dot, i.e. no module part.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    module_path, dot, class_name = class_path.rpartition(".")
    if not dot:
        # Without this check the failure surfaces as an opaque
        # "not enough values to unpack" error.
        raise ValueError(
            f"Expected a fully-qualified class path like "
            f"'package.module.ClassName', got {class_path!r}"
        )

    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)

    # Only forward options that were actually set.
    supplied = {key: value for key, value in kwargs.items() if value is not None}
    return cls(**supplied)
|
||||
|
||||
|
||||
def _build_parser():
    """Create the argument parser for the digit-classifier command line."""
    cli = argparse.ArgumentParser(
        description="Train and evaluate a digit classifier."
    )
    cli.add_argument(
        "classifier",
        nargs="?",
        help="Fully-qualified class, e.g. models.mlp.MLPClassifier",
    )
    cli.add_argument(
        "-e", "--explore",
        action="store_true",
        help="Show sample digits and the label distribution",
    )
    # nargs="?" with const=10: a bare "-a" means 10, "-a N" means N, and
    # omitting the flag leaves the value as None (no error analysis).
    cli.add_argument(
        "-a", "--error-analysis",
        type=int,
        nargs="?",
        const=10,
        default=None,
        metavar="N",
        help="Show up to N misclassified digits (default: 10)",
    )
    cli.add_argument(
        "--hidden",
        type=int,
        nargs="+",
        default=None,
        metavar="N",
        help="Hidden layer sizes, e.g. --hidden 128 64 (MLPClassifier only)",
    )
    cli.add_argument(
        "--epochs",
        type=int,
        default=None,
        metavar="N",
        help="Number of training epochs (MLPClassifier and CNNClassifier only)",
    )
    return cli


def main():
    """Parse the command line, then explore the data and/or evaluate a model.

    With neither a classifier nor ``--explore`` the help text is printed and
    nothing is loaded. Otherwise MNIST is loaded; ``--explore`` prints sample
    digits and the label distribution, and a classifier path causes that
    class to be instantiated, fitted on the training split, and evaluated on
    the test split, optionally followed by an error analysis.
    """
    parser = _build_parser()
    options = parser.parse_args()

    wants_model = bool(options.classifier)
    if not (wants_model or options.explore):
        parser.print_help()
        return

    train_x, test_x, train_y, test_y = load_mnist()

    if options.explore:
        out.explore(train_x, train_y)
    if not wants_model:
        return

    out.dataset_summary(len(train_x), len(test_x))

    # Unset options are passed as None; load_classifier is given them as
    # keyword arguments alongside the class path.
    hidden_sizes = tuple(options.hidden) if options.hidden else None
    model = load_classifier(
        options.classifier,
        hidden_sizes=hidden_sizes,
        epochs=options.epochs,
    )
    model.fit(train_x, train_y)
    predictions = model.predict(test_x)

    out.evaluation(test_y, predictions, type(model).__name__)

    limit = options.error_analysis
    if limit is not None:
        out.error_analysis(test_x, test_y, predictions, limit)


if __name__ == "__main__":
    main()
|
||||
71
cli/output.py
Normal file
71
cli/output.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
|
||||
|
||||
|
||||
def show_digit(pixels):
    """Print a flattened 28x28 image as ASCII art.

    Each pixel greater than 0.5 is drawn as ``#`` and every other pixel as
    ``.``; the characters of a row are separated by single spaces.

    Args:
        pixels: Array that can be reshaped to (28, 28).
    """
    grid = pixels.reshape(28, 28)
    # Threshold the whole image at once instead of pixel by pixel.
    glyphs = np.where(grid > 0.5, "#", ".")
    for glyph_row in glyphs:
        print(" ".join(glyph_row))
|
||||
|
||||
|
||||
def explore(X_train, y_train):
    """Print up to three sample digits and the label distribution.

    Args:
        X_train: Sequence of flattened images; each shown sample is passed
            to ``show_digit``.
        y_train: Labels aligned with ``X_train``.
    """
    rule = "=" * 60

    print(rule)
    print("SAMPLE DIGITS")
    print(rule)
    # Slicing instead of indexing 0..2 avoids an IndexError when fewer than
    # three samples are supplied.
    for pixels, label in zip(X_train[:3], y_train[:3]):
        print(f"\nLabel: {label}")
        show_digit(pixels)

    unique, counts = np.unique(y_train, return_counts=True)
    print()
    print(rule)
    print("LABEL DISTRIBUTION")
    print(rule)
    total = len(y_train)
    for digit, count in zip(unique, counts):
        print(f" {digit}: {count:5d} ({100 * count / total:.1f}%)")
    print()
|
||||
|
||||
|
||||
def dataset_summary(n_train, n_test):
    """Print a banner with the train/test sizes and their share of the total.

    Args:
        n_train: Number of training samples.
        n_test: Number of test samples.
    """
    divider = "=" * 60
    for banner_line in (divider, "DATASET", divider):
        print(banner_line)

    combined = n_train + n_test
    train_pct = 100 * n_train / combined
    test_pct = 100 * n_test / combined
    print(f"\n train: {n_train} ({train_pct:.0f}%) test: {n_test} ({test_pct:.0f}%)")
    print()
|
||||
|
||||
|
||||
def evaluation(y_true, y_pred, clf_name):
    """Print accuracy, macro-averaged F1 and per-digit accuracy.

    Per-digit accuracy is the diagonal of the 10x10 confusion matrix divided
    by the number of true samples of that digit. A digit with no true samples
    is reported as ``n/a`` rather than dividing by zero.

    Args:
        y_true: Ground-truth digit labels.
        y_pred: Predicted digit labels, aligned with ``y_true``.
        clf_name: Name shown in the results banner.
    """
    rule = "=" * 60
    print(rule)
    print(f"RESULTS: {clf_name}")
    print(rule)
    print()

    accuracy = accuracy_score(y_true, y_pred)
    avg_f1 = f1_score(y_true, y_pred, average="macro")
    print(f" accuracy {accuracy:>10.3f}")
    print(f" average f1 {avg_f1:>10.3f}")
    print()

    # Fixing the label list keeps the matrix 10x10 even when some digit is
    # absent from both y_true and y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(10)))
    correct = cm.diagonal()
    support = cm.sum(axis=1)

    print("Per-digit accuracy:")
    for digit, (hits, total) in enumerate(zip(correct, support)):
        if total == 0:
            # 0/0 would give NaN, and int(NaN * 30) raises ValueError.
            print(f" {digit}: n/a")
            continue
        acc = hits / total
        bar = "+" * int(acc * 30)
        print(f" {digit}: {acc:.3f} {bar}")
    print()
|
||||
|
||||
|
||||
def error_analysis(X, y_true, y_pred, n):
    """Print up to ``n`` misclassified samples with true and predicted labels.

    Args:
        X: Sequence of flattened images, aligned with the labels; each shown
            sample is passed to ``show_digit``.
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        n: Maximum number of misclassified samples to display. Values below
            zero are treated as zero.
    """
    errors = [
        (pixels, truth, guess)
        for pixels, truth, guess in zip(X, y_true, y_pred)
        if truth != guess
    ]
    # Clamp at zero: a negative slice bound would show all but the last |n|
    # errors instead of none.
    shown = errors[:max(n, 0)]

    rule = "=" * 60
    print(rule)
    print(f"ERROR ANALYSIS ({len(shown)} of {len(errors)} misclassified)")
    print(rule)
    for pixels, true_label, pred_label in shown:
        print(f"\n true={true_label} pred={pred_label}")
        show_digit(pixels)
    print()
|
||||
Reference in New Issue
Block a user