"""Streaming Whisper agent (bots/whisper) Real-time speech transcription agent that processes incoming audio streams and sends transcriptions as chat messages to the lobby. """ import asyncio import numpy as np import time import threading from collections import deque from queue import Queue, Empty from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast import numpy.typing as npt from pydantic import BaseModel # Core dependencies import librosa from logger import logger from aiortc import MediaStreamTrack from aiortc.mediastreams import MediaStreamError from av import AudioFrame # Import shared models for chat functionality import sys import os from voicebot.models import Peer sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) from shared.models import ChatMessageModel from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq # Type definitions AudioArray = npt.NDArray[np.float32] class AudioQueueItem(BaseModel): """Audio data with timestamp for processing queue.""" audio: AudioArray timestamp: float class Config: arbitrary_types_allowed = True class TranscriptionHistoryItem(BaseModel): """Transcription history item with metadata.""" message: str timestamp: float is_final: bool AGENT_NAME = "whisper" AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text" sample_rate = 16000 # Whisper expects 16kHz model_ids = { "Distil-Whisper": [ "distil-whisper/distil-large-v2", "distil-whisper/distil-medium.en", "distil-whisper/distil-small.en", ], "Whisper": [ "openai/whisper-large-v3", "openai/whisper-large-v2", "openai/whisper-large", "openai/whisper-medium", "openai/whisper-small", "openai/whisper-base", "openai/whisper-tiny", "openai/whisper-medium.en", "openai/whisper-small.en", "openai/whisper-base.en", "openai/whisper-tiny.en", ], } # Global whisper model and transcription handler _model_type = model_ids["Distil-Whisper"] _model_id = _model_type[0] 
logger.info(f"Loading Whisper model: {_model_id}")
_processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=_model_id)  # type: ignore
logger.info("Whisper processor loaded successfully")
_pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(
    pretrained_model_name_or_path=_model_id
)  # type: ignore
_pt_model.eval()  # type: ignore
logger.info("Whisper model loaded and set to evaluation mode")

_audio_processors: Dict[str, "AudioProcessor"] = {}  # Per-peer audio processors
_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None

# Each peer has its own worker thread, but all of them share _processor and
# _pt_model. This lock runs one transcription at a time, matching the
# serialization the model previously got from running on the single event loop.
_inference_lock = threading.Lock()

# RMS level at or below which audio is treated as silence.
_SILENCE_RMS_THRESHOLD = 0.001


def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
    """Run the Whisper processor on a waveform and return its input features.

    Args:
        audio_array: Waveform samples passed straight to the processor.
        sampling_rate: Sample rate of ``audio_array`` in Hz.

    Returns:
        ``input_features`` from the processor output (``return_tensors="pt"``).
    """
    processor_output = _processor(  # type: ignore
        audio_array,
        sampling_rate=sampling_rate,
        return_tensors="pt",
    )
    input_features: Any = processor_output.input_features  # type: ignore
    return input_features  # type: ignore


class AudioProcessor:
    """Buffers one peer's audio, splits it into phrases and transcribes them.

    ``add_audio_data`` is called from the asyncio event-loop thread: it collects
    frames and queues non-silent batches. A daemon worker thread consumes the
    queue, appends chunks to the current phrase and runs Whisper on it --
    partial results while the phrase grows, a final result once no speech has
    been queued for ``phrase_timeout`` seconds (or the phrase reaches
    ``max_phrase_duration``). Chat messages are scheduled back onto the event
    loop that was running when the processor was created.
    """

    def __init__(
        self, peer_name: str, send_chat_func: Callable[[str], Awaitable[None]]
    ):
        self.peer_name = peer_name
        self.send_chat_func = send_chat_func
        self.sample_rate = 16000  # Whisper expects 16kHz
        self.samples_per_frame = 480  # Nominal 30ms frame at 16kHz (informational only)

        # Audio buffering (event-loop thread)
        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000)
        # Seconds without newly queued speech before a phrase is considered complete.
        self.phrase_timeout = 3.0
        self.last_activity_time = time.time()  # time of the latest add_audio_data call

        # Phrase length limits in seconds.
        self.min_transcription_duration = 1.0  # partial transcription starts after this
        # The Whisper feature extractor only looks at 30s of audio, so longer
        # phrases are finalized instead of being silently truncated.
        self.max_phrase_duration = 30.0

        # Transcription state (touched only by the worker thread)
        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
        self.transcription_history: list[TranscriptionHistoryItem] = []
        self._last_chunk_timestamp = 0.0

        # The worker thread has no event loop of its own (asyncio.get_event_loop()
        # raises there), so capture the loop that is constructing this processor.
        try:
            self._loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = None
            logger.warning(
                f"AudioProcessor for {self.peer_name} created outside a running event loop; transcriptions cannot be sent"
            )

        # Background processing; a ``None`` item is the shutdown sentinel.
        self.processing_queue: Queue[Optional[AudioQueueItem]] = Queue()
        self.is_running = True
        self.processor_thread = threading.Thread(
            target=self._processing_loop, daemon=True
        )
        self.processor_thread.start()

        logger.info(
            f"AudioProcessor initialized for {self.peer_name} - sample_rate: {self.sample_rate}Hz, frame_size: {self.samples_per_frame}, phrase_timeout: {self.phrase_timeout}s"
        )

    def add_audio_data(self, audio_data: AudioArray):
        """Buffer one mono float32 16 kHz frame and queue speech for transcription.

        Once at least 10 frames are buffered, the buffer is queued for the
        worker if its combined RMS exceeds the silence threshold, otherwise it
        is discarded. Ignored after ``shutdown``.
        """
        if not self.is_running:
            logger.debug("AudioProcessor not running, ignoring audio data")
            return
        if len(audio_data) == 0:
            return

        self.audio_buffer.append(audio_data)
        self.last_activity_time = time.time()

        audio_rms = np.sqrt(np.mean(audio_data**2))
        audio_peak = np.max(np.abs(audio_data))
        buffered_samples = sum(len(chunk) for chunk in self.audio_buffer)
        buffer_duration_ms = 1000.0 * buffered_samples / self.sample_rate
        # Per-frame logging stays at debug level: frames arrive every 10-30ms.
        logger.debug(
            f"Added audio chunk: {len(audio_data)} samples, buffer size: {len(self.audio_buffer)} frames ({buffer_duration_ms:.0f}ms), RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
        )

        if len(self.audio_buffer) < 10:
            return

        combined_audio = np.concatenate(list(self.audio_buffer))
        combined_rms = np.sqrt(np.mean(combined_audio**2))
        if combined_rms > _SILENCE_RMS_THRESHOLD:
            logger.debug(
                f"Buffer threshold reached with meaningful audio (RMS: {combined_rms:.4f}), queuing for processing"
            )
            self._queue_for_processing()
        else:
            logger.debug(
                f"Buffer threshold reached but audio is silent (RMS: {combined_rms:.4f}), clearing buffer"
            )
            self.audio_buffer.clear()

    def _queue_for_processing(self):
        """Move the buffered frames into the worker queue as one timestamped chunk."""
        if not self.audio_buffer:
            logger.debug("No audio in buffer to queue for processing")
            return

        combined_audio = np.concatenate(list(self.audio_buffer))
        self.audio_buffer.clear()

        audio_duration_sec = len(combined_audio) / self.sample_rate
        audio_rms = np.sqrt(np.mean(combined_audio**2))
        audio_peak = np.max(np.abs(combined_audio))

        if audio_rms < _SILENCE_RMS_THRESHOLD and audio_peak < _SILENCE_RMS_THRESHOLD:
            logger.debug(
                f"Skipping silent audio chunk: RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
            )
            return

        logger.debug(
            f"Queuing audio chunk: {len(combined_audio)} samples, {audio_duration_sec:.2f}s duration, RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
        )
        try:
            queue_item = AudioQueueItem(audio=combined_audio, timestamp=time.time())
            # The queue is unbounded; this only fails if building the item fails.
            self.processing_queue.put_nowait(queue_item)
        except Exception as e:
            logger.warning(f"Failed to queue audio chunk for {self.peer_name}, dropping it: {e}")
            return
        logger.debug(
            f"Added to processing queue, queue size: {self.processing_queue.qsize()}"
        )

    def _processing_loop(self):
        """Worker thread: build phrases from queued chunks and transcribe them."""
        logger.info("ASR processing loop started")
        while self.is_running:
            try:
                try:
                    first_item = self.processing_queue.get(timeout=0.5)
                except Empty:
                    # No new speech: close the phrase if it has been idle long enough.
                    self._finalize_idle_phrase()
                    continue

                # Drain any backlog so a single transcription covers everything
                # that arrived while the previous one was running.
                batch = [first_item]
                while True:
                    try:
                        batch.append(self.processing_queue.get_nowait())
                    except Empty:
                        break
                logger.debug(f"Retrieved {len(batch)} audio chunk(s) from queue")

                stop_requested = False
                for item in batch:
                    if item is None:  # shutdown sentinel
                        stop_requested = True
                        break
                    self._append_chunk(item)
                if stop_requested:
                    break

                if (
                    len(self.current_phrase_audio)
                    > self.sample_rate * self.min_transcription_duration
                ):
                    phrase_rms = np.sqrt(np.mean(self.current_phrase_audio**2))
                    logger.info(
                        f"Current phrase >= {self.min_transcription_duration}s (RMS: {phrase_rms:.4f}), attempting streaming transcription"
                    )
                    self._transcribe_and_send_blocking(
                        self.current_phrase_audio.copy(), is_final=False
                    )
            except Exception as e:
                logger.error(f"Error in audio processing loop: {e}", exc_info=True)

        # Send whatever was still being spoken when the processor stopped.
        self._finalize_phrase()
        logger.info("ASR processing loop ended")

    def _append_chunk(self, item: AudioQueueItem):
        """Add one queued chunk to the current phrase, closing the phrase first on a long gap."""
        gap = item.timestamp - self._last_chunk_timestamp
        if len(self.current_phrase_audio) > 0 and gap > self.phrase_timeout:
            logger.debug(
                f"{gap:.2f}s gap before new audio from {self.peer_name}; completing previous phrase"
            )
            self._finalize_phrase()

        old_phrase_length = len(self.current_phrase_audio)
        self.current_phrase_audio = np.concatenate(
            [self.current_phrase_audio, item.audio]
        )
        self._last_chunk_timestamp = item.timestamp
        logger.debug(
            f"Updated current phrase: {old_phrase_length} -> {len(self.current_phrase_audio)} samples ({len(self.current_phrase_audio) / self.sample_rate:.2f}s)"
        )

        if len(self.current_phrase_audio) >= self.sample_rate * self.max_phrase_duration:
            logger.info(
                f"Phrase from {self.peer_name} reached {self.max_phrase_duration}s, finalizing"
            )
            self._finalize_phrase()

    def _finalize_idle_phrase(self):
        """Finalize the current phrase if no chunk has been queued for ``phrase_timeout`` seconds."""
        if (
            len(self.current_phrase_audio) > 0
            and time.time() - self._last_chunk_timestamp > self.phrase_timeout
        ):
            self._finalize_phrase()

    def _finalize_phrase(self):
        """Transcribe the current phrase as final and start a new, empty phrase."""
        if len(self.current_phrase_audio) == 0:
            return
        phrase_audio = self.current_phrase_audio
        self.current_phrase_audio = np.array([], dtype=np.float32)

        phrase_duration = len(phrase_audio) / self.sample_rate
        phrase_rms = np.sqrt(np.mean(phrase_audio**2))
        logger.info(
            f"Processing completed phrase: {phrase_duration:.2f}s duration, {len(phrase_audio)} samples, RMS: {phrase_rms:.4f}"
        )
        self._transcribe_and_send_blocking(phrase_audio, is_final=True)

    def _transcribe_and_send_blocking(self, audio_array: AudioArray, is_final: bool):
        """Transcribe ``audio_array`` with Whisper and post the text to chat.

        Runs synchronously; called from the worker thread. Audio shorter than
        0.3s or with RMS below 0.0005 is skipped. A partial (non-final) message
        identical to one of the last three sent is not re-sent. Errors are
        logged, never raised.
        """
        transcription_start_time = time.time()
        transcription_type = "final" if is_final else "streaming"
        try:
            audio_duration_sec = len(audio_array) / self.sample_rate

            min_duration = 0.3
            if len(audio_array) < self.sample_rate * min_duration:
                logger.debug(
                    f"Skipping {transcription_type} transcription: audio too short ({audio_duration_sec:.2f}s < {min_duration}s)"
                )
                return

            audio_rms = np.sqrt(np.mean(audio_array**2))
            audio_peak = np.max(np.abs(audio_array))
            if audio_rms < 0.0005:
                logger.debug(
                    f"Skipping {transcription_type} transcription: audio too quiet (RMS: {audio_rms:.6f})"
                )
                return

            logger.info(
                f"Starting {transcription_type} transcription: {audio_duration_sec:.2f}s audio, RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
            )

            audio_array = audio_array.astype(np.float32)

            with _inference_lock:
                feature_extraction_start = time.time()
                input_features = extract_input_features(audio_array, self.sample_rate)
                feature_extraction_time = time.time() - feature_extraction_start

                model_inference_start = time.time()
                predicted_ids = _pt_model.generate(input_features)  # type: ignore
                model_inference_time = time.time() - model_inference_start

                decoding_start = time.time()
                transcription = _processor.batch_decode(  # type: ignore
                    predicted_ids, skip_special_tokens=True
                )
                decoding_time = time.time() - decoding_start

            total_transcription_time = time.time() - transcription_start_time
            logger.debug(
                f"ASR timing - Feature extraction: {feature_extraction_time:.3f}s, Model inference: {model_inference_time:.3f}s, Decoding: {decoding_time:.3f}s, Total: {total_transcription_time:.3f}s"
            )

            text = transcription[0].strip() if transcription else ""
            if not text:
                logger.info(
                    f"❌ No text from {transcription_type} transcription for {self.peer_name} (empty result from model)"
                )
                return

            prefix = (
                f"🎤 {self.peer_name}: "
                if is_final
                else f"🎤 {self.peer_name} [partial]: "
            )
            message = f"{prefix}{text}"

            recent_messages = [h.message for h in self.transcription_history[-3:]]
            if not is_final and message in recent_messages:
                logger.debug(
                    f"Skipping duplicate {transcription_type} transcription: '{text}'"
                )
                return

            self._post_chat(message)

            # Keep a short history for deduplication.
            self.transcription_history.append(
                TranscriptionHistoryItem(
                    message=message, timestamp=time.time(), is_final=is_final
                )
            )
            if len(self.transcription_history) > 10:
                self.transcription_history.pop(0)

            logger.info(
                f"✅ Transcribed ({transcription_type}) for {self.peer_name}: '{text}' (processing time: {total_transcription_time:.3f}s, audio duration: {audio_duration_sec:.2f}s)"
            )
        except Exception as e:
            logger.error(
                f"Error in {transcription_type} transcription: {e}", exc_info=True
            )

    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool):
        """Async entry point kept for compatibility; runs the blocking transcription off-loop."""
        await asyncio.to_thread(
            self._transcribe_and_send_blocking, audio_array, is_final
        )

    async def _send_chat(self, message: str) -> None:
        """Await the current ``send_chat_func`` on the event loop."""
        await self.send_chat_func(message)

    def _post_chat(self, message: str) -> None:
        """Schedule ``message`` to be sent on the captured event loop; callable from any thread."""
        loop = self._loop
        if loop is None or loop.is_closed():
            logger.warning(
                f"No event loop available to send transcription for {self.peer_name}; dropping message"
            )
            return
        coro = self._send_chat(message)
        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError as e:
            coro.close()  # never scheduled; avoid a "never awaited" warning
            logger.warning(
                f"Could not schedule transcription for {self.peer_name}: {e}"
            )
            return
        future.add_done_callback(self._report_send_result)

    def _report_send_result(self, future: Any) -> None:
        """Log a failed chat send; nothing else observes the scheduled future."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logger.error(f"Failed to send transcription for {self.peer_name}: {error}")

    def shutdown(self):
        """Stop the worker thread, letting it transcribe any unfinished phrase.

        Waits up to 2s for the thread to exit. The thread is a daemon, so it
        does not keep the interpreter alive if it is still busy afterwards.
        """
        logger.info(f"Shutting down AudioProcessor for {self.peer_name}...")
        self.is_running = False
        # Wake the worker now instead of after its queue poll times out.
        self.processing_queue.put_nowait(None)

        if self.processor_thread.is_alive():
            logger.debug(
                f"Waiting for processor thread for {self.peer_name} to finish..."
            )
            self.processor_thread.join(timeout=2.0)
            if self.processor_thread.is_alive():
                logger.warning(
                    f"Processor thread for {self.peer_name} did not shut down cleanly within timeout"
                )
            else:
                logger.info(
                    f"Processor thread for {self.peer_name} shut down successfully"
                )
        logger.info(f"AudioProcessor for {self.peer_name} shutdown complete")


def _frame_to_mono_float32(frame: AudioFrame) -> AudioArray:
    """Convert a PyAV audio frame to 1-D float32 samples at the frame's own rate.

    Integer PCM is scaled to [-1, 1). Multi-channel audio is averaged to mono,
    honouring the sample layout: planar formats hold one row per channel,
    packed formats interleave channels within a single row.
    """
    data = frame.to_ndarray()
    channels = max(1, len(frame.layout.channels))

    if data.dtype.kind == "i":
        data = data.astype(np.float32) / float(-np.iinfo(data.dtype).min)
    elif data.dtype.kind == "u":
        half_range = (float(np.iinfo(data.dtype).max) + 1.0) / 2.0
        data = (data.astype(np.float32) - half_range) / half_range
    else:
        data = data.astype(np.float32)

    if channels > 1:
        if frame.format.is_planar:
            data = data.reshape(channels, -1).mean(axis=0)
        else:
            data = data.reshape(-1, channels).mean(axis=1)
    return cast(AudioArray, data.reshape(-1).astype(np.float32, copy=False))


def _resample_to_model_rate(audio: AudioArray, orig_sr: int) -> AudioArray:
    """Resample ``audio`` from ``orig_sr`` Hz to the module-level ``sample_rate``.

    NOTE(review): frames are resampled independently, which may introduce small
    discontinuities at frame boundaries; a stateful resampler would avoid that.
    """
    if orig_sr == sample_rate:
        return audio
    resampled = librosa.resample(  # type: ignore
        audio.astype(np.float64), orig_sr=orig_sr, target_sr=sample_rate
    )
    return cast(AudioArray, np.asarray(resampled, dtype=np.float32))


async def handle_track_received(peer: Peer, track: MediaStreamTrack):
    """Transcribe an incoming WebRTC audio track until it ends.

    Non-audio tracks are ignored. Each frame is converted to mono float32 at
    16 kHz and fed to the peer's AudioProcessor (created on first use). The
    peer's processor is cleaned up when the track ends or fails.
    """
    if track.kind != "audio":
        logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
        return

    if peer.peer_name not in _audio_processors:
        if _send_chat_func is None:
            logger.error(
                f"Cannot create AudioProcessor for {peer.peer_name}: no send_chat_func available"
            )
            return
        logger.info(f"Creating new AudioProcessor for {peer.peer_name}")
        _audio_processors[peer.peer_name] = AudioProcessor(
            peer_name=peer.peer_name, send_chat_func=_send_chat_func
        )

    audio_processor = _audio_processors[peer.peer_name]
    logger.info(f"Received audio track from {peer.peer_name}, starting transcription")

    frame_count = 0
    had_audio = False  # was the previous frame above the silence threshold?
    last_audio_frame = 0

    try:
        while True:
            try:
                frame = await track.recv()
            except MediaStreamError as e:
                # Connection closed or media stream ended - this is normal.
                logger.info(
                    f"Audio stream ended for {peer.peer_name} (MediaStreamError: {e})"
                )
                break
            except Exception as e:
                logger.error(
                    f"Error receiving audio frame from {peer.peer_name}: {e}",
                    exc_info=True,
                )
                break

            frame_count += 1
            if frame_count % 100 == 0:
                logger.info(f"Received {frame_count} frames from {peer.peer_name}")

            if not isinstance(frame, AudioFrame):
                logger.warning(
                    f"Received non-audio frame on audio track from {peer.peer_name}: type={type(frame)}"
                )
                continue

            try:
                audio_data = _frame_to_mono_float32(frame)
            except Exception as e:
                logger.error(f"Error converting frame to ndarray for {peer.peer_name}: {e}")
                continue

            logger.debug(
                f"Audio frame data: format={frame.format.name}, layout={frame.layout.name}, samples={frame.samples}, mono samples={len(audio_data)}"
            )

            if frame.sample_rate != sample_rate:
                original_length = len(audio_data)
                try:
                    audio_data = _resample_to_model_rate(audio_data, frame.sample_rate)
                except Exception as e:
                    # Feeding audio at the wrong rate would corrupt the phrase; drop it.
                    logger.error(
                        f"Resampling failed for {peer.peer_name}: {str(e)}; dropping frame"
                    )
                    continue
                logger.debug(
                    f"Resampled audio: {frame.sample_rate}Hz -> {sample_rate}Hz, {original_length} -> {len(audio_data)} samples"
                )

            if len(audio_data) == 0:
                continue

            frame_rms = np.sqrt(np.mean(audio_data**2))
            frame_peak = np.max(np.abs(audio_data))
            has_audio = frame_rms > _SILENCE_RMS_THRESHOLD

            if has_audio and not had_audio:
                frame_info = f"{frame.sample_rate}Hz, {frame.format.name}, {frame.layout.name}"
                logger.info(
                    f"🎤 AUDIO DETECTED from {peer.peer_name}! Frame #{frame_count}: {frame_info}, RMS: {frame_rms:.4f}, Peak: {frame_peak:.4f}"
                )
                last_audio_frame = frame_count
            elif not has_audio and had_audio:
                logger.info(
                    f"🔇 Audio stopped from {peer.peer_name} at frame #{frame_count} (last audio was frame #{last_audio_frame})"
                )
            elif has_audio:
                last_audio_frame = frame_count
                if frame_count % 100 == 0:
                    logger.info(
                        f"🎤 Audio continuing from {peer.peer_name}: Frame #{frame_count}, RMS: {frame_rms:.4f}"
                    )
            if not has_audio and frame_count % 200 == 0:
                logger.debug(
                    f"Connection active from {peer.peer_name}: Frame #{frame_count} (silent, RMS: {frame_rms:.6f})"
                )
            had_audio = has_audio

            audio_processor.add_audio_data(audio_data)
    except Exception as e:
        logger.error(
            f"Unexpected error processing audio track from {peer.peer_name}: {e}",
            exc_info=True,
        )
    finally:
        # Clean up the audio processor when the stream ends.
        cleanup_peer_processor(peer.peer_name)


def agent_info() -> Dict[str, str]:
    """Describe this agent for the orchestrator."""
    return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"}


def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
    """Whisper is not a media source - return no local tracks."""
    return {}


async def handle_chat_message(
    chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]
) -> Optional[str]:
    """Handle an incoming chat message; this agent never replies, so it returns None."""
    return None


async def on_track_received(peer: Peer, track: MediaStreamTrack):
    """Callback when a new track is received from a peer."""
    await handle_track_received(peer, track)


# Export functions for the orchestrator to discover
def get_track_handler():
    """Return the track handler function for the orchestrator to use."""
    return on_track_received


def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]):
    """Bind the send chat function used by new and existing audio processors."""
    global _send_chat_func
    logger.info("Binding send chat function to whisper agent")
    _send_chat_func = send_chat_func

    for peer_name, processor in _audio_processors.items():
        logger.debug(
            f"Updating AudioProcessor for {peer_name} with new send chat function"
        )
        processor.send_chat_func = send_chat_func


def cleanup_peer_processor(peer_name: str):
    """Shut down and forget the audio processor for a disconnected peer, if any."""
    processor = _audio_processors.pop(peer_name, None)
    if processor is None:
        logger.debug(f"No AudioProcessor found for peer {peer_name} during cleanup")
        return
    logger.info(f"Cleaning up AudioProcessor for disconnected peer: {peer_name}")
    processor.shutdown()
    logger.info(f"AudioProcessor for {peer_name} cleaned up successfully")


def get_active_processors() -> Dict[str, "AudioProcessor"]:
    """Get currently active audio processors (for debugging)."""
    return _audio_processors.copy()