Updates across the board

This commit is contained in:
Chris Proctor
2026-06-22 16:41:31 -04:00
parent 5ca97dc5d0
commit 73624d1a0c
33 changed files with 3104 additions and 643 deletions

86
retro_gamer/dashboard.py Normal file
View File

@@ -0,0 +1,86 @@
from __future__ import annotations
import os
import sys
from tqdm import tqdm
_CHART_HEIGHT = 22 # plotext chart area height in lines
def _terminal_width() -> int:
try:
return min(os.get_terminal_size().columns, 220)
except OSError:
return 120
def _build_charts(history: list[dict], width: int) -> str:
import plotext as plt
episodes = [d['episode'] for d in history]
series = [
("Epsilon", [d['epsilon'] for d in history]),
("Avg Steps", [d['avg_steps'] for d in history]),
("Avg Loss", [d['avg_loss'] for d in history]),
("Avg Reward", [d['avg_reward'] for d in history]),
]
panel_w = max(width // len(series), 20)
panels = []
for title, values in series:
plt.clf()
plt.canvas_color("default")
plt.axes_color("default")
plt.ticks_color("default")
if episodes:
plt.plot(episodes, values, color="default")
plt.title(title)
plt.xlabel("Episode")
plt.plotsize(panel_w, _CHART_HEIGHT)
panels.append(plt.build().splitlines())
height = max(len(p) for p in panels)
for p in panels:
while len(p) < height:
p.append(' ' * panel_w)
return '\n'.join(''.join(row) for row in zip(*panels))
class TrainingDashboard:
"""Inline training display: a plotext chart block that redraws in place
above a tqdm per-episode progress bar."""
def __init__(self, total_episodes: int, start_episode: int, history: list[dict]):
self._total = total_episodes
self._history = list(history)
self._rendered = False # have we drawn charts yet?
already_done = start_episode - 1
self._bar = tqdm(
initial=already_done,
total=total_episodes,
unit='ep',
dynamic_ncols=True,
)
self._draw()
def on_episode(self) -> None:
self._bar.update(1)
def on_checkpoint(self, stats: dict) -> None:
self._history.append(stats)
self._bar.clear()
self._draw()
self._bar.refresh()
def close(self) -> None:
self._bar.close()
def _draw(self) -> None:
chart = _build_charts(self._history, _terminal_width())
n_lines = len(chart.splitlines())
if self._rendered:
sys.stdout.write(f'\033[{n_lines}A')
sys.stdout.write(chart)
if not chart.endswith('\n'):
sys.stdout.write('\n')
sys.stdout.flush()
self._rendered = True