Initial commit
This commit is contained in:
5
retro_gamer/__init__.py
Normal file
5
retro_gamer/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.env import GameEnvironment
|
||||
from retro_gamer.trainer import DQNTrainer
|
||||
|
||||
__all__ = ["GameMetadata", "GameEnvironment", "DQNTrainer"]
|
||||
243
retro_gamer/cli.py
Normal file
243
retro_gamer/cli.py
Normal file
@@ -0,0 +1,243 @@
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
import click
|
||||
import tomli_w
|
||||
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.trainer import DQNTrainer, DEFAULTS
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""Train and run RL agents for retro games."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer create
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@cli.command()
|
||||
@click.option('--game', required=True,
|
||||
help='Python module containing create_game() e.g. retro.examples.snake')
|
||||
@click.option('--output', required=True,
|
||||
help='Directory to create for this training run')
|
||||
@click.option('--learning-rate', default=DEFAULTS['learning_rate'], type=float,
|
||||
help=f"Adam optimizer learning rate (default {DEFAULTS['learning_rate']})")
|
||||
@click.option('--lr-decay', default=DEFAULTS['lr_decay'], type=float,
|
||||
help=f"Multiplicative LR decay per episode (default {DEFAULTS['lr_decay']})")
|
||||
@click.option('--gamma', default=DEFAULTS['gamma'], type=float,
|
||||
help=f"Discount factor for future rewards (default {DEFAULTS['gamma']})")
|
||||
@click.option('--epsilon-decay', default=DEFAULTS['epsilon_decay'], type=float,
|
||||
help=f"Exploration rate decay per episode (default {DEFAULTS['epsilon_decay']})")
|
||||
@click.option('--epsilon-min', default=DEFAULTS['epsilon_min'], type=float,
|
||||
help=f"Minimum exploration rate (default {DEFAULTS['epsilon_min']})")
|
||||
@click.option('--batch-size', default=DEFAULTS['batch_size'], type=int,
|
||||
help=f"Experiences per training step (default {DEFAULTS['batch_size']})")
|
||||
@click.option('--memory-capacity', default=DEFAULTS['memory_capacity'], type=int,
|
||||
help=f"Replay buffer size (default {DEFAULTS['memory_capacity']})")
|
||||
@click.option('--target-update-freq', default=DEFAULTS['target_update_freq'], type=int,
|
||||
help=f"Steps between target network updates (default {DEFAULTS['target_update_freq']})")
|
||||
@click.option('--training-episodes', default=DEFAULTS['training_episodes'], type=int,
|
||||
help=f"Number of episodes to train (default {DEFAULTS['training_episodes']})")
|
||||
@click.option('--max-turns-per-episode', default=DEFAULTS['max_turns_per_episode'], type=int,
|
||||
help=f"Turn limit per episode (default {DEFAULTS['max_turns_per_episode']})")
|
||||
@click.option('--n-layers', default=DEFAULTS['n_layers'], type=int,
|
||||
help=f"Hidden layers in MLP head (default {DEFAULTS['n_layers']})")
|
||||
@click.option('--layer-size', default=DEFAULTS['layer_size'], type=int,
|
||||
help=f"Width of each hidden layer (default {DEFAULTS['layer_size']})")
|
||||
@click.option('--exploration-turns', default=DEFAULTS['exploration_turns'], type=int,
|
||||
help=f"Random turns for character discovery (default {DEFAULTS['exploration_turns']})")
|
||||
@click.option('--prioritize-experiences/--no-prioritize-experiences',
|
||||
default=DEFAULTS['prioritize_experiences'],
|
||||
help='Use prioritized experience replay')
|
||||
def create(game, output, **hyperparams):
|
||||
"""Create a new training run directory.
|
||||
|
||||
Game metadata (actions, reward signal, etc.) is read from the
|
||||
[tool.retro-gamer] section of the game's pyproject.toml.
|
||||
Board size is read directly from the game. Hyperparameter options
|
||||
control how the trainer learns, not what it learns about.
|
||||
"""
|
||||
try:
|
||||
metadata = GameMetadata.from_pyproject(game)
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
raise click.ClickException(str(e))
|
||||
|
||||
game_factory = _load_factory(game)
|
||||
g = game_factory()
|
||||
metadata.board_size = g.board_size
|
||||
|
||||
metadata.validate()
|
||||
|
||||
run_dir = Path(output)
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config = {
|
||||
'game': {'module': game},
|
||||
'metadata': metadata.to_dict(),
|
||||
'hyperparameters': hyperparams,
|
||||
}
|
||||
with open(run_dir / 'config.toml', 'wb') as f:
|
||||
tomli_w.dump(config, f)
|
||||
|
||||
click.echo(f"Created training run at {output}/config.toml")
|
||||
click.echo(f" game : {game}")
|
||||
click.echo(f" board_size : {metadata.board_size[0]}×{metadata.board_size[1]}")
|
||||
click.echo(f" actions : {metadata.actions}")
|
||||
click.echo(f" reward : {metadata.reward}")
|
||||
if metadata.character_set:
|
||||
click.echo(f" characters : {metadata.character_set}")
|
||||
else:
|
||||
click.echo(f" characters : (will be auto-discovered during training)")
|
||||
if metadata.observe_state:
|
||||
click.echo(f" observe : {metadata.observe_state}")
|
||||
click.echo(f" architecture: {'CNN (spatial)' if metadata.spatial else 'MLP (non-spatial)'}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer train
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--resume', default=None,
|
||||
help='Path to checkpoint to resume from (e.g. checkpoints/ep_0500.pt)')
|
||||
def train(run_dir, resume):
|
||||
"""Train (or resume training) a DQN agent."""
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
game_factory = _load_factory(config['game']['module'])
|
||||
metadata = GameMetadata.from_dict(config['metadata'])
|
||||
hyperparams = config.get('hyperparameters', {})
|
||||
|
||||
trainer = DQNTrainer(game_factory, metadata, run_dir, **hyperparams)
|
||||
if resume:
|
||||
click.echo(f"Resuming from {resume}")
|
||||
trainer.load_checkpoint(resume)
|
||||
click.echo(f"Training for {trainer.hp['training_episodes']} episodes…")
|
||||
trainer.train()
|
||||
click.echo(f"Done. Checkpoints in {run_dir}/checkpoints/")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer play
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
@click.option('--checkpoint', default='final',
|
||||
help='Checkpoint name e.g. "final" or "ep_0100"')
|
||||
@click.option('--framerate', default=12, type=int,
|
||||
help='Target frames per second')
|
||||
def play(run_dir, checkpoint, framerate):
|
||||
"""Watch a trained agent play the game."""
|
||||
import torch
|
||||
from time import sleep, perf_counter
|
||||
from blessed import Terminal
|
||||
from retro.input import ProgrammaticInput
|
||||
from retro.views.headless import HeadlessView
|
||||
from retro.views.terminal import TerminalView
|
||||
from retro_gamer.observation import encode_observation
|
||||
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
game_factory = _load_factory(config['game']['module'])
|
||||
metadata = GameMetadata.from_dict(config['metadata'])
|
||||
hyperparams = {**DEFAULTS, **config.get('hyperparameters', {})}
|
||||
|
||||
from retro_gamer.network import build_network
|
||||
model, _ = build_network(metadata, hyperparams)
|
||||
|
||||
ckpt_name = checkpoint if checkpoint.endswith('.pt') else f'{checkpoint}.pt'
|
||||
ckpt_path = run_dir_path / 'checkpoints' / ckpt_name
|
||||
ckpt = torch.load(ckpt_path, weights_only=True)
|
||||
model.load_state_dict(ckpt['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
inp = ProgrammaticInput()
|
||||
headless = HeadlessView()
|
||||
game = game_factory()
|
||||
game.input_source = inp
|
||||
game.view = headless
|
||||
game.start()
|
||||
|
||||
terminal = Terminal()
|
||||
term_view = TerminalView(terminal, color=game.color)
|
||||
|
||||
click.echo("Playing… (press Escape or Enter to quit)")
|
||||
with terminal.fullscreen(), terminal.hidden_cursor(), terminal.cbreak():
|
||||
term_view.on_game_start(game)
|
||||
while game.playing:
|
||||
t0 = perf_counter()
|
||||
|
||||
obs = encode_observation(headless.board_characters, dict(game.state), metadata)
|
||||
state_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
action_idx = int(model(state_t).argmax().item())
|
||||
action_key = None if action_idx >= len(metadata.actions) else metadata.actions[action_idx]
|
||||
|
||||
key = terminal.inkey(0)
|
||||
if key and key.name in ('KEY_ESCAPE', 'KEY_ENTER'):
|
||||
break
|
||||
|
||||
inp.press(action_key)
|
||||
game.step()
|
||||
term_view.render(game)
|
||||
|
||||
elapsed = perf_counter() - t0
|
||||
sleep(max(0, 1 / framerate - elapsed))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retro-gamer info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@cli.command()
|
||||
@click.argument('run_dir')
|
||||
def info(run_dir):
|
||||
"""Print a summary of a training run."""
|
||||
run_dir_path = Path(run_dir)
|
||||
config = _load_config(run_dir_path)
|
||||
click.echo(f"Game module : {config['game']['module']}")
|
||||
click.echo(f"Metadata : {config['metadata']}")
|
||||
click.echo(f"Hyperparams : {config.get('hyperparameters', {})}")
|
||||
|
||||
log_path = run_dir_path / 'training.log'
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text().splitlines()
|
||||
episode_lines = [l for l in lines if l.startswith('[EP')]
|
||||
if episode_lines:
|
||||
click.echo(f"\nLast 5 episodes:")
|
||||
for line in episode_lines[-5:]:
|
||||
click.echo(f" {line}")
|
||||
|
||||
ckpt_dir = run_dir_path / 'checkpoints'
|
||||
if ckpt_dir.exists():
|
||||
ckpts = sorted(ckpt_dir.glob('*.pt'))
|
||||
click.echo(f"\nCheckpoints ({len(ckpts)}): {[c.name for c in ckpts]}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_config(run_dir: Path) -> dict:
|
||||
config_path = run_dir / 'config.toml'
|
||||
if not config_path.exists():
|
||||
raise click.ClickException(f"No config.toml found in {run_dir}")
|
||||
with open(config_path, 'rb') as f:
|
||||
return tomllib.load(f)
|
||||
|
||||
|
||||
def _load_factory(module_name: str):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
raise click.ClickException(f"Cannot import game module '{module_name}': {e}")
|
||||
if not hasattr(module, 'create_game'):
|
||||
raise click.ClickException(
|
||||
f"Module '{module_name}' has no create_game() function"
|
||||
)
|
||||
return module.create_game
|
||||
78
retro_gamer/env.py
Normal file
78
retro_gamer/env.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import Callable
|
||||
from retro.input import ProgrammaticInput
|
||||
from retro.views.headless import HeadlessView
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.observation import encode_observation
|
||||
|
||||
|
||||
class GameEnvironment:
|
||||
"""Gym-style wrapper around a retro Game for RL training.
|
||||
|
||||
Provides reset() / step(action) / observe(), managing one training episode
|
||||
at a time. The game is restarted by calling the factory function on each reset.
|
||||
"""
|
||||
|
||||
def __init__(self, game_factory: Callable, metadata: GameMetadata):
|
||||
self.game_factory = game_factory
|
||||
self.metadata = metadata
|
||||
self.game = None
|
||||
self.view: HeadlessView | None = None
|
||||
self.inp: ProgrammaticInput | None = None
|
||||
self._prev_reward: float = 0.0
|
||||
|
||||
def reset(self) -> np.ndarray:
|
||||
"""Create a fresh game episode and return the initial observation."""
|
||||
self.inp = ProgrammaticInput()
|
||||
self.view = HeadlessView()
|
||||
self.game = self.game_factory()
|
||||
self.game.input_source = self.inp
|
||||
self.game.view = self.view
|
||||
self.game.start()
|
||||
self._prev_reward = float(self.game.state.get(self.metadata.reward, 0))
|
||||
return self._observe()
|
||||
|
||||
def step(self, action: str | None) -> tuple[np.ndarray, float, bool]:
|
||||
"""Advance one turn with the given action.
|
||||
|
||||
action: a keystroke string (e.g. 'KEY_RIGHT') or None for no-op.
|
||||
Returns (observation, reward, done).
|
||||
Reward is the change in the reward state key since the previous step.
|
||||
"""
|
||||
self.inp.press(action)
|
||||
self.game.step()
|
||||
obs = self._observe()
|
||||
reward = self._delta_reward()
|
||||
done = not self.game.playing
|
||||
return obs, reward, done
|
||||
|
||||
def _observe(self) -> np.ndarray:
|
||||
return encode_observation(
|
||||
self.view.board_characters,
|
||||
dict(self.game.state),
|
||||
self.metadata,
|
||||
)
|
||||
|
||||
def _delta_reward(self) -> float:
|
||||
current = float(self.game.state.get(self.metadata.reward, 0))
|
||||
delta = current - self._prev_reward
|
||||
self._prev_reward = current
|
||||
return delta
|
||||
|
||||
def discover_character_set(self, exploration_turns: int) -> list[str]:
|
||||
"""Run random turns to discover the characters that appear on the board.
|
||||
Returns the sorted character list (excluding space).
|
||||
"""
|
||||
obs = self.reset()
|
||||
chars: set[str] = set()
|
||||
for _ in range(exploration_turns):
|
||||
for row in self.view.board_characters:
|
||||
chars.update(row)
|
||||
action = random.choice(self.metadata.actions + [None])
|
||||
_, _, done = self.step(action)
|
||||
if done:
|
||||
self.reset()
|
||||
chars.discard(' ')
|
||||
return sorted(chars)
|
||||
0
retro_gamer/examples/__init__.py
Normal file
0
retro_gamer/examples/__init__.py
Normal file
17
retro_gamer/examples/beast/__init__.py
Normal file
17
retro_gamer/examples/beast/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from retro.game import Game
|
||||
from retro_gamer.examples.beast.board import Board
|
||||
|
||||
WIDTH = 40
|
||||
HEIGHT = 20
|
||||
NUM_BEASTS = 10
|
||||
|
||||
def create_game():
|
||||
"""Return a fresh, initialized Beast game."""
|
||||
board = Board(WIDTH, HEIGHT, num_beasts=NUM_BEASTS)
|
||||
state = {'beasts_killed': 0}
|
||||
game = Game(board.get_agents(), state, board_size=(WIDTH, HEIGHT))
|
||||
game.num_beasts = NUM_BEASTS
|
||||
return game
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_game().play()
|
||||
0
retro_gamer/examples/beast/agents/__init__.py
Normal file
0
retro_gamer/examples/beast/agents/__init__.py
Normal file
67
retro_gamer/examples/beast/agents/beast.py
Normal file
67
retro_gamer/examples/beast/agents/beast.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from retro_gamer.examples.beast.helpers import add, distance, get_occupant
|
||||
from random import random, choice
|
||||
|
||||
class Beast:
|
||||
"""A beast that hunts the player."""
|
||||
character = "H"
|
||||
color = "red"
|
||||
probability_of_moving = 0.03
|
||||
probability_of_random_move = 0.2
|
||||
deadly = True
|
||||
|
||||
def __init__(self, position):
|
||||
self.position = position
|
||||
|
||||
def handle_push(self, vector, game):
|
||||
future_position = add(self.position, vector)
|
||||
on_board = game.on_board(future_position)
|
||||
obstacle = get_occupant(game, future_position)
|
||||
if obstacle or not on_board:
|
||||
self.die(game)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def play_turn(self, game):
|
||||
if self.should_move():
|
||||
possible_moves = []
|
||||
for position in self.get_adjacent_positions():
|
||||
if game.is_empty(position) and game.on_board(position):
|
||||
possible_moves.append(position)
|
||||
if possible_moves:
|
||||
if self.should_move_randomly():
|
||||
self.position = choice(possible_moves)
|
||||
else:
|
||||
self.position = self.choose_best_move(possible_moves, game)
|
||||
player = game.get_agent_by_name("player")
|
||||
if player and player.position == self.position:
|
||||
player.die(game)
|
||||
|
||||
def get_adjacent_positions(self):
|
||||
"""Returns all eight adjacent positions, including diagonals."""
|
||||
positions = []
|
||||
for i in [-1, 0, 1]:
|
||||
for j in [-1, 0, 1]:
|
||||
if i or j:
|
||||
positions.append(add(self.position, (i, j)))
|
||||
return positions
|
||||
|
||||
def should_move(self):
|
||||
return random() < self.probability_of_moving
|
||||
|
||||
def should_move_randomly(self):
|
||||
return random() < self.probability_of_random_move
|
||||
|
||||
def choose_best_move(self, possible_moves, game):
|
||||
player = game.get_agent_by_name("player")
|
||||
move_distances = [[distance(player.position, move), move] for move in possible_moves]
|
||||
shortest_distance, best_move = sorted(move_distances)[0]
|
||||
return best_move
|
||||
|
||||
def die(self, game):
|
||||
game.remove_agent(self)
|
||||
game.num_beasts -= 1
|
||||
game.state['beasts_killed'] += 1
|
||||
if game.num_beasts == 0:
|
||||
game.state["message"] = "You win!"
|
||||
game.end()
|
||||
25
retro_gamer/examples/beast/agents/block.py
Normal file
25
retro_gamer/examples/beast/agents/block.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from retro_gamer.examples.beast.helpers import add, get_occupant
|
||||
|
||||
class Block:
|
||||
"""A static block that can be pushed by the player."""
|
||||
character = "█"
|
||||
color = "green4"
|
||||
deadly = False
|
||||
|
||||
def __init__(self, position):
|
||||
self.position = position
|
||||
|
||||
def handle_push(self, vector, game):
|
||||
"""Responds to a push in the direction of vector.
|
||||
Returns True when the push succeeds in creating empty space.
|
||||
"""
|
||||
future_position = add(self.position, vector)
|
||||
on_board = game.on_board(future_position)
|
||||
obstacle = get_occupant(game, future_position)
|
||||
if obstacle:
|
||||
success = obstacle.handle_push(vector, game)
|
||||
else:
|
||||
success = on_board
|
||||
if success:
|
||||
self.position = future_position
|
||||
return success
|
||||
39
retro_gamer/examples/beast/agents/player.py
Normal file
39
retro_gamer/examples/beast/agents/player.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from retro_gamer.examples.beast.helpers import add, get_occupant
|
||||
|
||||
direction_vectors = {
|
||||
"KEY_RIGHT": (1, 0),
|
||||
"KEY_UP": (0, -1),
|
||||
"KEY_LEFT": (-1, 0),
|
||||
"KEY_DOWN": (0, 1),
|
||||
}
|
||||
|
||||
class Player:
|
||||
character = "*"
|
||||
color = "white"
|
||||
name = "player"
|
||||
deadly = False
|
||||
|
||||
def __init__(self, position):
|
||||
self.position = position
|
||||
|
||||
def handle_keystroke(self, keystroke, game):
|
||||
if keystroke.name in direction_vectors:
|
||||
vector = direction_vectors[keystroke.name]
|
||||
self.try_to_move(vector, game)
|
||||
|
||||
def try_to_move(self, vector, game):
|
||||
future_position = add(self.position, vector)
|
||||
on_board = game.on_board(future_position)
|
||||
obstacle = get_occupant(game, future_position)
|
||||
if obstacle:
|
||||
if obstacle.deadly:
|
||||
self.die(game)
|
||||
elif obstacle.handle_push(vector, game):
|
||||
self.position = future_position
|
||||
elif on_board:
|
||||
self.position = future_position
|
||||
|
||||
def die(self, game):
|
||||
self.color = "black_on_red"
|
||||
game.state["message"] = "The beasties win!"
|
||||
game.end()
|
||||
44
retro_gamer/examples/beast/board.py
Normal file
44
retro_gamer/examples/beast/board.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from random import shuffle
|
||||
from retro_gamer.examples.beast.agents.player import Player
|
||||
from retro_gamer.examples.beast.agents.beast import Beast
|
||||
from retro_gamer.examples.beast.agents.block import Block
|
||||
|
||||
class Board:
|
||||
"""Creates the agents needed at the beginning of the game and assigns their positions."""
|
||||
|
||||
def __init__(self, width, height, block_density=0.3, num_beasts=10):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.block_density = block_density
|
||||
self.num_blocks = round(width * height * block_density)
|
||||
self.num_empty_spaces = width * height - self.num_blocks
|
||||
self.num_beasts = num_beasts
|
||||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
if self.block_density < 0 or self.block_density > 1:
|
||||
raise ValueError("block density must be between 0 and 1.")
|
||||
if self.num_empty_spaces < self.num_beasts + 1:
|
||||
raise ValueError("Not enough space on the board.")
|
||||
|
||||
def get_agents(self):
|
||||
"""Returns a list of agents initialized in their starting positions."""
|
||||
positions = self.get_all_positions()
|
||||
shuffle(positions)
|
||||
|
||||
player_position = positions[0]
|
||||
beast_positions = positions[1:self.num_beasts + 1]
|
||||
block_positions = positions[-self.num_blocks:]
|
||||
|
||||
player = [Player(player_position)]
|
||||
beasts = [Beast(pos) for pos in beast_positions]
|
||||
blocks = [Block(pos) for pos in block_positions]
|
||||
return player + beasts + blocks
|
||||
|
||||
def get_all_positions(self):
|
||||
"""Returns a list of all positions on the board."""
|
||||
positions = []
|
||||
for i in range(self.width):
|
||||
for j in range(self.height):
|
||||
positions.append((i, j))
|
||||
return positions
|
||||
18
retro_gamer/examples/beast/helpers.py
Normal file
18
retro_gamer/examples/beast/helpers.py
Normal file
@@ -0,0 +1,18 @@
|
||||
def add(vec0, vec1):
|
||||
"""Adds two vectors."""
|
||||
x0, y0 = vec0
|
||||
x1, y1 = vec1
|
||||
return (x0 + x1, y0 + y1)
|
||||
|
||||
def get_occupant(game, position):
|
||||
"""Returns the agent at position, if there is one."""
|
||||
positions_with_agents = game.get_agents_by_position()
|
||||
if position in positions_with_agents:
|
||||
agents_at_position = positions_with_agents[position]
|
||||
return agents_at_position[0]
|
||||
|
||||
def distance(vec0, vec1):
|
||||
"""Returns the Manhattan distance between two positions."""
|
||||
x0, y0 = vec0
|
||||
x1, y1 = vec1
|
||||
return abs(x1 - x0) + abs(y1 - y0)
|
||||
6
retro_gamer/examples/beast/pyproject.toml
Normal file
6
retro_gamer/examples/beast/pyproject.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[tool.retro-gamer]
|
||||
actions = ["KEY_RIGHT", "KEY_UP", "KEY_LEFT", "KEY_DOWN"]
|
||||
reward = "beasts_killed"
|
||||
character_set = ["*", "H", "█"]
|
||||
spatial = true
|
||||
observe_state = []
|
||||
74
retro_gamer/memory.py
Normal file
74
retro_gamer/memory.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
import random
|
||||
from collections import deque
|
||||
from typing import NamedTuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Experience(NamedTuple):
|
||||
state: np.ndarray
|
||||
action: int
|
||||
reward: float
|
||||
next_state: np.ndarray
|
||||
done: bool
|
||||
|
||||
|
||||
class ReplayMemory:
|
||||
"""Fixed-capacity ring buffer of experiences sampled uniformly."""
|
||||
|
||||
def __init__(self, capacity: int):
|
||||
self.memory: deque[Experience] = deque(maxlen=capacity)
|
||||
|
||||
def push(self, state, action, reward, next_state, done):
|
||||
self.memory.append(Experience(state, action, reward, next_state, done))
|
||||
|
||||
def sample(self, batch_size: int) -> list[Experience]:
|
||||
return random.sample(self.memory, batch_size)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
||||
|
||||
class PrioritizedReplayMemory:
|
||||
"""Experience replay buffer with priority-weighted sampling.
|
||||
|
||||
Experiences with higher TD-error are sampled more often (alpha controls
|
||||
the strength of prioritization). Importance-sampling weights (beta) correct
|
||||
for the resulting bias.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity: int, alpha: float = 0.6, beta: float = 0.4):
|
||||
self.capacity = capacity
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.memory: list[Experience] = []
|
||||
self.priorities: list[float] = []
|
||||
self._pos = 0
|
||||
|
||||
def push(self, state, action, reward, next_state, done):
|
||||
max_priority = max(self.priorities, default=1.0)
|
||||
exp = Experience(state, action, reward, next_state, done)
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(exp)
|
||||
self.priorities.append(max_priority)
|
||||
else:
|
||||
self.memory[self._pos] = exp
|
||||
self.priorities[self._pos] = max_priority
|
||||
self._pos = (self._pos + 1) % self.capacity
|
||||
|
||||
def sample(self, batch_size: int) -> tuple[list[Experience], np.ndarray, np.ndarray]:
|
||||
"""Returns (experiences, indices, importance_sampling_weights)."""
|
||||
probs = np.array(self.priorities, dtype=np.float64) ** self.alpha
|
||||
probs /= probs.sum()
|
||||
indices = np.random.choice(len(self.memory), batch_size, p=probs)
|
||||
weights = (len(self.memory) * probs[indices]) ** -self.beta
|
||||
weights = (weights / weights.max()).astype(np.float32)
|
||||
experiences = [self.memory[i] for i in indices]
|
||||
return experiences, indices, weights
|
||||
|
||||
def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray):
|
||||
for idx, err in zip(indices, td_errors):
|
||||
self.priorities[idx] = float(abs(err)) + 1e-6
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
121
retro_gamer/metadata.py
Normal file
121
retro_gamer/metadata.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import tomllib
|
||||
import tomli_w
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class GameMetadata:
|
||||
"""Describes a retro game for training purposes.
|
||||
|
||||
Required fields: actions, reward.
|
||||
Optional fields: character_set, spatial, observe_state.
|
||||
Discovered fields: board_size (read from game.board_size at startup).
|
||||
|
||||
Metadata is read from the game's own pyproject.toml under the
|
||||
[tool.retro-gamer] section via GameMetadata.from_pyproject().
|
||||
"""
|
||||
actions: list[str]
|
||||
reward: str
|
||||
character_set: list[str] | None = None
|
||||
spatial: bool = True
|
||||
observe_state: list[str] = field(default_factory=list)
|
||||
board_size: tuple[int, int] | None = None
|
||||
|
||||
def validate(self):
|
||||
if not self.actions:
|
||||
raise ValueError("actions must be a non-empty list")
|
||||
if not self.reward:
|
||||
raise ValueError("reward must be a non-empty string")
|
||||
if self.reward in self.observe_state:
|
||||
raise ValueError(f"reward key '{self.reward}' must not appear in observe_state")
|
||||
if self.character_set is not None:
|
||||
for ch in self.character_set:
|
||||
if len(ch) != 1:
|
||||
raise ValueError(f"character_set entries must be single characters, got {ch!r}")
|
||||
|
||||
@classmethod
|
||||
def from_pyproject(cls, module_name: str) -> GameMetadata:
|
||||
"""Load metadata from the [tool.retro-gamer] section of the game's pyproject.toml.
|
||||
|
||||
Finds pyproject.toml by walking up from the module's source file.
|
||||
Raises FileNotFoundError if no pyproject.toml is found, or ValueError
|
||||
if the file exists but has no [tool.retro-gamer] section.
|
||||
"""
|
||||
pyproject_path = _find_pyproject(module_name)
|
||||
if pyproject_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"Could not find pyproject.toml for module '{module_name}'. "
|
||||
f"Make sure the module is part of a Python project with a pyproject.toml."
|
||||
)
|
||||
with open(pyproject_path, 'rb') as f:
|
||||
data = tomllib.load(f)
|
||||
section = data.get('tool', {}).get('retro-gamer')
|
||||
if section is None:
|
||||
raise ValueError(
|
||||
f"No [tool.retro-gamer] section found in {pyproject_path}.\n"
|
||||
f"Add game metadata to your pyproject.toml:\n\n"
|
||||
f"[tool.retro-gamer]\n"
|
||||
f"actions = [\"KEY_RIGHT\", ...]\n"
|
||||
f"reward = \"score\"\n"
|
||||
)
|
||||
return cls.from_dict(section)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> GameMetadata:
|
||||
board_size = tuple(d['board_size']) if 'board_size' in d else None
|
||||
return cls(
|
||||
actions=d['actions'],
|
||||
reward=d['reward'],
|
||||
character_set=d.get('character_set'),
|
||||
spatial=d.get('spatial', True),
|
||||
observe_state=d.get('observe_state', []),
|
||||
board_size=board_size,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = {
|
||||
'actions': self.actions,
|
||||
'reward': self.reward,
|
||||
'spatial': self.spatial,
|
||||
'observe_state': self.observe_state,
|
||||
}
|
||||
if self.board_size is not None:
|
||||
d['board_size'] = list(self.board_size)
|
||||
if self.character_set is not None:
|
||||
d['character_set'] = self.character_set
|
||||
return d
|
||||
|
||||
def to_toml(self, path: str | Path):
|
||||
with open(path, 'wb') as f:
|
||||
tomli_w.dump({'metadata': self.to_dict()}, f)
|
||||
|
||||
@property
|
||||
def obs_size(self) -> int:
|
||||
"""Total size of the flat observation vector."""
|
||||
C = len(self.character_set) if self.character_set else 0
|
||||
bw, bh = self.board_size
|
||||
return C * bw * bh + len(self.observe_state)
|
||||
|
||||
@property
|
||||
def n_actions(self) -> int:
|
||||
"""Number of actions including no-op."""
|
||||
return len(self.actions) + 1
|
||||
|
||||
|
||||
def _find_pyproject(module_name: str) -> Path | None:
|
||||
"""Walk up from a module's source file to find its pyproject.toml."""
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError:
|
||||
return None
|
||||
module_file = getattr(module, '__file__', None)
|
||||
if module_file is None:
|
||||
return None
|
||||
for parent in Path(module_file).resolve().parents:
|
||||
candidate = parent / 'pyproject.toml'
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
100
retro_gamer/network.py
Normal file
100
retro_gamer/network.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
|
||||
|
||||
def build_network(
|
||||
metadata: GameMetadata,
|
||||
hyperparams: dict,
|
||||
) -> tuple[nn.Module, str]:
|
||||
"""Build a Q-network from game metadata and hyperparameters.
|
||||
|
||||
Returns (model, rationale) where rationale is a multi-line string
|
||||
describing the architecture and the reasoning behind each choice.
|
||||
"""
|
||||
n_layers = hyperparams.get('n_layers', 2)
|
||||
layer_size = hyperparams.get('layer_size', 128)
|
||||
C = len(metadata.character_set)
|
||||
bw, bh = metadata.board_size
|
||||
W, H = bw, bh
|
||||
n_state = len(metadata.observe_state)
|
||||
n_actions = metadata.n_actions
|
||||
|
||||
lines = []
|
||||
lines.append("[INIT] === Network Architecture ===")
|
||||
lines.append(f"[INIT] Board: {W}×{H}, character set: {C} chars (one-hot per cell)")
|
||||
lines.append(f"[INIT] Observed state keys: {n_state} | Actions (incl. no-op): {n_actions}")
|
||||
|
||||
if metadata.spatial:
|
||||
model = _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines)
|
||||
else:
|
||||
obs_size = C * W * H + n_state
|
||||
model = _build_flat(obs_size, n_layers, layer_size, n_actions, lines)
|
||||
|
||||
lines.append(f"[INIT] Hidden layers: {n_layers} | Layer width: {layer_size}")
|
||||
lines.append(f"[INIT] Output: {n_actions} Q-values")
|
||||
lines.append(f"[INIT] Actions: {metadata.actions} + (no-op)")
|
||||
return model, '\n'.join(lines)
|
||||
|
||||
|
||||
def _build_spatial(C, H, W, n_state, n_layers, layer_size, n_actions, lines):
|
||||
lines.append("[INIT] spatial=True → using CNN architecture")
|
||||
lines.append("[INIT] Rationale: the board is a 2-D spatial scene; a CNN captures")
|
||||
lines.append("[INIT] local patterns (walls, items nearby) more efficiently than an MLP.")
|
||||
lines.append(f"[INIT] CNN: Conv2d({C}→32, k=3, pad=1) → ReLU → Conv2d(32→64, k=3, pad=1) → ReLU")
|
||||
conv_out = 64 * H * W # padding=1 preserves spatial dims
|
||||
lines.append(f"[INIT] CNN output: 64 channels × {H}×{W} = {conv_out} features (flattened)")
|
||||
mlp_in = conv_out + n_state
|
||||
lines.append(f"[INIT] MLP head input: {conv_out} (conv) + {n_state} (state) = {mlp_in}")
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(mlp_in)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
|
||||
return _SpatialNet(C, H, W, n_state, n_layers, layer_size, n_actions)
|
||||
|
||||
|
||||
def _build_flat(obs_size, n_layers, layer_size, n_actions, lines):
|
||||
lines.append("[INIT] spatial=False → using MLP architecture")
|
||||
lines.append("[INIT] Rationale: the board encodes UI/status rather than a spatial scene;")
|
||||
lines.append("[INIT] a flat MLP over the full observation is sufficient.")
|
||||
lines.append(f"[INIT] MLP: {' → '.join([str(obs_size)] + [str(layer_size)] * n_layers + [str(n_actions)])}")
|
||||
return _FlatNet(obs_size, n_layers, layer_size, n_actions)
|
||||
|
||||
|
||||
class _SpatialNet(nn.Module):
|
||||
def __init__(self, C, H, W, n_state, n_layers, layer_size, n_actions):
|
||||
super().__init__()
|
||||
self.C, self.H, self.W = C, H, W
|
||||
self.n_board = C * H * W
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(C, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
)
|
||||
conv_out = 64 * H * W
|
||||
mlp_in = conv_out + n_state
|
||||
layers: list[nn.Module] = []
|
||||
for i in range(n_layers):
|
||||
in_size = mlp_in if i == 0 else layer_size
|
||||
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
|
||||
layers.append(nn.Linear(layer_size, n_actions))
|
||||
self.mlp = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
board = x[:, :self.n_board].reshape(-1, self.C, self.H, self.W)
|
||||
state = x[:, self.n_board:]
|
||||
conv_out = self.conv(board).flatten(start_dim=1)
|
||||
return self.mlp(torch.cat([conv_out, state], dim=1))
|
||||
|
||||
|
||||
class _FlatNet(nn.Module):
|
||||
def __init__(self, obs_size, n_layers, layer_size, n_actions):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
for i in range(n_layers):
|
||||
in_size = obs_size if i == 0 else layer_size
|
||||
layers += [nn.Linear(in_size, layer_size), nn.ReLU()]
|
||||
layers.append(nn.Linear(layer_size, n_actions))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
49
retro_gamer/observation.py
Normal file
49
retro_gamer/observation.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import numpy as np
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
|
||||
|
||||
def encode_board(board_chars: list[list[str]], character_set: list[str]) -> np.ndarray:
|
||||
"""One-hot encode the board.
|
||||
|
||||
Returns an array of shape (H, W, C) where C = len(character_set).
|
||||
Unknown characters produce a zero vector.
|
||||
"""
|
||||
char_to_idx = {c: i for i, c in enumerate(character_set)}
|
||||
H = len(board_chars)
|
||||
W = len(board_chars[0]) if board_chars else 0
|
||||
C = len(character_set)
|
||||
arr = np.zeros((H, W, C), dtype=np.float32)
|
||||
for y, row in enumerate(board_chars):
|
||||
for x, char in enumerate(row):
|
||||
idx = char_to_idx.get(char)
|
||||
if idx is not None:
|
||||
arr[y, x, idx] = 1.0
|
||||
return arr
|
||||
|
||||
|
||||
def encode_state(state: dict, observe_state: list[str]) -> np.ndarray:
|
||||
"""Extract observed state keys into a 1D float array."""
|
||||
return np.array([float(state.get(k, 0)) for k in observe_state], dtype=np.float32)
|
||||
|
||||
|
||||
def encode_observation(
|
||||
board_chars: list[list[str]],
|
||||
state: dict,
|
||||
metadata: GameMetadata,
|
||||
) -> np.ndarray:
|
||||
"""Encode board + state into a flat 1D observation vector.
|
||||
|
||||
For spatial games the board is encoded channel-first (C, H, W) then flattened,
|
||||
so the network can reshape it back for CNN processing. For non-spatial games the
|
||||
board is encoded (H, W, C) then flattened.
|
||||
The state vector is appended at the end in both cases.
|
||||
"""
|
||||
if not metadata.character_set:
|
||||
raise ValueError("character_set must be set before encoding observations")
|
||||
board = encode_board(board_chars, metadata.character_set) # (H, W, C)
|
||||
if metadata.spatial:
|
||||
board_vec = board.transpose(2, 0, 1).flatten() # C*H*W, channel-first
|
||||
else:
|
||||
board_vec = board.flatten() # H*W*C
|
||||
state_vec = encode_state(state, metadata.observe_state)
|
||||
return np.concatenate([board_vec, state_vec])
|
||||
255
retro_gamer/trainer.py
Normal file
255
retro_gamer/trainer.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from __future__ import annotations
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import tomli_w
|
||||
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.env import GameEnvironment
|
||||
from retro_gamer.network import build_network
|
||||
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
|
||||
|
||||
DEFAULTS: dict = {
|
||||
'learning_rate': 1e-3,
|
||||
'lr_decay': 0.995,
|
||||
'gamma': 0.99,
|
||||
'epsilon': 1.0,
|
||||
'epsilon_decay': 0.995,
|
||||
'epsilon_min': 0.05,
|
||||
'batch_size': 64,
|
||||
'memory_capacity': 10_000,
|
||||
'target_update_freq': 100,
|
||||
'training_episodes': 1_000,
|
||||
'n_layers': 2,
|
||||
'layer_size': 128,
|
||||
'prioritize_experiences': False,
|
||||
'exploration_turns': 200,
|
||||
'unknown_character_strategy': 'ignore',
|
||||
'max_turns_per_episode': 2_000,
|
||||
}
|
||||
|
||||
|
||||
class DQNTrainer:
|
||||
"""Trains a deep Q-network agent to play a retro game.
|
||||
|
||||
On initialization the trainer:
|
||||
1. Discovers the character set (if not already specified in metadata).
|
||||
2. Builds the Q-network and logs the full architecture with rationale.
|
||||
3. Saves config.toml and starts training.log in run_dir.
|
||||
|
||||
Call train() to run all episodes and save checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
game_factory: Callable,
|
||||
metadata: GameMetadata,
|
||||
run_dir: str | Path,
|
||||
**hyperparams,
|
||||
):
|
||||
self.game_factory = game_factory
|
||||
self.metadata = metadata
|
||||
self.run_dir = Path(run_dir)
|
||||
self.hp: dict = {**DEFAULTS, **hyperparams}
|
||||
self.run_dir.mkdir(parents=True, exist_ok=True)
|
||||
(self.run_dir / 'checkpoints').mkdir(exist_ok=True)
|
||||
|
||||
self.env = GameEnvironment(game_factory, metadata)
|
||||
|
||||
if metadata.board_size is None:
|
||||
g = game_factory()
|
||||
metadata.board_size = g.board_size
|
||||
|
||||
if metadata.character_set is None:
|
||||
self._discover_character_set()
|
||||
|
||||
self.model, rationale = build_network(metadata, self.hp)
|
||||
self.target_model, _ = build_network(metadata, self.hp)
|
||||
self.target_model.load_state_dict(self.model.state_dict())
|
||||
self.target_model.eval()
|
||||
|
||||
self.optimizer = optim.Adam(
|
||||
self.model.parameters(), lr=self.hp['learning_rate']
|
||||
)
|
||||
self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
|
||||
self.optimizer, gamma=self.hp['lr_decay']
|
||||
)
|
||||
|
||||
if self.hp['prioritize_experiences']:
|
||||
self.memory = PrioritizedReplayMemory(self.hp['memory_capacity'])
|
||||
else:
|
||||
self.memory = ReplayMemory(self.hp['memory_capacity'])
|
||||
|
||||
self.epsilon: float = self.hp['epsilon']
|
||||
self.total_steps: int = 0
|
||||
|
||||
self._save_config()
|
||||
self._open_log(rationale)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def train(self):
|
||||
"""Run all training episodes and save checkpoints."""
|
||||
for episode in range(1, self.hp['training_episodes'] + 1):
|
||||
total_reward, steps, avg_loss = self._run_episode()
|
||||
self.epsilon = max(
|
||||
self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
|
||||
)
|
||||
self.lr_scheduler.step()
|
||||
self._log_episode(episode, total_reward, steps, avg_loss)
|
||||
if episode % 100 == 0:
|
||||
self._save_checkpoint(f'ep_{episode:04d}.pt')
|
||||
self._save_checkpoint('final.pt')
|
||||
|
||||
def load_checkpoint(self, path: str | Path):
|
||||
ckpt = torch.load(path, weights_only=True)
|
||||
self.model.load_state_dict(ckpt['model_state_dict'])
|
||||
self.target_model.load_state_dict(ckpt['model_state_dict'])
|
||||
self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
|
||||
self.epsilon = ckpt['epsilon']
|
||||
self.total_steps = ckpt['total_steps']
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Training loop internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_episode(self) -> tuple[float, int, float]:
|
||||
state = self.env.reset()
|
||||
total_reward = 0.0
|
||||
total_loss = 0.0
|
||||
loss_count = 0
|
||||
|
||||
for step in range(self.hp['max_turns_per_episode']):
|
||||
state_t = torch.as_tensor(state, dtype=torch.float32)
|
||||
action_idx = self._select_action(state_t)
|
||||
action_key = self._idx_to_key(action_idx)
|
||||
|
||||
next_state, reward, done = self.env.step(action_key)
|
||||
self.memory.push(state, action_idx, reward, next_state, done)
|
||||
|
||||
loss = self._train_step()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
loss_count += 1
|
||||
|
||||
self.total_steps += 1
|
||||
if self.total_steps % self.hp['target_update_freq'] == 0:
|
||||
self.target_model.load_state_dict(self.model.state_dict())
|
||||
|
||||
state = next_state
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
|
||||
avg_loss = total_loss / loss_count if loss_count else 0.0
|
||||
return total_reward, step + 1, avg_loss
|
||||
|
||||
def _select_action(self, state_t: torch.Tensor) -> int:
|
||||
if random.random() < self.epsilon:
|
||||
return random.randrange(self.metadata.n_actions)
|
||||
with torch.no_grad():
|
||||
return int(self.model(state_t.unsqueeze(0)).argmax().item())
|
||||
|
||||
def _idx_to_key(self, idx: int) -> str | None:
|
||||
if idx >= len(self.metadata.actions):
|
||||
return None
|
||||
return self.metadata.actions[idx]
|
||||
|
||||
def _train_step(self) -> float | None:
|
||||
if len(self.memory) < self.hp['batch_size']:
|
||||
return None
|
||||
|
||||
if self.hp['prioritize_experiences']:
|
||||
assert isinstance(self.memory, PrioritizedReplayMemory)
|
||||
experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
|
||||
weight_t = torch.as_tensor(weights, dtype=torch.float32)
|
||||
else:
|
||||
experiences = self.memory.sample(self.hp['batch_size'])
|
||||
indices = None
|
||||
weight_t = None
|
||||
|
||||
states = torch.as_tensor(
|
||||
np.array([e.state for e in experiences]), dtype=torch.float32
|
||||
)
|
||||
actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long)
|
||||
rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32)
|
||||
next_states = torch.as_tensor(
|
||||
np.array([e.next_state for e in experiences]), dtype=torch.float32
|
||||
)
|
||||
dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32)
|
||||
|
||||
q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
with torch.no_grad():
|
||||
next_q = self.target_model(next_states).max(1).values
|
||||
targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)
|
||||
|
||||
element_loss = nn.functional.mse_loss(q_values, targets, reduction='none')
|
||||
|
||||
if weight_t is not None:
|
||||
loss = (weight_t * element_loss).mean()
|
||||
td_errors = (q_values - targets).detach().abs().numpy()
|
||||
self.memory.update_priorities(indices, td_errors)
|
||||
else:
|
||||
loss = element_loss.mean()
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
return float(loss.item())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Initialisation helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _discover_character_set(self):
|
||||
chars = self.env.discover_character_set(self.hp['exploration_turns'])
|
||||
self.metadata.character_set = chars
|
||||
self._log_raw(
|
||||
f"[INIT] character_set not specified — discovered {len(chars)} chars "
|
||||
f"after {self.hp['exploration_turns']} exploration turns: {chars}"
|
||||
)
|
||||
|
||||
def _save_config(self):
|
||||
config_path = self.run_dir / 'config.toml'
|
||||
config: dict = {}
|
||||
if config_path.exists():
|
||||
import tomllib
|
||||
with open(config_path, 'rb') as f:
|
||||
config = tomllib.load(f)
|
||||
config['metadata'] = self.metadata.to_dict()
|
||||
config['hyperparameters'] = self.hp
|
||||
with open(config_path, 'wb') as f:
|
||||
tomli_w.dump(config, f)
|
||||
|
||||
def _open_log(self, rationale: str):
|
||||
self.log_path = self.run_dir / 'training.log'
|
||||
with open(self.log_path, 'w') as f:
|
||||
f.write(rationale + '\n')
|
||||
|
||||
def _log_raw(self, line: str):
|
||||
with open(self.log_path, 'a') as f:
|
||||
f.write(line + '\n')
|
||||
|
||||
def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float):
|
||||
line = (
|
||||
f"[EP {episode:04d}] total_reward={total_reward:.1f} "
|
||||
f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}"
|
||||
)
|
||||
self._log_raw(line)
|
||||
|
||||
def _save_checkpoint(self, name: str):
|
||||
torch.save(
|
||||
{
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'epsilon': self.epsilon,
|
||||
'total_steps': self.total_steps,
|
||||
},
|
||||
self.run_dir / 'checkpoints' / name,
|
||||
)
|
||||
Reference in New Issue
Block a user