diff --git a/training/scripts/self_play.py b/training/scripts/self_play.py index d63e03b..f0c33d0 100644 --- a/training/scripts/self_play.py +++ b/training/scripts/self_play.py @@ -42,6 +42,9 @@ sys.path.insert(0, str(ROOT)) OUTPUT = ROOT / "data" / "processed" / "self_play.jsonl" +# Module-level API key, set from args in main() +_API_KEY = None + # --- RCON (persistent connection) --- from agent.tools.persistent_rcon import get_rcon @@ -72,8 +75,8 @@ def rcon_command(cmd, host, port, password): # --- LLM calls --- -def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None): - """Call Ollama and return content with think blocks stripped.""" +def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None, api_key=None): + """Call Ollama (or gateway) and return content with think blocks stripped.""" payload = { "model": model, "messages": [ @@ -85,7 +88,10 @@ def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, f } if fmt: payload["format"] = fmt - r = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=120) + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + r = requests.post(f"{ollama_url}/api/chat", json=payload, headers=headers, timeout=120) r.raise_for_status() content = r.json()["message"]["content"] # Strip think blocks @@ -345,7 +351,7 @@ def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass, # First attempt try: - raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json") + raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json", api_key=_API_KEY) result = json.loads(raw) except (json.JSONDecodeError, Exception) as e: # Try extracting JSON @@ -401,7 +407,7 @@ def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass, retry_prompt = f"Original request: {prompt}\n\nFailed commands:\n{error_context}\n\nPlease fix the commands." try: - raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json") + raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json", api_key=_API_KEY) result = json.loads(raw) except: match = re.search(r'\{[\s\S]*\}', raw if 'raw' in dir() else '') @@ -542,6 +548,7 @@ def main(): parser = argparse.ArgumentParser(description="Self-play training data generator") parser.add_argument("--model", default="qwen3-8b-mc-lora-v3") parser.add_argument("--ollama-url", default="http://192.168.0.141:11434") + parser.add_argument("--api-key", default=None, help="API key for authenticated gateways") parser.add_argument("--rcon-host", default="192.168.0.244") parser.add_argument("--rcon-port", type=int, default=25578) parser.add_argument("--rcon-pass", default="REDACTED_RCON") @@ -553,6 +560,9 @@ def main(): parser.add_argument("--max-retries", type=int, default=2) args = parser.parse_args() + global _API_KEY + _API_KEY = args.api_key + tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)] print(f"Self-play training data generator")