from __future__ import annotations import random from datetime import datetime from pathlib import Path from time import perf_counter from typing import Callable import numpy as np import torch import torch.nn as nn import torch.optim as optim import tomli_w from tqdm import tqdm from retro_gamer.metadata import GameMetadata from retro_gamer.env import GameEnvironment from retro_gamer.network import build_network from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory MODEL_KEYS: frozenset = frozenset({'hidden_sizes'}) DEFAULTS: dict = { # [model] 'hidden_sizes': [128, 64], # [training] 'learning_rate': 1e-4, 'learning_rate_decay': 0.9999, 'gamma': 0.99, 'epsilon': 1.0, 'epsilon_decay': 0.9997, 'epsilon_min': 0.05, 'batch_size': 64, 'memory_capacity': 50_000, 'target_update_freq': 500, 'train_every': 4, 'training_episodes': 20_000, 'prioritize_experiences': True, 'exploration_turns': 200, 'unknown_character_strategy': 'ignore', 'max_turns_per_episode': 2_000, } def _get_device() -> torch.device: if torch.backends.mps.is_available(): return torch.device('mps') if torch.cuda.is_available(): return torch.device('cuda') return torch.device('cpu') # Fields that make an existing checkpoint incompatible with the current config. # Changing any of these requires starting training from scratch. 
_INCOMPATIBLE_METADATA = {
    'actions': 'the list of actions the agent can take (changes output layer size)',
    'reward': 'the reward signal — Q-values trained on the old signal are meaningless for the new one',
    'character_set': 'the set of board characters (changes input layer size)',
    'board_size': 'the board dimensions (changes input layer size)',
}

# Preprocessing options compared against the checkpoint's config snapshot.
_INCOMPATIBLE_PREPROCESSING = {
    'spatial': 'spatial vs non-spatial network type (changes network architecture)',
    'board': 'whether the board is included in the observation (changes input size)',
    'observe_state': 'the state keys included in the observation (changes input size)',
    'observe_state_sizes': 'the size of each observed state key (changes input layer size)',
    'egocentric': 'egocentric board transformation (changes input representation)',
    'egocentric_player': 'the agent used as the egocentric center (changes input representation)',
    'egocentric_radius': 'the egocentric crop radius (changes input layer size)',
}

# Architecture hyperparameters compared against the checkpoint's config snapshot.
_INCOMPATIBLE_ARCH = {
    'hidden_sizes': 'the hidden layer sizes (changes network shape)',
}


def _is_int(value) -> bool:
    """Return True if *value* is an int but not a bool.

    ``bool`` is a subclass of ``int`` in Python, so a TOML ``true`` would
    otherwise be accepted wherever an integer is expected.
    """
    return isinstance(value, int) and not isinstance(value, bool)


def _is_number(value) -> bool:
    """Return True if *value* is an int or float but not a bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _as_comparable(value):
    """Return *value* with tuples converted to lists (recursively) for ``==`` checks.

    TOML has no tuple type, so a value that round-trips through config.toml
    comes back as a list even when it was a tuple in memory (for example the
    ``(side, side)`` board_size assigned for egocentric views). Without this,
    identical configurations could be reported as incompatible.
    """
    if isinstance(value, (list, tuple)):
        return [_as_comparable(v) for v in value]
    if isinstance(value, dict):
        return {k: _as_comparable(v) for k, v in value.items()}
    return value


def validate_hyperparams(hp: dict):
    """Check all hyperparameters and raise ValueError listing every problem found.

    Every problem is collected first so the user can fix them all in one pass.
    Booleans are rejected wherever a number is expected, and non-finite
    learning rates (``nan``/``inf``) are rejected.

    Raises:
        ValueError: if any hyperparameter is missing, mistyped, or out of range.
    """
    problems = []

    def _problem(heading: str, explanation: str, fix: str | None = None):
        text = f" {heading}\n {explanation}"
        if fix:
            text += f"\n → {fix}"
        problems.append(text)

    hs = hp.get('hidden_sizes')
    if not isinstance(hs, list) or len(hs) == 0:
        _problem(
            f"hidden_sizes = {hs!r}",
            "This sets the shape of the neural network's hidden layers. It must be a\n"
            " non-empty list of positive integers, one number per layer.",
            "Try: hidden_sizes = [512, 256]",
        )
    elif any(not _is_int(s) or s <= 0 for s in hs):
        bad = [s for s in hs if not _is_int(s) or s <= 0]
        _problem(
            f"hidden_sizes = {hs!r}",
            f"Every layer size must be a positive integer, but got {bad!r}.\n"
            " Each number is the count of neurons in that layer — more neurons means\n"
            " more capacity to learn complex patterns.",
            "Try: hidden_sizes = [512, 256]",
        )

    lr = hp.get('learning_rate')
    # The chained comparison is False for nan and for inf, so both are rejected.
    if not _is_number(lr) or not (0 < lr < float('inf')):
        _problem(
            f"learning_rate = {lr!r}",
            "The learning rate controls how much the network adjusts its weights after\n"
            " each training step. Too high and training becomes unstable; too low and\n"
            " it learns very slowly. It must be a positive number.",
            "Typical values are between 0.0001 and 0.01. Try: learning_rate = 0.001",
        )

    for key, blurb in [
        ('learning_rate_decay',
         "After each episode, the learning rate is multiplied by this value, gradually\n"
         " slowing learning over time. A value of 1.0 means no decay; closer to 0\n"
         " means very aggressive decay. It must be greater than 0 and at most 1."),
        ('gamma',
         "Gamma is the discount factor: how much the agent values future rewards versus\n"
         " immediate ones. 0.99 means future rewards are nearly as important as now;\n"
         " 0.0 means the agent only cares about the very next step.\n"
         " It must be greater than 0 and at most 1."),
        ('epsilon_decay',
         "Each episode, the exploration rate (epsilon) is multiplied by this to gradually\n"
         " reduce random actions over time. It must be between 0 (exclusive) and 1."),
    ]:
        v = hp.get(key)
        if not _is_number(v) or not (0 < v <= 1):
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try a value close to but less than 1, like {key} = 0.995",
            )

    for key, blurb in [
        ('epsilon',
         "Epsilon is the probability of taking a random action (exploration vs.\n"
         " exploitation). It starts high — usually 1.0, meaning fully random —\n"
         " and decays toward epsilon_min during training. It must be between 0 and 1."),
        ('epsilon_min',
         "This is the lowest exploration rate allowed. Even after lots of training, the\n"
         " agent keeps at least this much randomness so it keeps discovering new things.\n"
         " It must be between 0 and 1."),
    ]:
        v = hp.get(key)
        if not _is_number(v) or not (0 <= v <= 1):
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try: {key} = {'0.05' if 'min' in key else '1.0'}",
            )

    eps = hp.get('epsilon')
    eps_min = hp.get('epsilon_min')
    # Only cross-check once both values are individually valid, to avoid
    # reporting the same mistake twice.
    if (_is_number(eps) and _is_number(eps_min)
            and 0 <= eps <= 1 and 0 <= eps_min <= 1 and eps_min > eps):
        _problem(
            f"epsilon_min = {eps_min!r} is greater than epsilon = {eps!r}",
            "epsilon is the starting exploration rate and epsilon_min is the floor it\n"
            " decays toward, so epsilon_min must be less than or equal to epsilon.",
            "Try: epsilon = 1.0 and epsilon_min = 0.05",
        )

    for key, blurb in [
        ('batch_size',
         "Training samples this many past experiences from the replay buffer at once\n"
         " to compute a learning update. Must be a positive integer."),
        ('memory_capacity',
         "The replay buffer stores this many past experiences. When it fills up, the\n"
         " oldest are discarded. A larger buffer means more diverse training data.\n"
         " Must be a positive integer."),
        ('target_update_freq',
         "The target network (a stable copy of the Q-network used to compute targets)\n"
         " is updated every this many steps. Must be a positive integer."),
        ('train_every',
         "A training step runs once every this many game steps. This lets the agent\n"
         " collect several new experiences before updating. Must be a positive integer."),
        ('training_episodes',
         "The total number of episodes (games) to train for. Must be a positive integer."),
        ('max_turns_per_episode',
         "If a game episode hasn't ended naturally after this many steps, it's cut\n"
         " short. This prevents a buggy or stuck agent from running forever.\n"
         " Must be a positive integer."),
    ]:
        v = hp.get(key)
        if not _is_int(v) or v <= 0:
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try: {key} = {DEFAULTS[key]}",
            )

    v = hp.get('exploration_turns')
    if not _is_int(v) or v < 0:
        _problem(
            f"exploration_turns = {v!r}",
            "When no character_set is specified, the trainer runs this many random turns\n"
            " to discover what characters appear on the board. Must be 0 or more\n"
            " (0 skips discovery entirely, which only works if character_set is set).",
            f"Try: exploration_turns = {DEFAULTS['exploration_turns']}",
        )

    bs = hp.get('batch_size')
    mc = hp.get('memory_capacity')
    if _is_int(bs) and bs > 0 and _is_int(mc) and mc > 0 and bs > mc:
        _problem(
            f"batch_size = {bs} is larger than memory_capacity = {mc}",
            "Training samples a batch of past experiences from the replay buffer each\n"
            " step, so the buffer must be able to hold at least as many experiences\n"
            f" as the batch size. With batch_size = {bs}, the buffer holds {mc} —\n"
            " that's not enough to sample from.",
            f"Try: memory_capacity = {max(bs * 100, DEFAULTS['memory_capacity'])} "
            "(a much larger buffer also improves learning quality)",
        )

    if problems:
        n = len(problems)
        noun = "problem" if n == 1 else "problems"
        header = f"Found {n} {noun} in your training configuration:\n\n"
        footer = "\n\nFix these in config.toml, then run 'retro-gamer train' again."
        raise ValueError(header + "\n\n".join(problems) + footer)


def _format_duration(seconds: float) -> str:
    """Format *seconds* as ``XhMMmSSs``, or ``MmSSs`` when under an hour."""
    m, s = divmod(int(seconds), 60)
    h, m = divmod(m, 60)
    if h:
        return f"{h}h{m:02d}m{s:02d}s"
    return f"{m}m{s:02d}s"


class DQNTrainer:
    """Trains a deep Q-network agent to play a retro game.

    Automatically selects the best available training device: Apple Silicon
    GPU (MPS), NVIDIA GPU (CUDA), or CPU. The chosen device is recorded in
    ``training.log``.

    On initialization, the trainer:

    1. Discovers the character set if not already specified in *metadata*
       (and the flat size of each ``observe_state`` key if not given).
    2. Builds the Q-network and logs its full architecture with rationale.
    3. Writes ``config.toml`` and initializes ``training.log`` in *run_dir*.

    Hyperparameters can be passed as keyword arguments; see the
    :ref:`hyperparameters` reference for all options. Values not supplied
    fall back to sensible defaults.

    Call :meth:`train` to run all episodes. Checkpoints are saved every 100
    episodes (and after the final episode), and training can be stopped
    (Ctrl+C) and resumed at any time.

    Example::

        from retro_gamer import GameMetadata, DQNTrainer
        from retro.examples.snake import create_game

        metadata = GameMetadata.from_pyproject("retro.examples.snake")
        trainer = DQNTrainer(create_game, metadata, "runs/snake/")
        trainer.train()
    """

    def __init__(
        self,
        game_factory: Callable,
        metadata: GameMetadata,
        run_dir: str | Path,
        preprocessing: dict | None = None,
        **hyperparams,
    ):
        """Validate the configuration, build the networks, and prepare *run_dir*.

        Args:
            game_factory: zero-argument callable returning a new game instance.
            metadata: game metadata. It is mutated in place (``board``,
                ``board_size``, ``character_set``, ``extras_size``).
            run_dir: directory for ``config.toml``, ``training.log`` and
                ``checkpoints/``. It is created if missing.
            preprocessing: observation options (``board``, ``observe_state``,
                ``observe_state_sizes``, ``egocentric``, ``egocentric_player``,
                ``egocentric_radius``).
            **hyperparams: overrides for ``DEFAULTS``.

        Raises:
            ValueError: for invalid hyperparameters or inconsistent
                preprocessing options.
        """
        self.game_factory = game_factory
        self.metadata = metadata
        self.run_dir = Path(run_dir)
        self.hp: dict = {**DEFAULTS, **hyperparams}
        validate_hyperparams(self.hp)
        # training.log is only opened at the end of __init__ (it needs the
        # network rationale), so [INIT] messages produced earlier are
        # buffered here and flushed by _open_log.
        self._init_notes: list[str] = []

        self.run_dir.mkdir(parents=True, exist_ok=True)
        (self.run_dir / 'checkpoints').mkdir(exist_ok=True)

        pre = preprocessing or {}
        self.observe_state: list[str] = pre.get('observe_state', [])
        self.egocentric: bool = pre.get('egocentric', False)
        self.egocentric_player: str | None = pre.get('egocentric_player', None)
        self.egocentric_radius: int | None = pre.get('egocentric_radius', None)
        self.board: bool = pre.get('board', True)
        self.observe_state_sizes: dict[str, int] = pre.get('observe_state_sizes', {})

        if self.board is False and metadata.spatial:
            raise ValueError(
                "preprocessing.board = false is incompatible with spatial = true.\n"
                "A CNN requires a 2-D board to operate on. Either set spatial = false\n"
                "or keep board = true."
            )
        if self.board is False and not self.observe_state:
            raise ValueError(
                "preprocessing.board = false requires at least one entry in observe_state.\n"
                "With board=false, the agent observes only the game state variables listed\n"
                "in observe_state — if that list is empty, there is nothing to observe."
            )
        if self.egocentric and not self.egocentric_radius:
            raise ValueError(
                "preprocessing.egocentric = true requires egocentric_radius.\n"
                "Choose a value based on how far the agent needs to see, e.g.:\n"
                " egocentric_radius = 5 # 11×11 tight local view\n"
                " egocentric_radius = 8 # 17×17 wider view"
            )

        metadata.board = self.board
        if metadata.board_size is None:
            # Instantiate one throwaway game just to read its board dimensions.
            g = game_factory()
            metadata.board_size = g.board_size
        # NOTE(review): this overrides board_size whenever a radius is set,
        # even if egocentric is false — confirm GameEnvironment crops in
        # that case too.
        if self.egocentric_radius:
            side = 2 * self.egocentric_radius + 1
            metadata.board_size = (side, side)

        self.env = GameEnvironment(
            game_factory,
            metadata,
            observe_state=self.observe_state,
            egocentric=self.egocentric,
            egocentric_player=self.egocentric_player,
            egocentric_radius=self.egocentric_radius,
            board=self.board,
            observe_state_sizes=self.observe_state_sizes,
        )

        if metadata.character_set is None and self.board:
            self._discover_character_set()

        if self.observe_state and not self.observe_state_sizes:
            self._discover_observe_state_sizes()
            self.env.observe_state_sizes = self.observe_state_sizes

        metadata.extras_size = sum(self.observe_state_sizes.values())

        self.device = _get_device()
        self.model, rationale = build_network(metadata, self.hp)
        self.target_model, _ = build_network(metadata, self.hp)
        self.model.to(self.device)
        self.target_model.to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()

        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.hp['learning_rate']
        )
        self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=self.hp['learning_rate_decay']
        )

        if self.hp['prioritize_experiences']:
            self.memory = PrioritizedReplayMemory(self.hp['memory_capacity'])
        else:
            self.memory = ReplayMemory(self.hp['memory_capacity'])

        self.epsilon: float = self.hp['epsilon']
        self.total_steps: int = 0
        self.total_training_seconds: float = 0.0
        self.start_episode: int = 1
        self._resumed_from: str | None = None

        self._save_config()
        self._open_log(rationale)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def train(self, on_checkpoint=None, on_episode=None):
        """Run all training episodes and save checkpoints.

        *on_checkpoint*, if provided, is called after each checkpoint with a
        dict containing ``episode``, ``avg_reward``, ``avg_steps``,
        ``avg_loss``, and ``epsilon``. *on_episode*, if provided, is called
        after every episode.

        When either callback is supplied, the built-in tqdm progress bar is
        suppressed (the caller is expected to show its own progress UI).

        A checkpoint is written every 100 episodes and after the final
        episode. Epsilon decays after every episode; the learning rate decays
        only after episodes in which at least one gradient step ran.
        """
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        if self._resumed_from:
            self._log_raw(f'\n=== Resumed from {self._resumed_from} | {timestamp} ===')
        else:
            self._log_raw(f'\n=== Training started | {timestamp} ===')

        use_tqdm = on_checkpoint is None and on_episode is None
        if use_tqdm:
            print("Press Control+C to stop training early. Progress will be saved at the latest checkpoint.")

        ckpt_start = perf_counter()
        episode_rewards: list[float] = []
        episode_losses: list[float] = []
        episode_steps: list[int] = []

        episodes = range(self.start_episode, self.hp['training_episodes'] + 1)
        bar = tqdm(episodes, unit='ep') if use_tqdm else episodes

        try:
            for episode in bar:
                total_reward, steps, avg_loss, trained = self._run_episode()
                episode_rewards.append(total_reward)
                # Only episodes that actually trained contribute to the loss
                # average; untrained episodes report a placeholder 0.0.
                if trained:
                    episode_losses.append(avg_loss)
                episode_steps.append(steps)

                self.epsilon = max(
                    self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
                )
                if trained:
                    self.lr_scheduler.step()

                if on_episode:
                    on_episode()
                if use_tqdm:
                    bar.set_postfix(
                        reward=f'{total_reward:.1f}',
                        eps=f'{self.epsilon:.3f}',
                        loss=f'{avg_loss:.4f}',
                    )

                is_checkpoint = (episode % 100 == 0)
                is_last = (episode == self.hp['training_episodes'])
                if is_checkpoint or (is_last and episode_rewards):
                    now = perf_counter()
                    ckpt_elapsed = now - ckpt_start
                    self.total_training_seconds += ckpt_elapsed
                    ckpt_start = now
                    self._save_checkpoint(f'ep_{episode:04d}.pt', episode)
                    stats = self._log_checkpoint(
                        episode, episode_rewards, episode_losses, episode_steps, ckpt_elapsed
                    )
                    episode_rewards = []
                    episode_losses = []
                    episode_steps = []
                    if on_checkpoint:
                        on_checkpoint(stats)
        finally:
            # Leave the terminal clean even when training is interrupted.
            if use_tqdm:
                bar.close()

    def load_checkpoint(self, path: str | Path):
        """Load a checkpoint to resume training.

        Checkpoints are PyTorch state dicts stored under
        ``run_dir/checkpoints/``. Each contains model weights, optimizer
        state, current epsilon, and total step count.

        Raises :exc:`ValueError` if the checkpoint was trained with a
        different character set, board size, action space, or network
        architecture. The error message names each changed field and explains
        why it is incompatible.

        The CLI invokes this automatically; call directly only when driving
        training from Python.
        """
        ckpt = torch.load(path, weights_only=True, map_location='cpu')
        self._check_compatibility(ckpt, path)

        self.model.load_state_dict(ckpt['model_state_dict'])
        self.target_model.load_state_dict(ckpt['model_state_dict'])
        self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        # The checkpoint was mapped to CPU; move optimizer state tensors to
        # the training device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(self.device)
        self.epsilon = ckpt['epsilon']
        self.total_steps = ckpt['total_steps']
        self.total_training_seconds = ckpt.get('total_training_seconds', 0.0)
        self.start_episode = ckpt.get('episode', 0) + 1
        self._resumed_from = Path(path).name

    # ------------------------------------------------------------------
    # Training loop internals
    # ------------------------------------------------------------------

    def _run_episode(self) -> tuple[float, int, float, bool]:
        """Play one episode with epsilon-greedy actions, training as it goes.

        Returns:
            ``(total_reward, steps_taken, avg_loss, trained)``. *avg_loss* is
            0.0 and *trained* is False when no gradient step ran (e.g. the
            replay buffer is still smaller than ``batch_size``).
        """
        state = self.env.reset()
        total_reward = 0.0
        total_loss = 0.0
        loss_count = 0

        # max_turns_per_episode is validated > 0, so `step` is always bound.
        for step in range(self.hp['max_turns_per_episode']):
            state_t = torch.as_tensor(state, dtype=torch.float32).to(self.device)
            action_idx = self._select_action(state_t)
            action_key = self._idx_to_key(action_idx)
            next_state, reward, done = self.env.step(action_key)

            self.memory.push(state, action_idx, reward, next_state, done)

            if self.total_steps % self.hp['train_every'] == 0:
                loss = self._train_step()
                if loss is not None:
                    total_loss += loss
                    loss_count += 1

            self.total_steps += 1
            # Periodically sync the frozen target network with the online one.
            if self.total_steps % self.hp['target_update_freq'] == 0:
                self.target_model.load_state_dict(self.model.state_dict())

            state = next_state
            total_reward += reward
            if done:
                break

        avg_loss = total_loss / loss_count if loss_count else 0.0
        return total_reward, step + 1, avg_loss, loss_count > 0

    def _select_action(self, state_t: torch.Tensor) -> int:
        """Return a random action index with probability epsilon, else the greedy one."""
        if random.random() < self.epsilon:
            return random.randrange(self.metadata.n_actions)
        with torch.no_grad():
            return int(self.model(state_t.unsqueeze(0)).argmax().item())

    def _idx_to_key(self, idx: int) -> str | None:
        """Map an action index to its key; indices past ``metadata.actions`` map to None.

        NOTE(review): presumably the extra index is a "do nothing" action —
        confirm with GameMetadata.n_actions.
        """
        if idx >= len(self.metadata.actions):
            return None
        return self.metadata.actions[idx]

    def _train_step(self) -> float | None:
        """Run one DQN gradient step on a sampled batch.

        Returns the batch loss, or None when the buffer holds fewer than
        ``batch_size`` experiences. With prioritized replay, the per-sample
        losses are importance-weighted and priorities are updated with the
        absolute TD errors.
        """
        if len(self.memory) < self.hp['batch_size']:
            return None

        if self.hp['prioritize_experiences']:
            assert isinstance(self.memory, PrioritizedReplayMemory)
            experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
            weight_t = torch.as_tensor(weights, dtype=torch.float32).to(self.device)
        else:
            experiences = self.memory.sample(self.hp['batch_size'])
            indices = None
            weight_t = None

        states = torch.as_tensor(
            np.array([e.state for e in experiences]), dtype=torch.float32
        ).to(self.device)
        actions = torch.as_tensor(
            [e.action for e in experiences], dtype=torch.long
        ).to(self.device)
        rewards = torch.as_tensor(
            [e.reward for e in experiences], dtype=torch.float32
        ).to(self.device)
        next_states = torch.as_tensor(
            np.array([e.next_state for e in experiences]), dtype=torch.float32
        ).to(self.device)
        dones = torch.as_tensor(
            [e.done for e in experiences], dtype=torch.float32
        ).to(self.device)

        # Q(s, a) for the actions actually taken.
        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.target_model(next_states).max(1).values
            # Terminal transitions do not bootstrap from the next state.
            targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)

        element_loss = nn.functional.huber_loss(q_values, targets, reduction='none', delta=1.0)
        if weight_t is not None:
            loss = (weight_t * element_loss).mean()
            td_errors = (q_values - targets).detach().abs().cpu().numpy()
            self.memory.update_priorities(indices, td_errors)
        else:
            loss = element_loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
        self.optimizer.step()
        return float(loss.item())

    # ------------------------------------------------------------------
    # Compatibility checking
    # ------------------------------------------------------------------

    def _config_snapshot(self) -> dict:
        """Return the config values stored in checkpoints for compatibility checks."""
        return {
            'metadata': self.metadata.to_dict(),
            'preprocessing': {
                'spatial': self.metadata.spatial,
                'board': self.board,
                'observe_state': self.observe_state,
                'observe_state_sizes': self.observe_state_sizes,
                'egocentric': self.egocentric,
                'egocentric_player': self.egocentric_player,
                'egocentric_radius': self.egocentric_radius,
            },
            'hidden_sizes': self.hp['hidden_sizes'],
        }

    def _check_compatibility(self, ckpt: dict, path: str | Path):
        """Raise ValueError if *ckpt* was trained under an incompatible config.

        Checkpoints without a ``config_snapshot`` are accepted unchecked.
        Tuples and lists holding the same values are treated as equal.
        """
        snapshot = ckpt.get('config_snapshot')
        if snapshot is None:
            return
        current = self._config_snapshot()

        sections = [
            (_INCOMPATIBLE_METADATA, snapshot.get('metadata', {}), current['metadata']),
            (_INCOMPATIBLE_PREPROCESSING, snapshot.get('preprocessing', {}), current['preprocessing']),
            (_INCOMPATIBLE_ARCH, snapshot, current),
        ]
        issues = []
        for fields, old_section, new_section in sections:
            for field, desc in fields.items():
                old_val = old_section.get(field)
                new_val = new_section.get(field)
                if _as_comparable(old_val) != _as_comparable(new_val):
                    issues.append((field, desc, old_val, new_val))

        if not issues:
            return

        lines = [
            f"Cannot resume from {Path(path).name}: incompatible changes detected in config.toml.",
            "",
            "The following changes require starting fresh. The existing model was trained",
            "on a different problem and its weights cannot be reused:",
            "",
        ]
        for field, desc, old_val, new_val in issues:
            lines += [
                f" {field}",
                f" was : {old_val!r}",
                f" now : {new_val!r}",
                f" why : {desc}",
                "",
            ]
        lines += [
            "Run 'retro-gamer clean RUN_DIR' to remove existing checkpoints and the",
            "training log, then run 'retro-gamer train RUN_DIR' to start fresh.",
        ]
        raise ValueError("\n".join(lines))

    # ------------------------------------------------------------------
    # Initialisation helpers
    # ------------------------------------------------------------------

    def _discover_character_set(self):
        """Ask the environment for the board characters and store them on metadata.

        Passes ``exploration_turns`` to ``env.discover_character_set``. The
        resulting [INIT] message is buffered in ``_init_notes`` because
        ``training.log`` has not been opened yet at this point in __init__.
        """
        chars = self.env.discover_character_set(self.hp['exploration_turns'])
        self.metadata.character_set = chars
        self._init_notes.append(
            f"[INIT] character_set not specified — discovered {len(chars)} chars "
            f"after {self.hp['exploration_turns']} exploration turns: {chars}"
        )

    def _discover_observe_state_sizes(self):
        """Sample game.state to determine the flat size of each observe_state key.

        List/tuple values count as their length; anything else (including a
        key missing from the state) counts as a single scalar.
        """
        self.env.reset()
        state = dict(self.env.game.state)
        sizes = {}
        for key in self.observe_state:
            val = state.get(key, 0)
            sizes[key] = len(val) if isinstance(val, (list, tuple)) else 1
        self.observe_state_sizes = sizes

    def _save_config(self):
        """Merge the effective configuration into ``run_dir/config.toml``.

        Existing tables and keys not managed here are preserved. The TOML is
        serialized fully before the file is touched and then swapped in
        atomically, so a serialization error cannot leave a truncated config.
        """
        config_path = self.run_dir / 'config.toml'
        config: dict = {}
        if config_path.exists():
            import tomllib
            with open(config_path, 'rb') as f:
                config = tomllib.load(f)

        config['metadata'] = self.metadata.to_dict()
        pre = config.setdefault('preprocessing', {})
        pre['spatial'] = self.metadata.spatial
        pre['board'] = self.board
        pre['observe_state'] = self.observe_state
        if self.observe_state_sizes:
            pre['observe_state_sizes'] = self.observe_state_sizes
        pre['egocentric'] = self.egocentric
        if self.egocentric_player:
            pre['egocentric_player'] = self.egocentric_player
        if self.egocentric_radius:
            pre['egocentric_radius'] = self.egocentric_radius
        config['model'] = {k: v for k, v in self.hp.items() if k in MODEL_KEYS}
        config['training'] = {k: v for k, v in self.hp.items() if k not in MODEL_KEYS}

        text = tomli_w.dumps(config)
        tmp_path = config_path.with_name(config_path.name + '.tmp')
        tmp_path.write_bytes(text.encode('utf-8'))
        tmp_path.replace(config_path)

    def _open_log(self, rationale: str):
        """Initialise ``training.log`` and flush buffered [INIT] notes.

        A new log starts with the network *rationale* and the chosen device.
        An existing log (a resumed run) keeps its header, and any buffered
        notes are appended.
        """
        self.log_path = self.run_dir / 'training.log'
        lines: list[str] = []
        if not self.log_path.exists():
            lines += [rationale, f'[INIT] Device: {self.device}']
        lines += self._init_notes
        self._init_notes = []
        if lines:
            with open(self.log_path, 'a', encoding='utf-8') as f:
                f.writelines(line + '\n' for line in lines)

    def _log_raw(self, line: str):
        """Append *line* to ``training.log`` (UTF-8, since lines may contain any board char)."""
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(line + '\n')

    def _log_checkpoint(
        self,
        episode: int,
        rewards: list[float],
        losses: list[float],
        steps: list[int],
        ckpt_elapsed: float,
    ) -> dict:
        """Log a one-line summary of the episodes since the last checkpoint.

        Returns the summary stats passed to ``train``'s *on_checkpoint*
        callback. The loss is logged with four decimals because Huber losses
        are usually well below 0.1.
        """
        n = len(rewards)
        start_ep = episode - n + 1
        avg_reward = sum(rewards) / n if n else 0.0
        avg_loss = sum(losses) / len(losses) if losses else 0.0
        avg_steps = sum(steps) / n if n else 0.0

        line = (
            f"[ep_{episode:04d}]"
            f" ep={start_ep:04d}-{episode:04d}"
            f" avg_reward={avg_reward:+.1f}"
            f" avg_steps={avg_steps:.0f}"
            f" epsilon={self.epsilon:.3f}"
            f" avg_loss={avg_loss:.4f}"
            f" time={_format_duration(ckpt_elapsed)}"
            f" total={_format_duration(self.total_training_seconds)}"
        )
        self._log_raw(line)
        return {
            'episode': episode,
            'avg_reward': avg_reward,
            'avg_steps': avg_steps,
            'avg_loss': avg_loss,
            'epsilon': self.epsilon,
        }

    def _save_checkpoint(self, name: str, episode: int):
        """Save model/optimizer state and progress to ``run_dir/checkpoints/<name>``.

        The file is written to a temporary name first and then atomically
        renamed, so an interruption (e.g. Ctrl+C) mid-save cannot leave a
        truncated checkpoint behind.
        """
        target = self.run_dir / 'checkpoints' / name
        tmp_target = target.with_name(target.name + '.tmp')
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epsilon': self.epsilon,
                'total_steps': self.total_steps,
                'episode': episode,
                'total_training_seconds': self.total_training_seconds,
                'config_snapshot': self._config_snapshot(),
            },
            tmp_target,
        )
        tmp_target.replace(target)