import asyncio
import threading
import numpy as np
import time
import os
from typing import Dict, Optional, Callable, Awaitable, Any, Union
import numpy.typing as npt

# Core dependencies
import librosa
from shared.logger import logger
from aiortc import MediaStreamTrack
from aiortc.mediastreams import MediaStreamError
from av import AudioFrame, VideoFrame
import cv2
import fractions
from time import perf_counter
# Import shared models for chat functionality
import sys

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from shared.models import ChatMessageModel
from voicebot.models import Peer

# Defer importing the heavy TTS module until needed (lazy import).
# Tests and other code can monkeypatch these names on this module.
get_tts_engine = None
text_to_speech = None
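
# For example, a test can stub out synthesis without loading the model
# (hypothetical test code; assumes this module is importable as `tts_bot`):
#
#     import numpy as np
#     tts_bot.text_to_speech = lambda text: np.zeros(16000, dtype=np.float32)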

# Global configuration and constants
AGENT_NAME = "Text To Speech Bot"
AGENT_DESCRIPTION = (
    "Real-time speech generation - converts text to speech on Intel Arc B580"
)
SAMPLE_RATE = 16000  # 16 kHz, matching the rest of the pipeline (Whisper-compatible)

# Most recently generated TTS audio (float32 at SAMPLE_RATE), awaiting playback
_audio_data: Optional[npt.NDArray[np.float32]] = None

# Track background processing threads per lobby to ensure only one worker runs
# per lobby session at a time.
_lobby_threads: Dict[str, Any] = {}
_lobby_threads_lock = threading.Lock()
# Per-lobby job metadata for status display
_lobby_jobs: Dict[str, Dict[str, Any]] = {}
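
# Shape of a _lobby_jobs entry, for reference (populated in handle_chat_message):
#
#     {"status": "running" | "finished", "start_time": float,
#      "end_time": Optional[float], "message": str,
#      "worker": Thread | Future, "error": Optional[str],
#      "audio_samples": Optional[int]}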


class TTSModel:
    """TTS via Microsoft VibeVoice."""

    def __init__(
        self,
        peer_name: str,
        send_chat_func: Callable[[ChatMessageModel], Awaitable[None]],
        create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel],
    ):
        self.peer_name = peer_name
        self.send_chat_func = send_chat_func
        self.create_chat_message_func = create_chat_message_func


_send_chat_func: Optional[Callable[[ChatMessageModel], Awaitable[None]]] = None
_create_chat_message_func: Optional[
    Callable[[str, Optional[str]], ChatMessageModel]
] = None


class MediaClock:
    """Simple monotonic clock for media tracks."""

    def __init__(self) -> None:
        self.t0 = perf_counter()

    def now(self) -> float:
        return perf_counter() - self.t0


class WaveformVideoTrack(MediaStreamTrack):
    """Video track that renders a live waveform of the incoming audio.

    The track reads the most active per-peer buffer in
    `WaveformVideoTrack.buffer` and renders the last ~2 s of its audio,
    along with a header describing the state of any background TTS jobs.
    """

    kind = "video"

    # Shared buffers, keyed by peer name
    buffer: Dict[str, npt.NDArray[np.float32]] = {}
    speech_status: Dict[str, Dict[str, Any]] = {}

    def __init__(
        self, session_name: str, width: int = 640, height: int = 480, fps: int = 15
    ) -> None:
        super().__init__()
        self.session_name = session_name
        self.width = int(width)
        self.height = int(height)
        self.fps = int(fps)
        self.clock = MediaClock()
        self._next_frame_index = 0

    async def next_timestamp(self) -> tuple[int, float]:
        # pts is expressed in the conventional 90 kHz RTP video clock
        pts = int(self._next_frame_index * (1 / self.fps) * 90000)
        time_base = 1 / 90000
        return pts, time_base
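
    # Worked example: at fps=15 each frame advances pts by 90000 / 15 = 6000,
    # so frame 30 carries pts = 180000, i.e. 2.0 s at the 1/90000 time base.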

    async def recv(self) -> VideoFrame:
        pts, _ = await self.next_timestamp()

        # Schedule the frame according to the media clock
        target_t = self._next_frame_index / self.fps
        now = self.clock.now()
        if target_t > now:
            await asyncio.sleep(target_t - now)

        self._next_frame_index += 1

        frame_array: npt.NDArray[np.uint8] = np.zeros(
            (self.height, self.width, 3), dtype=np.uint8
        )

        # Display model / TTS job status prominently
        status_text = "Idle"
        progress = 0.0

        try:
            # Show the number of active lobby TTS jobs
            with _lobby_threads_lock:
                active_jobs = {
                    k: v
                    for k, v in _lobby_jobs.items()
                    if v.get("status") == "running"
                }

            if active_jobs:
                # Pick the most recent running job (sort by start_time desc)
                recent_lobby, recent_job = sorted(
                    active_jobs.items(),
                    key=lambda kv: kv[1].get("start_time", 0),
                    reverse=True,
                )[0]
                # Elapsed time is available here if needed for overlays
                _ = time.time() - recent_job.get("start_time", time.time())
                snippet = (recent_job.get("message") or "")[:80]
                status_text = f"TTS running ({len(active_jobs)}): {snippet}"
                # Progress is a heuristic (running jobs -> 0.5)
                progress = 0.5
            else:
                # If produced audio is waiting in _audio_data, show ready-to-play
                if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
                    status_text = "TTS: Ready (audio buffered)"
                    progress = 1.0
                else:
                    status_text = "Idle"
                    progress = 0.0
        except Exception:
            status_text = "Status: error"
            progress = 0.0

        # Draw status background (taller bar to fit the larger text)
        cv2.rectangle(frame_array, (0, 0), (self.width, 80), (0, 0, 0), -1)

        # Draw a progress bar while a job is running
        if progress < 1.0 and "Ready" not in status_text:
            bar_width = int(progress * (self.width - 40))
            cv2.rectangle(frame_array, (20, 55), (20 + bar_width, 70), (0, 255, 0), -1)
            cv2.rectangle(
                frame_array, (20, 55), (self.width - 20, 70), (255, 255, 255), 2
            )

        # Draw status text (larger font)
        cv2.putText(
            frame_array,
            f"Status: {status_text}",
            (10, 35),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.2,
            (255, 255, 255),
            3,
        )

        # Draw a clock in the lower right corner, right justified
        current_time = time.strftime("%H:%M:%S")
        (text_width, _), _ = cv2.getTextSize(
            current_time, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2
        )
        clock_x = self.width - text_width - 10  # 10 px margin from the right edge
        clock_y = self.height - 30  # e.g. y=450 when height=480
        cv2.putText(
            frame_array,
            current_time,
            (clock_x, clock_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.0,
            (255, 255, 255),
            2,
        )

        # Draw TTS job status (small text under the header)
        try:
            with _lobby_threads_lock:
                active_jobs_count = sum(
                    1 for v in _lobby_jobs.values() if v.get("status") == "running"
                )
                # Get the most recent running job, if any
                recent_job = None
                if active_jobs_count > 0:
                    running = [
                        v for v in _lobby_jobs.values() if v.get("status") == "running"
                    ]
                    recent_job = max(running, key=lambda j: j.get("start_time", 0))

            audio_samples = (
                getattr(_audio_data, "size", 0) if _audio_data is not None else 0
            )

            job_snippet = ""
            job_elapsed = None
            if recent_job:
                job_snippet = (recent_job.get("message") or "")[:80]
                job_elapsed = time.time() - recent_job.get("start_time", time.time())

            info_x = 10
            info_y = 95

            if active_jobs_count > 0:
                cv2.putText(
                    frame_array,
                    f"TTS jobs: {active_jobs_count} | {job_snippet}",
                    (info_x, info_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (0, 255, 255),
                    1,
                )
                if job_elapsed is not None:
                    cv2.putText(
                        frame_array,
                        f"Elapsed: {job_elapsed:.1f}s",
                        (info_x + 420, info_y),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        (200, 200, 255),
                        1,
                    )
            else:
                # Show audio buffer status when idle
                cv2.putText(
                    frame_array,
                    f"Audio buffered: {audio_samples} samples",
                    (info_x, info_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (200, 200, 200),
                    1,
                )
        except Exception:
            # Non-critical overlay; ignore failures
            pass

        # Select the most active audio buffer and get its speech status
        best_proc = None
        best_rms = 0.0
        speech_info = None

        try:
            for pname, arr in self.__class__.buffer.items():
                try:
                    if len(arr) == 0:
                        rms = 0.0
                    else:
                        rms = float(np.sqrt(np.mean(arr**2)))
                    if rms > best_rms:
                        best_rms = rms
                        best_proc = (pname, arr.copy())
                        speech_info = self.__class__.speech_status.get(pname, {})
                except Exception:
                    continue
        except Exception:
            best_proc = None
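
        # RMS is the activity measure here: a full-scale sine has RMS
        # 1/sqrt(2) ~= 0.707 while silence has RMS 0, so the loudest
        # recent speaker wins.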

        if best_proc is not None:
            pname, arr = best_proc

            # Use the last 2 seconds of audio, padded with zeros if shorter
            samples_needed = SAMPLE_RATE * 2  # 2 seconds
            if len(arr) <= 0:
                arr_segment = np.zeros(samples_needed, dtype=np.float32)
            elif len(arr) >= samples_needed:
                arr_segment = arr[-samples_needed:].copy()
            else:
                # Pad with zeros at the beginning
                arr_segment = np.concatenate(
                    [np.zeros(samples_needed - len(arr), dtype=np.float32), arr]
                )
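
            # Worked example: 1.5 s of buffered audio at 16 kHz is 24000
            # samples, so 32000 - 24000 = 8000 zero samples are prepended.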

            # Note: no normalization is applied here. Historically the display
            # normalized against the stream's observed peak amplitude via a
            # per-stream processor (`proc`); that processor is currently
            # disabled, so `proc` stays None and the raw segment is shown.
            proc = None
            norm = arr_segment.astype(np.float32)

            # Map audio samples to pixels across the width
            if norm.size < self.width:
                padded = np.zeros(self.width, dtype=np.float32)
                if norm.size > 0:
                    padded[-norm.size :] = norm
                norm = padded
            else:
                block = int(np.ceil(norm.size / self.width))
                norm = np.array(
                    [
                        np.mean(norm[i * block : min((i + 1) * block, norm.size)])
                        for i in range(self.width)
                    ],
                    dtype=np.float32,
                )

            # For display we use the same `norm` computed above (single code
            # path). `display_norm` is an alias to avoid confusion later on;
            # it is not a second normalization step.
            display_norm = norm
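
            # Worked example: the 2 s segment holds 32000 samples; at
            # width=640 the block size is ceil(32000 / 640) = 50, so each
            # pixel column shows the mean of 50 consecutive samples.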

            # Draw the waveform with color coding for speech detection
            points: list[tuple[int, int]] = []
            colors: list[tuple[int, int, int]] = []  # Color for each point

            for x in range(self.width):
                v = (
                    float(display_norm[x])
                    if x < display_norm.size and not np.isnan(display_norm[x])
                    else 0.0
                )
                y = int((1.0 - ((v + 1.0) / 2.0)) * (self.height - 120)) + 100
                points.append((x, y))

                # Color based on speech detection status
                is_speech = (
                    speech_info.get("is_speech", False) if speech_info else False
                )
                energy_check = (
                    speech_info.get("energy_check", False) if speech_info else False
                )

                if is_speech:
                    colors.append((0, 255, 0))  # Green for detected speech
                elif energy_check:
                    colors.append((255, 255, 0))  # Yellow for energy but not speech
                else:
                    colors.append((128, 128, 128))  # Gray for background noise

            # Draw the colored waveform
            if len(points) > 1:
                for i in range(len(points) - 1):
                    cv2.line(frame_array, points[i], points[i + 1], colors[i], 1)

            # Draw a historical peak indicator (horizontal lines at +/- target_peak).
            # Note: `proc` is currently always None (see above), so this block
            # is inert until a per-stream processor is wired back in.
            try:
                if proc is not None and getattr(proc, "normalization_enabled", False):
                    target_peak = float(getattr(proc, "normalization_target_peak", 0.0))
                    # Clamp target_peak to [0, 1]
                    target_peak = max(0.0, min(1.0, target_peak))

                    def _amp_to_y(a: float) -> int:
                        return (
                            int((1.0 - ((a + 1.0) / 2.0)) * (self.height - 120)) + 100
                        )

                    top_y = _amp_to_y(target_peak)
                    bot_y = _amp_to_y(-target_peak)

                    # Draw thin magenta lines across the waveform area
                    cv2.line(
                        frame_array,
                        (0, top_y),
                        (self.width - 1, top_y),
                        (255, 0, 255),
                        1,
                    )
                    cv2.line(
                        frame_array,
                        (0, bot_y),
                        (self.width - 1, bot_y),
                        (255, 0, 255),
                        1,
                    )

                    # Label the peak with small text near the right edge
                    label = f"Peak:{target_peak:.2f}"
                    (tw, _), _ = cv2.getTextSize(
                        label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
                    )
                    lx = max(10, self.width - tw - 12)
                    ly = max(12, top_y - 6)
                    cv2.putText(
                        frame_array,
                        label,
                        (lx, ly),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        (255, 0, 255),
                        1,
                    )
            except Exception:
                # Non-critical: ignore any drawing errors
                pass

            # Add the speech detection status overlay
            if speech_info:
                self._draw_speech_status(frame_array, speech_info, pname)

            cv2.putText(
                frame_array,
                f"Waveform: {pname}",
                (10, self.height - 15),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.0,
                (255, 255, 255),
                2,
            )

        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
        frame.pts = pts
        frame.time_base = fractions.Fraction(1, 90000)
        return frame

    def _draw_speech_status(
        self,
        frame_array: npt.NDArray[np.uint8],
        speech_info: Dict[str, Any],
        pname: str,
    ):
        """Draw speech detection status information."""

        y_offset = 100

        # Main status
        is_speech = speech_info.get("is_speech", False)
        status_text = "SPEECH" if is_speech else "NOISE"
        status_color = (0, 255, 0) if is_speech else (128, 128, 128)

        adaptive_thresh = speech_info.get("adaptive_threshold", 0)
        cv2.putText(
            frame_array,
            f"{pname}: {status_text} (thresh: {adaptive_thresh:.4f})",
            (10, y_offset),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            status_color,
            2,
        )

        # Detailed metrics (smaller text)
        metrics = [
            f"Energy: {speech_info.get('energy', 0):.3f} ({'Y' if speech_info.get('energy_check', False) else 'N'})",
            f"ZCR: {speech_info.get('zcr', 0):.3f} ({'Y' if speech_info.get('zcr_check', False) else 'N'})",
            f"Spectral: {'Y' if 300 < speech_info.get('centroid', 0) < 3400 else 'N'}/{'Y' if speech_info.get('rolloff', 0) < 2000 else 'N'}/{'Y' if speech_info.get('flux', 0) > 0.01 else 'N'} ({'Y' if speech_info.get('spectral_check', False) else 'N'})",
            f"Harmonic: {speech_info.get('hamonicity', 0):.3f} ({'Y' if speech_info.get('harmonic_check', False) else 'N'})",
            f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})",
        ]

        for metric in metrics:
            cv2.putText(
                frame_array,
                metric,
                (320, y_offset),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.4,
                (255, 255, 255),
                1,
            )
            y_offset += 15

        logic_result = "E:" + ("Y" if speech_info.get("energy_check", False) else "N")
        logic_result += " Z:" + ("Y" if speech_info.get("zcr_check", False) else "N")
        logic_result += " S:" + (
            "Y" if speech_info.get("spectral_check", False) else "N"
        )
        logic_result += " H:" + (
            "Y" if speech_info.get("harmonic_check", False) else "N"
        )
        logic_result += " T:" + (
            "Y" if speech_info.get("temporal_consistency", False) else "N"
        )

        cv2.putText(
            frame_array,
            logic_result,
            (320, y_offset + 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (255, 255, 255),
            1,
        )

        # Noise floor indicator
        noise_floor = speech_info.get("noise_floor_energy", 0)
        cv2.putText(
            frame_array,
            f"Noise Floor: {noise_floor:.4f}",
            (10, y_offset + 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (200, 200, 200),
            1,
        )


async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
    """Handle incoming audio tracks from WebRTC peers."""
    logger.info(
        f"handle_track_received called for {peer.peer_name} with track kind: {track.kind}"
    )

    if track.kind != "audio":
        logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
        return

    # Initialize the raw audio buffer so graphing can start immediately
    if peer.peer_name not in WaveformVideoTrack.buffer:
        WaveformVideoTrack.buffer[peer.peer_name] = np.array([], dtype=np.float32)

    logger.info(f"Starting frame processing loop for {peer.peer_name}")
    frame_count = 0
    try:
        while True:
            try:
                frame = await track.recv()
                frame_count += 1

                if frame_count == 1:
                    logger.info(f"Received first frame from {peer.peer_name}")
                elif frame_count % 100 == 0:
                    logger.debug(
                        f"Received {frame_count} frames from {peer.peer_name}"
                    )
            except MediaStreamError as e:
                logger.info(f"Audio stream ended for {peer.peer_name}: {e}")
                break
            except Exception as e:
                logger.error(f"Error receiving frame from {peer.peer_name}: {e}")
                break

            if isinstance(frame, AudioFrame):
                try:
                    # Convert the frame to a numpy array
                    audio_data = frame.to_ndarray()

                    # Handle format conversion (stereo -> mono, int -> float)
                    audio_data = _process_audio_frame(audio_data, frame)

                    # Resample if needed
                    if frame.sample_rate != SAMPLE_RATE:
                        audio_data = _resample_audio(
                            audio_data, frame.sample_rate, SAMPLE_RATE
                        )

                    # Convert to float32
                    audio_data_float32 = audio_data.astype(np.float32)

                    # Update the visualization buffer immediately
                    WaveformVideoTrack.buffer[peer.peer_name] = np.concatenate(
                        [WaveformVideoTrack.buffer[peer.peer_name], audio_data_float32]
                    )
                    # Limit the buffer to the last 10 seconds
                    max_samples = SAMPLE_RATE * 10
                    if len(WaveformVideoTrack.buffer[peer.peer_name]) > max_samples:
                        WaveformVideoTrack.buffer[peer.peer_name] = (
                            WaveformVideoTrack.buffer[peer.peer_name][-max_samples:]
                        )
                except Exception as e:
                    logger.error(
                        f"Error processing audio frame for {peer.peer_name}: {e}"
                    )
                    continue
    except Exception as e:
        logger.error(
            f"Unexpected error in audio processing for {peer.peer_name}: {e}",
            exc_info=True,
        )
    finally:
        cleanup_peer_processor(peer.peer_name)


def _process_audio_frame(
    audio_data: npt.NDArray[Any], frame: AudioFrame
) -> npt.NDArray[np.float32]:
    """Convert an audio frame to mono float32 in [-1.0, 1.0]."""
    # Handle stereo to mono conversion
    if audio_data.ndim == 2:
        if audio_data.shape[0] == 1:
            audio_data = audio_data.squeeze(0)
        else:
            # Average across the channel axis (the smaller dimension), keeping
            # the full-length sample axis intact
            channel_axis = 0 if audio_data.shape[0] < audio_data.shape[1] else 1
            audio_data = np.mean(audio_data, axis=channel_axis)

    # Normalize to [-1.0, 1.0] based on the integer data type
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0
    elif audio_data.dtype == np.int32:
        audio_data = audio_data.astype(np.float32) / 2147483648.0

    return audio_data.astype(np.float32)
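
# Worked example for the int16 branch above: a sample of 16384 maps to
# 16384 / 32768.0 = 0.5, and -32768 maps to exactly -1.0.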


def _resample_audio(
    audio_data: npt.NDArray[np.float32], orig_sr: int, target_sr: int
) -> npt.NDArray[np.float32]:
    """Resample audio efficiently."""
    try:
        # Handle stereo audio by converting to mono if necessary
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Use high-quality resampling
        resampled = librosa.resample(  # type: ignore
            audio_data.astype(np.float64),
            orig_sr=orig_sr,
            target_sr=target_sr,
            res_type="kaiser_fast",  # Good balance of quality and speed
        )
        return resampled.astype(np.float32)  # type: ignore
    except Exception as e:
        logger.error(f"Resampling failed: {e}")
        raise ValueError(
            f"Failed to resample audio from {orig_sr} Hz to {target_sr} Hz: {e}"
        )
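
# Worked example: resampling 48 kHz input to SAMPLE_RATE (16 kHz) keeps one
# sample in three, so a 20 ms frame of 960 samples becomes 320 samples.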


# Public API functions
def agent_info() -> Dict[str, str]:
    return {
        "name": AGENT_NAME,
        "description": AGENT_DESCRIPTION,
        "has_media": "true",
        "configurable": "true",
    }


def get_config_schema() -> Dict[str, Any]:
    """Get the configuration schema for the TTS bot."""
    return {
        "bot_name": AGENT_NAME,
        "version": "1.0",
        "parameters": [],
        "categories": [],
    }


def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
    """Handle a configuration update for a specific lobby."""
    try:
        logger.info(f"Updating TTS config for lobby {lobby_id}: {config_values}")

        # There are no tunable parameters yet (see get_config_schema), so any
        # update is accepted as applied.
        config_applied = True

        return config_applied
    except Exception as e:
        logger.error(f"Failed to apply TTS config update: {e}")
        return False


def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
    """Create agent tracks: a synthetic waveform video track plus an audio
    track that plays buffered TTS output (or a quiet test tone while idle)."""

    class SilentAudioTrack(MediaStreamTrack):
        kind = "audio"

        def __init__(
            self, sample_rate: int = SAMPLE_RATE, channels: int = 1, fps: int = 50
        ):
            super().__init__()
            self.sample_rate = sample_rate
            self.channels = channels
            self.fps = fps
            self.samples_per_frame = int(self.sample_rate / self.fps)
            self._timestamp = 0
            # Per-track playback buffer (float32, mono, -1..1)
            self._play_buffer = np.array([], dtype=np.float32)
            # Source sample rate of the buffered audio (if any)
            self._play_src_sr = None
            # Phase of the synthetic sine tone (radians)
            self._sine_phase = 0.0

        async def recv(self) -> AudioFrame:
            global _audio_data
            logger.debug(
                "recv called with _audio_data: %s and size: %s",
                "set" if _audio_data is not None else "unset",
                getattr(_audio_data, "size", 0) if _audio_data is not None else "N/A",
            )
            try:
                # If new global TTS audio was produced, grab it for playback
                if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
                    # The background worker provides audio already resampled
                    # to the track sample rate (SAMPLE_RATE). Keep recv fast
                    # and non-blocking by avoiding any heavy DSP here.
                    try:
                        src_audio = np.asarray(_audio_data, dtype=np.float32)
                        # Clear the global buffer (consumed into this track)
                        _audio_data = None
                        # Use the provided audio directly as the play buffer
                        self._play_buffer = src_audio.astype(np.float32)
                        self._play_src_sr = self.sample_rate
                    except Exception:
                        logger.exception(
                            "Failed to prepare TTS audio buffer in recv; falling back to silence"
                        )

                # (Intentionally no expensive I/O or DSP here - the background
                # worker performs resampling and any debug file writes.)

                # Prepare output samples for this frame
                if self._play_buffer.size > 0:
                    take = min(self.samples_per_frame, int(self._play_buffer.size))
                    out = self._play_buffer[:take]
                    # Advance the buffer
                    if take >= self._play_buffer.size:
                        self._play_buffer = np.array([], dtype=np.float32)
                    else:
                        self._play_buffer = self._play_buffer[take:]
                    # Zero-pad a short final chunk so every frame carries
                    # exactly samples_per_frame samples (pts advances by a
                    # fixed stride below)
                    if take < self.samples_per_frame:
                        out = np.concatenate(
                            [
                                out,
                                np.zeros(
                                    self.samples_per_frame - take, dtype=np.float32
                                ),
                            ]
                        )
                else:
                    # No TTS audio buffered; emit a 110 Hz sine at 0.1 volume
                    freq = 110.0
                    volume = 0.1
                    n = self.samples_per_frame
                    # Incremental phase per sample
                    phase_inc = 2.0 * np.pi * freq / float(self.sample_rate)
                    samples_idx = np.arange(n, dtype=np.float32)
                    out = volume * np.sin(self._sine_phase + phase_inc * samples_idx)
                    # Advance and wrap the phase
                    self._sine_phase = (self._sine_phase + phase_inc * n) % (
                        2.0 * np.pi
                    )
                    logger.debug(
                        "No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)",
                        n,
                        self.sample_rate,
                    )
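
                # Worked example for the tone above: at 16 kHz, phase_inc =
                # 2*pi*110/16000 ~= 0.0432 rad per sample; a 20 ms frame of
                # 320 samples advances the phase by ~13.8 rad before wrapping.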

                # Convert float32 [-1.0, 1.0] to int16 PCM
                pcm = np.clip(out, -1.0, 1.0)
                pcm_int16 = (pcm * 32767.0).astype(np.int16)

                # For packed s16, AudioFrame.from_ndarray expects a single
                # row of interleaved samples: shape (1, samples * channels)
                if self.channels == 1:
                    data = np.expand_dims(pcm_int16, axis=0)
                    layout = "mono"
                else:
                    # Duplicate mono into interleaved stereo
                    data = np.empty((1, pcm_int16.size * 2), dtype=np.int16)
                    data[0, 0::2] = pcm_int16
                    data[0, 1::2] = pcm_int16
                    layout = "stereo"

                frame = AudioFrame.from_ndarray(data, layout=layout)
                frame.sample_rate = self.sample_rate
                frame.pts = self._timestamp
                frame.time_base = fractions.Fraction(1, self.sample_rate)
                self._timestamp += self.samples_per_frame
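
                # pts bookkeeping: at 16 kHz and fps=50 each frame carries
                # 16000 / 50 = 320 samples, so pts advances 0, 320, 640, ...
                # in units of the 1/16000 s time base.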

                # Pace the frame rate to avoid busy-looping
                await asyncio.sleep(1 / self.fps)

                return frame
            except Exception as e:
                logger.exception(f"Error in SilentAudioTrack.recv: {e}")
                # On error, fall back to a silent frame (packed s16 layout:
                # one row of samples * channels)
                data = np.zeros(
                    (1, self.samples_per_frame * self.channels), dtype=np.int16
                )
                frame = AudioFrame.from_ndarray(
                    data, layout="mono" if self.channels == 1 else "stereo"
                )
                frame.sample_rate = self.sample_rate
                frame.pts = self._timestamp
                frame.time_base = fractions.Fraction(1, self.sample_rate)
                self._timestamp += self.samples_per_frame
                await asyncio.sleep(1 / self.fps)
                return frame

    try:
        video_track = WaveformVideoTrack(
            session_name=session_name, width=640, height=480, fps=15
        )
        audio_track = SilentAudioTrack()
        return {"video": video_track, "audio": audio_track}
    except Exception as e:
        logger.error(f"Failed to create agent tracks: {e}")
        return {}
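
# A minimal usage sketch (hypothetical host code; `pc` stands in for an
# aiortc RTCPeerConnection owned by the host framework):
#
#     tracks = create_agent_tracks("demo-session")
#     for track in tracks.values():
#         pc.addTrack(track)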


async def handle_chat_message(
    chat_message: ChatMessageModel,
    send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]],
) -> Optional[str]:
    """Handle incoming chat messages by synthesizing speech in the background."""
    global _audio_data, _lobby_threads

    logger.info(f"Received chat message: {chat_message.message}")

    lobby_id = getattr(chat_message, "lobby_id", None)
    if not lobby_id:
        # If no lobby id is present, use a sentinel key so the dedup logic
        # below still applies to the lobby-less case
        lobby_id = "__no_lobby__"

    # Ensure we only run one background worker per lobby
    loop = asyncio.get_running_loop()

    with _lobby_threads_lock:
        existing = _lobby_threads.get(lobby_id)
        # `existing` may be a Thread or an asyncio Future-like object
        already_running = False
        try:
            if existing is not None:
                if hasattr(existing, "is_alive") and existing.is_alive():
                    already_running = True
                elif hasattr(existing, "done") and not existing.done():
                    already_running = True
        except Exception:
            # Conservative: assume running if we can't determine
            already_running = True

    if already_running:
        logger.info("Chat processing already active for lobby %s", lobby_id)
        try:
            # Prefer the bound send/create chat helpers if available
            if _create_chat_message_func is not None and _send_chat_func is not None:
                cm = _create_chat_message_func("Already processing", lobby_id)
                await _send_chat_func(cm)
            else:
                # Fall back to the provided send_message_func. Some callers
                # expect a ChatMessageModel, but send_message_func also
                # supports a plain string in some integrations.
                await send_message_func("Already processing")
        except Exception:
            logger.exception("Failed to send 'Already processing' reply")
        return None

    # Background worker function (runs in a threadpool via run_in_executor)
    def _background_worker(
        lobby: str, msg: ChatMessageModel, _loop: asyncio.AbstractEventLoop
    ) -> None:
        global _audio_data
        logger.info(
            "TTS background worker starting for lobby %s (msg id=%s)",
            lobby,
            getattr(msg, "id", "no-id"),
        )
        start_time = time.time()
        try:
            try:
                # Prefer an already-bound engine getter if present (useful in tests)
                if globals().get("get_tts_engine"):
                    engine_getter = globals().get("get_tts_engine")
                else:
                    from .vibevoicetts import get_tts_engine as engine_getter

                logger.info(
                    "TTS engine getter resolved: %s",
                    getattr(engine_getter, "__name__", str(engine_getter)),
                )

                engine = engine_getter()
                logger.info("TTS engine instance created: %s", type(engine))
                # The blocking TTS call runs here, in the background thread
                logger.info("Invoking engine.text_to_speech for lobby %s", lobby)
                raw_audio = engine.text_to_speech(
                    text=f"Speaker 1: {msg.message}", verbose=True
                )
                logger.info(
                    "TTS generation completed for lobby %s in %.2fs",
                    lobby,
                    time.time() - start_time,
                )

                # Determine the source sample rate if available
                try:
                    src_sr = engine.get_sample_rate()
                except Exception:
                    src_sr = 24000

                # Ensure a float32 numpy array
                try:
                    raw_audio_arr = np.asarray(raw_audio, dtype=np.float32)
                except Exception:
                    raw_audio_arr = np.array([], dtype=np.float32)

                # Resample to the track sample rate if needed (in this thread)
                if raw_audio_arr.size > 0 and src_sr != SAMPLE_RATE:
                    try:
                        resampled = librosa.resample(
                            raw_audio_arr.astype(np.float64),
                            orig_sr=src_sr,
                            target_sr=SAMPLE_RATE,
                            res_type="kaiser_fast",
                        )
                        _audio_data = resampled.astype(np.float32)
                        logger.info(
                            "Background worker resampled audio from %d Hz to %d Hz (samples=%d)",
                            src_sr,
                            SAMPLE_RATE,
                            getattr(_audio_data, "size", 0),
                        )
                    except Exception:
                        logger.exception(
                            "Background resampling failed; using raw audio as-is"
                        )
                        _audio_data = raw_audio_arr.astype(np.float32)
                else:
                    _audio_data = raw_audio_arr.astype(np.float32)
                    logger.info(
                        "Background worker assigned raw audio buffer (samples=%d, src_sr=%s)",
                        getattr(_audio_data, "size", 0),
                        src_sr,
                    )
            except Exception:
                # Fallback: try the module-level convenience function if monkeypatched
                try:
                    tts_fn = globals().get("text_to_speech")
                    if tts_fn:
                        logger.info(
                            "Using monkeypatched text_to_speech convenience function"
                        )
                        raw_audio = tts_fn(text=f"Speaker 1: {msg.message}")
                    else:
                        # Last resort: import the convenience function lazily
                        logger.info(
                            "Importing text_to_speech fallback from vibevoicetts"
                        )
                        from .vibevoicetts import text_to_speech as tts_fn

                        raw_audio = tts_fn(text=f"Speaker 1: {msg.message}")
                    logger.info("Fallback TTS invocation completed for lobby %s", lobby)
                    # Normalize into a numpy array here as well
                    try:
                        raw_audio_arr = np.asarray(raw_audio, dtype=np.float32)
                    except Exception:
                        raw_audio_arr = np.array([], dtype=np.float32)

                    _audio_data = raw_audio_arr.astype(np.float32)
                    logger.info(
                        "Background worker assigned raw audio buffer (samples=%d)",
                        getattr(_audio_data, "size", 0),
                    )
                except Exception:
                    logger.exception("Failed to perform TTS in background worker")

        except Exception:
            logger.exception(
                "Unhandled error in background TTS worker for lobby %s", lobby
            )
        finally:
            # Update job metadata and clean up the thread entry for this lobby
            try:
                with _lobby_threads_lock:
                    # Update job metadata if present
                    job = _lobby_jobs.get(lobby)
                    if job is not None:
                        job["end_time"] = time.time()
                        job["status"] = "finished"
                        try:
                            job["audio_samples"] = (
                                int(getattr(_audio_data, "size", 0))
                                if _audio_data is not None
                                else 0
                            )
                        except Exception:
                            job["audio_samples"] = None

                    # If the stored worker is the current thread, remove it.
                    # Futures are removed by the done callback registered below.
                    th = _lobby_threads.get(lobby)
                    if th is threading.current_thread():
                        del _lobby_threads[lobby]
            except Exception:
                logger.exception(
                    "Error cleaning up background thread record for lobby %s", lobby
                )

    # Schedule the worker in the event loop's default threadpool. This avoids
    # raw Thread.start() races and integrates better with asyncio-based hosts.
    logger.info(
        "Scheduling background TTS worker for lobby %s via run_in_executor", lobby_id
    )
    worker_obj = None
    try:
        # A small wrapper so the very first thing the thread does is emit a log
        def _bg_wrapper(
            lobby_w: str, msg_w: ChatMessageModel, loop_w: asyncio.AbstractEventLoop
        ) -> None:
            logger.info(
                "Background worker wrapper started for lobby %s (thread=%s)",
                lobby_w,
                threading.current_thread(),
            )
            try:
                _background_worker(lobby_w, msg_w, loop_w)
            finally:
                logger.info(
                    "Background worker wrapper exiting for lobby %s (thread=%s)",
                    lobby_w,
                    threading.current_thread(),
                )

        fut = loop.run_in_executor(None, _bg_wrapper, lobby_id, chat_message, loop)
        worker_obj = fut
        with _lobby_threads_lock:
            _lobby_threads[lobby_id] = worker_obj
        logger.info(
            "Scheduled background TTS worker (wrapper) for lobby %s: %s", lobby_id, fut
        )

        # Attach a done callback to clean up the registry and update job
        # metadata when the worker finishes
        def _on_worker_done(fut_obj: "asyncio.Future[Any]") -> None:
            try:
                # Log the exception or result for easier debugging
                try:
                    exc = fut_obj.exception()
                except Exception:
                    exc = None

                if exc:
                    logger.exception(
                        "Background TTS worker for lobby %s raised exception",
                        lobby_id,
                        exc_info=exc,
                    )
                else:
                    try:
                        res = fut_obj.result()
                        logger.info(
                            "Background TTS worker for lobby %s completed successfully; result=%s",
                            lobby_id,
                            str(res),
                        )
                    except Exception:
                        # Some futures may not have a result or may raise on retrieval
                        logger.info(
                            "Background TTS worker for lobby %s completed (no result).",
                            lobby_id,
                        )

                with _lobby_threads_lock:
                    # Mark the job finished and attach exception info if present
                    job = _lobby_jobs.get(lobby_id)
                    if job is not None:
                        job["end_time"] = time.time()
                        job["status"] = "finished"
                        try:
                            job["audio_samples"] = (
                                int(getattr(_audio_data, "size", 0))
                                if _audio_data is not None
                                else 0
                            )
                        except Exception:
                            job["audio_samples"] = None

                    # Remove only if the stored object is this future
                    stored = _lobby_threads.get(lobby_id)
                    if stored is fut_obj:
                        del _lobby_threads[lobby_id]
            except Exception:
                logger.exception("Error in worker done callback for lobby %s", lobby_id)

        try:
            fut.add_done_callback(_on_worker_done)
        except Exception:
            # Ignore if the future does not support callbacks
            pass

    except Exception:
        # Fall back to a raw Thread in the unlikely case run_in_executor fails
        thread = threading.Thread(
            target=_background_worker,
            args=(lobby_id, chat_message, loop),
            daemon=True,
        )
        worker_obj = thread
        with _lobby_threads_lock:
            _lobby_threads[lobby_id] = thread
        logger.info(
            "Created background TTS thread for lobby %s (fallback): %s",
            lobby_id,
            thread,
        )

    # Record job metadata for status display
    try:
        with _lobby_threads_lock:
            _lobby_jobs[lobby_id] = {
                "status": "running",
                "start_time": time.time(),
                "message": getattr(chat_message, "message", ""),
                "worker": worker_obj,
                "error": None,
                "end_time": None,
                "audio_samples": None,
            }
    except Exception:
        logger.exception("Failed to record lobby job metadata for %s", lobby_id)

    # If we fell back to a raw Thread, start it now; otherwise the future is
    # already scheduled.
    try:
        with _lobby_threads_lock:
            stored = _lobby_threads.get(lobby_id)
        if stored is not None and hasattr(stored, "start"):
            logger.info(
                "Starting fallback background TTS thread for lobby %s", lobby_id
            )
            stored.start()
            logger.info("Background TTS thread started for lobby %s", lobby_id)
    except Exception:
        logger.exception("Failed to start background TTS worker for %s", lobby_id)

    return None


async def on_track_received(peer: Peer, track: MediaStreamTrack) -> None:
    """Callback invoked when a new track is received from a peer."""
    await handle_track_received(peer, track)


def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]:
    """Return the track handler function."""
    return on_track_received


def bind_send_chat_function(
    send_chat_func: Callable[[ChatMessageModel], Awaitable[None]],
    create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel],
) -> None:
    """Bind the helpers used to post chat messages back into the lobby."""
    global _send_chat_func, _create_chat_message_func

    logger.info("Binding send chat function to TTS agent")
    _send_chat_func = send_chat_func
    _create_chat_message_func = create_chat_message_func


def cleanup_peer_processor(peer_name: str) -> None:
    """Clean up buffers for a disconnected peer."""
    if peer_name in WaveformVideoTrack.buffer:
        del WaveformVideoTrack.buffer[peer_name]
    if peer_name in WaveformVideoTrack.speech_status:
        del WaveformVideoTrack.speech_status[peer_name]