TTS is working on CPU

This commit is contained in:
James Ketr 2025-09-19 11:55:16 -07:00
parent a0aa65ec1c
commit 4058f729e2

View File

@ -1,4 +1,5 @@
import asyncio
import threading
import numpy as np
import time
import os
@ -15,7 +16,6 @@ from av import AudioFrame, VideoFrame
import cv2
import fractions
from time import perf_counter
import wave
# Import shared models for chat functionality
import sys
@ -26,7 +26,10 @@ sys.path.append(
from shared.models import ChatMessageModel
from voicebot.models import Peer
from .vibevoicetts import text_to_speech
# Defer importing the heavy TTS module until needed (lazy import).
# Tests and other code can monkeypatch these names on this module.
get_tts_engine = None
text_to_speech = None
# Global configuration and constants
AGENT_NAME = "Text To Speech Bot"
@ -38,6 +41,13 @@ SAMPLE_RATE = 16000 # Whisper expects 16kHz
_audio_data = None
# Track background processing threads per lobby to ensure only one worker runs
# per lobby session at a time.
_lobby_threads: Dict[str, Any] = {}
_lobby_threads_lock = threading.Lock()
# Per-lobby job metadata for status display
_lobby_jobs: Dict[str, Dict[str, Any]] = {}
class TTSModel:
"""TTS via MSFT VibeVoice"""
@ -112,9 +122,37 @@ class WaveformVideoTrack(MediaStreamTrack):
(self.height, self.width, 3), dtype=np.uint8
)
# Display model loading status prominently
status_text = "Initializing..."
progress = 0.1
# Display model / TTS job status prominently
status_text = "Idle"
progress = 0.0
try:
# Show number of active lobby TTS jobs
with _lobby_threads_lock:
active_jobs = {k: v for k, v in _lobby_jobs.items() if v.get("status") == "running"}
if active_jobs:
# Pick the most recent running job (sort by start_time desc)
recent_lobby, recent_job = sorted(
active_jobs.items(), key=lambda kv: kv[1].get("start_time", 0), reverse=True
)[0]
# elapsed time available if needed for overlays
_ = time.time() - recent_job.get("start_time", time.time())
snippet = (recent_job.get("message") or "")[:80]
status_text = f"TTS running ({len(active_jobs)}): {snippet}"
# Progress as a heuristic (running jobs -> 0.5)
progress = 0.5
else:
# If there is produced audio waiting in _audio_data, show ready-to-play
if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
status_text = "TTS: Ready (audio buffered)"
progress = 1.0
else:
status_text = "Idle"
progress = 0.0
except Exception:
status_text = "Status: error"
progress = 0.0
# Draw status background (increased height for larger text)
cv2.rectangle(frame_array, (0, 0), (self.width, 80), (0, 0, 0), -1)
@ -155,6 +193,62 @@ class WaveformVideoTrack(MediaStreamTrack):
2,
)
# Draw TTS job status (small text under header)
try:
with _lobby_threads_lock:
active_jobs_count = sum(1 for v in _lobby_jobs.values() if v.get("status") == "running")
# Get most recent job info if present
recent_job = None
if active_jobs_count > 0:
running = [v for v in _lobby_jobs.values() if v.get("status") == "running"]
recent_job = max(running, key=lambda j: j.get("start_time", 0))
audio_samples = getattr(_audio_data, "size", 0) if _audio_data is not None else 0
job_snippet = ""
job_elapsed = None
if recent_job:
job_snippet = (recent_job.get("message") or "")[:80]
job_elapsed = time.time() - recent_job.get("start_time", time.time())
info_x = 10
info_y = 95
if active_jobs_count > 0:
cv2.putText(
frame_array,
f"TTS jobs: {active_jobs_count} | {job_snippet}",
(info_x, info_y),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 255, 255),
1,
)
if job_elapsed is not None:
cv2.putText(
frame_array,
f"Elapsed: {job_elapsed:.1f}s",
(info_x + 420, info_y),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(200, 200, 255),
1,
)
else:
# show audio buffer status when idle
cv2.putText(
frame_array,
f"Audio buffered: {audio_samples} samples",
(info_x, info_y),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(200, 200, 200),
1,
)
except Exception:
# Non-critical overlay; ignore failures
pass
# Select the most active audio buffer and get its speech status
best_proc = None
best_rms = 0.0
@ -649,57 +743,30 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
async def recv(self) -> AudioFrame:
global _audio_data
# If new global TTS audio was produced, grab and prepare it for playback
logger.info("recv called with _audio_data: %s and size: %s", "set" if _audio_data is not None else "unset", getattr(_audio_data, 'size', 0) if _audio_data is not None else 'N/A')
logger.info(
"recv called with _audio_data: %s and size: %s",
"set" if _audio_data is not None else "unset",
getattr(_audio_data, "size", 0) if _audio_data is not None else "N/A",
)
try:
if _audio_data is not None and getattr(_audio_data, 'size', 0) > 0:
# Get source sample rate from TTS engine if available
if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
# Expect background worker to provide audio already resampled to
# the track sample rate (SAMPLE_RATE). Keep recv fast and
# non-blocking by avoiding any heavy DSP here.
try:
from .vibevoicetts import get_tts_engine
src_sr = get_tts_engine().get_sample_rate()
except Exception:
# Fallback: assume 24000 (VibeVoice default)
logger.info("Falling back to default source sample rate of 24000 Hz")
src_sr = 24000
# Ensure numpy float32
src_audio = np.asarray(_audio_data, dtype=np.float32)
# Clear global buffer (we consumed it into this track)
_audio_data = None
# If source sr differs from track sr, resample
if src_sr != self.sample_rate:
try:
resampled = librosa.resample(
src_audio.astype(np.float64),
orig_sr=src_sr,
target_sr=self.sample_rate,
res_type="kaiser_fast",
)
self._play_buffer = resampled.astype(np.float32)
except Exception:
# On failure, fallback to nearest simple scaling (not ideal)
logger.exception("Failed to resample TTS audio; using source as-is")
self._play_buffer = src_audio.astype(np.float32)
else:
src_audio = np.asarray(_audio_data, dtype=np.float32)
# Clear global buffer (we consumed it into this track)
_audio_data = None
# Use the provided audio buffer directly as play buffer.
self._play_buffer = src_audio.astype(np.float32)
self._play_src_sr = self.sample_rate
# Save a copy of the consumed audio to disk for inspection
try:
pcm_save = np.clip(self._play_buffer, -1.0, 1.0)
pcm_int16_save = (pcm_save * 32767.0).astype(np.int16)
sample_path = "./sample.wav"
with wave.open(sample_path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(self.sample_rate)
wf.writeframes(pcm_int16_save.tobytes())
logger.info("Wrote TTS sample to %s (%d samples, %d Hz)", sample_path, len(pcm_int16_save), self.sample_rate)
self._play_src_sr = self.sample_rate
except Exception:
logger.exception("Failed to write sample.wav")
logger.exception(
"Failed to prepare TTS audio buffer in recv; falling back to silence"
)
# (Intentionally avoid expensive I/O/DSP here - background worker
# will perform resampling and any debug file writes.)
# Prepare output samples for this frame
if self._play_buffer.size > 0:
@ -721,7 +788,11 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
out = volume * np.sin(self._sine_phase + phase_inc * samples_idx)
# advance and wrap phase
self._sine_phase = (self._sine_phase + phase_inc * n) % (2.0 * np.pi)
logger.debug("No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)", n, self.sample_rate)
logger.debug(
"No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)",
n,
self.sample_rate,
)
# Convert float32 [-1.0,1.0] to int16 PCM
pcm = np.clip(out, -1.0, 1.0)
@ -750,7 +821,9 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
logger.exception(f"Error in SilentAudioTrack.recv: {e}")
# On error, fall back to silent frame
data = np.zeros((self.channels, self.samples_per_frame), dtype=np.int16)
frame = AudioFrame.from_ndarray(data, layout="mono" if self.channels == 1 else "stereo")
frame = AudioFrame.from_ndarray(
data, layout="mono" if self.channels == 1 else "stereo"
)
frame.sample_rate = self.sample_rate
frame.pts = self._timestamp
frame.time_base = fractions.Fraction(1, self.sample_rate)
@ -774,17 +847,239 @@ async def handle_chat_message(
send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]],
) -> Optional[str]:
"""Handle incoming chat messages."""
global _audio_data
logger.info(f"Received chat message: {chat_message.message}")
# This ends up blocking; spin it off into a thread
try:
from .vibevoicetts import get_tts_engine
global _audio_data, _lobby_threads
logger.info(f"Received chat message: {chat_message.message}")
lobby_id = getattr(chat_message, "lobby_id", None)
if not lobby_id:
# No lobby id present: fall back to a shared sentinel key. The dedup check
# below still applies, so all lobby-less messages share one worker slot.
lobby_id = "__no_lobby__"
# Ensure we only run one background worker per lobby
loop = asyncio.get_running_loop()
with _lobby_threads_lock:
existing = _lobby_threads.get(lobby_id)
# existing may be a Thread or an asyncio Future-like object
already_running = False
try:
if existing is not None:
if hasattr(existing, "is_alive") and existing.is_alive():
already_running = True
elif hasattr(existing, "done") and not existing.done():
already_running = True
except Exception:
# Conservative: assume running if we can't determine
already_running = True
if already_running:
logger.info("Chat processing already active for lobby %s", lobby_id)
try:
# Prefer using the bound send/create chat helpers if available
if _create_chat_message_func is not None and _send_chat_func is not None:
cm = _create_chat_message_func("Already processing", lobby_id)
await _send_chat_func(cm)
else:
# Fallback to provided send_message_func. Some callers expect ChatMessageModel,
# but send_message_func also supports plain string in some integrations.
await send_message_func("Already processing")
except Exception:
logger.exception("Failed to send 'Already processing' reply")
return None
# Create background worker function (runs in threadpool via run_in_executor)
def _background_worker(lobby: str, msg: ChatMessageModel, _loop: asyncio.AbstractEventLoop) -> None:
    """Synthesize speech for one chat message and publish it to ``_audio_data``.

    Primary path: resolve an engine getter (a module-level ``get_tts_engine``
    if one is bound, otherwise the one imported from ``.vibevoicetts``), call
    ``engine.text_to_speech`` and, when the engine's sample rate differs from
    ``SAMPLE_RATE``, resample with librosa before storing the float32 buffer
    in the module-global ``_audio_data``.

    Fallback path: if anything in the primary path raises, the failure is
    logged and the ``text_to_speech`` convenience function is tried instead.

    On exit (success or failure) the lobby's entry in ``_lobby_jobs`` is
    marked finished and, if this thread is the one registered in
    ``_lobby_threads``, that registration is removed.

    Args:
        lobby: Key used for ``_lobby_jobs`` / ``_lobby_threads`` bookkeeping.
        msg: Chat message whose ``message`` text is spoken.
        _loop: Not used by this function.
    """
    global _audio_data
    logger.info("TTS background worker starting for lobby %s (msg id=%s)", lobby, getattr(msg, "id", "no-id"))
    start_time = time.time()
    try:
        try:
            # Prefer an already-bound engine getter if present (useful in tests)
            engine_getter = globals().get("get_tts_engine")
            if not engine_getter:
                from .vibevoicetts import get_tts_engine as engine_getter
            logger.info("TTS engine getter resolved: %s", getattr(engine_getter, "__name__", str(engine_getter)))
            engine = engine_getter()
            logger.info("TTS engine instance created: %s", type(engine))

            # Blocking TTS call; this function is meant to run off the event loop
            logger.info("Invoking engine.text_to_speech for lobby %s", lobby)
            raw_audio = engine.text_to_speech(text=f"Speaker 1: {msg.message}", verbose=True)
            logger.info("TTS generation completed for lobby %s in %.2fs", lobby, time.time() - start_time)

            # Determine source sample rate if available
            try:
                src_sr = engine.get_sample_rate()
            except Exception:
                src_sr = 24000

            # Ensure numpy array and float32
            try:
                raw_audio_arr = np.asarray(raw_audio, dtype=np.float32)
            except Exception:
                raw_audio_arr = np.array([], dtype=np.float32)

            # Resample to track sample rate if needed (do this in bg thread)
            if raw_audio_arr.size > 0 and src_sr != SAMPLE_RATE:
                try:
                    resampled = librosa.resample(
                        raw_audio_arr.astype(np.float64),
                        orig_sr=src_sr,
                        target_sr=SAMPLE_RATE,
                        res_type="kaiser_fast",
                    )
                    _audio_data = resampled.astype(np.float32)
                    logger.info("Background worker resampled audio from %d Hz to %d Hz (samples=%d)", src_sr, SAMPLE_RATE, getattr(_audio_data, "size", 0))
                except Exception:
                    logger.exception("Background resampling failed; using raw audio as-is")
                    _audio_data = raw_audio_arr.astype(np.float32)
            else:
                _audio_data = raw_audio_arr.astype(np.float32)
                logger.info("Background worker assigned raw audio buffer (samples=%d, src_sr=%s)", getattr(_audio_data, "size", 0), src_sr)
        except Exception:
            # Record why the primary path failed before trying the fallback;
            # otherwise the root cause is lost whenever the fallback succeeds.
            logger.exception("Primary TTS engine path failed for lobby %s; trying text_to_speech fallback", lobby)
            # Fallback: try module-level convenience function if monkeypatched
            try:
                tts_fn = globals().get("text_to_speech")
                if tts_fn:
                    logger.info("Using monkeypatched text_to_speech convenience function")
                    raw_audio = tts_fn(text=f"Speaker 1: {msg.message}")
                else:
                    # Last resort: import the convenience function lazily
                    logger.info("Importing text_to_speech fallback from vibevoicetts")
                    from .vibevoicetts import text_to_speech as tts_fn
                    raw_audio = tts_fn(text=f"Speaker 1: {msg.message}")
                logger.info("Fallback TTS invocation completed for lobby %s", lobby)
                # normalize into numpy here as well
                try:
                    raw_audio_arr = np.asarray(raw_audio, dtype=np.float32)
                except Exception:
                    raw_audio_arr = np.array([], dtype=np.float32)
                # NOTE(review): unlike the primary path, this buffer is stored
                # without resampling to SAMPLE_RATE - confirm the fallback's
                # output rate matches what the consumer of _audio_data expects.
                _audio_data = raw_audio_arr.astype(np.float32)
                logger.info("Background worker assigned raw audio buffer (samples=%d)", getattr(_audio_data, 'size', 0))
            except Exception:
                logger.exception("Failed to perform TTS in background worker")
    except Exception:
        logger.exception("Unhandled error in background TTS worker for lobby %s", lobby)
    finally:
        # Update job metadata and cleanup thread entry for this lobby
        try:
            with _lobby_threads_lock:
                # Update job metadata if present
                job = _lobby_jobs.get(lobby)
                if job is not None:
                    job["end_time"] = time.time()
                    job["status"] = "finished"
                    try:
                        job["audio_samples"] = int(getattr(_audio_data, "size", 0)) if _audio_data is not None else 0
                    except Exception:
                        job["audio_samples"] = None
                # If the stored worker is the current thread, remove it. For futures,
                # the entry is removed by the future's done callback instead.
                th = _lobby_threads.get(lobby)
                if th is threading.current_thread():
                    del _lobby_threads[lobby]
        except Exception:
            logger.exception("Error cleaning up background thread record for lobby %s", lobby)
# Schedule the worker in the event loop's default threadpool. This avoids
# raw Thread.start() races and integrates better with asyncio-based hosts.
logger.info("Scheduling background TTS worker for lobby %s via run_in_executor", lobby_id)
worker_obj = None
try:
# Use a small wrapper so the very first thing the thread does is emit a log
def _bg_wrapper(lobby_w: str, msg_w: ChatMessageModel, loop_w: asyncio.AbstractEventLoop) -> None:
    """Run ``_background_worker`` bracketed by start/exit log lines.

    The start line is emitted before any TTS work so it is visible that the
    worker actually began executing; the exit line is emitted even when the
    worker raises.
    """
    # Both log lines run on the same thread, so resolve the handle once.
    this_thread = threading.current_thread()
    logger.info(
        "Background worker wrapper started for lobby %s (thread=%s)",
        lobby_w,
        this_thread,
    )
    try:
        _background_worker(lobby_w, msg_w, loop_w)
    finally:
        logger.info(
            "Background worker wrapper exiting for lobby %s (thread=%s)",
            lobby_w,
            this_thread,
        )
fut = loop.run_in_executor(None, _bg_wrapper, lobby_id, chat_message, loop)
worker_obj = fut
_lobby_threads[lobby_id] = worker_obj
logger.info("Scheduled background TTS worker (wrapper) for lobby %s: %s", lobby_id, fut)
# Attach a done callback to clean up registry and update job metadata when finished
def _on_worker_done(fut_obj: "asyncio.Future[Any]") -> None:
    """Done-callback for the future that runs a lobby's TTS worker.

    Logs the outcome, marks the lobby's job record in ``_lobby_jobs`` as
    finished (storing the worker's exception, if any, under ``"error"``) and
    removes the future from ``_lobby_threads`` when it is still the object
    registered for ``lobby_id``.

    Args:
        fut_obj: The finished (or cancelled) future this callback was
            attached to.
    """
    try:
        exc = None
        if fut_obj.cancelled():
            # exception()/result() raise CancelledError on a cancelled future.
            # That is a BaseException on Python 3.8+, so the `except Exception`
            # below would not stop it and the cleanup would be skipped.
            logger.info("Background TTS worker for lobby %s was cancelled", lobby_id)
        else:
            exc = fut_obj.exception()
            if exc is not None:
                logger.error("Background TTS worker for lobby %s raised exception", lobby_id, exc_info=exc)
            else:
                logger.info("Background TTS worker for lobby %s completed successfully; result=%s", lobby_id, str(fut_obj.result()))

        with _lobby_threads_lock:
            # mark job finished and attach exception info if present
            job = _lobby_jobs.get(lobby_id)
            if job is not None:
                job["end_time"] = time.time()
                job["status"] = "finished"
                if exc is not None:
                    job["error"] = repr(exc)
                try:
                    job["audio_samples"] = int(getattr(_audio_data, "size", 0)) if _audio_data is not None else 0
                except Exception:
                    job["audio_samples"] = None
            # remove only if the stored object is this future; a newer worker
            # may already have been registered under the same lobby key
            if _lobby_threads.get(lobby_id) is fut_obj:
                del _lobby_threads[lobby_id]
    except Exception:
        logger.exception("Error in worker done callback for lobby %s", lobby_id)
try:
fut.add_done_callback(_on_worker_done)
except Exception:
# ignore if the future does not support callbacks
pass
engine = get_tts_engine()
_audio_data = engine.text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None, verbose=True)
except Exception:
# Fallback to convenience function if direct engine call fails
_audio_data = text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None)
# Fallback to raw Thread in the unlikely case run_in_executor fails
thread = threading.Thread(target=_background_worker, args=(lobby_id, chat_message, loop), daemon=True)
worker_obj = thread
_lobby_threads[lobby_id] = thread
logger.info("Created background TTS thread for lobby %s (fallback): %s", lobby_id, thread)
# Record job metadata for status display
try:
with _lobby_threads_lock:
_lobby_jobs[lobby_id] = {
"status": "running",
"start_time": time.time(),
"message": getattr(chat_message, "message", ""),
"worker": worker_obj,
"error": None,
"end_time": None,
"audio_samples": None,
}
except Exception:
logger.exception("Failed to record lobby job metadata for %s", lobby_id)
# If we fell back to a raw Thread, start it now; otherwise the future is already scheduled.
try:
stored = _lobby_threads.get(lobby_id)
if stored is not None and hasattr(stored, "start"):
logger.info("Starting fallback background TTS thread for lobby %s", lobby_id)
stored.start()
logger.info("Background TTS thread started for lobby %s", lobby_id)
except Exception:
logger.exception("Failed to start background TTS worker for %s", lobby_id)
return None