From 30886d9fa82a11acce0dd16f6be717b1252983a8 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Fri, 19 Sep 2025 13:25:18 -0700 Subject: [PATCH] Move vibevoice to GPU --- voicebot/bots/vibevoice.py | 1143 --------------------------------- voicebot/bots/vibevoicetts.py | 51 +- voicebot/requirements.txt | 25 +- 3 files changed, 53 insertions(+), 1166 deletions(-) delete mode 100644 voicebot/bots/vibevoice.py diff --git a/voicebot/bots/vibevoice.py b/voicebot/bots/vibevoice.py deleted file mode 100644 index 6c397f5..0000000 --- a/voicebot/bots/vibevoice.py +++ /dev/null @@ -1,1143 +0,0 @@ -import asyncio -import threading -import numpy as np -import time -import os -from typing import Dict, Optional, Callable, Awaitable, Any, Union -import numpy.typing as npt - - -# Core dependencies -import librosa -from shared.logger import logger -from aiortc import MediaStreamTrack -from aiortc.mediastreams import MediaStreamError -from av import AudioFrame, VideoFrame -import cv2 -import fractions -from time import perf_counter - -# Import shared models for chat functionality -import sys - -sys.path.append( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -) -from shared.models import ChatMessageModel -from voicebot.models import Peer - -# Defer importing the heavy TTS module until needed (lazy import). -# Tests and other code can monkeypatch these names on this module. -get_tts_engine = None -text_to_speech = None - -# Global configuration and constants -AGENT_NAME = "Text To Speech Bot" -AGENT_DESCRIPTION = ( - "Real-time speech generation- converts text to speech on Intel Arc B580" -) -SAMPLE_RATE = 16000 # Whisper expects 16kHz - - -_audio_data = None - -# Track background processing threads per lobby to ensure only one worker runs -# per lobby session at a time. -_lobby_threads: Dict[str, Any] = {} -_lobby_threads_lock = threading.Lock() -# Per-lobby job metadata for status display -_lobby_jobs: Dict[str, Dict[str, Any]] = {} - -class TTSModel: - """TTS via MSFT VibeVoice""" - - def __init__( - self, - peer_name: str, - send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], - create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel], - ): - self.peer_name = peer_name - self.send_chat_func = send_chat_func - self.create_chat_message_func = create_chat_message_func - - -_send_chat_func: Optional[Callable[[ChatMessageModel], Awaitable[None]]] = None -_create_chat_message_func: Optional[ - Callable[[str, Optional[str]], ChatMessageModel] -] = None - -class MediaClock: - """Simple monotonic clock for media tracks.""" - - def __init__(self) -> None: - self.t0 = perf_counter() - - def now(self) -> float: - return perf_counter() - self.t0 - - -class WaveformVideoTrack(MediaStreamTrack): - """Video track that renders a live waveform of the incoming audio. - - The track reads the most-active `OptimizedAudioProcessor` in - `_audio_processors` and renders the last ~2s of its `current_phrase_audio`. - If no audio is available, the track will display a "No audio" message. 
- """ - - kind = "video" - - # Shared buffer for audio data - buffer: Dict[str, npt.NDArray[np.float32]] = {} - speech_status: Dict[str, Dict[str, Any]] = {} - - def __init__( - self, session_name: str, width: int = 640, height: int = 480, fps: int = 15 - ) -> None: - super().__init__() - self.session_name = session_name - self.width = int(width) - self.height = int(height) - self.fps = int(fps) - self.clock = MediaClock() - self._next_frame_index = 0 - - async def next_timestamp(self) -> tuple[int, float]: - pts = int(self._next_frame_index * (1 / self.fps) * 90000) - time_base = 1 / 90000 - return pts, time_base - - async def recv(self) -> VideoFrame: - pts, _ = await self.next_timestamp() - - # schedule frame according to clock - target_t = self._next_frame_index / self.fps - now = self.clock.now() - if target_t > now: - await asyncio.sleep(target_t - now) - - self._next_frame_index += 1 - - frame_array: npt.NDArray[np.uint8] = np.zeros( - (self.height, self.width, 3), dtype=np.uint8 - ) - - # Display model / TTS job status prominently - status_text = "Idle" - progress = 0.0 - - try: - # Show number of active lobby TTS jobs - with _lobby_threads_lock: - active_jobs = {k: v for k, v in _lobby_jobs.items() if v.get("status") == "running"} - - if active_jobs: - # Pick the most recent running job (sort by start_time desc) - recent_lobby, recent_job = sorted( - active_jobs.items(), key=lambda kv: kv[1].get("start_time", 0), reverse=True - )[0] - # elapsed time available if needed for overlays - _ = time.time() - recent_job.get("start_time", time.time()) - snippet = (recent_job.get("message") or "")[:80] - status_text = f"TTS running ({len(active_jobs)}): {snippet}" - # Progress as a heuristic (running jobs -> 0.5) - progress = 0.5 - else: - # If there is produced audio waiting in _audio_data, show ready-to-play - if _audio_data is not None and getattr(_audio_data, "size", 0) > 0: - status_text = "TTS: Ready (audio buffered)" - progress = 1.0 - else: - status_text = "Idle" - progress = 0.0 - except Exception: - status_text = "Status: error" - progress = 0.0 - - # Draw status background (increased height for larger text) - cv2.rectangle(frame_array, (0, 0), (self.width, 80), (0, 0, 0), -1) - - # Draw progress bar if loading - if progress < 1.0 and "Ready" not in status_text: - bar_width = int(progress * (self.width - 40)) - cv2.rectangle(frame_array, (20, 55), (20 + bar_width, 70), (0, 255, 0), -1) - cv2.rectangle( - frame_array, (20, 55), (self.width - 20, 70), (255, 255, 255), 2 - ) - - # Draw status text (larger font) - cv2.putText( - frame_array, - f"Status: {status_text}", - (10, 35), - cv2.FONT_HERSHEY_SIMPLEX, - 1.2, - (255, 255, 255), - 3, - ) - - # Draw clock in lower right corner, right justified - current_time = time.strftime("%H:%M:%S") - (text_width, _), _ = cv2.getTextSize( - current_time, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2 - ) - clock_x = self.width - text_width - 10 # 10px margin from right edge - clock_y = self.height - 30 # Move to 450 for height=480 - cv2.putText( - frame_array, - current_time, - (clock_x, clock_y), - cv2.FONT_HERSHEY_SIMPLEX, - 1.0, - (255, 255, 255), - 2, - ) - - # Draw TTS job status (small text under header) - try: - with _lobby_threads_lock: - active_jobs_count = sum(1 for v in _lobby_jobs.values() if v.get("status") == "running") - # Get most recent job info if present - recent_job = None - if active_jobs_count > 0: - running = [v for v in _lobby_jobs.values() if v.get("status") == "running"] - recent_job = max(running, key=lambda j: j.get("start_time", 
0)) - - audio_samples = getattr(_audio_data, "size", 0) if _audio_data is not None else 0 - - job_snippet = "" - job_elapsed = None - if recent_job: - job_snippet = (recent_job.get("message") or "")[:80] - job_elapsed = time.time() - recent_job.get("start_time", time.time()) - - info_x = 10 - info_y = 95 - - if active_jobs_count > 0: - cv2.putText( - frame_array, - f"TTS jobs: {active_jobs_count} | {job_snippet}", - (info_x, info_y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (0, 255, 255), - 1, - ) - if job_elapsed is not None: - cv2.putText( - frame_array, - f"Elapsed: {job_elapsed:.1f}s", - (info_x + 420, info_y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (200, 200, 255), - 1, - ) - else: - # show audio buffer status when idle - cv2.putText( - frame_array, - f"Audio buffered: {audio_samples} samples", - (info_x, info_y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (200, 200, 200), - 1, - ) - except Exception: - # Non-critical overlay; ignore failures - pass - - # Prefer the generated TTS audio buffer if present. vibevoice ignores - # incoming WebRTC audio; the global `_audio_data` is populated by the - # TTS background worker and represents the audio we want to visualize. - best_proc = None - best_rms = 0.0 - speech_info = None - - try: - if _audio_data is not None and getattr(_audio_data, "size", 0) > 0: - # Use a synthetic pname to indicate TTS-generated audio and - # copy the buffer for safe local use. - try: - tts_arr = np.asarray(_audio_data, dtype=np.float32) - best_proc = ("__tts__", tts_arr.copy()) - # Mark as speech for coloring purposes - speech_info = {"is_speech": True, "energy_check": True} - except Exception: - best_proc = None - else: - for pname, arr in self.__class__.buffer.items(): - try: - if len(arr) == 0: - rms = 0.0 - else: - rms = float(np.sqrt(np.mean(arr**2))) - if rms > best_rms: - best_rms = rms - best_proc = (pname, arr.copy()) - speech_info = self.__class__.speech_status.get(pname, {}) - except Exception: - continue - except Exception: - best_proc = None - - if best_proc is not None: - pname, arr = best_proc - - # Use the last 2 second of audio data, padded with zeros if less - samples_needed = SAMPLE_RATE * 2 # 2 second(s) - if len(arr) <= 0: - arr_segment = np.zeros(samples_needed, dtype=np.float32) - elif len(arr) >= samples_needed: - arr_segment = arr[-samples_needed:].copy() - else: - # Pad with zeros at the beginning - arr_segment = np.concatenate( - [np.zeros(samples_needed - len(arr), dtype=np.float32), arr] - ) - - # Single normalization code path: normalize based on the historical - # peak observed for this stream (proc.max_observed_amplitude). This - # ensures the waveform display is consistent over time and avoids - # using the instantaneous buffer peak. - proc = None - norm = arr_segment.astype(np.float32) - - # Map audio samples to pixels across the width - if norm.size < self.width: - padded = np.zeros(self.width, dtype=np.float32) - if norm.size > 0: - padded[-norm.size :] = norm - norm = padded - else: - block = int(np.ceil(norm.size / self.width)) - norm = np.array( - [ - np.mean(norm[i * block : min((i + 1) * block, norm.size)]) - for i in range(self.width) - ], - dtype=np.float32, - ) - - # For display we use the same `norm` computed above (single code - # path). Use `display_norm` alias to avoid confusion later in the - # code but don't recompute normalization. 
- display_norm = norm - - # Draw waveform with color coding for speech detection - points: list[tuple[int, int]] = [] - colors: list[tuple[int, int, int]] = [] # Color for each point - - for x in range(self.width): - v = ( - float(display_norm[x]) - if x < display_norm.size and not np.isnan(display_norm[x]) - else 0.0 - ) - y = int((1.0 - ((v + 1.0) / 2.0)) * (self.height - 120)) + 100 - points.append((x, y)) - - # Color based on speech detection status - is_speech = ( - speech_info.get("is_speech", False) if speech_info else False - ) - energy_check = ( - speech_info.get("energy_check", False) if speech_info else False - ) - - if is_speech: - colors.append((0, 255, 0)) # Green for detected speech - elif energy_check: - colors.append((255, 255, 0)) # Yellow for energy but not speech - else: - colors.append((128, 128, 128)) # Gray for background noise - - # Draw colored waveform - if len(points) > 1: - for i in range(len(points) - 1): - cv2.line(frame_array, points[i], points[i + 1], colors[i], 1) - - # Draw historical peak indicator (horizontal lines at +/-(target_peak)) - try: - if proc is not None and getattr(proc, "normalization_enabled", False): - target_peak = float(getattr(proc, "normalization_target_peak", 0.0)) - # Ensure target_peak is within [0, 1] - target_peak = max(0.0, min(1.0, target_peak)) - - def _amp_to_y(a: float) -> int: - return ( - int((1.0 - ((a + 1.0) / 2.0)) * (self.height - 120)) + 100 - ) - - top_y = _amp_to_y(target_peak) - bot_y = _amp_to_y(-target_peak) - - # Draw thin magenta lines across the waveform area - cv2.line( - frame_array, - (0, top_y), - (self.width - 1, top_y), - (255, 0, 255), - 1, - ) - cv2.line( - frame_array, - (0, bot_y), - (self.width - 1, bot_y), - (255, 0, 255), - 1, - ) - - # Label the peak with small text near the right edge - label = f"Peak:{target_peak:.2f}" - (tw, _), _ = cv2.getTextSize( - label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 - ) - lx = max(10, self.width - tw - 12) - ly = max(12, top_y - 6) - cv2.putText( - frame_array, - label, - (lx, ly), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 0, 255), - 1, - ) - except Exception: - # Non-critical: ignore any drawing errors - pass - - # Add speech detection status overlay - if speech_info: - self._draw_speech_status(frame_array, speech_info, pname) - - cv2.putText( - frame_array, - f"Waveform: {pname}", - (10, self.height - 15), - cv2.FONT_HERSHEY_SIMPLEX, - 1.0, - (255, 255, 255), - 2, - ) - - frame = VideoFrame.from_ndarray(frame_array, format="bgr24") - frame.pts = pts - frame.time_base = fractions.Fraction(1 / 90000).limit_denominator(1000000) - return frame - - def _draw_speech_status( - self, - frame_array: npt.NDArray[np.uint8], - speech_info: Dict[str, Any], - pname: str, - ): - """Draw speech detection status information.""" - - y_offset = 100 - - # Main status - is_speech = speech_info.get("is_speech", False) - status_text = "SPEECH" if is_speech else "NOISE" - status_color = (0, 255, 0) if is_speech else (128, 128, 128) - - adaptive_thresh = speech_info.get("adaptive_threshold", 0) - cv2.putText( - frame_array, - f"{pname}: {status_text} (thresh: {adaptive_thresh:.4f})", - (10, y_offset), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - status_color, - 2, - ) - - # Detailed metrics (smaller text) - metrics = [ - f"Energy: {speech_info.get('energy', 0):.3f} ({'Y' if speech_info.get('energy_check', False) else 'N'})", - f"ZCR: {speech_info.get('zcr', 0):.3f} ({'Y' if speech_info.get('zcr_check', False) else 'N'})", - f"Spectral: {'Y' if 300 < speech_info.get('centroid', 0) < 3400 else 
'N'}/{'Y' if speech_info.get('rolloff', 0) < 2000 else 'N'}/{'Y' if speech_info.get('flux', 0) > 0.01 else 'N'} ({'Y' if speech_info.get('spectral_check', False) else 'N'})", - f"Harmonic: {speech_info.get('hamonicity', 0):.3f} ({'Y' if speech_info.get('harmonic_check', False) else 'N'})", - f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})", - ] - - for _, metric in enumerate(metrics): - cv2.putText( - frame_array, - metric, - (320, y_offset), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) - y_offset += 15 - - logic_result = "E:" + ("Y" if speech_info.get("energy_check", False) else "N") - logic_result += " Z:" + ("Y" if speech_info.get("zcr_check", False) else "N") - logic_result += " S:" + ( - "Y" if speech_info.get("spectral_check", False) else "N" - ) - logic_result += " H:" + ( - "Y" if speech_info.get("harmonic_check", False) else "N" - ) - logic_result += " T:" + ( - "Y" if speech_info.get("temporal_consistency", False) else "N" - ) - - cv2.putText( - frame_array, - logic_result, - (320, y_offset + 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (255, 255, 255), - 1, - ) - - # Noise floor indicator - noise_floor = speech_info.get("noise_floor_energy", 0) - cv2.putText( - frame_array, - f"Noise Floor: {noise_floor:.4f}", - (10, y_offset + 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (200, 200, 200), - 1, - ) - - -async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None: - """Handle incoming audio tracks from WebRTC peers.""" - logger.info( - f"handle_track_received called for {peer.peer_name} with track kind: {track.kind}" - ) - global _audio_processors, _send_chat_func - - if track.kind != "audio": - logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}") - return - - # This bot (vibevoice) does not use incoming WebRTC audio for TTS or - # waveform rendering. Ignore audio tracks entirely to avoid populating - # the shared waveform buffer with remote audio. The generated TTS audio - # is stored in the module-global `_audio_data` and will be used for the - # waveform and silent audio track playback instead. 
- logger.info(f"vibevoice: ignoring incoming audio track from {peer.peer_name}") - return - - # Initialize raw audio buffer for immediate graphing - if peer.peer_name not in WaveformVideoTrack.buffer: - WaveformVideoTrack.buffer[peer.peer_name] = np.array([], dtype=np.float32) - - # Start processing frames immediately (before processor is ready) - logger.info(f"Starting frame processing loop for {peer.peer_name}") - frame_count = 0 - while True: - try: - frame = await track.recv() - frame_count += 1 - - if frame_count % 100 == 0: - logger.debug(f"Received {frame_count} frames from {peer.peer_name}") - - except MediaStreamError as e: - logger.info(f"Audio stream ended for {peer.peer_name}: {e}") - break - except Exception as e: - logger.error(f"Error receiving frame from {peer.peer_name}: {e}") - break - - if isinstance(frame, AudioFrame): - try: - # Convert frame to numpy array - audio_data = frame.to_ndarray() - - # Handle audio format conversion - audio_data = _process_audio_frame(audio_data, frame) - - # Resample if needed - if frame.sample_rate != SAMPLE_RATE: - audio_data = _resample_audio( - audio_data, frame.sample_rate, SAMPLE_RATE - ) - - # Convert to float32 - audio_data_float32 = audio_data.astype(np.float32) - - # Update visualization buffer immediately - WaveformVideoTrack.buffer[peer.peer_name] = np.concatenate( - [WaveformVideoTrack.buffer[peer.peer_name], audio_data_float32] - ) - # Limit buffer size to last 10 seconds - max_samples = SAMPLE_RATE * 10 - if len(WaveformVideoTrack.buffer[peer.peer_name]) > max_samples: - WaveformVideoTrack.buffer[peer.peer_name] = ( - WaveformVideoTrack.buffer[peer.peer_name][-max_samples:] - ) - - # # Process with optimized processor if available - # if peer.peer_name in _audio_processors: - # audio_processor = _audio_processors[peer.peer_name] - # audio_processor.add_audio_data(audio_data_float32) - - except Exception as e: - logger.error(f"Error processing audio frame for {peer.peer_name}: {e}") - continue - - # If processor already exists, just continue processing - # audio_processor = _audio_processors[peer.peer_name] - logger.info(f"Continuing OpenVINO audio processing for {peer.peer_name}") - - try: - frame_count = 0 - logger.info(f"Entering frame processing loop for {peer.peer_name}") - while True: - try: - logger.debug(f"Waiting for frame from {peer.peer_name}") - frame = await track.recv() - frame_count += 1 - - if frame_count == 1: - logger.info(f"Received first frame from {peer.peer_name}") - elif frame_count % 50 == 0: - logger.info(f"Received {frame_count} frames from {peer.peer_name}") - - except MediaStreamError as e: - logger.info(f"Audio stream ended for {peer.peer_name}: {e}") - break - except Exception as e: - logger.error(f"Error receiving frame from {peer.peer_name}: {e}") - break - - if isinstance(frame, AudioFrame): - try: - # Convert frame to numpy array - audio_data = frame.to_ndarray() - - # Handle audio format conversion - audio_data = _process_audio_frame(audio_data, frame) - - # Resample if needed - if frame.sample_rate != SAMPLE_RATE: - audio_data = _resample_audio( - audio_data, frame.sample_rate, SAMPLE_RATE - ) - - # Convert to float32 - audio_data_float32 = audio_data.astype(np.float32) - - logger.debug( - f"Processed audio frame {frame_count} from {peer.peer_name}: {len(audio_data_float32)} samples" - ) - - # Process with optimized processor if available - # audio_processor.add_audio_data(audio_data_float32) - - except Exception as e: - logger.error( - f"Error processing audio frame for {peer.peer_name}: 
{e}" - ) - continue - - except Exception as e: - logger.error( - f"Unexpected error in audio processing for {peer.peer_name}: {e}", - exc_info=True, - ) - finally: - cleanup_peer_processor(peer.peer_name) - - -def _process_audio_frame( - audio_data: npt.NDArray[Any], frame: AudioFrame -) -> npt.NDArray[np.float32]: - """Process audio frame format conversion.""" - # Handle stereo to mono conversion - if audio_data.ndim == 2: - if audio_data.shape[0] == 1: - audio_data = audio_data.squeeze(0) - else: - audio_data = np.mean( - audio_data, axis=0 if audio_data.shape[0] > audio_data.shape[1] else 1 - ) - - # Normalize based on data type - if audio_data.dtype == np.int16: - audio_data = audio_data.astype(np.float32) / 32768.0 - elif audio_data.dtype == np.int32: - audio_data = audio_data.astype(np.float32) / 2147483648.0 - - return audio_data.astype(np.float32) - - -def _resample_audio( - audio_data: npt.NDArray[np.float32], orig_sr: int, target_sr: int -) -> npt.NDArray[np.float32]: - """Resample audio efficiently.""" - try: - # Handle stereo audio by converting to mono if necessary - if audio_data.ndim > 1: - audio_data = np.mean(audio_data, axis=1) - - # Use high-quality resampling - resampled = librosa.resample( # type: ignore - audio_data.astype(np.float64), - orig_sr=orig_sr, - target_sr=target_sr, - res_type="kaiser_fast", # Good balance of quality and speed - ) - return resampled.astype(np.float32) # type: ignore - except Exception as e: - logger.error(f"Resampling failed: {e}") - raise ValueError( - f"Failed to resample audio from {orig_sr} Hz to {target_sr} Hz: {e}" - ) - - -# Public API functions -def agent_info() -> Dict[str, str]: - return { - "name": AGENT_NAME, - "description": AGENT_DESCRIPTION, - "has_media": "true", - "configurable": "true", - } - - -def get_config_schema() -> Dict[str, Any]: - """Get the configuration schema for the Whisper bot""" - return { - "bot_name": AGENT_NAME, - "version": "1.0", - "parameters": [ - ], - "categories": [ - ], - } - - -def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool: - """Handle configuration update for a specific lobby""" - global _model_id, _device, _ov_config - - try: - logger.info(f"Updating TTS config for lobby {lobby_id}: {config_values}") - - config_applied = False - - config_applied = True - - return config_applied - - except Exception as e: - logger.error(f"Failed to apply Whisper config update: {e}") - return False - - -def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]: - """Create agent tracks. 
Provides a synthetic video waveform track and a silent audio track for compatibility.""" - - class SilentAudioTrack(MediaStreamTrack): - kind = "audio" - - def __init__( - self, sample_rate: int = SAMPLE_RATE, channels: int = 1, fps: int = 50 - ): - super().__init__() - self.sample_rate = sample_rate - self.channels = channels - self.fps = fps - self.samples_per_frame = int(self.sample_rate / self.fps) - self._timestamp = 0 - # Per-track playback buffer (float32, mono, -1..1) - self._play_buffer = np.array([], dtype=np.float32) - # Source sample rate of the buffered audio (if any) - self._play_src_sr = None - # Phase for synthetic sine tone (radians) - self._sine_phase = 0.0 - - async def recv(self) -> AudioFrame: - global _audio_data - # If new global TTS audio was produced, grab and prepare it for playback - logger.info( - "recv called with _audio_data: %s and size: %s", - "set" if _audio_data is not None else "unset", - getattr(_audio_data, "size", 0) if _audio_data is not None else "N/A", - ) - try: - if _audio_data is not None and getattr(_audio_data, "size", 0) > 0: - # Expect background worker to provide audio already resampled to - # the track sample rate (SAMPLE_RATE). Keep recv fast and - # non-blocking by avoiding any heavy DSP here. - try: - src_audio = np.asarray(_audio_data, dtype=np.float32) - # Clear global buffer (we consumed it into this track) - _audio_data = None - # Use the provided audio buffer directly as play buffer. - self._play_buffer = src_audio.astype(np.float32) - self._play_src_sr = self.sample_rate - except Exception: - logger.exception( - "Failed to prepare TTS audio buffer in recv; falling back to silence" - ) - - # (Intentionally avoid expensive I/O/DSP here - background worker - # will perform resampling and any debug file writes.) 
- - # Prepare output samples for this frame - if self._play_buffer.size > 0: - take = min(self.samples_per_frame, int(self._play_buffer.size)) - out = self._play_buffer[:take] - # Advance buffer - if take >= self._play_buffer.size: - self._play_buffer = np.array([], dtype=np.float32) - else: - self._play_buffer = self._play_buffer[take:] - else: - # No TTS audio buffered; output a 110 Hz sine tone at 0.1 volume - freq = 110.0 - volume = 0.1 - n = self.samples_per_frame - # incremental phase per sample - phase_inc = 2.0 * np.pi * freq / float(self.sample_rate) - samples_idx = np.arange(n, dtype=np.float32) - out = volume * np.sin(self._sine_phase + phase_inc * samples_idx) - # advance and wrap phase - self._sine_phase = (self._sine_phase + phase_inc * n) % (2.0 * np.pi) - logger.debug( - "No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)", - n, - self.sample_rate, - ) - - # Convert float32 [-1.0,1.0] to int16 PCM - pcm = np.clip(out, -1.0, 1.0) - pcm_int16 = (pcm * 32767.0).astype(np.int16) - - # aiortc AudioFrame.from_ndarray expects shape (channels, samples) - if self.channels == 1: - data = np.expand_dims(pcm_int16, axis=0) - layout = "mono" - else: - # Duplicate mono into stereo - data = np.vstack([pcm_int16, pcm_int16]) - layout = "stereo" - - frame = AudioFrame.from_ndarray(data, layout=layout) - frame.sample_rate = self.sample_rate - frame.pts = self._timestamp - frame.time_base = fractions.Fraction(1, self.sample_rate) - self._timestamp += self.samples_per_frame - - # Pace the frame rate to avoid busy-looping - await asyncio.sleep(1 / self.fps) - - return frame - except Exception as e: - logger.exception(f"Error in SilentAudioTrack.recv: {e}") - # On error, fall back to silent frame - data = np.zeros((self.channels, self.samples_per_frame), dtype=np.int16) - frame = AudioFrame.from_ndarray( - data, layout="mono" if self.channels == 1 else "stereo" - ) - frame.sample_rate = self.sample_rate - frame.pts = self._timestamp - frame.time_base = fractions.Fraction(1, self.sample_rate) - self._timestamp += self.samples_per_frame - await asyncio.sleep(1 / self.fps) - return frame - - try: - video_track = WaveformVideoTrack( - session_name=session_name, width=640, height=480, fps=15 - ) - audio_track = SilentAudioTrack() - return {"video": video_track, "audio": audio_track} - except Exception as e: - logger.error(f"Failed to create agent tracks: {e}") - return {} - - -async def handle_chat_message( - chat_message: ChatMessageModel, - send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]], -) -> Optional[str]: - """Handle incoming chat messages.""" - global _audio_data, _lobby_threads - - logger.info(f"Received chat message: {chat_message.message}") - - lobby_id = getattr(chat_message, "lobby_id", None) - if not lobby_id: - # If no lobby id present, just spawn a worker without dedup logic - lobby_id = "__no_lobby__" - - # Ensure we only run one background worker per lobby - loop = asyncio.get_running_loop() - - with _lobby_threads_lock: - existing = _lobby_threads.get(lobby_id) - # existing may be a Thread or an asyncio Future-like object - already_running = False - try: - if existing is not None: - if hasattr(existing, "is_alive") and existing.is_alive(): - already_running = True - elif hasattr(existing, "done") and not existing.done(): - already_running = True - except Exception: - # Conservative: assume running if we can't determine - already_running = True - - if already_running: - logger.info("Chat processing already active for lobby %s", lobby_id) - try: 
- # Prefer using the bound send/create chat helpers if available - if _create_chat_message_func is not None and _send_chat_func is not None: - cm = _create_chat_message_func("Already processing", lobby_id) - await _send_chat_func(cm) - else: - # Fallback to provided send_message_func. Some callers expect ChatMessageModel, - # but send_message_func also supports plain string in some integrations. - await send_message_func("Already processing") - except Exception: - logger.exception("Failed to send 'Already processing' reply") - return None - - # Create background worker function (runs in threadpool via run_in_executor) - def _background_worker(lobby: str, msg: ChatMessageModel, _loop: asyncio.AbstractEventLoop) -> None: - global _audio_data - logger.info("TTS background worker starting for lobby %s (msg id=%s)", lobby, getattr(msg, "id", "no-id")) - start_time = time.time() - try: - try: - # Prefer an already-bound engine getter if present (useful in tests) - if globals().get("get_tts_engine"): - engine_getter = globals().get("get_tts_engine") - else: - from .vibevoicetts import get_tts_engine as engine_getter - - logger.info("TTS engine getter resolved: %s", getattr(engine_getter, "__name__", str(engine_getter))) - - engine = engine_getter() - logger.info("TTS engine instance created: %s", type(engine)) - # Blocking TTS call moved to background thread - logger.info("Invoking engine.text_to_speech for lobby %s", lobby) - raw_audio = engine.text_to_speech(text=f"Speaker 1: {msg.message}", verbose=True) - logger.info("TTS generation completed for lobby %s in %.2fs", lobby, time.time() - start_time) - - # Determine source sample rate if available - try: - src_sr = engine.get_sample_rate() - except Exception: - src_sr = 24000 - - # Ensure numpy array and float32 - try: - raw_audio_arr = np.asarray(raw_audio, dtype=np.float32) - except Exception: - raw_audio_arr = np.array([], dtype=np.float32) - - # Resample to track sample rate if needed (do this in bg thread) - if raw_audio_arr.size > 0 and src_sr != SAMPLE_RATE: - try: - resampled = librosa.resample( - raw_audio_arr.astype(np.float64), - orig_sr=src_sr, - target_sr=SAMPLE_RATE, - res_type="kaiser_fast", - ) - _audio_data = resampled.astype(np.float32) - logger.info("Background worker resampled audio from %d Hz to %d Hz (samples=%d)", src_sr, SAMPLE_RATE, getattr(_audio_data, "size", 0)) - except Exception: - logger.exception("Background resampling failed; using raw audio as-is") - _audio_data = raw_audio_arr.astype(np.float32) - else: - _audio_data = raw_audio_arr.astype(np.float32) - logger.info("Background worker assigned raw audio buffer (samples=%d, src_sr=%s)", getattr(_audio_data, "size", 0), src_sr) - except Exception: - # Fallback: try module-level convenience function if monkeypatched - try: - tts_fn = globals().get("text_to_speech") - if tts_fn: - logger.info("Using monkeypatched text_to_speech convenience function") - raw_audio = tts_fn(text=f"Speaker 1: {msg.message}") - else: - # Last resort: import the convenience function lazily - logger.info("Importing text_to_speech fallback from vibevoicetts") - from .vibevoicetts import text_to_speech as tts_fn - raw_audio = tts_fn(text=f"Speaker 1: {msg.message}") - logger.info("Fallback TTS invocation completed for lobby %s", lobby) - # normalize into numpy here as well - try: - raw_audio_arr = np.asarray(raw_audio, dtype=np.float32) - except Exception: - raw_audio_arr = np.array([], dtype=np.float32) - - _audio_data = raw_audio_arr.astype(np.float32) - logger.info("Background worker 
assigned raw audio buffer (samples=%d)", getattr(_audio_data, 'size', 0)) - except Exception: - logger.exception("Failed to perform TTS in background worker") - - except Exception: - logger.exception("Unhandled error in background TTS worker for lobby %s", lobby) - finally: - # Update job metadata and cleanup thread entry for this lobby - try: - with _lobby_threads_lock: - # Update job metadata if present - job = _lobby_jobs.get(lobby) - if job is not None: - job["end_time"] = time.time() - job["status"] = "finished" - try: - job["audio_samples"] = int(getattr(_audio_data, "size", 0)) if _audio_data is not None else 0 - except Exception: - job["audio_samples"] = None - - # If the stored worker is the current thread, remove it. For futures, - # they will be removed by the done callback registered below. - th = _lobby_threads.get(lobby) - if th is threading.current_thread(): - del _lobby_threads[lobby] - except Exception: - logger.exception("Error cleaning up background thread record for lobby %s", lobby) - - # Schedule the worker in the event loop's default threadpool. This avoids - # raw Thread.start() races and integrates better with asyncio-based hosts. - logger.info("Scheduling background TTS worker for lobby %s via run_in_executor", lobby_id) - worker_obj = None - try: - # Use a small wrapper so the very first thing the thread does is emit a log - def _bg_wrapper(lobby_w: str, msg_w: ChatMessageModel, loop_w: asyncio.AbstractEventLoop) -> None: - logger.info("Background worker wrapper started for lobby %s (thread=%s)", lobby_w, threading.current_thread()) - try: - _background_worker(lobby_w, msg_w, loop_w) - finally: - logger.info("Background worker wrapper exiting for lobby %s (thread=%s)", lobby_w, threading.current_thread()) - - fut = loop.run_in_executor(None, _bg_wrapper, lobby_id, chat_message, loop) - worker_obj = fut - _lobby_threads[lobby_id] = worker_obj - logger.info("Scheduled background TTS worker (wrapper) for lobby %s: %s", lobby_id, fut) - - # Attach a done callback to clean up registry and update job metadata when finished - def _on_worker_done(fut_obj: "asyncio.Future[Any]") -> None: - try: - # Log exception or result for easier debugging - try: - exc = fut_obj.exception() - except Exception: - exc = None - - if exc: - logger.exception("Background TTS worker for lobby %s raised exception", lobby_id, exc_info=exc) - else: - try: - res = fut_obj.result() - logger.info("Background TTS worker for lobby %s completed successfully; result=%s", lobby_id, str(res)) - except Exception: - # Some futures may not have a result or may raise when retrieving - logger.info("Background TTS worker for lobby %s completed (no result).", lobby_id) - - with _lobby_threads_lock: - # mark job finished and attach exception info if present - job = _lobby_jobs.get(lobby_id) - if job is not None: - job["end_time"] = time.time() - job["status"] = "finished" - try: - job["audio_samples"] = int(getattr(_audio_data, "size", 0)) if _audio_data is not None else 0 - except Exception: - job["audio_samples"] = None - - # remove only if the stored object is this future - stored = _lobby_threads.get(lobby_id) - if stored is fut_obj: - del _lobby_threads[lobby_id] - except Exception: - logger.exception("Error in worker done callback for lobby %s", lobby_id) - - try: - fut.add_done_callback(_on_worker_done) - except Exception: - # ignore if the future does not support callbacks - pass - - except Exception: - # Fallback to raw Thread in the unlikely case run_in_executor fails - thread = 
threading.Thread(target=_background_worker, args=(lobby_id, chat_message, loop), daemon=True) - worker_obj = thread - _lobby_threads[lobby_id] = thread - logger.info("Created background TTS thread for lobby %s (fallback): %s", lobby_id, thread) - - # Record job metadata for status display - try: - with _lobby_threads_lock: - _lobby_jobs[lobby_id] = { - "status": "running", - "start_time": time.time(), - "message": getattr(chat_message, "message", ""), - "worker": worker_obj, - "error": None, - "end_time": None, - "audio_samples": None, - } - except Exception: - logger.exception("Failed to record lobby job metadata for %s", lobby_id) - - # If we fell back to a raw Thread, start it now; otherwise the future is already scheduled. - try: - stored = _lobby_threads.get(lobby_id) - if stored is not None and hasattr(stored, "start"): - logger.info("Starting fallback background TTS thread for lobby %s", lobby_id) - stored.start() - logger.info("Background TTS thread started for lobby %s", lobby_id) - except Exception: - logger.exception("Failed to start background TTS worker for %s", lobby_id) - - return None - - -async def on_track_received(peer: Peer, track: MediaStreamTrack) -> None: - """Callback when a new track is received from a peer.""" - await handle_track_received(peer, track) - - -def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]: - """Return the track handler function.""" - return on_track_received - - -def bind_send_chat_function( - send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], - create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel], -) -> None: - """Bind the send chat function.""" - global _send_chat_func, _create_chat_message_func, _audio_processors - - logger.info("Binding send chat function to OpenVINO whisper agent") - _send_chat_func = send_chat_func - _create_chat_message_func = create_chat_message_func - - # Update existing processors - # for peer_name, processor in _audio_processors.items(): - # processor.send_chat_func = send_chat_func - # processor.create_chat_message_func = create_chat_message_func - # logger.debug(f"Updated processor for {peer_name} with new send chat function") - - -def cleanup_peer_processor(peer_name: str) -> None: - """Clean up processor for disconnected peer.""" - global _audio_processors - - if peer_name in WaveformVideoTrack.buffer: - del WaveformVideoTrack.buffer[peer_name] - - - diff --git a/voicebot/bots/vibevoicetts.py b/voicebot/bots/vibevoicetts.py index 8727ac1..a05bea4 100644 --- a/voicebot/bots/vibevoicetts.py +++ b/voicebot/bots/vibevoicetts.py @@ -5,9 +5,47 @@ from typing import Any, List, Tuple, Union, Optional import time import torch import numpy as np +import sys -from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference -from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor +# Defer importing the external `vibevoice` package until we actually need it. +# In some environments the `vibevoice` package isn't installed into site-packages +# but the repo contains a local copy under `voicebot/VibeVoice`. Attempt a lazy +# import and, if that fails, add the local path(s) to sys.path and retry. 
+ +def _import_vibevoice_symbols(): + try: + from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference + from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor + return VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor + except Exception: + # If a required package (like `diffusers`) is missing inside the + # container or venv, importing deeper VibeVoice modules will raise + # ModuleNotFoundError. Detect that and raise a clearer error that + # includes install instructions. + import traceback as _tb + exc_type, exc_val, exc_tb = _tb.sys.exc_info() + if isinstance(exc_val, ModuleNotFoundError): + missing = str(exc_val).split("'")[1] if "'" in str(exc_val) else str(exc_val) + raise ModuleNotFoundError( + f"Missing dependency when importing VibeVoice: {missing}.\n" + "Install required packages inside the voicebot container.\n" + "Example (inside container):\n" + " PYTHONPATH=/shared:/voicebot uv run python3 -m pip install diffusers accelerate safetensors\n" + "Or add the packages to the voicebot service environment / pyproject and rebuild." + ) from exc_val + # Try adding likely repository-local paths where VibeVoice lives + base = os.path.dirname(__file__) # voicebot/bots + candidates = [ + os.path.abspath(os.path.join(base, "..", "VibeVoice")), + os.path.abspath(os.path.join("/", "voicebot", "VibeVoice")), + ] + for p in candidates: + if os.path.isdir(p) and p not in sys.path: + sys.path.insert(0, p) + # Retry import + from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference + from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor + return VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor from shared.logger import logger @@ -155,6 +193,8 @@ class VibeVoiceTTS: def _load_model(self): """Load the model and processor with device-specific configuration.""" logger.info(f"Loading processor & model from {self.model_path}") + # Ensure external vibevoice symbols are available (lazy import) + VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor = _import_vibevoice_symbols() self.processor = VibeVoiceProcessor.from_pretrained(self.model_path) # Decide dtype & attention implementation @@ -401,15 +441,16 @@ class VibeVoiceTTS: logger.info(f"Generated tokens: {generated_tokens}") logger.info(f"Total tokens: {output_tokens}") + # Return audio data as numpy array # Return audio data as numpy array if outputs.speech_outputs and outputs.speech_outputs[0] is not None: audio_tensor = outputs.speech_outputs[0] - # Convert to numpy array on CPU + # Convert to numpy array on CPU, ensuring compatible dtype if hasattr(audio_tensor, 'cpu'): - audio_data = audio_tensor.cpu().numpy() + audio_data = audio_tensor.cpu().float().numpy() # Convert to float32 first else: - audio_data = np.array(audio_tensor) + audio_data = np.array(audio_tensor, dtype=np.float32) # Ensure it's a 1D array if audio_data.ndim > 1: diff --git a/voicebot/requirements.txt b/voicebot/requirements.txt index f754094..b523b25 100644 --- a/voicebot/requirements.txt +++ b/voicebot/requirements.txt @@ -1,3 +1,4 @@ +--extra-index-url https://download.pytorch.org/whl/xpu about-time==4.2.1 aiofiles==24.1.0 aiohappyeyeballs==2.6.1 @@ -74,20 +75,6 @@ ninja==1.13.0 nncf==2.18.0 numba==0.61.2 numpy==2.2.6 -nvidia-cublas-cu12==12.8.4.1 -nvidia-cuda-cupti-cu12==12.8.90 -nvidia-cuda-nvrtc-cu12==12.8.93 -nvidia-cuda-runtime-cu12==12.8.90 -nvidia-cudnn-cu12==9.10.2.21 -nvidia-cufft-cu12==11.3.3.83 
-nvidia-cufile-cu12==1.13.1.3 -nvidia-curand-cu12==10.3.9.90 -nvidia-cusolver-cu12==11.7.3.90 -nvidia-cusparse-cu12==12.5.8.93 -nvidia-cusparselt-cu12==0.7.1 -nvidia-nccl-cu12==2.27.3 -nvidia-nvjitlink-cu12==12.8.93 -nvidia-nvtx-cu12==12.8.90 onnx==1.19.0 openai==1.107.2 openai-whisper @ git+https://github.com/openai/whisper.git@c0d2f624c09dc18e709e37c2ad90c039a4eb72a2 @@ -157,11 +144,13 @@ threadpoolctl==3.6.0 tiktoken==0.11.0 tokenizers==0.21.4 tomlkit==0.13.3 -torch==2.8.0 -torchvision==0.23.0 tqdm==4.67.1 +torch==2.8.0+xpu +torchvision==0.23.0+xpu +torchaudio==2.8.0+xpu transformers==4.53.3 -triton==3.4.0 +diffusers +accelerate typer==0.17.4 typing-extensions==4.15.0 typing-inspection==0.4.1 @@ -172,4 +161,4 @@ watchdog==6.0.0 websockets==15.0.1 wrapt==1.17.3 xxhash==3.5.0 -yarl==1.20.1 +yarl==1.20.1 \ No newline at end of file
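
Note: the new _import_vibevoice_symbols helper in vibevoicetts.py defers the vibevoice import and, when it fails, retries after pushing repo-local checkout paths onto sys.path. A reduced sketch of that pattern follows; the candidate directory is a placeholder, not the real container layout.

# Sketch of the lazy-import-with-local-fallback pattern (paths are illustrative).
import importlib
import os
import sys

def import_with_local_fallback(module_name, candidate_dirs):
    # Try the normal import first; only touch sys.path when it fails.
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        for path in candidate_dirs:
            if os.path.isdir(path) and path not in sys.path:
                sys.path.insert(0, path)
        # Retry once with the local paths visible.
        return importlib.import_module(module_name)

# e.g. import_with_local_fallback("vibevoice", [os.path.abspath("voicebot/VibeVoice")])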
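Note: the requirements change swaps the CUDA wheels for the Intel XPU builds, and the vibevoicetts.py hunk casts the generated speech tensor to float32 on the CPU before converting it to numpy. A minimal sketch of both ideas, assuming PyTorch >= 2.4 with the XPU backend as pulled from the https://download.pytorch.org/whl/xpu index added above (tensor sizes and prints are illustrative only):

# Verify the XPU install and mirror the .cpu().float().numpy() conversion.
import numpy as np
import torch

# Prefer the Intel GPU when the XPU runtime is usable, otherwise fall back to CPU.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
print(f"Using device: {device}")

# Stand-in for a generated speech tensor; the model may return bfloat16 on the GPU.
speech = torch.randn(24000, dtype=torch.bfloat16, device=device)

# bfloat16 tensors cannot be handed to numpy directly; cast to float32 on the CPU
# first, which is what the .cpu().float().numpy() change in vibevoicetts.py does.
audio = speech.cpu().float().numpy()
assert audio.dtype == np.float32
print(f"{audio.shape[0]} samples, dtype={audio.dtype}")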