Files
2026-06-22 16:41:31 -04:00

375 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import importlib
import sys
import tomllib
from pathlib import Path
import click
import tomli_w
from retro_gamer.metadata import GameMetadata
from retro_gamer.trainer import DQNTrainer, DEFAULTS, MODEL_KEYS
@click.group()
def cli():
    """Train and run RL agents for retro games."""
    # Root command group; the subcommands in this file attach to it via
    # @cli.command(). The docstring above is what click shows as help text.
# ---------------------------------------------------------------------------
# retro-gamer create
# ---------------------------------------------------------------------------
def _parse_hidden_sizes(raw: str) -> list[int]:
    """Convert a comma-separated string such as ``"512,256"`` to layer sizes.

    Raises:
        ValueError: if an entry is not an integer or is not strictly positive.
    """
    sizes = [int(part.strip()) for part in raw.split(',')]
    if any(size <= 0 for size in sizes):
        raise ValueError(f"hidden layer sizes must be positive: {raw!r}")
    return sizes


@cli.command()
@click.option('--game', required=True,
              help='Game to train: a .py file path (e.g. my_game.py) or a Python module '
                   '(e.g. retro.examples.snake)')
@click.option('--output', required=True,
              help='Directory to create for this training run')
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
              help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
@click.option('--learning-rate-decay', default=DEFAULTS['learning_rate_decay'], type=float,
              help=f"Multiplicative LR decay per episode (default {DEFAULTS['learning_rate_decay']})")
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
              help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
              help=f"Exploration rate decay per episode (default {DEFAULTS['epsilon_decay']})")
@click.option('--epsilon-min', default=DEFAULTS['epsilon_min'], type=float,
              help=f"Minimum exploration rate (default {DEFAULTS['epsilon_min']})")
@click.option('--batch-size', default=DEFAULTS['batch_size'], type=int,
              help=f"Experiences per training step (default {DEFAULTS['batch_size']})")
@click.option('--memory-capacity', default=DEFAULTS['memory_capacity'], type=int,
              help=f"Replay buffer size (default {DEFAULTS['memory_capacity']})")
@click.option('--target-update-freq', default=DEFAULTS['target_update_freq'], type=int,
              help=f"Steps between target network updates (default {DEFAULTS['target_update_freq']})")
@click.option('--training-episodes', default=DEFAULTS['training_episodes'], type=int,
              help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
              help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
@click.option('--hidden-sizes', default=','.join(str(s) for s in DEFAULTS['hidden_sizes']),
              help=f"Comma-separated hidden layer sizes, e.g. 512,256 (default {DEFAULTS['hidden_sizes']})")
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
              help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
@click.option('--train-every', default=DEFAULTS['train_every'], type=int,
              help=f"Run a training step every N game steps (default {DEFAULTS['train_every']})")
@click.option('--prioritize-experiences/--no-prioritize-experiences',
              default=DEFAULTS['prioritize_experiences'],
              help='Use prioritized experience replay')
def create(game, output, **hyperparams):
    """Create a new training run directory.

    Game metadata (actions, reward signal, etc.) is read from the
    [tool.retro-gamer] section of the game's pyproject.toml.
    Board size is read directly from the game. Hyperparameter options
    control how the trainer learns, not what it learns about.
    """
    # --hidden-sizes arrives as a string; every other option is already typed
    # by click. Zero and negative sizes are rejected so the check matches the
    # "positive integers" wording of the error message.
    raw = hyperparams['hidden_sizes']
    try:
        hyperparams['hidden_sizes'] = _parse_hidden_sizes(raw)
    except ValueError:
        raise click.ClickException(
            f"Could not parse --hidden-sizes {raw!r}.\n"
            "It should be a comma-separated list of positive integers, one per hidden layer.\n"
            "Example: --hidden-sizes 512,256"
        ) from None

    game_config = _parse_game_arg(game)
    try:
        metadata = GameMetadata.from_pyproject(game_config['module'])
    except (FileNotFoundError, ValueError) as e:
        raise click.ClickException(str(e))

    # Instantiate the game once, only to read its board size.
    game_factory = _load_factory(game_config)
    g = game_factory()
    metadata.board_size = g.board_size
    # NOTE(review): errors raised by validate() are not converted to a
    # ClickException here — confirm whether it can raise ValueError.
    metadata.validate()

    run_dir = Path(output)
    run_dir.mkdir(parents=True, exist_ok=True)

    preprocessing = {'spatial': metadata.spatial}
    # Hyperparameters are split into [model] and [training] by MODEL_KEYS.
    config = {
        'game': game_config,
        'metadata': metadata.to_dict(),
        'preprocessing': preprocessing,
        'model': {k: v for k, v in hyperparams.items() if k in MODEL_KEYS},
        'training': {k: v for k, v in hyperparams.items() if k not in MODEL_KEYS},
    }
    with open(run_dir / 'config.toml', 'wb') as f:
        tomli_w.dump(config, f)

    click.echo(f"Created training run at {output}/config.toml")
    click.echo(f" game : {game}")
    click.echo(f" board_size : {metadata.board_size[0]}×{metadata.board_size[1]}")
    click.echo(f" actions : {metadata.actions}")
    click.echo(f" reward : {metadata.reward}")
    if metadata.character_set:
        click.echo(f" characters : {metadata.character_set}")
    else:
        click.echo(" characters : (will be auto-discovered during training)")
    click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")
# ---------------------------------------------------------------------------
# retro-gamer train
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
def train(run_dir):
    """Train a DQN agent, resuming automatically from the latest checkpoint.

    To start fresh or resume from an earlier point, delete the checkpoints
    you no longer want from RUN_DIR/checkpoints/.
    """
    run_root = Path(run_dir)
    settings = _load_config(run_root)
    make_game = _load_factory(settings['game'])
    game_meta = GameMetadata.from_dict(settings['metadata'])
    prep = settings.get('preprocessing', {})
    game_meta.spatial = prep.get('spatial', False)

    # [model] is optional in config.toml; [training] keys win on a clash.
    trainer_kwargs = dict(settings.get('model', {}))
    trainer_kwargs.update(settings['training'])
    try:
        trainer = DQNTrainer(make_game, game_meta, run_dir, preprocessing=prep, **trainer_kwargs)
    except ValueError as err:
        raise click.ClickException(str(err))

    newest = _latest_checkpoint(run_root)
    if newest is not None:
        click.echo(f"Resuming from {newest.name}")
        try:
            trainer.load_checkpoint(newest)
        except ValueError as err:
            raise click.ClickException(str(err))

    # Nothing left to do once the next episode is past the configured total.
    if trainer.start_episode > trainer.hp['training_episodes']:
        click.echo(
            f"Training already complete ({trainer.hp['training_episodes']} episodes). "
            "To keep training, increase training_episodes in config.toml."
        )
        return

    # Imported only here, after the early exit above.
    from retro_gamer.log_parser import parse_checkpoints
    from retro_gamer.dashboard import TrainingDashboard

    past_checkpoints = parse_checkpoints(run_root / 'training.log')
    dashboard = TrainingDashboard(trainer.hp['training_episodes'], trainer.start_episode, past_checkpoints)
    try:
        trainer.train(on_checkpoint=dashboard.on_checkpoint, on_episode=dashboard.on_episode)
    finally:
        # Close the dashboard even if training is interrupted.
        dashboard.close()
    click.echo(f"Done. Checkpoints saved in {run_dir}/checkpoints/")
# ---------------------------------------------------------------------------
# retro-gamer play
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--checkpoint', default=None,
              help='Checkpoint to load, e.g. "ep_0100". Defaults to the latest available.')
@click.option('--framerate', default=12, type=click.IntRange(min=1),
              help='Target frames per second')
def play(run_dir, checkpoint, framerate):
    """Watch a trained agent play the game."""
    # Imported here rather than at module level, so the terminal and model
    # dependencies are only loaded when this command runs.
    from time import sleep, perf_counter
    from blessed import Terminal
    from retro.views.terminal import TerminalView
    from retro_gamer.model_agent import TrainedPolicy, PolicyInput

    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)
    game_factory = _load_factory(config['game'])
    try:
        ai = TrainedPolicy(run_dir_path, checkpoint=checkpoint)
    except FileNotFoundError as e:
        raise click.ClickException(str(e))

    game = game_factory()
    inp = PolicyInput(ai, game)
    terminal = Terminal()
    term_view = TerminalView(terminal, color=game.color)

    # --framerate is restricted to >= 1 by click.IntRange, so this division
    # cannot raise ZeroDivisionError.
    frame_budget = 1 / framerate

    click.echo("Playing… (press Escape or Enter to quit)")
    with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
        game.input_source = inp
        game.view = term_view
        game.start()
        while game.playing:
            t0 = perf_counter()
            # Poll for a key with a zero timeout.
            key = terminal.inkey(0)
            if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
                break
            game.step()
            # Sleep off whatever remains of this frame's time slice.
            elapsed = perf_counter() - t0
            sleep(max(0, frame_budget - elapsed))
# ---------------------------------------------------------------------------
# retro-gamer plot
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--output', '-o', default=None,
              help='Save to file (e.g. plot.png) instead of displaying interactively.')
def plot(run_dir, output):
    """Plot training metrics (reward, steps, loss, epsilon) from a run's log."""
    from retro_gamer.plotter import plot_run

    training_log = Path(run_dir) / 'training.log'
    if not training_log.exists():
        raise click.ClickException(f"No training.log found in {run_dir}")

    # An empty or missing --output is passed to plot_run as None.
    if output:
        destination = Path(output)
    else:
        destination = None

    try:
        plot_run(training_log, destination)
    except ValueError as err:
        raise click.ClickException(str(err))
# ---------------------------------------------------------------------------
# retro-gamer info
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
def info(run_dir):
    """Print a summary of a training run."""
    run_root = Path(run_dir)
    settings = _load_config(run_root)

    # Echo each config section; [preprocessing] and [model] are optional.
    click.echo(f"Game module : {settings['game']['module']}")
    click.echo(f"Metadata : {settings['metadata']}")
    click.echo(f"Preprocessing : {settings.get('preprocessing', {})}")
    click.echo(f"Model : {settings.get('model', {})}")
    click.echo(f"Training : {settings['training']}")

    # Checkpoint entries in the log are the lines that start with "[ep_".
    log_file = run_root / 'training.log'
    if log_file.exists():
        checkpoint_entries = [
            entry
            for entry in log_file.read_text().splitlines()
            if entry.startswith('[ep_')
        ]
        if checkpoint_entries:
            click.echo("\nLast 5 checkpoints:")
            for entry in checkpoint_entries[-5:]:
                click.echo(f" {entry}")

    # List the checkpoint files on disk, sorted by name.
    checkpoint_dir = run_root / 'checkpoints'
    if checkpoint_dir.exists():
        saved = sorted(checkpoint_dir.glob('*.pt'))
        saved_names = [item.name for item in saved]
        click.echo(f"\nCheckpoints ({len(saved)}): {saved_names}")
# ---------------------------------------------------------------------------
# retro-gamer clean
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--yes', '-y', is_flag=True, help='Skip confirmation prompt')
def clean(run_dir, yes):
    """Remove all checkpoints and the training log from a run directory.

    Use this after changing the game description or network architecture,
    which require starting training from scratch.
    """
    run_root = Path(run_dir)
    if not run_root.exists():
        raise click.ClickException(f"Directory not found: {run_dir}")

    # Collect deletion targets: every *.pt checkpoint, then the training log.
    checkpoint_dir = run_root / 'checkpoints'
    targets = sorted(checkpoint_dir.glob('*.pt')) if checkpoint_dir.exists() else []
    training_log = run_root / 'training.log'
    if training_log.exists():
        targets.append(training_log)

    if not targets:
        click.echo("Nothing to clean.")
        return

    n_ckpts = len([target for target in targets if target.suffix == '.pt'])
    click.echo(f"Will remove {n_ckpts} checkpoint(s) and training log from {run_dir}/:")
    for target in targets:
        click.echo(f" {target.relative_to(run_root)}")

    # abort=True makes click abort the command if the user declines.
    if not yes:
        click.confirm("\nProceed?", abort=True)

    for target in targets:
        target.unlink()
    click.echo(f"Cleaned. Run 'retro-gamer train {run_dir}' to start fresh.")

    # NOTE(review): this listing runs after the *.pt files were unlinked and
    # looks like a leftover copied from `info` — confirm it is intentional.
    checkpoint_dir = run_root / 'checkpoints'
    if checkpoint_dir.exists():
        remaining = sorted(checkpoint_dir.glob('*.pt'))
        remaining_names = [item.name for item in remaining]
        click.echo(f"\nCheckpoints ({len(remaining)}): {remaining_names}")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _parse_game_arg(arg: str) -> dict:
    """Interpret the --game value as a file, a package directory, or a module.

    An existing ``.py`` file or directory yields ``{'module', 'path'}``, where
    ``path`` is the resolved parent directory; that directory is also put at
    the front of ``sys.path`` if it is not already listed. Anything that does
    not exist on disk is treated as a module name and yields ``{'module': arg}``.

    Raises:
        click.ClickException: if ``arg`` exists but is neither a ``.py`` file
            nor a directory.
    """
    candidate = Path(arg)
    if not candidate.exists():
        # Not a filesystem entry: assume it is already an importable name.
        return {'module': arg}

    if candidate.is_file() and candidate.suffix == '.py':
        module_name = candidate.stem
    elif candidate.is_dir():
        module_name = candidate.name
    else:
        raise click.ClickException(
            f"{arg!r} is not a .py file or package directory"
        )

    # Both accepted forms are imported relative to their parent directory.
    import_root = str(candidate.parent.resolve())
    if import_root not in sys.path:
        sys.path.insert(0, import_root)
    return {'module': module_name, 'path': import_root}
def _latest_checkpoint(run_dir: Path) -> Path | None:
"""Return the most recent checkpoint in run_dir/checkpoints/, or None."""
ckpt_dir = run_dir / 'checkpoints'
if ckpt_dir.exists():
candidates = sorted(ckpt_dir.glob('ep_*.pt'))
if candidates:
return candidates[-1]
return None
def _load_config(run_dir: Path) -> dict:
config_path = run_dir / 'config.toml'
if not config_path.exists():
raise click.ClickException(f"No config.toml found in {run_dir}")
with open(config_path, 'rb') as f:
return tomllib.load(f)
def _load_factory(game_config: dict):
    """Import the configured game module and return its ``create_game``.

    ``game_config['module']`` names the module to import. When
    ``game_config`` has a non-empty ``'path'``, that directory is put at the
    front of ``sys.path`` first (unless already listed) so the import can
    find the module.

    Raises:
        click.ClickException: if the import fails with ImportError, or the
            module has no ``create_game`` attribute.
    """
    extra_dir = game_config.get('path')
    if extra_dir and extra_dir not in sys.path:
        sys.path.insert(0, extra_dir)

    name = game_config['module']
    try:
        game_module = importlib.import_module(name)
    except ImportError as e:
        raise click.ClickException(f"Cannot import game module '{name}': {e}")

    # A unique sentinel tells "attribute missing" apart from any real value.
    missing = object()
    factory = getattr(game_module, 'create_game', missing)
    if factory is missing:
        raise click.ClickException(
            f"Module '{name}' has no create_game() function"
        )
    return factory