Files
retro-gamer/retro_gamer/metadata.py
2026-06-22 16:41:31 -04:00

153 lines
6.0 KiB
Python

from __future__ import annotations
import importlib
import tomllib
import tomli_w
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class GameMetadata:
"""Describes a retro game for training purposes.
Required fields: actions, reward.
Optional fields: character_set, spatial.
Discovered fields: board_size (from game.board_size), extras_size (from
the observe_state list in [preprocessing]).
"""
actions: list[str]
reward: str
character_set: list[str] | None = None
spatial: bool = False
board: bool = True
board_size: tuple[int, int] | None = None
extras_size: int = 0
def validate(self):
if not self.actions:
raise ValueError(
"The 'actions' list in [tool.retro-gamer] is empty or missing.\n"
"It should list the keyboard keys your agent can press, for example:\n\n"
' actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n\n'
"The agent will learn which actions lead to higher rewards."
)
if not isinstance(self.actions, list) or not all(isinstance(a, str) for a in self.actions):
raise ValueError(
f"'actions' must be a list of strings, but got: {self.actions!r}\n"
"Each entry should be a key name like \"KEY_RIGHT\" or \"KEY_SPACE\"."
)
if not self.reward:
raise ValueError(
"The 'reward' field in [tool.retro-gamer] is empty or missing.\n"
"It should name a game state variable whose value the agent is trying\n"
"to maximize — for example:\n\n"
" reward = \"score\"\n\n"
"The trainer watches how this value changes each step and uses those\n"
"changes as the reward signal."
)
if self.character_set is not None:
if not isinstance(self.character_set, list):
raise ValueError(
f"'character_set' must be a list of single characters, but got: {self.character_set!r}\n"
"Example: character_set = [\"@\", \"*\", \"#\"]"
)
for ch in self.character_set:
if not isinstance(ch, str) or len(ch) != 1:
raise ValueError(
f"Every entry in character_set must be a single character, but got {ch!r}.\n"
"Each character represents one type of cell on the game board.\n"
"If you're not sure what characters your game uses, remove character_set\n"
"entirely and the trainer will discover them automatically."
)
@classmethod
def from_pyproject(cls, module_name: str) -> GameMetadata:
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml."""
pyproject_path = _find_pyproject(module_name)
if pyproject_path is None:
raise FileNotFoundError(
f"Could not find pyproject.toml for module '{module_name}'. "
f"Make sure the module is part of a Python project with a pyproject.toml."
)
with open(pyproject_path, 'rb') as f:
data = tomllib.load(f)
section = data.get('tool', {}).get('retro-gamer')
if section is None:
raise ValueError(
f"No [tool.retro-gamer] section found in {pyproject_path}.\n"
f"Add game metadata to your pyproject.toml:\n\n"
f"[tool.retro-gamer]\n"
f"actions = [\"KEY_RIGHT\", ...]\n"
f"reward = \"score\"\n"
)
return cls.from_dict(section)
@classmethod
def from_dict(cls, d: dict) -> GameMetadata:
missing = [k for k in ('actions', 'reward') if k not in d]
if missing:
fields = ' and '.join(f"'{k}'" for k in missing)
raise ValueError(
f"The [tool.retro-gamer] section is missing required {fields}.\n"
"A minimal configuration looks like this:\n\n"
"[tool.retro-gamer]\n"
'actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n'
'reward = "score"\n\n'
"See the documentation for all available options."
)
board_size = tuple(d['board_size']) if 'board_size' in d else None
return cls(
actions=d['actions'],
reward=d['reward'],
character_set=d.get('character_set'),
spatial=d.get('spatial', False),
board_size=board_size,
)
def to_dict(self) -> dict:
d = {
'actions': self.actions,
'reward': self.reward,
}
if self.board_size is not None:
d['board_size'] = list(self.board_size)
if self.character_set is not None:
d['character_set'] = self.character_set
return d
def to_toml(self, path: str | Path):
with open(path, 'wb') as f:
tomli_w.dump({'metadata': self.to_dict()}, f)
@property
def obs_size(self) -> int:
"""Total size of the flat observation vector."""
if not self.board:
return self.extras_size
C = len(self.character_set) if self.character_set else 0
bw, bh = self.board_size
return C * bw * bh + self.extras_size
@property
def n_actions(self) -> int:
"""Number of actions including no-op."""
return len(self.actions) + 1
def _find_pyproject(module_name: str) -> Path | None:
"""Walk up from a module's source file to find its pyproject.toml."""
try:
module = importlib.import_module(module_name)
except ImportError:
return None
module_file = getattr(module, '__file__', None)
if module_file is None:
return None
for parent in Path(module_file).resolve().parents:
candidate = parent / 'pyproject.toml'
if candidate.exists():
return candidate
if parent.name == 'site-packages':
break
return None