Updates across the board

This commit is contained in:
Chris Proctor
2026-06-22 16:41:31 -04:00
parent 5ca97dc5d0
commit 73624d1a0c
33 changed files with 3104 additions and 643 deletions

View File

@@ -1,5 +1,6 @@
from retro_gamer.metadata import GameMetadata
from retro_gamer.env import GameEnvironment
from retro_gamer.trainer import DQNTrainer
from retro_gamer.model_agent import TrainedPolicy, PolicyInput
__all__ = ["GameMetadata", "GameEnvironment", "DQNTrainer"]
__all__ = ["GameMetadata", "GameEnvironment", "DQNTrainer", "TrainedPolicy", "PolicyInput"]

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
import importlib
import sys
import tomllib
from pathlib import Path
import click
import tomli_w
from retro_gamer.metadata import GameMetadata
from retro_gamer.trainer import DQNTrainer, DEFAULTS
from retro_gamer.trainer import DQNTrainer, DEFAULTS, MODEL_KEYS
@click.group()
@@ -20,13 +21,14 @@ def cli():
@cli.command()
@click.option('--game', required=True,
help='Python module containing create_game() e.g. retro.examples.snake')
help='Game to train: a .py file path (e.g. my_game.py) or a Python module '
'(e.g. retro.examples.snake)')
@click.option('--output', required=True,
help='Directory to create for this training run')
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
@click.option('--lr-decay', default=DEFAULTS['lr_decay'], type=float,
help=f"Multiplicative LR decay per episode (default {DEFAULTS['lr_decay']})")
@click.option('--learning-rate-decay', default=DEFAULTS['learning_rate_decay'], type=float,
help=f"Multiplicative LR decay per episode (default {DEFAULTS['learning_rate_decay']})")
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
@@ -43,12 +45,12 @@ def cli():
help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
@click.option('--n-layers', default=DEFAULTS['n_layers'], type=int,
help=f"Hidden layers in MLP head (default {DEFAULTS['n_layers']})")
@click.option('--layer-size', default=DEFAULTS['layer_size'], type=int,
help=f"Width of each hidden layer (default {DEFAULTS['layer_size']})")
@click.option('--hidden-sizes', default=','.join(str(s) for s in DEFAULTS['hidden_sizes']),
help=f"Comma-separated hidden layer sizes, e.g. 512,256 (default {DEFAULTS['hidden_sizes']})")
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
@click.option('--train-every', default=DEFAULTS['train_every'], type=int,
help=f"Run a training step every N game steps (default {DEFAULTS['train_every']})")
@click.option('--prioritize-experiences/--no-prioritize-experiences',
default=DEFAULTS['prioritize_experiences'],
help='Use prioritized experience replay')
@@ -60,12 +62,23 @@ def create(game, output, **hyperparams):
Board size is read directly from the game. Hyperparameter options
control how the trainer learns, not what it learns about.
"""
raw = hyperparams['hidden_sizes']
try:
metadata = GameMetadata.from_pyproject(game)
hyperparams['hidden_sizes'] = [int(x.strip()) for x in raw.split(',')]
except ValueError:
raise click.ClickException(
f"Could not parse --hidden-sizes {raw!r}.\n"
"It should be a comma-separated list of positive integers, one per hidden layer.\n"
"Example: --hidden-sizes 512,256"
)
game_config = _parse_game_arg(game)
try:
metadata = GameMetadata.from_pyproject(game_config['module'])
except (FileNotFoundError, ValueError) as e:
raise click.ClickException(str(e))
game_factory = _load_factory(game)
game_factory = _load_factory(game_config)
g = game_factory()
metadata.board_size = g.board_size
@@ -74,10 +87,13 @@ def create(game, output, **hyperparams):
run_dir = Path(output)
run_dir.mkdir(parents=True, exist_ok=True)
preprocessing = {'spatial': metadata.spatial}
config = {
'game': {'module': game},
'game': game_config,
'metadata': metadata.to_dict(),
'hyperparameters': hyperparams,
'preprocessing': preprocessing,
'model': {k: v for k, v in hyperparams.items() if k in MODEL_KEYS},
'training': {k: v for k, v in hyperparams.items() if k not in MODEL_KEYS},
}
with open(run_dir / 'config.toml', 'wb') as f:
tomli_w.dump(config, f)
@@ -91,8 +107,6 @@ def create(game, output, **hyperparams):
click.echo(f" characters : {metadata.character_set}")
else:
click.echo(f" characters : (will be auto-discovered during training)")
if metadata.observe_state:
click.echo(f" observe : {metadata.observe_state}")
click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")
@@ -102,23 +116,49 @@ def create(game, output, **hyperparams):
@cli.command()
@click.argument('run_dir')
@click.option('--resume', default=None,
help='Path to checkpoint to resume from (e.g. checkpoints/ep_0500.pt)')
def train(run_dir, resume):
"""Train (or resume training) a DQN agent."""
def train(run_dir):
"""Train a DQN agent, resuming automatically from the latest checkpoint.
To start fresh or resume from an earlier point, delete the checkpoints
you no longer want from RUN_DIR/checkpoints/.
"""
run_dir_path = Path(run_dir)
config = _load_config(run_dir_path)
game_factory = _load_factory(config['game']['module'])
game_factory = _load_factory(config['game'])
metadata = GameMetadata.from_dict(config['metadata'])
hyperparams = config.get('hyperparameters', {})
preprocessing = config.get('preprocessing', {})
metadata.spatial = preprocessing.get('spatial', False)
hyperparams = {**config.get('model', {}), **config['training']}
trainer = DQNTrainer(game_factory, metadata, run_dir, **hyperparams)
if resume:
click.echo(f"Resuming from {resume}")
trainer.load_checkpoint(resume)
click.echo(f"Training for {trainer.hp['training_episodes']} episodes…")
trainer.train()
click.echo(f"Done. Checkpoints in {run_dir}/checkpoints/")
try:
trainer = DQNTrainer(game_factory, metadata, run_dir, preprocessing=preprocessing, **hyperparams)
except ValueError as e:
raise click.ClickException(str(e))
latest = _latest_checkpoint(run_dir_path)
if latest:
click.echo(f"Resuming from {latest.name}")
try:
trainer.load_checkpoint(latest)
except ValueError as e:
raise click.ClickException(str(e))
if trainer.start_episode > trainer.hp['training_episodes']:
click.echo(
f"Training already complete ({trainer.hp['training_episodes']} episodes). "
"To keep training, increase training_episodes in config.toml."
)
return
from retro_gamer.log_parser import parse_checkpoints
from retro_gamer.dashboard import TrainingDashboard
history = parse_checkpoints(run_dir_path / 'training.log')
display = TrainingDashboard(trainer.hp['training_episodes'], trainer.start_episode, history)
try:
trainer.train(on_checkpoint=display.on_checkpoint, on_episode=display.on_episode)
finally:
display.close()
click.echo(f"Done. Checkpoints saved in {run_dir}/checkpoints/")
# ---------------------------------------------------------------------------
@@ -127,69 +167,72 @@ def train(run_dir, resume):
@cli.command()
@click.argument('run_dir')
@click.option('--checkpoint', default='final',
help='Checkpoint name e.g. "final" or "ep_0100"')
@click.option('--checkpoint', default=None,
help='Checkpoint to load, e.g. "ep_0100". Defaults to the latest available.')
@click.option('--framerate', default=12, type=int,
help='Target frames per second')
def play(run_dir, checkpoint, framerate):
"""Watch a trained agent play the game."""
import torch
from time import sleep, perf_counter
from blessed import Terminal
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView
from retro.views.terminal import TerminalView
from retro_gamer.observation import encode_observation
from retro_gamer.model_agent import TrainedPolicy, PolicyInput
run_dir_path = Path(run_dir)
config = _load_config(run_dir_path)
game_factory = _load_factory(config['game']['module'])
metadata = GameMetadata.from_dict(config['metadata'])
hyperparams = {**DEFAULTS, **config.get('hyperparameters', {})}
game_factory = _load_factory(config['game'])
from retro_gamer.network import build_network
model, _ = build_network(metadata, hyperparams)
try:
ai = TrainedPolicy(run_dir_path, checkpoint=checkpoint)
except FileNotFoundError as e:
raise click.ClickException(str(e))
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
ckpt_path = run_dir_path / 'checkpoints' / ckpt_name
ckpt = torch.load(ckpt_path, weights_only=True)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
inp = ProgrammaticInput()
headless = HeadlessView()
game = game_factory()
game.input_source = inp
game.view = headless
game.start()
inp = PolicyInput(ai, game)
terminal = Terminal()
term_view = TerminalView(terminal, color=game.color)
click.echo("Playing… (press Escape or Enter to quit)")
with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
term_view.on_game_start(game)
game.input_source = inp
game.view = term_view
game.start()
while game.playing:
t0 = perf_counter()
obs = encode_observation(headless.board_characters, dict(game.state), metadata)
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
action_idx = int(model(state_t).argmax().item())
action_key = None if action_idx >= len(metadata.actions) else metadata.actions[action_idx]
key = terminal.inkey(0)
if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
break
inp.press(action_key)
game.step()
term_view.render(game)
elapsed = perf_counter() - t0
sleep(max(0, 1 / framerate - elapsed))
# ---------------------------------------------------------------------------
# retro-gamer plot
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--output', '-o', default=None,
help='Save to file (e.g. plot.png) instead of displaying interactively.')
def plot(run_dir, output):
"""Plot training metrics (reward, steps, loss, epsilon) from a run's log."""
from retro_gamer.plotter import plot_run
run_dir_path = Path(run_dir)
log_path = run_dir_path / 'training.log'
if not log_path.exists():
raise click.ClickException(f"No training.log found in {run_dir}")
output_path = Path(output) if output else None
try:
plot_run(log_path, output_path)
except ValueError as e:
raise click.ClickException(str(e))
# ---------------------------------------------------------------------------
# retro-gamer info
# ---------------------------------------------------------------------------
@@ -200,17 +243,19 @@ def info(run_dir):
"""Print a summary of a training run."""
run_dir_path = Path(run_dir)
config = _load_config(run_dir_path)
click.echo(f"Game module : {config['game']['module']}")
click.echo(f"Metadata : {config['metadata']}")
click.echo(f"Hyperparams : {config.get('hyperparameters', {})}")
click.echo(f"Game module : {config['game']['module']}")
click.echo(f"Metadata : {config['metadata']}")
click.echo(f"Preprocessing : {config.get('preprocessing', {})}")
click.echo(f"Model : {config.get('model', {})}")
click.echo(f"Training : {config['training']}")
log_path = run_dir_path / 'training.log'
if log_path.exists():
lines = log_path.read_text().splitlines()
episode_lines = [l for l in lines if l.startswith('[EP')]
if episode_lines:
click.echo(f"\nLast 5 episodes:")
for line in episode_lines[-5:]:
ckpt_lines = [l for l in lines if l.startswith('[ep_')]
if ckpt_lines:
click.echo(f"\nLast 5 checkpoints:")
for line in ckpt_lines[-5:]:
click.echo(f" {line}")
ckpt_dir = run_dir_path / 'checkpoints'
@@ -219,10 +264,91 @@ def info(run_dir):
click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")
# ---------------------------------------------------------------------------
# retro-gamer clean
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--yes', '-y', is_flag=True, help='Skip confirmation prompt')
def clean(run_dir, yes):
"""Remove all checkpoints and the training log from a run directory.
Use this after changing the game description or network architecture,
which require starting training from scratch.
"""
run_dir_path = Path(run_dir)
if not run_dir_path.exists():
raise click.ClickException(f"Directory not found: {run_dir}")
to_remove = []
ckpt_dir = run_dir_path / 'checkpoints'
if ckpt_dir.exists():
to_remove.extend(sorted(ckpt_dir.glob('*.pt')))
log_path = run_dir_path / 'training.log'
if log_path.exists():
to_remove.append(log_path)
if not to_remove:
click.echo("Nothing to clean.")
return
n_ckpts = sum(1 for p in to_remove if p.suffix == '.pt')
click.echo(f"Will remove {n_ckpts} checkpoint(s) and training log from {run_dir}/:")
for p in to_remove:
click.echo(f" {p.relative_to(run_dir_path)}")
if not yes:
click.confirm("\nProceed?", abort=True)
for p in to_remove:
p.unlink()
click.echo(f"Cleaned. Run 'retro-gamer train {run_dir}' to start fresh.")
ckpt_dir = run_dir_path / 'checkpoints'
if ckpt_dir.exists():
ckpts = sorted(ckpt_dir.glob('*.pt'))
click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _parse_game_arg(arg: str) -> dict:
"""Accept a .py file path, a package directory, or a Python module name.
Returns a dict with at least 'module', and optionally 'path' (the
directory to add to sys.path so the module can be imported).
"""
p = Path(arg)
if p.exists():
if p.is_file() and p.suffix == '.py':
path = str(p.parent.resolve())
module = p.stem
elif p.is_dir():
path = str(p.parent.resolve())
module = p.name
else:
raise click.ClickException(
f"{arg!r} is not a .py file or package directory"
)
if path not in sys.path:
sys.path.insert(0, path)
return {'module': module, 'path': path}
return {'module': arg}
def _latest_checkpoint(run_dir: Path) -> Path | None:
"""Return the most recent checkpoint in run_dir/checkpoints/, or None."""
ckpt_dir = run_dir / 'checkpoints'
if ckpt_dir.exists():
candidates = sorted(ckpt_dir.glob('ep_*.pt'))
if candidates:
return candidates[-1]
return None
def _load_config(run_dir: Path) -> dict:
config_path = run_dir / 'config.toml'
if not config_path.exists():
@@ -231,7 +357,11 @@ def _load_config(run_dir: Path) -> dict:
return tomllib.load(f)
def _load_factory(module_name: str):
def _load_factory(game_config: dict):
path = game_config.get('path')
if path and path not in sys.path:
sys.path.insert(0, path)
module_name = game_config['module']
try:
module = importlib.import_module(module_name)
except ImportError as e:
@@ -241,3 +371,4 @@ def _load_factory(module_name: str):
f"Module '{module_name}' has no create_game() function"
)
return module.create_game

86
retro_gamer/dashboard.py Normal file
View File

@@ -0,0 +1,86 @@
from __future__ import annotations
import os
import sys
from tqdm import tqdm
_CHART_HEIGHT = 22 # plotext chart area height in lines
def _terminal_width() -> int:
try:
return min(os.get_terminal_size().columns, 220)
except OSError:
return 120
def _build_charts(history: list[dict], width: int) -> str:
import plotext as plt
episodes = [d['episode'] for d in history]
series = [
("Epsilon", [d['epsilon'] for d in history]),
("Avg Steps", [d['avg_steps'] for d in history]),
("Avg Loss", [d['avg_loss'] for d in history]),
("Avg Reward", [d['avg_reward'] for d in history]),
]
panel_w = max(width // len(series), 20)
panels = []
for title, values in series:
plt.clf()
plt.canvas_color("default")
plt.axes_color("default")
plt.ticks_color("default")
if episodes:
plt.plot(episodes, values, color="default")
plt.title(title)
plt.xlabel("Episode")
plt.plotsize(panel_w, _CHART_HEIGHT)
panels.append(plt.build().splitlines())
height = max(len(p) for p in panels)
for p in panels:
while len(p) < height:
p.append(' ' * panel_w)
return '\n'.join(''.join(row) for row in zip(*panels))
class TrainingDashboard:
"""Inline training display: a plotext chart block that redraws in place
above a tqdm per-episode progress bar."""
def __init__(self, total_episodes: int, start_episode: int, history: list[dict]):
self._total = total_episodes
self._history = list(history)
self._rendered = False # have we drawn charts yet?
already_done = start_episode - 1
self._bar = tqdm(
initial=already_done,
total=total_episodes,
unit='ep',
dynamic_ncols=True,
)
self._draw()
def on_episode(self) -> None:
self._bar.update(1)
def on_checkpoint(self, stats: dict) -> None:
self._history.append(stats)
self._bar.clear()
self._draw()
self._bar.refresh()
def close(self) -> None:
self._bar.close()
def _draw(self) -> None:
chart = _build_charts(self._history, _terminal_width())
n_lines = len(chart.splitlines())
if self._rendered:
sys.stdout.write(f'\033[{n_lines}A')
sys.stdout.write(chart)
if not chart.endswith('\n'):
sys.stdout.write('\n')
sys.stdout.flush()
self._rendered = True

View File

@@ -9,15 +9,27 @@ from retro_gamer.observation import encode_observation
class GameEnvironment:
"""Gym-style wrapper around a retro Game for RL training.
"""Gym-style wrapper around a retro game for RL training."""
Provides reset() / step(action) / observe(), managing one training episode
at a time. The game is restarted by calling the factory function on each reset.
"""
def __init__(self, game_factory: Callable, metadata: GameMetadata):
def __init__(
self,
game_factory: Callable,
metadata: GameMetadata,
observe_state: list[str] | None = None,
egocentric: bool = False,
egocentric_player: str | None = None,
egocentric_radius: int | None = None,
board: bool = True,
observe_state_sizes: dict[str, int] | None = None,
):
self.game_factory = game_factory
self.metadata = metadata
self.observe_state = observe_state or []
self.egocentric = egocentric
self.egocentric_player = egocentric_player
self.egocentric_radius = egocentric_radius
self.board = board
self.observe_state_sizes = observe_state_sizes or {}
self.game = None
self.view: HeadlessView | None = None
self.inp: ProgrammaticInput | None = None
@@ -35,12 +47,7 @@ class GameEnvironment:
return self._observe()
def step(self, action: str | None) -> tuple[np.ndarray, float, bool]:
"""Advance one turn with the given action.
action: a keystroke string (e.g. 'KEY_RIGHT') or None for no-op.
Returns (observation, reward, done).
Reward is the change in the reward state key since the previous step.
"""
"""Advance one turn. Returns (observation, reward, done)."""
self.inp.press(action)
self.game.step()
obs = self._observe()
@@ -49,12 +56,47 @@ class GameEnvironment:
return obs, reward, done
def _observe(self) -> np.ndarray:
state = dict(self.game.state)
if self.observe_state_sizes:
self._check_state_sizes(state)
player_pos = None
if self.egocentric and self.egocentric_player:
agent = self.game.get_agent_by_name(self.egocentric_player)
if agent is not None:
player_pos = agent.position
return encode_observation(
self.view.board_characters,
dict(self.game.state),
state,
self.metadata,
self.observe_state,
player_pos=player_pos,
egocentric_radius=self.egocentric_radius,
board=self.board,
)
def _check_state_sizes(self, state: dict):
for key, expected in self.observe_state_sizes.items():
val = state.get(key)
if val is None:
actual = 0
elif isinstance(val, (list, tuple)):
actual = len(val)
else:
actual = 1
if actual != expected:
raise ValueError(
f"State key '{key}' changed size during training:\n"
f" Expected : {expected} (discovered at training start)\n"
f" Got : {actual}\n\n"
f"This means game.state['{key}'] has a different length in some\n"
f"episodes than it had when training started. The neural network\n"
f"has a fixed input size and cannot adapt to changing state shapes.\n\n"
f"Fix: make sure create_game() always initializes '{key}' with a\n"
f"fixed-length value before the game starts each episode.\n"
f"For example, if '{key}' is a list of 9 values, it must always be\n"
f"a list of exactly 9 values — never more, never fewer, never missing."
)
def _delta_reward(self) -> float:
current = float(self.game.state.get(self.metadata.reward, 0))
delta = current - self._prev_reward
@@ -62,10 +104,8 @@ class GameEnvironment:
return delta
def discover_character_set(self, exploration_turns: int) -> list[str]:
"""Run random turns to discover the characters that appear on the board.
Returns the sorted character list (excluding space).
"""
obs = self.reset()
"""Run random turns to discover the characters that appear on the board."""
self.reset()
chars: set[str] = set()
for _ in range(exploration_turns):
for row in self.view.board_characters:

View File

@@ -1,17 +0,0 @@
from retro.game import Game
from retro_gamer.examples.beast.board import Board
WIDTH = 40
HEIGHT = 20
NUM_BEASTS = 10
def create_game():
"""Return a fresh, initialized Beast game."""
board = Board(WIDTH, HEIGHT, num_beasts=NUM_BEASTS)
state = {'beasts_killed': 0}
game = Game(board.get_agents(), state, board_size=(WIDTH, HEIGHT))
game.num_beasts = NUM_BEASTS
return game
if __name__ == '__main__':
create_game().play()

View File

@@ -1,67 +0,0 @@
from retro_gamer.examples.beast.helpers import add, distance, get_occupant
from random import random, choice
class Beast:
"""A beast that hunts the player."""
character = "H"
color = "red"
probability_of_moving = 0.03
probability_of_random_move = 0.2
deadly = True
def __init__(self, position):
self.position = position
def handle_push(self, vector, game):
future_position = add(self.position, vector)
on_board = game.on_board(future_position)
obstacle = get_occupant(game, future_position)
if obstacle or not on_board:
self.die(game)
return True
else:
return False
def play_turn(self, game):
if self.should_move():
possible_moves = []
for position in self.get_adjacent_positions():
if game.is_empty(position) and game.on_board(position):
possible_moves.append(position)
if possible_moves:
if self.should_move_randomly():
self.position = choice(possible_moves)
else:
self.position = self.choose_best_move(possible_moves, game)
player = game.get_agent_by_name("player")
if player and player.position == self.position:
player.die(game)
def get_adjacent_positions(self):
"""Returns all eight adjacent positions, including diagonals."""
positions = []
for i in [-1, 0, 1]:
for j in [-1, 0, 1]:
if i or j:
positions.append(add(self.position, (i, j)))
return positions
def should_move(self):
return random() < self.probability_of_moving
def should_move_randomly(self):
return random() < self.probability_of_random_move
def choose_best_move(self, possible_moves, game):
player = game.get_agent_by_name("player")
move_distances = [[distance(player.position, move), move] for move in possible_moves]
shortest_distance, best_move = sorted(move_distances)[0]
return best_move
def die(self, game):
game.remove_agent(self)
game.num_beasts -= 1
game.state['beasts_killed'] += 1
if game.num_beasts == 0:
game.state["message"] = "You win!"
game.end()

View File

@@ -1,25 +0,0 @@
from retro_gamer.examples.beast.helpers import add, get_occupant
class Block:
"""A static block that can be pushed by the player."""
character = ""
color = "green4"
deadly = False
def __init__(self, position):
self.position = position
def handle_push(self, vector, game):
"""Responds to a push in the direction of vector.
Returns True when the push succeeds in creating empty space.
"""
future_position = add(self.position, vector)
on_board = game.on_board(future_position)
obstacle = get_occupant(game, future_position)
if obstacle:
success = obstacle.handle_push(vector, game)
else:
success = on_board
if success:
self.position = future_position
return success

View File

@@ -1,39 +0,0 @@
from retro_gamer.examples.beast.helpers import add, get_occupant
direction_vectors = {
"KEY_RIGHT": (1, 0),
"KEY_UP": (0, -1),
"KEY_LEFT": (-1, 0),
"KEY_DOWN": (0, 1),
}
class Player:
character = "*"
color = "white"
name = "player"
deadly = False
def __init__(self, position):
self.position = position
def handle_keystroke(self, keystroke, game):
if keystroke.name in direction_vectors:
vector = direction_vectors[keystroke.name]
self.try_to_move(vector, game)
def try_to_move(self, vector, game):
future_position = add(self.position, vector)
on_board = game.on_board(future_position)
obstacle = get_occupant(game, future_position)
if obstacle:
if obstacle.deadly:
self.die(game)
elif obstacle.handle_push(vector, game):
self.position = future_position
elif on_board:
self.position = future_position
def die(self, game):
self.color = "black_on_red"
game.state["message"] = "The beasties win!"
game.end()

View File

@@ -1,44 +0,0 @@
from random import shuffle
from retro_gamer.examples.beast.agents.player import Player
from retro_gamer.examples.beast.agents.beast import Beast
from retro_gamer.examples.beast.agents.block import Block
class Board:
"""Creates the agents needed at the beginning of the game and assigns their positions."""
def __init__(self, width, height, block_density=0.3, num_beasts=10):
self.width = width
self.height = height
self.block_density = block_density
self.num_blocks = round(width * height * block_density)
self.num_empty_spaces = width * height - self.num_blocks
self.num_beasts = num_beasts
self.validate()
def validate(self):
if self.block_density < 0 or self.block_density > 1:
raise ValueError("block density must be between 0 and 1.")
if self.num_empty_spaces < self.num_beasts + 1:
raise ValueError("Not enough space on the board.")
def get_agents(self):
"""Returns a list of agents initialized in their starting positions."""
positions = self.get_all_positions()
shuffle(positions)
player_position = positions[0]
beast_positions = positions[1:self.num_beasts + 1]
block_positions = positions[-self.num_blocks:]
player = [Player(player_position)]
beasts = [Beast(pos) for pos in beast_positions]
blocks = [Block(pos) for pos in block_positions]
return player + beasts + blocks
def get_all_positions(self):
"""Returns a list of all positions on the board."""
positions = []
for i in range(self.width):
for j in range(self.height):
positions.append((i, j))
return positions

View File

@@ -1,18 +0,0 @@
def add(vec0, vec1):
"""Adds two vectors."""
x0, y0 = vec0
x1, y1 = vec1
return (x0 + x1, y0 + y1)
def get_occupant(game, position):
"""Returns the agent at position, if there is one."""
positions_with_agents = game.get_agents_by_position()
if position in positions_with_agents:
agents_at_position = positions_with_agents[position]
return agents_at_position[0]
def distance(vec0, vec1):
"""Returns the Manhattan distance between two positions."""
x0, y0 = vec0
x1, y1 = vec1
return abs(x1 - x0) + abs(y1 - y0)

View File

@@ -1,6 +0,0 @@
[tool.retro-gamer]
actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]
reward = "beasts_killed"
character_set = ["*", "H", "█"]
spatial = true
observe_state = []

30
retro_gamer/log_parser.py Normal file
View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import re
from pathlib import Path
_LINE_RE = re.compile(
r'\[ep_(\d+)\]'
r'.*avg_reward=([+-]?\d+\.?\d*)'
r'.*avg_steps=(\d+\.?\d*)'
r'.*epsilon=(\d+\.?\d*)'
r'.*avg_loss=(\d+\.?\d*)'
)
def parse_checkpoints(log_path: Path) -> list[dict]:
"""Parse checkpoint lines from a training log. Returns a list of dicts
with keys: episode, avg_reward, avg_steps, epsilon, avg_loss."""
results = []
if not log_path.exists():
return results
for line in log_path.read_text().splitlines():
m = _LINE_RE.search(line)
if m:
results.append({
'episode': int(m.group(1)),
'avg_reward': float(m.group(2)),
'avg_steps': float(m.group(3)),
'epsilon': float(m.group(4)),
'avg_loss': float(m.group(5)),
})
return results

View File

@@ -11,39 +11,58 @@ class GameMetadata:
"""Describes a retro game for training purposes.
Required fields: actions, reward.
Optional fields: character_set, spatial, observe_state.
Discovered fields: board_size (read from game.board_size at startup).
Metadata is read from the game's own pyproject.toml under the
[tool.retro-gamer] section via GameMetadata.from_pyproject().
Optional fields: character_set, spatial.
Discovered fields: board_size (from game.board_size), extras_size (from
the observe_state list in [preprocessing]).
"""
actions: list[str]
reward: str
character_set: list[str] | None = None
spatial: bool = True
observe_state: list[str] = field(default_factory=list)
spatial: bool = False
board: bool = True
board_size: tuple[int, int] | None = None
extras_size: int = 0
def validate(self):
if not self.actions:
raise ValueError("actions must be a non-empty list")
raise ValueError(
"The 'actions' list in [tool.retro-gamer] is empty or missing.\n"
"It should list the keyboard keys your agent can press, for example:\n\n"
' actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n\n'
"The agent will learn which actions lead to higher rewards."
)
if not isinstance(self.actions, list) or not all(isinstance(a, str) for a in self.actions):
raise ValueError(
f"'actions' must be a list of strings, but got: {self.actions!r}\n"
"Each entry should be a key name like \"KEY_RIGHT\" or \"KEY_SPACE\"."
)
if not self.reward:
raise ValueError("reward must be a non-empty string")
if self.reward in self.observe_state:
raise ValueError(f"reward key '{self.reward}' must not appear in observe_state")
raise ValueError(
"The 'reward' field in [tool.retro-gamer] is empty or missing.\n"
"It should name a game state variable whose value the agent is trying\n"
"to maximize — for example:\n\n"
" reward = \"score\"\n\n"
"The trainer watches how this value changes each step and uses those\n"
"changes as the reward signal."
)
if self.character_set is not None:
if not isinstance(self.character_set, list):
raise ValueError(
f"'character_set' must be a list of single characters, but got: {self.character_set!r}\n"
"Example: character_set = [\"@\", \"*\", \"#\"]"
)
for ch in self.character_set:
if len(ch) != 1:
raise ValueError(f"character_set entries must be single characters, got {ch!r}")
if not isinstance(ch, str) or len(ch) != 1:
raise ValueError(
f"Every entry in character_set must be a single character, but got {ch!r}.\n"
"Each character represents one type of cell on the game board.\n"
"If you're not sure what characters your game uses, remove character_set\n"
"entirely and the trainer will discover them automatically."
)
@classmethod
def from_pyproject(cls, module_name: str) -> GameMetadata:
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml.
Finds pyproject.toml by walking up from the module's source file.
Raises FileNotFoundError if no pyproject.toml is found, or ValueError
if the file exists but has no [tool.retro-gamer] section.
"""
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml."""
pyproject_path = _find_pyproject(module_name)
if pyproject_path is None:
raise FileNotFoundError(
@@ -65,13 +84,23 @@ class GameMetadata:
@classmethod
def from_dict(cls, d: dict) -> GameMetadata:
missing = [k for k in ('actions', 'reward') if k not in d]
if missing:
fields = ' and '.join(f"'{k}'" for k in missing)
raise ValueError(
f"The [tool.retro-gamer] section is missing required {fields}.\n"
"A minimal configuration looks like this:\n\n"
"[tool.retro-gamer]\n"
'actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n'
'reward = "score"\n\n'
"See the documentation for all available options."
)
board_size = tuple(d['board_size']) if 'board_size' in d else None
return cls(
actions=d['actions'],
reward=d['reward'],
character_set=d.get('character_set'),
spatial=d.get('spatial', True),
observe_state=d.get('observe_state', []),
spatial=d.get('spatial', False),
board_size=board_size,
)
@@ -79,8 +108,6 @@ class GameMetadata:
d = {
'actions': self.actions,
'reward': self.reward,
'spatial': self.spatial,
'observe_state': self.observe_state,
}
if self.board_size is not None:
d['board_size'] = list(self.board_size)
@@ -95,9 +122,11 @@ class GameMetadata:
@property
def obs_size(self) -> int:
"""Total size of the flat observation vector."""
if not self.board:
return self.extras_size
C = len(self.character_set) if self.character_set else 0
bw, bh = self.board_size
return C * bw * bh + len(self.observe_state)
return C * bw * bh + self.extras_size
@property
def n_actions(self) -> int:
@@ -118,4 +147,6 @@ def _find_pyproject(module_name: str) -> Path | None:
candidate = parent / 'pyproject.toml'
if candidate.exists():
return candidate
if parent.name == 'site-packages':
break
return None

139
retro_gamer/model_agent.py Normal file
View File

@@ -0,0 +1,139 @@
from __future__ import annotations
import tomllib
import torch
from pathlib import Path
from blessed.keyboard import Keystroke
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView
from retro_gamer.metadata import GameMetadata
from retro_gamer.observation import encode_observation
class TrainedPolicy:
"""A trained retro-gamer model that can observe a game and choose actions.
Load from a training run directory, then call ``get_action(game)`` from
inside any agent's ``play_turn`` to get the model's recommended key.
Example::
from retro_gamer import TrainedPolicy
_ai = TrainedPolicy("runs/enemy/")
class EnemyAgent:
def play_turn(self, game):
key = _ai.get_action(game)
if key == 'KEY_RIGHT': self.direction = (1, 0)
...
"""
def __init__(self, run_dir: str | Path, checkpoint: str | None = None):
from retro_gamer.network import build_network
from retro_gamer.trainer import DEFAULTS
run_dir = Path(run_dir)
config_path = run_dir / 'config.toml'
if not config_path.exists():
raise FileNotFoundError(f"No config.toml found in {run_dir}")
with open(config_path, 'rb') as f:
config = tomllib.load(f)
self._metadata = GameMetadata.from_dict(config['metadata'])
pre = config.get('preprocessing', {})
self._metadata.spatial = pre.get('spatial', False)
self._metadata.board = pre.get('board', True)
observe_state_sizes = pre.get('observe_state_sizes', {})
self._observe_state: list[str] = pre.get('observe_state', [])
self._egocentric: bool = pre.get('egocentric', False)
self._egocentric_player: str | None = pre.get('egocentric_player')
self._egocentric_radius: int | None = pre.get('egocentric_radius')
self._board: bool = pre.get('board', True)
if observe_state_sizes:
self._metadata.extras_size = sum(observe_state_sizes.values())
else:
self._metadata.extras_size = len(self._observe_state)
hyperparams = {**DEFAULTS, **config.get('model', {}), **config.get('training', {})}
self._model, _ = build_network(self._metadata, hyperparams)
if checkpoint is not None:
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
ckpt_path = run_dir / 'checkpoints' / ckpt_name
if not ckpt_path.exists():
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
else:
ckpt_dir = run_dir / 'checkpoints'
candidates = sorted(ckpt_dir.glob('ep_*.pt')) if ckpt_dir.exists() else []
if not candidates:
raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
ckpt_path = candidates[-1]
ckpt = torch.load(ckpt_path, weights_only=True)
self._model.load_state_dict(ckpt['model_state_dict'])
self._model.eval()
def get_action(self, game) -> str | None:
"""Return the key the model recommends this turn, or None for no-op."""
view = HeadlessView()
view.on_game_start(game)
view.render(game)
board_chars = view.board_characters
player_pos = None
if self._egocentric and self._egocentric_player:
agent = game.get_agent_by_name(self._egocentric_player)
if agent is not None:
player_pos = agent.position
obs = encode_observation(
board_chars,
dict(game.state),
self._metadata,
self._observe_state,
player_pos=player_pos,
egocentric_radius=self._egocentric_radius,
board=self._board,
)
device = next(self._model.parameters()).device
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
with torch.no_grad():
action_idx = int(self._model(state_t).argmax().item())
if action_idx >= len(self._metadata.actions):
return None
return self._metadata.actions[action_idx]
def _keystroke(name: str) -> Keystroke:
if name.startswith("KEY_"):
return Keystroke(ucs='', code=None, name=name)
return Keystroke(ucs=name, code=None, name=None)
class PolicyInput:
"""An InputSource that drives the game with a TrainedPolicy instead of the keyboard.
Pass it as ``input_source`` to ``game.play()`` and everything else works
exactly as usual.
Example::
from retro_gamer import TrainedPolicy, PolicyInput
ai = TrainedPolicy("runs/snake/")
game = create_game()
game.play(input_source=PolicyInput(ai, game))
"""
def __init__(self, model: TrainedPolicy, game):
self._model = model
self._game = game
self._inp = ProgrammaticInput()
def collect(self) -> set:
key = self._model.get_action(self._game)
self._inp.press(key)
return self._inp.collect()

View File

@@ -13,32 +13,36 @@ def build_network(
Returns (model, rationale) where rationale is a multi-line string
describing the architecture and the reasoning behind each choice.
"""
n_layers = hyperparams.get('n_layers', 2)
layer_size = hyperparams.get('layer_size', 128)
C = len(metadata.character_set)
bw, bh = metadata.board_size
W, H = bw, bh
n_state = len(metadata.observe_state)
hidden_sizes = hyperparams.get('hidden_sizes', [512, 256])
n_state = metadata.extras_size
n_actions = metadata.n_actions
lines = []
lines.append("[INIT] === Network Architecture ===")
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
lines.append(f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}")
if metadata.spatial:
model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)
if metadata.board:
C = len(metadata.character_set)
bw, bh = metadata.board_size
W, H = bw, bh
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}")
if metadata.spatial:
model = _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines)
else:
obs_size = C * W * H + n_state
model = _build_flat(obs_size, hidden_sizes, n_actions, lines)
else:
obs_size = C * W * H + n_state
model = _build_flat(obs_size, n_layers, layer_size, n_actions, lines)
lines.append(f"[INIT] Board: disabled (board=false, state-only observation)")
lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}")
model = _build_flat(n_state, hidden_sizes, n_actions, lines)
lines.append(f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}")
lines.append(f"[INIT] Hidden layers: {len(hidden_sizes)} | Layer sizes: {hidden_sizes}")
lines.append(f"[INIT] Output: {n_actions} Q-values")
lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)")
return model, '\n'.join(lines)
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
def _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines):
lines.append("[INIT] spatial=True → using CNN architecture")
lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
@@ -47,20 +51,20 @@ def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
mlp_in = conv_out + n_state
lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
lines.append(f"[INIT] MLP: {''.join([str(mlp_in)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
lines.append(f"[INIT] MLP: {''.join([str(mlp_in)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}")
return _SpatialNet(C, H, W, n_state, hidden_sizes, n_actions)
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
def _build_flat(obs_size, hidden_sizes, n_actions, lines):
lines.append("[INIT] spatial=False → using MLP architecture")
lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
lines.append("[INIT] a flat MLP over the full observation is sufficient.")
lines.append(f"[INIT] MLP: {''.join([str(obs_size)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
return _FlatNet(obs_size, n_layers, layer_size, n_actions)
lines.append(f"[INIT] MLP: {''.join([str(obs_size)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}")
return _FlatNet(obs_size, hidden_sizes, n_actions)
class _SpatialNet(nn.Module):
def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
def __init__(self, C, H, W, n_state, hidden_sizes, n_actions):
super().__init__()
self.C, self.H, self.W = C, H, W
self.n_board = C * H * W
@@ -73,10 +77,11 @@ class _SpatialNet(nn.Module):
conv_out = 64 * H * W
mlp_in = conv_out + n_state
layers: list[nn.Module] = []
for i in range(n_layers):
in_size = mlp_in if i == 0 else layer_size
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
layers.append(nn.Linear(layer_size, n_actions))
prev = mlp_in
for size in hidden_sizes:
layers += [nn.Linear(prev, size), nn.ReLU()]
prev = size
layers.append(nn.Linear(prev, n_actions))
self.mlp = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -87,13 +92,14 @@ class _SpatialNet(nn.Module):
class _FlatNet(nn.Module):
def __init__(self, obs_size, n_layers, layer_size, n_actions):
def __init__(self, obs_size, hidden_sizes, n_actions):
super().__init__()
layers: list[nn.Module] = []
for i in range(n_layers):
in_size = obs_size if i == 0 else layer_size
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
layers.append(nn.Linear(layer_size, n_actions))
prev = obs_size
for size in hidden_sizes:
layers += [nn.Linear(prev, size), nn.ReLU()]
prev = size
layers.append(nn.Linear(prev, n_actions))
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:

View File

@@ -3,11 +3,7 @@ from retro_gamer.metadata import GameMetadata
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
"""One-hot encode the board.
Returns an array of shape (H, W, C) where C = len(character_set).
Unknown characters produce a zero vector.
"""
"""One-hot encode the board. Returns (H, W, C). Unknown characters → zero vector."""
char_to_idx = {c: i for i, c in enumerate(character_set)}
H = len(board_chars)
W = len(board_chars[0]) if board_chars else 0
@@ -22,28 +18,78 @@ def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.n
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
"""Extract observed state keys into a 1D float array."""
return np.array([float(state.get(k, 0)) for k in observe_state], dtype=np.float32)
"""Extract selected keys from game.state into a 1D float array.
Scalar values contribute one element; list/tuple values are flattened.
"""
values: list[float] = []
for k in observe_state:
val = state.get(k, 0)
if isinstance(val, (list, tuple)):
values.extend(float(x) for x in val)
else:
values.append(float(val))
return np.array(values, dtype=np.float32)
def egocentric_board(
board_chars: list[list[str]],
player_pos: tuple[int, int],
radius: int,
) -> list[list[str]]:
"""Crop the board to a (2r+1)×(2r+1) window centred on player_pos.
Out-of-bounds cells are filled with a space (treated as empty by the
encoder). The resulting grid is always square with side 2*radius+1.
"""
H = len(board_chars)
W = len(board_chars[0]) if board_chars else 0
px, py = player_pos
result = []
for dy in range(-radius, radius + 1):
row = []
for dx in range(-radius, radius + 1):
src_x = px + dx
src_y = py + dy
if 0 <= src_x < W and 0 <= src_y < H:
row.append(board_chars[src_y][src_x])
else:
row.append(' ')
result.append(row)
return result
def encode_observation(
board_chars: list[list[str]],
state: dict,
metadata: GameMetadata,
observe_state: list[str],
player_pos: tuple[int, int] | None = None,
egocentric_radius: int | None = None,
board: bool = True,
) -> np.ndarray:
"""Encode board + state into a flat 1D observation vector.
"""Encode board and/or selected state values into a flat 1D observation vector.
For spatial games the board is encoded channel-first (C, H, W) then flattened,
so the network can reshape it back for CNN processing. For non-spatial games the
board is encoded (H, W, C) then flattened.
The state vector is appended at the end in both cases.
When *board* is True the board is encoded and prepended to the vector. If
player_pos and egocentric_radius are given the board is first cropped to a
(2r+1)×(2r+1) window centred on the player. For spatial games the board is
encoded channel-first (C, H, W) then flattened; for non-spatial games it is
encoded (H, W, C) then flattened. The state vector is appended at the end.
When *board* is False only the observe_state features are returned.
"""
if not metadata.character_set:
raise ValueError("character_set must be set before encoding observations")
board = encode_board(board_chars, metadata.character_set) # (H, W, C)
if metadata.spatial:
board_vec = board.transpose(2, 0, 1).flatten() # C*H*W, channel-first
if board:
if not metadata.character_set:
raise ValueError("character_set must be set before encoding observations")
if player_pos is not None and egocentric_radius is not None:
board_chars = egocentric_board(board_chars, player_pos, egocentric_radius)
board_enc = encode_board(board_chars, metadata.character_set) # (H, W, C)
if metadata.spatial:
board_vec = board_enc.transpose(2, 0, 1).flatten()
else:
board_vec = board_enc.flatten()
if observe_state:
return np.concatenate([board_vec, encode_state(state, observe_state)])
return board_vec
else:
board_vec = board.flatten() # H*W*C
state_vec = encode_state(state, metadata.observe_state)
return np.concatenate([board_vec, state_vec])
return encode_state(state, observe_state)

55
retro_gamer/plotter.py Normal file
View File

@@ -0,0 +1,55 @@
from __future__ import annotations
from pathlib import Path
from retro_gamer.log_parser import parse_checkpoints
def plot_run(log_path: Path, output: Path | None = None) -> None:
"""Generate training metric plots from a training.log file.
Displays an interactive window unless *output* is given, in which case
the figure is saved to that path (PNG, PDF, SVG, etc.).
"""
import matplotlib.pyplot as plt
import seaborn as sns
data = parse_checkpoints(log_path)
if not data:
raise ValueError(f"No checkpoint data found in {log_path}")
episodes = [d['episode'] for d in data]
rewards = [d['avg_reward'] for d in data]
steps = [d['avg_steps'] for d in data]
losses = [d['avg_loss'] for d in data]
epsilons = [d['epsilon'] for d in data]
sns.set_theme(style='darkgrid')
fig, axes = plt.subplots(2, 2, figsize=(12, 7))
(ax_reward, ax_steps), (ax_loss, ax_epsilon) = axes
ax_reward.plot(episodes, rewards)
ax_reward.axhline(0, color='gray', linestyle='--', linewidth=0.8, alpha=0.6)
ax_reward.set_title('Average Reward')
ax_reward.set_xlabel('Episode')
ax_steps.plot(episodes, steps, color='C1')
ax_steps.set_title('Average Steps')
ax_steps.set_xlabel('Episode')
ax_loss.plot(episodes, losses, color='C2')
ax_loss.set_yscale('log')
ax_loss.set_title('Average Loss')
ax_loss.set_xlabel('Episode')
ax_epsilon.plot(episodes, epsilons, color='C3')
ax_epsilon.set_title('Epsilon (exploration rate)')
ax_epsilon.set_xlabel('Episode')
ax_epsilon.set_ylim(0, 1)
fig.suptitle(f'Training: {log_path.parent.name}', fontsize=13)
plt.tight_layout()
if output:
plt.savefig(output, dpi=150, bbox_inches='tight')
print(f"Plot saved to {output}")
else:
plt.show()

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import random
from datetime import datetime
from pathlib import Path
from time import perf_counter
from typing import Callable
import numpy as np
import torch
@@ -8,40 +10,250 @@ import torch.nn as nn
import torch.optim as optim
import tomli_w
from tqdm import tqdm
from retro_gamer.metadata import GameMetadata
from retro_gamer.env import GameEnvironment
from retro_gamer.network import build_network
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
MODEL_KEYS: frozenset = frozenset({'hidden_sizes'})
DEFAULTS: dict = {
'learning_rate': 1e-3,
'lr_decay': 0.995,
# [model]
'hidden_sizes': [128, 64],
# [training]
'learning_rate': 1e-4,
'learning_rate_decay': 0.9999,
'gamma': 0.99,
'epsilon': 1.0,
'epsilon_decay': 0.995,
'epsilon_decay': 0.9997,
'epsilon_min': 0.05,
'batch_size': 64,
'memory_capacity': 10_000,
'target_update_freq': 100,
'training_episodes': 1_000,
'n_layers': 2,
'layer_size': 128,
'prioritize_experiences': False,
'memory_capacity': 50_000,
'target_update_freq': 500,
'train_every': 4,
'training_episodes': 20_000,
'prioritize_experiences': True,
'exploration_turns': 200,
'unknown_character_strategy': 'ignore',
'max_turns_per_episode': 2_000,
}
def _get_device() -> torch.device:
if torch.backends.mps.is_available():
return torch.device('mps')
if torch.cuda.is_available():
return torch.device('cuda')
return torch.device('cpu')
# Fields that make an existing checkpoint incompatible with the current config.
# Changing any of these requires starting training from scratch.
_INCOMPATIBLE_METADATA = {
'actions': 'the list of actions the agent can take (changes output layer size)',
'reward': 'the reward signal — Q-values trained on the old signal are meaningless for the new one',
'character_set': 'the set of board characters (changes input layer size)',
'board_size': 'the board dimensions (changes input layer size)',
}
_INCOMPATIBLE_PREPROCESSING = {
'spatial': 'spatial vs non-spatial network type (changes network architecture)',
'board': 'whether the board is included in the observation (changes input size)',
'observe_state': 'the state keys included in the observation (changes input size)',
'observe_state_sizes': 'the size of each observed state key (changes input layer size)',
'egocentric': 'egocentric board transformation (changes input representation)',
'egocentric_player': 'the agent used as the egocentric center (changes input representation)',
'egocentric_radius': 'the egocentric crop radius (changes input layer size)',
}
_INCOMPATIBLE_ARCH = {
'hidden_sizes': 'the hidden layer sizes (changes network shape)',
}
def validate_hyperparams(hp: dict):
"""Check all hyperparameters and raise ValueError listing every problem found."""
problems = []
def _problem(heading: str, explanation: str, fix: str | None = None):
text = f" {heading}\n {explanation}"
if fix:
text += f"\n{fix}"
problems.append(text)
hs = hp.get('hidden_sizes')
if not isinstance(hs, list) or len(hs) == 0:
_problem(
f"hidden_sizes = {hs!r}",
"This sets the shape of the neural network's hidden layers. It must be a\n"
" non-empty list of positive integers, one number per layer.",
"Try: hidden_sizes = [512, 256]",
)
elif any(not isinstance(s, int) or s <= 0 for s in hs):
bad = [s for s in hs if not isinstance(s, int) or s <= 0]
_problem(
f"hidden_sizes = {hs!r}",
f"Every layer size must be a positive integer, but got {bad!r}.\n"
" Each number is the count of neurons in that layer — more neurons means\n"
" more capacity to learn complex patterns.",
"Try: hidden_sizes = [512, 256]",
)
lr = hp.get('learning_rate')
if not isinstance(lr, (int, float)) or lr <= 0:
_problem(
f"learning_rate = {lr!r}",
"The learning rate controls how much the network adjusts its weights after\n"
" each training step. Too high and training becomes unstable; too low and\n"
" it learns very slowly. It must be a positive number.",
"Typical values are between 0.0001 and 0.01. Try: learning_rate = 0.001",
)
for key, blurb in [
('learning_rate_decay',
"After each episode, the learning rate is multiplied by this value, gradually\n"
" slowing learning over time. A value of 1.0 means no decay; closer to 0\n"
" means very aggressive decay. It must be greater than 0 and at most 1."),
('gamma',
"Gamma is the discount factor: how much the agent values future rewards versus\n"
" immediate ones. 0.99 means future rewards are nearly as important as now;\n"
" 0.0 means the agent only cares about the very next step.\n"
" It must be greater than 0 and at most 1."),
('epsilon_decay',
"Each episode, the exploration rate (epsilon) is multiplied by this to gradually\n"
" reduce random actions over time. It must be between 0 (exclusive) and 1."),
]:
v = hp.get(key)
if not isinstance(v, (int, float)) or not (0 < v <= 1):
_problem(
f"{key} = {v!r}",
blurb,
f"Try a value close to but less than 1, like {key} = 0.995",
)
for key, blurb in [
('epsilon',
"Epsilon is the probability of taking a random action (exploration vs.\n"
" exploitation). It starts high — usually 1.0, meaning fully random —\n"
" and decays toward epsilon_min during training. It must be between 0 and 1."),
('epsilon_min',
"This is the lowest exploration rate allowed. Even after lots of training, the\n"
" agent keeps at least this much randomness so it keeps discovering new things.\n"
" It must be between 0 and 1."),
]:
v = hp.get(key)
if not isinstance(v, (int, float)) or not (0 <= v <= 1):
_problem(
f"{key} = {v!r}",
blurb,
f"Try: {key} = {'0.05' if 'min' in key else '1.0'}",
)
eps = hp.get('epsilon')
eps_min = hp.get('epsilon_min')
if (isinstance(eps, (int, float)) and isinstance(eps_min, (int, float))
and 0 <= eps <= 1 and 0 <= eps_min <= 1 and eps_min > eps):
_problem(
f"epsilon_min = {eps_min!r} is greater than epsilon = {eps!r}",
"epsilon is the starting exploration rate and epsilon_min is the floor it\n"
" decays toward, so epsilon_min must be less than or equal to epsilon.",
f"Try: epsilon = 1.0 and epsilon_min = 0.05",
)
for key, blurb in [
('batch_size',
"Training samples this many past experiences from the replay buffer at once\n"
" to compute a learning update. Must be a positive integer."),
('memory_capacity',
"The replay buffer stores this many past experiences. When it fills up, the\n"
" oldest are discarded. A larger buffer means more diverse training data.\n"
" Must be a positive integer."),
('target_update_freq',
"The target network (a stable copy of the Q-network used to compute targets)\n"
" is updated every this many steps. Must be a positive integer."),
('train_every',
"A training step runs once every this many game steps. This lets the agent\n"
" collect several new experiences before updating. Must be a positive integer."),
('training_episodes',
"The total number of episodes (games) to train for. Must be a positive integer."),
('max_turns_per_episode',
"If a game episode hasn't ended naturally after this many steps, it's cut\n"
" short. This prevents a buggy or stuck agent from running forever.\n"
" Must be a positive integer."),
]:
v = hp.get(key)
if not isinstance(v, int) or v <= 0:
_problem(
f"{key} = {v!r}",
blurb,
f"Try: {key} = {DEFAULTS[key]}",
)
v = hp.get('exploration_turns')
if not isinstance(v, int) or v < 0:
_problem(
f"exploration_turns = {v!r}",
"When no character_set is specified, the trainer runs this many random turns\n"
" to discover what characters appear on the board. Must be 0 or more\n"
" (0 skips discovery entirely, which only works if character_set is set).",
f"Try: exploration_turns = {DEFAULTS['exploration_turns']}",
)
bs = hp.get('batch_size')
mc = hp.get('memory_capacity')
if (isinstance(bs, int) and bs > 0 and isinstance(mc, int) and mc > 0 and bs > mc):
_problem(
f"batch_size = {bs} is larger than memory_capacity = {mc}",
"Training samples a batch of past experiences from the replay buffer each\n"
" step, so the buffer must be able to hold at least as many experiences\n"
f" as the batch size. With batch_size = {bs}, the buffer holds {mc}\n"
" that's not enough to sample from.",
f"Try: memory_capacity = {max(bs * 100, DEFAULTS['memory_capacity'])} "
f"(a much larger buffer also improves learning quality)",
)
if problems:
n = len(problems)
noun = "problem" if n == 1 else "problems"
header = f"Found {n} {noun} in your training configuration:\n\n"
footer = "\n\nFix these in config.toml, then run 'retro-gamer train' again."
raise ValueError(header + "\n\n".join(problems) + footer)
def _format_duration(seconds: float) -> str:
m, s = divmod(int(seconds), 60)
h, m = divmod(m, 60)
if h:
return f"{h}h{m:02d}m{s:02d}s"
return f"{m}m{s:02d}s"
class DQNTrainer:
"""Trains a deep Q-network agent to play a retro game.
On initialization the trainer:
1. Discovers the character set (if not already specified in metadata).
2. Builds the Q-network and logs the full architecture with rationale.
3. Saves config.toml and starts training.log in run_dir.
Automatically selects the best available training device: Apple Silicon
GPU (MPS), NVIDIA GPU (CUDA), or CPU. The chosen device is recorded in
``training.log``.
Call train() to run all episodes and save checkpoints.
On initialization, the trainer:
1. Discovers the character set if not already specified in *metadata*.
2. Builds the Q-network and logs its full architecture with rationale.
3. Writes ``config.toml`` and initializes ``training.log`` in *run_dir*.
Hyperparameters can be passed as keyword arguments; see the
:ref:`hyperparameters` reference for all options. Values not supplied
fall back to sensible defaults.
Call :meth:`train` to run all episodes. Checkpoints are saved every 100
episodes and training can be stopped (Ctrl+C) and resumed at any time.
Example::
from retro_gamer import GameMetadata, DQNTrainer
from retro.examples.snake import create_game
metadata = GameMetadata.from_pyproject("retro.examples.snake")
trainer = DQNTrainer(create_game, metadata, "runs/snake/")
trainer.train()
"""
def __init__(
@@ -49,26 +261,80 @@ class DQNTrainer:
game_factory: Callable,
metadata: GameMetadata,
run_dir: str | Path,
preprocessing: dict | None = None,
**hyperparams,
):
self.game_factory = game_factory
self.metadata = metadata
self.run_dir = Path(run_dir)
self.hp: dict = {**DEFAULTS, **hyperparams}
validate_hyperparams(self.hp)
self.run_dir.mkdir(parents=True, exist_ok=True)
(self.run_dir / 'checkpoints').mkdir(exist_ok=True)
self.env = GameEnvironment(game_factory, metadata)
pre = preprocessing or {}
self.observe_state: list[str] = pre.get('observe_state', [])
self.egocentric: bool = pre.get('egocentric', False)
self.egocentric_player: str | None = pre.get('egocentric_player', None)
self.egocentric_radius: int | None = pre.get('egocentric_radius', None)
self.board: bool = pre.get('board', True)
self.observe_state_sizes: dict[str, int] = pre.get('observe_state_sizes', {})
if self.board is False and metadata.spatial:
raise ValueError(
"preprocessing.board = false is incompatible with spatial = true.\n"
"A CNN requires a 2-D board to operate on. Either set spatial = false\n"
"or keep board = true."
)
if self.board is False and not self.observe_state:
raise ValueError(
"preprocessing.board = false requires at least one entry in observe_state.\n"
"With board=false, the agent observes only the game state variables listed\n"
"in observe_state — if that list is empty, there is nothing to observe."
)
if self.egocentric and not self.egocentric_radius:
raise ValueError(
"preprocessing.egocentric = true requires egocentric_radius.\n"
"Choose a value based on how far the agent needs to see, e.g.:\n"
" egocentric_radius = 5 # 11×11 tight local view\n"
" egocentric_radius = 8 # 17×17 wider view"
)
metadata.board = self.board
if metadata.board_size is None:
g = game_factory()
metadata.board_size = g.board_size
if metadata.character_set is None:
if self.egocentric_radius:
side = 2 * self.egocentric_radius + 1
metadata.board_size = (side, side)
self.env = GameEnvironment(
game_factory, metadata,
observe_state=self.observe_state,
egocentric=self.egocentric,
egocentric_player=self.egocentric_player,
egocentric_radius=self.egocentric_radius,
board=self.board,
observe_state_sizes=self.observe_state_sizes,
)
if metadata.character_set is None and self.board:
self._discover_character_set()
if self.observe_state and not self.observe_state_sizes:
self._discover_observe_state_sizes()
self.env.observe_state_sizes = self.observe_state_sizes
metadata.extras_size = sum(self.observe_state_sizes.values()) if self.observe_state_sizes else 0
self.device = _get_device()
self.model, rationale = build_network(metadata, self.hp)
self.target_model, _ = build_network(metadata, self.hp)
self.model.to(self.device)
self.target_model.to(self.device)
self.target_model.load_state_dict(self.model.state_dict())
self.target_model.eval()
@@ -76,7 +342,7 @@ class DQNTrainer:
self.model.parameters(), lr=self.hp['learning_rate']
)
self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
self.optimizer, gamma=self.hp['lr_decay']
self.optimizer, gamma=self.hp['learning_rate_decay']
)
if self.hp['prioritize_experiences']:
@@ -86,6 +352,9 @@ class DQNTrainer:
self.epsilon: float = self.hp['epsilon']
self.total_steps: int = 0
self.total_training_seconds: float = 0.0
self.start_episode: int = 1
self._resumed_from: str | None = None
self._save_config()
self._open_log(rationale)
@@ -94,49 +363,127 @@ class DQNTrainer:
# Public API
# ------------------------------------------------------------------
def train(self):
"""Run all training episodes and save checkpoints."""
for episode in range(1, self.hp['training_episodes'] + 1):
total_reward, steps, avg_loss = self._run_episode()
def train(self, on_checkpoint=None, on_episode=None):
"""Run all training episodes and save checkpoints.
*on_checkpoint*, if provided, is called after each checkpoint with a
dict containing ``episode``, ``avg_reward``, ``avg_steps``,
``avg_loss``, and ``epsilon``. *on_episode*, if provided, is called
after every episode. When either callback is supplied, the built-in
tqdm progress bar is suppressed (the caller is expected to show its
own progress UI).
"""
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
if self._resumed_from:
self._log_raw(f'\n=== Resumed from {self._resumed_from} | {timestamp} ===')
else:
self._log_raw(f'\n=== Training started | {timestamp} ===')
use_tqdm = on_checkpoint is None and on_episode is None
if use_tqdm:
print("Press Control+C to stop training early. Progress will be saved at the latest checkpoint.")
session_start = perf_counter()
ckpt_start = perf_counter()
episode_rewards: list[float] = []
episode_losses: list[float] = []
episode_steps: list[int] = []
episodes = range(self.start_episode, self.hp['training_episodes'] + 1)
bar = tqdm(episodes, unit='ep') if use_tqdm else episodes
for episode in bar:
total_reward, steps, avg_loss, trained = self._run_episode()
episode_rewards.append(total_reward)
if avg_loss > 0:
episode_losses.append(avg_loss)
episode_steps.append(steps)
self.epsilon = max(
self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
)
self.lr_scheduler.step()
self._log_episode(episode, total_reward, steps, avg_loss)
if episode % 100 == 0:
self._save_checkpoint(f'ep_{episode:04d}.pt')
self._save_checkpoint('final.pt')
if trained:
self.lr_scheduler.step()
if on_episode:
on_episode()
if use_tqdm:
bar.set_postfix(
reward=f'{total_reward:.1f}',
eps=f'{self.epsilon:.3f}',
loss=f'{avg_loss:.4f}',
)
is_checkpoint = (episode % 100 == 0)
is_last = (episode == self.hp['training_episodes'])
if is_checkpoint or (is_last and episode_rewards):
now = perf_counter()
ckpt_elapsed = now - ckpt_start
self.total_training_seconds += ckpt_elapsed
ckpt_start = now
self._save_checkpoint(f'ep_{episode:04d}.pt', episode)
stats = self._log_checkpoint(episode, episode_rewards, episode_losses, episode_steps, ckpt_elapsed)
episode_rewards = []
episode_losses = []
episode_steps = []
if on_checkpoint:
on_checkpoint(stats)
def load_checkpoint(self, path: str | Path):
ckpt = torch.load(path, weights_only=True)
"""Load a checkpoint to resume training.
Checkpoints are PyTorch state dicts stored under
``run_dir/checkpoints/``. Each contains model weights, optimizer
state, current epsilon, and total step count.
Raises :exc:`ValueError` if the checkpoint was trained with a
different character set, board size, action space, or network
architecture. The error message names each changed field and explains
why it is incompatible.
The CLI invokes this automatically; call directly only when driving
training from Python.
"""
ckpt = torch.load(path, weights_only=True, map_location='cpu')
self._check_compatibility(ckpt, path)
self.model.load_state_dict(ckpt['model_state_dict'])
self.target_model.load_state_dict(ckpt['model_state_dict'])
self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
for state in self.optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(self.device)
self.epsilon = ckpt['epsilon']
self.total_steps = ckpt['total_steps']
self.total_training_seconds = ckpt.get('total_training_seconds', 0.0)
self.start_episode = ckpt.get('episode', 0) + 1
self._resumed_from = Path(path).name
# ------------------------------------------------------------------
# Training loop internals
# ------------------------------------------------------------------
def _run_episode(self) -> tuple[float, int, float]:
def _run_episode(self) -> tuple[float, int, float, bool]:
state = self.env.reset()
total_reward = 0.0
total_loss = 0.0
loss_count = 0
for step in range(self.hp['max_turns_per_episode']):
state_t = torch.as_tensor(state, dtype=torch.float32)
state_t = torch.as_tensor(state, dtype=torch.float32).to(self.device)
action_idx = self._select_action(state_t)
action_key = self._idx_to_key(action_idx)
next_state, reward, done = self.env.step(action_key)
self.memory.push(state, action_idx, reward, next_state, done)
loss = self._train_step()
if loss is not None:
total_loss += loss
loss_count += 1
if self.total_steps % self.hp['train_every'] == 0:
loss = self._train_step()
if loss is not None:
total_loss += loss
loss_count += 1
self.total_steps += 1
if self.total_steps % self.hp['target_update_freq'] == 0:
@@ -148,7 +495,7 @@ class DQNTrainer:
break
avg_loss = total_loss / loss_count if loss_count else 0.0
return total_reward, step + 1, avg_loss
return total_reward, step + 1, avg_loss, loss_count > 0
def _select_action(self, state_t: torch.Tensor) -> int:
if random.random() < self.epsilon:
@@ -168,7 +515,7 @@ class DQNTrainer:
if self.hp['prioritize_experiences']:
assert isinstance(self.memory, PrioritizedReplayMemory)
experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
weight_t = torch.as_tensor(weights, dtype=torch.float32)
weight_t = torch.as_tensor(weights, dtype=torch.float32).to(self.device)
else:
experiences = self.memory.sample(self.hp['batch_size'])
indices = None
@@ -176,33 +523,107 @@ class DQNTrainer:
states = torch.as_tensor(
np.array([e.state for e in experiences]), dtype=torch.float32
)
actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long)
rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32)
).to(self.device)
actions = torch.as_tensor(
[e.action for e in experiences], dtype=torch.long
).to(self.device)
rewards = torch.as_tensor(
[e.reward for e in experiences], dtype=torch.float32
).to(self.device)
next_states = torch.as_tensor(
np.array([e.next_state for e in experiences]), dtype=torch.float32
)
dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32)
).to(self.device)
dones = torch.as_tensor(
[e.done for e in experiences], dtype=torch.float32
).to(self.device)
q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
with torch.no_grad():
next_q = self.target_model(next_states).max(1).values
targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)
element_loss = nn.functional.mse_loss(q_values, targets, reduction='none')
element_loss = nn.functional.huber_loss(q_values, targets, reduction='none', delta=1.0)
if weight_t is not None:
loss = (weight_t * element_loss).mean()
td_errors = (q_values - targets).detach().abs().numpy()
td_errors = (q_values - targets).detach().abs().cpu().numpy()
self.memory.update_priorities(indices, td_errors)
else:
loss = element_loss.mean()
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
self.optimizer.step()
return float(loss.item())
# ------------------------------------------------------------------
# Compatibility checking
# ------------------------------------------------------------------
def _config_snapshot(self) -> dict:
return {
'metadata': self.metadata.to_dict(),
'preprocessing': {
'spatial': self.metadata.spatial,
'board': self.board,
'observe_state': self.observe_state,
'observe_state_sizes': self.observe_state_sizes,
'egocentric': self.egocentric,
'egocentric_player': self.egocentric_player,
'egocentric_radius': self.egocentric_radius,
},
'hidden_sizes': self.hp['hidden_sizes'],
}
def _check_compatibility(self, ckpt: dict, path: str | Path):
snapshot = ckpt.get('config_snapshot')
if snapshot is None:
return
current = self._config_snapshot()
issues = []
old_meta = snapshot.get('metadata', {})
new_meta = current['metadata']
for field, desc in _INCOMPATIBLE_METADATA.items():
if old_meta.get(field) != new_meta.get(field):
issues.append((field, desc, old_meta.get(field), new_meta.get(field)))
old_pre = snapshot.get('preprocessing', {})
new_pre = current['preprocessing']
for field, desc in _INCOMPATIBLE_PREPROCESSING.items():
if old_pre.get(field) != new_pre.get(field):
issues.append((field, desc, old_pre.get(field), new_pre.get(field)))
for field, desc in _INCOMPATIBLE_ARCH.items():
if snapshot.get(field) != current.get(field):
issues.append((field, desc, snapshot.get(field), current.get(field)))
if not issues:
return
lines = [
f"Cannot resume from {Path(path).name}: incompatible changes detected in config.toml.",
"",
"The following changes require starting fresh. The existing model was trained",
"on a different problem and its weights cannot be reused:",
"",
]
for field, desc, old_val, new_val in issues:
lines += [
f" {field}",
f" was : {old_val!r}",
f" now : {new_val!r}",
f" why : {desc}",
"",
]
lines += [
"Run 'retro-gamer clean RUN_DIR' to remove existing checkpoints and the",
"training log, then run 'retro-gamer train RUN_DIR' to start fresh.",
]
raise ValueError("\n".join(lines))
# ------------------------------------------------------------------
# Initialisation helpers
# ------------------------------------------------------------------
@@ -215,6 +636,16 @@ class DQNTrainer:
f"after {self.hp['exploration_turns']} exploration turns: {chars}"
)
def _discover_observe_state_sizes(self):
"""Sample game.state to determine the flat size of each observe_state key."""
self.env.reset()
state = dict(self.env.game.state)
sizes = {}
for key in self.observe_state:
val = state.get(key, 0)
sizes[key] = len(val) if isinstance(val, (list, tuple)) else 1
self.observe_state_sizes = sizes
def _save_config(self):
config_path = self.run_dir / 'config.toml'
config: dict = {}
@@ -223,33 +654,75 @@ class DQNTrainer:
with open(config_path, 'rb') as f:
config = tomllib.load(f)
config['metadata'] = self.metadata.to_dict()
config['hyperparameters'] = self.hp
pre = config.setdefault('preprocessing', {})
pre['spatial'] = self.metadata.spatial
pre['board'] = self.board
pre['observe_state'] = self.observe_state
if self.observe_state_sizes:
pre['observe_state_sizes'] = self.observe_state_sizes
pre['egocentric'] = self.egocentric
if self.egocentric_player:
pre['egocentric_player'] = self.egocentric_player
if self.egocentric_radius:
pre['egocentric_radius'] = self.egocentric_radius
config['model'] = {k: v for k, v in self.hp.items() if k in MODEL_KEYS}
config['training'] = {k: v for k, v in self.hp.items() if k not in MODEL_KEYS}
with open(config_path, 'wb') as f:
tomli_w.dump(config, f)
def _open_log(self, rationale: str):
self.log_path = self.run_dir / 'training.log'
with open(self.log_path, 'w') as f:
f.write(rationale + '\n')
if not self.log_path.exists():
with open(self.log_path, 'w') as f:
f.write(rationale + '\n')
f.write(f'[INIT] Device: {self.device}\n')
def _log_raw(self, line: str):
with open(self.log_path, 'a') as f:
f.write(line + '\n')
def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float):
def _log_checkpoint(
self,
episode: int,
rewards: list[float],
losses: list[float],
steps: list[int],
ckpt_elapsed: float,
) -> dict:
n = len(rewards)
start_ep = episode - n + 1
avg_reward = sum(rewards) / n if n else 0.0
avg_loss = sum(losses) / len(losses) if losses else 0.0
avg_steps = sum(steps) / n if n else 0.0
line = (
f"[EP {episode:04d}] total_reward={total_reward:.1f} "
f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}"
f"[ep_{episode:04d}]"
f" ep={start_ep:04d}-{episode:04d}"
f" avg_reward={avg_reward:+.1f}"
f" avg_steps={avg_steps:.0f}"
f" epsilon={self.epsilon:.3f}"
f" avg_loss={avg_loss:.1f}"
f" time={_format_duration(ckpt_elapsed)}"
f" total={_format_duration(self.total_training_seconds)}"
)
self._log_raw(line)
return {
'episode': episode,
'avg_reward': avg_reward,
'avg_steps': avg_steps,
'avg_loss': avg_loss,
'epsilon': self.epsilon,
}
def _save_checkpoint(self, name: str):
def _save_checkpoint(self, name: str, episode: int):
torch.save(
{
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'epsilon': self.epsilon,
'total_steps': self.total_steps,
'episode': episode,
'total_training_seconds': self.total_training_seconds,
'config_snapshot': self._config_snapshot(),
},
self.run_dir / 'checkpoints' / name,
)