31 lines
917 B
Python
31 lines
917 B
Python
from __future__ import annotations
|
|
import re
|
|
from pathlib import Path
|
|
|
|
# A single numeric field value as it may appear in the log: optional sign,
# integer or decimal digits, optional exponent, or nan/inf. The exponent
# matters because Python's str()/f"{x}" prints small floats such as a decayed
# epsilon as '1e-05'; a digits-only pattern would silently capture just '1'.
_NUM = r'[+-]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?|nan|inf)'

# Matches a checkpoint line: an '[ep_<N>]' tag followed by the four metric
# fields in this fixed order, with arbitrary text allowed between them.
# Lazy '.*?' plus '\b' select the first occurrence of each key that is not
# preceded by another identifier character, so e.g. 'best_avg_reward=' is
# never mistaken for 'avg_reward='.
# Groups: 1 episode, 2 avg_reward, 3 avg_steps, 4 epsilon, 5 avg_loss.
_LINE_RE = re.compile(
    r'\[ep_(\d+)\]'
    rf'.*?\bavg_reward=({_NUM})'
    rf'.*?\bavg_steps=({_NUM})'
    rf'.*?\bepsilon=({_NUM})'
    rf'.*?\bavg_loss=({_NUM})'
)


def parse_checkpoints(log_path: Path | str) -> list[dict]:
    """Parse checkpoint lines from a training log.

    A checkpoint line is any line matched by ``_LINE_RE``. It contains an
    ``[ep_<N>]`` tag followed, in this order, by ``avg_reward=``,
    ``avg_steps=``, ``epsilon=`` and ``avg_loss=`` fields. Values may be
    signed, use exponent notation, or be ``nan``/``inf``. All other lines
    are ignored.

    Args:
        log_path: Path to the log file (a plain ``str`` is also accepted).

    Returns:
        One dict per checkpoint line, in file order. Each dict has the keys
        ``episode`` (int) and ``avg_reward``, ``avg_steps``, ``epsilon``,
        ``avg_loss`` (float). The list is empty if the file does not exist.
    """
    try:
        # EAFP rather than exists()-then-read: no race if the log is rotated
        # or removed in between. NotADirectoryError mirrors Path.exists()
        # returning False for paths like 'some_file/train.log'.
        # utf-8 is explicit so parsing does not depend on the platform locale.
        # Undecodable bytes (e.g. a torn write) are replaced instead of
        # aborting the parse, since checkpoint lines themselves are ASCII.
        fh = open(log_path, encoding='utf-8', errors='replace')
    except (FileNotFoundError, NotADirectoryError):
        return []

    with fh:
        # Stream line by line so large logs are never held in memory whole.
        return [
            {
                'episode': int(m.group(1)),
                'avg_reward': float(m.group(2)),
                'avg_steps': float(m.group(3)),
                'epsilon': float(m.group(4)),
                'avg_loss': float(m.group(5)),
            }
            for m in map(_LINE_RE.search, fh)
            if m
        ]
|