Fix training script: bf16 for Ampere GPU, add system prompts to training data
- Switch fp16 to bf16 (RTX 3090 Ti is Ampere, supports BF16 natively) - Include system prompt in training conversations (mode-aware: sudo/god/god_system) - Include message field only for god modes - Add determine_mode() and get_system_prompt() helpers Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -18,8 +18,39 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def determine_mode(example: dict) -> str:
|
||||||
|
"""Determine prompt mode from the example."""
|
||||||
|
query = example["input"]["user_message"]
|
||||||
|
eid = example.get("id", "")
|
||||||
|
if query.lower().startswith("pray "):
|
||||||
|
return "god"
|
||||||
|
elif eid.startswith("negative-") and "god" in query.lower():
|
||||||
|
return "god_system"
|
||||||
|
elif example.get("source") == "prayer_log":
|
||||||
|
return "god"
|
||||||
|
return "sudo"
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_prompt(mode: str) -> str:
|
||||||
|
"""Get the system prompt for training. Import from project if available, fallback to inline."""
|
||||||
|
try:
|
||||||
|
import sys
|
||||||
|
script_dir = Path(__file__).resolve().parent
|
||||||
|
project_root = script_dir.parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
from agent.prompts.system_prompts import get_prompt
|
||||||
|
return get_prompt(mode)
|
||||||
|
except ImportError:
|
||||||
|
# Minimal fallback prompts
|
||||||
|
if mode == "god":
|
||||||
|
return "You are God in a Minecraft server. Return JSON: {\"message\": \"...\", \"commands\": [...], \"reasoning\": \"...\"}"
|
||||||
|
elif mode == "god_system":
|
||||||
|
return "You are God performing an unprompted intervention. Return JSON: {\"message\": \"...\", \"commands\": [...]}"
|
||||||
|
return "You are a Minecraft 1.21 command translator. Return JSON: {\"commands\": [...], \"reasoning\": \"...\"}"
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(path: str) -> list:
|
def load_dataset(path: str) -> list:
|
||||||
"""Load seed dataset and format for SFT training."""
|
"""Load seed dataset and format for SFT training with system prompts and mode awareness."""
|
||||||
examples = []
|
examples = []
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
@@ -27,6 +58,10 @@ def load_dataset(path: str) -> list:
|
|||||||
continue
|
continue
|
||||||
ex = json.loads(line)
|
ex = json.loads(line)
|
||||||
|
|
||||||
|
# Determine mode and get system prompt
|
||||||
|
mode = determine_mode(ex)
|
||||||
|
system_prompt = get_system_prompt(mode)
|
||||||
|
|
||||||
# Build the training conversation
|
# Build the training conversation
|
||||||
inp = ex["input"]
|
inp = ex["input"]
|
||||||
out = ex["output"]
|
out = ex["output"]
|
||||||
@@ -48,11 +83,14 @@ def load_dataset(path: str) -> list:
|
|||||||
response = {
|
response = {
|
||||||
"reasoning": out.get("reasoning", ""),
|
"reasoning": out.get("reasoning", ""),
|
||||||
"commands": out.get("commands", []),
|
"commands": out.get("commands", []),
|
||||||
"message": out.get("message"),
|
|
||||||
}
|
}
|
||||||
|
# Include message field for god modes
|
||||||
|
if mode in ("god", "god_system"):
|
||||||
|
response["message"] = out.get("message") or ""
|
||||||
|
|
||||||
examples.append({
|
examples.append({
|
||||||
"conversations": [
|
"conversations": [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": user_msg},
|
{"role": "user", "content": user_msg},
|
||||||
{"role": "assistant", "content": json.dumps(response)},
|
{"role": "assistant", "content": json.dumps(response)},
|
||||||
]
|
]
|
||||||
@@ -158,7 +196,7 @@ def main():
|
|||||||
lr_scheduler_type="cosine",
|
lr_scheduler_type="cosine",
|
||||||
warmup_ratio=0.1,
|
warmup_ratio=0.1,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
fp16=True,
|
bf16=True,
|
||||||
logging_steps=1,
|
logging_steps=1,
|
||||||
save_strategy="epoch",
|
save_strategy="epoch",
|
||||||
seed=42,
|
seed=42,
|
||||||
|
|||||||
Reference in New Issue
Block a user