Updates across the board
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import random
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Callable
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -8,40 +10,250 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import tomli_w
|
||||
|
||||
from tqdm import tqdm
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.env import GameEnvironment
|
||||
from retro_gamer.network import build_network
|
||||
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
|
||||
|
||||
MODEL_KEYS: frozenset = frozenset({'hidden_sizes'})
|
||||
|
||||
DEFAULTS: dict = {
|
||||
'learning_rate': 1e-3,
|
||||
'lr_decay': 0.995,
|
||||
# [model]
|
||||
'hidden_sizes': [128, 64],
|
||||
# [training]
|
||||
'learning_rate': 1e-4,
|
||||
'learning_rate_decay': 0.9999,
|
||||
'gamma': 0.99,
|
||||
'epsilon': 1.0,
|
||||
'epsilon_decay': 0.995,
|
||||
'epsilon_decay': 0.9997,
|
||||
'epsilon_min': 0.05,
|
||||
'batch_size': 64,
|
||||
'memory_capacity': 10_000,
|
||||
'target_update_freq': 100,
|
||||
'training_episodes': 1_000,
|
||||
'n_layers': 2,
|
||||
'layer_size': 128,
|
||||
'prioritize_experiences': False,
|
||||
'memory_capacity': 50_000,
|
||||
'target_update_freq': 500,
|
||||
'train_every': 4,
|
||||
'training_episodes': 20_000,
|
||||
'prioritize_experiences': True,
|
||||
'exploration_turns': 200,
|
||||
'unknown_character_strategy': 'ignore',
|
||||
'max_turns_per_episode': 2_000,
|
||||
}
|
||||
|
||||
|
||||
def _get_device() -> torch.device:
|
||||
if torch.backends.mps.is_available():
|
||||
return torch.device('mps')
|
||||
if torch.cuda.is_available():
|
||||
return torch.device('cuda')
|
||||
return torch.device('cpu')
|
||||
|
||||
# Fields that make an existing checkpoint incompatible with the current config.
|
||||
# Changing any of these requires starting training from scratch.
|
||||
_INCOMPATIBLE_METADATA = {
|
||||
'actions': 'the list of actions the agent can take (changes output layer size)',
|
||||
'reward': 'the reward signal — Q-values trained on the old signal are meaningless for the new one',
|
||||
'character_set': 'the set of board characters (changes input layer size)',
|
||||
'board_size': 'the board dimensions (changes input layer size)',
|
||||
}
|
||||
_INCOMPATIBLE_PREPROCESSING = {
|
||||
'spatial': 'spatial vs non-spatial network type (changes network architecture)',
|
||||
'board': 'whether the board is included in the observation (changes input size)',
|
||||
'observe_state': 'the state keys included in the observation (changes input size)',
|
||||
'observe_state_sizes': 'the size of each observed state key (changes input layer size)',
|
||||
'egocentric': 'egocentric board transformation (changes input representation)',
|
||||
'egocentric_player': 'the agent used as the egocentric center (changes input representation)',
|
||||
'egocentric_radius': 'the egocentric crop radius (changes input layer size)',
|
||||
}
|
||||
_INCOMPATIBLE_ARCH = {
|
||||
'hidden_sizes': 'the hidden layer sizes (changes network shape)',
|
||||
}
|
||||
|
||||
|
||||
def validate_hyperparams(hp: dict):
|
||||
"""Check all hyperparameters and raise ValueError listing every problem found."""
|
||||
problems = []
|
||||
|
||||
def _problem(heading: str, explanation: str, fix: str | None = None):
|
||||
text = f" {heading}\n {explanation}"
|
||||
if fix:
|
||||
text += f"\n → {fix}"
|
||||
problems.append(text)
|
||||
|
||||
hs = hp.get('hidden_sizes')
|
||||
if not isinstance(hs, list) or len(hs) == 0:
|
||||
_problem(
|
||||
f"hidden_sizes = {hs!r}",
|
||||
"This sets the shape of the neural network's hidden layers. It must be a\n"
|
||||
" non-empty list of positive integers, one number per layer.",
|
||||
"Try: hidden_sizes = [512, 256]",
|
||||
)
|
||||
elif any(not isinstance(s, int) or s <= 0 for s in hs):
|
||||
bad = [s for s in hs if not isinstance(s, int) or s <= 0]
|
||||
_problem(
|
||||
f"hidden_sizes = {hs!r}",
|
||||
f"Every layer size must be a positive integer, but got {bad!r}.\n"
|
||||
" Each number is the count of neurons in that layer — more neurons means\n"
|
||||
" more capacity to learn complex patterns.",
|
||||
"Try: hidden_sizes = [512, 256]",
|
||||
)
|
||||
|
||||
lr = hp.get('learning_rate')
|
||||
if not isinstance(lr, (int, float)) or lr <= 0:
|
||||
_problem(
|
||||
f"learning_rate = {lr!r}",
|
||||
"The learning rate controls how much the network adjusts its weights after\n"
|
||||
" each training step. Too high and training becomes unstable; too low and\n"
|
||||
" it learns very slowly. It must be a positive number.",
|
||||
"Typical values are between 0.0001 and 0.01. Try: learning_rate = 0.001",
|
||||
)
|
||||
|
||||
for key, blurb in [
|
||||
('learning_rate_decay',
|
||||
"After each episode, the learning rate is multiplied by this value, gradually\n"
|
||||
" slowing learning over time. A value of 1.0 means no decay; closer to 0\n"
|
||||
" means very aggressive decay. It must be greater than 0 and at most 1."),
|
||||
('gamma',
|
||||
"Gamma is the discount factor: how much the agent values future rewards versus\n"
|
||||
" immediate ones. 0.99 means future rewards are nearly as important as now;\n"
|
||||
" 0.0 means the agent only cares about the very next step.\n"
|
||||
" It must be greater than 0 and at most 1."),
|
||||
('epsilon_decay',
|
||||
"Each episode, the exploration rate (epsilon) is multiplied by this to gradually\n"
|
||||
" reduce random actions over time. It must be between 0 (exclusive) and 1."),
|
||||
]:
|
||||
v = hp.get(key)
|
||||
if not isinstance(v, (int, float)) or not (0 < v <= 1):
|
||||
_problem(
|
||||
f"{key} = {v!r}",
|
||||
blurb,
|
||||
f"Try a value close to but less than 1, like {key} = 0.995",
|
||||
)
|
||||
|
||||
for key, blurb in [
|
||||
('epsilon',
|
||||
"Epsilon is the probability of taking a random action (exploration vs.\n"
|
||||
" exploitation). It starts high — usually 1.0, meaning fully random —\n"
|
||||
" and decays toward epsilon_min during training. It must be between 0 and 1."),
|
||||
('epsilon_min',
|
||||
"This is the lowest exploration rate allowed. Even after lots of training, the\n"
|
||||
" agent keeps at least this much randomness so it keeps discovering new things.\n"
|
||||
" It must be between 0 and 1."),
|
||||
]:
|
||||
v = hp.get(key)
|
||||
if not isinstance(v, (int, float)) or not (0 <= v <= 1):
|
||||
_problem(
|
||||
f"{key} = {v!r}",
|
||||
blurb,
|
||||
f"Try: {key} = {'0.05' if 'min' in key else '1.0'}",
|
||||
)
|
||||
|
||||
eps = hp.get('epsilon')
|
||||
eps_min = hp.get('epsilon_min')
|
||||
if (isinstance(eps, (int, float)) and isinstance(eps_min, (int, float))
|
||||
and 0 <= eps <= 1 and 0 <= eps_min <= 1 and eps_min > eps):
|
||||
_problem(
|
||||
f"epsilon_min = {eps_min!r} is greater than epsilon = {eps!r}",
|
||||
"epsilon is the starting exploration rate and epsilon_min is the floor it\n"
|
||||
" decays toward, so epsilon_min must be less than or equal to epsilon.",
|
||||
f"Try: epsilon = 1.0 and epsilon_min = 0.05",
|
||||
)
|
||||
|
||||
for key, blurb in [
|
||||
('batch_size',
|
||||
"Training samples this many past experiences from the replay buffer at once\n"
|
||||
" to compute a learning update. Must be a positive integer."),
|
||||
('memory_capacity',
|
||||
"The replay buffer stores this many past experiences. When it fills up, the\n"
|
||||
" oldest are discarded. A larger buffer means more diverse training data.\n"
|
||||
" Must be a positive integer."),
|
||||
('target_update_freq',
|
||||
"The target network (a stable copy of the Q-network used to compute targets)\n"
|
||||
" is updated every this many steps. Must be a positive integer."),
|
||||
('train_every',
|
||||
"A training step runs once every this many game steps. This lets the agent\n"
|
||||
" collect several new experiences before updating. Must be a positive integer."),
|
||||
('training_episodes',
|
||||
"The total number of episodes (games) to train for. Must be a positive integer."),
|
||||
('max_turns_per_episode',
|
||||
"If a game episode hasn't ended naturally after this many steps, it's cut\n"
|
||||
" short. This prevents a buggy or stuck agent from running forever.\n"
|
||||
" Must be a positive integer."),
|
||||
]:
|
||||
v = hp.get(key)
|
||||
if not isinstance(v, int) or v <= 0:
|
||||
_problem(
|
||||
f"{key} = {v!r}",
|
||||
blurb,
|
||||
f"Try: {key} = {DEFAULTS[key]}",
|
||||
)
|
||||
|
||||
v = hp.get('exploration_turns')
|
||||
if not isinstance(v, int) or v < 0:
|
||||
_problem(
|
||||
f"exploration_turns = {v!r}",
|
||||
"When no character_set is specified, the trainer runs this many random turns\n"
|
||||
" to discover what characters appear on the board. Must be 0 or more\n"
|
||||
" (0 skips discovery entirely, which only works if character_set is set).",
|
||||
f"Try: exploration_turns = {DEFAULTS['exploration_turns']}",
|
||||
)
|
||||
|
||||
bs = hp.get('batch_size')
|
||||
mc = hp.get('memory_capacity')
|
||||
if (isinstance(bs, int) and bs > 0 and isinstance(mc, int) and mc > 0 and bs > mc):
|
||||
_problem(
|
||||
f"batch_size = {bs} is larger than memory_capacity = {mc}",
|
||||
"Training samples a batch of past experiences from the replay buffer each\n"
|
||||
" step, so the buffer must be able to hold at least as many experiences\n"
|
||||
f" as the batch size. With batch_size = {bs}, the buffer holds {mc} —\n"
|
||||
" that's not enough to sample from.",
|
||||
f"Try: memory_capacity = {max(bs * 100, DEFAULTS['memory_capacity'])} "
|
||||
f"(a much larger buffer also improves learning quality)",
|
||||
)
|
||||
|
||||
if problems:
|
||||
n = len(problems)
|
||||
noun = "problem" if n == 1 else "problems"
|
||||
header = f"Found {n} {noun} in your training configuration:\n\n"
|
||||
footer = "\n\nFix these in config.toml, then run 'retro-gamer train' again."
|
||||
raise ValueError(header + "\n\n".join(problems) + footer)
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
m, s = divmod(int(seconds), 60)
|
||||
h, m = divmod(m, 60)
|
||||
if h:
|
||||
return f"{h}h{m:02d}m{s:02d}s"
|
||||
return f"{m}m{s:02d}s"
|
||||
|
||||
|
||||
class DQNTrainer:
|
||||
"""Trains a deep Q-network agent to play a retro game.
|
||||
|
||||
On initialization the trainer:
|
||||
1. Discovers the character set (if not already specified in metadata).
|
||||
2. Builds the Q-network and logs the full architecture with rationale.
|
||||
3. Saves config.toml and starts training.log in run_dir.
|
||||
Automatically selects the best available training device: Apple Silicon
|
||||
GPU (MPS), NVIDIA GPU (CUDA), or CPU. The chosen device is recorded in
|
||||
``training.log``.
|
||||
|
||||
Call train() to run all episodes and save checkpoints.
|
||||
On initialization, the trainer:
|
||||
|
||||
1. Discovers the character set if not already specified in *metadata*.
|
||||
2. Builds the Q-network and logs its full architecture with rationale.
|
||||
3. Writes ``config.toml`` and initializes ``training.log`` in *run_dir*.
|
||||
|
||||
Hyperparameters can be passed as keyword arguments; see the
|
||||
:ref:`hyperparameters` reference for all options. Values not supplied
|
||||
fall back to sensible defaults.
|
||||
|
||||
Call :meth:`train` to run all episodes. Checkpoints are saved every 100
|
||||
episodes and training can be stopped (Ctrl+C) and resumed at any time.
|
||||
|
||||
Example::
|
||||
|
||||
from retro_gamer import GameMetadata, DQNTrainer
|
||||
from retro.examples.snake import create_game
|
||||
|
||||
metadata = GameMetadata.from_pyproject("retro.examples.snake")
|
||||
trainer = DQNTrainer(create_game, metadata, "runs/snake/")
|
||||
trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -49,26 +261,80 @@ class DQNTrainer:
|
||||
game_factory: Callable,
|
||||
metadata: GameMetadata,
|
||||
run_dir: str | Path,
|
||||
preprocessing: dict | None = None,
|
||||
**hyperparams,
|
||||
):
|
||||
self.game_factory = game_factory
|
||||
self.metadata = metadata
|
||||
self.run_dir = Path(run_dir)
|
||||
self.hp: dict = {**DEFAULTS, **hyperparams}
|
||||
validate_hyperparams(self.hp)
|
||||
self.run_dir.mkdir(parents=True, exist_ok=True)
|
||||
(self.run_dir / 'checkpoints').mkdir(exist_ok=True)
|
||||
|
||||
self.env = GameEnvironment(game_factory, metadata)
|
||||
pre = preprocessing or {}
|
||||
self.observe_state: list[str] = pre.get('observe_state', [])
|
||||
self.egocentric: bool = pre.get('egocentric', False)
|
||||
self.egocentric_player: str | None = pre.get('egocentric_player', None)
|
||||
self.egocentric_radius: int | None = pre.get('egocentric_radius', None)
|
||||
self.board: bool = pre.get('board', True)
|
||||
self.observe_state_sizes: dict[str, int] = pre.get('observe_state_sizes', {})
|
||||
|
||||
if self.board is False and metadata.spatial:
|
||||
raise ValueError(
|
||||
"preprocessing.board = false is incompatible with spatial = true.\n"
|
||||
"A CNN requires a 2-D board to operate on. Either set spatial = false\n"
|
||||
"or keep board = true."
|
||||
)
|
||||
if self.board is False and not self.observe_state:
|
||||
raise ValueError(
|
||||
"preprocessing.board = false requires at least one entry in observe_state.\n"
|
||||
"With board=false, the agent observes only the game state variables listed\n"
|
||||
"in observe_state — if that list is empty, there is nothing to observe."
|
||||
)
|
||||
if self.egocentric and not self.egocentric_radius:
|
||||
raise ValueError(
|
||||
"preprocessing.egocentric = true requires egocentric_radius.\n"
|
||||
"Choose a value based on how far the agent needs to see, e.g.:\n"
|
||||
" egocentric_radius = 5 # 11×11 tight local view\n"
|
||||
" egocentric_radius = 8 # 17×17 wider view"
|
||||
)
|
||||
|
||||
metadata.board = self.board
|
||||
|
||||
if metadata.board_size is None:
|
||||
g = game_factory()
|
||||
metadata.board_size = g.board_size
|
||||
|
||||
if metadata.character_set is None:
|
||||
if self.egocentric_radius:
|
||||
side = 2 * self.egocentric_radius + 1
|
||||
metadata.board_size = (side, side)
|
||||
|
||||
self.env = GameEnvironment(
|
||||
game_factory, metadata,
|
||||
observe_state=self.observe_state,
|
||||
egocentric=self.egocentric,
|
||||
egocentric_player=self.egocentric_player,
|
||||
egocentric_radius=self.egocentric_radius,
|
||||
board=self.board,
|
||||
observe_state_sizes=self.observe_state_sizes,
|
||||
)
|
||||
|
||||
if metadata.character_set is None and self.board:
|
||||
self._discover_character_set()
|
||||
|
||||
if self.observe_state and not self.observe_state_sizes:
|
||||
self._discover_observe_state_sizes()
|
||||
self.env.observe_state_sizes = self.observe_state_sizes
|
||||
|
||||
metadata.extras_size = sum(self.observe_state_sizes.values()) if self.observe_state_sizes else 0
|
||||
|
||||
self.device = _get_device()
|
||||
|
||||
self.model, rationale = build_network(metadata, self.hp)
|
||||
self.target_model, _ = build_network(metadata, self.hp)
|
||||
self.model.to(self.device)
|
||||
self.target_model.to(self.device)
|
||||
self.target_model.load_state_dict(self.model.state_dict())
|
||||
self.target_model.eval()
|
||||
|
||||
@@ -76,7 +342,7 @@ class DQNTrainer:
|
||||
self.model.parameters(), lr=self.hp['learning_rate']
|
||||
)
|
||||
self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
|
||||
self.optimizer, gamma=self.hp['lr_decay']
|
||||
self.optimizer, gamma=self.hp['learning_rate_decay']
|
||||
)
|
||||
|
||||
if self.hp['prioritize_experiences']:
|
||||
@@ -86,6 +352,9 @@ class DQNTrainer:
|
||||
|
||||
self.epsilon: float = self.hp['epsilon']
|
||||
self.total_steps: int = 0
|
||||
self.total_training_seconds: float = 0.0
|
||||
self.start_episode: int = 1
|
||||
self._resumed_from: str | None = None
|
||||
|
||||
self._save_config()
|
||||
self._open_log(rationale)
|
||||
@@ -94,49 +363,127 @@ class DQNTrainer:
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def train(self):
|
||||
"""Run all training episodes and save checkpoints."""
|
||||
for episode in range(1, self.hp['training_episodes'] + 1):
|
||||
total_reward, steps, avg_loss = self._run_episode()
|
||||
def train(self, on_checkpoint=None, on_episode=None):
|
||||
"""Run all training episodes and save checkpoints.
|
||||
|
||||
*on_checkpoint*, if provided, is called after each checkpoint with a
|
||||
dict containing ``episode``, ``avg_reward``, ``avg_steps``,
|
||||
``avg_loss``, and ``epsilon``. *on_episode*, if provided, is called
|
||||
after every episode. When either callback is supplied, the built-in
|
||||
tqdm progress bar is suppressed (the caller is expected to show its
|
||||
own progress UI).
|
||||
"""
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
if self._resumed_from:
|
||||
self._log_raw(f'\n=== Resumed from {self._resumed_from} | {timestamp} ===')
|
||||
else:
|
||||
self._log_raw(f'\n=== Training started | {timestamp} ===')
|
||||
|
||||
use_tqdm = on_checkpoint is None and on_episode is None
|
||||
if use_tqdm:
|
||||
print("Press Control+C to stop training early. Progress will be saved at the latest checkpoint.")
|
||||
|
||||
session_start = perf_counter()
|
||||
ckpt_start = perf_counter()
|
||||
episode_rewards: list[float] = []
|
||||
episode_losses: list[float] = []
|
||||
episode_steps: list[int] = []
|
||||
|
||||
episodes = range(self.start_episode, self.hp['training_episodes'] + 1)
|
||||
bar = tqdm(episodes, unit='ep') if use_tqdm else episodes
|
||||
for episode in bar:
|
||||
total_reward, steps, avg_loss, trained = self._run_episode()
|
||||
episode_rewards.append(total_reward)
|
||||
if avg_loss > 0:
|
||||
episode_losses.append(avg_loss)
|
||||
episode_steps.append(steps)
|
||||
|
||||
self.epsilon = max(
|
||||
self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
|
||||
)
|
||||
self.lr_scheduler.step()
|
||||
self._log_episode(episode, total_reward, steps, avg_loss)
|
||||
if episode % 100 == 0:
|
||||
self._save_checkpoint(f'ep_{episode:04d}.pt')
|
||||
self._save_checkpoint('final.pt')
|
||||
if trained:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if on_episode:
|
||||
on_episode()
|
||||
|
||||
if use_tqdm:
|
||||
bar.set_postfix(
|
||||
reward=f'{total_reward:.1f}',
|
||||
eps=f'{self.epsilon:.3f}',
|
||||
loss=f'{avg_loss:.4f}',
|
||||
)
|
||||
|
||||
is_checkpoint = (episode % 100 == 0)
|
||||
is_last = (episode == self.hp['training_episodes'])
|
||||
if is_checkpoint or (is_last and episode_rewards):
|
||||
now = perf_counter()
|
||||
ckpt_elapsed = now - ckpt_start
|
||||
self.total_training_seconds += ckpt_elapsed
|
||||
ckpt_start = now
|
||||
|
||||
self._save_checkpoint(f'ep_{episode:04d}.pt', episode)
|
||||
stats = self._log_checkpoint(episode, episode_rewards, episode_losses, episode_steps, ckpt_elapsed)
|
||||
episode_rewards = []
|
||||
episode_losses = []
|
||||
episode_steps = []
|
||||
|
||||
if on_checkpoint:
|
||||
on_checkpoint(stats)
|
||||
|
||||
def load_checkpoint(self, path: str | Path):
|
||||
ckpt = torch.load(path, weights_only=True)
|
||||
"""Load a checkpoint to resume training.
|
||||
|
||||
Checkpoints are PyTorch state dicts stored under
|
||||
``run_dir/checkpoints/``. Each contains model weights, optimizer
|
||||
state, current epsilon, and total step count.
|
||||
|
||||
Raises :exc:`ValueError` if the checkpoint was trained with a
|
||||
different character set, board size, action space, or network
|
||||
architecture. The error message names each changed field and explains
|
||||
why it is incompatible.
|
||||
|
||||
The CLI invokes this automatically; call directly only when driving
|
||||
training from Python.
|
||||
"""
|
||||
ckpt = torch.load(path, weights_only=True, map_location='cpu')
|
||||
self._check_compatibility(ckpt, path)
|
||||
self.model.load_state_dict(ckpt['model_state_dict'])
|
||||
self.target_model.load_state_dict(ckpt['model_state_dict'])
|
||||
self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
|
||||
for state in self.optimizer.state.values():
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(self.device)
|
||||
self.epsilon = ckpt['epsilon']
|
||||
self.total_steps = ckpt['total_steps']
|
||||
self.total_training_seconds = ckpt.get('total_training_seconds', 0.0)
|
||||
self.start_episode = ckpt.get('episode', 0) + 1
|
||||
self._resumed_from = Path(path).name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Training loop internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_episode(self) -> tuple[float, int, float]:
|
||||
def _run_episode(self) -> tuple[float, int, float, bool]:
|
||||
state = self.env.reset()
|
||||
total_reward = 0.0
|
||||
total_loss = 0.0
|
||||
loss_count = 0
|
||||
|
||||
for step in range(self.hp['max_turns_per_episode']):
|
||||
state_t = torch.as_tensor(state, dtype=torch.float32)
|
||||
state_t = torch.as_tensor(state, dtype=torch.float32).to(self.device)
|
||||
action_idx = self._select_action(state_t)
|
||||
action_key = self._idx_to_key(action_idx)
|
||||
|
||||
next_state, reward, done = self.env.step(action_key)
|
||||
self.memory.push(state, action_idx, reward, next_state, done)
|
||||
|
||||
loss = self._train_step()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
loss_count += 1
|
||||
if self.total_steps % self.hp['train_every'] == 0:
|
||||
loss = self._train_step()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
loss_count += 1
|
||||
|
||||
self.total_steps += 1
|
||||
if self.total_steps % self.hp['target_update_freq'] == 0:
|
||||
@@ -148,7 +495,7 @@ class DQNTrainer:
|
||||
break
|
||||
|
||||
avg_loss = total_loss / loss_count if loss_count else 0.0
|
||||
return total_reward, step + 1, avg_loss
|
||||
return total_reward, step + 1, avg_loss, loss_count > 0
|
||||
|
||||
def _select_action(self, state_t: torch.Tensor) -> int:
|
||||
if random.random() < self.epsilon:
|
||||
@@ -168,7 +515,7 @@ class DQNTrainer:
|
||||
if self.hp['prioritize_experiences']:
|
||||
assert isinstance(self.memory, PrioritizedReplayMemory)
|
||||
experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
|
||||
weight_t = torch.as_tensor(weights, dtype=torch.float32)
|
||||
weight_t = torch.as_tensor(weights, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
experiences = self.memory.sample(self.hp['batch_size'])
|
||||
indices = None
|
||||
@@ -176,33 +523,107 @@ class DQNTrainer:
|
||||
|
||||
states = torch.as_tensor(
|
||||
np.array([e.state for e in experiences]), dtype=torch.float32
|
||||
)
|
||||
actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long)
|
||||
rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32)
|
||||
).to(self.device)
|
||||
actions = torch.as_tensor(
|
||||
[e.action for e in experiences], dtype=torch.long
|
||||
).to(self.device)
|
||||
rewards = torch.as_tensor(
|
||||
[e.reward for e in experiences], dtype=torch.float32
|
||||
).to(self.device)
|
||||
next_states = torch.as_tensor(
|
||||
np.array([e.next_state for e in experiences]), dtype=torch.float32
|
||||
)
|
||||
dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32)
|
||||
).to(self.device)
|
||||
dones = torch.as_tensor(
|
||||
[e.done for e in experiences], dtype=torch.float32
|
||||
).to(self.device)
|
||||
|
||||
q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
with torch.no_grad():
|
||||
next_q = self.target_model(next_states).max(1).values
|
||||
targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)
|
||||
|
||||
element_loss = nn.functional.mse_loss(q_values, targets, reduction='none')
|
||||
element_loss = nn.functional.huber_loss(q_values, targets, reduction='none', delta=1.0)
|
||||
|
||||
if weight_t is not None:
|
||||
loss = (weight_t * element_loss).mean()
|
||||
td_errors = (q_values - targets).detach().abs().numpy()
|
||||
td_errors = (q_values - targets).detach().abs().cpu().numpy()
|
||||
self.memory.update_priorities(indices, td_errors)
|
||||
else:
|
||||
loss = element_loss.mean()
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
||||
self.optimizer.step()
|
||||
return float(loss.item())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Compatibility checking
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _config_snapshot(self) -> dict:
|
||||
return {
|
||||
'metadata': self.metadata.to_dict(),
|
||||
'preprocessing': {
|
||||
'spatial': self.metadata.spatial,
|
||||
'board': self.board,
|
||||
'observe_state': self.observe_state,
|
||||
'observe_state_sizes': self.observe_state_sizes,
|
||||
'egocentric': self.egocentric,
|
||||
'egocentric_player': self.egocentric_player,
|
||||
'egocentric_radius': self.egocentric_radius,
|
||||
},
|
||||
'hidden_sizes': self.hp['hidden_sizes'],
|
||||
}
|
||||
|
||||
def _check_compatibility(self, ckpt: dict, path: str | Path):
|
||||
snapshot = ckpt.get('config_snapshot')
|
||||
if snapshot is None:
|
||||
return
|
||||
|
||||
current = self._config_snapshot()
|
||||
issues = []
|
||||
|
||||
old_meta = snapshot.get('metadata', {})
|
||||
new_meta = current['metadata']
|
||||
for field, desc in _INCOMPATIBLE_METADATA.items():
|
||||
if old_meta.get(field) != new_meta.get(field):
|
||||
issues.append((field, desc, old_meta.get(field), new_meta.get(field)))
|
||||
|
||||
old_pre = snapshot.get('preprocessing', {})
|
||||
new_pre = current['preprocessing']
|
||||
for field, desc in _INCOMPATIBLE_PREPROCESSING.items():
|
||||
if old_pre.get(field) != new_pre.get(field):
|
||||
issues.append((field, desc, old_pre.get(field), new_pre.get(field)))
|
||||
|
||||
for field, desc in _INCOMPATIBLE_ARCH.items():
|
||||
if snapshot.get(field) != current.get(field):
|
||||
issues.append((field, desc, snapshot.get(field), current.get(field)))
|
||||
|
||||
if not issues:
|
||||
return
|
||||
|
||||
lines = [
|
||||
f"Cannot resume from {Path(path).name}: incompatible changes detected in config.toml.",
|
||||
"",
|
||||
"The following changes require starting fresh. The existing model was trained",
|
||||
"on a different problem and its weights cannot be reused:",
|
||||
"",
|
||||
]
|
||||
for field, desc, old_val, new_val in issues:
|
||||
lines += [
|
||||
f" {field}",
|
||||
f" was : {old_val!r}",
|
||||
f" now : {new_val!r}",
|
||||
f" why : {desc}",
|
||||
"",
|
||||
]
|
||||
lines += [
|
||||
"Run 'retro-gamer clean RUN_DIR' to remove existing checkpoints and the",
|
||||
"training log, then run 'retro-gamer train RUN_DIR' to start fresh.",
|
||||
]
|
||||
raise ValueError("\n".join(lines))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Initialisation helpers
|
||||
# ------------------------------------------------------------------
|
||||
@@ -215,6 +636,16 @@ class DQNTrainer:
|
||||
f"after {self.hp['exploration_turns']} exploration turns: {chars}"
|
||||
)
|
||||
|
||||
def _discover_observe_state_sizes(self):
|
||||
"""Sample game.state to determine the flat size of each observe_state key."""
|
||||
self.env.reset()
|
||||
state = dict(self.env.game.state)
|
||||
sizes = {}
|
||||
for key in self.observe_state:
|
||||
val = state.get(key, 0)
|
||||
sizes[key] = len(val) if isinstance(val, (list, tuple)) else 1
|
||||
self.observe_state_sizes = sizes
|
||||
|
||||
def _save_config(self):
|
||||
config_path = self.run_dir / 'config.toml'
|
||||
config: dict = {}
|
||||
@@ -223,33 +654,75 @@ class DQNTrainer:
|
||||
with open(config_path, 'rb') as f:
|
||||
config = tomllib.load(f)
|
||||
config['metadata'] = self.metadata.to_dict()
|
||||
config['hyperparameters'] = self.hp
|
||||
pre = config.setdefault('preprocessing', {})
|
||||
pre['spatial'] = self.metadata.spatial
|
||||
pre['board'] = self.board
|
||||
pre['observe_state'] = self.observe_state
|
||||
if self.observe_state_sizes:
|
||||
pre['observe_state_sizes'] = self.observe_state_sizes
|
||||
pre['egocentric'] = self.egocentric
|
||||
if self.egocentric_player:
|
||||
pre['egocentric_player'] = self.egocentric_player
|
||||
if self.egocentric_radius:
|
||||
pre['egocentric_radius'] = self.egocentric_radius
|
||||
config['model'] = {k: v for k, v in self.hp.items() if k in MODEL_KEYS}
|
||||
config['training'] = {k: v for k, v in self.hp.items() if k not in MODEL_KEYS}
|
||||
with open(config_path, 'wb') as f:
|
||||
tomli_w.dump(config, f)
|
||||
|
||||
def _open_log(self, rationale: str):
|
||||
self.log_path = self.run_dir / 'training.log'
|
||||
with open(self.log_path, 'w') as f:
|
||||
f.write(rationale + '\n')
|
||||
if not self.log_path.exists():
|
||||
with open(self.log_path, 'w') as f:
|
||||
f.write(rationale + '\n')
|
||||
f.write(f'[INIT] Device: {self.device}\n')
|
||||
|
||||
def _log_raw(self, line: str):
|
||||
with open(self.log_path, 'a') as f:
|
||||
f.write(line + '\n')
|
||||
|
||||
def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float):
|
||||
def _log_checkpoint(
|
||||
self,
|
||||
episode: int,
|
||||
rewards: list[float],
|
||||
losses: list[float],
|
||||
steps: list[int],
|
||||
ckpt_elapsed: float,
|
||||
) -> dict:
|
||||
n = len(rewards)
|
||||
start_ep = episode - n + 1
|
||||
avg_reward = sum(rewards) / n if n else 0.0
|
||||
avg_loss = sum(losses) / len(losses) if losses else 0.0
|
||||
avg_steps = sum(steps) / n if n else 0.0
|
||||
line = (
|
||||
f"[EP {episode:04d}] total_reward={total_reward:.1f} "
|
||||
f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}"
|
||||
f"[ep_{episode:04d}]"
|
||||
f" ep={start_ep:04d}-{episode:04d}"
|
||||
f" avg_reward={avg_reward:+.1f}"
|
||||
f" avg_steps={avg_steps:.0f}"
|
||||
f" epsilon={self.epsilon:.3f}"
|
||||
f" avg_loss={avg_loss:.1f}"
|
||||
f" time={_format_duration(ckpt_elapsed)}"
|
||||
f" total={_format_duration(self.total_training_seconds)}"
|
||||
)
|
||||
self._log_raw(line)
|
||||
return {
|
||||
'episode': episode,
|
||||
'avg_reward': avg_reward,
|
||||
'avg_steps': avg_steps,
|
||||
'avg_loss': avg_loss,
|
||||
'epsilon': self.epsilon,
|
||||
}
|
||||
|
||||
def _save_checkpoint(self, name: str):
|
||||
def _save_checkpoint(self, name: str, episode: int):
|
||||
torch.save(
|
||||
{
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'epsilon': self.epsilon,
|
||||
'total_steps': self.total_steps,
|
||||
'episode': episode,
|
||||
'total_training_seconds': self.total_training_seconds,
|
||||
'config_snapshot': self._config_snapshot(),
|
||||
},
|
||||
self.run_dir / 'checkpoints' / name,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user