50 lines
1.9 KiB
Python
50 lines
1.9 KiB
Python
import numpy as np
|
|
from retro_gamer.metadata import GameMetadata
|
|
|
|
|
|
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """One-hot encode the board.

    Args:
        board_chars: Board rows; each row is a sequence of single-cell
            characters (a list of str or a plain str both work).
        character_set: Ordered known characters; the character at position
            ``i`` is encoded in channel ``i``.

    Returns:
        A float32 array of shape (H, W, C), where H = number of rows,
        W = length of the longest row and C = len(character_set).
        Unknown characters produce a zero vector, as do cells past the end
        of a row that is shorter than the longest one.
    """
    char_to_idx = {c: i for i, c in enumerate(character_set)}
    height = len(board_chars)
    # Size by the longest row rather than the first one: with ragged input a
    # longer later row would otherwise index past the array and raise
    # IndexError, while shorter rows were already zero-padded.
    width = max((len(row) for row in board_chars), default=0)
    arr = np.zeros((height, width, len(character_set)), dtype=np.float32)
    for y, row in enumerate(board_chars):
        for x, char in enumerate(row):
            idx = char_to_idx.get(char)
            if idx is not None:
                arr[y, x, idx] = 1.0
    return arr
|
|
|
|
|
|
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Pack the selected ``state`` entries into a 1D float32 vector.

    One element is emitted per key in ``observe_state``, in that order. Each
    value is converted with ``float()``, and keys absent from ``state``
    contribute 0.0.
    """
    values = [float(state.get(key, 0)) for key in observe_state]
    return np.asarray(values, dtype=np.float32)
|
|
|
|
|
|
def encode_observation(
    board_chars: list[list[str]],
    state: dict,
    metadata: GameMetadata,
) -> np.ndarray:
    """Flatten the board and the observed state into one 1D vector.

    The board is one-hot encoded with ``encode_board`` to shape (H, W, C).
    When ``metadata.spatial`` is truthy, the channel axis is moved to the
    front, giving (C, H, W), before flattening. This lets the network reshape
    the flat vector back for CNN processing. Otherwise the (H, W, C) layout
    is flattened as-is.

    In both cases the ``encode_state`` vector for ``metadata.observe_state``
    is appended after the board values.

    Raises:
        ValueError: if ``metadata.character_set`` is empty or unset.
    """
    charset = metadata.character_set
    if not charset:
        raise ValueError("character_set must be set before encoding observations")

    one_hot = encode_board(board_chars, charset)  # (H, W, C)
    # Channel-first (C, H, W) for spatial games; unchanged (H, W, C) otherwise.
    layout = np.moveaxis(one_hot, -1, 0) if metadata.spatial else one_hot
    observed = encode_state(state, metadata.observe_state)
    return np.concatenate((layout.ravel(), observed))
|