Initial commit

This commit is contained in:
Chris Proctor
2026-05-08 14:07:17 -04:00
commit 5ca97dc5d0
36 changed files with 4147 additions and 0 deletions

100
retro_gamer/network.py Normal file
View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import torch
import torch.nn as nn
from retro_gamer.metadata import GameMetadata
def build_network(
metadata: GameMetadata,
hyperparams: dict,
) -> tuple[nn.Module, str]:
"""Build a Q-network from game metadata and hyperparameters.
Returns (model, rationale) where rationale is a multi-line string
describing the architecture and the reasoning behind each choice.
"""
n_layers = hyperparams.get('n_layers', 2)
layer_size = hyperparams.get('layer_size', 128)
C = len(metadata.character_set)
bw, bh = metadata.board_size
W, H = bw, bh
n_state = len(metadata.observe_state)
n_actions = metadata.n_actions
lines = []
lines.append("[INIT] === Network Architecture ===")
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
lines.append(f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}")
if metadata.spatial:
model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)
else:
obs_size = C * W * H + n_state
model = _build_flat(obs_size, n_layers, layer_size, n_actions, lines)
lines.append(f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}")
lines.append(f"[INIT] Output: {n_actions} Q-values")
lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)")
return model, '\n'.join(lines)
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
lines.append("[INIT] spatial=True → using CNN architecture")
lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
lines.append(f"[INIT] CNN: Conv2d({C}→32, k=3, pad=1) → ReLU → Conv2d(32→64, k=3, pad=1) → ReLU")
conv_out = 64 * H * W # padding=1 preserves spatial dims
lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
mlp_in = conv_out + n_state
lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
lines.append(f"[INIT] MLP: {''.join([str(mlp_in)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
lines.append("[INIT] spatial=False → using MLP architecture")
lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
lines.append("[INIT] a flat MLP over the full observation is sufficient.")
lines.append(f"[INIT] MLP: {''.join([str(obs_size)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
return _FlatNet(obs_size, n_layers, layer_size, n_actions)
class _SpatialNet(nn.Module):
def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
super().__init__()
self.C, self.H, self.W = C, H, W
self.n_board = C * H * W
self.conv = nn.Sequential(
nn.Conv2d(C, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
)
conv_out = 64 * H * W
mlp_in = conv_out + n_state
layers: list[nn.Module] = []
for i in range(n_layers):
in_size = mlp_in if i == 0 else layer_size
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
layers.append(nn.Linear(layer_size, n_actions))
self.mlp = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
board = x[:, :self.n_board].reshape(-1, self.C, self.H, self.W)
state = x[:, self.n_board:]
conv_out = self.conv(board).flatten(start_dim=1)
return self.mlp(torch.cat([conv_out, state], dim=1))
class _FlatNet(nn.Module):
def __init__(self, obs_size, n_layers, layer_size, n_actions):
super().__init__()
layers: list[nn.Module] = []
for i in range(n_layers):
in_size = obs_size if i == 0 else layer_size
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
layers.append(nn.Linear(layer_size, n_actions))
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)