Updates across the board
This commit is contained in:
30
retro_gamer/log_parser.py
Normal file
30
retro_gamer/log_parser.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
_LINE_RE = re.compile(
|
||||
r'\[ep_(\d+)\]'
|
||||
r'.*avg_reward=([+-]?\d+\.?\d*)'
|
||||
r'.*avg_steps=(\d+\.?\d*)'
|
||||
r'.*epsilon=(\d+\.?\d*)'
|
||||
r'.*avg_loss=(\d+\.?\d*)'
|
||||
)
|
||||
|
||||
|
||||
def parse_checkpoints(log_path: Path) -> list[dict]:
|
||||
"""Parse checkpoint lines from a training log. Returns a list of dicts
|
||||
with keys: episode, avg_reward, avg_steps, epsilon, avg_loss."""
|
||||
results = []
|
||||
if not log_path.exists():
|
||||
return results
|
||||
for line in log_path.read_text().splitlines():
|
||||
m = _LINE_RE.search(line)
|
||||
if m:
|
||||
results.append({
|
||||
'episode': int(m.group(1)),
|
||||
'avg_reward': float(m.group(2)),
|
||||
'avg_steps': float(m.group(3)),
|
||||
'epsilon': float(m.group(4)),
|
||||
'avg_loss': float(m.group(5)),
|
||||
})
|
||||
return results
|
||||
Reference in New Issue
Block a user