84 lines
3.0 KiB
Python
84 lines
3.0 KiB
Python
import time
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader, TensorDataset
|
|
|
|
|
|
class MLP(nn.Module):
    """Fully connected feed-forward network.

    The network is a stack of ``Linear -> ReLU`` pairs, one per entry in
    ``hidden_sizes``, followed by a final ``Linear`` layer that produces one
    raw score (logit) per class. No softmax is applied here.

    Args:
        hidden_sizes: Iterable of hidden-layer widths, in order. An empty
            iterable yields a single ``Linear(input_size, num_classes)``.
        input_size: Number of features in the last dimension of the input.
            Defaults to 784 (presumably flattened 28x28 images -- confirm
            against the caller).
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, hidden_sizes=(128, 64), input_size=784, num_classes=10):
        super().__init__()
        # Consecutive pairs of `widths` are the (in, out) sizes of each hidden layer.
        widths = [input_size, *hidden_sizes]
        layers = []
        for in_features, out_features in zip(widths, widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], num_classes))
        # Kept as `self.net` with the same layer order so existing
        # state_dict keys ("net.0.weight", ...) still match.
        self.net = nn.Sequential(*layers)

    def forward(self, pixels):
        """Return the logits for ``pixels``.

        Args:
            pixels: Float tensor whose last dimension has ``input_size`` features.

        Returns:
            Tensor of logits whose last dimension has ``num_classes`` entries.
        """
        return self.net(pixels)
|
|
|
|
|
|
class MLPClassifier:
    """Train and query an ``MLP`` through a ``fit`` / ``predict`` interface.

    ``fit`` holds out the first tenth of the samples for validation, trains on
    the remainder with Adam and cross-entropy loss, and prints one progress
    line per epoch.

    Args:
        hidden_sizes: Hidden-layer widths, forwarded to ``MLP``.
        epochs: Number of passes over the training split.
        batch_size: Mini-batch size used during training (keyword-only).
        lr: Learning rate passed to Adam (keyword-only).
    """

    # Upper bound on samples per forward pass outside training, so validation
    # and prediction never push the whole dataset through the model at once.
    _EVAL_BATCH_SIZE = 1024

    def __init__(self, hidden_sizes=(128, 64), epochs=10, *, batch_size=64, lr=1e-3):
        self.hidden_sizes = tuple(hidden_sizes)
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr

    def fit(self, X, y):
        """Train a fresh model on ``X`` / ``y`` and return ``self``.

        Args:
            X: Samples, convertible to a float32 tensor (array-like or tensor).
            y: Integer class labels, convertible to an int64 tensor.

        Returns:
            This classifier, so calls can be chained.

        Raises:
            ValueError: If the training split is empty.
        """
        images = self._to_tensor(X, torch.float32)
        labels = self._to_tensor(y, torch.long)
        train_images, train_labels, val_images, val_labels = self._split(images, labels)
        if len(train_images) == 0:
            # Fail early with a clear message instead of dividing by zero
            # when averaging the loss over an empty epoch.
            raise ValueError("fit() requires at least one training sample")

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model = MLP(hidden_sizes=self.hidden_sizes).to(self._device)

        batches = DataLoader(
            TensorDataset(train_images, train_labels),
            batch_size=self.batch_size,
            shuffle=True,
        )
        optimizer = optim.Adam(self._model.parameters(), lr=self.lr)
        loss_fn = nn.CrossEntropyLoss()

        print(f"\nTraining MLP (hidden_sizes={self.hidden_sizes}, epochs={self.epochs})")
        for epoch in range(1, self.epochs + 1):
            # perf_counter is monotonic, so the elapsed time cannot go
            # negative if the wall clock is adjusted mid-epoch.
            t0 = time.perf_counter()
            avg_loss = self._train_one_epoch(batches, optimizer, loss_fn)
            val_accuracy = self._accuracy(val_images, val_labels)
            elapsed = time.perf_counter() - t0
            print(f" epoch {epoch:2d}/{self.epochs} loss={avg_loss:.3f} val_accuracy={val_accuracy:.3f} {elapsed:.1f}s")
        print()
        return self

    @staticmethod
    def _to_tensor(data, dtype):
        """Convert ``data`` to a tensor of ``dtype`` without tracking gradients."""
        if isinstance(data, torch.Tensor):
            # torch.tensor(<tensor>) emits a copy-construct UserWarning and
            # always copies; detach + to() avoids both.
            return data.detach().to(dtype=dtype)
        return torch.tensor(data, dtype=dtype)

    def _fitted_model(self):
        """Return the trained model, or raise if ``fit`` has not been called."""
        model = getattr(self, "_model", None)
        if model is None:
            # AttributeError matches what an unfitted instance raised before.
            raise AttributeError(
                "This MLPClassifier is not fitted yet; call fit() before predicting."
            )
        return model

    def _split(self, images, labels):
        """Split into ``(train_images, train_labels, val_images, val_labels)``.

        The first ``len(images) // 10`` samples become the validation set.
        NOTE: the data is not shuffled first, so input sorted by label gives
        an unrepresentative validation set.
        """
        n_val = len(images) // 10
        return images[n_val:], labels[n_val:], images[:n_val], labels[:n_val]

    def _train_one_epoch(self, batches, optimizer, loss_fn):
        """Run one optimisation pass over ``batches``; return the mean loss per sample."""
        self._model.train()
        total_loss = 0.0
        n_samples = 0
        for image_batch, label_batch in batches:
            image_batch = image_batch.to(self._device)
            label_batch = label_batch.to(self._device)
            optimizer.zero_grad()
            loss = loss_fn(self._model(image_batch), label_batch)
            loss.backward()
            optimizer.step()
            # Weight each batch mean by its size so a short final batch is
            # not over-represented (assumes loss_fn uses mean reduction).
            batch_n = len(label_batch)
            total_loss += loss.item() * batch_n
            n_samples += batch_n
        return total_loss / n_samples

    def _logits(self, images):
        """Return the model's logits for ``images`` as one CPU tensor, computed in chunks."""
        model = self._fitted_model()
        model.eval()
        step = self._EVAL_BATCH_SIZE
        outputs = []
        with torch.no_grad():
            # max(..., 1) guarantees one pass, so empty input still yields
            # an empty, correctly shaped result.
            for start in range(0, max(len(images), 1), step):
                chunk = images[start:start + step].to(self._device)
                outputs.append(model(chunk).cpu())
        return torch.cat(outputs)

    def _accuracy(self, images, labels):
        """Return the fraction of ``images`` whose predicted class equals ``labels``.

        The result is NaN when ``images`` is empty.
        """
        predictions = self._logits(images).argmax(dim=1)
        return (predictions == labels.cpu()).float().mean().item()

    def predict_proba(self, X):
        """Return class probabilities (softmax over logits) as a NumPy array.

        Raises:
            AttributeError: If called before ``fit``.
        """
        images = self._to_tensor(X, torch.float32)
        return torch.softmax(self._logits(images), dim=1).numpy()

    def predict(self, X):
        """Return the most probable class index for each sample in ``X``."""
        return self.predict_proba(X).argmax(axis=1)
|