from __future__ import annotations import random from pathlib import Path from typing import Callable import numpy as np import torch import torch.nn as nn import torch.optim as optim import tomli_w from retro_gamer.metadata import GameMetadata from retro_gamer.env import GameEnvironment from retro_gamer.network import build_network from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory DEFAULTS: dict = { 'learning_rate': 1e-3, 'lr_decay': 0.995, 'gamma': 0.99, 'epsilon': 1.0, 'epsilon_decay': 0.995, 'epsilon_min': 0.05, 'batch_size': 64, 'memory_capacity': 10_000, 'target_update_freq': 100, 'training_episodes': 1_000, 'n_layers': 2, 'layer_size': 128, 'prioritize_experiences': False, 'exploration_turns': 200, 'unknown_character_strategy': 'ignore', 'max_turns_per_episode': 2_000, } class DQNTrainer: """Trains a deep Q-network agent to play a retro game. On initialization the trainer: 1. Discovers the character set (if not already specified in metadata). 2. Builds the Q-network and logs the full architecture with rationale. 3. Saves config.toml and starts training.log in run_dir. Call train() to run all episodes and save checkpoints. """ def __init__( self, game_factory: Callable, metadata: GameMetadata, run_dir: str | Path, **hyperparams, ): self.game_factory = game_factory self.metadata = metadata self.run_dir = Path(run_dir) self.hp: dict = {**DEFAULTS, **hyperparams} self.run_dir.mkdir(parents=True, exist_ok=True) (self.run_dir / 'checkpoints').mkdir(exist_ok=True) self.env = GameEnvironment(game_factory, metadata) if metadata.board_size is None: g = game_factory() metadata.board_size = g.board_size if metadata.character_set is None: self._discover_character_set() self.model, rationale = build_network(metadata, self.hp) self.target_model, _ = build_network(metadata, self.hp) self.target_model.load_state_dict(self.model.state_dict()) self.target_model.eval() self.optimizer = optim.Adam( self.model.parameters(), lr=self.hp['learning_rate'] ) self.lr_scheduler = optim.lr_scheduler.ExponentialLR( self.optimizer, gamma=self.hp['lr_decay'] ) if self.hp['prioritize_experiences']: self.memory = PrioritizedReplayMemory(self.hp['memory_capacity']) else: self.memory = ReplayMemory(self.hp['memory_capacity']) self.epsilon: float = self.hp['epsilon'] self.total_steps: int = 0 self._save_config() self._open_log(rationale) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def train(self): """Run all training episodes and save checkpoints.""" for episode in range(1, self.hp['training_episodes'] + 1): total_reward, steps, avg_loss = self._run_episode() self.epsilon = max( self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay'] ) self.lr_scheduler.step() self._log_episode(episode, total_reward, steps, avg_loss) if episode % 100 == 0: self._save_checkpoint(f'ep_{episode:04d}.pt') self._save_checkpoint('final.pt') def load_checkpoint(self, path: str | Path): ckpt = torch.load(path, weights_only=True) self.model.load_state_dict(ckpt['model_state_dict']) self.target_model.load_state_dict(ckpt['model_state_dict']) self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) self.epsilon = ckpt['epsilon'] self.total_steps = ckpt['total_steps'] # ------------------------------------------------------------------ # Training loop internals # ------------------------------------------------------------------ def _run_episode(self) -> tuple[float, int, float]: state = self.env.reset() total_reward = 0.0 total_loss = 0.0 loss_count = 0 for step in range(self.hp['max_turns_per_episode']): state_t = torch.as_tensor(state, dtype=torch.float32) action_idx = self._select_action(state_t) action_key = self._idx_to_key(action_idx) next_state, reward, done = self.env.step(action_key) self.memory.push(state, action_idx, reward, next_state, done) loss = self._train_step() if loss is not None: total_loss += loss loss_count += 1 self.total_steps += 1 if self.total_steps % self.hp['target_update_freq'] == 0: self.target_model.load_state_dict(self.model.state_dict()) state = next_state total_reward += reward if done: break avg_loss = total_loss / loss_count if loss_count else 0.0 return total_reward, step + 1, avg_loss def _select_action(self, state_t: torch.Tensor) -> int: if random.random() < self.epsilon: return random.randrange(self.metadata.n_actions) with torch.no_grad(): return int(self.model(state_t.unsqueeze(0)).argmax().item()) def _idx_to_key(self, idx: int) -> str | None: if idx >= len(self.metadata.actions): return None return self.metadata.actions[idx] def _train_step(self) -> float | None: if len(self.memory) < self.hp['batch_size']: return None if self.hp['prioritize_experiences']: assert isinstance(self.memory, PrioritizedReplayMemory) experiences, indices, weights = self.memory.sample(self.hp['batch_size']) weight_t = torch.as_tensor(weights, dtype=torch.float32) else: experiences = self.memory.sample(self.hp['batch_size']) indices = None weight_t = None states = torch.as_tensor( np.array([e.state for e in experiences]), dtype=torch.float32 ) actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long) rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32) next_states = torch.as_tensor( np.array([e.next_state for e in experiences]), dtype=torch.float32 ) dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32) q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1) with torch.no_grad(): next_q = self.target_model(next_states).max(1).values targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones) element_loss = nn.functional.mse_loss(q_values, targets, reduction='none') if weight_t is not None: loss = (weight_t * element_loss).mean() td_errors = (q_values - targets).detach().abs().numpy() self.memory.update_priorities(indices, td_errors) else: loss = element_loss.mean() self.optimizer.zero_grad() loss.backward() self.optimizer.step() return float(loss.item()) # ------------------------------------------------------------------ # Initialisation helpers # ------------------------------------------------------------------ def _discover_character_set(self): chars = self.env.discover_character_set(self.hp['exploration_turns']) self.metadata.character_set = chars self._log_raw( f"[INIT] character_set not specified — discovered {len(chars)} chars " f"after {self.hp['exploration_turns']} exploration turns: {chars}" ) def _save_config(self): config_path = self.run_dir / 'config.toml' config: dict = {} if config_path.exists(): import tomllib with open(config_path, 'rb') as f: config = tomllib.load(f) config['metadata'] = self.metadata.to_dict() config['hyperparameters'] = self.hp with open(config_path, 'wb') as f: tomli_w.dump(config, f) def _open_log(self, rationale: str): self.log_path = self.run_dir / 'training.log' with open(self.log_path, 'w') as f: f.write(rationale + '\n') def _log_raw(self, line: str): with open(self.log_path, 'a') as f: f.write(line + '\n') def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float): line = ( f"[EP {episode:04d}] total_reward={total_reward:.1f} " f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}" ) self._log_raw(line) def _save_checkpoint(self, name: str): torch.save( { 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'epsilon': self.epsilon, 'total_steps': self.total_steps, }, self.run_dir / 'checkpoints' / name, )