Files
retro-gamer/retro_gamer/trainer.py
Chris Proctor 5ca97dc5d0 Initial commit
2026-05-08 14:07:17 -04:00

256 lines
9.1 KiB
Python

from __future__ import annotations
import random
from pathlib import Path
from typing import Callable
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tomli_w
from retro_gamer.metadata import GameMetadata
from retro_gamer.env import GameEnvironment
from retro_gamer.network import build_network
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
# Default hyperparameters for DQNTrainer. Keyword arguments passed to the
# trainer are merged over these, so callers only specify what they change.
DEFAULTS: dict = dict(
    # --- optimisation ---
    learning_rate=1e-3,  # initial learning rate handed to the Adam optimizer
    lr_decay=0.995,  # ExponentialLR gamma; the scheduler steps once per episode
    gamma=0.99,  # discount factor applied to the target network's next-state value
    # --- epsilon-greedy exploration ---
    epsilon=1.0,  # starting probability of choosing a random action
    epsilon_decay=0.995,  # epsilon is multiplied by this after every episode
    epsilon_min=0.05,  # floor below which epsilon is never decayed
    # --- replay and target network ---
    batch_size=64,  # experiences per gradient step; no training until memory holds this many
    memory_capacity=10_000,  # capacity argument given to the replay memory
    target_update_freq=100,  # environment steps between target-network syncs
    # --- run length ---
    training_episodes=1_000,  # number of episodes train() runs
    # --- network shape (forwarded to build_network in the hyperparameter dict;
    # presumably hidden-layer count and width, confirm in retro_gamer.network) ---
    n_layers=2,
    layer_size=128,
    # --- misc ---
    prioritize_experiences=False,  # True selects PrioritizedReplayMemory over ReplayMemory
    exploration_turns=200,  # turns passed to env.discover_character_set()
    # NOTE(review): not read in this module (the dict is only forwarded to
    # build_network); confirm where this setting is consumed.
    unknown_character_strategy='ignore',
    max_turns_per_episode=2_000,  # hard cap on steps in a single episode
)
class DQNTrainer:
    """Trains a deep Q-network agent to play a retro game.

    On initialization the trainer:

    1. Discovers the character set (if not already specified in metadata).
    2. Builds the Q-network and logs the full architecture with rationale.
    3. Saves config.toml and starts training.log in run_dir.

    Call train() to run all episodes and save checkpoints.
    """

    def __init__(
        self,
        game_factory: Callable,
        metadata: GameMetadata,
        run_dir: str | Path,
        **hyperparams,
    ):
        """Set up the environment, networks, optimizer, replay memory and log.

        Args:
            game_factory: Callable invoked with no arguments to obtain a game
                object (used here to read ``board_size`` when it is missing).
            metadata: Game description. ``board_size`` and ``character_set``
                are filled in on this object when they are ``None``.
            run_dir: Directory that receives ``config.toml``, ``training.log``
                and ``checkpoints/``; created if it does not exist.
            **hyperparams: Overrides merged on top of ``DEFAULTS``.
        """
        self.game_factory = game_factory
        self.metadata = metadata
        self.run_dir = Path(run_dir)
        self.hp: dict = {**DEFAULTS, **hyperparams}

        # training.log is only created by _open_log() at the end of setup
        # (it needs the network rationale).  Until then _log_raw() parks its
        # lines in this buffer so messages emitted during setup - notably the
        # character-set discovery line - are neither lost nor a crash.
        self.log_path: Path | None = None
        self._pending_log_lines: list[str] = []

        self.run_dir.mkdir(parents=True, exist_ok=True)
        (self.run_dir / 'checkpoints').mkdir(exist_ok=True)

        self.env = GameEnvironment(game_factory, metadata)
        if metadata.board_size is None:
            g = game_factory()
            metadata.board_size = g.board_size
        if metadata.character_set is None:
            self._discover_character_set()

        # Online network plus a frozen copy used to compute bootstrap targets.
        self.model, rationale = build_network(metadata, self.hp)
        self.target_model, _ = build_network(metadata, self.hp)
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()

        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.hp['learning_rate']
        )
        self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=self.hp['lr_decay']
        )

        if self.hp['prioritize_experiences']:
            self.memory = PrioritizedReplayMemory(self.hp['memory_capacity'])
        else:
            self.memory = ReplayMemory(self.hp['memory_capacity'])

        self.epsilon: float = self.hp['epsilon']
        self.total_steps: int = 0

        self._save_config()
        self._open_log(rationale)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def train(self):
        """Run all training episodes and save checkpoints.

        After every episode epsilon is decayed (never below ``epsilon_min``),
        the learning-rate scheduler is stepped and one line is appended to the
        log.  A checkpoint is written every 100 episodes and ``final.pt`` is
        written once all episodes are done.
        """
        for episode in range(1, self.hp['training_episodes'] + 1):
            total_reward, steps, avg_loss = self._run_episode()
            self.epsilon = max(
                self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
            )
            self.lr_scheduler.step()
            self._log_episode(episode, total_reward, steps, avg_loss)
            if episode % 100 == 0:
                self._save_checkpoint(f'ep_{episode:04d}.pt')
        self._save_checkpoint('final.pt')

    def load_checkpoint(self, path: str | Path):
        """Restore trainer state from a file written by ``_save_checkpoint``.

        Loads the saved weights into both the online and the target network,
        and restores the optimizer state, epsilon and the global step counter.
        The learning-rate scheduler's own state is not part of the checkpoint
        and is therefore left as it is.

        Args:
            path: Location of the checkpoint file.
        """
        ckpt = torch.load(path, weights_only=True)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.target_model.load_state_dict(ckpt['model_state_dict'])
        self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        self.epsilon = ckpt['epsilon']
        self.total_steps = ckpt['total_steps']

    # ------------------------------------------------------------------
    # Training loop internals
    # ------------------------------------------------------------------
    def _run_episode(self) -> tuple[float, int, float]:
        """Play one episode, training after every environment step.

        Returns:
            ``(total_reward, steps_taken, avg_loss)``.  ``avg_loss`` averages
            only the steps on which a gradient update actually happened and is
            ``0.0`` when there were none.  ``steps_taken`` is ``0`` when
            ``max_turns_per_episode`` is ``0``.
        """
        state = self.env.reset()
        total_reward = 0.0
        total_loss = 0.0
        loss_count = 0
        # Counted explicitly rather than derived from the loop variable so the
        # return value is well defined even if the loop body never runs.
        steps = 0
        for _ in range(self.hp['max_turns_per_episode']):
            steps += 1
            state_t = torch.as_tensor(state, dtype=torch.float32)
            action_idx = self._select_action(state_t)
            action_key = self._idx_to_key(action_idx)
            next_state, reward, done = self.env.step(action_key)
            self.memory.push(state, action_idx, reward, next_state, done)

            loss = self._train_step()
            if loss is not None:
                total_loss += loss
                loss_count += 1

            self.total_steps += 1
            # Periodically copy the online weights into the target network.
            if self.total_steps % self.hp['target_update_freq'] == 0:
                self.target_model.load_state_dict(self.model.state_dict())

            state = next_state
            total_reward += reward
            if done:
                break
        avg_loss = total_loss / loss_count if loss_count else 0.0
        return total_reward, steps, avg_loss

    def _select_action(self, state_t: torch.Tensor) -> int:
        """Pick an action index epsilon-greedily.

        With probability ``self.epsilon`` a uniformly random index below
        ``metadata.n_actions`` is returned; otherwise the argmax of the online
        network's output for ``state_t`` (evaluated as a batch of one).
        """
        if random.random() < self.epsilon:
            return random.randrange(self.metadata.n_actions)
        with torch.no_grad():
            return int(self.model(state_t.unsqueeze(0)).argmax().item())

    def _idx_to_key(self, idx: int) -> str | None:
        """Translate an action index into the entry of ``metadata.actions``.

        Indices at or beyond the end of ``metadata.actions`` map to ``None``,
        which is passed to ``env.step`` unchanged (presumably meaning "no key
        pressed" - confirm against GameEnvironment).
        """
        if idx >= len(self.metadata.actions):
            return None
        return self.metadata.actions[idx]

    def _train_step(self) -> float | None:
        """Run one gradient update on a sampled batch.

        Returns:
            The scalar loss, or ``None`` when the replay memory holds fewer
            than ``batch_size`` experiences and no update was performed.
        """
        if len(self.memory) < self.hp['batch_size']:
            return None

        if self.hp['prioritize_experiences']:
            assert isinstance(self.memory, PrioritizedReplayMemory)
            experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
            weight_t = torch.as_tensor(weights, dtype=torch.float32)
        else:
            experiences = self.memory.sample(self.hp['batch_size'])
            indices = None
            weight_t = None

        # Stack the sampled experience fields into batch tensors.
        states = torch.as_tensor(
            np.array([e.state for e in experiences]), dtype=torch.float32
        )
        actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long)
        rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32)
        next_states = torch.as_tensor(
            np.array([e.next_state for e in experiences]), dtype=torch.float32
        )
        dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32)

        # Q-value of the action that was actually taken in each sampled state.
        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.target_model(next_states).max(1).values
            # (1 - done) removes the bootstrap term for terminal transitions.
            targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)

        element_loss = nn.functional.mse_loss(q_values, targets, reduction='none')
        if weight_t is not None:
            loss = (weight_t * element_loss).mean()
            # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if
            # the model is ever moved to another device.
            td_errors = (q_values - targets).detach().abs().cpu().numpy()
            self.memory.update_priorities(indices, td_errors)
        else:
            loss = element_loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return float(loss.item())

    # ------------------------------------------------------------------
    # Initialisation helpers
    # ------------------------------------------------------------------
    def _discover_character_set(self):
        """Ask the environment for the character set and record it.

        Stores the result on ``metadata.character_set`` and emits an ``[INIT]``
        log line.  This runs before the log file exists, so the line is
        buffered by ``_log_raw`` and written once ``_open_log`` is called.
        """
        chars = self.env.discover_character_set(self.hp['exploration_turns'])
        self.metadata.character_set = chars
        self._log_raw(
            f"[INIT] character_set not specified — discovered {len(chars)} chars "
            f"after {self.hp['exploration_turns']} exploration turns: {chars}"
        )

    def _save_config(self):
        """Write metadata and hyperparameters to ``run_dir/config.toml``.

        An existing config.toml is loaded first, so any other tables in it are
        kept; only ``metadata`` and ``hyperparameters`` are replaced.
        """
        config_path = self.run_dir / 'config.toml'
        config: dict = {}
        if config_path.exists():
            import tomllib

            with open(config_path, 'rb') as f:
                config = tomllib.load(f)
        config['metadata'] = self.metadata.to_dict()
        config['hyperparameters'] = self.hp
        with open(config_path, 'wb') as f:
            tomli_w.dump(config, f)

    def _open_log(self, rationale: str):
        """Create (or truncate) training.log and write its opening lines.

        The rationale text comes first, followed by any lines that were passed
        to ``_log_raw`` before the log existed.

        Args:
            rationale: Text written at the top of the log.
        """
        self.log_path = self.run_dir / 'training.log'
        # Explicit UTF-8: log lines may contain non-ASCII game characters.
        with open(self.log_path, 'w', encoding='utf-8') as f:
            f.write(rationale + '\n')
            f.writelines(f'{line}\n' for line in self._pending_log_lines)
        self._pending_log_lines.clear()

    def _log_raw(self, line: str):
        """Append one line to training.log.

        If the log has not been opened yet, the line is held in memory and
        written by ``_open_log`` instead.
        """
        if self.log_path is None:
            self._pending_log_lines.append(line)
            return
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(line + '\n')

    def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float):
        """Append the one-line summary of a finished episode to the log."""
        line = (
            f"[EP {episode:04d}] total_reward={total_reward:.1f} "
            f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}"
        )
        self._log_raw(line)

    def _save_checkpoint(self, name: str):
        """Save model/optimizer state, epsilon and step count.

        Args:
            name: File name created inside ``run_dir/checkpoints``.
        """
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epsilon': self.epsilon,
                'total_steps': self.total_steps,
            },
            self.run_dir / 'checkpoints' / name,
        )