# ai-voicebot/voicebot/bots/vibevoice.py
"""
VibeVoice WebRTC Bot - Text-to-Speech WebRTC Agent
A WebRTC bot that converts incoming text messages to speech using VibeVoice
and streams the generated audio in real-time.
"""
import secrets
import numpy as np
import cv2
import fractions
import time
import threading
from queue import Queue, Empty
from typing import Awaitable, Callable, Dict, Optional, Tuple, Any, Union
import torch
import librosa
import soundfile as sf
from av.audio.frame import AudioFrame
from av import VideoFrame
from aiortc import MediaStreamTrack
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from vibevoice.modular.streamer import AudioStreamer
from transformers.utils import logging
from shared.logger import logger
from shared.models import ChatMessageModel
# Configure logging
logging.set_verbosity_info()
tts_logger = logging.get_logger(__name__)
# Global registry to store active tracks by session
_active_tracks: Dict[str, Dict[str, Any]] = {}
class MediaClock:
"""Shared clock for media synchronization."""
def __init__(self):
self.t0 = time.perf_counter()
def now(self) -> float:
return time.perf_counter() - self.t0
class VibeVoiceWebRTCProcessor:
"""VibeVoice processor adapted for WebRTC streaming."""
def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
"""Initialize the VibeVoice processor for WebRTC."""
self.model_path = model_path
self.device = device
self.inference_steps = inference_steps
self.is_loaded = False
self.audio_queue = Queue()
self.is_generating = False
self.load_model()
self.setup_voice_presets()
def load_model(self):
"""Load the VibeVoice model and processor."""
try:
logger.info(f"Loading VibeVoice model from {self.model_path}")
# Normalize device
if self.device.lower() == "mpx":
self.device = "mps"
if self.device == "mps" and not torch.backends.mps.is_available():
logger.warning("MPS not available. Falling back to CPU.")
self.device = "cpu"
# Load processor
self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
# Determine dtype and attention implementation
if self.device == "mps":
load_dtype = torch.float32
attn_impl = "sdpa"
elif self.device == "cuda":
load_dtype = torch.bfloat16
attn_impl = "flash_attention_2"
else:
load_dtype = torch.float32
attn_impl = "sdpa"
logger.info(f"Using device: {self.device}, dtype: {load_dtype}, attention: {attn_impl}")
# Load model with fallback
try:
if self.device == "mps":
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=load_dtype,
attn_implementation=attn_impl,
device_map=None,
)
self.model.to("mps")
elif self.device == "cuda":
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=load_dtype,
device_map="cuda",
attn_implementation=attn_impl,
)
else:
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=load_dtype,
device_map="cpu",
attn_implementation=attn_impl,
)
except Exception as e:
if attn_impl == 'flash_attention_2':
logger.warning(f"Flash attention failed: {e}, falling back to SDPA")
fallback_attn = "sdpa"
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=load_dtype,
device_map=(self.device if self.device in ("cuda", "cpu") else None),
attn_implementation=fallback_attn,
)
if self.device == "mps":
self.model.to("mps")
else:
raise e
self.model.eval()
# Configure noise scheduler
self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
self.model.model.noise_scheduler.config,
algorithm_type='sde-dpmsolver++',
beta_schedule='squaredcos_cap_v2'
)
self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
self.is_loaded = True
logger.info("VibeVoice model loaded successfully")
except Exception as e:
logger.error(f"Failed to load VibeVoice model: {e}")
self.is_loaded = False
def setup_voice_presets(self):
"""Setup voice presets by scanning the voices directory."""
import os
voices_dir = os.path.join(os.path.dirname(__file__), "voices")
self.voice_presets = {}
self.available_voices = {}
if not os.path.exists(voices_dir):
logger.warning(f"Voices directory not found at {voices_dir}")
# Create default fallback voices
self.available_voices = {
'default_female': None,
'default_male': None
}
return
# Scan for audio files
audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')
for filename in os.listdir(voices_dir):
if filename.lower().endswith(audio_extensions):
name = os.path.splitext(filename)[0]
full_path = os.path.join(voices_dir, filename)
if os.path.isfile(full_path):
self.voice_presets[name] = full_path
self.available_voices = {
name: path for name, path in self.voice_presets.items()
if os.path.exists(path)
}
if self.available_voices:
logger.info(f"Found {len(self.available_voices)} voice presets: {list(self.available_voices.keys())}")
else:
logger.warning("No voice presets found")
    def read_audio(self, audio_path: Optional[str], target_sr: int = 24000) -> np.ndarray:
"""Read and preprocess audio file."""
try:
if audio_path is None:
# Return a simple sine wave as fallback
duration = 1.0 # 1 second
t = np.linspace(0, duration, int(target_sr * duration))
return np.sin(2 * np.pi * 440 * t).astype(np.float32)
wav, sr = sf.read(audio_path)
if len(wav.shape) > 1:
wav = np.mean(wav, axis=1)
if sr != target_sr:
wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
return wav.astype(np.float32)
except Exception as e:
logger.error(f"Error reading audio {audio_path}: {e}")
# Return silence as fallback
return np.zeros(target_sr, dtype=np.float32)
    async def generate_speech_async(self, text: str, voice_name: Optional[str] = None, cfg_scale: float = 1.3):
"""Generate speech asynchronously and put audio chunks in queue."""
if not self.is_loaded:
logger.error("Model not loaded, cannot generate speech")
return
if self.is_generating:
logger.warning("Already generating speech, skipping new request")
return
self.is_generating = True
try:
# Select voice
if voice_name and voice_name in self.available_voices:
voice_path = self.available_voices[voice_name]
else:
# Use first available voice or fallback
if self.available_voices:
voice_path = list(self.available_voices.values())[0]
voice_name = list(self.available_voices.keys())[0]
else:
voice_path = None
voice_name = "default"
logger.info(f"Generating speech for text: '{text[:50]}...' using voice: {voice_name}")
# Load voice sample
voice_sample = self.read_audio(voice_path)
voice_samples = [voice_sample]
# Format script
formatted_script = f"Speaker 0: {text}"
# Process input
inputs = self.processor(
text=[formatted_script],
voice_samples=[voice_samples],
padding=True,
return_tensors="pt",
return_attention_mask=True,
)
# Move to device
target_device = self.device if self.device in ("cuda", "mps") else "cpu"
for k, v in inputs.items():
if torch.is_tensor(v):
inputs[k] = v.to(target_device)
# Create audio streamer
audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
# Run generation in thread
def generate_worker():
try:
outputs = self.model.generate(
**inputs,
max_new_tokens=None,
cfg_scale=cfg_scale,
tokenizer=self.processor.tokenizer,
generation_config={'do_sample': False},
audio_streamer=audio_streamer,
verbose=False,
refresh_negative=True,
)
except Exception as e:
logger.error(f"Error in generation: {e}")
audio_streamer.end()
# Start generation
generation_thread = threading.Thread(target=generate_worker)
generation_thread.start()
# Process audio chunks
audio_stream = audio_streamer.get_stream(0)
for audio_chunk in audio_stream:
if torch.is_tensor(audio_chunk):
if audio_chunk.dtype == torch.bfloat16:
audio_chunk = audio_chunk.float()
audio_np = audio_chunk.cpu().numpy().astype(np.float32)
else:
audio_np = np.array(audio_chunk, dtype=np.float32)
# Ensure 1D
if len(audio_np.shape) > 1:
audio_np = audio_np.squeeze()
# Put in queue for audio track
self.audio_queue.put(audio_np)
# Wait for generation to complete
generation_thread.join(timeout=30.0)
# Signal end of audio
self.audio_queue.put(None) # End marker
logger.info("Speech generation completed")
except Exception as e:
logger.error(f"Error generating speech: {e}")
finally:
self.is_generating = False
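
# A minimal usage sketch for the processor above (illustrative only, never called
# at import time). It assumes model weights are reachable at the default path and
# that at least one voice sample exists under the bundled "voices" directory;
# both are assumptions, not guarantees.
async def _example_generate_speech() -> None:
    """Illustrative only: synthesize one utterance and drain the audio queue."""
    processor = VibeVoiceWebRTCProcessor(
        model_path="vibevoice/VibeVoice-1.5B",  # same default as create_vibevoice_tracks
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # Passing voice_name=None falls back to the first discovered voice preset.
    await processor.generate_speech_async("Hello from VibeVoice.", voice_name=None)
    # Chunks are float32 numpy arrays at 24 kHz; None marks the end of the utterance.
    while True:
        try:
            chunk = processor.audio_queue.get_nowait()
        except Empty:
            break
        if chunk is None:
            break
        logger.info(f"Received audio chunk with {len(chunk)} samples")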
class VibeVoiceAudioTrack(MediaStreamTrack):
"""Audio track that streams VibeVoice generated speech."""
kind = "audio"
def __init__(self, clock: MediaClock, config: Dict[str, Any], tts_processor: VibeVoiceWebRTCProcessor):
"""Initialize the VibeVoice audio track."""
super().__init__()
self.clock = clock
self.config = config
self.tts_processor = tts_processor
self.sample_rate = config.get('sample_rate', 24000) # VibeVoice uses 24kHz
self.samples_per_frame = config.get('samples_per_frame', 480) # 20ms at 24kHz
self._samples_generated: int = 0
# Audio buffer for TTS audio
self.audio_buffer = np.array([], dtype=np.float32)
self.buffer_lock = threading.Lock()
# Fallback tone parameters
self.frequency = config.get('frequency', 440.0)
self.volume = config.get('volume', 0.1) # Lower default volume
self.mode = config.get('audio_mode', 'silence') # Default to silence when no TTS
def update_config(self, config_updates: Dict[str, Any]) -> bool:
"""Update the audio track configuration."""
try:
self.config.update(config_updates)
if 'sample_rate' in config_updates:
self.sample_rate = config_updates['sample_rate']
if 'samples_per_frame' in config_updates:
self.samples_per_frame = config_updates['samples_per_frame']
if 'frequency' in config_updates:
self.frequency = config_updates['frequency']
if 'volume' in config_updates:
self.volume = config_updates['volume']
if 'audio_mode' in config_updates:
self.mode = config_updates['audio_mode']
logger.info(f"VibeVoice audio track configuration updated: {config_updates}")
return True
except Exception as e:
logger.error(f"Error updating VibeVoice audio track configuration: {e}")
return False
def add_tts_audio(self, audio_chunk: np.ndarray):
"""Add TTS generated audio to the buffer."""
with self.buffer_lock:
self.audio_buffer = np.append(self.audio_buffer, audio_chunk)
def _consume_audio_buffer(self, num_samples: int) -> np.ndarray:
"""Consume audio samples from the buffer."""
with self.buffer_lock:
if len(self.audio_buffer) >= num_samples:
samples = self.audio_buffer[:num_samples].copy()
self.audio_buffer = self.audio_buffer[num_samples:]
return samples
elif len(self.audio_buffer) > 0:
# Return what we have and pad with silence
samples = self.audio_buffer.copy()
self.audio_buffer = np.array([], dtype=np.float32)
padding = np.zeros(num_samples - len(samples), dtype=np.float32)
return np.concatenate([samples, padding])
else:
return np.zeros(num_samples, dtype=np.float32)
async def next_timestamp(self) -> Tuple[int, float]:
pts = self._samples_generated
time_base = 1 / self.sample_rate
return pts, time_base
async def recv(self) -> AudioFrame:
pts, time_base = await self.next_timestamp()
# Check for TTS audio in queue and add to buffer
while True:
try:
audio_chunk = self.tts_processor.audio_queue.get_nowait()
if audio_chunk is None:
# End marker
break
self.add_tts_audio(audio_chunk)
except Empty:
break
# Try to get audio from TTS buffer first
samples = self._consume_audio_buffer(self.samples_per_frame)
# If no TTS audio available, generate fallback audio based on mode
if np.all(samples == 0) and self.mode != 'silence':
if self.mode == 'tone':
samples = self._generate_tone()
elif self.mode == 'noise':
samples = self._generate_noise()
# Convert to WebRTC format (16-bit stereo)
# Resample from 24kHz to target sample rate if needed
if self.sample_rate != 24000:
samples = librosa.resample(samples, orig_sr=24000, target_sr=self.sample_rate)
        # Convert to 16-bit, clipping first to avoid integer wrap-around
        left = (np.clip(samples * self.volume, -1.0, 1.0) * 32767).astype(np.int16)
right = left.copy()
# Interleave channels for stereo
interleaved = np.empty(len(left) * 2, dtype=np.int16)
interleaved[0::2] = left
interleaved[1::2] = right
stereo = interleaved.reshape(1, -1)
frame = AudioFrame.from_ndarray(stereo, format="s16", layout="stereo")
frame.sample_rate = self.sample_rate
frame.pts = pts
frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
self._samples_generated += self.samples_per_frame
return frame
def _generate_tone(self) -> np.ndarray:
"""Generate sine wave tone as fallback."""
t = (np.arange(self.samples_per_frame) + self._samples_generated) / self.sample_rate
return np.sin(2 * np.pi * self.frequency * t).astype(np.float32)
def _generate_noise(self) -> np.ndarray:
"""Generate white noise as fallback."""
return np.random.uniform(-1, 1, self.samples_per_frame).astype(np.float32)
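
# Illustrative helper (a sketch mirroring the conversion inside recv() above):
# pack a mono float32 chunk into the 16-bit interleaved stereo layout that
# aiortc's AudioFrame.from_ndarray expects for format="s16", layout="stereo".
# It is not used by the track itself.
def _example_mono_to_s16_stereo(samples: np.ndarray, volume: float = 1.0) -> AudioFrame:
    """Pack a mono float32 buffer into a stereo s16 AudioFrame at 24 kHz."""
    left = (np.clip(samples * volume, -1.0, 1.0) * 32767).astype(np.int16)
    interleaved = np.empty(len(left) * 2, dtype=np.int16)
    interleaved[0::2] = left  # left channel
    interleaved[1::2] = left  # duplicate the mono signal on the right channel
    frame = AudioFrame.from_ndarray(interleaved.reshape(1, -1), format="s16", layout="stereo")
    frame.sample_rate = 24000  # VibeVoice native rate; 480 samples per frame = 20 ms
    return frame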
class ConfigurableVideoTrack(MediaStreamTrack):
"""Configurable video track with different visualization modes"""
kind = "video"
def __init__(self, clock: MediaClock, config: Dict[str, Any]):
"""Initialize the configurable video track."""
super().__init__()
self.clock = clock
self.config = config
self.width = config.get('width', 320)
self.height = config.get('height', 240)
self.fps = config.get('fps', 15)
self.mode = config.get('visualization', 'ball')
self.frame_count = 0
        # Initialize bouncing ball state (shared with runtime config updates)
        self._initialize_ball_state()
def update_config(self, config_updates: Dict[str, Any]) -> bool:
"""Update the video track configuration dynamically."""
try:
old_mode = self.mode
old_width = self.width
old_height = self.height
self.config.update(config_updates)
if 'width' in config_updates:
self.width = config_updates['width']
if 'height' in config_updates:
self.height = config_updates['height']
if 'fps' in config_updates:
self.fps = config_updates['fps']
if 'visualization' in config_updates:
self.mode = config_updates['visualization']
if self.mode != old_mode:
self._initialize_mode_state()
logger.info(f"Video mode changed from {old_mode} to {self.mode}")
if self.width != old_width or self.height != old_height:
self._initialize_ball_state()
logger.info(f"Video track configuration updated: {config_updates}")
return True
except Exception as e:
logger.error(f"Error updating video track configuration: {e}")
return False
def _initialize_mode_state(self):
"""Initialize state specific to the current visualization mode."""
self._initialize_ball_state()
def _initialize_ball_state(self):
"""Initialize bouncing ball state."""
self.ball_x = self.width // 2
self.ball_y = self.height // 2
self.ball_dx = 2
self.ball_dy = 2
self.ball_radius = min(self.width, self.height) * 0.06
async def next_timestamp(self) -> Tuple[int, float]:
pts = int(self.frame_count * (90000 / self.fps))
time_base = 1 / 90000
return pts, time_base
async def recv(self) -> VideoFrame:
pts, time_base = await self.next_timestamp()
# Create frame based on mode
if self.mode == 'ball':
frame_array = self._generate_ball_frame()
elif self.mode == 'waveform':
frame_array = self._generate_waveform_frame()
elif self.mode == 'static':
frame_array = self._generate_static_frame()
else:
frame_array = self._generate_ball_frame()
frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
frame.pts = pts
frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
self.frame_count += 1
return frame
def _generate_ball_frame(self) -> Any:
"""Generate bouncing ball visualization"""
frame = np.zeros((self.height, self.width, 3), dtype=np.uint8)
# Update ball position
self.ball_x += self.ball_dx
self.ball_y += self.ball_dy
# Bounce off walls
if self.ball_x <= self.ball_radius or self.ball_x >= self.width - self.ball_radius:
self.ball_dx = -self.ball_dx
if self.ball_y <= self.ball_radius or self.ball_y >= self.height - self.ball_radius:
self.ball_dy = -self.ball_dy
# Draw ball
cv2.circle(frame, (int(self.ball_x), int(self.ball_y)), int(self.ball_radius), (0, 255, 0), -1)
# Add timestamp
timestamp = f"Frame: {self.frame_count}"
cv2.putText(frame, timestamp, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
return frame
def _generate_waveform_frame(self) -> Any:
"""Generate waveform visualization"""
frame = np.zeros((self.height, self.width, 3), dtype=np.uint8)
# Generate sine wave
x = np.linspace(0, 4*np.pi, self.width)
y = np.sin(x + self.frame_count * 0.1) * self.height // 4 + self.height // 2
# Draw waveform
for i in range(1, len(y)):
cv2.line(frame, (i-1, int(y[i-1])), (i, int(y[i])), (255, 255, 255), 2)
return frame
def _generate_static_frame(self) -> Any:
"""Generate static color frame"""
color = self.config.get('static_color', (128, 128, 128))
frame = np.full((self.height, self.width, 3), color, dtype=np.uint8)
return frame
def create_vibevoice_tracks(session_name: str, config: Optional[Dict[str, Any]] = None) -> Dict[str, MediaStreamTrack]:
"""
Create VibeVoice-enabled WebRTC tracks.
Args:
session_name: Name for the session
config: Configuration dictionary
Returns:
Dictionary containing 'video' and 'audio' tracks
"""
if config is None:
config = {}
# Set defaults
default_config = {
'visualization': 'ball',
'audio_mode': 'silence', # Start with silence
'width': 320,
'height': 240,
'fps': 15,
'sample_rate': 24000, # VibeVoice native sample rate
'samples_per_frame': 480, # 20ms at 24kHz
'frequency': 440.0,
'volume': 0.1,
'static_color': (128, 128, 128),
'model_path': 'vibevoice/VibeVoice-1.5B', # Default model path
'device': 'cuda' if torch.cuda.is_available() else 'cpu',
'voice_preset': None
}
default_config.update(config)
# Parse static_color if it's a string
if isinstance(default_config.get('static_color'), str):
try:
color_str = default_config['static_color']
r, g, b = map(int, color_str.split(','))
default_config['static_color'] = (r, g, b)
except (ValueError, TypeError):
logger.warning(f"Invalid static_color format: {default_config.get('static_color')}, using default")
default_config['static_color'] = (128, 128, 128)
media_clock = MediaClock()
# Initialize VibeVoice processor
tts_processor = VibeVoiceWebRTCProcessor(
model_path=default_config['model_path'],
device=default_config['device'],
inference_steps=5
)
# Create tracks
video_track = ConfigurableVideoTrack(media_clock, default_config)
audio_track = VibeVoiceAudioTrack(media_clock, default_config, tts_processor)
logger.info(f"Created VibeVoice tracks for {session_name} with config: {default_config}")
# Store tracks and processor in global registry
_active_tracks[session_name] = {
"video": video_track,
"audio": audio_track,
"tts_processor": tts_processor
}
return {"video": video_track, "audio": audio_track}
# Agent descriptor
AGENT_NAME = "VibeVoice TTS Bot"
AGENT_DESCRIPTION = "WebRTC bot with VibeVoice text-to-speech capabilities"
def agent_info() -> Dict[str, str]:
"""Return agent metadata for discovery."""
return {
"name": AGENT_NAME,
"description": AGENT_DESCRIPTION,
"has_media": "true",
"configurable": "true",
"has_tts": "true"
}
def get_config_schema() -> Dict[str, Any]:
"""Get the configuration schema for the VibeVoice TTS Bot."""
return {
"bot_name": AGENT_NAME,
"version": "1.0",
"parameters": [
{
"name": "visualization",
"type": "select",
"label": "Video Visualization Mode",
"description": "Choose the type of video visualization to display",
"default_value": "ball",
"required": True,
"options": [
{"value": "ball", "label": "Bouncing Ball"},
{"value": "waveform", "label": "Sine Wave Animation"},
{"value": "static", "label": "Static Color Frame"}
]
},
{
"name": "audio_mode",
"type": "select",
"label": "Fallback Audio Mode",
"description": "Audio mode when no TTS is active",
"default_value": "silence",
"required": True,
"options": [
{"value": "silence", "label": "Silence"},
{"value": "tone", "label": "Sine Wave Tone"},
{"value": "noise", "label": "White Noise"}
]
},
{
"name": "voice_preset",
"type": "string",
"label": "Voice Preset",
"description": "Name of the voice preset to use for TTS",
"default_value": "",
"required": False,
"max_length": 50
},
{
"name": "device",
"type": "select",
"label": "Processing Device",
"description": "Device to use for TTS processing",
"default_value": "cuda",
"required": True,
"options": [
{"value": "cuda", "label": "CUDA (GPU)"},
{"value": "cpu", "label": "CPU"},
{"value": "mps", "label": "Apple Metal (MPS)"}
]
},
{
"name": "width",
"type": "number",
"label": "Video Width",
"description": "Width of the video frame in pixels",
"default_value": 320,
"required": False,
"min_value": 160,
"max_value": 1920,
"step": 1
},
{
"name": "height",
"type": "number",
"label": "Video Height",
"description": "Height of the video frame in pixels",
"default_value": 240,
"required": False,
"min_value": 120,
"max_value": 1080,
"step": 1
},
{
"name": "volume",
"type": "range",
"label": "Audio Volume",
"description": "Volume level (0.0 to 1.0)",
"default_value": 0.1,
"required": False,
"min_value": 0.0,
"max_value": 1.0,
"step": 0.1
}
],
"categories": [
{
"TTS Settings": ["voice_preset", "device"]
},
{
"Video Settings": ["visualization", "width", "height"]
},
{
"Audio Settings": ["audio_mode", "volume"]
}
]
}
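
# Example payload matching the schema above (values are illustrative; range and
# option validation is expected to happen in the hosting service).
_EXAMPLE_CONFIG_VALUES: Dict[str, Any] = {
    "visualization": "waveform",
    "audio_mode": "tone",
    "voice_preset": "",  # empty string falls back to the first discovered voice
    "device": "cpu",
    "width": 640,
    "height": 480,
    "volume": 0.2,
}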
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
"""Factory wrapper used by the FastAPI service to instantiate tracks for an agent."""
return create_vibevoice_tracks(session_name)
async def handle_chat_message(
chat_message: ChatMessageModel,
send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
) -> Optional[str]:
"""Handle chat messages and convert them to speech."""
try:
message_text = chat_message.message.strip()
if not message_text:
return None
logger.info(f"Processing TTS for message: '{message_text[:50]}...'")
        # Find the TTS processor for this session.
        # Note: this assumes the incoming message exposes a session identifier;
        # adjust this lookup if the hosting service routes sessions differently.
session_name = getattr(chat_message, 'session_id', None)
if not session_name:
# Try to find any active session (fallback)
if _active_tracks:
session_name = list(_active_tracks.keys())[0]
if session_name and session_name in _active_tracks:
tts_processor = _active_tracks[session_name].get("tts_processor")
if tts_processor:
# Get voice configuration
audio_track = _active_tracks[session_name].get("audio")
voice_preset = audio_track.config.get('voice_preset') if audio_track else None
# Generate speech asynchronously
await tts_processor.generate_speech_async(message_text, voice_preset)
# Send confirmation message
response_message = ChatMessageModel(
id=secrets.token_hex(16),
message=f"🔊 Converting to speech: '{message_text[:30]}{'...' if len(message_text) > 30 else ''}'",
sender_name="VibeVoice Bot",
sender_session_id=session_name,
lobby_id=chat_message.lobby_id,
timestamp=time.time()
)
await send_message_func(response_message)
else:
logger.error(f"No TTS processor found for session {session_name}")
else:
logger.error("No active session found for TTS processing")
return None
except Exception as e:
logger.error(f"Error processing chat message for TTS: {e}")
return None
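
# Sketch of how the chat handler above might be exercised by the hosting
# service; the stub sender and the identifiers used here are illustrative.
async def _example_handle_chat() -> None:
    """Feed one chat message through handle_chat_message with a stub sender."""
    async def send_stub(message: Union[str, ChatMessageModel]) -> None:
        logger.info(f"Would send: {message}")

    incoming = ChatMessageModel(
        id=secrets.token_hex(16),
        message="Read this sentence aloud, please.",
        sender_name="demo-user",
        sender_session_id="demo-session",  # illustrative session identifier
        lobby_id="demo-lobby",             # illustrative lobby identifier
        timestamp=time.time(),
    )
    await handle_chat_message(incoming, send_stub)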
def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
"""Handle runtime configuration updates for the VibeVoice bot."""
try:
validated_config = {}
# Parse and validate static_color if provided
if 'static_color' in config_values:
color_value = config_values['static_color']
if isinstance(color_value, str):
try:
r, g, b = map(int, color_value.split(','))
validated_config['static_color'] = (r, g, b)
except (ValueError, TypeError):
logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
return False
elif isinstance(color_value, (tuple, list)) and len(color_value) == 3:
validated_config['static_color'] = tuple(color_value)
else:
logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update")
return False
# Copy other valid configuration values
valid_keys = {'visualization', 'audio_mode', 'width', 'height', 'fps',
'sample_rate', 'frequency', 'volume', 'voice_preset', 'device'}
for key in valid_keys:
if key in config_values:
validated_config[key] = config_values[key]
if validated_config:
logger.info(f"Configuration updated for {lobby_id}: {validated_config}")
# Update running tracks if they exist in the registry
if lobby_id in _active_tracks:
tracks = _active_tracks[lobby_id]
video_track = tracks.get("video")
audio_track = tracks.get("audio")
# Update video track configuration
if video_track and hasattr(video_track, 'update_config'):
video_updated = video_track.update_config(validated_config)
if video_updated:
logger.info(f"Video track configuration updated for {lobby_id}")
else:
logger.warning(f"Failed to update video track configuration for {lobby_id}")
# Update audio track configuration
if audio_track and hasattr(audio_track, 'update_config'):
audio_updated = audio_track.update_config(validated_config)
if audio_updated:
logger.info(f"Audio track configuration updated for {lobby_id}")
else:
logger.warning(f"Failed to update audio track configuration for {lobby_id}")
return True
else:
logger.warning(f"No active tracks found for session {lobby_id}")
return False
else:
logger.warning(f"No valid configuration values provided for {lobby_id}")
return False
except Exception as e:
logger.error(f"Error updating configuration for {lobby_id}: {e}")
return False
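
# Sketch of a runtime configuration update. The lobby id shown is illustrative
# and must match a key registered in _active_tracks by create_vibevoice_tracks
# for the change to reach live tracks; otherwise the call returns False.
def _example_config_update() -> bool:
    """Apply a small runtime config change to an active session, if any."""
    return handle_config_update(
        "demo-lobby",  # illustrative lobby/session identifier
        {"visualization": "static", "static_color": "0,128,255", "volume": 0.3},
    )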