Files
retro-gamer/retro_gamer/observation.py
2026-06-22 16:41:31 -04:00

96 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
from retro_gamer.metadata import GameMetadata
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """One-hot encode a character grid into a float32 array of shape (H, W, C).

    H is the number of rows, W is the length of the first row and C is
    ``len(character_set)``. Channel ``i`` is set to 1.0 wherever the cell equals
    ``character_set[i]``. If a character is listed more than once, its last
    occurrence decides the channel. Characters that are not in
    *character_set* leave an all-zero vector.
    """
    lookup = dict(zip(character_set, range(len(character_set))))
    height = len(board_chars)
    width = len(board_chars[0]) if board_chars else 0
    one_hot = np.zeros((height, width, len(character_set)), dtype=np.float32)

    # Gather (row, col, channel) triples for every recognised character, then
    # set them all at once with integer-array indexing.
    hits = [
        (r, c, lookup[ch])
        for r, row in enumerate(board_chars)
        for c, ch in enumerate(row)
        if ch in lookup
    ]
    if hits:
        rows, cols, chans = zip(*hits)
        one_hot[list(rows), list(cols), list(chans)] = 1.0
    return one_hot
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Extract selected keys from game.state into a 1D float32 array.

    Keys are read in the order given by *observe_state*:

    * a scalar value contributes one element (converted with ``float``);
    * a list/tuple value contributes one element per item, in order;
    * a ``numpy.ndarray`` value of any rank is flattened in C order.

    A key missing from *state* contributes a single ``0.0``. If that key is
    normally list- or array-valued, the vector is shorter than usual.
    """
    values: list[float] = []
    for key in observe_state:
        val = state.get(key, 0)
        if isinstance(val, np.ndarray):
            # float() rejects multi-element arrays, so flatten explicitly.
            values.extend(float(x) for x in val.ravel())
        elif isinstance(val, (list, tuple)):
            values.extend(float(x) for x in val)
        else:
            values.append(float(val))
    return np.array(values, dtype=np.float32)
def egocentric_board(
    board_chars: list[list[str]],
    player_pos: tuple[int, int],
    radius: int,
    *,
    fill: str = ' ',
) -> list[list[str]]:
    """Crop the board to a (2r+1)x(2r+1) window centred on *player_pos*.

    *player_pos* is ``(x, y)``: x indexes within a row (column) and y selects
    the row, i.e. the cell read is ``board_chars[y][x]``.

    Cells outside the board, including positions past the end of a shorter
    row, are filled with *fill*. The encoder only treats padding as "empty"
    if *fill* is not in the character set. With the default ``' '``, padding
    cannot be told apart from a real space tile whenever ``' '`` is in
    ``character_set``.

    For ``radius >= 0`` the result is always square with side ``2*radius+1``.
    A negative radius gives an empty list.
    """
    px, py = player_pos
    height = len(board_chars)

    def cell(x: int, y: int) -> str:
        # Check the row's own length so ragged boards never raise IndexError.
        if 0 <= y < height:
            row = board_chars[y]
            if 0 <= x < len(row):
                return row[x]
        return fill

    offsets = range(-radius, radius + 1)
    return [[cell(px + dx, py + dy) for dx in offsets] for dy in offsets]
def encode_observation(
    board_chars: list[list[str]],
    state: dict,
    metadata: GameMetadata,
    observe_state: list[str],
    player_pos: tuple[int, int] | None = None,
    egocentric_radius: int | None = None,
    board: bool = True,
) -> np.ndarray:
    """Build the flat float32 observation vector for one step.

    With ``board=False`` only the *observe_state* features are returned.

    Otherwise the layout is ``[board features][state features]``:

    * If both *player_pos* and *egocentric_radius* are given, the board is
      first cropped to a (2r+1)x(2r+1) window around the player.
    * The board is one-hot encoded against ``metadata.character_set``.
    * If ``metadata.spatial`` is true, the encoding is flattened in
      channel-first (C, H, W) order; otherwise in (H, W, C) order.
    * The state features are appended only when *observe_state* is non-empty.

    Raises:
        ValueError: if *board* is true and ``metadata.character_set`` is empty.
    """
    # State-only observation: the board is never touched.
    if not board:
        return encode_state(state, observe_state)

    if not metadata.character_set:
        raise ValueError("character_set must be set before encoding observations")

    crop = player_pos is not None and egocentric_radius is not None
    grid = egocentric_board(board_chars, player_pos, egocentric_radius) if crop else board_chars

    one_hot = encode_board(grid, metadata.character_set)  # (H, W, C)
    ordered = one_hot.transpose(2, 0, 1) if metadata.spatial else one_hot
    board_vec = ordered.flatten()

    if not observe_state:
        return board_vec
    return np.concatenate([board_vec, encode_state(state, observe_state)])