140 lines
5.0 KiB
Python
140 lines
5.0 KiB
Python
from __future__ import annotations

import re
import tomllib
from pathlib import Path

import torch
from blessed.keyboard import Keystroke

from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView
from retro_gamer.metadata import GameMetadata
from retro_gamer.observation import encode_observation
|
|
|
|
|
|
class TrainedPolicy:
|
|
"""A trained retro-gamer model that can observe a game and choose actions.
|
|
|
|
Load from a training run directory, then call ``get_action(game)`` from
|
|
inside any agent's ``play_turn`` to get the model's recommended key.
|
|
|
|
Example::
|
|
|
|
from retro_gamer import TrainedPolicy
|
|
|
|
_ai = TrainedPolicy("runs/enemy/")
|
|
|
|
class EnemyAgent:
|
|
def play_turn(self, game):
|
|
key = _ai.get_action(game)
|
|
if key == 'KEY_RIGHT': self.direction = (1, 0)
|
|
...
|
|
"""
|
|
|
|
def __init__(self, run_dir: str | Path, checkpoint: str | None = None):
|
|
from retro_gamer.network import build_network
|
|
from retro_gamer.trainer import DEFAULTS
|
|
|
|
run_dir = Path(run_dir)
|
|
config_path = run_dir / 'config.toml'
|
|
if not config_path.exists():
|
|
raise FileNotFoundError(f"No config.toml found in {run_dir}")
|
|
|
|
with open(config_path, 'rb') as f:
|
|
config = tomllib.load(f)
|
|
|
|
self._metadata = GameMetadata.from_dict(config['metadata'])
|
|
|
|
pre = config.get('preprocessing', {})
|
|
self._metadata.spatial = pre.get('spatial', False)
|
|
self._metadata.board = pre.get('board', True)
|
|
observe_state_sizes = pre.get('observe_state_sizes', {})
|
|
self._observe_state: list[str] = pre.get('observe_state', [])
|
|
self._egocentric: bool = pre.get('egocentric', False)
|
|
self._egocentric_player: str | None = pre.get('egocentric_player')
|
|
self._egocentric_radius: int | None = pre.get('egocentric_radius')
|
|
self._board: bool = pre.get('board', True)
|
|
|
|
if observe_state_sizes:
|
|
self._metadata.extras_size = sum(observe_state_sizes.values())
|
|
else:
|
|
self._metadata.extras_size = len(self._observe_state)
|
|
|
|
hyperparams = {**DEFAULTS, **config.get('model', {}), **config.get('training', {})}
|
|
self._model, _ = build_network(self._metadata, hyperparams)
|
|
|
|
if checkpoint is not None:
|
|
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
|
|
ckpt_path = run_dir / 'checkpoints' / ckpt_name
|
|
if not ckpt_path.exists():
|
|
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
|
else:
|
|
ckpt_dir = run_dir / 'checkpoints'
|
|
candidates = sorted(ckpt_dir.glob('ep_*.pt')) if ckpt_dir.exists() else []
|
|
if not candidates:
|
|
raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
|
|
ckpt_path = candidates[-1]
|
|
|
|
ckpt = torch.load(ckpt_path, weights_only=True)
|
|
self._model.load_state_dict(ckpt['model_state_dict'])
|
|
self._model.eval()
|
|
|
|
def get_action(self, game) -> str | None:
|
|
"""Return the key the model recommends this turn, or None for no-op."""
|
|
view = HeadlessView()
|
|
view.on_game_start(game)
|
|
view.render(game)
|
|
board_chars = view.board_characters
|
|
|
|
player_pos = None
|
|
if self._egocentric and self._egocentric_player:
|
|
agent = game.get_agent_by_name(self._egocentric_player)
|
|
if agent is not None:
|
|
player_pos = agent.position
|
|
|
|
obs = encode_observation(
|
|
board_chars,
|
|
dict(game.state),
|
|
self._metadata,
|
|
self._observe_state,
|
|
player_pos=player_pos,
|
|
egocentric_radius=self._egocentric_radius,
|
|
board=self._board,
|
|
)
|
|
|
|
device = next(self._model.parameters()).device
|
|
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
|
|
with torch.no_grad():
|
|
action_idx = int(self._model(state_t).argmax().item())
|
|
if action_idx >= len(self._metadata.actions):
|
|
return None
|
|
return self._metadata.actions[action_idx]
|
|
|
|
|
|
def _keystroke(name: str) -> Keystroke:
    """Build a blessed ``Keystroke`` from a key name or a literal character.

    Names beginning with ``KEY_`` (e.g. ``'KEY_LEFT'``) become named keystrokes
    with empty ``ucs`` text; any other string is used as the typed text with no
    name. ``code`` is always ``None``.
    """
    is_named_key = name.startswith("KEY_")
    if not is_named_key:
        return Keystroke(ucs=name, code=None, name=None)
    return Keystroke(ucs='', code=None, name=name)
|
|
|
|
|
|
class PolicyInput:
    """An InputSource that drives the game with a TrainedPolicy instead of the keyboard.

    Pass it as ``input_source`` to ``game.play()`` and everything else works
    exactly as usual.

    Example::

        from retro_gamer import TrainedPolicy, PolicyInput

        ai = TrainedPolicy("runs/snake/")
        game = create_game()
        game.play(input_source=PolicyInput(ai, game))
    """

    def __init__(self, model: TrainedPolicy, game):
        """Wrap ``model`` so it chooses the keys pressed for ``game``.

        Args:
            model: Policy queried once per ``collect`` call.
            game: Game instance passed to ``model.get_action``.
        """
        self._model = model
        self._game = game
        self._inp = ProgrammaticInput()

    def collect(self) -> set:
        """Ask the policy for this turn's key and return the collected input.

        When the policy returns ``None`` (its no-op action), no key is pressed
        for this call.
        """
        key = self._model.get_action(self._game)
        # get_action returns None for the no-op action; pressing it would hand
        # a non-key value to ProgrammaticInput.
        if key is not None:
            self._inp.press(key)
        return self._inp.collect()
|