diff --git a/training/scripts/train_lora.py b/training/scripts/train_lora.py index c2d5a58..152e408 100644 --- a/training/scripts/train_lora.py +++ b/training/scripts/train_lora.py @@ -18,8 +18,39 @@ import os from pathlib import Path +def determine_mode(example: dict) -> str: + """Determine prompt mode from the example.""" + query = example["input"]["user_message"] + eid = example.get("id", "") + if query.lower().startswith("pray "): + return "god" + elif eid.startswith("negative-") and "god" in query.lower(): + return "god_system" + elif example.get("source") == "prayer_log": + return "god" + return "sudo" + + +def get_system_prompt(mode: str) -> str: + """Get the system prompt for training. Import from project if available, fallback to inline.""" + try: + import sys + script_dir = Path(__file__).resolve().parent + project_root = script_dir.parent.parent + sys.path.insert(0, str(project_root)) + from agent.prompts.system_prompts import get_prompt + return get_prompt(mode) + except ImportError: + # Minimal fallback prompts + if mode == "god": + return "You are God in a Minecraft server. Return JSON: {\"message\": \"...\", \"commands\": [...], \"reasoning\": \"...\"}" + elif mode == "god_system": + return "You are God performing an unprompted intervention. Return JSON: {\"message\": \"...\", \"commands\": [...]}" + return "You are a Minecraft 1.21 command translator. Return JSON: {\"commands\": [...], \"reasoning\": \"...\"}" + + def load_dataset(path: str) -> list: - """Load seed dataset and format for SFT training.""" + """Load seed dataset and format for SFT training with system prompts and mode awareness.""" examples = [] with open(path) as f: for line in f: @@ -27,6 +58,10 @@ def load_dataset(path: str) -> list: continue ex = json.loads(line) + # Determine mode and get system prompt + mode = determine_mode(ex) + system_prompt = get_system_prompt(mode) + # Build the training conversation inp = ex["input"] out = ex["output"] @@ -48,11 +83,14 @@ def load_dataset(path: str) -> list: response = { "reasoning": out.get("reasoning", ""), "commands": out.get("commands", []), - "message": out.get("message"), } + # Include message field for god modes + if mode in ("god", "god_system"): + response["message"] = out.get("message") or "" examples.append({ "conversations": [ + {"role": "system", "content": system_prompt}, {"role": "user", "content": user_msg}, {"role": "assistant", "content": json.dumps(response)}, ] @@ -158,7 +196,7 @@ def main(): lr_scheduler_type="cosine", warmup_ratio=0.1, weight_decay=0.01, - fp16=True, + bf16=True, logging_steps=1, save_strategy="epoch", seed=42,