Updates across the board

This commit is contained in:
Chris Proctor
2026-06-22 16:41:31 -04:00
parent 5ca97dc5d0
commit 73624d1a0c
33 changed files with 3104 additions and 643 deletions

View File

@@ -13,32 +13,36 @@ def build_network(
Returns (model, rationale) where rationale is a multi-line string
describing the architecture and the reasoning behind each choice.
"""
n_layers = hyperparams.get('n_layers', 2)
layer_size = hyperparams.get('layer_size', 128)
C = len(metadata.character_set)
bw, bh = metadata.board_size
W, H = bw, bh
n_state = len(metadata.observe_state)
hidden_sizes = hyperparams.get('hidden_sizes', [512, 256])
n_state = metadata.extras_size
n_actions = metadata.n_actions
lines = []
lines.append("[INIT] === Network Architecture ===")
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
lines.append(f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}")
if metadata.spatial:
model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)
if metadata.board:
C = len(metadata.character_set)
bw, bh = metadata.board_size
W, H = bw, bh
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}")
if metadata.spatial:
model = _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines)
else:
obs_size = C * W * H + n_state
model = _build_flat(obs_size, hidden_sizes, n_actions, lines)
else:
obs_size = C * W * H + n_state
model = _build_flat(obs_size, n_layers, layer_size, n_actions, lines)
lines.append(f"[INIT] Board: disabled (board=false, state-only observation)")
lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}")
model = _build_flat(n_state, hidden_sizes, n_actions, lines)
lines.append(f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}")
lines.append(f"[INIT] Hidden layers: {len(hidden_sizes)} | Layer sizes: {hidden_sizes}")
lines.append(f"[INIT] Output: {n_actions} Q-values")
lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)")
return model, '\n'.join(lines)
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
def _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines):
lines.append("[INIT] spatial=True → using CNN architecture")
lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
@@ -47,20 +51,20 @@ def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
mlp_in = conv_out + n_state
lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
lines.append(f"[INIT] MLP: {''.join([str(mlp_in)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
lines.append(f"[INIT] MLP: {''.join([str(mlp_in)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}")
return _SpatialNet(C, H, W, n_state, hidden_sizes, n_actions)
def _build_flat(obs_size, hidden_sizes, n_actions, lines):
    """Build the flat MLP network and record the rationale for choosing it.

    Args:
        obs_size: Input width of the first linear layer.
        hidden_sizes: Hidden-layer widths, in order from input to output.
        n_actions: Width of the final linear layer (one output per action).
        lines: List of log lines; extended in place with "[INIT]" entries.

    Returns:
        A ``_FlatNet`` built from ``obs_size``, ``hidden_sizes`` and
        ``n_actions``.
    """
    lines.append("[INIT] spatial=False → using MLP architecture")
    lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
    lines.append("[INIT] a flat MLP over the full observation is sufficient.")
    # Join with an explicit separator: an empty separator fuses the layer
    # widths into a single unreadable run of digits.
    widths = [obs_size, *hidden_sizes, n_actions]
    lines.append(f"[INIT] MLP: {' → '.join(str(w) for w in widths)}")
    return _FlatNet(obs_size, hidden_sizes, n_actions)
class _SpatialNet(nn.Module):
def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
def __init__(self, C, H, W, n_state, hidden_sizes, n_actions):
super().__init__()
self.C, self.H, self.W = C, H, W
self.n_board = C * H * W
@@ -73,10 +77,11 @@ class _SpatialNet(nn.Module):
conv_out = 64 * H * W
mlp_in = conv_out + n_state
layers: list[nn.Module] = []
for i in range(n_layers):
in_size = mlp_in if i == 0 else layer_size
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
layers.append(nn.Linear(layer_size, n_actions))
prev = mlp_in
for size in hidden_sizes:
layers += [nn.Linear(prev, size), nn.ReLU()]
prev = size
layers.append(nn.Linear(prev, n_actions))
self.mlp = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -87,13 +92,14 @@ class _SpatialNet(nn.Module):
class _FlatNet(nn.Module):
def __init__(self, obs_size, n_layers, layer_size, n_actions):
def __init__(self, obs_size, hidden_sizes, n_actions):
super().__init__()
layers: list[nn.Module] = []
for i in range(n_layers):
in_size = obs_size if i == 0 else layer_size
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
layers.append(nn.Linear(layer_size, n_actions))
prev = obs_size
for size in hidden_sizes:
layers += [nn.Linear(prev, size), nn.ReLU()]
prev = size
layers.append(nn.Linear(prev, n_actions))
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor: