Updates across the board
This commit is contained in:
55
retro_gamer/plotter.py
Normal file
55
retro_gamer/plotter.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from retro_gamer.log_parser import parse_checkpoints
|
||||
|
||||
|
||||
def plot_run(log_path: Path, output: Path | None = None) -> None:
|
||||
"""Generate training metric plots from a training.log file.
|
||||
|
||||
Displays an interactive window unless *output* is given, in which case
|
||||
the figure is saved to that path (PNG, PDF, SVG, etc.).
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
data = parse_checkpoints(log_path)
|
||||
if not data:
|
||||
raise ValueError(f"No checkpoint data found in {log_path}")
|
||||
|
||||
episodes = [d['episode'] for d in data]
|
||||
rewards = [d['avg_reward'] for d in data]
|
||||
steps = [d['avg_steps'] for d in data]
|
||||
losses = [d['avg_loss'] for d in data]
|
||||
epsilons = [d['epsilon'] for d in data]
|
||||
|
||||
sns.set_theme(style='darkgrid')
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 7))
|
||||
(ax_reward, ax_steps), (ax_loss, ax_epsilon) = axes
|
||||
|
||||
ax_reward.plot(episodes, rewards)
|
||||
ax_reward.axhline(0, color='gray', linestyle='--', linewidth=0.8, alpha=0.6)
|
||||
ax_reward.set_title('Average Reward')
|
||||
ax_reward.set_xlabel('Episode')
|
||||
|
||||
ax_steps.plot(episodes, steps, color='C1')
|
||||
ax_steps.set_title('Average Steps')
|
||||
ax_steps.set_xlabel('Episode')
|
||||
|
||||
ax_loss.plot(episodes, losses, color='C2')
|
||||
ax_loss.set_yscale('log')
|
||||
ax_loss.set_title('Average Loss')
|
||||
ax_loss.set_xlabel('Episode')
|
||||
|
||||
ax_epsilon.plot(episodes, epsilons, color='C3')
|
||||
ax_epsilon.set_title('Epsilon (exploration rate)')
|
||||
ax_epsilon.set_xlabel('Episode')
|
||||
ax_epsilon.set_ylim(0, 1)
|
||||
|
||||
fig.suptitle(f'Training: {log_path.parent.name}', fontsize=13)
|
||||
plt.tight_layout()
|
||||
|
||||
if output:
|
||||
plt.savefig(output, dpi=150, bbox_inches='tight')
|
||||
print(f"Plot saved to {output}")
|
||||
else:
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user