from __future__ import annotations import torch import torch.nn as nn from retro_gamer.metadata import GameMetadata def build_network( metadata: GameMetadata, hyperparams: dict, ) -> tuple[nn.Module, str]: """Build a Q-network from game metadata and hyperparameters. Returns (model, rationale) where rationale is a multi-line string describing the architecture and the reasoning behind each choice. """ hidden_sizes = hyperparams.get('hidden_sizes', [512, 256]) n_state = metadata.extras_size n_actions = metadata.n_actions lines = [] lines.append("[INIT] === Network Architecture ===") if metadata.board: C = len(metadata.character_set) bw, bh = metadata.board_size W, H = bw, bh lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)") lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}") if metadata.spatial: model = _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines) else: obs_size = C * W * H + n_state model = _build_flat(obs_size, hidden_sizes, n_actions, lines) else: lines.append(f"[INIT] Board: disabled (board=false, state-only observation)") lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}") model = _build_flat(n_state, hidden_sizes, n_actions, lines) lines.append(f"[INIT] Hidden layers: {len(hidden_sizes)} | Layer sizes: {hidden_sizes}") lines.append(f"[INIT] Output: {n_actions} Q-values") lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)") return model, '\n'.join(lines) def _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines): lines.append("[INIT] spatial=True → using CNN architecture") lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures") lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.") lines.append(f"[INIT] CNN: Conv2d({C}→32, k=3, pad=1) → ReLU → Conv2d(32→64, k=3, pad=1) → ReLU") conv_out = 64 * H * W # padding=1 preserves spatial dims lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)") mlp_in = conv_out + n_state lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}") lines.append(f"[INIT] MLP: {' → '.join([str(mlp_in)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}") return _SpatialNet(C, H, W, n_state, hidden_sizes, n_actions) def _build_flat(obs_size, hidden_sizes, n_actions, lines): lines.append("[INIT] spatial=False → using MLP architecture") lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;") lines.append("[INIT] a flat MLP over the full observation is sufficient.") lines.append(f"[INIT] MLP: {' → '.join([str(obs_size)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}") return _FlatNet(obs_size, hidden_sizes, n_actions) class _SpatialNet(nn.Module): def __init__(self, C, H, W, n_state, hidden_sizes, n_actions): super().__init__() self.C, self.H, self.W = C, H, W self.n_board = C * H * W self.conv = nn.Sequential( nn.Conv2d(C, 32, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), ) conv_out = 64 * H * W mlp_in = conv_out + n_state layers: list[nn.Module] = [] prev = mlp_in for size in hidden_sizes: layers += [nn.Linear(prev, size), nn.ReLU()] prev = size layers.append(nn.Linear(prev, n_actions)) self.mlp = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: board = x[:, :self.n_board].reshape(-1, self.C, self.H, self.W) state = x[:, self.n_board:] conv_out = self.conv(board).flatten(start_dim=1) return self.mlp(torch.cat([conv_out, state], dim=1)) class _FlatNet(nn.Module): def __init__(self, obs_size, hidden_sizes, n_actions): super().__init__() layers: list[nn.Module] = [] prev = obs_size for size in hidden_sizes: layers += [nn.Linear(prev, size), nn.ReLU()] prev = size layers.append(nn.Linear(prev, n_actions)) self.net = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x)