Self-play: --api-key for authenticated gateway connections
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -42,6 +42,9 @@ sys.path.insert(0, str(ROOT))
|
||||
|
||||
OUTPUT = ROOT / "data" / "processed" / "self_play.jsonl"
|
||||
|
||||
# Module-level API key, set from args in main()
|
||||
_API_KEY = None
|
||||
|
||||
# --- RCON (persistent connection) ---
|
||||
|
||||
from agent.tools.persistent_rcon import get_rcon
|
||||
@@ -72,8 +75,8 @@ def rcon_command(cmd, host, port, password):
|
||||
|
||||
# --- LLM calls ---
|
||||
|
||||
def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None):
|
||||
"""Call Ollama and return content with think blocks stripped."""
|
||||
def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, fmt=None, api_key=None):
|
||||
"""Call Ollama (or gateway) and return content with think blocks stripped."""
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
@@ -85,7 +88,10 @@ def llm_call(model, system, user, ollama_url, temperature=0.7, max_tokens=500, f
|
||||
}
|
||||
if fmt:
|
||||
payload["format"] = fmt
|
||||
r = requests.post(f"{ollama_url}/api/chat", json=payload, timeout=120)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
r = requests.post(f"{ollama_url}/api/chat", json=payload, headers=headers, timeout=120)
|
||||
r.raise_for_status()
|
||||
content = r.json()["message"]["content"]
|
||||
# Strip think blocks
|
||||
@@ -345,7 +351,7 @@ def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass,
|
||||
|
||||
# First attempt
|
||||
try:
|
||||
raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json")
|
||||
raw = llm_call(model, system, prompt, ollama_url, temperature=0.3, max_tokens=300, fmt="json", api_key=_API_KEY)
|
||||
result = json.loads(raw)
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
# Try extracting JSON
|
||||
@@ -401,7 +407,7 @@ def attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass,
|
||||
retry_prompt = f"Original request: {prompt}\n\nFailed commands:\n{error_context}\n\nPlease fix the commands."
|
||||
|
||||
try:
|
||||
raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json")
|
||||
raw = llm_call(model, RETRY_SYSTEM, retry_prompt, ollama_url, temperature=0.2, max_tokens=300, fmt="json", api_key=_API_KEY)
|
||||
result = json.loads(raw)
|
||||
except:
|
||||
match = re.search(r'\{[\s\S]*\}', raw if 'raw' in dir() else '')
|
||||
@@ -542,6 +548,7 @@ def main():
|
||||
parser = argparse.ArgumentParser(description="Self-play training data generator")
|
||||
parser.add_argument("--model", default="qwen3-8b-mc-lora-v3")
|
||||
parser.add_argument("--ollama-url", default="http://192.168.0.141:11434")
|
||||
parser.add_argument("--api-key", default=None, help="API key for authenticated gateways")
|
||||
parser.add_argument("--rcon-host", default="192.168.0.244")
|
||||
parser.add_argument("--rcon-port", type=int, default=25578)
|
||||
parser.add_argument("--rcon-pass", default="REDACTED_RCON")
|
||||
@@ -553,6 +560,9 @@ def main():
|
||||
parser.add_argument("--max-retries", type=int, default=2)
|
||||
args = parser.parse_args()
|
||||
|
||||
global _API_KEY
|
||||
_API_KEY = args.api_key
|
||||
|
||||
tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)]
|
||||
|
||||
print(f"Self-play training data generator")
|
||||
|
||||
Reference in New Issue
Block a user