Updates
This commit is contained in:
@@ -3,7 +3,7 @@ from sklearn.datasets import fetch_openml
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
def load_mnist(n_train=10000, n_test=2000):
|
||||
def load_mnist(n_train=10000, n_test=2000, full=False):
|
||||
"""Load MNIST from sklearn (downloads on first run).
|
||||
|
||||
For speed, uses a subset of the data by default. Set n_train=60000
|
||||
@@ -18,6 +18,8 @@ def load_mnist(n_train=10000, n_test=2000):
|
||||
X = mnist.data.astype(np.float32) / 255.0
|
||||
y = mnist.target.astype(int)
|
||||
|
||||
if full:
|
||||
n_train, n_test = 60000, 10000
|
||||
return train_test_split(
|
||||
X, y, train_size=n_train, test_size=n_test, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
22
cli/main.py
22
cli/main.py
@@ -10,9 +10,10 @@ Usage:
|
||||
digits models.mlp.MLPClassifier -a
|
||||
digits models.cnn.CNNClassifier --epochs 3
|
||||
digits models.cnn.CNNClassifier -a 5
|
||||
digits models.cnn.CNNClassifier --save weights/cnn
|
||||
digits weights/cnn
|
||||
digits weights/cnn --run
|
||||
digits models.cnn.CNNClassifier --save cnn
|
||||
digits cnn
|
||||
digits cnn --run
|
||||
digits models.cnn.CNNClassifier --full
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -77,8 +78,13 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
metavar="DIR",
|
||||
help="After training, save the model's configuration and weights to DIR",
|
||||
metavar="NAME",
|
||||
help="After training, save the model to weights/NAME (e.g. --save cnn)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full",
|
||||
action="store_true",
|
||||
help="Train on the full MNIST dataset (60,000 examples) instead of the default 10,000-example subset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run",
|
||||
@@ -91,7 +97,7 @@ def main():
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
X_train, X_test, y_train, y_test = load_mnist()
|
||||
X_train, X_test, y_train, y_test = load_mnist(full=args.full)
|
||||
|
||||
if args.explore is not None:
|
||||
out.explore(X_train, y_train, args.explore)
|
||||
@@ -102,7 +108,7 @@ def main():
|
||||
|
||||
if is_saved_model(args.classifier):
|
||||
clf = load_model(args.classifier)
|
||||
print(f"Loaded saved model from {args.classifier}\n")
|
||||
print(f"Loaded saved model: {args.classifier}\n")
|
||||
else:
|
||||
clf = load_classifier(
|
||||
args.classifier,
|
||||
@@ -112,7 +118,7 @@ def main():
|
||||
clf.fit(X_train, y_train)
|
||||
if args.save:
|
||||
save_model(clf, args.save)
|
||||
print(f"Saved model to {args.save}\n")
|
||||
print(f"Saved model: {args.save}\n")
|
||||
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
|
||||
@@ -54,6 +54,13 @@ def evaluation(y_true, y_pred, clf_name):
|
||||
print(f" {digit}: {acc:.3f} {bar}")
|
||||
print()
|
||||
|
||||
print("Confusion matrix (row=actual, col=predicted):")
|
||||
header = " " + "".join(f"{d:5d}" for d in range(10))
|
||||
print(header)
|
||||
for actual, row in enumerate(cm):
|
||||
print(f" {actual:3d} " + "".join(f"{v:5d}" for v in row))
|
||||
print()
|
||||
|
||||
|
||||
def error_analysis(X, y_true, y_pred, n):
|
||||
errors = [
|
||||
|
||||
@@ -3,16 +3,26 @@ import os
|
||||
import joblib
|
||||
|
||||
MODEL_FILE = "model.joblib"
|
||||
WEIGHTS_DIR = "weights"
|
||||
|
||||
|
||||
def _resolve(name):
|
||||
if name.startswith(WEIGHTS_DIR + os.sep) or name.startswith(WEIGHTS_DIR + "/"):
|
||||
return name
|
||||
return os.path.join(WEIGHTS_DIR, name)
|
||||
|
||||
|
||||
def is_saved_model(path):
|
||||
return os.path.isdir(path) and os.path.exists(os.path.join(path, MODEL_FILE))
|
||||
directory = _resolve(path)
|
||||
return os.path.isdir(directory) and os.path.exists(os.path.join(directory, MODEL_FILE))
|
||||
|
||||
|
||||
def save_model(clf, directory):
|
||||
def save_model(clf, name):
|
||||
directory = _resolve(name)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
joblib.dump(clf, os.path.join(directory, MODEL_FILE))
|
||||
|
||||
|
||||
def load_model(directory):
|
||||
def load_model(path):
|
||||
directory = _resolve(path)
|
||||
return joblib.load(os.path.join(directory, MODEL_FILE))
|
||||
|
||||
@@ -32,6 +32,12 @@ def run(clf):
|
||||
print("Could not open the webcam.")
|
||||
return
|
||||
|
||||
capture.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
|
||||
# Discard the first several frames while the camera warms up
|
||||
for _ in range(10):
|
||||
capture.read()
|
||||
|
||||
print("Hold a handwritten digit up to the camera, inside the box.")
|
||||
print("Press 'q' (with the video window focused) to quit.\n")
|
||||
|
||||
@@ -56,7 +62,7 @@ def run(clf):
|
||||
cv2.putText(frame, label, (left, top - 12), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 200, 0), 2)
|
||||
cv2.imshow(WINDOW_TITLE, frame)
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
if cv2.waitKey(30) & 0xFF == ord("q"):
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user