Files
retro-gamer/retro_gamer/memory.py
Chris Proctor 5ca97dc5d0 Initial commit
2026-05-08 14:07:17 -04:00

75 lines
2.6 KiB
Python

from __future__ import annotations
import random
from collections import deque
from typing import NamedTuple
import numpy as np
class Experience(NamedTuple):
    """A single transition held in replay memory.

    Field names follow the usual ``(s, a, r, s', done)`` reinforcement-learning
    layout; the exact meaning of each field (e.g. observation shape, what sets
    ``done``) is decided by the caller — confirm against the producing code.
    """

    state: np.ndarray  # state before the action (shape not fixed here)
    action: int  # action identifier
    reward: float  # scalar reward
    next_state: np.ndarray  # state after the action
    done: bool  # termination flag for this transition
class ReplayMemory:
    """Uniform-sampling experience store with a hard size limit.

    Backed by a ``deque`` whose ``maxlen`` is ``capacity``: once full, every
    new append silently discards the oldest stored experience.
    """

    def __init__(self, capacity: int):
        # deque(iterable, maxlen) in positional form: start empty, cap size.
        self.memory: deque[Experience] = deque((), capacity)

    def push(self, state, action, reward, next_state, done):
        """Record one transition as an ``Experience``."""
        record = Experience(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
        )
        self.memory.append(record)

    def sample(self, batch_size: int) -> list[Experience]:
        """Return ``batch_size`` stored experiences drawn uniformly without replacement.

        Delegates to ``random.sample``, which raises ``ValueError`` when
        ``batch_size`` exceeds the number of stored experiences.
        """
        batch = random.sample(self.memory, k=batch_size)
        return batch

    def __len__(self):
        """Number of experiences currently stored (never more than ``capacity``)."""
        stored = self.memory
        return len(stored)
class PrioritizedReplayMemory:
    """Experience replay buffer with priority-weighted sampling.

    Experiences with higher TD-error are sampled more often (alpha controls
    the strength of prioritization). Importance-sampling weights (beta) correct
    for the resulting bias.

    Storage is a fixed-capacity ring: once ``capacity`` experiences are held,
    each push overwrites the oldest slot. Stored priorities are the raw
    values ``|td_error| + epsilon``; ``alpha`` is applied only at sampling
    time, so ``P(i) = p_i**alpha / sum_j p_j**alpha``.

    Args:
        capacity: Maximum number of experiences kept; must be >= 1.
        alpha: Prioritization exponent (0 gives uniform sampling).
        beta: Default importance-sampling exponent; can be overridden per
            call to ``sample`` (e.g. to anneal it over training).
        epsilon: Added to ``|td_error|`` in ``update_priorities``; for
            ``epsilon > 0`` it keeps every priority strictly positive.

    Raises:
        ValueError: If ``capacity`` is less than 1.
    """

    def __init__(
        self,
        capacity: int,
        alpha: float = 0.6,
        beta: float = 0.4,
        epsilon: float = 1e-6,
    ):
        # A non-positive capacity would otherwise only fail later, inside
        # push(), with IndexError / ZeroDivisionError.
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity!r}")
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.epsilon = epsilon
        self.memory: list[Experience] = []
        self.priorities: list[float] = []
        # Next slot to overwrite once the buffer is full.
        self._pos = 0

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest slot when full.

        The new experience receives the largest priority currently stored
        (1.0 when the buffer is empty), so until its TD error is known it is
        at least as likely to be sampled as any other stored experience.
        """
        # Linear scan over stored priorities; sample() is linear in the
        # buffer size as well.
        max_priority = max(self.priorities, default=1.0)
        exp = Experience(state, action, reward, next_state, done)
        if len(self.memory) < self.capacity:
            self.memory.append(exp)
            self.priorities.append(max_priority)
        else:
            self.memory[self._pos] = exp
            self.priorities[self._pos] = max_priority
        # _pos advances on every push and wraps to 0 exactly when the buffer
        # first fills, so from then on it always points at the oldest entry.
        self._pos = (self._pos + 1) % self.capacity

    def sample(
        self, batch_size: int, beta: float | None = None
    ) -> tuple[list[Experience], np.ndarray, np.ndarray]:
        """Draw ``batch_size`` experiences, with replacement, by priority.

        Args:
            batch_size: Number of draws; the same index may appear twice.
            beta: Importance-sampling exponent for this call; defaults to
                ``self.beta``.

        Returns:
            ``(experiences, indices, importance_sampling_weights)``. The
            ``indices`` are buffer positions to pass back to
            ``update_priorities``. The weights are float32 and normalized so
            that the largest weight in the batch is 1.

        Raises:
            ValueError: If the buffer is empty.
        """
        n = len(self.memory)
        if n == 0:
            raise ValueError("cannot sample from an empty PrioritizedReplayMemory")
        if beta is None:
            beta = self.beta
        probs = np.asarray(self.priorities, dtype=np.float64) ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(n, batch_size, p=probs)
        weights = (n * probs[indices]) ** -beta
        # Normalizing an empty batch (batch_size == 0) would fail on max().
        if weights.size:
            weights /= weights.max()
        weights = weights.astype(np.float32)
        experiences = [self.memory[i] for i in indices]
        return experiences, indices, weights

    def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray):
        """Set new priorities for previously sampled experiences.

        Each priority becomes ``|td_error| + epsilon``. Inputs are validated
        before any priority is written, so a batch rejected with
        ``ValueError`` leaves the buffer unchanged.

        Args:
            indices: Buffer positions, as returned by ``sample``.
            td_errors: One TD error per index (array-like of scalars).

        Raises:
            ValueError: If the two inputs differ in length, or if any
                resulting priority is NaN or infinite. Such a priority would
                otherwise make every later ``sample`` call fail.
        """
        positions = list(indices)
        errors = list(td_errors)
        if len(positions) != len(errors):
            raise ValueError(
                f"got {len(positions)} indices but {len(errors)} TD errors"
            )
        new_priorities = [float(abs(err)) + self.epsilon for err in errors]
        if not all(np.isfinite(p) for p in new_priorities):
            raise ValueError("TD errors must be finite (got NaN or infinity)")
        for idx, priority in zip(positions, new_priorities):
            self.priorities[idx] = priority

    def __len__(self):
        """Number of experiences currently stored (at most ``capacity``)."""
        return len(self.memory)