"""Command-line interface for retro-gamer.

Subcommands: ``create`` (set up a training-run directory), ``train``,
``play`` (watch a trained agent in the terminal) and ``info``.
"""
from __future__ import annotations

import importlib
import tomllib
from pathlib import Path

import click
import tomli_w

from retro_gamer.metadata import GameMetadata
from retro_gamer.trainer import DQNTrainer, DEFAULTS


@click.group()
def cli():
    """Train and run RL agents for retro games."""


# ---------------------------------------------------------------------------
# retro-gamer create
# ---------------------------------------------------------------------------

@cli.command()
@click.option('--game', required=True,
              help='Python module containing create_game() e.g. retro.examples.snake')
@click.option('--output', required=True,
              help='Directory to create for this training run')
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
              help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
@click.option('--lr-decay', default=DEFAULTS['lr_decay'], type=float,
              help=f"Multiplicative LR decay per episode (default {DEFAULTS['lr_decay']})")
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
              help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
              help=f"Exploration rate decay per episode (default {DEFAULTS['epsilon_decay']})")
@click.option('--epsilon-min', default=DEFAULTS['epsilon_min'], type=float,
              help=f"Minimum exploration rate (default {DEFAULTS['epsilon_min']})")
@click.option('--batch-size', default=DEFAULTS['batch_size'], type=int,
              help=f"Experiences per training step (default {DEFAULTS['batch_size']})")
@click.option('--memory-capacity', default=DEFAULTS['memory_capacity'], type=int,
              help=f"Replay buffer size (default {DEFAULTS['memory_capacity']})")
@click.option('--target-update-freq', default=DEFAULTS['target_update_freq'], type=int,
              help=f"Steps between target network updates (default {DEFAULTS['target_update_freq']})")
@click.option('--training-episodes', default=DEFAULTS['training_episodes'], type=int,
              help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
              help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
@click.option('--n-layers', default=DEFAULTS['n_layers'], type=int,
              help=f"Hidden layers in MLP head (default {DEFAULTS['n_layers']})")
@click.option('--layer-size', default=DEFAULTS['layer_size'], type=int,
              help=f"Width of each hidden layer (default {DEFAULTS['layer_size']})")
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
              help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
@click.option('--prioritize-experiences/--no-prioritize-experiences',
              default=DEFAULTS['prioritize_experiences'],
              help='Use prioritized experience replay')
def create(game, output, **hyperparams):
    """Create a new training run directory.

    Game metadata (actions, reward signal, etc.) is read from the
    [tool.retro-gamer] section of the game's pyproject.toml. Board size is
    read directly from the game. Hyperparameter options control how the
    trainer learns, not what it learns about.
    """
    # `hyperparams` collects every option except --game/--output, keyed by
    # click's parameter name (e.g. 'learning_rate'), and is written verbatim
    # to the [hyperparameters] table of config.toml.
    try:
        metadata = GameMetadata.from_pyproject(game)
    except (FileNotFoundError, ValueError) as e:
        raise click.ClickException(str(e)) from e

    # Board size comes from a live game instance rather than pyproject.toml,
    # so it must be filled in before the metadata is validated.
    game_factory = _load_factory(game)
    g = game_factory()
    metadata.board_size = g.board_size
    try:
        metadata.validate()
    except ValueError as e:
        # NOTE(review): assumes validate() reports bad metadata with
        # ValueError, like from_pyproject above; other exceptions propagate.
        raise click.ClickException(str(e)) from e

    run_dir = Path(output)
    run_dir.mkdir(parents=True, exist_ok=True)

    config = {
        'game': {'module': game},
        'metadata': metadata.to_dict(),
        'hyperparameters': hyperparams,
    }
    with open(run_dir / 'config.toml', 'wb') as f:
        tomli_w.dump(config, f)

    click.echo(f"Created training run at {output}/config.toml")
    click.echo(f" game : {game}")
    click.echo(f" board_size : {metadata.board_size[0]}×{metadata.board_size[1]}")
    click.echo(f" actions : {metadata.actions}")
    click.echo(f" reward : {metadata.reward}")
    if metadata.character_set:
        click.echo(f" characters : {metadata.character_set}")
    else:
        click.echo(" characters : (will be auto-discovered during training)")
    if metadata.observe_state:
        click.echo(f" observe : {metadata.observe_state}")
    click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")


# ---------------------------------------------------------------------------
# retro-gamer train
# ---------------------------------------------------------------------------

@cli.command()
@click.argument('run_dir')
@click.option('--resume', default=None,
              help='Path to checkpoint to resume from (e.g. checkpoints/ep_0500.pt)')
def train(run_dir, resume):
    """Train (or resume training) a DQN agent."""
    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)
    game_factory = _load_factory(config['game']['module'])
    metadata = GameMetadata.from_dict(config['metadata'])
    # Only the values stored in config.toml are passed on; the effective
    # settings are read back from trainer.hp below.
    hyperparams = config.get('hyperparameters', {})

    trainer = DQNTrainer(game_factory, metadata, run_dir, **hyperparams)
    if resume:
        click.echo(f"Resuming from {resume}")
        # NOTE(review): `resume` is passed through unchanged; whether a
        # relative path resolves against run_dir or the CWD is decided by
        # load_checkpoint — confirm.
        trainer.load_checkpoint(resume)

    click.echo(f"Training for {trainer.hp['training_episodes']} episodes…")
    trainer.train()
    click.echo(f"Done. Checkpoints in {run_dir}/checkpoints/")


# ---------------------------------------------------------------------------
# retro-gamer play
# ---------------------------------------------------------------------------

@cli.command()
@click.argument('run_dir')
@click.option('--checkpoint', default='final',
              help='Checkpoint name e.g. "final" or "ep_0100"')
# IntRange(min=1): the per-frame time budget is 1 / framerate, so 0 would
# raise ZeroDivisionError mid-game and negative rates are meaningless.
@click.option('--framerate', default=12, type=click.IntRange(min=1),
              help='Target frames per second')
def play(run_dir, checkpoint, framerate):
    """Watch a trained agent play the game."""
    # Imported here rather than at module level, so only `play` requires
    # torch, blessed and the retro views to be importable.
    import torch
    from time import sleep, perf_counter
    from blessed import Terminal
    from retro.input import ProgrammaticInput
    from retro.views.headless import HeadlessView
    from retro.views.terminal import TerminalView
    from retro_gamer.network import build_network
    from retro_gamer.observation import encode_observation

    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)
    game_factory = _load_factory(config['game']['module'])
    metadata = GameMetadata.from_dict(config['metadata'])
    # Keys absent from config.toml fall back to DEFAULTS so build_network
    # always receives a complete hyperparameter set.
    hyperparams = {**DEFAULTS, **config.get('hyperparameters', {})}

    # Accept both "ep_0100" and "ep_0100.pt".
    ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
    ckpt_path = run_dir_path / 'checkpoints' / ckpt_name
    if not ckpt_path.exists():
        raise click.ClickException(f"No checkpoint found at {ckpt_path}")

    model, _ = build_network(metadata, hyperparams)
    # map_location='cpu': observations below are created as CPU tensors and
    # load_state_dict copies into the model's own parameters anyway, so this
    # lets a checkpoint saved from a GPU run load on a CPU-only machine.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # The game renders into a headless view (whose board_characters feed the
    # policy) and is driven by programmatic key presses; a separate
    # TerminalView draws each frame for the person watching.
    inp = ProgrammaticInput()
    headless = HeadlessView()
    game = game_factory()
    game.input_source = inp
    game.view = headless
    game.start()

    terminal = Terminal()
    term_view = TerminalView(terminal, color=game.color)
    frame_budget = 1 / framerate

    click.echo("Playing… (press Escape or Enter to quit)")
    with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
        term_view.on_game_start(game)
        while game.playing:
            t0 = perf_counter()

            # Greedy policy: pick the action with the highest predicted value.
            obs = encode_observation(headless.board_characters, dict(game.state), metadata)
            state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                action_idx = int(model(state_t).argmax().item())
            # Indices past the end of metadata.actions map to None —
            # presumably a "press nothing this turn" action; confirm against
            # build_network's output size.
            if action_idx < len(metadata.actions):
                action_key = metadata.actions[action_idx]
            else:
                action_key = None

            # Non-blocking key poll (timeout 0) so the viewer can quit.
            key = terminal.inkey(0)
            if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
                break

            inp.press(action_key)
            game.step()
            term_view.render(game)

            # Sleep off whatever remains of this frame's time budget.
            elapsed = perf_counter() - t0
            sleep(max(0, frame_budget - elapsed))


# ---------------------------------------------------------------------------
# retro-gamer info
# ---------------------------------------------------------------------------

@cli.command()
@click.argument('run_dir')
def info(run_dir):
    """Print a summary of a training run."""
    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)
    click.echo(f"Game module : {config['game']['module']}")
    click.echo(f"Metadata : {config['metadata']}")
    click.echo(f"Hyperparams : {config.get('hyperparameters', {})}")

    log_path = run_dir_path / 'training.log'
    if log_path.exists():
        lines = log_path.read_text().splitlines()
        # Per-episode lines in training.log are the ones starting with "[EP".
        episode_lines = [entry for entry in lines if entry.startswith('[EP')]
        if episode_lines:
            click.echo("\nLast 5 episodes:")
            for line in episode_lines[-5:]:
                click.echo(f" {line}")

    ckpt_dir = run_dir_path / 'checkpoints'
    if ckpt_dir.exists():
        ckpts = sorted(ckpt_dir.glob('*.pt'))
        click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _load_config(run_dir: Path) -> dict:
    """Read and parse ``<run_dir>/config.toml``.

    Raises:
        click.ClickException: if the file is missing or is not valid TOML.
    """
    config_path = run_dir / 'config.toml'
    if not config_path.exists():
        raise click.ClickException(f"No config.toml found in {run_dir}")
    with open(config_path, 'rb') as f:
        try:
            return tomllib.load(f)
        except tomllib.TOMLDecodeError as e:
            raise click.ClickException(f"Invalid config.toml in {run_dir}: {e}") from e


def _load_factory(module_name: str):
    """Import *module_name* and return its ``create_game`` attribute.

    Raises:
        click.ClickException: if the module cannot be imported or has no
            ``create_game`` attribute.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        raise click.ClickException(f"Cannot import game module '{module_name}': {e}") from e
    if not hasattr(module, 'create_game'):
        raise click.ClickException(
            f"Module '{module_name}' has no create_game() function"
        )
    return module.create_game