Initial commit
This commit is contained in:
49
retro_gamer/observation.py
Normal file
49
retro_gamer/observation.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import numpy as np
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
|
||||
|
||||
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """One-hot encode the board.

    Args:
        board_chars: The board as a list of rows, each row a list of
            single-cell strings.
        character_set: The known characters; a character's position in this
            list is its channel index.

    Returns:
        A float32 array of shape (H, W, C) where H = len(board_chars), W is
        the length of the longest row and C = len(character_set). Cells
        missing from shorter rows, and cells holding a character that is not
        in ``character_set``, are left as all-zero vectors.
    """
    char_to_idx = {c: i for i, c in enumerate(character_set)}
    height = len(board_chars)
    # Size the width from the longest row rather than only the first one:
    # shorter rows were already zero-padded, but a longer later row used to
    # index past the end of the array and raise IndexError.
    width = max((len(row) for row in board_chars), default=0)
    arr = np.zeros((height, width, len(character_set)), dtype=np.float32)
    for y, row in enumerate(board_chars):
        for x, char in enumerate(row):
            idx = char_to_idx.get(char)
            if idx is not None:
                arr[y, x, idx] = 1.0
    return arr
|
||||
|
||||
|
||||
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Extract observed state keys into a 1D float array.

    Args:
        state: Mapping of state names to values convertible with ``float()``.
        observe_state: The keys to read, in output order.

    Returns:
        A float32 array with one entry per key in ``observe_state``. A key
        that is absent from ``state`` contributes 0.0.
    """
    return np.array([float(state.get(k, 0)) for k in observe_state], dtype=np.float32)
|
||||
|
||||
|
||||
def encode_observation(
    board_chars: list[list[str]],
    state: dict,
    metadata: GameMetadata,
) -> np.ndarray:
    """Encode board + state into a flat 1D observation vector.

    For spatial games the board is encoded channel-first (C, H, W) then flattened,
    so the network can reshape it back for CNN processing. For non-spatial games the
    board is encoded (H, W, C) then flattened.
    The state vector is appended at the end in both cases.

    Args:
        board_chars: The board as a list of rows of single-cell strings.
        state: Mapping of state names to numeric values.
        metadata: Supplies ``character_set`` (one-hot channels), ``spatial``
            (selects the board layout) and ``observe_state`` (state keys to
            append).

    Returns:
        The flattened board followed by the observed state values.

    Raises:
        ValueError: If ``metadata.character_set`` is empty or unset.
    """
    if not metadata.character_set:
        raise ValueError("character_set must be set before encoding observations")
    board = encode_board(board_chars, metadata.character_set)  # (H, W, C)
    if metadata.spatial:
        board_vec = board.transpose(2, 0, 1).flatten()  # C*H*W, channel-first
    else:
        board_vec = board.flatten()  # H*W*C
    state_vec = encode_state(state, metadata.observe_state)
    return np.concatenate([board_vec, state_vec])
|
||||
Reference in New Issue
Block a user