Updates across the board
This commit is contained in:
@@ -9,15 +9,27 @@ from retro_gamer.observation import encode_observation
|
||||
|
||||
|
||||
class GameEnvironment:
|
||||
"""Gym-style wrapper around a retro Game for RL training.
|
||||
"""Gym-style wrapper around a retro game for RL training."""
|
||||
|
||||
Provides reset() / step(action) / observe(), managing one training episode
|
||||
at a time. The game is restarted by calling the factory function on each reset.
|
||||
"""
|
||||
|
||||
def __init__(self, game_factory: Callable, metadata: GameMetadata):
|
||||
def __init__(
|
||||
self,
|
||||
game_factory: Callable,
|
||||
metadata: GameMetadata,
|
||||
observe_state: list[str] | None = None,
|
||||
egocentric: bool = False,
|
||||
egocentric_player: str | None = None,
|
||||
egocentric_radius: int | None = None,
|
||||
board: bool = True,
|
||||
observe_state_sizes: dict[str, int] | None = None,
|
||||
):
|
||||
self.game_factory = game_factory
|
||||
self.metadata = metadata
|
||||
self.observe_state = observe_state or []
|
||||
self.egocentric = egocentric
|
||||
self.egocentric_player = egocentric_player
|
||||
self.egocentric_radius = egocentric_radius
|
||||
self.board = board
|
||||
self.observe_state_sizes = observe_state_sizes or {}
|
||||
self.game = None
|
||||
self.view: HeadlessView | None = None
|
||||
self.inp: ProgrammaticInput | None = None
|
||||
@@ -35,12 +47,7 @@ class GameEnvironment:
|
||||
return self._observe()
|
||||
|
||||
def step(self, action: str | None) -> tuple[np.ndarray, float, bool]:
|
||||
"""Advance one turn with the given action.
|
||||
|
||||
action: a keystroke string (e.g. 'KEY_RIGHT') or None for no-op.
|
||||
Returns (observation, reward, done).
|
||||
Reward is the change in the reward state key since the previous step.
|
||||
"""
|
||||
"""Advance one turn. Returns (observation, reward, done)."""
|
||||
self.inp.press(action)
|
||||
self.game.step()
|
||||
obs = self._observe()
|
||||
@@ -49,12 +56,47 @@ class GameEnvironment:
|
||||
return obs, reward, done
|
||||
|
||||
def _observe(self) -> np.ndarray:
    """Encode the current game state and board into an observation.

    Takes a shallow ``dict`` snapshot of ``self.game.state``; when
    ``self.observe_state_sizes`` is non-empty the snapshot is validated
    with ``_check_state_sizes`` (which raises ``ValueError`` on a size
    mismatch). If egocentric mode is on and a player name is configured,
    the agent with that name is looked up and, when found, its position
    is forwarded to the encoder; otherwise ``player_pos`` stays ``None``.

    Returns whatever ``encode_observation`` returns for these inputs.
    """
    state = dict(self.game.state)
    if self.observe_state_sizes:
        self._check_state_sizes(state)

    player_pos = None
    if self.egocentric and self.egocentric_player:
        agent = self.game.get_agent_by_name(self.egocentric_player)
        if agent is not None:
            player_pos = agent.position

    # Pass the snapshot that was just validated, exactly once. A second
    # ``dict(self.game.state)`` positional argument would shift metadata
    # and observe_state into the wrong parameter slots.
    return encode_observation(
        self.view.board_characters,
        state,
        self.metadata,
        self.observe_state,
        player_pos=player_pos,
        egocentric_radius=self.egocentric_radius,
        board=self.board,
    )
|
||||
|
||||
def _check_state_sizes(self, state: dict):
    """Verify that each tracked state key still has its expected size.

    For every ``key -> expected`` pair in ``self.observe_state_sizes`` the
    size of ``state[key]`` is computed as:

    * ``0`` when the key is missing or its value is ``None``;
    * ``len(value)`` when the value is a ``list`` or ``tuple``;
    * ``1`` for any other value.

    Raises:
        ValueError: on the first key whose size differs from the expected
            one, with a message explaining the mismatch and how to fix it.
    """
    for key, expected in self.observe_state_sizes.items():
        val = state.get(key)
        if val is None:
            actual = 0
        elif isinstance(val, (list, tuple)):
            actual = len(val)
        else:
            actual = 1

        if actual == expected:
            continue

        raise ValueError(
            f"State key '{key}' changed size during training:\n"
            f" Expected : {expected} (discovered at training start)\n"
            f" Got : {actual}\n\n"
            f"This means game.state['{key}'] has a different length in some\n"
            "episodes than it had when training started. The neural network\n"
            "has a fixed input size and cannot adapt to changing state shapes.\n\n"
            f"Fix: make sure create_game() always initializes '{key}' with a\n"
            "fixed-length value before the game starts each episode.\n"
            f"For example, if '{key}' is a list of 9 values, it must always be\n"
            "a list of exactly 9 values — never more, never fewer, never missing."
        )
|
||||
|
||||
def _delta_reward(self) -> float:
|
||||
current = float(self.game.state.get(self.metadata.reward, 0))
|
||||
delta = current - self._prev_reward
|
||||
@@ -62,10 +104,8 @@ class GameEnvironment:
|
||||
return delta
|
||||
|
||||
def discover_character_set(self, exploration_turns: int) -> list[str]:
|
||||
"""Run random turns to discover the characters that appear on the board.
|
||||
Returns the sorted character list (excluding space).
|
||||
"""
|
||||
obs = self.reset()
|
||||
"""Run random turns to discover the characters that appear on the board."""
|
||||
self.reset()
|
||||
chars: set[str] = set()
|
||||
for _ in range(exploration_turns):
|
||||
for row in self.view.board_characters:
|
||||
|
||||
Reference in New Issue
Block a user