from __future__ import annotations from pathlib import Path from retro_gamer.log_parser import parse_checkpoints def plot_run(log_path: Path, output: Path | None = None) -> None: """Generate training metric plots from a training.log file. Displays an interactive window unless *output* is given, in which case the figure is saved to that path (PNG, PDF, SVG, etc.). """ import matplotlib.pyplot as plt import seaborn as sns data = parse_checkpoints(log_path) if not data: raise ValueError(f"No checkpoint data found in {log_path}") episodes = [d['episode'] for d in data] rewards = [d['avg_reward'] for d in data] steps = [d['avg_steps'] for d in data] losses = [d['avg_loss'] for d in data] epsilons = [d['epsilon'] for d in data] sns.set_theme(style='darkgrid') fig, axes = plt.subplots(2, 2, figsize=(12, 7)) (ax_reward, ax_steps), (ax_loss, ax_epsilon) = axes ax_reward.plot(episodes, rewards) ax_reward.axhline(0, color='gray', linestyle='--', linewidth=0.8, alpha=0.6) ax_reward.set_title('Average Reward') ax_reward.set_xlabel('Episode') ax_steps.plot(episodes, steps, color='C1') ax_steps.set_title('Average Steps') ax_steps.set_xlabel('Episode') ax_loss.plot(episodes, losses, color='C2') ax_loss.set_yscale('log') ax_loss.set_title('Average Loss') ax_loss.set_xlabel('Episode') ax_epsilon.plot(episodes, epsilons, color='C3') ax_epsilon.set_title('Epsilon (exploration rate)') ax_epsilon.set_xlabel('Episode') ax_epsilon.set_ylim(0, 1) fig.suptitle(f'Training: {log_path.parent.name}', fontsize=13) plt.tight_layout() if output: plt.savefig(output, dpi=150, bbox_inches='tight') print(f"Plot saved to {output}") else: plt.show()