"""Run a trained retro-gamer model inside a retro game.

``TrainedPolicy`` loads a model from a training run directory and picks a key
for the current game state. ``PolicyInput`` wraps a ``TrainedPolicy`` so it can
be passed to ``game.play()`` as the input source in place of the keyboard.
"""
from __future__ import annotations

import tomllib
from pathlib import Path

import torch
from blessed.keyboard import Keystroke
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView

from retro_gamer.metadata import GameMetadata
from retro_gamer.observation import encode_observation


def _checkpoint_sort_key(path: Path) -> tuple[int, int, str]:
    """Order ``ep_*.pt`` checkpoint files so the highest episode sorts last.

    Names whose part after ``ep_`` parses as an integer are ordered
    numerically, so ``ep_10.pt`` sorts after ``ep_9.pt``. A plain string sort
    would put ``ep_10.pt`` first unless episode numbers are zero-padded.
    Names that do not parse sort after all numeric ones, ordered by name.
    """
    suffix = path.stem.removeprefix('ep_')
    try:
        return (0, int(suffix), path.name)
    except ValueError:
        return (1, 0, path.name)


def _resolve_checkpoint(run_dir: Path, checkpoint: str | None) -> Path:
    """Return the checkpoint file to load from ``run_dir / 'checkpoints'``.

    Args:
        run_dir: Training run directory.
        checkpoint: Checkpoint name, with or without the ``.pt`` extension.
            If None, the highest-numbered ``ep_*.pt`` file is used.

    Raises:
        FileNotFoundError: If the named checkpoint does not exist. Also raised
            when ``checkpoint`` is None and there are no ``ep_*.pt`` files.
    """
    ckpt_dir = run_dir / 'checkpoints'
    if checkpoint is not None:
        ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
        ckpt_path = ckpt_dir / ckpt_name
        if not ckpt_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
        return ckpt_path

    candidates = (
        sorted(ckpt_dir.glob('ep_*.pt'), key=_checkpoint_sort_key)
        if ckpt_dir.exists()
        else []
    )
    if not candidates:
        raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
    return candidates[-1]


class TrainedPolicy:
    """A trained retro-gamer model that can observe a game and choose actions.

    Load from a training run directory, then call ``get_action(game)`` from
    inside any agent's ``play_turn`` to get the model's recommended key.

    Example::

        from retro_gamer import TrainedPolicy

        _ai = TrainedPolicy("runs/enemy/")

        class EnemyAgent:
            def play_turn(self, game):
                key = _ai.get_action(game)
                if key == 'KEY_RIGHT':
                    self.direction = (1, 0)
                ...
    """

    def __init__(self, run_dir: str | Path, checkpoint: str | None = None):
        """Load the run's config, build the network, and load its weights.

        Args:
            run_dir: Training run directory. It must contain ``config.toml``
                and a ``checkpoints/`` subdirectory.
            checkpoint: Name of the checkpoint to load. The ``.pt`` extension
                is optional. Defaults to the highest-numbered ``ep_*.pt``.

        Raises:
            FileNotFoundError: If ``config.toml`` or the checkpoint is missing.
        """
        # Imported here rather than at module level. This is presumably to
        # avoid an import cycle or a heavy import; TODO confirm.
        from retro_gamer.network import build_network
        from retro_gamer.trainer import DEFAULTS

        run_dir = Path(run_dir)
        config_path = run_dir / 'config.toml'
        if not config_path.exists():
            raise FileNotFoundError(f"No config.toml found in {run_dir}")
        with open(config_path, 'rb') as f:
            config = tomllib.load(f)

        self._metadata = GameMetadata.from_dict(config['metadata'])

        # Observation-encoding options from the run config. get_action passes
        # these to encode_observation.
        pre = config.get('preprocessing', {})
        self._metadata.spatial = pre.get('spatial', False)
        self._metadata.board = pre.get('board', True)
        observe_state_sizes = pre.get('observe_state_sizes', {})
        self._observe_state: list[str] = pre.get('observe_state', [])
        self._egocentric: bool = pre.get('egocentric', False)
        self._egocentric_player: str | None = pre.get('egocentric_player')
        self._egocentric_radius: int | None = pre.get('egocentric_radius')
        self._board: bool = pre.get('board', True)

        # extras_size is the sum of the configured per-key sizes when given.
        # Otherwise it is one per observed state key (presumably one scalar
        # per key; confirm against encode_observation).
        if observe_state_sizes:
            self._metadata.extras_size = sum(observe_state_sizes.values())
        else:
            self._metadata.extras_size = len(self._observe_state)

        # Later sections win: [training] overrides [model], which overrides DEFAULTS.
        hyperparams = {**DEFAULTS, **config.get('model', {}), **config.get('training', {})}
        self._model, _ = build_network(self._metadata, hyperparams)

        ckpt_path = _resolve_checkpoint(run_dir, checkpoint)
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only machine. load_state_dict then copies the weights onto
        # whatever device the model's parameters already live on.
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
        self._model.load_state_dict(ckpt['model_state_dict'])
        self._model.eval()

    def get_action(self, game) -> str | None:
        """Return the key the model recommends this turn, or None for no-op.

        The game is rendered into a fresh ``HeadlessView`` to get the board
        characters. These are encoded together with ``game.state`` using the
        preprocessing options loaded from the run config. The chosen index is
        the argmax of the model's output. An index past the end of
        ``metadata.actions`` yields None.
        """
        view = HeadlessView()
        view.on_game_start(game)
        view.render(game)
        board_chars = view.board_characters

        # Egocentric encoding is centred on a named agent. If that agent is
        # not found, player_pos stays None.
        player_pos = None
        if self._egocentric and self._egocentric_player:
            agent = game.get_agent_by_name(self._egocentric_player)
            if agent is not None:
                player_pos = agent.position

        obs = encode_observation(
            board_chars,
            dict(game.state),
            self._metadata,
            self._observe_state,
            player_pos=player_pos,
            egocentric_radius=self._egocentric_radius,
            board=self._board,
        )

        device = next(self._model.parameters()).device
        # Prepend a leading dimension of size 1 (a single-item batch).
        state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            action_idx = int(self._model(state_t).argmax().item())

        if action_idx >= len(self._metadata.actions):
            return None
        return self._metadata.actions[action_idx]


def _keystroke(name: str) -> Keystroke:
    """Build a blessed ``Keystroke`` from a key name.

    Names starting with ``KEY_`` become named keystrokes with an empty
    ``ucs``. Any other string is used as the keystroke's literal text.
    """
    if name.startswith("KEY_"):
        return Keystroke(ucs='', code=None, name=name)
    return Keystroke(ucs=name, code=None, name=None)


class PolicyInput:
    """An InputSource that drives the game with a TrainedPolicy instead of the keyboard.

    Pass it as ``input_source`` to ``game.play()`` and everything else works
    exactly as usual.

    Example::

        from retro_gamer import TrainedPolicy, PolicyInput

        ai = TrainedPolicy("runs/snake/")
        game = create_game()
        game.play(input_source=PolicyInput(ai, game))
    """

    def __init__(self, model: TrainedPolicy, game):
        """Wrap ``model`` so it chooses keys for ``game``."""
        self._model = model
        self._game = game
        self._inp = ProgrammaticInput()

    def collect(self) -> set:
        """Ask the policy for this turn's key and return the collected input.

        When the policy returns None (no-op), nothing is pressed, so no key
        from the policy is queued this turn.
        """
        key = self._model.get_action(self._game)
        if key is not None:
            self._inp.press(key)
        return self._inp.collect()