Updates across the board
This commit is contained in:
@@ -11,39 +11,58 @@ class GameMetadata:
|
||||
"""Describes a retro game for training purposes.
|
||||
|
||||
Required fields: actions, reward.
|
||||
Optional fields: character_set, spatial, observe_state.
|
||||
Discovered fields: board_size (read from game.board_size at startup).
|
||||
|
||||
Metadata is read from the game's own pyproject.toml under the
|
||||
[tool.retro-gamer] section via GameMetadata.from_pyproject().
|
||||
Optional fields: character_set, spatial.
|
||||
Discovered fields: board_size (from game.board_size), extras_size (from
|
||||
the observe_state list in [preprocessing]).
|
||||
"""
|
||||
actions: list[str]
|
||||
reward: str
|
||||
character_set: list[str] | None = None
|
||||
spatial: bool = True
|
||||
observe_state: list[str] = field(default_factory=list)
|
||||
spatial: bool = False
|
||||
board: bool = True
|
||||
board_size: tuple[int, int] | None = None
|
||||
extras_size: int = 0
|
||||
|
||||
def validate(self):
|
||||
if not self.actions:
|
||||
raise ValueError("actions must be a non-empty list")
|
||||
raise ValueError(
|
||||
"The 'actions' list in [tool.retro-gamer] is empty or missing.\n"
|
||||
"It should list the keyboard keys your agent can press, for example:\n\n"
|
||||
' actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n\n'
|
||||
"The agent will learn which actions lead to higher rewards."
|
||||
)
|
||||
if not isinstance(self.actions, list) or not all(isinstance(a, str) for a in self.actions):
|
||||
raise ValueError(
|
||||
f"'actions' must be a list of strings, but got: {self.actions!r}\n"
|
||||
"Each entry should be a key name like \"KEY_RIGHT\" or \"KEY_SPACE\"."
|
||||
)
|
||||
if not self.reward:
|
||||
raise ValueError("reward must be a non-empty string")
|
||||
if self.reward in self.observe_state:
|
||||
raise ValueError(f"reward key '{self.reward}' must not appear in observe_state")
|
||||
raise ValueError(
|
||||
"The 'reward' field in [tool.retro-gamer] is empty or missing.\n"
|
||||
"It should name a game state variable whose value the agent is trying\n"
|
||||
"to maximize — for example:\n\n"
|
||||
" reward = \"score\"\n\n"
|
||||
"The trainer watches how this value changes each step and uses those\n"
|
||||
"changes as the reward signal."
|
||||
)
|
||||
if self.character_set is not None:
|
||||
if not isinstance(self.character_set, list):
|
||||
raise ValueError(
|
||||
f"'character_set' must be a list of single characters, but got: {self.character_set!r}\n"
|
||||
"Example: character_set = [\"@\", \"*\", \"#\"]"
|
||||
)
|
||||
for ch in self.character_set:
|
||||
if len(ch) != 1:
|
||||
raise ValueError(f"character_set entries must be single characters, got {ch!r}")
|
||||
if not isinstance(ch, str) or len(ch) != 1:
|
||||
raise ValueError(
|
||||
f"Every entry in character_set must be a single character, but got {ch!r}.\n"
|
||||
"Each character represents one type of cell on the game board.\n"
|
||||
"If you're not sure what characters your game uses, remove character_set\n"
|
||||
"entirely and the trainer will discover them automatically."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pyproject(cls, module_name: str) -> GameMetadata:
|
||||
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml.
|
||||
|
||||
Finds pyproject.toml by walking up from the module's source file.
|
||||
Raises FileNotFoundError if no pyproject.toml is found, or ValueError
|
||||
if the file exists but has no [tool.retro-gamer] section.
|
||||
"""
|
||||
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml."""
|
||||
pyproject_path = _find_pyproject(module_name)
|
||||
if pyproject_path is None:
|
||||
raise FileNotFoundError(
|
||||
@@ -65,13 +84,23 @@ class GameMetadata:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> GameMetadata:
|
||||
missing = [k for k in ('actions', 'reward') if k not in d]
|
||||
if missing:
|
||||
fields = ' and '.join(f"'{k}'" for k in missing)
|
||||
raise ValueError(
|
||||
f"The [tool.retro-gamer] section is missing required {fields}.\n"
|
||||
"A minimal configuration looks like this:\n\n"
|
||||
"[tool.retro-gamer]\n"
|
||||
'actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]\n'
|
||||
'reward = "score"\n\n'
|
||||
"See the documentation for all available options."
|
||||
)
|
||||
board_size = tuple(d['board_size']) if 'board_size' in d else None
|
||||
return cls(
|
||||
actions=d['actions'],
|
||||
reward=d['reward'],
|
||||
character_set=d.get('character_set'),
|
||||
spatial=d.get('spatial', True),
|
||||
observe_state=d.get('observe_state', []),
|
||||
spatial=d.get('spatial', False),
|
||||
board_size=board_size,
|
||||
)
|
||||
|
||||
@@ -79,8 +108,6 @@ class GameMetadata:
|
||||
d = {
|
||||
'actions': self.actions,
|
||||
'reward': self.reward,
|
||||
'spatial': self.spatial,
|
||||
'observe_state': self.observe_state,
|
||||
}
|
||||
if self.board_size is not None:
|
||||
d['board_size'] = list(self.board_size)
|
||||
@@ -95,9 +122,11 @@ class GameMetadata:
|
||||
@property
|
||||
def obs_size(self) -> int:
|
||||
"""Total size of the flat observation vector."""
|
||||
if not self.board:
|
||||
return self.extras_size
|
||||
C = len(self.character_set) if self.character_set else 0
|
||||
bw, bh = self.board_size
|
||||
return C * bw * bh + len(self.observe_state)
|
||||
return C * bw * bh + self.extras_size
|
||||
|
||||
@property
|
||||
def n_actions(self) -> int:
|
||||
@@ -118,4 +147,6 @@ def _find_pyproject(module_name: str) -> Path | None:
|
||||
candidate = parent / 'pyproject.toml'
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
if parent.name == 'site-packages':
|
||||
break
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user