Updates across the board

This commit is contained in:
Chris Proctor
2026-06-22 16:41:31 -04:00
parent 5ca97dc5d0
commit 73624d1a0c
33 changed files with 3104 additions and 643 deletions

139
retro_gamer/model_agent.py Normal file
View File

@@ -0,0 +1,139 @@
from __future__ import annotations
import tomllib
import torch
from pathlib import Path
from blessed.keyboard import Keystroke
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView
from retro_gamer.metadata import GameMetadata
from retro_gamer.observation import encode_observation
class TrainedPolicy:
"""A trained retro-gamer model that can observe a game and choose actions.
Load from a training run directory, then call ``get_action(game)`` from
inside any agent's ``play_turn`` to get the model's recommended key.
Example::
from retro_gamer import TrainedPolicy
_ai = TrainedPolicy("runs/enemy/")
class EnemyAgent:
def play_turn(self, game):
key = _ai.get_action(game)
if key == 'KEY_RIGHT': self.direction = (1, 0)
...
"""
def __init__(self, run_dir: str | Path, checkpoint: str | None = None):
from retro_gamer.network import build_network
from retro_gamer.trainer import DEFAULTS
run_dir = Path(run_dir)
config_path = run_dir / 'config.toml'
if not config_path.exists():
raise FileNotFoundError(f"No config.toml found in {run_dir}")
with open(config_path, 'rb') as f:
config = tomllib.load(f)
self._metadata = GameMetadata.from_dict(config['metadata'])
pre = config.get('preprocessing', {})
self._metadata.spatial = pre.get('spatial', False)
self._metadata.board = pre.get('board', True)
observe_state_sizes = pre.get('observe_state_sizes', {})
self._observe_state: list[str] = pre.get('observe_state', [])
self._egocentric: bool = pre.get('egocentric', False)
self._egocentric_player: str | None = pre.get('egocentric_player')
self._egocentric_radius: int | None = pre.get('egocentric_radius')
self._board: bool = pre.get('board', True)
if observe_state_sizes:
self._metadata.extras_size = sum(observe_state_sizes.values())
else:
self._metadata.extras_size = len(self._observe_state)
hyperparams = {**DEFAULTS, **config.get('model', {}), **config.get('training', {})}
self._model, _ = build_network(self._metadata, hyperparams)
if checkpoint is not None:
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
ckpt_path = run_dir / 'checkpoints' / ckpt_name
if not ckpt_path.exists():
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
else:
ckpt_dir = run_dir / 'checkpoints'
candidates = sorted(ckpt_dir.glob('ep_*.pt')) if ckpt_dir.exists() else []
if not candidates:
raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
ckpt_path = candidates[-1]
ckpt = torch.load(ckpt_path, weights_only=True)
self._model.load_state_dict(ckpt['model_state_dict'])
self._model.eval()
def get_action(self, game) -> str | None:
"""Return the key the model recommends this turn, or None for no-op."""
view = HeadlessView()
view.on_game_start(game)
view.render(game)
board_chars = view.board_characters
player_pos = None
if self._egocentric and self._egocentric_player:
agent = game.get_agent_by_name(self._egocentric_player)
if agent is not None:
player_pos = agent.position
obs = encode_observation(
board_chars,
dict(game.state),
self._metadata,
self._observe_state,
player_pos=player_pos,
egocentric_radius=self._egocentric_radius,
board=self._board,
)
device = next(self._model.parameters()).device
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
with torch.no_grad():
action_idx = int(self._model(state_t).argmax().item())
if action_idx >= len(self._metadata.actions):
return None
return self._metadata.actions[action_idx]
def _keystroke(name: str) -> Keystroke:
if name.startswith("KEY_"):
return Keystroke(ucs='', code=None, name=name)
return Keystroke(ucs=name, code=None, name=None)
class PolicyInput:
"""An InputSource that drives the game with a TrainedPolicy instead of the keyboard.
Pass it as ``input_source`` to ``game.play()`` and everything else works
exactly as usual.
Example::
from retro_gamer import TrainedPolicy, PolicyInput
ai = TrainedPolicy("runs/snake/")
game = create_game()
game.play(input_source=PolicyInput(ai, game))
"""
def __init__(self, model: TrainedPolicy, game):
self._model = model
self._game = game
self._inp = ProgrammaticInput()
def collect(self) -> set:
key = self._model.get_action(self._game)
self._inp.press(key)
return self._inp.collect()