# retro-gamer Implementation Plan

## Overview

`retro-gamer` is a Python package that trains deep Q-learning agents to play games built with the `retro-games` framework. Students specify game metadata (character set, actions, reward signal, etc.) and adjust training hyperparameters, letting the training scaffold make principled architecture decisions on their behalf. The package is designed to be pedagogically transparent: every design decision is logged with a plain-language rationale.

---

## Package Structure

```
retro_gamer/
  __init__.py
  metadata.py       # GameMetadata dataclass + TOML load/save + validation
  env.py            # GameEnvironment: wraps a Game in a step/reset/observe interface
  observation.py    # Converts board + state_dict → numpy tensors (one-hot encoding, etc.)
  network.py        # Builds PyTorch CNN or MLP from metadata + hyperparameters
  memory.py         # ReplayMemory (standard FIFO and prioritized variants)
  trainer.py        # DQNTrainer: orchestrates training, logging, checkpointing
  cli.py            # Click-based CLI (create, train, play, info subcommands)
```

---

## Phase 1: retro-games Framework Changes

These changes must land in the `retro` package before retro-gamer can work. See the "Required Changes to retro-games Framework" section at the bottom for details.

---

## Phase 2: GameMetadata (`metadata.py`)

`GameMetadata` is a validated dataclass that lives as a TOML file alongside each training run.

### Required fields
| Field | Type | Description |
|---|---|---|
| `board_size` | `(int, int)` | Width × height of the board |
| `actions` | `list[str]` | Keystroke names the agent may emit (e.g. `"KEY_RIGHT"`, `"q"`) |
| `reward` | `str` | Key in `game.state` used as the reward signal |

### Optional fields
| Field | Type | Default | Description |
|---|---|---|---|
| `character_set` | `list[str]` | `None` | Characters that may appear on the board. Each character gets a one-hot encoding slot. If absent, the trainer explores first to discover the set. |
| `spatial` | `bool` | `True` | When `True`, treat the board as a 2D spatial scene (use CNN). When `False`, flatten and use MLP. |
| `observe_state` | `list[str]` | `[]` | State dict keys to append to the observation vector. Values must be `int`, `float`, or `bool`. Must not include the key named by `reward`. |

### TOML example
```toml
[game]
board_size = [32, 16]
actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]
reward = "score"
character_set = ["@", "*", ">", "<", "^", "v"]
spatial = true
observe_state = []
```

### Validation
- `board_size` must be two positive integers.
- `actions` must be non-empty.
- The key named by `reward` must not appear in `observe_state`.
- `character_set` entries must be single characters.

---

## Phase 3: GameEnvironment (`env.py`)

`GameEnvironment` is retro-gamer's boundary with the retro-games framework. It owns the game lifecycle for one training episode and exposes a gym-style interface.

```python
class GameEnvironment:
    def __init__(self, game_factory: Callable[[], Game], metadata: GameMetadata):
        ...

    def reset(self) -> Observation:
        """Calls game_factory() to get a fresh Game; returns initial observation."""

    def step(self, action: str | None) -> tuple[Observation, float, bool]:
        """
        Injects action into game, runs one turn via Game.step().
        Returns (observation, reward, done).
        Reward = current state[reward_key] - previous state[reward_key].
        """

    def observe(self) -> Observation:
        """Returns current observation without advancing the game."""
```

`Observation` is a numpy array constructed by `observation.py`.
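
To illustrate the reward rule in the `step()` docstring (reward is the change in `state[reward_key]` since the previous turn), here is a small dependency-free sketch. `FakeGame` and `RewardTracker` are hypothetical names introduced only for this example, and reading `game.playing` to detect the end of an episode is an assumption.

```python
class FakeGame:
    """Hypothetical stand-in for a retro Game: the score rises on every second turn."""

    def __init__(self):
        self.state = {"score": 0}
        self.turn_number = 0
        self.playing = True

    def step(self):
        self.turn_number += 1
        if self.turn_number % 2 == 0:
            self.state["score"] += 1
        if self.turn_number >= 6:
            self.playing = False  # assumed end-of-episode signal


class RewardTracker:
    """Turns a cumulative state value into a per-step reward."""

    def __init__(self, reward_key: str):
        self.reward_key = reward_key
        self.previous = 0.0

    def reset(self, state: dict) -> None:
        self.previous = float(state[self.reward_key])

    def delta(self, state: dict) -> float:
        current = float(state[self.reward_key])
        reward = current - self.previous
        self.previous = current
        return reward


game = FakeGame()
tracker = RewardTracker("score")
tracker.reset(game.state)
rewards = []
while game.playing:
    game.step()
    rewards.append(tracker.delta(game.state))
```

Because each reward is a difference of consecutive scores, the rewards of an episode sum to the total change in score, which makes the tracker easy to unit-test.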

### Character-set discovery
When `character_set` is not provided, `GameEnvironment.reset()` runs `exploration_turns` random turns, collects every character seen, and sets `metadata.character_set`. This also triggers a rebuild of any already-initialized network.

---

## Phase 4: Observation Encoding (`observation.py`)

```python
import numpy as np

def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
    """
    Returns array of shape (H, W, C) where C = len(character_set).
    Each cell is a one-hot vector. Unknown characters become the zero vector.
    """

def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
    """Returns 1-D float array of observed state values."""

def encode_observation(board_chars, state, metadata) -> np.ndarray:
    """Concatenates encoded board (flattened if not spatial) with encoded state."""
```

The board tensor shape fed to the network (the `(H, W, C)` array from `encode_board` is transposed to channel-first in the spatial case):
- Spatial: `(C, H, W)` (PyTorch channel-first for conv layers)
- Non-spatial: `(H * W * C + len(observe_state),)` flat vector
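
A possible shape for `encode_board`, together with the two layouts listed above, is sketched below. It assumes the board is a rectangular list of rows, and the tiny 2×3 board is illustrative data only.

```python
import numpy as np


def encode_board(board_chars, character_set):
    """One-hot encode a 2D list of characters into an (H, W, C) float array."""
    index = {ch: i for i, ch in enumerate(character_set)}
    height, width = len(board_chars), len(board_chars[0])
    encoded = np.zeros((height, width, len(character_set)), dtype=np.float32)
    for y, row in enumerate(board_chars):
        for x, ch in enumerate(row):
            channel = index.get(ch)
            if channel is not None:  # unknown characters stay as the zero vector
                encoded[y, x, channel] = 1.0
    return encoded


board = [list("@ *"), list("  ?")]      # 2 rows x 3 columns; "?" is not in the set
hwc = encode_board(board, ["@", "*"])   # shape (2, 3, 2)
chw = np.transpose(hwc, (2, 0, 1))      # shape (2, 2, 3): channel-first for conv layers
flat = hwc.reshape(-1)                  # shape (12,): board part of the non-spatial vector
```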

---

## Phase 5: Network Architecture (`network.py`)

Two builders, one per architecture, each return a `torch.nn.Module`; `build_network` chooses between them.

### `build_network(metadata, hyperparams) -> tuple[nn.Module, str]`
The function picks the architecture based on `metadata.spatial` and returns the model together with a string recording its rationale.

**Spatial (CNN)**
```
Conv2d(in=C, out=32, kernel=3) → ReLU
Conv2d(32, 64, kernel=3) → ReLU
Flatten
Linear(auto-computed, layer_size) → ReLU   (repeated n_layers times)
Linear(layer_size, n_actions)
```

**Non-spatial (MLP)**
```
Linear(obs_size, layer_size) → ReLU   (repeated n_layers times)
Linear(layer_size, n_actions)
```
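
The "auto-computed" input size of the first `Linear` layer in the CNN can be worked out by hand. The sketch below assumes stride 1, no padding, and dilation 1 (the `Conv2d` defaults in PyTorch); the plan does not state these explicitly, and the helper names are illustrative.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 0) -> int:
    """Output length of one spatial dimension after a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_cnn_size(width: int, height: int, channels_out: int = 64, n_convs: int = 2) -> int:
    """Number of features after n_convs 3x3 convolutions followed by Flatten."""
    for _ in range(n_convs):
        width, height = conv_out(width), conv_out(height)
    return channels_out * width * height


# Board size from the TOML example: each 3x3 convolution trims 2 cells per dimension.
size = flattened_cnn_size(32, 16)   # 64 channels x 28 x 12
```

For the 32×16 board this gives 64 × 28 × 12 = 21504 input features for the first fully connected layer.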

### Advanced override
Students who know PyTorch can pass a custom `nn.Module` subclass via `--model-class` in the CLI or directly to `DQNTrainer`. This bypasses the builder entirely. The log will note that a custom model was supplied.

---

## Phase 6: Replay Memory (`memory.py`)

Two classes with the same interface:
- `ReplayMemory(capacity)`: standard ring buffer.
- `PrioritizedReplayMemory(capacity, alpha, beta)`: priority-weighted sampling based on temporal-difference error magnitude. Used when `prioritize_experiences=True`.

```python
memory.push(state, action, reward, next_state, done)
batch = memory.sample(batch_size)
memory.update_priorities(indices, td_errors)   # PER only
```
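
The standard (non-prioritized) buffer can be sketched with a bounded `deque`. This is one plausible shape for `ReplayMemory`, not the definitive one; the `Transition` tuple and the `__len__` method are assumptions added for the example.

```python
import random
from collections import deque, namedtuple

# Hypothetical record type for one experience
Transition = namedtuple("Transition", "state action reward next_state done")


class ReplayMemory:
    """Fixed-capacity FIFO buffer; the oldest transition is evicted when full."""

    def __init__(self, capacity: int):
        self._buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done) -> None:
        self._buffer.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> list:
        # Uniform sampling without replacement
        return random.sample(list(self._buffer), batch_size)

    def __len__(self) -> int:
        return len(self._buffer)


memory = ReplayMemory(capacity=3)
for t in range(5):
    memory.push(state=t, action="KEY_UP", reward=0.0, next_state=t + 1, done=False)
batch = memory.sample(2)   # drawn from the three most recent transitions only
```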

---

## Phase 7: DQN Trainer (`trainer.py`)

### Hyperparameters
| Name | Default | Description |
|---|---|---|
| `learning_rate` | `1e-3` | Adam optimizer LR |
| `lr_decay` | `0.995` | Multiplicative LR decay per episode |
| `gamma` | `0.99` | Discount factor |
| `epsilon` | `1.0` | Initial ε for ε-greedy exploration |
| `epsilon_decay` | `0.995` | ε decay per episode |
| `epsilon_min` | `0.05` | Minimum ε |
| `batch_size` | `64` | Experiences per training step |
| `memory_capacity` | `10000` | Replay buffer size |
| `target_update_freq` | `100` | Steps between target-network updates |
| `training_episodes` | `1000` | Number of episodes |
| `n_layers` | `2` | Hidden layers in MLP head |
| `layer_size` | `128` | Width of each hidden layer |
| `prioritize_experiences` | `False` | Use prioritized experience replay |
| `exploration_turns` | `200` | Random turns used to discover character set (when not specified) |
| `unknown_character_strategy` | `"ignore"` | `"ignore"` or `"extend"` (rebuild model when new character found) |
| `max_turns_per_episode` | `2000` | Safety cutoff for a single episode |
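
To get a feel for how `epsilon`, `epsilon_decay`, and `epsilon_min` interact, the schedule can be simulated directly. This sketch assumes ε is multiplied by `epsilon_decay` once after every episode and clamped at `epsilon_min`; the function name is illustrative.

```python
def epsilon_schedule(epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.05, episodes=1000):
    """Yield the ε used in each episode, decaying multiplicatively down to a floor."""
    for _ in range(episodes):
        yield epsilon
        epsilon = max(epsilon_min, epsilon * epsilon_decay)


schedule = list(epsilon_schedule())
# 0-based index of the first episode played at the floor value
first_floor_episode = schedule.index(0.05)
```

Under the defaults, ε reaches its floor roughly 600 episodes into the 1000-episode run, so about the last 40% of training is played at ε = 0.05.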

### Training directory structure
```
run_dir/                  (e.g. snake_20240501_120000/)
  config.toml             # metadata + all hyperparameters, frozen at training start
  training.log            # structured log (see below)
  checkpoints/
    ep_0100.pt
    ep_0500.pt
    final.pt
```

### Log file format
The first log entry (written at init) is a full, human-readable description of the model architecture and every design decision, e.g.:

```
[INIT] Board: 32×16, character set: 6 chars → input channels = 6
[INIT] spatial=True → using CNN architecture
[INIT] CNN layers: Conv(6→32, k=3), Conv(32→64, k=3)
[INIT] Flattened CNN output: 64×28×12 = 21504 → Linear(21504, 128) → Linear(128, 128) → Linear(128, 5)
[INIT] Actions: KEY_RIGHT KEY_UP KEY_LEFT KEY_DOWN (none) = 5 outputs
[INIT] Optimizer: Adam, lr=0.001, decay=0.995
[INIT] Exploration: ε-greedy starting at ε=1.0, min=0.05
...
[EP 0001] total_reward=3 steps=145 epsilon=0.995 avg_loss=0.043
```

### DQNTrainer interface
```python
class DQNTrainer:
    def __init__(self, game_factory, metadata, run_dir, **hyperparams): ...
    def train(self): ...                  # runs all episodes, saves checkpoints
    def load_checkpoint(self, path): ...
```

---

## Phase 8: CLI (`cli.py`)

Implemented with Click, exposed as the `retro-gamer` entry point.

### `retro-gamer create`
Creates a new training run directory with a `config.toml` stub.
```
retro-gamer create --game snake --output ./runs/snake/ \
    --actions KEY_RIGHT,KEY_UP,KEY_LEFT,KEY_DOWN \
    --reward score --board-size 32x16 \
    [--character-set @,*,>,<,^,v] [--spatial | --no-spatial] \
    [--n-layers 2] [--layer-size 128] [--training-episodes 1000] ...
```

### `retro-gamer train`
Runs (or resumes) training.
```
retro-gamer train ./runs/snake/
retro-gamer train ./runs/snake/ --resume checkpoints/ep_0500.pt
```

### `retro-gamer play`
Loads a saved agent and plays the game visually.
```
retro-gamer play ./runs/snake/ --checkpoint final
retro-gamer play ./runs/snake/ --checkpoint checkpoints/ep_0100.pt
```

### `retro-gamer info`
Prints a summary of the run: metadata, hyperparameters, latest training stats.

---

## Implementation Order

1. Land retro-games framework changes (input abstraction, `step()`, view refactor).
2. `metadata.py` + tests: validate, load, save TOML.
3. `env.py`: wraps game; unit-testable with a trivial fake game.
4. `observation.py`: encoding functions; purely functional, easy to test.
5. `network.py`: builders; test by checking output shapes.
6. `memory.py`: both buffer types.
7. `trainer.py`: integration; test with a tiny fast game.
8. `cli.py`: wire everything together.

---

# Required Changes to retro-games Framework

## Design principle

These changes are driven by good software design (clean separation of input handling and rendering from game logic), not by retro-gamer's RL concepts. The refactoring should stand on its own as an improvement to the retro-games framework.

---

## 1. Factor out input handling (`input.py`)

**Problem:** `collect_keystrokes(terminal)` is baked into the `Game.play()` loop. There is no way for external code to supply keystrokes programmatically without faking blessed objects, and there is no clean seam to test game logic in isolation from terminal I/O.

**Refactor:** Introduce an `InputSource` protocol and two implementations. Game holds an `input_source`; `step()` calls `self.input_source.collect()` regardless of where input originates.

```python
# retro/input.py

class InputSource:
    """Protocol for objects that supply keystrokes to the game each turn."""
    def collect(self) -> set:
        raise NotImplementedError

class TerminalInput(InputSource):
    """Reads keystrokes from a blessed Terminal (current behaviour)."""
    def __init__(self, terminal):
        self.terminal = terminal

    def collect(self) -> set:
        keys = set()
        while True:
            key = self.terminal.inkey(0.001)
            if key:
                keys.add(key)
            else:
                break
        return keys

class ProgrammaticInput(InputSource):
    """Accepts keystrokes injected by external code (e.g. retro-gamer)."""
    def __init__(self):
        self._pending: str | None = None

    def press(self, key: str | None):
        """Queue a single keystroke for the next turn.
        key should be a key name string (e.g. "KEY_RIGHT", "q") or None.
        """
        self._pending = key

    def collect(self) -> set:
        key = self._pending
        self._pending = None
        if key is None:
            return set()
        return {_make_keystroke(key)}

def _make_keystroke(key_str: str):
    """Creates a blessed-compatible keystroke from a string."""
    from blessed.keyboard import Keystroke
    if key_str.startswith("KEY_"):
        return Keystroke(ucs='', code=None, name=key_str)
    return Keystroke(ucs=key_str, code=None, name=None)
```

The `_make_keystroke` helper is an internal detail of the input module: Game and agents never know they're receiving synthetic keystrokes, and the RL framing never leaks into Game.

`Game.__init__` gains an `input_source: InputSource | None = None` parameter. When `None`, `play()` creates a `TerminalInput` internally (preserving existing usage); when provided explicitly, it is used in `step()`.

retro-gamer usage:
```python
inp = ProgrammaticInput()
game = create_game()
game.input_source = inp   # create_game() takes no arguments, so set the source after construction
inp.press("KEY_RIGHT")
game.step()
```

---

## 2. `Game.step()` and its relationship to `Game.play()`

**Refactor:** Extract the per-turn logic from `play()` into a `step()` method. `play()` then loops over `step()` calls, adding only the terminal context management, rendering, and frame-rate sleep on top. This ensures game logic is exercised via the same code path whether training or playing.

```python
def step(self):
    """Run one turn: collect input, let each agent act, advance state."""
    self.turn_number += 1
    self.keys_pressed = self.input_source.collect()
    if self.debug and self.keys_pressed:
        self.log("Keys: " + ', '.join(k.name or str(k) for k in self.keys_pressed))
    self.prior_view_position = self.view_position
    self.prior_agent_positions = self.agent_positions
    for agent in self.agents:
        if hasattr(agent, 'handle_keystroke'):
            for key in self.keys_pressed:
                agent.handle_keystroke(key, self)
        if hasattr(agent, 'play_turn'):
            agent.play_turn(self)
        self._position_cache = None
        if getattr(agent, 'display', True):
            for pos in agent_occupied_positions(agent):
                if not self.on_board(pos):
                    raise IllegalMove(agent, pos)
    self.agent_positions = self.get_agents_by_position()
    self.state.changed = False

def play(self):
    """Run the game loop in a terminal with rendering and frame-rate control."""
    self.playing = True
    terminal = Terminal()
    if self.input_source is None:
        # Fall back to terminal input only when no source was supplied
        self.input_source = TerminalInput(terminal)
    with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
        view = TerminalView(terminal, color=self.color)
        self.agent_positions = {}
        self.state.changed = True
        while self.playing:
            turn_start = perf_counter()
            self.step()
            view.render(self)
            turn_elapsed = perf_counter() - turn_start
            sleep(max(0, 1 / self.framerate - turn_elapsed))
        if self.dump_state:
            ...
        if self.wait_for_enter:
            ...
```

Note: `step()` does not return anything. It simply advances game state. retro-gamer reads back whatever it needs from the game object (board characters via the `HeadlessView`, described below) after calling `step()`. This keeps `step()` free of RL-specific return values.

---

## 3. Refactor view rendering: `View` protocol, `TerminalView`, `HeadlessView`

**Problem:** All rendering code lives in one `View` class that is tightly coupled to `blessed.Terminal`. There is no way to run the game loop without terminal output, and no way for external code to read board state.

**Refactor:** Separate the concept of "rendering" from game logic by defining a `View` protocol and two implementations.

```
retro/view.py  →  retro/views/terminal.py   (TerminalView: current View, renamed)
               →  retro/views/headless.py   (HeadlessView: new)
               →  retro/views/__init__.py   (exports View protocol + both classes)
```

### `View` protocol
```python
# retro/views/__init__.py
from typing import Protocol

class View(Protocol):
    def on_game_start(self, game) -> None: ...
    def render(self, game) -> None: ...
```

### `TerminalView`
The current `View` class, moved to `retro/views/terminal.py` and renamed `TerminalView`. No behaviour changes: this is a pure rename/move.

### `HeadlessView`
```python
# retro/views/headless.py
from retro.views._util import get_agent_character

class HeadlessView:
    """Maintains a readable board state without any terminal output.
    After each game.step(), board_characters reflects the current board.
    """
    def __init__(self):
        self.board_characters: list[list[str]] = []

    def on_game_start(self, game) -> None:
        bw, bh = game.board_size
        self.board_characters = [[' '] * bw for _ in range(bh)]

    def render(self, game) -> None:
        """Recompute board_characters from current agent positions."""
        bw, bh = game.board_size
        board = [[' '] * bw for _ in range(bh)]
        for (x, y), agents in game.get_agents_by_position().items():
            top = max(agents, key=lambda a: getattr(a, 'z', 0) or 0)
            board[y][x] = get_agent_character(top, (x, y))
        self.board_characters = board
```

`get_agent_character` is already defined at module level in the current `view.py`; it moves to `retro/views/_util.py` and is imported by both views.

### How Game uses a View

`Game.__init__` gains an optional `view: View | None = None` parameter. `play()` renders through a freshly created `TerminalView` (as today), regardless of what was set. When `step()` is called directly (as retro-gamer does), `game.view.render(game)` is called at the end of each step if a view is set.

```python
# in Game.__init__
self.view = view   # None by default

# in Game.step() (end of step)
if self.view is not None:
    self.view.render(self)
```

retro-gamer usage:
```python
headless = HeadlessView()
game = create_game()
game.view = headless   # create_game() takes no arguments, so attach the view after construction
game.step()
board = headless.board_characters   # 2D list of chars, ready to encode
```
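
The two details of `HeadlessView.render` that are easiest to get wrong are the row-major indexing (`board[y][x]`, with `board_size` given as width × height) and the z-order tie-break. The following standalone sketch reproduces just that logic; `StubAgent` and `render_board` are hypothetical stand-ins, and reading a `character` attribute directly replaces the real `get_agent_character` helper.

```python
class StubAgent:
    """Hypothetical stand-in for a retro agent."""

    def __init__(self, character, position, z=None):
        self.character = character
        self.position = position
        self.z = z


def render_board(board_size, agents):
    """Build a 2D character grid, drawing the highest-z agent on contested cells."""
    width, height = board_size
    board = [[" "] * width for _ in range(height)]
    by_position = {}
    for agent in agents:
        by_position.setdefault(agent.position, []).append(agent)
    for (x, y), stacked in by_position.items():
        # A missing or None z counts as 0, as in HeadlessView.render
        top = max(stacked, key=lambda a: getattr(a, "z", 0) or 0)
        board[y][x] = top.character
    return board


agents = [StubAgent("*", (1, 0)), StubAgent("@", (1, 0), z=1), StubAgent(">", (2, 1))]
board = render_board((3, 2), agents)
```

The result has `height` rows of `width` characters, which matches the `(H, W, C)` layout that `encode_board` produces from it.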

### Why this matters

This separation means:
- Game logic (agent turns, state transitions) has no terminal dependency.
- `TerminalView` can evolve independently (e.g., colour schemes, layout changes).
- Future rendering targets (web, test harness) are first-class citizens, not hacks.
- retro-gamer reads board state from `HeadlessView.board_characters`, a clean, stable API.

---

## 4. `create_game()` factory convention and entry points

**Convention:** Every game module must expose a no-argument function named `create_game` that returns a fully initialized `Game` instance. Standardizing the name is the right call: it keeps retro-gamer simple and makes the contract explicit for students.

```python
# retro/examples/snake.py (add at bottom):
def create_game():
    head = SnakeHead()
    apple = Apple()
    game = Game([head, apple], {'score': 0}, board_size=(32, 16), framerate=12)
    apple.relocate(game)
    return game
```

retro-gamer loads a game as:
```python
import importlib
module = importlib.import_module('snake')   # or 'retro.examples.snake'
game = module.create_game()
```

The CLI accepts a dotted module name: `retro-gamer create --game retro.examples.snake`.

### What are entry points and should we use them?

Python package **entry points** are a mechanism for installed packages to advertise named callables that other packages can discover at runtime, without either side hard-coding import paths. They are declared in `pyproject.toml`:

```toml
# In the game package's pyproject.toml:
[project.entry-points."retro.games"]
snake = "retro.examples.snake:create_game"
```

retro-gamer can then discover all registered games without knowing their module names:

```python
from importlib.metadata import entry_points
games = {ep.name: ep.load() for ep in entry_points(group="retro.games")}
# games == {'snake': <function create_game at 0x...>}
```

This is the same mechanism pytest uses to discover plugins and Flask uses to discover CLI command plugins.

**Recommendation:** Support both. For day-to-day use, students specify a module name and retro-gamer calls `module.create_game()`. For published or installed game packages (e.g., a course package with many games), the entry point mechanism lets retro-gamer discover games automatically (`retro-gamer list` could enumerate all installed games). These are complementary, not competing.

The standardized function name `create_game` is still required even when using entry points, because retro-gamer also needs to call the function without going through the entry point registry (e.g., when the module is on `sys.path` but not installed as a package).

---

## Summary of retro changes

| Change | Where | Complexity |
|---|---|---|
| `InputSource` protocol + `TerminalInput` + `ProgrammaticInput` | new `retro/input.py` | Medium |
| `_make_keystroke` helper | `retro/input.py` (internal) | Small |
| `Game.__init__` accepts `input_source`, `view` | `game.py` | Small |
| Extract `Game.step()` from `Game.play()` | `game.py` | Medium |
| `View` protocol | new `retro/views/__init__.py` | Small |
| `TerminalView` (rename + move current View) | `retro/view.py` → `retro/views/terminal.py` | Small (rename) |
| `HeadlessView` | new `retro/views/headless.py` | Small |
| `get_agent_character` moved to shared util | `retro/views/_util.py` | Trivial |
| `create_game()` factory in `snake.py` + docs | `snake.py` | Trivial |