3-tier self-play: command drills, self-critique, adversarial
Tier 1 — Command drills: Random seed prompts → generate commands → RCON validates Teaches: accurate command syntax Tier 2 — Single-shot self-critique: Model invents a tricky prompt AND responds in one call RCON validates the self-generated commands Teaches: edge-case awareness, self-evaluation Tier 3 — Adversarial self-play: Session A generates challenging prompts Fresh Session B responds cold (can't cheat) RCON validates, self-corrects on errors Teaches: robustness, generalization Usage: --tier 1|2|3|all --rounds N --focus category Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+256
-55
@@ -1,20 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
self_play.py — Self-play training data generator.
|
||||
self_play.py — Multi-tier self-play training data generator.
|
||||
|
||||
The fine-tuned model generates its own training data by:
|
||||
1. Generating diverse edge-case prompts it's uncertain about
|
||||
2. Attempting commands via RCON
|
||||
3. Self-correcting on errors
|
||||
4. Saving successful sequences as training examples
|
||||
Three tiers of self-play, each teaching different skills:
|
||||
|
||||
This creates a closed-loop learning system with RCON as ground truth.
|
||||
No API cost — runs entirely on the local model.
|
||||
Tier 1 — Command drills:
|
||||
Feed known prompts, execute commands via RCON, validate syntax.
|
||||
Teaches: accurate command generation.
|
||||
Usage: --tier 1 --rounds 50
|
||||
|
||||
Usage:
|
||||
python3 training/scripts/self_play.py --rounds 100 --model qwen3.5-9b-mc-v4
|
||||
python3 training/scripts/self_play.py --rounds 50 --dry-run
|
||||
python3 training/scripts/self_play.py --rounds 200 --focus enchantments
|
||||
Tier 2 — Single-shot self-critique:
|
||||
Model generates BOTH the prompt AND response in one call.
|
||||
Teaches: edge-case awareness, self-evaluation.
|
||||
Usage: --tier 2 --rounds 50
|
||||
|
||||
Tier 3 — Adversarial self-play:
|
||||
Session A generates a challenging prompt. Fresh Session B responds.
|
||||
RCON validates. Model can't cheat by knowing both sides.
|
||||
Teaches: robustness, generalization, error correction.
|
||||
Usage: --tier 3 --rounds 50
|
||||
|
||||
All tiers:
|
||||
--tier all --rounds 50 (runs ~17 rounds of each)
|
||||
|
||||
No API cost — runs entirely on the local model with RCON as ground truth.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -182,6 +191,120 @@ Analyze the error and return a corrected command.
|
||||
Return JSON: {"commands": ["corrected_cmd"], "reasoning": "what was wrong and how you fixed it"}"""
|
||||
|
||||
|
||||
# --- Tier 1: Command drills ---
|
||||
|
||||
def run_tier1_drill(model, ollama_url, rcon_host, rcon_port, rcon_pass, max_retries=2):
|
||||
"""Pick a random prompt from seed dataset, generate commands, validate via RCON."""
|
||||
seed_path = ROOT / "data" / "processed" / "seed_dataset.jsonl"
|
||||
with open(seed_path) as f:
|
||||
lines = [l for l in f if l.strip()]
|
||||
line = random.choice(lines)
|
||||
ex = json.loads(line)
|
||||
prompt = ex["input"]["user_message"]
|
||||
# Only drill command_gen examples
|
||||
if ex.get("category") not in ("command_gen",):
|
||||
return None
|
||||
|
||||
trace = attempt_command(model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass, max_retries)
|
||||
trace["tier"] = 1
|
||||
trace["original_commands"] = ex.get("output", {}).get("commands", [])
|
||||
return trace
|
||||
|
||||
|
||||
# --- Tier 2: Single-shot self-critique ---
|
||||
|
||||
SELF_CRITIQUE_SYSTEM = """You are a Minecraft 1.21 AI training data generator AND command translator.
|
||||
|
||||
Your task: generate a challenging Minecraft player request, then respond to it yourself.
|
||||
Focus on edge cases you might get wrong: unusual items, complex enchantments, execute chains, ambiguous phrasing.
|
||||
|
||||
Return JSON:
|
||||
{
|
||||
"generated_prompt": "the player request you invented (must start with 'sudo ' or 'pray ')",
|
||||
"difficulty": "what makes this tricky",
|
||||
"commands": ["cmd1", "cmd2"],
|
||||
"reasoning": "why these commands are correct",
|
||||
"message": "God message if pray, empty string if sudo"
|
||||
}
|
||||
|
||||
Commands use minecraft: prefix. Enchantments: item[enchantments={name:level}].
|
||||
Effects: effect give <player> minecraft:<effect> <seconds> <amplifier>.
|
||||
Player: slingshooter08. Do NOT start commands with /."""
|
||||
|
||||
|
||||
def run_tier2_selfcritique(model, ollama_url, rcon_host, rcon_port, rcon_pass, category=None):
|
||||
"""Model generates a prompt AND responds in one shot, then RCON validates."""
|
||||
focus = ""
|
||||
if category:
|
||||
focus = f"\nFocus area: {category}. Generate a prompt specifically testing {category}."
|
||||
|
||||
try:
|
||||
raw = llm_call(
|
||||
model=model,
|
||||
system=SELF_CRITIQUE_SYSTEM + focus,
|
||||
user="Generate one challenging Minecraft request and your response. Be creative — pick something you might get wrong.",
|
||||
ollama_url=ollama_url,
|
||||
temperature=0.9,
|
||||
max_tokens=500,
|
||||
fmt="json",
|
||||
)
|
||||
result = json.loads(raw)
|
||||
except:
|
||||
match = re.search(r'\{[\s\S]*\}', raw if 'raw' in dir() else '')
|
||||
if match:
|
||||
try:
|
||||
result = json.loads(match.group())
|
||||
except:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
prompt = result.get("generated_prompt", "")
|
||||
commands = result.get("commands") or []
|
||||
message = result.get("message") or ""
|
||||
reasoning = result.get("reasoning") or ""
|
||||
difficulty = result.get("difficulty") or ""
|
||||
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
trace = {
|
||||
"prompt": prompt,
|
||||
"mode": "god" if prompt.lower().startswith("pray ") else "sudo",
|
||||
"tier": 2,
|
||||
"difficulty_note": difficulty,
|
||||
"attempts": [],
|
||||
"final_success": False,
|
||||
"self_corrected": False,
|
||||
}
|
||||
|
||||
if not commands:
|
||||
trace["attempts"].append({
|
||||
"commands": [], "reasoning": reasoning, "message": message,
|
||||
"rcon_results": [], "all_success": True,
|
||||
})
|
||||
trace["final_success"] = True
|
||||
return trace
|
||||
|
||||
# Validate via RCON
|
||||
rcon_results = []
|
||||
all_success = True
|
||||
for cmd in commands:
|
||||
success, rcon_result = rcon_command(cmd, rcon_host, rcon_port, rcon_pass)
|
||||
rcon_results.append({"command": cmd, "success": success, "result": rcon_result})
|
||||
if not success:
|
||||
all_success = False
|
||||
|
||||
trace["attempts"].append({
|
||||
"commands": commands, "reasoning": reasoning, "message": message,
|
||||
"rcon_results": rcon_results, "all_success": all_success,
|
||||
})
|
||||
trace["final_success"] = all_success
|
||||
return trace
|
||||
|
||||
|
||||
# --- Tier 3: Adversarial self-play (original generate_prompts + attempt_command) ---
|
||||
|
||||
def generate_prompts(model, ollama_url, category=None):
|
||||
"""Use the model to generate edge-case prompts for itself."""
|
||||
if category:
|
||||
@@ -436,16 +559,20 @@ def main():
|
||||
parser.add_argument("--rcon-port", type=int, default=25578)
|
||||
parser.add_argument("--rcon-pass", default="REDACTED_RCON")
|
||||
parser.add_argument("--rounds", type=int, default=50)
|
||||
parser.add_argument("--tier", default="all", choices=["1", "2", "3", "all"])
|
||||
parser.add_argument("--focus", default=None, choices=list(EXPLORATION_CATEGORIES.keys()))
|
||||
parser.add_argument("--output", default=str(OUTPUT))
|
||||
parser.add_argument("--dry-run", action="store_true")
|
||||
parser.add_argument("--max-retries", type=int, default=2)
|
||||
args = parser.parse_args()
|
||||
|
||||
tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)]
|
||||
|
||||
print(f"Self-play training data generator")
|
||||
print(f" Model: {args.model}")
|
||||
print(f" RCON: {args.rcon_host}:{args.rcon_port}")
|
||||
print(f" Rounds: {args.rounds}")
|
||||
print(f" Tiers: {tiers}")
|
||||
print(f" Focus: {args.focus or 'all categories'}")
|
||||
print(f" Max retries: {args.max_retries}")
|
||||
print(f" Output: {args.output}")
|
||||
@@ -454,60 +581,129 @@ def main():
|
||||
stats = {
|
||||
"rounds": 0, "prompts_generated": 0, "attempts": 0,
|
||||
"success_first_try": 0, "self_corrected": 0, "failed": 0,
|
||||
"training_examples": 0, "by_category": {},
|
||||
"training_examples": 0, "by_tier": {1: 0, 2: 0, 3: 0}, "by_category": {},
|
||||
}
|
||||
|
||||
for round_num in range(args.rounds):
|
||||
print(f"\n--- Round {round_num + 1}/{args.rounds} ---")
|
||||
|
||||
# Generate prompts
|
||||
prompts = generate_prompts(args.model, args.ollama_url, args.focus)
|
||||
if not prompts:
|
||||
print(" No prompts generated, skipping round")
|
||||
continue
|
||||
|
||||
stats["prompts_generated"] += len(prompts)
|
||||
print(f" Generated {len(prompts)} prompts")
|
||||
|
||||
for p in prompts:
|
||||
prompt = p["prompt"]
|
||||
cat = p["category"]
|
||||
stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0})
|
||||
stats["by_category"][cat]["total"] += 1
|
||||
stats["attempts"] += 1
|
||||
|
||||
print(f" [{cat}] {prompt[:60]:60}", end="")
|
||||
tier = tiers[round_num % len(tiers)]
|
||||
print(f"\n--- Round {round_num + 1}/{args.rounds} [Tier {tier}] ---")
|
||||
|
||||
if tier == 1:
|
||||
# Command drill: pick random seed example, try to execute
|
||||
if args.dry_run:
|
||||
print(" [DRY RUN]")
|
||||
print(" [DRY RUN] Would drill a random seed prompt via RCON")
|
||||
stats["rounds"] += 1
|
||||
continue
|
||||
|
||||
trace = attempt_command(
|
||||
args.model, args.ollama_url, prompt,
|
||||
args.rcon_host, args.rcon_port, args.rcon_pass,
|
||||
max_retries=args.max_retries,
|
||||
for _ in range(5): # 5 drills per round
|
||||
trace = run_tier1_drill(
|
||||
args.model, args.ollama_url,
|
||||
args.rcon_host, args.rcon_port, args.rcon_pass,
|
||||
args.max_retries,
|
||||
)
|
||||
if trace is None:
|
||||
continue
|
||||
stats["attempts"] += 1
|
||||
stats["by_tier"][1] += 1
|
||||
prompt = trace["prompt"]
|
||||
print(f" [drill] {prompt[:55]:55}", end="")
|
||||
if trace["final_success"] and not trace["self_corrected"]:
|
||||
stats["success_first_try"] += 1
|
||||
n_cmds = len(trace["attempts"][0].get("commands", []))
|
||||
print(f" OK ({n_cmds} cmds)")
|
||||
elif trace.get("self_corrected"):
|
||||
stats["self_corrected"] += 1
|
||||
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
|
||||
else:
|
||||
stats["failed"] += 1
|
||||
print(f" FAILED")
|
||||
examples = trace_to_training(trace)
|
||||
all_examples.extend(examples)
|
||||
stats["training_examples"] += len(examples)
|
||||
time.sleep(0.5)
|
||||
|
||||
elif tier == 2:
|
||||
# Self-critique: model generates prompt + response, RCON validates
|
||||
cats = [args.focus] if args.focus else random.sample(
|
||||
list(EXPLORATION_CATEGORIES.keys()), min(3, len(EXPLORATION_CATEGORIES))
|
||||
)
|
||||
for cat in cats:
|
||||
if args.dry_run:
|
||||
print(f" [DRY RUN] Would self-critique on {cat}")
|
||||
continue
|
||||
|
||||
if trace["final_success"] and not trace["self_corrected"]:
|
||||
stats["success_first_try"] += 1
|
||||
stats["by_category"][cat]["success"] += 1
|
||||
n_cmds = len(trace["attempts"][0].get("commands", []))
|
||||
print(f" OK ({n_cmds} cmds)")
|
||||
elif trace["self_corrected"]:
|
||||
stats["self_corrected"] += 1
|
||||
stats["by_category"][cat]["corrected"] += 1
|
||||
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
|
||||
else:
|
||||
stats["failed"] += 1
|
||||
print(f" FAILED")
|
||||
trace = run_tier2_selfcritique(
|
||||
args.model, args.ollama_url,
|
||||
args.rcon_host, args.rcon_port, args.rcon_pass,
|
||||
category=cat,
|
||||
)
|
||||
if trace is None:
|
||||
continue
|
||||
stats["attempts"] += 1
|
||||
stats["by_tier"][2] += 1
|
||||
prompt = trace["prompt"]
|
||||
diff = trace.get("difficulty_note", "")[:30]
|
||||
print(f" [self-critique:{cat[:12]}] {prompt[:40]:40} ({diff})", end="")
|
||||
if trace["final_success"]:
|
||||
stats["success_first_try"] += 1
|
||||
n_cmds = len(trace["attempts"][0].get("commands", []))
|
||||
print(f" OK ({n_cmds} cmds)")
|
||||
else:
|
||||
stats["failed"] += 1
|
||||
print(f" FAILED (self-generated bad commands)")
|
||||
examples = trace_to_training(trace)
|
||||
all_examples.extend(examples)
|
||||
stats["training_examples"] += len(examples)
|
||||
time.sleep(1)
|
||||
|
||||
# Convert to training examples
|
||||
examples = trace_to_training(trace)
|
||||
all_examples.extend(examples)
|
||||
stats["training_examples"] += len(examples)
|
||||
elif tier == 3:
|
||||
# Adversarial: generate prompts in Session A, respond in fresh Session B
|
||||
prompts = generate_prompts(args.model, args.ollama_url, args.focus)
|
||||
if not prompts:
|
||||
print(" No prompts generated, skipping round")
|
||||
stats["rounds"] += 1
|
||||
continue
|
||||
|
||||
# Brief pause between attempts
|
||||
time.sleep(1)
|
||||
stats["prompts_generated"] += len(prompts)
|
||||
print(f" Generated {len(prompts)} adversarial prompts")
|
||||
|
||||
for p in prompts:
|
||||
prompt = p["prompt"]
|
||||
cat = p["category"]
|
||||
stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0})
|
||||
stats["by_category"][cat]["total"] += 1
|
||||
stats["attempts"] += 1
|
||||
stats["by_tier"][3] += 1
|
||||
|
||||
print(f" [adversarial:{cat[:12]}] {prompt[:48]:48}", end="")
|
||||
|
||||
if args.dry_run:
|
||||
print(" [DRY RUN]")
|
||||
continue
|
||||
|
||||
trace = attempt_command(
|
||||
args.model, args.ollama_url, prompt,
|
||||
args.rcon_host, args.rcon_port, args.rcon_pass,
|
||||
max_retries=args.max_retries,
|
||||
)
|
||||
|
||||
if trace["final_success"] and not trace["self_corrected"]:
|
||||
stats["success_first_try"] += 1
|
||||
stats["by_category"][cat]["success"] += 1
|
||||
n_cmds = len(trace["attempts"][0].get("commands", []))
|
||||
print(f" OK ({n_cmds} cmds)")
|
||||
elif trace["self_corrected"]:
|
||||
stats["self_corrected"] += 1
|
||||
stats["by_category"][cat]["corrected"] += 1
|
||||
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
|
||||
else:
|
||||
stats["failed"] += 1
|
||||
print(f" FAILED")
|
||||
|
||||
examples = trace_to_training(trace)
|
||||
all_examples.extend(examples)
|
||||
stats["training_examples"] += len(examples)
|
||||
time.sleep(1)
|
||||
|
||||
stats["rounds"] += 1
|
||||
|
||||
@@ -530,6 +726,11 @@ def main():
|
||||
print(f" Failed: {stats['failed']}")
|
||||
print(f" Training examples:{stats['training_examples']}")
|
||||
|
||||
print(f"\n By tier:")
|
||||
for t in sorted(stats["by_tier"]):
|
||||
labels = {1: "Command drills", 2: "Self-critique", 3: "Adversarial"}
|
||||
print(f" Tier {t} ({labels[t]:16}): {stats['by_tier'][t]} attempts")
|
||||
|
||||
if stats["by_category"]:
|
||||
print(f"\n By category:")
|
||||
for cat, s in sorted(stats["by_category"].items()):
|
||||
|
||||
Reference in New Issue
Block a user