101 lines
4.3 KiB
Python
101 lines
4.3 KiB
Python
from __future__ import annotations
|
||
import torch
|
||
import torch.nn as nn
|
||
from retro_gamer.metadata import GameMetadata
|
||
|
||
|
||
def build_network(
    metadata: GameMetadata,
    hyperparams: dict,
) -> tuple[nn.Module, str]:
    """Build a Q-network from game metadata and hyperparameters.

    A CNN-based network is built when ``metadata.spatial`` is truthy,
    otherwise a plain MLP is built.

    Args:
        metadata: Game description. Attributes read here: ``character_set``,
            ``board_size`` (unpacked as width, height), ``observe_state``,
            ``n_actions``, ``spatial`` and ``actions``.
        hyperparams: Optional keys ``'n_layers'`` (default 2) and
            ``'layer_size'`` (default 128).

    Returns:
        ``(model, rationale)`` where ``rationale`` is a newline-joined,
        multi-line string describing the architecture and the reasoning
        behind each choice.
    """
    n_layers = hyperparams.get('n_layers', 2)
    layer_size = hyperparams.get('layer_size', 128)

    C = len(metadata.character_set)
    W, H = metadata.board_size
    n_state = len(metadata.observe_state)
    n_actions = metadata.n_actions

    lines = [
        "[INIT] === Network Architecture ===",
        f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)",
        f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}",
    ]

    # Each builder appends its own architecture-specific lines to `lines`.
    if not metadata.spatial:
        # Flat observation size: C * W * H board values plus n_state state values.
        model = _build_flat(C * W * H + n_state, n_layers, layer_size, n_actions, lines)
    else:
        model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)

    lines += [
        f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}",
        f"[INIT] Output: {n_actions} Q-values",
        f"[INIT] Actions: {metadata.actions} + (no-op)",
    ]
    return model, '\n'.join(lines)
|
||
|
||
|
||
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
    """Record the CNN architecture rationale and construct a ``_SpatialNet``.

    Appends ``[INIT]`` description lines to ``lines`` in place, then returns
    ``_SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)``.
    """
    # Both 3x3 convolutions use padding=1, so the 64-channel map keeps H x W.
    conv_out = 64 * H * W
    mlp_in = conv_out + n_state
    mlp_dims = [mlp_in, *([layer_size] * n_layers), n_actions]
    lines.extend([
        "[INIT] spatial=True → using CNN architecture",
        "[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures",
        "[INIT] local patterns (walls, items nearby) more efficiently than an MLP.",
        f"[INIT] CNN: Conv2d({C}→32, k=3, pad=1) → ReLU → Conv2d(32→64, k=3, pad=1) → ReLU",
        f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)",
        f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}",
        f"[INIT] MLP: {' → '.join(map(str, mlp_dims))}",
    ])
    return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
|
||
|
||
|
||
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
    """Record the MLP architecture rationale and construct a ``_FlatNet``.

    Appends ``[INIT]`` description lines to ``lines`` in place, then returns
    ``_FlatNet(obs_size, n_layers, layer_size, n_actions)``.
    """
    widths = [obs_size, *([layer_size] * n_layers), n_actions]
    lines.extend([
        "[INIT] spatial=False → using MLP architecture",
        "[INIT] Rationale: the board encodes UI/status rather than a spatial scene;",
        "[INIT] a flat MLP over the full observation is sufficient.",
        f"[INIT] MLP: {' → '.join(map(str, widths))}",
    ])
    return _FlatNet(obs_size, n_layers, layer_size, n_actions)
|
||
|
||
|
||
class _SpatialNet(nn.Module):
    """CNN + MLP Q-network over a flat observation vector.

    The first ``C * H * W`` features of each observation are reshaped to a
    ``(C, H, W)`` board and passed through two 3x3 convolutions (padding 1,
    so the spatial size is preserved). The flattened 64-channel feature map is
    concatenated with the remaining ``n_state`` features and fed to an MLP
    with ``n_layers`` hidden ReLU layers of width ``layer_size`` and
    ``n_actions`` outputs.

    NOTE(review): the reshape assumes the board part of the observation is
    laid out channel-major (C, H, W) — confirm against the observation encoder.
    """

    def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
        super().__init__()
        if n_layers < 0:
            raise ValueError(f"n_layers must be non-negative, got {n_layers}")
        self.C, self.H, self.W = C, H, W
        self.n_board = C * H * W
        self.conv = nn.Sequential(
            nn.Conv2d(C, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # kernel_size=3 with padding=1 preserves H x W -> 64 * H * W conv features.
        mlp_in = 64 * H * W + n_state
        layers: list[nn.Module] = []
        in_size = mlp_in
        for _ in range(n_layers):
            layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
            in_size = layer_size
        # Track the running width so that n_layers == 0 yields Linear(mlp_in, n_actions)
        # instead of a mis-sized Linear(layer_size, n_actions).
        layers.append(nn.Linear(in_size, n_actions))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values for ``x``.

        ``x`` is either a batch ``(B, n_board + n_state)`` (returns ``(B, n_actions)``)
        or a single 1-D observation (returns ``(n_actions,)``), matching what a
        plain MLP accepts.
        """
        unbatched = x.dim() == 1
        if unbatched:
            x = x.unsqueeze(0)
        # Use the explicit batch size rather than -1: a mis-sized input then fails
        # at the reshape instead of silently regrouping rows into a different
        # number of boards, and an empty batch reshapes cleanly.
        board = x[:, :self.n_board].reshape(x.shape[0], self.C, self.H, self.W)
        state = x[:, self.n_board:]
        features = self.conv(board).flatten(start_dim=1)
        q_values = self.mlp(torch.cat([features, state], dim=1))
        return q_values.squeeze(0) if unbatched else q_values
|
||
|
||
|
||
class _FlatNet(nn.Module):
    """Plain MLP Q-network.

    Maps ``obs_size`` inputs through ``n_layers`` hidden ReLU layers of width
    ``layer_size`` to ``n_actions`` outputs. With ``n_layers == 0`` it is a
    single linear map from the observation to the outputs.
    """

    def __init__(self, obs_size, n_layers, layer_size, n_actions):
        super().__init__()
        if n_layers < 0:
            raise ValueError(f"n_layers must be non-negative, got {n_layers}")
        layers: list[nn.Module] = []
        in_size = obs_size
        for _ in range(n_layers):
            layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
            in_size = layer_size
        # Track the running width so that n_layers == 0 yields Linear(obs_size, n_actions)
        # instead of a mis-sized Linear(layer_size, n_actions).
        layers.append(nn.Linear(in_size, n_actions))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values: last dimension of ``x`` is ``obs_size``, of the result ``n_actions``."""
        return self.net(x)
|