
"""Streaming Whisper agent (bots/whisper)
Real-time speech transcription agent that processes incoming audio streams
and sends transcriptions as chat messages to the lobby.
"""
import asyncio
import numpy as np
import time
import threading
from collections import deque
from queue import Queue, Empty
from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
import numpy.typing as npt
from pydantic import BaseModel
# Core dependencies
import librosa
from logger import logger
from aiortc import MediaStreamTrack
from aiortc.mediastreams import MediaStreamError
from av import AudioFrame
# Import shared models for chat functionality
import sys
import os
from voicebot.models import Peer
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from shared.models import ChatMessageModel
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Type definitions
AudioArray = npt.NDArray[np.float32]
class AudioQueueItem(BaseModel):
"""Audio data with timestamp for processing queue."""
audio: AudioArray
timestamp: float
class Config:
arbitrary_types_allowed = True
class TranscriptionHistoryItem(BaseModel):
"""Transcription history item with metadata."""
message: str
timestamp: float
is_final: bool
AGENT_NAME = "whisper"
AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text"
sample_rate = 16000 # Whisper expects 16kHz
model_ids = {
"Distil-Whisper": [
"distil-whisper/distil-large-v2",
"distil-whisper/distil-medium.en",
"distil-whisper/distil-small.en",
],
"Whisper": [
"openai/whisper-large-v3",
"openai/whisper-large-v2",
"openai/whisper-large",
"openai/whisper-medium",
"openai/whisper-small",
"openai/whisper-base",
"openai/whisper-tiny",
"openai/whisper-medium.en",
"openai/whisper-small.en",
"openai/whisper-base.en",
"openai/whisper-tiny.en",
],
}
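# Note: the defaults below pick the first Distil-Whisper checkpoint. As a sketch,
# switching to a multilingual OpenAI checkpoint would only mean pointing _model_id
# at a different entry of this dict, e.g.:
#   _model_id = model_ids["Whisper"][4]  # "openai/whisper-small" (assumes it fits in memory)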
# Global whisper model and transcription handler
_model_type = model_ids["Distil-Whisper"]
_model_id = _model_type[0]
logger.info(f"Loading Whisper model: {_model_id}")
_processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=_model_id) # type: ignore
logger.info("Whisper processor loaded successfully")
_pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(
pretrained_model_name_or_path=_model_id
) # type: ignore
_pt_model.eval() # type: ignore
logger.info("Whisper model loaded and set to evaluation mode")
_audio_processors: Dict[str, "AudioProcessor"] = {} # Per-peer audio processors
_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None
def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
"""Extract input features from audio array and sampling rate."""
processor_output = _processor( # type: ignore
audio_array,
sampling_rate=sampling_rate,
return_tensors="pt",
)
input_features: Any = processor_output.input_features # type: ignore
return input_features # type: ignore
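# Usage sketch (mirrors what _transcribe_and_send does below): given ~1 s of 16 kHz
# float32 audio in `audio_chunk` (a placeholder name, not defined in this module):
#   feats = extract_input_features(audio_chunk, sample_rate)
#   ids = _pt_model.generate(feats)
#   text = _processor.batch_decode(ids, skip_special_tokens=True)[0]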
class AudioProcessor:
"""Handles audio stream processing and transcription with sentence chunking for a specific peer."""
main_loop: Optional[asyncio.AbstractEventLoop]
def __init__(
self, peer_name: str, send_chat_func: Callable[[str], Awaitable[None]]
):
self.peer_name = peer_name
self.send_chat_func = send_chat_func
self.sample_rate = 16000 # Whisper expects 16kHz
self.samples_per_frame = 480 # Common WebRTC frame size at 16kHz (30ms)
# Audio buffering
self.audio_buffer: Deque[AudioArray] = deque(
maxlen=1000
) # ~30 seconds at 30ms frames
self.phrase_timeout = (
3.0 # seconds of silence before considering phrase complete
)
self.last_activity_time = time.time()
# Transcription state
self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
self.transcription_history: list[TranscriptionHistoryItem] = []
# Capture the main thread's event loop for background processing
try:
self.main_loop = asyncio.get_running_loop()
logger.debug(f"Captured main event loop for {self.peer_name}")
except RuntimeError:
# No event loop running, we'll need to create one
self.main_loop = None
logger.warning(f"No event loop running when initializing AudioProcessor for {self.peer_name}")
# Background processing
self.processing_queue: Queue[AudioQueueItem] = Queue()
self.is_running = True
self.processor_thread = threading.Thread(
target=self._processing_loop, daemon=True
)
self.processor_thread.start()
logger.info(
f"AudioProcessor initialized for {self.peer_name} - sample_rate: {self.sample_rate}Hz, frame_size: {self.samples_per_frame}, phrase_timeout: {self.phrase_timeout}s"
)
def add_audio_data(self, audio_data: AudioArray):
"""Add new audio data to the processing buffer."""
if not self.is_running:
logger.debug("AudioProcessor not running, ignoring audio data")
return
# Only buffer non-empty frames (resampling to 16 kHz is handled upstream in handle_track_received)
if len(audio_data) > 0:
audio_received_time = time.time()
self.audio_buffer.append(audio_data)
self.last_activity_time = audio_received_time
# Calculate audio metrics to detect silence
audio_rms = np.sqrt(np.mean(audio_data**2))
audio_peak = np.max(np.abs(audio_data))
# Log audio buffer status (reduced verbosity)
buffer_duration_ms = len(self.audio_buffer) * 30 # assuming 30ms frames
# Only log if we have meaningful audio or every 50 frames
if audio_rms > 0.001 or len(self.audio_buffer) % 50 == 0:
logger.info(
f"📥 AUDIO BUFFER ADD at {audio_received_time:.3f}: {len(audio_data)} samples, buffer size: {len(self.audio_buffer)} frames ({buffer_duration_ms}ms), RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f} (peer: {self.peer_name})"
)
else:
logger.debug(
f"Added silent audio chunk: {len(audio_data)} samples, buffer size: {len(self.audio_buffer)} frames"
)
# Check if we should process accumulated audio
if len(self.audio_buffer) >= 10: # Process every ~300ms (10 * 30ms frames)
# Check if we have any meaningful audio in the buffer
combined_audio = np.concatenate(list(self.audio_buffer))
combined_rms = np.sqrt(np.mean(combined_audio**2))
if combined_rms > 0.001: # Only process if not silence
buffer_queue_time = time.time()
logger.info(
f"🚀 BUFFER QUEUING at {buffer_queue_time:.3f}: Buffer threshold reached with meaningful audio (RMS: {combined_rms:.4f}), queuing for processing (peer: {self.peer_name})"
)
self._queue_for_processing()
else:
logger.debug(
f"Buffer threshold reached but audio is silent (RMS: {combined_rms:.4f}), clearing buffer"
)
self.audio_buffer.clear() # Clear silent audio
def _queue_for_processing(self):
"""Queue current audio buffer for transcription processing."""
if not self.audio_buffer:
logger.debug("No audio in buffer to queue for processing")
return
# Combine recent audio frames
combined_audio = np.concatenate(list(self.audio_buffer))
self.audio_buffer.clear()
# Calculate audio metrics
audio_duration_sec = len(combined_audio) / self.sample_rate
audio_rms = np.sqrt(np.mean(combined_audio**2))
audio_peak = np.max(np.abs(combined_audio))
# Skip completely silent audio
if audio_rms < 0.001 and audio_peak < 0.001:
logger.debug(
f"Skipping silent audio chunk: RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f}"
)
return
logger.info(
f"📦 AUDIO CHUNK QUEUED at {time.time():.3f}: {len(combined_audio)} samples, {audio_duration_sec:.2f}s duration, RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f} (peer: {self.peer_name})"
)
# Add to processing queue
try:
queue_item = AudioQueueItem(audio=combined_audio, timestamp=time.time())
self.processing_queue.put_nowait(queue_item)
logger.info(
f"📋 PROCESSING QUEUE ADD at {time.time():.3f}: Added to processing queue, queue size: {self.processing_queue.qsize()} (peer: {self.peer_name})"
)
except Exception as e:
# Queue full, skip this chunk
logger.warning(f"Audio processing queue full, dropping audio chunk: {e}")
def _processing_loop(self):
"""Background thread that processes audio chunks for transcription."""
logger.info("ASR processing loop started")
while self.is_running:
try:
# Get audio chunk to process (blocking with timeout)
try:
audio_data = self.processing_queue.get(timeout=1.0)
processing_start_time = time.time()
logger.info(
f"🔄 PROCESSING STARTED at {processing_start_time:.3f}: Retrieved audio chunk from queue, remaining queue size: {self.processing_queue.qsize()} (peer: {self.peer_name})"
)
except Empty:
logger.debug("Processing queue timeout, checking for more audio...")
continue
audio_array = audio_data.audio
chunk_timestamp = audio_data.timestamp
# Check if this is a new phrase (gap in audio)
time_since_last = chunk_timestamp - self.last_activity_time
phrase_complete = time_since_last > self.phrase_timeout
logger.debug(
f"Processing audio chunk: {len(audio_array)} samples, time since last: {time_since_last:.2f}s, phrase_complete: {phrase_complete}"
)
if phrase_complete and len(self.current_phrase_audio) > 0:
# Process the completed phrase
phrase_duration = len(self.current_phrase_audio) / self.sample_rate
phrase_rms = np.sqrt(np.mean(self.current_phrase_audio**2))
logger.info(
f"Processing completed phrase: {phrase_duration:.2f}s duration, {len(self.current_phrase_audio)} samples, RMS: {phrase_rms:.4f}"
)
if self.main_loop and not self.main_loop.is_closed():
asyncio.run_coroutine_threadsafe(
self._transcribe_and_send(
self.current_phrase_audio.copy(), is_final=True
),
self.main_loop,
)
else:
logger.warning(
f"No event loop available for final transcription (peer: {self.peer_name})"
)
self.current_phrase_audio = np.array([], dtype=np.float32)
# Add new audio to current phrase
old_phrase_length = len(self.current_phrase_audio)
self.current_phrase_audio = np.concatenate(
[self.current_phrase_audio, audio_array]
)
current_phrase_duration = (
len(self.current_phrase_audio) / self.sample_rate
)
logger.debug(
f"Updated current phrase: {old_phrase_length} -> {len(self.current_phrase_audio)} samples ({current_phrase_duration:.2f}s)"
)
# Lower the threshold for streaming transcription to catch shorter phrases
min_transcription_duration = 1.0 # Reduced from 2.0 seconds
if (
len(self.current_phrase_audio)
> self.sample_rate * min_transcription_duration
): # At least 1 second
phrase_rms = np.sqrt(np.mean(self.current_phrase_audio**2))
logger.info(
f"Current phrase >= {min_transcription_duration}s (RMS: {phrase_rms:.4f}), attempting streaming transcription"
)
if self.main_loop and not self.main_loop.is_closed():
asyncio.run_coroutine_threadsafe(
self._transcribe_and_send(
self.current_phrase_audio.copy(), is_final=False
),
self.main_loop,
)
else:
logger.warning(
f"No event loop available for streaming transcription (peer: {self.peer_name})"
)
except Exception as e:
logger.error(f"Error in audio processing loop: {e}", exc_info=True)
logger.info("ASR processing loop ended")
async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool):
"""Transcribe audio and send result as chat message."""
global sample_rate
transcription_start_time = time.time()
transcription_type = "final" if is_final else "streaming"
try:
audio_duration_sec = len(audio_array) / self.sample_rate
# Reduce minimum audio duration threshold
min_duration = 0.3 # Reduced from 0.5 seconds
if len(audio_array) < self.sample_rate * min_duration:
logger.debug(
f"Skipping {transcription_type} transcription: audio too short ({audio_duration_sec:.2f}s < {min_duration}s)"
)
return
# Calculate audio quality metrics
audio_rms = np.sqrt(np.mean(audio_array**2))
audio_peak = np.max(np.abs(audio_array))
# More lenient silence detection
if audio_rms < 0.0005: # Very quiet threshold
logger.debug(
f"Skipping {transcription_type} transcription: audio too quiet (RMS: {audio_rms:.6f})"
)
return
logger.info(
f"🎬 TRANSCRIPTION STARTED ({transcription_type}) at {time.time():.3f}: {audio_duration_sec:.2f}s audio, RMS: {audio_rms:.4f}, Peak: {audio_peak:.4f} (peer: {self.peer_name})"
)
# Ensure audio is in the right format for Whisper
audio_array = audio_array.astype(np.float32)
# Transcribe with Whisper
feature_extraction_start = time.time()
input_features = extract_input_features(audio_array, sample_rate)
feature_extraction_time = time.time() - feature_extraction_start
model_inference_start = time.time()
predicted_ids = _pt_model.generate(input_features) # type: ignore
model_inference_time = time.time() - model_inference_start
decoding_start = time.time()
transcription = _processor.batch_decode(
predicted_ids, skip_special_tokens=True
) # type: ignore
decoding_time = time.time() - decoding_start
total_transcription_time = time.time() - transcription_start_time
logger.debug(
f"ASR timing - Feature extraction: {feature_extraction_time:.3f}s, Model inference: {model_inference_time:.3f}s, Decoding: {decoding_time:.3f}s, Total: {total_transcription_time:.3f}s"
)
text = (
transcription[0].strip()
if transcription and len(transcription) > 0
else ""
)
if text and len(text) > 0: # Accept any non-empty text
# Calculate timing information for the message
chat_send_start = time.time()
total_transcription_time = chat_send_start - transcription_start_time
# Create message with timing information included
status_marker = "🎤"  # same marker is used for final and partial results
type_marker = "" if is_final else " [partial]"
timing_info = f" (⏱️ {total_transcription_time:.2f}s from start: {transcription_start_time:.3f})"
prefix = f"{status_marker} {self.peer_name}{type_marker}: "
message = f"{prefix}{text}{timing_info}"
# Avoid sending duplicate messages (check text only, not timing)
text_only_message = f"{prefix}{text}"
if is_final or text_only_message not in [
h.message.split(' (⏱️')[0] for h in self.transcription_history[-3:]
]:
await self.send_chat_func(message)
chat_send_time = time.time() - chat_send_start
message_sent_time = time.time()
logger.info(
f"💬 CHAT MESSAGE SENT at {message_sent_time:.3f}: '{text}' (transcription started: {transcription_start_time:.3f}, chat send took: {chat_send_time:.3f}s, peer: {self.peer_name})"
)
# Keep history for deduplication
history_item = TranscriptionHistoryItem(
message=message, timestamp=time.time(), is_final=is_final
)
self.transcription_history.append(history_item)
# Limit history size
if len(self.transcription_history) > 10:
self.transcription_history.pop(0)
logger.info(
f"✅ Transcribed ({transcription_type}) for {self.peer_name}: '{text}' (processing time: {total_transcription_time:.3f}s, audio duration: {audio_duration_sec:.2f}s)"
)
# Log end-to-end pipeline timing
total_pipeline_time = message_sent_time - transcription_start_time
logger.info(
f"⏱️ PIPELINE TIMING ({transcription_type}): Total={total_pipeline_time:.3f}s (Transcription={total_transcription_time:.3f}s, Chat Send={chat_send_time:.3f}s, peer: {self.peer_name}) | 🕐 Start: {transcription_start_time:.3f}, End: {message_sent_time:.3f}"
)
else:
logger.debug(
f"Skipping duplicate {transcription_type} transcription: '{text}'"
)
else:
logger.info(
f"❌ No text from {transcription_type} transcription for {self.peer_name} (empty result from model)"
)
except Exception as e:
logger.error(
f"Error in {transcription_type} transcription: {e}", exc_info=True
)
def shutdown(self):
"""Shutdown the audio processor."""
logger.info(f"Shutting down AudioProcessor for {self.peer_name}...")
self.is_running = False
if self.processor_thread.is_alive():
logger.debug(
f"Waiting for processor thread for {self.peer_name} to finish..."
)
self.processor_thread.join(timeout=2.0)
if self.processor_thread.is_alive():
logger.warning(
f"Processor thread for {self.peer_name} did not shut down cleanly within timeout"
)
else:
logger.info(
f"Processor thread for {self.peer_name} shut down successfully"
)
logger.info(f"AudioProcessor for {self.peer_name} shutdown complete")
async def handle_track_received(peer: Peer, track: MediaStreamTrack):
"""Handle incoming audio tracks from WebRTC peers."""
global _audio_processors, _send_chat_func
if track.kind != "audio":
logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
return
# Create or get audio processor for this peer
if peer.peer_name not in _audio_processors:
if _send_chat_func is None:
logger.error(
f"Cannot create AudioProcessor for {peer.peer_name}: no send_chat_func available"
)
return
logger.info(f"Creating new AudioProcessor for {peer.peer_name}")
_audio_processors[peer.peer_name] = AudioProcessor(
peer_name=peer.peer_name, send_chat_func=_send_chat_func
)
audio_processor = _audio_processors[peer.peer_name]
logger.info(
f"Received audio track from {peer.peer_name}, starting transcription"
)
# Start the frame reception loop
try:
frame_count = 0
while True:
try:
# Receive audio frame
frame = await track.recv()
frame_count += 1
# Log less frequently now that we know frames are being received
if frame_count % 100 == 0:
logger.info(f"Received {frame_count} frames from {peer.peer_name}")
except MediaStreamError as e:
# Connection was closed or media stream ended - this is normal
logger.info(
f"Audio stream ended for {peer.peer_name} (MediaStreamError: {e})"
)
break
except Exception as e:
# Other errors during frame reception
logger.error(
f"Error receiving audio frame from {peer.peer_name}: {e}", exc_info=True
)
break
# Check if this is an audio frame and convert to numpy array for processing
if isinstance(frame, AudioFrame):
# Convert AudioFrame to numpy array
try:
audio_data = frame.to_ndarray()
except Exception as e:
logger.error(f"Error converting frame to ndarray for {peer.peer_name}: {e}")
continue
original_shape = audio_data.shape
original_dtype = audio_data.dtype
logger.debug(
f"Audio frame data: shape={original_shape}, dtype={original_dtype}, samples={frame.samples if hasattr(frame, 'samples') else 'unknown'}"
)
# Handle different audio formats - convert stereo to mono if needed
if audio_data.ndim == 2: # Stereo -> mono
if audio_data.shape[0] == 1: # Shape is (1, samples) - just squeeze the first dimension
audio_data = audio_data.squeeze(0)
logger.debug(f"Squeezed single-channel audio: {original_shape} -> {audio_data.shape}")
else:  # True stereo, e.g. (2, samples) or (samples, 2) - average across the channel axis
    # The channel axis is the smaller dimension; averaging over it leaves a 1-D mono signal
    channel_axis = 0 if audio_data.shape[0] < audio_data.shape[1] else 1
    audio_data = np.mean(audio_data, axis=channel_axis)
logger.debug(f"Converted stereo to mono: {original_shape} -> {audio_data.shape}")
# Convert to float32 and normalize based on data type
if audio_data.dtype == np.int16:
audio_data = audio_data.astype(np.float32) / 32768.0
logger.debug("Normalized int16 audio to float32")
elif audio_data.dtype == np.int32:
audio_data = audio_data.astype(np.float32) / 2147483648.0
logger.debug("Normalized int32 audio to float32")
# Resample to 16kHz if needed for Whisper model
if frame.sample_rate != sample_rate:
    original_length = len(audio_data)
    # Use librosa to resample, with an explicit float64 conversion for better precision
    try:
        audio_float64 = audio_data.astype(np.float64)
        audio_data = librosa.resample(  # type: ignore
            audio_float64, orig_sr=frame.sample_rate, target_sr=sample_rate
        )
        logger.debug(
            f"Resampled audio: {frame.sample_rate}Hz -> {sample_rate}Hz, {original_length} -> {len(audio_data)} samples"
        )
    except Exception as e:
        # Keep the original, un-resampled audio as a fallback
        logger.error(f"Resampling failed for {peer.peer_name}: {e}")
else:
    # Audio is already at the Whisper sample rate; no resampling needed
    pass
# Ensure audio_data is properly typed as float32 and calculate frame metrics
audio_data_float32 = cast(AudioArray, audio_data.astype(np.float32))
frame_rms = np.sqrt(np.mean(audio_data_float32**2))
frame_peak = np.max(np.abs(audio_data_float32))
# Track frame count and audio state
frame_count = getattr(peer, "_whisper_frame_count", 0) + 1
setattr(peer, "_whisper_frame_count", frame_count)
# Track if we've seen audio before (to detect start of speech)
had_audio = getattr(peer, "_whisper_had_audio", False)
# Define thresholds for "real audio" detection
audio_threshold = 0.001 # RMS threshold for detecting speech
has_audio = frame_rms > audio_threshold
# Log important audio events
if has_audio and not had_audio:
# Started receiving audio
frame_info = f"{frame.sample_rate}Hz, {frame.format.name}, {frame.layout.name}"
logger.info(
f"🎤 AUDIO DETECTED from {peer.peer_name}! Frame #{frame_count}: {frame_info}, RMS: {frame_rms:.4f}, Peak: {frame_peak:.4f}"
)
setattr(peer, "_whisper_had_audio", True)
setattr(peer, "_whisper_last_audio_frame", frame_count)
elif not has_audio and had_audio:
# Stopped receiving audio
last_audio_frame = getattr(peer, "_whisper_last_audio_frame", 0)
logger.info(
f"🔇 Audio stopped from {peer.peer_name} at frame #{frame_count} (last audio was frame #{last_audio_frame})"
)
setattr(peer, "_whisper_had_audio", False)
elif has_audio:
# Continue receiving audio - update last audio frame but don't spam logs
setattr(peer, "_whisper_last_audio_frame", frame_count)
# Only log every 100 frames when continuously receiving audio
if frame_count % 100 == 0:
logger.info(
f"🎤 Audio continuing from {peer.peer_name}: Frame #{frame_count}, RMS: {frame_rms:.4f}"
)
# Log connection info much less frequently (every 200 frames when silent)
if not has_audio and frame_count % 200 == 0:
logger.debug(
f"Connection active from {peer.peer_name}: Frame #{frame_count} (silent, RMS: {frame_rms:.6f})"
)
# Send processed audio to the audio processor for transcription
if audio_processor:
audio_processor.add_audio_data(audio_data_float32)
else:
logger.warning(
f"No audio processor available to handle audio data for {peer.peer_name}"
)
else:
logger.warning(
f"Received non-audio frame on audio track from {peer.peer_name}: type={type(frame)}"
)
except Exception as e:
logger.error(
f"Unexpected error processing audio track from {peer.peer_name}: {e}", exc_info=True
)
finally:
# Clean up the audio processor when the stream ends
cleanup_peer_processor(peer.peer_name)
def agent_info() -> Dict[str, str]:
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"}
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
"""Whisper is not a media source - return no local tracks."""
return {}
async def handle_chat_message(
chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]
) -> Optional[str]:
"""Handle incoming chat messages and optionally return a response."""
pass
async def on_track_received(peer: Peer, track: MediaStreamTrack):
"""Callback when a new track is received from a peer."""
await handle_track_received(peer, track)
# Export functions for the orchestrator to discover
def get_track_handler():
"""Return the track handler function for the orchestrator to use."""
return on_track_received
def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]):
"""Bind the send chat function to be used for all audio processors."""
global _send_chat_func, _audio_processors
logger.info("Binding send chat function to whisper agent")
_send_chat_func = send_chat_func
# Update existing audio processors
for peer_name, processor in _audio_processors.items():
logger.debug(
f"Updating AudioProcessor for {peer_name} with new send chat function"
)
processor.send_chat_func = send_chat_func
def cleanup_peer_processor(peer_name: str):
"""Clean up audio processor for a disconnected peer."""
global _audio_processors
if peer_name in _audio_processors:
logger.info(f"Cleaning up AudioProcessor for disconnected peer: {peer_name}")
processor = _audio_processors[peer_name]
processor.shutdown()
del _audio_processors[peer_name]
logger.info(f"AudioProcessor for {peer_name} cleaned up successfully")
else:
logger.debug(f"No AudioProcessor found for peer {peer_name} during cleanup")
def get_active_processors() -> Dict[str, "AudioProcessor"]:
"""Get currently active audio processors (for debugging)."""
return _audio_processors.copy()
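# ---------------------------------------------------------------------------
# Minimal local smoke test (sketch, not part of the agent API). It assumes the
# checkpoint selected above can be loaded on this machine and simply pushes two
# seconds of synthetic tone through an AudioProcessor, printing whatever "chat"
# output the transcription pipeline produces (a pure tone may well transcribe to
# nothing). The peer name "local-test" and the 440 Hz tone are illustrative
# values only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    async def _demo() -> None:
        async def print_chat(message: str) -> None:
            # Stand-in for the orchestrator's send-chat callable
            print(f"[chat] {message}")

        bind_send_chat_function(print_chat)
        processor = AudioProcessor(peer_name="local-test", send_chat_func=print_chat)
        try:
            # Two seconds of 440 Hz tone at 16 kHz, fed in 30 ms (480-sample) frames
            t = np.linspace(0, 2.0, 2 * sample_rate, endpoint=False)
            tone = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
            for start in range(0, len(tone), 480):
                processor.add_audio_data(tone[start : start + 480])
            # Give the background thread time to run the model and post results
            await asyncio.sleep(10.0)
        finally:
            processor.shutdown()

    asyncio.run(_demo())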