Revisions
This commit is contained in:
58
cli/main.py
58
cli/main.py
@@ -2,20 +2,26 @@
|
||||
|
||||
Usage:
|
||||
digits -e
|
||||
digits models.handpicked.HandPickedClassifier
|
||||
digits -e 10
|
||||
digits models.features.FeatureClassifier
|
||||
digits models.pixels.PixelClassifier
|
||||
digits models.mlp.MLPClassifier
|
||||
digits models.mlp.MLPClassifier --hidden 64 64
|
||||
digits models.mlp.MLPClassifier -a
|
||||
digits models.cnn.CNNClassifier --epochs 3
|
||||
digits models.cnn.CNNClassifier -a 5
|
||||
digits models.cnn.CNNClassifier --save weights/cnn
|
||||
digits weights/cnn
|
||||
digits weights/cnn --run
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
import cli.output as out
|
||||
import cli.webcam as webcam
|
||||
from cli.data import load_mnist
|
||||
from cli.persistence import is_saved_model, load_model, save_model
|
||||
|
||||
|
||||
def load_classifier(class_path, **kwargs):
|
||||
@@ -33,12 +39,17 @@ def main():
|
||||
parser.add_argument(
|
||||
"classifier",
|
||||
nargs="?",
|
||||
help="Fully-qualified class, e.g. models.mlp.MLPClassifier",
|
||||
help="Fully-qualified class (e.g. models.mlp.MLPClassifier), "
|
||||
"or the path to a model saved with --save",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e", "--explore",
|
||||
action="store_true",
|
||||
help="Show sample digits and the label distribution",
|
||||
type=int,
|
||||
nargs="?",
|
||||
const=3,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="Show N sample digits and the label distribution (default: 3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--error-analysis",
|
||||
@@ -64,27 +75,45 @@ def main():
|
||||
metavar="N",
|
||||
help="Number of training epochs (MLPClassifier and CNNClassifier only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
metavar="DIR",
|
||||
help="After training, save the model's configuration and weights to DIR",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run",
|
||||
action="store_true",
|
||||
help="Open the webcam and classify handwritten digits live",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.classifier and not args.explore:
|
||||
if not args.classifier and args.explore is None:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
X_train, X_test, y_train, y_test = load_mnist()
|
||||
|
||||
if args.explore:
|
||||
out.explore(X_train, y_train)
|
||||
if args.explore is not None:
|
||||
out.explore(X_train, y_train, args.explore)
|
||||
if not args.classifier:
|
||||
return
|
||||
|
||||
out.dataset_summary(len(X_train), len(X_test))
|
||||
|
||||
clf = load_classifier(
|
||||
args.classifier,
|
||||
hidden_sizes=tuple(args.hidden) if args.hidden else None,
|
||||
epochs=args.epochs,
|
||||
)
|
||||
clf.fit(X_train, y_train)
|
||||
if is_saved_model(args.classifier):
|
||||
clf = load_model(args.classifier)
|
||||
print(f"Loaded saved model from {args.classifier}\n")
|
||||
else:
|
||||
clf = load_classifier(
|
||||
args.classifier,
|
||||
hidden_sizes=tuple(args.hidden) if args.hidden else None,
|
||||
epochs=args.epochs,
|
||||
)
|
||||
clf.fit(X_train, y_train)
|
||||
if args.save:
|
||||
save_model(clf, args.save)
|
||||
print(f"Saved model to {args.save}\n")
|
||||
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
out.evaluation(y_test, y_pred, type(clf).__name__)
|
||||
@@ -92,6 +121,9 @@ def main():
|
||||
if args.error_analysis is not None:
|
||||
out.error_analysis(X_test, y_test, y_pred, args.error_analysis)
|
||||
|
||||
if args.run:
|
||||
webcam.run(clf)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -8,11 +8,11 @@ def show_digit(pixels):
|
||||
print(" ".join("#" if p > 0.5 else "." for p in row))
|
||||
|
||||
|
||||
def explore(X_train, y_train):
|
||||
def explore(X_train, y_train, n=3):
|
||||
print("=" * 60)
|
||||
print("SAMPLE DIGITS")
|
||||
print("=" * 60)
|
||||
for i in range(3):
|
||||
for i in range(n):
|
||||
print(f"\nLabel: {y_train[i]}")
|
||||
show_digit(X_train[i])
|
||||
|
||||
|
||||
18
cli/persistence.py
Normal file
18
cli/persistence.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import os
|
||||
|
||||
import joblib
|
||||
|
||||
MODEL_FILE = "model.joblib"
|
||||
|
||||
|
||||
def is_saved_model(path):
|
||||
return os.path.isdir(path) and os.path.exists(os.path.join(path, MODEL_FILE))
|
||||
|
||||
|
||||
def save_model(clf, directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
joblib.dump(clf, os.path.join(directory, MODEL_FILE))
|
||||
|
||||
|
||||
def load_model(directory):
|
||||
return joblib.load(os.path.join(directory, MODEL_FILE))
|
||||
66
cli/webcam.py
Normal file
66
cli/webcam.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
WINDOW_TITLE = "Hold up a digit -- press q to quit"
|
||||
|
||||
|
||||
def preprocess(region):
|
||||
"""Turn a captured square region into a 784-value array like MNIST's.
|
||||
|
||||
MNIST digits are white strokes on a black background, so after
|
||||
converting to grayscale and shrinking to 28x28, we invert the
|
||||
brightness (ink-on-paper is normally dark-on-light) and scale
|
||||
pixel values down to the [0, 1] range `load_mnist` uses.
|
||||
"""
|
||||
gray = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
|
||||
small = cv2.resize(gray, (28, 28))
|
||||
inverted = 255 - small
|
||||
return (inverted.astype(np.float32) / 255.0).flatten()
|
||||
|
||||
|
||||
def central_square(frame):
|
||||
height, width = frame.shape[:2]
|
||||
size = min(height, width)
|
||||
top = (height - size) // 2
|
||||
left = (width - size) // 2
|
||||
return top, left, size
|
||||
|
||||
|
||||
def run(clf):
|
||||
capture = cv2.VideoCapture(0)
|
||||
if not capture.isOpened():
|
||||
print("Could not open the webcam.")
|
||||
return
|
||||
|
||||
print("Hold a handwritten digit up to the camera, inside the box.")
|
||||
print("Press 'q' (with the video window focused) to quit.\n")
|
||||
|
||||
try:
|
||||
while True:
|
||||
found, frame = capture.read()
|
||||
if not found:
|
||||
break
|
||||
frame = cv2.flip(frame, 1)
|
||||
|
||||
top, left, size = central_square(frame)
|
||||
region = frame[top:top + size, left:left + size]
|
||||
pixels = preprocess(region)
|
||||
|
||||
probabilities = clf.predict_proba([pixels])[0]
|
||||
digit = probabilities.argmax()
|
||||
confidence = probabilities[digit]
|
||||
print(f"\rpredicted digit: {digit} confidence: {confidence:.2f} ", end="", flush=True)
|
||||
|
||||
label = f"{digit} ({confidence:.0%})"
|
||||
cv2.rectangle(frame, (left, top), (left + size, top + size), (0, 200, 0), 2)
|
||||
cv2.putText(frame, label, (left, top - 12), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 200, 0), 2)
|
||||
cv2.imshow(WINDOW_TITLE, frame)
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
capture.release()
|
||||
cv2.destroyAllWindows()
|
||||
print()
|
||||
Reference in New Issue
Block a user