feat: native-bakeoff scaffold — Ollama JSON vs native-token tool-calling
Three-arm harness under scripts/native-bakeoff/: - arm A: /api/chat with JSON tools (current default) - arm B: /api/generate raw:true with canonical HF jinja template rendered directly - arm C: google-deepmind/gemma JAX ToolSampler (env-gated, JAX required) Interim finding from A+B sweep on matt-strix gemma4:26b Q4: Ollama's bidirectional JSON↔native tool-call translator is faithful. The "long" multi-tool task produces identical behavior (7 steps / 6 tools) on both arms. Earlier arm-B parser bug that looked like a divergence was a harness issue: preserving the model's <|channel>thought\n<channel|> prefix as assistant content tripped the jinja template's tool_response-following conditional, appending a spurious <turn|>\n that corrupted the next step's prompt. Fixed by dropping the channel prefix on the assistant message. Arm C left as scaffolded-but-not-run — the JAX/bf16 reference path would answer "does the GGUF runtime diverge from DeepMind's implementation" but requires a separate env with the `gemma` PyPI package. Parked pending SDXL eviction or vast-h100 session. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,261 @@
|
||||
"""Arm C: google-deepmind/gemma JAX ToolSampler (reference path).
|
||||
|
||||
This arm runs against the *canonical* JAX reference implementation. No
|
||||
Ollama, no llama.cpp, no GGUF quantization, no wire protocol — the
|
||||
chat template, token-level sampling, and tool-call parsing all happen
|
||||
inside the Python process using the code Google wrote for Gemma 4.
|
||||
|
||||
**Environment requirement** — this arm cannot run inside the Ollama-only
|
||||
environment used by arms A/B. Setup:
|
||||
|
||||
pip install jax[cuda12] gemma # or jax[cpu] for CPU fallback
|
||||
huggingface-cli login # weights download via HF
|
||||
|
||||
It will download `gm.ckpts.CheckpointPath.GEMMA4_E4B_IT` on first run
|
||||
(~8GB). Run this arm on a host with ≥16GB RAM (CPU) or ≥10GB VRAM (GPU).
|
||||
|
||||
**Known caveat** — the `gm.text.ToolSampler` docstring notes that
|
||||
"Gemma 1, 2 and 3 models were not specifically trained for tool use"
|
||||
and flags the sampler as a proof-of-concept. Gemma 4 *is* tool-trained
|
||||
so it should do better here, but if this arm underperforms A/B it may
|
||||
be the sampler wrapper, not the model. The trace logs the raw sampler
|
||||
turns so that can be diagnosed post-hoc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Local imports are guarded so the harness can at least import this
# module on a non-JAX host for syntax checking. The actual run() call
# returns an "env_missing" error trace if the env isn't set up.
|
||||
try:
|
||||
from gemma import gm # type: ignore
|
||||
_GEMMA_AVAILABLE = True
|
||||
except ImportError:
|
||||
gm = None # type: ignore
|
||||
_GEMMA_AVAILABLE = False
|
||||
|
||||
from tasks import SYSTEM_PROMPT, FAKE_HISTORY, TASKS, execute_tool_stub # noqa: F401 (TASKS for parity with A/B)
|
||||
|
||||
|
||||
# -------- Tool wrappers: one gm.tools.Tool subclass per stub --------
|
||||
#
|
||||
# ToolSampler requires DESCRIPTION + EXAMPLE for each tool so the model
|
||||
# sees an in-context example of the calling pattern. The EXAMPLE bodies
|
||||
# are intentionally short — they're primers, not test cases.
|
||||
|
||||
def _build_tools():
    """Instantiate the eight ToolSampler-compatible stub wrappers.

    The wrapper classes subclass ``gm.tools.Tool``, so they are defined
    inside this function: ``gm`` is ``None`` when the gemma import failed,
    and deferring the class definitions keeps the module importable there.

    Returns:
        A list with one instance of each wrapper, in registration order.
    """
    assert gm is not None

    # Short local aliases for the two library names every wrapper needs.
    tool_base = gm.tools.Tool
    make_example = gm.tools.Example

    class WebSearch(tool_base):
        DESCRIPTION = "Search the web for current information."
        EXAMPLE = make_example(
            query="recent Home Assistant release notes",
            thought="web_search is the right tool for current events / docs.",
            tool_kwargs={"query": "home assistant latest release"},
            tool_kwargs_doc={"query": "<search query string>"},
            result="1. HA 2026.4 released...",
            answer="Home Assistant 2026.4 is the most recent release.",
        )

        def call(self, query: str) -> str:
            stub_args = {"query": query}
            return execute_tool_stub("web_search", stub_args)

    class SethSearch(tool_base):
        DESCRIPTION = "Search Seth's homelab (repos, wiki, media). Use source='sethflix' for movies/TV."
        EXAMPLE = make_example(
            query="any cyberpunk movies on sethflix?",
            thought="Use source=sethflix to search the movie library.",
            tool_kwargs={"query": "cyberpunk", "source": "sethflix"},
            tool_kwargs_doc={
                "query": "<search query>",
                "source": "<'sethflix' | 'general'>",
                "limit": "<int, default 10>",
            },
            result="Blade Runner 2049, Ghost in the Shell, ...",
            answer="Yes — Blade Runner 2049, Ghost in the Shell, and a few others.",
        )

        def call(self, query: str, source: str = "general", limit: int = 10) -> str:
            stub_args = {"query": query, "source": source, "limit": limit}
            return execute_tool_stub("sethsearch", stub_args)

    class CheckSethflix(tool_base):
        DESCRIPTION = "Verify which comma-separated titles are in sethflix."
        EXAMPLE = make_example(
            query="is The Matrix in the library?",
            thought="check_sethflix verifies library membership.",
            tool_kwargs={"titles": "The Matrix"},
            tool_kwargs_doc={"titles": "<comma-separated title list>"},
            result="- The Matrix: IN LIBRARY",
            answer="Yes, The Matrix is in the library.",
        )

        def call(self, titles: str) -> str:
            stub_args = {"titles": titles}
            return execute_tool_stub("check_sethflix", stub_args)

    class MemoryRead(tool_base):
        DESCRIPTION = "Look up stored facts about a topic or user."
        EXAMPLE = make_example(
            query="what do I have about home automation?",
            thought="memory_read is the right tool.",
            tool_kwargs={"query": "home automation"},
            tool_kwargs_doc={"query": "<topic>", "user": "<optional user filter>"},
            result="- home_automation: Seth uses HA on VM 706...",
            answer="You have notes about HA on VM 706 with Zigbee2MQTT.",
        )

        def call(self, query: str, user: str = "") -> str:
            stub_args = {"query": query, "user": user}
            return execute_tool_stub("memory_read", stub_args)

    class MemoryWrite(tool_base):
        DESCRIPTION = "Store a durable fact."
        EXAMPLE = make_example(
            query="remember that Seth prefers dark themes",
            thought="memory_write stores a key/content pair.",
            tool_kwargs={"key": "theme_preference", "content": "dark with orange accents"},
            tool_kwargs_doc={"key": "<short id>", "content": "<fact body>", "user": "<optional>"},
            result="stored: theme_preference = dark with orange accents",
            answer="Saved.",
        )

        def call(self, key: str, content: str, user: str = "") -> str:
            stub_args = {"key": key, "content": content, "user": user}
            return execute_tool_stub("memory_write", stub_args)

    class WebFetch(tool_base):
        DESCRIPTION = "Fetch the text contents of a URL."
        EXAMPLE = make_example(
            query="fetch https://example.com/docs",
            thought="web_fetch pulls page text.",
            tool_kwargs={"url": "https://example.com/docs"},
            tool_kwargs_doc={"url": "<absolute URL>"},
            result="fetched content: ...",
            answer="The page discusses X, Y, Z.",
        )

        def call(self, url: str) -> str:
            stub_args = {"url": url}
            return execute_tool_stub("web_fetch", stub_args)

    class ChatSearch(tool_base):
        DESCRIPTION = "Search message history across Matrix rooms."
        EXAMPLE = make_example(
            query="have we talked about grafana before?",
            thought="chat_search looks through prior messages.",
            tool_kwargs={"query": "grafana"},
            tool_kwargs_doc={"query": "<search query>"},
            result="[2026-03-14] @seth: grafana dashboard...",
            answer="Yes — you discussed a grafana dashboard on March 14.",
        )

        def call(self, query: str) -> str:
            stub_args = {"query": query}
            return execute_tool_stub("chat_search", stub_args)

    class GenerateImage(tool_base):
        DESCRIPTION = "Generate an image via SDXL."
        EXAMPLE = make_example(
            query="make me a sunset image",
            thought="generate_image dispatches to SDXL.",
            tool_kwargs={"prompt": "dramatic ocean sunset"},
            tool_kwargs_doc={"prompt": "<image description>"},
            result="image generated: /mxc/abc/sunset.png",
            answer="Done — here's the sunset image.",
        )

        def call(self, prompt: str) -> str:
            stub_args = {"prompt": prompt}
            return execute_tool_stub("generate_image", stub_args)

    # Registration order is preserved because it is recorded in the trace.
    wrapper_classes = (
        WebSearch, SethSearch, CheckSethflix,
        MemoryRead, MemoryWrite, WebFetch,
        ChatSearch, GenerateImage,
    )
    return [wrapper_cls() for wrapper_cls in wrapper_classes]
|
||||
|
||||
|
||||
async def run(
    *,
    ollama_url: str,  # unused; kept for CLI parity with arms A/B
    model: str,  # unused; arm C loads its own checkpoint
    task_prompt: str,
    num_ctx: int,  # unused; ToolSampler uses its own seq_len
    num_predict: int,
    step_budget: int,
) -> dict[str, Any]:
    """Run one task through the JAX ``gm.text.ToolSampler`` and return a trace.

    Args:
        ollama_url: Ignored; present so all arms share one call signature.
        model: Ignored; the checkpoint is fixed to ``GEMMA4_E4B_IT``.
        task_prompt: The user request, appended after a compact history block.
        num_ctx: Ignored by this arm.
        num_predict: Not referenced by this arm.
        step_budget: Not enforced here; the trace notes that it is ignored.

    Returns:
        A trace dict. When the gemma package is not importable, a short
        dict with an ``error`` and ``halt_reason="env_missing"``. Otherwise
        the full trace with per-turn info and a ``final`` summary; a sampler
        exception is reported as ``halt_reason="sampler_error: ..."``.
    """
    if not _GEMMA_AVAILABLE:
        env_missing_final = {
            "halt_reason": "env_missing",
            "steps_used": 0,
            "tool_calls_total": 0,
            "wall_clock_s": 0,
        }
        return {
            "arm": "jax-native",
            "error": "gemma package not importable — run in a JAX+gemma env. See module docstring.",
            "final": env_missing_final,
        }

    # Let JAX use the whole GPU if present (per colab_tool_use.ipynb hint).
    os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.95")

    load_started = time.time()
    network = gm.nn.Gemma4_E4B()
    weights = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA4_E4B_IT)
    tools = _build_tools()
    sampler = gm.text.ToolSampler(
        model=network,
        params=weights,
        tools=tools,
        print_stream=False,
    )
    load_elapsed_s = round(time.time() - load_started, 2)

    # ToolSampler is not handed a system prompt or role-tagged history here.
    # Instead the last six FAKE_HISTORY entries are flattened into the user
    # message. This is a known fidelity compromise versus arms A/B; if an
    # inter-arm delta is traced to it, pre-populate `sampler.turns` instead.
    history_lines = [
        f"{entry['role'].upper()}: {entry['content']}" for entry in FAKE_HISTORY[-6:]
    ]
    history_compact = "\n".join(history_lines)
    user_msg = (
        f"[prior chat context]\n{history_compact}\n\n"
        f"[2026-04-18 14:20] @seth:sethpc.xyz: {task_prompt}"
    )

    trace: dict[str, Any] = {
        "arm": "jax-native",
        "checkpoint": "GEMMA4_E4B_IT",
        "tools_registered": [tool.__class__.__name__ for tool in tools],
        "load_elapsed_s": load_elapsed_s,
        "step_budget_note": "ToolSampler manages its own step loop; step_budget ignored",
        "started_at": time.time(),
        "turns": [],
        "final": None,
    }

    try:
        chat_started = time.time()
        answer = sampler.chat(user_msg)
        elapsed = round(time.time() - chat_started, 2)
    except Exception as e:
        trace["final"] = {
            "halt_reason": f"sampler_error: {e}",
            "steps_used": 0,
            "tool_calls_total": 0,
            "wall_clock_s": round(time.time() - trace["started_at"], 2),
        }
        return trace

    # Turn schemas vary between gemma releases, so log defensively: copy
    # whichever of the known attributes exist so they can be inspected later.
    logged_attrs = ("query", "thought", "tool_name", "tool_kwargs", "tool_result", "answer")
    json_friendly = (str, int, float, bool, list, dict)
    sampler_turns = list(getattr(sampler, "turns", []) or [])
    tool_call_total = 0
    for step_no, turn in enumerate(sampler_turns, start=1):
        info: dict[str, Any] = {"step": step_no, "turn_type": turn.__class__.__name__}
        for attr in logged_attrs:
            value = getattr(turn, attr, None)
            if value is None:
                continue
            # Anything that is not a plain JSON-ish type is stringified.
            info[attr] = value if isinstance(value, json_friendly) else str(value)
        if info.get("tool_name"):
            tool_call_total += 1
        trace["turns"].append(info)

    trace["final"] = {
        "halt_reason": "answer_returned" if answer else "no_answer",
        "steps_used": len(sampler_turns),
        "tool_calls_total": tool_call_total,
        "wall_clock_s": round(time.time() - trace["started_at"], 2),
        "model_answer": answer,
        "sampler_elapsed_s": elapsed,
    }
    return trace
|
||||
@@ -0,0 +1,121 @@
|
||||
"""Arm A: Ollama /api/chat with JSON tools.
|
||||
|
||||
This is the baseline — what mort-bot, OpenWebUI, and every other Ollama
|
||||
client does. Ollama's server translates the OpenAI-style JSON tools
|
||||
array into Gemma's native <|tool>declaration:...<tool|> tokens and
|
||||
parses the model's <|tool_call>call:...<tool_call|> output back into
|
||||
structured tool_calls. This arm measures what we already live with.
|
||||
|
||||
Think setting: fixed to `false` per round-3 bakeoff finding (26B silently
|
||||
stops on think:true in multi-turn tool loops). For E4B the finding was
|
||||
less load-bearing but we hold think:false constant across arms so
|
||||
only the inference path varies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from tasks import SYSTEM_PROMPT, TOOLS, FAKE_HISTORY, execute_tool_stub
|
||||
|
||||
|
||||
async def run(
    *,
    ollama_url: str,
    model: str,
    task_prompt: str,
    num_ctx: int,
    num_predict: int,
    step_budget: int,
) -> dict[str, Any]:
    """Run one task through Ollama's /api/chat with JSON tool definitions.

    Each step posts the running message list, records a turn entry, then
    either stops (no tool calls) or executes every requested tool via
    ``execute_tool_stub`` and appends the results as ``tool`` messages.

    Args:
        ollama_url: Base URL of the Ollama server (no trailing path).
        model: Ollama model tag to query.
        task_prompt: The user request appended after the fake history.
        num_ctx: Forwarded as the ``num_ctx`` option.
        num_predict: Forwarded as the ``num_predict`` option.
        step_budget: Maximum number of request/response steps.

    Returns:
        A trace dict with per-step ``turns`` and a ``final`` summary whose
        ``halt_reason`` is ``"no_tool_calls"``, ``"step_budget"``, or
        ``"error: ..."`` for transport failures and Ollama error replies.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}] + list(FAKE_HISTORY)
    messages.append({"role": "user", "content": f"[2026-04-18 14:20] @seth:sethpc.xyz: {task_prompt}"})

    trace: dict[str, Any] = {
        "arm": "ollama-json",
        "model": model,
        "num_ctx": num_ctx,
        "num_predict": num_predict,
        "started_at": time.time(),
        "turns": [],
        "final": None,
    }

    tool_call_total = 0
    halt: str | None = None

    async with aiohttp.ClientSession() as session:
        for step in range(1, step_budget + 1):
            t0 = time.time()
            payload = {
                "model": model,
                "messages": messages,
                "tools": TOOLS,
                "stream": False,
                "think": False,
                "options": {"num_ctx": num_ctx, "num_predict": num_predict,
                            "temperature": 0.7, "top_p": 0.95, "top_k": 64},
                "keep_alive": "2h",
            }
            try:
                async with session.post(
                    f"{ollama_url}/api/chat", json=payload,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as resp:
                    status = resp.status
                    r = await resp.json()
            except Exception as e:
                halt = f"error: {e}"
                trace["turns"].append({"step": step, "error": str(e)})
                break

            # Ollama reports failures (unknown model, bad request, ...) as a
            # JSON body carrying "error". Without this check such a reply
            # parses as an empty assistant message and the run would be
            # logged as a clean "no_tool_calls" finish.
            if isinstance(r, dict):
                api_error = r.get("error")
            else:
                api_error = "non-object JSON response"
            if api_error or status >= 400:
                detail = str(api_error or f"HTTP {status}")
                halt = f"error: {detail}"
                trace["turns"].append({"step": step, "error": detail})
                break

            msg = r.get("message", {}) or {}
            content = msg.get("content", "") or ""
            tool_calls = msg.get("tool_calls") or []
            history_chars = sum(len(m.get("content", "") or "") for m in messages)

            trace["turns"].append({
                "step": step,
                "elapsed_s": round(time.time() - t0, 2),
                "prompt_eval_count": r.get("prompt_eval_count"),
                "eval_count": r.get("eval_count"),
                "content_len": len(content),
                "tool_call_count": len(tool_calls),
                "history_chars_before_append": history_chars,
            })
            messages.append(msg)

            if not tool_calls:
                halt = "no_tool_calls"
                break

            tool_call_total += len(tool_calls)
            for tc in tool_calls:
                fn = tc.get("function", {})
                name = fn.get("name")
                args = fn.get("arguments") or {}
                if isinstance(args, str):
                    # Arguments may arrive JSON-encoded; fall back to no
                    # arguments when the string is not valid JSON.
                    try:
                        args = json.loads(args)
                    except ValueError:
                        args = {}
                result = execute_tool_stub(name, args)
                messages.append({"role": "tool", "content": result})

            if step == step_budget:
                halt = "step_budget"
                break

    trace["final"] = {
        "halt_reason": halt,
        "steps_used": len(trace["turns"]),
        "tool_calls_total": tool_call_total,
        "wall_clock_s": round(time.time() - trace["started_at"], 2),
        "final_message_count": len(messages),
        "final_history_chars": sum(len(m.get("content", "") or "") for m in messages),
    }
    return trace
|
||||
@@ -0,0 +1,275 @@
|
||||
"""Arm B: Ollama /api/generate with raw:true and native Gemma 4 tokens.
|
||||
|
||||
Renders the canonical HF jinja chat template directly, sends the
|
||||
resulting string to Ollama's /api/generate with `raw: true` (which
|
||||
bypasses Ollama's own templating / BOS handling), and parses
|
||||
<|tool_call>call:NAME{args}<tool_call|> out of the completion with a
|
||||
regex.
|
||||
|
||||
The point of this arm: isolate Ollama's tool parser. Arm A lets
|
||||
Ollama's server translate OpenAI-shaped JSON tools into native tokens
|
||||
AND translate the model's native <|tool_call> output back into
|
||||
structured `tool_calls`. Arm B keeps everything native end-to-end and
|
||||
only uses Ollama as a thin completion engine. If A and B diverge, the
|
||||
delta lives in Ollama's bidirectional JSON↔native translator.
|
||||
|
||||
Template source: tooling/huggingface/model-cards/gemma-4-E4B-it-chat_template.jinja
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import jinja2
|
||||
|
||||
from tasks import SYSTEM_PROMPT, TOOLS, FAKE_HISTORY, execute_tool_stub
|
||||
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
_TEMPLATE_PATH = _REPO_ROOT / "tooling" / "huggingface" / "model-cards" / "gemma-4-E4B-it-chat_template.jinja"
|
||||
|
||||
|
||||
def _load_template() -> jinja2.Template:
    """Compile the chat template stored at ``_TEMPLATE_PATH``.

    The file is read as UTF-8 explicitly so the rendered prompt does not
    depend on the host's default locale encoding.

    Returns:
        The compiled ``jinja2.Template``.

    Raises:
        OSError: If the template file cannot be read.
    """
    env = jinja2.Environment(
        keep_trailing_newline=True,
        # Canonical template uses `{%- ... -%}` whitespace control; keep
        # jinja defaults so it renders exactly as HF's template expects.
    )
    return env.from_string(_TEMPLATE_PATH.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
# Matches one complete native tool call of the form
# `<|tool_call>call:NAME{BODY}<tool_call|>`.
#   name - the tool identifier (word characters only)
#   body - the raw argument text between the braces; matched non-greedily,
#          so it ends at the first `}` that is directly followed by the
#          closing `<tool_call|>` token.
# re.DOTALL lets the body span newlines.
_TOOL_CALL_RE = re.compile(
    r"<\|tool_call>call:(?P<name>\w+)\{(?P<body>.*?)\}<tool_call\|>",
    re.DOTALL,
)
|
||||
|
||||
|
||||
def _split_top_level(body: str, str_delim: str) -> list[str]:
    """Split *body* on commas that sit outside strings and outside {}/[]."""
    parts: list[str] = []
    buf: list[str] = []
    depth = 0
    in_str = False
    i = 0
    while i < len(body):
        if body.startswith(str_delim, i):
            # The string delimiter toggles quoting; keep it in the piece so
            # the value parser can recognise string values.
            in_str = not in_str
            buf.append(str_delim)
            i += len(str_delim)
            continue
        ch = body[i]
        if not in_str:
            if ch == "," and depth == 0:
                parts.append("".join(buf))
                buf = []
                i += 1
                continue
            if ch in "{[":
                depth += 1
            elif ch in "}]" and depth:
                # A stray closer at depth 0 is kept as plain text.
                depth -= 1
        buf.append(ch)
        i += 1
    if buf:
        parts.append("".join(buf))
    return parts


def _parse_native_value(value: str, str_delim: str) -> Any:
    """Convert one already-stripped native value into a Python object."""
    if value.startswith(str_delim) and value.endswith(str_delim):
        return value[len(str_delim) : -len(str_delim)]
    if value == "true":
        return True
    if value == "false":
        return False
    if len(value) >= 2 and value[0] == "{" and value[-1] == "}":
        return _parse_native_args(value[1:-1])
    if len(value) >= 2 and value[0] == "[" and value[-1] == "]":
        items = (item.strip() for item in _split_top_level(value[1:-1], str_delim))
        return [_parse_native_value(item, str_delim) for item in items if item]
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        # Unrecognised bare token: hand it back unchanged.
        return value


def _parse_native_args(body: str) -> dict[str, Any]:
    """Parse the body of a <|tool_call>call:NAME{...}<tool_call|>.

    Gemma 4 native arg format (from the jinja template's format_argument
    macro with escape_keys=False):
      - key:<|"|>stringval<|"|>
      - key:123
      - key:true / key:false
      - key:{nested:...}          (mapping args)
      - key:[<|"|>item<|"|>,...]  (array args)

    Splitting is depth-aware: commas inside a ``<|"|>``-quoted string, a
    ``{...}`` mapping, or a ``[...]`` array do not separate arguments.
    Mappings are parsed recursively into dicts and arrays into lists.

    Args:
        body: The text between the outer braces of the tool call.

    Returns:
        A dict of argument name to parsed value. Pieces without a ``:`` are
        skipped; values that are neither quoted, boolean, nested, nor
        numeric are returned as raw strings.
    """
    str_delim = '<|"|>'
    out: dict[str, Any] = {}
    if not body:
        return out

    for part in _split_top_level(body, str_delim):
        # Only the first colon separates key from value, so colons inside
        # the value (URLs, nested mappings) are preserved.
        key, sep, value = part.partition(":")
        if not sep:
            continue
        out[key.strip()] = _parse_native_value(value.strip(), str_delim)
    return out
|
||||
|
||||
|
||||
def _render(messages: list[dict[str, Any]]) -> str:
    """Render *messages* into a raw prompt string with the chat template.

    The template is loaded via ``_load_template`` and rendered with the
    module's ``TOOLS``, a generation prompt, ``<bos>`` as the BOS token,
    and thinking disabled.
    """
    render_context = {
        "messages": messages,
        "tools": TOOLS,
        "add_generation_prompt": True,
        "bos_token": "<bos>",
        "enable_thinking": False,
    }
    template = _load_template()
    return template.render(**render_context)
|
||||
|
||||
|
||||
async def run(
    *,
    ollama_url: str,
    model: str,
    task_prompt: str,
    num_ctx: int,
    num_predict: int,
    step_budget: int,
) -> dict[str, Any]:
    """Run one task through Ollama's /api/generate in raw, native-token mode.

    Each step renders the message list with the chat template, posts the
    prompt with ``raw: true``, and parses native tool calls out of the
    completion with ``_TOOL_CALL_RE``. Parsed calls are executed via
    ``execute_tool_stub`` and fed back as ``tool`` messages.

    Args:
        ollama_url: Base URL of the Ollama server (no trailing path).
        model: Ollama model tag to query.
        task_prompt: The user request appended after the fake history.
        num_ctx: Forwarded as the ``num_ctx`` option.
        num_predict: Forwarded as the ``num_predict`` option.
        step_budget: Maximum number of request/response steps.

    Returns:
        A trace dict with per-step ``turns`` and a ``final`` summary whose
        ``halt_reason`` is ``"no_tool_calls"``, ``"step_budget"``, or
        ``"error: ..."`` for transport failures and Ollama error replies.
    """
    messages: list[dict[str, Any]] = [{"role": "system", "content": SYSTEM_PROMPT}] + list(FAKE_HISTORY)
    messages.append({"role": "user", "content": f"[2026-04-18 14:20] @seth:sethpc.xyz: {task_prompt}"})

    trace: dict[str, Any] = {
        "arm": "ollama-native",
        "model": model,
        "num_ctx": num_ctx,
        "num_predict": num_predict,
        "started_at": time.time(),
        "turns": [],
        "final": None,
    }

    tool_call_total = 0
    halt: str | None = None

    async with aiohttp.ClientSession() as session:
        for step in range(1, step_budget + 1):
            t0 = time.time()
            prompt = _render(messages)
            payload = {
                "model": model,
                "prompt": prompt,
                "raw": True,
                "stream": False,
                "options": {
                    "num_ctx": num_ctx,
                    "num_predict": num_predict,
                    "temperature": 0.7, "top_p": 0.95, "top_k": 64,
                    # Stop at either end-of-turn (final answer) or end-of-tool-call.
                    # "<tool_call|>" lets the regex match on the full call; we
                    # re-append "<tool_call|>" before parsing to keep the regex
                    # simple. "<turn|>" catches a clean final answer.
                    "stop": ["<turn|>", "<tool_call|>"],
                },
                "keep_alive": "2h",
            }
            try:
                async with session.post(
                    f"{ollama_url}/api/generate", json=payload,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as resp:
                    status = resp.status
                    r = await resp.json()
            except Exception as e:
                halt = f"error: {e}"
                trace["turns"].append({"step": step, "error": str(e)})
                break

            # Ollama reports failures (unknown model, bad request, ...) as a
            # JSON body carrying "error". Without this check such a reply
            # parses as an empty completion and the run would be logged as a
            # clean "no_tool_calls" finish.
            if isinstance(r, dict):
                api_error = r.get("error")
            else:
                api_error = "non-object JSON response"
            if api_error or status >= 400:
                detail = str(api_error or f"HTTP {status}")
                halt = f"error: {detail}"
                trace["turns"].append({"step": step, "error": detail})
                break

            completion = r.get("response", "") or ""
            stop_reason_native = r.get("done_reason") or r.get("stop_reason") or ""

            # Rebuild the full assistant turn. Ollama's /api/generate
            # strips the matched stop token from the response, so we
            # always re-append based on which open token is present.
            # An unclosed `<|tool_call>` means the model was emitting a
            # tool call when the stop token fired; otherwise the model
            # was producing a final text turn.
            if completion.rstrip().endswith(("<tool_call|>", "<turn|>")):
                full = completion
            elif "<|tool_call>" in completion and "<tool_call|>" not in completion:
                full = completion + "<tool_call|>"
            else:
                full = completion + "<turn|>"

            matches = list(_TOOL_CALL_RE.finditer(full))
            history_chars = sum(len(m.get("content", "") or "") for m in messages)

            trace["turns"].append({
                "step": step,
                "elapsed_s": round(time.time() - t0, 2),
                "prompt_eval_count": r.get("prompt_eval_count"),
                "eval_count": r.get("eval_count"),
                "content_len": len(completion),
                "tool_call_count": len(matches),
                "stop_reason": stop_reason_native,
                "history_chars_before_append": history_chars,
                "raw_completion_head": completion[:240],
                "raw_completion_tail": completion[-240:] if len(completion) > 240 else "",
                "prompt_tail": prompt[-400:],
                "prompt_head": prompt[:200],
            })

            if not matches:
                # Final answer — take the text minus any trailing <turn|>.
                content = full.replace("<turn|>", "").strip()
                messages.append({"role": "assistant", "content": content})
                halt = "no_tool_calls"
                break

            # Build an assistant message with tool_calls (OpenAI shape) so the
            # jinja template re-renders them correctly on the next iteration.
            tool_calls_msg: list[dict[str, Any]] = []
            for m in matches:
                name = m.group("name")
                args = _parse_native_args(m.group("body"))
                tool_calls_msg.append({
                    "id": f"call_{step}_{len(tool_calls_msg)}",
                    "function": {"name": name, "arguments": args},
                })
            # Content MUST be empty when the message has tool_calls + will
            # have tool_responses inlined on next render. The jinja
            # template's post-turn conditional checks message.get('content')
            # before strip_thinking and any non-empty string (even a bare
            # <|channel>thought\n<channel|> prefix from the model) causes
            # a spurious <turn|>\n to be appended after <tool_response|>,
            # which breaks turn continuation on the following step.
            messages.append({
                "role": "assistant",
                "content": "",
                "tool_calls": tool_calls_msg,
            })

            tool_call_total += len(tool_calls_msg)
            for tc in tool_calls_msg:
                fn = tc["function"]
                result = execute_tool_stub(fn["name"], fn["arguments"])
                messages.append({
                    "role": "tool",
                    "tool_call_id": tc["id"],
                    "name": fn["name"],
                    "content": result,
                })

            if step == step_budget:
                halt = "step_budget"
                break

    trace["final"] = {
        "halt_reason": halt,
        "steps_used": len(trace["turns"]),
        "tool_calls_total": tool_call_total,
        "wall_clock_s": round(time.time() - trace["started_at"], 2),
        "final_message_count": len(messages),
        "final_history_chars": sum(len(m.get("content", "") or "") for m in messages),
    }
    return trace
|
||||
Reference in New Issue
Block a user