Updates
This commit is contained in:
@@ -3,7 +3,7 @@ from sklearn.datasets import fetch_openml
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
def load_mnist(n_train=10000, n_test=2000):
|
||||
def load_mnist(n_train=10000, n_test=2000, full=False):
|
||||
"""Load MNIST from sklearn (downloads on first run).
|
||||
|
||||
For speed, uses a subset of the data by default. Set n_train=60000
|
||||
@@ -18,6 +18,8 @@ def load_mnist(n_train=10000, n_test=2000):
|
||||
X = mnist.data.astype(np.float32) / 255.0
|
||||
y = mnist.target.astype(int)
|
||||
|
||||
if full:
|
||||
n_train, n_test = 60000, 10000
|
||||
return train_test_split(
|
||||
X, y, train_size=n_train, test_size=n_test, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user