Files
retro-gamer/retro_gamer/plotter.py
2026-06-22 16:41:31 -04:00

56 lines
1.8 KiB
Python

from __future__ import annotations
from pathlib import Path
from retro_gamer.log_parser import parse_checkpoints
def plot_run(log_path: Path, output: Path | None = None) -> None:
    """Generate training metric plots from a training.log file.

    Draws a 2x2 grid of line plots (average reward, average steps, average
    loss on a log scale, and epsilon) against the episode number of each
    checkpoint record returned by ``parse_checkpoints``.

    Displays an interactive window unless *output* is given, in which case
    the figure is saved to that path (PNG, PDF, SVG, etc.) and then closed.

    Args:
        log_path: Path to the training log. The name of its parent
            directory is used in the figure title.
        output: Destination file for the rendered figure. When omitted,
            the figure is shown with ``plt.show()`` instead.

    Raises:
        ValueError: If no checkpoint data is parsed from *log_path*.
    """
    # Imported inside the function rather than at module level; presumably
    # so importing this module does not require the plotting stack --
    # confirm before hoisting these to the top of the file.
    import matplotlib.pyplot as plt
    import seaborn as sns

    data = parse_checkpoints(log_path)
    if not data:
        raise ValueError(f"No checkpoint data found in {log_path}")

    # Pull each metric out as its own series, one value per checkpoint.
    episodes = [d['episode'] for d in data]
    rewards = [d['avg_reward'] for d in data]
    steps = [d['avg_steps'] for d in data]
    losses = [d['avg_loss'] for d in data]
    epsilons = [d['epsilon'] for d in data]

    sns.set_theme(style='darkgrid')
    fig, axes = plt.subplots(2, 2, figsize=(12, 7))
    (ax_reward, ax_steps), (ax_loss, ax_epsilon) = axes

    ax_reward.plot(episodes, rewards)
    # Dashed reference line at zero reward.
    ax_reward.axhline(0, color='gray', linestyle='--', linewidth=0.8, alpha=0.6)
    ax_reward.set_title('Average Reward')
    ax_reward.set_xlabel('Episode')

    ax_steps.plot(episodes, steps, color='C1')
    ax_steps.set_title('Average Steps')
    ax_steps.set_xlabel('Episode')

    ax_loss.plot(episodes, losses, color='C2')
    ax_loss.set_yscale('log')
    ax_loss.set_title('Average Loss')
    ax_loss.set_xlabel('Episode')

    ax_epsilon.plot(episodes, epsilons, color='C3')
    ax_epsilon.set_title('Epsilon (exploration rate)')
    ax_epsilon.set_xlabel('Episode')
    ax_epsilon.set_ylim(0, 1)

    fig.suptitle(f'Training: {log_path.parent.name}', fontsize=13)
    fig.tight_layout()

    if not output:
        # The figure is deliberately left open on this path: closing it
        # after a non-blocking show() would dismiss the window immediately.
        plt.show()
        return

    try:
        fig.savefig(output, dpi=150, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it here (even if saving failed) to avoid accumulating
        # figures when this function is called repeatedly.
        plt.close(fig)
    print(f"Plot saved to {output}")