Initial commit
This commit is contained in:
78
retro_gamer/env.py
Normal file
78
retro_gamer/env.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import Callable
|
||||
from retro.input import ProgrammaticInput
|
||||
from retro.views.headless import HeadlessView
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.observation import encode_observation
|
||||
|
||||
|
||||
class GameEnvironment:
    """Gym-style wrapper around a retro Game for RL training.

    Exposes reset() and step(action) and manages one training episode at a
    time. A new game object is obtained by calling ``game_factory`` on every
    reset(); the game is wired to a ProgrammaticInput and a HeadlessView so it
    can be driven and observed without a terminal.
    """

    def __init__(self, game_factory: Callable, metadata: GameMetadata):
        """Store the factory and metadata; no game exists until reset().

        game_factory: zero-argument callable returning a new game object.
        metadata: supplies the reward state key (``metadata.reward``) and the
            list of available actions (``metadata.actions``); it is also
            passed through to encode_observation().
        """
        self.game_factory = game_factory
        self.metadata = metadata
        self.game = None
        self.view: HeadlessView | None = None
        self.inp: ProgrammaticInput | None = None
        self._prev_reward: float = 0.0

    def reset(self) -> np.ndarray:
        """Create a fresh game episode and return the initial observation."""
        self.inp = ProgrammaticInput()
        self.view = HeadlessView()
        self.game = self.game_factory()
        self.game.input_source = self.inp
        self.game.view = self.view
        self.game.start()
        # Baseline for reward deltas: the first step() reports only the change
        # relative to the state right after start().
        self._prev_reward = self._read_reward()
        return self._observe()

    def step(self, action: str | None) -> tuple[np.ndarray, float, bool]:
        """Advance one turn with the given action.

        action: a keystroke string (e.g. 'KEY_RIGHT') or None for no-op.
        Returns (observation, reward, done).
        Reward is the change in the reward state key since the previous step.
        Raises RuntimeError if called before reset().
        """
        if self.game is None or self.inp is None:
            raise RuntimeError("step() called before reset(); call reset() first.")
        # NOTE(review): None is forwarded to press() unchanged; this assumes
        # ProgrammaticInput treats it as "no key" -- confirm.
        self.inp.press(action)
        self.game.step()
        obs = self._observe()
        reward = self._delta_reward()
        done = not self.game.playing
        return obs, reward, done

    def _observe(self) -> np.ndarray:
        """Encode the current board characters and game state."""
        return encode_observation(
            self.view.board_characters,
            dict(self.game.state),  # pass a copy of the game's state mapping
            self.metadata,
        )

    def _read_reward(self) -> float:
        """Return the current value of the reward state key (0 if absent)."""
        return float(self.game.state.get(self.metadata.reward, 0))

    def _delta_reward(self) -> float:
        """Return the reward change since the last call and update the baseline."""
        current = self._read_reward()
        delta = current - self._prev_reward
        self._prev_reward = current
        return delta

    def _collect_board_characters(self, chars: set[str]) -> None:
        """Add every character currently on the view's board to ``chars``."""
        # presumably each row is a string or a sequence of 1-char strings
        for row in self.view.board_characters:
            chars.update(row)

    def discover_character_set(self, exploration_turns: int) -> list[str]:
        """Run random turns to discover the characters that appear on the board.

        The board is scanned after every reset and after every step, so the
        final frame of an episode is included before the game is restarted.
        Returns the sorted character list (excluding space).
        """
        self.reset()
        chars: set[str] = set()
        self._collect_board_characters(chars)
        # Built once; unpacking also works if metadata.actions is a tuple.
        choices = [*self.metadata.actions, None]
        for _ in range(exploration_turns):
            _, _, done = self.step(random.choice(choices))
            # Scan before a possible reset so terminal-frame characters count.
            self._collect_board_characters(chars)
            if done:
                self.reset()
                self._collect_board_characters(chars)
        chars.discard(' ')
        return sorted(chars)
|
||||
Reference in New Issue
Block a user