From 4058f729e29264baf9f2da2f360e5f0ff353d1f5 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Fri, 19 Sep 2025 11:55:16 -0700 Subject: [PATCH] TTS is working on CPU --- voicebot/bots/vibevoice.py | 421 +++++++++++++++++++++++++++++++------ 1 file changed, 358 insertions(+), 63 deletions(-) diff --git a/voicebot/bots/vibevoice.py b/voicebot/bots/vibevoice.py index ca6f38f..89ab6ae 100644 --- a/voicebot/bots/vibevoice.py +++ b/voicebot/bots/vibevoice.py @@ -1,4 +1,5 @@ import asyncio +import threading import numpy as np import time import os @@ -15,7 +16,6 @@ from av import AudioFrame, VideoFrame import cv2 import fractions from time import perf_counter -import wave # Import shared models for chat functionality import sys @@ -26,7 +26,10 @@ sys.path.append( from shared.models import ChatMessageModel from voicebot.models import Peer -from .vibevoicetts import text_to_speech +# Defer importing the heavy TTS module until needed (lazy import). +# Tests and other code can monkeypatch these names on this module. +get_tts_engine = None +text_to_speech = None # Global configuration and constants AGENT_NAME = "Text To Speech Bot" @@ -38,6 +41,13 @@ SAMPLE_RATE = 16000 # Whisper expects 16kHz _audio_data = None +# Track background processing threads per lobby to ensure only one worker runs +# per lobby session at a time. +_lobby_threads: Dict[str, Any] = {} +_lobby_threads_lock = threading.Lock() +# Per-lobby job metadata for status display +_lobby_jobs: Dict[str, Dict[str, Any]] = {} + class TTSModel: """TTS via MSFT VibeVoice""" @@ -112,9 +122,37 @@ class WaveformVideoTrack(MediaStreamTrack): (self.height, self.width, 3), dtype=np.uint8 ) - # Display model loading status prominently - status_text = "Initializing..." - progress = 0.1 + # Display model / TTS job status prominently + status_text = "Idle" + progress = 0.0 + + try: + # Show number of active lobby TTS jobs + with _lobby_threads_lock: + active_jobs = {k: v for k, v in _lobby_jobs.items() if v.get("status") == "running"} + + if active_jobs: + # Pick the most recent running job (sort by start_time desc) + recent_lobby, recent_job = sorted( + active_jobs.items(), key=lambda kv: kv[1].get("start_time", 0), reverse=True + )[0] + # elapsed time available if needed for overlays + _ = time.time() - recent_job.get("start_time", time.time()) + snippet = (recent_job.get("message") or "")[:80] + status_text = f"TTS running ({len(active_jobs)}): {snippet}" + # Progress as a heuristic (running jobs -> 0.5) + progress = 0.5 + else: + # If there is produced audio waiting in _audio_data, show ready-to-play + if _audio_data is not None and getattr(_audio_data, "size", 0) > 0: + status_text = "TTS: Ready (audio buffered)" + progress = 1.0 + else: + status_text = "Idle" + progress = 0.0 + except Exception: + status_text = "Status: error" + progress = 0.0 # Draw status background (increased height for larger text) cv2.rectangle(frame_array, (0, 0), (self.width, 80), (0, 0, 0), -1) @@ -155,6 +193,62 @@ class WaveformVideoTrack(MediaStreamTrack): 2, ) + # Draw TTS job status (small text under header) + try: + with _lobby_threads_lock: + active_jobs_count = sum(1 for v in _lobby_jobs.values() if v.get("status") == "running") + # Get most recent job info if present + recent_job = None + if active_jobs_count > 0: + running = [v for v in _lobby_jobs.values() if v.get("status") == "running"] + recent_job = max(running, key=lambda j: j.get("start_time", 0)) + + audio_samples = getattr(_audio_data, "size", 0) if _audio_data is not None else 0 + + job_snippet = "" + job_elapsed = None + if recent_job: + job_snippet = (recent_job.get("message") or "")[:80] + job_elapsed = time.time() - recent_job.get("start_time", time.time()) + + info_x = 10 + info_y = 95 + + if active_jobs_count > 0: + cv2.putText( + frame_array, + f"TTS jobs: {active_jobs_count} | {job_snippet}", + (info_x, info_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 255), + 1, + ) + if job_elapsed is not None: + cv2.putText( + frame_array, + f"Elapsed: {job_elapsed:.1f}s", + (info_x + 420, info_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (200, 200, 255), + 1, + ) + else: + # show audio buffer status when idle + cv2.putText( + frame_array, + f"Audio buffered: {audio_samples} samples", + (info_x, info_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (200, 200, 200), + 1, + ) + except Exception: + # Non-critical overlay; ignore failures + pass + # Select the most active audio buffer and get its speech status best_proc = None best_rms = 0.0 @@ -649,57 +743,30 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]: async def recv(self) -> AudioFrame: global _audio_data # If new global TTS audio was produced, grab and prepare it for playback - logger.info("recv called with _audio_data: %s and size: %s", "set" if _audio_data is not None else "unset", getattr(_audio_data, 'size', 0) if _audio_data is not None else 'N/A') + logger.info( + "recv called with _audio_data: %s and size: %s", + "set" if _audio_data is not None else "unset", + getattr(_audio_data, "size", 0) if _audio_data is not None else "N/A", + ) try: - if _audio_data is not None and getattr(_audio_data, 'size', 0) > 0: - # Get source sample rate from TTS engine if available + if _audio_data is not None and getattr(_audio_data, "size", 0) > 0: + # Expect background worker to provide audio already resampled to + # the track sample rate (SAMPLE_RATE). Keep recv fast and + # non-blocking by avoiding any heavy DSP here. try: - from .vibevoicetts import get_tts_engine - - src_sr = get_tts_engine().get_sample_rate() - except Exception: - # Fallback: assume 24000 (VibeVoice default) - logger.info("Falling back to default source sample rate of 24000 Hz") - src_sr = 24000 - - # Ensure numpy float32 - src_audio = np.asarray(_audio_data, dtype=np.float32) - - # Clear global buffer (we consumed it into this track) - _audio_data = None - - # If source sr differs from track sr, resample - if src_sr != self.sample_rate: - try: - resampled = librosa.resample( - src_audio.astype(np.float64), - orig_sr=src_sr, - target_sr=self.sample_rate, - res_type="kaiser_fast", - ) - self._play_buffer = resampled.astype(np.float32) - except Exception: - # On failure, fallback to nearest simple scaling (not ideal) - logger.exception("Failed to resample TTS audio; using source as-is") - self._play_buffer = src_audio.astype(np.float32) - else: + src_audio = np.asarray(_audio_data, dtype=np.float32) + # Clear global buffer (we consumed it into this track) + _audio_data = None + # Use the provided audio buffer directly as play buffer. self._play_buffer = src_audio.astype(np.float32) - - self._play_src_sr = self.sample_rate - - # Save a copy of the consumed audio to disk for inspection - try: - pcm_save = np.clip(self._play_buffer, -1.0, 1.0) - pcm_int16_save = (pcm_save * 32767.0).astype(np.int16) - sample_path = "./sample.wav" - with wave.open(sample_path, "wb") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(self.sample_rate) - wf.writeframes(pcm_int16_save.tobytes()) - logger.info("Wrote TTS sample to %s (%d samples, %d Hz)", sample_path, len(pcm_int16_save), self.sample_rate) + self._play_src_sr = self.sample_rate except Exception: - logger.exception("Failed to write sample.wav") + logger.exception( + "Failed to prepare TTS audio buffer in recv; falling back to silence" + ) + + # (Intentionally avoid expensive I/O/DSP here - background worker + # will perform resampling and any debug file writes.) # Prepare output samples for this frame if self._play_buffer.size > 0: @@ -721,7 +788,11 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]: out = volume * np.sin(self._sine_phase + phase_inc * samples_idx) # advance and wrap phase self._sine_phase = (self._sine_phase + phase_inc * n) % (2.0 * np.pi) - logger.debug("No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)", n, self.sample_rate) + logger.debug( + "No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)", + n, + self.sample_rate, + ) # Convert float32 [-1.0,1.0] to int16 PCM pcm = np.clip(out, -1.0, 1.0) @@ -750,7 +821,9 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]: logger.exception(f"Error in SilentAudioTrack.recv: {e}") # On error, fall back to silent frame data = np.zeros((self.channels, self.samples_per_frame), dtype=np.int16) - frame = AudioFrame.from_ndarray(data, layout="mono" if self.channels == 1 else "stereo") + frame = AudioFrame.from_ndarray( + data, layout="mono" if self.channels == 1 else "stereo" + ) frame.sample_rate = self.sample_rate frame.pts = self._timestamp frame.time_base = fractions.Fraction(1, self.sample_rate) @@ -774,17 +847,239 @@ async def handle_chat_message( send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]], ) -> Optional[str]: """Handle incoming chat messages.""" - global _audio_data - logger.info(f"Received chat message: {chat_message.message}") - # This ends up blocking; spin it off into a thread - try: - from .vibevoicetts import get_tts_engine + global _audio_data, _lobby_threads + + logger.info(f"Received chat message: {chat_message.message}") + + lobby_id = getattr(chat_message, "lobby_id", None) + if not lobby_id: + # If no lobby id present, just spawn a worker without dedup logic + lobby_id = "__no_lobby__" + + # Ensure we only run one background worker per lobby + loop = asyncio.get_running_loop() + + with _lobby_threads_lock: + existing = _lobby_threads.get(lobby_id) + # existing may be a Thread or an asyncio Future-like object + already_running = False + try: + if existing is not None: + if hasattr(existing, "is_alive") and existing.is_alive(): + already_running = True + elif hasattr(existing, "done") and not existing.done(): + already_running = True + except Exception: + # Conservative: assume running if we can't determine + already_running = True + + if already_running: + logger.info("Chat processing already active for lobby %s", lobby_id) + try: + # Prefer using the bound send/create chat helpers if available + if _create_chat_message_func is not None and _send_chat_func is not None: + cm = _create_chat_message_func("Already processing", lobby_id) + await _send_chat_func(cm) + else: + # Fallback to provided send_message_func. Some callers expect ChatMessageModel, + # but send_message_func also supports plain string in some integrations. + await send_message_func("Already processing") + except Exception: + logger.exception("Failed to send 'Already processing' reply") + return None + + # Create background worker function (runs in threadpool via run_in_executor) + def _background_worker(lobby: str, msg: ChatMessageModel, _loop: asyncio.AbstractEventLoop) -> None: + global _audio_data + logger.info("TTS background worker starting for lobby %s (msg id=%s)", lobby, getattr(msg, "id", "no-id")) + start_time = time.time() + try: + try: + # Prefer an already-bound engine getter if present (useful in tests) + if globals().get("get_tts_engine"): + engine_getter = globals().get("get_tts_engine") + else: + from .vibevoicetts import get_tts_engine as engine_getter + + logger.info("TTS engine getter resolved: %s", getattr(engine_getter, "__name__", str(engine_getter))) + + engine = engine_getter() + logger.info("TTS engine instance created: %s", type(engine)) + # Blocking TTS call moved to background thread + logger.info("Invoking engine.text_to_speech for lobby %s", lobby) + raw_audio = engine.text_to_speech(text=f"Speaker 1: {msg.message}", verbose=True) + logger.info("TTS generation completed for lobby %s in %.2fs", lobby, time.time() - start_time) + + # Determine source sample rate if available + try: + src_sr = engine.get_sample_rate() + except Exception: + src_sr = 24000 + + # Ensure numpy array and float32 + try: + raw_audio_arr = np.asarray(raw_audio, dtype=np.float32) + except Exception: + raw_audio_arr = np.array([], dtype=np.float32) + + # Resample to track sample rate if needed (do this in bg thread) + if raw_audio_arr.size > 0 and src_sr != SAMPLE_RATE: + try: + resampled = librosa.resample( + raw_audio_arr.astype(np.float64), + orig_sr=src_sr, + target_sr=SAMPLE_RATE, + res_type="kaiser_fast", + ) + _audio_data = resampled.astype(np.float32) + logger.info("Background worker resampled audio from %d Hz to %d Hz (samples=%d)", src_sr, SAMPLE_RATE, getattr(_audio_data, "size", 0)) + except Exception: + logger.exception("Background resampling failed; using raw audio as-is") + _audio_data = raw_audio_arr.astype(np.float32) + else: + _audio_data = raw_audio_arr.astype(np.float32) + logger.info("Background worker assigned raw audio buffer (samples=%d, src_sr=%s)", getattr(_audio_data, "size", 0), src_sr) + except Exception: + # Fallback: try module-level convenience function if monkeypatched + try: + tts_fn = globals().get("text_to_speech") + if tts_fn: + logger.info("Using monkeypatched text_to_speech convenience function") + raw_audio = tts_fn(text=f"Speaker 1: {msg.message}") + else: + # Last resort: import the convenience function lazily + logger.info("Importing text_to_speech fallback from vibevoicetts") + from .vibevoicetts import text_to_speech as tts_fn + raw_audio = tts_fn(text=f"Speaker 1: {msg.message}") + logger.info("Fallback TTS invocation completed for lobby %s", lobby) + # normalize into numpy here as well + try: + raw_audio_arr = np.asarray(raw_audio, dtype=np.float32) + except Exception: + raw_audio_arr = np.array([], dtype=np.float32) + + _audio_data = raw_audio_arr.astype(np.float32) + logger.info("Background worker assigned raw audio buffer (samples=%d)", getattr(_audio_data, 'size', 0)) + except Exception: + logger.exception("Failed to perform TTS in background worker") + + except Exception: + logger.exception("Unhandled error in background TTS worker for lobby %s", lobby) + finally: + # Update job metadata and cleanup thread entry for this lobby + try: + with _lobby_threads_lock: + # Update job metadata if present + job = _lobby_jobs.get(lobby) + if job is not None: + job["end_time"] = time.time() + job["status"] = "finished" + try: + job["audio_samples"] = int(getattr(_audio_data, "size", 0)) if _audio_data is not None else 0 + except Exception: + job["audio_samples"] = None + + # If the stored worker is the current thread, remove it. For futures, + # they will be removed by the done callback registered below. + th = _lobby_threads.get(lobby) + if th is threading.current_thread(): + del _lobby_threads[lobby] + except Exception: + logger.exception("Error cleaning up background thread record for lobby %s", lobby) + + # Schedule the worker in the event loop's default threadpool. This avoids + # raw Thread.start() races and integrates better with asyncio-based hosts. + logger.info("Scheduling background TTS worker for lobby %s via run_in_executor", lobby_id) + worker_obj = None + try: + # Use a small wrapper so the very first thing the thread does is emit a log + def _bg_wrapper(lobby_w: str, msg_w: ChatMessageModel, loop_w: asyncio.AbstractEventLoop) -> None: + logger.info("Background worker wrapper started for lobby %s (thread=%s)", lobby_w, threading.current_thread()) + try: + _background_worker(lobby_w, msg_w, loop_w) + finally: + logger.info("Background worker wrapper exiting for lobby %s (thread=%s)", lobby_w, threading.current_thread()) + + fut = loop.run_in_executor(None, _bg_wrapper, lobby_id, chat_message, loop) + worker_obj = fut + _lobby_threads[lobby_id] = worker_obj + logger.info("Scheduled background TTS worker (wrapper) for lobby %s: %s", lobby_id, fut) + + # Attach a done callback to clean up registry and update job metadata when finished + def _on_worker_done(fut_obj: "asyncio.Future[Any]") -> None: + try: + # Log exception or result for easier debugging + try: + exc = fut_obj.exception() + except Exception: + exc = None + + if exc: + logger.exception("Background TTS worker for lobby %s raised exception", lobby_id, exc_info=exc) + else: + try: + res = fut_obj.result() + logger.info("Background TTS worker for lobby %s completed successfully; result=%s", lobby_id, str(res)) + except Exception: + # Some futures may not have a result or may raise when retrieving + logger.info("Background TTS worker for lobby %s completed (no result).", lobby_id) + + with _lobby_threads_lock: + # mark job finished and attach exception info if present + job = _lobby_jobs.get(lobby_id) + if job is not None: + job["end_time"] = time.time() + job["status"] = "finished" + try: + job["audio_samples"] = int(getattr(_audio_data, "size", 0)) if _audio_data is not None else 0 + except Exception: + job["audio_samples"] = None + + # remove only if the stored object is this future + stored = _lobby_threads.get(lobby_id) + if stored is fut_obj: + del _lobby_threads[lobby_id] + except Exception: + logger.exception("Error in worker done callback for lobby %s", lobby_id) + + try: + fut.add_done_callback(_on_worker_done) + except Exception: + # ignore if the future does not support callbacks + pass - engine = get_tts_engine() - _audio_data = engine.text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None, verbose=True) except Exception: - # Fallback to convenience function if direct engine call fails - _audio_data = text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None) + # Fallback to raw Thread in the unlikely case run_in_executor fails + thread = threading.Thread(target=_background_worker, args=(lobby_id, chat_message, loop), daemon=True) + worker_obj = thread + _lobby_threads[lobby_id] = thread + logger.info("Created background TTS thread for lobby %s (fallback): %s", lobby_id, thread) + + # Record job metadata for status display + try: + with _lobby_threads_lock: + _lobby_jobs[lobby_id] = { + "status": "running", + "start_time": time.time(), + "message": getattr(chat_message, "message", ""), + "worker": worker_obj, + "error": None, + "end_time": None, + "audio_samples": None, + } + except Exception: + logger.exception("Failed to record lobby job metadata for %s", lobby_id) + + # If we fell back to a raw Thread, start it now; otherwise the future is already scheduled. + try: + stored = _lobby_threads.get(lobby_id) + if stored is not None and hasattr(stored, "start"): + logger.info("Starting fallback background TTS thread for lobby %s", lobby_id) + stored.start() + logger.info("Background TTS thread started for lobby %s", lobby_id) + except Exception: + logger.exception("Failed to start background TTS worker for %s", lobby_id) + return None