"""Streaming Whisper agent (bots/whisper)
|
|
|
|
Real-time speech transcription agent that processes incoming audio streams
|
|
and sends transcriptions as chat messages to the lobby.
|
|
"""
|
|
|
|
import asyncio
|
|
import numpy as np
|
|
import time
|
|
import threading
|
|
from collections import deque
|
|
from queue import Queue, Empty
|
|
from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
|
|
import numpy.typing as npt
|
|
from pydantic import BaseModel
|
|
|
|
# Core dependencies
|
|
import librosa
|
|
from logger import logger
|
|
from aiortc import MediaStreamTrack
|
|
from av import AudioFrame
|
|
|
|
# Import shared models for chat functionality
|
|
import sys
|
|
import os
|
|
|
|
from voicebot.models import Peer
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
from shared.models import ChatMessageModel
|
|
|
|
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
|
|
|
# Type definitions
|
|
AudioArray = npt.NDArray[np.float32]
|
|
|
|
class AudioQueueItem(BaseModel):
|
|
"""Audio data with timestamp for processing queue."""
|
|
audio: AudioArray
|
|
timestamp: float
|
|
|
|
class Config:
|
|
arbitrary_types_allowed = True
|
|
|
|
class TranscriptionHistoryItem(BaseModel):
|
|
"""Transcription history item with metadata."""
|
|
message: str
|
|
timestamp: float
|
|
is_final: bool
|
|
|
|
AGENT_NAME = "whisper"
|
|
AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text"
|
|
sample_rate = 16000 # Whisper expects 16kHz
|
|
|
|
model_ids = {
|
|
"Distil-Whisper": [
|
|
"distil-whisper/distil-large-v2",
|
|
"distil-whisper/distil-medium.en",
|
|
"distil-whisper/distil-small.en"
|
|
],
|
|
"Whisper": [
|
|
"openai/whisper-large-v3",
|
|
"openai/whisper-large-v2",
|
|
"openai/whisper-large",
|
|
"openai/whisper-medium",
|
|
"openai/whisper-small",
|
|
"openai/whisper-base",
|
|
"openai/whisper-tiny",
|
|
"openai/whisper-medium.en",
|
|
"openai/whisper-small.en",
|
|
"openai/whisper-base.en",
|
|
"openai/whisper-tiny.en",
|
|
]
|
|
}
|
|
# Global whisper model and transcription handler
|
|
_model_type = model_ids["Distil-Whisper"]
|
|
_model_id = _model_type[0]
|
|
_processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=_model_id) # type: ignore
|
|
_pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=_model_id) # type: ignore
|
|
_pt_model.eval() # type: ignore
|
|
_audio_processor: Optional['AudioProcessor'] = None
|
|
|
|
|
|
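# Note on checkpoints (illustrative): a different model can be selected by
# editing _model_id above before the from_pretrained calls run, e.g.
#
#   _model_id = model_ids["Whisper"][6]  # "openai/whisper-tiny"
#
# Smaller checkpoints trade accuracy for lower latency; the index shown is
# just an example entry from the model_ids table.

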
def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
    """Extract input features from audio array and sampling rate."""
    processor_output = _processor(  # type: ignore
        audio_array,
        sampling_rate=sampling_rate,
        return_tensors="pt",
    )
    input_features: Any = processor_output.input_features  # type: ignore
    return input_features  # type: ignore


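# Illustrative sketch of the inference path (mirrors _transcribe_and_send;
# kept as comments so nothing runs at import time):
#
#   features = extract_input_features(np.zeros(sample_rate, dtype=np.float32), sample_rate)
#   predicted_ids = _pt_model.generate(features)
#   text = _processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
#
# For silence, Whisper may return an empty or spurious short string.

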
class AudioProcessor:
    """Handles audio stream processing and transcription with sentence chunking."""

    def __init__(self, send_chat_func: Callable[[str], Awaitable[None]]):
        self.send_chat_func = send_chat_func
        self.sample_rate = 16000  # Whisper expects 16kHz
        self.samples_per_frame = 480  # Common WebRTC frame size at 16kHz (30ms)

        # Audio buffering
        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000)  # ~30 seconds at 30ms frames
        self.phrase_timeout = 3.0  # seconds of silence before considering phrase complete
        self.last_activity_time = time.time()

        # Transcription state
        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
        self.transcription_history: list[TranscriptionHistoryItem] = []

        # Event loop that owns send_chat_func; captured in add_audio_data so the
        # background thread can schedule coroutines back onto it.
        self._main_loop: Optional[asyncio.AbstractEventLoop] = None

        # Background processing
        self.processing_queue: Queue[AudioQueueItem] = Queue()
        self.is_running = True
        self.processor_thread = threading.Thread(target=self._processing_loop, daemon=True)
        self.processor_thread.start()

        logger.info("AudioProcessor initialized for real-time transcription")

    def add_audio_data(self, audio_data: AudioArray):
        """Add new audio data to the processing buffer."""
        if not self.is_running:
            return

        # Remember the running event loop (this method is called from async
        # track-handling code) so the worker thread can post results back to it.
        try:
            self._main_loop = asyncio.get_running_loop()
        except RuntimeError:
            pass

        # Audio arriving here is already mono float32 at 16kHz; resampling is
        # done in handle_track_received.
        if len(audio_data) > 0:
            self.audio_buffer.append(audio_data)
            self.last_activity_time = time.time()

        # Check if we should process accumulated audio
        if len(self.audio_buffer) >= 10:  # Process every ~300ms (10 * 30ms frames)
            self._queue_for_processing()

    def _queue_for_processing(self):
        """Queue current audio buffer for transcription processing."""
        if not self.audio_buffer:
            return

        # Combine recent audio frames
        combined_audio = np.concatenate(list(self.audio_buffer))
        self.audio_buffer.clear()

        # Add to processing queue
        try:
            queue_item = AudioQueueItem(
                audio=combined_audio,
                timestamp=time.time(),
            )
            self.processing_queue.put_nowait(queue_item)
        except Exception:
            # Queue full, skip this chunk
            logger.debug("Audio processing queue full, dropping audio chunk")

    def _processing_loop(self):
        """Background thread that processes audio chunks for transcription."""
        while self.is_running:
            try:
                # Get audio chunk to process (blocking with timeout)
                try:
                    audio_data = self.processing_queue.get(timeout=1.0)
                except Empty:
                    continue

                audio_array = audio_data.audio
                chunk_timestamp = audio_data.timestamp

                # Check if this is a new phrase (gap in audio)
                time_since_last = chunk_timestamp - self.last_activity_time
                phrase_complete = time_since_last > self.phrase_timeout

                if phrase_complete and len(self.current_phrase_audio) > 0:
                    # Process the completed phrase on the main event loop
                    if self._main_loop is not None:
                        asyncio.run_coroutine_threadsafe(
                            self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=True),
                            self._main_loop,
                        )
                    self.current_phrase_audio = np.array([], dtype=np.float32)

                # Add new audio to current phrase
                self.current_phrase_audio = np.concatenate([self.current_phrase_audio, audio_array])

                # Also do streaming transcription for immediate feedback
                if len(self.current_phrase_audio) > self.sample_rate * 2:  # At least 2 seconds
                    if self._main_loop is not None:
                        asyncio.run_coroutine_threadsafe(
                            self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=False),
                            self._main_loop,
                        )

            except Exception as e:
                logger.error(f"Error in audio processing loop: {e}", exc_info=True)

    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool):
        """Transcribe audio and send result as chat message."""
        try:
            if len(audio_array) < self.sample_rate * 0.5:  # Skip very short audio
                return

            # Ensure audio is in the right format for Whisper
            audio_array = audio_array.astype(np.float32)

            # Transcribe with Whisper
            input_features = extract_input_features(audio_array, sample_rate)
            predicted_ids = _pt_model.generate(input_features)  # type: ignore
            transcription = _processor.batch_decode(predicted_ids, skip_special_tokens=True)  # type: ignore

            # batch_decode returns a list of strings (one per batch item)
            text = transcription[0].strip()

            if text and len(text) > 1:  # Only send meaningful transcriptions
                prefix = "🎤 " if is_final else "🎤 [partial] "
                message = f"{prefix}{text}"

                # Avoid sending duplicate messages
                if is_final or message not in [h.message for h in self.transcription_history[-3:]]:
                    await self.send_chat_func(message)

                    # Keep history for deduplication
                    history_item = TranscriptionHistoryItem(
                        message=message,
                        timestamp=time.time(),
                        is_final=is_final,
                    )
                    self.transcription_history.append(history_item)

                    # Limit history size
                    if len(self.transcription_history) > 10:
                        self.transcription_history.pop(0)

                    logger.info(f"Transcribed ({'final' if is_final else 'partial'}): {text}")

        except Exception as e:
            logger.error(f"Error in transcription: {e}", exc_info=True)

    def shutdown(self):
        """Shutdown the audio processor."""
        self.is_running = False
        if self.processor_thread.is_alive():
            self.processor_thread.join(timeout=2.0)


async def handle_track_received(peer: Peer, track: MediaStreamTrack):
    """Handle incoming audio tracks from WebRTC peers."""
    global _audio_processor

    if track.kind != "audio":
        logger.info(f"Ignoring non-audio track: {track.kind}")
        return

    logger.info(f"Received audio track from {peer.peer_name}, starting transcription")

    try:
        while True:
            # Receive audio frame
            frame = await track.recv()
            if isinstance(frame, AudioFrame):
                # Convert AudioFrame to numpy array
                audio_data = frame.to_ndarray()

                # Convert integer PCM to float32 in [-1, 1] before downmixing or
                # resampling, so the scale check sees the original dtype
                if audio_data.dtype == np.int16:
                    audio_data = audio_data.astype(np.float32) / 32768.0
                elif audio_data.dtype == np.int32:
                    audio_data = audio_data.astype(np.float32) / 2147483648.0

                # Downmix to mono. Planar frames arrive as (channels, samples);
                # packed frames arrive as (1, samples * channels) and need to be
                # de-interleaved by channel count before averaging.
                if audio_data.ndim == 2:
                    if audio_data.shape[0] > 1:
                        audio_data = audio_data.mean(axis=0)
                    else:
                        channels = audio_data.size // frame.samples if frame.samples else 1
                        audio_data = audio_data.reshape(-1, channels).mean(axis=1)

                # Resample to 16kHz if needed
                if frame.sample_rate != sample_rate:
                    audio_data = librosa.resample(  # type: ignore
                        audio_data,
                        orig_sr=frame.sample_rate,
                        target_sr=sample_rate,
                    )

                # Ensure audio_data is AudioArray (float32)
                audio_data_float32 = cast(AudioArray, audio_data.astype(np.float32))

                # Send to audio processor
                if _audio_processor:
                    _audio_processor.add_audio_data(audio_data_float32)

    except Exception as e:
        logger.error(f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True)


def agent_info() -> Dict[str, str]:
    return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"}


def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
    """Whisper is not a media source - return no local tracks."""
    return {}


async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]) -> Optional[str]:
    """Handle incoming chat messages and optionally return a response."""
    return None


async def on_track_received(peer: Peer, track: MediaStreamTrack):
    """Callback when a new track is received from a peer."""
    await handle_track_received(peer, track)


# Export functions for the orchestrator to discover
def get_track_handler():
    """Return the track handler function for the orchestrator to use."""
    return on_track_received


def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]):
    """Bind the send chat function to the audio processor."""
    global _send_chat_func, _audio_processor
    _send_chat_func = send_chat_func
    if _audio_processor:
        _audio_processor.send_chat_func = send_chat_func
    else:
        _audio_processor = AudioProcessor(send_chat_func=send_chat_func)
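

if __name__ == "__main__":
    # Illustrative local smoke test, not part of the orchestrator contract.
    # It binds a stub chat sender (plain print) and pushes a few seconds of
    # synthetic 30 ms frames through the same path handle_track_received uses,
    # so any partial transcriptions Whisper produces are printed to stdout.
    async def _demo() -> None:
        async def _print_chat(message: str) -> None:
            print(f"[chat] {message}")

        bind_send_chat_function(_print_chat)
        assert _audio_processor is not None

        # 480 samples = one 30 ms frame at 16 kHz; use a quiet 440 Hz tone.
        t = np.arange(480, dtype=np.float32) / sample_rate
        frame_audio = (0.1 * np.sin(2.0 * np.pi * 440.0 * t)).astype(np.float32)

        for _ in range(int(4.0 / 0.03)):
            _audio_processor.add_audio_data(frame_audio)
            await asyncio.sleep(0.03)

        # Give the background thread a moment to flush, then shut down.
        await asyncio.sleep(3.0)
        _audio_processor.shutdown()

    asyncio.run(_demo())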