Fix training script: bf16 for Ampere GPU, add system prompts to training data

- Switch fp16 to bf16 (RTX 3090 Ti is Ampere, supports BF16 natively)
- Include system prompt in training conversations (mode-aware: sudo/god/god_system)
- Include message field only for god modes
- Add determine_mode() and get_system_prompt() helpers

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-18 16:26:47 -04:00
parent 78031d16c0
commit 142e4fd3c4
+41 -3
View File
@@ -18,8 +18,39 @@ import os
from pathlib import Path from pathlib import Path
def determine_mode(example: dict) -> str:
    """Determine the prompt mode for a training example.

    Rules are checked in order:

    1. The user message starts with ``"pray "`` (case-insensitive) -> ``"god"``.
    2. The example id starts with ``"negative-"`` and the user message
       contains ``"god"`` (case-insensitive) -> ``"god_system"``.
    3. The example's ``"source"`` is ``"prayer_log"`` -> ``"god"``.
    4. Anything else -> ``"sudo"``.

    Args:
        example: One dataset record. Must contain
            ``example["input"]["user_message"]``; ``"id"`` and ``"source"``
            are optional.

    Returns:
        One of ``"god"``, ``"god_system"`` or ``"sudo"``.

    Raises:
        KeyError: If ``"input"`` or ``"user_message"`` is missing.
    """
    # Lowercase once; both text rules are case-insensitive.
    query = example["input"]["user_message"].lower()
    # ``.get("id", "")`` still yields None when the key is present with a null
    # value (and ids may be numeric), so coerce before using str methods.
    eid = str(example.get("id") or "")

    if query.startswith("pray "):
        return "god"
    if eid.startswith("negative-") and "god" in query:
        return "god_system"
    if example.get("source") == "prayer_log":
        return "god"
    return "sudo"
def get_system_prompt(mode: str) -> str:
    """Return the system prompt to use for training in the given mode.

    Tries to import ``get_prompt`` from the project's
    ``agent.prompts.system_prompts`` module, after making the directory two
    levels above this script's directory importable. If that import fails,
    a minimal inline prompt is returned instead.

    Args:
        mode: Prompt mode. ``"god"`` and ``"god_system"`` have dedicated
            fallback prompts; any other value gets the command-translator
            fallback.

    Returns:
        The system prompt text.
    """
    import sys

    try:
        script_dir = Path(__file__).resolve().parent
        project_root = str(script_dir.parent.parent)
        # This function is called once per training example; guard the insert
        # so sys.path does not gain a duplicate entry on every call.
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        from agent.prompts.system_prompts import get_prompt
        return get_prompt(mode)
    except ImportError:
        # Minimal fallback prompts
        if mode == "god":
            return "You are God in a Minecraft server. Return JSON: {\"message\": \"...\", \"commands\": [...], \"reasoning\": \"...\"}"
        elif mode == "god_system":
            return "You are God performing an unprompted intervention. Return JSON: {\"message\": \"...\", \"commands\": [...]}"
        return "You are a Minecraft 1.21 command translator. Return JSON: {\"commands\": [...], \"reasoning\": \"...\"}"
def load_dataset(path: str) -> list: def load_dataset(path: str) -> list:
"""Load seed dataset and format for SFT training.""" """Load seed dataset and format for SFT training with system prompts and mode awareness."""
examples = [] examples = []
with open(path) as f: with open(path) as f:
for line in f: for line in f:
@@ -27,6 +58,10 @@ def load_dataset(path: str) -> list:
continue continue
ex = json.loads(line) ex = json.loads(line)
# Determine mode and get system prompt
mode = determine_mode(ex)
system_prompt = get_system_prompt(mode)
# Build the training conversation # Build the training conversation
inp = ex["input"] inp = ex["input"]
out = ex["output"] out = ex["output"]
@@ -48,11 +83,14 @@ def load_dataset(path: str) -> list:
response = { response = {
"reasoning": out.get("reasoning", ""), "reasoning": out.get("reasoning", ""),
"commands": out.get("commands", []), "commands": out.get("commands", []),
"message": out.get("message"),
} }
# Include message field for god modes
if mode in ("god", "god_system"):
response["message"] = out.get("message") or ""
examples.append({ examples.append({
"conversations": [ "conversations": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_msg}, {"role": "user", "content": user_msg},
{"role": "assistant", "content": json.dumps(response)}, {"role": "assistant", "content": json.dumps(response)},
] ]
@@ -158,7 +196,7 @@ def main():
lr_scheduler_type="cosine", lr_scheduler_type="cosine",
warmup_ratio=0.1, warmup_ratio=0.1,
weight_decay=0.01, weight_decay=0.01,
fp16=True, bf16=True,
logging_steps=1, logging_steps=1,
save_strategy="epoch", save_strategy="epoch",
seed=42, seed=42,