from __future__ import annotations import importlib import sys import tomllib from pathlib import Path import click import tomli_w from retro_gamer.metadata import GameMetadata from retro_gamer.trainer import DQNTrainer, DEFAULTS, MODEL_KEYS @click.group() def cli(): """Train and run RL agents for retro games.""" # --------------------------------------------------------------------------- # retro-gamer create # --------------------------------------------------------------------------- @cli.command() @click.option('--game', required=True, help='Game to train: a .py file path (e.g. my_game.py) or a Python module ' '(e.g. retro.examples.snake)') @click.option('--output', required=True, help='Directory to create for this training run') @click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float, help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})") @click.option('--learning-rate-decay', default=DEFAULTS['learning_rate_decay'], type=float, help=f"Multiplicative LR decay per episode (default {DEFAULTS['learning_rate_decay']})") @click.option('--gamma', default=DEFAULTS['gamma'], type=float, help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})") @click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float, help=f"Exploration rate decay per episode (default {DEFAULTS['epsilon_decay']})") @click.option('--epsilon-min', default=DEFAULTS['epsilon_min'], type=float, help=f"Minimum exploration rate (default {DEFAULTS['epsilon_min']})") @click.option('--batch-size', default=DEFAULTS['batch_size'], type=int, help=f"Experiences per training step (default {DEFAULTS['batch_size']})") @click.option('--memory-capacity', default=DEFAULTS['memory_capacity'], type=int, help=f"Replay buffer size (default {DEFAULTS['memory_capacity']})") @click.option('--target-update-freq', default=DEFAULTS['target_update_freq'], type=int, help=f"Steps between target network updates (default {DEFAULTS['target_update_freq']})") @click.option('--training-episodes', default=DEFAULTS['training_episodes'], type=int, help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})") @click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int, help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})") @click.option('--hidden-sizes', default=','.join(str(s) for s in DEFAULTS['hidden_sizes']), help=f"Comma-separated hidden layer sizes, e.g. 512,256 (default {DEFAULTS['hidden_sizes']})") @click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int, help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})") @click.option('--train-every', default=DEFAULTS['train_every'], type=int, help=f"Run a training step every N game steps (default {DEFAULTS['train_every']})") @click.option('--prioritize-experiences/--no-prioritize-experiences', default=DEFAULTS['prioritize_experiences'], help='Use prioritized experience replay') def create(game, output, **hyperparams): """Create a new training run directory. Game metadata (actions, reward signal, etc.) is read from the [tool.retro-gamer] section of the game's pyproject.toml. Board size is read directly from the game. Hyperparameter options control how the trainer learns, not what it learns about. """ raw = hyperparams['hidden_sizes'] try: hyperparams['hidden_sizes'] = [int(x.strip()) for x in raw.split(',')] except ValueError: raise click.ClickException( f"Could not parse --hidden-sizes {raw!r}.\n" "It should be a comma-separated list of positive integers, one per hidden layer.\n" "Example: --hidden-sizes 512,256" ) game_config = _parse_game_arg(game) try: metadata = GameMetadata.from_pyproject(game_config['module']) except (FileNotFoundError, ValueError) as e: raise click.ClickException(str(e)) game_factory = _load_factory(game_config) g = game_factory() metadata.board_size = g.board_size metadata.validate() run_dir = Path(output) run_dir.mkdir(parents=True, exist_ok=True) preprocessing = {'spatial': metadata.spatial} config = { 'game': game_config, 'metadata': metadata.to_dict(), 'preprocessing': preprocessing, 'model': {k: v for k, v in hyperparams.items() if k in MODEL_KEYS}, 'training': {k: v for k, v in hyperparams.items() if k not in MODEL_KEYS}, } with open(run_dir / 'config.toml', 'wb') as f: tomli_w.dump(config, f) click.echo(f"Created training run at {output}/config.toml") click.echo(f" game : {game}") click.echo(f" board_size : {metadata.board_size[0]}×{metadata.board_size[1]}") click.echo(f" actions : {metadata.actions}") click.echo(f" reward : {metadata.reward}") if metadata.character_set: click.echo(f" characters : {metadata.character_set}") else: click.echo(f" characters : (will be auto-discovered during training)") click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}") # --------------------------------------------------------------------------- # retro-gamer train # --------------------------------------------------------------------------- @cli.command() @click.argument('run_dir') def train(run_dir): """Train a DQN agent, resuming automatically from the latest checkpoint. To start fresh or resume from an earlier point, delete the checkpoints you no longer want from RUN_DIR/checkpoints/. """ run_dir_path = Path(run_dir) config = _load_config(run_dir_path) game_factory = _load_factory(config['game']) metadata = GameMetadata.from_dict(config['metadata']) preprocessing = config.get('preprocessing', {}) metadata.spatial = preprocessing.get('spatial', False) hyperparams = {**config.get('model', {}), **config['training']} try: trainer = DQNTrainer(game_factory, metadata, run_dir, preprocessing=preprocessing, **hyperparams) except ValueError as e: raise click.ClickException(str(e)) latest = _latest_checkpoint(run_dir_path) if latest: click.echo(f"Resuming from {latest.name}") try: trainer.load_checkpoint(latest) except ValueError as e: raise click.ClickException(str(e)) if trainer.start_episode > trainer.hp['training_episodes']: click.echo( f"Training already complete ({trainer.hp['training_episodes']} episodes). " "To keep training, increase training_episodes in config.toml." ) return from retro_gamer.log_parser import parse_checkpoints from retro_gamer.dashboard import TrainingDashboard history = parse_checkpoints(run_dir_path / 'training.log') display = TrainingDashboard(trainer.hp['training_episodes'], trainer.start_episode, history) try: trainer.train(on_checkpoint=display.on_checkpoint, on_episode=display.on_episode) finally: display.close() click.echo(f"Done. Checkpoints saved in {run_dir}/checkpoints/") # --------------------------------------------------------------------------- # retro-gamer play # --------------------------------------------------------------------------- @cli.command() @click.argument('run_dir') @click.option('--checkpoint', default=None, help='Checkpoint to load, e.g. "ep_0100". Defaults to the latest available.') @click.option('--framerate', default=12, type=int, help='Target frames per second') def play(run_dir, checkpoint, framerate): """Watch a trained agent play the game.""" from time import sleep, perf_counter from blessed import Terminal from retro.views.terminal import TerminalView from retro_gamer.model_agent import TrainedPolicy, PolicyInput run_dir_path = Path(run_dir) config = _load_config(run_dir_path) game_factory = _load_factory(config['game']) try: ai = TrainedPolicy(run_dir_path, checkpoint=checkpoint) except FileNotFoundError as e: raise click.ClickException(str(e)) game = game_factory() inp = PolicyInput(ai, game) terminal = Terminal() term_view = TerminalView(terminal, color=game.color) click.echo("Playing… (press Escape or Enter to quit)") with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak(): game.input_source = inp game.view = term_view game.start() while game.playing: t0 = perf_counter() key = terminal.inkey(0) if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'): break game.step() elapsed = perf_counter() - t0 sleep(max(0, 1 / framerate - elapsed)) # --------------------------------------------------------------------------- # retro-gamer plot # --------------------------------------------------------------------------- @cli.command() @click.argument('run_dir') @click.option('--output', '-o', default=None, help='Save to file (e.g. plot.png) instead of displaying interactively.') def plot(run_dir, output): """Plot training metrics (reward, steps, loss, epsilon) from a run's log.""" from retro_gamer.plotter import plot_run run_dir_path = Path(run_dir) log_path = run_dir_path / 'training.log' if not log_path.exists(): raise click.ClickException(f"No training.log found in {run_dir}") output_path = Path(output) if output else None try: plot_run(log_path, output_path) except ValueError as e: raise click.ClickException(str(e)) # --------------------------------------------------------------------------- # retro-gamer info # --------------------------------------------------------------------------- @cli.command() @click.argument('run_dir') def info(run_dir): """Print a summary of a training run.""" run_dir_path = Path(run_dir) config = _load_config(run_dir_path) click.echo(f"Game module : {config['game']['module']}") click.echo(f"Metadata : {config['metadata']}") click.echo(f"Preprocessing : {config.get('preprocessing', {})}") click.echo(f"Model : {config.get('model', {})}") click.echo(f"Training : {config['training']}") log_path = run_dir_path / 'training.log' if log_path.exists(): lines = log_path.read_text().splitlines() ckpt_lines = [l for l in lines if l.startswith('[ep_')] if ckpt_lines: click.echo(f"\nLast 5 checkpoints:") for line in ckpt_lines[-5:]: click.echo(f" {line}") ckpt_dir = run_dir_path / 'checkpoints' if ckpt_dir.exists(): ckpts = sorted(ckpt_dir.glob('*.pt')) click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}") # --------------------------------------------------------------------------- # retro-gamer clean # --------------------------------------------------------------------------- @cli.command() @click.argument('run_dir') @click.option('--yes', '-y', is_flag=True, help='Skip confirmation prompt') def clean(run_dir, yes): """Remove all checkpoints and the training log from a run directory. Use this after changing the game description or network architecture, which require starting training from scratch. """ run_dir_path = Path(run_dir) if not run_dir_path.exists(): raise click.ClickException(f"Directory not found: {run_dir}") to_remove = [] ckpt_dir = run_dir_path / 'checkpoints' if ckpt_dir.exists(): to_remove.extend(sorted(ckpt_dir.glob('*.pt'))) log_path = run_dir_path / 'training.log' if log_path.exists(): to_remove.append(log_path) if not to_remove: click.echo("Nothing to clean.") return n_ckpts = sum(1 for p in to_remove if p.suffix == '.pt') click.echo(f"Will remove {n_ckpts} checkpoint(s) and training log from {run_dir}/:") for p in to_remove: click.echo(f" {p.relative_to(run_dir_path)}") if not yes: click.confirm("\nProceed?", abort=True) for p in to_remove: p.unlink() click.echo(f"Cleaned. Run 'retro-gamer train {run_dir}' to start fresh.") ckpt_dir = run_dir_path / 'checkpoints' if ckpt_dir.exists(): ckpts = sorted(ckpt_dir.glob('*.pt')) click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _parse_game_arg(arg: str) -> dict: """Accept a .py file path, a package directory, or a Python module name. Returns a dict with at least 'module', and optionally 'path' (the directory to add to sys.path so the module can be imported). """ p = Path(arg) if p.exists(): if p.is_file() and p.suffix == '.py': path = str(p.parent.resolve()) module = p.stem elif p.is_dir(): path = str(p.parent.resolve()) module = p.name else: raise click.ClickException( f"{arg!r} is not a .py file or package directory" ) if path not in sys.path: sys.path.insert(0, path) return {'module': module, 'path': path} return {'module': arg} def _latest_checkpoint(run_dir: Path) -> Path | None: """Return the most recent checkpoint in run_dir/checkpoints/, or None.""" ckpt_dir = run_dir / 'checkpoints' if ckpt_dir.exists(): candidates = sorted(ckpt_dir.glob('ep_*.pt')) if candidates: return candidates[-1] return None def _load_config(run_dir: Path) -> dict: config_path = run_dir / 'config.toml' if not config_path.exists(): raise click.ClickException(f"No config.toml found in {run_dir}") with open(config_path, 'rb') as f: return tomllib.load(f) def _load_factory(game_config: dict): path = game_config.get('path') if path and path not in sys.path: sys.path.insert(0, path) module_name = game_config['module'] try: module = importlib.import_module(module_name) except ImportError as e: raise click.ClickException(f"Cannot import game module '{module_name}': {e}") if not hasattr(module, 'create_game'): raise click.ClickException( f"Module '{module_name}' has no create_game() function" ) return module.create_game