This commit is contained in:
Chris Proctor
2026-06-22 16:08:23 -04:00
parent 95278c854d
commit 255c189d2f
9 changed files with 111 additions and 68 deletions

View File

@@ -1,3 +1,5 @@
import time
import torch
import torch.nn as nn
import torch.optim as optim
@@ -8,16 +10,14 @@ class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3), # 28x28 -> 26x26
nn.Conv2d(1, 32, kernel_size=3, stride=2), # 28x28 -> 13x13
nn.ReLU(),
nn.MaxPool2d(2), # 26x26 -> 13x13
nn.Conv2d(32, 64, kernel_size=3), # 13x13 -> 11x11
nn.Conv2d(32, 64, kernel_size=3, stride=2), # 13x13 -> 6x6
nn.ReLU(),
nn.MaxPool2d(2), # 11x11 -> 5x5
)
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 5 * 5, 128),
nn.Linear(64 * 6 * 6, 128),
nn.ReLU(),
nn.Linear(128, 10),
)
@@ -51,6 +51,7 @@ class CNNClassifier:
print(f"\nTraining CNN (epochs={self.epochs})")
for epoch in range(1, self.epochs + 1):
t0 = time.time()
model.train()
total_loss = 0
for xb, yb in loader:
@@ -66,7 +67,8 @@ class CNNClassifier:
val_pred = model(X_val.to(device)).argmax(dim=1).cpu()
val_accuracy = (val_pred == y_val).float().mean().item()
print(f" epoch {epoch:2d}/{self.epochs} loss={total_loss / len(loader):.3f} val_accuracy={val_accuracy:.3f}")
elapsed = time.time() - t0
print(f" epoch {epoch:2d}/{self.epochs} loss={total_loss / len(loader):.3f} val_accuracy={val_accuracy:.3f} {elapsed:.1f}s")
print()
self._model = model