from __future__ import annotations import importlib import tomllib import tomli_w from dataclasses import dataclass, field from pathlib import Path @dataclass class GameMetadata: """Describes a retro game for training purposes. Required fields: actions, reward. Optional fields: character_set, spatial, observe_state. Discovered fields: board_size (read from game.board_size at startup). Metadata is read from the game's own pyproject.toml under the [tool.retro-gamer] section via GameMetadata.from_pyproject(). """ actions: list[str] reward: str character_set: list[str] | None = None spatial: bool = True observe_state: list[str] = field(default_factory=list) board_size: tuple[int, int] | None = None def validate(self): if not self.actions: raise ValueError("actions must be a non-empty list") if not self.reward: raise ValueError("reward must be a non-empty string") if self.reward in self.observe_state: raise ValueError(f"reward key '{self.reward}' must not appear in observe_state") if self.character_set is not None: for ch in self.character_set: if len(ch) != 1: raise ValueError(f"character_set entries must be single characters, got {ch!r}") @classmethod def from_pyproject(cls, module_name: str) -> GameMetadata: """Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml. Finds pyproject.toml by walking up from the module's source file. Raises FileNotFoundError if no pyproject.toml is found, or ValueError if the file exists but has no [tool.retro-gamer] section. """ pyproject_path = _find_pyproject(module_name) if pyproject_path is None: raise FileNotFoundError( f"Could not find pyproject.toml for module '{module_name}'. " f"Make sure the module is part of a Python project with a pyproject.toml." ) with open(pyproject_path, 'rb') as f: data = tomllib.load(f) section = data.get('tool', {}).get('retro-gamer') if section is None: raise ValueError( f"No [tool.retro-gamer] section found in {pyproject_path}.\n" f"Add game metadata to your pyproject.toml:\n\n" f"[tool.retro-gamer]\n" f"actions = [\"KEY_RIGHT\", ...]\n" f"reward = \"score\"\n" ) return cls.from_dict(section) @classmethod def from_dict(cls, d: dict) -> GameMetadata: board_size = tuple(d['board_size']) if 'board_size' in d else None return cls( actions=d['actions'], reward=d['reward'], character_set=d.get('character_set'), spatial=d.get('spatial', True), observe_state=d.get('observe_state', []), board_size=board_size, ) def to_dict(self) -> dict: d = { 'actions': self.actions, 'reward': self.reward, 'spatial': self.spatial, 'observe_state': self.observe_state, } if self.board_size is not None: d['board_size'] = list(self.board_size) if self.character_set is not None: d['character_set'] = self.character_set return d def to_toml(self, path: str | Path): with open(path, 'wb') as f: tomli_w.dump({'metadata': self.to_dict()}, f) @property def obs_size(self) -> int: """Total size of the flat observation vector.""" C = len(self.character_set) if self.character_set else 0 bw, bh = self.board_size return C * bw * bh + len(self.observe_state) @property def n_actions(self) -> int: """Number of actions including no-op.""" return len(self.actions) + 1 def _find_pyproject(module_name: str) -> Path | None: """Walk up from a module's source file to find its pyproject.toml.""" try: module = importlib.import_module(module_name) except ImportError: return None module_file = getattr(module, '__file__', None) if module_file is None: return None for parent in Path(module_file).resolve().parents: candidate = parent / 'pyproject.toml' if candidate.exists(): return candidate return None