Initial commit
This commit is contained in:
255
retro_gamer/trainer.py
Normal file
255
retro_gamer/trainer.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from __future__ import annotations
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import tomli_w
|
||||
|
||||
from retro_gamer.metadata import GameMetadata
|
||||
from retro_gamer.env import GameEnvironment
|
||||
from retro_gamer.network import build_network
|
||||
from retro_gamer.memory import ReplayMemory, PrioritizedReplayMemory
|
||||
|
||||
DEFAULTS: dict = {
|
||||
'learning_rate': 1e-3,
|
||||
'lr_decay': 0.995,
|
||||
'gamma': 0.99,
|
||||
'epsilon': 1.0,
|
||||
'epsilon_decay': 0.995,
|
||||
'epsilon_min': 0.05,
|
||||
'batch_size': 64,
|
||||
'memory_capacity': 10_000,
|
||||
'target_update_freq': 100,
|
||||
'training_episodes': 1_000,
|
||||
'n_layers': 2,
|
||||
'layer_size': 128,
|
||||
'prioritize_experiences': False,
|
||||
'exploration_turns': 200,
|
||||
'unknown_character_strategy': 'ignore',
|
||||
'max_turns_per_episode': 2_000,
|
||||
}
|
||||
|
||||
|
||||
class DQNTrainer:
|
||||
"""Trains a deep Q-network agent to play a retro game.
|
||||
|
||||
On initialization the trainer:
|
||||
1. Discovers the character set (if not already specified in metadata).
|
||||
2. Builds the Q-network and logs the full architecture with rationale.
|
||||
3. Saves config.toml and starts training.log in run_dir.
|
||||
|
||||
Call train() to run all episodes and save checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
game_factory: Callable,
|
||||
metadata: GameMetadata,
|
||||
run_dir: str | Path,
|
||||
**hyperparams,
|
||||
):
|
||||
self.game_factory = game_factory
|
||||
self.metadata = metadata
|
||||
self.run_dir = Path(run_dir)
|
||||
self.hp: dict = {**DEFAULTS, **hyperparams}
|
||||
self.run_dir.mkdir(parents=True, exist_ok=True)
|
||||
(self.run_dir / 'checkpoints').mkdir(exist_ok=True)
|
||||
|
||||
self.env = GameEnvironment(game_factory, metadata)
|
||||
|
||||
if metadata.board_size is None:
|
||||
g = game_factory()
|
||||
metadata.board_size = g.board_size
|
||||
|
||||
if metadata.character_set is None:
|
||||
self._discover_character_set()
|
||||
|
||||
self.model, rationale = build_network(metadata, self.hp)
|
||||
self.target_model, _ = build_network(metadata, self.hp)
|
||||
self.target_model.load_state_dict(self.model.state_dict())
|
||||
self.target_model.eval()
|
||||
|
||||
self.optimizer = optim.Adam(
|
||||
self.model.parameters(), lr=self.hp['learning_rate']
|
||||
)
|
||||
self.lr_scheduler = optim.lr_scheduler.ExponentialLR(
|
||||
self.optimizer, gamma=self.hp['lr_decay']
|
||||
)
|
||||
|
||||
if self.hp['prioritize_experiences']:
|
||||
self.memory = PrioritizedReplayMemory(self.hp['memory_capacity'])
|
||||
else:
|
||||
self.memory = ReplayMemory(self.hp['memory_capacity'])
|
||||
|
||||
self.epsilon: float = self.hp['epsilon']
|
||||
self.total_steps: int = 0
|
||||
|
||||
self._save_config()
|
||||
self._open_log(rationale)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def train(self):
|
||||
"""Run all training episodes and save checkpoints."""
|
||||
for episode in range(1, self.hp['training_episodes'] + 1):
|
||||
total_reward, steps, avg_loss = self._run_episode()
|
||||
self.epsilon = max(
|
||||
self.hp['epsilon_min'], self.epsilon * self.hp['epsilon_decay']
|
||||
)
|
||||
self.lr_scheduler.step()
|
||||
self._log_episode(episode, total_reward, steps, avg_loss)
|
||||
if episode % 100 == 0:
|
||||
self._save_checkpoint(f'ep_{episode:04d}.pt')
|
||||
self._save_checkpoint('final.pt')
|
||||
|
||||
def load_checkpoint(self, path: str | Path):
|
||||
ckpt = torch.load(path, weights_only=True)
|
||||
self.model.load_state_dict(ckpt['model_state_dict'])
|
||||
self.target_model.load_state_dict(ckpt['model_state_dict'])
|
||||
self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
|
||||
self.epsilon = ckpt['epsilon']
|
||||
self.total_steps = ckpt['total_steps']
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Training loop internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_episode(self) -> tuple[float, int, float]:
|
||||
state = self.env.reset()
|
||||
total_reward = 0.0
|
||||
total_loss = 0.0
|
||||
loss_count = 0
|
||||
|
||||
for step in range(self.hp['max_turns_per_episode']):
|
||||
state_t = torch.as_tensor(state, dtype=torch.float32)
|
||||
action_idx = self._select_action(state_t)
|
||||
action_key = self._idx_to_key(action_idx)
|
||||
|
||||
next_state, reward, done = self.env.step(action_key)
|
||||
self.memory.push(state, action_idx, reward, next_state, done)
|
||||
|
||||
loss = self._train_step()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
loss_count += 1
|
||||
|
||||
self.total_steps += 1
|
||||
if self.total_steps % self.hp['target_update_freq'] == 0:
|
||||
self.target_model.load_state_dict(self.model.state_dict())
|
||||
|
||||
state = next_state
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
|
||||
avg_loss = total_loss / loss_count if loss_count else 0.0
|
||||
return total_reward, step + 1, avg_loss
|
||||
|
||||
def _select_action(self, state_t: torch.Tensor) -> int:
|
||||
if random.random() < self.epsilon:
|
||||
return random.randrange(self.metadata.n_actions)
|
||||
with torch.no_grad():
|
||||
return int(self.model(state_t.unsqueeze(0)).argmax().item())
|
||||
|
||||
def _idx_to_key(self, idx: int) -> str | None:
|
||||
if idx >= len(self.metadata.actions):
|
||||
return None
|
||||
return self.metadata.actions[idx]
|
||||
|
||||
def _train_step(self) -> float | None:
|
||||
if len(self.memory) < self.hp['batch_size']:
|
||||
return None
|
||||
|
||||
if self.hp['prioritize_experiences']:
|
||||
assert isinstance(self.memory, PrioritizedReplayMemory)
|
||||
experiences, indices, weights = self.memory.sample(self.hp['batch_size'])
|
||||
weight_t = torch.as_tensor(weights, dtype=torch.float32)
|
||||
else:
|
||||
experiences = self.memory.sample(self.hp['batch_size'])
|
||||
indices = None
|
||||
weight_t = None
|
||||
|
||||
states = torch.as_tensor(
|
||||
np.array([e.state for e in experiences]), dtype=torch.float32
|
||||
)
|
||||
actions = torch.as_tensor([e.action for e in experiences], dtype=torch.long)
|
||||
rewards = torch.as_tensor([e.reward for e in experiences], dtype=torch.float32)
|
||||
next_states = torch.as_tensor(
|
||||
np.array([e.next_state for e in experiences]), dtype=torch.float32
|
||||
)
|
||||
dones = torch.as_tensor([e.done for e in experiences], dtype=torch.float32)
|
||||
|
||||
q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
with torch.no_grad():
|
||||
next_q = self.target_model(next_states).max(1).values
|
||||
targets = rewards + self.hp['gamma'] * next_q * (1.0 - dones)
|
||||
|
||||
element_loss = nn.functional.mse_loss(q_values, targets, reduction='none')
|
||||
|
||||
if weight_t is not None:
|
||||
loss = (weight_t * element_loss).mean()
|
||||
td_errors = (q_values - targets).detach().abs().numpy()
|
||||
self.memory.update_priorities(indices, td_errors)
|
||||
else:
|
||||
loss = element_loss.mean()
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
return float(loss.item())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Initialisation helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _discover_character_set(self):
|
||||
chars = self.env.discover_character_set(self.hp['exploration_turns'])
|
||||
self.metadata.character_set = chars
|
||||
self._log_raw(
|
||||
f"[INIT] character_set not specified — discovered {len(chars)} chars "
|
||||
f"after {self.hp['exploration_turns']} exploration turns: {chars}"
|
||||
)
|
||||
|
||||
def _save_config(self):
|
||||
config_path = self.run_dir / 'config.toml'
|
||||
config: dict = {}
|
||||
if config_path.exists():
|
||||
import tomllib
|
||||
with open(config_path, 'rb') as f:
|
||||
config = tomllib.load(f)
|
||||
config['metadata'] = self.metadata.to_dict()
|
||||
config['hyperparameters'] = self.hp
|
||||
with open(config_path, 'wb') as f:
|
||||
tomli_w.dump(config, f)
|
||||
|
||||
def _open_log(self, rationale: str):
|
||||
self.log_path = self.run_dir / 'training.log'
|
||||
with open(self.log_path, 'w') as f:
|
||||
f.write(rationale + '\n')
|
||||
|
||||
def _log_raw(self, line: str):
|
||||
with open(self.log_path, 'a') as f:
|
||||
f.write(line + '\n')
|
||||
|
||||
def _log_episode(self, episode: int, total_reward: float, steps: int, avg_loss: float):
|
||||
line = (
|
||||
f"[EP {episode:04d}] total_reward={total_reward:.1f} "
|
||||
f"steps={steps} epsilon={self.epsilon:.4f} avg_loss={avg_loss:.6f}"
|
||||
)
|
||||
self._log_raw(line)
|
||||
|
||||
def _save_checkpoint(self, name: str):
|
||||
torch.save(
|
||||
{
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'epsilon': self.epsilon,
|
||||
'total_steps': self.total_steps,
|
||||
},
|
||||
self.run_dir / 'checkpoints' / name,
|
||||
)
|
||||
Reference in New Issue
Block a user