from __future__ import annotations import random import numpy as np from typing import Callable from retro.input import ProgrammaticInput from retro.views.headless import HeadlessView from retro_gamer.metadata import GameMetadata from retro_gamer.observation import encode_observation class GameEnvironment: """Gym-style wrapper around a retro Game for RL training. Provides reset() / step(action) / observe(), managing one training episode at a time. The game is restarted by calling the factory function on each reset. """ def __init__(self, game_factory: Callable, metadata: GameMetadata): self.game_factory = game_factory self.metadata = metadata self.game = None self.view: HeadlessView | None = None self.inp: ProgrammaticInput | None = None self._prev_reward: float = 0.0 def reset(self) -> np.ndarray: """Create a fresh game episode and return the initial observation.""" self.inp = ProgrammaticInput() self.view = HeadlessView() self.game = self.game_factory() self.game.input_source = self.inp self.game.view = self.view self.game.start() self._prev_reward = float(self.game.state.get(self.metadata.reward, 0)) return self._observe() def step(self, action: str | None) -> tuple[np.ndarray, float, bool]: """Advance one turn with the given action. action: a keystroke string (e.g. 'KEY_RIGHT') or None for no-op. Returns (observation, reward, done). Reward is the change in the reward state key since the previous step. """ self.inp.press(action) self.game.step() obs = self._observe() reward = self._delta_reward() done = not self.game.playing return obs, reward, done def _observe(self) -> np.ndarray: return encode_observation( self.view.board_characters, dict(self.game.state), self.metadata, ) def _delta_reward(self) -> float: current = float(self.game.state.get(self.metadata.reward, 0)) delta = current - self._prev_reward self._prev_reward = current return delta def discover_character_set(self, exploration_turns: int) -> list[str]: """Run random turns to discover the characters that appear on the board. Returns the sorted character list (excluding space). """ obs = self.reset() chars: set[str] = set() for _ in range(exploration_turns): for row in self.view.board_characters: chars.update(row) action = random.choice(self.metadata.actions + [None]) _, _, done = self.step(action) if done: self.reset() chars.discard(' ') return sorted(chars)