# Source listing: 620 lines, 24 KiB, Python

"""Streaming Whisper agent (bots/whisper)
Real-time speech transcription agent that processes incoming audio streams
and sends transcriptions as chat messages to the lobby.
"""
import asyncio
import numpy as np
import time
import threading
from collections import deque
from queue import Queue, Empty
from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
import numpy.typing as npt
from pydantic import BaseModel
# Core dependencies
import librosa
from logger import logger
from aiortc import MediaStreamTrack
from av import AudioFrame
# Import shared models for chat functionality
import sys
import os
from voicebot.models import Peer
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from shared.models import ChatMessageModel
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Type definitions
# Float32 NumPy sample array used for audio frames, buffered chunks and whole
# phrases throughout this module. AudioQueueItem (a pydantic model) uses it as
# a field type together with arbitrary_types_allowed.
AudioArray = npt.NDArray[np.float32]
class AudioQueueItem(BaseModel):
    """Audio data with timestamp for processing queue.

    Created by ``AudioProcessor._queue_for_processing`` and consumed by the
    processor's background processing loop.
    """

    # Concatenated float32 samples of one buffered chunk (16 kHz in this module)
    audio: AudioArray
    # time.time() at the moment the chunk was queued
    timestamp: float

    class Config:
        # Lets pydantic accept the non-pydantic ndarray type used by ``audio``
        arbitrary_types_allowed = True
class TranscriptionHistoryItem(BaseModel):
    """Transcription history item with metadata.

    ``AudioProcessor`` keeps a short list of these to avoid re-sending an
    identical partial transcription.
    """

    # Full chat message as sent, including the "🎤 <peer>: " style prefix
    message: str
    # time.time() when the entry was recorded
    timestamp: float
    # True for end-of-phrase transcriptions, False for streaming partials
    is_final: bool
AGENT_NAME = "whisper"
AGENT_DESCRIPTION = (
    "Real-time speech transcription (Whisper) - converts speech to text"
)

# Target rate for all audio handed to Whisper; incoming WebRTC frames are
# resampled to this rate before buffering.
sample_rate = 16000

# Hugging Face checkpoint ids grouped by model family. The module's load code
# uses the first "Distil-Whisper" entry.
model_ids = {
    "Distil-Whisper": [
        f"distil-whisper/distil-{variant}"
        for variant in ("large-v2", "medium.en", "small.en")
    ],
    "Whisper": [
        f"openai/whisper-{variant}"
        for variant in (
            "large-v3",
            "large-v2",
            "large",
            "medium",
            "small",
            "base",
            "tiny",
            "medium.en",
            "small.en",
            "base.en",
            "tiny.en",
        )
    ],
}
# Global whisper model and transcription handler.
# Everything below runs once, at import time: the processor and model are
# loaded before any peer connects and are shared by every AudioProcessor.
# NOTE(review): from_pretrained may fetch the checkpoint from the Hugging Face
# Hub when it is not cached locally, so the first import can be slow.
# The checkpoint is the first entry of the "Distil-Whisper" family.
_model_type = model_ids["Distil-Whisper"]
_model_id = _model_type[0]
logger.info(f"Loading Whisper model: {_model_id}")
# Turns raw audio into model input features and decodes token ids to text
_processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=_model_id) # type: ignore
logger.info("Whisper processor loaded successfully")
_pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(
    pretrained_model_name_or_path=_model_id
) # type: ignore
# Inference-only use: switch the torch module to evaluation mode
_pt_model.eval() # type: ignore
logger.info("Whisper model loaded and set to evaluation mode")
_audio_processors: Dict[str, "AudioProcessor"] = {} # Per-peer audio processors
# Chat sender installed by bind_send_chat_function(); while it is None,
# handle_track_received refuses to create new processors.
_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None
def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
    """Run the shared Whisper processor on raw audio and return its features.

    Args:
        audio_array: Audio samples to featurize.
        sampling_rate: Sample rate of ``audio_array`` in Hz.

    Returns:
        The ``input_features`` attribute of the processor output, built as
        PyTorch tensors (``return_tensors="pt"``).
    """
    features: Any = _processor(  # type: ignore
        audio_array, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features  # type: ignore
    return features
class AudioProcessor:
    """Turn one peer's audio stream into chat transcriptions.

    ``add_audio_data`` receives mono float32 frames at ``self.sample_rate``
    from the asyncio side. Every 10 frames, the buffer is either discarded as
    silence or queued for a background thread. That thread does three things:

    * groups queued chunks into phrases;
    * runs Whisper on the open phrase ("[partial]" messages) and on finished
      phrases (final messages);
    * hands each message to ``send_chat_func`` on the event loop this
      processor was created on.

    A phrase ends when no new speech arrives for ``phrase_timeout`` seconds,
    or when it reaches ``max_phrase_duration``.

    Whisper inference (``_pt_model.generate``) is synchronous, so it runs on
    the background thread. Only the chat send is scheduled onto the event
    loop, so WebRTC handling is never blocked by the model.
    """

    def __init__(
        self, peer_name: str, send_chat_func: Callable[[str], Awaitable[None]]
    ):
        self.peer_name = peer_name
        self.send_chat_func = send_chat_func
        self.sample_rate = 16000  # Whisper expects 16kHz
        self.samples_per_frame = 480  # Common WebRTC frame size at 16kHz (30ms)
        # Audio buffering: add_audio_data flushes this every 10 frames
        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000)
        # Seconds without new speech before the current phrase is finalized
        self.phrase_timeout = 3.0
        # Whisper works on 30 s windows (the HF feature extractor pads or
        # truncates to that length by default), so longer phrases are
        # finalized early rather than having their tail silently dropped.
        self.max_phrase_duration = 30.0
        self.last_activity_time = time.time()
        # Transcription state (only touched by the processing thread)
        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
        self.transcription_history: list[TranscriptionHistoryItem] = []
        # Queue timestamp of the newest speech chunk in the current phrase
        self.last_chunk_time: Optional[float] = None
        # The processing thread has no event loop of its own; calling
        # asyncio.get_event_loop() there raises RuntimeError. Remember the
        # loop this processor is created on and schedule chat sends onto it.
        self.main_loop: Optional[asyncio.AbstractEventLoop] = None
        try:
            self.main_loop = asyncio.get_running_loop()
        except RuntimeError:
            logger.warning(
                f"AudioProcessor for {self.peer_name} created without a running event loop; transcriptions cannot be sent"
            )
        # Background processing
        self.processing_queue: Queue[AudioQueueItem] = Queue()
        self.is_running = True
        self.processor_thread = threading.Thread(
            target=self._processing_loop, daemon=True
        )
        self.processor_thread.start()
        logger.info(
            f"AudioProcessor initialized for {self.peer_name} - sample_rate: {self.sample_rate}Hz, frame_size: {self.samples_per_frame}, phrase_timeout: {self.phrase_timeout}s"
        )

    def add_audio_data(self, audio_data: AudioArray):
        """Buffer one mono float32 frame and flush the buffer every 10 frames.

        Called from the asyncio side with samples at ``self.sample_rate``.
        Once 10 frames are buffered, the buffer is queued for the processing
        thread if its combined RMS exceeds 0.001. Otherwise it is discarded
        as silence. Calls after ``shutdown`` are ignored.
        """
        if not self.is_running:
            logger.debug("AudioProcessor not running, ignoring audio data")
            return
        if len(audio_data) == 0:
            return
        self.audio_buffer.append(audio_data)
        self.last_activity_time = time.time()
        # Calculate audio metrics to detect silence
        audio_rms = np.sqrt(np.mean(audio_data**2))
        audio_peak = np.max(np.abs(audio_data))
        # Derive the buffered duration from the real sample count; frame
        # length depends on the sender and on resampling, not a fixed 30 ms.
        buffered_samples = sum(len(chunk) for chunk in self.audio_buffer)
        buffer_duration_ms = buffered_samples * 1000.0 / self.sample_rate
        # Only log if we have meaningful audio or every 50 frames
        if audio_rms > 0.001 or len(self.audio_buffer) % 50 == 0:
            logger.info(
                f"Added audio chunk: {len(audio_data)} samples, buffer size: {len(self.audio_buffer)} frames ({buffer_duration_ms:.0f}ms), RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
            )
        else:
            logger.debug(
                f"Added silent audio chunk: {len(audio_data)} samples, buffer size: {len(self.audio_buffer)} frames"
            )
        # Check if we should process accumulated audio
        if len(self.audio_buffer) >= 10:
            combined_audio = np.concatenate(list(self.audio_buffer))
            combined_rms = np.sqrt(np.mean(combined_audio**2))
            if combined_rms > 0.001:  # Only process if not silence
                logger.info(
                    f"Buffer threshold reached with meaningful audio (RMS: {combined_rms:.4f}), queuing for processing"
                )
                self._queue_for_processing()
            else:
                logger.debug(
                    f"Buffer threshold reached but audio is silent (RMS: {combined_rms:.4f}), clearing buffer"
                )
                self.audio_buffer.clear()  # Clear silent audio

    def _queue_for_processing(self):
        """Concatenate and clear the buffer, then queue it unless it is silent."""
        if not self.audio_buffer:
            logger.debug("No audio in buffer to queue for processing")
            return
        combined_audio = np.concatenate(list(self.audio_buffer))
        self.audio_buffer.clear()
        audio_duration_sec = len(combined_audio) / self.sample_rate
        audio_rms = np.sqrt(np.mean(combined_audio**2))
        audio_peak = np.max(np.abs(combined_audio))
        # Skip completely silent audio
        if audio_rms < 0.001 and audio_peak < 0.001:
            logger.debug(
                f"Skipping silent audio chunk: RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
            )
            return
        logger.info(
            f"Queuing audio chunk: {len(combined_audio)} samples, {audio_duration_sec:.2f}s duration, RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
        )
        try:
            queue_item = AudioQueueItem(audio=combined_audio, timestamp=time.time())
            self.processing_queue.put_nowait(queue_item)
            logger.info(
                f"Added to processing queue, queue size: {self.processing_queue.qsize()}"
            )
        except Exception as e:
            # Queue full, skip this chunk
            logger.warning(f"Audio processing queue full, dropping audio chunk: {e}")

    def _processing_loop(self):
        """Background thread body: fold queued chunks into phrases and transcribe."""
        logger.info("ASR processing loop started")
        while self.is_running:
            try:
                try:
                    first_item = self.processing_queue.get(timeout=1.0)
                except Empty:
                    logger.debug("Processing queue timeout, checking for more audio...")
                    # No new speech: finalize once the speaker has been quiet
                    # long enough, instead of waiting for the next utterance.
                    if self._phrase_timed_out(time.time()):
                        self._finalize_phrase()
                    continue
                # Coalesce everything queued while the previous inference ran,
                # so the loop catches up instead of falling further behind.
                pending = [first_item, *self._drain_queue()]
                logger.debug(f"Retrieved {len(pending)} audio chunk(s) from queue")
                for item in pending:
                    self._add_chunk_to_phrase(item)
                self._maybe_stream_partial()
            except Exception as e:
                logger.error(f"Error in audio processing loop: {e}", exc_info=True)
        logger.info("ASR processing loop ended")

    def _drain_queue(self) -> list[AudioQueueItem]:
        """Return every item currently waiting in the queue without blocking."""
        drained: list[AudioQueueItem] = []
        while True:
            try:
                drained.append(self.processing_queue.get_nowait())
            except Empty:
                return drained

    def _add_chunk_to_phrase(self, item: AudioQueueItem) -> None:
        """Append one chunk to the phrase, closing the phrase first after a long gap."""
        audio_array = item.audio
        chunk_timestamp = item.timestamp
        # Silent buffers are never queued, so the gap between consecutive
        # chunk timestamps measures how long the speaker was quiet.
        if self.last_chunk_time is None:
            time_since_last = 0.0
        else:
            time_since_last = chunk_timestamp - self.last_chunk_time
        phrase_complete = time_since_last > self.phrase_timeout
        logger.debug(
            f"Processing audio chunk: {len(audio_array)} samples, time since last: {time_since_last:.2f}s, phrase_complete: {phrase_complete}"
        )
        if phrase_complete and len(self.current_phrase_audio) > 0:
            self._finalize_phrase()
        old_phrase_length = len(self.current_phrase_audio)
        self.current_phrase_audio = np.concatenate(
            [self.current_phrase_audio, audio_array]
        )
        self.last_chunk_time = chunk_timestamp
        current_phrase_duration = len(self.current_phrase_audio) / self.sample_rate
        logger.debug(
            f"Updated current phrase: {old_phrase_length} -> {len(self.current_phrase_audio)} samples ({current_phrase_duration:.2f}s)"
        )
        if current_phrase_duration >= self.max_phrase_duration:
            logger.info(
                f"Phrase reached {self.max_phrase_duration:.0f}s limit for {self.peer_name}, finalizing early"
            )
            self._finalize_phrase()

    def _phrase_timed_out(self, now: float) -> bool:
        """Return True when an open phrase has had no new speech for phrase_timeout seconds."""
        return (
            len(self.current_phrase_audio) > 0
            and self.last_chunk_time is not None
            and now - self.last_chunk_time > self.phrase_timeout
        )

    def _finalize_phrase(self) -> None:
        """Transcribe the current phrase as final and start a new, empty phrase."""
        phrase_audio = self.current_phrase_audio
        # Reset before the slow inference; the phrase is handed off exactly once.
        self.current_phrase_audio = np.array([], dtype=np.float32)
        if len(phrase_audio) == 0:
            return
        phrase_duration = len(phrase_audio) / self.sample_rate
        phrase_rms = np.sqrt(np.mean(phrase_audio**2))
        logger.info(
            f"Processing completed phrase: {phrase_duration:.2f}s duration, {len(phrase_audio)} samples, RMS: {phrase_rms:.4f}"
        )
        self._transcribe_and_send(phrase_audio, is_final=True)

    def _maybe_stream_partial(self) -> None:
        """Send a partial transcription once the open phrase is longer than 1 s."""
        min_transcription_duration = 1.0
        if len(self.current_phrase_audio) <= self.sample_rate * min_transcription_duration:
            return
        phrase_rms = np.sqrt(np.mean(self.current_phrase_audio**2))
        logger.info(
            f"Current phrase >= {min_transcription_duration}s (RMS: {phrase_rms:.4f}), attempting streaming transcription"
        )
        self._transcribe_and_send(self.current_phrase_audio, is_final=False)

    def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool) -> None:
        """Transcribe *audio_array* with Whisper and post the text as a chat message.

        Runs synchronously on the processing thread, because inference is
        blocking; only the chat send is handed to the event loop.

        Args:
            audio_array: Float32 samples at ``self.sample_rate``.
            is_final: True for a finished phrase, False for a streaming partial.

        Audio shorter than 0.3 s, or with RMS below 0.0005, is skipped. A
        partial whose message equals one of the last three sent messages is
        not re-sent. Errors are logged and never raised.
        """
        transcription_start_time = time.time()
        transcription_type = "final" if is_final else "streaming"
        try:
            audio_duration_sec = len(audio_array) / self.sample_rate
            min_duration = 0.3
            if len(audio_array) < self.sample_rate * min_duration:
                logger.debug(
                    f"Skipping {transcription_type} transcription: audio too short ({audio_duration_sec:.2f}s < {min_duration}s)"
                )
                return
            audio_rms = np.sqrt(np.mean(audio_array**2))
            audio_peak = np.max(np.abs(audio_array))
            if audio_rms < 0.0005:  # Very quiet threshold
                logger.debug(
                    f"Skipping {transcription_type} transcription: audio too quiet (RMS: {audio_rms:.6f})"
                )
                return
            logger.info(
                f"Starting {transcription_type} transcription: {audio_duration_sec:.2f}s audio, RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
            )
            audio_array = audio_array.astype(np.float32)
            feature_extraction_start = time.time()
            input_features = extract_input_features(audio_array, self.sample_rate)
            feature_extraction_time = time.time() - feature_extraction_start
            model_inference_start = time.time()
            predicted_ids = _pt_model.generate(input_features)  # type: ignore
            model_inference_time = time.time() - model_inference_start
            decoding_start = time.time()
            transcription = _processor.batch_decode(  # type: ignore
                predicted_ids, skip_special_tokens=True
            )
            decoding_time = time.time() - decoding_start
            total_transcription_time = time.time() - transcription_start_time
            logger.debug(
                f"ASR timing - Feature extraction: {feature_extraction_time:.3f}s, Model inference: {model_inference_time:.3f}s, Decoding: {decoding_time:.3f}s, Total: {total_transcription_time:.3f}s"
            )
            text = transcription[0].strip() if transcription else ""
            if not text:
                logger.info(
                    f"❌ No text from {transcription_type} transcription for {self.peer_name} (empty result from model)"
                )
                return
            prefix = (
                f"🎤 {self.peer_name}: "
                if is_final
                else f"🎤 {self.peer_name} [partial]: "
            )
            message = f"{prefix}{text}"
            recent_messages = [h.message for h in self.transcription_history[-3:]]
            if not is_final and message in recent_messages:
                logger.debug(
                    f"Skipping duplicate {transcription_type} transcription: '{text}'"
                )
                return
            if not self._send_message(message):
                return
            self.transcription_history.append(
                TranscriptionHistoryItem(
                    message=message, timestamp=time.time(), is_final=is_final
                )
            )
            # Keep only the 10 most recent entries for deduplication
            del self.transcription_history[:-10]
            logger.info(
                f"✅ Transcribed ({transcription_type}) for {self.peer_name}: '{text}' (processing time: {total_transcription_time:.3f}s, audio duration: {audio_duration_sec:.2f}s)"
            )
        except Exception as e:
            logger.error(
                f"Error in {transcription_type} transcription: {e}", exc_info=True
            )

    def _send_message(self, message: str) -> bool:
        """Schedule ``send_chat_func(message)`` on the main loop; return False if impossible."""
        loop = self.main_loop
        if loop is None or loop.is_closed():
            logger.warning(
                f"No event loop available to send transcription for {self.peer_name}"
            )
            return False
        coro = self._deliver(message)
        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError as e:
            # Loop closed between the check and the call: avoid a
            # "coroutine was never awaited" warning.
            coro.close()
            logger.warning(
                f"No event loop available to send transcription for {self.peer_name}: {e}"
            )
            return False
        future.add_done_callback(self._log_send_result)
        return True

    async def _deliver(self, message: str) -> None:
        """Await the current ``send_chat_func`` on the event loop."""
        await self.send_chat_func(message)

    def _log_send_result(self, future: Any) -> None:
        """Log a chat send that was cancelled or raised."""
        if future.cancelled():
            logger.warning(f"Sending transcription for {self.peer_name} was cancelled")
            return
        exc = future.exception()
        if exc is not None:
            logger.error(f"Failed to send transcription for {self.peer_name}: {exc}")

    def shutdown(self):
        """Stop the processing thread, waiting up to 2 s for it to exit."""
        logger.info(f"Shutting down AudioProcessor for {self.peer_name}...")
        self.is_running = False
        if self.processor_thread.is_alive():
            logger.debug(
                f"Waiting for processor thread for {self.peer_name} to finish..."
            )
            self.processor_thread.join(timeout=2.0)
            if self.processor_thread.is_alive():
                logger.warning(
                    f"Processor thread for {self.peer_name} did not shut down cleanly within timeout"
                )
            else:
                logger.info(
                    f"Processor thread for {self.peer_name} shut down successfully"
                )
        logger.info(f"AudioProcessor for {self.peer_name} shutdown complete")
def _frame_to_mono_float32(frame: AudioFrame) -> AudioArray:
    """Convert a decoded audio frame to mono float32 samples at the frame's own rate.

    ``AudioFrame.to_ndarray()`` gives different shapes per format:

    * planar formats: ``(channels, samples)``;
    * packed formats such as ``s16``: ``(1, samples * channels)``, with the
      channels interleaved.

    The downmix therefore has to follow the layout. Integer samples are
    scaled into [-1, 1) before downmixing. Averaging first would produce
    float64 and bypass the integer-dtype scaling.
    """
    raw = frame.to_ndarray()
    if raw.dtype == np.int16:
        samples = raw.astype(np.float32) / 32768.0
    elif raw.dtype == np.int32:
        samples = raw.astype(np.float32) / 2147483648.0
    elif raw.dtype == np.uint8:
        samples = (raw.astype(np.float32) - 128.0) / 128.0
    else:
        samples = raw.astype(np.float32)
    channels = max(1, len(frame.layout.channels))
    if frame.format.is_planar:
        mono = samples.reshape(channels, -1).mean(axis=0)
    else:
        mono = samples.reshape(-1, channels).mean(axis=1)
    return cast(AudioArray, mono.astype(np.float32, copy=False))


async def handle_track_received(peer: Peer, track: MediaStreamTrack):
    """Feed an incoming WebRTC audio track to the peer's AudioProcessor until it ends.

    Non-audio tracks are ignored. The peer's processor is created on first
    use, which requires ``bind_send_chat_function`` to have been called.
    Each frame goes through these steps:

    * downmixed to mono float32;
    * resampled to ``sample_rate`` if needed;
    * passed to ``AudioProcessor.add_audio_data``.

    The end of the track is logged at info level. Any other error is logged
    with a traceback and stops processing.
    """
    # recv() raises this when the remote track stops.
    from aiortc.mediastreams import MediaStreamError

    if track.kind != "audio":
        logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
        return

    audio_processor = _audio_processors.get(peer.peer_name)
    if audio_processor is None:
        if _send_chat_func is None:
            logger.error(
                f"Cannot create AudioProcessor for {peer.peer_name}: no send_chat_func available"
            )
            return
        logger.info(f"Creating new AudioProcessor for {peer.peer_name}")
        audio_processor = AudioProcessor(
            peer_name=peer.peer_name, send_chat_func=_send_chat_func
        )
        _audio_processors[peer.peer_name] = audio_processor

    logger.info(
        f"Received audio track from {peer.peer_name}, starting transcription (processor available: {audio_processor is not None})"
    )

    frame_count = 0
    try:
        while True:
            try:
                frame = await track.recv()
            except MediaStreamError:
                logger.info(f"Audio track from {peer.peer_name} ended")
                break

            if not isinstance(frame, AudioFrame):
                logger.warning(
                    f"Received non-audio frame on audio track from {peer.peer_name}: type={type(frame)}"
                )
                continue

            frame_count += 1
            frame_info = (
                f"{frame.sample_rate}Hz, {frame.format.name}, {frame.layout.name}"
            )
            logger.debug(f"Received audio frame from {peer.peer_name}: {frame_info}")

            audio = _frame_to_mono_float32(frame)
            if audio.size == 0:
                continue

            # Resample to 16kHz if needed.
            # NOTE(review): resampling each small frame independently can add
            # boundary artifacts; a stateful resampler would avoid that.
            if frame.sample_rate != sample_rate:
                original_length = len(audio)
                audio = librosa.resample(  # type: ignore
                    audio, orig_sr=frame.sample_rate, target_sr=sample_rate
                )
                logger.debug(
                    f"Resampled audio: {frame.sample_rate}Hz -> {sample_rate}Hz, {original_length} -> {len(audio)} samples"
                )

            audio_data_float32 = cast(AudioArray, np.asarray(audio, dtype=np.float32))
            frame_rms = np.sqrt(np.mean(audio_data_float32**2))
            frame_peak = np.max(np.abs(audio_data_float32))
            # Only log full frame details every 20 frames to reduce noise
            if frame_count % 20 == 0:
                logger.info(
                    f"Audio frame #{frame_count} from {peer.peer_name}: {frame_info}, {len(audio_data_float32)} samples, RMS: {frame_rms:.4f}, Peak: {frame_peak:.4f}"
                )
            else:
                logger.debug(
                    f"Audio frame #{frame_count}: RMS: {frame_rms:.4f}, Peak: {frame_peak:.4f}"
                )

            audio_processor.add_audio_data(audio_data_float32)
    except Exception as e:
        logger.error(
            f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True
        )
def agent_info() -> Dict[str, str]:
    """Describe this agent to the orchestrator.

    Returns a dict with ``name``, ``description`` and ``has_media``. The last
    is the string ``"false"``, because this agent publishes no media.
    """
    info: Dict[str, str] = dict(name=AGENT_NAME, description=AGENT_DESCRIPTION)
    info["has_media"] = "false"
    return info
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
    """Return the local media tracks this agent publishes: always none.

    Whisper only consumes audio, so the mapping is empty; *session_name* is
    not used.
    """
    return dict()
async def handle_chat_message(
    chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]
) -> Optional[str]:
    """Ignore incoming chat messages.

    This agent only produces transcriptions; it never replies to chat, so the
    result is always ``None`` and *send_message_func* is never called.
    """
    return None
async def on_track_received(peer: Peer, track: MediaStreamTrack):
    """Track callback for the orchestrator; delegates to ``handle_track_received``."""
    handler = handle_track_received
    await handler(peer, track)
# Export functions for the orchestrator to discover
def get_track_handler():
    """Expose ``on_track_received`` so the orchestrator can register it for new tracks."""
    handler = on_track_received
    return handler
def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]):
    """Install *send_chat_func* as the chat sender for new and existing processors.

    New processors pick it up from the module global. Processors created
    earlier hold their own reference, so each of them is repointed here.
    """
    global _send_chat_func
    logger.info("Binding send chat function to whisper agent")
    _send_chat_func = send_chat_func
    for name, proc in _audio_processors.items():
        logger.debug(f"Updating AudioProcessor for {name} with new send chat function")
        proc.send_chat_func = send_chat_func
def cleanup_peer_processor(peer_name: str):
    """Shut down and forget the AudioProcessor of a disconnected peer, if one exists.

    The processor is removed from the registry only after ``shutdown()``
    returns. An unknown *peer_name* is logged at debug level and otherwise
    ignored.
    """
    try:
        proc = _audio_processors[peer_name]
    except KeyError:
        logger.debug(f"No AudioProcessor found for peer {peer_name} during cleanup")
        return
    logger.info(f"Cleaning up AudioProcessor for disconnected peer: {peer_name}")
    proc.shutdown()
    del _audio_processors[peer_name]
    logger.info(f"AudioProcessor for {peer_name} cleaned up successfully")
def get_active_processors() -> Dict[str, "AudioProcessor"]:
    """Return a shallow snapshot of the per-peer processors (for debugging).

    Mutating the returned dict does not affect the registry; the processor
    objects themselves are shared.
    """
    return dict(_audio_processors)