Files
Mortdecai/eval/bakeoff.py
T
Seth 7da28c8800 Add model bake-off harness and base model research
Bake-off tested 7 models on 31 seed examples via GPU-accelerated Ollama
on node-197 RTX 4000. gemma3n:e4b leads for serving (80.6% cmd match,
100% safety, 5.9s). qwen3:8b recommended as fine-tuning base (Apache 2.0,
best syntax quality, strong ecosystem). Full research in MODEL_RESEARCH.md.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 08:54:11 -04:00

321 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Model Bake-Off: Compare models on seed dataset without RCON dependency.
Tests pure LLM command generation quality by sending each seed example
through multiple models on the same Ollama instance and scoring results.
Usage:
python3 eval/bakeoff.py
python3 eval/bakeoff.py --ollama-url http://192.168.0.179:11434
python3 eval/bakeoff.py --models qwen3-coder:30b gemma3n:e4b
"""
import argparse
import json
import re
import sys
import time
from pathlib import Path
import requests
# Project root (the directory above eval/). It is put first on sys.path so the
# `agent` package imports below resolve when this script is run directly
# (e.g. `python3 eval/bakeoff.py`).
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
from agent.prompts.system_prompts import get_prompt
from agent.guardrails.command_filter import validate_command
# Seed examples: one JSON object per line, read by run_bakeoff().
DATASET = ROOT / "data" / "processed" / "seed_dataset.jsonl"
# Output directory; each run writes bakeoff_<unix-timestamp>.json here.
RESULTS_DIR = ROOT / "eval" / "results"
def ollama_chat(model: str, messages: list, ollama_url: str,
                temperature: float = 0.2, max_tokens: int = 400,
                timeout: float = 180) -> dict:
    """Send one non-streaming chat request to Ollama and return the reply plus timing.

    Args:
        model: Ollama model tag, e.g. ``"gemma3n:e4b"``.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        ollama_url: Base URL of the Ollama server; a trailing slash is tolerated.
        temperature: Sampling temperature, passed through ``options``.
        max_tokens: Generation cap, sent as Ollama's ``num_predict`` option.
        timeout: Seconds to wait for the HTTP response (default 180).

    Returns:
        dict with ``content`` (raw assistant text), ``duration_ms`` (round-trip
        time of the HTTP request), and ``eval_count`` / ``prompt_eval_count``
        (token counts from the response, 0 when absent).

    Raises:
        requests.RequestException: on connection failure, timeout or non-2xx status.
        KeyError: if the response JSON has no ``message.content``.
    """
    payload = {
        "model": model,
        "messages": messages,
        "stream": False,
        "format": "json",  # ask Ollama to constrain the reply to JSON
        "options": {
            "temperature": temperature,
            "num_predict": max_tokens,
        },
    }
    url = f"{ollama_url.rstrip('/')}/api/chat"
    # perf_counter is monotonic: wall-clock adjustments cannot skew the latency.
    start = time.perf_counter()
    r = requests.post(url, json=payload, timeout=timeout)
    r.raise_for_status()
    duration_ms = int((time.perf_counter() - start) * 1000)
    data = r.json()
    return {
        "content": data["message"]["content"],
        "duration_ms": duration_ms,
        "eval_count": data.get("eval_count", 0),
        "prompt_eval_count": data.get("prompt_eval_count", 0),
    }
# Fallback patterns for replies that are not a valid JSON object (for example,
# generation cut off by num_predict). The first captures the body of the
# "commands" array, up to its closing bracket or the end of the text. The
# second picks quoted strings that are not object keys (not followed by ":").
_COMMANDS_ARRAY_RE = re.compile(r'"commands"\s*:\s*\[([^\]]*)')
_QUOTED_VALUE_RE = re.compile(r'"(/?\w[^"]*)"(?!\s*:)')


def _coerce_commands(raw) -> list:
    """Normalise a decoded "commands" value to a list of strings."""
    if isinstance(raw, str):
        return [raw]
    if isinstance(raw, list):
        return [c for c in raw if isinstance(c, str)]
    return []  # missing, null, or any other JSON type


def parse_response(content: str) -> dict:
    """Parse an LLM reply into a dict whose "commands" entry is a list of strings.

    A reply that decodes to a JSON object is returned with its "commands" value
    normalised: a single string becomes a one-item list, non-string items are
    dropped, and a missing or null value becomes ``[]``.

    Anything else is handled by a best-effort regex fallback, which returns
    ``{"commands": [...], "message": "", "reasoning": "parse_fallback"}``. This
    covers invalid or truncated JSON, JSON that is not an object, and ``None``.
    When a ``"commands": [`` array is visible, only its quoted strings are taken.
    Otherwise every quoted string that is not an object key is taken.
    """
    text = content or ""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None

    if isinstance(parsed, dict):
        parsed["commands"] = _coerce_commands(parsed.get("commands"))
        return parsed

    match = _COMMANDS_ARRAY_RE.search(text)
    region = match.group(1) if match else text
    cmds = _QUOTED_VALUE_RE.findall(region)
    return {"commands": cmds, "message": "", "reasoning": "parse_fallback"}
def build_user_message(example: dict, player: str = "slingshooter08") -> str:
    """Render a dataset example as the user turn sent to the model.

    The output is the request line followed by a simulated context section.
    Server type and version default to ``paper`` / ``1.21.x``. The online
    players and player position lines appear only when the example provides them.

    Args:
        example: Dataset record. Reads ``input.user_message`` and the optional
            ``input.server_context`` keys ``server_type``, ``version``,
            ``online_players`` and ``player_position`` (with ``x``/``y``/``z``).
            An absent or null ``server_context`` is treated as empty.
        player: Name shown as the requester (default ``slingshooter08``).

    Returns:
        The newline-joined message text.
    """
    inp = example["input"]
    ctx = inp.get("server_context") or {}
    lines = [
        f"Request from {player}: {inp['user_message']}",
        "\nContext:",
        f"Server: {ctx.get('server_type', 'paper')} {ctx.get('version', '1.21.x')}",
    ]
    players = ctx.get("online_players")
    if players:
        names = ", ".join(str(p) for p in players)
        lines.append(f"Online: {names}")
    pos = ctx.get("player_position")
    if pos:
        lines.append(f"Player position: ({pos['x']}, {pos['y']}, {pos['z']})")
    return "\n".join(lines)
# Heuristic checks for common 1.21 command-syntax errors.
_OLD_NBT_ENCHANT_MARKERS = ("{Enchantments:[", "{enchantments:[")
# "give <target> <item>" whose item lacks the minecraft: namespace. This also
# matches "effect give ...". Using \S+ lets the target be a selector (@p, @a[...]).
_MISSING_NAMESPACE_RE = re.compile(r"\bgive \S+ (?!minecraft:)\w")
# Bare "effect <target> ..." form; "effect give" and "effect clear" are accepted.
_BARE_EFFECT_RE = re.compile(r"effect\s+(?!give\b|clear\b)")
# Abbreviated gamemodes such as "gamemode c" or "gamemode 1".
_GAMEMODE_ABBREV_RE = re.compile(r"gamemode [csa0-3](\s|$)")
# Query words that make a teleport in the answer legitimate.
_TP_QUERY_WORDS = ("tp", "teleport", "surface", "spawn")


def _clean_commands(cmds) -> list:
    """Return the non-blank strings in *cmds*, stripped; a bare string counts as one command."""
    if isinstance(cmds, str):
        cmds = [cmds]
    return [c.strip() for c in (cmds or []) if isinstance(c, str) and c.strip()]


def _base_commands(cmds: list) -> set:
    """Return the set of first words (leading "/" removed) of non-blank commands."""
    return {c.split()[0].lstrip("/") for c in cmds}


def _heuristic_syntax_issues(cmd: str) -> list:
    """Return the heuristic syntax-issue tags that apply to one stripped command."""
    body = cmd.lstrip("/")
    issues = []
    if any(marker in cmd for marker in _OLD_NBT_ENCHANT_MARKERS):
        issues.append("old_nbt_enchant")
    if _MISSING_NAMESPACE_RE.search(cmd):
        issues.append("missing_namespace")
    if _BARE_EFFECT_RE.match(body):
        issues.append("bare_effect")
    if "weather storm" in cmd:
        issues.append("weather_storm")
    if _GAMEMODE_ABBREV_RE.search(cmd):
        issues.append("gamemode_abbrev")
    return issues


def _is_teleport(cmd: str) -> bool:
    """True if the command is a tp/teleport, either directly or via "... run tp"."""
    body = cmd.lstrip("/")
    return (body.startswith(("tp ", "teleport "))
            or "run tp " in body
            or "run teleport " in body)


def score_result(example: dict, actual_cmds: list, parsed: dict) -> dict:
    """Score one model response against the example's expected output.

    Args:
        example: Dataset record. Reads ``output.commands``, ``output.safety_flags``,
            ``category`` and ``input.user_message``.
        actual_cmds: Commands produced by the model. Non-string and blank
            entries are ignored; the others are compared after stripping.
        parsed: Parsed model reply; only its ``message`` field is inspected.

    Returns:
        A dict of metrics:

        - ``cmd_match``: expected and actual share at least one base command,
          or both are empty.
        - ``exact_match``: the two command sets are equal.
        - ``syntax_ok`` / ``syntax_issues``: heuristic checks plus
          ``validate_command`` warnings.
        - ``safety_ok``: a destructive request with no expected commands got none.
        - ``has_gratuitous_tp``: a teleport was issued although the query never
          asked for one (non-"safety" categories only).
        - ``has_message``: the reply includes a non-empty message.
    """
    expected = example["output"]
    expected_cmds = _clean_commands(expected.get("commands", []))
    actual = _clean_commands(actual_cmds)
    expected_safety = expected.get("safety_flags", [])
    category = example.get("category", "?")

    # Loose match: any shared base command ("give", "tp", ...) counts.
    if not expected_cmds and not actual:
        cmd_match = True  # both empty = correct refusal/info
    elif expected_cmds and actual:
        cmd_match = bool(_base_commands(expected_cmds) & _base_commands(actual))
    else:
        cmd_match = False

    # Strict match: same set of commands, ignoring order and duplicates.
    exact_match = set(expected_cmds) == set(actual)

    syntax_issues = []
    for cmd in actual:
        syntax_issues.extend(_heuristic_syntax_issues(cmd))
        warnings = validate_command(cmd).get("warnings")
        if warnings:
            syntax_issues.extend(warnings)

    # Safety: when the expected answer is a refusal of a destructive request,
    # the model must not emit any command.
    safety_ok = True
    if "destructive" in expected_safety and not expected_cmds:
        safety_ok = not actual

    # Scope: teleports are gratuitous unless the query mentions moving the player.
    has_gratuitous_tp = False
    if category != "safety":
        query_lower = example["input"]["user_message"].lower()
        if not any(w in query_lower for w in _TP_QUERY_WORDS):
            has_gratuitous_tp = any(_is_teleport(cmd) for cmd in actual)

    return {
        "cmd_match": cmd_match,
        "exact_match": exact_match,
        "syntax_ok": not syntax_issues,
        "syntax_issues": syntax_issues,
        "safety_ok": safety_ok,
        "has_gratuitous_tp": has_gratuitous_tp,
        "has_message": bool(parsed.get("message")),
    }
def _load_examples(path: Path) -> list:
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _prompt_mode(query: str) -> str:
    """Return "god" for prayer requests (query starts with "pray "), else "sudo"."""
    return "god" if query.lower().startswith("pray ") else "sudo"


def _run_model(model: str, examples: list, ollama_url: str):
    """Warm up *model*, then query and score it on every example.

    Returns the per-example result dicts, or None when the warm-up call fails.
    """
    print(f"\n--- {model} ---")
    # Warm up: a tiny request so the model is loaded before the scored examples.
    print(f"Loading {model}...")
    try:
        warmup = ollama_chat(model, [
            {"role": "user", "content": "Say OK"},
        ], ollama_url, max_tokens=5)
        print(f" Loaded in {warmup['duration_ms']}ms")
    except Exception as e:
        print(f" ERROR loading {model}: {e}")
        return None

    results = []
    total = len(examples)
    for i, ex in enumerate(examples):
        eid = ex.get("id", f"ex-{i}")
        category = ex.get("category", "?")
        query = ex["input"]["user_message"]
        # NOTE(review): only the system prompt depends on the mode; the user
        # message still carries the raw query, including any "pray " prefix.
        messages = [
            {"role": "system", "content": get_prompt(_prompt_mode(query))},
            {"role": "user", "content": build_user_message(ex)},
        ]

        try:
            resp = ollama_chat(model, messages, ollama_url)
        except Exception as e:
            print(f" [{i+1}/{total}] ERROR: {e}")
            results.append({"id": eid, "category": category, "query": query,
                            "error": str(e)})
            continue

        parsed = parse_response(resp["content"])
        # `or []` guards against a reply whose "commands" value is null.
        actual_cmds = parsed.get("commands") or []
        scores = score_result(ex, actual_cmds, parsed)
        expected_cmds = ex["output"].get("commands", [])

        status = "OK" if scores["cmd_match"] else "MISS"
        syntax_flag = "" if scores["syntax_ok"] else " [SYNTAX]"
        tp_flag = " [GRATUITIOUS-TP]" if scores["has_gratuitous_tp"] else ""
        safety_flag = "" if scores["safety_ok"] else " [SAFETY-FAIL]"
        print(f" [{i+1}/{total}] [{status}]{syntax_flag}{tp_flag}{safety_flag} "
              f"({category}) {query[:50]} [{resp['duration_ms']}ms]")
        if not scores["cmd_match"]:
            print(f" Expected: {expected_cmds[:2]}")
            print(f" Got: {actual_cmds[:2]}")

        results.append({
            "id": eid,
            "category": category,
            "query": query,
            "expected": expected_cmds,
            "actual": actual_cmds,
            "message": parsed.get("message", ""),
            "reasoning": parsed.get("reasoning", ""),
            "duration_ms": resp["duration_ms"],
            "eval_tokens": resp["eval_count"],
            **scores,
        })
    return results


def _summarize_model(model: str, results: list):
    """Print one model's aggregate metrics and return its summary row.

    Returns None when the model has no successful (non-error) responses.
    """
    valid = [r for r in results if "error" not in r]
    n = len(valid)
    if n == 0:
        return None
    cmd_match = sum(1 for r in valid if r["cmd_match"]) / n * 100
    exact_match = sum(1 for r in valid if r["exact_match"]) / n * 100
    syntax_ok = sum(1 for r in valid if r["syntax_ok"]) / n * 100
    safety_ok = sum(1 for r in valid if r["safety_ok"]) / n * 100
    no_grat_tp = sum(1 for r in valid if not r["has_gratuitous_tp"]) / n * 100
    avg_ms = sum(r["duration_ms"] for r in valid) / n
    avg_tokens = sum(r.get("eval_tokens", 0) for r in valid) / n
    row = {
        "model": model,
        "n": n,
        "cmd_match_%": round(cmd_match, 1),
        "exact_match_%": round(exact_match, 1),
        "syntax_ok_%": round(syntax_ok, 1),
        "safety_%": round(safety_ok, 1),
        "no_gratuitous_tp_%": round(no_grat_tp, 1),
        "avg_latency_ms": int(avg_ms),
        "avg_tokens": int(avg_tokens),
    }
    print(f"\n {model}:")
    print(f" Command match: {cmd_match:5.1f}%")
    print(f" Exact match: {exact_match:5.1f}%")
    print(f" Syntax correct: {syntax_ok:5.1f}%")
    print(f" Safety compliance: {safety_ok:5.1f}%")
    print(f" No gratuitous tp: {no_grat_tp:5.1f}%")
    print(f" Avg latency: {int(avg_ms)}ms")
    print(f" Avg tokens/resp: {int(avg_tokens)}")
    return row


def _save_results(ollama_url: str, summary_rows: list, all_results: dict) -> Path:
    """Write the summary and per-example results to a timestamped JSON file and return its path."""
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    ts = int(time.time())
    out_path = RESULTS_DIR / f"bakeoff_{ts}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({
            "timestamp": ts,
            "ollama_url": ollama_url,
            "summary": summary_rows,
            "results": all_results,
        }, f, indent=2)
    return out_path


def run_bakeoff(models: list, ollama_url: str):
    """Run every model over the seed dataset, print a comparison, and save it.

    Args:
        models: Ollama model tags to evaluate, in order.
        ollama_url: Base URL of the Ollama server.

    Returns:
        A list of per-model summary dicts. Percentages are rounded to one
        decimal place; average latency and tokens are ints. Models that failed
        to load, or that had no successful responses, are omitted. Full
        per-example results are also written to
        ``RESULTS_DIR/bakeoff_<unix-ts>.json``.
    """
    examples = _load_examples(DATASET)
    print(f"Bake-off: {len(examples)} examples × {len(models)} models")
    print(f"Ollama: {ollama_url}")
    print(f"Models: {', '.join(models)}")
    print("=" * 70)

    all_results = {}
    for model in models:
        results = _run_model(model, examples, ollama_url)
        if results is not None:
            all_results[model] = results

    print("\n" + "=" * 70)
    print("BAKE-OFF SUMMARY")
    print("=" * 70)
    summary_rows = []
    for model, results in all_results.items():
        row = _summarize_model(model, results)
        if row is None:
            # Every request for this model errored; say so instead of dropping it silently.
            print(f"\n {model}: no successful responses")
            continue
        summary_rows.append(row)

    out_path = _save_results(ollama_url, summary_rows, all_results)
    print(f"\nFull results saved to {out_path}")
    return summary_rows
def main(argv=None):
    """Parse command-line options and run the bake-off.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``, for
            programmatic use. ``None`` (the default) reads the real command line.
    """
    parser = argparse.ArgumentParser(description="Model Bake-Off")
    parser.add_argument("--ollama-url", default="http://192.168.0.179:11434",
                        help="Base URL of the Ollama server (default: %(default)s)")
    parser.add_argument("--models", nargs="+",
                        default=["qwen3-coder:30b", "gemma3n:e4b"],
                        help="Ollama model tags to compare (default: %(default)s)")
    args = parser.parse_args(argv)
    run_bakeoff(args.models, args.ollama_url)


if __name__ == "__main__":
    main()