feat: add FastAPI app with WebSocket streaming and escalation loop

This commit is contained in:
Mortdecai
2026-04-10 01:27:52 -04:00
parent 5d03c46dcc
commit ca52b94ffd
2 changed files with 286 additions and 0 deletions
+254
View File
@@ -0,0 +1,254 @@
"""FastAPI application — WebSocket streaming, REST endpoints, background workers."""
import asyncio
import base64
import logging
import random
import time
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from server.config import config
from server.escalation import EscalationEngine
from server.asset_pool import AssetPool
from server.streaming import StreamManager
from server.prompts import get_image_prompt, get_voice_text, get_direct_address_text
logger = logging.getLogger("ai-hell")
# Global instances (set during lifespan or create_app)
escalation: EscalationEngine | None = None
pool: AssetPool | None = None
stream: StreamManager | None = None
asset_gen = None # AssetGenerator (lazy, needs GPU)
voice_gen = None # VoiceGenerator (lazy, needs GPU)
_workers: list[asyncio.Task] = []
def create_app(skip_models: bool = False) -> FastAPI:
"""Create the FastAPI app. skip_models=True for testing without GPU."""
global escalation, pool, stream, asset_gen, voice_gen
escalation = EscalationEngine()
pool = AssetPool()
stream = StreamManager()
@asynccontextmanager
async def lifespan(app: FastAPI):
global asset_gen, voice_gen
escalation.start_session()
if not skip_models:
from server.asset_generator import AssetGenerator
from server.voice_generator import VoiceGenerator
logger.info("Loading SDXL Turbo...")
asset_gen = AssetGenerator()
logger.info("Loading XTTS v2...")
voice_gen = VoiceGenerator()
logger.info("Models loaded. Generating initial batch...")
# Generate initial asset batch in background
loop = asyncio.get_running_loop()
_workers.append(asyncio.create_task(_initial_batch(loop)))
_workers.append(asyncio.create_task(_background_generator(loop)))
_workers.append(asyncio.create_task(_escalation_loop()))
yield
# Shutdown workers
for task in _workers:
task.cancel()
the_app = FastAPI(title="AI Hell", lifespan=lifespan)
# Mount assets directory for static serving
assets_dir = Path(config.assets_dir)
assets_dir.mkdir(parents=True, exist_ok=True)
(assets_dir / "img").mkdir(exist_ok=True)
(assets_dir / "audio").mkdir(exist_ok=True)
the_app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")
# --- REST endpoints ---
@the_app.get("/")
async def index():
html_path = Path(__file__).parent.parent / "frontend" / "index.html"
if html_path.exists():
return FileResponse(html_path, media_type="text/html")
return HTMLResponse("<h1>AI Hell</h1><p>Frontend not found.</p>")
@the_app.get("/status")
async def status():
return {
"intensity": round(escalation.get_intensity(), 2),
"connected_clients": stream.client_count,
**pool.get_status(),
}
@the_app.post("/reset")
async def reset():
escalation.reset()
return {"status": "ok"}
# --- WebSocket ---
@the_app.websocket("/stream")
async def stream_ws(ws: WebSocket):
await ws.accept()
stream.add_client(ws)
# Send current state immediately
intensity = escalation.get_intensity()
params = escalation.get_phase_params(intensity)
await ws.send_text(
__import__("json").dumps({
"type": "phase",
"intensity": round(intensity, 2),
"params": params,
})
)
try:
while True:
await ws.receive_text() # Keep alive, ignore pings
except WebSocketDisconnect:
pass
finally:
stream.remove_client(ws)
# Serve frontend shader files
@the_app.get("/shaders/{filename}")
async def serve_shader(filename: str):
shader_path = Path(__file__).parent.parent / "frontend" / "shaders" / filename
if shader_path.exists():
return FileResponse(shader_path, media_type="text/plain")
return HTMLResponse("Not found", status_code=404)
return the_app
async def _initial_batch(loop: asyncio.AbstractEventLoop) -> None:
"""Generate the initial pool of images and audio clips."""
batch_size = config.escalation.initial_batch_size
img_count = int(batch_size * 0.75)
audio_count = batch_size - img_count
for i in range(img_count):
severity = (i / max(1, img_count - 1)) * 4.0 # Spread across severity range
prompt = get_image_prompt(severity)
try:
data = await asyncio.to_thread(asset_gen.generate, prompt)
pool.add_image(data, severity=severity)
logger.info(f"Initial image {i+1}/{img_count} (severity={severity:.1f})")
except Exception as e:
logger.error(f"Failed to generate initial image: {e}")
for i in range(audio_count):
severity = (i / max(1, audio_count - 1)) * 4.0
text = get_voice_text()
try:
data = await asyncio.to_thread(voice_gen.generate, text)
pool.add_audio(data, severity=severity)
logger.info(f"Initial audio {i+1}/{audio_count} (severity={severity:.1f})")
except Exception as e:
logger.error(f"Failed to generate initial audio: {e}")
logger.info("Initial batch complete.")
async def _background_generator(loop: asyncio.AbstractEventLoop) -> None:
"""Continuously generate new assets biased toward current viewer needs."""
while True:
await asyncio.sleep(random.uniform(10, 30))
if stream.client_count == 0:
continue
intensity = escalation.get_intensity()
severity = escalation.select_severity(intensity)
# Alternate between images and audio
if random.random() < 0.7: # 70% images, 30% audio
prompt = get_image_prompt(severity)
try:
data = await asyncio.to_thread(asset_gen.generate, prompt)
pool.add_image(data, severity=severity)
except Exception as e:
logger.error(f"Background image gen failed: {e}")
else:
text = get_voice_text()
try:
data = await asyncio.to_thread(voice_gen.generate, text)
pool.add_audio(data, severity=severity)
except Exception as e:
logger.error(f"Background audio gen failed: {e}")
async def _escalation_loop() -> None:
"""Main escalation loop — pushes phase updates and triggers events."""
while True:
if stream.client_count == 0:
await asyncio.sleep(1)
continue
intensity = escalation.get_intensity()
params = escalation.get_phase_params(intensity)
# Phase update
await stream.broadcast_phase(intensity=intensity, params=params)
# Asset swap
severity = escalation.select_severity(intensity)
url = pool.select_image(target_severity=severity)
if url:
transition = _pick_transition(intensity)
await stream.broadcast_asset(url=url, severity=severity, transition=transition)
# Whisper check
voice_interval = escalation.get_voice_interval(intensity)
if random.random() < (2.0 / max(1.0, voice_interval)):
audio_url = pool.select_audio(target_severity=severity)
if audio_url:
await stream.broadcast_whisper(
url=audio_url,
pan=random.uniform(-1.0, 1.0),
volume=random.uniform(0.1, 0.8),
reverb=random.uniform(0.3, 0.9),
)
# Direct address check (rarer)
if intensity > 1.5 and random.random() < params["voice_frequency"] * 0.1:
text = get_direct_address_text()
if voice_gen:
try:
data = await asyncio.to_thread(voice_gen.generate, text)
audio_b64 = base64.b64encode(data).decode("ascii")
await stream.broadcast_address(audio_b64=audio_b64, text=text)
except Exception as e:
logger.error(f"Direct address gen failed: {e}")
# Surprise scare check
if random.random() < params["surprise_chance"] * 0.05:
effect = random.choice(["face_flash", "white_out", "inversion", "glitch_burst"])
duration = random.randint(50, 300)
await stream.broadcast_scare(effect=effect, duration_ms=duration)
# Wait for next cycle
swap_interval = escalation.get_asset_swap_interval(intensity)
await asyncio.sleep(swap_interval)
def _pick_transition(intensity: float) -> str:
"""Pick transition mode based on intensity."""
if intensity < 1.0:
return "crossfade"
elif intensity < 2.5:
return random.choice(["crossfade", "dissolve", "melt_morph"])
else:
return random.choice(["glitch_cut", "melt_morph", "dissolve", "crossfade"])
# Default app instance for uvicorn
app = create_app(skip_models=False)
+32
View File
@@ -0,0 +1,32 @@
from fastapi.testclient import TestClient
from server.main import create_app
class TestRESTEndpoints:
def test_status_endpoint(self):
"""GET /status returns intensity and pool info."""
test_app = create_app(skip_models=True)
with TestClient(test_app) as client:
resp = client.get("/status")
assert resp.status_code == 200
data = resp.json()
assert "intensity" in data
assert "connected_clients" in data
assert "image_pool_size" in data
def test_reset_endpoint(self):
"""POST /reset restarts escalation."""
test_app = create_app(skip_models=True)
with TestClient(test_app) as client:
resp = client.post("/reset")
assert resp.status_code == 200
assert resp.json()["status"] == "ok"
def test_index_serves_html(self):
"""GET / serves the frontend HTML (or fallback)."""
test_app = create_app(skip_models=True)
with TestClient(test_app) as client:
resp = client.get("/")
assert resp.status_code == 200
assert "text/html" in resp.headers["content-type"]