from __future__ import annotations

import random
from typing import Callable

import numpy as np

from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView
from retro_gamer.metadata import GameMetadata
from retro_gamer.observation import encode_observation


class GameEnvironment:
    """Gym-style wrapper around a retro game for RL training.

    Each episode is a fresh game object produced by ``game_factory``. The game
    is driven through a ``ProgrammaticInput`` and rendered into a
    ``HeadlessView``; observations are produced by ``encode_observation`` and
    rewards are the step-to-step change of ``game.state[metadata.reward]``.

    Typical usage::

        obs = env.reset()
        obs, reward, done = env.step(action)
    """

    def __init__(
        self,
        game_factory: Callable,
        metadata: GameMetadata,
        observe_state: list[str] | None = None,
        egocentric: bool = False,
        egocentric_player: str | None = None,
        egocentric_radius: int | None = None,
        board: bool = True,
        observe_state_sizes: dict[str, int] | None = None,
    ):
        """Store the configuration; no game is created until ``reset()``.

        Args:
            game_factory: Zero-argument callable invoked by ``reset()`` to
                build the game for a new episode.
            metadata: Game description. ``metadata.reward`` is used as the key
                into ``game.state`` for the reward signal and
                ``metadata.actions`` supplies the actions sampled by
                ``discover_character_set``; the object is also forwarded to
                ``encode_observation``.
            observe_state: Forwarded to ``encode_observation``; ``None``
                becomes an empty list.
            egocentric: When truthy together with ``egocentric_player``, the
                named agent's position is passed to ``encode_observation`` as
                ``player_pos``.
            egocentric_player: Name used to look the agent up via
                ``game.get_agent_by_name``.
            egocentric_radius: Forwarded to ``encode_observation`` unchanged.
            board: Forwarded to ``encode_observation`` as ``board``.
            observe_state_sizes: Mapping of state key to the size that key is
                expected to keep. When non-empty, it is verified on every
                observation. ``None`` becomes an empty dict (no checking).
        """
        self.game_factory = game_factory
        self.metadata = metadata
        self.observe_state = observe_state or []
        self.egocentric = egocentric
        self.egocentric_player = egocentric_player
        self.egocentric_radius = egocentric_radius
        self.board = board
        self.observe_state_sizes = observe_state_sizes or {}

        # Per-episode objects; populated by reset().
        self.game = None
        self.view: HeadlessView | None = None
        self.inp: ProgrammaticInput | None = None
        # Reward value seen at the previous observation, used to compute deltas.
        self._prev_reward: float = 0.0

    def reset(self) -> np.ndarray:
        """Create a fresh game episode and return the initial observation."""
        self.inp = ProgrammaticInput()
        self.view = HeadlessView()
        self.game = self.game_factory()
        self.game.input_source = self.inp
        self.game.view = self.view
        self.game.start()
        # Baseline the reward so the first step reports a delta relative to
        # whatever value the game starts with.
        self._prev_reward = self._current_reward()
        return self._observe()

    def step(self, action: str | None) -> tuple[np.ndarray, float, bool]:
        """Advance one turn.

        Args:
            action: Value handed to the input source's ``press`` before the
                game is stepped. ``None`` is passed through as-is.

        Returns:
            ``(observation, reward, done)`` where ``reward`` is the change in
            ``game.state[metadata.reward]`` since the previous observation and
            ``done`` is ``not game.playing``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.game is None or self.inp is None:
            raise RuntimeError(
                "GameEnvironment.reset() must be called before step()"
            )
        self.inp.press(action)
        self.game.step()
        obs = self._observe()
        reward = self._delta_reward()
        done = not self.game.playing
        return obs, reward, done

    def _observe(self) -> np.ndarray:
        """Encode the current board and game state into an observation.

        Raises:
            ValueError: If ``observe_state_sizes`` is set and a state key no
                longer has its expected size.
        """
        # Work on a shallow copy of the game's state mapping.
        state = dict(self.game.state)
        if self.observe_state_sizes:
            self._check_state_sizes(state)

        player_pos = None
        if self.egocentric and self.egocentric_player:
            agent = self.game.get_agent_by_name(self.egocentric_player)
            # If the agent cannot be found, player_pos stays None.
            if agent is not None:
                player_pos = agent.position

        return encode_observation(
            self.view.board_characters,
            state,
            self.metadata,
            self.observe_state,
            player_pos=player_pos,
            egocentric_radius=self.egocentric_radius,
            board=self.board,
        )

    def _check_state_sizes(self, state: dict):
        """Verify every tracked state key still has its expected size.

        A missing or ``None`` value counts as size 0, a list or tuple counts
        as its length, and any other value counts as size 1.

        Raises:
            ValueError: On the first key whose size differs from the expected
                one recorded in ``observe_state_sizes``.
        """
        for key, expected in self.observe_state_sizes.items():
            val = state.get(key)
            if val is None:
                actual = 0
            elif isinstance(val, (list, tuple)):
                actual = len(val)
            else:
                actual = 1
            if actual != expected:
                raise ValueError(
                    f"State key '{key}' changed size during training:\n"
                    f" Expected : {expected} (discovered at training start)\n"
                    f" Got : {actual}\n\n"
                    f"This means game.state['{key}'] has a different length in some\n"
                    f"episodes than it had when training started. The neural network\n"
                    f"has a fixed input size and cannot adapt to changing state shapes.\n\n"
                    f"Fix: make sure create_game() always initializes '{key}' with a\n"
                    f"fixed-length value before the game starts each episode.\n"
                    f"For example, if '{key}' is a list of 9 values, it must always be\n"
                    f"a list of exactly 9 values — never more, never fewer, never missing."
                )

    def _current_reward(self) -> float:
        """Return ``game.state[metadata.reward]`` as a float (0 if absent)."""
        return float(self.game.state.get(self.metadata.reward, 0))

    def _delta_reward(self) -> float:
        """Return the reward change since the last call and update the baseline."""
        current = self._current_reward()
        delta = current - self._prev_reward
        self._prev_reward = current
        return delta

    def _collect_board_characters(self, chars: set[str]) -> None:
        """Add every character of the view's current board rows to ``chars``."""
        for row in self.view.board_characters:
            chars.update(row)

    def discover_character_set(self, exploration_turns: int) -> list[str]:
        """Run random turns to discover the characters that appear on the board.

        The board is sampled right after the initial reset, after every step
        (so boards seen on the turn an episode ends are included, since
        ``step()`` still encodes them), and after every mid-exploration reset.

        Args:
            exploration_turns: Number of random steps to take. Each step picks
                uniformly among ``metadata.actions`` plus ``None``.

        Returns:
            The sorted list of distinct characters seen, excluding the space
            character.
        """
        self.reset()
        # Built once; unpacking accepts any sequence of actions, not only a list.
        choices = [*self.metadata.actions, None]

        chars: set[str] = set()
        self._collect_board_characters(chars)
        for _ in range(exploration_turns):
            _, _, done = self.step(random.choice(choices))
            # Sample before any reset so the end-of-episode board is not lost.
            self._collect_board_characters(chars)
            if done:
                self.reset()
                self._collect_board_characters(chars)

        chars.discard(' ')
        return sorted(chars)