"""Streaming Whisper agent (bots/whisper) - OpenVINO Optimized for Intel Arc B580
|
|
|
|
Real-time speech transcription agent that processes incoming audio streams
|
|
and sends transcriptions as chat messages to the lobby.
|
|
Optimized for Intel Arc B580 GPU using OpenVINO inference engine.
|
|
"""
|
|
|
|
import asyncio
|
|
import numpy as np
|
|
import time
|
|
import threading
|
|
import os
|
|
import gc
|
|
import shutil
|
|
from queue import Queue, Empty
|
|
from typing import Dict, Optional, Callable, Awaitable, Any, List, Union
|
|
from pathlib import Path
|
|
import numpy.typing as npt
|
|
from pydantic import BaseModel, Field, ConfigDict
|
|
|
|
# Core dependencies
|
|
import librosa
|
|
from shared.logger import logger
|
|
from aiortc import MediaStreamTrack
|
|
from aiortc.mediastreams import MediaStreamError
|
|
from av import AudioFrame, VideoFrame
|
|
import cv2
|
|
import fractions
|
|
from time import perf_counter
|
|
|
|
# Import shared models for chat functionality
|
|
import sys
|
|
sys.path.append(
|
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
)
|
|
from shared.models import ChatMessageModel
|
|
from voicebot.models import Peer
|
|
|
|
# OpenVINO optimized imports
|
|
import openvino as ov
|
|
from optimum.intel.openvino import OVModelForSpeechSeq2Seq # type: ignore
|
|
from transformers import WhisperProcessor
|
|
from openvino.runtime import Core  # type: ignore
import torch

# Import quantization dependencies with error handling
try:
    import nncf  # type: ignore
    from optimum.intel.openvino.quantization import InferRequestWrapper  # type: ignore

    QUANTIZATION_AVAILABLE = True
except ImportError:
    nncf = None  # type: ignore
    InferRequestWrapper = None  # type: ignore
    QUANTIZATION_AVAILABLE = False

# Type definitions
AudioArray = npt.NDArray[np.float32]
ModelConfig = Dict[str, Union[str, int, bool]]
CalibrationData = List[Dict[str, Any]]

_device = "GPU.1"  # Default to Intel Arc B580 GPU


def get_available_devices() -> list[dict[str, Any]]:
    """List available OpenVINO devices with their properties."""
    try:
        core = Core()
        devices = core.available_devices
        device_info: list[dict[str, Any]] = []
        for device in devices:
            try:
                # Get device properties
                properties = core.get_property(device, "FULL_DEVICE_NAME")
                # Attempt to get additional properties if available
                try:
                    device_type = core.get_property(device, "DEVICE_TYPE")
                except Exception:
                    device_type = "N/A"
                try:
                    capabilities: Any = core.get_property(device, "SUPPORTED_PROPERTIES")
                except Exception:
                    capabilities = "N/A"
                device_info.append({
                    "name": device,
                    "full_name": properties,
                    "type": device_type,
                    "capabilities": capabilities
                })
            except Exception as e:
                logger.error(f"Failed to retrieve properties for device {device}: {e}")
                device_info.append({
                    "name": device,
                    "full_name": "Unknown",
                    "type": "N/A",
                    "capabilities": "N/A"
                })
        return device_info
    except Exception as e:
        logger.error(f"Failed to retrieve available devices: {e}")
        return []


def print_available_devices(device: str | None = None):
    """Print available OpenVINO devices in a formatted manner."""
    devices = get_available_devices()
    if not devices:
        logger.info("No OpenVINO devices detected.")
        return
    logger.info("Available OpenVINO Devices:")
    for d in devices:
        logger.info(f"- Device: {d.get('name')} {'*' if d.get('name') == device else ''}")
        logger.info(f"  Full Name: {d.get('full_name')}")
        logger.info(f"  Type: {d.get('type')}")


print_available_devices(_device)


class AudioQueueItem(BaseModel):
    """Audio data with timestamp for processing queue."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    audio: AudioArray = Field(..., description="Audio data as numpy array")
    timestamp: float = Field(..., description="Timestamp when audio was captured")


class TranscriptionHistoryItem(BaseModel):
    """Transcription history item with metadata."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    message: str = Field(..., description="Transcribed text message")
    timestamp: float = Field(..., description="When transcription was completed")
    is_final: bool = Field(..., description="Whether this is final or streaming transcription")


class OpenVINOConfig(BaseModel):
    """OpenVINO configuration for Intel Arc B580 optimization."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    device: str = Field(default=_device, description="Target device for inference")
    cache_dir: str = Field(default="./ov_cache", description="Cache directory for compiled models")
    enable_quantization: bool = Field(default=True, description="Enable INT8 quantization")
    throughput_streams: int = Field(default=2, description="Number of inference streams")
    max_threads: int = Field(default=8, description="Maximum number of threads")

    def to_ov_config(self) -> ModelConfig:
        """Convert to OpenVINO configuration dictionary."""
        cfg: ModelConfig = {"CACHE_DIR": self.cache_dir}

        # Only include GPU-specific tuning options when the target device is a GPU.
        # Some OpenVINO plugins (notably the CPU plugin) will raise NotFound
        # errors for GPU_* properties, so avoid passing them unless applicable.
        # Match "GPU" as well as enumerated devices such as "GPU.1".
        device = (self.device or "").upper()
        if device.startswith("GPU"):
            cfg.update(
                {
                    # Throughput / stream tuning
                    "GPU_THROUGHPUT_STREAMS": str(self.throughput_streams),
                    # Threading controls may be driver/plugin-specific; keep minimal.
                    # NOTE: We intentionally do NOT set GPU_MAX_NUM_THREADS here
                    # because some OpenVINO plugins / builds (and the CPU plugin
                    # during a fallback) do not recognize the property and will
                    # raise NotFound/UnsupportedProperty errors. If you need to
                    # tune GPU threads for a specific driver, set that externally
                    # or via vendor-specific tools.
                }
            )
        else:
            # Safe CPU-side defaults
            cfg.update(
                {
                    "CPU_THROUGHPUT_NUM_THREADS": str(self.max_threads),
                    "CPU_BIND_THREAD": "YES",
                }
            )

        return cfg


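# Illustrative sketch (not executed): with the defaults above, to_ov_config()
# produces roughly these dictionaries; treat the exact keys as assumptions, since
# supported properties vary across OpenVINO versions and device plugins.
#
#   OpenVINOConfig().to_ov_config()            # device "GPU.1"
#   -> {"CACHE_DIR": "./ov_cache", "GPU_THROUGHPUT_STREAMS": "2"}
#
#   OpenVINOConfig(device="CPU").to_ov_config()
#   -> {"CACHE_DIR": "./ov_cache", "CPU_THROUGHPUT_NUM_THREADS": "8", "CPU_BIND_THREAD": "YES"}

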
# Global configuration and constants
AGENT_NAME = "whisper"
AGENT_DESCRIPTION = "Real-time speech transcription (OpenVINO Whisper) - converts speech to text on Intel Arc B580"
SAMPLE_RATE = 16000  # Whisper expects 16kHz
CHUNK_DURATION_MS = 100  # Reduced latency - 100ms chunks
VAD_THRESHOLD = 0.01  # Voice activity detection threshold
MAX_SILENCE_FRAMES = 30  # 3 seconds of silence before stopping

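# Derived timing: MAX_SILENCE_FRAMES (30) chunks of CHUNK_DURATION_MS (100 ms)
# each means 30 * 0.1 s = 3 s of sustained silence before a phrase is finalized.
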
model_ids = {
    "Distil-Whisper": [
        "distil-whisper/distil-large-v2",
        "distil-whisper/distil-medium.en",
        "distil-whisper/distil-small.en",
    ],
    "Whisper": [
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ],
}

# Global model configuration
_model_type = model_ids["Distil-Whisper"]
_model_id = _model_type[0]  # Use distil-large-v2 for best quality
_ov_config = OpenVINOConfig()
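# To experiment with a different checkpoint, point _model_id at another entry,
# e.g. _model_id = model_ids["Whisper"][6] selects "openai/whisper-tiny" for
# quick smoke tests (much faster to load, at the cost of accuracy).

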
def setup_intel_arc_environment() -> None:
    """Configure environment variables for optimal Intel Arc B580 performance."""
    os.environ["OV_GPU_CACHE_MODEL"] = "1"
    os.environ["OV_GPU_ENABLE_OPENCL_THROTTLING"] = "0"
    os.environ["OV_GPU_DISABLE_WINOGRAD"] = "1"
    logger.info("Intel Arc B580 environment variables configured")


class OpenVINOWhisperModel:
    """OpenVINO optimized Whisper model for Intel Arc B580."""

    def __init__(self, model_id: str, config: OpenVINOConfig, device: str):
        self.model_id = model_id
        self.config = config
        self.device = device
        self.model_path = Path(model_id.replace('/', '_'))
        self.quantized_model_path = Path(f"{self.model_path}_quantized")

        self.processor: Optional[WhisperProcessor] = None
        self.ov_model: Optional[OVModelForSpeechSeq2Seq] = None
        self.is_quantized = False

        self._initialize_model()

    def _initialize_model(self) -> None:
        """Initialize processor and OpenVINO model with robust error handling."""
        logger.info(f"Initializing OpenVINO Whisper model: {self.model_id}")

        try:
            # Initialize processor
            logger.info(f"Loading Whisper model '{self.model_id}' on device: {self.device}")
            self.processor = WhisperProcessor.from_pretrained(self.model_id, use_fast=True)  # type: ignore
            logger.info("Whisper processor loaded successfully")

            # Export the model to OpenVINO IR if not already converted
            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(self.model_id, export=True, device=self.device)  # type: ignore

            logger.info("Whisper model exported as OpenVINO IR")

            # # Try to load quantized model first if it exists
            # if self.config.enable_quantization and self.quantized_model_path.exists():
            #     if self._try_load_quantized_model():
            #         return

            # # Load or create FP16 model
            # if self.model_path.exists():
            #     self._load_fp16_model()
            # else:
            #     self._convert_model()

            # # Try quantization after model is loaded and compiled
            # if self.config.enable_quantization and not self.is_quantized:
            #     self._try_quantize_existing_model()

        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            # Fallback to basic conversion without quantization
            self._fallback_initialization()

    def _fallback_initialization(self) -> None:
        """Fallback initialization without quantization."""
        logger.warning("Falling back to basic OpenVINO conversion without quantization")
        try:
            if not self.model_path.exists():
                self._convert_model_basic()
            self._load_fp16_model()
        except Exception as e:
            logger.error(f"Fallback initialization failed: {e}")
            raise RuntimeError("Failed to initialize OpenVINO model") from e

    def _convert_model(self) -> None:
        """Convert PyTorch model to OpenVINO format."""
        logger.info(f"Converting {self.model_id} to OpenVINO format...")

        try:
            # Convert to OpenVINO with FP16 for Arc GPU
            ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
                self.model_id,
                ov_config=self.config.to_ov_config(),
                export=True,
                compile=False,
                load_in_8bit=False
            )

            # Enable FP16 for Intel Arc performance
            ov_model.half()
            ov_model.save_pretrained(self.model_path)
            logger.info("Model converted and saved in FP16 format")

            # Load the converted model
            self.ov_model = ov_model
            self._compile_model()

        except Exception as e:
            logger.error(f"Model conversion failed: {e}")
            raise

    def _convert_model_basic(self) -> None:
        """Basic model conversion without advanced features."""
        logger.info(f"Basic conversion of {self.model_id} to OpenVINO format...")

        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
            self.model_id,
            export=True,
            compile=False
        )

        ov_model.save_pretrained(self.model_path)
        logger.info("Basic model conversion completed")

    def _load_fp16_model(self) -> None:
        """Load existing FP16 OpenVINO model."""
        logger.info("Loading existing FP16 OpenVINO model...")
        try:
            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
                self.model_path,
                ov_config=self.config.to_ov_config(),
                compile=False
            )
            self._compile_model()
        except Exception as e:
            logger.error(f"Failed to load FP16 model: {e}")
            # Try basic loading
            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
                self.model_path,
                compile=False
            )
            self._compile_model()

    def _try_load_quantized_model(self) -> bool:
        """Try to load existing quantized model."""
        try:
            logger.info("Loading existing INT8 quantized model...")
            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
                self.quantized_model_path,
                ov_config=self.config.to_ov_config(),
                compile=False
            )
            self._compile_model()
            self.is_quantized = True
            logger.info("Quantized model loaded successfully")
            return True
        except Exception as e:
            logger.warning(f"Failed to load quantized model: {e}")
            return False

    def _try_quantize_existing_model(self) -> None:
        """Try to quantize the existing model after it's loaded."""
        if not QUANTIZATION_AVAILABLE:
            logger.info("Quantization libraries not available, skipping quantization")
            return

        if self.ov_model is None:
            logger.warning("No model loaded, cannot quantize")
            return

        # Check if model components are available
        if not hasattr(self.ov_model, 'encoder') or self.ov_model.encoder is None:
            logger.warning("Model encoder not available, skipping quantization")
            return

        if not hasattr(self.ov_model, 'decoder_with_past') or self.ov_model.decoder_with_past is None:
            logger.warning("Model decoder_with_past not available, skipping quantization")
            return

        try:
            logger.info("Attempting to quantize compiled model...")
            self._quantize_model_safe()
        except Exception as e:
            logger.warning(f"Quantization failed, continuing with FP16 model: {e}")

    def _quantize_model_safe(self) -> None:
        """Safely quantize the model with extensive error handling."""
        if not nncf:
            logger.info("Quantization libraries not available, skipping quantization")
            return

        if self.quantized_model_path.exists():
            logger.info("Quantized model already exists")
            return

        if self.ov_model is None:
            raise RuntimeError("No model to quantize")

        if not self.ov_model.decoder_with_past:
            raise RuntimeError("Model decoder_with_past not available")

        logger.info("Creating INT8 quantized model for Intel Arc B580...")

        try:
            # Collect calibration data with error handling
            calibration_data = self._collect_calibration_data_safe()
            if not calibration_data:
                logger.warning("No calibration data collected, skipping quantization")
                return

            # Quantize encoder
            if calibration_data.get('encoder'):
                logger.info("Quantizing encoder...")
                quantized_encoder = nncf.quantize(
                    self.ov_model.encoder.model,
                    nncf.Dataset(calibration_data['encoder']),
                    model_type=nncf.ModelType.TRANSFORMER,
                    subset_size=min(len(calibration_data['encoder']), 50)
                )
            else:
                logger.warning("No encoder calibration data, copying original encoder")
                quantized_encoder = self.ov_model.encoder.model

            # Quantize decoder
            if calibration_data.get('decoder'):
                logger.info("Quantizing decoder with past...")
                quantized_decoder = nncf.quantize(
                    self.ov_model.decoder_with_past.model,
                    nncf.Dataset(calibration_data['decoder']),
                    model_type=nncf.ModelType.TRANSFORMER,
                    subset_size=min(len(calibration_data['decoder']), 50)
                )
            else:
                logger.warning("No decoder calibration data, copying original decoder")
                quantized_decoder = self.ov_model.decoder_with_past.model

            # Save quantized models
            self.quantized_model_path.mkdir(parents=True, exist_ok=True)
            ov.save_model(quantized_encoder, self.quantized_model_path / "openvino_encoder_model.xml")  # type: ignore
            ov.save_model(quantized_decoder, self.quantized_model_path / "openvino_decoder_with_past_model.xml")  # type: ignore

            # Copy remaining files
            self._copy_model_files()

            # Clean up
            del quantized_encoder, quantized_decoder, calibration_data
            gc.collect()

            # Load quantized model
            if self._try_load_quantized_model():
                logger.info("Quantization completed successfully")

        except Exception as e:
            logger.error(f"Quantization failed: {e}")
            # Clean up partial quantization
            if self.quantized_model_path.exists():
                shutil.rmtree(self.quantized_model_path, ignore_errors=True)

    def _collect_calibration_data_safe(self, dataset_size: int = 20) -> Dict[str, CalibrationData]:
        """Safely collect calibration data with extensive error handling."""
        if self.ov_model is None or self.processor is None:
            return {}

        logger.info(f"Collecting calibration data ({dataset_size} samples)...")

        # Check model components
        if not self.ov_model.encoder:
            logger.warning("Encoder not available for calibration")
            return {}

        if not self.ov_model.decoder_with_past:
            logger.warning("Decoder with past not available for calibration")
            return {}

        # Check if requests are available
        if not hasattr(self.ov_model.encoder, 'request') or self.ov_model.encoder.request is None:
            logger.warning("Encoder request not available for calibration")
            return {}

        if not hasattr(self.ov_model.decoder_with_past, 'request') or self.ov_model.decoder_with_past.request is None:
            logger.warning("Decoder request not available for calibration")
            return {}

        # Setup data collection
        original_encoder_request = self.ov_model.encoder.request
        original_decoder_request = self.ov_model.decoder_with_past.request

        encoder_data: CalibrationData = []
        decoder_data: CalibrationData = []

        try:
            self.ov_model.encoder.request = InferRequestWrapper(original_encoder_request, encoder_data)
            self.ov_model.decoder_with_past.request = InferRequestWrapper(original_decoder_request, decoder_data)

            # Generate synthetic calibration data instead of loading dataset
            logger.info("Generating synthetic calibration data...")
            for i in range(dataset_size):
                try:
                    # Generate random audio similar to speech
                    duration = 2.0 + np.random.random() * 3.0  # 2-5 seconds
                    synthetic_audio = np.random.randn(int(SAMPLE_RATE * duration)).astype(np.float32) * 0.1

                    inputs: Any = self.processor(
                        synthetic_audio,
                        sampling_rate=SAMPLE_RATE,
                        return_tensors="pt"
                    )

                    # Run inference to collect calibration data
                    generated_ids = self.ov_model.generate(inputs.input_features, max_new_tokens=10)

                    if i % 5 == 0:
                        logger.debug(f"Generated calibration sample {i+1}/{dataset_size}")

                except Exception as e:
                    logger.warning(f"Failed to generate calibration sample {i}: {e}")
                    continue

        except Exception as e:
            logger.error(f"Error during calibration data collection: {e}")
        finally:
            # Restore original requests
            try:
                self.ov_model.encoder.request = original_encoder_request
                self.ov_model.decoder_with_past.request = original_decoder_request
            except Exception as e:
                logger.warning(f"Failed to restore original requests: {e}")

        result = {}
        if encoder_data:
            result['encoder'] = encoder_data
            logger.info(f"Collected {len(encoder_data)} encoder calibration samples")
        if decoder_data:
            result['decoder'] = decoder_data
            logger.info(f"Collected {len(decoder_data)} decoder calibration samples")

        return result

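    # How the calibration capture works: while the wrappers are installed above,
    # every self.ov_model.generate() call routes encoder and decoder-with-past
    # inference through InferRequestWrapper, which appends the observed input
    # tensors to encoder_data / decoder_data; those lists become the
    # nncf.Dataset consumed by _quantize_model_safe().
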
    def _copy_model_files(self) -> None:
        """Copy necessary model files for quantized model."""
        try:
            # Copy config and first-step decoder
            if (self.model_path / "config.json").exists():
                shutil.copy(self.model_path / "config.json", self.quantized_model_path / "config.json")

            decoder_xml = self.model_path / "openvino_decoder_model.xml"
            decoder_bin = self.model_path / "openvino_decoder_model.bin"

            if decoder_xml.exists():
                shutil.copy(decoder_xml, self.quantized_model_path / "openvino_decoder_model.xml")
            if decoder_bin.exists():
                shutil.copy(decoder_bin, self.quantized_model_path / "openvino_decoder_model.bin")

        except Exception as e:
            logger.warning(f"Failed to copy some model files: {e}")

    def _compile_model(self) -> None:
        """Compile model for Intel Arc B580."""
        if self.ov_model is None:
            raise RuntimeError("Model not loaded")

        logger.info("Compiling model for Intel Arc B580...")
        try:
            self.ov_model.to(self.config.device)
            self.ov_model.compile()

            # Warmup for optimal performance
            self._warmup_model()
            logger.info("Model compiled and warmed up successfully")
        except Exception as e:
            logger.warning(f"Failed to compile for {self.config.device}, attempting safe CPU fallback: {e}")
            # Fallback: reload/compile model with a CPU-only ov_config to avoid
            # passing GPU-specific properties to the CPU plugin, which can raise
            # NotFound/UnsupportedProperty exceptions.
            try:
                # Ensure device is CPU and use conservative CPU threading options
                cpu_cfg = OpenVINOConfig(
                    device='CPU',
                    cache_dir=self.config.cache_dir,
                    enable_quantization=self.config.enable_quantization,
                    throughput_streams=1,
                    max_threads=self.config.max_threads,
                )

                logger.info("Reloading model with CPU-only OpenVINO config for safe compilation")
                # Try to reload using the existing saved model path if possible
                try:
                    self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
                        self.model_path,
                        ov_config=cpu_cfg.to_ov_config(),
                        compile=False
                    )
                except Exception:
                    # If loading the saved model failed, try loading without ov_config
                    self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(self.model_path, compile=False)

                # Compile on CPU
                self.ov_model.to('CPU')
                try:
                    self.ov_model.compile()
                except Exception as compile_cpu_e:
                    logger.warning(f"CPU compile failed, retrying once: {compile_cpu_e}")
                    self.ov_model.compile()

                self._warmup_model()
                logger.info("Model compiled for CPU successfully")
            except Exception as cpu_e:
                logger.error(f"Failed to compile for CPU as well: {cpu_e}")
                raise

    def _warmup_model(self) -> None:
        """Warmup model for consistent GPU performance."""
        if self.ov_model is None or self.processor is None:
            return

        try:
            logger.info("Warming up model...")
            dummy_audio = np.random.randn(SAMPLE_RATE).astype(np.float32)  # 1 second
            dummy_features = self.processor(
                dummy_audio,
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt"
            ).input_features

            # Run warmup iterations
            for i in range(3):
                _ = self.ov_model.generate(dummy_features, max_new_tokens=10)
                if i == 0:
                    logger.debug("First warmup iteration completed")
        except Exception as e:
            logger.warning(f"Model warmup failed: {e}")

    def generate(self, input_features: torch.Tensor, language: str = "en") -> torch.Tensor:
        """Generate transcription from input features."""
        if self.ov_model is None:
            raise RuntimeError("Model not initialized")

        generation_config: dict[str, Any] = {
            "max_length": 448,
            "num_beams": 4,  # Use beam search for better results
            # "num_beams": 1,  # Greedy decoding for speed
            "no_repeat_ngram_size": 3,  # Prevent repetitive phrases
            "language": language,  # Explicitly set the transcription language
            "task": "transcribe",  # Ensure transcription, not translation
            "suppress_tokens": None,  # Disable default suppress_tokens to avoid conflicts
            "begin_suppress_tokens": None,  # Disable default begin_suppress_tokens
            "max_new_tokens": 128,
            "do_sample": False
        }
        try:
            return self.ov_model.generate(  # type: ignore
                input_features,
                **generation_config
            )
        except Exception as e:
            logger.error(f"Model generation failed: {e}")
            raise RuntimeError(f"Failed to generate transcription: {e}") from e

    def decode(self, token_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
        """Decode token IDs to text."""
        if self.processor is None:
            raise RuntimeError("Processor not initialized")

        return self.processor.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)


# Global model instance with deferred loading
_whisper_model: Optional[OpenVINOWhisperModel] = None
_audio_processors: Dict[str, "OptimizedAudioProcessor"] = {}
_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None


def _ensure_model_loaded(device: str = _device) -> OpenVINOWhisperModel:
    """Ensure the global model is loaded."""
    global _whisper_model
    if _whisper_model is None:
        setup_intel_arc_environment()
        logger.info(f"Loading OpenVINO Whisper model: {_model_id}")
        _whisper_model = OpenVINOWhisperModel(model_id=_model_id, config=_ov_config, device=device)
        logger.info("OpenVINO Whisper model loaded successfully")
    return _whisper_model


def extract_input_features(audio_array: AudioArray, sampling_rate: int) -> torch.Tensor:
    """Extract input features from audio array optimized for OpenVINO."""
    ov_model = _ensure_model_loaded()
    if ov_model.processor is None:
        raise RuntimeError("Processor not initialized")

    inputs = ov_model.processor(
        audio_array,
        sampling_rate=sampling_rate,
        return_tensors="pt",
    )
    return inputs.input_features


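# Note (assumption, not enforced here): WhisperProcessor pads or trims audio to
# Whisper's 30 s context and returns log-Mel features shaped roughly
# (1, 80, 3000) for most checkpoints, including distil-large-v2; newer
# checkpoints such as large-v3 use 128 mel bins, so treat the exact shape as
# checkpoint-dependent.

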
class OptimizedAudioProcessor:
    """Optimized audio processor for Intel Arc B580 with reduced latency."""

    def __init__(self, peer_name: str, send_chat_func: Callable[[str], Awaitable[None]]):
        self.peer_name = peer_name
        self.send_chat_func = send_chat_func
        self.sample_rate = SAMPLE_RATE

        # Optimized buffering parameters
        self.chunk_size = int(self.sample_rate * CHUNK_DURATION_MS / 1000)  # 100ms chunks
        self.buffer_size = self.chunk_size * 50  # 5 seconds max

        # Circular buffer for zero-copy operations
        self.audio_buffer = np.zeros(self.buffer_size, dtype=np.float32)
        self.write_ptr = 0
        self.read_ptr = 0

        # Voice Activity Detection
        self.vad_threshold = VAD_THRESHOLD
        self.silence_frames = 0
        self.max_silence_frames = MAX_SILENCE_FRAMES

        # Processing state
        self.current_phrase_audio = np.array([], dtype=np.float32)
        self.transcription_history: List[TranscriptionHistoryItem] = []
        self.last_activity_time = time.time()

        # Async processing
        self.processing_queue: asyncio.Queue[AudioQueueItem] = asyncio.Queue(maxsize=10)
        self.is_running = True

        # Start async processing task
        try:
            self.main_loop = asyncio.get_running_loop()
            asyncio.create_task(self._async_processing_loop())
            logger.info(f"Started async processing for {self.peer_name}")
        except RuntimeError:
            # Fallback to thread-based processing
            self.main_loop = None
            self.processor_thread = threading.Thread(target=self._thread_processing_loop, daemon=True)
            self.processor_thread.start()
            logger.warning(f"Using thread-based processing for {self.peer_name}")

        logger.info(f"OptimizedAudioProcessor initialized for {self.peer_name}")

    def add_audio_data(self, audio_data: AudioArray) -> None:
        """Add audio data with Voice Activity Detection and circular buffering."""
        if not self.is_running or len(audio_data) == 0:
            return

        # Voice Activity Detection
        energy = np.sqrt(np.mean(audio_data**2))
        has_speech = energy > self.vad_threshold

        if not has_speech:
            self.silence_frames += 1
            if self.silence_frames > self.max_silence_frames:
                # Clear current phrase on long silence
                if len(self.current_phrase_audio) > 0:
                    self._queue_final_transcription()
            return
        else:
            self.silence_frames = 0
            self.last_activity_time = time.time()

        # Add to circular buffer (zero-copy when possible)
        self._add_to_circular_buffer(audio_data)

        # Check if we should process
        if self._available_samples() >= self.chunk_size:
            self._queue_for_processing()

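    # Ring-buffer arithmetic used by the helpers below (illustrative numbers):
    # with buffer_size = 80000 (5 s at 16 kHz) and write_ptr = 79950, writing a
    # 100-sample chunk stores 50 samples at the tail, wraps 50 to the head, and
    # leaves write_ptr = (79950 + 100) % 80000 = 50.
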
    def _add_to_circular_buffer(self, audio_data: AudioArray) -> None:
        """Add data to circular buffer efficiently."""
        chunk_len = len(audio_data)

        if self.write_ptr + chunk_len <= self.buffer_size:
            # Simple case - no wraparound
            self.audio_buffer[self.write_ptr:self.write_ptr + chunk_len] = audio_data
        else:
            # Wraparound case
            first_part = self.buffer_size - self.write_ptr
            self.audio_buffer[self.write_ptr:] = audio_data[:first_part]
            self.audio_buffer[:chunk_len - first_part] = audio_data[first_part:]

        self.write_ptr = (self.write_ptr + chunk_len) % self.buffer_size

    def _available_samples(self) -> int:
        """Calculate available samples in circular buffer."""
        if self.write_ptr >= self.read_ptr:
            return self.write_ptr - self.read_ptr
        else:
            return self.buffer_size - self.read_ptr + self.write_ptr

    def _extract_chunk(self, size: int) -> AudioArray:
        """Extract chunk from circular buffer."""
        if self.read_ptr + size <= self.buffer_size:
            chunk = self.audio_buffer[self.read_ptr:self.read_ptr + size].copy()
        else:
            first_part = self.buffer_size - self.read_ptr
            chunk = np.concatenate([
                self.audio_buffer[self.read_ptr:],
                self.audio_buffer[:size - first_part]
            ])

        self.read_ptr = (self.read_ptr + size) % self.buffer_size
        return chunk.astype(np.float32)

    def _queue_for_processing(self) -> None:
        """Queue audio chunk for processing."""
        available = self._available_samples()
        if available < self.chunk_size:
            return

        # Extract chunk for processing
        chunk = self._extract_chunk(self.chunk_size)

        # Create queue item
        queue_item = AudioQueueItem(audio=chunk, timestamp=time.time())

        # Queue for processing
        if self.main_loop:
            try:
                self.processing_queue.put_nowait(queue_item)
            except asyncio.QueueFull:
                logger.warning(f"Processing queue full for {self.peer_name}, dropping chunk")
        else:
            # Thread-based fallback
            try:
                threading_queue = getattr(self, '_threading_queue', None)
                if threading_queue:
                    threading_queue.put_nowait(queue_item)
            except Exception as e:
                logger.warning(f"Threading queue issue for {self.peer_name}: {e}")

    def _queue_final_transcription(self) -> None:
        """Queue final transcription of current phrase."""
        if len(self.current_phrase_audio) > self.sample_rate * 0.5:  # At least 0.5 seconds
            if self.main_loop:
                asyncio.create_task(self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=True))

        self.current_phrase_audio = np.array([], dtype=np.float32)

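    # Transcription cadence implemented by the two loops below: an interim
    # ("streaming") pass roughly once per accumulated second of voiced audio,
    # and a final pass either after ~2 s without new chunks or when the VAD in
    # add_audio_data() sees MAX_SILENCE_FRAMES of silence.
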
    async def _async_processing_loop(self) -> None:
        """Async processing loop for audio chunks."""
        logger.info(f"Started async processing loop for {self.peer_name}")

        while self.is_running:
            try:
                # Get audio chunk
                audio_item = await asyncio.wait_for(self.processing_queue.get(), timeout=1.0)

                # Add to current phrase
                self.current_phrase_audio = np.concatenate([self.current_phrase_audio, audio_item.audio])

                # Check if we should transcribe
                phrase_duration = len(self.current_phrase_audio) / self.sample_rate

                if phrase_duration >= 1.0:  # Transcribe every 1 second
                    await self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=False)

            except asyncio.TimeoutError:
                # Check for final transcription on timeout
                if len(self.current_phrase_audio) > 0 and time.time() - self.last_activity_time > 2.0:
                    await self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=True)
                    self.current_phrase_audio = np.array([], dtype=np.float32)
            except Exception as e:
                logger.error(f"Error in async processing loop for {self.peer_name}: {e}")

        logger.info(f"Async processing loop ended for {self.peer_name}")

    def _thread_processing_loop(self) -> None:
        """Thread-based processing loop fallback."""
        self._threading_queue: Queue[AudioQueueItem] = Queue(maxsize=10)
        logger.info(f"Started thread processing loop for {self.peer_name}")

        while self.is_running:
            try:
                audio_item = self._threading_queue.get(timeout=1.0)

                # Add to current phrase
                self.current_phrase_audio = np.concatenate([self.current_phrase_audio, audio_item.audio])

                # Check if we should transcribe
                phrase_duration = len(self.current_phrase_audio) / self.sample_rate

                if phrase_duration >= 1.0:
                    if self.main_loop:
                        asyncio.run_coroutine_threadsafe(
                            self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=False),
                            self.main_loop
                        )

            except Empty:
                # Check for final transcription
                if len(self.current_phrase_audio) > 0 and time.time() - self.last_activity_time > 2.0:
                    if self.main_loop:
                        asyncio.run_coroutine_threadsafe(
                            self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=True),
                            self.main_loop
                        )
                    self.current_phrase_audio = np.array([], dtype=np.float32)
            except Exception as e:
                logger.error(f"Error in thread processing loop for {self.peer_name}: {e}")

    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool, language: str = "en") -> None:
        """
        Transcribe raw numpy audio data using OpenVINO Whisper.

        Parameters:
        - audio_array: normalized 1D numpy array containing mono PCM data at 16 kHz.
        - is_final: whether this is a final transcription (True) or interim (False)
        - language: language code for transcription (default 'en' for English)
        """
        if audio_array.ndim != 1:
            raise ValueError("Expected mono audio as a 1D numpy array.")

        transcription_start = time.time()
        transcription_type = "final" if is_final else "streaming"

        try:
            audio_duration = len(audio_array) / self.sample_rate

            # Skip very short audio
            if audio_duration < 0.3:
                logger.debug(f"Skipping {transcription_type} transcription: too short ({audio_duration:.2f}s)")
                return

            # Audio quality check
            audio_rms = np.sqrt(np.mean(audio_array**2))
            if audio_rms < 0.001:
                logger.debug(f"Skipping {transcription_type} transcription: too quiet (RMS: {audio_rms:.6f})")
                return

            logger.info(f"🎬 OpenVINO transcription ({transcription_type}) started: {audio_duration:.2f}s, RMS: {audio_rms:.4f}")

            # Extract features for OpenVINO
            input_features = extract_input_features(audio_array, self.sample_rate)
            # logger.info(f"Features extracted for OpenVINO: {input_features.shape}")

            # GPU inference with OpenVINO
            ov_model = _ensure_model_loaded()
            generated_ids = ov_model.generate(input_features, language=language)

            # Decode tokens into text
            transcription = ov_model.decode(generated_ids)[0]
            text = transcription.strip() if transcription else ""
            logger.info(f"Transcription text: {text}")
            transcription_time = time.time() - transcription_start

            if text and len(text.split()) >= 2:
                # Create message with timing
                status_marker = "⚡" if is_final else "🎤"
                type_marker = "" if is_final else " [streaming]"
                timing_info = f" (🚀 {transcription_time:.2f}s)"

                message = f"{status_marker} {self.peer_name}{type_marker}: {text}{timing_info}"

                # Avoid duplicates
                if not self._is_duplicate(text):
                    await self.send_chat_func(message)

                    # Update history
                    self.transcription_history.append(TranscriptionHistoryItem(
                        message=message,
                        timestamp=time.time(),
                        is_final=is_final
                    ))

                    # Limit history
                    if len(self.transcription_history) > 10:
                        self.transcription_history.pop(0)

                    logger.info(f"✅ OpenVINO transcription ({transcription_type}): '{text}' ({transcription_time:.3f}s)")
                else:
                    logger.debug(f"Skipping duplicate {transcription_type} transcription: '{text}'")
            else:
                logger.debug(f"Empty or too short transcription result: '{text}'")

        except Exception as e:
            logger.error(f"Error in OpenVINO {transcription_type} transcription: {e}", exc_info=True)

    def _is_duplicate(self, text: str) -> bool:
        """Check if transcription is duplicate of recent ones."""
        recent_texts = [h.message.split(': ', 1)[-1].split(' (🚀')[0]
                        for h in self.transcription_history[-3:]]
        return text in recent_texts

    def shutdown(self) -> None:
        """Shutdown the audio processor."""
        logger.info(f"Shutting down OptimizedAudioProcessor for {self.peer_name}...")
        self.is_running = False

        # Final transcription if needed
        if len(self.current_phrase_audio) > 0:
            if self.main_loop:
                asyncio.create_task(self._transcribe_and_send(self.current_phrase_audio.copy(), is_final=True))

        # Cleanup thread if exists
        if hasattr(self, 'processor_thread') and self.processor_thread.is_alive():
            self.processor_thread.join(timeout=2.0)

        logger.info(f"OptimizedAudioProcessor shutdown complete for {self.peer_name}")


def normalize_audio(audio_data: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
    """Normalize audio to have maximum amplitude of 1.0."""
    max_amplitude = np.max(np.abs(audio_data))
    if max_amplitude > 0:
        audio_data = audio_data / max_amplitude
    return audio_data


class MediaClock:
    """Simple monotonic clock for media tracks."""

    def __init__(self) -> None:
        self.t0 = perf_counter()

    def now(self) -> float:
        return perf_counter() - self.t0


class WaveformVideoTrack(MediaStreamTrack):
    """Video track that renders a live waveform of the incoming audio.

    The track reads the most-active `OptimizedAudioProcessor` in
    `_audio_processors` and renders the last ~2s of its `current_phrase_audio`.
    If no audio is available, the track will display a "No audio" message.
    """

    kind = "video"

    def __init__(self, session_name: str, width: int = 640, height: int = 240, fps: int = 15) -> None:
        super().__init__()
        self.session_name = session_name
        self.width = int(width)
        self.height = int(height)
        self.fps = int(fps)
        self.clock = MediaClock()
        self._next_frame_index = 0

    async def next_timestamp(self) -> tuple[int, float]:
        pts = int(self._next_frame_index * (1 / self.fps) * 90000)
        time_base = 1 / 90000
        return pts, time_base

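    # The 90000 above is the standard 90 kHz RTP video clock: at fps = 15 each
    # frame advances pts by 90000 / 15 = 6000 ticks, and recv() stamps frames
    # with the matching 1/90000 time base.
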
    async def recv(self) -> VideoFrame:
        pts, time_base = await self.next_timestamp()

        # Schedule frame according to clock
        target_t = self._next_frame_index / self.fps
        now = self.clock.now()
        if target_t > now:
            await asyncio.sleep(target_t - now)

        self._next_frame_index += 1

        frame_array: npt.NDArray[np.uint8] = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        # Select the most active processor (highest RMS) and draw its waveform
        best_proc = None
        best_rms = 0.0
        try:
            for pname, proc in _audio_processors.items():
                try:
                    arr = getattr(proc, 'current_phrase_audio', None)
                    if arr is None or len(arr) == 0:
                        continue
                    rms = float(np.sqrt(np.mean(arr**2)))
                    if rms > best_rms:
                        best_rms = rms
                        best_proc = (pname, arr.copy())
                except Exception:
                    continue
        except Exception:
            best_proc = None

        if best_proc is not None:
            pname, arr = best_proc

            # Use up to 2 seconds of audio for the waveform
            window_samples = min(len(arr), SAMPLE_RATE * 2)
            if window_samples <= 0:
                arr_segment = np.zeros(1, dtype=np.float32)
            else:
                arr_segment = arr[-window_samples:]

            # Normalize segment to -1..1 safely
            maxv = float(np.max(np.abs(arr_segment))) if arr_segment.size > 0 else 0.0
            if maxv > 0:
                norm = arr_segment / maxv
            else:
                norm = np.zeros_like(arr_segment)

            # Map audio samples to pixels across the width
            if norm.size < self.width:
                padded = np.zeros(self.width, dtype=np.float32)
                if norm.size > 0:
                    padded[-norm.size:] = norm
                norm = padded
            else:
                block = int(np.ceil(norm.size / self.width))
                norm = np.array(
                    [np.mean(norm[i * block : min((i + 1) * block, norm.size)]) for i in range(self.width)],
                    dtype=np.float32,
                )

            # Create polyline points, avoid NaN
            points: list[tuple[int, int]] = []
            for x in range(self.width):
                v = float(norm[x]) if x < norm.size and not np.isnan(norm[x]) else 0.0
                y = int((1.0 - ((v + 1.0) / 2.0)) * (self.height - 1))
                points.append((x, max(0, min(self.height - 1, y))))

            if len(points) > 1:
                pts_np = np.array(points, dtype=np.int32)
                cv2.polylines(frame_array, [pts_np], isClosed=False, color=(0, 200, 80), thickness=2)

            cv2.putText(frame_array, f"Waveform: {pname}", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        else:
            cv2.putText(frame_array, "No audio", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (200, 200, 200), 1)

        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
        frame.pts = pts
        frame.time_base = fractions.Fraction(1 / 90000).limit_denominator(1000000)
        return frame


async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
    """Handle incoming audio tracks from WebRTC peers."""
    global _audio_processors, _send_chat_func

    if track.kind != "audio":
        logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
        return

    # Create audio processor
    if peer.peer_name not in _audio_processors:
        if _send_chat_func is None:
            logger.error(f"Cannot create processor for {peer.peer_name}: no send_chat_func")
            return

        logger.info(f"Creating OptimizedAudioProcessor for {peer.peer_name}")
        _audio_processors[peer.peer_name] = OptimizedAudioProcessor(
            peer_name=peer.peer_name,
            send_chat_func=_send_chat_func
        )

    audio_processor = _audio_processors[peer.peer_name]
    logger.info(f"Starting OpenVINO audio processing for {peer.peer_name}")

    try:
        frame_count = 0
        while True:
            try:
                frame = await track.recv()
                frame_count += 1

                if frame_count % 100 == 0:
                    logger.debug(f"Received {frame_count} frames from {peer.peer_name}")

            except MediaStreamError as e:
                logger.info(f"Audio stream ended for {peer.peer_name}: {e}")
                break
            except Exception as e:
                logger.error(f"Error receiving frame from {peer.peer_name}: {e}")
                break

            if isinstance(frame, AudioFrame):
                try:
                    # Convert frame to numpy array
                    audio_data = frame.to_ndarray()

                    # Handle audio format conversion
                    audio_data = _process_audio_frame(audio_data, frame)

                    # Resample if needed
                    if frame.sample_rate != SAMPLE_RATE:
                        audio_data = _resample_audio(audio_data, frame.sample_rate, SAMPLE_RATE)

                    # Convert to float32
                    audio_data_float32 = audio_data.astype(np.float32)

                    # Process with optimized processor
                    audio_processor.add_audio_data(audio_data_float32)

                except Exception as e:
                    logger.error(f"Error processing audio frame for {peer.peer_name}: {e}")
                    continue

    except Exception as e:
        logger.error(f"Unexpected error in audio processing for {peer.peer_name}: {e}", exc_info=True)
    finally:
        cleanup_peer_processor(peer.peer_name)


def _process_audio_frame(audio_data: npt.NDArray[Any], frame: AudioFrame) -> npt.NDArray[np.float32]:
    """Process audio frame format conversion."""
    # Handle stereo to mono conversion
    if audio_data.ndim == 2:
        if audio_data.shape[0] == 1:
            audio_data = audio_data.squeeze(0)
        else:
            # Average across the channel axis (the smaller dimension), keeping the
            # time axis intact, for either (channels, samples) or (samples, channels) layout.
            channel_axis = 0 if audio_data.shape[0] < audio_data.shape[1] else 1
            audio_data = np.mean(audio_data, axis=channel_axis)

    # Normalize based on data type
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0
    elif audio_data.dtype == np.int32:
        audio_data = audio_data.astype(np.float32) / 2147483648.0

    return audio_data.astype(np.float32)


def _resample_audio(audio_data: npt.NDArray[np.float32], orig_sr: int, target_sr: int) -> npt.NDArray[np.float32]:
    """Resample audio efficiently."""
    try:
        # Handle stereo audio by converting to mono if necessary
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Use high-quality resampling
        resampled = librosa.resample(
            audio_data.astype(np.float64),
            orig_sr=orig_sr,
            target_sr=target_sr,
            res_type='kaiser_fast'  # Good balance of quality and speed
        )
        return resampled.astype(np.float32)
    except Exception as e:
        logger.error(f"Resampling failed: {e}")
        raise ValueError(f"Failed to resample audio from {orig_sr} Hz to {target_sr} Hz: {e}")


# Public API functions
def agent_info() -> Dict[str, str]:
    return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "true"}


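# Expected wiring (a sketch; the hosting lobby runtime is assumed to do roughly
# this, it is not defined in this module):
#
#   bind_send_chat_function(send_chat)          # route transcriptions to lobby chat
#   handler = get_track_handler()               # await handler(peer, track) per incoming track
#   tracks = create_agent_tracks("session-1")   # waveform video + silent audio to publish

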
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
    """Create agent tracks. Provides a synthetic video waveform track and a silent audio track for compatibility."""

    class SilentAudioTrack(MediaStreamTrack):
        kind = "audio"

        def __init__(self, sample_rate: int = SAMPLE_RATE, channels: int = 1, fps: int = 50):
            super().__init__()
            self.sample_rate = sample_rate
            self.channels = channels
            self.fps = fps
            self.samples_per_frame = int(self.sample_rate / self.fps)
            self._timestamp = 0

        async def recv(self) -> AudioFrame:
            # Generate silent audio as int16 (required by aiortc)
            data = np.zeros((self.channels, self.samples_per_frame), dtype=np.int16)
            frame = AudioFrame.from_ndarray(data, layout="mono" if self.channels == 1 else "stereo")
            frame.sample_rate = self.sample_rate
            frame.pts = self._timestamp
            frame.time_base = fractions.Fraction(1, self.sample_rate)
            self._timestamp += self.samples_per_frame
            await asyncio.sleep(1 / self.fps)
            return frame

    try:
        video_track = WaveformVideoTrack(session_name=session_name, width=640, height=240, fps=15)
        audio_track = SilentAudioTrack()
        return {"video": video_track, "audio": audio_track}
    except Exception as e:
        logger.error(f"Failed to create agent tracks: {e}")
        return {}


async def handle_chat_message(
    chat_message: ChatMessageModel,
    send_message_func: Callable[[str], Awaitable[None]]
) -> Optional[str]:
    """Handle incoming chat messages."""
    return None


async def on_track_received(peer: Peer, track: MediaStreamTrack) -> None:
    """Callback when a new track is received from a peer."""
    await handle_track_received(peer, track)


def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]:
    """Return the track handler function."""
    return on_track_received


def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]) -> None:
    """Bind the send chat function."""
    global _send_chat_func, _audio_processors

    logger.info("Binding send chat function to OpenVINO whisper agent")
    _send_chat_func = send_chat_func

    # Update existing processors
    for peer_name, processor in _audio_processors.items():
        processor.send_chat_func = send_chat_func
        logger.debug(f"Updated processor for {peer_name} with new send chat function")


def cleanup_peer_processor(peer_name: str) -> None:
    """Clean up processor for disconnected peer."""
    global _audio_processors

    if peer_name in _audio_processors:
        logger.info(f"Cleaning up processor for {peer_name}")
        processor = _audio_processors[peer_name]
        processor.shutdown()
        del _audio_processors[peer_name]
        logger.info(f"Processor cleanup complete for {peer_name}")


def get_active_processors() -> Dict[str, OptimizedAudioProcessor]:
    """Get active processors for debugging."""
    return _audio_processors.copy()


def get_model_info() -> Dict[str, Any]:
    """Get information about the loaded model."""
    ov_model = _ensure_model_loaded()
    return {
        "model_id": _model_id,
        "device": _ov_config.device,
        "quantization_enabled": _ov_config.enable_quantization,
        "is_quantized": ov_model.is_quantized,
        "sample_rate": SAMPLE_RATE,
        "chunk_duration_ms": CHUNK_DURATION_MS
    }
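

# Minimal manual smoke test (a sketch, not part of the agent's runtime path):
# running this module directly prints the agent metadata and then the model
# info, which triggers a one-off model load on the configured device.
if __name__ == "__main__":
    print(agent_info())
    print(get_model_info())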