879 lines
33 KiB
Python
879 lines
33 KiB
Python
"""
|
|
VibeVoice WebRTC Bot - Text-to-Speech WebRTC Agent
|
|
|
|
A WebRTC bot that converts incoming text messages to speech using VibeVoice
|
|
and streams the generated audio in real-time.
|
|
"""
|
|
|
|
import secrets
|
|
import numpy as np
|
|
import cv2
|
|
import fractions
|
|
import time
|
|
import threading
|
|
from queue import Queue, Empty
|
|
from typing import Awaitable, Callable, Dict, Optional, Tuple, Any, Union
|
|
import torch
|
|
import librosa
|
|
import soundfile as sf
|
|
from av.audio.frame import AudioFrame
|
|
from av import VideoFrame
|
|
from aiortc import MediaStreamTrack
|
|
|
|
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
|
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
|
from vibevoice.modular.streamer import AudioStreamer
|
|
from transformers.utils import logging
|
|
|
|
from shared.logger import logger
|
|
from shared.models import ChatMessageModel
|
|
|
|
# Configure logging
|
|
logging.set_verbosity_info()
|
|
tts_logger = logging.get_logger(__name__)
|
|
|
|
# Global registry to store active tracks by session
|
|
_active_tracks: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
|
class MediaClock:
|
|
"""Shared clock for media synchronization."""
|
|
|
|
def __init__(self):
|
|
self.t0 = time.perf_counter()
|
|
|
|
def now(self) -> float:
|
|
return time.perf_counter() - self.t0
|
|
|
|
|
|
class VibeVoiceWebRTCProcessor:
|
|
"""VibeVoice processor adapted for WebRTC streaming."""
|
|
|
|
def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
|
|
"""Initialize the VibeVoice processor for WebRTC."""
|
|
self.model_path = model_path
|
|
self.device = device
|
|
self.inference_steps = inference_steps
|
|
self.is_loaded = False
|
|
self.audio_queue = Queue()
|
|
self.is_generating = False
|
|
self.load_model()
|
|
self.setup_voice_presets()
|
|
|
|
def load_model(self):
|
|
"""Load the VibeVoice model and processor."""
|
|
try:
|
|
logger.info(f"Loading VibeVoice model from {self.model_path}")
|
|
|
|
# Normalize device
|
|
if self.device.lower() == "mpx":
|
|
self.device = "mps"
|
|
if self.device == "mps" and not torch.backends.mps.is_available():
|
|
logger.warning("MPS not available. Falling back to CPU.")
|
|
self.device = "cpu"
|
|
|
|
# Load processor
|
|
self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
|
|
|
|
# Determine dtype and attention implementation
|
|
if self.device == "mps":
|
|
load_dtype = torch.float32
|
|
attn_impl = "sdpa"
|
|
elif self.device == "cuda":
|
|
load_dtype = torch.bfloat16
|
|
attn_impl = "flash_attention_2"
|
|
else:
|
|
load_dtype = torch.float32
|
|
attn_impl = "sdpa"
|
|
|
|
logger.info(f"Using device: {self.device}, dtype: {load_dtype}, attention: {attn_impl}")
|
|
|
|
# Load model with fallback
|
|
try:
|
|
if self.device == "mps":
|
|
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
|
self.model_path,
|
|
torch_dtype=load_dtype,
|
|
attn_implementation=attn_impl,
|
|
device_map=None,
|
|
)
|
|
self.model.to("mps")
|
|
elif self.device == "cuda":
|
|
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
|
self.model_path,
|
|
torch_dtype=load_dtype,
|
|
device_map="cuda",
|
|
attn_implementation=attn_impl,
|
|
)
|
|
else:
|
|
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
|
self.model_path,
|
|
torch_dtype=load_dtype,
|
|
device_map="cpu",
|
|
attn_implementation=attn_impl,
|
|
)
|
|
except Exception as e:
|
|
if attn_impl == 'flash_attention_2':
|
|
logger.warning(f"Flash attention failed: {e}, falling back to SDPA")
|
|
fallback_attn = "sdpa"
|
|
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
|
self.model_path,
|
|
torch_dtype=load_dtype,
|
|
device_map=(self.device if self.device in ("cuda", "cpu") else None),
|
|
attn_implementation=fallback_attn,
|
|
)
|
|
if self.device == "mps":
|
|
self.model.to("mps")
|
|
else:
|
|
raise e
|
|
|
|
self.model.eval()
|
|
|
|
# Configure noise scheduler
|
|
self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
|
|
self.model.model.noise_scheduler.config,
|
|
algorithm_type='sde-dpmsolver++',
|
|
beta_schedule='squaredcos_cap_v2'
|
|
)
|
|
self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
|
|
|
|
self.is_loaded = True
|
|
logger.info("VibeVoice model loaded successfully")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load VibeVoice model: {e}")
|
|
self.is_loaded = False
|
|
|
|
def setup_voice_presets(self):
|
|
"""Setup voice presets by scanning the voices directory."""
|
|
import os
|
|
voices_dir = os.path.join(os.path.dirname(__file__), "voices")
|
|
|
|
self.voice_presets = {}
|
|
self.available_voices = {}
|
|
|
|
if not os.path.exists(voices_dir):
|
|
logger.warning(f"Voices directory not found at {voices_dir}")
|
|
# Create default fallback voices
|
|
self.available_voices = {
|
|
'default_female': None,
|
|
'default_male': None
|
|
}
|
|
return
|
|
|
|
# Scan for audio files
|
|
audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')
|
|
for filename in os.listdir(voices_dir):
|
|
if filename.lower().endswith(audio_extensions):
|
|
name = os.path.splitext(filename)[0]
|
|
full_path = os.path.join(voices_dir, filename)
|
|
if os.path.isfile(full_path):
|
|
self.voice_presets[name] = full_path
|
|
|
|
self.available_voices = {
|
|
name: path for name, path in self.voice_presets.items()
|
|
if os.path.exists(path)
|
|
}
|
|
|
|
if self.available_voices:
|
|
logger.info(f"Found {len(self.available_voices)} voice presets: {list(self.available_voices.keys())}")
|
|
else:
|
|
logger.warning("No voice presets found")
|
|
|
|
def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
|
|
"""Read and preprocess audio file."""
|
|
try:
|
|
if audio_path is None:
|
|
# Return a simple sine wave as fallback
|
|
duration = 1.0 # 1 second
|
|
t = np.linspace(0, duration, int(target_sr * duration))
|
|
return np.sin(2 * np.pi * 440 * t).astype(np.float32)
|
|
|
|
wav, sr = sf.read(audio_path)
|
|
if len(wav.shape) > 1:
|
|
wav = np.mean(wav, axis=1)
|
|
if sr != target_sr:
|
|
wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
|
|
return wav.astype(np.float32)
|
|
except Exception as e:
|
|
logger.error(f"Error reading audio {audio_path}: {e}")
|
|
# Return silence as fallback
|
|
return np.zeros(target_sr, dtype=np.float32)
|
|
|
|
async def generate_speech_async(self, text: str, voice_name: str = None, cfg_scale: float = 1.3):
|
|
"""Generate speech asynchronously and put audio chunks in queue."""
|
|
if not self.is_loaded:
|
|
logger.error("Model not loaded, cannot generate speech")
|
|
return
|
|
|
|
if self.is_generating:
|
|
logger.warning("Already generating speech, skipping new request")
|
|
return
|
|
|
|
self.is_generating = True
|
|
|
|
try:
|
|
# Select voice
|
|
if voice_name and voice_name in self.available_voices:
|
|
voice_path = self.available_voices[voice_name]
|
|
else:
|
|
# Use first available voice or fallback
|
|
if self.available_voices:
|
|
voice_path = list(self.available_voices.values())[0]
|
|
voice_name = list(self.available_voices.keys())[0]
|
|
else:
|
|
voice_path = None
|
|
voice_name = "default"
|
|
|
|
logger.info(f"Generating speech for text: '{text[:50]}...' using voice: {voice_name}")
|
|
|
|
# Load voice sample
|
|
voice_sample = self.read_audio(voice_path)
|
|
voice_samples = [voice_sample]
|
|
|
|
# Format script
|
|
formatted_script = f"Speaker 0: {text}"
|
|
|
|
# Process input
|
|
inputs = self.processor(
|
|
text=[formatted_script],
|
|
voice_samples=[voice_samples],
|
|
padding=True,
|
|
return_tensors="pt",
|
|
return_attention_mask=True,
|
|
)
|
|
|
|
# Move to device
|
|
target_device = self.device if self.device in ("cuda", "mps") else "cpu"
|
|
for k, v in inputs.items():
|
|
if torch.is_tensor(v):
|
|
inputs[k] = v.to(target_device)
|
|
|
|
# Create audio streamer
|
|
audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
|
|
|
|
# Run generation in thread
|
|
def generate_worker():
|
|
try:
|
|
outputs = self.model.generate(
|
|
**inputs,
|
|
max_new_tokens=None,
|
|
cfg_scale=cfg_scale,
|
|
tokenizer=self.processor.tokenizer,
|
|
generation_config={'do_sample': False},
|
|
audio_streamer=audio_streamer,
|
|
verbose=False,
|
|
refresh_negative=True,
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Error in generation: {e}")
|
|
audio_streamer.end()
|
|
|
|
# Start generation
|
|
generation_thread = threading.Thread(target=generate_worker)
|
|
generation_thread.start()
|
|
|
|
# Process audio chunks
|
|
audio_stream = audio_streamer.get_stream(0)
|
|
|
|
for audio_chunk in audio_stream:
|
|
if torch.is_tensor(audio_chunk):
|
|
if audio_chunk.dtype == torch.bfloat16:
|
|
audio_chunk = audio_chunk.float()
|
|
audio_np = audio_chunk.cpu().numpy().astype(np.float32)
|
|
else:
|
|
audio_np = np.array(audio_chunk, dtype=np.float32)
|
|
|
|
# Ensure 1D
|
|
if len(audio_np.shape) > 1:
|
|
audio_np = audio_np.squeeze()
|
|
|
|
# Put in queue for audio track
|
|
self.audio_queue.put(audio_np)
|
|
|
|
# Wait for generation to complete
|
|
generation_thread.join(timeout=30.0)
|
|
|
|
# Signal end of audio
|
|
self.audio_queue.put(None) # End marker
|
|
|
|
logger.info("Speech generation completed")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error generating speech: {e}")
|
|
finally:
|
|
self.is_generating = False
|
|
|
|
|
|
class VibeVoiceAudioTrack(MediaStreamTrack):
|
|
"""Audio track that streams VibeVoice generated speech."""
|
|
|
|
kind = "audio"
|
|
|
|
def __init__(self, clock: MediaClock, config: Dict[str, Any], tts_processor: VibeVoiceWebRTCProcessor):
|
|
"""Initialize the VibeVoice audio track."""
|
|
super().__init__()
|
|
self.clock = clock
|
|
self.config = config
|
|
self.tts_processor = tts_processor
|
|
self.sample_rate = config.get('sample_rate', 24000) # VibeVoice uses 24kHz
|
|
self.samples_per_frame = config.get('samples_per_frame', 480) # 20ms at 24kHz
|
|
self._samples_generated: int = 0
|
|
|
|
# Audio buffer for TTS audio
|
|
self.audio_buffer = np.array([], dtype=np.float32)
|
|
self.buffer_lock = threading.Lock()
|
|
|
|
# Fallback tone parameters
|
|
self.frequency = config.get('frequency', 440.0)
|
|
self.volume = config.get('volume', 0.1) # Lower default volume
|
|
self.mode = config.get('audio_mode', 'silence') # Default to silence when no TTS
|
|
|
|
def update_config(self, config_updates: Dict[str, Any]) -> bool:
|
|
"""Update the audio track configuration."""
|
|
try:
|
|
self.config.update(config_updates)
|
|
|
|
if 'sample_rate' in config_updates:
|
|
self.sample_rate = config_updates['sample_rate']
|
|
if 'samples_per_frame' in config_updates:
|
|
self.samples_per_frame = config_updates['samples_per_frame']
|
|
if 'frequency' in config_updates:
|
|
self.frequency = config_updates['frequency']
|
|
if 'volume' in config_updates:
|
|
self.volume = config_updates['volume']
|
|
if 'audio_mode' in config_updates:
|
|
self.mode = config_updates['audio_mode']
|
|
|
|
logger.info(f"VibeVoice audio track configuration updated: {config_updates}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error updating VibeVoice audio track configuration: {e}")
|
|
return False
|
|
|
|
def add_tts_audio(self, audio_chunk: np.ndarray):
|
|
"""Add TTS generated audio to the buffer."""
|
|
with self.buffer_lock:
|
|
self.audio_buffer = np.append(self.audio_buffer, audio_chunk)
|
|
|
|
def _consume_audio_buffer(self, num_samples: int) -> np.ndarray:
|
|
"""Consume audio samples from the buffer."""
|
|
with self.buffer_lock:
|
|
if len(self.audio_buffer) >= num_samples:
|
|
samples = self.audio_buffer[:num_samples].copy()
|
|
self.audio_buffer = self.audio_buffer[num_samples:]
|
|
return samples
|
|
elif len(self.audio_buffer) > 0:
|
|
# Return what we have and pad with silence
|
|
samples = self.audio_buffer.copy()
|
|
self.audio_buffer = np.array([], dtype=np.float32)
|
|
padding = np.zeros(num_samples - len(samples), dtype=np.float32)
|
|
return np.concatenate([samples, padding])
|
|
else:
|
|
return np.zeros(num_samples, dtype=np.float32)
|
|
|
|
async def next_timestamp(self) -> Tuple[int, float]:
|
|
pts = self._samples_generated
|
|
time_base = 1 / self.sample_rate
|
|
return pts, time_base
|
|
|
|
async def recv(self) -> AudioFrame:
|
|
pts, time_base = await self.next_timestamp()
|
|
|
|
# Check for TTS audio in queue and add to buffer
|
|
while True:
|
|
try:
|
|
audio_chunk = self.tts_processor.audio_queue.get_nowait()
|
|
if audio_chunk is None:
|
|
# End marker
|
|
break
|
|
self.add_tts_audio(audio_chunk)
|
|
except Empty:
|
|
break
|
|
|
|
# Try to get audio from TTS buffer first
|
|
samples = self._consume_audio_buffer(self.samples_per_frame)
|
|
|
|
# If no TTS audio available, generate fallback audio based on mode
|
|
if np.all(samples == 0) and self.mode != 'silence':
|
|
if self.mode == 'tone':
|
|
samples = self._generate_tone()
|
|
elif self.mode == 'noise':
|
|
samples = self._generate_noise()
|
|
|
|
# Convert to WebRTC format (16-bit stereo)
|
|
# Resample from 24kHz to target sample rate if needed
|
|
if self.sample_rate != 24000:
|
|
samples = librosa.resample(samples, orig_sr=24000, target_sr=self.sample_rate)
|
|
|
|
# Convert to 16-bit
|
|
left = (samples * self.volume * 32767).astype(np.int16)
|
|
right = left.copy()
|
|
|
|
# Interleave channels for stereo
|
|
interleaved = np.empty(len(left) * 2, dtype=np.int16)
|
|
interleaved[0::2] = left
|
|
interleaved[1::2] = right
|
|
|
|
stereo = interleaved.reshape(1, -1)
|
|
|
|
frame = AudioFrame.from_ndarray(stereo, format="s16", layout="stereo")
|
|
frame.sample_rate = self.sample_rate
|
|
frame.pts = pts
|
|
frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
|
|
|
|
self._samples_generated += self.samples_per_frame
|
|
return frame
|
|
|
|
def _generate_tone(self) -> np.ndarray:
|
|
"""Generate sine wave tone as fallback."""
|
|
t = (np.arange(self.samples_per_frame) + self._samples_generated) / self.sample_rate
|
|
return np.sin(2 * np.pi * self.frequency * t).astype(np.float32)
|
|
|
|
def _generate_noise(self) -> np.ndarray:
|
|
"""Generate white noise as fallback."""
|
|
return np.random.uniform(-1, 1, self.samples_per_frame).astype(np.float32)
|
|
|
|
|
|
class ConfigurableVideoTrack(MediaStreamTrack):
|
|
"""Configurable video track with different visualization modes"""
|
|
|
|
kind = "video"
|
|
|
|
def __init__(self, clock: MediaClock, config: Dict[str, Any]):
|
|
"""Initialize the configurable video track."""
|
|
super().__init__()
|
|
self.clock = clock
|
|
self.config = config
|
|
self.width = config.get('width', 320)
|
|
self.height = config.get('height', 240)
|
|
self.fps = config.get('fps', 15)
|
|
self.mode = config.get('visualization', 'ball')
|
|
self.frame_count = 0
|
|
|
|
# Initialize ball attributes
|
|
self.ball_x = self.width // 2
|
|
self.ball_y = self.height // 2
|
|
self.ball_dx = 2
|
|
self.ball_dy = 2
|
|
self.ball_radius = 20
|
|
|
|
def update_config(self, config_updates: Dict[str, Any]) -> bool:
|
|
"""Update the video track configuration dynamically."""
|
|
try:
|
|
old_mode = self.mode
|
|
old_width = self.width
|
|
old_height = self.height
|
|
|
|
self.config.update(config_updates)
|
|
|
|
if 'width' in config_updates:
|
|
self.width = config_updates['width']
|
|
if 'height' in config_updates:
|
|
self.height = config_updates['height']
|
|
if 'fps' in config_updates:
|
|
self.fps = config_updates['fps']
|
|
if 'visualization' in config_updates:
|
|
self.mode = config_updates['visualization']
|
|
|
|
if self.mode != old_mode:
|
|
self._initialize_mode_state()
|
|
logger.info(f"Video mode changed from {old_mode} to {self.mode}")
|
|
|
|
if self.width != old_width or self.height != old_height:
|
|
self._initialize_ball_state()
|
|
|
|
logger.info(f"Video track configuration updated: {config_updates}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error updating video track configuration: {e}")
|
|
return False
|
|
|
|
def _initialize_mode_state(self):
|
|
"""Initialize state specific to the current visualization mode."""
|
|
self._initialize_ball_state()
|
|
|
|
def _initialize_ball_state(self):
|
|
"""Initialize bouncing ball state."""
|
|
self.ball_x = self.width // 2
|
|
self.ball_y = self.height // 2
|
|
self.ball_dx = 2
|
|
self.ball_dy = 2
|
|
self.ball_radius = min(self.width, self.height) * 0.06
|
|
|
|
async def next_timestamp(self) -> Tuple[int, float]:
|
|
pts = int(self.frame_count * (90000 / self.fps))
|
|
time_base = 1 / 90000
|
|
return pts, time_base
|
|
|
|
async def recv(self) -> VideoFrame:
|
|
pts, time_base = await self.next_timestamp()
|
|
|
|
# Create frame based on mode
|
|
if self.mode == 'ball':
|
|
frame_array = self._generate_ball_frame()
|
|
elif self.mode == 'waveform':
|
|
frame_array = self._generate_waveform_frame()
|
|
elif self.mode == 'static':
|
|
frame_array = self._generate_static_frame()
|
|
else:
|
|
frame_array = self._generate_ball_frame()
|
|
|
|
frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
|
|
frame.pts = pts
|
|
frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
|
|
|
|
self.frame_count += 1
|
|
return frame
|
|
|
|
def _generate_ball_frame(self) -> Any:
|
|
"""Generate bouncing ball visualization"""
|
|
frame = np.zeros((self.height, self.width, 3), dtype=np.uint8)
|
|
|
|
# Update ball position
|
|
self.ball_x += self.ball_dx
|
|
self.ball_y += self.ball_dy
|
|
|
|
# Bounce off walls
|
|
if self.ball_x <= self.ball_radius or self.ball_x >= self.width - self.ball_radius:
|
|
self.ball_dx = -self.ball_dx
|
|
if self.ball_y <= self.ball_radius or self.ball_y >= self.height - self.ball_radius:
|
|
self.ball_dy = -self.ball_dy
|
|
|
|
# Draw ball
|
|
cv2.circle(frame, (int(self.ball_x), int(self.ball_y)), int(self.ball_radius), (0, 255, 0), -1)
|
|
|
|
# Add timestamp
|
|
timestamp = f"Frame: {self.frame_count}"
|
|
cv2.putText(frame, timestamp, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
|
|
|
return frame
|
|
|
|
def _generate_waveform_frame(self) -> Any:
|
|
"""Generate waveform visualization"""
|
|
frame = np.zeros((self.height, self.width, 3), dtype=np.uint8)
|
|
|
|
# Generate sine wave
|
|
x = np.linspace(0, 4*np.pi, self.width)
|
|
y = np.sin(x + self.frame_count * 0.1) * self.height // 4 + self.height // 2
|
|
|
|
# Draw waveform
|
|
for i in range(1, len(y)):
|
|
cv2.line(frame, (i-1, int(y[i-1])), (i, int(y[i])), (255, 255, 255), 2)
|
|
|
|
return frame
|
|
|
|
def _generate_static_frame(self) -> Any:
|
|
"""Generate static color frame"""
|
|
color = self.config.get('static_color', (128, 128, 128))
|
|
frame = np.full((self.height, self.width, 3), color, dtype=np.uint8)
|
|
return frame
|
|
|
|
|
|
def create_vibevoice_tracks(session_name: str, config: Optional[Dict[str, Any]] = None) -> Dict[str, MediaStreamTrack]:
|
|
"""
|
|
Create VibeVoice-enabled WebRTC tracks.
|
|
|
|
Args:
|
|
session_name: Name for the session
|
|
config: Configuration dictionary
|
|
|
|
Returns:
|
|
Dictionary containing 'video' and 'audio' tracks
|
|
"""
|
|
if config is None:
|
|
config = {}
|
|
|
|
# Set defaults
|
|
default_config = {
|
|
'visualization': 'ball',
|
|
'audio_mode': 'silence', # Start with silence
|
|
'width': 320,
|
|
'height': 240,
|
|
'fps': 15,
|
|
'sample_rate': 24000, # VibeVoice native sample rate
|
|
'samples_per_frame': 480, # 20ms at 24kHz
|
|
'frequency': 440.0,
|
|
'volume': 0.1,
|
|
'static_color': (128, 128, 128),
|
|
'model_path': 'vibevoice/VibeVoice-1.5B', # Default model path
|
|
'device': 'cuda' if torch.cuda.is_available() else 'cpu',
|
|
'voice_preset': None
|
|
}
|
|
default_config.update(config)
|
|
|
|
# Parse static_color if it's a string
|
|
if isinstance(default_config.get('static_color'), str):
|
|
try:
|
|
color_str = default_config['static_color']
|
|
r, g, b = map(int, color_str.split(','))
|
|
default_config['static_color'] = (r, g, b)
|
|
except (ValueError, TypeError):
|
|
logger.warning(f"Invalid static_color format: {default_config.get('static_color')}, using default")
|
|
default_config['static_color'] = (128, 128, 128)
|
|
|
|
media_clock = MediaClock()
|
|
|
|
# Initialize VibeVoice processor
|
|
tts_processor = VibeVoiceWebRTCProcessor(
|
|
model_path=default_config['model_path'],
|
|
device=default_config['device'],
|
|
inference_steps=5
|
|
)
|
|
|
|
# Create tracks
|
|
video_track = ConfigurableVideoTrack(media_clock, default_config)
|
|
audio_track = VibeVoiceAudioTrack(media_clock, default_config, tts_processor)
|
|
|
|
logger.info(f"Created VibeVoice tracks for {session_name} with config: {default_config}")
|
|
|
|
# Store tracks and processor in global registry
|
|
_active_tracks[session_name] = {
|
|
"video": video_track,
|
|
"audio": audio_track,
|
|
"tts_processor": tts_processor
|
|
}
|
|
|
|
return {"video": video_track, "audio": audio_track}
|
|
|
|
|
|
# Agent descriptor
|
|
AGENT_NAME = "VibeVoice TTS Bot"
|
|
AGENT_DESCRIPTION = "WebRTC bot with VibeVoice text-to-speech capabilities"
|
|
|
|
def agent_info() -> Dict[str, str]:
|
|
"""Return agent metadata for discovery."""
|
|
return {
|
|
"name": AGENT_NAME,
|
|
"description": AGENT_DESCRIPTION,
|
|
"has_media": "true",
|
|
"configurable": "true",
|
|
"has_tts": "true"
|
|
}
|
|
|
|
|
|
def get_config_schema() -> Dict[str, Any]:
|
|
"""Get the configuration schema for the VibeVoice TTS Bot."""
|
|
return {
|
|
"bot_name": AGENT_NAME,
|
|
"version": "1.0",
|
|
"parameters": [
|
|
{
|
|
"name": "visualization",
|
|
"type": "select",
|
|
"label": "Video Visualization Mode",
|
|
"description": "Choose the type of video visualization to display",
|
|
"default_value": "ball",
|
|
"required": True,
|
|
"options": [
|
|
{"value": "ball", "label": "Bouncing Ball"},
|
|
{"value": "waveform", "label": "Sine Wave Animation"},
|
|
{"value": "static", "label": "Static Color Frame"}
|
|
]
|
|
},
|
|
{
|
|
"name": "audio_mode",
|
|
"type": "select",
|
|
"label": "Fallback Audio Mode",
|
|
"description": "Audio mode when no TTS is active",
|
|
"default_value": "silence",
|
|
"required": True,
|
|
"options": [
|
|
{"value": "silence", "label": "Silence"},
|
|
{"value": "tone", "label": "Sine Wave Tone"},
|
|
{"value": "noise", "label": "White Noise"}
|
|
]
|
|
},
|
|
{
|
|
"name": "voice_preset",
|
|
"type": "string",
|
|
"label": "Voice Preset",
|
|
"description": "Name of the voice preset to use for TTS",
|
|
"default_value": "",
|
|
"required": False,
|
|
"max_length": 50
|
|
},
|
|
{
|
|
"name": "device",
|
|
"type": "select",
|
|
"label": "Processing Device",
|
|
"description": "Device to use for TTS processing",
|
|
"default_value": "cuda",
|
|
"required": True,
|
|
"options": [
|
|
{"value": "cuda", "label": "CUDA (GPU)"},
|
|
{"value": "cpu", "label": "CPU"},
|
|
{"value": "mps", "label": "Apple Metal (MPS)"}
|
|
]
|
|
},
|
|
{
|
|
"name": "width",
|
|
"type": "number",
|
|
"label": "Video Width",
|
|
"description": "Width of the video frame in pixels",
|
|
"default_value": 320,
|
|
"required": False,
|
|
"min_value": 160,
|
|
"max_value": 1920,
|
|
"step": 1
|
|
},
|
|
{
|
|
"name": "height",
|
|
"type": "number",
|
|
"label": "Video Height",
|
|
"description": "Height of the video frame in pixels",
|
|
"default_value": 240,
|
|
"required": False,
|
|
"min_value": 120,
|
|
"max_value": 1080,
|
|
"step": 1
|
|
},
|
|
{
|
|
"name": "volume",
|
|
"type": "range",
|
|
"label": "Audio Volume",
|
|
"description": "Volume level (0.0 to 1.0)",
|
|
"default_value": 0.1,
|
|
"required": False,
|
|
"min_value": 0.0,
|
|
"max_value": 1.0,
|
|
"step": 0.1
|
|
}
|
|
],
|
|
"categories": [
|
|
{
|
|
"TTS Settings": ["voice_preset", "device"]
|
|
},
|
|
{
|
|
"Video Settings": ["visualization", "width", "height"]
|
|
},
|
|
{
|
|
"Audio Settings": ["audio_mode", "volume"]
|
|
}
|
|
]
|
|
}
|
|
|
|
|
|
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
|
|
"""Factory wrapper used by the FastAPI service to instantiate tracks for an agent."""
|
|
return create_vibevoice_tracks(session_name)
|
|
|
|
|
|
async def handle_chat_message(
|
|
chat_message: ChatMessageModel,
|
|
send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
|
|
) -> Optional[str]:
|
|
"""Handle chat messages and convert them to speech."""
|
|
try:
|
|
message_text = chat_message.message.strip()
|
|
if not message_text:
|
|
return None
|
|
|
|
logger.info(f"Processing TTS for message: '{message_text[:50]}...'")
|
|
|
|
# Find the TTS processor for this session
|
|
# Note: This assumes the session_name is available somehow
|
|
# You might need to modify this based on how sessions are handled
|
|
session_name = getattr(chat_message, 'session_id', None)
|
|
if not session_name:
|
|
# Try to find any active session (fallback)
|
|
if _active_tracks:
|
|
session_name = list(_active_tracks.keys())[0]
|
|
|
|
if session_name and session_name in _active_tracks:
|
|
tts_processor = _active_tracks[session_name].get("tts_processor")
|
|
if tts_processor:
|
|
# Get voice configuration
|
|
audio_track = _active_tracks[session_name].get("audio")
|
|
voice_preset = audio_track.config.get('voice_preset') if audio_track else None
|
|
|
|
# Generate speech asynchronously
|
|
await tts_processor.generate_speech_async(message_text, voice_preset)
|
|
|
|
# Send confirmation message
|
|
response_message = ChatMessageModel(
|
|
id=secrets.token_hex(16),
|
|
message=f"🔊 Converting to speech: '{message_text[:30]}{'...' if len(message_text) > 30 else ''}'",
|
|
sender_name="VibeVoice Bot",
|
|
sender_session_id=session_name,
|
|
lobby_id=chat_message.lobby_id,
|
|
timestamp=time.time()
|
|
)
|
|
await send_message_func(response_message)
|
|
else:
|
|
logger.error(f"No TTS processor found for session {session_name}")
|
|
else:
|
|
logger.error("No active session found for TTS processing")
|
|
|
|
return None
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing chat message for TTS: {e}")
|
|
return None
|
|
|
|
|
|
def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
|
|
"""Handle runtime configuration updates for the VibeVoice bot."""
|
|
try:
|
|
validated_config = {}
|
|
|
|
# Parse and validate static_color if provided
|
|
if 'static_color' in config_values:
|
|
color_value = config_values['static_color']
|
|
if isinstance(color_value, str):
|
|
try:
|
|
r, g, b = map(int, color_value.split(','))
|
|
validated_config['static_color'] = (r, g, b)
|
|
except (ValueError, TypeError):
|
|
logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
|
|
return False
|
|
elif isinstance(color_value, (tuple, list)) and len(color_value) == 3:
|
|
validated_config['static_color'] = tuple(color_value)
|
|
else:
|
|
logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update")
|
|
return False
|
|
|
|
# Copy other valid configuration values
|
|
valid_keys = {'visualization', 'audio_mode', 'width', 'height', 'fps',
|
|
'sample_rate', 'frequency', 'volume', 'voice_preset', 'device'}
|
|
for key in valid_keys:
|
|
if key in config_values:
|
|
validated_config[key] = config_values[key]
|
|
|
|
if validated_config:
|
|
logger.info(f"Configuration updated for {lobby_id}: {validated_config}")
|
|
|
|
# Update running tracks if they exist in the registry
|
|
if lobby_id in _active_tracks:
|
|
tracks = _active_tracks[lobby_id]
|
|
video_track = tracks.get("video")
|
|
audio_track = tracks.get("audio")
|
|
|
|
# Update video track configuration
|
|
if video_track and hasattr(video_track, 'update_config'):
|
|
video_updated = video_track.update_config(validated_config)
|
|
if video_updated:
|
|
logger.info(f"Video track configuration updated for {lobby_id}")
|
|
else:
|
|
logger.warning(f"Failed to update video track configuration for {lobby_id}")
|
|
|
|
# Update audio track configuration
|
|
if audio_track and hasattr(audio_track, 'update_config'):
|
|
audio_updated = audio_track.update_config(validated_config)
|
|
if audio_updated:
|
|
logger.info(f"Audio track configuration updated for {lobby_id}")
|
|
else:
|
|
logger.warning(f"Failed to update audio track configuration for {lobby_id}")
|
|
|
|
return True
|
|
else:
|
|
logger.warning(f"No active tracks found for session {lobby_id}")
|
|
return False
|
|
else:
|
|
logger.warning(f"No valid configuration values provided for {lobby_id}")
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error updating configuration for {lobby_id}: {e}")
|
|
return False |