df5542f7d6
Three-arm harness under scripts/native-bakeoff/:
- arm A: /api/chat with JSON tools (current default)
- arm B: /api/generate raw:true with the canonical HF jinja template rendered directly
- arm C: google-deepmind/gemma JAX ToolSampler (env-gated, JAX required)

Interim finding from the A+B sweep on matt-strix gemma4:26b Q4: Ollama's bidirectional JSON↔native tool-call translator is faithful. The "long" multi-tool task produces identical behavior (7 steps / 6 tools) on both arms.

An earlier arm-B result that looked like a divergence was a harness parser bug. Preserving the model's <|channel>thought\n<channel|> prefix as assistant content tripped the jinja template's tool_response-following conditional. That appended a spurious <turn|>\n, which corrupted the next step's prompt. Fixed by dropping the channel prefix from the assistant message.

Arm C is scaffolded but not yet run. The JAX/bf16 reference path would answer "does the GGUF runtime diverge from DeepMind's implementation?", but it requires a separate env with the `gemma` PyPI package. Parked pending SDXL eviction or a vast-h100 session.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
276 lines
9.8 KiB
Python
276 lines
9.8 KiB
Python
"""Arm B: Ollama /api/generate with raw:true and native Gemma 4 tokens.

Renders the canonical HF jinja chat template directly, sends the
resulting string to Ollama's /api/generate with `raw: true` (which
bypasses Ollama's own templating / BOS handling), and parses
<|tool_call>call:NAME{args}<tool_call|> out of the completion with a
regex.

The point of this arm: isolate Ollama's tool-call translation. Arm A lets
Ollama's server translate OpenAI-shaped JSON tools into native tokens
AND translate the model's native <|tool_call> output back into
structured `tool_calls`. Arm B keeps everything native end-to-end and
only uses Ollama as a thin completion engine. If A and B diverge, the
delta lives in Ollama's bidirectional JSON↔native translator.

Template source: tooling/huggingface/model-cards/gemma-4-E4B-it-chat_template.jinja
"""
|
|
|
|
from __future__ import annotations

import asyncio
import functools
import json
import re
import time
from pathlib import Path
from typing import Any

import aiohttp
import jinja2

from tasks import SYSTEM_PROMPT, TOOLS, FAKE_HISTORY, execute_tool_stub
|
|
|
|
|
|
_REPO_ROOT = Path(__file__).resolve().parents[3]
|
|
_TEMPLATE_PATH = _REPO_ROOT / "tooling" / "huggingface" / "model-cards" / "gemma-4-E4B-it-chat_template.jinja"
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_template() -> jinja2.Template:
    """Compile the canonical Gemma 4 chat template the way HF transformers does.

    Plain jinja defaults do NOT reproduce ``tokenizer.apply_chat_template``.
    transformers compiles chat templates in an ``ImmutableSandboxedEnvironment``
    with ``trim_blocks`` and ``lstrip_blocks`` enabled and the
    ``loopcontrols`` extension loaded. It also installs:

    - a ``tojson`` filter backed by ``json.dumps``. Jinja's builtin sorts
      keys and HTML-escapes ``<``, ``>``, ``&`` and ``'``.
    - ``raise_exception`` and ``strftime_now`` globals.

    Matching that environment keeps this arm's prompt byte-identical to what
    the HF tooling would render. HF's ``{% generation %}`` tracking extension
    is not mirrored.

    The compiled template is cached: ``_render`` runs once per agent step and
    the template file is not expected to change during a run.

    Returns:
        The compiled template, ready for ``.render(...)``.
    """
    from datetime import datetime

    from jinja2.sandbox import ImmutableSandboxedEnvironment

    def raise_exception(message: str) -> None:
        raise jinja2.TemplateError(message)

    def tojson(
        x: Any,
        ensure_ascii: bool = False,
        indent: Any = None,
        separators: Any = None,
        sort_keys: bool = False,
    ) -> str:
        return json.dumps(
            x,
            ensure_ascii=ensure_ascii,
            indent=indent,
            separators=separators,
            sort_keys=sort_keys,
        )

    def strftime_now(fmt: str) -> str:
        return datetime.now().strftime(fmt)

    env = ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        # Import-path string avoids needing a separate `import jinja2.ext`.
        extensions=["jinja2.ext.loopcontrols"],
    )
    env.filters["tojson"] = tojson
    env.globals["raise_exception"] = raise_exception
    env.globals["strftime_now"] = strftime_now
    # The template contains non-ASCII-safe markup; don't depend on the locale.
    return env.from_string(_TEMPLATE_PATH.read_text(encoding="utf-8"))
|
|
|
|
|
|
_TOOL_CALL_RE = re.compile(
|
|
r"<\|tool_call>call:(?P<name>\w+)\{(?P<body>.*?)\}<tool_call\|>",
|
|
re.DOTALL,
|
|
)
|
|
|
|
|
|
def _parse_native_args(body: str) -> dict[str, Any]:
|
|
"""Parse the body of a <|tool_call>call:NAME{...}<tool_call|>.
|
|
|
|
Gemma 4 native arg format (from the jinja template's format_argument
|
|
macro with escape_keys=False):
|
|
- key:<|"|>stringval<|"|>
|
|
- key:123
|
|
- key:true / key:false
|
|
- key:{nested:...} (for mapping args — not used by our stubs)
|
|
- key:[<|"|>item<|"|>,...] (for array args — not used by our stubs)
|
|
|
|
Our stub tool schemas are flat (string / integer / bool), so a
|
|
simple top-level comma split is enough. If a future tool needs
|
|
nested args this needs depth-aware splitting.
|
|
"""
|
|
out: dict[str, Any] = {}
|
|
if not body:
|
|
return out
|
|
|
|
# Top-level comma split, respecting only the `<|"|>...<|"|>` string
|
|
# delimiter (since our tool args don't nest). This intentionally
|
|
# doesn't handle {...} or [...] — flag it with a log entry in the
|
|
# harness if a future tool needs those.
|
|
parts: list[str] = []
|
|
buf = ""
|
|
i = 0
|
|
str_delim = '<|"|>'
|
|
in_str = False
|
|
while i < len(body):
|
|
if body[i : i + len(str_delim)] == str_delim:
|
|
in_str = not in_str
|
|
buf += str_delim
|
|
i += len(str_delim)
|
|
continue
|
|
if body[i] == "," and not in_str:
|
|
parts.append(buf)
|
|
buf = ""
|
|
i += 1
|
|
continue
|
|
buf += body[i]
|
|
i += 1
|
|
if buf:
|
|
parts.append(buf)
|
|
|
|
for p in parts:
|
|
if ":" not in p:
|
|
continue
|
|
k, _, v = p.partition(":")
|
|
k = k.strip()
|
|
v = v.strip()
|
|
if v.startswith(str_delim) and v.endswith(str_delim):
|
|
out[k] = v[len(str_delim) : -len(str_delim)]
|
|
elif v == "true":
|
|
out[k] = True
|
|
elif v == "false":
|
|
out[k] = False
|
|
else:
|
|
try:
|
|
out[k] = int(v)
|
|
except ValueError:
|
|
try:
|
|
out[k] = float(v)
|
|
except ValueError:
|
|
out[k] = v
|
|
return out
|
|
|
|
|
|
def _render(messages: list[dict[str, Any]]) -> str:
    """Render ``messages`` plus the stub ``TOOLS`` into a raw prompt string.

    Uses the template from ``_load_template``. The keyword arguments:

    - ``add_generation_prompt=True`` asks the template to append the
      generation prompt.
    - ``bos_token`` is the literal ``"<bos>"`` text, so the prompt carries
      its own BOS marker.
    - ``enable_thinking=False``.
    """
    render_kwargs: dict[str, Any] = dict(
        messages=messages,
        tools=TOOLS,
        add_generation_prompt=True,
        bos_token="<bos>",
        enable_thinking=False,
    )
    template = _load_template()
    return template.render(**render_kwargs)
|
|
|
|
|
|
def _history_chars(messages: list[dict[str, Any]]) -> int:
    """Total length of every message's ``content`` field (missing/None counts as 0)."""
    return sum(len(m.get("content", "") or "") for m in messages)


async def run(
    *,
    ollama_url: str,
    model: str,
    task_prompt: str,
    num_ctx: int,
    num_predict: int,
    step_budget: int,
) -> dict[str, Any]:
    """Run one task through arm B (raw native-token completions) and return a trace.

    Each step does the following:

    1. Renders the whole conversation with the canonical chat template.
    2. Requests a completion from Ollama's ``/api/generate`` with
       ``raw: true``.
    3. Parses native ``<|tool_call>`` blocks out of the completion.
    4. Runs each call through ``execute_tool_stub``.
    5. Appends the calls and their results to the conversation.

    The loop stops on a completion with no parseable tool call, a failed
    request (transport error or an Ollama error response), or after
    ``step_budget`` steps.

    Args:
        ollama_url: Base URL of the Ollama server; a trailing ``/`` is ignored.
        model: Ollama model name.
        task_prompt: User request appended after ``FAKE_HISTORY``.
        num_ctx: Context window size passed in the request options.
        num_predict: Per-step generation limit passed in the request options.
        step_budget: Maximum number of completion requests.

    Returns:
        A trace dict with per-step ``turns`` entries and a ``final`` summary.
        ``final["halt_reason"]`` is one of:

        - ``"no_tool_calls"``
        - ``"step_budget"``
        - ``"error: ..."``
        - ``None``, only when ``step_budget < 1``.
    """
    # Avoid "//api/generate" when the caller passes a URL ending in "/".
    base_url = ollama_url.rstrip("/")
    messages: list[dict[str, Any]] = [{"role": "system", "content": SYSTEM_PROMPT}] + list(FAKE_HISTORY)
    messages.append({"role": "user", "content": f"[2026-04-18 14:20] @seth:sethpc.xyz: {task_prompt}"})

    trace: dict[str, Any] = {
        "arm": "ollama-native",
        "model": model,
        "num_ctx": num_ctx,
        "num_predict": num_predict,
        "started_at": time.time(),
        "turns": [],
        "final": None,
    }

    tool_call_total = 0
    halt: str | None = None

    async with aiohttp.ClientSession() as session:
        for step in range(1, step_budget + 1):
            # Monotonic clock for durations; started_at above stays wall time.
            t0 = time.perf_counter()
            prompt = _render(messages)
            payload = {
                "model": model,
                "prompt": prompt,
                "raw": True,
                "stream": False,
                "options": {
                    "num_ctx": num_ctx,
                    "num_predict": num_predict,
                    "temperature": 0.7,
                    "top_p": 0.95,
                    "top_k": 64,
                    # Stop at either end-of-turn (final answer) or
                    # end-of-tool-call. Ollama strips the matched stop string
                    # from "response", so "<tool_call|>" is re-appended below
                    # before regex parsing. Side effect: generation halts after
                    # the FIRST tool call, so this arm emits at most one tool
                    # call per assistant turn.
                    "stop": ["<turn|>", "<tool_call|>"],
                },
                "keep_alive": "2h",
            }
            try:
                async with session.post(
                    f"{base_url}/api/generate",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=300),
                ) as resp:
                    status = resp.status
                    r = await resp.json()
            except Exception as e:
                # Best-effort harness: record transport / timeout / decode
                # failures in the trace rather than crashing the sweep.
                halt = f"error: {e}"
                trace["turns"].append({"step": step, "error": str(e)})
                break

            # Ollama reports server-side failures (unknown model, bad options,
            # ...) as {"error": "..."} with a non-2xx status. Without this
            # check the missing "response" would be recorded as an empty,
            # clean final answer with halt_reason "no_tool_calls".
            if status >= 400 or not isinstance(r, dict) or r.get("error"):
                detail = r.get("error") if isinstance(r, dict) else None
                err = f"HTTP {status}: {detail or r!r}"
                halt = f"error: {err}"
                trace["turns"].append({"step": step, "error": err})
                break

            completion = r.get("response", "") or ""
            stop_reason_native = r.get("done_reason") or r.get("stop_reason") or ""

            # Rebuild the full assistant turn. Ollama's /api/generate
            # strips the matched stop token from the response, so we
            # always re-append based on which open token is present.
            # An unclosed `<|tool_call>` means the model was emitting a
            # tool call when the stop token fired; otherwise the model
            # was producing a final text turn.
            if completion.rstrip().endswith(("<tool_call|>", "<turn|>")):
                full = completion
            elif "<|tool_call>" in completion and "<tool_call|>" not in completion:
                full = completion + "<tool_call|>"
            else:
                full = completion + "<turn|>"

            matches = list(_TOOL_CALL_RE.finditer(full))

            trace["turns"].append({
                "step": step,
                "elapsed_s": round(time.perf_counter() - t0, 2),
                "prompt_eval_count": r.get("prompt_eval_count"),
                "eval_count": r.get("eval_count"),
                "content_len": len(completion),
                "tool_call_count": len(matches),
                "stop_reason": stop_reason_native,
                "history_chars_before_append": _history_chars(messages),
                "raw_completion_head": completion[:240],
                "raw_completion_tail": completion[-240:] if len(completion) > 240 else "",
                "prompt_tail": prompt[-400:],
                "prompt_head": prompt[:200],
            })

            if not matches:
                # Final answer — take the text minus any trailing <turn|>.
                # NOTE: a tool call cut off by num_predict (unbalanced braces)
                # also lands here; the turn's stop_reason tells them apart.
                content = full.replace("<turn|>", "").strip()
                messages.append({"role": "assistant", "content": content})
                halt = "no_tool_calls"
                break

            # Build an assistant message with tool_calls (OpenAI shape) so the
            # jinja template re-renders them correctly on the next iteration.
            tool_calls_msg: list[dict[str, Any]] = [
                {
                    "id": f"call_{step}_{idx}",
                    "function": {
                        "name": m.group("name"),
                        "arguments": _parse_native_args(m.group("body")),
                    },
                }
                for idx, m in enumerate(matches)
            ]
            # Content MUST be empty when the message has tool_calls and will
            # have tool_responses inlined on the next render. The jinja
            # template's post-turn conditional checks message.get('content')
            # before strip_thinking. Any non-empty string, even a bare
            # <|channel>thought\n<channel|> prefix from the model, causes a
            # spurious <turn|>\n to be appended after <tool_response|>, which
            # breaks turn continuation on the following step.
            messages.append({
                "role": "assistant",
                "content": "",
                "tool_calls": tool_calls_msg,
            })

            tool_call_total += len(tool_calls_msg)
            for tc in tool_calls_msg:
                fn = tc["function"]
                result = execute_tool_stub(fn["name"], fn["arguments"])
                messages.append({
                    "role": "tool",
                    "tool_call_id": tc["id"],
                    "name": fn["name"],
                    "content": result,
                })

            if step == step_budget:
                halt = "step_budget"
                break

    trace["final"] = {
        "halt_reason": halt,
        "steps_used": len(trace["turns"]),
        "tool_calls_total": tool_call_total,
        "wall_clock_s": round(time.time() - trace["started_at"], 2),
        "final_message_count": len(messages),
        "final_history_chars": _history_chars(messages),
    }
    return trace
|