334 lines
13 KiB
Python

"""Streaming Whisper agent (bots/whisper)
Real-time speech transcription agent that processes incoming audio streams
and sends transcriptions as chat messages to the lobby.
"""
import asyncio
import numpy as np
import time
import threading
from collections import deque
from queue import Queue, Empty
from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
import numpy.typing as npt
from pydantic import BaseModel
# Core dependencies
import librosa
from logger import logger
from aiortc import MediaStreamTrack
from av import AudioFrame
# Import shared models for chat functionality
import sys
import os
from voicebot.models import Peer
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from shared.models import ChatMessageModel
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Type definitions
# float32 NumPy sample buffer; presumably 1-D mono samples as produced by
# handle_track_received — TODO confirm no other producer passes 2-D data.
AudioArray = npt.NDArray[np.float32]
class AudioQueueItem(BaseModel):
    """Audio data with timestamp for processing queue.

    Built by AudioProcessor._queue_for_processing and consumed by the
    AudioProcessor background processing loop.
    """
    audio: AudioArray  # concatenated float32 samples for one queued chunk
    timestamp: float  # time.time() at the moment the chunk was queued
    class Config:
        # Needed so pydantic accepts the numpy array annotation, which has no
        # built-in pydantic schema.
        arbitrary_types_allowed = True
class TranscriptionHistoryItem(BaseModel):
    """Transcription history item with metadata.

    AudioProcessor keeps a short list of these so that repeated partial
    transcriptions are not re-sent to the chat.
    """
    message: str  # the chat message exactly as sent, including the "🎤" prefix
    timestamp: float  # time.time() when the message was recorded
    is_final: bool  # True for a completed phrase, False for a partial result
# Agent identity reported to the orchestrator by agent_info()
AGENT_NAME = "whisper"
AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text"
sample_rate = 16000 # Whisper expects 16kHz; incoming audio is resampled to this rate
# Hugging Face checkpoint ids grouped by model family. The module-level loader
# uses model_ids["Distil-Whisper"][0], so the order of each list matters.
model_ids = {
    "Distil-Whisper": [
        "distil-whisper/distil-large-v2",
        "distil-whisper/distil-medium.en",
        "distil-whisper/distil-small.en"
    ],
    "Whisper": [
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ]
}
# Global whisper model and transcription handler
# Checkpoint family to use; the first id in that list is loaded below.
_model_type = model_ids["Distil-Whisper"]
_model_id = _model_type[0]
# Feature processor and seq2seq model are loaded eagerly, at import time.
_processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=_model_id) # type: ignore
_pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=_model_id) # type: ignore
# Inference-only use: put the PyTorch model in evaluation mode.
_pt_model.eval() # type: ignore
# Shared processor instance; created by bind_send_chat_function().
_audio_processor: Optional['AudioProcessor'] = None
def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
    """Turn raw audio into the model-ready feature tensor.

    Args:
        audio_array: Audio samples to featurise.
        sampling_rate: Rate of ``audio_array`` in Hz, forwarded to the processor.

    Returns:
        The ``input_features`` attribute of the module-level processor's
        output, requested as PyTorch tensors.
    """
    features = _processor(audio_array, return_tensors="pt", sampling_rate=sampling_rate).input_features  # type: ignore
    return features  # type: ignore
class AudioProcessor:
    """Buffer streamed audio, split it into phrases and post transcriptions to chat.

    Audio arrives on the asyncio event loop through :meth:`add_audio_data`.
    About every 10 frames, the buffered audio is handed to a background
    thread. That thread appends it to the current phrase and runs Whisper
    inference itself, so the event loop is never blocked by the model.
    Finished text is scheduled back onto the event loop, where
    ``send_chat_func`` is awaited.

    A phrase is finalised in either of two cases:

    * no audio chunk has arrived for ``phrase_timeout`` seconds;
    * the phrase reaches ``max_phrase_seconds``. Whisper only sees a 30 s
      window, so longer audio would be silently truncated.

    NOTE(review): phrase boundaries come from gaps in audio *arrival*. If the
    WebRTC stack keeps delivering (silent) frames, phrases will mostly end at
    ``max_phrase_seconds``, and an energy/VAD check may be needed — TODO confirm.
    """

    def __init__(self, send_chat_func: Callable[[str], Awaitable[None]]):
        self.send_chat_func = send_chat_func
        self.sample_rate = 16000  # Whisper expects 16kHz
        self.samples_per_frame = 480  # Common WebRTC frame size at 16kHz (30ms)

        # Audio buffering (event-loop side)
        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000)  # ~30 seconds at 30ms frames
        self.phrase_timeout = 3.0  # seconds without new audio before a phrase is complete
        self.max_phrase_seconds = 30.0  # Whisper's input window; longer audio gets truncated
        self.last_activity_time = time.time()

        # Transcription state (owned by the processing thread)
        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
        self.transcription_history: list[TranscriptionHistoryItem] = []
        self._last_chunk_time = time.time()

        # Event loop on which send_chat_func must run. The worker thread has no
        # loop of its own, so the loop is captured from the async caller side.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            pass  # constructed outside a coroutine; add_audio_data() captures it later

        # Background processing. The queue is bounded so a stalled model cannot
        # grow memory without limit (100 chunks ≈ 30 s at the nominal chunk size).
        self.processing_queue: Queue[AudioQueueItem] = Queue(maxsize=100)
        self.is_running = True
        self.processor_thread = threading.Thread(target=self._processing_loop, daemon=True)
        self.processor_thread.start()
        logger.info("AudioProcessor initialized for real-time transcription")

    def add_audio_data(self, audio_data: AudioArray):
        """Append one frame of audio to the buffer.

        The frame is expected to be 16 kHz float32, as produced by the track
        handler. Call this from the event loop: the running loop is remembered
        so the worker thread can schedule chat messages on it.
        """
        if not self.is_running:
            return
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            pass  # not called from a coroutine; keep any loop captured earlier
        if len(audio_data) > 0:
            self.audio_buffer.append(audio_data)
            self.last_activity_time = time.time()
        # Hand buffered audio to the worker every ~10 frames (~300 ms at 30 ms frames)
        if len(self.audio_buffer) >= 10:
            self._queue_for_processing()

    def _queue_for_processing(self):
        """Concatenate the buffered frames and enqueue them for the worker thread."""
        if not self.audio_buffer:
            return
        combined_audio = np.concatenate(list(self.audio_buffer))
        self.audio_buffer.clear()
        try:
            queue_item = AudioQueueItem(audio=combined_audio, timestamp=time.time())
            self.processing_queue.put_nowait(queue_item)
        except Exception:
            # put_nowait raises queue.Full once the bounded backlog is exhausted
            logger.debug("Audio processing queue full, dropping audio chunk")

    def _processing_loop(self):
        """Worker thread: build phrases from queued chunks and transcribe them."""
        while self.is_running:
            try:
                try:
                    item = self.processing_queue.get(timeout=1.0)
                except Empty:
                    # No audio is arriving at all (track paused or ended):
                    # do not leave the last phrase pending forever.
                    if time.time() - self._last_chunk_time > self.phrase_timeout:
                        self._finalize_phrase()
                    continue
                self._consume_chunk(item)
                # Absorb whatever piled up while inference was running, so that
                # partial results follow the live audio instead of lagging ever further.
                while True:
                    try:
                        self._consume_chunk(self.processing_queue.get_nowait())
                    except Empty:
                        break
                # Streaming (partial) transcription for immediate feedback
                if len(self.current_phrase_audio) > self.sample_rate * 2:  # At least 2 seconds
                    self._transcribe_and_dispatch(self.current_phrase_audio.copy(), is_final=False)
            except Exception as e:
                logger.error(f"Error in audio processing loop: {e}", exc_info=True)

    def _consume_chunk(self, item: AudioQueueItem) -> None:
        """Add one queued chunk to the current phrase, closing phrases as needed."""
        # A gap longer than phrase_timeout since the previous chunk starts a new phrase.
        if item.timestamp - self._last_chunk_time > self.phrase_timeout:
            self._finalize_phrase()
        self._last_chunk_time = item.timestamp
        self.current_phrase_audio = np.concatenate([self.current_phrase_audio, item.audio])
        # Whisper truncates input beyond its window, so emit the phrase before it gets that long.
        if len(self.current_phrase_audio) >= self.sample_rate * self.max_phrase_seconds:
            self._finalize_phrase()

    def _finalize_phrase(self) -> None:
        """Transcribe the accumulated phrase as final and start a new, empty one."""
        phrase = self.current_phrase_audio
        self.current_phrase_audio = np.array([], dtype=np.float32)
        if len(phrase) > 0:
            self._transcribe_and_dispatch(phrase, is_final=True)

    def _transcribe_and_dispatch(self, audio_array: AudioArray, is_final: bool) -> None:
        """Transcribe on the calling (worker) thread and post the text on the event loop."""
        loop = self._loop
        if loop is None or loop.is_closed():
            logger.debug("No event loop available, dropping transcription")
            return
        text = self._transcribe(audio_array)
        if not text:
            return
        asyncio.run_coroutine_threadsafe(self._send_transcription(text, is_final), loop)

    def _transcribe(self, audio_array: AudioArray) -> str:
        """Run Whisper on *audio_array*; return the stripped text, or "" if skipped/failed."""
        if len(audio_array) < self.sample_rate * 0.5:  # Skip very short audio
            return ""
        try:
            input_features = extract_input_features(audio_array.astype(np.float32), self.sample_rate)
            predicted_ids = _pt_model.generate(input_features)  # type: ignore
            decoded: list[str] = _processor.batch_decode(predicted_ids, skip_special_tokens=True)  # type: ignore
        except Exception as e:
            logger.error(f"Error in transcription: {e}", exc_info=True)
            return ""
        # batch_decode returns one string per batch entry; a single clip was submitted.
        return decoded[0].strip() if decoded else ""

    async def _send_transcription(self, text: str, is_final: bool) -> None:
        """Post *text* to chat (runs on the event loop), skipping repeated partials."""
        if len(text) <= 1:  # Only send meaningful transcriptions
            return
        prefix = "🎤 " if is_final else "🎤 [partial] "
        message = f"{prefix}{text}"
        # Avoid sending duplicate messages
        recent_messages = [h.message for h in self.transcription_history[-3:]]
        if not is_final and message in recent_messages:
            return
        try:
            await self.send_chat_func(message)
        except Exception as e:
            logger.error(f"Error in transcription: {e}", exc_info=True)
            return
        # Keep a short history for deduplication
        self.transcription_history.append(
            TranscriptionHistoryItem(message=message, timestamp=time.time(), is_final=is_final)
        )
        del self.transcription_history[:-10]
        logger.info(f"Transcribed ({'final' if is_final else 'partial'}): {text}")

    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool):
        """Coroutine entry point: transcribe off the event loop, then send the result.

        Inference runs in the default executor so the loop is not blocked.
        """
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, self._transcribe, audio_array)
        if text:
            await self._send_transcription(text, is_final)

    def shutdown(self):
        """Stop accepting audio and wait briefly for the worker thread to exit."""
        self.is_running = False
        if self.processor_thread.is_alive():
            self.processor_thread.join(timeout=2.0)
def _frame_to_mono_float32(frame: AudioFrame) -> AudioArray:
    """Convert a PyAV audio frame to 1-D float32 samples at the frame's own rate.

    Integer PCM is scaled to [-1.0, 1.0) *before* channels are averaged. The
    channel layout (planar vs packed/interleaved) is honoured, so every
    sample of the frame survives the downmix.
    """
    samples = frame.to_ndarray()
    # Normalise first: np.mean on int16 yields float64, which would otherwise
    # skip the integer scaling entirely.
    if samples.dtype == np.int16:
        samples = samples.astype(np.float32) / 32768.0
    elif samples.dtype == np.int32:
        samples = samples.astype(np.float32) / 2147483648.0
    else:
        samples = samples.astype(np.float32)
    if samples.ndim == 2:
        if frame.format.is_planar:
            # Planar audio is laid out as (channels, samples): average the rows.
            samples = samples.mean(axis=0)
        else:
            # Packed audio is (1, samples * channels) with interleaved channels.
            channels = max(len(frame.layout.channels), 1)
            samples = samples.reshape(-1, channels).mean(axis=1)
    return cast(AudioArray, samples.astype(np.float32, copy=False))


async def handle_track_received(peer: Peer, track: MediaStreamTrack):
    """Consume an incoming WebRTC track and feed its audio to the transcriber.

    Non-audio tracks are ignored. Each audio frame goes through these steps:

    1. Downmix to mono and scale to float32.
    2. Resample to ``sample_rate``.
    3. Pass to the module-level AudioProcessor. Frames are dropped while no
       processor is bound.

    Returns when the track ends, or after logging an unexpected error.
    """
    # aiortc raises MediaStreamError from recv() once a track has ended.
    from aiortc.mediastreams import MediaStreamError

    if track.kind != "audio":
        logger.info(f"Ignoring non-audio track: {track.kind}")
        return
    logger.info(f"Received audio track from {peer.peer_name}, starting transcription")
    format_logged = False
    try:
        while True:
            frame = await track.recv()
            if not isinstance(frame, AudioFrame):
                logger.warning(f"Received non-audio frame on audio track from {peer.peer_name}")
                continue
            if not format_logged:
                # Logging every frame would flood the log; once per track is enough.
                logger.info(f"Received audio frame: {frame.sample_rate}Hz, {frame.format.name}, {frame.layout.name}")
                format_logged = True
            audio_data = _frame_to_mono_float32(frame)
            # Resample to 16kHz if needed.
            # NOTE(review): resampling each small frame independently may add
            # artifacts at frame boundaries; a stateful resampler could avoid this.
            if frame.sample_rate != sample_rate:
                audio_data = librosa.resample(  # type: ignore
                    audio_data,
                    orig_sr=frame.sample_rate,
                    target_sr=sample_rate,
                )
            if _audio_processor:
                _audio_processor.add_audio_data(cast(AudioArray, audio_data.astype(np.float32)))
    except MediaStreamError:
        logger.info(f"Audio track from {peer.peer_name} ended")
    except Exception as e:
        logger.error(f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True)
def agent_info() -> Dict[str, str]:
    """Describe this agent to the orchestrator.

    Returns:
        A mapping with the agent's name and description, and ``has_media``
        set to the string ``"false"`` because the agent publishes no media.
    """
    info: Dict[str, str] = dict(
        name=AGENT_NAME,
        description=AGENT_DESCRIPTION,
        has_media="false",
    )
    return info
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
    """Return the local media tracks this agent publishes for *session_name*.

    Whisper only consumes audio and is not a media source, so the mapping is
    always a fresh empty dict.
    """
    no_tracks: dict[str, MediaStreamTrack] = {}
    return no_tracks
async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]) -> Optional[str]:
    """React to a chat message posted in the lobby.

    This agent never answers chat: it ignores *chat_message*, never calls
    *send_message_func*, and always returns ``None``.
    """
    return None
async def on_track_received(peer: Peer, track: MediaStreamTrack):
    """Orchestrator-facing callback for a track newly received from *peer*.

    All work is delegated to handle_track_received; this completes when it does.
    """
    return await handle_track_received(peer, track)
# Export functions for the orchestrator to discover
def get_track_handler():
    """Give the orchestrator the coroutine to run for each incoming track.

    Returns:
        on_track_received, which forwards to handle_track_received.
    """
    handler = on_track_received
    return handler
def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]):
    """Install *send_chat_func* as the sink for transcription chat messages.

    The function is stored in the module global ``_send_chat_func``. On the
    first call this also creates the shared AudioProcessor. Later calls only
    swap the function on the existing processor.
    """
    global _send_chat_func, _audio_processor
    _send_chat_func = send_chat_func
    if _audio_processor is None:
        _audio_processor = AudioProcessor(send_chat_func=send_chat_func)
        return
    _audio_processor.send_chat_func = send_chat_func