Initial commit
This commit is contained in:
100
retro_gamer/network.py
Normal file
100
retro_gamer/network.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
|
||||
|
||||
def build_network(
|
||||
metadata: GameMetadata,
|
||||
hyperparams: dict,
|
||||
) -> tuple[nn.Module, str]:
|
||||
"""Build a Q-network from game metadata and hyperparameters.
|
||||
|
||||
Returns (model, rationale) where rationale is a multi-line string
|
||||
describing the architecture and the reasoning behind each choice.
|
||||
"""
|
||||
n_layers = hyperparams.get('n_layers', 2)
|
||||
layer_size = hyperparams.get('layer_size', 128)
|
||||
C = len(metadata.character_set)
|
||||
bw, bh = metadata.board_size
|
||||
W, H = bw, bh
|
||||
n_state = len(metadata.observe_state)
|
||||
n_actions = metadata.n_actions
|
||||
|
||||
lines = []
|
||||
lines.append("[INIT] === Network Architecture ===")
|
||||
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
|
||||
lines.append(f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}")
|
||||
|
||||
if metadata.spatial:
|
||||
model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)
|
||||
else:
|
||||
obs_size = C * W * H + n_state
|
||||
model = _build_flat(obs_size, n_layers, layer_size, n_actions, lines)
|
||||
|
||||
lines.append(f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}")
|
||||
lines.append(f"[INIT] Output: {n_actions} Q-values")
|
||||
lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)")
|
||||
return model, '\n'.join(lines)
|
||||
|
||||
|
||||
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
|
||||
lines.append("[INIT] spatial=True → using CNN architecture")
|
||||
lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
|
||||
lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
|
||||
lines.append(f"[INIT] CNN: Conv2d({C}→32, k=3, pad=1) → ReLU → Conv2d(32→64, k=3, pad=1) → ReLU")
|
||||
conv_out = 64 * H * W # padding=1 preserves spatial dims
|
||||
lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
|
||||
mlp_in = conv_out + n_state
|
||||
lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(mlp_in)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
|
||||
return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
|
||||
|
||||
|
||||
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
|
||||
lines.append("[INIT] spatial=False → using MLP architecture")
|
||||
lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
|
||||
lines.append("[INIT] a flat MLP over the full observation is sufficient.")
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(obs_size)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
|
||||
return _FlatNet(obs_size, n_layers, layer_size, n_actions)
|
||||
|
||||
|
||||
class _SpatialNet(nn.Module):
|
||||
def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
|
||||
super().__init__()
|
||||
self.C, self.H, self.W = C, H, W
|
||||
self.n_board = C * H * W
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(C, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
)
|
||||
conv_out = 64 * H * W
|
||||
mlp_in = conv_out + n_state
|
||||
layers: list[nn.Module] = []
|
||||
for i in range(n_layers):
|
||||
in_size = mlp_in if i == 0 else layer_size
|
||||
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
|
||||
layers.append(nn.Linear(layer_size, n_actions))
|
||||
self.mlp = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
board = x[:, :self.n_board].reshape(-1, self.C, self.H, self.W)
|
||||
state = x[:, self.n_board:]
|
||||
conv_out = self.conv(board).flatten(start_dim=1)
|
||||
return self.mlp(torch.cat([conv_out, state], dim=1))
|
||||
|
||||
|
||||
class _FlatNet(nn.Module):
|
||||
def __init__(self, obs_size, n_layers, layer_size, n_actions):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
for i in range(n_layers):
|
||||
in_size = obs_size if i == 0 else layer_size
|
||||
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
|
||||
layers.append(nn.Linear(layer_size, n_actions))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
Reference in New Issue
Block a user