Initial commit
This commit is contained in:
74
retro_gamer/memory.py
Normal file
74
retro_gamer/memory.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
import random
|
||||
from collections import deque
|
||||
from typing import NamedTuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Experience(NamedTuple):
|
||||
state: np.ndarray
|
||||
action: int
|
||||
reward: float
|
||||
next_state: np.ndarray
|
||||
done: bool
|
||||
|
||||
|
||||
class ReplayMemory:
|
||||
"""Fixed-capacity ring buffer of experiences sampled uniformly."""
|
||||
|
||||
def __init__(self, capacity: int):
|
||||
self.memory: deque[Experience] = deque(maxlen=capacity)
|
||||
|
||||
def push(self, state, action, reward, next_state, done):
|
||||
self.memory.append(Experience(state, action, reward, next_state, done))
|
||||
|
||||
def sample(self, batch_size: int) -> list[Experience]:
|
||||
return random.sample(self.memory, batch_size)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
||||
|
||||
class PrioritizedReplayMemory:
|
||||
"""Experience replay buffer with priority-weighted sampling.
|
||||
|
||||
Experiences with higher TD-error are sampled more often (alpha controls
|
||||
the strength of prioritization). Importance-sampling weights (beta) correct
|
||||
for the resulting bias.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity: int, alpha: float = 0.6, beta: float = 0.4):
|
||||
self.capacity = capacity
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.memory: list[Experience] = []
|
||||
self.priorities: list[float] = []
|
||||
self._pos = 0
|
||||
|
||||
def push(self, state, action, reward, next_state, done):
|
||||
max_priority = max(self.priorities, default=1.0)
|
||||
exp = Experience(state, action, reward, next_state, done)
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(exp)
|
||||
self.priorities.append(max_priority)
|
||||
else:
|
||||
self.memory[self._pos] = exp
|
||||
self.priorities[self._pos] = max_priority
|
||||
self._pos = (self._pos + 1) % self.capacity
|
||||
|
||||
def sample(self, batch_size: int) -> tuple[list[Experience], np.ndarray, np.ndarray]:
|
||||
"""Returns (experiences, indices, importance_sampling_weights)."""
|
||||
probs = np.array(self.priorities, dtype=np.float64) ** self.alpha
|
||||
probs /= probs.sum()
|
||||
indices = np.random.choice(len(self.memory), batch_size, p=probs)
|
||||
weights = (len(self.memory) * probs[indices]) ** -self.beta
|
||||
weights = (weights / weights.max()).astype(np.float32)
|
||||
experiences = [self.memory[i] for i in indices]
|
||||
return experiences, indices, weights
|
||||
|
||||
def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray):
|
||||
for idx, err in zip(indices, td_errors):
|
||||
self.priorities[idx] = float(abs(err)) + 1e-6
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
Reference in New Issue
Block a user