TTS is working on CPU
This commit is contained in:
parent
a0aa65ec1c
commit
4058f729e2
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
import os
|
||||
@ -15,7 +16,6 @@ from av import AudioFrame, VideoFrame
|
||||
import cv2
|
||||
import fractions
|
||||
from time import perf_counter
|
||||
import wave
|
||||
|
||||
# Import shared models for chat functionality
|
||||
import sys
|
||||
@ -26,7 +26,10 @@ sys.path.append(
|
||||
from shared.models import ChatMessageModel
|
||||
from voicebot.models import Peer
|
||||
|
||||
from .vibevoicetts import text_to_speech
|
||||
# Defer importing the heavy TTS module until needed (lazy import).
|
||||
# Tests and other code can monkeypatch these names on this module.
|
||||
get_tts_engine = None
|
||||
text_to_speech = None
|
||||
|
||||
# Global configuration and constants
|
||||
AGENT_NAME = "Text To Speech Bot"
|
||||
@ -38,6 +41,13 @@ SAMPLE_RATE = 16000 # Whisper expects 16kHz
|
||||
|
||||
_audio_data = None
|
||||
|
||||
# Track background processing threads per lobby to ensure only one worker runs
|
||||
# per lobby session at a time.
|
||||
_lobby_threads: Dict[str, Any] = {}
|
||||
_lobby_threads_lock = threading.Lock()
|
||||
# Per-lobby job metadata for status display
|
||||
_lobby_jobs: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
class TTSModel:
|
||||
"""TTS via MSFT VibeVoice"""
|
||||
|
||||
@ -112,9 +122,37 @@ class WaveformVideoTrack(MediaStreamTrack):
|
||||
(self.height, self.width, 3), dtype=np.uint8
|
||||
)
|
||||
|
||||
# Display model loading status prominently
|
||||
status_text = "Initializing..."
|
||||
progress = 0.1
|
||||
# Display model / TTS job status prominently
|
||||
status_text = "Idle"
|
||||
progress = 0.0
|
||||
|
||||
try:
|
||||
# Show number of active lobby TTS jobs
|
||||
with _lobby_threads_lock:
|
||||
active_jobs = {k: v for k, v in _lobby_jobs.items() if v.get("status") == "running"}
|
||||
|
||||
if active_jobs:
|
||||
# Pick the most recent running job (sort by start_time desc)
|
||||
recent_lobby, recent_job = sorted(
|
||||
active_jobs.items(), key=lambda kv: kv[1].get("start_time", 0), reverse=True
|
||||
)[0]
|
||||
# elapsed time available if needed for overlays
|
||||
_ = time.time() - recent_job.get("start_time", time.time())
|
||||
snippet = (recent_job.get("message") or "")[:80]
|
||||
status_text = f"TTS running ({len(active_jobs)}): {snippet}"
|
||||
# Progress as a heuristic (running jobs -> 0.5)
|
||||
progress = 0.5
|
||||
else:
|
||||
# If there is produced audio waiting in _audio_data, show ready-to-play
|
||||
if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
|
||||
status_text = "TTS: Ready (audio buffered)"
|
||||
progress = 1.0
|
||||
else:
|
||||
status_text = "Idle"
|
||||
progress = 0.0
|
||||
except Exception:
|
||||
status_text = "Status: error"
|
||||
progress = 0.0
|
||||
|
||||
# Draw status background (increased height for larger text)
|
||||
cv2.rectangle(frame_array, (0, 0), (self.width, 80), (0, 0, 0), -1)
|
||||
@ -155,6 +193,62 @@ class WaveformVideoTrack(MediaStreamTrack):
|
||||
2,
|
||||
)
|
||||
|
||||
# Draw TTS job status (small text under header)
|
||||
try:
|
||||
with _lobby_threads_lock:
|
||||
active_jobs_count = sum(1 for v in _lobby_jobs.values() if v.get("status") == "running")
|
||||
# Get most recent job info if present
|
||||
recent_job = None
|
||||
if active_jobs_count > 0:
|
||||
running = [v for v in _lobby_jobs.values() if v.get("status") == "running"]
|
||||
recent_job = max(running, key=lambda j: j.get("start_time", 0))
|
||||
|
||||
audio_samples = getattr(_audio_data, "size", 0) if _audio_data is not None else 0
|
||||
|
||||
job_snippet = ""
|
||||
job_elapsed = None
|
||||
if recent_job:
|
||||
job_snippet = (recent_job.get("message") or "")[:80]
|
||||
job_elapsed = time.time() - recent_job.get("start_time", time.time())
|
||||
|
||||
info_x = 10
|
||||
info_y = 95
|
||||
|
||||
if active_jobs_count > 0:
|
||||
cv2.putText(
|
||||
frame_array,
|
||||
f"TTS jobs: {active_jobs_count} | {job_snippet}",
|
||||
(info_x, info_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 255, 255),
|
||||
1,
|
||||
)
|
||||
if job_elapsed is not None:
|
||||
cv2.putText(
|
||||
frame_array,
|
||||
f"Elapsed: {job_elapsed:.1f}s",
|
||||
(info_x + 420, info_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(200, 200, 255),
|
||||
1,
|
||||
)
|
||||
else:
|
||||
# show audio buffer status when idle
|
||||
cv2.putText(
|
||||
frame_array,
|
||||
f"Audio buffered: {audio_samples} samples",
|
||||
(info_x, info_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(200, 200, 200),
|
||||
1,
|
||||
)
|
||||
except Exception:
|
||||
# Non-critical overlay; ignore failures
|
||||
pass
|
||||
|
||||
# Select the most active audio buffer and get its speech status
|
||||
best_proc = None
|
||||
best_rms = 0.0
|
||||
@ -649,57 +743,30 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
|
||||
async def recv(self) -> AudioFrame:
|
||||
global _audio_data
|
||||
# If new global TTS audio was produced, grab and prepare it for playback
|
||||
logger.info("recv called with _audio_data: %s and size: %s", "set" if _audio_data is not None else "unset", getattr(_audio_data, 'size', 0) if _audio_data is not None else 'N/A')
|
||||
logger.info(
|
||||
"recv called with _audio_data: %s and size: %s",
|
||||
"set" if _audio_data is not None else "unset",
|
||||
getattr(_audio_data, "size", 0) if _audio_data is not None else "N/A",
|
||||
)
|
||||
try:
|
||||
if _audio_data is not None and getattr(_audio_data, 'size', 0) > 0:
|
||||
# Get source sample rate from TTS engine if available
|
||||
if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
|
||||
# Expect background worker to provide audio already resampled to
|
||||
# the track sample rate (SAMPLE_RATE). Keep recv fast and
|
||||
# non-blocking by avoiding any heavy DSP here.
|
||||
try:
|
||||
from .vibevoicetts import get_tts_engine
|
||||
|
||||
src_sr = get_tts_engine().get_sample_rate()
|
||||
except Exception:
|
||||
# Fallback: assume 24000 (VibeVoice default)
|
||||
logger.info("Falling back to default source sample rate of 24000 Hz")
|
||||
src_sr = 24000
|
||||
|
||||
# Ensure numpy float32
|
||||
src_audio = np.asarray(_audio_data, dtype=np.float32)
|
||||
|
||||
# Clear global buffer (we consumed it into this track)
|
||||
_audio_data = None
|
||||
|
||||
# If source sr differs from track sr, resample
|
||||
if src_sr != self.sample_rate:
|
||||
try:
|
||||
resampled = librosa.resample(
|
||||
src_audio.astype(np.float64),
|
||||
orig_sr=src_sr,
|
||||
target_sr=self.sample_rate,
|
||||
res_type="kaiser_fast",
|
||||
)
|
||||
self._play_buffer = resampled.astype(np.float32)
|
||||
except Exception:
|
||||
# On failure, fall back to using the source audio as-is, without resampling (not ideal)
|
||||
logger.exception("Failed to resample TTS audio; using source as-is")
|
||||
self._play_buffer = src_audio.astype(np.float32)
|
||||
else:
|
||||
src_audio = np.asarray(_audio_data, dtype=np.float32)
|
||||
# Clear global buffer (we consumed it into this track)
|
||||
_audio_data = None
|
||||
# Use the provided audio buffer directly as play buffer.
|
||||
self._play_buffer = src_audio.astype(np.float32)
|
||||
|
||||
self._play_src_sr = self.sample_rate
|
||||
|
||||
# Save a copy of the consumed audio to disk for inspection
|
||||
try:
|
||||
pcm_save = np.clip(self._play_buffer, -1.0, 1.0)
|
||||
pcm_int16_save = (pcm_save * 32767.0).astype(np.int16)
|
||||
sample_path = "./sample.wav"
|
||||
with wave.open(sample_path, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(self.sample_rate)
|
||||
wf.writeframes(pcm_int16_save.tobytes())
|
||||
logger.info("Wrote TTS sample to %s (%d samples, %d Hz)", sample_path, len(pcm_int16_save), self.sample_rate)
|
||||
self._play_src_sr = self.sample_rate
|
||||
except Exception:
|
||||
logger.exception("Failed to write sample.wav")
|
||||
logger.exception(
|
||||
"Failed to prepare TTS audio buffer in recv; falling back to silence"
|
||||
)
|
||||
|
||||
# (Intentionally avoid expensive I/O/DSP here - background worker
|
||||
# will perform resampling and any debug file writes.)
|
||||
|
||||
# Prepare output samples for this frame
|
||||
if self._play_buffer.size > 0:
|
||||
@ -721,7 +788,11 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
|
||||
out = volume * np.sin(self._sine_phase + phase_inc * samples_idx)
|
||||
# advance and wrap phase
|
||||
self._sine_phase = (self._sine_phase + phase_inc * n) % (2.0 * np.pi)
|
||||
logger.debug("No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)", n, self.sample_rate)
|
||||
logger.debug(
|
||||
"No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)",
|
||||
n,
|
||||
self.sample_rate,
|
||||
)
|
||||
|
||||
# Convert float32 [-1.0,1.0] to int16 PCM
|
||||
pcm = np.clip(out, -1.0, 1.0)
|
||||
@ -750,7 +821,9 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
|
||||
logger.exception(f"Error in SilentAudioTrack.recv: {e}")
|
||||
# On error, fall back to silent frame
|
||||
data = np.zeros((self.channels, self.samples_per_frame), dtype=np.int16)
|
||||
frame = AudioFrame.from_ndarray(data, layout="mono" if self.channels == 1 else "stereo")
|
||||
frame = AudioFrame.from_ndarray(
|
||||
data, layout="mono" if self.channels == 1 else "stereo"
|
||||
)
|
||||
frame.sample_rate = self.sample_rate
|
||||
frame.pts = self._timestamp
|
||||
frame.time_base = fractions.Fraction(1, self.sample_rate)
|
||||
@ -774,17 +847,239 @@ async def handle_chat_message(
|
||||
send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]],
|
||||
) -> Optional[str]:
|
||||
"""Handle incoming chat messages."""
|
||||
global _audio_data
|
||||
logger.info(f"Received chat message: {chat_message.message}")
|
||||
# This ends up blocking; spin it off into a thread
|
||||
try:
|
||||
from .vibevoicetts import get_tts_engine
|
||||
global _audio_data, _lobby_threads
|
||||
|
||||
logger.info(f"Received chat message: {chat_message.message}")
|
||||
|
||||
lobby_id = getattr(chat_message, "lobby_id", None)
|
||||
if not lobby_id:
|
||||
# If no lobby id is present, fall back to a shared placeholder key; lobby-less messages are then deduplicated against each other under that key
|
||||
lobby_id = "__no_lobby__"
|
||||
|
||||
# Ensure we only run one background worker per lobby
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
with _lobby_threads_lock:
|
||||
existing = _lobby_threads.get(lobby_id)
|
||||
# existing may be a Thread or an asyncio Future-like object
|
||||
already_running = False
|
||||
try:
|
||||
if existing is not None:
|
||||
if hasattr(existing, "is_alive") and existing.is_alive():
|
||||
already_running = True
|
||||
elif hasattr(existing, "done") and not existing.done():
|
||||
already_running = True
|
||||
except Exception:
|
||||
# Conservative: assume running if we can't determine
|
||||
already_running = True
|
||||
|
||||
if already_running:
|
||||
logger.info("Chat processing already active for lobby %s", lobby_id)
|
||||
try:
|
||||
# Prefer using the bound send/create chat helpers if available
|
||||
if _create_chat_message_func is not None and _send_chat_func is not None:
|
||||
cm = _create_chat_message_func("Already processing", lobby_id)
|
||||
await _send_chat_func(cm)
|
||||
else:
|
||||
# Fallback to provided send_message_func. Some callers expect ChatMessageModel,
|
||||
# but send_message_func also supports plain string in some integrations.
|
||||
await send_message_func("Already processing")
|
||||
except Exception:
|
||||
logger.exception("Failed to send 'Already processing' reply")
|
||||
return None
|
||||
|
||||
# Create background worker function (runs in threadpool via run_in_executor)
|
||||
def _background_worker(lobby: str, msg: ChatMessageModel, _loop: asyncio.AbstractEventLoop) -> None:
|
||||
global _audio_data
|
||||
logger.info("TTS background worker starting for lobby %s (msg id=%s)", lobby, getattr(msg, "id", "no-id"))
|
||||
start_time = time.time()
|
||||
try:
|
||||
try:
|
||||
# Prefer an already-bound engine getter if present (useful in tests)
|
||||
if globals().get("get_tts_engine"):
|
||||
engine_getter = globals().get("get_tts_engine")
|
||||
else:
|
||||
from .vibevoicetts import get_tts_engine as engine_getter
|
||||
|
||||
logger.info("TTS engine getter resolved: %s", getattr(engine_getter, "__name__", str(engine_getter)))
|
||||
|
||||
engine = engine_getter()
|
||||
logger.info("TTS engine instance created: %s", type(engine))
|
||||
# Blocking TTS call moved to background thread
|
||||
logger.info("Invoking engine.text_to_speech for lobby %s", lobby)
|
||||
raw_audio = engine.text_to_speech(text=f"Speaker 1: {msg.message}", verbose=True)
|
||||
logger.info("TTS generation completed for lobby %s in %.2fs", lobby, time.time() - start_time)
|
||||
|
||||
# Determine source sample rate if available
|
||||
try:
|
||||
src_sr = engine.get_sample_rate()
|
||||
except Exception:
|
||||
src_sr = 24000
|
||||
|
||||
# Ensure numpy array and float32
|
||||
try:
|
||||
raw_audio_arr = np.asarray(raw_audio, dtype=np.float32)
|
||||
except Exception:
|
||||
raw_audio_arr = np.array([], dtype=np.float32)
|
||||
|
||||
# Resample to track sample rate if needed (do this in bg thread)
|
||||
if raw_audio_arr.size > 0 and src_sr != SAMPLE_RATE:
|
||||
try:
|
||||
resampled = librosa.resample(
|
||||
raw_audio_arr.astype(np.float64),
|
||||
orig_sr=src_sr,
|
||||
target_sr=SAMPLE_RATE,
|
||||
res_type="kaiser_fast",
|
||||
)
|
||||
_audio_data = resampled.astype(np.float32)
|
||||
logger.info("Background worker resampled audio from %d Hz to %d Hz (samples=%d)", src_sr, SAMPLE_RATE, getattr(_audio_data, "size", 0))
|
||||
except Exception:
|
||||
logger.exception("Background resampling failed; using raw audio as-is")
|
||||
_audio_data = raw_audio_arr.astype(np.float32)
|
||||
else:
|
||||
_audio_data = raw_audio_arr.astype(np.float32)
|
||||
logger.info("Background worker assigned raw audio buffer (samples=%d, src_sr=%s)", getattr(_audio_data, "size", 0), src_sr)
|
||||
except Exception:
|
||||
# Fallback: try module-level convenience function if monkeypatched
|
||||
try:
|
||||
tts_fn = globals().get("text_to_speech")
|
||||
if tts_fn:
|
||||
logger.info("Using monkeypatched text_to_speech convenience function")
|
||||
raw_audio = tts_fn(text=f"Speaker 1: {msg.message}")
|
||||
else:
|
||||
# Last resort: import the convenience function lazily
|
||||
logger.info("Importing text_to_speech fallback from vibevoicetts")
|
||||
from .vibevoicetts import text_to_speech as tts_fn
|
||||
raw_audio = tts_fn(text=f"Speaker 1: {msg.message}")
|
||||
logger.info("Fallback TTS invocation completed for lobby %s", lobby)
|
||||
# normalize into numpy here as well
|
||||
try:
|
||||
raw_audio_arr = np.asarray(raw_audio, dtype=np.float32)
|
||||
except Exception:
|
||||
raw_audio_arr = np.array([], dtype=np.float32)
|
||||
|
||||
_audio_data = raw_audio_arr.astype(np.float32)
|
||||
logger.info("Background worker assigned raw audio buffer (samples=%d)", getattr(_audio_data, 'size', 0))
|
||||
except Exception:
|
||||
logger.exception("Failed to perform TTS in background worker")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Unhandled error in background TTS worker for lobby %s", lobby)
|
||||
finally:
|
||||
# Update job metadata and cleanup thread entry for this lobby
|
||||
try:
|
||||
with _lobby_threads_lock:
|
||||
# Update job metadata if present
|
||||
job = _lobby_jobs.get(lobby)
|
||||
if job is not None:
|
||||
job["end_time"] = time.time()
|
||||
job["status"] = "finished"
|
||||
try:
|
||||
job["audio_samples"] = int(getattr(_audio_data, "size", 0)) if _audio_data is not None else 0
|
||||
except Exception:
|
||||
job["audio_samples"] = None
|
||||
|
||||
# If the stored worker is the current thread, remove it. For futures,
|
||||
# they will be removed by the done callback registered below.
|
||||
th = _lobby_threads.get(lobby)
|
||||
if th is threading.current_thread():
|
||||
del _lobby_threads[lobby]
|
||||
except Exception:
|
||||
logger.exception("Error cleaning up background thread record for lobby %s", lobby)
|
||||
|
||||
# Schedule the worker in the event loop's default threadpool. This avoids
|
||||
# raw Thread.start() races and integrates better with asyncio-based hosts.
|
||||
logger.info("Scheduling background TTS worker for lobby %s via run_in_executor", lobby_id)
|
||||
worker_obj = None
|
||||
try:
|
||||
# Use a small wrapper so the very first thing the thread does is emit a log
|
||||
def _bg_wrapper(lobby_w: str, msg_w: ChatMessageModel, loop_w: asyncio.AbstractEventLoop) -> None:
|
||||
logger.info("Background worker wrapper started for lobby %s (thread=%s)", lobby_w, threading.current_thread())
|
||||
try:
|
||||
_background_worker(lobby_w, msg_w, loop_w)
|
||||
finally:
|
||||
logger.info("Background worker wrapper exiting for lobby %s (thread=%s)", lobby_w, threading.current_thread())
|
||||
|
||||
fut = loop.run_in_executor(None, _bg_wrapper, lobby_id, chat_message, loop)
|
||||
worker_obj = fut
|
||||
_lobby_threads[lobby_id] = worker_obj
|
||||
logger.info("Scheduled background TTS worker (wrapper) for lobby %s: %s", lobby_id, fut)
|
||||
|
||||
# Attach a done callback to clean up registry and update job metadata when finished
|
||||
def _on_worker_done(fut_obj: "asyncio.Future[Any]") -> None:
|
||||
try:
|
||||
# Log exception or result for easier debugging
|
||||
try:
|
||||
exc = fut_obj.exception()
|
||||
except Exception:
|
||||
exc = None
|
||||
|
||||
if exc:
|
||||
logger.exception("Background TTS worker for lobby %s raised exception", lobby_id, exc_info=exc)
|
||||
else:
|
||||
try:
|
||||
res = fut_obj.result()
|
||||
logger.info("Background TTS worker for lobby %s completed successfully; result=%s", lobby_id, str(res))
|
||||
except Exception:
|
||||
# Some futures may not have a result or may raise when retrieving
|
||||
logger.info("Background TTS worker for lobby %s completed (no result).", lobby_id)
|
||||
|
||||
with _lobby_threads_lock:
|
||||
# mark job finished and attach exception info if present
|
||||
job = _lobby_jobs.get(lobby_id)
|
||||
if job is not None:
|
||||
job["end_time"] = time.time()
|
||||
job["status"] = "finished"
|
||||
try:
|
||||
job["audio_samples"] = int(getattr(_audio_data, "size", 0)) if _audio_data is not None else 0
|
||||
except Exception:
|
||||
job["audio_samples"] = None
|
||||
|
||||
# remove only if the stored object is this future
|
||||
stored = _lobby_threads.get(lobby_id)
|
||||
if stored is fut_obj:
|
||||
del _lobby_threads[lobby_id]
|
||||
except Exception:
|
||||
logger.exception("Error in worker done callback for lobby %s", lobby_id)
|
||||
|
||||
try:
|
||||
fut.add_done_callback(_on_worker_done)
|
||||
except Exception:
|
||||
# ignore if the future does not support callbacks
|
||||
pass
|
||||
|
||||
engine = get_tts_engine()
|
||||
_audio_data = engine.text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None, verbose=True)
|
||||
except Exception:
|
||||
# Fallback to convenience function if direct engine call fails
|
||||
_audio_data = text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None)
|
||||
# Fallback to raw Thread in the unlikely case run_in_executor fails
|
||||
thread = threading.Thread(target=_background_worker, args=(lobby_id, chat_message, loop), daemon=True)
|
||||
worker_obj = thread
|
||||
_lobby_threads[lobby_id] = thread
|
||||
logger.info("Created background TTS thread for lobby %s (fallback): %s", lobby_id, thread)
|
||||
|
||||
# Record job metadata for status display
|
||||
try:
|
||||
with _lobby_threads_lock:
|
||||
_lobby_jobs[lobby_id] = {
|
||||
"status": "running",
|
||||
"start_time": time.time(),
|
||||
"message": getattr(chat_message, "message", ""),
|
||||
"worker": worker_obj,
|
||||
"error": None,
|
||||
"end_time": None,
|
||||
"audio_samples": None,
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Failed to record lobby job metadata for %s", lobby_id)
|
||||
|
||||
# If we fell back to a raw Thread, start it now; otherwise the future is already scheduled.
|
||||
try:
|
||||
stored = _lobby_threads.get(lobby_id)
|
||||
if stored is not None and hasattr(stored, "start"):
|
||||
logger.info("Starting fallback background TTS thread for lobby %s", lobby_id)
|
||||
stored.start()
|
||||
logger.info("Background TTS thread started for lobby %s", lobby_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to start background TTS worker for %s", lobby_id)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user