This commit is contained in:
Chris Proctor
2026-06-22 16:08:23 -04:00
parent 95278c854d
commit 255c189d2f
9 changed files with 111 additions and 68 deletions

View File

@@ -3,7 +3,7 @@ from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
def load_mnist(n_train=10000, n_test=2000):
def load_mnist(n_train=10000, n_test=2000, full=False):
"""Load MNIST from sklearn (downloads on first run).
For speed, uses a subset of the data by default. Set n_train=60000
@@ -18,6 +18,8 @@ def load_mnist(n_train=10000, n_test=2000):
X = mnist.data.astype(np.float32) / 255.0
y = mnist.target.astype(int)
if full:
n_train, n_test = 60000, 10000
return train_test_split(
X, y, train_size=n_train, test_size=n_test, random_state=42, stratify=y
)

View File

@@ -10,9 +10,10 @@ Usage:
digits models.mlp.MLPClassifier -a
digits models.cnn.CNNClassifier --epochs 3
digits models.cnn.CNNClassifier -a 5
digits models.cnn.CNNClassifier --save weights/cnn
digits weights/cnn
digits weights/cnn --run
digits models.cnn.CNNClassifier --save cnn
digits cnn
digits cnn --run
digits models.cnn.CNNClassifier --full
"""
import argparse
@@ -77,8 +78,13 @@ def main():
)
parser.add_argument(
"--save",
metavar="DIR",
help="After training, save the model's configuration and weights to DIR",
metavar="NAME",
help="After training, save the model to weights/NAME (e.g. --save cnn)",
)
parser.add_argument(
"--full",
action="store_true",
help="Train on the full MNIST dataset (60,000 examples) instead of the default 10,000-example subset",
)
parser.add_argument(
"--run",
@@ -91,7 +97,7 @@ def main():
parser.print_help()
return
X_train, X_test, y_train, y_test = load_mnist()
X_train, X_test, y_train, y_test = load_mnist(full=args.full)
if args.explore is not None:
out.explore(X_train, y_train, args.explore)
@@ -102,7 +108,7 @@ def main():
if is_saved_model(args.classifier):
clf = load_model(args.classifier)
print(f"Loaded saved model from {args.classifier}\n")
print(f"Loaded saved model: {args.classifier}\n")
else:
clf = load_classifier(
args.classifier,
@@ -112,7 +118,7 @@ def main():
clf.fit(X_train, y_train)
if args.save:
save_model(clf, args.save)
print(f"Saved model to {args.save}\n")
print(f"Saved model: {args.save}\n")
y_pred = clf.predict(X_test)

View File

@@ -54,6 +54,13 @@ def evaluation(y_true, y_pred, clf_name):
print(f" {digit}: {acc:.3f} {bar}")
print()
print("Confusion matrix (row=actual, col=predicted):")
header = " " + "".join(f"{d:5d}" for d in range(10))
print(header)
for actual, row in enumerate(cm):
print(f" {actual:3d} " + "".join(f"{v:5d}" for v in row))
print()
def error_analysis(X, y_true, y_pred, n):
errors = [

View File

@@ -3,16 +3,26 @@ import os
import joblib
MODEL_FILE = "model.joblib"
WEIGHTS_DIR = "weights"
def _resolve(name):
if name.startswith(WEIGHTS_DIR + os.sep) or name.startswith(WEIGHTS_DIR + "/"):
return name
return os.path.join(WEIGHTS_DIR, name)
def is_saved_model(path):
return os.path.isdir(path) and os.path.exists(os.path.join(path, MODEL_FILE))
directory = _resolve(path)
return os.path.isdir(directory) and os.path.exists(os.path.join(directory, MODEL_FILE))
def save_model(clf, directory):
def save_model(clf, name):
directory = _resolve(name)
os.makedirs(directory, exist_ok=True)
joblib.dump(clf, os.path.join(directory, MODEL_FILE))
def load_model(directory):
def load_model(path):
directory = _resolve(path)
return joblib.load(os.path.join(directory, MODEL_FILE))

View File

@@ -32,6 +32,12 @@ def run(clf):
print("Could not open the webcam.")
return
capture.set(cv2.CAP_PROP_BUFFERSIZE, 1)
# Discard the first several frames while the camera warms up
for _ in range(10):
capture.read()
print("Hold a handwritten digit up to the camera, inside the box.")
print("Press 'q' (with the video window focused) to quit.\n")
@@ -56,7 +62,7 @@ def run(clf):
cv2.putText(frame, label, (left, top - 12), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 200, 0), 2)
cv2.imshow(WINDOW_TITLE, frame)
if cv2.waitKey(1) & 0xFF == ord("q"):
if cv2.waitKey(30) & 0xFF == ord("q"):
break
except KeyboardInterrupt:
pass