"""Streaming Whisper agent (bots/whisper) Real-time speech transcription agent that processes incoming audio streams and sends transcriptions as chat messages to the lobby. """ import asyncio import numpy as np import time import threading from collections import deque from queue import Queue, Empty from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast import numpy.typing as npt from pydantic import BaseModel # Core dependencies import librosa from logger import logger from aiortc import MediaStreamTrack from av import AudioFrame # Import shared models for chat functionality import sys import os from voicebot.models import Peer sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) from shared.models import ChatMessageModel from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq # Type definitions AudioArray = npt.NDArray[np.float32] class AudioQueueItem(BaseModel): """Audio data with timestamp for processing queue.""" audio: AudioArray timestamp: float class Config: arbitrary_types_allowed = True class TranscriptionHistoryItem(BaseModel): """Transcription history item with metadata.""" message: str timestamp: float is_final: bool AGENT_NAME = "whisper" AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text" sample_rate = 16000 # Whisper expects 16kHz model_ids = { "Distil-Whisper": [ "distil-whisper/distil-large-v2", "distil-whisper/distil-medium.en", "distil-whisper/distil-small.en" ], "Whisper": [ "openai/whisper-large-v3", "openai/whisper-large-v2", "openai/whisper-large", "openai/whisper-medium", "openai/whisper-small", "openai/whisper-base", "openai/whisper-tiny", "openai/whisper-medium.en", "openai/whisper-small.en", "openai/whisper-base.en", "openai/whisper-tiny.en", ] } # Global whisper model and transcription handler _model_type = model_ids["Distil-Whisper"] _model_id = _model_type[0] _processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=_model_id) # type: ignore _pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=_model_id) # type: ignore _pt_model.eval() # type: ignore _audio_processor: Optional['AudioProcessor'] = None def extract_input_features(audio_array: Any, sampling_rate: int) -> Any: """Extract input features from audio array and sampling rate.""" processor_output = _processor( # type: ignore audio_array, sampling_rate=sampling_rate, return_tensors="pt", ) input_features: Any = processor_output.input_features # type: ignore return input_features # type: ignore class AudioProcessor: """Handles audio stream processing and transcription with sentence chunking.""" def __init__(self, send_chat_func: Callable[[str], Awaitable[None]]): self.send_chat_func = send_chat_func self.sample_rate = 16000 # Whisper expects 16kHz self.samples_per_frame = 480 # Common WebRTC frame size at 16kHz (30ms) # Audio buffering self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000) # ~30 seconds at 30ms frames self.phrase_timeout = 3.0 # seconds of silence before considering phrase complete self.last_activity_time = time.time() # Transcription state self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32) self.transcription_history: list[TranscriptionHistoryItem] = [] # Background processing self.processing_queue: Queue[AudioQueueItem] = Queue() self.is_running = True self.processor_thread = threading.Thread(target=self._processing_loop, daemon=True) self.processor_thread.start() logger.info("AudioProcessor initialized for real-time transcription") def add_audio_data(self, audio_data: AudioArray): """Add new audio data to the processing buffer.""" if not self.is_running: return # Resample if needed (WebRTC might provide different sample rates) if len(audio_data) > 0: self.audio_buffer.append(audio_data) self.last_activity_time = time.time() # Check if we should process accumulated audio if len(self.audio_buffer) >= 10: # Process every ~300ms (10 * 30ms frames) self._queue_for_processing() def _queue_for_processing(self): """Queue current audio buffer for transcription processing.""" if not self.audio_buffer: return # Combine recent audio frames combined_audio = np.concatenate(list(self.audio_buffer)) self.audio_buffer.clear() # Add to processing queue try: queue_item = AudioQueueItem( audio=combined_audio, timestamp=time.time() ) self.processing_queue.put_nowait(queue_item) except Exception: # Queue full, skip this chunk logger.debug("Audio processing queue full, dropping audio chunk") def _processing_loop(self): """Background thread that processes audio chunks for transcription.""" global _whisper_model while self.is_running: try: # Get audio chunk to process (blocking with timeout) try: audio_data = self.processing_queue.get(timeout=1.0) except Empty: continue audio_array = audio_data.audio chunk_timestamp = audio_data.timestamp # Check if this is a new phrase (gap in audio) time_since_last = chunk_timestamp - self.last_activity_time phrase_complete = time_since_last > self.phrase_timeout if phrase_complete and len(self.current_phrase_audio) > 0: # Process the completed phrase try: loop = asyncio.get_event_loop() asyncio.run_coroutine_threadsafe( self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=True), loop ) except RuntimeError: # No event loop running, skip this transcription pass self.current_phrase_audio = np.array([], dtype=np.float32) # Add new audio to current phrase self.current_phrase_audio = np.concatenate([self.current_phrase_audio, audio_array]) # Also do streaming transcription for immediate feedback if len(self.current_phrase_audio) > self.sample_rate * 2: # At least 2 seconds try: loop = asyncio.get_event_loop() asyncio.run_coroutine_threadsafe( self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=False), loop ) except RuntimeError: # No event loop running, skip this transcription pass except Exception as e: logger.error(f"Error in audio processing loop: {e}", exc_info=True) async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool): """Transcribe audio and send result as chat message.""" global sample_rate try: if len(audio_array) < self.sample_rate * 0.5: # Skip very short audio return # Ensure audio is in the right format for Whisper audio_array = audio_array.astype(np.float32) # Transcribe with Whisper input_features = extract_input_features(audio_array, sample_rate) predicted_ids = _pt_model.generate(input_features) # type: ignore transcription = _processor.batch_decode(predicted_ids, skip_special_tokens=True) # type: ignore text = transcription.strip() if text and len(text) > 1: # Only send meaningful transcriptions prefix = "🎤 " if is_final else "🎤 [partial] " message = f"{prefix}{text}" # Avoid sending duplicate messages if is_final or message not in [h.message for h in self.transcription_history[-3:]]: await self.send_chat_func(message) # Keep history for deduplication history_item = TranscriptionHistoryItem( message=message, timestamp=time.time(), is_final=is_final ) self.transcription_history.append(history_item) # Limit history size if len(self.transcription_history) > 10: self.transcription_history.pop(0) logger.info(f"Transcribed ({'final' if is_final else 'partial'}): {text}") except Exception as e: logger.error(f"Error in transcription: {e}", exc_info=True) def shutdown(self): """Shutdown the audio processor.""" self.is_running = False if self.processor_thread.is_alive(): self.processor_thread.join(timeout=2.0) async def handle_track_received(peer: Peer, track: MediaStreamTrack): """Handle incoming audio tracks from WebRTC peers.""" global _audio_processor if track.kind != "audio": logger.info(f"Ignoring non-audio track: {track.kind}") return logger.info(f"Received audio track from {peer.peer_name}, starting transcription") try: while True: # Receive audio frame frame = await track.recv() if isinstance(frame, AudioFrame): # Convert AudioFrame to numpy array audio_data = frame.to_ndarray() # Handle different audio formats if audio_data.ndim == 2: # Stereo -> mono audio_data = np.mean(audio_data, axis=1) # Convert to float32 and normalize if audio_data.dtype == np.int16: audio_data = audio_data.astype(np.float32) / 32768.0 elif audio_data.dtype == np.int32: audio_data = audio_data.astype(np.float32) / 2147483648.0 # Resample to 16kHz if needed if frame.sample_rate != sample_rate: audio_data = librosa.resample( # type: ignore audio_data, orig_sr=frame.sample_rate, target_sr=sample_rate ) # Ensure audio_data is AudioArray (float32) audio_data_float32 = cast(AudioArray, audio_data.astype(np.float32)) # Send to audio processor if _audio_processor: _audio_processor.add_audio_data(audio_data_float32) except Exception as e: logger.error(f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True) def agent_info() -> Dict[str, str]: return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"} def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]: """Whisper is not a media source - return no local tracks.""" return {} async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]) -> Optional[str]: """Handle incoming chat messages and optionally return a response.""" pass async def on_track_received(peer: Peer, track: MediaStreamTrack): """Callback when a new track is received from a peer.""" await handle_track_received(peer, track) # Export functions for the orchestrator to discover def get_track_handler(): """Return the track handler function for the orchestrator to use.""" return on_track_received def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]): """Bind the send chat function to the audio processor.""" global _send_chat_func, _audio_processor _send_chat_func = send_chat_func if _audio_processor: _audio_processor.send_chat_func = send_chat_func else: _audio_processor = AudioProcessor(send_chat_func=send_chat_func)