Updates across the board
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import sys
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
import click
|
||||
import tomli_w
|
||||
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.trainer import DQNTrainer, DEFAULTS
|
||||
from retro_gamer.trainer import DQNTrainer, DEFAULTS, MODEL_KEYS
|
||||
|
||||
|
||||
@click.group()
|
||||
@@ -20,13 +21,14 @@ def cli():
|
||||
|
||||
@cli.command()
|
||||
@click.option('--game', required=True,
|
||||
help='Python module containing create_game() e.g. retro.examples.snake')
|
||||
help='Game to train: a .py file path (e.g. my_game.py) or a Python module '
|
||||
'(e.g. retro.examples.snake)')
|
||||
@click.option('--output', required=True,
|
||||
help='Directory to create for this training run')
|
||||
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
|
||||
help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
|
||||
@click.option('--lr-decay', default=DEFAULTS['lr_decay'], type=float,
|
||||
help=f"Multiplicative LR decay per episode (default {DEFAULTS['lr_decay']})")
|
||||
@click.option('--learning-rate-decay', default=DEFAULTS['learning_rate_decay'], type=float,
|
||||
help=f"Multiplicative LR decay per episode (default {DEFAULTS['learning_rate_decay']})")
|
||||
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
|
||||
help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
|
||||
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
|
||||
@@ -43,12 +45,12 @@ def cli():
|
||||
help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
|
||||
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
|
||||
help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
|
||||
@click.option('--n-layers', default=DEFAULTS['n_layers'], type=int,
|
||||
help=f"Hidden layers in MLP head (default {DEFAULTS['n_layers']})")
|
||||
@click.option('--layer-size', default=DEFAULTS['layer_size'], type=int,
|
||||
help=f"Width of each hidden layer (default {DEFAULTS['layer_size']})")
|
||||
@click.option('--hidden-sizes', default=','.join(str(s) for s in DEFAULTS['hidden_sizes']),
|
||||
help=f"Comma-separated hidden layer sizes, e.g. 512,256 (default {DEFAULTS['hidden_sizes']})")
|
||||
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
|
||||
help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
|
||||
@click.option('--train-every', default=DEFAULTS['train_every'], type=int,
|
||||
help=f"Run a training step every N game steps (default {DEFAULTS['train_every']})")
|
||||
@click.option('--prioritize-experiences/--no-prioritize-experiences',
|
||||
default=DEFAULTS['prioritize_experiences'],
|
||||
help='Use prioritized experience replay')
|
||||
@@ -60,12 +62,23 @@ def create(game, output, **hyperparams):
|
||||
Board size is read directly from the game. Hyperparameter options
|
||||
control how the trainer learns, not what it learns about.
|
||||
"""
|
||||
raw = hyperparams['hidden_sizes']
|
||||
try:
|
||||
metadata = GameMetadata.from_pyproject(game)
|
||||
hyperparams['hidden_sizes'] = [int(x.strip()) for x in raw.split(',')]
|
||||
except ValueError:
|
||||
raise click.ClickException(
|
||||
f"Could not parse --hidden-sizes {raw!r}.\n"
|
||||
"It should be a comma-separated list of positive integers, one per hidden layer.\n"
|
||||
"Example: --hidden-sizes 512,256"
|
||||
)
|
||||
game_config = _parse_game_arg(game)
|
||||
|
||||
try:
|
||||
metadata = GameMetadata.from_pyproject(game_config['module'])
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
game_factory = _load_factory(game)
|
||||
game_factory = _load_factory(game_config)
|
||||
g = game_factory()
|
||||
metadata.board_size = g.board_size
|
||||
|
||||
@@ -74,10 +87,13 @@ def create(game, output, **hyperparams):
|
||||
run_dir = Path(output)
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
preprocessing = {'spatial': metadata.spatial}
|
||||
config = {
|
||||
'game': {'module': game},
|
||||
'game': game_config,
|
||||
'metadata': metadata.to_dict(),
|
||||
'hyperparameters': hyperparams,
|
||||
'preprocessing': preprocessing,
|
||||
'model': {k: v for k, v in hyperparams.items() if k in MODEL_KEYS},
|
||||
'training': {k: v for k, v in hyperparams.items() if k not in MODEL_KEYS},
|
||||
}
|
||||
with open(run_dir / 'config.toml', 'wb') as f:
|
||||
tomli_w.dump(config, f)
|
||||
@@ -91,8 +107,6 @@ def create(game, output, **hyperparams):
|
||||
click.echo(f" characters : {metadata.character_set}")
|
||||
else:
|
||||
click.echo(f" characters : (will be auto-discovered during training)")
|
||||
if metadata.observe_state:
|
||||
click.echo(f" observe : {metadata.observe_state}")
|
||||
click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")
|
||||
|
||||
|
||||
@@ -102,23 +116,49 @@ def create(game, output, **hyperparams):
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--resume', default=None,
|
||||
help='Path to checkpoint to resume from (e.g. checkpoints/ep_0500.pt)')
|
||||
def train(run_dir, resume):
|
||||
"""Train (or resume training) a DQN agent."""
|
||||
def train(run_dir):
|
||||
"""Train a DQN agent, resuming automatically from the latest checkpoint.
|
||||
|
||||
To start fresh or resume from an earlier point, delete the checkpoints
|
||||
you no longer want from RUN_DIR/checkpoints/.
|
||||
"""
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
game_factory = _load_factory(config['game']['module'])
|
||||
game_factory = _load_factory(config['game'])
|
||||
metadata = GameMetadata.from_dict(config['metadata'])
|
||||
hyperparams = config.get('hyperparameters', {})
|
||||
preprocessing = config.get('preprocessing', {})
|
||||
metadata.spatial = preprocessing.get('spatial', False)
|
||||
hyperparams = {**config.get('model', {}), **config['training']}
|
||||
|
||||
trainer = DQNTrainer(game_factory, metadata, run_dir, **hyperparams)
|
||||
if resume:
|
||||
click.echo(f"Resuming from {resume}")
|
||||
trainer.load_checkpoint(resume)
|
||||
click.echo(f"Training for {trainer.hp['training_episodes']} episodes…")
|
||||
trainer.train()
|
||||
click.echo(f"Done. Checkpoints in {run_dir}/checkpoints/")
|
||||
try:
|
||||
trainer = DQNTrainer(game_factory, metadata, run_dir, preprocessing=preprocessing, **hyperparams)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
latest = _latest_checkpoint(run_dir_path)
|
||||
if latest:
|
||||
click.echo(f"Resuming from {latest.name}")
|
||||
try:
|
||||
trainer.load_checkpoint(latest)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
if trainer.start_episode > trainer.hp['training_episodes']:
|
||||
click.echo(
|
||||
f"Training already complete ({trainer.hp['training_episodes']} episodes). "
|
||||
"To keep training, increase training_episodes in config.toml."
|
||||
)
|
||||
return
|
||||
|
||||
from retro_gamer.log_parser import parse_checkpoints
|
||||
from retro_gamer.dashboard import TrainingDashboard
|
||||
history = parse_checkpoints(run_dir_path / 'training.log')
|
||||
display = TrainingDashboard(trainer.hp['training_episodes'], trainer.start_episode, history)
|
||||
try:
|
||||
trainer.train(on_checkpoint=display.on_checkpoint, on_episode=display.on_episode)
|
||||
finally:
|
||||
display.close()
|
||||
click.echo(f"Done. Checkpoints saved in {run_dir}/checkpoints/")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -127,69 +167,72 @@ def train(run_dir, resume):
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--checkpoint', default='final',
|
||||
help='Checkpoint name e.g. "final" or "ep_0100"')
|
||||
@click.option('--checkpoint', default=None,
|
||||
help='Checkpoint to load, e.g. "ep_0100". Defaults to the latest available.')
|
||||
@click.option('--framerate', default=12, type=int,
|
||||
help='Target frames per second')
|
||||
def play(run_dir, checkpoint, framerate):
|
||||
"""Watch a trained agent play the game."""
|
||||
import torch
|
||||
from time import sleep, perf_counter
|
||||
from blessed import Terminal
|
||||
from retro.input import ProgrammaticInput
|
||||
from retro.views.headless import HeadlessView
|
||||
from retro.views.terminal import TerminalView
|
||||
from retro_gamer.observation import encode_observation
|
||||
from retro_gamer.model_agent import TrainedPolicy, PolicyInput
|
||||
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
game_factory = _load_factory(config['game']['module'])
|
||||
metadata = GameMetadata.from_dict(config['metadata'])
|
||||
hyperparams = {**DEFAULTS, **config.get('hyperparameters', {})}
|
||||
game_factory = _load_factory(config['game'])
|
||||
|
||||
from retro_gamer.network import build_network
|
||||
model, _ = build_network(metadata, hyperparams)
|
||||
try:
|
||||
ai = TrainedPolicy(run_dir_path, checkpoint=checkpoint)
|
||||
except FileNotFoundError as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
|
||||
ckpt_path = run_dir_path / 'checkpoints' / ckpt_name
|
||||
ckpt = torch.load(ckpt_path, weights_only=True)
|
||||
model.load_state_dict(ckpt['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
inp = ProgrammaticInput()
|
||||
headless = HeadlessView()
|
||||
game = game_factory()
|
||||
game.input_source = inp
|
||||
game.view = headless
|
||||
game.start()
|
||||
inp = PolicyInput(ai, game)
|
||||
|
||||
terminal = Terminal()
|
||||
term_view = TerminalView(terminal, color=game.color)
|
||||
|
||||
click.echo("Playing… (press Escape or Enter to quit)")
|
||||
with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
|
||||
term_view.on_game_start(game)
|
||||
game.input_source = inp
|
||||
game.view = term_view
|
||||
game.start()
|
||||
while game.playing:
|
||||
t0 = perf_counter()
|
||||
|
||||
obs = encode_observation(headless.board_characters, dict(game.state), metadata)
|
||||
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
action_idx = int(model(state_t).argmax().item())
|
||||
action_key = None if action_idx >= len(metadata.actions) else metadata.actions[action_idx]
|
||||
|
||||
key = terminal.inkey(0)
|
||||
if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
|
||||
break
|
||||
|
||||
inp.press(action_key)
|
||||
game.step()
|
||||
term_view.render(game)
|
||||
|
||||
elapsed = perf_counter() - t0
|
||||
sleep(max(0, 1 / framerate - elapsed))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer plot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--output', '-o', default=None,
|
||||
help='Save to file (e.g. plot.png) instead of displaying interactively.')
|
||||
def plot(run_dir, output):
|
||||
"""Plot training metrics (reward, steps, loss, epsilon) from a run's log."""
|
||||
from retro_gamer.plotter import plot_run
|
||||
run_dir_path = Path(run_dir)
|
||||
log_path = run_dir_path / 'training.log'
|
||||
if not log_path.exists():
|
||||
raise click.ClickException(f"No training.log found in {run_dir}")
|
||||
output_path = Path(output) if output else None
|
||||
try:
|
||||
plot_run(log_path, output_path)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer info
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -200,17 +243,19 @@ def info(run_dir):
|
||||
"""Print a summary of a training run."""
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
click.echo(f"Game module : {config['game']['module']}")
|
||||
click.echo(f"Metadata : {config['metadata']}")
|
||||
click.echo(f"Hyperparams : {config.get('hyperparameters', {})}")
|
||||
click.echo(f"Game module : {config['game']['module']}")
|
||||
click.echo(f"Metadata : {config['metadata']}")
|
||||
click.echo(f"Preprocessing : {config.get('preprocessing', {})}")
|
||||
click.echo(f"Model : {config.get('model', {})}")
|
||||
click.echo(f"Training : {config['training']}")
|
||||
|
||||
log_path = run_dir_path / 'training.log'
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text().splitlines()
|
||||
episode_lines = [l for l in lines if l.startswith('[EP')]
|
||||
if episode_lines:
|
||||
click.echo(f"\nLast 5 episodes:")
|
||||
for line in episode_lines[-5:]:
|
||||
ckpt_lines = [l for l in lines if l.startswith('[ep_')]
|
||||
if ckpt_lines:
|
||||
click.echo(f"\nLast 5 checkpoints:")
|
||||
for line in ckpt_lines[-5:]:
|
||||
click.echo(f" {line}")
|
||||
|
||||
ckpt_dir = run_dir_path / 'checkpoints'
|
||||
@@ -219,10 +264,91 @@ def info(run_dir):
|
||||
click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer clean
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--yes', '-y', is_flag=True, help='Skip confirmation prompt')
|
||||
def clean(run_dir, yes):
|
||||
"""Remove all checkpoints and the training log from a run directory.
|
||||
|
||||
Use this after changing the game description or network architecture,
|
||||
which require starting training from scratch.
|
||||
"""
|
||||
run_dir_path = Path(run_dir)
|
||||
if not run_dir_path.exists():
|
||||
raise click.ClickException(f"Directory not found: {run_dir}")
|
||||
|
||||
to_remove = []
|
||||
ckpt_dir = run_dir_path / 'checkpoints'
|
||||
if ckpt_dir.exists():
|
||||
to_remove.extend(sorted(ckpt_dir.glob('*.pt')))
|
||||
log_path = run_dir_path / 'training.log'
|
||||
if log_path.exists():
|
||||
to_remove.append(log_path)
|
||||
|
||||
if not to_remove:
|
||||
click.echo("Nothing to clean.")
|
||||
return
|
||||
|
||||
n_ckpts = sum(1 for p in to_remove if p.suffix == '.pt')
|
||||
click.echo(f"Will remove {n_ckpts} checkpoint(s) and training log from {run_dir}/:")
|
||||
for p in to_remove:
|
||||
click.echo(f" {p.relative_to(run_dir_path)}")
|
||||
|
||||
if not yes:
|
||||
click.confirm("\nProceed?", abort=True)
|
||||
|
||||
for p in to_remove:
|
||||
p.unlink()
|
||||
click.echo(f"Cleaned. Run 'retro-gamer train {run_dir}' to start fresh.")
|
||||
|
||||
ckpt_dir = run_dir_path / 'checkpoints'
|
||||
if ckpt_dir.exists():
|
||||
ckpts = sorted(ckpt_dir.glob('*.pt'))
|
||||
click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_game_arg(arg: str) -> dict:
|
||||
"""Accept a .py file path, a package directory, or a Python module name.
|
||||
|
||||
Returns a dict with at least 'module', and optionally 'path' (the
|
||||
directory to add to sys.path so the module can be imported).
|
||||
"""
|
||||
p = Path(arg)
|
||||
if p.exists():
|
||||
if p.is_file() and p.suffix == '.py':
|
||||
path = str(p.parent.resolve())
|
||||
module = p.stem
|
||||
elif p.is_dir():
|
||||
path = str(p.parent.resolve())
|
||||
module = p.name
|
||||
else:
|
||||
raise click.ClickException(
|
||||
f"{arg!r} is not a .py file or package directory"
|
||||
)
|
||||
if path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
return {'module': module, 'path': path}
|
||||
return {'module': arg}
|
||||
|
||||
|
||||
def _latest_checkpoint(run_dir: Path) -> Path | None:
|
||||
"""Return the most recent checkpoint in run_dir/checkpoints/, or None."""
|
||||
ckpt_dir = run_dir / 'checkpoints'
|
||||
if ckpt_dir.exists():
|
||||
candidates = sorted(ckpt_dir.glob('ep_*.pt'))
|
||||
if candidates:
|
||||
return candidates[-1]
|
||||
return None
|
||||
|
||||
|
||||
def _load_config(run_dir: Path) -> dict:
|
||||
config_path = run_dir / 'config.toml'
|
||||
if not config_path.exists():
|
||||
@@ -231,7 +357,11 @@ def _load_config(run_dir: Path) -> dict:
|
||||
return tomllib.load(f)
|
||||
|
||||
|
||||
def _load_factory(module_name: str):
|
||||
def _load_factory(game_config: dict):
|
||||
path = game_config.get('path')
|
||||
if path and path not in sys.path:
|
||||
sys.path.insert(0, path)
|
||||
module_name = game_config['module']
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
@@ -241,3 +371,4 @@ def _load_factory(module_name: str):
|
||||
f"Module '{module_name}' has no create_game() function"
|
||||
)
|
||||
return module.create_game
|
||||
|
||||
|
||||
Reference in New Issue
Block a user