diff --git a/voicebot/bots/vibevoice.py b/voicebot/bots/vibevoice.py index c34d9c6..ca6f38f 100644 --- a/voicebot/bots/vibevoice.py +++ b/voicebot/bots/vibevoice.py @@ -1,879 +1,827 @@ -""" -VibeVoice WebRTC Bot - Text-to-Speech WebRTC Agent - -A WebRTC bot that converts incoming text messages to speech using VibeVoice -and streams the generated audio in real-time. -""" - -import secrets +import asyncio import numpy as np +import time +import os +from typing import Dict, Optional, Callable, Awaitable, Any, Union +import numpy.typing as npt + + +# Core dependencies +import librosa +from shared.logger import logger +from aiortc import MediaStreamTrack +from aiortc.mediastreams import MediaStreamError +from av import AudioFrame, VideoFrame import cv2 import fractions -import time -import threading -from queue import Queue, Empty -from typing import Awaitable, Callable, Dict, Optional, Tuple, Any, Union -import torch -import librosa -import soundfile as sf -from av.audio.frame import AudioFrame -from av import VideoFrame -from aiortc import MediaStreamTrack +from time import perf_counter +import wave -from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference -from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor -from vibevoice.modular.streamer import AudioStreamer -from transformers.utils import logging +# Import shared models for chat functionality +import sys -from shared.logger import logger +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) from shared.models import ChatMessageModel +from voicebot.models import Peer -# Configure logging -logging.set_verbosity_info() -tts_logger = logging.get_logger(__name__) +from .vibevoicetts import text_to_speech -# Global registry to store active tracks by session -_active_tracks: Dict[str, Dict[str, Any]] = {} +# Global configuration and constants +AGENT_NAME = "Text To Speech Bot" +AGENT_DESCRIPTION = ( + "Real-time speech generation- converts text to speech on Intel Arc B580" +) +SAMPLE_RATE = 16000 # Whisper expects 16kHz +_audio_data = None + +class TTSModel: + """TTS via MSFT VibeVoice""" + + def __init__( + self, + peer_name: str, + send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], + create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel], + ): + self.peer_name = peer_name + self.send_chat_func = send_chat_func + self.create_chat_message_func = create_chat_message_func + + +_send_chat_func: Optional[Callable[[ChatMessageModel], Awaitable[None]]] = None +_create_chat_message_func: Optional[ + Callable[[str, Optional[str]], ChatMessageModel] +] = None + class MediaClock: - """Shared clock for media synchronization.""" + """Simple monotonic clock for media tracks.""" - def __init__(self): - self.t0 = time.perf_counter() + def __init__(self) -> None: + self.t0 = perf_counter() def now(self) -> float: - return time.perf_counter() - self.t0 + return perf_counter() - self.t0 -class VibeVoiceWebRTCProcessor: - """VibeVoice processor adapted for WebRTC streaming.""" - - def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5): - """Initialize the VibeVoice processor for WebRTC.""" - self.model_path = model_path - self.device = device - self.inference_steps = inference_steps - self.is_loaded = False - self.audio_queue = Queue() - self.is_generating = False - self.load_model() - self.setup_voice_presets() - - def load_model(self): - """Load the VibeVoice model and processor.""" - try: - 
logger.info(f"Loading VibeVoice model from {self.model_path}") - - # Normalize device - if self.device.lower() == "mpx": - self.device = "mps" - if self.device == "mps" and not torch.backends.mps.is_available(): - logger.warning("MPS not available. Falling back to CPU.") - self.device = "cpu" - - # Load processor - self.processor = VibeVoiceProcessor.from_pretrained(self.model_path) - - # Determine dtype and attention implementation - if self.device == "mps": - load_dtype = torch.float32 - attn_impl = "sdpa" - elif self.device == "cuda": - load_dtype = torch.bfloat16 - attn_impl = "flash_attention_2" - else: - load_dtype = torch.float32 - attn_impl = "sdpa" - - logger.info(f"Using device: {self.device}, dtype: {load_dtype}, attention: {attn_impl}") - - # Load model with fallback - try: - if self.device == "mps": - self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( - self.model_path, - torch_dtype=load_dtype, - attn_implementation=attn_impl, - device_map=None, - ) - self.model.to("mps") - elif self.device == "cuda": - self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( - self.model_path, - torch_dtype=load_dtype, - device_map="cuda", - attn_implementation=attn_impl, - ) - else: - self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( - self.model_path, - torch_dtype=load_dtype, - device_map="cpu", - attn_implementation=attn_impl, - ) - except Exception as e: - if attn_impl == 'flash_attention_2': - logger.warning(f"Flash attention failed: {e}, falling back to SDPA") - fallback_attn = "sdpa" - self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( - self.model_path, - torch_dtype=load_dtype, - device_map=(self.device if self.device in ("cuda", "cpu") else None), - attn_implementation=fallback_attn, - ) - if self.device == "mps": - self.model.to("mps") - else: - raise e - - self.model.eval() - - # Configure noise scheduler - self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config( - self.model.model.noise_scheduler.config, - algorithm_type='sde-dpmsolver++', - beta_schedule='squaredcos_cap_v2' - ) - self.model.set_ddpm_inference_steps(num_steps=self.inference_steps) - - self.is_loaded = True - logger.info("VibeVoice model loaded successfully") - - except Exception as e: - logger.error(f"Failed to load VibeVoice model: {e}") - self.is_loaded = False - - def setup_voice_presets(self): - """Setup voice presets by scanning the voices directory.""" - import os - voices_dir = os.path.join(os.path.dirname(__file__), "voices") - - self.voice_presets = {} - self.available_voices = {} - - if not os.path.exists(voices_dir): - logger.warning(f"Voices directory not found at {voices_dir}") - # Create default fallback voices - self.available_voices = { - 'default_female': None, - 'default_male': None - } - return - - # Scan for audio files - audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac') - for filename in os.listdir(voices_dir): - if filename.lower().endswith(audio_extensions): - name = os.path.splitext(filename)[0] - full_path = os.path.join(voices_dir, filename) - if os.path.isfile(full_path): - self.voice_presets[name] = full_path - - self.available_voices = { - name: path for name, path in self.voice_presets.items() - if os.path.exists(path) - } - - if self.available_voices: - logger.info(f"Found {len(self.available_voices)} voice presets: {list(self.available_voices.keys())}") - else: - logger.warning("No voice presets found") - - def read_audio(self, audio_path: str, target_sr: 
int = 24000) -> np.ndarray: - """Read and preprocess audio file.""" - try: - if audio_path is None: - # Return a simple sine wave as fallback - duration = 1.0 # 1 second - t = np.linspace(0, duration, int(target_sr * duration)) - return np.sin(2 * np.pi * 440 * t).astype(np.float32) - - wav, sr = sf.read(audio_path) - if len(wav.shape) > 1: - wav = np.mean(wav, axis=1) - if sr != target_sr: - wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr) - return wav.astype(np.float32) - except Exception as e: - logger.error(f"Error reading audio {audio_path}: {e}") - # Return silence as fallback - return np.zeros(target_sr, dtype=np.float32) - - async def generate_speech_async(self, text: str, voice_name: str = None, cfg_scale: float = 1.3): - """Generate speech asynchronously and put audio chunks in queue.""" - if not self.is_loaded: - logger.error("Model not loaded, cannot generate speech") - return - - if self.is_generating: - logger.warning("Already generating speech, skipping new request") - return - - self.is_generating = True - - try: - # Select voice - if voice_name and voice_name in self.available_voices: - voice_path = self.available_voices[voice_name] - else: - # Use first available voice or fallback - if self.available_voices: - voice_path = list(self.available_voices.values())[0] - voice_name = list(self.available_voices.keys())[0] - else: - voice_path = None - voice_name = "default" - - logger.info(f"Generating speech for text: '{text[:50]}...' using voice: {voice_name}") - - # Load voice sample - voice_sample = self.read_audio(voice_path) - voice_samples = [voice_sample] - - # Format script - formatted_script = f"Speaker 0: {text}" - - # Process input - inputs = self.processor( - text=[formatted_script], - voice_samples=[voice_samples], - padding=True, - return_tensors="pt", - return_attention_mask=True, - ) - - # Move to device - target_device = self.device if self.device in ("cuda", "mps") else "cpu" - for k, v in inputs.items(): - if torch.is_tensor(v): - inputs[k] = v.to(target_device) - - # Create audio streamer - audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None) - - # Run generation in thread - def generate_worker(): - try: - outputs = self.model.generate( - **inputs, - max_new_tokens=None, - cfg_scale=cfg_scale, - tokenizer=self.processor.tokenizer, - generation_config={'do_sample': False}, - audio_streamer=audio_streamer, - verbose=False, - refresh_negative=True, - ) - except Exception as e: - logger.error(f"Error in generation: {e}") - audio_streamer.end() - - # Start generation - generation_thread = threading.Thread(target=generate_worker) - generation_thread.start() - - # Process audio chunks - audio_stream = audio_streamer.get_stream(0) - - for audio_chunk in audio_stream: - if torch.is_tensor(audio_chunk): - if audio_chunk.dtype == torch.bfloat16: - audio_chunk = audio_chunk.float() - audio_np = audio_chunk.cpu().numpy().astype(np.float32) - else: - audio_np = np.array(audio_chunk, dtype=np.float32) - - # Ensure 1D - if len(audio_np.shape) > 1: - audio_np = audio_np.squeeze() - - # Put in queue for audio track - self.audio_queue.put(audio_np) - - # Wait for generation to complete - generation_thread.join(timeout=30.0) - - # Signal end of audio - self.audio_queue.put(None) # End marker - - logger.info("Speech generation completed") - - except Exception as e: - logger.error(f"Error generating speech: {e}") - finally: - self.is_generating = False +class WaveformVideoTrack(MediaStreamTrack): + """Video track that renders a live waveform of 
the incoming audio. - -class VibeVoiceAudioTrack(MediaStreamTrack): - """Audio track that streams VibeVoice generated speech.""" - - kind = "audio" - - def __init__(self, clock: MediaClock, config: Dict[str, Any], tts_processor: VibeVoiceWebRTCProcessor): - """Initialize the VibeVoice audio track.""" - super().__init__() - self.clock = clock - self.config = config - self.tts_processor = tts_processor - self.sample_rate = config.get('sample_rate', 24000) # VibeVoice uses 24kHz - self.samples_per_frame = config.get('samples_per_frame', 480) # 20ms at 24kHz - self._samples_generated: int = 0 - - # Audio buffer for TTS audio - self.audio_buffer = np.array([], dtype=np.float32) - self.buffer_lock = threading.Lock() - - # Fallback tone parameters - self.frequency = config.get('frequency', 440.0) - self.volume = config.get('volume', 0.1) # Lower default volume - self.mode = config.get('audio_mode', 'silence') # Default to silence when no TTS - - def update_config(self, config_updates: Dict[str, Any]) -> bool: - """Update the audio track configuration.""" - try: - self.config.update(config_updates) - - if 'sample_rate' in config_updates: - self.sample_rate = config_updates['sample_rate'] - if 'samples_per_frame' in config_updates: - self.samples_per_frame = config_updates['samples_per_frame'] - if 'frequency' in config_updates: - self.frequency = config_updates['frequency'] - if 'volume' in config_updates: - self.volume = config_updates['volume'] - if 'audio_mode' in config_updates: - self.mode = config_updates['audio_mode'] - - logger.info(f"VibeVoice audio track configuration updated: {config_updates}") - return True - except Exception as e: - logger.error(f"Error updating VibeVoice audio track configuration: {e}") - return False - - def add_tts_audio(self, audio_chunk: np.ndarray): - """Add TTS generated audio to the buffer.""" - with self.buffer_lock: - self.audio_buffer = np.append(self.audio_buffer, audio_chunk) - - def _consume_audio_buffer(self, num_samples: int) -> np.ndarray: - """Consume audio samples from the buffer.""" - with self.buffer_lock: - if len(self.audio_buffer) >= num_samples: - samples = self.audio_buffer[:num_samples].copy() - self.audio_buffer = self.audio_buffer[num_samples:] - return samples - elif len(self.audio_buffer) > 0: - # Return what we have and pad with silence - samples = self.audio_buffer.copy() - self.audio_buffer = np.array([], dtype=np.float32) - padding = np.zeros(num_samples - len(samples), dtype=np.float32) - return np.concatenate([samples, padding]) - else: - return np.zeros(num_samples, dtype=np.float32) - - async def next_timestamp(self) -> Tuple[int, float]: - pts = self._samples_generated - time_base = 1 / self.sample_rate - return pts, time_base - - async def recv(self) -> AudioFrame: - pts, time_base = await self.next_timestamp() - - # Check for TTS audio in queue and add to buffer - while True: - try: - audio_chunk = self.tts_processor.audio_queue.get_nowait() - if audio_chunk is None: - # End marker - break - self.add_tts_audio(audio_chunk) - except Empty: - break - - # Try to get audio from TTS buffer first - samples = self._consume_audio_buffer(self.samples_per_frame) - - # If no TTS audio available, generate fallback audio based on mode - if np.all(samples == 0) and self.mode != 'silence': - if self.mode == 'tone': - samples = self._generate_tone() - elif self.mode == 'noise': - samples = self._generate_noise() - - # Convert to WebRTC format (16-bit stereo) - # Resample from 24kHz to target sample rate if needed - if self.sample_rate != 
24000: - samples = librosa.resample(samples, orig_sr=24000, target_sr=self.sample_rate) - - # Convert to 16-bit - left = (samples * self.volume * 32767).astype(np.int16) - right = left.copy() - - # Interleave channels for stereo - interleaved = np.empty(len(left) * 2, dtype=np.int16) - interleaved[0::2] = left - interleaved[1::2] = right - - stereo = interleaved.reshape(1, -1) - - frame = AudioFrame.from_ndarray(stereo, format="s16", layout="stereo") - frame.sample_rate = self.sample_rate - frame.pts = pts - frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000) - - self._samples_generated += self.samples_per_frame - return frame - - def _generate_tone(self) -> np.ndarray: - """Generate sine wave tone as fallback.""" - t = (np.arange(self.samples_per_frame) + self._samples_generated) / self.sample_rate - return np.sin(2 * np.pi * self.frequency * t).astype(np.float32) - - def _generate_noise(self) -> np.ndarray: - """Generate white noise as fallback.""" - return np.random.uniform(-1, 1, self.samples_per_frame).astype(np.float32) - - -class ConfigurableVideoTrack(MediaStreamTrack): - """Configurable video track with different visualization modes""" + The track reads the most-active `OptimizedAudioProcessor` in + `_audio_processors` and renders the last ~2s of its `current_phrase_audio`. + If no audio is available, the track will display a "No audio" message. + """ kind = "video" - def __init__(self, clock: MediaClock, config: Dict[str, Any]): - """Initialize the configurable video track.""" + # Shared buffer for audio data + buffer: Dict[str, npt.NDArray[np.float32]] = {} + speech_status: Dict[str, Dict[str, Any]] = {} + + def __init__( + self, session_name: str, width: int = 640, height: int = 480, fps: int = 15 + ) -> None: super().__init__() - self.clock = clock - self.config = config - self.width = config.get('width', 320) - self.height = config.get('height', 240) - self.fps = config.get('fps', 15) - self.mode = config.get('visualization', 'ball') - self.frame_count = 0 + self.session_name = session_name + self.width = int(width) + self.height = int(height) + self.fps = int(fps) + self.clock = MediaClock() + self._next_frame_index = 0 - # Initialize ball attributes - self.ball_x = self.width // 2 - self.ball_y = self.height // 2 - self.ball_dx = 2 - self.ball_dy = 2 - self.ball_radius = 20 - - def update_config(self, config_updates: Dict[str, Any]) -> bool: - """Update the video track configuration dynamically.""" - try: - old_mode = self.mode - old_width = self.width - old_height = self.height - - self.config.update(config_updates) - - if 'width' in config_updates: - self.width = config_updates['width'] - if 'height' in config_updates: - self.height = config_updates['height'] - if 'fps' in config_updates: - self.fps = config_updates['fps'] - if 'visualization' in config_updates: - self.mode = config_updates['visualization'] - - if self.mode != old_mode: - self._initialize_mode_state() - logger.info(f"Video mode changed from {old_mode} to {self.mode}") - - if self.width != old_width or self.height != old_height: - self._initialize_ball_state() - - logger.info(f"Video track configuration updated: {config_updates}") - return True - - except Exception as e: - logger.error(f"Error updating video track configuration: {e}") - return False - - def _initialize_mode_state(self): - """Initialize state specific to the current visualization mode.""" - self._initialize_ball_state() - - def _initialize_ball_state(self): - """Initialize bouncing ball state.""" - self.ball_x = 
self.width // 2 - self.ball_y = self.height // 2 - self.ball_dx = 2 - self.ball_dy = 2 - self.ball_radius = min(self.width, self.height) * 0.06 - - async def next_timestamp(self) -> Tuple[int, float]: - pts = int(self.frame_count * (90000 / self.fps)) + async def next_timestamp(self) -> tuple[int, float]: + pts = int(self._next_frame_index * (1 / self.fps) * 90000) time_base = 1 / 90000 return pts, time_base async def recv(self) -> VideoFrame: - pts, time_base = await self.next_timestamp() + pts, _ = await self.next_timestamp() - # Create frame based on mode - if self.mode == 'ball': - frame_array = self._generate_ball_frame() - elif self.mode == 'waveform': - frame_array = self._generate_waveform_frame() - elif self.mode == 'static': - frame_array = self._generate_static_frame() - else: - frame_array = self._generate_ball_frame() + # schedule frame according to clock + target_t = self._next_frame_index / self.fps + now = self.clock.now() + if target_t > now: + await asyncio.sleep(target_t - now) + + self._next_frame_index += 1 + + frame_array: npt.NDArray[np.uint8] = np.zeros( + (self.height, self.width, 3), dtype=np.uint8 + ) + + # Display model loading status prominently + status_text = "Initializing..." + progress = 0.1 + + # Draw status background (increased height for larger text) + cv2.rectangle(frame_array, (0, 0), (self.width, 80), (0, 0, 0), -1) + + # Draw progress bar if loading + if progress < 1.0 and "Ready" not in status_text: + bar_width = int(progress * (self.width - 40)) + cv2.rectangle(frame_array, (20, 55), (20 + bar_width, 70), (0, 255, 0), -1) + cv2.rectangle( + frame_array, (20, 55), (self.width - 20, 70), (255, 255, 255), 2 + ) + + # Draw status text (larger font) + cv2.putText( + frame_array, + f"Status: {status_text}", + (10, 35), + cv2.FONT_HERSHEY_SIMPLEX, + 1.2, + (255, 255, 255), + 3, + ) + + # Draw clock in lower right corner, right justified + current_time = time.strftime("%H:%M:%S") + (text_width, _), _ = cv2.getTextSize( + current_time, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2 + ) + clock_x = self.width - text_width - 10 # 10px margin from right edge + clock_y = self.height - 30 # Move to 450 for height=480 + cv2.putText( + frame_array, + current_time, + (clock_x, clock_y), + cv2.FONT_HERSHEY_SIMPLEX, + 1.0, + (255, 255, 255), + 2, + ) + + # Select the most active audio buffer and get its speech status + best_proc = None + best_rms = 0.0 + speech_info = None + + try: + for pname, arr in self.__class__.buffer.items(): + try: + if len(arr) == 0: + rms = 0.0 + else: + rms = float(np.sqrt(np.mean(arr**2))) + if rms > best_rms: + best_rms = rms + best_proc = (pname, arr.copy()) + speech_info = self.__class__.speech_status.get(pname, {}) + except Exception: + continue + except Exception: + best_proc = None + + if best_proc is not None: + pname, arr = best_proc + + # Use the last 2 second of audio data, padded with zeros if less + samples_needed = SAMPLE_RATE * 2 # 2 second(s) + if len(arr) <= 0: + arr_segment = np.zeros(samples_needed, dtype=np.float32) + elif len(arr) >= samples_needed: + arr_segment = arr[-samples_needed:].copy() + else: + # Pad with zeros at the beginning + arr_segment = np.concatenate( + [np.zeros(samples_needed - len(arr), dtype=np.float32), arr] + ) + + # Single normalization code path: normalize based on the historical + # peak observed for this stream (proc.max_observed_amplitude). This + # ensures the waveform display is consistent over time and avoids + # using the instantaneous buffer peak. 
+ proc = None + norm = arr_segment.astype(np.float32) + + # Map audio samples to pixels across the width + if norm.size < self.width: + padded = np.zeros(self.width, dtype=np.float32) + if norm.size > 0: + padded[-norm.size :] = norm + norm = padded + else: + block = int(np.ceil(norm.size / self.width)) + norm = np.array( + [ + np.mean(norm[i * block : min((i + 1) * block, norm.size)]) + for i in range(self.width) + ], + dtype=np.float32, + ) + + # For display we use the same `norm` computed above (single code + # path). Use `display_norm` alias to avoid confusion later in the + # code but don't recompute normalization. + display_norm = norm + + # Draw waveform with color coding for speech detection + points: list[tuple[int, int]] = [] + colors: list[tuple[int, int, int]] = [] # Color for each point + + for x in range(self.width): + v = ( + float(display_norm[x]) + if x < display_norm.size and not np.isnan(display_norm[x]) + else 0.0 + ) + y = int((1.0 - ((v + 1.0) / 2.0)) * (self.height - 120)) + 100 + points.append((x, y)) + + # Color based on speech detection status + is_speech = ( + speech_info.get("is_speech", False) if speech_info else False + ) + energy_check = ( + speech_info.get("energy_check", False) if speech_info else False + ) + + if is_speech: + colors.append((0, 255, 0)) # Green for detected speech + elif energy_check: + colors.append((255, 255, 0)) # Yellow for energy but not speech + else: + colors.append((128, 128, 128)) # Gray for background noise + + # Draw colored waveform + if len(points) > 1: + for i in range(len(points) - 1): + cv2.line(frame_array, points[i], points[i + 1], colors[i], 1) + + # Draw historical peak indicator (horizontal lines at +/-(target_peak)) + try: + if proc is not None and getattr(proc, "normalization_enabled", False): + target_peak = float(getattr(proc, "normalization_target_peak", 0.0)) + # Ensure target_peak is within [0, 1] + target_peak = max(0.0, min(1.0, target_peak)) + + def _amp_to_y(a: float) -> int: + return ( + int((1.0 - ((a + 1.0) / 2.0)) * (self.height - 120)) + 100 + ) + + top_y = _amp_to_y(target_peak) + bot_y = _amp_to_y(-target_peak) + + # Draw thin magenta lines across the waveform area + cv2.line( + frame_array, + (0, top_y), + (self.width - 1, top_y), + (255, 0, 255), + 1, + ) + cv2.line( + frame_array, + (0, bot_y), + (self.width - 1, bot_y), + (255, 0, 255), + 1, + ) + + # Label the peak with small text near the right edge + label = f"Peak:{target_peak:.2f}" + (tw, _), _ = cv2.getTextSize( + label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 + ) + lx = max(10, self.width - tw - 12) + ly = max(12, top_y - 6) + cv2.putText( + frame_array, + label, + (lx, ly), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 0, 255), + 1, + ) + except Exception: + # Non-critical: ignore any drawing errors + pass + + # Add speech detection status overlay + if speech_info: + self._draw_speech_status(frame_array, speech_info, pname) + + cv2.putText( + frame_array, + f"Waveform: {pname}", + (10, self.height - 15), + cv2.FONT_HERSHEY_SIMPLEX, + 1.0, + (255, 255, 255), + 2, + ) frame = VideoFrame.from_ndarray(frame_array, format="bgr24") frame.pts = pts - frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000) - - self.frame_count += 1 + frame.time_base = fractions.Fraction(1 / 90000).limit_denominator(1000000) return frame - def _generate_ball_frame(self) -> Any: - """Generate bouncing ball visualization""" - frame = np.zeros((self.height, self.width, 3), dtype=np.uint8) + def _draw_speech_status( + self, + frame_array: npt.NDArray[np.uint8], 
+ speech_info: Dict[str, Any], + pname: str, + ): + """Draw speech detection status information.""" - # Update ball position - self.ball_x += self.ball_dx - self.ball_y += self.ball_dy + y_offset = 100 - # Bounce off walls - if self.ball_x <= self.ball_radius or self.ball_x >= self.width - self.ball_radius: - self.ball_dx = -self.ball_dx - if self.ball_y <= self.ball_radius or self.ball_y >= self.height - self.ball_radius: - self.ball_dy = -self.ball_dy + # Main status + is_speech = speech_info.get("is_speech", False) + status_text = "SPEECH" if is_speech else "NOISE" + status_color = (0, 255, 0) if is_speech else (128, 128, 128) - # Draw ball - cv2.circle(frame, (int(self.ball_x), int(self.ball_y)), int(self.ball_radius), (0, 255, 0), -1) + adaptive_thresh = speech_info.get("adaptive_threshold", 0) + cv2.putText( + frame_array, + f"{pname}: {status_text} (thresh: {adaptive_thresh:.4f})", + (10, y_offset), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + status_color, + 2, + ) - # Add timestamp - timestamp = f"Frame: {self.frame_count}" - cv2.putText(frame, timestamp, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) + # Detailed metrics (smaller text) + metrics = [ + f"Energy: {speech_info.get('energy', 0):.3f} ({'Y' if speech_info.get('energy_check', False) else 'N'})", + f"ZCR: {speech_info.get('zcr', 0):.3f} ({'Y' if speech_info.get('zcr_check', False) else 'N'})", + f"Spectral: {'Y' if 300 < speech_info.get('centroid', 0) < 3400 else 'N'}/{'Y' if speech_info.get('rolloff', 0) < 2000 else 'N'}/{'Y' if speech_info.get('flux', 0) > 0.01 else 'N'} ({'Y' if speech_info.get('spectral_check', False) else 'N'})", + f"Harmonic: {speech_info.get('hamonicity', 0):.3f} ({'Y' if speech_info.get('harmonic_check', False) else 'N'})", + f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})", + ] - return frame + for _, metric in enumerate(metrics): + cv2.putText( + frame_array, + metric, + (320, y_offset), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) + y_offset += 15 - def _generate_waveform_frame(self) -> Any: - """Generate waveform visualization""" - frame = np.zeros((self.height, self.width, 3), dtype=np.uint8) + logic_result = "E:" + ("Y" if speech_info.get("energy_check", False) else "N") + logic_result += " Z:" + ("Y" if speech_info.get("zcr_check", False) else "N") + logic_result += " S:" + ( + "Y" if speech_info.get("spectral_check", False) else "N" + ) + logic_result += " H:" + ( + "Y" if speech_info.get("harmonic_check", False) else "N" + ) + logic_result += " T:" + ( + "Y" if speech_info.get("temporal_consistency", False) else "N" + ) - # Generate sine wave - x = np.linspace(0, 4*np.pi, self.width) - y = np.sin(x + self.frame_count * 0.1) * self.height // 4 + self.height // 2 + cv2.putText( + frame_array, + logic_result, + (320, y_offset + 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (255, 255, 255), + 1, + ) - # Draw waveform - for i in range(1, len(y)): - cv2.line(frame, (i-1, int(y[i-1])), (i, int(y[i])), (255, 255, 255), 2) - - return frame - - def _generate_static_frame(self) -> Any: - """Generate static color frame""" - color = self.config.get('static_color', (128, 128, 128)) - frame = np.full((self.height, self.width, 3), color, dtype=np.uint8) - return frame + # Noise floor indicator + noise_floor = speech_info.get("noise_floor_energy", 0) + cv2.putText( + frame_array, + f"Noise Floor: {noise_floor:.4f}", + (10, y_offset + 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (200, 200, 200), + 1, + ) -def create_vibevoice_tracks(session_name: str, config: 
Optional[Dict[str, Any]] = None) -> Dict[str, MediaStreamTrack]: - """ - Create VibeVoice-enabled WebRTC tracks. - - Args: - session_name: Name for the session - config: Configuration dictionary - - Returns: - Dictionary containing 'video' and 'audio' tracks - """ - if config is None: - config = {} - - # Set defaults - default_config = { - 'visualization': 'ball', - 'audio_mode': 'silence', # Start with silence - 'width': 320, - 'height': 240, - 'fps': 15, - 'sample_rate': 24000, # VibeVoice native sample rate - 'samples_per_frame': 480, # 20ms at 24kHz - 'frequency': 440.0, - 'volume': 0.1, - 'static_color': (128, 128, 128), - 'model_path': 'vibevoice/VibeVoice-1.5B', # Default model path - 'device': 'cuda' if torch.cuda.is_available() else 'cpu', - 'voice_preset': None - } - default_config.update(config) - - # Parse static_color if it's a string - if isinstance(default_config.get('static_color'), str): - try: - color_str = default_config['static_color'] - r, g, b = map(int, color_str.split(',')) - default_config['static_color'] = (r, g, b) - except (ValueError, TypeError): - logger.warning(f"Invalid static_color format: {default_config.get('static_color')}, using default") - default_config['static_color'] = (128, 128, 128) - - media_clock = MediaClock() - - # Initialize VibeVoice processor - tts_processor = VibeVoiceWebRTCProcessor( - model_path=default_config['model_path'], - device=default_config['device'], - inference_steps=5 +async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None: + """Handle incoming audio tracks from WebRTC peers.""" + logger.info( + f"handle_track_received called for {peer.peer_name} with track kind: {track.kind}" ) + global _audio_processors, _send_chat_func - # Create tracks - video_track = ConfigurableVideoTrack(media_clock, default_config) - audio_track = VibeVoiceAudioTrack(media_clock, default_config, tts_processor) + if track.kind != "audio": + logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}") + return - logger.info(f"Created VibeVoice tracks for {session_name} with config: {default_config}") + # Initialize raw audio buffer for immediate graphing + if peer.peer_name not in WaveformVideoTrack.buffer: + WaveformVideoTrack.buffer[peer.peer_name] = np.array([], dtype=np.float32) - # Store tracks and processor in global registry - _active_tracks[session_name] = { - "video": video_track, - "audio": audio_track, - "tts_processor": tts_processor - } + # Start processing frames immediately (before processor is ready) + logger.info(f"Starting frame processing loop for {peer.peer_name}") + frame_count = 0 + while True: + try: + frame = await track.recv() + frame_count += 1 - return {"video": video_track, "audio": audio_track} + if frame_count % 100 == 0: + logger.debug(f"Received {frame_count} frames from {peer.peer_name}") + + except MediaStreamError as e: + logger.info(f"Audio stream ended for {peer.peer_name}: {e}") + break + except Exception as e: + logger.error(f"Error receiving frame from {peer.peer_name}: {e}") + break + + if isinstance(frame, AudioFrame): + try: + # Convert frame to numpy array + audio_data = frame.to_ndarray() + + # Handle audio format conversion + audio_data = _process_audio_frame(audio_data, frame) + + # Resample if needed + if frame.sample_rate != SAMPLE_RATE: + audio_data = _resample_audio( + audio_data, frame.sample_rate, SAMPLE_RATE + ) + + # Convert to float32 + audio_data_float32 = audio_data.astype(np.float32) + + # Update visualization buffer immediately + 
WaveformVideoTrack.buffer[peer.peer_name] = np.concatenate( + [WaveformVideoTrack.buffer[peer.peer_name], audio_data_float32] + ) + # Limit buffer size to last 10 seconds + max_samples = SAMPLE_RATE * 10 + if len(WaveformVideoTrack.buffer[peer.peer_name]) > max_samples: + WaveformVideoTrack.buffer[peer.peer_name] = ( + WaveformVideoTrack.buffer[peer.peer_name][-max_samples:] + ) + + # # Process with optimized processor if available + # if peer.peer_name in _audio_processors: + # audio_processor = _audio_processors[peer.peer_name] + # audio_processor.add_audio_data(audio_data_float32) + + except Exception as e: + logger.error(f"Error processing audio frame for {peer.peer_name}: {e}") + continue + + # If processor already exists, just continue processing + # audio_processor = _audio_processors[peer.peer_name] + logger.info(f"Continuing OpenVINO audio processing for {peer.peer_name}") + + try: + frame_count = 0 + logger.info(f"Entering frame processing loop for {peer.peer_name}") + while True: + try: + logger.debug(f"Waiting for frame from {peer.peer_name}") + frame = await track.recv() + frame_count += 1 + + if frame_count == 1: + logger.info(f"Received first frame from {peer.peer_name}") + elif frame_count % 50 == 0: + logger.info(f"Received {frame_count} frames from {peer.peer_name}") + + except MediaStreamError as e: + logger.info(f"Audio stream ended for {peer.peer_name}: {e}") + break + except Exception as e: + logger.error(f"Error receiving frame from {peer.peer_name}: {e}") + break + + if isinstance(frame, AudioFrame): + try: + # Convert frame to numpy array + audio_data = frame.to_ndarray() + + # Handle audio format conversion + audio_data = _process_audio_frame(audio_data, frame) + + # Resample if needed + if frame.sample_rate != SAMPLE_RATE: + audio_data = _resample_audio( + audio_data, frame.sample_rate, SAMPLE_RATE + ) + + # Convert to float32 + audio_data_float32 = audio_data.astype(np.float32) + + logger.debug( + f"Processed audio frame {frame_count} from {peer.peer_name}: {len(audio_data_float32)} samples" + ) + + # Process with optimized processor if available + # audio_processor.add_audio_data(audio_data_float32) + + except Exception as e: + logger.error( + f"Error processing audio frame for {peer.peer_name}: {e}" + ) + continue + + except Exception as e: + logger.error( + f"Unexpected error in audio processing for {peer.peer_name}: {e}", + exc_info=True, + ) + finally: + cleanup_peer_processor(peer.peer_name) -# Agent descriptor -AGENT_NAME = "VibeVoice TTS Bot" -AGENT_DESCRIPTION = "WebRTC bot with VibeVoice text-to-speech capabilities" +def _process_audio_frame( + audio_data: npt.NDArray[Any], frame: AudioFrame +) -> npt.NDArray[np.float32]: + """Process audio frame format conversion.""" + # Handle stereo to mono conversion + if audio_data.ndim == 2: + if audio_data.shape[0] == 1: + audio_data = audio_data.squeeze(0) + else: + audio_data = np.mean( + audio_data, axis=0 if audio_data.shape[0] > audio_data.shape[1] else 1 + ) + # Normalize based on data type + if audio_data.dtype == np.int16: + audio_data = audio_data.astype(np.float32) / 32768.0 + elif audio_data.dtype == np.int32: + audio_data = audio_data.astype(np.float32) / 2147483648.0 + + return audio_data.astype(np.float32) + + +def _resample_audio( + audio_data: npt.NDArray[np.float32], orig_sr: int, target_sr: int +) -> npt.NDArray[np.float32]: + """Resample audio efficiently.""" + try: + # Handle stereo audio by converting to mono if necessary + if audio_data.ndim > 1: + audio_data = np.mean(audio_data, axis=1) + 
+ # Use high-quality resampling + resampled = librosa.resample( # type: ignore + audio_data.astype(np.float64), + orig_sr=orig_sr, + target_sr=target_sr, + res_type="kaiser_fast", # Good balance of quality and speed + ) + return resampled.astype(np.float32) # type: ignore + except Exception as e: + logger.error(f"Resampling failed: {e}") + raise ValueError( + f"Failed to resample audio from {orig_sr} Hz to {target_sr} Hz: {e}" + ) + + +# Public API functions def agent_info() -> Dict[str, str]: - """Return agent metadata for discovery.""" return { "name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "true", "configurable": "true", - "has_tts": "true" } def get_config_schema() -> Dict[str, Any]: - """Get the configuration schema for the VibeVoice TTS Bot.""" + """Get the configuration schema for the Whisper bot""" return { "bot_name": AGENT_NAME, "version": "1.0", "parameters": [ - { - "name": "visualization", - "type": "select", - "label": "Video Visualization Mode", - "description": "Choose the type of video visualization to display", - "default_value": "ball", - "required": True, - "options": [ - {"value": "ball", "label": "Bouncing Ball"}, - {"value": "waveform", "label": "Sine Wave Animation"}, - {"value": "static", "label": "Static Color Frame"} - ] - }, - { - "name": "audio_mode", - "type": "select", - "label": "Fallback Audio Mode", - "description": "Audio mode when no TTS is active", - "default_value": "silence", - "required": True, - "options": [ - {"value": "silence", "label": "Silence"}, - {"value": "tone", "label": "Sine Wave Tone"}, - {"value": "noise", "label": "White Noise"} - ] - }, - { - "name": "voice_preset", - "type": "string", - "label": "Voice Preset", - "description": "Name of the voice preset to use for TTS", - "default_value": "", - "required": False, - "max_length": 50 - }, - { - "name": "device", - "type": "select", - "label": "Processing Device", - "description": "Device to use for TTS processing", - "default_value": "cuda", - "required": True, - "options": [ - {"value": "cuda", "label": "CUDA (GPU)"}, - {"value": "cpu", "label": "CPU"}, - {"value": "mps", "label": "Apple Metal (MPS)"} - ] - }, - { - "name": "width", - "type": "number", - "label": "Video Width", - "description": "Width of the video frame in pixels", - "default_value": 320, - "required": False, - "min_value": 160, - "max_value": 1920, - "step": 1 - }, - { - "name": "height", - "type": "number", - "label": "Video Height", - "description": "Height of the video frame in pixels", - "default_value": 240, - "required": False, - "min_value": 120, - "max_value": 1080, - "step": 1 - }, - { - "name": "volume", - "type": "range", - "label": "Audio Volume", - "description": "Volume level (0.0 to 1.0)", - "default_value": 0.1, - "required": False, - "min_value": 0.0, - "max_value": 1.0, - "step": 0.1 - } ], "categories": [ - { - "TTS Settings": ["voice_preset", "device"] - }, - { - "Video Settings": ["visualization", "width", "height"] - }, - { - "Audio Settings": ["audio_mode", "volume"] - } - ] + ], } +def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool: + """Handle configuration update for a specific lobby""" + global _model_id, _device, _ov_config + + try: + logger.info(f"Updating TTS config for lobby {lobby_id}: {config_values}") + + config_applied = False + + config_applied = True + + return config_applied + + except Exception as e: + logger.error(f"Failed to apply Whisper config update: {e}") + return False + + def create_agent_tracks(session_name: str) -> 
Dict[str, MediaStreamTrack]: - """Factory wrapper used by the FastAPI service to instantiate tracks for an agent.""" - return create_vibevoice_tracks(session_name) + """Create agent tracks. Provides a synthetic video waveform track and a silent audio track for compatibility.""" + + class SilentAudioTrack(MediaStreamTrack): + kind = "audio" + + def __init__( + self, sample_rate: int = SAMPLE_RATE, channels: int = 1, fps: int = 50 + ): + super().__init__() + self.sample_rate = sample_rate + self.channels = channels + self.fps = fps + self.samples_per_frame = int(self.sample_rate / self.fps) + self._timestamp = 0 + # Per-track playback buffer (float32, mono, -1..1) + self._play_buffer = np.array([], dtype=np.float32) + # Source sample rate of the buffered audio (if any) + self._play_src_sr = None + # Phase for synthetic sine tone (radians) + self._sine_phase = 0.0 + + async def recv(self) -> AudioFrame: + global _audio_data + # If new global TTS audio was produced, grab and prepare it for playback + logger.info("recv called with _audio_data: %s and size: %s", "set" if _audio_data is not None else "unset", getattr(_audio_data, 'size', 0) if _audio_data is not None else 'N/A') + try: + if _audio_data is not None and getattr(_audio_data, 'size', 0) > 0: + # Get source sample rate from TTS engine if available + try: + from .vibevoicetts import get_tts_engine + + src_sr = get_tts_engine().get_sample_rate() + except Exception: + # Fallback: assume 24000 (VibeVoice default) + logger.info("Falling back to default source sample rate of 24000 Hz") + src_sr = 24000 + + # Ensure numpy float32 + src_audio = np.asarray(_audio_data, dtype=np.float32) + + # Clear global buffer (we consumed it into this track) + _audio_data = None + + # If source sr differs from track sr, resample + if src_sr != self.sample_rate: + try: + resampled = librosa.resample( + src_audio.astype(np.float64), + orig_sr=src_sr, + target_sr=self.sample_rate, + res_type="kaiser_fast", + ) + self._play_buffer = resampled.astype(np.float32) + except Exception: + # On failure, fallback to nearest simple scaling (not ideal) + logger.exception("Failed to resample TTS audio; using source as-is") + self._play_buffer = src_audio.astype(np.float32) + else: + self._play_buffer = src_audio.astype(np.float32) + + self._play_src_sr = self.sample_rate + + # Save a copy of the consumed audio to disk for inspection + try: + pcm_save = np.clip(self._play_buffer, -1.0, 1.0) + pcm_int16_save = (pcm_save * 32767.0).astype(np.int16) + sample_path = "./sample.wav" + with wave.open(sample_path, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(self.sample_rate) + wf.writeframes(pcm_int16_save.tobytes()) + logger.info("Wrote TTS sample to %s (%d samples, %d Hz)", sample_path, len(pcm_int16_save), self.sample_rate) + except Exception: + logger.exception("Failed to write sample.wav") + + # Prepare output samples for this frame + if self._play_buffer.size > 0: + take = min(self.samples_per_frame, int(self._play_buffer.size)) + out = self._play_buffer[:take] + # Advance buffer + if take >= self._play_buffer.size: + self._play_buffer = np.array([], dtype=np.float32) + else: + self._play_buffer = self._play_buffer[take:] + else: + # No TTS audio buffered; output a 110 Hz sine tone at 0.1 volume + freq = 110.0 + volume = 0.1 + n = self.samples_per_frame + # incremental phase per sample + phase_inc = 2.0 * np.pi * freq / float(self.sample_rate) + samples_idx = np.arange(n, dtype=np.float32) + out = volume * np.sin(self._sine_phase + phase_inc * 
samples_idx) + # advance and wrap phase + self._sine_phase = (self._sine_phase + phase_inc * n) % (2.0 * np.pi) + logger.debug("No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)", n, self.sample_rate) + + # Convert float32 [-1.0,1.0] to int16 PCM + pcm = np.clip(out, -1.0, 1.0) + pcm_int16 = (pcm * 32767.0).astype(np.int16) + + # aiortc AudioFrame.from_ndarray expects shape (channels, samples) + if self.channels == 1: + data = np.expand_dims(pcm_int16, axis=0) + layout = "mono" + else: + # Duplicate mono into stereo + data = np.vstack([pcm_int16, pcm_int16]) + layout = "stereo" + + frame = AudioFrame.from_ndarray(data, layout=layout) + frame.sample_rate = self.sample_rate + frame.pts = self._timestamp + frame.time_base = fractions.Fraction(1, self.sample_rate) + self._timestamp += self.samples_per_frame + + # Pace the frame rate to avoid busy-looping + await asyncio.sleep(1 / self.fps) + + return frame + except Exception as e: + logger.exception(f"Error in SilentAudioTrack.recv: {e}") + # On error, fall back to silent frame + data = np.zeros((self.channels, self.samples_per_frame), dtype=np.int16) + frame = AudioFrame.from_ndarray(data, layout="mono" if self.channels == 1 else "stereo") + frame.sample_rate = self.sample_rate + frame.pts = self._timestamp + frame.time_base = fractions.Fraction(1, self.sample_rate) + self._timestamp += self.samples_per_frame + await asyncio.sleep(1 / self.fps) + return frame + + try: + video_track = WaveformVideoTrack( + session_name=session_name, width=640, height=480, fps=15 + ) + audio_track = SilentAudioTrack() + return {"video": video_track, "audio": audio_track} + except Exception as e: + logger.error(f"Failed to create agent tracks: {e}") + return {} async def handle_chat_message( - chat_message: ChatMessageModel, - send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]] + chat_message: ChatMessageModel, + send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]], ) -> Optional[str]: - """Handle chat messages and convert them to speech.""" + """Handle incoming chat messages.""" + global _audio_data + logger.info(f"Received chat message: {chat_message.message}") + # This ends up blocking; spin it off into a thread try: - message_text = chat_message.message.strip() - if not message_text: - return None - - logger.info(f"Processing TTS for message: '{message_text[:50]}...'") - - # Find the TTS processor for this session - # Note: This assumes the session_name is available somehow - # You might need to modify this based on how sessions are handled - session_name = getattr(chat_message, 'session_id', None) - if not session_name: - # Try to find any active session (fallback) - if _active_tracks: - session_name = list(_active_tracks.keys())[0] - - if session_name and session_name in _active_tracks: - tts_processor = _active_tracks[session_name].get("tts_processor") - if tts_processor: - # Get voice configuration - audio_track = _active_tracks[session_name].get("audio") - voice_preset = audio_track.config.get('voice_preset') if audio_track else None - - # Generate speech asynchronously - await tts_processor.generate_speech_async(message_text, voice_preset) - - # Send confirmation message - response_message = ChatMessageModel( - id=secrets.token_hex(16), - message=f"🔊 Converting to speech: '{message_text[:30]}{'...' 
if len(message_text) > 30 else ''}'",
-                    sender_name="VibeVoice Bot",
-                    sender_session_id=session_name,
-                    lobby_id=chat_message.lobby_id,
-                    timestamp=time.time()
-                )
-                await send_message_func(response_message)
-            else:
-                logger.error(f"No TTS processor found for session {session_name}")
-        else:
-            logger.error("No active session found for TTS processing")
-
-        return None
-
-    except Exception as e:
-        logger.error(f"Error processing chat message for TTS: {e}")
-        return None
+        from .vibevoicetts import get_tts_engine
+
+        engine = get_tts_engine()
+        _audio_data = engine.text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None, verbose=True)
+    except Exception:
+        # Fallback to convenience function if direct engine call fails
+        _audio_data = text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None)
+
     return None


-def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
-    """Handle runtime configuration updates for the VibeVoice bot."""
-    try:
-        validated_config = {}
+async def on_track_received(peer: Peer, track: MediaStreamTrack) -> None:
+    """Callback when a new track is received from a peer."""
+    await handle_track_received(peer, track)

-        # Parse and validate static_color if provided
-        if 'static_color' in config_values:
-            color_value = config_values['static_color']
-            if isinstance(color_value, str):
-                try:
-                    r, g, b = map(int, color_value.split(','))
-                    validated_config['static_color'] = (r, g, b)
-                except (ValueError, TypeError):
-                    logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
-                    return False
-            elif isinstance(color_value, (tuple, list)) and len(color_value) == 3:
-                validated_config['static_color'] = tuple(color_value)
-            else:
-                logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update")
-                return False

-        # Copy other valid configuration values
-        valid_keys = {'visualization', 'audio_mode', 'width', 'height', 'fps',
-                      'sample_rate', 'frequency', 'volume', 'voice_preset', 'device'}
-        for key in valid_keys:
-            if key in config_values:
-                validated_config[key] = config_values[key]
+def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]:
+    """Return the track handler function."""
+    return on_track_received

-        if validated_config:
-            logger.info(f"Configuration updated for {lobby_id}: {validated_config}")

-            # Update running tracks if they exist in the registry
-            if lobby_id in _active_tracks:
-                tracks = _active_tracks[lobby_id]
-                video_track = tracks.get("video")
-                audio_track = tracks.get("audio")
+def bind_send_chat_function(
+    send_chat_func: Callable[[ChatMessageModel], Awaitable[None]],
+    create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel],
+) -> None:
+    """Bind the chat send and message-creation callbacks for this agent."""
+    global _send_chat_func, _create_chat_message_func

-            # Update video track configuration
-            if video_track and hasattr(video_track, 'update_config'):
-                video_updated = video_track.update_config(validated_config)
-                if video_updated:
-                    logger.info(f"Video track configuration updated for {lobby_id}")
-                else:
-                    logger.warning(f"Failed to update video track configuration for {lobby_id}")
+    logger.info("Binding send chat function to VibeVoice TTS agent")
+    _send_chat_func = send_chat_func
+    _create_chat_message_func = create_chat_message_func
+
+    # Update existing processors
+    # for peer_name, processor in _audio_processors.items():
+    #     processor.send_chat_func = send_chat_func
+    #     processor.create_chat_message_func = 
create_chat_message_func
+    #     logger.debug(f"Updated processor for {peer_name} with new send chat function")
+
+
+def cleanup_peer_processor(peer_name: str) -> None:
+    """Clean up per-peer state (waveform buffer) when a peer disconnects."""
+    if peer_name in WaveformVideoTrack.buffer:
+        del WaveformVideoTrack.buffer[peer_name]


-            # Update audio track configuration
-            if audio_track and hasattr(audio_track, 'update_config'):
-                audio_updated = audio_track.update_config(validated_config)
-                if audio_updated:
-                    logger.info(f"Audio track configuration updated for {lobby_id}")
-                else:
-                    logger.warning(f"Failed to update audio track configuration for {lobby_id}")
-            return True
-        else:
-            logger.warning(f"No active tracks found for session {lobby_id}")
-            return False
-    else:
-        logger.warning(f"No valid configuration values provided for {lobby_id}")
-        return False
-    except Exception as e:
-        logger.error(f"Error updating configuration for {lobby_id}: {e}")
-        return False
\ No newline at end of file
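
The waveform-drawing code in WaveformVideoTrack.recv keeps a comment describing normalization against the historical peak of the stream (proc.max_observed_amplitude), but in this version proc is always None and the segment is drawn unscaled. Below is a minimal sketch of the peak-tracked display normalization that comment describes; PeakTracker and its attribute names are illustrative, not part of the patch.

import numpy as np

# Illustrative only: peak-tracked display normalization as described in the
# comment inside WaveformVideoTrack.recv(); not part of the patch itself.
class PeakTracker:
    def __init__(self) -> None:
        # Small floor avoids division by zero before any audio arrives.
        self.max_observed_amplitude: float = 1e-6

    def normalize_for_display(self, segment: np.ndarray) -> np.ndarray:
        peak = float(np.max(np.abs(segment))) if segment.size else 0.0
        # Track the historical peak so the display stays consistent over time
        # instead of rescaling to each frame's instantaneous maximum.
        self.max_observed_amplitude = max(self.max_observed_amplitude, peak)
        return (segment / self.max_observed_amplitude).astype(np.float32)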
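
handle_chat_message notes in a comment that the VibeVoice call blocks, yet it still invokes engine.text_to_speech directly inside the async handler, so audio and video recv() pacing stalls for the duration of synthesis. A minimal sketch of off-loading that call with asyncio.to_thread, assuming only the text_to_speech helper already imported from .vibevoicetts and that it returns the NumPy sample array the handler expects; synthesize_off_loop is a hypothetical name.

import asyncio
import numpy as np

async def synthesize_off_loop(message: str) -> np.ndarray:
    """Run the blocking VibeVoice synthesis in a worker thread."""
    def _blocking() -> np.ndarray:
        # Same call the handler's fallback path makes today.
        return text_to_speech(text=f"Speaker 1: {message}", speaker_names=None)
    return await asyncio.to_thread(_blocking)

Inside the handler this would replace the direct call, e.g. _audio_data = await synthesize_off_loop(chat_message.message), leaving the event loop free to keep emitting frames.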
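
The handoff from the chat handler to SilentAudioTrack.recv currently goes through the module-level _audio_data variable, so a second chat message can overwrite synthesized audio before the track has consumed it. As a design sketch only, here is the same handoff via an asyncio.Queue; _tts_queue, publish_tts_audio, and next_tts_chunk are hypothetical names, not part of the patch.

import asyncio
from typing import Optional
import numpy as np

# Hypothetical queue-based handoff between handle_chat_message and the audio track.
_tts_queue: "asyncio.Queue[np.ndarray]" = asyncio.Queue()

async def publish_tts_audio(samples: np.ndarray) -> None:
    # Handler side: enqueue each finished utterance instead of overwriting a global.
    await _tts_queue.put(samples.astype(np.float32))

def next_tts_chunk() -> Optional[np.ndarray]:
    # Track side: a non-blocking fetch inside recv() keeps frame pacing intact.
    try:
        return _tts_queue.get_nowait()
    except asyncio.QueueEmpty:
        return None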