87 lines
2.5 KiB
Python
87 lines
2.5 KiB
Python
from __future__ import annotations
|
|
import os
|
|
import sys
|
|
from tqdm import tqdm
|
|
|
|
_CHART_HEIGHT = 22 # plotext chart area height in lines
|
|
|
|
|
|
def _terminal_width() -> int:
|
|
try:
|
|
return min(os.get_terminal_size().columns, 220)
|
|
except OSError:
|
|
return 120
|
|
|
|
|
|
def _build_charts(history: list[dict], width: int) -> str:
    """Render one plotext line chart per training metric, tiled side by side.

    Every entry of *history* is indexed with the keys ``'episode'``,
    ``'epsilon'``, ``'avg_steps'``, ``'avg_loss'`` and ``'avg_reward'``; the
    episode values form the x axis of all four panels. Each panel is given
    ``max(width // 4, 20)`` columns and ``_CHART_HEIGHT`` lines. When
    *history* is empty the panels are rendered with titles/labels but no line.

    Returns:
        The panels' rows concatenated horizontally and joined with ``'\\n'``
        (no trailing newline).

    NOTE(review): because of the 20-column floor, the combined row is wider
    than *width* whenever width < 80. Padding also assumes plotext rows are
    exactly the panel width -- TODO confirm.
    """
    import plotext as plt

    x_values = [entry['episode'] for entry in history]
    metric_keys = (
        ("Epsilon", 'epsilon'),
        ("Avg Steps", 'avg_steps'),
        ("Avg Loss", 'avg_loss'),
        ("Avg Reward", 'avg_reward'),
    )
    # Collect every series up front, before touching plotext's global figure.
    series = [(label, [entry[key] for entry in history]) for label, key in metric_keys]

    column_width = max(width // len(series), 20)

    def render_panel(label: str, y_values: list) -> list[str]:
        # plotext keeps one global figure, so reset it before every panel.
        plt.clf()
        plt.canvas_color("default")
        plt.axes_color("default")
        plt.ticks_color("default")
        if x_values:
            plt.plot(x_values, y_values, color="default")
        plt.title(label)
        plt.xlabel("Episode")
        plt.plotsize(column_width, _CHART_HEIGHT)
        return plt.build().splitlines()

    columns = [render_panel(label, y_values) for label, y_values in series]

    # Pad shorter panels with blank rows so zip() keeps every row.
    row_count = max(map(len, columns))
    blank_row = ' ' * column_width
    padded = [column + [blank_row] * (row_count - len(column)) for column in columns]

    return '\n'.join(''.join(cells) for cells in zip(*padded))
|
|
|
|
|
|
class TrainingDashboard:
    """Inline training display: a plotext chart block that redraws in place
    above a tqdm per-episode progress bar.

    The chart block is written to ``sys.stdout`` and repositioned with ANSI
    cursor escapes; the bar is a ``tqdm`` instance using its default output
    stream. In-place redrawing assumes both reach the same terminal.
    """

    def __init__(self, total_episodes: int, start_episode: int, history: list[dict]):
        """Draw the initial charts, then start the progress bar beneath them.

        Args:
            total_episodes: Used as the bar's ``total``.
            start_episode: The bar starts with ``start_episode - 1`` episodes
                already counted (clamped so it is never negative).
            history: Checkpoint stats recorded so far. The list is copied, so
                later mutation by the caller does not affect the dashboard.
        """
        self._total = total_episodes
        self._history = list(history)
        self._rendered = False  # have we drawn charts yet?
        self._drawn_lines = 0  # height of the chart block currently on screen

        # Draw BEFORE creating the bar: tqdm prints the bar as soon as it is
        # constructed and leaves the cursor at the end of that line, so drawing
        # afterwards would splice the first chart row onto the bar line and
        # misalign every later cursor-up redraw.
        self._draw()

        already_done = max(start_episode - 1, 0)
        self._bar = tqdm(
            initial=already_done,
            total=total_episodes,
            unit='ep',
            dynamic_ncols=True,
        )

    def on_episode(self) -> None:
        """Advance the progress bar by one finished episode."""
        self._bar.update(1)

    def on_checkpoint(self, stats: dict) -> None:
        """Append *stats* to the history and redraw the charts above the bar.

        *stats* presumably needs the keys ``_build_charts`` reads
        ('episode', 'epsilon', 'avg_steps', 'avg_loss', 'avg_reward').
        """
        self._history.append(stats)
        # Blank the bar line first so the redraw starts from column 0 of the
        # line just below the chart block; the bar is repainted afterwards.
        self._bar.clear()
        self._draw()
        self._bar.refresh()

    def close(self) -> None:
        """Close the progress bar; the chart block is left on screen."""
        self._bar.close()

    def _draw(self) -> None:
        """(Re)draw the chart block from the current history.

        On a redraw the cursor is moved up over the block drawn *last* time and
        everything from there to the end of the screen is erased, so a new
        block of a different height leaves no stale rows behind.
        """
        chart = _build_charts(self._history, _terminal_width())
        if not chart.endswith('\n'):
            chart += '\n'
        if self._rendered and self._drawn_lines > 0:
            # CSI n A: cursor up n lines; CSI J: erase to end of screen.
            sys.stdout.write(f'\033[{self._drawn_lines}A\033[J')
        sys.stdout.write(chart)
        sys.stdout.flush()
        self._drawn_lines = len(chart.splitlines())
        self._rendered = True
|