75 lines
2.6 KiB
Python
75 lines
2.6 KiB
Python
from __future__ import annotations
|
|
import random
|
|
from collections import deque
|
|
from typing import NamedTuple
|
|
import numpy as np
|
|
|
|
|
|
class Experience(NamedTuple):
    """A single stored transition, kept as an immutable named tuple.

    Field order is ``(state, action, reward, next_state, done)``, so an
    instance can also be unpacked positionally.
    """

    state: np.ndarray       # state the transition starts from
    action: int             # action taken in ``state``
    reward: float           # reward recorded for this step
    next_state: np.ndarray  # state reached after ``action``
    done: bool              # terminal flag for this step
|
|
|
|
|
|
class ReplayMemory:
    """Bounded replay buffer sampled uniformly at random.

    Transitions live in a ``deque`` with ``maxlen=capacity``; once it is
    full, every new push drops the oldest stored transition.
    """

    def __init__(self, capacity: int):
        # ``maxlen`` gives ring-buffer eviction for free.
        self.memory: deque[Experience] = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Wrap the arguments in an ``Experience`` and append it."""
        transition = Experience(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
        )
        self.memory.append(transition)

    def sample(self, batch_size: int) -> list[Experience]:
        """Return ``batch_size`` stored transitions chosen uniformly.

        Delegates to :func:`random.sample` (no replacement), which raises
        ``ValueError`` when ``batch_size`` exceeds the number stored.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)
|
|
|
|
|
|
class PrioritizedReplayMemory:
    """Experience replay buffer with priority-weighted sampling.

    Experiences with higher TD-error are sampled more often (``alpha``
    controls the strength of prioritization; ``alpha=0`` samples uniformly).
    Importance-sampling weights (``beta``) correct for the resulting bias.

    Storage is a ring buffer: once ``capacity`` experiences are held, each
    push overwrites the oldest slot.
    """

    def __init__(
        self,
        capacity: int,
        alpha: float = 0.6,
        beta: float = 0.4,
        epsilon: float = 1e-6,
    ):
        """
        Args:
            capacity: maximum number of stored experiences; must be positive.
            alpha: exponent applied to raw priorities when sampling.
            beta: exponent of the importance-sampling correction.
            epsilon: added to ``|td_error|`` in :meth:`update_priorities` so
                that (for ``epsilon > 0``) no priority drops to zero.

        Raises:
            ValueError: if ``capacity`` is not positive.
        """
        if capacity <= 0:
            # A non-positive capacity would make the first push fail
            # (IndexError on the empty list, then modulo by zero).
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.epsilon = epsilon
        self.memory: list[Experience] = []
        # Raw priorities (|td_error| + epsilon); alpha is applied in sample().
        self.priorities: list[float] = []
        # Next slot to overwrite once the buffer is full.
        self._pos = 0

    def push(self, state, action, reward, next_state, done):
        """Store a transition with the current maximum priority.

        New experiences receive the largest priority currently in the buffer
        (1.0 when empty), so they are favoured until their TD-error is known.
        """
        max_priority = max(self.priorities, default=1.0)
        exp = Experience(state, action, reward, next_state, done)
        if len(self.memory) < self.capacity:
            self.memory.append(exp)
            self.priorities.append(max_priority)
        else:
            # Full: overwrite the oldest slot and advance the ring cursor.
            self.memory[self._pos] = exp
            self.priorities[self._pos] = max_priority
            self._pos = (self._pos + 1) % self.capacity

    def sample(self, batch_size: int) -> tuple[list[Experience], np.ndarray, np.ndarray]:
        """Draw experiences with probability proportional to ``priority ** alpha``.

        Sampling is with replacement, so an index may appear more than once.

        Returns:
            ``(experiences, indices, importance_sampling_weights)``.
            ``indices`` can be passed back to :meth:`update_priorities`.
            Weights are float32, computed as ``(N * P(i)) ** -beta`` and
            divided by the batch maximum.

        Raises:
            ValueError: if the buffer is empty.
        """
        n = len(self.memory)
        if n == 0:
            raise ValueError("cannot sample from an empty PrioritizedReplayMemory")
        probs = np.array(self.priorities, dtype=np.float64) ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(n, batch_size, p=probs)
        weights = (n * probs[indices]) ** -self.beta
        if weights.size:
            # Normalise by the batch max. Skip this for an empty batch, where
            # .max() raises on a zero-size array.
            weights /= weights.max()
        experiences = [self.memory[i] for i in indices]
        return experiences, indices, weights.astype(np.float32)

    def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray):
        """Set the priority at each index to ``|td_error| + epsilon``."""
        for idx, err in zip(indices, td_errors):
            self.priorities[idx] = float(abs(err)) + self.epsilon

    def __len__(self):
        """Number of experiences currently stored."""
        return len(self.memory)
|