"""Streaming Whisper agent (bots/whisper)

Real-time speech transcription agent that processes incoming audio streams and
sends transcriptions as chat messages to the lobby.
"""

import asyncio
import numpy as np
import time
import threading
from collections import deque
from queue import Queue, Empty, Full
from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
import numpy.typing as npt
from pydantic import BaseModel

# Core dependencies
import librosa
from logger import logger
from aiortc import MediaStreamTrack
from aiortc.mediastreams import MediaStreamError
from av import AudioFrame

# Import shared models for chat functionality
import sys
import os

from voicebot.models import Peer

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from shared.models import ChatMessageModel  # noqa: E402  (needs the sys.path tweak above)

from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Type definitions
AudioArray = npt.NDArray[np.float32]


class AudioQueueItem(BaseModel):
    """Audio data with timestamp for processing queue."""

    audio: AudioArray
    timestamp: float  # time.monotonic() at the moment the chunk was queued

    class Config:
        arbitrary_types_allowed = True


class TranscriptionHistoryItem(BaseModel):
    """Transcription history item with metadata."""

    message: str
    timestamp: float
    is_final: bool


AGENT_NAME = "whisper"
AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text"
sample_rate = 16000  # Whisper expects 16kHz

model_ids = {
    "Distil-Whisper": [
        "distil-whisper/distil-large-v2",
        "distil-whisper/distil-medium.en",
        "distil-whisper/distil-small.en",
    ],
    "Whisper": [
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ],
}

# Global whisper model and transcription handler
_model_type = model_ids["Distil-Whisper"]
_model_id = _model_type[0]
logger.info(f"Loading Whisper model: {_model_id}")
_processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=_model_id)  # type: ignore
logger.info("Whisper processor loaded successfully")
_pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(
    pretrained_model_name_or_path=_model_id
)  # type: ignore
_pt_model.eval()  # type: ignore
logger.info("Whisper model loaded and set to evaluation mode")

# The processor/model pair is shared by every peer's background thread; this
# lock keeps inference serialized (as it was when everything ran on the loop).
_model_lock = threading.Lock()

_audio_processors: Dict[str, "AudioProcessor"] = {}  # Per-peer audio processors
_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None


def _rms(audio: npt.NDArray[Any]) -> float:
    """Return the root-mean-square level of *audio* (0.0 for an empty array)."""
    if audio.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(np.square(audio, dtype=np.float64))))


def _peak(audio: npt.NDArray[Any]) -> float:
    """Return the largest absolute sample value of *audio* (0.0 for an empty array)."""
    if audio.size == 0:
        return 0.0
    return float(np.max(np.abs(audio)))


def _to_float32(samples: npt.NDArray[Any]) -> AudioArray:
    """Convert PCM samples to float32, rescaling int16/int32 into [-1, 1).

    Must run BEFORE any averaging: ``np.mean`` on integer input returns float64,
    which would hide the integer dtype and skip normalization.
    """
    if samples.dtype == np.int16:
        return samples.astype(np.float32) / 32768.0
    if samples.dtype == np.int32:
        return samples.astype(np.float32) / 2147483648.0
    return samples.astype(np.float32, copy=False)


def _downmix_to_mono(samples: AudioArray, channels: int, planar: bool) -> AudioArray:
    """Average all channels of a PyAV ``AudioFrame.to_ndarray()`` result.

    PyAV returns planar audio as ``(channels, samples)`` and packed
    (interleaved) audio as ``(1, samples * channels)``; averaging along the
    wrong axis would collapse the whole frame into a single value.
    """
    if samples.ndim == 1:
        return samples
    if planar or channels <= 1:
        # One row per channel (a mono packed frame is also a single row).
        return samples.mean(axis=0)
    # Packed: de-interleave into (samples, channels) before averaging.
    return samples.reshape(-1, channels).mean(axis=1)


def _audio_frame_to_whisper_input(frame: AudioFrame) -> AudioArray:
    """Convert a WebRTC audio frame to 1-D float32 mono audio at ``sample_rate``."""
    samples = _to_float32(frame.to_ndarray())
    mono = _downmix_to_mono(samples, len(frame.layout.channels), frame.format.is_planar)
    if frame.sample_rate != sample_rate:
        original_length = len(mono)
        mono = librosa.resample(  # type: ignore
            mono, orig_sr=frame.sample_rate, target_sr=sample_rate
        )
        logger.debug(
            f"Resampled audio: {frame.sample_rate}Hz -> {sample_rate}Hz, {original_length} -> {len(mono)} samples"
        )
    return cast(AudioArray, np.ascontiguousarray(mono, dtype=np.float32))


def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
    """Extract input features from audio array and sampling rate."""
    processor_output = _processor(  # type: ignore
        audio_array,
        sampling_rate=sampling_rate,
        return_tensors="pt",
    )
    input_features: Any = processor_output.input_features  # type: ignore
    return input_features  # type: ignore


class AudioProcessor:
    """Buffers one peer's audio, segments it into phrases and transcribes them.

    Audio arrives on the asyncio event loop via :meth:`add_audio_data`. Once
    ``chunk_duration`` seconds are buffered, non-silent chunks are queued for a
    background thread. That thread appends them to the current phrase, runs
    Whisper on it (partial results while the phrase grows, a final result once
    no voiced audio has arrived for ``phrase_timeout`` seconds) and schedules
    the chat message on the event loop that created this processor.
    """

    def __init__(
        self,
        peer_name: str,
        send_chat_func: Callable[[str], Awaitable[None]],
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        """Create the processor and start its background thread.

        Args:
            peer_name: Display name used in chat messages and logs.
            send_chat_func: Coroutine function that posts a chat message.
            loop: Event loop on which ``send_chat_func`` must run. Defaults to
                the loop running in the caller; if there is none, transcription
                results cannot be delivered and are skipped with a warning.
        """
        self.peer_name = peer_name
        self.send_chat_func = send_chat_func
        self.sample_rate = sample_rate  # Whisper expects 16kHz
        self.samples_per_frame = 480  # Common WebRTC frame size at 16kHz (30ms)

        # The worker thread has no event loop of its own, so remember the one
        # that owns send_chat_func and schedule coroutines onto it.
        if loop is None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None
        self.main_loop: Optional[asyncio.AbstractEventLoop] = loop

        # Tunables
        self.silence_rms_threshold = 0.001  # buffered chunks at/below this RMS are dropped
        self.chunk_duration = 0.3  # seconds of audio buffered before a silence check
        self.phrase_timeout = 3.0  # seconds without voiced audio before a phrase is final
        self.min_streaming_duration = 1.0  # phrase length before partial results start
        self.min_transcription_duration = 0.3  # shorter audio is never sent to Whisper
        self.quiet_rms_threshold = 0.0005  # quieter audio is never sent to Whisper
        self.max_phrase_duration = 25.0  # force a final result well inside Whisper's 30s window

        # Audio buffering
        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000)
        self.last_activity_time = time.time()

        # Transcription state (owned by the processing thread)
        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
        self._last_chunk_time: Optional[float] = None
        self.transcription_history: list[TranscriptionHistoryItem] = []

        # Background processing; bounded so a slow model cannot grow memory forever.
        self.processing_queue: Queue[AudioQueueItem] = Queue(maxsize=100)
        self.is_running = True
        self.processor_thread = threading.Thread(
            target=self._processing_loop, name=f"whisper-{peer_name}", daemon=True
        )
        self.processor_thread.start()

        if self.main_loop is None:
            logger.warning(
                f"AudioProcessor for {self.peer_name} created without an event loop; transcriptions cannot be sent"
            )
        logger.info(
            f"AudioProcessor initialized for {self.peer_name} - sample_rate: {self.sample_rate}Hz, frame_size: {self.samples_per_frame}, phrase_timeout: {self.phrase_timeout}s"
        )

    def add_audio_data(self, audio_data: AudioArray) -> None:
        """Buffer one frame of mono float32 audio at ``sample_rate``.

        When at least ``chunk_duration`` seconds are buffered, the buffer is
        handed to :meth:`_queue_for_processing` (which drops silent chunks).
        """
        if not self.is_running:
            logger.debug("AudioProcessor not running, ignoring audio data")
            return
        if len(audio_data) == 0:
            return

        self.audio_buffer.append(audio_data)
        self.last_activity_time = time.time()

        audio_rms = _rms(audio_data)
        audio_peak = _peak(audio_data)
        # Measure the buffer in samples rather than assuming a fixed frame length.
        buffered_samples = sum(len(chunk) for chunk in self.audio_buffer)
        buffer_duration_ms = 1000.0 * buffered_samples / self.sample_rate

        # Only log at info level for meaningful audio or every 50 frames
        if audio_rms > self.silence_rms_threshold or len(self.audio_buffer) % 50 == 0:
            logger.info(
                f"Added audio chunk: {len(audio_data)} samples, buffer size: {len(self.audio_buffer)} frames ({buffer_duration_ms:.0f}ms), RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
            )
        else:
            logger.debug(
                f"Added silent audio chunk: {len(audio_data)} samples, buffer size: {len(self.audio_buffer)} frames"
            )

        if buffered_samples >= self.sample_rate * self.chunk_duration:
            self._queue_for_processing()

    def _queue_for_processing(self) -> None:
        """Drain the buffer and queue it for transcription unless it is silent."""
        if not self.audio_buffer:
            logger.debug("No audio in buffer to queue for processing")
            return

        combined_audio = np.concatenate(list(self.audio_buffer))
        self.audio_buffer.clear()

        audio_rms = _rms(combined_audio)
        audio_peak = _peak(combined_audio)
        if audio_rms <= self.silence_rms_threshold:
            logger.debug(
                f"Buffer threshold reached but audio is silent (RMS: {audio_rms:.4f}), clearing buffer"
            )
            return

        audio_duration_sec = len(combined_audio) / self.sample_rate
        logger.info(
            f"Queuing audio chunk: {len(combined_audio)} samples, {audio_duration_sec:.2f}s duration, RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
        )
        try:
            self.processing_queue.put_nowait(
                AudioQueueItem(audio=combined_audio, timestamp=time.monotonic())
            )
        except Full:
            logger.warning("Audio processing queue full, dropping audio chunk")
            return
        logger.info(
            f"Added to processing queue, queue size: {self.processing_queue.qsize()}"
        )

    def _processing_loop(self) -> None:
        """Background thread: consume queued chunks until :meth:`shutdown`."""
        logger.info("ASR processing loop started")
        while self.is_running:
            try:
                try:
                    item = self.processing_queue.get(timeout=1.0)
                except Empty:
                    # No new voiced audio: close the phrase once the pause is long enough.
                    self._finalize_if_idle()
                    continue
                self._process_chunk(item)
            except Exception as e:
                logger.error(f"Error in audio processing loop: {e}", exc_info=True)
        logger.info("ASR processing loop ended")

    def _process_chunk(self, item: AudioQueueItem) -> None:
        """Append a voiced chunk to the current phrase and transcribe as appropriate."""
        # Silent chunks are never queued, so the gap between consecutive queued
        # chunks is the length of the pause in speech.
        if self._last_chunk_time is None:
            time_since_last = 0.0
        else:
            time_since_last = item.timestamp - self._last_chunk_time
        self._last_chunk_time = item.timestamp
        phrase_complete = time_since_last > self.phrase_timeout
        logger.debug(
            f"Processing audio chunk: {len(item.audio)} samples, time since last: {time_since_last:.2f}s, phrase_complete: {phrase_complete}"
        )

        if phrase_complete and len(self.current_phrase_audio) > 0:
            self._finalize_phrase()

        old_phrase_length = len(self.current_phrase_audio)
        self.current_phrase_audio = np.concatenate([self.current_phrase_audio, item.audio])
        current_phrase_duration = len(self.current_phrase_audio) / self.sample_rate
        logger.debug(
            f"Updated current phrase: {old_phrase_length} -> {len(self.current_phrase_audio)} samples ({current_phrase_duration:.2f}s)"
        )

        if current_phrase_duration >= self.max_phrase_duration:
            logger.info(
                f"Phrase for {self.peer_name} reached {current_phrase_duration:.2f}s, finalizing early"
            )
            self._finalize_phrase()
            return

        if len(self.current_phrase_audio) <= self.sample_rate * self.min_streaming_duration:
            return
        if not self.processing_queue.empty():
            # More audio is waiting; the next partial will include it anyway.
            logger.debug("Audio still queued, skipping streaming transcription to catch up")
            return

        logger.info(
            f"Current phrase >= {self.min_streaming_duration}s (RMS: {_rms(self.current_phrase_audio):.4f}), attempting streaming transcription"
        )
        self._transcribe_and_dispatch(self.current_phrase_audio.copy(), is_final=False)

    def _finalize_if_idle(self) -> None:
        """Finalize the pending phrase if no voiced chunk arrived for ``phrase_timeout``."""
        if len(self.current_phrase_audio) == 0 or self._last_chunk_time is None:
            return
        if time.monotonic() - self._last_chunk_time > self.phrase_timeout:
            self._finalize_phrase()

    def _finalize_phrase(self) -> None:
        """Reset the current phrase and transcribe it as a final result."""
        phrase_audio = self.current_phrase_audio
        self.current_phrase_audio = np.array([], dtype=np.float32)
        phrase_duration = len(phrase_audio) / self.sample_rate
        logger.info(
            f"Processing completed phrase: {phrase_duration:.2f}s duration, {len(phrase_audio)} samples, RMS: {_rms(phrase_audio):.4f}"
        )
        self._transcribe_and_dispatch(phrase_audio, is_final=True)

    def _transcribe_and_dispatch(self, audio_array: AudioArray, is_final: bool) -> None:
        """Transcribe on the calling (worker) thread, then post the text on the event loop."""
        transcription_type = "final" if is_final else "streaming"
        loop = self.main_loop
        if loop is None or loop.is_closed():
            logger.warning(f"No event loop available for {transcription_type} transcription")
            return

        text = self._transcribe(audio_array, is_final)
        if not text:
            return

        coro = self._send_transcription(text, is_final)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError as e:
            coro.close()  # avoid a "coroutine was never awaited" warning
            logger.warning(f"No event loop available for {transcription_type} transcription: {e}")

    def _transcribe(self, audio_array: AudioArray, is_final: bool) -> Optional[str]:
        """Run Whisper on *audio_array* (blocking).

        Returns:
            The stripped transcription, or None if the audio was too short or too
            quiet, the model produced no text, or transcription failed.
        """
        transcription_type = "final" if is_final else "streaming"
        audio_duration_sec = len(audio_array) / self.sample_rate

        if len(audio_array) < self.sample_rate * self.min_transcription_duration:
            logger.debug(
                f"Skipping {transcription_type} transcription: audio too short ({audio_duration_sec:.2f}s < {self.min_transcription_duration}s)"
            )
            return None

        audio_rms = _rms(audio_array)
        audio_peak = _peak(audio_array)
        if audio_rms < self.quiet_rms_threshold:
            logger.debug(
                f"Skipping {transcription_type} transcription: audio too quiet (RMS: {audio_rms:.6f})"
            )
            return None

        logger.info(
            f"Starting {transcription_type} transcription: {audio_duration_sec:.2f}s audio, RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
        )
        start = time.perf_counter()
        try:
            with _model_lock:
                input_features = extract_input_features(
                    audio_array.astype(np.float32, copy=False), self.sample_rate
                )
                features_done = time.perf_counter()
                predicted_ids = _pt_model.generate(input_features)  # type: ignore
                inference_done = time.perf_counter()
                transcription = _processor.batch_decode(  # type: ignore
                    predicted_ids, skip_special_tokens=True
                )
                decoding_done = time.perf_counter()
        except Exception as e:
            logger.error(f"Error in {transcription_type} transcription: {e}", exc_info=True)
            return None

        total_transcription_time = decoding_done - start
        logger.debug(
            f"ASR timing - Feature extraction: {features_done - start:.3f}s, Model inference: {inference_done - features_done:.3f}s, Decoding: {decoding_done - inference_done:.3f}s, Total: {total_transcription_time:.3f}s"
        )

        text = transcription[0].strip() if transcription else ""
        if not text:
            logger.info(
                f"❌ No text from {transcription_type} transcription for {self.peer_name} (empty result from model)"
            )
            return None

        logger.info(
            f"✅ Transcribed ({transcription_type}) for {self.peer_name}: '{text}' (processing time: {total_transcription_time:.3f}s, audio duration: {audio_duration_sec:.2f}s)"
        )
        return text

    async def _send_transcription(self, text: str, is_final: bool) -> None:
        """Post *text* to chat, skipping partials that repeat one of the last 3 messages."""
        transcription_type = "final" if is_final else "streaming"
        prefix = f"🎤 {self.peer_name}: " if is_final else f"🎤 {self.peer_name} [partial]: "
        message = f"{prefix}{text}"

        recent_messages = [h.message for h in self.transcription_history[-3:]]
        if not is_final and message in recent_messages:
            logger.debug(f"Skipping duplicate {transcription_type} transcription: '{text}'")
            return

        try:
            await self.send_chat_func(message)
        except Exception as e:
            logger.error(
                f"Error sending {transcription_type} transcription for {self.peer_name}: {e}",
                exc_info=True,
            )
            return

        # Keep a short history for deduplication
        self.transcription_history.append(
            TranscriptionHistoryItem(message=message, timestamp=time.time(), is_final=is_final)
        )
        if len(self.transcription_history) > 10:
            self.transcription_history.pop(0)

    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool) -> None:
        """Transcribe *audio_array* off the event loop, then send the result as a chat message."""
        text = await asyncio.to_thread(self._transcribe, audio_array, is_final)
        if text:
            await self._send_transcription(text, is_final)

    def shutdown(self) -> None:
        """Stop the background thread, waiting up to 2 seconds for it to exit."""
        logger.info(f"Shutting down AudioProcessor for {self.peer_name}...")
        self.is_running = False
        if self.processor_thread.is_alive():
            logger.debug(f"Waiting for processor thread for {self.peer_name} to finish...")
            self.processor_thread.join(timeout=2.0)
            if self.processor_thread.is_alive():
                logger.warning(
                    f"Processor thread for {self.peer_name} did not shut down cleanly within timeout"
                )
            else:
                logger.info(f"Processor thread for {self.peer_name} shut down successfully")
        logger.info(f"AudioProcessor for {self.peer_name} shutdown complete")


async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
    """Feed every frame of a peer's audio track into that peer's AudioProcessor.

    Non-audio tracks are ignored. Runs until the track ends (``MediaStreamError``
    from ``recv``) or an unexpected error occurs; either way it returns normally.
    """
    if track.kind != "audio":
        logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
        return

    # Create or get audio processor for this peer
    audio_processor = _audio_processors.get(peer.peer_name)
    if audio_processor is None:
        if _send_chat_func is None:
            logger.error(
                f"Cannot create AudioProcessor for {peer.peer_name}: no send_chat_func available"
            )
            return
        logger.info(f"Creating new AudioProcessor for {peer.peer_name}")
        audio_processor = AudioProcessor(
            peer_name=peer.peer_name,
            send_chat_func=_send_chat_func,
            loop=asyncio.get_running_loop(),
        )
        _audio_processors[peer.peer_name] = audio_processor

    logger.info(f"Received audio track from {peer.peer_name}, starting transcription")

    frame_count = 0  # local counter: avoids setting ad-hoc attributes on Peer
    try:
        while True:
            frame = await track.recv()
            if not isinstance(frame, AudioFrame):
                logger.warning(
                    f"Received non-audio frame on audio track from {peer.peer_name}: type={type(frame)}"
                )
                continue

            frame_count += 1
            frame_info = f"{frame.sample_rate}Hz, {frame.format.name}, {frame.layout.name}"
            logger.debug(f"Received audio frame from {peer.peer_name}: {frame_info}")

            audio_data = _audio_frame_to_whisper_input(frame)

            frame_rms = _rms(audio_data)
            frame_peak = _peak(audio_data)
            # Only log full frame details every 20 frames to reduce noise
            if frame_count % 20 == 0:
                logger.info(
                    f"Audio frame #{frame_count} from {peer.peer_name}: {frame_info}, {len(audio_data)} samples, RMS: {frame_rms:.4f}, Peak: {frame_peak:.4f}"
                )
            else:
                logger.debug(
                    f"Audio frame #{frame_count}: RMS: {frame_rms:.4f}, Peak: {frame_peak:.4f}"
                )

            audio_processor.add_audio_data(audio_data)
    except MediaStreamError:
        logger.info(f"Audio track from {peer.peer_name} ended")
    except Exception as e:
        logger.error(
            f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True
        )


def agent_info() -> Dict[str, str]:
    """Return the agent's name, description and media capability flag."""
    return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"}


def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
    """Whisper is not a media source - return no local tracks."""
    return {}


async def handle_chat_message(
    chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]
) -> Optional[str]:
    """Handle incoming chat messages; Whisper never replies, so this returns None."""
    return None


async def on_track_received(peer: Peer, track: MediaStreamTrack) -> None:
    """Callback when a new track is received from a peer."""
    await handle_track_received(peer, track)


# Export functions for the orchestrator to discover
def get_track_handler():
    """Return the track handler function for the orchestrator to use."""
    return on_track_received


def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]) -> None:
    """Bind the send chat function to be used for all (current and future) audio processors."""
    global _send_chat_func
    logger.info("Binding send chat function to whisper agent")
    _send_chat_func = send_chat_func
    for peer_name, processor in _audio_processors.items():
        logger.debug(f"Updating AudioProcessor for {peer_name} with new send chat function")
        processor.send_chat_func = send_chat_func


def cleanup_peer_processor(peer_name: str) -> None:
    """Shut down and forget the audio processor for a disconnected peer, if any."""
    processor = _audio_processors.pop(peer_name, None)
    if processor is None:
        logger.debug(f"No AudioProcessor found for peer {peer_name} during cleanup")
        return
    logger.info(f"Cleaning up AudioProcessor for disconnected peer: {peer_name}")
    processor.shutdown()
    logger.info(f"AudioProcessor for {peer_name} cleaned up successfully")


def get_active_processors() -> Dict[str, "AudioProcessor"]:
    """Get currently active audio processors (for debugging)."""
    return _audio_processors.copy()