Updates across the board
This commit is contained in:
86
retro_gamer/dashboard.py
Normal file
86
retro_gamer/dashboard.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import sys
|
||||
from tqdm import tqdm
|
||||
|
||||
_CHART_HEIGHT = 22 # plotext chart area height in lines
|
||||
|
||||
|
||||
def _terminal_width() -> int:
|
||||
try:
|
||||
return min(os.get_terminal_size().columns, 220)
|
||||
except OSError:
|
||||
return 120
|
||||
|
||||
|
||||
def _build_charts(history: list[dict], width: int) -> str:
|
||||
import plotext as plt
|
||||
episodes = [d['episode'] for d in history]
|
||||
series = [
|
||||
("Epsilon", [d['epsilon'] for d in history]),
|
||||
("Avg Steps", [d['avg_steps'] for d in history]),
|
||||
("Avg Loss", [d['avg_loss'] for d in history]),
|
||||
("Avg Reward", [d['avg_reward'] for d in history]),
|
||||
]
|
||||
panel_w = max(width // len(series), 20)
|
||||
panels = []
|
||||
for title, values in series:
|
||||
plt.clf()
|
||||
plt.canvas_color("default")
|
||||
plt.axes_color("default")
|
||||
plt.ticks_color("default")
|
||||
if episodes:
|
||||
plt.plot(episodes, values, color="default")
|
||||
plt.title(title)
|
||||
plt.xlabel("Episode")
|
||||
plt.plotsize(panel_w, _CHART_HEIGHT)
|
||||
panels.append(plt.build().splitlines())
|
||||
|
||||
height = max(len(p) for p in panels)
|
||||
for p in panels:
|
||||
while len(p) < height:
|
||||
p.append(' ' * panel_w)
|
||||
return '\n'.join(''.join(row) for row in zip(*panels))
|
||||
|
||||
|
||||
class TrainingDashboard:
|
||||
"""Inline training display: a plotext chart block that redraws in place
|
||||
above a tqdm per-episode progress bar."""
|
||||
|
||||
def __init__(self, total_episodes: int, start_episode: int, history: list[dict]):
|
||||
self._total = total_episodes
|
||||
self._history = list(history)
|
||||
self._rendered = False # have we drawn charts yet?
|
||||
|
||||
already_done = start_episode - 1
|
||||
self._bar = tqdm(
|
||||
initial=already_done,
|
||||
total=total_episodes,
|
||||
unit='ep',
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
self._draw()
|
||||
|
||||
def on_episode(self) -> None:
|
||||
self._bar.update(1)
|
||||
|
||||
def on_checkpoint(self, stats: dict) -> None:
|
||||
self._history.append(stats)
|
||||
self._bar.clear()
|
||||
self._draw()
|
||||
self._bar.refresh()
|
||||
|
||||
def close(self) -> None:
|
||||
self._bar.close()
|
||||
|
||||
def _draw(self) -> None:
|
||||
chart = _build_charts(self._history, _terminal_width())
|
||||
n_lines = len(chart.splitlines())
|
||||
if self._rendered:
|
||||
sys.stdout.write(f'\033[{n_lines}A')
|
||||
sys.stdout.write(chart)
|
||||
if not chart.endswith('\n'):
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
self._rendered = True
|
||||
Reference in New Issue
Block a user