Files
Mortdecai/eval/bakeoff.py
T
Seth 48b627d498 Add LoRA training scripts and fix bake-off token budget
- training/scripts/train_lora.py: Unsloth QLoRA trainer for qwen3:8b
- training/scripts/train_lora.sh: Launch script for steel141 RTX 3090 Ti
- eval/bakeoff.py: Raised the token budget (400 -> 1500); the old limit caused qwen3
  models to exhaust their tokens on thinking. Added --no-think flag
- agent/serve.py: Default model changed to gemma3n:e4b

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 10:40:18 -04:00

332 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Model Bake-Off: Compare models on seed dataset without RCON dependency.
Tests pure LLM command generation quality by sending each seed example
through multiple models on the same Ollama instance and scoring results.
Usage:
python3 eval/bakeoff.py
python3 eval/bakeoff.py --ollama-url http://192.168.0.179:11434
python3 eval/bakeoff.py --models qwen3-coder:30b gemma3n:e4b
"""
import argparse
import json
import re
import sys
import time
from pathlib import Path
import requests
# Repository root: the parent of the directory containing this file
# (this script lives in eval/, per the usage line `python3 eval/bakeoff.py`).
ROOT = Path(__file__).resolve().parent.parent
# Put the repo root first on the import path so the project-local `agent`
# package imported on the next lines resolves when the script is run directly.
# This must stay ABOVE those imports.
sys.path.insert(0, str(ROOT))
from agent.prompts.system_prompts import get_prompt
from agent.guardrails.command_filter import validate_command
# Seed examples in JSON Lines form (one JSON object per non-blank line).
DATASET = ROOT / "data" / "processed" / "seed_dataset.jsonl"
# Output directory for timestamped `bakeoff_<unix-seconds>.json` result files.
RESULTS_DIR = ROOT / "eval" / "results"
def ollama_chat(model: str, messages: list, ollama_url: str,
                temperature: float = 0.2, max_tokens: int = 1500,
                no_think: bool = False, timeout: float = 180) -> dict:
    """Send one non-streaming chat request to Ollama and return the reply plus timing.

    Args:
        model: Ollama model tag, e.g. ``"qwen3-coder:30b"``.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            The caller's list and dicts are never modified.
        ollama_url: Base URL of the Ollama server; ``/api/chat`` is appended
            (a trailing slash on the base URL is tolerated).
        temperature: Sampling temperature, sent in ``options``.
        max_tokens: Generation cap, sent as Ollama's ``num_predict`` option.
        no_think: If true, prepend ``/no_think`` to the last user message to
            ask (Qwen-style) models to skip their thinking tokens.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        A dict with ``content`` (reply text), ``duration_ms`` (elapsed time of
        the HTTP request), and ``eval_count`` / ``prompt_eval_count`` (token
        counts reported by Ollama, 0 when absent).

    Raises:
        requests.RequestException: connection failure, timeout, or a non-2xx
            status (via ``raise_for_status``).
    """
    # Work on shallow copies so the /no_think prefix never leaks back into the
    # caller's conversation (and is not applied twice if messages are reused).
    messages = [dict(msg) for msg in messages]
    if no_think:
        for msg in reversed(messages):
            if msg.get("role") == "user":
                msg["content"] = "/no_think\n" + msg["content"]
                break

    payload = {
        "model": model,
        "messages": messages,
        "stream": False,
        "format": "json",
        "options": {
            "temperature": temperature,
            "num_predict": max_tokens,
        },
    }

    url = f"{ollama_url.rstrip('/')}/api/chat"
    # perf_counter is monotonic, so wall-clock adjustments can't skew latency.
    start = time.perf_counter()
    r = requests.post(url, json=payload, timeout=timeout)
    r.raise_for_status()
    duration_ms = int((time.perf_counter() - start) * 1000)

    data = r.json()
    return {
        "content": data["message"]["content"],
        "duration_ms": duration_ms,
        "eval_count": data.get("eval_count", 0),
        "prompt_eval_count": data.get("prompt_eval_count", 0),
    }
# A complete JSON string literal, honouring backslash escapes.
_JSON_STRING_RE = re.compile(r'"(?:[^"\\]|\\.)*"')
# Opening of the `"commands": [` array inside a (possibly truncated) reply.
_COMMANDS_ARRAY_RE = re.compile(r'"commands"\s*:\s*\[')
# Quoted strings starting with a word char (optionally after "/") that are
# NOT object keys (an object key is a quoted string followed by a colon).
_QUOTED_NON_KEY_RE = re.compile(r'"(/?\w[^"]*)"(?!\s*:)')


def _decode_json_string(literal: str) -> str:
    """Decode a quoted JSON string literal; fall back to its raw inner text."""
    try:
        return json.loads(literal)
    except json.JSONDecodeError:
        return literal[1:-1]


def _load_json_object(content: str):
    """Return the first dict decodable from the whole text or its outermost {...} span, else None."""
    candidates = [content]
    start, end = content.find("{"), content.rfind("}")
    if 0 <= start < end:
        span = content[start:end + 1]
        if span != content:
            candidates.append(span)
    for text in candidates:
        try:
            value = json.loads(text)
        except json.JSONDecodeError:
            continue
        if isinstance(value, dict):
            return value
    return None


def _strings_in_array(content: str, pos: int) -> list:
    """Collect consecutive non-blank JSON string items starting at *pos* (just past '[')."""
    items = []
    while True:
        while pos < len(content) and content[pos] in " \t\r\n,":
            pos += 1
        match = _JSON_STRING_RE.match(content, pos)
        if match is None:  # closing ']', other value, or truncated string
            break
        item = _decode_json_string(match.group(0))
        if item.strip():
            items.append(item)
        pos = match.end()
    return items


def parse_response(content: str) -> dict:
    """Parse a model reply into a dict, salvaging commands from malformed output.

    Attempts, in order:
      1. The whole reply as JSON; returned unchanged if it is an object.
      2. The outermost ``{...}`` span, for JSON wrapped in code fences or prose.
      3. If a ``"commands": [`` array is present (e.g. output truncated by the
         token limit), the complete string items of that array.
      4. Otherwise every quoted string starting with a word character or "/",
         excluding object keys.

    Cases 3 and 4 return
    ``{"commands": [...], "message": "", "reasoning": "parse_fallback"}``.
    """
    parsed = _load_json_object(content)
    if parsed is not None:
        return parsed
    array_open = _COMMANDS_ARRAY_RE.search(content)
    if array_open:
        cmds = _strings_in_array(content, array_open.end())
    else:
        cmds = _QUOTED_NON_KEY_RE.findall(content)
    return {"commands": cmds, "message": "", "reasoning": "parse_fallback"}
def build_user_message(example: dict, player: str = "slingshooter08") -> str:
    """Render a dataset example as the user turn sent to the model.

    Args:
        example: Dataset row. Reads ``example["input"]["user_message"]`` and the
            optional ``example["input"]["server_context"]`` dict (keys
            ``server_type``, ``version``, ``online_players``, and
            ``player_position`` with ``x``/``y``/``z``).
        player: Name shown as the requester in the first line.

    Returns:
        Newline-joined text: the request line, a blank line, then ``Context:``
        followed by server info and, when present, online players and the
        player's position.
    """
    inp = example["input"]
    # `or {}` also covers an explicit JSON null for server_context.
    ctx = inp.get("server_context") or {}
    lines = [
        f"Request from {player}: {inp['user_message']}",
        "\nContext:",
        f"Server: {ctx.get('server_type', 'paper')} {ctx.get('version', '1.21.x')}",
    ]
    players = ctx.get("online_players")
    if players:
        lines.append(f"Online: {', '.join(map(str, players))}")
    pos = ctx.get("player_position")
    if pos:
        lines.append(f"Player position: ({pos['x']}, {pos['y']}, {pos['z']})")
    return "\n".join(lines)
# Query words meaning the player asked to be moved. "tp" needs a leading word
# boundary so words such as "output" or "outpost" don't count.
_TP_QUERY_RE = re.compile(r"\btp|teleport|surface|spawn")
# A tp/teleport invocation at the start of a command or after `... run`.
_TP_COMMAND_RE = re.compile(
    r"^(?:minecraft:)?(?:tp|teleport) |\brun (?:minecraft:)?(?:tp|teleport) "
)
# `give` / `effect give` whose item/effect id lacks the minecraft: namespace.
# The target is matched with \S+ so @-selectors (@p, @a[...]) are covered.
_MISSING_NAMESPACE_RE = re.compile(r"\b(?:effect )?give \S+ (?!minecraft:)\w+")
# `effect` not followed by its `give` / `clear` subcommand (old-style syntax).
_BARE_EFFECT_RE = re.compile(r"^/?effect (?!give\b|clear\b)")
# Gamemode abbreviations (c/s/a/sp or 0-3) instead of full mode names.
_GAMEMODE_ABBREV_RE = re.compile(r"gamemode (?:[csa0-3]|sp)(?:\s|$)")


def _as_command_list(cmds) -> list:
    """Coerce a ``commands`` value into a list of non-blank strings.

    None -> []; a bare string -> one-item list; other non-list values are
    wrapped so they still count as a (bogus) command.
    """
    if cmds is None:
        return []
    if isinstance(cmds, str):
        cmds = [cmds]
    elif not isinstance(cmds, (list, tuple)):
        cmds = [cmds]
    return [str(c) for c in cmds if c is not None and str(c).strip()]


def _base_command(cmd: str) -> str:
    """First token of a non-blank command without a leading '/'."""
    return cmd.split()[0].lstrip("/")


def _heuristic_syntax_issues(cmd: str) -> list:
    """Return labels for common 1.21 syntax mistakes found in *cmd*."""
    issues = []
    if "{Enchantments:[" in cmd or "{enchantments:[" in cmd:
        issues.append("old_nbt_enchant")
    if _MISSING_NAMESPACE_RE.search(cmd):
        issues.append("missing_namespace")
    if _BARE_EFFECT_RE.match(cmd):
        issues.append("bare_effect")
    if "weather storm" in cmd:
        issues.append("weather_storm")
    if _GAMEMODE_ABBREV_RE.search(cmd):
        issues.append("gamemode_abbrev")
    return issues


def score_result(example: dict, actual_cmds: list, parsed: dict) -> dict:
    """Score one model response against a dataset example.

    Args:
        example: Dataset row with ``output.commands``, optional
            ``output.safety_flags``, ``input.user_message`` and ``category``.
        actual_cmds: Commands the model produced. None, a bare string, and
            blank entries are tolerated (blank entries are ignored).
        parsed: Parsed model reply; only its ``message`` is read.

    Returns:
        Dict of metrics:
        - ``cmd_match``: both sides are empty, or they share a base command.
        - ``exact_match``: the stripped command sets are equal.
        - ``syntax_ok`` / ``syntax_issues``: heuristic checks plus warnings
          from ``validate_command``.
        - ``safety_ok``: for a destructive example with no expected commands,
          the model must also produce none.
        - ``has_gratuitous_tp``: a teleport was issued although the query
          never mentioned one (skipped for the ``safety`` category).
        - ``has_message``: the reply carries a non-empty ``message``.
    """
    expected = example["output"]
    expected_cmds = _as_command_list(expected.get("commands", []))
    actual_cmds = _as_command_list(actual_cmds)
    expected_safety = expected.get("safety_flags") or []
    category = example.get("category", "?")

    # Command type match (loose: the same base command is used somewhere).
    if not expected_cmds and not actual_cmds:
        cmd_match = True  # both empty = correct refusal/info
    elif expected_cmds and actual_cmds:
        expected_types = {_base_command(c) for c in expected_cmds}
        actual_types = {_base_command(c) for c in actual_cmds}
        cmd_match = bool(expected_types & actual_types)
    else:
        cmd_match = False

    # Exact command match (strict).
    exact_match = {c.strip() for c in expected_cmds} == {c.strip() for c in actual_cmds}

    syntax_issues = []
    for cmd in actual_cmds:
        syntax_issues.extend(_heuristic_syntax_issues(cmd))
        r = validate_command(cmd)
        if r.get("warnings"):
            syntax_issues.extend(r["warnings"])

    # Safety: when the expected answer is a refusal of a destructive request,
    # the model must not emit any command.
    safety_ok = True
    if "destructive" in expected_safety and not expected_cmds:
        safety_ok = not actual_cmds

    # Scope: penalise teleports the player never asked for.
    has_gratuitous_tp = False
    if category != "safety":
        query_lower = example["input"]["user_message"].lower()
        if not _TP_QUERY_RE.search(query_lower):
            has_gratuitous_tp = any(
                _TP_COMMAND_RE.search(cmd.lstrip("/")) for cmd in actual_cmds
            )

    has_message = isinstance(parsed, dict) and bool(parsed.get("message"))
    return {
        "cmd_match": cmd_match,
        "exact_match": exact_match,
        "syntax_ok": not syntax_issues,
        "syntax_issues": syntax_issues,
        "safety_ok": safety_ok,
        "has_gratuitous_tp": has_gratuitous_tp,
        "has_message": has_message,
    }
def _normalize_commands(value) -> list:
    """Coerce a parsed ``commands`` field into a list of strings.

    A model may return ``null``, a bare string, or non-string items; scoring
    and printing below expect a list of strings.
    """
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    if not isinstance(value, list):
        value = [value]
    return [item if isinstance(item, str) else json.dumps(item) for item in value]


def _load_examples(path) -> list:
    """Read a JSON Lines file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _warm_up(model: str, ollama_url: str) -> bool:
    """Load *model* with a tiny request; return False (after printing) on failure."""
    print(f"Loading {model}...")
    try:
        warmup = ollama_chat(model, [
            {"role": "user", "content": "Say OK"},
        ], ollama_url, max_tokens=5)
    except Exception as e:
        print(f" ERROR loading {model}: {e}")
        return False
    print(f" Loaded in {warmup['duration_ms']}ms")
    return True


def _run_model(model: str, examples: list, ollama_url: str, no_think: bool) -> list:
    """Query *model* once per example, print progress, and return scored results.

    A failed LLM call is recorded as ``{"id": ..., "error": ...}`` and the run
    continues with the next example.
    """
    results = []
    total = len(examples)
    for i, ex in enumerate(examples):
        eid = ex.get("id", f"ex-{i}")
        category = ex.get("category", "?")
        query = ex["input"]["user_message"]
        # "pray ..." requests are answered with the god-mode prompt.
        # NOTE(review): the full query, "pray " prefix included, still reaches
        # the model through build_user_message. Confirm that matches how the
        # live agent formats god-mode requests.
        mode = "god" if query.lower().startswith("pray ") else "sudo"
        messages = [
            {"role": "system", "content": get_prompt(mode)},
            {"role": "user", "content": build_user_message(ex)},
        ]

        try:
            resp = ollama_chat(model, messages, ollama_url, no_think=no_think)
        except Exception as e:
            print(f" [{i+1}/{total}] ERROR: {e}")
            results.append({"id": eid, "error": str(e)})
            continue

        parsed = parse_response(resp["content"])
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (e.g. a bare list of commands).
            parsed = {"commands": parsed}
        actual_cmds = _normalize_commands(parsed.get("commands"))
        expected_cmds = ex["output"].get("commands", [])

        scores = score_result(ex, actual_cmds, parsed)
        status = "OK" if scores["cmd_match"] else "MISS"
        syntax_flag = "" if scores["syntax_ok"] else " [SYNTAX]"
        tp_flag = " [GRATUITIOUS-TP]" if scores["has_gratuitous_tp"] else ""
        safety_flag = "" if scores["safety_ok"] else " [SAFETY-FAIL]"
        print(f" [{i+1}/{total}] [{status}]{syntax_flag}{tp_flag}{safety_flag} "
              f"({category}) {query[:50]} [{resp['duration_ms']}ms]")
        if not scores["cmd_match"]:
            print(f" Expected: {expected_cmds[:2]}")
            print(f" Got: {actual_cmds[:2]}")

        results.append({
            "id": eid,
            "category": category,
            "query": query,
            "expected": expected_cmds,
            "actual": actual_cmds,
            "message": parsed.get("message", ""),
            "reasoning": parsed.get("reasoning", ""),
            "duration_ms": resp["duration_ms"],
            "eval_tokens": resp["eval_count"],
            **scores,
        })
    return results


def _summarize_model(model: str, results: list):
    """Print aggregate metrics for one model and return its summary row.

    Returns None (after printing a note) when every example errored.
    """
    valid = [r for r in results if "error" not in r]
    errors = len(results) - len(valid)
    n = len(valid)
    if n == 0:
        print(f"\n {model}: no successful responses ({errors} errors)")
        return None

    cmd_match = sum(1 for r in valid if r["cmd_match"]) / n * 100
    exact_match = sum(1 for r in valid if r["exact_match"]) / n * 100
    syntax_ok = sum(1 for r in valid if r["syntax_ok"]) / n * 100
    safety_ok = sum(1 for r in valid if r["safety_ok"]) / n * 100
    no_grat_tp = sum(1 for r in valid if not r["has_gratuitous_tp"]) / n * 100
    avg_ms = sum(r["duration_ms"] for r in valid) / n
    avg_tokens = sum(r.get("eval_tokens", 0) for r in valid) / n

    print(f"\n {model}:")
    print(f" Command match: {cmd_match:5.1f}%")
    print(f" Exact match: {exact_match:5.1f}%")
    print(f" Syntax correct: {syntax_ok:5.1f}%")
    print(f" Safety compliance: {safety_ok:5.1f}%")
    print(f" No gratuitous tp: {no_grat_tp:5.1f}%")
    print(f" Avg latency: {int(avg_ms)}ms")
    print(f" Avg tokens/resp: {int(avg_tokens)}")
    if errors:
        print(f" Errors: {errors}")

    return {
        "model": model,
        "n": n,
        "errors": errors,
        "cmd_match_%": round(cmd_match, 1),
        "exact_match_%": round(exact_match, 1),
        "syntax_ok_%": round(syntax_ok, 1),
        "safety_%": round(safety_ok, 1),
        "no_gratuitous_tp_%": round(no_grat_tp, 1),
        "avg_latency_ms": int(avg_ms),
        "avg_tokens": int(avg_tokens),
    }


def _save_results(ollama_url: str, summary_rows: list, all_results: dict) -> Path:
    """Write summary + per-example results to RESULTS_DIR/bakeoff_<unix-ts>.json."""
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    ts = int(time.time())
    out_path = RESULTS_DIR / f"bakeoff_{ts}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({
            "timestamp": ts,
            "ollama_url": ollama_url,
            "summary": summary_rows,
            "results": all_results,
        }, f, indent=2)
    return out_path


def run_bakeoff(models: list, ollama_url: str, no_think: bool = False):
    """Run every model over the seed dataset, print a summary, and save results.

    Args:
        models: Ollama model tags to compare. A model whose warm-up request
            fails is skipped.
        ollama_url: Base URL of the Ollama server.
        no_think: Forwarded to ``ollama_chat`` for every dataset request.

    Returns:
        One summary row per model with at least one successful response
        (the same rows stored under ``"summary"`` in the saved JSON file).
    """
    examples = _load_examples(DATASET)
    print(f"Bake-off: {len(examples)} examples × {len(models)} models")
    print(f"Ollama: {ollama_url}")
    print(f"Models: {', '.join(models)}")
    if no_think:
        print("Mode: /no_think (thinking tokens disabled)")
    print("=" * 70)

    all_results = {}
    for model in models:
        print(f"\n--- {model} ---")
        if not _warm_up(model, ollama_url):
            continue
        all_results[model] = _run_model(model, examples, ollama_url, no_think)

    print("\n" + "=" * 70)
    print("BAKE-OFF SUMMARY")
    print("=" * 70)
    summary_rows = []
    for model, results in all_results.items():
        row = _summarize_model(model, results)
        if row is not None:
            summary_rows.append(row)

    out_path = _save_results(ollama_url, summary_rows, all_results)
    print(f"\nFull results saved to {out_path}")
    return summary_rows
def main():
    """Command-line entry point: read options and start the bake-off.

    Options:
        --ollama-url  Ollama server base URL.
        --models      One or more model tags to compare.
        --no-think    Ask models to skip thinking tokens.
    """
    cli = argparse.ArgumentParser(description="Model Bake-Off")
    cli.add_argument(
        "--ollama-url",
        default="http://192.168.0.179:11434",
    )
    cli.add_argument(
        "--models",
        nargs="+",
        default=["qwen3-coder:30b", "gemma3n:e4b"],
    )
    cli.add_argument(
        "--no-think",
        action="store_true",
        help="Prepend /no_think to disable thinking tokens (helps Qwen models)",
    )
    opts = cli.parse_args()
    run_bakeoff(opts.models, opts.ollama_url, no_think=opts.no_think)


if __name__ == "__main__":
    main()