Updates across the board
This commit is contained in:
@@ -3,11 +3,7 @@ from retro_gamer.metadata import GameMetadata
|
||||
|
||||
|
||||
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
|
||||
"""One-hot encode the board.
|
||||
|
||||
Returns an array of shape (H, W, C) where C = len(character_set).
|
||||
Unknown characters produce a zero vector.
|
||||
"""
|
||||
"""One-hot encode the board. Returns (H, W, C). Unknown characters → zero vector."""
|
||||
char_to_idx = {c: i for i, c in enumerate(character_set)}
|
||||
H = len(board_chars)
|
||||
W = len(board_chars[0]) if board_chars else 0
|
||||
@@ -22,28 +18,78 @@ def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.n
|
||||
|
||||
|
||||
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Extract selected keys from the game state into a 1D float32 array.

    Keys are visited in the order given by ``observe_state``. A key missing
    from ``state`` contributes a single ``0.0``. Scalar values contribute one
    element; list/tuple values contribute one element per item, and NumPy
    arrays are flattened so every entry contributes one element.

    Args:
        state: Mapping from state key to a scalar, list, tuple or ndarray.
        observe_state: Keys to extract, in output order.

    Returns:
        A 1D ``float32`` array; empty when ``observe_state`` is empty.

    Raises:
        TypeError, ValueError: If a value (or an item of a list/tuple value)
            cannot be converted with ``float()``.
    """
    values: list[float] = []
    for key in observe_state:
        val = state.get(key, 0)
        if isinstance(val, np.ndarray):
            # float() on a multi-element array raises, so flatten it the same
            # way list/tuple values are expanded.
            values.extend(float(x) for x in val.ravel())
        elif isinstance(val, (list, tuple)):
            values.extend(float(x) for x in val)
        else:
            values.append(float(val))
    return np.array(values, dtype=np.float32)
|
||||
|
||||
|
||||
def egocentric_board(
    board_chars: list[list[str]],
    player_pos: tuple[int, int],
    radius: int,
) -> list[list[str]]:
    """Crop the board to a (2r+1)x(2r+1) window centred on ``player_pos``.

    Args:
        board_chars: Board as a list of rows, each row a list of characters.
        player_pos: ``(x, y)`` position of the player, where ``x`` indexes the
            column within a row and ``y`` indexes the row.
        radius: Number of cells to include on each side of the player.

    Returns:
        A new grid of ``2 * radius + 1`` rows by ``2 * radius + 1`` columns.
        Cells that fall outside the board are filled with a single space.
        A negative ``radius`` yields an empty list.
    """
    height = len(board_chars)
    px, py = player_pos
    window: list[list[str]] = []
    for src_y in range(py - radius, py + radius + 1):
        # Rows outside the board contribute only padding cells.
        src_row = board_chars[src_y] if 0 <= src_y < height else ()
        # Bound each row by its own length so a short row cannot raise
        # IndexError; for rectangular boards this equals the board width.
        width = len(src_row)
        window.append(
            [
                src_row[src_x] if 0 <= src_x < width else " "
                for src_x in range(px - radius, px + radius + 1)
            ]
        )
    return window
|
||||
|
||||
|
||||
def encode_observation(
    board_chars: list[list[str]],
    state: dict,
    metadata: GameMetadata,
    observe_state: list[str],
    player_pos: tuple[int, int] | None = None,
    egocentric_radius: int | None = None,
    board: bool = True,
) -> np.ndarray:
    """Encode board and/or selected state values into a flat 1D observation vector.

    When *board* is True the board is encoded and placed first in the vector.
    If both ``player_pos`` and ``egocentric_radius`` are given, the board is
    first cropped to a (2r+1)x(2r+1) window centred on the player; if either
    is ``None`` the full board is used. For spatial games
    (``metadata.spatial`` truthy) the encoded board is reordered channel-first
    (C, H, W) before flattening; otherwise it is flattened as (H, W, C). The
    state features selected by ``observe_state`` are appended at the end when
    ``observe_state`` is non-empty.

    When *board* is False only the ``observe_state`` features are returned and
    ``metadata`` is not consulted.

    Args:
        board_chars: Board as a list of rows of characters.
        state: Game state mapping passed through to ``encode_state``.
        metadata: Supplies ``character_set`` and ``spatial``.
        observe_state: State keys to include, in output order.
        player_pos: Optional player position used for egocentric cropping.
        egocentric_radius: Optional crop radius used for egocentric cropping.
        board: Whether to include the encoded board in the observation.

    Returns:
        A 1D array: the flattened board followed by the state features, the
        board alone, or the state features alone, as described above.

    Raises:
        ValueError: If *board* is True and ``metadata.character_set`` is empty.
    """
    # State-only observation: nothing board-related needs validating.
    if not board:
        return encode_state(state, observe_state)

    if not metadata.character_set:
        raise ValueError("character_set must be set before encoding observations")

    if player_pos is not None and egocentric_radius is not None:
        board_chars = egocentric_board(board_chars, player_pos, egocentric_radius)

    # Keep the encoded array under its own name so the ``board`` flag is never
    # rebound to an ndarray (whose truth value would be ambiguous).
    board_enc = encode_board(board_chars, metadata.character_set)  # (H, W, C)
    if metadata.spatial:
        board_vec = board_enc.transpose(2, 0, 1).flatten()  # channel-first
    else:
        board_vec = board_enc.flatten()

    if not observe_state:
        return board_vec
    return np.concatenate([board_vec, encode_state(state, observe_state)])
|
||||
|
||||
Reference in New Issue
Block a user