Files
Mortdecai/training/scripts/generate_training_chart.py
T
Mortdecai d9acb653fe Fix chart labels, add version history table to README
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 15:48:35 -04:00

124 lines
4.8 KiB
Python

#!/usr/bin/env python3
"""
Generate SVG training history chart for the Gitea README.
X-axis: Model version
Y-axis: Training examples (bar) and inverse loss (line)
"""
import json
from pathlib import Path
# Where the chart is written: "branding/training_progress.svg" under the
# directory three levels above this script.
OUTPUT = Path(__file__).resolve().parent.parent.parent.joinpath(
    "branding", "training_progress.svg"
)

# Historical data from training runs, one entry per model version, oldest
# first. "examples" drives the bars, "loss" drives the 1/loss line and
# "label" is the small caption under each version number.
VERSIONS = [
    dict(version="0.1.0", examples=500, loss=2.10, label="seed only"),
    dict(version="0.2.0", examples=1200, loss=1.45, label="+entities"),
    dict(version="0.3.0", examples=2100, loss=0.82, label="+errors"),
    dict(version="0.4.0", examples=3175, loss=0.35, label="+tools"),
    dict(version="0.5.0", examples=4358, loss=0.16, label="+plugins"),
]

# Chart dimensions: overall canvas size, then the padding on each side
# (left, right, top, bottom) that surrounds the plotting area.
W, H = 700, 400
PAD_L, PAD_R, PAD_T, PAD_B = 70, 30, 40, 80
PLOT_W = W - (PAD_L + PAD_R)
PLOT_H = H - (PAD_T + PAD_B)

# Colors
BG, GRID = "#111111", "#252525"
TEXT, LABEL_COLOR = "#999999", "#e0e0e0"
BAR_COLOR, LINE_COLOR = "#D35400", "#4caf50"
def generate_svg(versions=None):
    """Render the training-history chart as an SVG document string.

    Bars show the number of training examples per model version. A line with
    dots shows inverse loss (1/loss), scaled independently of the bars so
    both series fill the plot area.

    Args:
        versions: Optional iterable of dicts with the keys ``"version"``,
            ``"examples"``, ``"loss"`` and ``"label"``. Defaults to the
            module-level ``VERSIONS``.

    Returns:
        The complete SVG markup as one string, without a trailing newline.

    Raises:
        ValueError: If there are no entries, or an entry has a loss that is
            not strictly positive (the 1/loss series cannot be plotted).
    """
    from xml.sax.saxutils import escape

    entries = list(VERSIONS if versions is None else versions)
    if not entries:
        raise ValueError("generate_svg needs at least one version entry")
    bad = [v["version"] for v in entries if v["loss"] <= 0]
    if bad:
        raise ValueError(f"loss must be > 0 to plot 1/loss; offending versions: {bad}")

    # 15% headroom above the tallest bar / highest point keeps the value
    # labels inside the plot. Fall back to 1.0 when every example count is
    # zero so the bar-height division below cannot divide by zero.
    max_examples = max(v["examples"] for v in entries) * 1.15 or 1.0
    max_inv_loss = max(1.0 / v["loss"] for v in entries) * 1.15

    gap = PLOT_W / len(entries)   # horizontal slot per version
    bar_w = gap * 0.6             # bar fills 60% of its slot
    baseline = PAD_T + PLOT_H     # y coordinate of the x-axis

    # Column centre and line height are shared by the bars, the polyline and
    # the dots, so compute them once per entry.
    columns = []
    for i, v in enumerate(entries):
        cx = PAD_L + gap * i + gap / 2
        inv_loss = 1.0 / v["loss"]
        ly = baseline - (inv_loss / max_inv_loss) * PLOT_H
        columns.append((cx, ly, v))

    # Collect one markup line per element and join once at the end.
    parts = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {W} {H}" width="{W}" height="{H}">',
        f'<rect width="{W}" height="{H}" fill="{BG}" rx="8"/>',
        '<!-- Title -->',
        f'<text x="{W/2}" y="25" fill="{LABEL_COLOR}" font-family="monospace" font-size="16" text-anchor="middle" font-weight="bold">Mortdecai Training Progress</text>',
        '<!-- Grid lines -->',
    ]

    # Y-axis grid (examples): five evenly spaced lines from 0 to max_examples.
    for i in range(5):
        y = baseline - (i / 4 * PLOT_H)
        val = int(max_examples * i / 4)
        parts.append(f'<line x1="{PAD_L}" y1="{y}" x2="{W-PAD_R}" y2="{y}" stroke="{GRID}" stroke-width="0.5"/>')
        parts.append(f'<text x="{PAD_L-5}" y="{y+4}" fill="{TEXT}" font-family="monospace" font-size="10" text-anchor="end">{val:,}</text>')

    # Bars (training examples) with their value on top and the version and
    # caption underneath.
    for cx, _ly, v in columns:
        bh = (v["examples"] / max_examples) * PLOT_H
        by = baseline - bh
        # Version and caption are free text: escape &, < and > so they
        # cannot break the XML.
        version_text = escape(str(v["version"]))
        label_text = escape(str(v["label"]))
        parts.append(f'<rect x="{cx - bar_w/2}" y="{by}" width="{bar_w}" height="{bh}" fill="{BAR_COLOR}" rx="3" opacity="0.85"/>')
        parts.append(f'<text x="{cx}" y="{by - 8}" fill="{BAR_COLOR}" font-family="monospace" font-size="11" text-anchor="middle" font-weight="bold">{v["examples"]:,}</text>')
        parts.append(f'<text x="{cx}" y="{baseline + 20}" fill="{LABEL_COLOR}" font-family="monospace" font-size="12" text-anchor="middle">{version_text}</text>')
        parts.append(f'<text x="{cx}" y="{baseline + 35}" fill="{TEXT}" font-family="monospace" font-size="9" text-anchor="middle">{label_text}</text>')

    # Line (inverse loss = quality)
    polyline = " ".join(f"{cx},{ly}" for cx, ly, _v in columns)
    parts.append(f'<polyline points="{polyline}" fill="none" stroke="{LINE_COLOR}" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round"/>')

    # Dots on the line, each annotated with the raw loss value.
    for cx, ly, v in columns:
        parts.append(f'<circle cx="{cx}" cy="{ly}" r="4" fill="{LINE_COLOR}"/>')
        parts.append(f'<text x="{cx}" y="{ly - 10}" fill="{LINE_COLOR}" font-family="monospace" font-size="10" text-anchor="middle">loss={v["loss"]}</text>')

    # Y-axis title, rotated to read bottom-to-top.
    parts.append(f'<text x="{PAD_L - 45}" y="{PAD_T + PLOT_H/2}" fill="{BAR_COLOR}" font-family="monospace" font-size="11" text-anchor="middle" transform="rotate(-90,{PAD_L-45},{PAD_T+PLOT_H/2})">Training Examples</text>')

    # Legend
    parts.append(f'<rect x="{W-180}" y="{PAD_T+5}" width="12" height="12" fill="{BAR_COLOR}" rx="2"/>')
    parts.append(f'<text x="{W-163}" y="{PAD_T+15}" fill="{TEXT}" font-family="monospace" font-size="10">Training Examples</text>')
    parts.append(f'<line x1="{W-180}" y1="{PAD_T+28}" x2="{W-168}" y2="{PAD_T+28}" stroke="{LINE_COLOR}" stroke-width="2.5"/>')
    parts.append(f'<text x="{W-163}" y="{PAD_T+32}" fill="{TEXT}" font-family="monospace" font-size="10">Model Quality (1/loss)</text>')

    # X-axis title
    parts.append(f'<text x="{W/2}" y="{H-10}" fill="{TEXT}" font-family="monospace" font-size="11" text-anchor="middle">Model Version</text>')

    parts.append("</svg>")
    return "\n".join(parts)
def main():
    """Generate the chart, write it to ``OUTPUT`` and report where it went.

    The parent directory of ``OUTPUT`` is created if it does not exist.
    """
    svg = generate_svg()
    OUTPUT.parent.mkdir(parents=True, exist_ok=True)
    # The document declares encoding="UTF-8" in its XML prolog, so write it
    # as UTF-8 explicitly instead of relying on the platform's default
    # encoding.
    OUTPUT.write_text(svg, encoding="utf-8")
    print(f"Chart saved to {OUTPUT}")
    print("Embed in README: ![Training Progress](branding/training_progress.svg)")


if __name__ == "__main__":
    main()