Initial commit

This commit is contained in:
Chris Proctor
2026-05-08 14:07:17 -04:00
commit 5ca97dc5d0
36 changed files with 4147 additions and 0 deletions

5
retro_gamer/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from retro_gamer.metadata import GameMetadata
from retro_gamer.env import GameEnvironment
from retro_gamer.trainer import DQNTrainer
__all__ = ["GameMetadata", "GameEnvironment", "DQNTrainer"]

243
retro_gamer/cli.py Normal file
View File

@@ -0,0 +1,243 @@
from __future__ import annotations
import importlib
import tomllib
from pathlib import Path
import click
import tomli_w
from retro_gamer.metadata import GameMetadata
from retro_gamer.trainer import DQNTrainer, DEFAULTS
# Root command group: the create/train/play/info commands below attach to it
# via @cli.command(). The docstring doubles as the top-level --help text.
@click.group()
def cli():
    """Train and run RL agents for retro games."""
# ---------------------------------------------------------------------------
# retro-gamer create
# ---------------------------------------------------------------------------
@cli.command()
@click.option('--game', required=True,
              help='Python module containing create_game() e.g. retro.examples.snake')
@click.option('--output', required=True,
              help='Directory to create for this training run')
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
              help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
@click.option('--lr-decay', default=DEFAULTS['lr_decay'], type=float,
              help=f"Multiplicative LR decay per episode (default {DEFAULTS['lr_decay']})")
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
              help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
              help=f"Exploration rate decay per episode (default {DEFAULTS['epsilon_decay']})")
@click.option('--epsilon-min', default=DEFAULTS['epsilon_min'], type=float,
              help=f"Minimum exploration rate (default {DEFAULTS['epsilon_min']})")
@click.option('--batch-size', default=DEFAULTS['batch_size'], type=int,
              help=f"Experiences per training step (default {DEFAULTS['batch_size']})")
@click.option('--memory-capacity', default=DEFAULTS['memory_capacity'], type=int,
              help=f"Replay buffer size (default {DEFAULTS['memory_capacity']})")
@click.option('--target-update-freq', default=DEFAULTS['target_update_freq'], type=int,
              help=f"Steps between target network updates (default {DEFAULTS['target_update_freq']})")
@click.option('--training-episodes', default=DEFAULTS['training_episodes'], type=int,
              help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
              help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
@click.option('--n-layers', default=DEFAULTS['n_layers'], type=int,
              help=f"Hidden layers in MLP head (default {DEFAULTS['n_layers']})")
@click.option('--layer-size', default=DEFAULTS['layer_size'], type=int,
              help=f"Width of each hidden layer (default {DEFAULTS['layer_size']})")
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
              help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
@click.option('--prioritize-experiences/--no-prioritize-experiences',
              default=DEFAULTS['prioritize_experiences'],
              help='Use prioritized experience replay')
def create(game, output, **hyperparams):
    """Create a new training run directory.
    Game metadata (actions, reward signal, etc.) is read from the
    [tool.retro-gamer] section of the game's pyproject.toml.
    Board size is read directly from the game. Hyperparameter options
    control how the trainer learns, not what it learns about.
    """
    # Import the game module first. The pyproject lookup swallows ImportError,
    # so doing it first would report an unimportable module as a missing
    # pyproject.toml instead of as an import failure.
    game_factory = _load_factory(game)

    try:
        metadata = GameMetadata.from_pyproject(game)
    except KeyError as e:
        # A [tool.retro-gamer] section without 'actions' or 'reward'.
        raise click.ClickException(
            f"[tool.retro-gamer] is missing required key {e}"
        ) from e
    except (FileNotFoundError, ValueError) as e:
        raise click.ClickException(str(e)) from e

    # The board size is a property of the running game, not of the TOML file.
    metadata.board_size = game_factory().board_size

    try:
        metadata.validate()
    except ValueError as e:
        # Report bad metadata as a CLI error rather than a traceback.
        raise click.ClickException(str(e)) from e

    run_dir = Path(output)
    run_dir.mkdir(parents=True, exist_ok=True)
    config = {
        'game': {'module': game},
        'metadata': metadata.to_dict(),
        'hyperparameters': hyperparams,
    }
    with open(run_dir / 'config.toml', 'wb') as f:
        tomli_w.dump(config, f)

    click.echo(f"Created training run at {output}/config.toml")
    click.echo(f" game : {game}")
    click.echo(f" board_size : {metadata.board_size[0]}×{metadata.board_size[1]}")
    click.echo(f" actions : {metadata.actions}")
    click.echo(f" reward : {metadata.reward}")
    if metadata.character_set:
        click.echo(f" characters : {metadata.character_set}")
    else:
        click.echo(" characters : (will be auto-discovered during training)")
    if metadata.observe_state:
        click.echo(f" observe : {metadata.observe_state}")
    click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")
# ---------------------------------------------------------------------------
# retro-gamer train
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--resume', default=None,
              help='Path to checkpoint to resume from (e.g. checkpoints/ep_0500.pt)')
def train(run_dir, resume):
    """Train (or resume training) a DQN agent."""
    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)

    # Resolve the checkpoint before building the trainer so a bad path fails
    # fast. The help text's example is relative to the run directory, so try
    # that location when the path does not exist as given.
    resume_path = None
    if resume:
        resume_path = Path(resume)
        if not resume_path.exists() and (run_dir_path / resume_path).exists():
            resume_path = run_dir_path / resume_path
        if not resume_path.exists():
            raise click.ClickException(f"Checkpoint not found: {resume}")

    game_factory = _load_factory(config['game']['module'])
    metadata = GameMetadata.from_dict(config['metadata'])
    hyperparams = config.get('hyperparameters', {})
    trainer = DQNTrainer(game_factory, metadata, run_dir, **hyperparams)

    if resume_path is not None:
        click.echo(f"Resuming from {resume}")
        trainer.load_checkpoint(resume_path)

    click.echo(f"Training for {trainer.hp['training_episodes']} episodes…")
    trainer.train()
    click.echo(f"Done. Checkpoints in {run_dir}/checkpoints/")
# ---------------------------------------------------------------------------
# retro-gamer play
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
@click.option('--checkpoint', default='final',
              help='Checkpoint name e.g. "final" or "ep_0100"')
@click.option('--framerate', default=12, type=click.IntRange(min=1),
              help='Target frames per second')
def play(run_dir, checkpoint, framerate):
    """Watch a trained agent play the game."""
    # Heavy / optional dependencies are imported lazily so the other
    # sub-commands do not pay for them.
    import torch
    from time import sleep, perf_counter
    from blessed import Terminal
    from retro.input import ProgrammaticInput
    from retro.views.headless import HeadlessView
    from retro.views.terminal import TerminalView
    from retro_gamer.network import build_network
    from retro_gamer.observation import encode_observation

    run_dir_path = Path(run_dir)
    config = _load_config(run_dir_path)

    # Accept both "final" and "final.pt"; fail with a CLI error, not a
    # traceback, when the checkpoint is missing.
    ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
    ckpt_path = run_dir_path / 'checkpoints' / ckpt_name
    if not ckpt_path.exists():
        raise click.ClickException(f"Checkpoint not found: {ckpt_path}")

    game_factory = _load_factory(config['game']['module'])
    metadata = GameMetadata.from_dict(config['metadata'])
    hyperparams = {**DEFAULTS, **config.get('hyperparameters', {})}

    model, _ = build_network(metadata, hyperparams)
    ckpt = torch.load(ckpt_path, weights_only=True)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # The game is driven programmatically; the headless view supplies the
    # board characters that are encoded into the model's observation.
    inp = ProgrammaticInput()
    headless = HeadlessView()
    game = game_factory()
    game.input_source = inp
    game.view = headless
    game.start()

    terminal = Terminal()
    term_view = TerminalView(terminal, color=game.color)
    # --framerate is validated as >= 1 by click, so this cannot divide by zero.
    frame_time = 1 / framerate

    click.echo("Playing… (press Escape or Enter to quit)")
    with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
        term_view.on_game_start(game)
        while game.playing:
            t0 = perf_counter()
            obs = encode_observation(headless.board_characters, dict(game.state), metadata)
            state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                action_idx = int(model(state_t).argmax().item())
            # An index past the end of the action list is the no-op action.
            action_key = None if action_idx >= len(metadata.actions) else metadata.actions[action_idx]

            # Non-blocking poll so the viewer can quit at any frame.
            key = terminal.inkey(0)
            if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
                break

            inp.press(action_key)
            game.step()
            term_view.render(game)

            # Sleep off whatever is left of this frame's time budget.
            elapsed = perf_counter() - t0
            sleep(max(0, frame_time - elapsed))
# ---------------------------------------------------------------------------
# retro-gamer info
# ---------------------------------------------------------------------------
@cli.command()
@click.argument('run_dir')
def info(run_dir):
    """Print a summary of a training run."""
    root = Path(run_dir)
    config = _load_config(root)

    # Static configuration recorded when the run was created.
    click.echo(f"Game module : {config['game']['module']}")
    click.echo(f"Metadata : {config['metadata']}")
    click.echo(f"Hyperparams : {config.get('hyperparameters', {})}")

    # Tail of the per-episode log, when training has produced one.
    training_log = root / 'training.log'
    if training_log.exists():
        episode_entries = [
            entry
            for entry in training_log.read_text().splitlines()
            if entry.startswith('[EP')
        ]
        if episode_entries:
            click.echo("\nLast 5 episodes:")
            for entry in episode_entries[-5:]:
                click.echo(f" {entry}")

    # Saved checkpoints, listed in sorted path order.
    checkpoint_dir = root / 'checkpoints'
    if checkpoint_dir.exists():
        names = [path.name for path in sorted(checkpoint_dir.glob('*.pt'))]
        click.echo(f"\nCheckpoints ({len(names)}): {names}")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_config(run_dir: Path) -> dict:
config_path = run_dir / 'config.toml'
if not config_path.exists():
raise click.ClickException(f"No config.toml found in {run_dir}")
with open(config_path, 'rb') as f:
return tomllib.load(f)
def _load_factory(module_name: str):
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise click.ClickException(f"Cannot import game module '{module_name}': {e}")
if not hasattr(module, 'create_game'):
raise click.ClickException(
f"Module '{module_name}' has no create_game() function"
)
return module.create_game

78
retro_gamer/env.py Normal file
View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import random
import numpy as np
from typing import Callable
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView
from retro_gamer.metadata import GameMetadata
from retro_gamer.observation import encode_observation
class GameEnvironment:
    """Gym-style wrapper around a retro Game for RL training.

    Provides reset() / step(action), managing one training episode at a
    time. A brand-new game is built by calling the factory on each reset.
    """

    def __init__(self, game_factory: Callable, metadata: GameMetadata):
        self.game_factory = game_factory
        self.metadata = metadata
        # These three are (re)created by reset(); they stay None until then.
        self.game = None
        self.view: HeadlessView | None = None
        self.inp: ProgrammaticInput | None = None
        # Value of the reward state key after the previous step.
        self._prev_reward: float = 0.0

    def reset(self) -> np.ndarray:
        """Create a fresh game episode and return the initial observation."""
        self.inp = ProgrammaticInput()
        self.view = HeadlessView()
        self.game = self.game_factory()
        self.game.input_source = self.inp
        self.game.view = self.view
        self.game.start()
        self._prev_reward = float(self.game.state.get(self.metadata.reward, 0))
        return self._observe()

    def step(self, action: str | None) -> tuple[np.ndarray, float, bool]:
        """Advance one turn with the given action.

        action: a keystroke string (e.g. 'KEY_RIGHT') or None for no-op.
        Returns (observation, reward, done), where reward is the change in
        the reward state key since the previous step.
        Raises RuntimeError if called before reset().
        """
        if self.game is None or self.inp is None:
            raise RuntimeError("reset() must be called before step()")
        self.inp.press(action)
        self.game.step()
        obs = self._observe()
        reward = self._delta_reward()
        done = not self.game.playing
        return obs, reward, done

    def _observe(self) -> np.ndarray:
        """Encode the current board characters and game state."""
        return encode_observation(
            self.view.board_characters,
            dict(self.game.state),
            self.metadata,
        )

    def _delta_reward(self) -> float:
        """Return the change in the reward key since the last call."""
        current = float(self.game.state.get(self.metadata.reward, 0))
        delta = current - self._prev_reward
        self._prev_reward = current
        return delta

    def _scan_board(self, chars: set[str]) -> None:
        """Add every character currently on the board to ``chars``."""
        for row in self.view.board_characters:
            chars.update(row)

    def discover_character_set(self, exploration_turns: int) -> list[str]:
        """Run random turns to discover the characters that appear on the board.

        The board is scanned after the initial reset, after every step
        (including the final frame of an episode, before it is replaced),
        and after each mid-exploration reset.
        Returns the sorted character list (excluding space).
        """
        chars: set[str] = set()
        # The no-op (None) is a legal choice alongside the declared actions.
        choices = [*self.metadata.actions, None]

        self.reset()
        self._scan_board(chars)
        for _ in range(exploration_turns):
            _, _, done = self.step(random.choice(choices))
            self._scan_board(chars)
            if done:
                self.reset()
                self._scan_board(chars)

        chars.discard(' ')
        return sorted(chars)

View File

View File

@@ -0,0 +1,17 @@
from retro.game import Game
from retro_gamer.examples.beast.board import Board
# Board dimensions and beast count used for every Beast game.
WIDTH = 40
HEIGHT = 20
NUM_BEASTS = 10


def create_game():
    """Build and return a new Beast game, ready to be started."""
    layout = Board(WIDTH, HEIGHT, num_beasts=NUM_BEASTS)
    game = Game(
        layout.get_agents(),
        {'beasts_killed': 0},
        board_size=(WIDTH, HEIGHT),
    )
    # Live beast counter; Beast.die() decrements it and ends the game at zero.
    game.num_beasts = NUM_BEASTS
    return game


if __name__ == '__main__':
    create_game().play()

View File

@@ -0,0 +1,67 @@
from retro_gamer.examples.beast.helpers import add, distance, get_occupant
from random import random, choice
class Beast:
    """An enemy agent that occasionally moves, usually toward the player."""

    character = "H"
    color = "red"
    probability_of_moving = 0.03
    probability_of_random_move = 0.2
    deadly = True

    def __init__(self, position):
        self.position = position

    def handle_push(self, vector, game):
        """React to being pushed along ``vector``.

        The beast dies (and the push succeeds) when the square behind it is
        occupied or off the board; otherwise the push fails.
        """
        target = add(self.position, vector)
        target_on_board = game.on_board(target)
        blocker = get_occupant(game, target)
        crushed = bool(blocker) or not target_on_board
        if crushed:
            self.die(game)
        return crushed

    def play_turn(self, game):
        """Maybe move to an adjacent free square, then check for the player."""
        if not self.should_move():
            return
        open_squares = [
            square
            for square in self.get_adjacent_positions()
            if game.is_empty(square) and game.on_board(square)
        ]
        if not open_squares:
            return
        if self.should_move_randomly():
            self.position = choice(open_squares)
        else:
            self.position = self.choose_best_move(open_squares, game)
        player = game.get_agent_by_name("player")
        if player and player.position == self.position:
            player.die(game)

    def get_adjacent_positions(self):
        """Return the eight neighbouring positions, diagonals included."""
        offsets = (-1, 0, 1)
        return [
            add(self.position, (dx, dy))
            for dx in offsets
            for dy in offsets
            if dx or dy
        ]

    def should_move(self):
        """Roll against ``probability_of_moving``."""
        roll = random()
        return roll < self.probability_of_moving

    def should_move_randomly(self):
        """Roll against ``probability_of_random_move``."""
        roll = random()
        return roll < self.probability_of_random_move

    def choose_best_move(self, possible_moves, game):
        """Pick the move closest to the player (ties broken by position)."""
        player = game.get_agent_by_name("player")
        ranked = sorted(
            (distance(player.position, move), move) for move in possible_moves
        )
        return ranked[0][1]

    def die(self, game):
        """Remove this beast, update the score, and end the game on the last kill."""
        game.remove_agent(self)
        game.num_beasts -= 1
        game.state['beasts_killed'] += 1
        if game.num_beasts != 0:
            return
        game.state["message"] = "You win!"
        game.end()

View File

@@ -0,0 +1,25 @@
from retro_gamer.examples.beast.helpers import add, get_occupant
class Block:
    """An inert block the player can shove around the board."""

    # NOTE(review): empty in this view, while the project's character_set
    # lists a block glyph — confirm the glyph was not lost in extraction.
    character = ""
    color = "green4"
    deadly = False

    def __init__(self, position):
        self.position = position

    def handle_push(self, vector, game):
        """Try to slide one square along ``vector``.

        Whatever occupies the destination is pushed first. Returns True when
        the block moved and so freed the square it was on.
        """
        target = add(self.position, vector)
        target_on_board = game.on_board(target)
        blocker = get_occupant(game, target)
        # An occupant decides the outcome; an empty square only needs to be
        # on the board.
        moved = blocker.handle_push(vector, game) if blocker else target_on_board
        if moved:
            self.position = target
        return moved

View File

@@ -0,0 +1,39 @@
from retro_gamer.examples.beast.helpers import add, get_occupant
# Arrow-key name -> (dx, dy) step; y grows downward.
direction_vectors = {
    "KEY_RIGHT": (1, 0),
    "KEY_UP": (0, -1),
    "KEY_LEFT": (-1, 0),
    "KEY_DOWN": (0, 1),
}


class Player:
    """The user-controlled agent, moved with the arrow keys."""

    character = "*"
    color = "white"
    name = "player"
    deadly = False

    def __init__(self, position):
        self.position = position

    def handle_keystroke(self, keystroke, game):
        """Translate an arrow key into a move attempt; ignore other keys."""
        step = direction_vectors.get(keystroke.name)
        if step is not None:
            self.try_to_move(step, game)

    def try_to_move(self, vector, game):
        """Move one square along ``vector`` if the destination allows it.

        Walking into a deadly occupant kills the player; any other occupant
        must accept a push for the move to go through.
        """
        target = add(self.position, vector)
        target_on_board = game.on_board(target)
        blocker = get_occupant(game, target)
        if blocker:
            if blocker.deadly:
                self.die(game)
            elif blocker.handle_push(vector, game):
                self.position = target
            return
        if target_on_board:
            self.position = target

    def die(self, game):
        """Mark the player as dead and end the game."""
        self.color = "black_on_red"
        game.state["message"] = "The beasties win!"
        game.end()

View File

@@ -0,0 +1,44 @@
from random import shuffle
from retro_gamer.examples.beast.agents.player import Player
from retro_gamer.examples.beast.agents.beast import Beast
from retro_gamer.examples.beast.agents.block import Block
class Board:
    """Creates the agents needed at the beginning of the game and assigns their positions."""

    def __init__(self, width, height, block_density=0.3, num_beasts=10):
        self.width = width
        self.height = height
        self.block_density = block_density
        self.num_blocks = round(width * height * block_density)
        self.num_empty_spaces = width * height - self.num_blocks
        self.num_beasts = num_beasts
        self.validate()

    def validate(self):
        """Raise ValueError when the requested layout cannot fit on the board."""
        if self.block_density < 0 or self.block_density > 1:
            raise ValueError("block density must be between 0 and 1.")
        # Non-block squares must hold every beast plus the player.
        if self.num_empty_spaces < self.num_beasts + 1:
            raise ValueError("Not enough space on the board.")

    def get_agents(self):
        """Returns a list of agents initialized in their starting positions."""
        player_position, beast_positions, block_positions = self._assign_positions()
        player = [Player(player_position)]
        beasts = [Beast(pos) for pos in beast_positions]
        blocks = [Block(pos) for pos in block_positions]
        return player + beasts + blocks

    def _assign_positions(self):
        """Shuffle the board and split it into (player, beasts, blocks) positions.

        The player takes the first shuffled position, the beasts the next
        ``num_beasts``, and the blocks the last ``num_blocks``.
        """
        positions = self.get_all_positions()
        shuffle(positions)
        # Slice from an explicit start index: ``positions[-0:]`` would be the
        # whole list, putting a block on every square when num_blocks is 0.
        first_block = len(positions) - self.num_blocks
        return (
            positions[0],
            positions[1:self.num_beasts + 1],
            positions[first_block:],
        )

    def get_all_positions(self):
        """Returns a list of all positions on the board."""
        return [(i, j) for i in range(self.width) for j in range(self.height)]

View File

@@ -0,0 +1,18 @@
def add(vec0, vec1):
    """Return the component-wise sum of two 2-D vectors as a tuple."""
    (ax, ay), (bx, by) = vec0, vec1
    return (ax + bx, ay + by)
def get_occupant(game, position):
    """Return the first agent at ``position``, or None when no entry exists for it."""
    occupants_by_position = game.get_agents_by_position()
    if position not in occupants_by_position:
        return None
    return occupants_by_position[position][0]
def distance(vec0, vec1):
    """Return the Manhattan (taxicab) distance between two 2-D positions."""
    (ax, ay), (bx, by) = vec0, vec1
    dx = bx - ax
    dy = by - ay
    return abs(dx) + abs(dy)

View File

@@ -0,0 +1,6 @@
[tool.retro-gamer]
actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]
reward = "beasts_killed"
character_set = ["*", "H", "█"]
spatial = true
observe_state = []

74
retro_gamer/memory.py Normal file
View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import random
from collections import deque
from typing import NamedTuple
import numpy as np
class Experience(NamedTuple):
    """One stored transition, in the order the replay buffers' push() takes it."""

    state: np.ndarray       # observation before the action was taken
    action: int             # index of the action that was taken
    reward: float           # reward received for the step
    next_state: np.ndarray  # observation after the step
    done: bool              # True when the step ended the episode
class ReplayMemory:
    """Bounded FIFO store of experiences with uniform random sampling.

    Once ``capacity`` items are held, each push drops the oldest one.
    """

    def __init__(self, capacity: int):
        self.memory: deque[Experience] = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Record one transition."""
        transition = Experience(state, action, reward, next_state, done)
        self.memory.append(transition)

    def sample(self, batch_size: int) -> list[Experience]:
        """Draw ``batch_size`` distinct stored experiences at random."""
        batch = random.sample(self.memory, batch_size)
        return batch

    def __len__(self):
        return len(self.memory)
class PrioritizedReplayMemory:
    """Replay buffer that samples experiences in proportion to their priority.

    Sampling probability is ``priority ** alpha`` (normalised), so a larger
    alpha means stronger prioritisation. Each sample also carries an
    importance-sampling weight, shaped by beta, to offset the bias that
    non-uniform sampling introduces.
    """

    def __init__(self, capacity: int, alpha: float = 0.6, beta: float = 0.4):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.memory: list[Experience] = []
        self.priorities: list[float] = []
        # Next slot to overwrite once the buffer is full.
        self._pos = 0

    def push(self, state, action, reward, next_state, done):
        """Store a transition at the current maximum priority (1.0 when empty)."""
        top_priority = max(self.priorities, default=1.0)
        transition = Experience(state, action, reward, next_state, done)
        if len(self.memory) >= self.capacity:
            # Full: overwrite the slot under the write cursor.
            self.memory[self._pos] = transition
            self.priorities[self._pos] = top_priority
        else:
            self.memory.append(transition)
            self.priorities.append(top_priority)
        self._pos = (self._pos + 1) % self.capacity

    def sample(self, batch_size: int) -> tuple[list[Experience], np.ndarray, np.ndarray]:
        """Returns (experiences, indices, importance_sampling_weights)."""
        size = len(self.memory)
        scaled = np.array(self.priorities, dtype=np.float64) ** self.alpha
        probs = scaled / scaled.sum()
        indices = np.random.choice(size, batch_size, p=probs)
        raw_weights = (size * probs[indices]) ** -self.beta
        # Normalise so the largest weight in the batch is 1.
        weights = (raw_weights / raw_weights.max()).astype(np.float32)
        batch = [self.memory[i] for i in indices]
        return batch, indices, weights

    def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray):
        """Set each sampled slot's priority to |TD error| plus a small epsilon."""
        for slot, td_error in zip(indices, td_errors):
            self.priorities[slot] = float(abs(td_error)) + 1e-6

    def __len__(self):
        return len(self.memory)

121
retro_gamer/metadata.py Normal file
View File

@@ -0,0 +1,121 @@
from __future__ import annotations
import importlib
import tomllib
import tomli_w
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class GameMetadata:
"""Describes a retro game for training purposes.
Required fields: actions, reward.
Optional fields: character_set, spatial, observe_state.
Discovered fields: board_size (read from game.board_size at startup).
Metadata is read from the game's own pyproject.toml under the
[tool.retro-gamer] section via GameMetadata.from_pyproject().
"""
actions: list[str]
reward: str
character_set: list[str] | None = None
spatial: bool = True
observe_state: list[str] = field(default_factory=list)
board_size: tuple[int, int] | None = None
def validate(self):
if not self.actions:
raise ValueError("actions must be a non-empty list")
if not self.reward:
raise ValueError("reward must be a non-empty string")
if self.reward in self.observe_state:
raise ValueError(f"reward key '{self.reward}' must not appear in observe_state")
if self.character_set is not None:
for ch in self.character_set:
if len(ch) != 1:
raise ValueError(f"character_set entries must be single characters, got {ch!r}")
@classmethod
def from_pyproject(cls, module_name: str) -> GameMetadata:
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml.
Finds pyproject.toml by walking up from the module's source file.
Raises FileNotFoundError if no pyproject.toml is found, or ValueError
if the file exists but has no [tool.retro-gamer] section.
"""
pyproject_path = _find_pyproject(module_name)
if pyproject_path is None:
raise FileNotFoundError(
f"Could not find pyproject.toml for module '{module_name}'. "
f"Make sure the module is part of a Python project with a pyproject.toml."
)
with open(pyproject_path, 'rb') as f:
data = tomllib.load(f)
section = data.get('tool', {}).get('retro-gamer')
if section is None:
raise ValueError(
f"No [tool.retro-gamer] section found in {pyproject_path}.\n"
f"Add game metadata to your pyproject.toml:\n\n"
f"[tool.retro-gamer]\n"
f"actions = [\"KEY_RIGHT\", ...]\n"
f"reward = \"score\"\n"
)
return cls.from_dict(section)
@classmethod
def from_dict(cls, d: dict) -> GameMetadata:
board_size = tuple(d['board_size']) if 'board_size' in d else None
return cls(
actions=d['actions'],
reward=d['reward'],
character_set=d.get('character_set'),
spatial=d.get('spatial', True),
observe_state=d.get('observe_state', []),
board_size=board_size,
)
def to_dict(self) -> dict:
d = {
'actions': self.actions,
'reward': self.reward,
'spatial': self.spatial,
'observe_state': self.observe_state,
}
if self.board_size is not None:
d['board_size'] = list(self.board_size)
if self.character_set is not None:
d['character_set'] = self.character_set
return d
def to_toml(self, path: str | Path):
with open(path, 'wb') as f:
tomli_w.dump({'metadata': self.to_dict()}, f)
@property
def obs_size(self) -> int:
"""Total size of the flat observation vector."""
C = len(self.character_set) if self.character_set else 0
bw, bh = self.board_size
return C * bw * bh + len(self.observe_state)
@property
def n_actions(self) -> int:
"""Number of actions including no-op."""
return len(self.actions) + 1
def _find_pyproject(module_name: str) -> Path | None:
"""Walk up from a module's source file to find its pyproject.toml."""
try:
module = importlib.import_module(module_name)
except ImportError:
return None
module_file = getattr(module, '__file__', None)
if module_file is None:
return None
for parent in Path(module_file).resolve().parents:
candidate = parent / 'pyproject.toml'
if candidate.exists():
return candidate
return None

100
retro_gamer/network.py Normal file
View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import torch
import torch.nn as nn
from retro_gamer.metadata import GameMetadata
def build_network(
    metadata: GameMetadata,
    hyperparams: dict,
) -> tuple[nn.Module, str]:
    """Build a Q-network from game metadata and hyperparameters.

    Returns (model, rationale) where rationale is a multi-line string
    describing the architecture and the reasoning behind each choice.
    Raises ValueError when the metadata has no character_set or board_size,
    since both are needed to size the network.
    """
    if not metadata.character_set:
        raise ValueError("character_set must be set before building the network")
    if metadata.board_size is None:
        raise ValueError("board_size must be set before building the network")

    n_layers = hyperparams.get('n_layers', 2)
    layer_size = hyperparams.get('layer_size', 128)
    C = len(metadata.character_set)
    # board_size is (width, height).
    W, H = metadata.board_size
    n_state = len(metadata.observe_state)
    n_actions = metadata.n_actions

    lines = []
    lines.append("[INIT] === Network Architecture ===")
    lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
    lines.append(f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}")

    if metadata.spatial:
        model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)
    else:
        obs_size = C * W * H + n_state
        model = _build_flat(obs_size, n_layers, layer_size, n_actions, lines)

    lines.append(f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}")
    lines.append(f"[INIT] Output: {n_actions} Q-values")
    lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)")
    return model, '\n'.join(lines)
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
    """Describe the CNN architecture in ``lines`` and return a _SpatialNet.

    ``lines`` is mutated in place: the "[INIT]" log lines are appended to it.
    """
    lines.append("[INIT] spatial=True → using CNN architecture")
    lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
    lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
    lines.append(f"[INIT] CNN: Conv2d({C}→32, k=3, pad=1) → ReLU → Conv2d(32→64, k=3, pad=1) → ReLU")
    conv_out = 64 * H * W  # padding=1 preserves spatial dims
    lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
    mlp_in = conv_out + n_state
    lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
    # Separate the layer widths; joining on '' fused them into one number.
    widths = [str(mlp_in)] + [str(layer_size)] * n_layers + [str(n_actions)]
    lines.append(f"[INIT] MLP: {' → '.join(widths)}")
    return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
    """Describe the MLP architecture in ``lines`` and return a _FlatNet.

    ``lines`` is mutated in place: the "[INIT]" log lines are appended to it.
    """
    lines.append("[INIT] spatial=False → using MLP architecture")
    lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
    lines.append("[INIT] a flat MLP over the full observation is sufficient.")
    # Separate the layer widths; joining on '' fused them into one number.
    widths = [str(obs_size)] + [str(layer_size)] * n_layers + [str(n_actions)]
    lines.append(f"[INIT] MLP: {' → '.join(widths)}")
    return _FlatNet(obs_size, n_layers, layer_size, n_actions)
class _SpatialNet(nn.Module):
def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
super().__init__()
self.C, self.H, self.W = C, H, W
self.n_board = C * H * W
self.conv = nn.Sequential(
nn.Conv2d(C, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
)
conv_out = 64 * H * W
mlp_in = conv_out + n_state
layers: list[nn.Module] = []
for i in range(n_layers):
in_size = mlp_in if i == 0 else layer_size
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
layers.append(nn.Linear(layer_size, n_actions))
self.mlp = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
board = x[:, :self.n_board].reshape(-1, self.C, self.H, self.W)
state = x[:, self.n_board:]
conv_out = self.conv(board).flatten(start_dim=1)
return self.mlp(torch.cat([conv_out, state], dim=1))
class _FlatNet(nn.Module):
def __init__(self, obs_size, n_layers, layer_size, n_actions):
super().__init__()
layers: list[nn.Module] = []
for i in range(n_layers):
in_size = obs_size if i == 0 else layer_size
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
layers.append(nn.Linear(layer_size, n_actions))
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)

View File

@@ -0,0 +1,49 @@
import numpy as np
from retro_gamer.metadata import GameMetadata
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """One-hot encode a grid of board characters.

    The result has shape (H, W, C): H rows, W columns (taken from the first
    row) and C = len(character_set) channels. Characters not in
    character_set leave their cell all-zero.
    """
    channel_of = {ch: channel for channel, ch in enumerate(character_set)}
    n_rows = len(board_chars)
    n_cols = len(board_chars[0]) if board_chars else 0
    encoded = np.zeros((n_rows, n_cols, len(character_set)), dtype=np.float32)
    for row_idx, row in enumerate(board_chars):
        for col_idx, ch in enumerate(row):
            if ch in channel_of:
                encoded[row_idx, col_idx, channel_of[ch]] = 1.0
    return encoded
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Collect the observed state keys into a 1-D float32 array.

    Keys absent from ``state`` contribute 0.
    """
    values = [float(state.get(key, 0)) for key in observe_state]
    return np.array(values, dtype=np.float32)
def encode_observation(
    board_chars: list[list[str]],
    state: dict,
    metadata: GameMetadata,
) -> np.ndarray:
    """Encode the board and observed state as one flat 1-D vector.

    Spatial games flatten the board channel-first (C, H, W) so the network
    can reshape it for its conv layers; non-spatial games flatten it as
    (H, W, C). The observed state values always come last.
    Raises ValueError when the metadata has no character_set.
    """
    if not metadata.character_set:
        raise ValueError("character_set must be set before encoding observations")

    one_hot = encode_board(board_chars, metadata.character_set)  # (H, W, C)
    if metadata.spatial:
        one_hot = one_hot.transpose(2, 0, 1)  # -> (C, H, W)
    observed = encode_state(state, metadata.observe_state)
    return np.concatenate([one_hot.flatten(), observed])

255
retro_gamer/trainer.py Normal file
View File

@@ -0,0 +1,255 @@
from __future__ import annotations
import random
from pathlib import Path
from typing import Callable
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tomli_w
from retro_gamer.metadata import GameMetadata
from retro_gamer.env import GameEnvironment
from retro_gamer.network import build_network
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
# Baseline hyperparameters. DQNTrainer overlays caller-supplied values on top
# of these ({**DEFAULTS, **hyperparams}); the CLI uses them as option defaults.
DEFAULTS: dict = {
    'learning_rate': 1e-3,  # Adam optimizer learning rate
    'lr_decay': 0.995,  # multiplicative LR decay (ExponentialLR gamma), stepped once per episode
    'gamma': 0.99,  # discount factor for future rewards
    'epsilon': 1.0,  # initial probability of picking a random action
    'epsilon_decay': 0.995,  # epsilon is multiplied by this after each episode
    'epsilon_min': 0.05,  # floor below which epsilon is not decayed
    'batch_size': 64,  # experiences per training step
    'memory_capacity': 10_000,  # replay buffer size
    'target_update_freq': 100,  # steps between target-network syncs
    'training_episodes': 1_000,  # number of episodes train() runs
    'n_layers': 2,  # hidden layers in the MLP head
    'layer_size': 128,  # width of each hidden layer
    'prioritize_experiences': False,  # True selects PrioritizedReplayMemory
    'exploration_turns': 200,  # random turns for character discovery
    # NOTE(review): no code in this view reads this key — confirm where it is used.
    'unknown_character_strategy': 'ignore',
    'max_turns_per_episode': 2_000,  # turn limit per episode
}
class DQNTrainer:
"""Trains a deep Q-network agent to play a retro game.
On initialization the trainer:
1. Discovers the character set (if not already specified in metadata).
2. Builds the Q-network and logs the full architecture with rationale.
3. Saves config.toml and starts training.log in run_dir.
Call train() to run all episodes and save checkpoints.
"""
def __init__(
self,
game_factory: Callable,
metadata: GameMetadata,
run_dir: str | Path,
**hyperparams,
):
self.game_factory = game_factory
self.metadata = metadata
self.run_dir = Path(run_dir)
self.hp: dict = {**DEFAULTS, **hyperparams}
self.run_dir.mkdir(parents=True, exist_ok=True)
(self.run_dir / 'checkpoints').mkdir(exist_ok=True)
self.env = GameEnvironment(game_factory, metadata)
if metadata.board_size is None:
g = game_factory()
metadata.board_size = g.board_size
if metadata.character_set is None:
self._discover_character_set()
self.model, rationale = build_network(metadata, self.hp)
self.target_model, _ = build_network(metadata, self.hp)
self.target_model.load_state_dict(self.model.state_dict())
self.target_model.eval()
self.optimizer = optim.Adam(
self.model.parameters(), lr=self.hp['learning_rate']
)
self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
self.optimizer, gamma=self.hp['lr_decay']
)
if self.hp['prioritize_experiences']:
self.memory = PrioritizedReplayMemory(self.hp['memory_capacity'])
else:
self.memory = ReplayMemory(self.hp['memory_capacity'])
self.epsilon: float = self.hp['epsilon']
self.total_steps: int = 0
self._save_config()
self._open_log(rationale)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def train(self):
"""Run all training episodes and save checkpoints."""
for episode in range(1, self.hp['training_episodes'] + 1):
total_reward, steps, avg_loss = self._run_episode()
self.epsilon = max(
self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
)
self.lr_scheduler.step()
self._log_episode(episode, total_reward, steps, avg_loss)
if episode % 100 == 0:
self._save_checkpoint(f'ep_{episode:04d}.pt')
self._save_checkpoint('final.pt')
def load_checkpoint(self, path: str | Path):
ckpt = torch.load(path, weights_only=True)
self.model.load_state_dict(ckpt['model_state_dict'])
self.target_model.load_state_dict(ckpt['model_state_dict'])
self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
self.epsilon = ckpt['epsilon']
self.total_steps = ckpt['total_steps']
# ------------------------------------------------------------------
# Training loop internals
# ------------------------------------------------------------------
def _run_episode(self) -> tuple[float, int, float]:
state = self.env.reset()
total_reward = 0.0
total_loss = 0.0
loss_count = 0
for step in range(self.hp['max_turns_per_episode']):
state_t = torch.as_tensor(state, dtype=torch.float32)
action_idx = self._select_action(state_t)
action_key = self._idx_to_key(action_idx)
next_state, reward, done = self.env.step(action_key)
self.memory.push(state, action_idx, reward, next_state, done)
loss = self._train_step()
if loss is not None:
total_loss += loss
loss_count += 1
self.total_steps += 1
if self.total_steps % self.hp['target_update_freq'] == 0:
self.target_model.load_state_dict(self.model.state_dict())
state = next_state
total_reward += reward
if done:
break
avg_loss = total_loss / loss_count if loss_count else 0.0
return total_reward, step + 1, avg_loss
def _select_action(self, state_t: torch.Tensor) -> int:
if random.random() < self.epsilon:
return random.randrange(self.metadata.n_actions)
with torch.no_grad():
return int(self.model(state_t.unsqueeze(0)).argmax().item())
def _idx_to_key(self, idx: int) -> str | None:
if idx >= len(self.metadata.actions):
return None
return self.metadata.actions[idx]
def _train_step(self) -> float | None:
if len(self.memory) < self.hp['batch_size']:
return None
if self.hp['prioritize_experiences']:
assert isinstance(self.memory, PrioritizedReplayMemory)
experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
weight_t = torch.as_tensor(weights, dtype=torch.float32)
else:
experiences = self.memory.sample(self.hp['batch_size'])
indices = None
weight_t = None
states = torch.as_tensor(
np.array([e.state for e in experiences]), dtype=torch.float32
)
actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long)
rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32)
next_states = torch.as_tensor(
np.array([e.next_state for e in experiences]), dtype=torch.float32
)
dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32)
q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
with torch.no_grad():
next_q = self.target_model(next_states).max(1).values
targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)
element_loss = nn.functional.mse_loss(q_values, targets, reduction='none')
if weight_t is not None:
loss = (weight_t * element_loss).mean()
td_errors = (q_values - targets).detach().abs().numpy()
self.memory.update_priorities(indices, td_errors)
else:
loss = element_loss.mean()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return float(loss.item())
# ------------------------------------------------------------------
# Initialisation helpers
# ------------------------------------------------------------------
def _discover_character_set(self):
chars = self.env.discover_character_set(self.hp['exploration_turns'])
self.metadata.character_set = chars
self._log_raw(
f"[INIT] character_set not specified — discovered {len(chars)} chars "
f"after {self.hp['exploration_turns']} exploration turns: {chars}"
)
def _save_config(self):
    """Write metadata and hyperparameters to ``run_dir/config.toml``.

    If the file already exists it is loaded first, so top-level tables
    other than ``metadata`` and ``hyperparameters`` are kept.
    """
    config_path = self.run_dir / 'config.toml'
    config: dict = {}
    if config_path.exists():
        import tomllib
        with open(config_path, 'rb') as f:
            config = tomllib.load(f)
    config['metadata'] = self.metadata.to_dict()
    config['hyperparameters'] = self.hp
    # Serialize before opening for writing: opening with 'wb' truncates
    # the file, so a serialization error raised after that point would
    # destroy the existing config instead of leaving it untouched.
    payload = tomli_w.dumps(config).encode()
    with open(config_path, 'wb') as f:
        f.write(payload)
def _open_log(self, rationale: str):
self.log_path = self.run_dir / 'training.log'
with open(self.log_path, 'w') as f:
f.write(rationale + '\n')
def _log_raw(self, line: str):
with open(self.log_path, 'a') as f:
f.write(line + '\n')
def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float):
line = (
f"[EP {episode:04d}] total_reward={total_reward:.1f} "
f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}"
)
self._log_raw(line)
def _save_checkpoint(self, name: str):
torch.save(
{
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'epsilon': self.epsilon,
'total_steps': self.total_steps,
},
self.run_dir / 'checkpoints' / name,
)