Files
retro-gamer/retro_gamer/observation.py
Chris Proctor 5ca97dc5d0 Initial commit
2026-05-08 14:07:17 -04:00

50 lines
1.9 KiB
Python

import numpy as np
from retro_gamer.metadata import GameMetadata
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """Build a one-hot tensor from a grid of board characters.

    Args:
        board_chars: Rows of single board symbols. The width of the result is
            taken from the first row; an empty board yields width 0.
        character_set: Symbols that each get their own channel, in order.

    Returns:
        A float32 array of shape (H, W, C) with C = len(character_set).
        A cell whose symbol is not in ``character_set`` stays all zeros.
    """
    channel_of = dict(zip(character_set, range(len(character_set))))
    height = len(board_chars)
    width = 0 if not board_chars else len(board_chars[0])
    encoded = np.zeros((height, width, len(character_set)), dtype=np.float32)

    for row_no, row in enumerate(board_chars):
        for col_no, symbol in enumerate(row):
            channel = channel_of.get(symbol)
            if channel is None:
                # Unknown symbol: leave the cell as a zero vector.
                continue
            encoded[row_no, col_no, channel] = 1.0

    return encoded
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Collect the observed state entries into a 1D float32 vector.

    Args:
        state: Mapping of state names to values convertible with ``float()``.
        observe_state: Keys to read, in the order they appear in the output.

    Returns:
        A float32 array with one entry per key in ``observe_state``. A key
        that is absent from ``state`` contributes 0.0.
    """
    values = [float(state.get(key, 0)) for key in observe_state]
    return np.array(values, dtype=np.float32)
def encode_observation(
    board_chars: list[list[str]],
    state: dict,
    metadata: GameMetadata,
) -> np.ndarray:
    """Flatten the board and the observed state into one 1D observation vector.

    The board is one-hot encoded with ``metadata.character_set``. When
    ``metadata.spatial`` is truthy the encoding is reordered channel-first
    (C, H, W) before flattening, so a consumer can reshape it for CNN-style
    processing; otherwise it is flattened directly from (H, W, C). The values
    of ``metadata.observe_state`` are appended after the board in both cases.

    Raises:
        ValueError: If ``metadata.character_set`` is empty or unset.
    """
    if not metadata.character_set:
        raise ValueError("character_set must be set before encoding observations")

    one_hot = encode_board(board_chars, metadata.character_set)  # (H, W, C)
    # Channel-first for spatial games, channel-last otherwise.
    layout = one_hot.transpose(2, 0, 1) if metadata.spatial else one_hot
    flat_board = layout.flatten()
    extras = encode_state(state, metadata.observe_state)
    return np.concatenate((flat_board, extras))