Files
retro-gamer/retro_gamer/network.py
Chris Proctor 5ca97dc5d0 Initial commit
2026-05-08 14:07:17 -04:00

101 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import torch
import torch.nn as nn
from retro_gamer.metadata import GameMetadata
def build_network(
    metadata: GameMetadata,
    hyperparams: dict,
) -> tuple[nn.Module, str]:
    """Construct the Q-network for a game and explain how it was chosen.

    The architecture is selected by ``metadata.spatial``: the CNN-based
    builder when it is truthy, the flat MLP builder otherwise.
    ``hyperparams`` may provide ``n_layers`` (default 2) and ``layer_size``
    (default 128).

    Returns:
        A ``(model, rationale)`` pair, where ``rationale`` is the
        newline-joined log of every architecture decision.
    """
    depth = hyperparams.get('n_layers', 2)
    width = hyperparams.get('layer_size', 128)
    n_chars = len(metadata.character_set)
    board_w, board_h = metadata.board_size
    state_count = len(metadata.observe_state)
    action_count = metadata.n_actions

    # The builders append their own explanation to this same list.
    report = [
        "[INIT] === Network Architecture ===",
        f"[INIT] Board: {board_w}×{board_h}, character set: {n_chars} chars (one-hot per cell)",
        f"[INIT] Observed state keys: {state_count} | Actions (incl. no-op): {action_count}",
    ]

    if not metadata.spatial:
        # One one-hot vector per board cell plus the observed state values.
        flat_size = n_chars * board_w * board_h + state_count
        model = _build_flat(flat_size, depth, width, action_count, report)
    else:
        model = _build_spatial(
            n_chars, board_h, board_w, state_count,
            depth, width, action_count, report,
        )

    report += [
        f"[INIT] Hidden layers: {depth} | Layer width: {width}",
        f"[INIT] Output: {action_count} Q-values",
        f"[INIT] Actions: {metadata.actions} + (no-op)",
    ]
    return model, '\n'.join(report)
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
    """Build the CNN-based Q-network and log its architecture.

    Appends a human-readable description of the network to ``lines``
    (mutated in place) and returns a ``_SpatialNet`` constructed from the
    same arguments.

    Args:
        C: Number of input channels of the first convolution.
        H: Board height used in the feature-count arithmetic.
        W: Board width used in the feature-count arithmetic.
        n_state: Number of extra state features appended to the conv output.
        n_layers: Number of hidden layers in the MLP head.
        layer_size: Width of each hidden layer.
        n_actions: Number of Q-value outputs.
        lines: List of rationale strings; new entries are appended to it.

    Returns:
        The constructed ``_SpatialNet``.
    """
    lines.append("[INIT] spatial=True → using CNN architecture")
    lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
    lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
    lines.append(f"[INIT] CNN: Conv2d({C}→32, k=3, pad=1) → ReLU → Conv2d(32→64, k=3, pad=1) → ReLU")
    conv_out = 64 * H * W  # padding=1 preserves spatial dims
    lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
    mlp_in = conv_out + n_state
    lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
    # Join the layer widths with an arrow; an empty separator would fuse
    # them into a single unreadable number.
    mlp_sizes = [mlp_in] + [layer_size] * n_layers + [n_actions]
    mlp_desc = ' → '.join(str(size) for size in mlp_sizes)
    lines.append(f"[INIT] MLP: {mlp_desc}")
    return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
    """Build the flat MLP Q-network and log its architecture.

    Appends a human-readable description of the network to ``lines``
    (mutated in place) and returns a ``_FlatNet`` constructed from the same
    arguments.

    Args:
        obs_size: Length of the flat observation vector fed to the MLP.
        n_layers: Number of hidden layers.
        layer_size: Width of each hidden layer.
        n_actions: Number of Q-value outputs.
        lines: List of rationale strings; new entries are appended to it.

    Returns:
        The constructed ``_FlatNet``.
    """
    lines.append("[INIT] spatial=False → using MLP architecture")
    lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
    lines.append("[INIT] a flat MLP over the full observation is sufficient.")
    # Join the layer widths with an arrow; an empty separator would fuse
    # them into a single unreadable number.
    mlp_sizes = [obs_size] + [layer_size] * n_layers + [n_actions]
    mlp_desc = ' → '.join(str(size) for size in mlp_sizes)
    lines.append(f"[INIT] MLP: {mlp_desc}")
    return _FlatNet(obs_size, n_layers, layer_size, n_actions)
class _SpatialNet(nn.Module):
    """Q-network with a convolutional trunk and an MLP head.

    ``forward`` takes a flat batch whose first ``C * H * W`` columns are the
    board and whose remaining columns are extra state features. The board
    columns are reshaped to ``(batch, C, H, W)``, passed through two 3x3
    convolutions, flattened, concatenated with the state columns and fed to
    the MLP head, which emits ``n_actions`` values per row.
    """

    def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
        """Create the conv trunk and an MLP head with ``n_layers`` hidden layers.

        Args:
            C: Number of board channels.
            H: Board height.
            W: Board width.
            n_state: Number of state features appended after the board.
            n_layers: Number of hidden layers; 0 yields a single linear
                layer straight from the features to the outputs.
            layer_size: Width of each hidden layer.
            n_actions: Number of outputs.
        """
        super().__init__()
        self.C, self.H, self.W = C, H, W
        self.n_board = C * H * W
        self.conv = nn.Sequential(
            nn.Conv2d(C, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # padding=1 with a 3x3 kernel keeps H and W, so the flattened conv
        # output has 64 * H * W features.
        conv_out = 64 * H * W
        # Track the running input width so the output layer matches the
        # previous layer even when there are no hidden layers.
        in_size = conv_out + n_state
        layers: list[nn.Module] = []
        for _ in range(n_layers):
            layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
            in_size = layer_size
        layers.append(nn.Linear(in_size, n_actions))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MLP head's output for a 2-D batch of flat observations."""
        # NOTE(review): the reshape assumes the board part is flattened in
        # (C, H, W) order - confirm against the observation encoder.
        board = x[:, :self.n_board].reshape(-1, self.C, self.H, self.W)
        state = x[:, self.n_board:]
        conv_out = self.conv(board).flatten(start_dim=1)
        return self.mlp(torch.cat([conv_out, state], dim=1))
class _FlatNet(nn.Module):
    """Plain MLP Q-network over a flat observation vector."""

    def __init__(self, obs_size, n_layers, layer_size, n_actions):
        """Create an MLP with ``n_layers`` hidden ReLU layers.

        Args:
            obs_size: Length of each input observation vector.
            n_layers: Number of hidden layers; 0 yields a single linear
                layer straight from the observation to the outputs.
            layer_size: Width of each hidden layer.
            n_actions: Number of outputs.
        """
        super().__init__()
        # Track the running input width so the output layer matches the
        # previous layer even when there are no hidden layers.
        in_size = obs_size
        layers: list[nn.Module] = []
        for _ in range(n_layers):
            layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
            in_size = layer_size
        layers.append(nn.Linear(in_size, n_actions))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network's output for the given observations."""
        return self.net(x)