from __future__ import annotations

import random
from collections import deque
from typing import NamedTuple

import numpy as np


class Experience(NamedTuple):
    """One stored transition: (state, action, reward, next_state, done)."""

    state: np.ndarray
    action: int
    reward: float
    next_state: np.ndarray
    done: bool


class ReplayMemory:
    """Fixed-capacity ring buffer of experiences sampled uniformly."""

    def __init__(self, capacity: int):
        """Create an empty buffer holding at most ``capacity`` experiences."""
        # deque(maxlen=...) silently drops the oldest entry once it is full.
        self.memory: deque[Experience] = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest one when full."""
        self.memory.append(Experience(state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> list[Experience]:
        """Return ``batch_size`` distinct experiences chosen uniformly.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                experiences (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.memory)


class PrioritizedReplayMemory:
    """Experience replay buffer with priority-weighted sampling.

    Experiences with higher TD-error are sampled more often (alpha controls
    the strength of prioritization). Importance-sampling weights (beta)
    correct for the resulting bias.
    """

    # Added to every |TD-error| so that no stored priority is exactly zero,
    # which would make the corresponding experience unsampleable.
    _PRIORITY_EPS = 1e-6

    def __init__(self, capacity: int, alpha: float = 0.6, beta: float = 0.4):
        """Create an empty buffer.

        Args:
            capacity: maximum number of experiences kept; must be positive.
            alpha: exponent applied to priorities when building the sampling
                distribution (0 gives uniform sampling).
            beta: default exponent for the importance-sampling weights.

        Raises:
            ValueError: if ``capacity`` is not positive.
        """
        # A non-positive capacity would otherwise only fail later, inside
        # push(), with an unrelated IndexError / ZeroDivisionError.
        if capacity <= 0:
            raise ValueError(f"capacity must be a positive integer, got {capacity!r}")
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.memory: list[Experience] = []
        self.priorities: list[float] = []
        self._pos = 0  # next slot to overwrite once the buffer is full

    def push(self, state, action, reward, next_state, done):
        """Store one transition with the highest priority currently held.

        New experiences get the current maximum priority (1.0 for an empty
        buffer). Once the buffer is full, the slot at the write cursor is
        overwritten.
        """
        max_priority = max(self.priorities, default=1.0)
        exp = Experience(state, action, reward, next_state, done)
        if len(self.memory) < self.capacity:
            self.memory.append(exp)
            self.priorities.append(max_priority)
        else:
            self.memory[self._pos] = exp
            self.priorities[self._pos] = max_priority
        self._pos = (self._pos + 1) % self.capacity

    def sample(
        self, batch_size: int, beta: float | None = None
    ) -> tuple[list[Experience], np.ndarray, np.ndarray]:
        """Returns (experiences, indices, importance_sampling_weights).

        Indices are drawn with replacement, with probability proportional to
        ``priority ** alpha``. Weights are ``(N * p_i) ** -beta`` divided by
        the largest weight in the batch, returned as float32.

        Args:
            batch_size: number of experiences to draw; must be positive.
            beta: importance-sampling exponent for this call only; defaults
                to ``self.beta``. Lets callers anneal beta without mutating
                the buffer.

        Raises:
            ValueError: if the buffer is empty or ``batch_size`` is not
                positive.
        """
        if not self.memory:
            raise ValueError("cannot sample from an empty replay memory")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")
        if beta is None:
            beta = self.beta

        scaled = np.array(self.priorities, dtype=np.float64) ** self.alpha
        probs = scaled / scaled.sum()

        size = len(self.memory)
        indices = np.random.choice(size, batch_size, p=probs)

        weights = (size * probs[indices]) ** -beta
        # Normalise by the batch maximum so weights only ever scale updates down.
        weights = (weights / weights.max()).astype(np.float32)

        experiences = [self.memory[i] for i in indices]
        return experiences, indices, weights

    def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray):
        """Set the priority of each sampled index to ``|td_error| + eps``.

        All values are validated before anything is written, so a rejected
        call leaves the stored priorities untouched.

        Raises:
            ValueError: if ``indices`` and ``td_errors`` differ in length, or
                if any TD-error is NaN or infinite.
        """
        idx_list = list(indices)
        new_priorities = [float(abs(err)) + self._PRIORITY_EPS for err in td_errors]

        # zip() would silently drop the unmatched tail of the longer input.
        if len(idx_list) != len(new_priorities):
            raise ValueError(
                "indices and td_errors must have the same length, got "
                f"{len(idx_list)} and {len(new_priorities)}"
            )
        # A single NaN/inf priority turns the whole sampling distribution into
        # NaN and makes every later sample() call fail.
        if not all(np.isfinite(priority) for priority in new_priorities):
            raise ValueError("td_errors must all be finite (no NaN or inf)")

        for idx, priority in zip(idx_list, new_priorities):
            self.priorities[idx] = priority

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.memory)