retro-gamer Implementation Plan
Overview
retro-gamer is a Python package that trains deep Q-learning agents to play games built with the retro-games framework. Students specify game metadata (character set, actions, reward signal, etc.) and adjust training hyperparameters, letting the training scaffold make principled architecture decisions on their behalf. The package is designed to be pedagogically transparent: every design decision is logged with a plain-language rationale.
Package Structure
retro_gamer/
    __init__.py
    metadata.py       # GameMetadata dataclass + TOML load/save + validation
    env.py            # GameEnvironment: wraps a Game in a step/reset/observe interface
    observation.py    # Converts board + state_dict → numpy tensors (one-hot encoding, etc.)
    network.py        # Builds PyTorch CNN or MLP from metadata + hyperparameters
    memory.py         # ReplayMemory (standard FIFO and prioritized variants)
    trainer.py        # DQNTrainer: orchestrates training, logging, checkpointing
    cli.py            # Click-based CLI (create, train, play, info subcommands)
Phase 1 — retro-games Framework Changes
These changes must land in the retro package before retro-gamer can work. See the "Required Changes to retro-games Framework" section at the bottom for details.
Phase 2 — GameMetadata (metadata.py)
GameMetadata is a validated dataclass. It is persisted as a TOML file alongside each training run.
Required fields
| Field | Type | Description |
|---|---|---|
| board_size | (int, int) | Width × height of the board |
| actions | list[str] | Keystroke names the agent may emit (e.g. "KEY_RIGHT", "q") |
| reward | str | Key in game.state used as the reward signal |
Optional fields
| Field | Type | Default | Description |
|---|---|---|---|
| character_set | list[str] | None | Characters that may appear on the board. Each character gets a one-hot encoding slot. If absent, the trainer explores first to discover the set. |
| spatial | bool | True | When True, treat the board as a 2D spatial scene (use CNN). When False, flatten and use MLP. |
| observe_state | list[str] | [] | State dict keys to append to the observation vector. Values must be int, float, or bool. Must not include the reward key. |
TOML example
[game]
board_size = [32, 16]
actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]
reward = "score"
character_set = ["@", "*", ">", "<", "^", "v"]
spatial = true
observe_state = []
Validation
- board_size must be two positive integers.
- actions must be non-empty.
- reward must not appear in observe_state.
- character_set entries must be single characters.
Phase 3 — GameEnvironment (env.py)
GameEnvironment is retro-gamer's boundary with the retro-games framework. It owns the game lifecycle for one training episode and exposes a gym-style interface.
class GameEnvironment:
    def __init__(self, game_factory: Callable[[], Game], metadata: GameMetadata):
        ...

    def reset(self) -> Observation:
        """Calls game_factory() to get a fresh Game; returns initial observation."""

    def step(self, action: str | None) -> tuple[Observation, float, bool]:
        """
        Injects action into game, runs one turn via Game.step().
        Returns (observation, reward, done).
        Reward = current state[reward_key] - previous state[reward_key].
        """

    def observe(self) -> Observation:
        """Returns current observation without advancing the game."""
Observation is a numpy array constructed by observation.py.
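The reward rule in step() (the change in state[reward_key] between turns) can be exercised without the framework. A minimal sketch, assuming a hypothetical FakeGame stand-in and a hypothetical RewardTracker helper; observation encoding and the done flag are left out:

```python
class FakeGame:
    """Hypothetical stand-in for a retro Game: score rises by 2 each turn."""
    def __init__(self):
        self.state = {"score": 0}

    def step(self):
        self.state["score"] += 2


class RewardTracker:
    """Turns a cumulative state value into a per-step reward."""
    def __init__(self, game_factory, reward_key):
        self.game_factory = game_factory
        self.reward_key = reward_key
        self.game = None
        self._previous = 0.0

    def reset(self):
        # A fresh game per episode, as GameEnvironment.reset() describes.
        self.game = self.game_factory()
        self._previous = float(self.game.state[self.reward_key])

    def step(self):
        self.game.step()
        current = float(self.game.state[self.reward_key])
        reward = current - self._previous  # delta, not the running total
        self._previous = current
        return reward


tracker = RewardTracker(FakeGame, "score")
tracker.reset()
rewards = [tracker.step() for _ in range(3)]
```

Using the delta keeps the per-step reward at 2.0 here even though the score itself keeps growing.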
Character-set discovery
When character_set is not provided, GameEnvironment.reset() runs exploration_turns random turns, collects every character seen, and sets metadata.character_set. This also triggers a rebuild of any already-initialized network.
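The discovery step amounts to collecting a set of characters over sampled boards. A minimal sketch, assuming boards arrive as 2D lists of single-character strings; skipping the blank cell and sorting the result for a stable channel order are assumptions, and discover_character_set is a hypothetical helper name:

```python
def discover_character_set(boards, blank=" "):
    """Collect every non-blank character seen across sampled boards.

    boards is an iterable of 2D lists of single-character strings.
    Sorting gives a stable channel order between runs.
    """
    seen = set()
    for board in boards:
        for row in board:
            seen.update(ch for ch in row if ch != blank)
    return sorted(seen)


# Two boards sampled during exploration.
sampled = [
    [["@", " "], [" ", "*"]],
    [[" ", ">"], ["*", " "]],
]
character_set = discover_character_set(sampled)
```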
Phase 4 — Observation Encoding (observation.py)
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """
    Returns array of shape (H, W, C) where C = len(character_set).
    Each cell is a one-hot vector. Unknown characters become the zero vector.
    """

def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Returns 1-D float array of observed state values."""

def encode_observation(board_chars, state, metadata) -> np.ndarray:
    """Concatenates encoded board (flattened if not spatial) with encoded state."""
The board tensor shape fed to the network:
- Spatial: (C, H, W) (PyTorch channel-first for conv layers)
- Non-spatial: (H * W * C + len(observe_state),) flat vector
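The shapes above can be checked on a tiny board. A minimal sketch, assuming NumPy is available; doing the (H, W, C) to (C, H, W) conversion with a transpose is an assumption about where that step happens, and the appended 3.0 stands in for a single observed state value:

```python
import numpy as np


def encode_board(board_chars, character_set):
    """One-hot encode a board into shape (H, W, C); unknown chars stay zero."""
    index = {ch: i for i, ch in enumerate(character_set)}
    height, width = len(board_chars), len(board_chars[0])
    encoded = np.zeros((height, width, len(character_set)), dtype=np.float32)
    for y, row in enumerate(board_chars):
        for x, ch in enumerate(row):
            channel = index.get(ch)
            if channel is not None:
                encoded[y, x, channel] = 1.0
    return encoded


board = [["@", " ", "*"], [" ", "?", " "]]  # "?" is not in the character set
hwc = encode_board(board, ["@", "*"])

# Spatial case: move channels first for conv layers.
chw = np.transpose(hwc, (2, 0, 1))

# Non-spatial case: flatten the board and append observed state values.
flat = np.concatenate([hwc.ravel(), np.array([3.0], dtype=np.float32)])
```

Here H=2, W=3, C=2, so the flat vector has 2 * 3 * 2 + 1 = 13 entries.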
Phase 5 — Network Architecture (network.py)
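build_network is the single entry point. It delegates to a CNN builder or an MLP builder, and either one yields a torch.nn.Module.
build_network(metadata, hyperparams) -> nn.Module
The function picks the architecture based on metadata.spatial and records its rationale as a string, which is returned alongside the model so the trainer can log it.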
Spatial (CNN)
Conv2d(in=C, out=32, kernel=3) → ReLU
Conv2d(32, 64, kernel=3) → ReLU
Flatten
Linear(auto-computed, layer_size) → ReLU (repeated n_layers times)
Linear(layer_size, n_actions)
Non-spatial (MLP)
Linear(obs_size, layer_size) → ReLU (repeated n_layers times)
Linear(layer_size, n_actions)
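The "auto-computed" input width of the first Linear layer in the CNN follows from the conv output-size formula. A minimal sketch, assuming stride 1 and no padding, which is consistent with the 64×28×12 figure in the example log later in this plan; both function names are hypothetical:

```python
def conv_output_size(size, kernel=3, stride=1, padding=0):
    """Output length of one conv layer along a single dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_cnn_features(board_size, out_channels=64, n_conv_layers=2, kernel=3):
    """Input width of the first Linear layer after the conv stack."""
    width, height = board_size
    for _ in range(n_conv_layers):
        width = conv_output_size(width, kernel)
        height = conv_output_size(height, kernel)
    return out_channels * width * height


# 32x16 board: each 3x3 conv trims 2 from each dimension, giving 28x12.
features = flattened_cnn_features((32, 16))
```

For the 32×16 board this gives 64 × 28 × 12 = 21504.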
Advanced override
Students who know PyTorch can pass a custom nn.Module subclass via --model-class in the CLI or directly to DQNTrainer. This bypasses the builder entirely. The log will note that a custom model was supplied.
Phase 6 — Replay Memory (memory.py)
Two classes with the same interface:
- ReplayMemory(capacity): standard ring buffer.
- PrioritizedReplayMemory(capacity, alpha, beta): priority-weighted sampling based on temporal-difference error magnitude. Used when prioritize_experiences=True.
memory.push(state, action, reward, next_state, done)
batch = memory.sample(batch_size)
memory.update_priorities(indices, td_errors) # PER only
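The standard variant can be very small. A minimal sketch of the push/sample interface above, assuming uniform sampling without replacement; the Transition namedtuple and the use of collections.deque are illustrative choices, and the prioritized variant is not shown:

```python
import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", "state action reward next_state done")


class ReplayMemory:
    """Fixed-capacity FIFO buffer with uniform random sampling."""
    def __init__(self, capacity):
        # deque(maxlen=...) silently drops the oldest item when full.
        self._buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self._buffer.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Sampling without replacement; requires batch_size <= len(self).
        return random.sample(list(self._buffer), batch_size)

    def __len__(self):
        return len(self._buffer)


memory = ReplayMemory(capacity=3)
for step in range(5):
    memory.push(step, "KEY_RIGHT", 1.0, step + 1, False)
batch = memory.sample(2)
```

After five pushes into a capacity-3 buffer, only the three most recent transitions remain.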
Phase 7 — DQN Trainer (trainer.py)
Hyperparameters
| Name | Default | Description |
|---|---|---|
| learning_rate | 1e-3 | Adam optimizer LR |
| lr_decay | 0.995 | Multiplicative LR decay per episode |
| gamma | 0.99 | Discount factor |
| epsilon | 1.0 | Initial ε for ε-greedy exploration |
| epsilon_decay | 0.995 | ε decay per episode |
| epsilon_min | 0.05 | Minimum ε |
| batch_size | 64 | Experiences per training step |
| memory_capacity | 10000 | Replay buffer size |
| target_update_freq | 100 | Steps between target-network updates |
| training_episodes | 1000 | Number of episodes |
| n_layers | 2 | Hidden layers in MLP head |
| layer_size | 128 | Width of each hidden layer |
| prioritize_experiences | False | Use prioritized experience replay |
| exploration_turns | 200 | Random turns used to discover character set (when not specified) |
| unknown_character_strategy | "ignore" | "ignore" or "extend" (rebuild model when new character found) |
| max_turns_per_episode | 2000 | Safety cutoff for a single episode |
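The interaction between epsilon, epsilon_decay, and epsilon_min is easier to judge with numbers. A minimal sketch, assuming the decay is applied multiplicatively once per episode and clamped at the minimum; both helper names are hypothetical:

```python
def epsilon_at(episode, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.05):
    """Exploration rate after `episode` completed episodes."""
    return max(epsilon_min, epsilon * epsilon_decay ** episode)


def episodes_until_min(epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.05):
    """First episode count at which the decayed value reaches the floor."""
    episode = 0
    while epsilon * epsilon_decay ** episode > epsilon_min:
        episode += 1
    return episode


first = epsilon_at(1)
floor_episode = episodes_until_min()
```

With the defaults, ε reaches its floor of 0.05 after 598 episodes, so roughly the last 40% of a 1000-episode run explores at the minimum rate.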
Training directory structure
run_dir/                    (e.g. snake_20240501_120000/)
    config.toml             # metadata + all hyperparameters, frozen at training start
    training.log            # structured log (see below)
    checkpoints/
        ep_0100.pt
        ep_0500.pt
        final.pt
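The naming scheme in the tree above maps directly onto format specifiers. A minimal sketch, assuming the run directory is named after the game plus a start timestamp and that episode checkpoints are zero-padded to four digits, as the example names suggest; run_dir_name and checkpoint_path are hypothetical helpers:

```python
from datetime import datetime
from pathlib import Path


def run_dir_name(game_name, started_at):
    """Directory name such as snake_20240501_120000."""
    return f"{game_name}_{started_at:%Y%m%d_%H%M%S}"


def checkpoint_path(run_dir, episode):
    """Path of the checkpoint written after the given episode."""
    return Path(run_dir) / "checkpoints" / f"ep_{episode:04d}.pt"


name = run_dir_name("snake", datetime(2024, 5, 1, 12, 0, 0))
path = checkpoint_path(name, 100)
```

Zero-padding keeps checkpoints in episode order under a plain alphabetical listing.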
Log file format
The first log entry (written at init) is a full, human-readable description of the model architecture and every design decision, e.g.:
[INIT] Board: 32×16, character set: 6 chars → input channels = 6
[INIT] spatial=True → using CNN architecture
[INIT] CNN layers: Conv(6→32, k=3), Conv(32→64, k=3)
[INIT] Flattened CNN output: 64×28×12 = 21504 → Linear(21504, 128) → Linear(128, 5)
[INIT] Actions: KEY_RIGHT KEY_UP KEY_LEFT KEY_DOWN (none) = 5 outputs
[INIT] Optimizer: Adam, lr=0.001, decay=0.995
[INIT] Exploration: ε-greedy starting at ε=1.0, min=0.05
...
[EP 0001] total_reward=3 steps=145 epsilon=0.995 avg_loss=0.043
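A per-episode line in this layout can be produced with a single f-string. A minimal sketch, assuming three decimal places for epsilon and avg_loss (inferred from the example line, not specified by the plan); format_episode_line is a hypothetical helper:

```python
def format_episode_line(episode, total_reward, steps, epsilon, avg_loss):
    """Render one per-episode log line in the layout shown above."""
    return (
        f"[EP {episode:04d}] total_reward={total_reward} steps={steps} "
        f"epsilon={epsilon:.3f} avg_loss={avg_loss:.3f}"
    )


line = format_episode_line(1, 3, 145, 0.995, 0.043)
```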
DQNTrainer interface
class DQNTrainer:
    def __init__(self, game_factory, metadata, run_dir, **hyperparams): ...
    def train(self): ...  # runs all episodes, saves checkpoints
    def load_checkpoint(self, path): ...
Phase 8 — CLI (cli.py)
Implemented with Click, exposed as the retro-gamer entry point.
retro-gamer create
Creates a new training run directory with a config.toml stub.
retro-gamer create --game snake:create_game --output ./runs/snake/ \
--actions KEY_RIGHT,KEY_UP,KEY_LEFT,KEY_DOWN \
--reward score --board-size 32x16 \
[--character-set @,*,>,<,^,v] [--spatial] [--no-spatial] \
[--n-layers 2] [--layer-size 128] [--training-episodes 1000] ...
retro-gamer train
Runs (or resumes) training.
retro-gamer train ./runs/snake/
retro-gamer train ./runs/snake/ --resume checkpoints/ep_0500.pt
retro-gamer play
Loads a saved agent and plays the game visually.
retro-gamer play ./runs/snake/ --checkpoint final
retro-gamer play ./runs/snake/ --checkpoint checkpoints/ep_0100.pt
retro-gamer info
Prints a summary of the run: metadata, hyperparameters, latest training stats.
Implementation Order
- Land retro-games framework changes (input abstraction, step(), view refactor).
- metadata.py + tests: validate, load, save TOML.
- env.py: wraps game; unit-testable with a trivial fake game.
- observation.py: encoding functions; purely functional, easy to test.
- network.py: builders; test by checking output shapes.
- memory.py: both buffer types.
- trainer.py: integration; test with a tiny fast game.
- cli.py: wire everything together.
Required Changes to retro-games Framework
Design principle
These changes are driven by good software design — clean separation of input handling and rendering from game logic — not by retro-gamer's RL concepts. The refactoring should stand on its own as an improvement to the retro-games framework.
1. Factor out input handling (input.py)
Problem: collect_keystrokes(terminal) is baked into the Game.play() loop. There is no way for external code to supply keystrokes programmatically without faking blessed objects, and there is no clean seam to test game logic in isolation from terminal I/O.
Refactor: Introduce an InputSource protocol and two implementations. Game holds an input_source; step() calls self.input_source.collect() regardless of where input originates.
# retro/input.py
class InputSource:
    """Protocol for objects that supply keystrokes to the game each turn."""
    def collect(self) -> set:
        raise NotImplementedError


class TerminalInput(InputSource):
    """Reads keystrokes from a blessed Terminal (current behaviour)."""
    def __init__(self, terminal):
        self.terminal = terminal

    def collect(self) -> set:
        keys = set()
        while True:
            key = self.terminal.inkey(0.001)
            if key:
                keys.add(key)
            else:
                break
        return keys


class ProgrammaticInput(InputSource):
    """Accepts keystrokes injected by external code (e.g. retro-gamer)."""
    def __init__(self):
        self._pending: str | None = None

    def press(self, key: str | None):
        """Queue a single keystroke for the next turn.
        key should be a key name string (e.g. "KEY_RIGHT", "q") or None.
        """
        self._pending = key

    def collect(self) -> set:
        key = self._pending
        self._pending = None
        if key is None:
            return set()
        return {_make_keystroke(key)}


def _make_keystroke(key_str: str):
    """Creates a blessed-compatible keystroke from a string."""
    from blessed.keyboard import Keystroke
    if key_str.startswith("KEY_"):
        return Keystroke(ucs='', code=None, name=key_str)
    return Keystroke(ucs=key_str, code=None, name=None)
The _make_keystroke helper is an internal detail of the input module — Game and agents never know they're receiving synthetic keystrokes, and the RL framing never leaks into Game.
Game.__init__ gains an input_source: InputSource | None = None parameter. When None, play() creates a TerminalInput internally (preserving existing usage); when provided explicitly, it is used in step().
retro-gamer usage:
inp = ProgrammaticInput()
game = create_game()       # no-argument factory (see section 4)
game.input_source = inp    # attach the input source after construction
inp.press("KEY_RIGHT")
game.step()
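The one-shot behaviour of press() and collect() is worth seeing in isolation. A minimal sketch, assuming a simplified, blessed-free variant (the hypothetical ProgrammaticInputSketch) that returns plain key-name strings instead of Keystroke objects:

```python
class ProgrammaticInputSketch:
    """Blessed-free variant: collect() returns plain key-name strings."""
    def __init__(self):
        self._pending = None

    def press(self, key):
        self._pending = key  # a later press() before collect() overwrites this

    def collect(self):
        # Hand over the pending key and clear it in one step.
        key, self._pending = self._pending, None
        return set() if key is None else {key}


inp = ProgrammaticInputSketch()
inp.press("KEY_LEFT")
inp.press("KEY_RIGHT")       # replaces KEY_LEFT; only one key is queued
first_turn = inp.collect()
second_turn = inp.collect()  # nothing pressed since the last turn
```

Because the pending key is cleared on collect(), an agent that wants to hold a direction must call press() before every step().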
2. Game.step() and its relationship to Game.play()
Refactor: Extract the per-turn logic from play() into a step() method. play() then loops over step() calls, adding only the terminal context management, rendering, and frame-rate sleep on top. This ensures game logic is exercised via the same code path whether training or playing.
def step(self):
    """Run one turn: collect input, let each agent act, advance state."""
    self.turn_number += 1
    self.keys_pressed = self.input_source.collect()
    if self.debug and self.keys_pressed:
        self.log("Keys: " + ', '.join(k.name or str(k) for k in self.keys_pressed))
    self.prior_view_position = self.view_position
    self.prior_agent_positions = self.agent_positions
    for agent in self.agents:
        if hasattr(agent, 'handle_keystroke'):
            for key in self.keys_pressed:
                agent.handle_keystroke(key, self)
        if hasattr(agent, 'play_turn'):
            agent.play_turn(self)
        self._position_cache = None
        if getattr(agent, 'display', True):
            for pos in agent_occupied_positions(agent):
                if not self.on_board(pos):
                    raise IllegalMove(agent, pos)
    self.agent_positions = self.get_agents_by_position()
    self.state.changed = False

def play(self):
    """Run the game loop in a terminal with rendering and frame-rate control."""
    self.playing = True
    terminal = Terminal()
    terminal_input = TerminalInput(terminal)
    self.input_source = terminal_input
    with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
        view = TerminalView(terminal, color=self.color)
        self.agent_positions = {}
        self.state.changed = True
        while self.playing:
            turn_start = perf_counter()
            self.step()
            view.render(self)
            turn_elapsed = perf_counter() - turn_start
            sleep(max(0, 1 / self.framerate - turn_elapsed))
    if self.dump_state:
        ...
    if self.wait_for_enter:
        ...
Note: step() does not return anything. It simply advances game state. retro-gamer reads back whatever it needs from the game object (board characters via the HeadlessView — see below) after calling step(). This keeps step() free of RL-specific return values.
3. Refactor view rendering: View protocol, TerminalView, HeadlessView
Problem: All rendering code lives in one View class that is tightly coupled to blessed.Terminal. There is no way to run the game loop without terminal output, and no way for external code to read board state.
Refactor: Separate the concept of "rendering" from game logic by defining a View protocol and two implementations.
retro/view.py → retro/views/terminal.py (TerminalView — current View, renamed)
→ retro/views/headless.py (HeadlessView — new)
→ retro/views/__init__.py (exports View protocol + both classes)
View protocol
# retro/views/__init__.py
from typing import Protocol

class View(Protocol):
    def on_game_start(self, game) -> None: ...
    def render(self, game) -> None: ...
TerminalView
The current View class, moved to retro/views/terminal.py and renamed TerminalView. No behaviour changes — this is a pure rename/move.
HeadlessView
# retro/views/headless.py
class HeadlessView:
    """Maintains a readable board state without any terminal output.
    After each game.step(), board_characters reflects the current board.
    """
    def __init__(self):
        self.board_characters: list[list[str]] = []

    def on_game_start(self, game) -> None:
        bw, bh = game.board_size
        self.board_characters = [[' '] * bw for _ in range(bh)]

    def render(self, game) -> None:
        """Recompute board_characters from current agent positions."""
        bw, bh = game.board_size
        board = [[' '] * bw for _ in range(bh)]
        for (x, y), agents in game.get_agents_by_position().items():
            top = max(agents, key=lambda a: getattr(a, 'z', 0) or 0)
            board[y][x] = get_agent_character(top, (x, y))
        self.board_characters = board
get_agent_character is already defined at module level in the current view.py; it moves to retro/views/_util.py and is imported by both views.
How Game uses a View
Game.__init__ gains an optional view: View | None = None parameter. play() passes a freshly created TerminalView (as today), overriding whatever was set. When step() is called directly (as retro-gamer does), game.view.render(game) is called at the end of each step if a view is set.
# in Game.__init__
self.view = view  # None by default

# in Game.step() (end of step)
if self.view is not None:
    self.view.render(self)
retro-gamer usage:
headless = HeadlessView()
game = create_game()       # no-argument factory (see section 4)
game.view = headless       # attach the view after construction
game.step()
board = headless.board_characters  # 2D list of chars, ready to encode
Why this matters
This separation means:
- Game logic (agent turns, state transitions) has no terminal dependency.
- TerminalView can evolve independently (e.g., colour schemes, layout changes).
- Future rendering targets (web, test harness) are first-class citizens, not hacks.
- retro-gamer reads board state from HeadlessView.board_characters, a clean, stable API.
4. create_game() factory convention and entry points
Convention: Every game module must expose a no-argument function named create_game that returns a fully initialized Game instance. Standardizing the name is the right call — it keeps retro-gamer simple and makes the contract explicit for students.
# retro/examples/snake.py — add at bottom:
def create_game():
    head = SnakeHead()
    apple = Apple()
    game = Game([head, apple], {'score': 0}, board_size=(32, 16), framerate=12)
    apple.relocate(game)
    return game
retro-gamer loads a game as:
import importlib
module = importlib.import_module('snake') # or 'retro.examples.snake'
game = module.create_game()
The CLI accepts a dotted module name: retro-gamer create --game retro.examples.snake.
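Resolving a dotted module name to its factory can fail in two ways: the module may not import, or it may lack create_game. A minimal sketch of the lookup, assuming a hypothetical load_game_factory helper; the in-memory demo_game module exists only so the example runs without retro installed:

```python
import importlib
import sys
import types


def load_game_factory(module_name):
    """Import a module by dotted name and return its create_game callable."""
    module = importlib.import_module(module_name)
    factory = getattr(module, "create_game", None)
    if not callable(factory):
        raise AttributeError(f"{module_name} does not define create_game()")
    return factory


# Register a throwaway in-memory module so the sketch runs without retro.
demo = types.ModuleType("demo_game")
demo.create_game = lambda: "fresh game"
sys.modules["demo_game"] = demo

factory = load_game_factory("demo_game")
```

Returning the factory rather than a Game instance fits GameEnvironment, which calls game_factory() again on every reset().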
What are entry points and should we use them?
Python package entry points are a mechanism for installed packages to advertise named callables that other packages can discover at runtime, without either side hard-coding import paths. They are declared in pyproject.toml:
# In the game package's pyproject.toml:
[project.entry-points."retro.games"]
snake = "retro.examples.snake:create_game"
retro-gamer can then discover all registered games without knowing their module names:
from importlib.metadata import entry_points
games = {ep.name: ep.load() for ep in entry_points(group="retro.games")}
# games == {'snake': <function create_game at 0x...>}
This is the same mechanism pytest uses to discover plugins (the pytest11 group) and Flask uses to discover extra CLI commands (the flask.commands group).
Recommendation: Support both. For day-to-day use, students specify a module name and retro-gamer calls module.create_game(). For published or installed game packages (e.g., a course package with many games), the entry point mechanism lets retro-gamer discover games automatically (retro-gamer list could enumerate all installed games). These are complementary, not competing.
The standardized function name create_game is still required even when using entry points, because retro-gamer also needs to call the function without going through the entry point registry (e.g., when the module is on sys.path but not installed as a package).
Summary of retro changes
| Change | Where | Complexity |
|---|---|---|
| InputSource protocol + TerminalInput + ProgrammaticInput | new retro/input.py | Medium |
| _make_keystroke helper | retro/input.py (internal) | Small |
| Game.__init__ accepts input_source, view | game.py | Small |
| Extract Game.step() from Game.play() | game.py | Medium |
| View protocol | new retro/views/__init__.py | Small |
| TerminalView (rename + move current View) | retro/view.py → retro/views/terminal.py | Small (rename) |
| HeadlessView | new retro/views/headless.py | Small |
| get_agent_character moved to shared util | retro/views/_util.py | Trivial |
| create_game() factory in snake.py + docs | snake.py | Trivial |