# ai-voicebot/voicebot/bots/vibevoice.py

import asyncio
import numpy as np
import time
import os
from typing import Dict, Optional, Callable, Awaitable, Any, Union
import numpy.typing as npt
# Core dependencies
import librosa
from shared.logger import logger
from aiortc import MediaStreamTrack
from aiortc.mediastreams import MediaStreamError
from av import AudioFrame, VideoFrame
import cv2
import fractions
from time import perf_counter
import wave
# Import shared models for chat functionality
import sys
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from shared.models import ChatMessageModel
from voicebot.models import Peer
from .vibevoicetts import text_to_speech
# Global configuration and constants
AGENT_NAME = "Text To Speech Bot"
AGENT_DESCRIPTION = (
"Real-time speech generation- converts text to speech on Intel Arc B580"
)
SAMPLE_RATE = 16000  # Incoming audio is processed at 16 kHz (Whisper-compatible rate)
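# Most recent TTS waveform produced by handle_chat_message() (float32 samples at the
# engine's native rate); it is consumed and cleared by the agent audio track's recv().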
_audio_data = None
class TTSModel:
"""TTS via MSFT VibeVoice"""
def __init__(
self,
peer_name: str,
send_chat_func: Callable[[ChatMessageModel], Awaitable[None]],
create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel],
):
self.peer_name = peer_name
self.send_chat_func = send_chat_func
self.create_chat_message_func = create_chat_message_func
_send_chat_func: Optional[Callable[[ChatMessageModel], Awaitable[None]]] = None
_create_chat_message_func: Optional[
Callable[[str, Optional[str]], ChatMessageModel]
] = None
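# These module-level callbacks stay None until the host binds them via
# bind_send_chat_function(); they are how the bot would post chat messages back to the session.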
class MediaClock:
"""Simple monotonic clock for media tracks."""
def __init__(self) -> None:
self.t0 = perf_counter()
def now(self) -> float:
return perf_counter() - self.t0
class WaveformVideoTrack(MediaStreamTrack):
"""Video track that renders a live waveform of the incoming audio.
The track reads the most-active `OptimizedAudioProcessor` in
`_audio_processors` and renders the last ~2s of its `current_phrase_audio`.
If no audio is available, the track will display a "No audio" message.
"""
kind = "video"
# Shared buffer for audio data
buffer: Dict[str, npt.NDArray[np.float32]] = {}
speech_status: Dict[str, Dict[str, Any]] = {}
def __init__(
self, session_name: str, width: int = 640, height: int = 480, fps: int = 15
) -> None:
super().__init__()
self.session_name = session_name
self.width = int(width)
self.height = int(height)
self.fps = int(fps)
self.clock = MediaClock()
self._next_frame_index = 0
async def next_timestamp(self) -> tuple[int, float]:
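# pts is expressed in the standard 90 kHz RTP video clock, so each frame
# advances the timestamp by 90000 / fps ticks.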
pts = int(self._next_frame_index * (1 / self.fps) * 90000)
time_base = 1 / 90000
return pts, time_base
async def recv(self) -> VideoFrame:
pts, _ = await self.next_timestamp()
# schedule frame according to clock
target_t = self._next_frame_index / self.fps
now = self.clock.now()
if target_t > now:
await asyncio.sleep(target_t - now)
self._next_frame_index += 1
frame_array: npt.NDArray[np.uint8] = np.zeros(
(self.height, self.width, 3), dtype=np.uint8
)
# Status banner (static placeholder values; this bot does not track a model-loading phase)
status_text = "Initializing..."
progress = 0.1
# Draw status background (increased height for larger text)
cv2.rectangle(frame_array, (0, 0), (self.width, 80), (0, 0, 0), -1)
# Draw progress bar if loading
if progress < 1.0 and "Ready" not in status_text:
bar_width = int(progress * (self.width - 40))
cv2.rectangle(frame_array, (20, 55), (20 + bar_width, 70), (0, 255, 0), -1)
cv2.rectangle(
frame_array, (20, 55), (self.width - 20, 70), (255, 255, 255), 2
)
# Draw status text (larger font)
cv2.putText(
frame_array,
f"Status: {status_text}",
(10, 35),
cv2.FONT_HERSHEY_SIMPLEX,
1.2,
(255, 255, 255),
3,
)
# Draw clock in lower right corner, right justified
current_time = time.strftime("%H:%M:%S")
(text_width, _), _ = cv2.getTextSize(
current_time, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2
)
clock_x = self.width - text_width - 10 # 10px margin from right edge
clock_y = self.height - 30  # e.g. y=450 when height=480
cv2.putText(
frame_array,
current_time,
(clock_x, clock_y),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
(255, 255, 255),
2,
)
# Select the most active audio buffer and get its speech status
best_proc = None
best_rms = 0.0
speech_info = None
try:
for pname, arr in self.__class__.buffer.items():
try:
if len(arr) == 0:
rms = 0.0
else:
rms = float(np.sqrt(np.mean(arr**2)))
if rms > best_rms:
best_rms = rms
best_proc = (pname, arr.copy())
speech_info = self.__class__.speech_status.get(pname, {})
except Exception:
continue
except Exception:
best_proc = None
if best_proc is not None:
pname, arr = best_proc
# Use the last 2 seconds of audio data, padded with zeros if shorter
samples_needed = SAMPLE_RATE * 2 # 2 second(s)
if len(arr) <= 0:
arr_segment = np.zeros(samples_needed, dtype=np.float32)
elif len(arr) >= samples_needed:
arr_segment = arr[-samples_needed:].copy()
else:
# Pad with zeros at the beginning
arr_segment = np.concatenate(
[np.zeros(samples_needed - len(arr), dtype=np.float32), arr]
)
# No per-stream processor is attached in this bot (proc stays None), so the
# segment is used as-is and the waveform is drawn from the raw, unnormalized samples.
proc = None
norm = arr_segment.astype(np.float32)
# Map audio samples to pixels across the width
if norm.size < self.width:
padded = np.zeros(self.width, dtype=np.float32)
if norm.size > 0:
padded[-norm.size :] = norm
norm = padded
else:
block = int(np.ceil(norm.size / self.width))
norm = np.array(
[
np.mean(norm[i * block : min((i + 1) * block, norm.size)])
for i in range(self.width)
],
dtype=np.float32,
)
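# Example: a 2 s window at 16 kHz is 32000 samples; with width=640 each pixel
# column averages block = ceil(32000 / 640) = 50 samples.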
# For display we use the same `norm` computed above (single code
# path). Use `display_norm` alias to avoid confusion later in the
# code but don't recompute normalization.
display_norm = norm
# Draw waveform with color coding for speech detection
points: list[tuple[int, int]] = []
colors: list[tuple[int, int, int]] = [] # Color for each point
for x in range(self.width):
v = (
float(display_norm[x])
if x < display_norm.size and not np.isnan(display_norm[x])
else 0.0
)
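# Map v in [-1, 1] onto a band from y=100 (v=+1) down to y=height-20 (v=-1),
# leaving the top of the frame free for the status overlay.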
y = int((1.0 - ((v + 1.0) / 2.0)) * (self.height - 120)) + 100
points.append((x, y))
# Color based on speech detection status
is_speech = (
speech_info.get("is_speech", False) if speech_info else False
)
energy_check = (
speech_info.get("energy_check", False) if speech_info else False
)
if is_speech:
colors.append((0, 255, 0)) # Green for detected speech
elif energy_check:
colors.append((255, 255, 0)) # Yellow for energy but not speech
else:
colors.append((128, 128, 128)) # Gray for background noise
# Draw colored waveform
if len(points) > 1:
for i in range(len(points) - 1):
cv2.line(frame_array, points[i], points[i + 1], colors[i], 1)
# Draw historical peak indicator (horizontal lines at +/- target_peak); inert here since proc is always None
try:
if proc is not None and getattr(proc, "normalization_enabled", False):
target_peak = float(getattr(proc, "normalization_target_peak", 0.0))
# Ensure target_peak is within [0, 1]
target_peak = max(0.0, min(1.0, target_peak))
def _amp_to_y(a: float) -> int:
return (
int((1.0 - ((a + 1.0) / 2.0)) * (self.height - 120)) + 100
)
top_y = _amp_to_y(target_peak)
bot_y = _amp_to_y(-target_peak)
# Draw thin magenta lines across the waveform area
cv2.line(
frame_array,
(0, top_y),
(self.width - 1, top_y),
(255, 0, 255),
1,
)
cv2.line(
frame_array,
(0, bot_y),
(self.width - 1, bot_y),
(255, 0, 255),
1,
)
# Label the peak with small text near the right edge
label = f"Peak:{target_peak:.2f}"
(tw, _), _ = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
lx = max(10, self.width - tw - 12)
ly = max(12, top_y - 6)
cv2.putText(
frame_array,
label,
(lx, ly),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(255, 0, 255),
1,
)
except Exception:
# Non-critical: ignore any drawing errors
pass
# Add speech detection status overlay
if speech_info:
self._draw_speech_status(frame_array, speech_info, pname)
cv2.putText(
frame_array,
f"Waveform: {pname}",
(10, self.height - 15),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
(255, 255, 255),
2,
)
frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
frame.pts = pts
frame.time_base = fractions.Fraction(1, 90000)
return frame
def _draw_speech_status(
self,
frame_array: npt.NDArray[np.uint8],
speech_info: Dict[str, Any],
pname: str,
):
"""Draw speech detection status information."""
y_offset = 100
# Main status
is_speech = speech_info.get("is_speech", False)
status_text = "SPEECH" if is_speech else "NOISE"
status_color = (0, 255, 0) if is_speech else (128, 128, 128)
adaptive_thresh = speech_info.get("adaptive_threshold", 0)
cv2.putText(
frame_array,
f"{pname}: {status_text} (thresh: {adaptive_thresh:.4f})",
(10, y_offset),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
status_color,
2,
)
# Detailed metrics (smaller text)
metrics = [
f"Energy: {speech_info.get('energy', 0):.3f} ({'Y' if speech_info.get('energy_check', False) else 'N'})",
f"ZCR: {speech_info.get('zcr', 0):.3f} ({'Y' if speech_info.get('zcr_check', False) else 'N'})",
f"Spectral: {'Y' if 300 < speech_info.get('centroid', 0) < 3400 else 'N'}/{'Y' if speech_info.get('rolloff', 0) < 2000 else 'N'}/{'Y' if speech_info.get('flux', 0) > 0.01 else 'N'} ({'Y' if speech_info.get('spectral_check', False) else 'N'})",
f"Harmonic: {speech_info.get('hamonicity', 0):.3f} ({'Y' if speech_info.get('harmonic_check', False) else 'N'})",
f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})",
]
for metric in metrics:
cv2.putText(
frame_array,
metric,
(320, y_offset),
cv2.FONT_HERSHEY_SIMPLEX,
0.4,
(255, 255, 255),
1,
)
y_offset += 15
logic_result = "E:" + ("Y" if speech_info.get("energy_check", False) else "N")
logic_result += " Z:" + ("Y" if speech_info.get("zcr_check", False) else "N")
logic_result += " S:" + (
"Y" if speech_info.get("spectral_check", False) else "N"
)
logic_result += " H:" + (
"Y" if speech_info.get("harmonic_check", False) else "N"
)
logic_result += " T:" + (
"Y" if speech_info.get("temporal_consistency", False) else "N"
)
cv2.putText(
frame_array,
logic_result,
(320, y_offset + 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
1,
)
# Noise floor indicator
noise_floor = speech_info.get("noise_floor_energy", 0)
cv2.putText(
frame_array,
f"Noise Floor: {noise_floor:.4f}",
(10, y_offset + 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(200, 200, 200),
1,
)
async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
"""Handle incoming audio tracks from WebRTC peers."""
logger.info(
f"handle_track_received called for {peer.peer_name} with track kind: {track.kind}"
)
global _audio_processors, _send_chat_func
if track.kind != "audio":
logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
return
# Initialize raw audio buffer for immediate graphing
if peer.peer_name not in WaveformVideoTrack.buffer:
WaveformVideoTrack.buffer[peer.peer_name] = np.array([], dtype=np.float32)
# Start processing frames immediately (before processor is ready)
logger.info(f"Starting frame processing loop for {peer.peer_name}")
frame_count = 0
while True:
try:
frame = await track.recv()
frame_count += 1
if frame_count % 100 == 0:
logger.debug(f"Received {frame_count} frames from {peer.peer_name}")
except MediaStreamError as e:
logger.info(f"Audio stream ended for {peer.peer_name}: {e}")
break
except Exception as e:
logger.error(f"Error receiving frame from {peer.peer_name}: {e}")
break
if isinstance(frame, AudioFrame):
try:
# Convert frame to numpy array
audio_data = frame.to_ndarray()
# Handle audio format conversion
audio_data = _process_audio_frame(audio_data, frame)
# Resample if needed
if frame.sample_rate != SAMPLE_RATE:
audio_data = _resample_audio(
audio_data, frame.sample_rate, SAMPLE_RATE
)
# Convert to float32
audio_data_float32 = audio_data.astype(np.float32)
# Update visualization buffer immediately
WaveformVideoTrack.buffer[peer.peer_name] = np.concatenate(
[WaveformVideoTrack.buffer[peer.peer_name], audio_data_float32]
)
# Limit buffer size to last 10 seconds
max_samples = SAMPLE_RATE * 10
if len(WaveformVideoTrack.buffer[peer.peer_name]) > max_samples:
WaveformVideoTrack.buffer[peer.peer_name] = (
WaveformVideoTrack.buffer[peer.peer_name][-max_samples:]
)
# # Process with optimized processor if available
# if peer.peer_name in _audio_processors:
# audio_processor = _audio_processors[peer.peer_name]
# audio_processor.add_audio_data(audio_data_float32)
except Exception as e:
logger.error(f"Error processing audio frame for {peer.peer_name}: {e}")
continue
# NOTE: the receive loop above only exits once the stream ends or errors, so this
# second loop (carried over from the Whisper bot) normally has nothing left to process.
# audio_processor = _audio_processors[peer.peer_name]
logger.info(f"Continuing audio processing for {peer.peer_name}")
try:
frame_count = 0
logger.info(f"Entering frame processing loop for {peer.peer_name}")
while True:
try:
logger.debug(f"Waiting for frame from {peer.peer_name}")
frame = await track.recv()
frame_count += 1
if frame_count == 1:
logger.info(f"Received first frame from {peer.peer_name}")
elif frame_count % 50 == 0:
logger.info(f"Received {frame_count} frames from {peer.peer_name}")
except MediaStreamError as e:
logger.info(f"Audio stream ended for {peer.peer_name}: {e}")
break
except Exception as e:
logger.error(f"Error receiving frame from {peer.peer_name}: {e}")
break
if isinstance(frame, AudioFrame):
try:
# Convert frame to numpy array
audio_data = frame.to_ndarray()
# Handle audio format conversion
audio_data = _process_audio_frame(audio_data, frame)
# Resample if needed
if frame.sample_rate != SAMPLE_RATE:
audio_data = _resample_audio(
audio_data, frame.sample_rate, SAMPLE_RATE
)
# Convert to float32
audio_data_float32 = audio_data.astype(np.float32)
logger.debug(
f"Processed audio frame {frame_count} from {peer.peer_name}: {len(audio_data_float32)} samples"
)
# Process with optimized processor if available
# audio_processor.add_audio_data(audio_data_float32)
except Exception as e:
logger.error(
f"Error processing audio frame for {peer.peer_name}: {e}"
)
continue
except Exception as e:
logger.error(
f"Unexpected error in audio processing for {peer.peer_name}: {e}",
exc_info=True,
)
finally:
cleanup_peer_processor(peer.peer_name)
def _process_audio_frame(
audio_data: npt.NDArray[Any], frame: AudioFrame
) -> npt.NDArray[np.float32]:
"""Process audio frame format conversion."""
# Handle stereo to mono conversion
if audio_data.ndim == 2:
if audio_data.shape[0] == 1:
audio_data = audio_data.squeeze(0)
else:
# Average over the channel axis (the smaller dimension) to mix down to mono
audio_data = np.mean(
audio_data, axis=0 if audio_data.shape[0] < audio_data.shape[1] else 1
)
# Normalize based on data type
if audio_data.dtype == np.int16:
audio_data = audio_data.astype(np.float32) / 32768.0
elif audio_data.dtype == np.int32:
audio_data = audio_data.astype(np.float32) / 2147483648.0
return audio_data.astype(np.float32)
def _resample_audio(
audio_data: npt.NDArray[np.float32], orig_sr: int, target_sr: int
) -> npt.NDArray[np.float32]:
"""Resample audio efficiently."""
try:
# Handle stereo audio by converting to mono if necessary
if audio_data.ndim > 1:
audio_data = np.mean(audio_data, axis=1)
# Use high-quality resampling
resampled = librosa.resample( # type: ignore
audio_data.astype(np.float64),
orig_sr=orig_sr,
target_sr=target_sr,
res_type="kaiser_fast", # Good balance of quality and speed
)
return resampled.astype(np.float32) # type: ignore
except Exception as e:
logger.error(f"Resampling failed: {e}")
raise ValueError(
f"Failed to resample audio from {orig_sr} Hz to {target_sr} Hz: {e}"
)
# Public API functions
def agent_info() -> Dict[str, str]:
return {
"name": AGENT_NAME,
"description": AGENT_DESCRIPTION,
"has_media": "true",
"configurable": "true",
}
def get_config_schema() -> Dict[str, Any]:
"""Get the configuration schema for the Whisper bot"""
return {
"bot_name": AGENT_NAME,
"version": "1.0",
"parameters": [
],
"categories": [
],
}
def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
"""Handle configuration update for a specific lobby"""
try:
logger.info(f"Updating TTS config for lobby {lobby_id}: {config_values}")
# No tunable parameters are exposed yet, so any update is treated as applied
config_applied = True
return config_applied
except Exception as e:
logger.error(f"Failed to apply Whisper config update: {e}")
return False
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
"""Create agent tracks. Provides a synthetic video waveform track and a silent audio track for compatibility."""
class SilentAudioTrack(MediaStreamTrack):
kind = "audio"
def __init__(
self, sample_rate: int = SAMPLE_RATE, channels: int = 1, fps: int = 50
):
super().__init__()
self.sample_rate = sample_rate
self.channels = channels
self.fps = fps
self.samples_per_frame = int(self.sample_rate / self.fps)
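# With the defaults (16 kHz, 50 fps) each recv() emits 320 samples, i.e. 20 ms of audio.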
self._timestamp = 0
# Per-track playback buffer (float32, mono, -1..1)
self._play_buffer = np.array([], dtype=np.float32)
# Source sample rate of the buffered audio (if any)
self._play_src_sr = None
# Phase for synthetic sine tone (radians)
self._sine_phase = 0.0
async def recv(self) -> AudioFrame:
global _audio_data
# If new global TTS audio was produced, grab and prepare it for playback
logger.info("recv called with _audio_data: %s and size: %s", "set" if _audio_data is not None else "unset", getattr(_audio_data, 'size', 0) if _audio_data is not None else 'N/A')
try:
if _audio_data is not None and getattr(_audio_data, 'size', 0) > 0:
# Get source sample rate from TTS engine if available
try:
from .vibevoicetts import get_tts_engine
src_sr = get_tts_engine().get_sample_rate()
except Exception:
# Fallback: assume 24000 (VibeVoice default)
logger.info("Falling back to default source sample rate of 24000 Hz")
src_sr = 24000
# Ensure numpy float32
src_audio = np.asarray(_audio_data, dtype=np.float32)
# Clear global buffer (we consumed it into this track)
_audio_data = None
# If source sr differs from track sr, resample
if src_sr != self.sample_rate:
try:
resampled = librosa.resample(
src_audio.astype(np.float64),
orig_sr=src_sr,
target_sr=self.sample_rate,
res_type="kaiser_fast",
)
self._play_buffer = resampled.astype(np.float32)
except Exception:
# On failure, play the source audio unresampled (pitch/speed will be off)
logger.exception("Failed to resample TTS audio; using source as-is")
self._play_buffer = src_audio.astype(np.float32)
else:
self._play_buffer = src_audio.astype(np.float32)
self._play_src_sr = self.sample_rate
# Save a copy of the consumed audio to disk for inspection
try:
pcm_save = np.clip(self._play_buffer, -1.0, 1.0)
pcm_int16_save = (pcm_save * 32767.0).astype(np.int16)
sample_path = "./sample.wav"
with wave.open(sample_path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(self.sample_rate)
wf.writeframes(pcm_int16_save.tobytes())
logger.info("Wrote TTS sample to %s (%d samples, %d Hz)", sample_path, len(pcm_int16_save), self.sample_rate)
except Exception:
logger.exception("Failed to write sample.wav")
# Prepare output samples for this frame
if self._play_buffer.size > 0:
take = min(self.samples_per_frame, int(self._play_buffer.size))
out = self._play_buffer[:take]
# Advance buffer
if take >= self._play_buffer.size:
self._play_buffer = np.array([], dtype=np.float32)
else:
self._play_buffer = self._play_buffer[take:]
else:
# No TTS audio buffered; output a 110 Hz sine tone at 0.1 volume
freq = 110.0
volume = 0.1
n = self.samples_per_frame
# incremental phase per sample
phase_inc = 2.0 * np.pi * freq / float(self.sample_rate)
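# e.g. at 16 kHz, phase_inc = 2*pi*110/16000 ~= 0.0432 rad per sample; carrying the
# phase across frames keeps the tone continuous and avoids clicks at frame boundaries.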
samples_idx = np.arange(n, dtype=np.float32)
out = volume * np.sin(self._sine_phase + phase_inc * samples_idx)
# advance and wrap phase
self._sine_phase = (self._sine_phase + phase_inc * n) % (2.0 * np.pi)
logger.debug("No TTS audio: emitting 110Hz test tone (%d samples at %d Hz)", n, self.sample_rate)
# Convert float32 [-1.0,1.0] to int16 PCM
pcm = np.clip(out, -1.0, 1.0)
pcm_int16 = (pcm * 32767.0).astype(np.int16)
# For the packed s16 format, AudioFrame.from_ndarray expects an interleaved
# array of shape (1, samples * channels)
if self.channels == 1:
data = np.expand_dims(pcm_int16, axis=0)
layout = "mono"
else:
# Duplicate mono into interleaved stereo (L R L R ...)
data = np.empty(pcm_int16.size * 2, dtype=np.int16)
data[0::2] = pcm_int16
data[1::2] = pcm_int16
data = np.expand_dims(data, axis=0)
layout = "stereo"
frame = AudioFrame.from_ndarray(data, layout=layout)
frame.sample_rate = self.sample_rate
frame.pts = self._timestamp
frame.time_base = fractions.Fraction(1, self.sample_rate)
self._timestamp += self.samples_per_frame
# Pace the frame rate to avoid busy-looping
await asyncio.sleep(1 / self.fps)
return frame
except Exception as e:
logger.exception(f"Error in SilentAudioTrack.recv: {e}")
# On error, fall back to silent frame
data = np.zeros((self.channels, self.samples_per_frame), dtype=np.int16)
frame = AudioFrame.from_ndarray(data, layout="mono" if self.channels == 1 else "stereo")
frame.sample_rate = self.sample_rate
frame.pts = self._timestamp
frame.time_base = fractions.Fraction(1, self.sample_rate)
self._timestamp += self.samples_per_frame
await asyncio.sleep(1 / self.fps)
return frame
try:
video_track = WaveformVideoTrack(
session_name=session_name, width=640, height=480, fps=15
)
audio_track = SilentAudioTrack()
return {"video": video_track, "audio": audio_track}
except Exception as e:
logger.error(f"Failed to create agent tracks: {e}")
return {}
async def handle_chat_message(
chat_message: ChatMessageModel,
send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]],
) -> Optional[str]:
"""Handle incoming chat messages."""
global _audio_data
logger.info(f"Received chat message: {chat_message.message}")
# This ends up blocking; spin it off into a thread
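# A minimal sketch of how the synchronous TTS call could be pushed off the event loop,
# assuming get_tts_engine().text_to_speech is a plain blocking function:
#
#     engine = get_tts_engine()
#     _audio_data = await asyncio.to_thread(
#         engine.text_to_speech,
#         text=f"Speaker 1: {chat_message.message}",
#         speaker_names=None,
#         verbose=True,
#     )
#
# (asyncio.to_thread runs the call in a worker thread and awaits its result.)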
try:
from .vibevoicetts import get_tts_engine
engine = get_tts_engine()
_audio_data = engine.text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None, verbose=True)
except Exception:
# Fallback to convenience function if direct engine call fails
_audio_data = text_to_speech(text=f"Speaker 1: {chat_message.message}", speaker_names=None)
return None
async def on_track_received(peer: Peer, track: MediaStreamTrack) -> None:
"""Callback when a new track is received from a peer."""
await handle_track_received(peer, track)
def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]:
"""Return the track handler function."""
return on_track_received
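# Rough sketch of how a host framework is expected to wire this module up (the exact
# call sequence is an assumption based on the public functions exposed here):
#
#     info = agent_info()                                  # registration metadata
#     bind_send_chat_function(send_chat, create_chat_msg)  # inject chat callbacks
#     tracks = create_agent_tracks(session_name)           # waveform video + TTS audio
#     on_track = get_track_handler()                       # per-peer audio handler
#     await handle_chat_message(msg, send_chat)            # text -> speech on demand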
def bind_send_chat_function(
send_chat_func: Callable[[ChatMessageModel], Awaitable[None]],
create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel],
) -> None:
"""Bind the send chat function."""
global _send_chat_func, _create_chat_message_func, _audio_processors
logger.info("Binding send chat function to OpenVINO whisper agent")
_send_chat_func = send_chat_func
_create_chat_message_func = create_chat_message_func
# Update existing processors
# for peer_name, processor in _audio_processors.items():
# processor.send_chat_func = send_chat_func
# processor.create_chat_message_func = create_chat_message_func
# logger.debug(f"Updated processor for {peer_name} with new send chat function")
def cleanup_peer_processor(peer_name: str) -> None:
"""Clean up processor for disconnected peer."""
global _audio_processors
if peer_name in WaveformVideoTrack.buffer:
del WaveformVideoTrack.buffer[peer_name]