retro-gamer/plan.md
Chris Proctor 5ca97dc5d0 Initial commit
2026-05-08 14:07:17 -04:00


retro-gamer Implementation Plan

Overview

retro-gamer is a Python package that trains deep Q-learning agents to play games built with the retro-games framework. Students specify game metadata (character set, actions, reward signal, etc.) and adjust training hyperparameters, letting the training scaffold make principled architecture decisions on their behalf. The package is designed to be pedagogically transparent: every design decision is logged with a plain-language rationale.


Package Structure

retro_gamer/
    __init__.py
    metadata.py       # GameMetadata dataclass + TOML load/save + validation
    env.py            # GameEnvironment — wraps a Game in a step/reset/observe interface
    observation.py    # Converts board + state_dict → numpy tensors (one-hot encoding, etc.)
    network.py        # Builds PyTorch CNN or MLP from metadata + hyperparameters
    memory.py         # ReplayMemory (standard FIFO and prioritized variants)
    trainer.py        # DQNTrainer — orchestrates training, logging, checkpointing
    cli.py            # Click-based CLI (create, train, play, info subcommands)

Phase 1 — retro-games Framework Changes

These changes must land in the retro-games framework (the retro package) before retro-gamer can work. See the "Required Changes to retro-games Framework" section at the bottom for details.


Phase 2 — GameMetadata (metadata.py)

GameMetadata is a validated dataclass that lives as a TOML file alongside each training run.

Required fields

| Field | Type | Description |
| --- | --- | --- |
| board_size | (int, int) | Width × height of the board |
| actions | list[str] | Keystroke names the agent may emit (e.g. "KEY_RIGHT", "q") |
| reward | str | Key in game.state used as the reward signal |

Optional fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| character_set | list[str] | None | Characters that may appear on the board. Each character gets a one-hot encoding slot. If absent, the trainer explores first to discover the set. |
| spatial | bool | True | When True, treat the board as a 2D spatial scene (use CNN). When False, flatten and use MLP. |
| observe_state | list[str] | [] | State dict keys to append to the observation vector. Values must be int, float, or bool. Must not include reward. |

TOML example

[game]
board_size = [32, 16]
actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]
reward = "score"
character_set = ["@", "*", ">", "<", "^", "v"]
spatial = true
observe_state = []

Validation

  • board_size must be two positive integers.
  • actions must be non-empty.
  • reward must not appear in observe_state.
  • character_set entries must be single characters.

Phase 3 — GameEnvironment (env.py)

GameEnvironment is retro-gamer's boundary with the retro-games framework. It owns the game lifecycle for one training episode and exposes a gym-style interface.

class GameEnvironment:
    def __init__(self, game_factory: Callable[[], Game], metadata: GameMetadata):
        ...

    def reset(self) -> Observation:
        """Calls game_factory() to get a fresh Game; returns initial observation."""

    def step(self, action: str | None) -> tuple[Observation, float, bool]:
        """
        Injects action into game, runs one turn via Game.step().
        Returns (observation, reward, done).
        Reward = current state[reward_key] - previous state[reward_key].
        """

    def observe(self) -> Observation:
        """Returns current observation without advancing the game."""

Observation is a numpy array constructed by observation.py.
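The reward rule in step() (current value minus previous value) can be exercised without the real framework. The sketch below uses a hypothetical FakeGame and a cut-down MiniEnvironment that omits observations. It also assumes an episode ends when game.playing becomes False, which this plan does not state.

```python
class FakeGame:
    """Hypothetical stand-in for a retro Game, used only for this sketch."""

    def __init__(self):
        self.state = {"score": 0}
        self.playing = True
        self.next_key = None

    def step(self):
        # Pressing "q" scores a point; the game ends at 2 points.
        if self.next_key == "q":
            self.state["score"] += 1
        self.next_key = None
        if self.state["score"] >= 2:
            self.playing = False


class MiniEnvironment:
    """Reward and done bookkeeping only; observations are omitted."""

    def __init__(self, game_factory, reward_key):
        self.game_factory = game_factory
        self.reward_key = reward_key
        self.game = None
        self.previous = 0

    def reset(self):
        # A fresh game per episode, as in GameEnvironment.reset().
        self.game = self.game_factory()
        self.previous = self.game.state[self.reward_key]

    def step(self, action):
        self.game.next_key = action  # stands in for ProgrammaticInput.press()
        self.game.step()
        current = self.game.state[self.reward_key]
        reward = float(current - self.previous)  # delta, not the raw score
        self.previous = current
        done = not self.game.playing  # assumption: playing=False ends the episode
        return reward, done


env = MiniEnvironment(FakeGame, reward_key="score")
env.reset()
results = [env.step("q"), env.step(None), env.step("q")]
```

Using the delta rather than the raw score rewards the agent only on the turn where the score changes.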

Character-set discovery

When character_set is not provided, GameEnvironment.reset() runs exploration_turns random turns, collects every character seen, and sets metadata.character_set. This also triggers a rebuild of any already-initialized network.
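A minimal sketch of the collection step, assuming boards arrive as 2D lists of single-character strings. The function name discover_characters is hypothetical. Treating the blank cell " " as background (so it gets no encoding slot) is also an assumption, chosen to match the TOML example, which lists no space character.

```python
def discover_characters(boards):
    """Collect every distinct non-blank character seen across sampled boards."""
    seen = set()
    for board in boards:
        for row in board:
            seen.update(row)
    seen.discard(" ")  # assumption: blank cells are background, not a channel
    # Sorting gives a stable channel order from run to run.
    return sorted(seen)


boards = [
    [[" ", "@", " "], [" ", " ", "*"]],
    [[">", " ", " "], [" ", "*", " "]],
]
character_set = discover_characters(boards)
```

A stable ordering matters because each character's index becomes its one-hot channel, and a saved checkpoint is only meaningful if that mapping does not change.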


Phase 4 — Observation Encoding (observation.py)

def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """
    Returns array of shape (H, W, C) where C = len(character_set).
    Each cell is a one-hot vector. Unknown characters become the zero vector.
    """

def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Returns 1-D float array of observed state values."""

def encode_observation(board_chars, state, metadata) -> np.ndarray:
    """Concatenates encoded board (flattened if not spatial) with encoded state."""

The board tensor shape fed to the network:

  • Spatial: (C, H, W), transposed from the (H, W, C) output of encode_board into PyTorch's channel-first layout for conv layers
  • Non-spatial: (H * W * C + len(observe_state),) flat vector
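One possible implementation of the two encoders, given as a sketch rather than the final code. It follows the docstrings above (unknown characters map to the zero vector) and shows both output layouts on a tiny 3×2 board. The float32 dtype is an assumption.

```python
import numpy as np


def encode_board(board_chars, character_set):
    """One-hot encode a 2D list of characters into an (H, W, C) array."""
    index = {ch: i for i, ch in enumerate(character_set)}
    height, width = len(board_chars), len(board_chars[0])
    encoded = np.zeros((height, width, len(character_set)), dtype=np.float32)
    for y, row in enumerate(board_chars):
        for x, ch in enumerate(row):
            if ch in index:  # unknown characters stay as the zero vector
                encoded[y, x, index[ch]] = 1.0
    return encoded


def encode_state(state, observe_state):
    """Return the observed state values as a 1-D float array."""
    return np.array([float(state[key]) for key in observe_state], dtype=np.float32)


board = [[" ", "@", " "], [" ", " ", "*"]]  # H=2, W=3
hwc = encode_board(board, ["@", "*"])       # shape (2, 3, 2)
chw = np.transpose(hwc, (2, 0, 1))          # channel-first, shape (2, 2, 3)
flat = np.concatenate([hwc.ravel(), encode_state({"lives": 2}, ["lives"])])
```

The flat vector has H * W * C + len(observe_state) = 2 * 3 * 2 + 1 = 13 entries, matching the non-spatial shape above.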

Phase 5 — Network Architecture (network.py)

A single builder function returns a torch.nn.Module together with a plain-language rationale string.

build_network(metadata, hyperparams) -> tuple[nn.Module, str]

The function picks the architecture based on metadata.spatial and returns its rationale as a string alongside the model, so the trainer can write it to the log.

Spatial (CNN)

Conv2d(in=C, out=32, kernel=3) → ReLU
Conv2d(32, 64, kernel=3)       → ReLU
Flatten
Linear(auto-computed, layer_size)  → ReLU  (repeated n_layers times)
Linear(layer_size, n_actions)
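The "auto-computed" input width of the first Linear layer follows from convolution arithmetic. The sketch below assumes stride 1 and no padding. Neither is stated in this plan, but the assumption reproduces the 64×28×12 figure in the sample log for a 32×16 board. The function names are hypothetical.

```python
def conv_output_size(size, kernel=3, stride=1, padding=0):
    """Output length along one axis of a convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_cnn_size(board_size, out_channels=64, n_conv=2, kernel=3):
    """Number of features after the conv stack is flattened."""
    width, height = board_size
    for _ in range(n_conv):
        width = conv_output_size(width, kernel)
        height = conv_output_size(height, kernel)
    return out_channels * width * height


# 32 -> 30 -> 28 wide, 16 -> 14 -> 12 high, 64 channels
flat = flattened_cnn_size((32, 16))
```

Each unpadded 3×3 convolution trims two cells from each axis, so boards narrower than 5 cells in either direction would leave nothing after two layers; the builder may want to fall back to the MLP in that case.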

Non-spatial (MLP)

Linear(obs_size, layer_size) → ReLU  (repeated n_layers times)
Linear(layer_size, n_actions)

Advanced override

Students who know PyTorch can pass a custom nn.Module subclass via --model-class in the CLI or directly to DQNTrainer. This bypasses the builder entirely. The log will note that a custom model was supplied.


Phase 6 — Replay Memory (memory.py)

Two classes with the same interface:

  • ReplayMemory(capacity) — standard ring buffer.
  • PrioritizedReplayMemory(capacity, alpha, beta) — priority-weighted sampling based on temporal-difference error magnitude. Used when prioritize_experiences=True.

Typical calls:

memory.push(state, action, reward, next_state, done)
batch = memory.sample(batch_size)
memory.update_priorities(indices, td_errors)  # PER only
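The standard buffer is small enough to sketch in full. This version assumes uniform random sampling and uses a bounded deque as the ring buffer. The Experience tuple name is hypothetical.

```python
import random
from collections import deque, namedtuple

# Hypothetical record type for one transition.
Experience = namedtuple("Experience", "state action reward next_state done")


class ReplayMemory:
    """Fixed-capacity FIFO buffer with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen silently drops the oldest item when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append(Experience(state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)


memory = ReplayMemory(capacity=3)
for turn in range(5):
    memory.push(state=turn, action="q", reward=0.0, next_state=turn + 1, done=False)
batch = memory.sample(2)
```

After five pushes into a capacity-3 buffer only the three most recent transitions remain. The trainer should wait until len(memory) >= batch_size before sampling, since random.sample raises ValueError otherwise.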

Phase 7 — DQN Trainer (trainer.py)

Hyperparameters

| Name | Default | Description |
| --- | --- | --- |
| learning_rate | 1e-3 | Adam optimizer LR |
| lr_decay | 0.995 | Multiplicative LR decay per episode |
| gamma | 0.99 | Discount factor |
| epsilon | 1.0 | Initial ε for ε-greedy exploration |
| epsilon_decay | 0.995 | ε decay per episode |
| epsilon_min | 0.05 | Minimum ε |
| batch_size | 64 | Experiences per training step |
| memory_capacity | 10000 | Replay buffer size |
| target_update_freq | 100 | Steps between target-network updates |
| training_episodes | 1000 | Number of episodes |
| n_layers | 2 | Hidden layers in MLP head |
| layer_size | 128 | Width of each hidden layer |
| prioritize_experiences | False | Use prioritized experience replay |
| exploration_turns | 200 | Random turns used to discover character set (when not specified) |
| unknown_character_strategy | "ignore" | "ignore" or "extend" (rebuild model when new character found) |
| max_turns_per_episode | 2000 | Safety cutoff for a single episode |
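To see how the exploration defaults interact, the schedule can be tabulated directly. This sketch assumes ε is multiplied by epsilon_decay once at the end of each episode and clamped at epsilon_min. The exact point at which the decay is applied is not specified in this plan.

```python
def epsilon_schedule(epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.05, episodes=1000):
    """Return the epsilon value in effect during each episode."""
    values = []
    for _ in range(episodes):
        values.append(epsilon)
        # Decay once per episode, never dropping below the floor.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
    return values


schedule = epsilon_schedule()
# Zero-based index of the first episode played at the floor value.
first_at_min = next(i for i, eps in enumerate(schedule) if eps <= 0.05)
```

With the defaults, ε reaches its floor of 0.05 after 598 episodes, so roughly the last 40% of a 1000-episode run is played almost greedily. Students who shorten training_episodes may want a smaller epsilon_decay to keep that balance.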

Training directory structure

run_dir/              (e.g. snake_20240501_120000/)
    config.toml       # metadata + all hyperparameters, frozen at training start
    training.log      # structured log (see below)
    checkpoints/
        ep_0100.pt
        ep_0500.pt
        final.pt

Log file format

The first log entry (written at init) is a full, human-readable description of the model architecture and every design decision, e.g.:

[INIT] Board: 32×16, character set: 6 chars → input channels = 6
[INIT] spatial=True → using CNN architecture
[INIT] CNN layers: Conv(6→32, k=3), Conv(32→64, k=3)
[INIT] Flattened CNN output: 64×28×12 = 21504 → Linear(21504, 128) → Linear(128, 5)
[INIT] Actions: KEY_RIGHT KEY_UP KEY_LEFT KEY_DOWN (none) = 5 outputs
[INIT] Optimizer: Adam, lr=0.001, decay=0.995
[INIT] Exploration: ε-greedy starting at ε=1.0, min=0.05
...
[EP 0001] total_reward=3  steps=145  epsilon=0.995  avg_loss=0.043

DQNTrainer interface

class DQNTrainer:
    def __init__(self, game_factory, metadata, run_dir, **hyperparams): ...
    def train(self): ...             # runs all episodes, saves checkpoints
    def load_checkpoint(self, path): ...

Phase 8 — CLI (cli.py)

Implemented with Click, exposed as the retro-gamer entry point.

retro-gamer create

Creates a new training run directory with a config.toml stub.

retro-gamer create --game snake --output ./runs/snake/ \
    --actions KEY_RIGHT,KEY_UP,KEY_LEFT,KEY_DOWN \
    --reward score --board-size 32x16 \
    [--character-set @,*,>,<,^,v] [--spatial] [--no-spatial] \
    [--n-layers 2] [--layer-size 128] [--training-episodes 1000] ...

retro-gamer train

Runs (or resumes) training.

retro-gamer train ./runs/snake/
retro-gamer train ./runs/snake/ --resume checkpoints/ep_0500.pt

retro-gamer play

Loads a saved agent and plays the game visually.

retro-gamer play ./runs/snake/ --checkpoint final
retro-gamer play ./runs/snake/ --checkpoint checkpoints/ep_0100.pt

retro-gamer info

Prints a summary of the run: metadata, hyperparameters, latest training stats.


Implementation Order

  1. Land retro-games framework changes (input abstraction, step(), view refactor).
  2. metadata.py + tests — validate, load, save TOML.
  3. env.py — wraps game; unit-testable with a trivial fake game.
  4. observation.py — encoding functions; purely functional, easy to test.
  5. network.py — builders; test by checking output shapes.
  6. memory.py — both buffer types.
  7. trainer.py — integration; test with a tiny fast game.
  8. cli.py — wire everything together.


Required Changes to retro-games Framework

Design principle

These changes are driven by good software design — clean separation of input handling and rendering from game logic — not by retro-gamer's RL concepts. The refactoring should stand on its own as an improvement to the retro-games framework.


1. Factor out input handling (input.py)

Problem: collect_keystrokes(terminal) is baked into the Game.play() loop. There is no way for external code to supply keystrokes programmatically without faking blessed objects, and there is no clean seam to test game logic in isolation from terminal I/O.

Refactor: Introduce an InputSource protocol and two implementations. Game holds an input_source; step() calls self.input_source.collect() regardless of where input originates.

# retro/input.py

class InputSource:
    """Protocol for objects that supply keystrokes to the game each turn."""
    def collect(self) -> set:
        raise NotImplementedError

class TerminalInput(InputSource):
    """Reads keystrokes from a blessed Terminal (current behaviour)."""
    def __init__(self, terminal):
        self.terminal = terminal

    def collect(self) -> set:
        keys = set()
        while True:
            key = self.terminal.inkey(0.001)
            if key:
                keys.add(key)
            else:
                break
        return keys

class ProgrammaticInput(InputSource):
    """Accepts keystrokes injected by external code (e.g. retro-gamer)."""
    def __init__(self):
        self._pending: str | None = None

    def press(self, key: str | None):
        """Queue a single keystroke for the next turn.
        key should be a key name string (e.g. "KEY_RIGHT", "q") or None.
        """
        self._pending = key

    def collect(self) -> set:
        key = self._pending
        self._pending = None
        if key is None:
            return set()
        return {_make_keystroke(key)}

def _make_keystroke(key_str: str):
    """Creates a blessed-compatible keystroke from a string."""
    from blessed.keyboard import Keystroke
    if key_str.startswith("KEY_"):
        return Keystroke(ucs='', code=None, name=key_str)
    return Keystroke(ucs=key_str, code=None, name=None)

The _make_keystroke helper is an internal detail of the input module — Game and agents never know they're receiving synthetic keystrokes, and the RL framing never leaks into Game.

Game.__init__ gains an input_source: InputSource | None = None parameter. When None, play() creates a TerminalInput internally (preserving existing usage); when provided explicitly, it is used in step().

retro-gamer usage:

inp = ProgrammaticInput()
game = create_game()       # create_game takes no arguments (see section 4)
game.input_source = inp    # so the input source is attached after construction
inp.press("KEY_RIGHT")
game.step()

2. Game.step() and its relationship to Game.play()

Refactor: Extract the per-turn logic from play() into a step() method. play() then loops over step() calls, adding only the terminal context management, rendering, and frame-rate sleep on top. This ensures game logic is exercised via the same code path whether training or playing.

def step(self):
    """Run one turn: collect input, let each agent act, advance state."""
    self.turn_number += 1
    self.keys_pressed = self.input_source.collect()
    if self.debug and self.keys_pressed:
        self.log("Keys: " + ', '.join(k.name or str(k) for k in self.keys_pressed))
    self.prior_view_position = self.view_position
    self.prior_agent_positions = self.agent_positions
    for agent in self.agents:
        if hasattr(agent, 'handle_keystroke'):
            for key in self.keys_pressed:
                agent.handle_keystroke(key, self)
        if hasattr(agent, 'play_turn'):
            agent.play_turn(self)
            self._position_cache = None
        if getattr(agent, 'display', True):
            for pos in agent_occupied_positions(agent):
                if not self.on_board(pos):
                    raise IllegalMove(agent, pos)
    self.agent_positions = self.get_agents_by_position()
    self.state.changed = False

def play(self):
    """Run the game loop in a terminal with rendering and frame-rate control."""
    self.playing = True
    terminal = Terminal()
    terminal_input = TerminalInput(terminal)
    self.input_source = terminal_input
    with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
        view = TerminalView(terminal, color=self.color)
        self.agent_positions = {}
        self.state.changed = True
        while self.playing:
            turn_start = perf_counter()
            self.step()
            view.render(self)
            turn_elapsed = perf_counter() - turn_start
            sleep(max(0, 1 / self.framerate - turn_elapsed))
        if self.dump_state:
            ...
        if self.wait_for_enter:
            ...

Note: step() does not return anything. It simply advances game state. retro-gamer reads back whatever it needs from the game object (board characters via the HeadlessView — see below) after calling step(). This keeps step() free of RL-specific return values.


3. Refactor view rendering: View protocol, TerminalView, HeadlessView

Problem: All rendering code lives in one View class that is tightly coupled to blessed.Terminal. There is no way to run the game loop without terminal output, and no way for external code to read board state.

Refactor: Separate the concept of "rendering" from game logic by defining a View protocol and two implementations.

retro/view.py       → retro/views/terminal.py   (TerminalView — current View, renamed)
                    → retro/views/headless.py    (HeadlessView — new)
                    → retro/views/__init__.py    (exports View protocol + both classes)

View protocol

# retro/views/__init__.py
from typing import Protocol

class View(Protocol):
    def on_game_start(self, game) -> None: ...
    def render(self, game) -> None: ...

TerminalView

The current View class, moved to retro/views/terminal.py and renamed TerminalView. No behaviour changes — this is a pure rename/move.

HeadlessView

# retro/views/headless.py

class HeadlessView:
    """Maintains a readable board state without any terminal output.
    After each game.step(), board_characters reflects the current board.
    """
    def __init__(self):
        self.board_characters: list[list[str]] = []

    def on_game_start(self, game) -> None:
        bw, bh = game.board_size
        self.board_characters = [[' '] * bw for _ in range(bh)]

    def render(self, game) -> None:
        """Recompute board_characters from current agent positions."""
        bw, bh = game.board_size
        board = [[' '] * bw for _ in range(bh)]
        for (x, y), agents in game.get_agents_by_position().items():
            top = max(agents, key=lambda a: getattr(a, 'z', 0) or 0)
            board[y][x] = get_agent_character(top, (x, y))
        self.board_characters = board

get_agent_character is already defined at module level in the current view.py; it moves to retro/views/_util.py and is imported by both views.

How Game uses a View

Game.__init__ gains an optional view: View | None = None parameter. play() passes a freshly created TerminalView (as today), overriding whatever was set. When step() is called directly (as retro-gamer does), game.view.render(game) is called at the end of each step if a view is set.

# in Game.__init__
self.view = view  # None by default

# in Game.step()  (end of step)
if self.view is not None:
    self.view.render(self)

retro-gamer usage:

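headless = HeadlessView()
game = create_game()       # create_game takes no arguments (see section 4)
game.view = headless       # so the view is attached after construction
game.step()
board = headless.board_characters   # 2D list of chars, ready to encode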
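Because HeadlessView only needs board_size and get_agents_by_position(), it can be unit-tested against a fake game. In the sketch below, FakeAgent and FakeGame are hypothetical test doubles, and get_agent_character is a simplified stand-in that just returns agent.character; the real helper may do more.

```python
class FakeAgent:
    """Hypothetical single-cell agent for testing."""
    def __init__(self, character, position, z=0):
        self.character = character
        self.position = position
        self.z = z


class FakeGame:
    """Hypothetical game exposing only what HeadlessView reads."""
    board_size = (4, 2)  # width, height

    def __init__(self, agents):
        self.agents = agents

    def get_agents_by_position(self):
        by_position = {}
        for agent in self.agents:
            by_position.setdefault(agent.position, []).append(agent)
        return by_position


def get_agent_character(agent, position):
    # Simplified stand-in for the framework helper.
    return agent.character


class HeadlessView:
    def __init__(self):
        self.board_characters = []

    def render(self, game):
        bw, bh = game.board_size
        board = [[" "] * bw for _ in range(bh)]
        for (x, y), agents in game.get_agents_by_position().items():
            # The agent with the highest z wins a shared cell.
            top = max(agents, key=lambda a: getattr(a, "z", 0) or 0)
            board[y][x] = get_agent_character(top, (x, y))
        self.board_characters = board


game = FakeGame([
    FakeAgent("@", (1, 0)),
    FakeAgent("*", (3, 1)),
    FakeAgent("#", (1, 0), z=1),  # overlaps "@" and sits above it
])
view = HeadlessView()
view.render(game)
rows = ["".join(row) for row in view.board_characters]
```

Here rows is [" #  ", "   *"]: the "#" agent hides "@" because of its higher z, which is the same layering a player would see in the terminal.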

Why this matters

This separation means:

  • Game logic (agent turns, state transitions) has no terminal dependency.
  • TerminalView can evolve independently (e.g., colour schemes, layout changes).
  • Future rendering targets (web, test harness) are first-class citizens, not hacks.
  • retro-gamer reads board state from HeadlessView.board_characters — a clean, stable API.

4. create_game() factory convention and entry points

Convention: Every game module must expose a no-argument function named create_game that returns a fully initialized Game instance. Standardizing the name is the right call — it keeps retro-gamer simple and makes the contract explicit for students.

# retro/examples/snake.py — add at bottom:
def create_game():
    head = SnakeHead()
    apple = Apple()
    game = Game([head, apple], {'score': 0}, board_size=(32, 16), framerate=12)
    apple.relocate(game)
    return game

retro-gamer loads a game as:

import importlib
module = importlib.import_module('snake')   # or 'retro.examples.snake'
game = module.create_game()

The CLI accepts a dotted module name: retro-gamer create --game retro.examples.snake.
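A sketch of how the CLI might resolve that module name into a factory, with a clear error when the convention is not followed. The load_game_factory name is hypothetical, and the demonstration registers an in-memory module so the example runs without any game installed.

```python
import importlib
import sys
import types


def load_game_factory(module_name):
    """Import a module by dotted name and return its create_game callable."""
    module = importlib.import_module(module_name)
    factory = getattr(module, "create_game", None)
    if not callable(factory):
        raise AttributeError(
            f"{module_name} does not define a callable create_game()"
        )
    return factory


# Demonstration only: register a throwaway module instead of a real game.
demo = types.ModuleType("demo_game")
demo.create_game = lambda: "game instance"
sys.modules["demo_game"] = demo

factory = load_game_factory("demo_game")
game = factory()
```

Returning the factory rather than a Game instance fits GameEnvironment, which calls game_factory() again on every reset().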

What are entry points and should we use them?

Python package entry points are a mechanism for installed packages to advertise named callables that other packages can discover at runtime, without either side hard-coding import paths. They are declared in pyproject.toml:

# In the game package's pyproject.toml:
[project.entry-points."retro.games"]
snake = "retro.examples.snake:create_game"

retro-gamer can then discover all registered games without knowing their module names:

from importlib.metadata import entry_points
games = {ep.name: ep.load() for ep in entry_points(group="retro.games")}
# games == {'snake': <function create_game at 0x...>}

This is the same mechanism pytest uses to discover plugins and Flask uses to discover additional CLI commands.

Recommendation: Support both. For day-to-day use, students specify a module name and retro-gamer calls module.create_game(). For published or installed game packages (e.g., a course package with many games), the entry point mechanism lets retro-gamer discover games automatically (retro-gamer list could enumerate all installed games). These are complementary, not competing.

The standardized function name create_game is still required even when using entry points, because retro-gamer also needs to call the function without going through the entry point registry (e.g., when the module is on sys.path but not installed as a package).


Summary of retro changes

| Change | Where | Complexity |
| --- | --- | --- |
| InputSource protocol + TerminalInput + ProgrammaticInput | new retro/input.py | Medium |
| _make_keystroke helper | retro/input.py (internal) | Small |
| Game.__init__ accepts input_source, view | game.py | Small |
| Extract Game.step() from Game.play() | game.py | Medium |
| View protocol | new retro/views/__init__.py | Small |
| TerminalView (rename + move current View) | retro/view.py → retro/views/terminal.py | Small (rename) |
| HeadlessView | new retro/views/headless.py | Small |
| get_agent_character moved to shared util | retro/views/_util.py | Trivial |
| create_game() factory in snake.py + docs | snake.py | Trivial |