Files
retro-gamer/retro_gamer/env.py
2026-06-22 16:41:31 -04:00

119 lines
4.6 KiB
Python

from __future__ import annotations
import random
import numpy as np
from typing import Callable
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView
from retro_gamer.metadata import GameMetadata
from retro_gamer.observation import encode_observation
class GameEnvironment:
    """Gym-style wrapper around a retro game for RL training.

    Each episode is built fresh by ``reset()`` from ``game_factory``. The game
    is driven through a ``ProgrammaticInput`` and observed through a
    ``HeadlessView``. The reward returned by ``step()`` is the change in
    ``game.state[metadata.reward]`` since the previous step (or since reset).

    Args:
        game_factory: Zero-argument callable returning a new game object.
        metadata: Game description. ``metadata.reward`` names the state key
            read as the running score; ``metadata.actions`` is the pool of
            actions sampled by ``discover_character_set``.
        observe_state: State keys forwarded to ``encode_observation``.
        egocentric: When true and ``egocentric_player`` is set, the named
            agent's position is forwarded to ``encode_observation``.
        egocentric_player: Agent name looked up via
            ``game.get_agent_by_name`` for egocentric observations.
        egocentric_radius: Forwarded to ``encode_observation``.
        board: Forwarded to ``encode_observation``.
        observe_state_sizes: Expected size per state key. When non-empty,
            every observation is validated against it and a ``ValueError``
            is raised on mismatch.
    """

    def __init__(
        self,
        game_factory: Callable,
        metadata: GameMetadata,
        observe_state: list[str] | None = None,
        egocentric: bool = False,
        egocentric_player: str | None = None,
        egocentric_radius: int | None = None,
        board: bool = True,
        observe_state_sizes: dict[str, int] | None = None,
    ):
        self.game_factory = game_factory
        self.metadata = metadata
        self.observe_state = observe_state or []
        self.egocentric = egocentric
        self.egocentric_player = egocentric_player
        self.egocentric_radius = egocentric_radius
        self.board = board
        self.observe_state_sizes = observe_state_sizes or {}
        # Per-episode objects; populated by reset().
        self.game = None
        self.view: HeadlessView | None = None
        self.inp: ProgrammaticInput | None = None
        # Score seen at the end of the previous step, used for reward deltas.
        self._prev_reward: float = 0.0

    def reset(self) -> np.ndarray:
        """Create a fresh game episode and return the initial observation."""
        self.inp = ProgrammaticInput()
        self.view = HeadlessView()
        self.game = self.game_factory()
        self.game.input_source = self.inp
        self.game.view = self.view
        self.game.start()
        # Baseline for reward deltas: the first step() reports only the
        # change relative to the score the game starts with.
        self._prev_reward = self._current_reward()
        return self._observe()

    def step(self, action: str | None) -> tuple[np.ndarray, float, bool]:
        """Advance one turn. Returns (observation, reward, done).

        ``action`` (possibly ``None``) is passed to ``ProgrammaticInput.press``
        before the game advances. ``done`` is true once ``game.playing`` is
        false.

        Raises:
            RuntimeError: If called before ``reset()`` has started an episode.
        """
        self._require_started()
        self.inp.press(action)
        self.game.step()
        obs = self._observe()
        reward = self._delta_reward()
        done = not self.game.playing
        return obs, reward, done

    def _require_started(self) -> None:
        """Raise a clear error instead of an AttributeError on ``None``."""
        if self.game is None or self.inp is None or self.view is None:
            raise RuntimeError(
                "GameEnvironment.step() called before reset(); "
                "call reset() to start an episode first."
            )

    def _observe(self) -> np.ndarray:
        """Encode the current board and selected state into an observation."""
        # Copy so encode_observation cannot mutate the live game state.
        state = dict(self.game.state)
        if self.observe_state_sizes:
            self._check_state_sizes(state)
        player_pos = None
        if self.egocentric and self.egocentric_player:
            agent = self.game.get_agent_by_name(self.egocentric_player)
            # A missing agent (e.g. removed from play) falls back to a
            # non-egocentric encoding rather than failing.
            if agent is not None:
                player_pos = agent.position
        return encode_observation(
            self.view.board_characters,
            state,
            self.metadata,
            self.observe_state,
            player_pos=player_pos,
            egocentric_radius=self.egocentric_radius,
            board=self.board,
        )

    def _check_state_sizes(self, state: dict):
        """Validate that each tracked state key still has its expected size.

        Size rules: a missing/``None`` value counts as 0, a list or tuple
        counts as its length, and anything else counts as 1.

        Raises:
            ValueError: On the first key whose size differs from
                ``observe_state_sizes``.
        """
        for key, expected in self.observe_state_sizes.items():
            val = state.get(key)
            if val is None:
                actual = 0
            elif isinstance(val, (list, tuple)):
                actual = len(val)
            else:
                actual = 1
            if actual != expected:
                raise ValueError(
                    f"State key '{key}' changed size during training:\n"
                    f" Expected : {expected} (discovered at training start)\n"
                    f" Got : {actual}\n\n"
                    f"This means game.state['{key}'] has a different length in some\n"
                    f"episodes than it had when training started. The neural network\n"
                    f"has a fixed input size and cannot adapt to changing state shapes.\n\n"
                    f"Fix: make sure create_game() always initializes '{key}' with a\n"
                    f"fixed-length value before the game starts each episode.\n"
                    f"For example, if '{key}' is a list of 9 values, it must always be\n"
                    f"a list of exactly 9 values — never more, never fewer, never missing."
                )

    def _current_reward(self) -> float:
        """Return ``game.state[metadata.reward]`` as a float (0 if absent)."""
        return float(self.game.state.get(self.metadata.reward, 0))

    def _delta_reward(self) -> float:
        """Return the score change since the last call and update the baseline."""
        current = self._current_reward()
        delta = current - self._prev_reward
        self._prev_reward = current
        return delta

    def _collect_board_chars(self, chars: set[str]) -> None:
        """Add every character on the current board to ``chars`` in place."""
        for row in self.view.board_characters:
            chars.update(row)

    def discover_character_set(self, exploration_turns: int) -> list[str]:
        """Run random turns to discover the characters that appear on the board.

        The board is sampled after every reset and after every step,
        including the terminal step of an episode. That way, characters that
        only appear when a game ends are included; training observations
        returned by ``step()`` contain those frames too.

        Args:
            exploration_turns: Number of random actions to take. Episodes
                that end are restarted with ``reset()``.

        Returns:
            Sorted list of distinct board characters, excluding the blank ``' '``.
        """
        self.reset()
        # Invariant action pool; unpacking accepts any sequence type (list or
        # tuple) and adds None as an extra choice.
        choices = [*self.metadata.actions, None]
        chars: set[str] = set()
        self._collect_board_chars(chars)
        for _ in range(exploration_turns):
            _, _, done = self.step(random.choice(choices))
            self._collect_board_chars(chars)
            if done:
                self.reset()
                self._collect_board_chars(chars)
        chars.discard(' ')
        return sorted(chars)