Updates across the board

This commit is contained in:
Chris Proctor
2026-06-22 16:41:31 -04:00
parent 5ca97dc5d0
commit 73624d1a0c
33 changed files with 3104 additions and 643 deletions

View File

@@ -11,39 +11,58 @@ class GameMetadata:
"""Describes a retro game for training purposes.
Required fields: actions, reward.
Optional fields: character_set, spatial, observe_state.
Discovered fields: board_size (read from game.board_size at startup).
Metadata is read from the game's own pyproject.toml under the
[tool.retro-gamer] section via GameMetadata.from_pyproject().
Optional fields: character_set, spatial.
Discovered fields: board_size (from game.board_size), extras_size (from
the observe_state list in [preprocessing]).
"""
actions: list[str]
reward: str
character_set: list[str] | None = None
spatial: bool = True
observe_state: list[str] = field(default_factory=list)
spatial: bool = False
board: bool = True
board_size: tuple[int, int] | None = None
extras_size: int = 0
def validate(self):
if not self.actions:
raise ValueError("actions must be a non-empty list")
raise ValueError(
"The 'actions' list in [tool.retro-gamer] is empty or missing.\n"
"It should list the keyboard keys your agent can press, for example:\n\n"
' actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n\n'
"The agent will learn which actions lead to higher rewards."
)
if not isinstance(self.actions, list) or not all(isinstance(a, str) for a in self.actions):
raise ValueError(
f"'actions' must be a list of strings, but got: {self.actions!r}\n"
"Each entry should be a key name like \"KEY_RIGHT\" or \"KEY_SPACE\"."
)
if not self.reward:
raise ValueError("reward must be a non-empty string")
if self.reward in self.observe_state:
raise ValueError(f"reward key '{self.reward}' must not appear in observe_state")
raise ValueError(
"The 'reward' field in [tool.retro-gamer] is empty or missing.\n"
"It should name a game state variable whose value the agent is trying\n"
"to maximize — for example:\n\n"
" reward = \"score\"\n\n"
"The trainer watches how this value changes each step and uses those\n"
"changes as the reward signal."
)
if self.character_set is not None:
if not isinstance(self.character_set, list):
raise ValueError(
f"'character_set' must be a list of single characters, but got: {self.character_set!r}\n"
"Example: character_set = [\"@\", \"*\", \"#\"]"
)
for ch in self.character_set:
if len(ch) != 1:
raise ValueError(f"character_set entries must be single characters, got {ch!r}")
if not isinstance(ch, str) or len(ch) != 1:
raise ValueError(
f"Every entry in character_set must be a single character, but got {ch!r}.\n"
"Each character represents one type of cell on the game board.\n"
"If you're not sure what characters your game uses, remove character_set\n"
"entirely and the trainer will discover them automatically."
)
@classmethod
def from_pyproject(cls, module_name: str) -> GameMetadata:
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml.
Finds pyproject.toml by walking up from the module's source file.
Raises FileNotFoundError if no pyproject.toml is found, or ValueError
if the file exists but has no [tool.retro-gamer] section.
"""
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml."""
pyproject_path = _find_pyproject(module_name)
if pyproject_path is None:
raise FileNotFoundError(
@@ -65,13 +84,23 @@ class GameMetadata:
@classmethod
def from_dict(cls, d: dict) -> GameMetadata:
missing = [k for k in ('actions', 'reward') if k not in d]
if missing:
fields = ' and '.join(f"'{k}'" for k in missing)
raise ValueError(
f"The [tool.retro-gamer] section is missing required {fields}.\n"
"A minimal configuration looks like this:\n\n"
"[tool.retro-gamer]\n"
'actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n'
'reward = "score"\n\n'
"See the documentation for all available options."
)
board_size = tuple(d['board_size']) if 'board_size' in d else None
return cls(
actions=d['actions'],
reward=d['reward'],
character_set=d.get('character_set'),
spatial=d.get('spatial', True),
observe_state=d.get('observe_state', []),
spatial=d.get('spatial', False),
board_size=board_size,
)
@@ -79,8 +108,6 @@ class GameMetadata:
d = {
'actions': self.actions,
'reward': self.reward,
'spatial': self.spatial,
'observe_state': self.observe_state,
}
if self.board_size is not None:
d['board_size'] = list(self.board_size)
@@ -95,9 +122,11 @@ class GameMetadata:
@property
def obs_size(self) -> int:
"""Total size of the flat observation vector."""
if not self.board:
return self.extras_size
C = len(self.character_set) if self.character_set else 0
bw, bh = self.board_size
return C * bw * bh + len(self.observe_state)
return C * bw * bh + self.extras_size
@property
def n_actions(self) -> int:
@@ -118,4 +147,6 @@ def _find_pyproject(module_name: str) -> Path | None:
candidate = parent / 'pyproject.toml'
if candidate.exists():
return candidate
if parent.name == 'site-packages':
break
return None