"""Utilities for loading the MNIST handwritten-digit dataset."""
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split


def load_mnist(n_train=10000, n_test=2000, *, random_state=42, verbose=True):
    """Load a stratified random train/test subset of MNIST from OpenML.

    The data is fetched with scikit-learn's ``fetch_openml`` (downloads on
    the first run; later runs read scikit-learn's local data cache).

    For speed, only a subset of the 70,000 images is used by default.
    Passing ``n_train=60000`` and ``n_test=10000`` uses every image, but the
    result is still a *random* stratified split of all rows. It is NOT the
    canonical MNIST partition (first 60,000 rows train / last 10,000 test).

    Args:
        n_train: Training-set size, forwarded to ``train_test_split`` as
            ``train_size``.
        n_test: Test-set size, forwarded to ``train_test_split`` as
            ``test_size``.
        random_state: Seed for the shuffled split (keyword-only). The default
            of 42 reproduces the historical behavior of this function.
        verbose: If True (default), print a progress message before
            fetching (keyword-only).

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``:
        ``X_train``/``X_test`` are float32 arrays of shape (n, 784) with
        pixel values scaled to [0, 1]; ``y_train``/``y_test`` are integer
        arrays of digit labels 0-9.

    Raises:
        ValueError: From ``train_test_split`` if the requested sizes are
            invalid, e.g. ``n_train + n_test`` exceeds the number of samples.
    """
    if verbose:
        print("Loading MNIST (this may take a minute on first run)...")
    mnist = fetch_openml("mnist_784", version=1, as_frame=False)

    # Pixels are stored as 0-255 intensities; scale to [0, 1] in float32
    # to halve memory relative to float64.
    X = mnist.data.astype(np.float32) / 255.0
    # OpenML delivers the labels as strings ('0'..'9'); convert to ints.
    y = mnist.target.astype(int)

    # stratify=y keeps the per-digit class proportions equal in both splits.
    return train_test_split(
        X,
        y,
        train_size=n_train,
        test_size=n_test,
        random_state=random_state,
        stratify=y,
    )