244 lines
10 KiB
Python
244 lines
10 KiB
Python
from __future__ import annotations
|
||
import importlib
|
||
import tomllib
|
||
from pathlib import Path
|
||
import click
|
||
import tomli_w
|
||
|
||
from retro_gamer.metadata import GameMetadata
|
||
from retro_gamer.trainer import DQNTrainer, DEFAULTS
|
||
|
||
|
||
# Root command group: the subcommands in this file register themselves with
# @cli.command(). Click shows the docstring below as the top-level --help
# text, so it is kept short and user-facing.
@click.group()
def cli():
    """Train and run RL agents for retro games."""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# retro-gamer create
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@cli.command()
@click.option('--game', required=True,
              help='Python module containing create_game() e.g. retro.examples.snake')
@click.option('--output', required=True,
              help='Directory to create for this training run')
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
              help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
@click.option('--lr-decay', default=DEFAULTS['lr_decay'], type=float,
              help=f"Multiplicative LR decay per episode (default {DEFAULTS['lr_decay']})")
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
              help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
              help=f"Exploration rate decay per episode (default {DEFAULTS['epsilon_decay']})")
@click.option('--epsilon-min', default=DEFAULTS['epsilon_min'], type=float,
              help=f"Minimum exploration rate (default {DEFAULTS['epsilon_min']})")
@click.option('--batch-size', default=DEFAULTS['batch_size'], type=int,
              help=f"Experiences per training step (default {DEFAULTS['batch_size']})")
@click.option('--memory-capacity', default=DEFAULTS['memory_capacity'], type=int,
              help=f"Replay buffer size (default {DEFAULTS['memory_capacity']})")
@click.option('--target-update-freq', default=DEFAULTS['target_update_freq'], type=int,
              help=f"Steps between target network updates (default {DEFAULTS['target_update_freq']})")
@click.option('--training-episodes', default=DEFAULTS['training_episodes'], type=int,
              help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
              help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
@click.option('--n-layers', default=DEFAULTS['n_layers'], type=int,
              help=f"Hidden layers in MLP head (default {DEFAULTS['n_layers']})")
@click.option('--layer-size', default=DEFAULTS['layer_size'], type=int,
              help=f"Width of each hidden layer (default {DEFAULTS['layer_size']})")
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
              help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
@click.option('--prioritize-experiences/--no-prioritize-experiences',
              default=DEFAULTS['prioritize_experiences'],
              help='Use prioritized experience replay')
def create(game, output, **hyperparams):
    """Create a new training run directory.

    Game metadata (actions, reward signal, etc.) is read from the
    [tool.retro-gamer] section of the game's pyproject.toml.
    Board size is read directly from the game. Hyperparameter options
    control how the trainer learns, not what it learns about.
    """
    # NOTE: the docstring above is the command's --help text; maintainer notes
    # live in comments. `hyperparams` collects every option except --game and
    # --output, keyed by the option's Python name (e.g. `learning_rate`).
    try:
        metadata = GameMetadata.from_pyproject(game)
    except (FileNotFoundError, ValueError) as e:
        # Chain the cause so the original error is still reachable when
        # debugging; the user-facing message is unchanged.
        raise click.ClickException(str(e)) from e

    # Instantiate the game once, only to learn its board dimensions.
    game_factory = _load_factory(game)
    g = game_factory()
    metadata.board_size = g.board_size

    metadata.validate()

    run_dir = Path(output)
    run_dir.mkdir(parents=True, exist_ok=True)

    # Everything `train` / `play` / `info` need later is persisted here.
    config = {
        'game': {'module': game},
        'metadata': metadata.to_dict(),
        'hyperparameters': hyperparams,
    }
    with open(run_dir / 'config.toml', 'wb') as f:
        tomli_w.dump(config, f)

    click.echo(f"Created training run at {output}/config.toml")
    click.echo(f" game : {game}")
    click.echo(f" board_size : {metadata.board_size[0]}×{metadata.board_size[1]}")
    click.echo(f" actions : {metadata.actions}")
    click.echo(f" reward : {metadata.reward}")
    if metadata.character_set:
        click.echo(f" characters : {metadata.character_set}")
    else:
        click.echo(" characters : (will be auto-discovered during training)")
    if metadata.observe_state:
        click.echo(f" observe : {metadata.observe_state}")
    click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# retro-gamer train
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@cli.command()
@click.argument('run_dir')
@click.option(
    '--resume',
    default=None,
    help='Path to checkpoint to resume from (e.g. checkpoints/ep_0500.pt)',
)
def train(run_dir, resume):
    """Train (or resume training) a DQN agent."""
    # The docstring above is the command's --help text.
    # Rebuild everything the trainer needs from the run's config.toml.
    cfg = _load_config(Path(run_dir))
    factory = _load_factory(cfg['game']['module'])
    game_meta = GameMetadata.from_dict(cfg['metadata'])
    overrides = cfg.get('hyperparameters', {})

    # The trainer receives the run directory exactly as given on the
    # command line (a string), plus the stored hyperparameters as keywords.
    agent = DQNTrainer(factory, game_meta, run_dir, **overrides)

    if resume:
        click.echo(f"Resuming from {resume}")
        agent.load_checkpoint(resume)

    click.echo(f"Training for {agent.hp['training_episodes']} episodes…")
    agent.train()
    click.echo(f"Done. Checkpoints in {run_dir}/checkpoints/")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# retro-gamer play
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@cli.command()
@click.argument('run_dir')
@click.option('--checkpoint', default='final',
              help='Checkpoint name e.g. "final" or "ep_0100"')
@click.option('--framerate', default=12, type=click.IntRange(min=1),
              help='Target frames per second')
def play(run_dir, checkpoint, framerate):
    """Watch a trained agent play the game."""
    # The docstring above is the command's --help text.
    # Heavy / optional dependencies are imported lazily so the other
    # subcommands do not pay for them.
    import torch
    from time import sleep, perf_counter
    from blessed import Terminal
    from retro.input import ProgrammaticInput
    from retro.views.headless import HeadlessView
    from retro.views.terminal import TerminalView
    from retro_gamer.network import build_network
    from retro_gamer.observation import encode_observation

    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)
    game_factory = _load_factory(config['game']['module'])
    metadata = GameMetadata.from_dict(config['metadata'])
    # Stored hyperparameters win over the library defaults.
    hyperparams = {**DEFAULTS, **config.get('hyperparameters', {})}

    # Accept both "final" and "final.pt"; fail with a readable message
    # before doing any expensive work if the file is not there.
    ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
    ckpt_path = run_dir_path / 'checkpoints' / ckpt_name
    if not ckpt_path.is_file():
        raise click.ClickException(f"Checkpoint not found: {ckpt_path}")

    model, _ = build_network(metadata, hyperparams)
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # machine without CUDA; load_state_dict then copies into the model's
    # own parameters.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # The game runs against a headless view (the source of observations)
    # and a programmatic input source (the sink for the agent's actions).
    inp = ProgrammaticInput()
    headless = HeadlessView()
    game = game_factory()
    game.input_source = inp
    game.view = headless
    game.start()

    # A second, terminal-backed view is used purely to draw for the human.
    terminal = Terminal()
    term_view = TerminalView(terminal, color=game.color)

    # --framerate is validated to be >= 1, so this cannot divide by zero.
    frame_budget = 1 / framerate

    click.echo("Playing… (press Escape or Enter to quit)")
    with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
        term_view.on_game_start(game)
        while game.playing:
            t0 = perf_counter()

            obs = encode_observation(headless.board_characters, dict(game.state), metadata)
            state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                action_idx = int(model(state_t).argmax().item())
            # An index past the end of metadata.actions maps to None
            # (presumably "press nothing" -- confirm against the network's
            # output layout).
            action_key = None if action_idx >= len(metadata.actions) else metadata.actions[action_idx]

            # Poll the keyboard with a zero timeout so the user can quit.
            key = terminal.inkey(0)
            if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
                break

            inp.press(action_key)
            game.step()
            term_view.render(game)

            # Sleep off whatever is left of this frame's time budget.
            elapsed = perf_counter() - t0
            sleep(max(0, frame_budget - elapsed))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# retro-gamer info
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@cli.command()
@click.argument('run_dir')
def info(run_dir):
    """Print a summary of a training run."""
    # The docstring above is the command's --help text.
    root = Path(run_dir)
    cfg = _load_config(root)

    # Section 1: what was configured for this run.
    click.echo(f"Game module : {cfg['game']['module']}")
    click.echo(f"Metadata : {cfg['metadata']}")
    click.echo(f"Hyperparams : {cfg.get('hyperparameters', {})}")

    # Section 2: tail of the training log, restricted to lines that start
    # with '[EP'. Skipped when the log is absent or has no such lines.
    log_file = root / 'training.log'
    if log_file.exists():
        episode_rows = [
            row
            for row in log_file.read_text().splitlines()
            if row.startswith('[EP')
        ]
        if episode_rows:
            click.echo("\nLast 5 episodes:")
            for row in episode_rows[-5:]:
                click.echo(f" {row}")

    # Section 3: checkpoint files present on disk, sorted by path.
    checkpoints_dir = root / 'checkpoints'
    if checkpoints_dir.exists():
        names = [path.name for path in sorted(checkpoints_dir.glob('*.pt'))]
        click.echo(f"\nCheckpoints ({len(names)}): {names}")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _load_config(run_dir: Path) -> dict:
|
||
config_path = run_dir / 'config.toml'
|
||
if not config_path.exists():
|
||
raise click.ClickException(f"No config.toml found in {run_dir}")
|
||
with open(config_path, 'rb') as f:
|
||
return tomllib.load(f)
|
||
|
||
|
||
def _load_factory(module_name: str):
    """Import a game module and return its ``create_game`` factory.

    Args:
        module_name: Dotted import path of the game module.

    Returns:
        The module's ``create_game`` callable (not called here).

    Raises:
        click.ClickException: If the module cannot be imported, or it does
            not expose a callable named ``create_game``.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        # Keep the original ImportError attached as the cause.
        raise click.ClickException(
            f"Cannot import game module '{module_name}': {e}"
        ) from e

    factory = getattr(module, 'create_game', None)
    # Reject a missing attribute and a non-callable one alike: either way
    # the module does not provide a usable create_game() function, and
    # failing here is clearer than a TypeError at the eventual call site.
    if not callable(factory):
        raise click.ClickException(
            f"Module '{module_name}' has no create_game() function"
        )
    return factory
|