Updates across the board
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.env import GameEnvironment
|
||||
from retro_gamer.trainer import DQNTrainer
|
||||
from retro_gamer.model_agent import TrainedPolicy, PolicyInput
|
||||
|
||||
__all__ = ["GameMetadata", "GameEnvironment", "DQNTrainer"]
|
||||
__all__ = ["GameMetadata", "GameEnvironment", "DQNTrainer", "TrainedPolicy", "PolicyInput"]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import sys
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
import click
|
||||
import tomli_w
|
||||
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.trainer import DQNTrainer, DEFAULTS
|
||||
from retro_gamer.trainer import DQNTrainer, DEFAULTS, MODEL_KEYS
|
||||
|
||||
|
||||
@click.group()
|
||||
@@ -20,13 +21,14 @@ def cli():
|
||||
|
||||
@cli.command()
|
||||
@click.option('--game', required=True,
|
||||
help='Python module containing create_game() e.g. retro.examples.snake')
|
||||
help='Game to train: a .py file path (e.g. my_game.py) or a Python module '
|
||||
'(e.g. retro.examples.snake)')
|
||||
@click.option('--output', required=True,
|
||||
help='Directory to create for this training run')
|
||||
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
|
||||
help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
|
||||
@click.option('--lr-decay', default=DEFAULTS['lr_decay'], type=float,
|
||||
help=f"Multiplicative LR decay per episode (default {DEFAULTS['lr_decay']})")
|
||||
@click.option('--learning-rate-decay', default=DEFAULTS['learning_rate_decay'], type=float,
|
||||
help=f"Multiplicative LR decay per episode (default {DEFAULTS['learning_rate_decay']})")
|
||||
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
|
||||
help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
|
||||
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
|
||||
@@ -43,12 +45,12 @@ def cli():
|
||||
help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
|
||||
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
|
||||
help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
|
||||
@click.option('--n-layers', default=DEFAULTS['n_layers'], type=int,
|
||||
help=f"Hidden layers in MLP head (default {DEFAULTS['n_layers']})")
|
||||
@click.option('--layer-size', default=DEFAULTS['layer_size'], type=int,
|
||||
help=f"Width of each hidden layer (default {DEFAULTS['layer_size']})")
|
||||
@click.option('--hidden-sizes', default=','.join(str(s) for s in DEFAULTS['hidden_sizes']),
|
||||
help=f"Comma-separated hidden layer sizes, e.g. 512,256 (default {DEFAULTS['hidden_sizes']})")
|
||||
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
|
||||
help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
|
||||
@click.option('--train-every', default=DEFAULTS['train_every'], type=int,
|
||||
help=f"Run a training step every N game steps (default {DEFAULTS['train_every']})")
|
||||
@click.option('--prioritize-experiences/--no-prioritize-experiences',
|
||||
default=DEFAULTS['prioritize_experiences'],
|
||||
help='Use prioritized experience replay')
|
||||
@@ -60,12 +62,23 @@ def create(game, output, **hyperparams):
|
||||
Board size is read directly from the game. Hyperparameter options
|
||||
control how the trainer learns, not what it learns about.
|
||||
"""
|
||||
raw = hyperparams['hidden_sizes']
|
||||
try:
|
||||
metadata = GameMetadata.from_pyproject(game)
|
||||
hyperparams['hidden_sizes'] = [int(x.strip()) for x in raw.split(',')]
|
||||
except ValueError:
|
||||
raise click.ClickException(
|
||||
f"Could not parse --hidden-sizes {raw!r}.\n"
|
||||
"It should be a comma-separated list of positive integers, one per hidden layer.\n"
|
||||
"Example: --hidden-sizes 512,256"
|
||||
)
|
||||
game_config = _parse_game_arg(game)
|
||||
|
||||
try:
|
||||
metadata = GameMetadata.from_pyproject(game_config['module'])
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
game_factory = _load_factory(game)
|
||||
game_factory = _load_factory(game_config)
|
||||
g = game_factory()
|
||||
metadata.board_size = g.board_size
|
||||
|
||||
@@ -74,10 +87,13 @@ def create(game, output, **hyperparams):
|
||||
run_dir = Path(output)
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
preprocessing = {'spatial': metadata.spatial}
|
||||
config = {
|
||||
'game': {'module': game},
|
||||
'game': game_config,
|
||||
'metadata': metadata.to_dict(),
|
||||
'hyperparameters': hyperparams,
|
||||
'preprocessing': preprocessing,
|
||||
'model': {k: v for k, v in hyperparams.items() if k in MODEL_KEYS},
|
||||
'training': {k: v for k, v in hyperparams.items() if k not in MODEL_KEYS},
|
||||
}
|
||||
with open(run_dir / 'config.toml', 'wb') as f:
|
||||
tomli_w.dump(config, f)
|
||||
@@ -91,8 +107,6 @@ def create(game, output, **hyperparams):
|
||||
click.echo(f" characters : {metadata.character_set}")
|
||||
else:
|
||||
click.echo(f" characters : (will be auto-discovered during training)")
|
||||
if metadata.observe_state:
|
||||
click.echo(f" observe : {metadata.observe_state}")
|
||||
click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")
|
||||
|
||||
|
||||
@@ -102,23 +116,49 @@ def create(game, output, **hyperparams):
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--resume', default=None,
|
||||
help='Path to checkpoint to resume from (e.g. checkpoints/ep_0500.pt)')
|
||||
def train(run_dir, resume):
|
||||
"""Train (or resume training) a DQN agent."""
|
||||
def train(run_dir):
|
||||
"""Train a DQN agent, resuming automatically from the latest checkpoint.
|
||||
|
||||
To start fresh or resume from an earlier point, delete the checkpoints
|
||||
you no longer want from RUN_DIR/checkpoints/.
|
||||
"""
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
game_factory = _load_factory(config['game']['module'])
|
||||
game_factory = _load_factory(config['game'])
|
||||
metadata = GameMetadata.from_dict(config['metadata'])
|
||||
hyperparams = config.get('hyperparameters', {})
|
||||
preprocessing = config.get('preprocessing', {})
|
||||
metadata.spatial = preprocessing.get('spatial', False)
|
||||
hyperparams = {**config.get('model', {}), **config['training']}
|
||||
|
||||
trainer = DQNTrainer(game_factory, metadata, run_dir, **hyperparams)
|
||||
if resume:
|
||||
click.echo(f"Resuming from {resume}")
|
||||
trainer.load_checkpoint(resume)
|
||||
click.echo(f"Training for {trainer.hp['training_episodes']} episodes…")
|
||||
trainer.train()
|
||||
click.echo(f"Done. Checkpoints in {run_dir}/checkpoints/")
|
||||
try:
|
||||
trainer = DQNTrainer(game_factory, metadata, run_dir, preprocessing=preprocessing, **hyperparams)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
latest = _latest_checkpoint(run_dir_path)
|
||||
if latest:
|
||||
click.echo(f"Resuming from {latest.name}")
|
||||
try:
|
||||
trainer.load_checkpoint(latest)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
if trainer.start_episode > trainer.hp['training_episodes']:
|
||||
click.echo(
|
||||
f"Training already complete ({trainer.hp['training_episodes']} episodes). "
|
||||
"To keep training, increase training_episodes in config.toml."
|
||||
)
|
||||
return
|
||||
|
||||
from retro_gamer.log_parser import parse_checkpoints
|
||||
from retro_gamer.dashboard import TrainingDashboard
|
||||
history = parse_checkpoints(run_dir_path / 'training.log')
|
||||
display = TrainingDashboard(trainer.hp['training_episodes'], trainer.start_episode, history)
|
||||
try:
|
||||
trainer.train(on_checkpoint=display.on_checkpoint, on_episode=display.on_episode)
|
||||
finally:
|
||||
display.close()
|
||||
click.echo(f"Done. Checkpoints saved in {run_dir}/checkpoints/")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -127,69 +167,72 @@ def train(run_dir, resume):
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--checkpoint', default='final',
|
||||
help='Checkpoint name e.g. "final" or "ep_0100"')
|
||||
@click.option('--checkpoint', default=None,
|
||||
help='Checkpoint to load, e.g. "ep_0100". Defaults to the latest available.')
|
||||
@click.option('--framerate', default=12, type=int,
|
||||
help='Target frames per second')
|
||||
def play(run_dir, checkpoint, framerate):
|
||||
"""Watch a trained agent play the game."""
|
||||
import torch
|
||||
from time import sleep, perf_counter
|
||||
from blessed import Terminal
|
||||
from retro.input import ProgrammaticInput
|
||||
from retro.views.headless import HeadlessView
|
||||
from retro.views.terminal import TerminalView
|
||||
from retro_gamer.observation import encode_observation
|
||||
from retro_gamer.model_agent import TrainedPolicy, PolicyInput
|
||||
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
game_factory = _load_factory(config['game']['module'])
|
||||
metadata = GameMetadata.from_dict(config['metadata'])
|
||||
hyperparams = {**DEFAULTS, **config.get('hyperparameters', {})}
|
||||
game_factory = _load_factory(config['game'])
|
||||
|
||||
from retro_gamer.network import build_network
|
||||
model, _ = build_network(metadata, hyperparams)
|
||||
try:
|
||||
ai = TrainedPolicy(run_dir_path, checkpoint=checkpoint)
|
||||
except FileNotFoundError as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
|
||||
ckpt_path = run_dir_path / 'checkpoints' / ckpt_name
|
||||
ckpt = torch.load(ckpt_path, weights_only=True)
|
||||
model.load_state_dict(ckpt['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
inp = ProgrammaticInput()
|
||||
headless = HeadlessView()
|
||||
game = game_factory()
|
||||
game.input_source = inp
|
||||
game.view = headless
|
||||
game.start()
|
||||
inp = PolicyInput(ai, game)
|
||||
|
||||
terminal = Terminal()
|
||||
term_view = TerminalView(terminal, color=game.color)
|
||||
|
||||
click.echo("Playing… (press Escape or Enter to quit)")
|
||||
with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
|
||||
term_view.on_game_start(game)
|
||||
game.input_source = inp
|
||||
game.view = term_view
|
||||
game.start()
|
||||
while game.playing:
|
||||
t0 = perf_counter()
|
||||
|
||||
obs = encode_observation(headless.board_characters, dict(game.state), metadata)
|
||||
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
action_idx = int(model(state_t).argmax().item())
|
||||
action_key = None if action_idx >= len(metadata.actions) else metadata.actions[action_idx]
|
||||
|
||||
key = terminal.inkey(0)
|
||||
if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
|
||||
break
|
||||
|
||||
inp.press(action_key)
|
||||
game.step()
|
||||
term_view.render(game)
|
||||
|
||||
elapsed = perf_counter() - t0
|
||||
sleep(max(0, 1 / framerate - elapsed))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer plot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--output', '-o', default=None,
|
||||
help='Save to file (e.g. plot.png) instead of displaying interactively.')
|
||||
def plot(run_dir, output):
|
||||
"""Plot training metrics (reward, steps, loss, epsilon) from a run's log."""
|
||||
from retro_gamer.plotter import plot_run
|
||||
run_dir_path = Path(run_dir)
|
||||
log_path = run_dir_path / 'training.log'
|
||||
if not log_path.exists():
|
||||
raise click.ClickException(f"No training.log found in {run_dir}")
|
||||
output_path = Path(output) if output else None
|
||||
try:
|
||||
plot_run(log_path, output_path)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer info
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -200,17 +243,19 @@ def info(run_dir):
|
||||
"""Print a summary of a training run."""
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
click.echo(f"Game module : {config['game']['module']}")
|
||||
click.echo(f"Metadata : {config['metadata']}")
|
||||
click.echo(f"Hyperparams : {config.get('hyperparameters', {})}")
|
||||
click.echo(f"Game module : {config['game']['module']}")
|
||||
click.echo(f"Metadata : {config['metadata']}")
|
||||
click.echo(f"Preprocessing : {config.get('preprocessing', {})}")
|
||||
click.echo(f"Model : {config.get('model', {})}")
|
||||
click.echo(f"Training : {config['training']}")
|
||||
|
||||
log_path = run_dir_path / 'training.log'
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text().splitlines()
|
||||
episode_lines = [l for l in lines if l.startswith('[EP')]
|
||||
if episode_lines:
|
||||
click.echo(f"\nLast 5 episodes:")
|
||||
for line in episode_lines[-5:]:
|
||||
ckpt_lines = [l for l in lines if l.startswith('[ep_')]
|
||||
if ckpt_lines:
|
||||
click.echo(f"\nLast 5 checkpoints:")
|
||||
for line in ckpt_lines[-5:]:
|
||||
click.echo(f" {line}")
|
||||
|
||||
ckpt_dir = run_dir_path / 'checkpoints'
|
||||
@@ -219,10 +264,91 @@ def info(run_dir):
|
||||
click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer clean
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--yes', '-y', is_flag=True, help='Skip confirmation prompt')
|
||||
def clean(run_dir, yes):
|
||||
"""Remove all checkpoints and the training log from a run directory.
|
||||
|
||||
Use this after changing the game description or network architecture,
|
||||
which require starting training from scratch.
|
||||
"""
|
||||
run_dir_path = Path(run_dir)
|
||||
if not run_dir_path.exists():
|
||||
raise click.ClickException(f"Directory not found: {run_dir}")
|
||||
|
||||
to_remove = []
|
||||
ckpt_dir = run_dir_path / 'checkpoints'
|
||||
if ckpt_dir.exists():
|
||||
to_remove.extend(sorted(ckpt_dir.glob('*.pt')))
|
||||
log_path = run_dir_path / 'training.log'
|
||||
if log_path.exists():
|
||||
to_remove.append(log_path)
|
||||
|
||||
if not to_remove:
|
||||
click.echo("Nothing to clean.")
|
||||
return
|
||||
|
||||
n_ckpts = sum(1 for p in to_remove if p.suffix == '.pt')
|
||||
click.echo(f"Will remove {n_ckpts} checkpoint(s) and training log from {run_dir}/:")
|
||||
for p in to_remove:
|
||||
click.echo(f" {p.relative_to(run_dir_path)}")
|
||||
|
||||
if not yes:
|
||||
click.confirm("\nProceed?", abort=True)
|
||||
|
||||
for p in to_remove:
|
||||
p.unlink()
|
||||
click.echo(f"Cleaned. Run 'retro-gamer train {run_dir}' to start fresh.")
|
||||
|
||||
ckpt_dir = run_dir_path / 'checkpoints'
|
||||
if ckpt_dir.exists():
|
||||
ckpts = sorted(ckpt_dir.glob('*.pt'))
|
||||
click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_game_arg(arg: str) -> dict:
|
||||
"""Accept a .py file path, a package directory, or a Python module name.
|
||||
|
||||
Returns a dict with at least 'module', and optionally 'path' (the
|
||||
directory to add to sys.path so the module can be imported).
|
||||
"""
|
||||
p = Path(arg)
|
||||
if p.exists():
|
||||
if p.is_file() and p.suffix == '.py':
|
||||
path = str(p.parent.resolve())
|
||||
module = p.stem
|
||||
elif p.is_dir():
|
||||
path = str(p.parent.resolve())
|
||||
module = p.name
|
||||
else:
|
||||
raise click.ClickException(
|
||||
f"{arg!r} is not a .py file or package directory"
|
||||
)
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
return {'module': module, 'path': path}
|
||||
return {'module': arg}
|
||||
|
||||
|
||||
def _latest_checkpoint(run_dir: Path) -> Path | None:
|
||||
"""Return the most recent checkpoint in run_dir/checkpoints/, or None."""
|
||||
ckpt_dir = run_dir / 'checkpoints'
|
||||
if ckpt_dir.exists():
|
||||
candidates = sorted(ckpt_dir.glob('ep_*.pt'))
|
||||
if candidates:
|
||||
return candidates[-1]
|
||||
return None
|
||||
|
||||
|
||||
def _load_config(run_dir: Path) -> dict:
|
||||
config_path = run_dir / 'config.toml'
|
||||
if not config_path.exists():
|
||||
@@ -231,7 +357,11 @@ def _load_config(run_dir: Path) -> dict:
|
||||
return tomllib.load(f)
|
||||
|
||||
|
||||
def _load_factory(module_name: str):
|
||||
def _load_factory(game_config: dict):
|
||||
path = game_config.get('path')
|
||||
if path and path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
module_name = game_config['module']
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
@@ -241,3 +371,4 @@ def _load_factory(module_name: str):
|
||||
f"Module '{module_name}' has no create_game() function"
|
||||
)
|
||||
return module.create_game
|
||||
|
||||
|
||||
86
retro_gamer/dashboard.py
Normal file
86
retro_gamer/dashboard.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import sys
|
||||
from tqdm import tqdm
|
||||
|
||||
_CHART_HEIGHT = 22 # plotext chart area height in lines
|
||||
|
||||
|
||||
def _terminal_width() -> int:
|
||||
try:
|
||||
return min(os.get_terminal_size().columns, 220)
|
||||
except OSError:
|
||||
return 120
|
||||
|
||||
|
||||
def _build_charts(history: list[dict], width: int) -> str:
|
||||
import plotext as plt
|
||||
episodes = [d['episode'] for d in history]
|
||||
series = [
|
||||
("Epsilon", [d['epsilon'] for d in history]),
|
||||
("Avg Steps", [d['avg_steps'] for d in history]),
|
||||
("Avg Loss", [d['avg_loss'] for d in history]),
|
||||
("Avg Reward", [d['avg_reward'] for d in history]),
|
||||
]
|
||||
panel_w = max(width // len(series), 20)
|
||||
panels = []
|
||||
for title, values in series:
|
||||
plt.clf()
|
||||
plt.canvas_color("default")
|
||||
plt.axes_color("default")
|
||||
plt.ticks_color("default")
|
||||
if episodes:
|
||||
plt.plot(episodes, values, color="default")
|
||||
plt.title(title)
|
||||
plt.xlabel("Episode")
|
||||
plt.plotsize(panel_w, _CHART_HEIGHT)
|
||||
panels.append(plt.build().splitlines())
|
||||
|
||||
height = max(len(p) for p in panels)
|
||||
for p in panels:
|
||||
while len(p) < height:
|
||||
p.append(' ' * panel_w)
|
||||
return '\n'.join(''.join(row) for row in zip(*panels))
|
||||
|
||||
|
||||
class TrainingDashboard:
|
||||
"""Inline training display: a plotext chart block that redraws in place
|
||||
above a tqdm per-episode progress bar."""
|
||||
|
||||
def __init__(self, total_episodes: int, start_episode: int, history: list[dict]):
|
||||
self._total = total_episodes
|
||||
self._history = list(history)
|
||||
self._rendered = False # have we drawn charts yet?
|
||||
|
||||
already_done = start_episode - 1
|
||||
self._bar = tqdm(
|
||||
initial=already_done,
|
||||
total=total_episodes,
|
||||
unit='ep',
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
self._draw()
|
||||
|
||||
def on_episode(self) -> None:
|
||||
self._bar.update(1)
|
||||
|
||||
def on_checkpoint(self, stats: dict) -> None:
|
||||
self._history.append(stats)
|
||||
self._bar.clear()
|
||||
self._draw()
|
||||
self._bar.refresh()
|
||||
|
||||
def close(self) -> None:
|
||||
self._bar.close()
|
||||
|
||||
def _draw(self) -> None:
|
||||
chart = _build_charts(self._history, _terminal_width())
|
||||
n_lines = len(chart.splitlines())
|
||||
if self._rendered:
|
||||
sys.stdout.write(f'\033[{n_lines}A')
|
||||
sys.stdout.write(chart)
|
||||
if not chart.endswith('\n'):
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
self._rendered = True
|
||||
@@ -9,15 +9,27 @@ from retro_gamer.observation import encode_observation
|
||||
|
||||
|
||||
class GameEnvironment:
|
||||
"""Gym-style wrapper around a retro Game for RL training.
|
||||
"""Gym-style wrapper around a retro game for RL training."""
|
||||
|
||||
Provides reset() / step(action) / observe(), managing one training episode
|
||||
at a time. The game is restarted by calling the factory function on each reset.
|
||||
"""
|
||||
|
||||
def __init__(self, game_factory: Callable, metadata: GameMetadata):
|
||||
def __init__(
|
||||
self,
|
||||
game_factory: Callable,
|
||||
metadata: GameMetadata,
|
||||
observe_state: list[str] | None = None,
|
||||
egocentric: bool = False,
|
||||
egocentric_player: str | None = None,
|
||||
egocentric_radius: int | None = None,
|
||||
board: bool = True,
|
||||
observe_state_sizes: dict[str, int] | None = None,
|
||||
):
|
||||
self.game_factory = game_factory
|
||||
self.metadata = metadata
|
||||
self.observe_state = observe_state or []
|
||||
self.egocentric = egocentric
|
||||
self.egocentric_player = egocentric_player
|
||||
self.egocentric_radius = egocentric_radius
|
||||
self.board = board
|
||||
self.observe_state_sizes = observe_state_sizes or {}
|
||||
self.game = None
|
||||
self.view: HeadlessView | None = None
|
||||
self.inp: ProgrammaticInput | None = None
|
||||
@@ -35,12 +47,7 @@ class GameEnvironment:
|
||||
return self._observe()
|
||||
|
||||
def step(self, action: str | None) -> tuple[np.ndarray, float, bool]:
|
||||
"""Advance one turn with the given action.
|
||||
|
||||
action: a keystroke string (e.g. 'KEY_RIGHT') or None for no-op.
|
||||
Returns (observation, reward, done).
|
||||
Reward is the change in the reward state key since the previous step.
|
||||
"""
|
||||
"""Advance one turn. Returns (observation, reward, done)."""
|
||||
self.inp.press(action)
|
||||
self.game.step()
|
||||
obs = self._observe()
|
||||
@@ -49,12 +56,47 @@ class GameEnvironment:
|
||||
return obs, reward, done
|
||||
|
||||
def _observe(self) -> np.ndarray:
|
||||
state = dict(self.game.state)
|
||||
if self.observe_state_sizes:
|
||||
self._check_state_sizes(state)
|
||||
player_pos = None
|
||||
if self.egocentric and self.egocentric_player:
|
||||
agent = self.game.get_agent_by_name(self.egocentric_player)
|
||||
if agent is not None:
|
||||
player_pos = agent.position
|
||||
return encode_observation(
|
||||
self.view.board_characters,
|
||||
dict(self.game.state),
|
||||
state,
|
||||
self.metadata,
|
||||
self.observe_state,
|
||||
player_pos=player_pos,
|
||||
egocentric_radius=self.egocentric_radius,
|
||||
board=self.board,
|
||||
)
|
||||
|
||||
def _check_state_sizes(self, state: dict):
|
||||
for key, expected in self.observe_state_sizes.items():
|
||||
val = state.get(key)
|
||||
if val is None:
|
||||
actual = 0
|
||||
elif isinstance(val, (list, tuple)):
|
||||
actual = len(val)
|
||||
else:
|
||||
actual = 1
|
||||
if actual != expected:
|
||||
raise ValueError(
|
||||
f"State key '{key}' changed size during training:\n"
|
||||
f" Expected : {expected} (discovered at training start)\n"
|
||||
f" Got : {actual}\n\n"
|
||||
f"This means game.state['{key}'] has a different length in some\n"
|
||||
f"episodes than it had when training started. The neural network\n"
|
||||
f"has a fixed input size and cannot adapt to changing state shapes.\n\n"
|
||||
f"Fix: make sure create_game() always initializes '{key}' with a\n"
|
||||
f"fixed-length value before the game starts each episode.\n"
|
||||
f"For example, if '{key}' is a list of 9 values, it must always be\n"
|
||||
f"a list of exactly 9 values — never more, never fewer, never missing."
|
||||
)
|
||||
|
||||
def _delta_reward(self) -> float:
|
||||
current = float(self.game.state.get(self.metadata.reward, 0))
|
||||
delta = current - self._prev_reward
|
||||
@@ -62,10 +104,8 @@ class GameEnvironment:
|
||||
return delta
|
||||
|
||||
def discover_character_set(self, exploration_turns: int) -> list[str]:
|
||||
"""Run random turns to discover the characters that appear on the board.
|
||||
Returns the sorted character list (excluding space).
|
||||
"""
|
||||
obs = self.reset()
|
||||
"""Run random turns to discover the characters that appear on the board."""
|
||||
self.reset()
|
||||
chars: set[str] = set()
|
||||
for _ in range(exploration_turns):
|
||||
for row in self.view.board_characters:
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from retro.game import Game
|
||||
from retro_gamer.examples.beast.board import Board
|
||||
|
||||
WIDTH = 40
|
||||
HEIGHT = 20
|
||||
NUM_BEASTS = 10
|
||||
|
||||
def create_game():
|
||||
"""Return a fresh, initialized Beast game."""
|
||||
board = Board(WIDTH, HEIGHT, num_beasts=NUM_BEASTS)
|
||||
state = {'beasts_killed': 0}
|
||||
game = Game(board.get_agents(), state, board_size=(WIDTH, HEIGHT))
|
||||
game.num_beasts = NUM_BEASTS
|
||||
return game
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_game().play()
|
||||
@@ -1,67 +0,0 @@
|
||||
from retro_gamer.examples.beast.helpers import add, distance, get_occupant
|
||||
from random import random, choice
|
||||
|
||||
class Beast:
|
||||
"""A beast that hunts the player."""
|
||||
character = "H"
|
||||
color = "red"
|
||||
probability_of_moving = 0.03
|
||||
probability_of_random_move = 0.2
|
||||
deadly = True
|
||||
|
||||
def __init__(self, position):
|
||||
self.position = position
|
||||
|
||||
def handle_push(self, vector, game):
|
||||
future_position = add(self.position, vector)
|
||||
on_board = game.on_board(future_position)
|
||||
obstacle = get_occupant(game, future_position)
|
||||
if obstacle or not on_board:
|
||||
self.die(game)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def play_turn(self, game):
|
||||
if self.should_move():
|
||||
possible_moves = []
|
||||
for position in self.get_adjacent_positions():
|
||||
if game.is_empty(position) and game.on_board(position):
|
||||
possible_moves.append(position)
|
||||
if possible_moves:
|
||||
if self.should_move_randomly():
|
||||
self.position = choice(possible_moves)
|
||||
else:
|
||||
self.position = self.choose_best_move(possible_moves, game)
|
||||
player = game.get_agent_by_name("player")
|
||||
if player and player.position == self.position:
|
||||
player.die(game)
|
||||
|
||||
def get_adjacent_positions(self):
|
||||
"""Returns all eight adjacent positions, including diagonals."""
|
||||
positions = []
|
||||
for i in [-1, 0, 1]:
|
||||
for j in [-1, 0, 1]:
|
||||
if i or j:
|
||||
positions.append(add(self.position, (i, j)))
|
||||
return positions
|
||||
|
||||
def should_move(self):
|
||||
return random() < self.probability_of_moving
|
||||
|
||||
def should_move_randomly(self):
|
||||
return random() < self.probability_of_random_move
|
||||
|
||||
def choose_best_move(self, possible_moves, game):
|
||||
player = game.get_agent_by_name("player")
|
||||
move_distances = [[distance(player.position, move), move] for move in possible_moves]
|
||||
shortest_distance, best_move = sorted(move_distances)[0]
|
||||
return best_move
|
||||
|
||||
def die(self, game):
|
||||
game.remove_agent(self)
|
||||
game.num_beasts -= 1
|
||||
game.state['beasts_killed'] += 1
|
||||
if game.num_beasts == 0:
|
||||
game.state["message"] = "You win!"
|
||||
game.end()
|
||||
@@ -1,25 +0,0 @@
|
||||
from retro_gamer.examples.beast.helpers import add, get_occupant
|
||||
|
||||
class Block:
|
||||
"""A static block that can be pushed by the player."""
|
||||
character = "█"
|
||||
color = "green4"
|
||||
deadly = False
|
||||
|
||||
def __init__(self, position):
|
||||
self.position = position
|
||||
|
||||
def handle_push(self, vector, game):
|
||||
"""Responds to a push in the direction of vector.
|
||||
Returns True when the push succeeds in creating empty space.
|
||||
"""
|
||||
future_position = add(self.position, vector)
|
||||
on_board = game.on_board(future_position)
|
||||
obstacle = get_occupant(game, future_position)
|
||||
if obstacle:
|
||||
success = obstacle.handle_push(vector, game)
|
||||
else:
|
||||
success = on_board
|
||||
if success:
|
||||
self.position = future_position
|
||||
return success
|
||||
@@ -1,39 +0,0 @@
|
||||
from retro_gamer.examples.beast.helpers import add, get_occupant
|
||||
|
||||
direction_vectors = {
|
||||
"KEY_RIGHT": (1, 0),
|
||||
"KEY_UP": (0, -1),
|
||||
"KEY_LEFT": (-1, 0),
|
||||
"KEY_DOWN": (0, 1),
|
||||
}
|
||||
|
||||
class Player:
|
||||
character = "*"
|
||||
color = "white"
|
||||
name = "player"
|
||||
deadly = False
|
||||
|
||||
def __init__(self, position):
|
||||
self.position = position
|
||||
|
||||
def handle_keystroke(self, keystroke, game):
|
||||
if keystroke.name in direction_vectors:
|
||||
vector = direction_vectors[keystroke.name]
|
||||
self.try_to_move(vector, game)
|
||||
|
||||
def try_to_move(self, vector, game):
|
||||
future_position = add(self.position, vector)
|
||||
on_board = game.on_board(future_position)
|
||||
obstacle = get_occupant(game, future_position)
|
||||
if obstacle:
|
||||
if obstacle.deadly:
|
||||
self.die(game)
|
||||
elif obstacle.handle_push(vector, game):
|
||||
self.position = future_position
|
||||
elif on_board:
|
||||
self.position = future_position
|
||||
|
||||
def die(self, game):
|
||||
self.color = "black_on_red"
|
||||
game.state["message"] = "The beasties win!"
|
||||
game.end()
|
||||
@@ -1,44 +0,0 @@
|
||||
from random import shuffle
|
||||
from retro_gamer.examples.beast.agents.player import Player
|
||||
from retro_gamer.examples.beast.agents.beast import Beast
|
||||
from retro_gamer.examples.beast.agents.block import Block
|
||||
|
||||
class Board:
|
||||
"""Creates the agents needed at the beginning of the game and assigns their positions."""
|
||||
|
||||
def __init__(self, width, height, block_density=0.3, num_beasts=10):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.block_density = block_density
|
||||
self.num_blocks = round(width * height * block_density)
|
||||
self.num_empty_spaces = width * height - self.num_blocks
|
||||
self.num_beasts = num_beasts
|
||||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
if self.block_density < 0 or self.block_density > 1:
|
||||
raise ValueError("block density must be between 0 and 1.")
|
||||
if self.num_empty_spaces < self.num_beasts + 1:
|
||||
raise ValueError("Not enough space on the board.")
|
||||
|
||||
def get_agents(self):
|
||||
"""Returns a list of agents initialized in their starting positions."""
|
||||
positions = self.get_all_positions()
|
||||
shuffle(positions)
|
||||
|
||||
player_position = positions[0]
|
||||
beast_positions = positions[1:self.num_beasts + 1]
|
||||
block_positions = positions[-self.num_blocks:]
|
||||
|
||||
player = [Player(player_position)]
|
||||
beasts = [Beast(pos) for pos in beast_positions]
|
||||
blocks = [Block(pos) for pos in block_positions]
|
||||
return player + beasts + blocks
|
||||
|
||||
def get_all_positions(self):
|
||||
"""Returns a list of all positions on the board."""
|
||||
positions = []
|
||||
for i in range(self.width):
|
||||
for j in range(self.height):
|
||||
positions.append((i, j))
|
||||
return positions
|
||||
@@ -1,18 +0,0 @@
|
||||
def add(vec0, vec1):
|
||||
"""Adds two vectors."""
|
||||
x0, y0 = vec0
|
||||
x1, y1 = vec1
|
||||
return (x0 + x1, y0 + y1)
|
||||
|
||||
def get_occupant(game, position):
|
||||
"""Returns the agent at position, if there is one."""
|
||||
positions_with_agents = game.get_agents_by_position()
|
||||
if position in positions_with_agents:
|
||||
agents_at_position = positions_with_agents[position]
|
||||
return agents_at_position[0]
|
||||
|
||||
def distance(vec0, vec1):
|
||||
"""Returns the Manhattan distance between two positions."""
|
||||
x0, y0 = vec0
|
||||
x1, y1 = vec1
|
||||
return abs(x1 - x0) + abs(y1 - y0)
|
||||
@@ -1,6 +0,0 @@
|
||||
[tool.retro-gamer]
|
||||
actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]
|
||||
reward = "beasts_killed"
|
||||
character_set = ["*", "H", "█"]
|
||||
spatial = true
|
||||
observe_state = []
|
||||
30
retro_gamer/log_parser.py
Normal file
30
retro_gamer/log_parser.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
_LINE_RE = re.compile(
|
||||
r'\[ep_(\d+)\]'
|
||||
r'.*avg_reward=([+-]?\d+\.?\d*)'
|
||||
r'.*avg_steps=(\d+\.?\d*)'
|
||||
r'.*epsilon=(\d+\.?\d*)'
|
||||
r'.*avg_loss=(\d+\.?\d*)'
|
||||
)
|
||||
|
||||
|
||||
def parse_checkpoints(log_path: Path) -> list[dict]:
|
||||
"""Parse checkpoint lines from a training log. Returns a list of dicts
|
||||
with keys: episode, avg_reward, avg_steps, epsilon, avg_loss."""
|
||||
results = []
|
||||
if not log_path.exists():
|
||||
return results
|
||||
for line in log_path.read_text().splitlines():
|
||||
m = _LINE_RE.search(line)
|
||||
if m:
|
||||
results.append({
|
||||
'episode': int(m.group(1)),
|
||||
'avg_reward': float(m.group(2)),
|
||||
'avg_steps': float(m.group(3)),
|
||||
'epsilon': float(m.group(4)),
|
||||
'avg_loss': float(m.group(5)),
|
||||
})
|
||||
return results
|
||||
@@ -11,39 +11,58 @@ class GameMetadata:
|
||||
"""Describes a retro game for training purposes.
|
||||
|
||||
Required fields: actions, reward.
|
||||
Optional fields: character_set, spatial, observe_state.
|
||||
Discovered fields: board_size (read from game.board_size at startup).
|
||||
|
||||
Metadata is read from the game's own pyproject.toml under the
|
||||
[tool.retro-gamer] section via GameMetadata.from_pyproject().
|
||||
Optional fields: character_set, spatial.
|
||||
Discovered fields: board_size (from game.board_size), extras_size (from
|
||||
the observe_state list in [preprocessing]).
|
||||
"""
|
||||
actions: list[str]
|
||||
reward: str
|
||||
character_set: list[str] | None = None
|
||||
spatial: bool = True
|
||||
observe_state: list[str] = field(default_factory=list)
|
||||
spatial: bool = False
|
||||
board: bool = True
|
||||
board_size: tuple[int, int] | None = None
|
||||
extras_size: int = 0
|
||||
|
||||
def validate(self):
|
||||
if not self.actions:
|
||||
raise ValueError("actions must be a non-empty list")
|
||||
raise ValueError(
|
||||
"The 'actions' list in [tool.retro-gamer] is empty or missing.\n"
|
||||
"It should list the keyboard keys your agent can press, for example:\n\n"
|
||||
' actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n\n'
|
||||
"The agent will learn which actions lead to higher rewards."
|
||||
)
|
||||
if not isinstance(self.actions, list) or not all(isinstance(a, str) for a in self.actions):
|
||||
raise ValueError(
|
||||
f"'actions' must be a list of strings, but got: {self.actions!r}\n"
|
||||
"Each entry should be a key name like \"KEY_RIGHT\" or \"KEY_SPACE\"."
|
||||
)
|
||||
if not self.reward:
|
||||
raise ValueError("reward must be a non-empty string")
|
||||
if self.reward in self.observe_state:
|
||||
raise ValueError(f"reward key '{self.reward}' must not appear in observe_state")
|
||||
raise ValueError(
|
||||
"The 'reward' field in [tool.retro-gamer] is empty or missing.\n"
|
||||
"It should name a game state variable whose value the agent is trying\n"
|
||||
"to maximize — for example:\n\n"
|
||||
" reward = \"score\"\n\n"
|
||||
"The trainer watches how this value changes each step and uses those\n"
|
||||
"changes as the reward signal."
|
||||
)
|
||||
if self.character_set is not None:
|
||||
if not isinstance(self.character_set, list):
|
||||
raise ValueError(
|
||||
f"'character_set' must be a list of single characters, but got: {self.character_set!r}\n"
|
||||
"Example: character_set = [\"@\", \"*\", \"#\"]"
|
||||
)
|
||||
for ch in self.character_set:
|
||||
if len(ch) != 1:
|
||||
raise ValueError(f"character_set entries must be single characters, got {ch!r}")
|
||||
if not isinstance(ch, str) or len(ch) != 1:
|
||||
raise ValueError(
|
||||
f"Every entry in character_set must be a single character, but got {ch!r}.\n"
|
||||
"Each character represents one type of cell on the game board.\n"
|
||||
"If you're not sure what characters your game uses, remove character_set\n"
|
||||
"entirely and the trainer will discover them automatically."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pyproject(cls, module_name: str) -> GameMetadata:
|
||||
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml.
|
||||
|
||||
Finds pyproject.toml by walking up from the module's source file.
|
||||
Raises FileNotFoundError if no pyproject.toml is found, or ValueError
|
||||
if the file exists but has no [tool.retro-gamer] section.
|
||||
"""
|
||||
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml."""
|
||||
pyproject_path = _find_pyproject(module_name)
|
||||
if pyproject_path is None:
|
||||
raise FileNotFoundError(
|
||||
@@ -65,13 +84,23 @@ class GameMetadata:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> GameMetadata:
|
||||
missing = [k for k in ('actions', 'reward') if k not in d]
|
||||
if missing:
|
||||
fields = ' and '.join(f"'{k}'" for k in missing)
|
||||
raise ValueError(
|
||||
f"The [tool.retro-gamer] section is missing required {fields}.\n"
|
||||
"A minimal configuration looks like this:\n\n"
|
||||
"[tool.retro-gamer]\n"
|
||||
'actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n'
|
||||
'reward = "score"\n\n'
|
||||
"See the documentation for all available options."
|
||||
)
|
||||
board_size = tuple(d['board_size']) if 'board_size' in d else None
|
||||
return cls(
|
||||
actions=d['actions'],
|
||||
reward=d['reward'],
|
||||
character_set=d.get('character_set'),
|
||||
spatial=d.get('spatial', True),
|
||||
observe_state=d.get('observe_state', []),
|
||||
spatial=d.get('spatial', False),
|
||||
board_size=board_size,
|
||||
)
|
||||
|
||||
@@ -79,8 +108,6 @@ class GameMetadata:
|
||||
d = {
|
||||
'actions': self.actions,
|
||||
'reward': self.reward,
|
||||
'spatial': self.spatial,
|
||||
'observe_state': self.observe_state,
|
||||
}
|
||||
if self.board_size is not None:
|
||||
d['board_size'] = list(self.board_size)
|
||||
@@ -95,9 +122,11 @@ class GameMetadata:
|
||||
@property
|
||||
def obs_size(self) -> int:
|
||||
"""Total size of the flat observation vector."""
|
||||
if not self.board:
|
||||
return self.extras_size
|
||||
C = len(self.character_set) if self.character_set else 0
|
||||
bw, bh = self.board_size
|
||||
return C * bw * bh + len(self.observe_state)
|
||||
return C * bw * bh + self.extras_size
|
||||
|
||||
@property
|
||||
def n_actions(self) -> int:
|
||||
@@ -118,4 +147,6 @@ def _find_pyproject(module_name: str) -> Path | None:
|
||||
candidate = parent / 'pyproject.toml'
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
if parent.name == 'site-packages':
|
||||
break
|
||||
return None
|
||||
|
||||
139
retro_gamer/model_agent.py
Normal file
139
retro_gamer/model_agent.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
import tomllib
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from blessed.keyboard import Keystroke
|
||||
from retro.input import ProgrammaticInput
|
||||
from retro.views.headless import HeadlessView
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.observation import encode_observation
|
||||
|
||||
|
||||
class TrainedPolicy:
|
||||
"""A trained retro-gamer model that can observe a game and choose actions.
|
||||
|
||||
Load from a training run directory, then call ``get_action(game)`` from
|
||||
inside any agent's ``play_turn`` to get the model's recommended key.
|
||||
|
||||
Example::
|
||||
|
||||
from retro_gamer import TrainedPolicy
|
||||
|
||||
_ai = TrainedPolicy("runs/enemy/")
|
||||
|
||||
class EnemyAgent:
|
||||
def play_turn(self, game):
|
||||
key = _ai.get_action(game)
|
||||
if key == 'KEY_RIGHT': self.direction = (1, 0)
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, run_dir: str | Path, checkpoint: str | None = None):
|
||||
from retro_gamer.network import build_network
|
||||
from retro_gamer.trainer import DEFAULTS
|
||||
|
||||
run_dir = Path(run_dir)
|
||||
config_path = run_dir / 'config.toml'
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"No config.toml found in {run_dir}")
|
||||
|
||||
with open(config_path, 'rb') as f:
|
||||
config = tomllib.load(f)
|
||||
|
||||
self._metadata = GameMetadata.from_dict(config['metadata'])
|
||||
|
||||
pre = config.get('preprocessing', {})
|
||||
self._metadata.spatial = pre.get('spatial', False)
|
||||
self._metadata.board = pre.get('board', True)
|
||||
observe_state_sizes = pre.get('observe_state_sizes', {})
|
||||
self._observe_state: list[str] = pre.get('observe_state', [])
|
||||
self._egocentric: bool = pre.get('egocentric', False)
|
||||
self._egocentric_player: str | None = pre.get('egocentric_player')
|
||||
self._egocentric_radius: int | None = pre.get('egocentric_radius')
|
||||
self._board: bool = pre.get('board', True)
|
||||
|
||||
if observe_state_sizes:
|
||||
self._metadata.extras_size = sum(observe_state_sizes.values())
|
||||
else:
|
||||
self._metadata.extras_size = len(self._observe_state)
|
||||
|
||||
hyperparams = {**DEFAULTS, **config.get('model', {}), **config.get('training', {})}
|
||||
self._model, _ = build_network(self._metadata, hyperparams)
|
||||
|
||||
if checkpoint is not None:
|
||||
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
|
||||
ckpt_path = run_dir / 'checkpoints' / ckpt_name
|
||||
if not ckpt_path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
||||
else:
|
||||
ckpt_dir = run_dir / 'checkpoints'
|
||||
candidates = sorted(ckpt_dir.glob('ep_*.pt')) if ckpt_dir.exists() else []
|
||||
if not candidates:
|
||||
raise FileNotFoundError(f"No checkpoints found in {ckpt_dir}")
|
||||
ckpt_path = candidates[-1]
|
||||
|
||||
ckpt = torch.load(ckpt_path, weights_only=True)
|
||||
self._model.load_state_dict(ckpt['model_state_dict'])
|
||||
self._model.eval()
|
||||
|
||||
def get_action(self, game) -> str | None:
|
||||
"""Return the key the model recommends this turn, or None for no-op."""
|
||||
view = HeadlessView()
|
||||
view.on_game_start(game)
|
||||
view.render(game)
|
||||
board_chars = view.board_characters
|
||||
|
||||
player_pos = None
|
||||
if self._egocentric and self._egocentric_player:
|
||||
agent = game.get_agent_by_name(self._egocentric_player)
|
||||
if agent is not None:
|
||||
player_pos = agent.position
|
||||
|
||||
obs = encode_observation(
|
||||
board_chars,
|
||||
dict(game.state),
|
||||
self._metadata,
|
||||
self._observe_state,
|
||||
player_pos=player_pos,
|
||||
egocentric_radius=self._egocentric_radius,
|
||||
board=self._board,
|
||||
)
|
||||
|
||||
device = next(self._model.parameters()).device
|
||||
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
action_idx = int(self._model(state_t).argmax().item())
|
||||
if action_idx >= len(self._metadata.actions):
|
||||
return None
|
||||
return self._metadata.actions[action_idx]
|
||||
|
||||
|
||||
def _keystroke(name: str) -> Keystroke:
|
||||
if name.startswith("KEY_"):
|
||||
return Keystroke(ucs='', code=None, name=name)
|
||||
return Keystroke(ucs=name, code=None, name=None)
|
||||
|
||||
|
||||
class PolicyInput:
|
||||
"""An InputSource that drives the game with a TrainedPolicy instead of the keyboard.
|
||||
|
||||
Pass it as ``input_source`` to ``game.play()`` and everything else works
|
||||
exactly as usual.
|
||||
|
||||
Example::
|
||||
|
||||
from retro_gamer import TrainedPolicy, PolicyInput
|
||||
ai = TrainedPolicy("runs/snake/")
|
||||
game = create_game()
|
||||
game.play(input_source=PolicyInput(ai, game))
|
||||
"""
|
||||
|
||||
def __init__(self, model: TrainedPolicy, game):
|
||||
self._model = model
|
||||
self._game = game
|
||||
self._inp = ProgrammaticInput()
|
||||
|
||||
def collect(self) -> set:
|
||||
key = self._model.get_action(self._game)
|
||||
self._inp.press(key)
|
||||
return self._inp.collect()
|
||||
@@ -13,32 +13,36 @@ def build_network(
|
||||
Returns (model, rationale) where rationale is a multi-line string
|
||||
describing the architecture and the reasoning behind each choice.
|
||||
"""
|
||||
n_layers = hyperparams.get('n_layers', 2)
|
||||
layer_size = hyperparams.get('layer_size', 128)
|
||||
C = len(metadata.character_set)
|
||||
bw, bh = metadata.board_size
|
||||
W, H = bw, bh
|
||||
n_state = len(metadata.observe_state)
|
||||
hidden_sizes = hyperparams.get('hidden_sizes', [512, 256])
|
||||
n_state = metadata.extras_size
|
||||
n_actions = metadata.n_actions
|
||||
|
||||
lines = []
|
||||
lines.append("[INIT] === Network Architecture ===")
|
||||
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
|
||||
lines.append(f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}")
|
||||
|
||||
if metadata.spatial:
|
||||
model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)
|
||||
if metadata.board:
|
||||
C = len(metadata.character_set)
|
||||
bw, bh = metadata.board_size
|
||||
W, H = bw, bh
|
||||
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
|
||||
lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}")
|
||||
if metadata.spatial:
|
||||
model = _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines)
|
||||
else:
|
||||
obs_size = C * W * H + n_state
|
||||
model = _build_flat(obs_size, hidden_sizes, n_actions, lines)
|
||||
else:
|
||||
obs_size = C * W * H + n_state
|
||||
model = _build_flat(obs_size, n_layers, layer_size, n_actions, lines)
|
||||
lines.append(f"[INIT] Board: disabled (board=false, state-only observation)")
|
||||
lines.append(f"[INIT] Observed state features: {n_state} | Actions (incl. no-op): {n_actions}")
|
||||
model = _build_flat(n_state, hidden_sizes, n_actions, lines)
|
||||
|
||||
lines.append(f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}")
|
||||
lines.append(f"[INIT] Hidden layers: {len(hidden_sizes)} | Layer sizes: {hidden_sizes}")
|
||||
lines.append(f"[INIT] Output: {n_actions} Q-values")
|
||||
lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)")
|
||||
return model, '\n'.join(lines)
|
||||
|
||||
|
||||
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
|
||||
def _build_spatial(C, H, W, n_state, hidden_sizes, n_actions, lines):
|
||||
lines.append("[INIT] spatial=True → using CNN architecture")
|
||||
lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
|
||||
lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
|
||||
@@ -47,20 +51,20 @@ def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
|
||||
lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
|
||||
mlp_in = conv_out + n_state
|
||||
lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(mlp_in)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
|
||||
return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(mlp_in)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}")
|
||||
return _SpatialNet(C, H, W, n_state, hidden_sizes, n_actions)
|
||||
|
||||
|
||||
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
|
||||
def _build_flat(obs_size, hidden_sizes, n_actions, lines):
|
||||
lines.append("[INIT] spatial=False → using MLP architecture")
|
||||
lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
|
||||
lines.append("[INIT] a flat MLP over the full observation is sufficient.")
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(obs_size)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
|
||||
return _FlatNet(obs_size, n_layers, layer_size, n_actions)
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(obs_size)] + [str(s) for s in hidden_sizes] + [str(n_actions)])}")
|
||||
return _FlatNet(obs_size, hidden_sizes, n_actions)
|
||||
|
||||
|
||||
class _SpatialNet(nn.Module):
|
||||
def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
|
||||
def __init__(self, C, H, W, n_state, hidden_sizes, n_actions):
|
||||
super().__init__()
|
||||
self.C, self.H, self.W = C, H, W
|
||||
self.n_board = C * H * W
|
||||
@@ -73,10 +77,11 @@ class _SpatialNet(nn.Module):
|
||||
conv_out = 64 * H * W
|
||||
mlp_in = conv_out + n_state
|
||||
layers: list[nn.Module] = []
|
||||
for i in range(n_layers):
|
||||
in_size = mlp_in if i == 0 else layer_size
|
||||
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
|
||||
layers.append(nn.Linear(layer_size, n_actions))
|
||||
prev = mlp_in
|
||||
for size in hidden_sizes:
|
||||
layers += [nn.Linear(prev, size), nn.ReLU()]
|
||||
prev = size
|
||||
layers.append(nn.Linear(prev, n_actions))
|
||||
self.mlp = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -87,13 +92,14 @@ class _SpatialNet(nn.Module):
|
||||
|
||||
|
||||
class _FlatNet(nn.Module):
|
||||
def __init__(self, obs_size, n_layers, layer_size, n_actions):
|
||||
def __init__(self, obs_size, hidden_sizes, n_actions):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
for i in range(n_layers):
|
||||
in_size = obs_size if i == 0 else layer_size
|
||||
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
|
||||
layers.append(nn.Linear(layer_size, n_actions))
|
||||
prev = obs_size
|
||||
for size in hidden_sizes:
|
||||
layers += [nn.Linear(prev, size), nn.ReLU()]
|
||||
prev = size
|
||||
layers.append(nn.Linear(prev, n_actions))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -3,11 +3,7 @@ from retro_gamer.metadata import GameMetadata
|
||||
|
||||
|
||||
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
|
||||
"""One-hot encode the board.
|
||||
|
||||
Returns an array of shape (H, W, C) where C = len(character_set).
|
||||
Unknown characters produce a zero vector.
|
||||
"""
|
||||
"""One-hot encode the board. Returns (H, W, C). Unknown characters → zero vector."""
|
||||
char_to_idx = {c: i for i, c in enumerate(character_set)}
|
||||
H = len(board_chars)
|
||||
W = len(board_chars[0]) if board_chars else 0
|
||||
@@ -22,28 +18,78 @@ def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.n
|
||||
|
||||
|
||||
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
|
||||
"""Extract observed state keys into a 1D float array."""
|
||||
return np.array([float(state.get(k, 0)) for k in observe_state], dtype=np.float32)
|
||||
"""Extract selected keys from game.state into a 1D float array.
|
||||
|
||||
Scalar values contribute one element; list/tuple values are flattened.
|
||||
"""
|
||||
values: list[float] = []
|
||||
for k in observe_state:
|
||||
val = state.get(k, 0)
|
||||
if isinstance(val, (list, tuple)):
|
||||
values.extend(float(x) for x in val)
|
||||
else:
|
||||
values.append(float(val))
|
||||
return np.array(values, dtype=np.float32)
|
||||
|
||||
|
||||
def egocentric_board(
|
||||
board_chars: list[list[str]],
|
||||
player_pos: tuple[int, int],
|
||||
radius: int,
|
||||
) -> list[list[str]]:
|
||||
"""Crop the board to a (2r+1)×(2r+1) window centred on player_pos.
|
||||
|
||||
Out-of-bounds cells are filled with a space (treated as empty by the
|
||||
encoder). The resulting grid is always square with side 2*radius+1.
|
||||
"""
|
||||
H = len(board_chars)
|
||||
W = len(board_chars[0]) if board_chars else 0
|
||||
px, py = player_pos
|
||||
result = []
|
||||
for dy in range(-radius, radius + 1):
|
||||
row = []
|
||||
for dx in range(-radius, radius + 1):
|
||||
src_x = px + dx
|
||||
src_y = py + dy
|
||||
if 0 <= src_x < W and 0 <= src_y < H:
|
||||
row.append(board_chars[src_y][src_x])
|
||||
else:
|
||||
row.append(' ')
|
||||
result.append(row)
|
||||
return result
|
||||
|
||||
|
||||
def encode_observation(
|
||||
board_chars: list[list[str]],
|
||||
state: dict,
|
||||
metadata: GameMetadata,
|
||||
observe_state: list[str],
|
||||
player_pos: tuple[int, int] | None = None,
|
||||
egocentric_radius: int | None = None,
|
||||
board: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""Encode board + state into a flat 1D observation vector.
|
||||
"""Encode board and/or selected state values into a flat 1D observation vector.
|
||||
|
||||
For spatial games the board is encoded channel-first (C, H, W) then flattened,
|
||||
so the network can reshape it back for CNN processing. For non-spatial games the
|
||||
board is encoded (H, W, C) then flattened.
|
||||
The state vector is appended at the end in both cases.
|
||||
When *board* is True the board is encoded and prepended to the vector. If
|
||||
player_pos and egocentric_radius are given the board is first cropped to a
|
||||
(2r+1)×(2r+1) window centred on the player. For spatial games the board is
|
||||
encoded channel-first (C, H, W) then flattened; for non-spatial games it is
|
||||
encoded (H, W, C) then flattened. The state vector is appended at the end.
|
||||
|
||||
When *board* is False only the observe_state features are returned.
|
||||
"""
|
||||
if not metadata.character_set:
|
||||
raise ValueError("character_set must be set before encoding observations")
|
||||
board = encode_board(board_chars, metadata.character_set) # (H, W, C)
|
||||
if metadata.spatial:
|
||||
board_vec = board.transpose(2, 0, 1).flatten() # C*H*W, channel-first
|
||||
if board:
|
||||
if not metadata.character_set:
|
||||
raise ValueError("character_set must be set before encoding observations")
|
||||
if player_pos is not None and egocentric_radius is not None:
|
||||
board_chars = egocentric_board(board_chars, player_pos, egocentric_radius)
|
||||
board_enc = encode_board(board_chars, metadata.character_set) # (H, W, C)
|
||||
if metadata.spatial:
|
||||
board_vec = board_enc.transpose(2, 0, 1).flatten()
|
||||
else:
|
||||
board_vec = board_enc.flatten()
|
||||
if observe_state:
|
||||
return np.concatenate([board_vec, encode_state(state, observe_state)])
|
||||
return board_vec
|
||||
else:
|
||||
board_vec = board.flatten() # H*W*C
|
||||
state_vec = encode_state(state, metadata.observe_state)
|
||||
return np.concatenate([board_vec, state_vec])
|
||||
return encode_state(state, observe_state)
|
||||
|
||||
55
retro_gamer/plotter.py
Normal file
55
retro_gamer/plotter.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from retro_gamer.log_parser import parse_checkpoints
|
||||
|
||||
|
||||
def plot_run(log_path: Path, output: Path | None = None) -> None:
|
||||
"""Generate training metric plots from a training.log file.
|
||||
|
||||
Displays an interactive window unless *output* is given, in which case
|
||||
the figure is saved to that path (PNG, PDF, SVG, etc.).
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
data = parse_checkpoints(log_path)
|
||||
if not data:
|
||||
raise ValueError(f"No checkpoint data found in {log_path}")
|
||||
|
||||
episodes = [d['episode'] for d in data]
|
||||
rewards = [d['avg_reward'] for d in data]
|
||||
steps = [d['avg_steps'] for d in data]
|
||||
losses = [d['avg_loss'] for d in data]
|
||||
epsilons = [d['epsilon'] for d in data]
|
||||
|
||||
sns.set_theme(style='darkgrid')
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 7))
|
||||
(ax_reward, ax_steps), (ax_loss, ax_epsilon) = axes
|
||||
|
||||
ax_reward.plot(episodes, rewards)
|
||||
ax_reward.axhline(0, color='gray', linestyle='--', linewidth=0.8, alpha=0.6)
|
||||
ax_reward.set_title('Average Reward')
|
||||
ax_reward.set_xlabel('Episode')
|
||||
|
||||
ax_steps.plot(episodes, steps, color='C1')
|
||||
ax_steps.set_title('Average Steps')
|
||||
ax_steps.set_xlabel('Episode')
|
||||
|
||||
ax_loss.plot(episodes, losses, color='C2')
|
||||
ax_loss.set_yscale('log')
|
||||
ax_loss.set_title('Average Loss')
|
||||
ax_loss.set_xlabel('Episode')
|
||||
|
||||
ax_epsilon.plot(episodes, epsilons, color='C3')
|
||||
ax_epsilon.set_title('Epsilon (exploration rate)')
|
||||
ax_epsilon.set_xlabel('Episode')
|
||||
ax_epsilon.set_ylim(0, 1)
|
||||
|
||||
fig.suptitle(f'Training: {log_path.parent.name}', fontsize=13)
|
||||
plt.tight_layout()
|
||||
|
||||
if output:
|
||||
plt.savefig(output, dpi=150, bbox_inches='tight')
|
||||
print(f"Plot saved to {output}")
|
||||
else:
|
||||
plt.show()
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import random
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Callable
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -8,40 +10,250 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import tomli_w
|
||||
|
||||
from tqdm import tqdm
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.env import GameEnvironment
|
||||
from retro_gamer.network import build_network
|
||||
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
|
||||
|
||||
MODEL_KEYS: frozenset = frozenset({'hidden_sizes'})
|
||||
|
||||
DEFAULTS: dict = {
|
||||
'learning_rate': 1e-3,
|
||||
'lr_decay': 0.995,
|
||||
# [model]
|
||||
'hidden_sizes': [128, 64],
|
||||
# [training]
|
||||
'learning_rate': 1e-4,
|
||||
'learning_rate_decay': 0.9999,
|
||||
'gamma': 0.99,
|
||||
'epsilon': 1.0,
|
||||
'epsilon_decay': 0.995,
|
||||
'epsilon_decay': 0.9997,
|
||||
'epsilon_min': 0.05,
|
||||
'batch_size': 64,
|
||||
'memory_capacity': 10_000,
|
||||
'target_update_freq': 100,
|
||||
'training_episodes': 1_000,
|
||||
'n_layers': 2,
|
||||
'layer_size': 128,
|
||||
'prioritize_experiences': False,
|
||||
'memory_capacity': 50_000,
|
||||
'target_update_freq': 500,
|
||||
'train_every': 4,
|
||||
'training_episodes': 20_000,
|
||||
'prioritize_experiences': True,
|
||||
'exploration_turns': 200,
|
||||
'unknown_character_strategy': 'ignore',
|
||||
'max_turns_per_episode': 2_000,
|
||||
}
|
||||
|
||||
|
||||
def _get_device() -> torch.device:
|
||||
if torch.backends.mps.is_available():
|
||||
return torch.device('mps')
|
||||
if torch.cuda.is_available():
|
||||
return torch.device('cuda')
|
||||
return torch.device('cpu')
|
||||
|
||||
# Fields that make an existing checkpoint incompatible with the current config.
|
||||
# Changing any of these requires starting training from scratch.
|
||||
_INCOMPATIBLE_METADATA = {
|
||||
'actions': 'the list of actions the agent can take (changes output layer size)',
|
||||
'reward': 'the reward signal — Q-values trained on the old signal are meaningless for the new one',
|
||||
'character_set': 'the set of board characters (changes input layer size)',
|
||||
'board_size': 'the board dimensions (changes input layer size)',
|
||||
}
|
||||
_INCOMPATIBLE_PREPROCESSING = {
|
||||
'spatial': 'spatial vs non-spatial network type (changes network architecture)',
|
||||
'board': 'whether the board is included in the observation (changes input size)',
|
||||
'observe_state': 'the state keys included in the observation (changes input size)',
|
||||
'observe_state_sizes': 'the size of each observed state key (changes input layer size)',
|
||||
'egocentric': 'egocentric board transformation (changes input representation)',
|
||||
'egocentric_player': 'the agent used as the egocentric center (changes input representation)',
|
||||
'egocentric_radius': 'the egocentric crop radius (changes input layer size)',
|
||||
}
|
||||
_INCOMPATIBLE_ARCH = {
|
||||
'hidden_sizes': 'the hidden layer sizes (changes network shape)',
|
||||
}
|
||||
|
||||
|
||||
def validate_hyperparams(hp: dict):
|
||||
"""Check all hyperparameters and raise ValueError listing every problem found."""
|
||||
problems = []
|
||||
|
||||
def _problem(heading: str, explanation: str, fix: str | None = None):
|
||||
text = f" {heading}\n {explanation}"
|
||||
if fix:
|
||||
text += f"\n → {fix}"
|
||||
problems.append(text)
|
||||
|
||||
hs = hp.get('hidden_sizes')
|
||||
if not isinstance(hs, list) or len(hs) == 0:
|
||||
_problem(
|
||||
f"hidden_sizes = {hs!r}",
|
||||
"This sets the shape of the neural network's hidden layers. It must be a\n"
|
||||
" non-empty list of positive integers, one number per layer.",
|
||||
"Try: hidden_sizes = [512, 256]",
|
||||
)
|
||||
elif any(not isinstance(s, int) or s <= 0 for s in hs):
|
||||
bad = [s for s in hs if not isinstance(s, int) or s <= 0]
|
||||
_problem(
|
||||
f"hidden_sizes = {hs!r}",
|
||||
f"Every layer size must be a positive integer, but got {bad!r}.\n"
|
||||
" Each number is the count of neurons in that layer — more neurons means\n"
|
||||
" more capacity to learn complex patterns.",
|
||||
"Try: hidden_sizes = [512, 256]",
|
||||
)
|
||||
|
||||
lr = hp.get('learning_rate')
|
||||
if not isinstance(lr, (int, float)) or lr <= 0:
|
||||
_problem(
|
||||
f"learning_rate = {lr!r}",
|
||||
"The learning rate controls how much the network adjusts its weights after\n"
|
||||
" each training step. Too high and training becomes unstable; too low and\n"
|
||||
" it learns very slowly. It must be a positive number.",
|
||||
"Typical values are between 0.0001 and 0.01. Try: learning_rate = 0.001",
|
||||
)
|
||||
|
||||
for key, blurb in [
|
||||
('learning_rate_decay',
|
||||
"After each episode, the learning rate is multiplied by this value, gradually\n"
|
||||
" slowing learning over time. A value of 1.0 means no decay; closer to 0\n"
|
||||
" means very aggressive decay. It must be greater than 0 and at most 1."),
|
||||
('gamma',
|
||||
"Gamma is the discount factor: how much the agent values future rewards versus\n"
|
||||
" immediate ones. 0.99 means future rewards are nearly as important as now;\n"
|
||||
" 0.0 means the agent only cares about the very next step.\n"
|
||||
" It must be greater than 0 and at most 1."),
|
||||
('epsilon_decay',
|
||||
"Each episode, the exploration rate (epsilon) is multiplied by this to gradually\n"
|
||||
" reduce random actions over time. It must be between 0 (exclusive) and 1."),
|
||||
]:
|
||||
v = hp.get(key)
|
||||
if not isinstance(v, (int, float)) or not (0 < v <= 1):
|
||||
_problem(
|
||||
f"{key} = {v!r}",
|
||||
blurb,
|
||||
f"Try a value close to but less than 1, like {key} = 0.995",
|
||||
)
|
||||
|
||||
for key, blurb in [
|
||||
('epsilon',
|
||||
"Epsilon is the probability of taking a random action (exploration vs.\n"
|
||||
" exploitation). It starts high — usually 1.0, meaning fully random —\n"
|
||||
" and decays toward epsilon_min during training. It must be between 0 and 1."),
|
||||
('epsilon_min',
|
||||
"This is the lowest exploration rate allowed. Even after lots of training, the\n"
|
||||
" agent keeps at least this much randomness so it keeps discovering new things.\n"
|
||||
" It must be between 0 and 1."),
|
||||
]:
|
||||
v = hp.get(key)
|
||||
if not isinstance(v, (int, float)) or not (0 <= v <= 1):
|
||||
_problem(
|
||||
f"{key} = {v!r}",
|
||||
blurb,
|
||||
f"Try: {key} = {'0.05' if 'min' in key else '1.0'}",
|
||||
)
|
||||
|
||||
eps = hp.get('epsilon')
|
||||
eps_min = hp.get('epsilon_min')
|
||||
if (isinstance(eps, (int, float)) and isinstance(eps_min, (int, float))
|
||||
and 0 <= eps <= 1 and 0 <= eps_min <= 1 and eps_min > eps):
|
||||
_problem(
|
||||
f"epsilon_min = {eps_min!r} is greater than epsilon = {eps!r}",
|
||||
"epsilon is the starting exploration rate and epsilon_min is the floor it\n"
|
||||
" decays toward, so epsilon_min must be less than or equal to epsilon.",
|
||||
f"Try: epsilon = 1.0 and epsilon_min = 0.05",
|
||||
)
|
||||
|
||||
for key, blurb in [
|
||||
('batch_size',
|
||||
"Training samples this many past experiences from the replay buffer at once\n"
|
||||
" to compute a learning update. Must be a positive integer."),
|
||||
('memory_capacity',
|
||||
"The replay buffer stores this many past experiences. When it fills up, the\n"
|
||||
" oldest are discarded. A larger buffer means more diverse training data.\n"
|
||||
" Must be a positive integer."),
|
||||
('target_update_freq',
|
||||
"The target network (a stable copy of the Q-network used to compute targets)\n"
|
||||
" is updated every this many steps. Must be a positive integer."),
|
||||
('train_every',
|
||||
"A training step runs once every this many game steps. This lets the agent\n"
|
||||
" collect several new experiences before updating. Must be a positive integer."),
|
||||
('training_episodes',
|
||||
"The total number of episodes (games) to train for. Must be a positive integer."),
|
||||
('max_turns_per_episode',
|
||||
"If a game episode hasn't ended naturally after this many steps, it's cut\n"
|
||||
" short. This prevents a buggy or stuck agent from running forever.\n"
|
||||
" Must be a positive integer."),
|
||||
]:
|
||||
v = hp.get(key)
|
||||
if not isinstance(v, int) or v <= 0:
|
||||
_problem(
|
||||
f"{key} = {v!r}",
|
||||
blurb,
|
||||
f"Try: {key} = {DEFAULTS[key]}",
|
||||
)
|
||||
|
||||
v = hp.get('exploration_turns')
|
||||
if not isinstance(v, int) or v < 0:
|
||||
_problem(
|
||||
f"exploration_turns = {v!r}",
|
||||
"When no character_set is specified, the trainer runs this many random turns\n"
|
||||
" to discover what characters appear on the board. Must be 0 or more\n"
|
||||
" (0 skips discovery entirely, which only works if character_set is set).",
|
||||
f"Try: exploration_turns = {DEFAULTS['exploration_turns']}",
|
||||
)
|
||||
|
||||
bs = hp.get('batch_size')
|
||||
mc = hp.get('memory_capacity')
|
||||
if (isinstance(bs, int) and bs > 0 and isinstance(mc, int) and mc > 0 and bs > mc):
|
||||
_problem(
|
||||
f"batch_size = {bs} is larger than memory_capacity = {mc}",
|
||||
"Training samples a batch of past experiences from the replay buffer each\n"
|
||||
" step, so the buffer must be able to hold at least as many experiences\n"
|
||||
f" as the batch size. With batch_size = {bs}, the buffer holds {mc} —\n"
|
||||
" that's not enough to sample from.",
|
||||
f"Try: memory_capacity = {max(bs * 100, DEFAULTS['memory_capacity'])} "
|
||||
f"(a much larger buffer also improves learning quality)",
|
||||
)
|
||||
|
||||
if problems:
|
||||
n = len(problems)
|
||||
noun = "problem" if n == 1 else "problems"
|
||||
header = f"Found {n} {noun} in your training configuration:\n\n"
|
||||
footer = "\n\nFix these in config.toml, then run 'retro-gamer train' again."
|
||||
raise ValueError(header + "\n\n".join(problems) + footer)
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
m, s = divmod(int(seconds), 60)
|
||||
h, m = divmod(m, 60)
|
||||
if h:
|
||||
return f"{h}h{m:02d}m{s:02d}s"
|
||||
return f"{m}m{s:02d}s"
|
||||
|
||||
|
||||
class DQNTrainer:
|
||||
"""Trains a deep Q-network agent to play a retro game.
|
||||
|
||||
On initialization the trainer:
|
||||
1. Discovers the character set (if not already specified in metadata).
|
||||
2. Builds the Q-network and logs the full architecture with rationale.
|
||||
3. Saves config.toml and starts training.log in run_dir.
|
||||
Automatically selects the best available training device: Apple Silicon
|
||||
GPU (MPS), NVIDIA GPU (CUDA), or CPU. The chosen device is recorded in
|
||||
``training.log``.
|
||||
|
||||
Call train() to run all episodes and save checkpoints.
|
||||
On initialization, the trainer:
|
||||
|
||||
1. Discovers the character set if not already specified in *metadata*.
|
||||
2. Builds the Q-network and logs its full architecture with rationale.
|
||||
3. Writes ``config.toml`` and initializes ``training.log`` in *run_dir*.
|
||||
|
||||
Hyperparameters can be passed as keyword arguments; see the
|
||||
:ref:`hyperparameters` reference for all options. Values not supplied
|
||||
fall back to sensible defaults.
|
||||
|
||||
Call :meth:`train` to run all episodes. Checkpoints are saved every 100
|
||||
episodes and training can be stopped (Ctrl+C) and resumed at any time.
|
||||
|
||||
Example::
|
||||
|
||||
from retro_gamer import GameMetadata, DQNTrainer
|
||||
from retro.examples.snake import create_game
|
||||
|
||||
metadata = GameMetadata.from_pyproject("retro.examples.snake")
|
||||
trainer = DQNTrainer(create_game, metadata, "runs/snake/")
|
||||
trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -49,26 +261,80 @@ class DQNTrainer:
|
||||
game_factory: Callable,
|
||||
metadata: GameMetadata,
|
||||
run_dir: str | Path,
|
||||
preprocessing: dict | None = None,
|
||||
**hyperparams,
|
||||
):
|
||||
self.game_factory = game_factory
|
||||
self.metadata = metadata
|
||||
self.run_dir = Path(run_dir)
|
||||
self.hp: dict = {**DEFAULTS, **hyperparams}
|
||||
validate_hyperparams(self.hp)
|
||||
self.run_dir.mkdir(parents=True, exist_ok=True)
|
||||
(self.run_dir / 'checkpoints').mkdir(exist_ok=True)
|
||||
|
||||
self.env = GameEnvironment(game_factory, metadata)
|
||||
pre = preprocessing or {}
|
||||
self.observe_state: list[str] = pre.get('observe_state', [])
|
||||
self.egocentric: bool = pre.get('egocentric', False)
|
||||
self.egocentric_player: str | None = pre.get('egocentric_player', None)
|
||||
self.egocentric_radius: int | None = pre.get('egocentric_radius', None)
|
||||
self.board: bool = pre.get('board', True)
|
||||
self.observe_state_sizes: dict[str, int] = pre.get('observe_state_sizes', {})
|
||||
|
||||
if self.board is False and metadata.spatial:
|
||||
raise ValueError(
|
||||
"preprocessing.board = false is incompatible with spatial = true.\n"
|
||||
"A CNN requires a 2-D board to operate on. Either set spatial = false\n"
|
||||
"or keep board = true."
|
||||
)
|
||||
if self.board is False and not self.observe_state:
|
||||
raise ValueError(
|
||||
"preprocessing.board = false requires at least one entry in observe_state.\n"
|
||||
"With board=false, the agent observes only the game state variables listed\n"
|
||||
"in observe_state — if that list is empty, there is nothing to observe."
|
||||
)
|
||||
if self.egocentric and not self.egocentric_radius:
|
||||
raise ValueError(
|
||||
"preprocessing.egocentric = true requires egocentric_radius.\n"
|
||||
"Choose a value based on how far the agent needs to see, e.g.:\n"
|
||||
" egocentric_radius = 5 # 11×11 tight local view\n"
|
||||
" egocentric_radius = 8 # 17×17 wider view"
|
||||
)
|
||||
|
||||
metadata.board = self.board
|
||||
|
||||
if metadata.board_size is None:
|
||||
g = game_factory()
|
||||
metadata.board_size = g.board_size
|
||||
|
||||
if metadata.character_set is None:
|
||||
if self.egocentric_radius:
|
||||
side = 2 * self.egocentric_radius + 1
|
||||
metadata.board_size = (side, side)
|
||||
|
||||
self.env = GameEnvironment(
|
||||
game_factory, metadata,
|
||||
observe_state=self.observe_state,
|
||||
egocentric=self.egocentric,
|
||||
egocentric_player=self.egocentric_player,
|
||||
egocentric_radius=self.egocentric_radius,
|
||||
board=self.board,
|
||||
observe_state_sizes=self.observe_state_sizes,
|
||||
)
|
||||
|
||||
if metadata.character_set is None and self.board:
|
||||
self._discover_character_set()
|
||||
|
||||
if self.observe_state and not self.observe_state_sizes:
|
||||
self._discover_observe_state_sizes()
|
||||
self.env.observe_state_sizes = self.observe_state_sizes
|
||||
|
||||
metadata.extras_size = sum(self.observe_state_sizes.values()) if self.observe_state_sizes else 0
|
||||
|
||||
self.device = _get_device()
|
||||
|
||||
self.model, rationale = build_network(metadata, self.hp)
|
||||
self.target_model, _ = build_network(metadata, self.hp)
|
||||
self.model.to(self.device)
|
||||
self.target_model.to(self.device)
|
||||
self.target_model.load_state_dict(self.model.state_dict())
|
||||
self.target_model.eval()
|
||||
|
||||
@@ -76,7 +342,7 @@ class DQNTrainer:
|
||||
self.model.parameters(), lr=self.hp['learning_rate']
|
||||
)
|
||||
self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
|
||||
self.optimizer, gamma=self.hp['lr_decay']
|
||||
self.optimizer, gamma=self.hp['learning_rate_decay']
|
||||
)
|
||||
|
||||
if self.hp['prioritize_experiences']:
|
||||
@@ -86,6 +352,9 @@ class DQNTrainer:
|
||||
|
||||
self.epsilon: float = self.hp['epsilon']
|
||||
self.total_steps: int = 0
|
||||
self.total_training_seconds: float = 0.0
|
||||
self.start_episode: int = 1
|
||||
self._resumed_from: str | None = None
|
||||
|
||||
self._save_config()
|
||||
self._open_log(rationale)
|
||||
@@ -94,49 +363,127 @@ class DQNTrainer:
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def train(self):
|
||||
"""Run all training episodes and save checkpoints."""
|
||||
for episode in range(1, self.hp['training_episodes'] + 1):
|
||||
total_reward, steps, avg_loss = self._run_episode()
|
||||
def train(self, on_checkpoint=None, on_episode=None):
|
||||
"""Run all training episodes and save checkpoints.
|
||||
|
||||
*on_checkpoint*, if provided, is called after each checkpoint with a
|
||||
dict containing ``episode``, ``avg_reward``, ``avg_steps``,
|
||||
``avg_loss``, and ``epsilon``. *on_episode*, if provided, is called
|
||||
after every episode. When either callback is supplied, the built-in
|
||||
tqdm progress bar is suppressed (the caller is expected to show its
|
||||
own progress UI).
|
||||
"""
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
if self._resumed_from:
|
||||
self._log_raw(f'\n=== Resumed from {self._resumed_from} | {timestamp} ===')
|
||||
else:
|
||||
self._log_raw(f'\n=== Training started | {timestamp} ===')
|
||||
|
||||
use_tqdm = on_checkpoint is None and on_episode is None
|
||||
if use_tqdm:
|
||||
print("Press Control+C to stop training early. Progress will be saved at the latest checkpoint.")
|
||||
|
||||
session_start = perf_counter()
|
||||
ckpt_start = perf_counter()
|
||||
episode_rewards: list[float] = []
|
||||
episode_losses: list[float] = []
|
||||
episode_steps: list[int] = []
|
||||
|
||||
episodes = range(self.start_episode, self.hp['training_episodes'] + 1)
|
||||
bar = tqdm(episodes, unit='ep') if use_tqdm else episodes
|
||||
for episode in bar:
|
||||
total_reward, steps, avg_loss, trained = self._run_episode()
|
||||
episode_rewards.append(total_reward)
|
||||
if avg_loss > 0:
|
||||
episode_losses.append(avg_loss)
|
||||
episode_steps.append(steps)
|
||||
|
||||
self.epsilon = max(
|
||||
self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
|
||||
)
|
||||
self.lr_scheduler.step()
|
||||
self._log_episode(episode, total_reward, steps, avg_loss)
|
||||
if episode % 100 == 0:
|
||||
self._save_checkpoint(f'ep_{episode:04d}.pt')
|
||||
self._save_checkpoint('final.pt')
|
||||
if trained:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if on_episode:
|
||||
on_episode()
|
||||
|
||||
if use_tqdm:
|
||||
bar.set_postfix(
|
||||
reward=f'{total_reward:.1f}',
|
||||
eps=f'{self.epsilon:.3f}',
|
||||
loss=f'{avg_loss:.4f}',
|
||||
)
|
||||
|
||||
is_checkpoint = (episode % 100 == 0)
|
||||
is_last = (episode == self.hp['training_episodes'])
|
||||
if is_checkpoint or (is_last and episode_rewards):
|
||||
now = perf_counter()
|
||||
ckpt_elapsed = now - ckpt_start
|
||||
self.total_training_seconds += ckpt_elapsed
|
||||
ckpt_start = now
|
||||
|
||||
self._save_checkpoint(f'ep_{episode:04d}.pt', episode)
|
||||
stats = self._log_checkpoint(episode, episode_rewards, episode_losses, episode_steps, ckpt_elapsed)
|
||||
episode_rewards = []
|
||||
episode_losses = []
|
||||
episode_steps = []
|
||||
|
||||
if on_checkpoint:
|
||||
on_checkpoint(stats)
|
||||
|
||||
def load_checkpoint(self, path: str | Path):
|
||||
ckpt = torch.load(path, weights_only=True)
|
||||
"""Load a checkpoint to resume training.
|
||||
|
||||
Checkpoints are PyTorch state dicts stored under
|
||||
``run_dir/checkpoints/``. Each contains model weights, optimizer
|
||||
state, current epsilon, and total step count.
|
||||
|
||||
Raises :exc:`ValueError` if the checkpoint was trained with a
|
||||
different character set, board size, action space, or network
|
||||
architecture. The error message names each changed field and explains
|
||||
why it is incompatible.
|
||||
|
||||
The CLI invokes this automatically; call directly only when driving
|
||||
training from Python.
|
||||
"""
|
||||
ckpt = torch.load(path, weights_only=True, map_location='cpu')
|
||||
self._check_compatibility(ckpt, path)
|
||||
self.model.load_state_dict(ckpt['model_state_dict'])
|
||||
self.target_model.load_state_dict(ckpt['model_state_dict'])
|
||||
self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
|
||||
for state in self.optimizer.state.values():
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(self.device)
|
||||
self.epsilon = ckpt['epsilon']
|
||||
self.total_steps = ckpt['total_steps']
|
||||
self.total_training_seconds = ckpt.get('total_training_seconds', 0.0)
|
||||
self.start_episode = ckpt.get('episode', 0) + 1
|
||||
self._resumed_from = Path(path).name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Training loop internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_episode(self) -> tuple[float, int, float]:
|
||||
def _run_episode(self) -> tuple[float, int, float, bool]:
|
||||
state = self.env.reset()
|
||||
total_reward = 0.0
|
||||
total_loss = 0.0
|
||||
loss_count = 0
|
||||
|
||||
for step in range(self.hp['max_turns_per_episode']):
|
||||
state_t = torch.as_tensor(state, dtype=torch.float32)
|
||||
state_t = torch.as_tensor(state, dtype=torch.float32).to(self.device)
|
||||
action_idx = self._select_action(state_t)
|
||||
action_key = self._idx_to_key(action_idx)
|
||||
|
||||
next_state, reward, done = self.env.step(action_key)
|
||||
self.memory.push(state, action_idx, reward, next_state, done)
|
||||
|
||||
loss = self._train_step()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
loss_count += 1
|
||||
if self.total_steps % self.hp['train_every'] == 0:
|
||||
loss = self._train_step()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
loss_count += 1
|
||||
|
||||
self.total_steps += 1
|
||||
if self.total_steps % self.hp['target_update_freq'] == 0:
|
||||
@@ -148,7 +495,7 @@ class DQNTrainer:
|
||||
break
|
||||
|
||||
avg_loss = total_loss / loss_count if loss_count else 0.0
|
||||
return total_reward, step + 1, avg_loss
|
||||
return total_reward, step + 1, avg_loss, loss_count > 0
|
||||
|
||||
def _select_action(self, state_t: torch.Tensor) -> int:
|
||||
if random.random() < self.epsilon:
|
||||
@@ -168,7 +515,7 @@ class DQNTrainer:
|
||||
if self.hp['prioritize_experiences']:
|
||||
assert isinstance(self.memory, PrioritizedReplayMemory)
|
||||
experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
|
||||
weight_t = torch.as_tensor(weights, dtype=torch.float32)
|
||||
weight_t = torch.as_tensor(weights, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
experiences = self.memory.sample(self.hp['batch_size'])
|
||||
indices = None
|
||||
@@ -176,33 +523,107 @@ class DQNTrainer:
|
||||
|
||||
states = torch.as_tensor(
|
||||
np.array([e.state for e in experiences]), dtype=torch.float32
|
||||
)
|
||||
actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long)
|
||||
rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32)
|
||||
).to(self.device)
|
||||
actions = torch.as_tensor(
|
||||
[e.action for e in experiences], dtype=torch.long
|
||||
).to(self.device)
|
||||
rewards = torch.as_tensor(
|
||||
[e.reward for e in experiences], dtype=torch.float32
|
||||
).to(self.device)
|
||||
next_states = torch.as_tensor(
|
||||
np.array([e.next_state for e in experiences]), dtype=torch.float32
|
||||
)
|
||||
dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32)
|
||||
).to(self.device)
|
||||
dones = torch.as_tensor(
|
||||
[e.done for e in experiences], dtype=torch.float32
|
||||
).to(self.device)
|
||||
|
||||
q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
with torch.no_grad():
|
||||
next_q = self.target_model(next_states).max(1).values
|
||||
targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)
|
||||
|
||||
element_loss = nn.functional.mse_loss(q_values, targets, reduction='none')
|
||||
element_loss = nn.functional.huber_loss(q_values, targets, reduction='none', delta=1.0)
|
||||
|
||||
if weight_t is not None:
|
||||
loss = (weight_t * element_loss).mean()
|
||||
td_errors = (q_values - targets).detach().abs().numpy()
|
||||
td_errors = (q_values - targets).detach().abs().cpu().numpy()
|
||||
self.memory.update_priorities(indices, td_errors)
|
||||
else:
|
||||
loss = element_loss.mean()
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
||||
self.optimizer.step()
|
||||
return float(loss.item())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Compatibility checking
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _config_snapshot(self) -> dict:
|
||||
return {
|
||||
'metadata': self.metadata.to_dict(),
|
||||
'preprocessing': {
|
||||
'spatial': self.metadata.spatial,
|
||||
'board': self.board,
|
||||
'observe_state': self.observe_state,
|
||||
'observe_state_sizes': self.observe_state_sizes,
|
||||
'egocentric': self.egocentric,
|
||||
'egocentric_player': self.egocentric_player,
|
||||
'egocentric_radius': self.egocentric_radius,
|
||||
},
|
||||
'hidden_sizes': self.hp['hidden_sizes'],
|
||||
}
|
||||
|
||||
def _check_compatibility(self, ckpt: dict, path: str | Path):
|
||||
snapshot = ckpt.get('config_snapshot')
|
||||
if snapshot is None:
|
||||
return
|
||||
|
||||
current = self._config_snapshot()
|
||||
issues = []
|
||||
|
||||
old_meta = snapshot.get('metadata', {})
|
||||
new_meta = current['metadata']
|
||||
for field, desc in _INCOMPATIBLE_METADATA.items():
|
||||
if old_meta.get(field) != new_meta.get(field):
|
||||
issues.append((field, desc, old_meta.get(field), new_meta.get(field)))
|
||||
|
||||
old_pre = snapshot.get('preprocessing', {})
|
||||
new_pre = current['preprocessing']
|
||||
for field, desc in _INCOMPATIBLE_PREPROCESSING.items():
|
||||
if old_pre.get(field) != new_pre.get(field):
|
||||
issues.append((field, desc, old_pre.get(field), new_pre.get(field)))
|
||||
|
||||
for field, desc in _INCOMPATIBLE_ARCH.items():
|
||||
if snapshot.get(field) != current.get(field):
|
||||
issues.append((field, desc, snapshot.get(field), current.get(field)))
|
||||
|
||||
if not issues:
|
||||
return
|
||||
|
||||
lines = [
|
||||
f"Cannot resume from {Path(path).name}: incompatible changes detected in config.toml.",
|
||||
"",
|
||||
"The following changes require starting fresh. The existing model was trained",
|
||||
"on a different problem and its weights cannot be reused:",
|
||||
"",
|
||||
]
|
||||
for field, desc, old_val, new_val in issues:
|
||||
lines += [
|
||||
f" {field}",
|
||||
f" was : {old_val!r}",
|
||||
f" now : {new_val!r}",
|
||||
f" why : {desc}",
|
||||
"",
|
||||
]
|
||||
lines += [
|
||||
"Run 'retro-gamer clean RUN_DIR' to remove existing checkpoints and the",
|
||||
"training log, then run 'retro-gamer train RUN_DIR' to start fresh.",
|
||||
]
|
||||
raise ValueError("\n".join(lines))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Initialisation helpers
|
||||
# ------------------------------------------------------------------
|
||||
@@ -215,6 +636,16 @@ class DQNTrainer:
|
||||
f"after {self.hp['exploration_turns']} exploration turns: {chars}"
|
||||
)
|
||||
|
||||
def _discover_observe_state_sizes(self):
|
||||
"""Sample game.state to determine the flat size of each observe_state key."""
|
||||
self.env.reset()
|
||||
state = dict(self.env.game.state)
|
||||
sizes = {}
|
||||
for key in self.observe_state:
|
||||
val = state.get(key, 0)
|
||||
sizes[key] = len(val) if isinstance(val, (list, tuple)) else 1
|
||||
self.observe_state_sizes = sizes
|
||||
|
||||
def _save_config(self):
|
||||
config_path = self.run_dir / 'config.toml'
|
||||
config: dict = {}
|
||||
@@ -223,33 +654,75 @@ class DQNTrainer:
|
||||
with open(config_path, 'rb') as f:
|
||||
config = tomllib.load(f)
|
||||
config['metadata'] = self.metadata.to_dict()
|
||||
config['hyperparameters'] = self.hp
|
||||
pre = config.setdefault('preprocessing', {})
|
||||
pre['spatial'] = self.metadata.spatial
|
||||
pre['board'] = self.board
|
||||
pre['observe_state'] = self.observe_state
|
||||
if self.observe_state_sizes:
|
||||
pre['observe_state_sizes'] = self.observe_state_sizes
|
||||
pre['egocentric'] = self.egocentric
|
||||
if self.egocentric_player:
|
||||
pre['egocentric_player'] = self.egocentric_player
|
||||
if self.egocentric_radius:
|
||||
pre['egocentric_radius'] = self.egocentric_radius
|
||||
config['model'] = {k: v for k, v in self.hp.items() if k in MODEL_KEYS}
|
||||
config['training'] = {k: v for k, v in self.hp.items() if k not in MODEL_KEYS}
|
||||
with open(config_path, 'wb') as f:
|
||||
tomli_w.dump(config, f)
|
||||
|
||||
def _open_log(self, rationale: str):
|
||||
self.log_path = self.run_dir / 'training.log'
|
||||
with open(self.log_path, 'w') as f:
|
||||
f.write(rationale + '\n')
|
||||
if not self.log_path.exists():
|
||||
with open(self.log_path, 'w') as f:
|
||||
f.write(rationale + '\n')
|
||||
f.write(f'[INIT] Device: {self.device}\n')
|
||||
|
||||
def _log_raw(self, line: str):
|
||||
with open(self.log_path, 'a') as f:
|
||||
f.write(line + '\n')
|
||||
|
||||
def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float):
|
||||
def _log_checkpoint(
|
||||
self,
|
||||
episode: int,
|
||||
rewards: list[float],
|
||||
losses: list[float],
|
||||
steps: list[int],
|
||||
ckpt_elapsed: float,
|
||||
) -> dict:
|
||||
n = len(rewards)
|
||||
start_ep = episode - n + 1
|
||||
avg_reward = sum(rewards) / n if n else 0.0
|
||||
avg_loss = sum(losses) / len(losses) if losses else 0.0
|
||||
avg_steps = sum(steps) / n if n else 0.0
|
||||
line = (
|
||||
f"[EP {episode:04d}] total_reward={total_reward:.1f} "
|
||||
f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}"
|
||||
f"[ep_{episode:04d}]"
|
||||
f" ep={start_ep:04d}-{episode:04d}"
|
||||
f" avg_reward={avg_reward:+.1f}"
|
||||
f" avg_steps={avg_steps:.0f}"
|
||||
f" epsilon={self.epsilon:.3f}"
|
||||
f" avg_loss={avg_loss:.1f}"
|
||||
f" time={_format_duration(ckpt_elapsed)}"
|
||||
f" total={_format_duration(self.total_training_seconds)}"
|
||||
)
|
||||
self._log_raw(line)
|
||||
return {
|
||||
'episode': episode,
|
||||
'avg_reward': avg_reward,
|
||||
'avg_steps': avg_steps,
|
||||
'avg_loss': avg_loss,
|
||||
'epsilon': self.epsilon,
|
||||
}
|
||||
|
||||
def _save_checkpoint(self, name: str):
|
||||
def _save_checkpoint(self, name: str, episode: int):
|
||||
torch.save(
|
||||
{
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'epsilon': self.epsilon,
|
||||
'total_steps': self.total_steps,
|
||||
'episode': episode,
|
||||
'total_training_seconds': self.total_training_seconds,
|
||||
'config_snapshot': self._config_snapshot(),
|
||||
},
|
||||
self.run_dir / 'checkpoints' / name,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user