Updates across the board
This commit is contained in:
139
retro_gamer/model_agent.py
Normal file
139
retro_gamer/model_agent.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
import tomllib
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from blessed.keyboard import Keystroke
|
||||
from retro.input import ProgrammaticInput
|
||||
from retro.views.headless import HeadlessView
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.observation import encode_observation
|
||||
|
||||
|
||||
class TrainedPolicy:
|
||||
"""A trained retro-gamer model that can observe a game and choose actions.
|
||||
|
||||
Load from a training run directory, then call ``get_action(game)`` from
|
||||
inside any agent's ``play_turn`` to get the model's recommended key.
|
||||
|
||||
Example::
|
||||
|
||||
from retro_gamer import TrainedPolicy
|
||||
|
||||
_ai = TrainedPolicy("runs/enemy/")
|
||||
|
||||
class EnemyAgent:
|
||||
def play_turn(self, game):
|
||||
key = _ai.get_action(game)
|
||||
if key == 'KEY_RIGHT': self.direction = (1, 0)
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, run_dir: str | Path, checkpoint: str | None = None):
|
||||
from retro_gamer.network import build_network
|
||||
from retro_gamer.trainer import DEFAULTS
|
||||
|
||||
run_dir = Path(run_dir)
|
||||
config_path = run_dir / 'config.toml'
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"No config.toml found in {run_dir}")
|
||||
|
||||
with open(config_path, 'rb') as f:
|
||||
config = tomllib.load(f)
|
||||
|
||||
self._metadata = GameMetadata.from_dict(config['metadata'])
|
||||
|
||||
pre = config.get('preprocessing', {})
|
||||
self._metadata.spatial = pre.get('spatial', False)
|
||||
self._metadata.board = pre.get('board', True)
|
||||
observe_state_sizes = pre.get('observe_state_sizes', {})
|
||||
self._observe_state: list[str] = pre.get('observe_state', [])
|
||||
self._egocentric: bool = pre.get('egocentric', False)
|
||||
self._egocentric_player: str | None = pre.get('egocentric_player')
|
||||
self._egocentric_radius: int | None = pre.get('egocentric_radius')
|
||||
self._board: bool = pre.get('board', True)
|
||||
|
||||
if observe_state_sizes:
|
||||
self._metadata.extras_size = sum(observe_state_sizes.values())
|
||||
else:
|
||||
self._metadata.extras_size = len(self._observe_state)
|
||||
|
||||
hyperparams = {**DEFAULTS, **config.get('model', {}), **config.get('training', {})}
|
||||
self._model, _ = build_network(self._metadata, hyperparams)
|
||||
|
||||
if checkpoint is not None:
|
||||
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
|
||||
ckpt_path = run_dir / 'checkpoints' / ckpt_name
|
||||
if not ckpt_path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
||||
else:
|
||||
ckpt_dir = run_dir / 'checkpoints'
|
||||
candidates = sorted(ckpt_dir.glob('ep_*.pt')) if ckpt_dir.exists() else []
|
||||
if not candidates:
|
||||
raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
|
||||
ckpt_path = candidates[-1]
|
||||
|
||||
ckpt = torch.load(ckpt_path, weights_only=True)
|
||||
self._model.load_state_dict(ckpt['model_state_dict'])
|
||||
self._model.eval()
|
||||
|
||||
def get_action(self, game) -> str | None:
|
||||
"""Return the key the model recommends this turn, or None for no-op."""
|
||||
view = HeadlessView()
|
||||
view.on_game_start(game)
|
||||
view.render(game)
|
||||
board_chars = view.board_characters
|
||||
|
||||
player_pos = None
|
||||
if self._egocentric and self._egocentric_player:
|
||||
agent = game.get_agent_by_name(self._egocentric_player)
|
||||
if agent is not None:
|
||||
player_pos = agent.position
|
||||
|
||||
obs = encode_observation(
|
||||
board_chars,
|
||||
dict(game.state),
|
||||
self._metadata,
|
||||
self._observe_state,
|
||||
player_pos=player_pos,
|
||||
egocentric_radius=self._egocentric_radius,
|
||||
board=self._board,
|
||||
)
|
||||
|
||||
device = next(self._model.parameters()).device
|
||||
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
action_idx = int(self._model(state_t).argmax().item())
|
||||
if action_idx >= len(self._metadata.actions):
|
||||
return None
|
||||
return self._metadata.actions[action_idx]
|
||||
|
||||
|
||||
def _keystroke(name: str) -> Keystroke:
    """Wrap an action name in a ``Keystroke``.

    Names that start with ``"KEY_"`` become a keystroke with an empty
    character string and the name stored in ``name``; any other string is
    stored as the keystroke's character data with ``name`` left as ``None``.
    ``code`` is ``None`` in both cases.
    """
    is_named_key = name.startswith("KEY_")
    return Keystroke(
        ucs='' if is_named_key else name,
        code=None,
        name=name if is_named_key else None,
    )
|
||||
|
||||
|
||||
class PolicyInput:
    """An InputSource that drives the game with a TrainedPolicy instead of the keyboard.

    Pass it as ``input_source`` to ``game.play()`` and everything else works
    exactly as usual.

    Example::

        from retro_gamer import TrainedPolicy, PolicyInput
        ai = TrainedPolicy("runs/snake/")
        game = create_game()
        game.play(input_source=PolicyInput(ai, game))
    """

    def __init__(self, model: TrainedPolicy, game):
        """Bind a policy to the game it will observe.

        Args:
            model: Policy whose ``get_action`` is queried on every collect.
            game: Game instance passed to ``model.get_action``.
        """
        self._model = model
        self._game = game
        self._inp = ProgrammaticInput()

    def collect(self) -> set:
        """Ask the policy for this turn's key and return the collected input.

        When the policy returns ``None`` (no-op) nothing is pressed, so the
        result is whatever the underlying ``ProgrammaticInput`` yields with
        no new key queued.
        """
        key = self._model.get_action(self._game)
        # None means the policy chose no key; it must not be queued as a key.
        if key is not None:
            self._inp.press(key)
        return self._inp.collect()
|
||||
Reference in New Issue
Block a user