Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3bf1a9a373 | |||
| ad288d82f2 | |||
| 2274bf8066 | |||
| ca52b94ffd | |||
| 5d03c46dcc | |||
| 68a1d143e8 | |||
| ee08a1ffd3 | |||
| 33ea2ac43c | |||
| 5e12527e23 | |||
| 6725231775 | |||
| ed39bcf3f5 | |||
| e27789f2ba | |||
| 42ff0193fc | |||
| 7d5514b834 |
+2
-1
@@ -3,7 +3,8 @@ __pycache__/
|
||||
.pytest_cache/
|
||||
*.egg-info/
|
||||
venv/
|
||||
.venv/
|
||||
.env
|
||||
assets/
|
||||
samples/
|
||||
samples/*.wav
|
||||
.backup/
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
[Unit]
|
||||
Description=AI Hell - Passive Horror Webapp
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=root
|
||||
WorkingDirectory=/opt/ai-hell
|
||||
ExecStart=/opt/ai-hell/venv/bin/python run.py
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
Environment=PYTHONUNBUFFERED=1
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
Executable
+33
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
# AI Hell container setup script
# Run inside the LXC container on pve197 after GPU passthrough is configured.
#
# Installs system packages, creates a virtualenv under /opt/ai-hell and
# installs the Python dependencies listed in requirements.txt.
set -euo pipefail

APP_DIR=/opt/ai-hell

echo "=== AI Hell Container Setup ==="

# System packages
apt-get update
apt-get install -y python3 python3-pip python3-venv git ffmpeg

# Create app directory
mkdir -p "$APP_DIR"
cd "$APP_DIR"

# Clone or copy project
# (Adjust this based on whether you're cloning from Gitea or copying files)
# git clone http://192.168.0.125:3000/Seth/ai-hell.git .

# The clone above is commented out by default, so the project files may not be
# present yet. Stop here with a clear message instead of letting pip fail on a
# missing requirements file after the virtualenv has already been built.
if [[ ! -f requirements.txt ]]; then
    echo "ERROR: $APP_DIR/requirements.txt not found." >&2
    echo "Clone or copy the project into $APP_DIR, then re-run this script." >&2
    exit 1
fi

# Virtual environment
python3 -m venv venv
source venv/bin/activate

# Install dependencies
pip install --upgrade pip
pip install -r requirements.txt

# Create runtime directories
mkdir -p assets/img assets/audio samples

echo "=== Setup complete ==="
echo "Drop WAV files into /opt/ai-hell/samples/ for XTTS voice cloning sources."
echo "Start with: systemctl start ai-hell"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,466 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||||
<title>AI Hell</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
html, body { width: 100%; height: 100%; overflow: hidden; background: #000; cursor: none; }
|
||||
canvas { display: block; width: 100%; height: 100%; }
|
||||
#start-overlay {
|
||||
position: fixed; top: 0; left: 0; width: 100%; height: 100%;
|
||||
background: #000; display: flex; align-items: center; justify-content: center;
|
||||
cursor: pointer; z-index: 10;
|
||||
}
|
||||
#start-overlay span {
|
||||
color: #333; font-family: monospace; font-size: 14px;
|
||||
transition: opacity 2s;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="start-overlay"><span>click to enter</span></div>
|
||||
<canvas id="c"></canvas>
|
||||
|
||||
<script>
|
||||
// ============================================================
|
||||
// AI Hell — Frontend Compositor
|
||||
// ============================================================
|
||||
|
||||
const RECONNECT_DELAY = 2000;
|
||||
const PHASE_LERP_SPEED = 0.02;
|
||||
|
||||
// --- State ---
|
||||
let ws = null;
|
||||
let audioCtx = null;
|
||||
let gl = null;
|
||||
let program = null;
|
||||
let started = false;
|
||||
|
||||
// Current and target params (lerped)
|
||||
let currentParams = {
|
||||
morph_speed: 0, shader_severity: 0, noise_level: 0,
|
||||
voice_frequency: 0, surprise_chance: 0,
|
||||
};
|
||||
let targetParams = { ...currentParams };
|
||||
let currentIntensity = 0;
|
||||
|
||||
// Image state
|
||||
let currentTexture = null;
|
||||
let nextTexture = null;
|
||||
let blendProgress = 0;
|
||||
let blendTarget = 0;
|
||||
let transitionMode = 0; // 0=crossfade, 1=dissolve, 2=glitch, 3=melt
|
||||
let blendSpeed = 0.01;
|
||||
|
||||
// Flash state
|
||||
let flashIntensity = 0;
|
||||
let flashType = 0;
|
||||
let flashDecay = 0;
|
||||
|
||||
// Palette colors
|
||||
const PALETTES = {
|
||||
void_black: [0.1, 0.1, 0.1],
|
||||
crimson_void: [1.0, 0.15, 0.1],
|
||||
deep_rot: [0.5, 0.1, 0.05],
|
||||
sickly_green: [0.2, 0.6, 0.1],
|
||||
bruise_purple: [0.4, 0.1, 0.5],
|
||||
ash_grey: [0.4, 0.4, 0.4],
|
||||
bile_yellow: [0.6, 0.5, 0.1],
|
||||
blood_orange: [0.8, 0.3, 0.05],
|
||||
};
|
||||
let colorTint = [1.0, 1.0, 1.0];
|
||||
|
||||
// Cursor hide timer
|
||||
let cursorTimer = null;

// Show the pointer while the mouse is moving, then hide it again once it has
// been idle for three seconds.
document.addEventListener('mousemove', function onMouseMove() {
  const bodyStyle = document.body.style;
  bodyStyle.cursor = 'default';
  clearTimeout(cursorTimer);
  cursorTimer = setTimeout(function hideCursor() {
    bodyStyle.cursor = 'none';
  }, 3000);
});
|
||||
|
||||
// ============================================================
|
||||
// WebGL Setup
|
||||
// ============================================================
|
||||
|
||||
const VERT_SRC = `
|
||||
attribute vec2 a_position;
|
||||
varying vec2 v_uv;
|
||||
void main() {
|
||||
v_uv = a_position * 0.5 + 0.5;
|
||||
gl_Position = vec4(a_position, 0.0, 1.0);
|
||||
}`;
|
||||
|
||||
/**
 * Create the WebGL context, fetch and link the compositor shader, and set up
 * the fullscreen quad plus the two placeholder textures.
 *
 * Always returns a promise that resolves (never rejects), so callers can chain
 * `.then()` unconditionally. On any failure the error is logged and the global
 * `program` is left untouched, which makes `render()` skip drawing.
 *
 * @returns {Promise<void>}
 */
function initWebGL() {
  const canvas = document.getElementById('c');
  gl = canvas.getContext('webgl', { antialias: false, alpha: false });
  if (!gl) {
    console.error('WebGL not supported');
    // Still hand back a promise: the caller chains .then() on the result.
    return Promise.resolve();
  }

  // Load fragment shader via fetch
  return fetch('/shaders/compositor.frag')
    .then((r) => {
      // A 404/500 body is not shader source; fail loudly instead of compiling it.
      if (!r.ok) throw new Error(`Failed to load fragment shader (HTTP ${r.status})`);
      return r.text();
    })
    .then((fragSrc) => {
      const vs = compileShader(gl.VERTEX_SHADER, VERT_SRC);
      const fs = compileShader(gl.FRAGMENT_SHADER, fragSrc);
      const linked = gl.createProgram();
      gl.attachShader(linked, vs);
      gl.attachShader(linked, fs);
      gl.linkProgram(linked);
      if (!gl.getProgramParameter(linked, gl.LINK_STATUS)) {
        console.error('Shader link error:', gl.getProgramInfoLog(linked));
        return;
      }
      // Publish the program only once it is known to be usable, so the render
      // loop never issues draw calls against an unlinked program.
      program = linked;
      gl.useProgram(program);

      // Fullscreen quad
      const buf = gl.createBuffer();
      gl.bindBuffer(gl.ARRAY_BUFFER, buf);
      gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([-1, -1, 1, -1, -1, 1, 1, 1]), gl.STATIC_DRAW);
      const loc = gl.getAttribLocation(program, 'a_position');
      gl.enableVertexAttribArray(loc);
      gl.vertexAttribPointer(loc, 2, gl.FLOAT, false, 0, 0);

      // Create placeholder textures (black 1x1)
      currentTexture = createTexture(null);
      nextTexture = createTexture(null);

      resize();
      window.addEventListener('resize', resize);
    })
    .catch((err) => {
      console.error('WebGL initialisation failed:', err);
    });
}
|
||||
|
||||
/**
 * Compile one shader stage on the global `gl` context.
 * A compile failure is logged; the shader object is returned either way.
 *
 * @param {number} type - gl.VERTEX_SHADER or gl.FRAGMENT_SHADER.
 * @param {string} src - GLSL source text.
 * @returns {WebGLShader} The shader object (possibly uncompiled on error).
 */
function compileShader(type, src) {
  const shader = gl.createShader(type);
  gl.shaderSource(shader, src);
  gl.compileShader(shader);

  const compiled = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
  if (compiled) {
    return shader;
  }

  console.error('Shader error:', gl.getShaderInfoLog(shader));
  return shader;
}
|
||||
|
||||
/**
 * Allocate a 2D texture with clamp-to-edge wrapping and linear filtering.
 *
 * @param {TexImageSource|null} source - Image to upload; when falsy a single
 *   opaque black pixel is uploaded instead.
 * @returns {WebGLTexture}
 */
function createTexture(source) {
  const texture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, texture);

  const samplerSettings = [
    [gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE],
    [gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE],
    [gl.TEXTURE_MIN_FILTER, gl.LINEAR],
    [gl.TEXTURE_MAG_FILTER, gl.LINEAR],
  ];
  for (const [pname, value] of samplerSettings) {
    gl.texParameteri(gl.TEXTURE_2D, pname, value);
  }

  if (!source) {
    // 1x1 opaque black placeholder.
    const blackPixel = new Uint8Array([0, 0, 0, 255]);
    gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, 1, 1, 0, gl.RGBA, gl.UNSIGNED_BYTE, blackPixel);
    return texture;
  }

  gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, source);
  return texture;
}
|
||||
|
||||
/**
 * Load an image and hand it to `callback` once decoded.
 * A failed load is logged and `callback` is not invoked.
 *
 * @param {string} url - Image URL.
 * @param {(img: HTMLImageElement) => void} callback - Receives the loaded image.
 */
function loadImageTexture(url, callback) {
  const img = new Image();
  img.crossOrigin = 'anonymous';
  img.onload = () => callback(img);
  // Without this a missing or rotated-out asset fails with no trace at all.
  img.onerror = () => console.error('Failed to load image asset:', url);
  img.src = url;
}
|
||||
|
||||
/**
 * Replace the pixel contents of an existing texture.
 *
 * @param {WebGLTexture} tex - Texture to overwrite.
 * @param {TexImageSource} source - New image data.
 */
function updateTexture(tex, source) {
  const { TEXTURE_2D, RGBA, UNSIGNED_BYTE } = gl;
  gl.bindTexture(TEXTURE_2D, tex);
  gl.texImage2D(TEXTURE_2D, 0, RGBA, RGBA, UNSIGNED_BYTE, source);
}
|
||||
|
||||
/** Match the canvas backing store and GL viewport to the window size. */
function resize() {
  const surface = gl.canvas;
  const { innerWidth, innerHeight } = window;
  surface.width = innerWidth;
  surface.height = innerHeight;
  gl.viewport(0, 0, surface.width, surface.height);
}
|
||||
|
||||
// ============================================================
|
||||
// Audio System
|
||||
// ============================================================
|
||||
|
||||
// Layer 1: Ambient drones (client-side loops)
|
||||
let droneGain = null;
|
||||
let droneFilter = null;
|
||||
let droneSource = null;
|
||||
|
||||
// Layer 2: Whisper pool (server-pushed clips)
|
||||
// Layer 3: Direct address (server-pushed, dry)
|
||||
|
||||
/**
 * Build the audio graph: droneGain -> droneFilter -> master -> destination.
 * Starts the oscillator drone and exposes the master gain as
 * `window._audioMaster` for the whisper and direct-address players.
 */
function initAudio() {
  const AudioContextCtor = window.AudioContext || window.webkitAudioContext;
  audioCtx = new AudioContextCtor();

  // Master bus feeding the speakers.
  const masterBus = audioCtx.createGain();
  masterBus.gain.value = 0.8;
  masterBus.connect(audioCtx.destination);

  // Low-pass filter sitting between the drone gain and the master bus.
  droneFilter = audioCtx.createBiquadFilter();
  droneFilter.type = 'lowpass';
  droneFilter.frequency.value = 400;
  droneFilter.connect(masterBus);

  droneGain = audioCtx.createGain();
  droneGain.gain.value = 0.3;
  droneGain.connect(droneFilter);

  // The drone is synthesised from oscillators, so no audio files are needed.
  startDrone();

  // Whisper/address playback routes into this node.
  window._audioMaster = masterBus;
}
|
||||
|
||||
/**
 * Start the oscillator bed: four slightly detuned sines and a triangle
 * sub-bass feeding `droneGain`, plus a 0.05 Hz sine that modulates the drone
 * filter's cutoff.
 */
function startDrone() {
  // Create oscillator -> gain -> destination and start it immediately.
  const spawnVoice = (waveform, hertz, level, destination) => {
    const osc = audioCtx.createOscillator();
    osc.type = waveform;
    osc.frequency.value = hertz;

    const amp = audioCtx.createGain();
    amp.gain.value = level;
    osc.connect(amp);
    amp.connect(destination);
    osc.start();
  };

  // Close frequencies beat against each other.
  for (const hertz of [55, 55.5, 82.5, 110.2]) {
    spawnVoice('sine', hertz, 0.08, droneGain);
  }

  // Sub-bass rumble.
  spawnVoice('triangle', 30, 0.15, droneGain);

  // Very slow LFO; its gain of 200 is added to the filter frequency param.
  spawnVoice('sine', 0.05, 200, droneFilter.frequency);
}
|
||||
|
||||
/**
 * Ramp the drone's filter cutoff and gain toward values derived from the
 * escalation intensity (capped at 5) over the next two seconds.
 * Does nothing until the audio graph exists.
 *
 * @param {number} intensity - Current escalation intensity.
 */
function updateDroneFromIntensity(intensity) {
  if (!(droneFilter && droneGain)) {
    return;
  }

  const capped = Math.min(intensity, 5);
  const cutoffTarget = 400 + capped * 300;
  const gainTarget = 0.3 + capped * 0.1;

  droneFilter.frequency.linearRampToValueAtTime(cutoffTarget, audioCtx.currentTime + 2);
  droneGain.gain.linearRampToValueAtTime(gainTarget, audioCtx.currentTime + 2);
}
|
||||
|
||||
/**
 * Fetch, decode and play a whisper clip through the master bus.
 *
 * Signal path: source -> gain -> panner -> master, plus a single delayed copy
 * (source -> delay -> delayGain -> panner) standing in for reverb.
 * Failures are logged rather than silently dropped; the returned promise
 * always resolves.
 *
 * @param {string} url - Clip URL.
 * @param {number} pan - Stereo pan value for the panner node.
 * @param {number} volume - Gain applied to the dry signal.
 * @param {number} reverb - Scales the delay time (x0.1 s) and delayed level (x0.4).
 * @returns {Promise<void>} Settles once playback has started or failed.
 */
function playWhisper(url, pan, volume, reverb) {
  return fetch(url)
    .then((r) => {
      // The clip may have been rotated out of the pool; don't decode an error page.
      if (!r.ok) throw new Error(`HTTP ${r.status} fetching ${url}`);
      return r.arrayBuffer();
    })
    .then((buf) => audioCtx.decodeAudioData(buf))
    .then((audioBuffer) => {
      const source = audioCtx.createBufferSource();
      source.buffer = audioBuffer;

      const gainNode = audioCtx.createGain();
      gainNode.gain.value = volume;

      const panner = audioCtx.createStereoPanner();
      panner.pan.value = pan;

      // Simple convolver substitute: one delayed, attenuated copy
      const delay = audioCtx.createDelay();
      delay.delayTime.value = reverb * 0.1;
      const delayGain = audioCtx.createGain();
      delayGain.gain.value = reverb * 0.4;

      // Dry path
      source.connect(gainNode);
      gainNode.connect(panner);
      panner.connect(window._audioMaster);

      // Delayed path (feed-forward; nothing is routed back into the delay)
      source.connect(delay);
      delay.connect(delayGain);
      delayGain.connect(panner);

      source.start();
    })
    .catch((err) => {
      console.error('Whisper playback failed:', err);
    });
}
|
||||
|
||||
/**
 * Decode a base64-encoded audio payload and play it dry through the master bus.
 *
 * Invalid base64 or undecodable audio is logged instead of throwing into the
 * WebSocket message handler; the returned promise always resolves.
 *
 * @param {string} base64Audio - Base64-encoded audio file contents.
 * @returns {Promise<void>} Settles once playback has started or failed.
 */
function playDirectAddress(base64Audio) {
  let bytes;
  try {
    // atob throws synchronously on malformed input.
    const binary = atob(base64Audio);
    bytes = new Uint8Array(binary.length);
    for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i);
  } catch (err) {
    console.error('Direct address payload is not valid base64:', err);
    return Promise.resolve();
  }

  return audioCtx.decodeAudioData(bytes.buffer)
    .then((audioBuffer) => {
      const source = audioCtx.createBufferSource();
      source.buffer = audioBuffer;

      // Direct address: dry, centered, slightly louder
      const gainNode = audioCtx.createGain();
      gainNode.gain.value = 0.9;

      source.connect(gainNode);
      gainNode.connect(window._audioMaster);
      source.start();
    })
    .catch((err) => {
      console.error('Direct address playback failed:', err);
    });
}
|
||||
|
||||
// ============================================================
|
||||
// WebSocket
|
||||
// ============================================================
|
||||
|
||||
/**
 * Open the `/stream` WebSocket and dispatch server messages.
 *
 * Message types handled: `phase` (intensity, params, palette), `asset` (new
 * image + transition), `whisper`, `address` and `scare`. On close the socket
 * reconnects after RECONNECT_DELAY ms. The keepalive timer belongs to a single
 * socket and is cleared when that socket closes, so reconnects do not stack up
 * extra ping timers.
 */
function connect() {
  const proto = location.protocol === 'https:' ? 'wss:' : 'ws:';
  const socket = new WebSocket(proto + '//' + location.host + '/stream');
  ws = socket;
  let keepalive = null;

  socket.onopen = () => {
    // Keepalive
    keepalive = setInterval(() => {
      if (socket.readyState === WebSocket.OPEN) socket.send('{"type":"ping"}');
    }, 30000);
  };

  socket.onmessage = (ev) => {
    let msg;
    try {
      msg = JSON.parse(ev.data);
    } catch (err) {
      // One bad frame must not kill the handler for the rest of the session.
      console.error('Ignoring malformed stream message:', err);
      return;
    }
    if (!msg) return;

    switch (msg.type) {
      case 'phase': {
        const params = msg.params || {};
        currentIntensity = msg.intensity;
        Object.assign(targetParams, params);
        if (params.palette && PALETTES[params.palette]) {
          colorTint = PALETTES[params.palette];
        }
        updateDroneFromIntensity(msg.intensity);
        break;
      }

      case 'asset':
        // Load new image and start transition
        loadImageTexture(msg.url, (img) => {
          // Swap: the texture shown last becomes "current", the other one is
          // overwritten with the incoming image and becomes "next".
          const tmp = currentTexture;
          currentTexture = nextTexture;
          nextTexture = tmp;
          updateTexture(nextTexture, img);
          blendProgress = 0;
          blendTarget = 1;
          transitionMode = ['crossfade', 'dissolve', 'glitch_cut', 'melt_morph'].indexOf(msg.transition);
          if (transitionMode < 0) transitionMode = 0; // unknown name -> crossfade
          // Blend speed scales with morph_speed
          blendSpeed = 0.005 + currentParams.morph_speed * 0.03;
        });
        break;

      case 'whisper':
        playWhisper(msg.url, msg.pan, msg.volume, msg.reverb);
        break;

      case 'address':
        playDirectAddress(msg.audio);
        break;

      case 'scare':
        triggerScare(msg.effect, msg.duration_ms);
        break;
    }
  };

  socket.onclose = () => {
    clearInterval(keepalive);
    keepalive = null;
    setTimeout(connect, RECONNECT_DELAY);
  };
  // Close the socket that errored, not whichever socket `ws` points at now.
  socket.onerror = () => socket.close();
}
|
||||
|
||||
/**
 * Start a full-intensity flash that the render loop fades out frame by frame.
 *
 * The per-frame decay is sized so the flash lasts about `durationMs` at
 * 16.67 ms per frame. A missing, non-numeric, zero or negative duration falls
 * back to a one-frame flash instead of producing a NaN/Infinity decay (NaN
 * would leave `flashIntensity` stuck at NaN, since the render loop only decays
 * it while it is > 0).
 *
 * @param {string} effect - 'face_flash', 'white_out', 'inversion' or
 *   'glitch_burst'; anything else is treated as a white-out.
 * @param {number} durationMs - Intended flash length in milliseconds.
 */
function triggerScare(effect, durationMs) {
  flashIntensity = 1.0;

  const validDuration = typeof durationMs === 'number' && Number.isFinite(durationMs) && durationMs > 0;
  flashDecay = validDuration ? 1.0 / (durationMs / 16.67) : 1.0; // fraction removed per frame

  switch (effect) {
    case 'face_flash': flashType = 2; break;
    case 'white_out': flashType = 0; break;
    case 'inversion': flashType = 1; break;
    case 'glitch_burst': flashType = 0; flashDecay *= 0.3; break; // Slower decay
    default: flashType = 0;
  }
}
|
||||
|
||||
// ============================================================
|
||||
// Render Loop
|
||||
// ============================================================
|
||||
|
||||
let frameTime = 0;
|
||||
|
||||
// Uniform locations looked up so far, valid for `owner` only. Querying the
// driver for all twelve locations on every frame is avoidable work.
const uniformCache = { owner: null, locations: new Map() };

/** Return the (cached) location of uniform `name` in the current program. */
function uniformLocation(name) {
  if (uniformCache.owner !== program) {
    // Program was (re)created: locations from the old one are invalid.
    uniformCache.owner = program;
    uniformCache.locations = new Map();
  }
  if (!uniformCache.locations.has(name)) {
    uniformCache.locations.set(name, gl.getUniformLocation(program, name));
  }
  return uniformCache.locations.get(name);
}

/**
 * Per-frame callback: ease shader params toward their targets, advance the
 * image blend and flash decay, upload uniforms and draw the fullscreen quad.
 * Re-schedules itself every frame; drawing is skipped until WebGL is ready.
 *
 * @param {number} timestamp - requestAnimationFrame timestamp in milliseconds.
 */
function render(timestamp) {
  requestAnimationFrame(render);
  if (!gl || !program) return;

  frameTime = timestamp * 0.001; // seconds

  // Lerp params
  for (const key of Object.keys(currentParams)) {
    if (targetParams[key] !== undefined) {
      currentParams[key] += (targetParams[key] - currentParams[key]) * PHASE_LERP_SPEED;
    }
  }

  // Blend transition
  if (blendProgress < blendTarget) {
    blendProgress = Math.min(blendTarget, blendProgress + blendSpeed);
  }

  // Flash decay
  if (flashIntensity > 0) {
    flashIntensity = Math.max(0, flashIntensity - flashDecay);
  }

  // Bind textures
  gl.activeTexture(gl.TEXTURE0);
  gl.bindTexture(gl.TEXTURE_2D, currentTexture);
  gl.uniform1i(uniformLocation('u_currentImage'), 0);

  gl.activeTexture(gl.TEXTURE1);
  gl.bindTexture(gl.TEXTURE_2D, nextTexture);
  gl.uniform1i(uniformLocation('u_nextImage'), 1);

  // Set uniforms
  gl.uniform1f(uniformLocation('u_blend'), blendProgress);
  gl.uniform1i(uniformLocation('u_transitionMode'), transitionMode);
  gl.uniform1f(uniformLocation('u_morphSpeed'), currentParams.morph_speed);
  gl.uniform1f(uniformLocation('u_shaderSeverity'), currentParams.shader_severity);
  gl.uniform1f(uniformLocation('u_noiseLevel'), currentParams.noise_level);
  gl.uniform1f(uniformLocation('u_time'), frameTime);
  gl.uniform1f(uniformLocation('u_flashIntensity'), flashIntensity);
  gl.uniform1i(uniformLocation('u_flashType'), flashType);
  gl.uniform1f(uniformLocation('u_vignetteStrength'), 0.5 + currentParams.shader_severity * 0.5);
  gl.uniform3fv(uniformLocation('u_colorTint'), colorTint);

  // Draw
  gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
}
|
||||
|
||||
// ============================================================
|
||||
// Start
|
||||
// ============================================================
|
||||
|
||||
// First click on the overlay starts everything: audio must be created from a
// user gesture, and fullscreen can only be requested from one.
document.getElementById('start-overlay').addEventListener('click', () => {
  if (started) return;
  started = true;

  document.getElementById('start-overlay').style.display = 'none';

  // Request fullscreen
  const el = document.documentElement;
  const requestFs = el.requestFullscreen || el.webkitRequestFullscreen;
  if (requestFs) {
    const pending = requestFs.call(el);
    // The standard API returns a promise that rejects when the request is
    // denied; the prefixed one may return nothing.
    if (pending && typeof pending.catch === 'function') {
      pending.catch((err) => console.warn('Fullscreen request denied:', err));
    }
  }

  initAudio();
  // initWebGL may return undefined (no WebGL) or a rejected promise (shader
  // fetch failed). Either way still open the stream so the audio layers work;
  // render() skips drawing while no program is available.
  Promise.resolve(initWebGL())
    .catch((err) => {
      console.error('WebGL initialisation failed:', err);
    })
    .then(() => {
      connect();
      requestAnimationFrame(render);
    });
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,195 @@
|
||||
precision highp float;
|
||||
|
||||
// Textures
|
||||
uniform sampler2D u_currentImage;
|
||||
uniform sampler2D u_nextImage;
|
||||
|
||||
// Transition
|
||||
uniform float u_blend; // 0=current, 1=next
|
||||
uniform int u_transitionMode; // 0=crossfade, 1=dissolve, 2=glitch_cut, 3=melt_morph
|
||||
|
||||
// Shader params (from escalation engine, all 0-1)
|
||||
uniform float u_morphSpeed;
|
||||
uniform float u_shaderSeverity;
|
||||
uniform float u_noiseLevel;
|
||||
uniform float u_time;
|
||||
|
||||
// Scare flash
|
||||
uniform float u_flashIntensity; // 0=none, 1=full
|
||||
uniform int u_flashType; // 0=white, 1=inversion, 2=face_flash
|
||||
|
||||
// Vignette and overlay
|
||||
uniform float u_vignetteStrength;
|
||||
uniform vec3 u_colorTint; // palette color shift
|
||||
|
||||
varying vec2 v_uv;
|
||||
|
||||
// --- Noise functions ---
|
||||
|
||||
// Pseudo-random value in [0, 1) derived from a 2D coordinate.
float hash(vec2 p) {
    float projected = dot(p, vec2(127.1, 311.7));
    return fract(sin(projected) * 43758.5453);
}
|
||||
|
||||
// Value noise: hashes the four surrounding lattice corners and blends them
// with a smoothstep-shaped weight.
float noise(vec2 p) {
    vec2 cell = floor(p);
    vec2 local = fract(p);

    float h00 = hash(cell);
    float h10 = hash(cell + vec2(1.0, 0.0));
    float h01 = hash(cell + vec2(0.0, 1.0));
    float h11 = hash(cell + vec2(1.0, 1.0));

    vec2 w = local * local * (3.0 - 2.0 * local);
    return mix(h00, h10, w.x) + (h01 - h00) * w.y * (1.0 - w.x) + (h11 - h10) * w.x * w.y;
}
|
||||
|
||||
// --- Effect functions ---
|
||||
|
||||
// Sinusoidal UV wobble; frequency and amplitude both grow with severity.
// The y offset is computed from the already-displaced x coordinate.
vec2 meshWarp(vec2 uv, float severity, float time) {
    float waveFreq = 3.0 + severity * 8.0;
    float waveAmp = 0.002 + severity * 0.03;

    vec2 warped = uv;
    warped.x += sin(warped.y * waveFreq + time * 0.5) * waveAmp;
    warped.y += cos(warped.x * waveFreq + time * 0.7) * waveAmp;
    return warped;
}
|
||||
|
||||
// RGB split: red is sampled to the right, blue to the left, green in place.
vec3 chromaticAberration(sampler2D tex, vec2 uv, float severity) {
    vec2 shift = vec2(0.001 + severity * 0.015, 0.0);
    return vec3(
        texture2D(tex, uv + shift).r,
        texture2D(tex, uv).g,
        texture2D(tex, uv - shift).b
    );
}
|
||||
|
||||
// Vertical drip: each column is pushed along y by a noise value that drifts
// slowly over time.
vec2 meltEffect(vec2 uv, float severity, float time) {
    float drip = noise(vec2(uv.x * 10.0, time * 0.3));
    return vec2(uv.x, uv.y + severity * 0.05 * drip);
}
|
||||
|
||||
// Horizontal tearing: the screen is cut into 50 bands re-rolled 8 times per
// second; bands whose hash passes the threshold are shifted sideways.
vec2 glitchEffect(vec2 uv, float severity, float time) {
    float band = floor(uv.y * 50.0);
    float tick = floor(time * 8.0);
    float cutoff = 0.98 - severity * 0.15;
    float active = step(cutoff, hash(vec2(band, tick)));

    float jitter = hash(vec2(time, uv.y)) - 0.5;
    uv.x += active * jitter * severity * 0.1;
    return uv;
}
|
||||
|
||||
// Brightness multiplier oscillating between 0.90 and 1.00 along y.
float scanlines(vec2 uv, float time) {
    float wave = sin(uv.y * 800.0 + time * 2.0);
    return 0.95 + 0.05 * wave;
}
|
||||
|
||||
// Per-pixel brightness multiplier centred on 1.0; `amount` scales the spread.
float filmGrain(vec2 uv, float time, float amount) {
    float grain = hash(uv * 1000.0 + time) - 0.5;
    return 1.0 - amount * 0.5 * grain;
}
|
||||
|
||||
// Darkening factor that falls from 1.0 toward (1.0 - strength) between
// 0.3 and 0.9 UV units from the screen centre.
float vignette(vec2 uv, float strength) {
    float radius = length(uv - 0.5);
    float falloff = smoothstep(0.3, 0.9, radius);
    return 1.0 - falloff * strength;
}
|
||||
|
||||
// Brightness throb around 1.0; both rate and depth increase with severity.
float pulse(float time, float severity) {
    float rate = 1.0 + severity * 3.0;
    float depth = severity * 0.1;
    return 1.0 + sin(time * rate) * depth;
}
|
||||
|
||||
// --- Transition functions ---
|
||||
|
||||
// Linear mix from the current image (blend = 0) to the next one (blend = 1).
vec3 transitionCrossfade(vec2 uv, float blend) {
    vec3 fromColor = texture2D(u_currentImage, uv).rgb;
    vec3 toColor = texture2D(u_nextImage, uv).rgb;
    return mix(fromColor, toColor, blend);
}
|
||||
|
||||
// Fade the current image to black over the first half of the blend, then
// fade the next image up from black over the second half.
vec3 transitionDissolve(vec2 uv, float blend) {
    float mid = 0.5;
    vec3 fromColor = texture2D(u_currentImage, uv).rgb;
    vec3 toColor = texture2D(u_nextImage, uv).rgb;

    if (blend >= mid) {
        return toColor * ((blend - mid) / (1.0 - mid));
    }
    return fromColor * (1.0 - blend / mid);
}
|
||||
|
||||
// Hard cut with glitch artifacts: shows the current image until `blend`
// crosses a threshold that jitters around 0.5, then the next image. Within
// 0.1 of the threshold the visible image is drawn with a strong RGB split.
//
// The image is chosen with if/else rather than `cond ? samplerA : samplerB`:
// GLSL ES 1.00 does not allow samplers as operands of the ternary operator,
// so the sampler-selecting form fails to compile on conforming WebGL 1
// implementations.
vec3 transitionGlitchCut(vec2 uv, float blend, float time) {
    float threshold = 0.5 + 0.1 * sin(time * 20.0);
    bool showCurrent = blend < threshold;
    bool nearCut = abs(blend - threshold) < 0.1;

    if (showCurrent) {
        if (nearCut) {
            return chromaticAberration(u_currentImage, uv, 0.8);
        }
        return texture2D(u_currentImage, uv).rgb;
    }

    if (nearCut) {
        return chromaticAberration(u_nextImage, uv, 0.8);
    }
    return texture2D(u_nextImage, uv).rgb;
}
|
||||
|
||||
// The current image slides along y by a noise-driven amount that grows with
// the blend, while the next image fades in between blend 0.3 and 0.7.
vec3 transitionMeltMorph(vec2 uv, float blend, float time) {
    float sag = blend * 0.1 * noise(vec2(uv.x * 5.0, time));
    vec2 saggedUV = vec2(uv.x, uv.y + sag);

    vec3 fromColor = texture2D(u_currentImage, saggedUV).rgb;
    vec3 toColor = texture2D(u_nextImage, uv).rgb;
    float weight = smoothstep(0.3, 0.7, blend);
    return mix(fromColor, toColor, weight);
}
|
||||
|
||||
// --- Main ---
|
||||
|
||||
// Compositor entry point: distort the UVs, blend the two images with the
// selected transition, then apply aberration, tint, pulse, grain, scanlines,
// vignette and finally any scare flash.
void main() {
    vec2 uv = v_uv;
    float severity = u_shaderSeverity;

    // Apply distortion effects
    uv = meshWarp(uv, severity * u_morphSpeed, u_time);
    uv = meltEffect(uv, severity * 0.5, u_time);
    uv = glitchEffect(uv, severity, u_time);

    // Clamp UV to prevent sampling outside texture
    uv = clamp(uv, 0.0, 1.0);

    // Image transition
    vec3 color;
    if (u_transitionMode == 0) {
        color = transitionCrossfade(uv, u_blend);
    } else if (u_transitionMode == 1) {
        color = transitionDissolve(uv, u_blend);
    } else if (u_transitionMode == 2) {
        color = transitionGlitchCut(uv, u_blend, u_time);
    } else {
        color = transitionMeltMorph(uv, u_blend, u_time);
    }

    // Chromatic aberration on composited image.
    // Both images are sampled and weighted by the blend; sampling only
    // u_currentImage would keep mixing the previous picture into the frame
    // after the transition to u_nextImage has completed.
    if (severity > 0.1) {
        vec3 aberrated = mix(
            chromaticAberration(u_currentImage, uv, severity),
            chromaticAberration(u_nextImage, uv, severity),
            u_blend
        );
        color = mix(color, aberrated, severity * 0.3);
    }

    // Color tint / palette shift
    color = mix(color, color * u_colorTint, severity * 0.3);

    // Pulse brightness
    color *= pulse(u_time, severity);

    // Film grain / noise (screen-space UVs, so grain is not warped)
    color *= filmGrain(v_uv, u_time, u_noiseLevel);

    // Scanlines
    color *= scanlines(v_uv, u_time);

    // Vignette
    color *= vignette(v_uv, 0.5 + severity * 0.5);

    // Flash effects
    if (u_flashIntensity > 0.0) {
        if (u_flashType == 0) {
            // White-out
            color = mix(color, vec3(1.0), u_flashIntensity);
        } else if (u_flashType == 1) {
            // Inversion
            color = mix(color, 1.0 - color, u_flashIntensity);
        } else {
            // Face flash (red tint)
            color = mix(color, vec3(1.0, 0.0, 0.0), u_flashIntensity * 0.7);
        }
    }

    gl_FragColor = vec4(color, 1.0);
}
|
||||
@@ -0,0 +1,14 @@
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.34.0
|
||||
torch>=2.2.0
|
||||
diffusers>=0.30.0
|
||||
transformers>=4.40.0
|
||||
accelerate>=0.30.0
|
||||
TTS>=0.22.0
|
||||
Pillow>=10.0.0
|
||||
numpy>=1.26.0
|
||||
pydantic>=2.0.0
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23.0
|
||||
httpx>=0.27.0
|
||||
websockets>=12.0
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Entry point for AI Hell server."""
|
||||
|
||||
import logging
|
||||
import uvicorn
|
||||
|
||||
from server.config import config
|
||||
|
||||
# Root logger: timestamped INFO-level output for every module.
logging.basicConfig(
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    level=logging.INFO,
)

if __name__ == "__main__":
    # Bind address and port come from the shared application config.
    server_options = {
        "host": config.host,
        "port": config.port,
        "log_level": "info",
    }
    uvicorn.run("server.main:app", **server_options)
|
||||
@@ -0,0 +1,57 @@
|
||||
"""SDXL Turbo wrapper for horror image generation."""
|
||||
|
||||
import io
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
except ImportError: # pragma: no cover - exercised only in test environments
|
||||
AutoPipelineForText2Image = None # Tests patch this attribute directly.
|
||||
|
||||
from server.config import config
|
||||
from server.prompts import NEGATIVE_PROMPT
|
||||
|
||||
|
||||
class AssetGenerator:
    """Text-to-image generator backed by a diffusers AutoPipelineForText2Image."""

    def __init__(self, device: str | None = None, model_id: str | None = None):
        """Load the pipeline.

        Args:
            device: Torch device name; falls back to ``config.device``.
            model_id: Model identifier; falls back to ``config.models.sdxl_model_id``.

        Raises:
            RuntimeError: If diffusers could not be imported.
        """
        self.device = device or config.device
        self.model_id = model_id or config.models.sdxl_model_id

        if AutoPipelineForText2Image is None:
            raise RuntimeError(
                "diffusers is not installed; install diffusers to use AssetGenerator"
            )

        # Half precision (and the fp16 weight variant) only when running on CUDA.
        on_cuda = self.device == "cuda"
        if on_cuda:
            dtype, variant = torch.float16, "fp16"
        else:
            dtype, variant = torch.float32, None

        self._pipe = AutoPipelineForText2Image.from_pretrained(
            self.model_id,
            torch_dtype=dtype,
            variant=variant,
        )
        if on_cuda:
            self._pipe = self._pipe.to("cuda")

    def generate(self, prompt: str, seed: int | None = None) -> bytes:
        """Render one image for ``prompt`` and return it as PNG bytes.

        Step count, guidance scale and dimensions are read from ``config.models``.

        Args:
            prompt: Positive prompt text.
            seed: Seed for the torch generator; drawn at random when ``None``.
        """
        if seed is None:
            seed = torch.randint(0, 2**32, (1,)).item()

        rng = torch.Generator(device=self.device).manual_seed(seed)
        settings = config.models

        output = self._pipe(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            num_inference_steps=settings.sdxl_steps,
            guidance_scale=settings.sdxl_guidance_scale,
            width=settings.sdxl_width,
            height=settings.sdxl_height,
            generator=rng,
        )

        png = io.BytesIO()
        output.images[0].save(png, format="PNG")
        return png.getvalue()
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Asset pool — manages generated images and audio on disk with severity tagging."""
|
||||
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
from server.config import config
|
||||
|
||||
|
||||
@dataclass
class Asset:
    """Metadata record for one generated file in the asset pool."""

    filename: str
    severity: float
    created_at: float
    asset_type: str  # "image" or "audio"

    @property
    def url(self) -> str:
        """URL path of the file: images live under ``img``, everything else under ``audio``."""
        if self.asset_type == "image":
            folder = "img"
        else:
            folder = "audio"
        return f"/assets/{folder}/{self.filename}"
|
||||
|
||||
|
||||
class AssetPool:
|
||||
"""Thread-safe pool of generated assets with severity-based selection and rotation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dir: str | None = None,
|
||||
max_images: int | None = None,
|
||||
max_audio: int | None = None,
|
||||
):
|
||||
self.base_dir = Path(base_dir or config.assets_dir)
|
||||
self.max_images = max_images if max_images is not None else config.escalation.max_images
|
||||
self.max_audio = max_audio if max_audio is not None else config.escalation.max_audio_clips
|
||||
self._images: list[Asset] = []
|
||||
self._audio: list[Asset] = []
|
||||
self._lock = Lock()
|
||||
|
||||
(self.base_dir / "img").mkdir(parents=True, exist_ok=True)
|
||||
(self.base_dir / "audio").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def image_count(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._images)
|
||||
|
||||
@property
|
||||
def audio_count(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._audio)
|
||||
|
||||
def add_image(self, data: bytes, severity: float) -> str:
|
||||
"""Save image data to disk and add to pool. Returns URL path."""
|
||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||
path = self.base_dir / "img" / filename
|
||||
path.write_bytes(data)
|
||||
|
||||
asset = Asset(
|
||||
filename=filename,
|
||||
severity=severity,
|
||||
created_at=time.monotonic(),
|
||||
asset_type="image",
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._images.append(asset)
|
||||
self._rotate(self._images, self.max_images, "img")
|
||||
|
||||
return asset.url
|
||||
|
||||
def add_audio(self, data: bytes, severity: float) -> str:
|
||||
"""Save audio data to disk and add to pool. Returns URL path."""
|
||||
filename = f"{uuid.uuid4().hex[:12]}.wav"
|
||||
path = self.base_dir / "audio" / filename
|
||||
path.write_bytes(data)
|
||||
|
||||
asset = Asset(
|
||||
filename=filename,
|
||||
severity=severity,
|
||||
created_at=time.monotonic(),
|
||||
asset_type="audio",
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._audio.append(asset)
|
||||
self._rotate(self._audio, self.max_audio, "audio")
|
||||
|
||||
return asset.url
|
||||
|
||||
def select_image(self, target_severity: float) -> str | None:
|
||||
"""Select an image near the target severity. Weighted random, biased toward close matches."""
|
||||
with self._lock:
|
||||
return self._select(self._images, target_severity)
|
||||
|
||||
def select_audio(self, target_severity: float) -> str | None:
|
||||
"""Select an audio clip near the target severity."""
|
||||
with self._lock:
|
||||
return self._select(self._audio, target_severity)
|
||||
|
||||
def _select(self, assets: list[Asset], target: float) -> str | None:
|
||||
"""Weighted selection: closer severity = higher weight."""
|
||||
if not assets:
|
||||
return None
|
||||
weights = []
|
||||
for a in assets:
|
||||
distance = abs(a.severity - target)
|
||||
weights.append(1.0 / (1.0 + distance))
|
||||
chosen = random.choices(assets, weights=weights, k=1)[0]
|
||||
return chosen.url
|
||||
|
||||
def _rotate(self, assets: list[Asset], max_count: int, subdir: str) -> None:
|
||||
"""Remove oldest assets when pool exceeds capacity. Must hold lock."""
|
||||
while len(assets) > max_count:
|
||||
old = assets.pop(0)
|
||||
path = self.base_dir / subdir / old.filename
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def get_status(self) -> dict:
    """Report the current number of pooled images and audio clips.

    Both sizes are read under the pool lock so they come from one snapshot.
    """
    with self._lock:
        image_total = len(self._images)
        audio_total = len(self._audio)
    return {
        "image_pool_size": image_total,
        "audio_pool_size": audio_total,
    }
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Application configuration with Pydantic models."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EscalationConfig(BaseModel):
    """Escalation engine parameters.

    Read through ``config.escalation`` by the escalation engine and by the
    start-up batch generator.
    """
    # Growth rate of the intensity curve: intensity = log(1 + elapsed * rate).
    rate: float = 0.05
    # Total assets generated at start-up; split 75% images / 25% audio clips.
    initial_batch_size: int = 40
    # Pool capacities — presumably consumed by AssetPool; confirm against its constructor.
    max_images: int = 200
    max_audio_clips: int = 50
    asset_swap_min: float = 0.5  # seconds (at high intensity); also the floor after jitter
    asset_swap_max: float = 15.0  # seconds (at low intensity)
    # Mean of the exponential voice-interval draw at intensity 0; the mean is
    # divided by (1 + intensity) and never drops below 3.0.
    voice_mean_interval: float = 60.0  # Poisson mean at intensity 0
    # NOTE(review): silence_gap_* are not referenced by the code visible in
    # this change; presumably seconds — confirm before relying on them.
    silence_gap_min: float = 2.0
    silence_gap_max: float = 30.0
    fake_calm_chance: float = 0.08  # probability per phase update
    # Bounds of the uniform draw for a fake-calm period's duration.
    fake_calm_duration_min: float = 10.0
    fake_calm_duration_max: float = 30.0
    # Probability that a roll triggers a cluster burst (rapid-fire events).
    cluster_burst_chance: float = 0.1
    # Inclusive bounds for the number of events in one cluster burst.
    cluster_burst_count_min: int = 2
    cluster_burst_count_max: int = 5
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
    """AI model identifiers and generation parameters."""
    # Image model settings — presumably consumed by AssetGenerator, which is
    # not visible here; confirm how each value is passed to the pipeline.
    sdxl_model_id: str = "stabilityai/sdxl-turbo"
    sdxl_steps: int = 4
    sdxl_guidance_scale: float = 0.0
    sdxl_width: int = 512
    sdxl_height: int = 512
    # Default model name used by VoiceGenerator when none is passed in.
    xtts_model: str = "tts_models/multilingual/multi-dataset/xtts_v2"
    # Language code passed to every XTTS synthesis call.
    xtts_language: str = "en"
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
    """Top-level application config."""
    # Bind address and port — presumably read by the launcher script; confirm.
    host: str = "0.0.0.0"
    port: int = 8400
    # Default device for VoiceGenerator when no device argument is given.
    device: str = "cuda"
    # Directory created at app start-up and mounted at /assets for static serving.
    assets_dir: str = "assets"
    # Default directory VoiceGenerator scans for *.wav clone sources.
    samples_dir: str = "samples"
    # Nested sections get their own default instances via default_factory.
    escalation: EscalationConfig = Field(default_factory=EscalationConfig)
    models: ModelConfig = Field(default_factory=ModelConfig)


# Module-level singleton built from defaults; other modules import this name.
config = AppConfig()
|
||||
@@ -0,0 +1,121 @@
|
||||
"""Escalation engine — intensity curve, phase params, stochastic timing."""
|
||||
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from server.config import config
|
||||
|
||||
|
||||
class EscalationEngine:
|
||||
"""Computes intensity and rendering parameters from elapsed time.
|
||||
|
||||
Intensity follows a logarithmic curve: intensity = log(1 + elapsed * rate).
|
||||
Rendering parameters are derived from intensity and clamped to [0, 1].
|
||||
Timing intervals are randomized within ranges that shrink with intensity.
|
||||
"""
|
||||
|
||||
PALETTES = [
|
||||
"void_black", "crimson_void", "deep_rot", "sickly_green",
|
||||
"bruise_purple", "ash_grey", "bile_yellow", "blood_orange",
|
||||
]
|
||||
|
||||
def __init__(self, rate: float | None = None):
|
||||
self.rate = rate if rate is not None else config.escalation.rate
|
||||
self.session_start: float | None = None
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Begin a new escalation session."""
|
||||
self.session_start = time.monotonic()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Restart the session from zero."""
|
||||
self.session_start = time.monotonic()
|
||||
|
||||
def get_elapsed(self) -> float:
|
||||
"""Seconds since session start."""
|
||||
if self.session_start is None:
|
||||
return 0.0
|
||||
return time.monotonic() - self.session_start
|
||||
|
||||
def get_intensity(self, elapsed: float | None = None) -> float:
|
||||
"""Compute intensity from elapsed seconds. Logarithmic, never peaks."""
|
||||
if elapsed is None:
|
||||
elapsed = self.get_elapsed()
|
||||
return math.log(1 + elapsed * self.rate)
|
||||
|
||||
def get_phase_params(self, intensity: float | None = None) -> dict:
|
||||
"""Derive rendering parameters from current intensity.
|
||||
|
||||
All values clamped to [0, 1]. Higher intensity = more severe effects.
|
||||
"""
|
||||
if intensity is None:
|
||||
intensity = self.get_intensity()
|
||||
|
||||
def _sigmoid(x: float, midpoint: float = 2.0, steepness: float = 1.5) -> float:
|
||||
"""Smooth 0→1 mapping centered at midpoint."""
|
||||
return 1.0 / (1.0 + math.exp(-steepness * (x - midpoint)))
|
||||
|
||||
morph_speed = _sigmoid(intensity, midpoint=1.5, steepness=1.2)
|
||||
shader_severity = _sigmoid(intensity, midpoint=2.0, steepness=1.0)
|
||||
voice_frequency = _sigmoid(intensity, midpoint=2.5, steepness=0.8)
|
||||
noise_level = _sigmoid(intensity, midpoint=3.0, steepness=1.0)
|
||||
surprise_chance = min(1.0, _sigmoid(intensity, midpoint=3.5, steepness=0.6))
|
||||
|
||||
# Palette shifts to more aggressive colors at higher intensity
|
||||
palette_index = min(int(intensity), len(self.PALETTES) - 1)
|
||||
|
||||
return {
|
||||
"morph_speed": round(morph_speed, 3),
|
||||
"shader_severity": round(shader_severity, 3),
|
||||
"voice_frequency": round(voice_frequency, 3),
|
||||
"noise_level": round(noise_level, 3),
|
||||
"surprise_chance": round(surprise_chance, 3),
|
||||
"palette": self.PALETTES[palette_index],
|
||||
}
|
||||
|
||||
def get_asset_swap_interval(self, intensity: float | None = None) -> float:
|
||||
"""Random interval until next asset swap. Shrinks with intensity."""
|
||||
if intensity is None:
|
||||
intensity = self.get_intensity()
|
||||
cfg = config.escalation
|
||||
# Lerp from max to min as intensity increases, with noise
|
||||
t = min(1.0, intensity / 5.0)
|
||||
base = cfg.asset_swap_max - t * (cfg.asset_swap_max - cfg.asset_swap_min)
|
||||
jitter = random.uniform(-base * 0.3, base * 0.3)
|
||||
return max(cfg.asset_swap_min, base + jitter)
|
||||
|
||||
def get_voice_interval(self, intensity: float | None = None) -> float:
|
||||
"""Poisson-distributed interval until next voice clip. Mean decreases with intensity."""
|
||||
if intensity is None:
|
||||
intensity = self.get_intensity()
|
||||
cfg = config.escalation
|
||||
# Mean interval decreases from voice_mean_interval toward 3s
|
||||
mean = max(3.0, cfg.voice_mean_interval / (1 + intensity))
|
||||
return random.expovariate(1.0 / mean)
|
||||
|
||||
def should_fake_calm(self) -> bool:
|
||||
"""Roll for a fake calm period."""
|
||||
return random.random() < config.escalation.fake_calm_chance
|
||||
|
||||
def get_fake_calm_duration(self) -> float:
|
||||
"""Duration of a fake calm period."""
|
||||
cfg = config.escalation
|
||||
return random.uniform(cfg.fake_calm_duration_min, cfg.fake_calm_duration_max)
|
||||
|
||||
def should_cluster_burst(self) -> bool:
|
||||
"""Roll for a cluster burst (rapid-fire events)."""
|
||||
return random.random() < config.escalation.cluster_burst_chance
|
||||
|
||||
def get_cluster_burst_count(self) -> int:
|
||||
"""Number of events in a cluster burst."""
|
||||
cfg = config.escalation
|
||||
return random.randint(cfg.cluster_burst_count_min, cfg.cluster_burst_count_max)
|
||||
|
||||
def select_severity(self, intensity: float | None = None) -> float:
|
||||
"""Pick a target severity for the next asset. Biased toward current intensity."""
|
||||
if intensity is None:
|
||||
intensity = self.get_intensity()
|
||||
# Normal distribution centered on intensity, clipped to [0, intensity+1]
|
||||
severity = random.gauss(intensity, 0.5)
|
||||
return max(0.0, min(severity, intensity + 1.0))
|
||||
+254
@@ -0,0 +1,254 @@
|
||||
"""FastAPI application — WebSocket streaming, REST endpoints, background workers."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from server.config import config
|
||||
from server.escalation import EscalationEngine
|
||||
from server.asset_pool import AssetPool
|
||||
from server.streaming import StreamManager
|
||||
from server.prompts import get_image_prompt, get_voice_text, get_direct_address_text
|
||||
|
||||
logger = logging.getLogger("ai-hell")

# Global instances (set during lifespan or create_app).
# escalation, pool and stream are assigned by create_app(); asset_gen and
# voice_gen are assigned by the lifespan handler unless skip_models is set.
escalation: EscalationEngine | None = None
pool: AssetPool | None = None
stream: StreamManager | None = None
asset_gen = None  # AssetGenerator (lazy, needs GPU)
voice_gen = None  # VoiceGenerator (lazy, needs GPU)
# Background tasks started by the lifespan handler and cancelled at shutdown.
_workers: list[asyncio.Task] = []
|
||||
|
||||
|
||||
def create_app(skip_models: bool = False) -> FastAPI:
    """Create and configure the FastAPI application.

    Instantiates the module-level ``escalation``, ``pool`` and ``stream``
    singletons, creates and mounts the assets directory, and registers the
    REST, WebSocket and shader routes.

    Args:
        skip_models: When True the lifespan handler neither loads the
            generators nor starts background workers, so the app can run
            without a GPU (e.g. in tests).

    Returns:
        The configured FastAPI instance.
    """
    global escalation, pool, stream, asset_gen, voice_gen

    # Local import: json is not part of this module's import block.
    import json

    escalation = EscalationEngine()
    pool = AssetPool()
    stream = StreamManager()

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Start the session and workers on startup; stop workers on shutdown."""
        global asset_gen, voice_gen
        escalation.start_session()

        if not skip_models:
            from server.asset_generator import AssetGenerator
            from server.voice_generator import VoiceGenerator

            logger.info("Loading SDXL Turbo...")
            asset_gen = AssetGenerator()
            logger.info("Loading XTTS v2...")
            voice_gen = VoiceGenerator()
            logger.info("Models loaded. Generating initial batch...")

            # Generate initial asset batch in background
            loop = asyncio.get_running_loop()
            _workers.append(asyncio.create_task(_initial_batch(loop)))
            _workers.append(asyncio.create_task(_background_generator(loop)))
            _workers.append(asyncio.create_task(_escalation_loop()))

        try:
            yield
        finally:
            # Shutdown workers: cancel, then wait for the cancellations to be
            # processed so no task is still pending when the loop closes.
            # The list is cleared so a later app instance does not inherit
            # finished tasks.
            for task in _workers:
                task.cancel()
            await asyncio.gather(*_workers, return_exceptions=True)
            _workers.clear()

    the_app = FastAPI(title="AI Hell", lifespan=lifespan)

    # Mount assets directory for static serving
    assets_dir = Path(config.assets_dir)
    assets_dir.mkdir(parents=True, exist_ok=True)
    (assets_dir / "img").mkdir(exist_ok=True)
    (assets_dir / "audio").mkdir(exist_ok=True)
    the_app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")

    # --- REST endpoints ---

    @the_app.get("/")
    async def index():
        """Serve the frontend page, or a placeholder if it is missing."""
        html_path = Path(__file__).parent.parent / "frontend" / "index.html"
        if html_path.exists():
            return FileResponse(html_path, media_type="text/html")
        return HTMLResponse("<h1>AI Hell</h1><p>Frontend not found.</p>")

    @the_app.get("/status")
    async def status():
        """Report current intensity, client count and pool sizes."""
        return {
            "intensity": round(escalation.get_intensity(), 2),
            "connected_clients": stream.client_count,
            **pool.get_status(),
        }

    @the_app.post("/reset")
    async def reset():
        """Restart the escalation session from zero."""
        escalation.reset()
        return {"status": "ok"}

    # --- WebSocket ---

    @the_app.websocket("/stream")
    async def stream_ws(ws: WebSocket):
        """Register a client, send the current phase, then hold the socket open."""
        await ws.accept()
        stream.add_client(ws)
        try:
            # Send current state immediately. This is inside the try block so
            # a client that drops before or during the first send is still
            # unregistered by the finally clause.
            intensity = escalation.get_intensity()
            params = escalation.get_phase_params(intensity)
            await ws.send_text(
                json.dumps({
                    "type": "phase",
                    "intensity": round(intensity, 2),
                    "params": params,
                })
            )
            while True:
                await ws.receive_text()  # Keep alive, ignore pings
        except WebSocketDisconnect:
            pass
        finally:
            stream.remove_client(ws)

    # Serve frontend shader files
    @the_app.get("/shaders/{filename}")
    async def serve_shader(filename: str):
        """Serve one shader file from ``frontend/shaders`` as plain text."""
        shader_path = Path(__file__).parent.parent / "frontend" / "shaders" / filename
        # Accept only a bare file name that refers to a regular file. Names
        # containing path separators are rejected, and directory entries such
        # as ".." fail the is_file() check instead of reaching FileResponse.
        if Path(filename).name == filename and shader_path.is_file():
            return FileResponse(shader_path, media_type="text/plain")
        return HTMLResponse("Not found", status_code=404)

    return the_app
|
||||
|
||||
|
||||
async def _initial_batch(loop: asyncio.AbstractEventLoop) -> None:
    """Fill the pool with the start-up set of images and audio clips.

    The configured batch size is split 75% images / 25% audio. Within each
    kind the severities are spread evenly from 0.0 to 4.0. Generation runs
    in worker threads; a failed item is logged and skipped.

    Args:
        loop: Accepted for signature compatibility; not used in the body.
    """
    total = config.escalation.initial_batch_size
    img_count = int(total * 0.75)
    audio_count = total - img_count

    image_span = max(1, img_count - 1)
    for idx in range(img_count):
        severity = (idx / image_span) * 4.0  # Spread across severity range
        prompt = get_image_prompt(severity)
        try:
            png = await asyncio.to_thread(asset_gen.generate, prompt)
            pool.add_image(png, severity=severity)
            logger.info(f"Initial image {idx+1}/{img_count} (severity={severity:.1f})")
        except Exception as exc:
            logger.error(f"Failed to generate initial image: {exc}")

    audio_span = max(1, audio_count - 1)
    for idx in range(audio_count):
        severity = (idx / audio_span) * 4.0
        line = get_voice_text()
        try:
            wav = await asyncio.to_thread(voice_gen.generate, line)
            pool.add_audio(wav, severity=severity)
            logger.info(f"Initial audio {idx+1}/{audio_count} (severity={severity:.1f})")
        except Exception as exc:
            logger.error(f"Failed to generate initial audio: {exc}")

    logger.info("Initial batch complete.")
|
||||
|
||||
|
||||
async def _background_generator(loop: asyncio.AbstractEventLoop) -> None:
    """Keep topping up the pool while at least one client is connected.

    Every 10–30 seconds (uniformly random) one new asset is generated with a
    severity drawn around the current intensity. Roughly 70% of rounds
    produce an image and 30% an audio clip. Failures are logged and the loop
    continues.

    Args:
        loop: Accepted for signature compatibility; not used in the body.
    """
    while True:
        await asyncio.sleep(random.uniform(10, 30))
        if stream.client_count == 0:
            continue

        level = escalation.get_intensity()
        severity = escalation.select_severity(level)

        # Random split between kinds: 70% images, 30% audio
        make_image = random.random() < 0.7
        if make_image:
            prompt = get_image_prompt(severity)
            try:
                png = await asyncio.to_thread(asset_gen.generate, prompt)
                pool.add_image(png, severity=severity)
            except Exception as exc:
                logger.error(f"Background image gen failed: {exc}")
        else:
            line = get_voice_text()
            try:
                wav = await asyncio.to_thread(voice_gen.generate, line)
                pool.add_audio(wav, severity=severity)
            except Exception as exc:
                logger.error(f"Background audio gen failed: {exc}")
|
||||
|
||||
|
||||
async def _escalation_loop() -> None:
    """Drive the experience: broadcast phase updates and trigger events.

    While no client is connected the loop polls once per second. Otherwise
    each cycle broadcasts the current phase, swaps in an image, and rolls for
    a whisper, a freshly synthesized direct address, and a scare. It then
    sleeps for the engine's asset-swap interval.
    """
    while True:
        if stream.client_count == 0:
            await asyncio.sleep(1)
            continue

        level = escalation.get_intensity()
        phase = escalation.get_phase_params(level)

        # Phase update
        await stream.broadcast_phase(intensity=level, params=phase)

        # Asset swap
        target = escalation.select_severity(level)
        image_url = pool.select_image(target_severity=target)
        if image_url:
            await stream.broadcast_asset(
                url=image_url,
                severity=target,
                transition=_pick_transition(level),
            )

        # Whisper check: shorter drawn intervals make a whisper more likely
        expected_gap = escalation.get_voice_interval(level)
        if random.random() < (2.0 / max(1.0, expected_gap)):
            clip_url = pool.select_audio(target_severity=target)
            if clip_url:
                await stream.broadcast_whisper(
                    url=clip_url,
                    pan=random.uniform(-1.0, 1.0),
                    volume=random.uniform(0.1, 0.8),
                    reverb=random.uniform(0.3, 0.9),
                )

        # Direct address check (rarer)
        if level > 1.5 and random.random() < phase["voice_frequency"] * 0.1:
            line = get_direct_address_text()
            if voice_gen:
                try:
                    wav = await asyncio.to_thread(voice_gen.generate, line)
                    encoded = base64.b64encode(wav).decode("ascii")
                    await stream.broadcast_address(audio_b64=encoded, text=line)
                except Exception as exc:
                    logger.error(f"Direct address gen failed: {exc}")

        # Surprise scare check
        if random.random() < phase["surprise_chance"] * 0.05:
            effect = random.choice(["face_flash", "white_out", "inversion", "glitch_burst"])
            duration = random.randint(50, 300)
            await stream.broadcast_scare(effect=effect, duration_ms=duration)

        # Wait for next cycle
        await asyncio.sleep(escalation.get_asset_swap_interval(level))
|
||||
|
||||
|
||||
def _pick_transition(intensity: float) -> str:
|
||||
"""Pick transition mode based on intensity."""
|
||||
if intensity < 1.0:
|
||||
return "crossfade"
|
||||
elif intensity < 2.5:
|
||||
return random.choice(["crossfade", "dissolve", "melt_morph"])
|
||||
else:
|
||||
return random.choice(["glitch_cut", "melt_morph", "dissolve", "crossfade"])
|
||||
|
||||
|
||||
# Default app instance for uvicorn.
# Built at import time with skip_models=False, so the lifespan handler loads
# the generators and starts the background workers when the server starts.
app = create_app(skip_models=False)
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Horror prompt library — severity-tiered SDXL prompts + XTTS voice texts."""
|
||||
|
||||
import random
|
||||
|
||||
# Negative prompt applied to all SDXL generations.
# Adjacent string literals are concatenated into one comma-separated list.
NEGATIVE_PROMPT = (
    "cheerful, bright, colorful, cartoon, anime, text, watermark, "
    "logo, signature, pleasant, happy, cute, well-lit, clean"
)
|
||||
|
||||
# (min_severity, [prompt_templates])
# Each tier's prompts are available when severity >= min_severity.
# Prompts use {detail} placeholder for procedural variation.
# get_image_prompt() draws from every unlocked tier, weighting each template
# by its tier's 1-based rank, so higher tiers are favoured once unlocked.
SEVERITY_TIERS: list[tuple[float, list[str]]] = [
    (0.0, [
        "dark abstract void, deep shadows, {detail}, horror atmosphere, cinematic",
        "black fog rolling over dark water, {detail}, ominous, photorealistic",
        "concrete corridor stretching into darkness, {detail}, liminal space, unsettling",
        "dark gradient with subtle organic texture, {detail}, dread, macro photography",
        "abandoned room in total darkness, single light source, {detail}, eerie silence",
        "static noise pattern forming almost-shapes, {detail}, analog horror aesthetic",
        "deep underground cavern, no visible exit, {detail}, claustrophobic, dark",
    ]),
    (1.0, [
        "distorted human face emerging from darkness, {detail}, uncanny valley, horror",
        "long dark hallway with a figure at the end, {detail}, found footage aesthetic",
        "room where the walls are slightly wrong, {detail}, liminal horror, photorealistic",
        "mirror reflection that doesn't match, {detail}, psychological horror",
        "staircase descending into impossible depth, {detail}, surreal horror",
        "doorway opening into a void, {detail}, threshold horror, dark atmosphere",
        "familiar room but every proportion is wrong, {detail}, dreamlike horror",
    ]),
    (2.0, [
        "face melting into dark liquid, {detail}, body horror, visceral, photorealistic",
        "impossible architecture folding in on itself, {detail}, Escher nightmare, dark",
        "multiple overlapping faces fused together, {detail}, uncanny, disturbing",
        "corridor with too many doors, all slightly open, {detail}, psychological horror",
        "human figure with wrong number of limbs, {detail}, body horror, dark",
        "room full of eyes watching from every surface, {detail}, paranoid horror",
        "teeth growing from walls, {detail}, organic horror, visceral, photorealistic",
    ]),
    (3.0, [
        "screaming void, flesh merging with architecture, {detail}, extreme body horror",
        "reality fracturing into bleeding shards, {detail}, cosmic horror, overwhelming",
        "mass of tangled human forms, {detail}, hellscape, Beksinski inspired, photorealistic",
        "sky replaced by a massive watching face, {detail}, cosmic dread, surreal",
        "ground made of writhing organic matter, {detail}, Giger inspired, dark",
        "impossible geometry that hurts to perceive, {detail}, Lovecraftian, extreme",
        "world turned inside out, organs as landscape, {detail}, visceral cosmic horror",
    ]),
]
|
||||
|
||||
# Procedural detail fragments inserted into {detail} placeholders.
# get_image_prompt() picks one uniformly at random for each prompt.
_DETAILS = [
    "wet surfaces", "rust and decay", "dim red light", "flickering fluorescent",
    "peeling paint", "fog and mist", "dripping liquid", "cracked surfaces",
    "organic growths", "shadow patterns", "reflected light on water",
    "dust particles", "cobwebs", "stained walls", "scratched metal",
    "condensation", "mold patterns", "burned edges", "frozen in time",
    "overlapping shadows", "single bare bulb", "moonlight through cracks",
    "bioluminescent", "blood-red sky", "green pallor", "bruise-purple tint",
]
|
||||
|
||||
# Whisper fragments — sentence fragments, numbers, names, nonsense.
# get_voice_text() returns one of these uniformly at random.
_WHISPERS = [
    "seven", "behind you", "don't turn around", "it remembers",
    "the door", "counting", "almost time", "in the walls",
    "not alone", "watching", "three two one", "forgetting",
    "underneath", "the wrong room", "teeth", "it knows your name",
    "listen", "the sound", "nobody left", "opening",
    "he's here", "she won't stop", "the children", "below",
    "always here", "never gone", "the dark", "it follows",
    "coming closer", "just outside", "the mirror", "look",
    "run", "too late", "the floor", "above you",
    "inside", "the old house", "breathing", "footsteps",
    "scratching", "dripping", "humming", "whispers",
    "ha ha ha ha", "one two three four five", "again again again",
    "please", "help me", "let me in", "let me out",
]
|
||||
|
||||
# Direct address phrases — "it sees you" moments.
# get_direct_address_text() returns one of these uniformly at random.
_DIRECT_ADDRESS = [
    "you're still here",
    "don't leave",
    "I can see you",
    "why",
    "stay",
    "you came back",
    "I've been waiting",
    "don't close your eyes",
    "you can't leave",
    "I know you're there",
    "look at me",
    "do you hear it",
    "it's behind you",
    "we see you",
    "you're one of us now",
    "don't you remember",
    "you were here before",
    "this is yours",
    "you did this",
    "welcome home",
]
|
||||
|
||||
|
||||
def get_image_prompt(severity: float) -> str:
    """Build a horror prompt suited to ``severity``.

    Every tier whose threshold is at or below ``severity`` contributes its
    templates. If none qualifies, the first tier is used. Each template is
    weighted by the 1-based rank of its tier among the unlocked tiers, so
    the highest unlocked tier is the most likely source. The ``{detail}``
    placeholder is filled with a random detail fragment.
    """
    unlocked: list[tuple[float, str]] = [
        (floor, text)
        for floor, texts in SEVERITY_TIERS
        if severity >= floor
        for text in texts
    ]
    if not unlocked:
        unlocked = [(0.0, text) for text in SEVERITY_TIERS[0][1]]

    # Rank the distinct thresholds in ascending order, starting at 1.
    distinct_floors = sorted({floor for floor, _ in unlocked})
    rank_of = {floor: rank for rank, floor in enumerate(distinct_floors, start=1)}
    tier_weights = [rank_of[floor] for floor, _ in unlocked]

    chosen = random.choices(unlocked, weights=tier_weights, k=1)[0]
    fragment = random.choice(_DETAILS)
    return chosen[1].format(detail=fragment)
|
||||
|
||||
|
||||
def get_voice_text() -> str:
    """Return one whisper fragment, chosen uniformly at random, for voice synthesis."""
    fragment = random.choice(_WHISPERS)
    return fragment
|
||||
|
||||
|
||||
def get_direct_address_text() -> str:
    """Return one direct-address phrase, chosen uniformly at random."""
    phrase = random.choice(_DIRECT_ADDRESS)
    return phrase
|
||||
@@ -0,0 +1,90 @@
|
||||
"""WebSocket broadcast manager for streaming horror to connected clients."""
|
||||
|
||||
import json
|
||||
import random
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
class StreamManager:
|
||||
"""Manages WebSocket clients and broadcasts horror events."""
|
||||
|
||||
TRANSITIONS = ["crossfade", "dissolve", "glitch_cut", "melt_morph"]
|
||||
|
||||
def __init__(self):
|
||||
self._clients: set[WebSocket] = set()
|
||||
|
||||
@property
|
||||
def client_count(self) -> int:
|
||||
return len(self._clients)
|
||||
|
||||
def add_client(self, ws: WebSocket) -> None:
|
||||
self._clients.add(ws)
|
||||
|
||||
def remove_client(self, ws: WebSocket) -> None:
|
||||
self._clients.discard(ws)
|
||||
|
||||
async def _broadcast(self, message: str) -> None:
|
||||
"""Send to all clients, remove dead ones."""
|
||||
dead: list[WebSocket] = []
|
||||
for ws in self._clients:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except Exception:
|
||||
dead.append(ws)
|
||||
for ws in dead:
|
||||
self._clients.discard(ws)
|
||||
|
||||
async def broadcast_phase(self, intensity: float, params: dict) -> None:
|
||||
"""Push a phase update with current intensity and rendering params."""
|
||||
msg = json.dumps({
|
||||
"type": "phase",
|
||||
"intensity": round(intensity, 2),
|
||||
"params": params,
|
||||
})
|
||||
await self._broadcast(msg)
|
||||
|
||||
async def broadcast_asset(
|
||||
self, url: str, severity: float, transition: str | None = None,
|
||||
) -> None:
|
||||
"""Push a new image asset reference."""
|
||||
if transition is None:
|
||||
transition = random.choice(self.TRANSITIONS)
|
||||
msg = json.dumps({
|
||||
"type": "asset",
|
||||
"url": url,
|
||||
"severity": round(severity, 2),
|
||||
"transition": transition,
|
||||
})
|
||||
await self._broadcast(msg)
|
||||
|
||||
async def broadcast_whisper(
|
||||
self, url: str, pan: float, volume: float, reverb: float,
|
||||
) -> None:
|
||||
"""Push a whisper audio clip reference."""
|
||||
msg = json.dumps({
|
||||
"type": "whisper",
|
||||
"url": url,
|
||||
"pan": round(pan, 2),
|
||||
"volume": round(volume, 2),
|
||||
"reverb": round(reverb, 2),
|
||||
})
|
||||
await self._broadcast(msg)
|
||||
|
||||
async def broadcast_address(self, audio_b64: str, text: str) -> None:
|
||||
"""Push a direct address audio clip (base64-encoded WAV)."""
|
||||
msg = json.dumps({
|
||||
"type": "address",
|
||||
"audio": audio_b64,
|
||||
"text": text,
|
||||
})
|
||||
await self._broadcast(msg)
|
||||
|
||||
async def broadcast_scare(self, effect: str, duration_ms: int) -> None:
|
||||
"""Push a scare event (face flash, white-out, inversion, etc.)."""
|
||||
msg = json.dumps({
|
||||
"type": "scare",
|
||||
"effect": effect,
|
||||
"duration_ms": duration_ms,
|
||||
})
|
||||
await self._broadcast(msg)
|
||||
@@ -0,0 +1,73 @@
|
||||
"""XTTS v2 wrapper for voice cloning from non-voice audio samples."""
|
||||
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from TTS.api import TTS
|
||||
except ImportError:
|
||||
TTS = None # Tests patch this; real runtime requires the TTS package
|
||||
|
||||
from server.config import config
|
||||
|
||||
|
||||
class VoiceGenerator:
|
||||
"""Generates speech cloned from arbitrary audio samples via XTTS v2."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: str | None = None,
|
||||
model_name: str | None = None,
|
||||
samples_dir: str | None = None,
|
||||
):
|
||||
self.device = device or config.device
|
||||
self.model_name = model_name or config.models.xtts_model
|
||||
self.samples_dir = Path(samples_dir or config.samples_dir)
|
||||
if TTS is None:
|
||||
raise RuntimeError(
|
||||
"TTS package is not installed; cannot instantiate VoiceGenerator"
|
||||
)
|
||||
self._tts = TTS(model_name=self.model_name)
|
||||
self._tts.to(self.device)
|
||||
|
||||
def generate(self, text: str, speaker_wav: str | None = None) -> bytes:
|
||||
"""Generate speech as WAV bytes. Uses a random clone source if none specified."""
|
||||
if speaker_wav is None:
|
||||
speaker_wav = self.random_clone_source()
|
||||
if speaker_wav is None:
|
||||
raise ValueError("No speaker WAV provided and no samples available")
|
||||
|
||||
# XTTS writes to file, so use a temp file
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
tmp.close()
|
||||
try:
|
||||
self._tts.tts_to_file(
|
||||
text=text,
|
||||
speaker_wav=speaker_wav,
|
||||
language=config.models.xtts_language,
|
||||
file_path=tmp.name,
|
||||
)
|
||||
with open(tmp.name, "rb") as f:
|
||||
return f.read()
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp.name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def list_clone_sources(self) -> list[str]:
|
||||
"""List all WAV files in the samples directory."""
|
||||
if not self.samples_dir.is_dir():
|
||||
return []
|
||||
return [
|
||||
str(p) for p in sorted(self.samples_dir.glob("*.wav"))
|
||||
]
|
||||
|
||||
def random_clone_source(self) -> str | None:
|
||||
"""Pick a random clone source WAV file."""
|
||||
sources = self.list_clone_sources()
|
||||
if not sources:
|
||||
return None
|
||||
return random.choice(sources)
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Shared test fixtures for AI Hell."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from server.config import config
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_assets(tmp_path):
    """Temporary assets directory for pool tests.

    Returns pytest's built-in ``tmp_path`` unchanged, under a more
    descriptive fixture name.
    """
    return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_config():
    """Restore the config singleton to its pre-test state after every test.

    Config is a module-level singleton; tests shouldn't mutate it, but if
    they do, this ensures isolation. The instance's attribute dictionary is
    deep-copied before the test and written back afterwards, so changes to
    nested sections are undone as well.
    """
    # Local import: copy is not part of this module's import block.
    import copy

    saved_state = copy.deepcopy(vars(config))
    yield
    # Restore in place so modules holding a reference to ``config`` see the
    # original values again.
    live_state = vars(config)
    live_state.clear()
    live_state.update(saved_state)
|
||||
@@ -0,0 +1,52 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from PIL import Image
|
||||
|
||||
from server.asset_generator import AssetGenerator
|
||||
|
||||
|
||||
class TestAssetGenerator:
    """AssetGenerator tests with the diffusers pipeline class patched out."""

    @staticmethod
    def _stub_pipeline(mock_pipeline_cls, with_image=False):
        """Wire the patched pipeline class to return a mock pipeline.

        With ``with_image`` set, calling the pipeline yields a result whose
        ``images`` list holds one black 512x512 RGB image.
        """
        pipe = MagicMock()
        mock_pipeline_cls.from_pretrained.return_value = pipe
        pipe.to.return_value = pipe
        if with_image:
            outcome = MagicMock()
            outcome.images = [Image.new("RGB", (512, 512), color="black")]
            pipe.return_value = outcome
        return pipe

    @patch("server.asset_generator.AutoPipelineForText2Image")
    def test_init_loads_model(self, mock_pipeline_cls):
        """Constructing the generator calls ``from_pretrained`` exactly once."""
        self._stub_pipeline(mock_pipeline_cls)

        AssetGenerator(device="cpu")
        mock_pipeline_cls.from_pretrained.assert_called_once()

    @patch("server.asset_generator.AutoPipelineForText2Image")
    def test_generate_returns_bytes(self, mock_pipeline_cls):
        """``generate`` returns a non-empty bytes object."""
        self._stub_pipeline(mock_pipeline_cls, with_image=True)

        payload = AssetGenerator(device="cpu").generate("dark void, horror")
        assert isinstance(payload, bytes)
        assert len(payload) > 0

    @patch("server.asset_generator.AutoPipelineForText2Image")
    def test_generate_uses_negative_prompt(self, mock_pipeline_cls):
        """``generate`` passes a ``negative_prompt`` keyword to the pipeline."""
        pipe = self._stub_pipeline(mock_pipeline_cls, with_image=True)

        AssetGenerator(device="cpu").generate("test prompt")

        recorded = pipe.call_args
        assert "negative_prompt" in recorded.kwargs
|
||||
@@ -0,0 +1,112 @@
|
||||
from pathlib import Path
|
||||
|
||||
from server.asset_pool import AssetPool
|
||||
|
||||
|
||||
def _make_pool(tmp_path: Path, max_images: int = 10, max_audio: int = 5) -> AssetPool:
    """Build an AssetPool rooted at ``tmp_path`` with small default capacities."""
    pool_kwargs = {
        "base_dir": str(tmp_path),
        "max_images": max_images,
        "max_audio": max_audio,
    }
    return AssetPool(**pool_kwargs)
|
||||
|
||||
|
||||
def _fake_image(tmp_path: Path, pool: AssetPool, severity: float) -> str:
    """Add placeholder image bytes to ``pool`` and return the resulting URL.

    ``tmp_path`` is accepted for call-site symmetry and is not used.
    """
    placeholder = b"fake png data"
    image_url = pool.add_image(placeholder, severity=severity)
    return image_url
|
||||
|
||||
|
||||
def _fake_audio(tmp_path: Path, pool: AssetPool, severity: float) -> str:
    """Add placeholder audio bytes to ``pool`` and return the resulting URL.

    ``tmp_path`` is accepted for call-site symmetry and is not used.
    """
    placeholder = b"fake wav data"
    clip_url = pool.add_audio(placeholder, severity=severity)
    return clip_url
|
||||
|
||||
|
||||
class TestAssetPoolInit:
    """Construction-time behaviour of AssetPool."""

    def test_creates_directories(self, tmp_path):
        """Pool creates img/ and audio/ subdirectories."""
        # Only the constructor's side effect matters here, so the instance
        # is deliberately not bound to a name.
        _make_pool(tmp_path)
        assert (tmp_path / "img").is_dir()
        assert (tmp_path / "audio").is_dir()

    def test_empty_pool(self, tmp_path):
        """New pool has no assets."""
        pool = _make_pool(tmp_path)
        assert pool.image_count == 0
        assert pool.audio_count == 0
|
||||
|
||||
|
||||
class TestAddAssets:
    """Adding images and audio clips to an AssetPool."""

    def test_add_image(self, tmp_path):
        """One added image gives image_count == 1 and an /assets/img/*.png URL."""
        assets = _make_pool(tmp_path)
        image_url = _fake_image(tmp_path, assets, severity=1.0)
        assert assets.image_count == 1
        assert image_url.startswith("/assets/img/")
        assert image_url.endswith(".png")

    def test_add_audio(self, tmp_path):
        """One added clip gives audio_count == 1 and an /assets/audio/*.wav URL."""
        assets = _make_pool(tmp_path)
        audio_url = _fake_audio(tmp_path, assets, severity=1.0)
        assert assets.audio_count == 1
        assert audio_url.startswith("/assets/audio/")
        assert audio_url.endswith(".wav")

    def test_file_exists_on_disk(self, tmp_path):
        """The last URL segment names a file that exists under img/."""
        assets = _make_pool(tmp_path)
        image_url = _fake_image(tmp_path, assets, severity=1.0)
        stored_name = image_url.rsplit("/", 1)[-1]
        stored_path = tmp_path / "img" / stored_name
        assert stored_path.exists()
|
||||
|
||||
|
||||
class TestSelectAssets:
    """Selecting assets from the pool for a target severity."""

    def test_select_image_by_severity(self, tmp_path):
        """With three images stored, select_image returns something (not None)."""
        assets = _make_pool(tmp_path)
        for level in (0.5, 2.0, 4.0):
            _fake_image(tmp_path, assets, severity=level)
        chosen = assets.select_image(target_severity=1.8)
        assert chosen is not None

    def test_select_audio_by_severity(self, tmp_path):
        """With two clips stored, select_audio returns something (not None)."""
        assets = _make_pool(tmp_path)
        for level in (0.5, 3.0):
            _fake_audio(tmp_path, assets, severity=level)
        chosen = assets.select_audio(target_severity=2.5)
        assert chosen is not None

    def test_select_from_empty_returns_none(self, tmp_path):
        """Both selectors return None when nothing has been added."""
        assets = _make_pool(tmp_path)
        image_choice = assets.select_image(target_severity=1.0)
        assert image_choice is None
        audio_choice = assets.select_audio(target_severity=1.0)
        assert audio_choice is None
|
||||
|
||||
|
||||
class TestRotation:
    """Pool size stays capped once more assets than the maximum are added."""

    def test_image_rotation(self, tmp_path):
        """Five images into a pool capped at three leaves image_count == 3."""
        assets = _make_pool(tmp_path, max_images=3)
        for level in map(float, range(5)):
            _fake_image(tmp_path, assets, severity=level)
        assert assets.image_count == 3

    def test_audio_rotation(self, tmp_path):
        """Four clips into a pool capped at two leaves audio_count == 2."""
        assets = _make_pool(tmp_path, max_audio=2)
        for level in map(float, range(4)):
            _fake_audio(tmp_path, assets, severity=level)
        assert assets.audio_count == 2
|
||||
|
||||
|
||||
class TestStatus:
    """Shape of the dictionary returned by AssetPool.get_status()."""

    def test_status_dict(self, tmp_path):
        """After one image and one clip, both reported pool sizes are 1."""
        assets = _make_pool(tmp_path)
        _fake_image(tmp_path, assets, severity=1.0)
        _fake_audio(tmp_path, assets, severity=1.0)
        report = assets.get_status()
        assert report["image_pool_size"] == 1
        assert report["audio_pool_size"] == 1
|
||||
@@ -0,0 +1,28 @@
|
||||
from server.config import config, EscalationConfig, ModelConfig, AppConfig
|
||||
|
||||
|
||||
def test_config_defaults():
    """The module-level config exposes the expected port, host and device."""
    expected = (("port", 8400), ("host", "0.0.0.0"), ("device", "cuda"))
    for field, value in expected:
        assert getattr(config, field) == value
|
||||
|
||||
|
||||
def test_escalation_defaults():
    """Escalation settings carry the expected rate, batch size and pool limits."""
    escalation = config.escalation
    assert escalation.rate == 0.05
    assert escalation.initial_batch_size == 40
    assert escalation.max_images == 200
    assert escalation.max_audio_clips == 50
|
||||
|
||||
|
||||
def test_model_defaults():
    """Model identifiers contain the expected substrings."""
    models = config.models
    assert "sdxl-turbo" in models.sdxl_model_id
    assert "xtts" in models.xtts_model
|
||||
|
||||
|
||||
def test_timing_defaults():
    """Swap-interval bounds are ordered and the mean voice interval is positive."""
    escalation = config.escalation
    assert escalation.asset_swap_min < escalation.asset_swap_max
    assert escalation.voice_mean_interval > 0
|
||||
@@ -0,0 +1,116 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
from server.escalation import EscalationEngine
|
||||
|
||||
|
||||
class TestIntensityCurve:
    """Shape of EscalationEngine.get_intensity as a function of elapsed time."""

    def test_intensity_at_zero(self):
        """Intensity is exactly 0.0 at elapsed == 0."""
        curve = EscalationEngine(rate=0.05)
        assert curve.get_intensity(elapsed=0.0) == 0.0

    def test_intensity_at_40s(self):
        """At 40 s with rate 0.05, intensity is within 0.01 of ln(1 + 40 * 0.05)."""
        curve = EscalationEngine(rate=0.05)
        observed = curve.get_intensity(elapsed=40.0)
        expected = math.log(1 + 40 * 0.05)
        assert abs(observed - expected) < 0.01

    def test_intensity_monotonic(self):
        """Each sampled time yields a strictly larger intensity than the previous one."""
        curve = EscalationEngine(rate=0.05)
        floor = 0.0
        for seconds in (10, 30, 60, 120, 300, 600, 1800):
            current = curve.get_intensity(elapsed=float(seconds))
            assert current > floor
            floor = current

    def test_intensity_never_peaks(self):
        """Intensity at 60 minutes is still greater than at 30 minutes."""
        curve = EscalationEngine(rate=0.05)
        half_hour = curve.get_intensity(elapsed=1800.0)
        full_hour = curve.get_intensity(elapsed=3600.0)
        assert full_hour > half_hour
|
||||
|
||||
|
||||
class TestPhaseParams:
    """Values produced by EscalationEngine.get_phase_params."""

    def test_low_intensity_params(self):
        """At intensity 0.5, morph_speed and shader_severity are both below 0.3."""
        phase = EscalationEngine(rate=0.05).get_phase_params(intensity=0.5)
        assert phase["morph_speed"] < 0.3
        assert phase["shader_severity"] < 0.3

    def test_high_intensity_params(self):
        """At intensity 4.0, morph_speed exceeds 0.6 and shader_severity exceeds 0.7."""
        phase = EscalationEngine(rate=0.05).get_phase_params(intensity=4.0)
        assert phase["morph_speed"] > 0.6
        assert phase["shader_severity"] > 0.7

    def test_params_are_clamped(self):
        """At intensity 100.0, the four checked params remain within [0, 1]."""
        phase = EscalationEngine(rate=0.05).get_phase_params(intensity=100.0)
        checked = ("morph_speed", "shader_severity", "voice_frequency", "noise_level")
        for name in checked:
            assert 0.0 <= phase[name] <= 1.0
|
||||
|
||||
|
||||
class TestAssetSwapInterval:
    """Intervals returned by EscalationEngine.get_asset_swap_interval."""

    def test_low_intensity_slow_swaps(self):
        """At intensity 0.5 the interval is at least 5.0."""
        scheduler = EscalationEngine(rate=0.05)
        gap = scheduler.get_asset_swap_interval(intensity=0.5)
        assert gap >= 5.0

    def test_high_intensity_fast_swaps(self):
        """At intensity 5.0 the interval is at most 4.0."""
        scheduler = EscalationEngine(rate=0.05)
        gap = scheduler.get_asset_swap_interval(intensity=5.0)
        assert gap <= 4.0

    def test_interval_has_randomness(self):
        """Twenty draws at intensity 2.0 give more than one distinct value (2 d.p.)."""
        scheduler = EscalationEngine(rate=0.05)
        draws = []
        for _ in range(20):
            draws.append(scheduler.get_asset_swap_interval(intensity=2.0))
        distinct = {round(value, 2) for value in draws}
        assert len(distinct) > 1
|
||||
|
||||
|
||||
class TestVoiceInterval:
    """Average spacing returned by EscalationEngine.get_voice_interval."""

    @staticmethod
    def _average_interval(intensity):
        """Mean of 50 get_voice_interval draws from a fresh engine (rate 0.05)."""
        scheduler = EscalationEngine(rate=0.05)
        draws = [scheduler.get_voice_interval(intensity=intensity) for _ in range(50)]
        return sum(draws) / len(draws)

    def test_low_intensity_rare_voices(self):
        """At intensity 0.5 the sample mean exceeds 20.0."""
        # Averaged over many draws because the individual intervals vary per call.
        assert self._average_interval(0.5) > 20.0  # Mean should be around 40

    def test_high_intensity_frequent_voices(self):
        """At intensity 5.0 the sample mean is below 20.0."""
        assert self._average_interval(5.0) < 20.0
|
||||
|
||||
|
||||
class TestSessionTiming:
    """Session start, elapsed-time and reset behaviour of EscalationEngine."""

    def test_start_sets_time(self):
        """After start_session(), session_start is no longer None."""
        session = EscalationEngine(rate=0.05)
        session.start_session()
        assert session.session_start is not None

    def test_elapsed_time(self):
        """get_elapsed() is non-negative right after the session starts."""
        session = EscalationEngine(rate=0.05)
        session.start_session()
        assert session.get_elapsed() >= 0.0

    def test_reset_clears_session(self):
        """After reset(), get_elapsed() reports less than 0.1."""
        session = EscalationEngine(rate=0.05)
        session.start_session()
        # Let a little time pass so the pre-reset elapsed value is non-trivial.
        time.sleep(0.01)
        session.reset()
        assert session.get_elapsed() < 0.1
|
||||
@@ -0,0 +1,32 @@
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from server.main import create_app
|
||||
|
||||
|
||||
class TestRESTEndpoints:
    """HTTP endpoints of the app built with create_app(skip_models=True)."""

    def test_status_endpoint(self):
        """GET /status answers 200 with intensity, client and image-pool keys."""
        with TestClient(create_app(skip_models=True)) as http:
            response = http.get("/status")
            assert response.status_code == 200
            payload = response.json()
            for key in ("intensity", "connected_clients", "image_pool_size"):
                assert key in payload

    def test_reset_endpoint(self):
        """POST /reset answers 200 with a status of "ok"."""
        with TestClient(create_app(skip_models=True)) as http:
            response = http.post("/reset")
            assert response.status_code == 200
            assert response.json()["status"] == "ok"

    def test_index_serves_html(self):
        """GET / answers 200 with a text/html content type."""
        with TestClient(create_app(skip_models=True)) as http:
            response = http.get("/")
            assert response.status_code == 200
            content_type = response.headers["content-type"]
            assert "text/html" in content_type
|
||||
@@ -0,0 +1,61 @@
|
||||
from server.prompts import (
|
||||
get_image_prompt,
|
||||
get_voice_text,
|
||||
get_direct_address_text,
|
||||
SEVERITY_TIERS,
|
||||
)
|
||||
|
||||
|
||||
class TestImagePrompts:
    """Checks on get_image_prompt output and SEVERITY_TIERS ordering."""

    def test_low_severity_prompt(self):
        """Severity 0.5 yields a string longer than 10 characters."""
        text = get_image_prompt(severity=0.5)
        assert isinstance(text, str)
        assert len(text) > 10

    def test_high_severity_prompt(self):
        """Severity 4.0 yields a string longer than 10 characters."""
        text = get_image_prompt(severity=4.0)
        assert isinstance(text, str)
        assert len(text) > 10

    def test_prompts_vary(self):
        """Ten calls at severity 2.0 produce more than one distinct prompt."""
        distinct = {get_image_prompt(severity=2.0) for _ in range(10)}
        assert len(distinct) > 1

    def test_severity_tiers_ordered(self):
        """The first element of each tier is already in ascending order."""
        cutoffs = [tier[0] for tier in SEVERITY_TIERS]
        assert cutoffs == sorted(cutoffs)

    def test_negative_prompt_included(self):
        """Severity 1.0 yields a string.

        NOTE(review): despite the test name, nothing about a negative prompt
        is asserted here - only the return type is checked.
        """
        text = get_image_prompt(severity=1.0)
        assert isinstance(text, str)
|
||||
|
||||
|
||||
class TestVoiceTexts:
    """Checks on get_voice_text output."""

    def test_whisper_text(self):
        """A single call returns a non-empty string."""
        fragment = get_voice_text()
        assert isinstance(fragment, str)
        assert len(fragment) > 0

    def test_whisper_texts_vary(self):
        """Twenty calls produce more than one distinct text."""
        distinct = {get_voice_text() for _ in range(20)}
        assert len(distinct) > 1
|
||||
|
||||
|
||||
class TestDirectAddress:
    """Checks on get_direct_address_text output."""

    def test_direct_address_text(self):
        """A single call returns a non-empty string."""
        phrase = get_direct_address_text()
        assert isinstance(phrase, str)
        assert len(phrase) > 0

    def test_direct_address_texts_vary(self):
        """Twenty calls produce more than one distinct phrase."""
        distinct = {get_direct_address_text() for _ in range(20)}
        assert len(distinct) > 1
|
||||
@@ -0,0 +1,99 @@
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from server.streaming import StreamManager
|
||||
|
||||
|
||||
@pytest.fixture
def manager():
    """Provide a newly constructed StreamManager."""
    stream_manager = StreamManager()
    return stream_manager
|
||||
|
||||
|
||||
@pytest.fixture
def mock_ws():
    """Provide an AsyncMock websocket whose ``send_text`` is its own AsyncMock."""
    return AsyncMock(send_text=AsyncMock())
|
||||
|
||||
|
||||
class TestClientManagement:
    """Client bookkeeping on StreamManager (add / remove / count)."""

    def test_add_client(self, manager, mock_ws):
        """One added client gives client_count == 1."""
        manager.add_client(mock_ws)
        registered = manager.client_count
        assert registered == 1

    def test_remove_client(self, manager, mock_ws):
        """Adding then removing the same client gives client_count == 0."""
        manager.add_client(mock_ws)
        manager.remove_client(mock_ws)
        registered = manager.client_count
        assert registered == 0

    def test_remove_missing_client(self, manager, mock_ws):
        """Removing a client that was never added leaves client_count at 0."""
        manager.remove_client(mock_ws)
        registered = manager.client_count
        assert registered == 0
|
||||
|
||||
|
||||
class TestBroadcast:
    """JSON messages that StreamManager sends to clients via send_text."""

    @staticmethod
    def _last_payload(ws):
        """Decode the first positional argument of the most recent send_text call."""
        raw = ws.send_text.call_args.args[0]
        return json.loads(raw)

    @pytest.mark.asyncio
    async def test_broadcast_phase(self, manager, mock_ws):
        """broadcast_phase sends exactly one "phase" message with intensity and params."""
        manager.add_client(mock_ws)
        phase_params = {"morph_speed": 0.35, "shader_severity": 0.6, "palette": "crimson_void"}
        await manager.broadcast_phase(intensity=2.4, params=phase_params)
        mock_ws.send_text.assert_called_once()
        payload = self._last_payload(mock_ws)
        assert payload["type"] == "phase"
        assert payload["intensity"] == 2.4
        assert payload["params"]["morph_speed"] == 0.35

    @pytest.mark.asyncio
    async def test_broadcast_asset(self, manager, mock_ws):
        """broadcast_asset sends an "asset" message carrying the given URL."""
        manager.add_client(mock_ws)
        await manager.broadcast_asset(url="/assets/img/abc.png", severity=1.8, transition="melt")
        payload = self._last_payload(mock_ws)
        assert payload["type"] == "asset"
        assert payload["url"] == "/assets/img/abc.png"

    @pytest.mark.asyncio
    async def test_broadcast_whisper(self, manager, mock_ws):
        """broadcast_whisper sends a message whose type is "whisper"."""
        manager.add_client(mock_ws)
        await manager.broadcast_whisper(url="/assets/audio/w01.wav", pan=-0.3, volume=0.4, reverb=0.7)
        payload = self._last_payload(mock_ws)
        assert payload["type"] == "whisper"

    @pytest.mark.asyncio
    async def test_broadcast_scare(self, manager, mock_ws):
        """broadcast_scare sends a "scare" message carrying the given effect."""
        manager.add_client(mock_ws)
        await manager.broadcast_scare(effect="face_flash", duration_ms=150)
        payload = self._last_payload(mock_ws)
        assert payload["type"] == "scare"
        assert payload["effect"] == "face_flash"

    @pytest.mark.asyncio
    async def test_dead_client_cleanup(self, manager):
        """A client whose send_text raises is gone from the count after a broadcast."""
        failing_send = AsyncMock(side_effect=Exception("connection closed"))
        broken_ws = AsyncMock()
        broken_ws.send_text = failing_send
        manager.add_client(broken_ws)
        assert manager.client_count == 1

        await manager.broadcast_phase(intensity=1.0, params={})
        assert manager.client_count == 0
|
||||
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import tempfile
|
||||
import wave
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from server.voice_generator import VoiceGenerator
|
||||
|
||||
|
||||
class TestVoiceGenerator:
    """Tests for VoiceGenerator with the TTS class patched out."""

    @patch("server.voice_generator.TTS")
    def test_init_loads_model(self, mock_tts_cls):
        """Constructing a generator instantiates the (patched) TTS class exactly once."""
        mock_tts_cls.return_value = MagicMock()

        # Only the constructor's side effect is under test; the instance is not needed.
        VoiceGenerator(device="cpu")
        mock_tts_cls.assert_called_once()

    @patch("server.voice_generator.TTS")
    def test_generate_returns_wav_bytes(self, mock_tts_cls):
        """Generate returns WAV bytes."""
        mock_tts = MagicMock()
        mock_tts_cls.return_value = mock_tts

        # Reserve a temp path first; everything that follows sits inside the
        # try/finally so the file is removed even if writing the WAV fails.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            tmp_wav = f.name

        try:
            # Create a real WAV file for the mock to "produce"
            with wave.open(tmp_wav, "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)
                wf.setframerate(22050)
                wf.writeframes(b"\x00\x00" * 22050)  # 1 second of silence

            # Mock tts_to_file to copy our test WAV
            def fake_tts_to_file(text, speaker_wav, language, file_path):
                import shutil

                shutil.copy2(tmp_wav, file_path)

            mock_tts.tts_to_file = fake_tts_to_file

            gen = VoiceGenerator(device="cpu")
            data = gen.generate("hello", speaker_wav=tmp_wav)
            assert isinstance(data, bytes)
            assert len(data) > 0
        finally:
            os.unlink(tmp_wav)

    @patch("server.voice_generator.TTS")
    def test_list_clone_sources(self, mock_tts_cls):
        """Lists available clone source files."""
        mock_tts_cls.return_value = MagicMock()

        with tempfile.TemporaryDirectory() as samples_dir:
            # Create some fake sample files
            for name in ["dog.wav", "machine.wav", "wind.wav"]:
                with open(os.path.join(samples_dir, name), "wb") as f:
                    f.write(b"fake")

            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            sources = gen.list_clone_sources()
            assert len(sources) == 3
            assert all(s.endswith(".wav") for s in sources)

    @patch("server.voice_generator.TTS")
    def test_random_clone_source(self, mock_tts_cls):
        """Picks a random clone source from samples directory."""
        mock_tts_cls.return_value = MagicMock()

        with tempfile.TemporaryDirectory() as samples_dir:
            for name in ["a.wav", "b.wav", "c.wav"]:
                with open(os.path.join(samples_dir, name), "wb") as f:
                    f.write(b"fake")

            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            source = gen.random_clone_source()
            assert source is not None
            assert source.endswith(".wav")

    @patch("server.voice_generator.TTS")
    def test_empty_samples_dir(self, mock_tts_cls):
        """Empty samples dir returns None for random source."""
        mock_tts_cls.return_value = MagicMock()

        with tempfile.TemporaryDirectory() as samples_dir:
            gen = VoiceGenerator(device="cpu", samples_dir=samples_dir)
            assert gen.random_clone_source() is None
|
||||
Reference in New Issue
Block a user