122 lines
4.3 KiB
Python
122 lines
4.3 KiB
Python
from __future__ import annotations
|
|
import importlib
|
|
import tomllib
|
|
import tomli_w
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
|
|
|
|
@dataclass
|
|
class GameMetadata:
|
|
"""Describes a retro game for training purposes.
|
|
|
|
Required fields: actions, reward.
|
|
Optional fields: character_set, spatial, observe_state.
|
|
Discovered fields: board_size (read from game.board_size at startup).
|
|
|
|
Metadata is read from the game's own pyproject.toml under the
|
|
[tool.retro-gamer] section via GameMetadata.from_pyproject().
|
|
"""
|
|
actions: list[str]
|
|
reward: str
|
|
character_set: list[str] | None = None
|
|
spatial: bool = True
|
|
observe_state: list[str] = field(default_factory=list)
|
|
board_size: tuple[int, int] | None = None
|
|
|
|
def validate(self):
|
|
if not self.actions:
|
|
raise ValueError("actions must be a non-empty list")
|
|
if not self.reward:
|
|
raise ValueError("reward must be a non-empty string")
|
|
if self.reward in self.observe_state:
|
|
raise ValueError(f"reward key '{self.reward}' must not appear in observe_state")
|
|
if self.character_set is not None:
|
|
for ch in self.character_set:
|
|
if len(ch) != 1:
|
|
raise ValueError(f"character_set entries must be single characters, got {ch!r}")
|
|
|
|
@classmethod
|
|
def from_pyproject(cls, module_name: str) -> GameMetadata:
|
|
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml.
|
|
|
|
Finds pyproject.toml by walking up from the module's source file.
|
|
Raises FileNotFoundError if no pyproject.toml is found, or ValueError
|
|
if the file exists but has no [tool.retro-gamer] section.
|
|
"""
|
|
pyproject_path = _find_pyproject(module_name)
|
|
if pyproject_path is None:
|
|
raise FileNotFoundError(
|
|
f"Could not find pyproject.toml for module '{module_name}'. "
|
|
f"Make sure the module is part of a Python project with a pyproject.toml."
|
|
)
|
|
with open(pyproject_path, 'rb') as f:
|
|
data = tomllib.load(f)
|
|
section = data.get('tool', {}).get('retro-gamer')
|
|
if section is None:
|
|
raise ValueError(
|
|
f"No [tool.retro-gamer] section found in {pyproject_path}.\n"
|
|
f"Add game metadata to your pyproject.toml:\n\n"
|
|
f"[tool.retro-gamer]\n"
|
|
f"actions = [\"KEY_RIGHT\", ...]\n"
|
|
f"reward = \"score\"\n"
|
|
)
|
|
return cls.from_dict(section)
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict) -> GameMetadata:
|
|
board_size = tuple(d['board_size']) if 'board_size' in d else None
|
|
return cls(
|
|
actions=d['actions'],
|
|
reward=d['reward'],
|
|
character_set=d.get('character_set'),
|
|
spatial=d.get('spatial', True),
|
|
observe_state=d.get('observe_state', []),
|
|
board_size=board_size,
|
|
)
|
|
|
|
def to_dict(self) -> dict:
|
|
d = {
|
|
'actions': self.actions,
|
|
'reward': self.reward,
|
|
'spatial': self.spatial,
|
|
'observe_state': self.observe_state,
|
|
}
|
|
if self.board_size is not None:
|
|
d['board_size'] = list(self.board_size)
|
|
if self.character_set is not None:
|
|
d['character_set'] = self.character_set
|
|
return d
|
|
|
|
def to_toml(self, path: str | Path):
|
|
with open(path, 'wb') as f:
|
|
tomli_w.dump({'metadata': self.to_dict()}, f)
|
|
|
|
@property
|
|
def obs_size(self) -> int:
|
|
"""Total size of the flat observation vector."""
|
|
C = len(self.character_set) if self.character_set else 0
|
|
bw, bh = self.board_size
|
|
return C * bw * bh + len(self.observe_state)
|
|
|
|
@property
|
|
def n_actions(self) -> int:
|
|
"""Number of actions including no-op."""
|
|
return len(self.actions) + 1
|
|
|
|
|
|
def _find_pyproject(module_name: str) -> Path | None:
    """Walk up from a module's source file to find its pyproject.toml.

    Imports *module_name*, then checks each parent directory of its resolved
    ``__file__`` (nearest first) and returns the first ``pyproject.toml``
    file found.

    Returns:
        The path of the nearest pyproject.toml, or None if the module (or
        one of its parent packages) does not exist, the module has no
        ``__file__`` (e.g. built-in modules), or no pyproject.toml exists
        in any ancestor directory.

    Raises:
        ImportError: if the module exists but importing it fails for another
            reason (e.g. one of its own dependencies is missing). Such errors
            propagate instead of being misreported as a missing pyproject.toml.
    """
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        # "No pyproject to find" only when the requested module itself, or
        # one of its parent packages, is missing. A ModuleNotFoundError for
        # any other name was raised while running the module's own code.
        missing = err.name
        if (
            missing is None
            or missing == module_name
            or module_name.startswith(missing + '.')
        ):
            return None
        raise

    module_file = getattr(module, '__file__', None)
    if module_file is None:
        return None

    for parent in Path(module_file).resolve().parents:
        candidate = parent / 'pyproject.toml'
        # is_file(): a directory named pyproject.toml is not a project file.
        if candidate.is_file():
            return candidate
    return None
|