729 lines
30 KiB
Python
729 lines
30 KiB
Python
from __future__ import annotations
|
||
import random
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from time import perf_counter
|
||
from typing import Callable
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
import tomli_w
|
||
|
||
from tqdm import tqdm
|
||
from retro_gamer.metadata import GameMetadata
|
||
from retro_gamer.env import GameEnvironment
|
||
from retro_gamer.network import build_network
|
||
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
|
||
|
||
# Hyperparameter names that DQNTrainer._save_config writes to the [model]
# table of config.toml; every other hyperparameter goes to [training].
MODEL_KEYS: frozenset = frozenset({'hidden_sizes'})

# Default hyperparameters. DQNTrainer merges caller-supplied keyword arguments
# over this mapping, then runs validate_hyperparams() on the result.
DEFAULTS: dict = {
    # [model]
    'hidden_sizes': [128, 64],          # neurons per hidden layer, one entry per layer
    # [training]
    'learning_rate': 1e-4,              # initial Adam learning rate
    'learning_rate_decay': 0.9999,      # ExponentialLR gamma, stepped once per trained episode
    'gamma': 0.99,                      # discount factor used in the TD target
    'epsilon': 1.0,                     # starting exploration rate
    'epsilon_decay': 0.9997,            # epsilon is multiplied by this after every episode
    'epsilon_min': 0.05,                # floor that epsilon never decays below
    'batch_size': 64,                   # experiences sampled per training step
    'memory_capacity': 50_000,          # capacity passed to the replay memory
    'target_update_freq': 500,          # steps between target-network weight syncs
    'train_every': 4,                   # a training step is attempted every N environment steps
    'training_episodes': 20_000,        # last episode number to run
    'prioritize_experiences': True,     # PrioritizedReplayMemory when true, ReplayMemory otherwise
    'exploration_turns': 200,           # turns passed to env.discover_character_set()
    'unknown_character_strategy': 'ignore',  # not read in this file; presumably consumed elsewhere — TODO confirm
    'max_turns_per_episode': 2_000,     # hard cap on steps in a single episode
}
|
||
|
||
|
||
def _get_device() -> torch.device:
|
||
if torch.backends.mps.is_available():
|
||
return torch.device('mps')
|
||
if torch.cuda.is_available():
|
||
return torch.device('cuda')
|
||
return torch.device('cpu')
|
||
|
||
# Fields that make an existing checkpoint incompatible with the current config.
# Changing any of these requires starting training from scratch.
#
# Each table maps a field name to the explanation that
# DQNTrainer._check_compatibility prints on the "why" line of its error message
# when the value stored in a checkpoint's config snapshot differs from the
# current one.

# Compared against the 'metadata' section of the snapshot.
_INCOMPATIBLE_METADATA = {
    'actions': 'the list of actions the agent can take (changes output layer size)',
    'reward': 'the reward signal — Q-values trained on the old signal are meaningless for the new one',
    'character_set': 'the set of board characters (changes input layer size)',
    'board_size': 'the board dimensions (changes input layer size)',
}

# Compared against the 'preprocessing' section of the snapshot.
_INCOMPATIBLE_PREPROCESSING = {
    'spatial': 'spatial vs non-spatial network type (changes network architecture)',
    'board': 'whether the board is included in the observation (changes input size)',
    'observe_state': 'the state keys included in the observation (changes input size)',
    'observe_state_sizes': 'the size of each observed state key (changes input layer size)',
    'egocentric': 'egocentric board transformation (changes input representation)',
    'egocentric_player': 'the agent used as the egocentric center (changes input representation)',
    'egocentric_radius': 'the egocentric crop radius (changes input layer size)',
}

# Compared against top-level keys of the snapshot.
_INCOMPATIBLE_ARCH = {
    'hidden_sizes': 'the hidden layer sizes (changes network shape)',
}
|
||
|
||
|
||
def validate_hyperparams(hp: dict):
    """Check all hyperparameters and raise ValueError listing every problem found.

    Every check runs even after an earlier one fails, so the user sees all
    problems in a single error message rather than fixing them one at a time.

    Booleans are never accepted where a number is expected: ``bool`` is a
    subclass of ``int`` in Python, so without an explicit exclusion a value
    such as ``batch_size = true`` would silently validate as ``1``.

    Args:
        hp: Mapping of hyperparameter name to value (normally ``DEFAULTS``
            merged with user overrides).

    Raises:
        ValueError: If one or more hyperparameters are invalid. The message
            contains one entry per problem.
    """
    problems: list[str] = []

    def _problem(heading: str, explanation: str, fix: str | None = None):
        text = f" {heading}\n {explanation}"
        if fix:
            text += f"\n → {fix}"
        problems.append(text)

    def _is_int(value) -> bool:
        # bool is an int subclass; treat True/False as a type error, not as 1/0.
        return isinstance(value, int) and not isinstance(value, bool)

    def _is_number(value) -> bool:
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _is_finite(value) -> bool:
        # NaN is the only value not equal to itself; inf is caught by abs().
        return value == value and abs(value) != float('inf')

    hs = hp.get('hidden_sizes')
    if not isinstance(hs, list) or len(hs) == 0:
        _problem(
            f"hidden_sizes = {hs!r}",
            "This sets the shape of the neural network's hidden layers. It must be a\n"
            " non-empty list of positive integers, one number per layer.",
            "Try: hidden_sizes = [512, 256]",
        )
    else:
        bad = [s for s in hs if not _is_int(s) or s <= 0]
        if bad:
            _problem(
                f"hidden_sizes = {hs!r}",
                f"Every layer size must be a positive integer, but got {bad!r}.\n"
                " Each number is the count of neurons in that layer — more neurons means\n"
                " more capacity to learn complex patterns.",
                "Try: hidden_sizes = [512, 256]",
            )

    lr = hp.get('learning_rate')
    # NaN compares False against everything, so `lr <= 0` alone would let it through.
    if not _is_number(lr) or not _is_finite(lr) or lr <= 0:
        _problem(
            f"learning_rate = {lr!r}",
            "The learning rate controls how much the network adjusts its weights after\n"
            " each training step. Too high and training becomes unstable; too low and\n"
            " it learns very slowly. It must be a positive number.",
            "Typical values are between 0.0001 and 0.01. Try: learning_rate = 0.001",
        )

    # Multiplicative factors: must lie in the half-open interval (0, 1].
    for key, blurb in [
        ('learning_rate_decay',
         "After each episode, the learning rate is multiplied by this value, gradually\n"
         " slowing learning over time. A value of 1.0 means no decay; closer to 0\n"
         " means very aggressive decay. It must be greater than 0 and at most 1."),
        ('gamma',
         "Gamma is the discount factor: how much the agent values future rewards versus\n"
         " immediate ones. 0.99 means future rewards are nearly as important as now;\n"
         " 0.0 means the agent only cares about the very next step.\n"
         " It must be greater than 0 and at most 1."),
        ('epsilon_decay',
         "Each episode, the exploration rate (epsilon) is multiplied by this to gradually\n"
         " reduce random actions over time. It must be between 0 (exclusive) and 1."),
    ]:
        v = hp.get(key)
        if not _is_number(v) or not (0 < v <= 1):
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try a value close to but less than 1, like {key} = 0.995",
            )

    # Probabilities: must lie in the closed interval [0, 1].
    for key, blurb in [
        ('epsilon',
         "Epsilon is the probability of taking a random action (exploration vs.\n"
         " exploitation). It starts high — usually 1.0, meaning fully random —\n"
         " and decays toward epsilon_min during training. It must be between 0 and 1."),
        ('epsilon_min',
         "This is the lowest exploration rate allowed. Even after lots of training, the\n"
         " agent keeps at least this much randomness so it keeps discovering new things.\n"
         " It must be between 0 and 1."),
    ]:
        v = hp.get(key)
        if not _is_number(v) or not (0 <= v <= 1):
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try: {key} = {'0.05' if 'min' in key else '1.0'}",
            )

    # Cross-check only when both values are individually valid, so the user is
    # not shown a second, derivative complaint about an already-reported value.
    eps = hp.get('epsilon')
    eps_min = hp.get('epsilon_min')
    if (_is_number(eps) and _is_number(eps_min)
            and 0 <= eps <= 1 and 0 <= eps_min <= 1 and eps_min > eps):
        _problem(
            f"epsilon_min = {eps_min!r} is greater than epsilon = {eps!r}",
            "epsilon is the starting exploration rate and epsilon_min is the floor it\n"
            " decays toward, so epsilon_min must be less than or equal to epsilon.",
            "Try: epsilon = 1.0 and epsilon_min = 0.05",
        )

    # Counts: must be strictly positive integers.
    for key, blurb in [
        ('batch_size',
         "Training samples this many past experiences from the replay buffer at once\n"
         " to compute a learning update. Must be a positive integer."),
        ('memory_capacity',
         "The replay buffer stores this many past experiences. When it fills up, the\n"
         " oldest are discarded. A larger buffer means more diverse training data.\n"
         " Must be a positive integer."),
        ('target_update_freq',
         "The target network (a stable copy of the Q-network used to compute targets)\n"
         " is updated every this many steps. Must be a positive integer."),
        ('train_every',
         "A training step runs once every this many game steps. This lets the agent\n"
         " collect several new experiences before updating. Must be a positive integer."),
        ('training_episodes',
         "The total number of episodes (games) to train for. Must be a positive integer."),
        ('max_turns_per_episode',
         "If a game episode hasn't ended naturally after this many steps, it's cut\n"
         " short. This prevents a buggy or stuck agent from running forever.\n"
         " Must be a positive integer."),
    ]:
        v = hp.get(key)
        if not _is_int(v) or v <= 0:
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try: {key} = {DEFAULTS[key]}",
            )

    v = hp.get('exploration_turns')
    if not _is_int(v) or v < 0:
        _problem(
            f"exploration_turns = {v!r}",
            "When no character_set is specified, the trainer runs this many random turns\n"
            " to discover what characters appear on the board. Must be 0 or more\n"
            " (0 skips discovery entirely, which only works if character_set is set).",
            f"Try: exploration_turns = {DEFAULTS['exploration_turns']}",
        )

    bs = hp.get('batch_size')
    mc = hp.get('memory_capacity')
    if _is_int(bs) and bs > 0 and _is_int(mc) and mc > 0 and bs > mc:
        _problem(
            f"batch_size = {bs} is larger than memory_capacity = {mc}",
            "Training samples a batch of past experiences from the replay buffer each\n"
            " step, so the buffer must be able to hold at least as many experiences\n"
            f" as the batch size. With batch_size = {bs}, the buffer holds {mc} —\n"
            " that's not enough to sample from.",
            f"Try: memory_capacity = {max(bs * 100, DEFAULTS['memory_capacity'])} "
            f"(a much larger buffer also improves learning quality)",
        )

    if problems:
        n = len(problems)
        noun = "problem" if n == 1 else "problems"
        header = f"Found {n} {noun} in your training configuration:\n\n"
        footer = "\n\nFix these in config.toml, then run 'retro-gamer train' again."
        raise ValueError(header + "\n\n".join(problems) + footer)
|
||
|
||
|
||
def _format_duration(seconds: float) -> str:
|
||
m, s = divmod(int(seconds), 60)
|
||
h, m = divmod(m, 60)
|
||
if h:
|
||
return f"{h}h{m:02d}m{s:02d}s"
|
||
return f"{m}m{s:02d}s"
|
||
|
||
|
||
class DQNTrainer:
    """Trains a deep Q-network agent to play a retro game.

    Automatically selects the best available training device: Apple Silicon
    GPU (MPS), NVIDIA GPU (CUDA), or CPU. The chosen device is recorded in
    ``training.log``.

    On initialization, the trainer:

    1. Discovers the character set if not already specified in *metadata*.
    2. Builds the Q-network and logs its full architecture with rationale.
    3. Writes ``config.toml`` and initializes ``training.log`` in *run_dir*.

    Hyperparameters can be passed as keyword arguments; see the
    :ref:`hyperparameters` reference for all options. Values not supplied
    fall back to sensible defaults.

    Call :meth:`train` to run all episodes. Checkpoints are saved every 100
    episodes and training can be stopped (Ctrl+C) and resumed at any time.

    Example::

        from retro_gamer import GameMetadata, DQNTrainer
        from retro.examples.snake import create_game

        metadata = GameMetadata.from_pyproject("retro.examples.snake")
        trainer = DQNTrainer(create_game, metadata, "runs/snake/")
        trainer.train()
    """

    # Path of training.log; stays None until _open_log() has run. While it is
    # None, _log_raw() buffers lines instead of writing them.
    log_path: Path | None = None

    def __init__(
        self,
        game_factory: Callable,
        metadata: GameMetadata,
        run_dir: str | Path,
        preprocessing: dict | None = None,
        **hyperparams,
    ):
        """Validate the configuration, build the networks and prepare *run_dir*.

        Args:
            game_factory: Zero-argument callable returning a new game instance.
            metadata: Game description. It is mutated in place: ``board``,
                ``board_size``, ``character_set`` and ``extras_size`` are
                filled in or overwritten here.
            run_dir: Directory receiving ``config.toml``, ``training.log``
                and the ``checkpoints/`` folder. Created if missing.
            preprocessing: Optional observation settings (``observe_state``,
                ``observe_state_sizes``, ``egocentric``, ``egocentric_player``,
                ``egocentric_radius``, ``board``).
            **hyperparams: Overrides merged on top of ``DEFAULTS``.

        Raises:
            ValueError: If the hyperparameters fail validation or the
                preprocessing options contradict each other.
        """
        self.game_factory = game_factory
        self.metadata = metadata
        self.run_dir = Path(run_dir)
        self.hp: dict = {**DEFAULTS, **hyperparams}
        validate_hyperparams(self.hp)
        self.run_dir.mkdir(parents=True, exist_ok=True)
        (self.run_dir / 'checkpoints').mkdir(exist_ok=True)

        # Discovery helpers below log before the log file is opened (the log
        # header needs the network rationale, which is only known later).
        # Lines logged in the meantime are buffered here and flushed by
        # _open_log().
        self.log_path = None
        self._pending_log_lines: list[str] = []

        pre = preprocessing or {}
        self.observe_state: list[str] = pre.get('observe_state', [])
        self.egocentric: bool = pre.get('egocentric', False)
        self.egocentric_player: str | None = pre.get('egocentric_player', None)
        self.egocentric_radius: int | None = pre.get('egocentric_radius', None)
        self.board: bool = pre.get('board', True)
        self.observe_state_sizes: dict[str, int] = pre.get('observe_state_sizes', {})

        if self.board is False and metadata.spatial:
            raise ValueError(
                "preprocessing.board = false is incompatible with spatial = true.\n"
                "A CNN requires a 2-D board to operate on. Either set spatial = false\n"
                "or keep board = true."
            )
        if self.board is False and not self.observe_state:
            raise ValueError(
                "preprocessing.board = false requires at least one entry in observe_state.\n"
                "With board=false, the agent observes only the game state variables listed\n"
                "in observe_state — if that list is empty, there is nothing to observe."
            )
        if self.egocentric and not self.egocentric_radius:
            raise ValueError(
                "preprocessing.egocentric = true requires egocentric_radius.\n"
                "Choose a value based on how far the agent needs to see, e.g.:\n"
                " egocentric_radius = 5 # 11×11 tight local view\n"
                " egocentric_radius = 8 # 17×17 wider view"
            )

        metadata.board = self.board

        if metadata.board_size is None:
            # Instantiate one throwaway game just to read its board size.
            g = game_factory()
            metadata.board_size = g.board_size

        if self.egocentric_radius:
            # The network sees a square crop centred on the player, not the full board.
            side = 2 * self.egocentric_radius + 1
            metadata.board_size = (side, side)

        self.env = GameEnvironment(
            game_factory, metadata,
            observe_state=self.observe_state,
            egocentric=self.egocentric,
            egocentric_player=self.egocentric_player,
            egocentric_radius=self.egocentric_radius,
            board=self.board,
            observe_state_sizes=self.observe_state_sizes,
        )

        if metadata.character_set is None and self.board:
            self._discover_character_set()

        if self.observe_state and not self.observe_state_sizes:
            self._discover_observe_state_sizes()
            self.env.observe_state_sizes = self.observe_state_sizes

        metadata.extras_size = sum(self.observe_state_sizes.values()) if self.observe_state_sizes else 0

        self.device = _get_device()

        self.model, rationale = build_network(metadata, self.hp)
        self.target_model, _ = build_network(metadata, self.hp)
        self.model.to(self.device)
        self.target_model.to(self.device)
        # The target network starts as an exact copy and is only ever used for inference.
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()

        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.hp['learning_rate']
        )
        self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=self.hp['learning_rate_decay']
        )

        if self.hp['prioritize_experiences']:
            self.memory = PrioritizedReplayMemory(self.hp['memory_capacity'])
        else:
            self.memory = ReplayMemory(self.hp['memory_capacity'])

        self.epsilon: float = self.hp['epsilon']
        self.total_steps: int = 0
        self.total_training_seconds: float = 0.0
        self.start_episode: int = 1
        self._resumed_from: str | None = None

        self._save_config()
        self._open_log(rationale)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def train(self, on_checkpoint=None, on_episode=None):
        """Run all training episodes and save checkpoints.

        *on_checkpoint*, if provided, is called after each checkpoint with a
        dict containing ``episode``, ``avg_reward``, ``avg_steps``,
        ``avg_loss``, and ``epsilon``. *on_episode*, if provided, is called
        after every episode. When either callback is supplied, the built-in
        tqdm progress bar is suppressed (the caller is expected to show its
        own progress UI).
        """
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        if self._resumed_from:
            self._log_raw(f'\n=== Resumed from {self._resumed_from} | {timestamp} ===')
        else:
            self._log_raw(f'\n=== Training started | {timestamp} ===')

        use_tqdm = on_checkpoint is None and on_episode is None
        if use_tqdm:
            print("Press Control+C to stop training early. Progress will be saved at the latest checkpoint.")

        ckpt_start = perf_counter()
        # Per-checkpoint accumulators; reset after every checkpoint.
        episode_rewards: list[float] = []
        episode_losses: list[float] = []
        episode_steps: list[int] = []

        episodes = range(self.start_episode, self.hp['training_episodes'] + 1)
        bar = tqdm(episodes, unit='ep') if use_tqdm else episodes
        for episode in bar:
            total_reward, steps, avg_loss, trained = self._run_episode()
            episode_rewards.append(total_reward)
            # Episodes in which no optimisation step ran report a placeholder
            # loss of 0.0; keep those out of the average.
            if trained:
                episode_losses.append(avg_loss)
            episode_steps.append(steps)

            self.epsilon = max(
                self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
            )
            # Only decay the learning rate once the optimizer has actually stepped.
            if trained:
                self.lr_scheduler.step()

            if on_episode:
                on_episode()

            if use_tqdm:
                bar.set_postfix(
                    reward=f'{total_reward:.1f}',
                    eps=f'{self.epsilon:.3f}',
                    loss=f'{avg_loss:.4f}',
                )

            is_checkpoint = (episode % 100 == 0)
            is_last = (episode == self.hp['training_episodes'])
            if is_checkpoint or (is_last and episode_rewards):
                now = perf_counter()
                ckpt_elapsed = now - ckpt_start
                # Accumulate before saving so the checkpoint records the up-to-date total.
                self.total_training_seconds += ckpt_elapsed
                ckpt_start = now

                self._save_checkpoint(f'ep_{episode:04d}.pt', episode)
                stats = self._log_checkpoint(episode, episode_rewards, episode_losses, episode_steps, ckpt_elapsed)
                episode_rewards = []
                episode_losses = []
                episode_steps = []

                if on_checkpoint:
                    on_checkpoint(stats)

    def load_checkpoint(self, path: str | Path):
        """Load a checkpoint to resume training.

        Checkpoints are PyTorch state dicts stored under
        ``run_dir/checkpoints/``. Each contains model weights, optimizer
        state, current epsilon, and total step count.

        Raises :exc:`ValueError` if the checkpoint was trained with a
        different character set, board size, action space, or network
        architecture. The error message names each changed field and explains
        why it is incompatible.

        The CLI invokes this automatically; call directly only when driving
        training from Python.
        """
        ckpt = torch.load(path, weights_only=True, map_location='cpu')
        self._check_compatibility(ckpt, path)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.target_model.load_state_dict(ckpt['model_state_dict'])
        self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        # The checkpoint was mapped to CPU; make sure optimizer tensors live
        # on the training device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(self.device)
        self.epsilon = ckpt['epsilon']
        self.total_steps = ckpt['total_steps']
        self.total_training_seconds = ckpt.get('total_training_seconds', 0.0)
        self.start_episode = ckpt.get('episode', 0) + 1
        self._resumed_from = Path(path).name

    # ------------------------------------------------------------------
    # Training loop internals
    # ------------------------------------------------------------------

    def _run_episode(self) -> tuple[float, int, float, bool]:
        """Play one episode, training along the way.

        Returns:
            ``(total_reward, steps, avg_loss, trained)`` where *avg_loss* is
            the mean loss over the optimisation steps taken during the episode
            (``0.0`` if none ran) and *trained* says whether any ran.
        """
        state = self.env.reset()
        total_reward = 0.0
        total_loss = 0.0
        loss_count = 0
        steps = 0

        for _ in range(self.hp['max_turns_per_episode']):
            steps += 1
            state_t = torch.as_tensor(state, dtype=torch.float32).to(self.device)
            action_idx = self._select_action(state_t)
            action_key = self._idx_to_key(action_idx)

            next_state, reward, done = self.env.step(action_key)
            self.memory.push(state, action_idx, reward, next_state, done)

            if self.total_steps % self.hp['train_every'] == 0:
                loss = self._train_step()
                if loss is not None:
                    total_loss += loss
                    loss_count += 1

            self.total_steps += 1
            if self.total_steps % self.hp['target_update_freq'] == 0:
                self.target_model.load_state_dict(self.model.state_dict())

            state = next_state
            total_reward += reward
            if done:
                break

        avg_loss = total_loss / loss_count if loss_count else 0.0
        return total_reward, steps, avg_loss, loss_count > 0

    def _select_action(self, state_t: torch.Tensor) -> int:
        """Epsilon-greedy choice: a random action index with probability epsilon."""
        if random.random() < self.epsilon:
            return random.randrange(self.metadata.n_actions)
        with torch.no_grad():
            return int(self.model(state_t.unsqueeze(0)).argmax().item())

    def _idx_to_key(self, idx: int) -> str | None:
        """Map an action index to its key; indices past ``metadata.actions`` give ``None``."""
        if idx >= len(self.metadata.actions):
            return None
        return self.metadata.actions[idx]

    def _train_step(self) -> float | None:
        """Run one optimisation step on a sampled batch.

        Returns:
            The scalar loss, or ``None`` if the replay memory does not yet
            hold a full batch.
        """
        if len(self.memory) < self.hp['batch_size']:
            return None

        if self.hp['prioritize_experiences']:
            assert isinstance(self.memory, PrioritizedReplayMemory)
            experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
            weight_t = torch.as_tensor(weights, dtype=torch.float32).to(self.device)
        else:
            experiences = self.memory.sample(self.hp['batch_size'])
            indices = None
            weight_t = None

        states = torch.as_tensor(
            np.array([e.state for e in experiences]), dtype=torch.float32
        ).to(self.device)
        actions = torch.as_tensor(
            [e.action for e in experiences], dtype=torch.long
        ).to(self.device)
        rewards = torch.as_tensor(
            [e.reward for e in experiences], dtype=torch.float32
        ).to(self.device)
        next_states = torch.as_tensor(
            np.array([e.next_state for e in experiences]), dtype=torch.float32
        ).to(self.device)
        dones = torch.as_tensor(
            [e.done for e in experiences], dtype=torch.float32
        ).to(self.device)

        # Q(s, a) for the actions that were actually taken.
        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.target_model(next_states).max(1).values
            # (1 - done) zeroes the bootstrap term on terminal transitions.
            targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)

        element_loss = nn.functional.huber_loss(q_values, targets, reduction='none', delta=1.0)

        if weight_t is not None:
            loss = (weight_t * element_loss).mean()
            td_errors = (q_values - targets).detach().abs().cpu().numpy()
            self.memory.update_priorities(indices, td_errors)
        else:
            loss = element_loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
        self.optimizer.step()
        return float(loss.item())

    # ------------------------------------------------------------------
    # Compatibility checking
    # ------------------------------------------------------------------

    def _config_snapshot(self) -> dict:
        """Return the settings stored in checkpoints and compared on resume."""
        return {
            'metadata': self.metadata.to_dict(),
            'preprocessing': {
                'spatial': self.metadata.spatial,
                'board': self.board,
                'observe_state': self.observe_state,
                'observe_state_sizes': self.observe_state_sizes,
                'egocentric': self.egocentric,
                'egocentric_player': self.egocentric_player,
                'egocentric_radius': self.egocentric_radius,
            },
            'hidden_sizes': self.hp['hidden_sizes'],
        }

    def _check_compatibility(self, ckpt: dict, path: str | Path):
        """Raise ValueError if *ckpt* was produced under an incompatible configuration.

        Checkpoints without a ``config_snapshot`` entry are accepted unchecked.
        """
        snapshot = ckpt.get('config_snapshot')
        if snapshot is None:
            return

        current = self._config_snapshot()
        issues = []

        old_meta = snapshot.get('metadata', {})
        new_meta = current['metadata']
        for field, desc in _INCOMPATIBLE_METADATA.items():
            if old_meta.get(field) != new_meta.get(field):
                issues.append((field, desc, old_meta.get(field), new_meta.get(field)))

        old_pre = snapshot.get('preprocessing', {})
        new_pre = current['preprocessing']
        for field, desc in _INCOMPATIBLE_PREPROCESSING.items():
            if old_pre.get(field) != new_pre.get(field):
                issues.append((field, desc, old_pre.get(field), new_pre.get(field)))

        for field, desc in _INCOMPATIBLE_ARCH.items():
            if snapshot.get(field) != current.get(field):
                issues.append((field, desc, snapshot.get(field), current.get(field)))

        if not issues:
            return

        lines = [
            f"Cannot resume from {Path(path).name}: incompatible changes detected in config.toml.",
            "",
            "The following changes require starting fresh. The existing model was trained",
            "on a different problem and its weights cannot be reused:",
            "",
        ]
        for field, desc, old_val, new_val in issues:
            lines += [
                f" {field}",
                f" was : {old_val!r}",
                f" now : {new_val!r}",
                f" why : {desc}",
                "",
            ]
        lines += [
            "Run 'retro-gamer clean RUN_DIR' to remove existing checkpoints and the",
            "training log, then run 'retro-gamer train RUN_DIR' to start fresh.",
        ]
        raise ValueError("\n".join(lines))

    # ------------------------------------------------------------------
    # Initialisation helpers
    # ------------------------------------------------------------------

    def _discover_character_set(self):
        """Ask the environment to discover the board characters and record them."""
        chars = self.env.discover_character_set(self.hp['exploration_turns'])
        self.metadata.character_set = chars
        # Called from __init__ before the log is opened; _log_raw buffers the line.
        self._log_raw(
            f"[INIT] character_set not specified — discovered {len(chars)} chars "
            f"after {self.hp['exploration_turns']} exploration turns: {chars}"
        )

    def _discover_observe_state_sizes(self):
        """Sample game.state to determine the flat size of each observe_state key."""
        self.env.reset()
        state = dict(self.env.game.state)
        sizes = {}
        for key in self.observe_state:
            val = state.get(key, 0)
            # Lists and tuples contribute one slot per element; anything else one slot.
            sizes[key] = len(val) if isinstance(val, (list, tuple)) else 1
        self.observe_state_sizes = sizes

    def _save_config(self):
        """Write (or update) ``config.toml`` in the run directory.

        An existing file is loaded first so that tables and keys this method
        does not manage are preserved.
        """
        config_path = self.run_dir / 'config.toml'
        config: dict = {}
        if config_path.exists():
            import tomllib
            with open(config_path, 'rb') as f:
                config = tomllib.load(f)
        config['metadata'] = self.metadata.to_dict()
        pre = config.setdefault('preprocessing', {})
        pre['spatial'] = self.metadata.spatial
        pre['board'] = self.board
        pre['observe_state'] = self.observe_state
        if self.observe_state_sizes:
            pre['observe_state_sizes'] = self.observe_state_sizes
        pre['egocentric'] = self.egocentric
        if self.egocentric_player:
            pre['egocentric_player'] = self.egocentric_player
        if self.egocentric_radius:
            pre['egocentric_radius'] = self.egocentric_radius
        config['model'] = {k: v for k, v in self.hp.items() if k in MODEL_KEYS}
        config['training'] = {k: v for k, v in self.hp.items() if k not in MODEL_KEYS}
        with open(config_path, 'wb') as f:
            tomli_w.dump(config, f)

    def _open_log(self, rationale: str):
        """Create ``training.log`` if needed and flush any buffered lines.

        A new log starts with the network *rationale* and the chosen device.
        An existing log is left untouched and only appended to. Lines passed
        to :meth:`_log_raw` before this call are written afterwards, in order.
        """
        self.log_path = self.run_dir / 'training.log'
        if not self.log_path.exists():
            with open(self.log_path, 'w', encoding='utf-8') as f:
                f.write(rationale + '\n')
                f.write(f'[INIT] Device: {self.device}\n')
        pending, self._pending_log_lines = self._pending_log_lines, []
        for line in pending:
            self._log_raw(line)

    def _log_raw(self, line: str):
        """Append *line* to the training log.

        If the log has not been opened yet, the line is buffered and written
        by :meth:`_open_log`.
        """
        if self.log_path is None:
            self._pending_log_lines.append(line)
            return
        # Log lines contain non-ASCII text (em dashes, board characters), so
        # do not rely on the platform default encoding.
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(line + '\n')

    def _log_checkpoint(
        self,
        episode: int,
        rewards: list[float],
        losses: list[float],
        steps: list[int],
        ckpt_elapsed: float,
    ) -> dict:
        """Log a summary line for the episodes since the last checkpoint.

        Returns:
            Dict with ``episode``, ``avg_reward``, ``avg_steps``, ``avg_loss``
            and ``epsilon``.
        """
        n = len(rewards)
        start_ep = episode - n + 1
        avg_reward = sum(rewards) / n if n else 0.0
        avg_loss = sum(losses) / len(losses) if losses else 0.0
        avg_steps = sum(steps) / n if n else 0.0
        line = (
            f"[ep_{episode:04d}]"
            f" ep={start_ep:04d}-{episode:04d}"
            f" avg_reward={avg_reward:+.1f}"
            f" avg_steps={avg_steps:.0f}"
            f" epsilon={self.epsilon:.3f}"
            f" avg_loss={avg_loss:.1f}"
            f" time={_format_duration(ckpt_elapsed)}"
            f" total={_format_duration(self.total_training_seconds)}"
        )
        self._log_raw(line)
        return {
            'episode': episode,
            'avg_reward': avg_reward,
            'avg_steps': avg_steps,
            'avg_loss': avg_loss,
            'epsilon': self.epsilon,
        }

    def _save_checkpoint(self, name: str, episode: int):
        """Save model, optimizer and progress counters to ``checkpoints/<name>``."""
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epsilon': self.epsilon,
                'total_steps': self.total_steps,
                'episode': episode,
                'total_training_seconds': self.total_training_seconds,
                'config_snapshot': self._config_snapshot(),
            },
            self.run_dir / 'checkpoints' / name,
        )
|