Initial commit

This commit is contained in:
Chris Proctor
2026-05-08 14:07:17 -04:00
commit 5ca97dc5d0
36 changed files with 4147 additions and 0 deletions

121
retro_gamer/metadata.py Normal file
View File

@@ -0,0 +1,121 @@
from __future__ import annotations
import importlib
import tomllib
import tomli_w
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class GameMetadata:
"""Describes a retro game for training purposes.
Required fields: actions, reward.
Optional fields: character_set, spatial, observe_state.
Discovered fields: board_size (read from game.board_size at startup).
Metadata is read from the game's own pyproject.toml under the
[tool.retro-gamer] section via GameMetadata.from_pyproject().
"""
actions: list[str]
reward: str
character_set: list[str] | None = None
spatial: bool = True
observe_state: list[str] = field(default_factory=list)
board_size: tuple[int, int] | None = None
def validate(self):
if not self.actions:
raise ValueError("actions must be a non-empty list")
if not self.reward:
raise ValueError("reward must be a non-empty string")
if self.reward in self.observe_state:
raise ValueError(f"reward key '{self.reward}' must not appear in observe_state")
if self.character_set is not None:
for ch in self.character_set:
if len(ch) != 1:
raise ValueError(f"character_set entries must be single characters, got {ch!r}")
@classmethod
def from_pyproject(cls, module_name: str) -> GameMetadata:
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml.
Finds pyproject.toml by walking up from the module's source file.
Raises FileNotFoundError if no pyproject.toml is found, or ValueError
if the file exists but has no [tool.retro-gamer] section.
"""
pyproject_path = _find_pyproject(module_name)
if pyproject_path is None:
raise FileNotFoundError(
f"Could not find pyproject.toml for module '{module_name}'. "
f"Make sure the module is part of a Python project with a pyproject.toml."
)
with open(pyproject_path, 'rb') as f:
data = tomllib.load(f)
section = data.get('tool', {}).get('retro-gamer')
if section is None:
raise ValueError(
f"No [tool.retro-gamer] section found in {pyproject_path}.\n"
f"Add game metadata to your pyproject.toml:\n\n"
f"[tool.retro-gamer]\n"
f"actions = [\"KEY_RIGHT\", ...]\n"
f"reward = \"score\"\n"
)
return cls.from_dict(section)
@classmethod
def from_dict(cls, d: dict) -> GameMetadata:
board_size = tuple(d['board_size']) if 'board_size' in d else None
return cls(
actions=d['actions'],
reward=d['reward'],
character_set=d.get('character_set'),
spatial=d.get('spatial', True),
observe_state=d.get('observe_state', []),
board_size=board_size,
)
def to_dict(self) -> dict:
d = {
'actions': self.actions,
'reward': self.reward,
'spatial': self.spatial,
'observe_state': self.observe_state,
}
if self.board_size is not None:
d['board_size'] = list(self.board_size)
if self.character_set is not None:
d['character_set'] = self.character_set
return d
def to_toml(self, path: str | Path):
with open(path, 'wb') as f:
tomli_w.dump({'metadata': self.to_dict()}, f)
@property
def obs_size(self) -> int:
"""Total size of the flat observation vector."""
C = len(self.character_set) if self.character_set else 0
bw, bh = self.board_size
return C * bw * bh + len(self.observe_state)
@property
def n_actions(self) -> int:
"""Number of actions including no-op."""
return len(self.actions) + 1
def _find_pyproject(module_name: str) -> Path | None:
"""Walk up from a module's source file to find its pyproject.toml."""
try:
module = importlib.import_module(module_name)
except ImportError:
return None
module_file = getattr(module, '__file__', None)
if module_file is None:
return None
for parent in Path(module_file).resolve().parents:
candidate = parent / 'pyproject.toml'
if candidate.exists():
return candidate
return None