Updates
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
@@ -8,16 +10,14 @@ class CNN(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(1, 32, kernel_size=3), # 28x28 -> 26x26
|
||||
nn.Conv2d(1, 32, kernel_size=3, stride=2), # 28x28 -> 13x13
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2), # 26x26 -> 13x13
|
||||
nn.Conv2d(32, 64, kernel_size=3), # 13x13 -> 11x11
|
||||
nn.Conv2d(32, 64, kernel_size=3, stride=2), # 13x13 -> 6x6
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2), # 11x11 -> 5x5
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(64 * 5 * 5, 128),
|
||||
nn.Linear(64 * 6 * 6, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 10),
|
||||
)
|
||||
@@ -51,6 +51,7 @@ class CNNClassifier:
|
||||
|
||||
print(f"\nTraining CNN (epochs={self.epochs})")
|
||||
for epoch in range(1, self.epochs + 1):
|
||||
t0 = time.time()
|
||||
model.train()
|
||||
total_loss = 0
|
||||
for xb, yb in loader:
|
||||
@@ -66,7 +67,8 @@ class CNNClassifier:
|
||||
val_pred = model(X_val.to(device)).argmax(dim=1).cpu()
|
||||
val_accuracy = (val_pred == y_val).float().mean().item()
|
||||
|
||||
print(f" epoch {epoch:2d}/{self.epochs} loss={total_loss / len(loader):.3f} val_accuracy={val_accuracy:.3f}")
|
||||
elapsed = time.time() - t0
|
||||
print(f" epoch {epoch:2d}/{self.epochs} loss={total_loss / len(loader):.3f} val_accuracy={val_accuracy:.3f} {elapsed:.1f}s")
|
||||
print()
|
||||
|
||||
self._model = model
|
||||
|
||||
Reference in New Issue
Block a user