Files
retro-gamer/retro_gamer/cli.py
Chris Proctor 5ca97dc5d0 Initial commit
2026-05-08 14:07:17 -04:00

244 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import importlib
import tomllib
from pathlib import Path
import click
import tomli_w
from retro_gamer.metadata import GameMetadata
from retro_gamer.trainer import DQNTrainer, DEFAULTS
@click.group()
def cli():
    """Train and run RL agents for retro games."""
    # Intentionally empty: this is the click command group that the
    # subcommands in this module attach to via @cli.command(). Click uses
    # the docstring above as the top-level --help text, so edit it with care.
# ---------------------------------------------------------------------------
# retro-gamer create
# ---------------------------------------------------------------------------
@cli.command()
@click.option('--game', required=True,
              help='Python module containing create_game() e.g. retro.examples.snake')
@click.option('--output', required=True,
              help='Directory to create for this training run')
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
              help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
@click.option('--lr-decay', default=DEFAULTS['lr_decay'], type=float,
              help=f"Multiplicative LR decay per episode (default {DEFAULTS['lr_decay']})")
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
              help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
              help=f"Exploration rate decay per episode (default {DEFAULTS['epsilon_decay']})")
@click.option('--epsilon-min', default=DEFAULTS['epsilon_min'], type=float,
              help=f"Minimum exploration rate (default {DEFAULTS['epsilon_min']})")
@click.option('--batch-size', default=DEFAULTS['batch_size'], type=int,
              help=f"Experiences per training step (default {DEFAULTS['batch_size']})")
@click.option('--memory-capacity', default=DEFAULTS['memory_capacity'], type=int,
              help=f"Replay buffer size (default {DEFAULTS['memory_capacity']})")
@click.option('--target-update-freq', default=DEFAULTS['target_update_freq'], type=int,
              help=f"Steps between target network updates (default {DEFAULTS['target_update_freq']})")
@click.option('--training-episodes', default=DEFAULTS['training_episodes'], type=int,
              help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
              help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
@click.option('--n-layers', default=DEFAULTS['n_layers'], type=int,
              help=f"Hidden layers in MLP head (default {DEFAULTS['n_layers']})")
@click.option('--layer-size', default=DEFAULTS['layer_size'], type=int,
              help=f"Width of each hidden layer (default {DEFAULTS['layer_size']})")
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
              help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
@click.option('--prioritize-experiences/--no-prioritize-experiences',
              default=DEFAULTS['prioritize_experiences'],
              help='Use prioritized experience replay')
def create(game, output, **hyperparams):
    """Create a new training run directory.
    Game metadata (actions, reward signal, etc.) is read from the
    [tool.retro-gamer] section of the game's pyproject.toml.
    Board size is read directly from the game. Hyperparameter options
    control how the trainer learns, not what it learns about.
    """
    # NOTE: the docstring above is click's --help text for this command.
    # Every remaining keyword option lands in `hyperparams` and is stored
    # verbatim under [hyperparameters] in config.toml.
    try:
        metadata = GameMetadata.from_pyproject(game)
    except (FileNotFoundError, ValueError) as e:
        raise click.ClickException(str(e)) from e

    # Instantiate the game once, only to read its board size.
    game_factory = _load_factory(game)
    metadata.board_size = game_factory().board_size

    # Report invalid metadata as a clean CLI error, matching how
    # from_pyproject failures are handled above.
    # NOTE(review): assumes validate() signals problems with ValueError like
    # from_pyproject does — confirm; other exception types still propagate.
    try:
        metadata.validate()
    except ValueError as e:
        raise click.ClickException(str(e)) from e

    run_dir = Path(output)
    run_dir.mkdir(parents=True, exist_ok=True)
    config = {
        'game': {'module': game},
        'metadata': metadata.to_dict(),
        'hyperparameters': hyperparams,
    }
    # tomli_w writes bytes, hence binary mode.
    with open(run_dir / 'config.toml', 'wb') as f:
        tomli_w.dump(config, f)

    click.echo(f"Created training run at {output}/config.toml")
    click.echo(f" game : {game}")
    click.echo(f" board_size : {metadata.board_size[0]}×{metadata.board_size[1]}")
    click.echo(f" actions : {metadata.actions}")
    click.echo(f" reward : {metadata.reward}")
    if metadata.character_set:
        click.echo(f" characters : {metadata.character_set}")
    else:
        click.echo(" characters : (will be auto-discovered during training)")
    if metadata.observe_state:
        click.echo(f" observe : {metadata.observe_state}")
    click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")
# ---------------------------------------------------------------------------
# retro-gamer train
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--resume', default=None,
              help='Path to checkpoint to resume from (e.g. checkpoints/ep_0500.pt)')
def train(run_dir, resume):
    """Train (or resume training) a DQN agent."""
    cfg = _load_config(Path(run_dir))
    # Rebuild everything `create` recorded in config.toml and hand it to the
    # trainer. run_dir is forwarded as the raw string the user passed.
    trainer = DQNTrainer(
        _load_factory(cfg['game']['module']),
        GameMetadata.from_dict(cfg['metadata']),
        run_dir,
        **cfg.get('hyperparameters', {}),
    )

    if resume:
        # NOTE(review): `resume` is passed to load_checkpoint unchanged; whether
        # it is resolved against the cwd or run_dir is DQNTrainer's call — confirm.
        click.echo(f"Resuming from {resume}")
        trainer.load_checkpoint(resume)

    n_episodes = trainer.hp['training_episodes']
    click.echo(f"Training for {n_episodes} episodes…")
    trainer.train()
    click.echo(f"Done. Checkpoints in {run_dir}/checkpoints/")
# ---------------------------------------------------------------------------
# retro-gamer play
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--checkpoint', default='final',
              help='Checkpoint name e.g. "final" or "ep_0100"')
@click.option('--framerate', default=12, type=click.IntRange(min=1),
              help='Target frames per second')
def play(run_dir, checkpoint, framerate):
    """Watch a trained agent play the game."""
    # Imported locally (presumably so the other subcommands don't need torch
    # or the terminal libraries just to start up).
    import torch
    from time import sleep, perf_counter
    from blessed import Terminal
    from retro.input import ProgrammaticInput
    from retro.views.headless import HeadlessView
    from retro.views.terminal import TerminalView
    from retro_gamer.network import build_network
    from retro_gamer.observation import encode_observation

    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)
    game_factory = _load_factory(config['game']['module'])
    metadata = GameMetadata.from_dict(config['metadata'])
    # Fill any hyperparameters missing from config.toml with DEFAULTS so the
    # network is built with the same shape the checkpoint was trained with.
    hyperparams = {**DEFAULTS, **config.get('hyperparameters', {})}
    model, _ = build_network(metadata, hyperparams)

    # Accept either a bare name ("final", "ep_0100") or a file name ("final.pt").
    ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
    ckpt_path = run_dir_path / 'checkpoints' / ckpt_name
    if not ckpt_path.is_file():
        raise click.ClickException(f"No checkpoint found at {ckpt_path}")
    # Observations below are built as CPU tensors, so load onto the CPU too;
    # without map_location a GPU-saved checkpoint fails on CPU-only machines.
    ckpt = torch.load(ckpt_path, weights_only=True, map_location='cpu')
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # The game renders into a headless view (whose board the agent observes)
    # and is driven by programmatic key presses chosen by the model.
    inp = ProgrammaticInput()
    headless = HeadlessView()
    game = game_factory()
    game.input_source = inp
    game.view = headless
    game.start()

    terminal = Terminal()
    term_view = TerminalView(terminal, color=game.color)
    frame_time = 1 / framerate  # framerate >= 1 is enforced by IntRange
    click.echo("Playing… (press Escape or Enter to quit)")
    with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
        term_view.on_game_start(game)
        while game.playing:
            t0 = perf_counter()
            obs = encode_observation(headless.board_characters, dict(game.state), metadata)
            state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                action_idx = int(model(state_t).argmax().item())
            # Indices past the named actions map to None (presumably "press
            # nothing" for ProgrammaticInput — confirm).
            action_key = None if action_idx >= len(metadata.actions) else metadata.actions[action_idx]
            key = terminal.inkey(0)  # non-blocking poll for a quit key
            if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
                break
            inp.press(action_key)
            game.step()
            term_view.render(game)
            # Sleep off whatever remains of this frame's time budget.
            elapsed = perf_counter() - t0
            sleep(max(0, frame_time - elapsed))
# ---------------------------------------------------------------------------
# retro-gamer info
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
def info(run_dir):
    """Print a summary of a training run."""
    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)
    click.echo(f"Game module : {config['game']['module']}")
    click.echo(f"Metadata : {config['metadata']}")
    click.echo(f"Hyperparams : {config.get('hyperparameters', {})}")

    # Show the most recent episode lines (those starting with "[EP") from
    # training.log, if the log exists.
    log_path = run_dir_path / 'training.log'
    if log_path.exists():
        # errors='replace': a stray undecodable byte should not abort what is
        # only a read-only summary.
        log_lines = log_path.read_text(errors='replace').splitlines()
        recent = [line for line in log_lines if line.startswith('[EP')][-5:]
        if recent:
            # Report the actual count; fewer than 5 episodes may be logged.
            click.echo(f"\nLast {len(recent)} episodes:")
            for line in recent:
                click.echo(f" {line}")

    ckpt_dir = run_dir_path / 'checkpoints'
    if ckpt_dir.exists():
        ckpts = sorted(ckpt_dir.glob('*.pt'))
        click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_config(run_dir: Path) -> dict:
config_path = run_dir / 'config.toml'
if not config_path.exists():
raise click.ClickException(f"No config.toml found in {run_dir}")
with open(config_path, 'rb') as f:
return tomllib.load(f)
def _load_factory(module_name: str):
    """Import *module_name* and return its ``create_game`` factory.

    Args:
        module_name: Dotted module path, e.g. ``retro.examples.snake``.

    Returns:
        The module's ``create_game`` callable (not yet called).

    Raises:
        click.ClickException: if the name is empty, the module cannot be
            imported, or it lacks a callable ``create_game``.
    """
    if not module_name:
        # importlib would raise a bare ValueError("Empty module name").
        raise click.ClickException("Game module name must not be empty")
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        # Also catches ImportErrors raised by the game module's own imports.
        raise click.ClickException(f"Cannot import game module '{module_name}': {e}") from e
    factory = getattr(module, 'create_game', None)
    # Reject a non-callable attribute here rather than failing later with an
    # opaque "object is not callable" error.
    if not callable(factory):
        raise click.ClickException(
            f"Module '{module_name}' has no create_game() function"
        )
    return factory