3-tier self-play: command drills, self-critique, adversarial
Tier 1 — Command drills: random seed prompts → generate commands → RCON validates. Teaches: accurate command syntax.
Tier 2 — Single-shot self-critique: the model invents a tricky prompt AND responds in one call; RCON validates the self-generated commands. Teaches: edge-case awareness, self-evaluation.
Tier 3 — Adversarial self-play: Session A generates challenging prompts; a fresh Session B responds cold (can't cheat); RCON validates and the model self-corrects on errors. Teaches: robustness, generalization.
Usage: --tier 1|2|3|all --rounds N --focus category
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+256
-55
@@ -1,20 +1,29 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
self_play.py — Self-play training data generator.
|
self_play.py — Multi-tier self-play training data generator.
|
||||||
|
|
||||||
The fine-tuned model generates its own training data by:
|
Three tiers of self-play, each teaching different skills:
|
||||||
1. Generating diverse edge-case prompts it's uncertain about
|
|
||||||
2. Attempting commands via RCON
|
|
||||||
3. Self-correcting on errors
|
|
||||||
4. Saving successful sequences as training examples
|
|
||||||
|
|
||||||
This creates a closed-loop learning system with RCON as ground truth.
|
Tier 1 — Command drills:
|
||||||
No API cost — runs entirely on the local model.
|
Feed known prompts, execute commands via RCON, validate syntax.
|
||||||
|
Teaches: accurate command generation.
|
||||||
|
Usage: --tier 1 --rounds 50
|
||||||
|
|
||||||
Usage:
|
Tier 2 — Single-shot self-critique:
|
||||||
python3 training/scripts/self_play.py --rounds 100 --model qwen3.5-9b-mc-v4
|
Model generates BOTH the prompt AND response in one call.
|
||||||
python3 training/scripts/self_play.py --rounds 50 --dry-run
|
Teaches: edge-case awareness, self-evaluation.
|
||||||
python3 training/scripts/self_play.py --rounds 200 --focus enchantments
|
Usage: --tier 2 --rounds 50
|
||||||
|
|
||||||
|
Tier 3 — Adversarial self-play:
|
||||||
|
Session A generates a challenging prompt. Fresh Session B responds.
|
||||||
|
RCON validates. Model can't cheat by knowing both sides.
|
||||||
|
Teaches: robustness, generalization, error correction.
|
||||||
|
Usage: --tier 3 --rounds 50
|
||||||
|
|
||||||
|
All tiers:
|
||||||
|
--tier all --rounds 50 (runs ~17 rounds of each)
|
||||||
|
|
||||||
|
No API cost — runs entirely on the local model with RCON as ground truth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -182,6 +191,120 @@ Analyze the error and return a corrected command.
|
|||||||
Return JSON: {"commands": ["corrected_cmd"], "reasoning": "what was wrong and how you fixed it"}"""
|
Return JSON: {"commands": ["corrected_cmd"], "reasoning": "what was wrong and how you fixed it"}"""
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tier 1: Command drills ---
|
||||||
|
|
||||||
|
def run_tier1_drill(model, ollama_url, rcon_host, rcon_port, rcon_pass, max_retries=2):
    """Drill one randomly chosen seed prompt and validate the result via RCON.

    A single non-blank row is sampled from ``data/processed/seed_dataset.jsonl``
    (under ``ROOT``). Only rows whose ``category`` is ``"command_gen"`` are
    drilled; the row's ``input.user_message`` is handed to ``attempt_command``.

    Args:
        model: Model name forwarded to ``attempt_command``.
        ollama_url: LLM endpoint forwarded to ``attempt_command``.
        rcon_host: RCON host forwarded to ``attempt_command``.
        rcon_port: RCON port forwarded to ``attempt_command``.
        rcon_pass: RCON password forwarded to ``attempt_command``.
        max_retries: Retry budget forwarded to ``attempt_command``.

    Returns:
        The trace returned by ``attempt_command`` with ``"tier"`` set to 1 and
        ``"original_commands"`` set to the seed row's ``output.commands``
        (empty list when absent), or ``None`` when the sampled row cannot be
        drilled (empty dataset, malformed row, wrong category, missing prompt).

    Raises:
        OSError: If the seed dataset cannot be opened.
    """
    seed_path = ROOT / "data" / "processed" / "seed_dataset.jsonl"
    with open(seed_path, encoding="utf-8") as f:
        lines = [line for line in f if line.strip()]

    # random.choice() raises IndexError on an empty sequence.
    if not lines:
        return None

    try:
        ex = json.loads(random.choice(lines))
    except json.JSONDecodeError:
        # One corrupt row should cost a single drill, not the whole run.
        return None

    # Check the category BEFORE touching the prompt: rows of other categories
    # are not guaranteed to carry input.user_message.
    if not isinstance(ex, dict) or ex.get("category") != "command_gen":
        return None

    seed_input = ex.get("input")
    prompt = seed_input.get("user_message") if isinstance(seed_input, dict) else None
    if not prompt:
        return None

    trace = attempt_command(
        model, ollama_url, prompt, rcon_host, rcon_port, rcon_pass, max_retries
    )
    trace["tier"] = 1

    # "output" may be missing or null in a seed row; fall back to no commands.
    seed_output = ex.get("output")
    trace["original_commands"] = (
        seed_output.get("commands", []) if isinstance(seed_output, dict) else []
    )
    return trace
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tier 2: Single-shot self-critique ---
|
||||||
|
|
||||||
|
SELF_CRITIQUE_SYSTEM = """You are a Minecraft 1.21 AI training data generator AND command translator.
|
||||||
|
|
||||||
|
Your task: generate a challenging Minecraft player request, then respond to it yourself.
|
||||||
|
Focus on edge cases you might get wrong: unusual items, complex enchantments, execute chains, ambiguous phrasing.
|
||||||
|
|
||||||
|
Return JSON:
|
||||||
|
{
|
||||||
|
"generated_prompt": "the player request you invented (must start with 'sudo ' or 'pray ')",
|
||||||
|
"difficulty": "what makes this tricky",
|
||||||
|
"commands": ["cmd1", "cmd2"],
|
||||||
|
"reasoning": "why these commands are correct",
|
||||||
|
"message": "God message if pray, empty string if sudo"
|
||||||
|
}
|
||||||
|
|
||||||
|
Commands use minecraft: prefix. Enchantments: item[enchantments={name:level}].
|
||||||
|
Effects: effect give <player> minecraft:<effect> <seconds> <amplifier>.
|
||||||
|
Player: slingshooter08. Do NOT start commands with /."""
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_object(raw):
    """Return the JSON object contained in *raw*, or None if there is none."""
    if not isinstance(raw, str):
        return None
    try:
        parsed = json.loads(raw)
    except ValueError:
        # Fall back to the outermost {...} span when the reply wraps the JSON
        # in extra text.
        match = re.search(r'\{[\s\S]*\}', raw)
        if match is None:
            return None
        try:
            parsed = json.loads(match.group())
        except ValueError:
            return None
    # Only a JSON object is usable; a bare list/string/number is rejected.
    return parsed if isinstance(parsed, dict) else None


def run_tier2_selfcritique(model, ollama_url, rcon_host, rcon_port, rcon_pass, category=None):
    """Have the model invent a prompt AND answer it in one call, then validate via RCON.

    Args:
        model: Model name passed to ``llm_call``.
        ollama_url: LLM endpoint passed to ``llm_call``.
        rcon_host: RCON host passed to ``rcon_command``.
        rcon_port: RCON port passed to ``rcon_command``.
        rcon_pass: RCON password passed to ``rcon_command``.
        category: Optional focus area appended to the system prompt.

    Returns:
        A trace dict with keys ``prompt``, ``mode``, ``tier`` (2),
        ``difficulty_note``, ``attempts``, ``final_success`` and
        ``self_corrected`` (always False here), or ``None`` when the LLM call
        fails, the reply holds no JSON object, or no prompt was generated.
    """
    focus = ""
    if category:
        focus = f"\nFocus area: {category}. Generate a prompt specifically testing {category}."

    try:
        raw = llm_call(
            model=model,
            system=SELF_CRITIQUE_SYSTEM + focus,
            user="Generate one challenging Minecraft request and your response. Be creative — pick something you might get wrong.",
            ollama_url=ollama_url,
            temperature=0.9,
            max_tokens=500,
            fmt="json",
        )
    except Exception:
        # Best-effort: a failed generation skips this attempt. Catching
        # Exception (not a bare except) lets Ctrl-C / SystemExit propagate.
        return None

    result = _parse_json_object(raw)
    if result is None:
        return None

    prompt = result.get("generated_prompt", "")
    # A non-string prompt would crash on .lower() below.
    if not isinstance(prompt, str) or not prompt:
        return None

    commands = result.get("commands") or []
    if isinstance(commands, str):
        # A lone command returned as a string must not be iterated per character.
        commands = [commands]
    elif not isinstance(commands, (list, tuple)):
        commands = []
    commands = [cmd for cmd in commands if isinstance(cmd, str) and cmd.strip()]

    message = result.get("message") or ""
    reasoning = result.get("reasoning") or ""
    # Callers slice difficulty_note, so make sure it is a string.
    difficulty = str(result.get("difficulty") or "")

    trace = {
        "prompt": prompt,
        "mode": "god" if prompt.lower().startswith("pray ") else "sudo",
        "tier": 2,
        "difficulty_note": difficulty,
        "attempts": [],
        "final_success": False,
        "self_corrected": False,
    }

    if not commands:
        # Nothing to execute (e.g. a message-only reply): recorded as a success.
        trace["attempts"].append({
            "commands": [], "reasoning": reasoning, "message": message,
            "rcon_results": [], "all_success": True,
        })
        trace["final_success"] = True
        return trace

    # Validate via RCON; every command runs even after an earlier failure.
    rcon_results = []
    all_success = True
    for cmd in commands:
        success, rcon_result = rcon_command(cmd, rcon_host, rcon_port, rcon_pass)
        rcon_results.append({"command": cmd, "success": success, "result": rcon_result})
        if not success:
            all_success = False

    trace["attempts"].append({
        "commands": commands, "reasoning": reasoning, "message": message,
        "rcon_results": rcon_results, "all_success": all_success,
    })
    trace["final_success"] = all_success
    return trace
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tier 3: Adversarial self-play (original generate_prompts + attempt_command) ---
|
||||||
|
|
||||||
def generate_prompts(model, ollama_url, category=None):
|
def generate_prompts(model, ollama_url, category=None):
|
||||||
"""Use the model to generate edge-case prompts for itself."""
|
"""Use the model to generate edge-case prompts for itself."""
|
||||||
if category:
|
if category:
|
||||||
@@ -436,16 +559,20 @@ def main():
|
|||||||
parser.add_argument("--rcon-port", type=int, default=25578)
|
parser.add_argument("--rcon-port", type=int, default=25578)
|
||||||
parser.add_argument("--rcon-pass", default="REDACTED_RCON")
|
parser.add_argument("--rcon-pass", default="REDACTED_RCON")
|
||||||
parser.add_argument("--rounds", type=int, default=50)
|
parser.add_argument("--rounds", type=int, default=50)
|
||||||
|
parser.add_argument("--tier", default="all", choices=["1", "2", "3", "all"])
|
||||||
parser.add_argument("--focus", default=None, choices=list(EXPLORATION_CATEGORIES.keys()))
|
parser.add_argument("--focus", default=None, choices=list(EXPLORATION_CATEGORIES.keys()))
|
||||||
parser.add_argument("--output", default=str(OUTPUT))
|
parser.add_argument("--output", default=str(OUTPUT))
|
||||||
parser.add_argument("--dry-run", action="store_true")
|
parser.add_argument("--dry-run", action="store_true")
|
||||||
parser.add_argument("--max-retries", type=int, default=2)
|
parser.add_argument("--max-retries", type=int, default=2)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
tiers = [1, 2, 3] if args.tier == "all" else [int(args.tier)]
|
||||||
|
|
||||||
print(f"Self-play training data generator")
|
print(f"Self-play training data generator")
|
||||||
print(f" Model: {args.model}")
|
print(f" Model: {args.model}")
|
||||||
print(f" RCON: {args.rcon_host}:{args.rcon_port}")
|
print(f" RCON: {args.rcon_host}:{args.rcon_port}")
|
||||||
print(f" Rounds: {args.rounds}")
|
print(f" Rounds: {args.rounds}")
|
||||||
|
print(f" Tiers: {tiers}")
|
||||||
print(f" Focus: {args.focus or 'all categories'}")
|
print(f" Focus: {args.focus or 'all categories'}")
|
||||||
print(f" Max retries: {args.max_retries}")
|
print(f" Max retries: {args.max_retries}")
|
||||||
print(f" Output: {args.output}")
|
print(f" Output: {args.output}")
|
||||||
@@ -454,60 +581,129 @@ def main():
|
|||||||
stats = {
|
stats = {
|
||||||
"rounds": 0, "prompts_generated": 0, "attempts": 0,
|
"rounds": 0, "prompts_generated": 0, "attempts": 0,
|
||||||
"success_first_try": 0, "self_corrected": 0, "failed": 0,
|
"success_first_try": 0, "self_corrected": 0, "failed": 0,
|
||||||
"training_examples": 0, "by_category": {},
|
"training_examples": 0, "by_tier": {1: 0, 2: 0, 3: 0}, "by_category": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
for round_num in range(args.rounds):
|
for round_num in range(args.rounds):
|
||||||
print(f"\n--- Round {round_num + 1}/{args.rounds} ---")
|
tier = tiers[round_num % len(tiers)]
|
||||||
|
print(f"\n--- Round {round_num + 1}/{args.rounds} [Tier {tier}] ---")
|
||||||
# Generate prompts
|
|
||||||
prompts = generate_prompts(args.model, args.ollama_url, args.focus)
|
|
||||||
if not prompts:
|
|
||||||
print(" No prompts generated, skipping round")
|
|
||||||
continue
|
|
||||||
|
|
||||||
stats["prompts_generated"] += len(prompts)
|
|
||||||
print(f" Generated {len(prompts)} prompts")
|
|
||||||
|
|
||||||
for p in prompts:
|
|
||||||
prompt = p["prompt"]
|
|
||||||
cat = p["category"]
|
|
||||||
stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0})
|
|
||||||
stats["by_category"][cat]["total"] += 1
|
|
||||||
stats["attempts"] += 1
|
|
||||||
|
|
||||||
print(f" [{cat}] {prompt[:60]:60}", end="")
|
|
||||||
|
|
||||||
|
if tier == 1:
|
||||||
|
# Command drill: pick random seed example, try to execute
|
||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
print(" [DRY RUN]")
|
print(" [DRY RUN] Would drill a random seed prompt via RCON")
|
||||||
|
stats["rounds"] += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
trace = attempt_command(
|
for _ in range(5): # 5 drills per round
|
||||||
args.model, args.ollama_url, prompt,
|
trace = run_tier1_drill(
|
||||||
args.rcon_host, args.rcon_port, args.rcon_pass,
|
args.model, args.ollama_url,
|
||||||
max_retries=args.max_retries,
|
args.rcon_host, args.rcon_port, args.rcon_pass,
|
||||||
|
args.max_retries,
|
||||||
|
)
|
||||||
|
if trace is None:
|
||||||
|
continue
|
||||||
|
stats["attempts"] += 1
|
||||||
|
stats["by_tier"][1] += 1
|
||||||
|
prompt = trace["prompt"]
|
||||||
|
print(f" [drill] {prompt[:55]:55}", end="")
|
||||||
|
if trace["final_success"] and not trace["self_corrected"]:
|
||||||
|
stats["success_first_try"] += 1
|
||||||
|
n_cmds = len(trace["attempts"][0].get("commands", []))
|
||||||
|
print(f" OK ({n_cmds} cmds)")
|
||||||
|
elif trace.get("self_corrected"):
|
||||||
|
stats["self_corrected"] += 1
|
||||||
|
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
|
||||||
|
else:
|
||||||
|
stats["failed"] += 1
|
||||||
|
print(f" FAILED")
|
||||||
|
examples = trace_to_training(trace)
|
||||||
|
all_examples.extend(examples)
|
||||||
|
stats["training_examples"] += len(examples)
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
elif tier == 2:
|
||||||
|
# Self-critique: model generates prompt + response, RCON validates
|
||||||
|
cats = [args.focus] if args.focus else random.sample(
|
||||||
|
list(EXPLORATION_CATEGORIES.keys()), min(3, len(EXPLORATION_CATEGORIES))
|
||||||
)
|
)
|
||||||
|
for cat in cats:
|
||||||
|
if args.dry_run:
|
||||||
|
print(f" [DRY RUN] Would self-critique on {cat}")
|
||||||
|
continue
|
||||||
|
|
||||||
if trace["final_success"] and not trace["self_corrected"]:
|
trace = run_tier2_selfcritique(
|
||||||
stats["success_first_try"] += 1
|
args.model, args.ollama_url,
|
||||||
stats["by_category"][cat]["success"] += 1
|
args.rcon_host, args.rcon_port, args.rcon_pass,
|
||||||
n_cmds = len(trace["attempts"][0].get("commands", []))
|
category=cat,
|
||||||
print(f" OK ({n_cmds} cmds)")
|
)
|
||||||
elif trace["self_corrected"]:
|
if trace is None:
|
||||||
stats["self_corrected"] += 1
|
continue
|
||||||
stats["by_category"][cat]["corrected"] += 1
|
stats["attempts"] += 1
|
||||||
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
|
stats["by_tier"][2] += 1
|
||||||
else:
|
prompt = trace["prompt"]
|
||||||
stats["failed"] += 1
|
diff = trace.get("difficulty_note", "")[:30]
|
||||||
print(f" FAILED")
|
print(f" [self-critique:{cat[:12]}] {prompt[:40]:40} ({diff})", end="")
|
||||||
|
if trace["final_success"]:
|
||||||
|
stats["success_first_try"] += 1
|
||||||
|
n_cmds = len(trace["attempts"][0].get("commands", []))
|
||||||
|
print(f" OK ({n_cmds} cmds)")
|
||||||
|
else:
|
||||||
|
stats["failed"] += 1
|
||||||
|
print(f" FAILED (self-generated bad commands)")
|
||||||
|
examples = trace_to_training(trace)
|
||||||
|
all_examples.extend(examples)
|
||||||
|
stats["training_examples"] += len(examples)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
# Convert to training examples
|
elif tier == 3:
|
||||||
examples = trace_to_training(trace)
|
# Adversarial: generate prompts in Session A, respond in fresh Session B
|
||||||
all_examples.extend(examples)
|
prompts = generate_prompts(args.model, args.ollama_url, args.focus)
|
||||||
stats["training_examples"] += len(examples)
|
if not prompts:
|
||||||
|
print(" No prompts generated, skipping round")
|
||||||
|
stats["rounds"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
# Brief pause between attempts
|
stats["prompts_generated"] += len(prompts)
|
||||||
time.sleep(1)
|
print(f" Generated {len(prompts)} adversarial prompts")
|
||||||
|
|
||||||
|
for p in prompts:
|
||||||
|
prompt = p["prompt"]
|
||||||
|
cat = p["category"]
|
||||||
|
stats["by_category"].setdefault(cat, {"total": 0, "success": 0, "corrected": 0})
|
||||||
|
stats["by_category"][cat]["total"] += 1
|
||||||
|
stats["attempts"] += 1
|
||||||
|
stats["by_tier"][3] += 1
|
||||||
|
|
||||||
|
print(f" [adversarial:{cat[:12]}] {prompt[:48]:48}", end="")
|
||||||
|
|
||||||
|
if args.dry_run:
|
||||||
|
print(" [DRY RUN]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
trace = attempt_command(
|
||||||
|
args.model, args.ollama_url, prompt,
|
||||||
|
args.rcon_host, args.rcon_port, args.rcon_pass,
|
||||||
|
max_retries=args.max_retries,
|
||||||
|
)
|
||||||
|
|
||||||
|
if trace["final_success"] and not trace["self_corrected"]:
|
||||||
|
stats["success_first_try"] += 1
|
||||||
|
stats["by_category"][cat]["success"] += 1
|
||||||
|
n_cmds = len(trace["attempts"][0].get("commands", []))
|
||||||
|
print(f" OK ({n_cmds} cmds)")
|
||||||
|
elif trace["self_corrected"]:
|
||||||
|
stats["self_corrected"] += 1
|
||||||
|
stats["by_category"][cat]["corrected"] += 1
|
||||||
|
print(f" CORRECTED ({len(trace['attempts'])} attempts)")
|
||||||
|
else:
|
||||||
|
stats["failed"] += 1
|
||||||
|
print(f" FAILED")
|
||||||
|
|
||||||
|
examples = trace_to_training(trace)
|
||||||
|
all_examples.extend(examples)
|
||||||
|
stats["training_examples"] += len(examples)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
stats["rounds"] += 1
|
stats["rounds"] += 1
|
||||||
|
|
||||||
@@ -530,6 +726,11 @@ def main():
|
|||||||
print(f" Failed: {stats['failed']}")
|
print(f" Failed: {stats['failed']}")
|
||||||
print(f" Training examples:{stats['training_examples']}")
|
print(f" Training examples:{stats['training_examples']}")
|
||||||
|
|
||||||
|
print(f"\n By tier:")
|
||||||
|
for t in sorted(stats["by_tier"]):
|
||||||
|
labels = {1: "Command drills", 2: "Self-critique", 3: "Adversarial"}
|
||||||
|
print(f" Tier {t} ({labels[t]:16}): {stats['by_tier'][t]} attempts")
|
||||||
|
|
||||||
if stats["by_category"]:
|
if stats["by_category"]:
|
||||||
print(f"\n By category:")
|
print(f"\n By category:")
|
||||||
for cat, s in sorted(stats["by_category"].items()):
|
for cat, s in sorted(stats["by_category"].items()):
|
||||||
|
|||||||
Reference in New Issue
Block a user