import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split


def load_mnist(n_train=10000, n_test=2000, full=False):
    """Fetch MNIST through scikit-learn and hand back a train/test split.

    The data is downloaded on the first run. A reduced subset is used by
    default to keep things fast; request the full dataset either by passing
    ``n_train=60000`` and ``n_test=10000`` or simply ``full=True``.

    Args:
        n_train: Number of samples to place in the training split.
        n_test: Number of samples to place in the test split.
        full: When true, ``n_train`` and ``n_test`` are ignored and replaced
            by 60000 and 10000 respectively.

    Returns:
        ``X_train, X_test, y_train, y_test`` exactly as produced by
        ``train_test_split``:

        * ``X_train`` / ``X_test``: float32 arrays of shape (n, 784) with
          values in [0, 1].
        * ``y_train`` / ``y_test``: int arrays of digit labels 0-9.
    """
    # The ``full`` flag takes precedence over any explicitly passed sizes.
    if full:
        n_train, n_test = 60000, 10000

    print("Loading MNIST (this may take a minute on first run)...")
    dataset = fetch_openml("mnist_784", version=1, as_frame=False)

    # Cast to float32 first so the division keeps the array in float32.
    pixels = dataset.data.astype(np.float32) / 255.0
    labels = dataset.target.astype(int)

    split_options = {
        "train_size": n_train,
        "test_size": n_test,
        "random_state": 42,  # fixed seed -> the same split on every call
        "stratify": labels,  # keep per-digit proportions in both splits
    }
    return train_test_split(pixels, labels, **split_options)