Files
retro-gamer/retro_gamer/network.py
2026-06-22 16:41:31 -04:00

107 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import torch
import torch.nn as nn
from retro_gamer.metadata import GameMetadata
def build_network(
    metadata: GameMetadata,
    hyperparams: dict,
) -> tuple[nn.Module, str]:
    """Construct the Q-network for a game and explain how it was chosen.

    The architecture is selected from ``metadata``:

    * ``metadata.board`` falsy  -> flat MLP over the state features only.
    * board on, ``metadata.spatial`` truthy -> CNN over the one-hot board
      plus an MLP head that also receives the state features.
    * board on, not spatial -> flat MLP over the one-hot board concatenated
      with the state features.

    Hidden layer widths come from ``hyperparams['hidden_sizes']`` and
    default to ``[512, 256]`` when the key is absent.

    Returns ``(model, rationale)``; ``rationale`` is the newline-joined
    ``[INIT]`` log describing the architecture and the reasoning behind it.
    """
    layer_widths = hyperparams.get('hidden_sizes', [512, 256])
    extras = metadata.extras_size
    action_count = metadata.n_actions

    # Identical wording is emitted in both the board and the board-less case.
    feature_line = (
        f"[INIT] Observed state features: {extras} | Actions (incl. no-op): {action_count}"
    )

    log = ["[INIT] === Network Architecture ==="]

    if not metadata.board:
        log.append("[INIT] Board: disabled (board=false, state-only observation)")
        log.append(feature_line)
        model = _build_flat(extras, layer_widths, action_count, log)
    else:
        n_chars = len(metadata.character_set)
        # board_size is unpacked as (width, height).
        width, height = metadata.board_size
        log.append(
            f"[INIT] Board: {width}×{height}, character set: {n_chars} chars (one-hot per cell)"
        )
        log.append(feature_line)
        if metadata.spatial:
            model = _build_spatial(
                n_chars, height, width, extras, layer_widths, action_count, log
            )
        else:
            # One-hot board cells plus the extra state features, flattened.
            flat_inputs = n_chars * width * height + extras
            model = _build_flat(flat_inputs, layer_widths, action_count, log)

    log.extend(
        [
            f"[INIT] Hidden layers: {len(layer_widths)} | Layer sizes: {layer_widths}",
            f"[INIT] Output: {action_count} Q-values",
            f"[INIT] Actions: {metadata.actions} + (no-op)",
        ]
    )
    return model, '\n'.join(log)
def _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines):
    """Create the CNN-based Q-network and append its description to ``lines``.

    Args:
        C: number of input channels (one per character in the one-hot board).
        H: board height in cells.
        W: board width in cells.
        n_state: number of non-board state features fed to the MLP head.
        hidden_sizes: widths of the hidden layers of the MLP head.
        n_actions: number of Q-values produced by the network.
        lines: list of log strings; mutated in place with ``[INIT]`` entries.

    Returns:
        A ``_SpatialNet`` built from the same dimensions that were logged.
    """
    lines.append("[INIT] spatial=True → using CNN architecture")
    lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
    lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
    lines.append(f"[INIT] CNN: Conv2d({C}→32, k=3, pad=1) → ReLU → Conv2d(32→64, k=3, pad=1) → ReLU")

    conv_out = 64 * H * W  # padding=1 preserves spatial dims
    lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")

    mlp_in = conv_out + n_state
    lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")

    # Join the layer widths with an arrow; an empty separator would fuse them
    # into one unreadable run of digits (e.g. "51729512256" + "4").
    mlp_shape = " → ".join(str(size) for size in [mlp_in, *hidden_sizes, n_actions])
    lines.append(f"[INIT] MLP: {mlp_shape}")

    return _SpatialNet(C, H, W, n_state, hidden_sizes, n_actions)
def _build_flat(obs_size, hidden_sizes, n_actions, lines):
    """Create the MLP-only Q-network and append its description to ``lines``.

    Args:
        obs_size: length of the flat observation vector fed to the first layer.
        hidden_sizes: widths of the hidden layers.
        n_actions: number of Q-values produced by the network.
        lines: list of log strings; mutated in place with ``[INIT]`` entries.

    Returns:
        A ``_FlatNet`` built from the same dimensions that were logged.
    """
    lines.append("[INIT] spatial=False → using MLP architecture")
    lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
    lines.append("[INIT] a flat MLP over the full observation is sufficient.")

    # Join the layer widths with an arrow; an empty separator would fuse them
    # into one unreadable run of digits (e.g. "965122564").
    mlp_shape = " → ".join(str(size) for size in [obs_size, *hidden_sizes, n_actions])
    lines.append(f"[INIT] MLP: {mlp_shape}")

    return _FlatNet(obs_size, hidden_sizes, n_actions)
class _SpatialNet(nn.Module):
    """Q-network with a small CNN over the board and an MLP head.

    Each input row is a flat vector whose first ``C * H * W`` entries are
    reshaped to ``(C, H, W)`` and passed through two 3x3 convolutions
    (32 then 64 channels, padding 1). The flattened conv features are
    concatenated with the remaining entries of the row and fed to the MLP,
    which ends in a linear layer of ``n_actions`` outputs.

    NOTE(review): the reshape assumes the board part of the observation is
    laid out channel-major (C, then H, then W) -- confirm against the encoder.
    """

    def __init__(self, C, H, W, n_state, hidden_sizes, n_actions):
        super().__init__()
        self.C, self.H, self.W = C, H, W
        # Number of leading entries per row that belong to the board.
        self.n_board = C * H * W

        self.conv = nn.Sequential(
            nn.Conv2d(C, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # padding=1 with k=3 keeps H x W, so the conv yields 64 * H * W features.
        widths = [64 * H * W + n_state, *hidden_sizes]
        head: list[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            head.append(nn.Linear(fan_in, fan_out))
            head.append(nn.ReLU())
        head.append(nn.Linear(widths[-1], n_actions))
        self.mlp = nn.Sequential(*head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values for a batch of flat observations (one per row)."""
        split = self.n_board
        board_flat, extras = x[:, :split], x[:, split:]
        grid = board_flat.reshape(-1, self.C, self.H, self.W)
        features = self.conv(grid).flatten(start_dim=1)
        head_input = torch.cat([features, extras], dim=1)
        return self.mlp(head_input)
class _FlatNet(nn.Module):
    """Plain MLP Q-network over a flat observation vector.

    The network is a stack of ``Linear -> ReLU`` pairs, one per entry in
    ``hidden_sizes``, followed by a final ``Linear`` layer that produces
    ``n_actions`` outputs.
    """

    def __init__(self, obs_size, hidden_sizes, n_actions):
        super().__init__()
        stack: list[nn.Module] = []
        fan_in = obs_size
        for fan_out in hidden_sizes:
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
            fan_in = fan_out
        # The output layer is built last, after all hidden layers.
        self.net = nn.Sequential(*stack, nn.Linear(fan_in, n_actions))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output for the observation tensor ``x``."""
        q_values = self.net(x)
        return q_values