Updates across the board

This commit is contained in:
Chris Proctor
2026-06-22 16:41:31 -04:00
parent 5ca97dc5d0
commit 73624d1a0c
33 changed files with 3104 additions and 643 deletions

View File

@@ -3,11 +3,7 @@ from retro_gamer.metadata import GameMetadata
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
"""One-hot encode the board.
Returns an array of shape (H, W, C) where C = len(character_set).
Unknown characters produce a zero vector.
"""
"""One-hot encode the board. Returns (H, W, C). Unknown characters → zero vector."""
char_to_idx = {c: i for i, c in enumerate(character_set)}
H = len(board_chars)
W = len(board_chars[0]) if board_chars else 0
@@ -22,28 +18,78 @@ def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.n
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Extract selected keys from game.state into a 1D float array.

    Scalar values contribute one element; list/tuple values are flattened
    one level, each item converted with ``float``. A key that is missing
    from *state* contributes a single ``0.0``.

    Args:
        state: Mapping of state names to scalar or list/tuple values.
        observe_state: Keys to extract, in output order.

    Returns:
        A 1D ``float32`` array; empty when *observe_state* is empty.
    """
    values: list[float] = []
    for key in observe_state:
        val = state.get(key, 0)
        if isinstance(val, (list, tuple)):
            # Sequence-valued entries widen the vector by len(val).
            values.extend(float(item) for item in val)
        else:
            values.append(float(val))
    return np.array(values, dtype=np.float32)
def egocentric_board(
    board_chars: list[list[str]],
    player_pos: tuple[int, int],
    radius: int,
) -> list[list[str]]:
    """Crop the board to a (2r+1)×(2r+1) window centred on player_pos.

    Args:
        board_chars: Board as a list of rows, indexed ``board_chars[y][x]``.
        player_pos: ``(x, y)`` position of the window centre (column, row).
        radius: Number of cells kept on each side of the centre.

    Returns:
        A new grid with side ``2 * radius + 1``. Cells that fall outside the
        board (including past the end of a short row) are filled with a
        space. A negative radius yields an empty list.
    """
    height = len(board_chars)
    px, py = player_pos
    window: list[list[str]] = []
    for src_y in range(py - radius, py + radius + 1):
        # Rows above or below the board have no source cells at all.
        src_row = board_chars[src_y] if 0 <= src_y < height else ()
        # Width is checked per row so ragged boards pad instead of raising.
        width = len(src_row)
        window.append([
            src_row[src_x] if 0 <= src_x < width else ' '
            for src_x in range(px - radius, px + radius + 1)
        ])
    return window
def encode_observation(
    board_chars: list[list[str]],
    state: dict,
    metadata: GameMetadata,
    observe_state: list[str],
    player_pos: tuple[int, int] | None = None,
    egocentric_radius: int | None = None,
    board: bool = True,
) -> np.ndarray:
    """Encode board and/or selected state values into a flat 1D observation vector.

    When *board* is True the board is encoded and placed first in the vector.
    If both *player_pos* and *egocentric_radius* are given, the board is first
    cropped to a (2r+1)×(2r+1) window centred on the player; if either is
    None the full board is used. For spatial games the board is encoded
    channel-first (C, H, W) then flattened; for non-spatial games it is
    encoded (H, W, C) then flattened. The state vector is appended at the
    end when *observe_state* is non-empty.

    When *board* is False only the observe_state features are returned.

    Raises:
        ValueError: If *board* is True and ``metadata.character_set`` is empty.
    """
    if not board:
        # State-only observation: the character set is not needed here.
        return encode_state(state, observe_state)

    if not metadata.character_set:
        raise ValueError("character_set must be set before encoding observations")

    if player_pos is not None and egocentric_radius is not None:
        board_chars = egocentric_board(board_chars, player_pos, egocentric_radius)

    board_enc = encode_board(board_chars, metadata.character_set)  # (H, W, C)
    if metadata.spatial:
        # Channel-first layout so the flat vector can be reshaped to (C, H, W).
        board_vec = board_enc.transpose(2, 0, 1).flatten()
    else:
        board_vec = board_enc.flatten()

    if observe_state:
        return np.concatenate([board_vec, encode_state(state, observe_state)])
    return board_vec