96 lines
3.5 KiB
Python
96 lines
3.5 KiB
Python
import numpy as np
|
||
from retro_gamer.metadata import GameMetadata
|
||
|
||
|
||
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """One-hot encode the board into a float32 array of shape (H, W, C).

    H is the number of rows, W the length of the longest row and C the size
    of *character_set*. A cell holding ``character_set[i]`` gets a 1.0 in
    channel ``i`` (if a character is listed more than once, its last index
    is used). Unknown characters, and cells missing from rows shorter than
    W, are left as all-zero vectors. An empty board yields shape (0, 0, C).
    """
    char_to_idx = {c: i for i, c in enumerate(character_set)}
    height = len(board_chars)
    # Size the array from the widest row rather than the first one, so a
    # ragged board is zero-padded instead of raising IndexError when a later
    # row is longer than row 0.
    width = max((len(row) for row in board_chars), default=0)
    arr = np.zeros((height, width, len(character_set)), dtype=np.float32)
    for y, row in enumerate(board_chars):
        for x, char in enumerate(row):
            idx = char_to_idx.get(char)
            if idx is not None:
                arr[y, x, idx] = 1.0
    return arr
|
||
|
||
|
||
def _iter_floats(value):
    """Yield every leaf of *value* as a float, descending into lists, tuples and arrays."""
    if isinstance(value, (list, tuple)):
        for item in value:
            yield from _iter_floats(item)
    elif isinstance(value, np.ndarray):
        # ravel() walks arrays of any rank (including 0-d) element by element.
        for item in value.ravel():
            yield float(item)
    else:
        yield float(value)


def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Extract selected keys from game.state into a 1D float32 array.

    Keys are visited in the order given by *observe_state*. A scalar value
    contributes one element; list, tuple and NumPy array values are flattened
    recursively, so nested sequences contribute one element per leaf. A key
    missing from *state* contributes a single 0.0.

    Raises whatever ``float()`` raises (TypeError / ValueError) when a leaf
    value cannot be converted to a float.
    """
    values: list[float] = []
    for key in observe_state:
        values.extend(_iter_floats(state.get(key, 0)))
    return np.array(values, dtype=np.float32)
|
||
|
||
|
||
def egocentric_board(
    board_chars: list[list[str]],
    player_pos: tuple[int, int],
    radius: int,
) -> list[list[str]]:
    """Return the square window of the board centred on the player.

    *player_pos* is unpacked as (x, y), i.e. (column, row). The window spans
    ``radius`` cells in every direction, giving a grid with side
    ``2 * radius + 1``. Any window cell that falls outside the board (whose
    width is taken from its first row) is filled with a single space.
    """
    n_rows = len(board_chars)
    n_cols = len(board_chars[0]) if board_chars else 0
    centre_col, centre_row = player_pos
    offsets = range(-radius, radius + 1)

    def cell_at(col: int, row: int) -> str:
        # Board cell at (col, row), or a space when the coordinate is off-board.
        on_board = 0 <= col < n_cols and 0 <= row < n_rows
        return board_chars[row][col] if on_board else ' '

    return [
        [cell_at(centre_col + dx, centre_row + dy) for dx in offsets]
        for dy in offsets
    ]
|
||
|
||
|
||
def encode_observation(
    board_chars: list[list[str]],
    state: dict,
    metadata: GameMetadata,
    observe_state: list[str],
    player_pos: tuple[int, int] | None = None,
    egocentric_radius: int | None = None,
    board: bool = True,
) -> np.ndarray:
    """Build a flat 1D observation vector from the board and/or state values.

    With *board* False the result is just the encoded *observe_state*
    features. With *board* True the one-hot board encoding comes first and
    the state features (if any are requested) are appended after it.

    The board is cropped to a (2r+1)×(2r+1) window around the player only
    when both *player_pos* and *egocentric_radius* are supplied. For spatial
    games the (H, W, C) encoding is reordered to channel-first (C, H, W)
    before flattening; otherwise it is flattened as (H, W, C).

    Raises:
        ValueError: if *board* is True and ``metadata.character_set`` is empty
            or unset.
    """
    # State-only observation: the board and metadata are never touched.
    if not board:
        return encode_state(state, observe_state)

    charset = metadata.character_set
    if not charset:
        raise ValueError("character_set must be set before encoding observations")

    # Cropping needs both the position and the radius; otherwise use the full board.
    crop = not (player_pos is None or egocentric_radius is None)
    grid = egocentric_board(board_chars, player_pos, egocentric_radius) if crop else board_chars

    one_hot = encode_board(grid, charset)  # (H, W, C)
    layout = one_hot.transpose(2, 0, 1) if metadata.spatial else one_hot
    board_vec = layout.flatten()

    if not observe_state:
        return board_vec
    return np.concatenate([board_vec, encode_state(state, observe_state)])
|