Files
retro-gamer/retro_gamer/log_parser.py
2026-06-22 16:41:31 -04:00

31 lines
917 B
Python

from __future__ import annotations
import re
from pathlib import Path
# A decimal number in a form float() accepts: optional sign, digits with an
# optional fraction (or a bare fraction such as ".5"), and an optional
# exponent. The exponent matters: small values are commonly logged as e.g.
# "1e-05", and without it only the leading "1" would be captured.
_NUMBER = r'[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?'

# Matches one checkpoint line. Capture groups, in order:
#   1: episode number (digits after "ep_")
#   2: avg_reward   3: avg_steps   4: epsilon   5: avg_loss
# The fields must appear in this order; arbitrary text may sit between them.
_LINE_RE = re.compile(
    r'\[ep_(\d+)\]'
    r'.*avg_reward=(' + _NUMBER + r')'
    r'.*avg_steps=(' + _NUMBER + r')'
    r'.*epsilon=(' + _NUMBER + r')'
    r'.*avg_loss=(' + _NUMBER + r')'
)
def parse_checkpoints(log_path: Path) -> list[dict]:
    """Parse checkpoint lines from a training log.

    Every line matched by ``_LINE_RE`` yields one dict; lines that do not
    match are skipped. Results are returned in file order.

    Args:
        log_path: Path of the log file to read.

    Returns:
        A list of dicts with keys ``episode`` (int) and ``avg_reward``,
        ``avg_steps``, ``epsilon``, ``avg_loss`` (floats). The list is empty
        when the file does not exist.
    """
    if not log_path.exists():
        return []
    try:
        # Decode explicitly instead of relying on the locale default, and
        # replace undecodable bytes so one stray byte in the log cannot
        # abort the whole parse (the metrics themselves are plain ASCII).
        text = log_path.read_text(encoding='utf-8', errors='replace')
    except FileNotFoundError:
        # The file was removed between the existence check and the read;
        # treat it the same as a file that was never there.
        return []

    results = []
    for line in text.splitlines():
        m = _LINE_RE.search(line)
        if m is None:
            continue
        episode, avg_reward, avg_steps, epsilon, avg_loss = m.groups()
        results.append({
            'episode': int(episode),
            'avg_reward': float(avg_reward),
            'avg_steps': float(avg_steps),
            'epsilon': float(epsilon),
            'avg_loss': float(avg_loss),
        })
    return results