Files
retro-gamer/retro_gamer/trainer.py
2026-06-22 16:41:31 -04:00

729 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import math
import random
from datetime import datetime
from pathlib import Path
from time import perf_counter
from typing import Callable

import numpy as np
import tomli_w
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from retro_gamer.env import GameEnvironment
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
from retro_gamer.metadata import GameMetadata
from retro_gamer.network import build_network
# Hyperparameters that describe the network shape. They are written to the
# [model] table of config.toml; every other hyperparameter goes to [training].
MODEL_KEYS: frozenset = frozenset(['hidden_sizes'])

# Default value of every hyperparameter. Insertion order matters: the trainer
# merges user overrides on top of this dict and writes config.toml in this order.
DEFAULTS: dict = dict(
    # [model]
    hidden_sizes=[128, 64],
    # [training]
    learning_rate=1e-4,
    learning_rate_decay=0.9999,
    gamma=0.99,
    epsilon=1.0,
    epsilon_decay=0.9997,
    epsilon_min=0.05,
    batch_size=64,
    memory_capacity=50_000,
    target_update_freq=500,
    train_every=4,
    training_episodes=20_000,
    prioritize_experiences=True,
    exploration_turns=200,
    unknown_character_strategy='ignore',
    max_turns_per_episode=2_000,
)
def _get_device() -> torch.device:
    """Return the preferred torch device: Apple MPS, then CUDA, then CPU."""
    # Checked in priority order; the first backend reporting availability wins.
    preferred = (
        ('mps', torch.backends.mps.is_available),
        ('cuda', torch.cuda.is_available),
    )
    for device_name, is_available in preferred:
        if is_available():
            return torch.device(device_name)
    return torch.device('cpu')
# Settings that make an existing checkpoint unusable with the current config.
# Each table maps a setting name to the reason shown in the resume error;
# changing any of them means training has to start from scratch.

# Compared against GameMetadata.to_dict() values.
_INCOMPATIBLE_METADATA = dict(
    actions='the list of actions the agent can take (changes output layer size)',
    reward='the reward signal — Q-values trained on the old signal are meaningless for the new one',
    character_set='the set of board characters (changes input layer size)',
    board_size='the board dimensions (changes input layer size)',
)

# Compared against the [preprocessing] settings.
_INCOMPATIBLE_PREPROCESSING = dict(
    spatial='spatial vs non-spatial network type (changes network architecture)',
    board='whether the board is included in the observation (changes input size)',
    observe_state='the state keys included in the observation (changes input size)',
    observe_state_sizes='the size of each observed state key (changes input layer size)',
    egocentric='egocentric board transformation (changes input representation)',
    egocentric_player='the agent used as the egocentric center (changes input representation)',
    egocentric_radius='the egocentric crop radius (changes input layer size)',
)

# Compared against the model hyperparameters.
_INCOMPATIBLE_ARCH = dict(
    hidden_sizes='the hidden layer sizes (changes network shape)',
)
def _is_real_number(value) -> bool:
    """Return True if *value* is a finite ``int`` or ``float`` and not a ``bool``.

    ``bool`` subclasses ``int`` in Python, so a TOML ``true`` would otherwise
    pass numeric checks as 1. NaN and infinity (both legal TOML floats) are
    rejected too, since no hyperparameter is meaningful at those values.
    """
    return (
        isinstance(value, (int, float))
        and not isinstance(value, bool)
        and math.isfinite(value)
    )


def _is_plain_int(value) -> bool:
    """Return True if *value* is an ``int`` but not a ``bool``."""
    return isinstance(value, int) and not isinstance(value, bool)


def validate_hyperparams(hp: dict):
    """Check all hyperparameters and raise ValueError listing every problem found.

    Every problem is collected before raising, so the user can fix the whole
    configuration in one pass. Booleans are never accepted where a number is
    expected, and numeric settings must be finite.

    Args:
        hp: Mapping of hyperparameter name to value (the trainer passes
            ``DEFAULTS`` merged with user overrides). A checked key that is
            missing reads as ``None`` and is reported as invalid.

    Raises:
        ValueError: with an explanation and a suggested fix for each invalid
            setting.
    """
    problems = []

    def _problem(heading: str, explanation: str, fix: str | None = None):
        text = f" {heading}\n {explanation}"
        if fix:
            text += f"\n{fix}"
        problems.append(text)

    hs = hp.get('hidden_sizes')
    if not isinstance(hs, list) or len(hs) == 0:
        _problem(
            f"hidden_sizes = {hs!r}",
            "This sets the shape of the neural network's hidden layers. It must be a\n"
            " non-empty list of positive integers, one number per layer.",
            "Try: hidden_sizes = [512, 256]",
        )
    elif any(not _is_plain_int(s) or s <= 0 for s in hs):
        bad = [s for s in hs if not _is_plain_int(s) or s <= 0]
        _problem(
            f"hidden_sizes = {hs!r}",
            f"Every layer size must be a positive integer, but got {bad!r}.\n"
            " Each number is the count of neurons in that layer — more neurons means\n"
            " more capacity to learn complex patterns.",
            "Try: hidden_sizes = [512, 256]",
        )

    lr = hp.get('learning_rate')
    if not _is_real_number(lr) or lr <= 0:
        _problem(
            f"learning_rate = {lr!r}",
            "The learning rate controls how much the network adjusts its weights after\n"
            " each training step. Too high and training becomes unstable; too low and\n"
            " it learns very slowly. It must be a positive number.",
            "Typical values are between 0.0001 and 0.01. Try: learning_rate = 0.001",
        )

    # Multiplicative factors: valid range is (0, 1].
    for key, blurb in [
        ('learning_rate_decay',
         "After each episode, the learning rate is multiplied by this value, gradually\n"
         " slowing learning over time. A value of 1.0 means no decay; closer to 0\n"
         " means very aggressive decay. It must be greater than 0 and at most 1."),
        ('gamma',
         "Gamma is the discount factor: how much the agent values future rewards versus\n"
         " immediate ones. 0.99 means future rewards are nearly as important as now;\n"
         " 0.0 means the agent only cares about the very next step.\n"
         " It must be greater than 0 and at most 1."),
        ('epsilon_decay',
         "Each episode, the exploration rate (epsilon) is multiplied by this to gradually\n"
         " reduce random actions over time. It must be between 0 (exclusive) and 1."),
    ]:
        v = hp.get(key)
        if not _is_real_number(v) or not (0 < v <= 1):
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try a value close to but less than 1, like {key} = 0.995",
            )

    # Probabilities: valid range is [0, 1].
    for key, blurb in [
        ('epsilon',
         "Epsilon is the probability of taking a random action (exploration vs.\n"
         " exploitation). It starts high — usually 1.0, meaning fully random —\n"
         " and decays toward epsilon_min during training. It must be between 0 and 1."),
        ('epsilon_min',
         "This is the lowest exploration rate allowed. Even after lots of training, the\n"
         " agent keeps at least this much randomness so it keeps discovering new things.\n"
         " It must be between 0 and 1."),
    ]:
        v = hp.get(key)
        if not _is_real_number(v) or not (0 <= v <= 1):
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try: {key} = {'0.05' if 'min' in key else '1.0'}",
            )

    # Cross-check only when both values are individually valid, so a bad
    # value is not reported twice.
    eps = hp.get('epsilon')
    eps_min = hp.get('epsilon_min')
    if (_is_real_number(eps) and _is_real_number(eps_min)
            and 0 <= eps <= 1 and 0 <= eps_min <= 1 and eps_min > eps):
        _problem(
            f"epsilon_min = {eps_min!r} is greater than epsilon = {eps!r}",
            "epsilon is the starting exploration rate and epsilon_min is the floor it\n"
            " decays toward, so epsilon_min must be less than or equal to epsilon.",
            "Try: epsilon = 1.0 and epsilon_min = 0.05",
        )

    # Counts: must be positive integers.
    for key, blurb in [
        ('batch_size',
         "Training samples this many past experiences from the replay buffer at once\n"
         " to compute a learning update. Must be a positive integer."),
        ('memory_capacity',
         "The replay buffer stores this many past experiences. When it fills up, the\n"
         " oldest are discarded. A larger buffer means more diverse training data.\n"
         " Must be a positive integer."),
        ('target_update_freq',
         "The target network (a stable copy of the Q-network used to compute targets)\n"
         " is updated every this many steps. Must be a positive integer."),
        ('train_every',
         "A training step runs once every this many game steps. This lets the agent\n"
         " collect several new experiences before updating. Must be a positive integer."),
        ('training_episodes',
         "The total number of episodes (games) to train for. Must be a positive integer."),
        ('max_turns_per_episode',
         "If a game episode hasn't ended naturally after this many steps, it's cut\n"
         " short. This prevents a buggy or stuck agent from running forever.\n"
         " Must be a positive integer."),
    ]:
        v = hp.get(key)
        if not _is_plain_int(v) or v <= 0:
            _problem(
                f"{key} = {v!r}",
                blurb,
                f"Try: {key} = {DEFAULTS[key]}",
            )

    v = hp.get('exploration_turns')
    if not _is_plain_int(v) or v < 0:
        _problem(
            f"exploration_turns = {v!r}",
            "When no character_set is specified, the trainer runs this many random turns\n"
            " to discover what characters appear on the board. Must be 0 or more\n"
            " (0 skips discovery entirely, which only works if character_set is set).",
            f"Try: exploration_turns = {DEFAULTS['exploration_turns']}",
        )

    bs = hp.get('batch_size')
    mc = hp.get('memory_capacity')
    if (_is_plain_int(bs) and bs > 0 and _is_plain_int(mc) and mc > 0 and bs > mc):
        _problem(
            f"batch_size = {bs} is larger than memory_capacity = {mc}",
            "Training samples a batch of past experiences from the replay buffer each\n"
            " step, so the buffer must be able to hold at least as many experiences\n"
            f" as the batch size. With batch_size = {bs}, the buffer holds only {mc} —\n"
            " that's not enough to sample from.",
            f"Try: memory_capacity = {max(bs * 100, DEFAULTS['memory_capacity'])} "
            "(a much larger buffer also improves learning quality)",
        )

    if problems:
        n = len(problems)
        noun = "problem" if n == 1 else "problems"
        header = f"Found {n} {noun} in your training configuration:\n\n"
        footer = "\n\nFix these in config.toml, then run 'retro-gamer train' again."
        raise ValueError(header + "\n\n".join(problems) + footer)
def _format_duration(seconds: float) -> str:
    """Render *seconds* as ``XmYYs``, or ``XhYYmZZs`` once it reaches an hour.

    Fractional seconds are dropped via ``int()``.
    """
    total = int(seconds)
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    if not hours:
        return f"{minutes}m{secs:02d}s"
    return f"{hours}h{minutes:02d}m{secs:02d}s"
class DQNTrainer:
    """Trains a deep Q-network agent to play a retro game.

    Automatically selects the best available training device: Apple Silicon
    GPU (MPS), NVIDIA GPU (CUDA), or CPU. The chosen device is recorded in
    ``training.log``.

    On initialization, the trainer:

    1. Discovers the character set if not already specified in *metadata*.
    2. Builds the Q-network and logs its full architecture with rationale.
    3. Writes ``config.toml`` and initializes ``training.log`` in *run_dir*.

    Hyperparameters can be passed as keyword arguments; see the
    :ref:`hyperparameters` reference for all options. Values not supplied
    fall back to sensible defaults.

    Call :meth:`train` to run all episodes. Checkpoints are saved every 100
    episodes (and after the final episode), and training can be stopped
    (Ctrl+C) and resumed at any time.

    Example::

        from retro_gamer import GameMetadata, DQNTrainer
        from retro.examples.snake import create_game

        metadata = GameMetadata.from_pyproject("retro.examples.snake")
        trainer = DQNTrainer(create_game, metadata, "runs/snake/")
        trainer.train()
    """

    def __init__(
        self,
        game_factory: Callable,
        metadata: GameMetadata,
        run_dir: str | Path,
        preprocessing: dict | None = None,
        **hyperparams,
    ):
        """Validate the configuration, build both networks and prepare *run_dir*.

        *metadata* is updated in place: ``board``, ``board_size`` (when unset,
        or when an egocentric crop is used), ``character_set`` (when
        discovered) and ``extras_size`` are filled in before the network is
        built.

        Raises:
            ValueError: if a hyperparameter is invalid or the *preprocessing*
                options are inconsistent with each other or with *metadata*.
        """
        self.game_factory = game_factory
        self.metadata = metadata
        self.run_dir = Path(run_dir)
        self.hp: dict = {**DEFAULTS, **hyperparams}
        validate_hyperparams(self.hp)
        self.run_dir.mkdir(parents=True, exist_ok=True)
        (self.run_dir / 'checkpoints').mkdir(exist_ok=True)

        pre = preprocessing or {}
        self.observe_state: list[str] = pre.get('observe_state', [])
        self.egocentric: bool = pre.get('egocentric', False)
        self.egocentric_player: str | None = pre.get('egocentric_player', None)
        self.egocentric_radius: int | None = pre.get('egocentric_radius', None)
        self.board: bool = pre.get('board', True)
        self.observe_state_sizes: dict[str, int] = pre.get('observe_state_sizes', {})

        if self.board is False and metadata.spatial:
            raise ValueError(
                "preprocessing.board = false is incompatible with spatial = true.\n"
                "A CNN requires a 2-D board to operate on. Either set spatial = false\n"
                "or keep board = true."
            )
        if self.board is False and not self.observe_state:
            raise ValueError(
                "preprocessing.board = false requires at least one entry in observe_state.\n"
                "With board=false, the agent observes only the game state variables listed\n"
                "in observe_state — if that list is empty, there is nothing to observe."
            )
        if self.egocentric and not self.egocentric_radius:
            raise ValueError(
                "preprocessing.egocentric = true requires egocentric_radius.\n"
                "Choose a value based on how far the agent needs to see, e.g.:\n"
                " egocentric_radius = 5 # 11×11 tight local view\n"
                " egocentric_radius = 8 # 17×17 wider view"
            )

        metadata.board = self.board
        if metadata.board_size is None:
            g = game_factory()
            metadata.board_size = g.board_size
        # An egocentric observation is a (2r+1) x (2r+1) window, so the network
        # input is sized from the radius instead of the full board. This is
        # gated on the egocentric flag: a leftover egocentric_radius with
        # egocentric = false must not shrink the declared board size.
        if self.egocentric and self.egocentric_radius:
            side = 2 * self.egocentric_radius + 1
            metadata.board_size = (side, side)

        self.env = GameEnvironment(
            game_factory, metadata,
            observe_state=self.observe_state,
            egocentric=self.egocentric,
            egocentric_player=self.egocentric_player,
            egocentric_radius=self.egocentric_radius,
            board=self.board,
            observe_state_sizes=self.observe_state_sizes,
        )

        # training.log is only opened at the end of __init__ (a fresh log must
        # start with the network rationale), so lines produced while
        # initialising are held here and appended once the log exists.
        init_log_lines: list[str] = []
        if metadata.character_set is None and self.board:
            init_log_lines.append(self._discover_character_set())
        if self.observe_state and not self.observe_state_sizes:
            self._discover_observe_state_sizes()
        self.env.observe_state_sizes = self.observe_state_sizes
        metadata.extras_size = sum(self.observe_state_sizes.values()) if self.observe_state_sizes else 0

        self.device = _get_device()
        self.model, rationale = build_network(metadata, self.hp)
        self.target_model, _ = build_network(metadata, self.hp)
        self.model.to(self.device)
        self.target_model.to(self.device)
        # The target network starts as an exact copy of the online network and
        # is only ever evaluated (under no_grad) to compute TD targets.
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()

        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.hp['learning_rate']
        )
        self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=self.hp['learning_rate_decay']
        )
        if self.hp['prioritize_experiences']:
            self.memory = PrioritizedReplayMemory(self.hp['memory_capacity'])
        else:
            self.memory = ReplayMemory(self.hp['memory_capacity'])

        # Training progress; overwritten by load_checkpoint() when resuming.
        self.epsilon: float = self.hp['epsilon']
        self.total_steps: int = 0
        self.total_training_seconds: float = 0.0
        self.start_episode: int = 1
        self._resumed_from: str | None = None

        self._save_config()
        self._open_log(rationale)
        for line in init_log_lines:
            self._log_raw(line)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def train(self, on_checkpoint=None, on_episode=None):
        """Run all training episodes and save checkpoints.

        *on_checkpoint*, if provided, is called after each checkpoint with a
        dict containing ``episode``, ``avg_reward``, ``avg_steps``,
        ``avg_loss``, and ``epsilon``. *on_episode*, if provided, is called
        after every episode. When either callback is supplied, the built-in
        tqdm progress bar is suppressed (the caller is expected to show its
        own progress UI).
        """
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        if self._resumed_from:
            self._log_raw(f'\n=== Resumed from {self._resumed_from} | {timestamp} ===')
        else:
            self._log_raw(f'\n=== Training started | {timestamp} ===')

        use_tqdm = on_checkpoint is None and on_episode is None
        if use_tqdm:
            print("Press Control+C to stop training early. Progress will be saved at the latest checkpoint.")

        ckpt_start = perf_counter()
        # Per-episode stats accumulated since the last checkpoint.
        episode_rewards: list[float] = []
        episode_losses: list[float] = []
        episode_steps: list[int] = []

        last_episode = self.hp['training_episodes']
        episodes = range(self.start_episode, last_episode + 1)
        bar = tqdm(episodes, unit='ep') if use_tqdm else episodes
        for episode in bar:
            total_reward, steps, avg_loss, trained = self._run_episode()
            episode_rewards.append(total_reward)
            episode_steps.append(steps)
            # Only episodes that actually ran a training step contribute to the
            # average loss; a genuine loss of exactly 0.0 still counts.
            if trained:
                episode_losses.append(avg_loss)

            self.epsilon = max(
                self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
            )
            # The learning rate only decays once learning has started, i.e. not
            # while the replay buffer is still too small to sample a batch.
            if trained:
                self.lr_scheduler.step()

            if on_episode:
                on_episode()
            if use_tqdm:
                bar.set_postfix(
                    reward=f'{total_reward:.1f}',
                    eps=f'{self.epsilon:.3f}',
                    loss=f'{avg_loss:.4f}',
                )

            if episode % 100 == 0 or episode == last_episode:
                now = perf_counter()
                ckpt_elapsed = now - ckpt_start
                self.total_training_seconds += ckpt_elapsed
                ckpt_start = now
                self._save_checkpoint(f'ep_{episode:04d}.pt', episode)
                stats = self._log_checkpoint(episode, episode_rewards, episode_losses, episode_steps, ckpt_elapsed)
                episode_rewards = []
                episode_losses = []
                episode_steps = []
                if on_checkpoint:
                    on_checkpoint(stats)

    def load_checkpoint(self, path: str | Path):
        """Load a checkpoint to resume training.

        Checkpoints are PyTorch state dicts stored under
        ``run_dir/checkpoints/``. Each contains model weights, optimizer
        state, current epsilon, and total step count.

        Raises :exc:`ValueError` if the checkpoint was trained with a
        different character set, board size, action space, or network
        architecture. The error message names each changed field and explains
        why it is incompatible.

        The CLI invokes this automatically; call directly only when driving
        training from Python.
        """
        ckpt = torch.load(path, weights_only=True, map_location='cpu')
        self._check_compatibility(ckpt, path)
        # Both networks resume from the saved online weights.
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.target_model.load_state_dict(ckpt['model_state_dict'])
        self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        # The checkpoint was mapped to CPU; move optimizer state tensors to the
        # training device alongside the model parameters.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(self.device)
        self.epsilon = ckpt['epsilon']
        self.total_steps = ckpt['total_steps']
        self.total_training_seconds = ckpt.get('total_training_seconds', 0.0)
        self.start_episode = ckpt.get('episode', 0) + 1
        self._resumed_from = Path(path).name

    # ------------------------------------------------------------------
    # Training loop internals
    # ------------------------------------------------------------------

    def _run_episode(self) -> tuple[float, int, float, bool]:
        """Play one episode, storing transitions and training as it goes.

        Returns:
            ``(total_reward, steps, avg_loss, trained)``: *avg_loss* is the mean
            over training steps that produced a loss (0.0 if none did), and
            *trained* tells whether any such step happened.
        """
        state = self.env.reset()
        total_reward = 0.0
        total_loss = 0.0
        loss_count = 0
        for step in range(self.hp['max_turns_per_episode']):
            state_t = torch.as_tensor(state, dtype=torch.float32).to(self.device)
            action_idx = self._select_action(state_t)
            action_key = self._idx_to_key(action_idx)
            next_state, reward, done = self.env.step(action_key)
            self.memory.push(state, action_idx, reward, next_state, done)
            if self.total_steps % self.hp['train_every'] == 0:
                loss = self._train_step()
                if loss is not None:
                    total_loss += loss
                    loss_count += 1
            self.total_steps += 1
            # Periodically sync the target network with the online network.
            if self.total_steps % self.hp['target_update_freq'] == 0:
                self.target_model.load_state_dict(self.model.state_dict())
            state = next_state
            total_reward += reward
            if done:
                break
        avg_loss = total_loss / loss_count if loss_count else 0.0
        return total_reward, step + 1, avg_loss, loss_count > 0

    def _select_action(self, state_t: torch.Tensor) -> int:
        """Pick an action index epsilon-greedily for a single observation."""
        if random.random() < self.epsilon:
            return random.randrange(self.metadata.n_actions)
        with torch.no_grad():
            return int(self.model(state_t.unsqueeze(0)).argmax().item())

    def _idx_to_key(self, idx: int) -> str | None:
        """Map an action index to its key in ``metadata.actions``.

        Indices past the end of ``metadata.actions`` map to ``None``, which is
        passed to ``env.step`` unchanged.
        """
        if idx >= len(self.metadata.actions):
            return None
        return self.metadata.actions[idx]

    def _train_step(self) -> float | None:
        """Run one gradient step on a replay batch; return its loss.

        Returns ``None`` (no update) while the replay memory holds fewer than
        ``batch_size`` experiences. Targets are
        ``reward + gamma * max_a' Q_target(next_state, a') * (1 - done)``.
        """
        if len(self.memory) < self.hp['batch_size']:
            return None
        if self.hp['prioritize_experiences']:
            assert isinstance(self.memory, PrioritizedReplayMemory)
            # Per-sample weights from the prioritized memory scale each sample's
            # loss (presumably importance-sampling corrections).
            experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
            weight_t = torch.as_tensor(weights, dtype=torch.float32).to(self.device)
        else:
            experiences = self.memory.sample(self.hp['batch_size'])
            indices = None
            weight_t = None

        states = torch.as_tensor(
            np.array([e.state for e in experiences]), dtype=torch.float32
        ).to(self.device)
        actions = torch.as_tensor(
            [e.action for e in experiences], dtype=torch.long
        ).to(self.device)
        rewards = torch.as_tensor(
            [e.reward for e in experiences], dtype=torch.float32
        ).to(self.device)
        next_states = torch.as_tensor(
            np.array([e.next_state for e in experiences]), dtype=torch.float32
        ).to(self.device)
        dones = torch.as_tensor(
            [e.done for e in experiences], dtype=torch.float32
        ).to(self.device)

        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.target_model(next_states).max(1).values
            # Terminal transitions get no bootstrapped future value.
            targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)

        element_loss = nn.functional.huber_loss(q_values, targets, reduction='none', delta=1.0)
        if weight_t is not None:
            loss = (weight_t * element_loss).mean()
            # Feed the new absolute TD errors back as sample priorities.
            td_errors = (q_values - targets).detach().abs().cpu().numpy()
            self.memory.update_priorities(indices, td_errors)
        else:
            loss = element_loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
        self.optimizer.step()
        return float(loss.item())

    # ------------------------------------------------------------------
    # Compatibility checking
    # ------------------------------------------------------------------

    def _config_snapshot(self) -> dict:
        """Return the settings that determine network shape and meaning.

        Stored in every checkpoint and compared on resume.
        """
        return {
            'metadata': self.metadata.to_dict(),
            'preprocessing': {
                'spatial': self.metadata.spatial,
                'board': self.board,
                'observe_state': self.observe_state,
                'observe_state_sizes': self.observe_state_sizes,
                'egocentric': self.egocentric,
                'egocentric_player': self.egocentric_player,
                'egocentric_radius': self.egocentric_radius,
            },
            'hidden_sizes': self.hp['hidden_sizes'],
        }

    @staticmethod
    def _normalize_for_compare(value):
        """Return *value* with every tuple turned into a list, recursively.

        A setting can be a tuple in one run (the egocentric ``board_size`` is
        built as a tuple) and a list after a round trip through config.toml,
        so ``(11, 11)`` and ``[11, 11]`` must compare equal.
        """
        if isinstance(value, (list, tuple)):
            return [DQNTrainer._normalize_for_compare(v) for v in value]
        if isinstance(value, dict):
            return {k: DQNTrainer._normalize_for_compare(v) for k, v in value.items()}
        return value

    def _check_compatibility(self, ckpt: dict, path: str | Path):
        """Raise ValueError if *ckpt* was trained under an incompatible config.

        Checkpoints that carry no ``config_snapshot`` are accepted without
        checks. The error lists every changed field with its old value, new
        value and the reason it matters.
        """
        snapshot = ckpt.get('config_snapshot')
        if snapshot is None:
            return
        current = self._config_snapshot()
        norm = self._normalize_for_compare

        def _changed(old: dict, new: dict, fields: dict) -> list[tuple]:
            # Compare normalized values but report the raw ones to the user.
            return [
                (field, desc, old.get(field), new.get(field))
                for field, desc in fields.items()
                if norm(old.get(field)) != norm(new.get(field))
            ]

        issues = (
            _changed(snapshot.get('metadata', {}), current['metadata'], _INCOMPATIBLE_METADATA)
            + _changed(snapshot.get('preprocessing', {}), current['preprocessing'], _INCOMPATIBLE_PREPROCESSING)
            + _changed(snapshot, current, _INCOMPATIBLE_ARCH)
        )
        if not issues:
            return

        lines = [
            f"Cannot resume from {Path(path).name}: incompatible changes detected in config.toml.",
            "",
            "The following changes require starting fresh. The existing model was trained",
            "on a different problem and its weights cannot be reused:",
            "",
        ]
        for field, desc, old_val, new_val in issues:
            lines += [
                f" {field}",
                f" was : {old_val!r}",
                f" now : {new_val!r}",
                f" why : {desc}",
                "",
            ]
        lines += [
            "Run 'retro-gamer clean RUN_DIR' to remove existing checkpoints and the",
            "training log, then run 'retro-gamer train RUN_DIR' to start fresh.",
        ]
        raise ValueError("\n".join(lines))

    # ------------------------------------------------------------------
    # Initialisation helpers
    # ------------------------------------------------------------------

    def _discover_character_set(self) -> str:
        """Explore the game to find its board characters; return a log line.

        Sets ``metadata.character_set`` in place. The log line is returned
        rather than written because this runs during ``__init__``, before
        ``training.log`` has been opened.
        """
        turns = self.hp['exploration_turns']
        chars = self.env.discover_character_set(turns)
        self.metadata.character_set = chars
        return (
            f"[INIT] character_set not specified — discovered {len(chars)} chars "
            f"after {turns} exploration turns: {chars}"
        )

    def _discover_observe_state_sizes(self):
        """Sample game.state to determine the flat size of each observe_state key.

        List/tuple values contribute their length; any other value, including
        a key missing from the state, counts as size 1.
        """
        self.env.reset()
        state = dict(self.env.game.state)
        sizes = {}
        for key in self.observe_state:
            val = state.get(key, 0)
            sizes[key] = len(val) if isinstance(val, (list, tuple)) else 1
        self.observe_state_sizes = sizes

    def _save_config(self):
        """Write the effective configuration to ``run_dir/config.toml``.

        An existing file is loaded first so tables this method does not manage
        are preserved; ``metadata``, ``model`` and ``training`` are replaced
        and the managed ``preprocessing`` keys are overwritten.
        """
        config_path = self.run_dir / 'config.toml'
        config: dict = {}
        if config_path.exists():
            import tomllib
            with open(config_path, 'rb') as f:
                config = tomllib.load(f)
        config['metadata'] = self.metadata.to_dict()
        pre = config.setdefault('preprocessing', {})
        pre['spatial'] = self.metadata.spatial
        pre['board'] = self.board
        pre['observe_state'] = self.observe_state
        if self.observe_state_sizes:
            pre['observe_state_sizes'] = self.observe_state_sizes
        pre['egocentric'] = self.egocentric
        if self.egocentric_player:
            pre['egocentric_player'] = self.egocentric_player
        if self.egocentric_radius:
            pre['egocentric_radius'] = self.egocentric_radius
        config['model'] = {k: v for k, v in self.hp.items() if k in MODEL_KEYS}
        config['training'] = {k: v for k, v in self.hp.items() if k not in MODEL_KEYS}
        with open(config_path, 'wb') as f:
            tomli_w.dump(config, f)

    def _open_log(self, rationale: str):
        """Set ``log_path``; on a fresh run, start the log with the rationale and device.

        An existing log (e.g. when resuming) is left untouched.
        """
        self.log_path = self.run_dir / 'training.log'
        if not self.log_path.exists():
            with open(self.log_path, 'w') as f:
                f.write(rationale + '\n')
                f.write(f'[INIT] Device: {self.device}\n')

    def _log_raw(self, line: str):
        """Append *line* to training.log (requires ``_open_log`` to have run)."""
        with open(self.log_path, 'a') as f:
            f.write(line + '\n')

    def _log_checkpoint(
        self,
        episode: int,
        rewards: list[float],
        losses: list[float],
        steps: list[int],
        ckpt_elapsed: float,
    ) -> dict:
        """Log averages over the episodes since the previous checkpoint.

        Returns the stats dict passed to the ``on_checkpoint`` callback.
        """
        n = len(rewards)
        start_ep = episode - n + 1
        avg_reward = sum(rewards) / n if n else 0.0
        avg_loss = sum(losses) / len(losses) if losses else 0.0
        avg_steps = sum(steps) / n if n else 0.0
        # Huber losses are typically well below 1, so keep four decimals (the
        # same precision as the progress bar) instead of rounding to 0.0.
        line = (
            f"[ep_{episode:04d}]"
            f" ep={start_ep:04d}-{episode:04d}"
            f" avg_reward={avg_reward:+.1f}"
            f" avg_steps={avg_steps:.0f}"
            f" epsilon={self.epsilon:.3f}"
            f" avg_loss={avg_loss:.4f}"
            f" time={_format_duration(ckpt_elapsed)}"
            f" total={_format_duration(self.total_training_seconds)}"
        )
        self._log_raw(line)
        return {
            'episode': episode,
            'avg_reward': avg_reward,
            'avg_steps': avg_steps,
            'avg_loss': avg_loss,
            'epsilon': self.epsilon,
        }

    def _save_checkpoint(self, name: str, episode: int):
        """Write checkpoint *name* under ``run_dir/checkpoints``.

        The data is written to ``<name>.tmp`` and then renamed into place, so
        interrupting training (Ctrl+C) mid-save cannot leave a truncated file
        under the final name. A partially written temp file is removed.
        """
        path = self.run_dir / 'checkpoints' / name
        tmp_path = path.with_name(path.name + '.tmp')
        try:
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'epsilon': self.epsilon,
                    'total_steps': self.total_steps,
                    'episode': episode,
                    'total_training_seconds': self.total_training_seconds,
                    'config_snapshot': self._config_snapshot(),
                },
                tmp_path,
            )
            tmp_path.replace(path)
        except BaseException:
            # Includes KeyboardInterrupt: drop the partial file, then re-raise.
            tmp_path.unlink(missing_ok=True)
            raise