256 lines
9.1 KiB
Python
256 lines
9.1 KiB
Python
from __future__ import annotations
|
|
import random
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import tomli_w
|
|
|
|
from retro_gamer.metadata import GameMetadata
|
|
from retro_gamer.env import GameEnvironment
|
|
from retro_gamer.network import build_network
|
|
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
|
|
|
|
# Baseline hyperparameters. DQNTrainer builds its settings as
# ``{**DEFAULTS, **hyperparams}``, so every key below can be overridden by
# passing it as a keyword argument to the trainer.
DEFAULTS: dict = dict(
    # --- optimisation ---------------------------------------------------
    learning_rate=1e-3,           # initial Adam learning rate
    lr_decay=0.995,               # ExponentialLR gamma, stepped once per episode
    gamma=0.99,                   # discount applied to the target network's max Q
    # --- epsilon-greedy exploration ---------------------------------------
    epsilon=1.0,                  # starting probability of a random action
    epsilon_decay=0.995,          # multiplied into epsilon after every episode
    epsilon_min=0.05,             # lower bound for epsilon
    # --- replay / training schedule ---------------------------------------
    batch_size=64,                # minibatch size; no update until memory holds this many
    memory_capacity=10_000,       # size passed to the replay-memory constructor
    target_update_freq=100,       # environment steps between target-network syncs
    training_episodes=1_000,      # number of episodes run by train()
    # --- network shape (forwarded to build_network in the hyperparameter dict;
    #     presumably hidden-layer count / width -- confirm in build_network) ---
    n_layers=2,
    layer_size=128,
    # --- misc ----------------------------------------------------------------
    prioritize_experiences=False,  # True -> PrioritizedReplayMemory + weighted loss
    exploration_turns=200,         # turns used to discover the character set
    # NOTE(review): not read in this module; presumably consumed by
    # build_network, which receives the whole dict -- confirm.
    unknown_character_strategy='ignore',
    max_turns_per_episode=2_000,   # hard cap on steps within one episode
)
|
|
|
|
|
|
class DQNTrainer:
    """Trains a deep Q-network agent to play a retro game.

    On initialization the trainer:

    1. Discovers the character set (if not already specified in metadata).
    2. Builds the Q-network and logs the full architecture with rationale.
    3. Saves config.toml and starts training.log in run_dir.

    Call train() to run all episodes and save checkpoints.
    """

    def __init__(
        self,
        game_factory: Callable,
        metadata: GameMetadata,
        run_dir: str | Path,
        **hyperparams,
    ):
        """Set up environment, networks, optimizer, replay memory and run dir.

        Args:
            game_factory: Zero-argument callable returning a new game. It is
                handed to ``GameEnvironment``. When ``metadata.board_size`` is
                unset, it is also called once to read ``board_size``.
            metadata: Game description. ``board_size`` and ``character_set``
                are filled in on this object (mutated in place) when ``None``.
            run_dir: Directory for ``config.toml``, ``training.log`` and
                ``checkpoints/``. Created if missing.
            **hyperparams: Overrides for entries in ``DEFAULTS``.
        """
        self.game_factory = game_factory
        self.metadata = metadata
        self.run_dir = Path(run_dir)
        self.hp: dict = {**DEFAULTS, **hyperparams}
        self.run_dir.mkdir(parents=True, exist_ok=True)
        (self.run_dir / 'checkpoints').mkdir(exist_ok=True)

        # _log_raw can be reached (via character-set discovery) before
        # training.log exists. Such lines are buffered here and written
        # right after the rationale by _open_log.
        self._pending_log: list[str] | None = []

        self.env = GameEnvironment(game_factory, metadata)

        if metadata.board_size is None:
            # Instantiate one game solely to learn the board size.
            g = game_factory()
            metadata.board_size = g.board_size

        if metadata.character_set is None:
            self._discover_character_set()

        # Online network plus a target network that starts as an exact copy
        # and is only used for inference.
        self.model, rationale = build_network(metadata, self.hp)
        self.target_model, _ = build_network(metadata, self.hp)
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()

        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.hp['learning_rate']
        )
        self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=self.hp['lr_decay']
        )

        if self.hp['prioritize_experiences']:
            self.memory = PrioritizedReplayMemory(self.hp['memory_capacity'])
        else:
            self.memory = ReplayMemory(self.hp['memory_capacity'])

        self.epsilon: float = self.hp['epsilon']
        self.total_steps: int = 0

        self._save_config()
        self._open_log(rationale)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def train(self):
        """Run all training episodes and save checkpoints.

        After each episode, epsilon is decayed (floored at ``epsilon_min``),
        the learning-rate scheduler is stepped and a log line is written.
        A checkpoint ``ep_NNNN.pt`` is saved every 100 episodes, and
        ``final.pt`` is saved at the end.
        """
        for episode in range(1, self.hp['training_episodes'] + 1):
            total_reward, steps, avg_loss = self._run_episode()
            self.epsilon = max(
                self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
            )
            self.lr_scheduler.step()
            self._log_episode(episode, total_reward, steps, avg_loss)
            if episode % 100 == 0:
                self._save_checkpoint(f'ep_{episode:04d}.pt')
        self._save_checkpoint('final.pt')

    def load_checkpoint(self, path: str | Path):
        """Restore trainer state from a checkpoint written by _save_checkpoint.

        The saved model weights are loaded into both the online and the
        target network. The LR scheduler state is restored only when the
        checkpoint contains it, so older checkpoints without it still load.
        """
        ckpt = torch.load(path, weights_only=True)
        self.model.load_state_dict(ckpt['model_state_dict'])
        self.target_model.load_state_dict(ckpt['model_state_dict'])
        self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler_state = ckpt.get('lr_scheduler_state_dict')
        if scheduler_state is not None:
            self.lr_scheduler.load_state_dict(scheduler_state)
        self.epsilon = ckpt['epsilon']
        self.total_steps = ckpt['total_steps']

    # ------------------------------------------------------------------
    # Training loop internals
    # ------------------------------------------------------------------

    def _run_episode(self) -> tuple[float, int, float]:
        """Play one episode, attempting a training update after every step.

        Returns:
            ``(total_reward, steps_taken, avg_loss)``. ``avg_loss`` is the mean
            over the steps that actually trained, or 0.0 if none did.
        """
        state = self.env.reset()
        total_reward = 0.0
        total_loss = 0.0
        loss_count = 0
        # Counted explicitly: a loop variable would be unbound when
        # max_turns_per_episode is 0.
        steps = 0

        for _ in range(self.hp['max_turns_per_episode']):
            state_t = torch.as_tensor(state, dtype=torch.float32)
            action_idx = self._select_action(state_t)
            action_key = self._idx_to_key(action_idx)

            next_state, reward, done = self.env.step(action_key)
            self.memory.push(state, action_idx, reward, next_state, done)

            loss = self._train_step()
            if loss is not None:
                total_loss += loss
                loss_count += 1

            self.total_steps += 1
            # Periodically copy the online weights into the target network.
            if self.total_steps % self.hp['target_update_freq'] == 0:
                self.target_model.load_state_dict(self.model.state_dict())

            state = next_state
            total_reward += reward
            steps += 1
            if done:
                break

        avg_loss = total_loss / loss_count if loss_count else 0.0
        return total_reward, steps, avg_loss

    def _select_action(self, state_t: torch.Tensor) -> int:
        """Epsilon-greedy choice: random index with probability epsilon, else argmax Q."""
        if random.random() < self.epsilon:
            return random.randrange(self.metadata.n_actions)
        with torch.no_grad():
            return int(self.model(state_t.unsqueeze(0)).argmax().item())

    def _idx_to_key(self, idx: int) -> str | None:
        """Map an action index to its key, or None for indices past ``metadata.actions``."""
        if idx >= len(self.metadata.actions):
            return None
        return self.metadata.actions[idx]

    def _train_step(self) -> float | None:
        """Take one gradient step on a sampled minibatch.

        Returns:
            The scalar loss, or None when memory holds fewer than
            ``batch_size`` experiences (no update is performed).
        """
        if len(self.memory) < self.hp['batch_size']:
            return None

        if self.hp['prioritize_experiences']:
            assert isinstance(self.memory, PrioritizedReplayMemory)
            experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
            weight_t = torch.as_tensor(weights, dtype=torch.float32)
        else:
            experiences = self.memory.sample(self.hp['batch_size'])
            indices = None
            weight_t = None

        states = torch.as_tensor(
            np.array([e.state for e in experiences]), dtype=torch.float32
        )
        actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long)
        rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32)
        next_states = torch.as_tensor(
            np.array([e.next_state for e in experiences]), dtype=torch.float32
        )
        dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32)

        # Q(s, a) for the actions actually taken.
        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.target_model(next_states).max(1).values
            # Terminal transitions (done == 1) get no bootstrapped term.
            targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)

        element_loss = nn.functional.mse_loss(q_values, targets, reduction='none')

        if weight_t is not None:
            # Importance-weighted loss. Priorities are refreshed with |TD error|.
            loss = (weight_t * element_loss).mean()
            td_errors = (q_values - targets).detach().abs().numpy()
            self.memory.update_priorities(indices, td_errors)
        else:
            loss = element_loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return float(loss.item())

    # ------------------------------------------------------------------
    # Initialisation helpers
    # ------------------------------------------------------------------

    def _discover_character_set(self):
        """Fill ``metadata.character_set`` by exploring the environment, and log it."""
        chars = self.env.discover_character_set(self.hp['exploration_turns'])
        self.metadata.character_set = chars
        self._log_raw(
            f"[INIT] character_set not specified — discovered {len(chars)} chars "
            f"after {self.hp['exploration_turns']} exploration turns: {chars}"
        )

    def _save_config(self):
        """Write metadata and hyperparameters to run_dir/config.toml.

        Other top-level tables already present in an existing config.toml are
        kept. The 'metadata' and 'hyperparameters' tables are overwritten.
        """
        config_path = self.run_dir / 'config.toml'
        config: dict = {}
        if config_path.exists():
            import tomllib
            with open(config_path, 'rb') as f:
                config = tomllib.load(f)
        config['metadata'] = self.metadata.to_dict()
        config['hyperparameters'] = self.hp
        with open(config_path, 'wb') as f:
            tomli_w.dump(config, f)

    def _open_log(self, rationale: str):
        """Create (truncating) training.log with the rationale, then flush buffered lines."""
        self.log_path = self.run_dir / 'training.log'
        pending = getattr(self, '_pending_log', None) or []
        with open(self.log_path, 'w', encoding='utf-8') as f:
            f.write(rationale + '\n')
            for line in pending:
                f.write(line + '\n')
        # From now on _log_raw appends straight to the file.
        self._pending_log = None

    def _log_raw(self, line: str):
        """Append one line to training.log, or buffer it if the log is not open yet."""
        pending = getattr(self, '_pending_log', None)
        if pending is not None:
            pending.append(line)
            return
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(line + '\n')

    def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float):
        """Log a one-line summary of a finished episode."""
        line = (
            f"[EP {episode:04d}] total_reward={total_reward:.1f} "
            f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}"
        )
        self._log_raw(line)

    def _save_checkpoint(self, name: str):
        """Save model, optimizer, LR-scheduler, epsilon and step count to checkpoints/<name>."""
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),
                'epsilon': self.epsilon,
                'total_steps': self.total_steps,
            },
            self.run_dir / 'checkpoints' / name,
        )
|