Updates across the board
This commit is contained in:
@@ -13,32 +13,36 @@ def build_network(
|
||||
Returns (model, rationale) where rationale is a multi-line string
|
||||
describing the architecture and the reasoning behind each choice.
|
||||
"""
|
||||
n_layers = hyperparams.get('n_layers', 2)
|
||||
layer_size = hyperparams.get('layer_size', 128)
|
||||
C = len(metadata.character_set)
|
||||
bw, bh = metadata.board_size
|
||||
W, H = bw, bh
|
||||
n_state = len(metadata.observe_state)
|
||||
hidden_sizes = hyperparams.get('hidden_sizes', [512, 256])
|
||||
n_state = metadata.extras_size
|
||||
n_actions = metadata.n_actions
|
||||
|
||||
lines = []
|
||||
lines.append("[INIT] === Network Architecture ===")
|
||||
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
|
||||
lines.append(f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}")
|
||||
|
||||
if metadata.spatial:
|
||||
model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)
|
||||
if metadata.board:
|
||||
C = len(metadata.character_set)
|
||||
bw, bh = metadata.board_size
|
||||
W, H = bw, bh
|
||||
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
|
||||
lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}")
|
||||
if metadata.spatial:
|
||||
model = _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines)
|
||||
else:
|
||||
obs_size = C * W * H + n_state
|
||||
model = _build_flat(obs_size, hidden_sizes, n_actions, lines)
|
||||
else:
|
||||
obs_size = C * W * H + n_state
|
||||
model = _build_flat(obs_size, n_layers, layer_size, n_actions, lines)
|
||||
lines.append(f"[INIT] Board: disabled (board=false, state-only observation)")
|
||||
lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}")
|
||||
model = _build_flat(n_state, hidden_sizes, n_actions, lines)
|
||||
|
||||
lines.append(f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}")
|
||||
lines.append(f"[INIT] Hidden layers: {len(hidden_sizes)} | Layer sizes: {hidden_sizes}")
|
||||
lines.append(f"[INIT] Output: {n_actions} Q-values")
|
||||
lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)")
|
||||
return model, '\n'.join(lines)
|
||||
|
||||
|
||||
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
|
||||
def _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines):
|
||||
lines.append("[INIT] spatial=True → using CNN architecture")
|
||||
lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
|
||||
lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
|
||||
@@ -47,20 +51,20 @@ def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
|
||||
lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
|
||||
mlp_in = conv_out + n_state
|
||||
lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(mlp_in)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
|
||||
return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(mlp_in)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}")
|
||||
return _SpatialNet(C, H, W, n_state, hidden_sizes, n_actions)
|
||||
|
||||
|
||||
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
|
||||
def _build_flat(obs_size, hidden_sizes, n_actions, lines):
|
||||
lines.append("[INIT] spatial=False → using MLP architecture")
|
||||
lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
|
||||
lines.append("[INIT] a flat MLP over the full observation is sufficient.")
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(obs_size)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
|
||||
return _FlatNet(obs_size, n_layers, layer_size, n_actions)
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(obs_size)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}")
|
||||
return _FlatNet(obs_size, hidden_sizes, n_actions)
|
||||
|
||||
|
||||
class _SpatialNet(nn.Module):
|
||||
def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
|
||||
def __init__(self, C, H, W, n_state, hidden_sizes, n_actions):
|
||||
super().__init__()
|
||||
self.C, self.H, self.W = C, H, W
|
||||
self.n_board = C * H * W
|
||||
@@ -73,10 +77,11 @@ class _SpatialNet(nn.Module):
|
||||
conv_out = 64 * H * W
|
||||
mlp_in = conv_out + n_state
|
||||
layers: list[nn.Module] = []
|
||||
for i in range(n_layers):
|
||||
in_size = mlp_in if i == 0 else layer_size
|
||||
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
|
||||
layers.append(nn.Linear(layer_size, n_actions))
|
||||
prev = mlp_in
|
||||
for size in hidden_sizes:
|
||||
layers += [nn.Linear(prev, size), nn.ReLU()]
|
||||
prev = size
|
||||
layers.append(nn.Linear(prev, n_actions))
|
||||
self.mlp = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -87,13 +92,14 @@ class _SpatialNet(nn.Module):
|
||||
|
||||
|
||||
class _FlatNet(nn.Module):
|
||||
def __init__(self, obs_size, n_layers, layer_size, n_actions):
|
||||
def __init__(self, obs_size, hidden_sizes, n_actions):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
for i in range(n_layers):
|
||||
in_size = obs_size if i == 0 else layer_size
|
||||
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
|
||||
layers.append(nn.Linear(layer_size, n_actions))
|
||||
prev = obs_size
|
||||
for size in hidden_sizes:
|
||||
layers += [nn.Linear(prev, size), nn.ReLU()]
|
||||
prev = size
|
||||
layers.append(nn.Linear(prev, n_actions))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user