James Ketr 2025-09-16 13:57:58 -07:00
parent e96bd887ab
commit 90c3c6e19b
4 changed files with 451 additions and 344 deletions

View File

@@ -4,7 +4,6 @@
 !client
 !shared
 **/node_modules
-**/build
 **/dist
 **/__pycache__
 **/.venv

View File

@@ -1,7 +1,5 @@
 """
-Lobby management for the AI Voice Bot server.
-
-This module handles lobby lifecycle, participants, and chat functionality.
+Lobby management for the AI Voice Bot server. Handles lobby lifecycle, participants, and chat functionality.
 Extracted from main.py to improve maintainability and separation of concerns.
 """
@@ -9,21 +7,30 @@ from __future__ import annotations
 import secrets
 import time
 import threading
-from typing import Dict, List, Optional, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING, Callable
+import os

 # Import shared models
 try:
     # Try relative import first (when running as part of the package)
-    from ...shared.models import ChatMessageModel, ParticipantModel
+    from ...shared.models import (
+        ChatMessageModel,
+        ParticipantModel,
+        WebSocketMessageModel,
+    )
 except ImportError:
     try:
         # Try absolute import (when running directly)
         import sys
         import os
         sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-        from shared.models import ChatMessageModel, ParticipantModel
-    except ImportError:
+        from shared.models import (
+            ChatMessageModel,
+            ParticipantModel,
+            WebSocketMessageModel,
+        )
+    except ImportError as e:
         raise ImportError(
             f"Failed to import shared models: {e}. Ensure shared/models.py is accessible and PYTHONPATH is correctly set."
         )
@@ -32,25 +39,21 @@ from shared.logger import logger
 # Use try/except for importing events to handle both relative and absolute imports
 try:
-    from ..models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
+    from ..models.events import (
+        event_bus,
+        ChatMessageSent,
+        SessionDisconnected,
+        SessionLeftLobby,
+        Event,
+    )
 except ImportError:
-    try:
-        from models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
-    except ImportError:
-        # Create dummy event system for standalone testing
-        class DummyEventBus:
-            async def publish(self, event):
-                pass
-        event_bus = DummyEventBus()
-
-        class ChatMessageSent:
-            pass
-
-        class SessionDisconnected:
-            pass
-
-        class SessionLeftLobby:
-            pass
+    from models.events import (
+        event_bus,
+        ChatMessageSent,
+        SessionDisconnected,
+        SessionLeftLobby,
+        Event,
+    )

 if TYPE_CHECKING:
     from .session_manager import Session
@@ -62,34 +65,6 @@ class LobbyConfig:
 class Lobby:
-    async def broadcast_json(self, message: dict) -> None:
-        """Broadcast an arbitrary JSON message to all connected sessions in the lobby"""
-        failed_sessions: List[Session] = []
-        for peer in self.sessions.values():
-            if peer.ws:
-                try:
-                    await peer.ws.send_json(message)
-                except Exception as e:
-                    logger.warning(
-                        f"Failed to send broadcast_json message to {peer.getName()}: {e}"
-                    )
-                    failed_sessions.append(peer)
-        for failed_session in failed_sessions:
-            failed_session.ws = None
-
-    async def broadcast_peer_state_update(self, update: dict) -> None:
-        """Broadcast a peer state update to all connected sessions in the lobby"""
-        failed_sessions: List[Session] = []
-        for peer in self.sessions.values():
-            if peer.ws:
-                try:
-                    await peer.ws.send_json(update)
-                except Exception as e:
-                    logger.warning(
-                        f"Failed to send peer state update to {peer.getName()}: {e}"
-                    )
-                    failed_sessions.append(peer)
-        for failed_session in failed_sessions:
-            failed_session.ws = None
-
     """Individual lobby representing a chat/voice room"""

     def __init__(self, name: str, id: Optional[str] = None, private: bool = False):
@@ -104,6 +79,36 @@ class Lobby:
     def getName(self) -> str:
         return f"{self.short}:{self.name}"

+    async def broadcast_json(self, message: WebSocketMessageModel) -> None:
+        """Broadcast an arbitrary JSON message to all connected sessions in the lobby"""
+        failed_sessions: List[Session] = []
+        for peer in self.sessions.values():
+            if peer.ws:
+                try:
+                    await peer.ws.send_json(message.model_dump())
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to send broadcast_json message to {peer.getName()}: {e}"
+                    )
+                    failed_sessions.append(peer)
+        for failed_session in failed_sessions:
+            failed_session.ws = None
+
+    async def broadcast_peer_state_update(self, update: WebSocketMessageModel) -> None:
+        """Broadcast a peer state update to all connected sessions in the lobby"""
+        failed_sessions: List[Session] = []
+        for peer in self.sessions.values():
+            if peer.ws:
+                try:
+                    await peer.ws.send_json(update.model_dump())
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to send peer state update to {peer.getName()}: {e}"
+                    )
+                    failed_sessions.append(peer)
+        for failed_session in failed_sessions:
+            failed_session.ws = None
+
     async def update_state(self, requesting_session: Optional[Session] = None):
         """Update lobby state and notify participants"""
         with self.lock:
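
With the broadcast methods now typed, callers build a WebSocketMessageModel instead of a raw dict. A minimal usage sketch, assuming WebSocketMessageModel is a Pydantic model with type and data fields (as the handler changes later in this commit suggest):

# --- Usage sketch (illustrative only, not part of this commit) ---
update = WebSocketMessageModel(
    type="peer_state_update",
    data={"peer_id": "abc123", "muted": True, "video_on": False},
)
await lobby.broadcast_peer_state_update(update)  # serialized via model_dump() inside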
@@ -344,7 +349,7 @@ class LobbyManager:
             # Event system not available, skip subscriptions
             pass

-    async def handle(self, event):
+    async def handle(self, event: Event) -> None:
         """Handle events from the event bus"""
         if isinstance(event, SessionDisconnected):
@@ -352,7 +357,7 @@ class LobbyManager:
         elif isinstance(event, SessionLeftLobby):
             await self._handle_session_left_lobby(event)

-    async def _handle_session_disconnected(self, event):
+    async def _handle_session_disconnected(self, event: SessionDisconnected) -> None:
         """Handle session disconnection by removing from all lobbies"""
         session_id = event.session_id
@@ -372,7 +377,7 @@ class LobbyManager:
             if lobby.is_empty() and not lobby.private:
                 await self._cleanup_empty_lobby(lobby)

-    async def _handle_session_left_lobby(self, event):
+    async def _handle_session_left_lobby(self, event: SessionLeftLobby) -> None:
         """Handle explicit session leave"""
         # This is already handled by the session's leave_lobby method
         # but we could add additional cleanup logic here if needed
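
The typed handle() dispatches on concrete event classes with isinstance. A self-contained sketch of that pattern, using stand-in dataclasses rather than the project's actual models.events types:

# --- Dispatch pattern sketch (stand-in types, illustrative only) ---
from dataclasses import dataclass

@dataclass
class SessionDisconnected:
    session_id: str

@dataclass
class SessionLeftLobby:
    session_id: str

async def handle(event: object) -> None:
    # Dispatch on the concrete event type, as LobbyManager.handle does.
    if isinstance(event, SessionDisconnected):
        print(f"removing {event.session_id} from all lobbies")
    elif isinstance(event, SessionLeftLobby):
        print(f"{event.session_id} left a lobby")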
@@ -447,8 +452,8 @@ class LobbyManager:
         return removed_count

-    def set_name_protection_checker(self, checker_func):
+    def set_name_protection_checker(self, checker_func: Callable[[str], bool]) -> None:
         """Inject name protection checker from AuthManager"""
         # This allows us to inject the name protection logic without tight coupling
         for lobby in self.lobbies.values():
-            lobby._is_name_protected = checker_func
+            lobby._is_name_protected = checker_func  # type: ignore
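
A plausible wiring for the injected checker, assuming an AuthManager exposing is_name_protected(name: str) -> bool (that method is called by SetNameHandler below; the wiring site itself is an assumption):

# --- Hypothetical startup wiring (illustrative only) ---
lobby_manager.set_name_protection_checker(auth_manager.is_name_protected)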

View File

@@ -10,7 +10,12 @@ from typing import Dict, Any, TYPE_CHECKING
 from fastapi import WebSocket

 from shared.logger import logger
-from shared.models import ChatMessageModel
+from shared.models import (
+    ChatMessageModel,
+    WebSocketMessageModel,
+    WebSocketErrorModel,
+    UpdateNameModel,
+)
 from .webrtc_signaling import WebRTCSignalingHandlers
 from core.error_handling import (
     error_handler,
@@ -45,11 +50,11 @@ class PeerStateUpdateHandler(MessageHandler):
     async def handle(
         self,
-        session: Any,
-        lobby: Any,
-        data: dict,
-        websocket: Any,
-        managers: dict,
+        session: "Session",
+        lobby: "Lobby",
+        data: Dict[str, Any],
+        websocket: WebSocket,
+        managers: Dict[str, Any],
     ) -> None:
         # Only allow a user to update their own state
         if not lobby or not session:
@@ -59,14 +64,14 @@ class PeerStateUpdateHandler(MessageHandler):
             # Ignore attempts to update other users' state
             # Optionally log or send error to client
             return
-        update = {
-            "type": "peer_state_update",
-            "data": {
-                "peer_id": peer_id,
-                "muted": data.get("muted"),
-                "video_on": data.get("video_on"),
-            },
-        }
+        update = WebSocketMessageModel(
+            type="peer_state_update",
+            data={
+                "peer_id": peer_id,
+                "muted": data.get("muted"),
+                "video_on": data.get("video_on"),
+            },  # type: ignore
+        )
         await lobby.broadcast_peer_state_update(update)
@@ -86,10 +91,12 @@ class SetNameHandler(MessageHandler):
         if not data:
             logger.error(f"{session.getName()} - set_name missing data")
-            await websocket.send_json({
-                "type": "error",
-                "data": {"error": "set_name missing data"},
-            })
+            await websocket.send_json(
+                WebSocketMessageModel(
+                    type="error",
+                    data=WebSocketErrorModel(error="set_name missing data"),
+                ).model_dump()
+            )
             return

         name = data.get("name")
@@ -99,10 +106,11 @@ class SetNameHandler(MessageHandler):
         if not name:
             logger.error(f"{session.getName()} - Name required")
-            await websocket.send_json({
-                "type": "error",
-                "data": {"error": "Name required"}
-            })
+            await websocket.send_json(
+                WebSocketMessageModel(
+                    type="error", data=WebSocketErrorModel(error="Name required")
+                ).model_dump()
+            )
             return

         # Check if name is unique
@@ -114,13 +122,14 @@ class SetNameHandler(MessageHandler):
             session.setName(name)
             logger.info(f"{session.getName()}: -> update('name', {name})")
-            await websocket.send_json({
-                "type": "update_name",
-                "data": {
-                    "name": name,
-                    "protected": auth_manager.is_name_protected(name),
-                },
-            })
+            await websocket.send_json(
+                WebSocketMessageModel(
+                    type="update_name",
+                    data=UpdateNameModel(
+                        name=name, protected=auth_manager.is_name_protected(name)
+                    ),
+                ).model_dump()
+            )

             # Update lobby state
             await lobby.update_state()
@@ -131,10 +140,11 @@ class SetNameHandler(MessageHandler):
             if not allowed:
                 logger.warning(f"{session.getName()} - {reason}")
-                await websocket.send_json({
-                    "type": "error",
-                    "data": {"error": reason}
-                })
+                await websocket.send_json(
+                    WebSocketMessageModel(
+                        type="error", data=WebSocketErrorModel(error=reason)
+                    ).model_dump()
+                )
                 return

             # Takeover allowed - handle displacement
@@ -179,13 +189,14 @@ class SetNameHandler(MessageHandler):
             session.setName(name)
             logger.info(f"{session.getName()}: -> update('name', {name}) (takeover)")
-            await websocket.send_json({
-                "type": "update_name",
-                "data": {
-                    "name": name,
-                    "protected": auth_manager.is_name_protected(name),
-                },
-            })
+            await websocket.send_json(
+                WebSocketMessageModel(
+                    type="update_name",
+                    data=UpdateNameModel(
+                        name=name, protected=auth_manager.is_name_protected(name)
+                    ),
+                ).model_dump()
+            )

             # Update lobby state
             await lobby.update_state()
@@ -460,15 +471,15 @@ class MessageRouter:
     ):
         """Route a message to the appropriate handler with enhanced error handling"""
         if message_type not in self._handlers:
-            await error_handler.handle_error(
+            await error_handler.handle_error(  # type: ignore
                 ValidationError(f"Unknown message type: {message_type}"),
                 context={
                     "message_type": message_type,
                     "session_id": session.id if session else "unknown",
-                    "data_keys": list(data.keys()) if data else []
+                    "data_keys": list(data.keys()) if data else [],
                 },
                 websocket=websocket,
-                session_id=session.id if session else None
+                session_id=session.id if session else None,
             )
             return
@@ -480,48 +491,50 @@ class MessageRouter:
         except WebSocketError as e:
             # WebSocket specific errors - attempt recovery
-            await error_handler.handle_error(
+            await error_handler.handle_error(  # type: ignore
                 e,
                 context={
                     "message_type": message_type,
                     "session_id": session.id if session else "unknown",
-                    "handler": type(self._handlers[message_type]).__name__
+                    "handler": type(self._handlers[message_type]).__name__,
                 },
                 websocket=websocket,
                 session_id=session.id if session else None,
-                recovery_action=lambda: self._websocket_recovery(websocket, session)
+                recovery_action=lambda: self._websocket_recovery(websocket, session),
             )
         except ValidationError as e:
             # Validation errors - usually client-side issues
-            await error_handler.handle_error(
+            await error_handler.handle_error(  # type: ignore
                 e,
                 context={
                     "message_type": message_type,
                     "session_id": session.id if session else "unknown",
-                    "data": str(data)[:500]  # Truncate large data
+                    "data": str(data)[:500],  # Truncate large data
                 },
                 websocket=websocket,
-                session_id=session.id if session else None
+                session_id=session.id if session else None,
             )
         except Exception as e:
             # Unexpected errors - enhanced logging and fallback
-            await error_handler.handle_error(
+            await error_handler.handle_error(  # type: ignore
                 WebSocketError(
                     f"Unexpected error in {message_type} handler: {e}",
-                    severity=ErrorSeverity.HIGH
+                    severity=ErrorSeverity.HIGH,
                 ),
                 context={
                     "message_type": message_type,
                     "session_id": session.id if session else "unknown",
                     "handler": type(self._handlers[message_type]).__name__,
                     "exception_type": type(e).__name__,
-                    "traceback": str(e)
+                    "traceback": str(e),
                 },
                 websocket=websocket,
                 session_id=session.id if session else None,
-                recovery_action=lambda: self._generic_recovery(message_type, session, lobby)
+                recovery_action=lambda: self._generic_recovery(
+                    message_type, session, lobby
+                ),
             )

     async def _websocket_recovery(self, websocket: WebSocket, session: "Session"):
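
Each call site passes recovery_action as a zero-argument lambda returning a coroutine, which the error handler can await after reporting. A minimal sketch of that contract; the real core.error_handling.handle_error signature is inferred from the call sites above, not confirmed:

# --- Inferred recovery-callback contract (illustrative only) ---
from typing import Any, Awaitable, Callable, Dict, Optional

async def handle_error(
    error: Exception,
    context: Dict[str, Any],
    websocket: Any = None,
    session_id: Optional[str] = None,
    recovery_action: Optional[Callable[[], Awaitable[None]]] = None,
) -> None:
    print(f"handled: {error!r} with context {context}")
    if recovery_action is not None:
        await recovery_action()  # e.g. MessageRouter._websocket_recovery(...)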

View File

@@ -65,43 +65,6 @@ _device = "GPU.1"  # Default to Intel Arc B580 GPU
 _generate_global_lock = threading.Lock()

-def _blocking_generate_decode(audio_array: AudioArray, sample_rate: int, generation_config: GenerationConfig | None = None) -> str:
-    """Blocking helper to run processor -> model.generate -> decode while
-    holding a global lock to serialize OpenVINO access.
-    """
-    try:
-        with _generate_global_lock:
-            ov_model = _ensure_model_loaded()
-            if ov_model.processor is None:
-                raise RuntimeError("Processor not initialized for OpenVINO model")
-
-            # Extract features
-            inputs = ov_model.processor(audio_array, sampling_rate=sample_rate, return_tensors="pt")
-            input_features = inputs.input_features
-
-            # Use a basic generation config if none provided
-            gen_cfg = generation_config or GenerationConfig(max_new_tokens=128)
-
-            gen_out = ov_model.ov_model.generate(input_features, generation_config=gen_cfg)  # type: ignore
-
-            # Prefer .sequences if available
-            if hasattr(gen_out, "sequences"):
-                ids = gen_out.sequences
-            else:
-                ids = gen_out
-
-            # Decode
-            try:
-                transcription = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
-            except Exception:
-                transcription = ""
-
-            return transcription
-    except Exception as e:
-        logger.error(f"blocking_generate_decode failed: {e}", exc_info=True)
-        return ""
-
 def get_available_devices() -> list[dict[str, Any]]:
     """List available OpenVINO devices with their properties."""
     try:
@@ -230,7 +193,7 @@ class OpenVINOConfig(BaseModel):
             cfg.update(
                 {
                     "CPU_THROUGHPUT_NUM_THREADS": str(self.max_threads),
-                    "CPU_BIND_THREAD": "YES",
+                    # "CPU_BIND_THREAD": "YES",  # Removed: not supported by CPU plugin
                 }
             )
@@ -245,7 +208,7 @@ CHUNK_DURATION_MS = 100  # Reduced latency - 100ms chunks
 VAD_THRESHOLD = 0.01  # Initial voice activity detection threshold
 MAX_SILENCE_FRAMES = 30  # 3 seconds of silence before stopping (for overall silence)
 MAX_TRAILING_SILENCE_FRAMES = 5  # 0.5 seconds of trailing silence
-VAD_CONFIG = {
+VAD_CONFIG: Dict[str, Any] = {
     "energy_threshold": 0.01,
     "zcr_threshold": 0.1,
     "adapt_thresholds": True,
@@ -301,7 +264,7 @@ def setup_intel_arc_environment() -> None:
 class AdvancedVAD:
     """Advanced Voice Activity Detection with noise rejection."""

-    def __init__(self, sample_rate: int = SAMPLE_RATE):
+    def __init__(self, sample_rate: int = 16000):
         self.sample_rate = sample_rate
         # More permissive thresholds based on research
         self.energy_threshold = 0.005  # Reduced from 0.02
@@ -315,7 +278,7 @@ class AdvancedVAD:
         # Relaxed temporal consistency
         self.minimum_duration = 0.2  # Reduced from 0.3s
-        self.speech_history = []
+        self.speech_history: List[bool] = []
         self.max_history = 8  # Reduced from 10

         # Adaptive noise floor
@@ -327,7 +290,7 @@ class AdvancedVAD:
         self.prev_magnitude = None
         self.harmonic_threshold = 0.15  # Reduced from 0.3

-    def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, dict]:
+    def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, Dict[str, Any]]:
         """Analyze audio frame for speech vs noise."""

         # Basic energy features
@@ -403,7 +366,7 @@ class AdvancedVAD:
                 (1 - self.adaptation_rate) * self.noise_floor_energy
             )

-        metrics = {
+        metrics: Dict[str, Any] = {
             'energy': energy,
             'zcr': zcr,
             'centroid': spectral_features['centroid'],
@@ -419,9 +382,9 @@ class AdvancedVAD:
             'temporal_consistency': recent_speech
         }

-        return recent_speech, metrics
+        return recent_speech, metrics  # type: ignore

-    def _compute_spectral_features(self, audio_data: AudioArray) -> dict:
+    def _compute_spectral_features(self, audio_data: AudioArray) -> Dict[str, Any]:
         """Compute spectral features for speech detection."""

         # Apply window to reduce spectral leakage
@@ -464,7 +427,7 @@ class AdvancedVAD:
             'harmonicity': harmonicity
         }

-    def _compute_harmonicity(self, magnitude: np.ndarray, freqs: np.ndarray) -> float:
+    def _compute_harmonicity(self, magnitude: npt.NDArray[np.float32], freqs: npt.NDArray[np.float32]) -> float:
         """Compute harmonicity score (0-1, higher = more harmonic/speech-like)."""

         # Find fundamental frequency candidate (peak in 80-400Hz range for speech)
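
For intuition, the two cheapest features this VAD relies on (frame energy and zero-crossing rate) can be computed as below; the thresholds mirror the defaults in this diff, but the snippet is illustrative only:

# --- Energy/ZCR computation sketch (illustrative, not the module's code) ---
import numpy as np

def frame_energy_zcr(frame: np.ndarray) -> tuple[float, float]:
    energy = float(np.mean(frame ** 2))            # mean-square energy
    signs = np.sign(frame)
    signs[signs == 0] = 1.0
    zcr = float(np.mean(signs[1:] != signs[:-1]))  # fraction of sign flips
    return energy, zcr

frame = (np.random.randn(1600) * 0.05).astype(np.float32)  # 100 ms at 16 kHz
energy, zcr = frame_energy_zcr(frame)
speech_like = energy > 0.005 and 0.02 < zcr < 0.8  # thresholds from this diff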
@@ -483,24 +446,24 @@ class AdvancedVAD:
         # More robust F0 detection - find peaks instead of just max
         try:
             # Import scipy here to handle missing dependency gracefully
-            from scipy.signal import find_peaks
+            from scipy.signal import find_peaks  # type: ignore

             # Ensure distance is at least 1
             min_distance = max(1, int(len(speech_magnitude) * 0.05))

-            peaks, properties = find_peaks(
+            peaks, properties = find_peaks(  # type: ignore
                 speech_magnitude,
                 height=np.max(speech_magnitude) * 0.05,  # Lowered from 0.1
                 distance=min_distance,  # Minimum peak separation
             )

-            if len(peaks) == 0:
+            if len(peaks) == 0:  # type: ignore
                 # Fallback to simple max if no peaks found
                 f0_idx = np.argmax(speech_magnitude)
             else:
                 # Use the strongest peak
                 strongest_peak_idx = np.argmax(speech_magnitude[peaks])
-                f0_idx = peaks[strongest_peak_idx]
+                f0_idx = int(peaks[strongest_peak_idx])  # type: ignore

         except ImportError:
             # scipy not available, use simple max
@@ -526,8 +489,8 @@ class AdvancedVAD:
             harmonic_idx = np.argmin(np.abs(freqs - harmonic_freq))

             # Check a small neighborhood around the harmonic frequency
-            start_idx = max(0, harmonic_idx - 2)
-            end_idx = min(len(magnitude), harmonic_idx + 3)
+            start_idx = max(0, int(harmonic_idx) - 2)
+            end_idx = min(len(magnitude), int(harmonic_idx) + 3)
             local_max = np.max(magnitude[start_idx:end_idx])
             harmonic_strength += local_max
@@ -565,13 +528,13 @@ class OpenVINOWhisperModel:
             logger.info(
                 f"Loading Whisper model '{self.model_id}' on device: {self.device}"
             )
-            self.processor = WhisperProcessor.from_pretrained(
+            self.processor = WhisperProcessor.from_pretrained(  # type: ignore
                 self.model_id, use_fast=True
             )  # type: ignore
             logger.info("Whisper processor loaded successfully")

             # Export the model to OpenVINO IR if not already converted
-            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.model_id, export=True, device=self.device
             )  # type: ignore
@@ -614,7 +577,7 @@ class OpenVINOWhisperModel:
         try:
             # Convert to OpenVINO with FP16 for Arc GPU
-            ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.model_id,
                 ov_config=self.config.to_ov_config(),
                 export=True,
@@ -623,12 +586,13 @@ class OpenVINOWhisperModel:
             )

             # Enable FP16 for Intel Arc performance
-            ov_model.half()
-            ov_model.save_pretrained(self.model_path)
+            if hasattr(ov_model, 'half'):
+                ov_model.half()  # type: ignore
+            ov_model.save_pretrained(self.model_path)  # type: ignore
             logger.info("Model converted and saved in FP16 format")

             # Load the converted model
-            self.ov_model = ov_model
+            self.ov_model = ov_model  # type: ignore
             self._compile_model()

         except Exception as e:
@@ -639,38 +603,38 @@ class OpenVINOWhisperModel:
         """Basic model conversion without advanced features."""
         logger.info(f"Basic conversion of {self.model_id} to OpenVINO format...")
-        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
             self.model_id, export=True, compile=False
         )
-        ov_model.save_pretrained(self.model_path)
+        ov_model.save_pretrained(self.model_path)  # type: ignore
         logger.info("Basic model conversion completed")

     def _load_fp16_model(self) -> None:
         """Load existing FP16 OpenVINO model."""
         logger.info("Loading existing FP16 OpenVINO model...")
         try:
-            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.model_path, ov_config=self.config.to_ov_config(), compile=False
-            )
+            )  # type: ignore
             self._compile_model()
         except Exception as e:
             logger.error(f"Failed to load FP16 model: {e}")
             # Try basic loading
-            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.model_path, compile=False
-            )
+            )  # type: ignore
             self._compile_model()

     def _try_load_quantized_model(self) -> bool:
         """Try to load existing quantized model."""
         try:
             logger.info("Loading existing INT8 quantized model...")
-            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.quantized_model_path,
                 ov_config=self.config.to_ov_config(),
                 compile=False,
-            )
+            )  # type: ignore
             self._compile_model()
             self.is_quantized = True
             logger.info("Quantized model loaded successfully")
@@ -690,13 +654,12 @@ class OpenVINOWhisperModel:
             return

         # Check if model components are available
-        if not hasattr(self.ov_model, "encoder") or self.ov_model.encoder is None:
+        if not hasattr(self.ov_model, "encoder"):
             logger.warning("Model encoder not available, skipping quantization")
             return

         if (
             not hasattr(self.ov_model, "decoder_with_past")
-            or self.ov_model.decoder_with_past is None
         ):
             logger.warning(
                 "Model decoder_with_past not available, skipping quantization"
@@ -761,14 +724,14 @@ class OpenVINOWhisperModel:
             # Save quantized models
             self.quantized_model_path.mkdir(parents=True, exist_ok=True)
-            ov.save_model(
+            ov.save_model(  # type: ignore
                 quantized_encoder,
                 self.quantized_model_path / "openvino_encoder_model.xml",
             )  # type: ignore
-            ov.save_model(
+            ov.save_model(  # type: ignore
                 quantized_decoder,
                 self.quantized_model_path / "openvino_decoder_with_past_model.xml",
             )  # type: ignore

             # Copy remaining files
             self._copy_model_files()
@@ -828,12 +791,12 @@ class OpenVINOWhisperModel:
         decoder_data: CalibrationData = []

         try:
-            self.ov_model.encoder.request = InferRequestWrapper(
-                original_encoder_request, encoder_data
-            )
+            self.ov_model.encoder.request = InferRequestWrapper(  # type: ignore
+                original_encoder_request, encoder_data  # type: ignore
+            )  # type: ignore
             self.ov_model.decoder_with_past.request = InferRequestWrapper(
                 original_decoder_request, decoder_data
-            )
+            )  # type: ignore

             # Generate synthetic calibration data instead of loading dataset
             logger.info("Generating synthetic calibration data...")
@@ -842,17 +805,17 @@ class OpenVINOWhisperModel:
                 # Generate random audio similar to speech
                 duration = 2.0 + np.random.random() * 3.0  # 2-5 seconds
                 synthetic_audio = (
-                    np.random.randn(int(SAMPLE_RATE * duration)).astype(np.float32)
+                    np.random.randn(int(16000 * duration)).astype(np.float32)
                     * 0.1
                 )

-                inputs: Any = self.processor(
-                    synthetic_audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
-                )
+                inputs: Any = self.processor(  # type: ignore
+                    synthetic_audio, sampling_rate=16000, return_tensors="pt"
+                )  # type: ignore

                 # Run inference to collect calibration data
-                generated_ids = self.ov_model.generate(
-                    inputs.input_features, max_new_tokens=10
-                )
+                _ = self.ov_model.generate(  # type: ignore
+                    inputs.input_features, max_new_tokens=10  # type: ignore
+                )

                 if i % 5 == 0:
@@ -882,7 +845,7 @@ class OpenVINOWhisperModel:
             result["decoder"] = decoder_data
             logger.info(f"Collected {len(decoder_data)} decoder calibration samples")

-        return result
+        return result  # type: ignore

     def _copy_model_files(self) -> None:
         """Copy necessary model files for quantized model."""
@@ -951,28 +914,29 @@ class OpenVINOWhisperModel:
             )
             # Try to reload using the existing saved model path if possible
             try:
-                self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+                self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                     self.model_path, ov_config=cpu_cfg.to_ov_config(), compile=False
-                )
+                )  # type: ignore
             except Exception:
                 # If loading the saved model failed, try loading without ov_config
-                self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+                self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                     self.model_path, compile=False
-                )
+                )  # type: ignore

             # Compile on CPU
-            self.ov_model.to("CPU")
-            # Provide CPU-only ov_config if supported
-            try:
-                self.ov_model.compile()
-            except Exception as compile_cpu_e:
-                logger.warning(
-                    f"CPU compile with CPU ov_config failed, retrying default compile: {compile_cpu_e}"
-                )
-                self.ov_model.compile()
+            if self.ov_model is not None:
+                self.ov_model.to("CPU")  # type: ignore
+                # Provide CPU-only ov_config if supported
+                try:
+                    self.ov_model.compile()  # type: ignore
+                except Exception as compile_cpu_e:
+                    logger.warning(
+                        f"CPU compile with CPU ov_config failed, retrying default compile: {compile_cpu_e}"
+                    )
+                    self.ov_model.compile()  # type: ignore

             self._warmup_model()
             logger.info("Model compiled for CPU successfully")
         except Exception as cpu_e:
             logger.error(f"Failed to compile for CPU as well: {cpu_e}")
             raise
@@ -984,14 +948,14 @@ class OpenVINOWhisperModel:
         try:
             logger.info("Warming up model...")
-            dummy_audio = np.random.randn(SAMPLE_RATE).astype(np.float32)  # 1 second
-            dummy_features = self.processor(
-                dummy_audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
+            dummy_audio = np.random.randn(16000).astype(np.float32)  # 1 second
+            dummy_features = self.processor(  # type: ignore
+                dummy_audio, sampling_rate=16000, return_tensors="pt"
             ).input_features

             # Run warmup iterations
             for i in range(3):
-                _ = self.ov_model.generate(dummy_features, max_new_tokens=10)
+                _ = self.ov_model.generate(dummy_features, max_new_tokens=10)  # type: ignore
                 if i == 0:
                     logger.debug("First warmup iteration completed")
         except Exception as e:
@@ -1004,9 +968,9 @@ class OpenVINOWhisperModel:
         if self.processor is None:
             raise RuntimeError("Processor not initialized")

-        return self.processor.batch_decode(
+        return self.processor.batch_decode(  # type: ignore
             token_ids, skip_special_tokens=skip_special_tokens
-        )
+        )  # type: ignore

 # Global model instance with deferred loading
@@ -1046,12 +1010,12 @@ def extract_input_features(audio_array: AudioArray, sampling_rate: int) -> torch
     if ov_model.processor is None:
         raise RuntimeError("Processor not initialized")

-    inputs = ov_model.processor(
+    inputs = ov_model.processor(  # type: ignore
         audio_array,
         sampling_rate=sampling_rate,
         return_tensors="pt",
-    )
-    return inputs.input_features
+    )  # type: ignore
+    return inputs.input_features  # type: ignore

 class VoiceActivityDetector(BaseModel):
@@ -1064,7 +1028,7 @@ class VoiceActivityDetector(BaseModel):
 def simple_robust_vad(
     audio_data: AudioArray,
     energy_threshold: float = 0.01,
-    sample_rate: int = SAMPLE_RATE,
+    sample_rate: int = 16000,
 ) -> VoiceActivityDetector:
     """Simplified robust VAD."""
@@ -1091,7 +1055,7 @@ def enhanced_vad(
     audio_data: AudioArray,
     energy_threshold: float = 0.01,
     zcr_threshold: float = 0.1,
-    sample_rate: int = SAMPLE_RATE,
+    sample_rate: int = 16000,
 ) -> VoiceActivityDetector:
     """Enhanced VAD using multiple features.
@@ -1137,14 +1101,39 @@ class OptimizedAudioProcessor:
         self.peer_name = peer_name
         self.send_chat_func = send_chat_func
         self.create_chat_message_func = create_chat_message_func
-        self.sample_rate = SAMPLE_RATE
+
+        # Audio processing settings (use defaults, can be overridden per instance)
+        self.sample_rate = 16000  # Default Whisper sample rate
+        self.chunk_duration_ms = 100  # Default chunk duration
+        self.chunk_size = int(self.sample_rate * self.chunk_duration_ms / 1000)
+
+        # Silence handling parameters
+        self.max_silence_frames = 30  # Default max silence frames
+        self.max_trailing_silence_frames = 5  # Default trailing silence frames
+
+        # VAD settings (use defaults, can be overridden per instance)
+        self.vad_energy_threshold = 0.005
+        self.vad_zcr_min = 0.02
+        self.vad_zcr_max = 0.8
+        self.vad_spectral_centroid_min = 200
+        self.vad_spectral_centroid_max = 4000
+        self.vad_spectral_rolloff_threshold = 3000
+        self.vad_minimum_duration = 0.2
+        self.vad_max_history = 8
+        self.vad_noise_floor_energy = 0.001
+        self.vad_adaptation_rate = 0.05
+        self.vad_harmonic_threshold = 0.15
+
+        # Normalization settings
+        self.normalization_enabled = True  # Default normalization enabled
+        self.normalization_target_peak = 0.7  # Default target peak
+        self.max_normalization_gain = 10.0  # Default max gain

         # Initialize visualization buffer if not already done
         if self.peer_name not in WaveformVideoTrack.buffer:
             WaveformVideoTrack.buffer[self.peer_name] = np.array([], dtype=np.float32)

         # Optimized buffering parameters
-        self.chunk_size = int(self.sample_rate * CHUNK_DURATION_MS / 1000)
         self.buffer_size = self.chunk_size * 50

         # Circular buffer for zero-copy operations
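
The buffer sizing above is simple arithmetic: with the 16 kHz / 100 ms defaults each chunk is 1600 samples and the 50-chunk buffer covers 5 seconds.

# --- Worked example of the sizing above, using the defaults from this diff ---
sample_rate = 16000                                       # Hz
chunk_duration_ms = 100                                   # ms
chunk_size = int(sample_rate * chunk_duration_ms / 1000)  # 1600 samples
buffer_size = chunk_size * 50                             # 80000 samples = 5.0 s
assert (chunk_size, buffer_size) == (1600, 80000)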
@@ -1154,8 +1143,6 @@ class OptimizedAudioProcessor:
         # Silence handling parameters
         self.silence_frames: int = 0
-        self.max_silence_frames: int = MAX_SILENCE_FRAMES
-        self.max_trailing_silence_frames: int = MAX_TRAILING_SILENCE_FRAMES

         # Enhanced VAD parameters with EMA for noise adaptation
         self.advanced_vad = AdvancedVAD(sample_rate=self.sample_rate)
@@ -1165,9 +1152,6 @@ class OptimizedAudioProcessor:
         # maximum which helps models expect a consistent level across peers.
         # It's intentionally permissive and capped to avoid amplifying noise.
         self.max_observed_amplitude: float = 1e-6
-        self.normalization_enabled: bool = True
-        self.normalization_target_peak: float = 0.95
-        self.max_normalization_gain: float = 3.0  # avoid amplifying tiny noise too much

         # Processing state
         self.current_phrase_audio = np.array([], dtype=np.float32)
@@ -1476,9 +1460,9 @@ class OptimizedAudioProcessor:
             ov_model = _ensure_model_loaded()

             # Extract features (this is relatively cheap but keep on thread)
-            input_features = ov_model.processor(
+            input_features = ov_model.processor(  # type: ignore
                 audio_in, sampling_rate=self.sample_rate, return_tensors="pt"
-            ).input_features
+            ).input_features  # type: ignore

             # Perform generation (blocking)
             # Use the same generation configuration as the async path
@@ -1496,23 +1480,24 @@ class OptimizedAudioProcessor:
             # Serialize access to the underlying OpenVINO generation call
             # to avoid concurrency problems with the OpenVINO runtime.
             with _generate_global_lock:
-                gen_out = ov_model.ov_model.generate(
-                    input_features, generation_config=gen_cfg
+                gen_out = ov_model.ov_model.generate(  # type: ignore
+                    input_features, generation_config=gen_cfg  # type: ignore
                 )

             # Try to extract sequences if present
-            if hasattr(gen_out, "sequences"):
-                ids = gen_out.sequences
+            if hasattr(gen_out, "sequences"):  # type: ignore
+                ids = gen_out.sequences  # type: ignore
             else:
-                ids = gen_out
+                ids = gen_out  # type: ignore

             # Decode
+            text: str = ""
             try:
-                text = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
+                text = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()  # type: ignore
             except Exception:
                 text = ""

-            return text, 0.0
+            return text, 0.0  # type: ignore
         except Exception as e:
             logger.error(f"Blocking transcription failed for {self.peer_name}: {e}", exc_info=True)
             return "", 0.0
@@ -1933,7 +1918,7 @@ class OptimizedAudioProcessor:
             # Many generate implementations return an object with a
             # `.sequences` attribute, so prefer that when available.
             if hasattr(generation_output, "sequences"):
-                generated_ids = generation_output.sequences
+                generated_ids = generation_output.sequences  # type: ignore
             else:
                 generated_ids = generation_output
@@ -1958,9 +1943,9 @@ class OptimizedAudioProcessor:
             # Primary decode attempt
             transcription: str = ""
             try:
-                transcription = ov_model.processor.batch_decode(
+                transcription = ov_model.processor.batch_decode(  # type: ignore
                     generated_ids, skip_special_tokens=True
-                )[0].strip()
+                )[0].strip()  # type: ignore
             except Exception as decode_e:
                 logger.warning(f"{self.peer_name}: primary decode failed: {decode_e}")
@@ -1969,11 +1954,11 @@ class OptimizedAudioProcessor:
             if not transcription:
                 try:
                     if hasattr(generation_output, "sequences") and (
-                        generated_ids is not generation_output.sequences
+                        generated_ids is not generation_output.sequences  # type: ignore
                     ):
-                        transcription = ov_model.processor.batch_decode(
-                            generation_output.sequences, skip_special_tokens=True
-                        )[0].strip()
+                        transcription = ov_model.processor.batch_decode(  # type: ignore
+                            generation_output.sequences, skip_special_tokens=True  # type: ignore
+                        )[0].strip()  # type: ignore
                 except Exception as fallback_e:
                     logger.warning(f"{self.peer_name}: fallback decode failed: {fallback_e}")
@@ -1982,11 +1967,11 @@ class OptimizedAudioProcessor:
                 try:
                     if is_final:
                         logger.info(
-                            f"{self.peer_name}: final transcription empty after decode; generated_ids repr/shape: {repr(generated_ids)[:200]}"
+                            f"{self.peer_name}: final transcription empty after decode"
                         )
                     else:
                         logger.debug(
-                            f"{self.peer_name}: streaming transcription empty after decode; generated_ids repr/shape: {repr(generated_ids)[:200]}"
+                            f"{self.peer_name}: streaming transcription empty after decode"
                         )
                 except Exception:
                     logger.debug(f"{self.peer_name}: generated_ids unavailable for diagnostics")
@@ -2020,7 +2005,7 @@ class OptimizedAudioProcessor:
             # Avoid duplicates for streaming updates, but always send final
             # transcriptions so the UI/clients receive the final marker even
             # if the text matches a recent interim result.
-            if is_final or not self._is_duplicate(transcription):
+            if is_final or not self._is_duplicate(transcription):  # type: ignore
                 # Reuse the existing message ID when possible so the frontend
                 # updates the streaming message into a final message instead
                 # of creating a new one. If there is no current_message, a
@@ -2163,7 +2148,7 @@ class WaveformVideoTrack(MediaStreamTrack):
     # Shared buffer for audio data
     buffer: Dict[str, npt.NDArray[np.float32]] = {}
-    speech_status: Dict[str, dict] = {}
+    speech_status: Dict[str, Dict[str, Any]] = {}

     def __init__(
         self, session_name: str, width: int = 640, height: int = 480, fps: int = 15
@@ -2182,7 +2167,7 @@ class WaveformVideoTrack(MediaStreamTrack):
         return pts, time_base

     async def recv(self) -> VideoFrame:
-        pts, time_base = await self.next_timestamp()
+        pts, _ = await self.next_timestamp()

         # schedule frame according to clock
         target_t = self._next_frame_index / self.fps
@@ -2224,7 +2209,7 @@ class WaveformVideoTrack(MediaStreamTrack):
         # Draw clock in lower right corner, right justified
         current_time = time.strftime("%H:%M:%S")
-        (text_width, text_height), _ = cv2.getTextSize(
+        (text_width, _), _ = cv2.getTextSize(
             current_time, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2
         )
         clock_x = self.width - text_width - 10  # 10px margin from right edge
@@ -2364,7 +2349,7 @@ class WaveformVideoTrack(MediaStreamTrack):
             # Label the peak with small text near the right edge
             label = f"Peak:{target_peak:.2f}"
-            (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+            (tw, _), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
             lx = max(10, self.width - tw - 12)
             ly = max(12, top_y - 6)
             cv2.putText(frame_array, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 1)
@@ -2391,7 +2376,7 @@ class WaveformVideoTrack(MediaStreamTrack):
         frame.time_base = fractions.Fraction(1 / 90000).limit_denominator(1000000)
         return frame

-    def _draw_speech_status(self, frame_array: np.ndarray, speech_info: dict, pname: str):
+    def _draw_speech_status(self, frame_array: npt.NDArray[np.uint8], speech_info: Dict[str, Any], pname: str):
         """Draw speech detection status information."""
         y_offset = 100
@@ -2414,7 +2399,7 @@ class WaveformVideoTrack(MediaStreamTrack):
             f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})"
         ]

-        for i, metric in enumerate(metrics):
+        for _, metric in enumerate(metrics):
             cv2.putText(frame_array, metric,
                         (320, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
                         (255, 255, 255), 1)
@@ -2646,13 +2631,13 @@ def _resample_audio(
             audio_data = np.mean(audio_data, axis=1)

         # Use high-quality resampling
-        resampled = librosa.resample(
+        resampled = librosa.resample(  # type: ignore
             audio_data.astype(np.float64),
             orig_sr=orig_sr,
             target_sr=target_sr,
             res_type="kaiser_fast",  # Good balance of quality and speed
         )

-        return resampled.astype(np.float32)
+        return resampled.astype(np.float32)  # type: ignore
     except Exception as e:
         logger.error(f"Resampling failed: {e}")
         raise ValueError(
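
As a standalone reference for the call above (librosa takes the sample rates as keyword arguments in current releases):

# --- Standalone resampling sketch, mirroring the call above ---
import numpy as np
import librosa

audio_48k = np.random.randn(48000).astype(np.float64)  # 1 s at 48 kHz
audio_16k = librosa.resample(
    audio_48k, orig_sr=48000, target_sr=16000, res_type="kaiser_fast"
)
assert abs(len(audio_16k) - 16000) <= 1  # ~1 s at 16 kHz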
@@ -2760,7 +2745,7 @@ def get_config_schema() -> Dict[str, Any]:
             "type": "range",
             "label": "VAD Threshold",
             "description": "Voice activity detection threshold",
-            "default_value": VAD_THRESHOLD,
+            "default_value": 0.01,
             "required": False,
             "min_value": 0.001,
             "max_value": 0.1,
@@ -2915,7 +2900,7 @@ def get_config_schema() -> Dict[str, Any]:
             "type": "boolean",
             "label": "Enable Normalization",
             "description": "Normalize incoming audio based on observed peak amplitude before transcription and visualization",
-            "default_value": NORMALIZATION_ENABLED,
+            "default_value": True,
             "required": False
         },
         {
@@ -2923,7 +2908,7 @@ def get_config_schema() -> Dict[str, Any]:
             "type": "number",
             "label": "Normalization Target Peak",
             "description": "Target peak (0-1) used when normalizing audio",
-            "default_value": NORMALIZATION_TARGET_PEAK,
+            "default_value": 0.7,
             "required": False,
             "min_value": 0.5,
             "max_value": 1.0
@@ -2933,7 +2918,7 @@ def get_config_schema() -> Dict[str, Any]:
             "type": "range",
             "label": "Max Normalization Gain",
             "description": "Maximum allowed gain applied during normalization",
-            "default_value": MAX_NORMALIZATION_GAIN,
+            "default_value": 10.0,
             "required": False,
             "min_value": 1.0,
             "max_value": 10.0,
@@ -2951,15 +2936,14 @@ def get_config_schema() -> Dict[str, Any]:

 def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
     """Handle configuration update for a specific lobby"""
-    global _model_id, _device, _ov_config, SAMPLE_RATE, CHUNK_DURATION_MS, VAD_THRESHOLD
-    global MAX_SILENCE_FRAMES, MAX_TRAILING_SILENCE_FRAMES
+    global _model_id, _device, _ov_config

     try:
         logger.info(f"Updating Whisper config for lobby {lobby_id}: {config_values}")
         config_applied = False

-        # Update model configuration
+        # Update model configuration (global - affects all instances)
         if "model_id" in config_values:
             new_model_id = config_values["model_id"]
             if new_model_id in [model for models in model_ids.values() for model in models]:
@@ -2969,9 +2953,9 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
             else:
                 logger.warning(f"Invalid model_id: {new_model_id}")

-        # Update device configuration
+        # Update device configuration (global - affects all instances)
         if "device" in config_values:
-            new_device = config_values["device"]
+            new_device = config_values["device"]  # type: ignore
             available_devices = [d["name"] for d in get_available_devices()]
             if new_device in available_devices or new_device in ["CPU", "GPU", "GPU.1"]:
                 _device = new_device
@@ -2981,7 +2965,7 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
             else:
                 logger.warning(f"Invalid device: {new_device}, available: {available_devices}")

-        # Update OpenVINO configuration
+        # Update OpenVINO configuration (global - affects all instances)
         if "enable_quantization" in config_values:
             _ov_config.enable_quantization = bool(config_values["enable_quantization"])
             config_applied = True
@@ -3001,106 +2985,212 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
         config_applied = True
         logger.info(f"Updated max_threads to: {_ov_config.max_threads}")
-    # Update audio processing parameters
+    # Update audio processing parameters for existing processors
     if "sample_rate" in config_values:
         rate = int(config_values["sample_rate"])
         if 8000 <= rate <= 48000:
-            SAMPLE_RATE = rate
+            # Update existing processors
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.sample_rate = rate
+                    proc.chunk_size = int(proc.sample_rate * proc.chunk_duration_ms / 1000)
+                    logger.info(f"Updated sample_rate to {rate} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update sample_rate for processor: {pname}")
             config_applied = True
-            logger.info(f"Updated sample_rate to: {SAMPLE_RATE}")
+            logger.info(f"Updated sample_rate to: {rate}")
     if "chunk_duration_ms" in config_values:
         duration = int(config_values["chunk_duration_ms"])
         if 50 <= duration <= 500:
-            CHUNK_DURATION_MS = duration
+            # Update existing processors
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.chunk_duration_ms = duration
+                    proc.chunk_size = int(proc.sample_rate * proc.chunk_duration_ms / 1000)
+                    logger.info(f"Updated chunk_duration_ms to {duration} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update chunk_duration_ms for processor: {pname}")
             config_applied = True
-            logger.info(f"Updated chunk_duration_ms to: {CHUNK_DURATION_MS}")
+            logger.info(f"Updated chunk_duration_ms to: {duration}")
-    if "vad_threshold" in config_values:
-        threshold = float(config_values["vad_threshold"])
-        if 0.001 <= threshold <= 0.1:
-            VAD_THRESHOLD = threshold
-            config_applied = True
-            logger.info(f"Updated vad_threshold to: {VAD_THRESHOLD}")
     if "max_silence_frames" in config_values:
         frames = int(config_values["max_silence_frames"])
         if 10 <= frames <= 100:
-            MAX_SILENCE_FRAMES = frames
+            # Update existing processors
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.max_silence_frames = frames
+                    logger.info(f"Updated max_silence_frames to {frames} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update max_silence_frames for processor: {pname}")
             config_applied = True
-            logger.info(f"Updated max_silence_frames to: {MAX_SILENCE_FRAMES}")
+            logger.info(f"Updated max_silence_frames to: {frames}")
     if "max_trailing_silence_frames" in config_values:
         frames = int(config_values["max_trailing_silence_frames"])
         if 1 <= frames <= 20:
-            MAX_TRAILING_SILENCE_FRAMES = frames
-            config_applied = True
-            logger.info(f"Updated max_trailing_silence_frames to: {MAX_TRAILING_SILENCE_FRAMES}")
-    # Update VAD configuration (this would require updating existing processors)
-    vad_updates = {}
-    if "vad_energy_threshold" in config_values:
-        vad_updates["energy_threshold"] = float(config_values["vad_energy_threshold"])
-    if "vad_zcr_min" in config_values:
-        vad_updates["zcr_min"] = float(config_values["vad_zcr_min"])
-    if "vad_zcr_max" in config_values:
-        vad_updates["zcr_max"] = float(config_values["vad_zcr_max"])
-    if "vad_spectral_centroid_min" in config_values:
-        vad_updates["spectral_centroid_min"] = float(config_values["vad_spectral_centroid_min"])
-    if "vad_spectral_centroid_max" in config_values:
-        vad_updates["spectral_centroid_max"] = float(config_values["vad_spectral_centroid_max"])
-    if "vad_spectral_rolloff_threshold" in config_values:
-        vad_updates["spectral_rolloff_threshold"] = float(config_values["vad_spectral_rolloff_threshold"])
-    if "vad_minimum_duration" in config_values:
-        vad_updates["minimum_duration"] = float(config_values["vad_minimum_duration"])
-    if "vad_max_history" in config_values:
-        vad_updates["max_history"] = int(config_values["vad_max_history"])
-    if "vad_noise_floor_energy" in config_values:
-        vad_updates["noise_floor_energy"] = float(config_values["vad_noise_floor_energy"])
-    if "vad_adaptation_rate" in config_values:
-        vad_updates["adaptation_rate"] = float(config_values["vad_adaptation_rate"])
-    if "vad_harmonic_threshold" in config_values:
-        vad_updates["harmonic_threshold"] = float(config_values["vad_harmonic_threshold"])
-    if vad_updates:
-        # Update VAD_CONFIG global
-        VAD_CONFIG.update(vad_updates)
-        config_applied = True
-        logger.info(f"Updated VAD config: {vad_updates}")
-        # Note: Existing processors would need to be recreated to pick up VAD changes
-        # For now, we'll log that a restart may be needed
-        logger.info("VAD configuration updated - existing processors may need restart to take effect")
-    # Normalization updates: apply to global defaults and active processors
-    norm_updates = False
-    if "normalization_enabled" in config_values:
-        NORMALIZATION_ENABLED = bool(config_values["normalization_enabled"])
-        norm_updates = True
-        logger.info(f"Updated NORMALIZATION_ENABLED to: {NORMALIZATION_ENABLED}")
-    if "normalization_target_peak" in config_values:
-        NORMALIZATION_TARGET_PEAK = float(config_values["normalization_target_peak"])
-        norm_updates = True
-        logger.info(f"Updated NORMALIZATION_TARGET_PEAK to: {NORMALIZATION_TARGET_PEAK}")
-    if "max_normalization_gain" in config_values:
-        MAX_NORMALIZATION_GAIN = float(config_values["max_normalization_gain"])
-        norm_updates = True
-        logger.info(f"Updated MAX_NORMALIZATION_GAIN to: {MAX_NORMALIZATION_GAIN}")
-    if norm_updates:
-        # Propagate changes to existing processors
-        try:
-            for pname, proc in list(_audio_processors.items()):
-                try:
-                    proc.normalization_enabled = NORMALIZATION_ENABLED
-                    proc.normalization_target_peak = NORMALIZATION_TARGET_PEAK
-                    proc.max_normalization_gain = MAX_NORMALIZATION_GAIN
-                    logger.info(f"Applied normalization config to processor: {pname}")
-                except Exception:
-                    logger.debug(f"Failed to apply normalization config to processor: {pname}")
-            config_applied = True
-        except Exception:
-            logger.debug("Failed to propagate normalization settings to processors")
+            # Update existing processors
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.max_trailing_silence_frames = frames
+                    logger.info(f"Updated max_trailing_silence_frames to {frames} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update max_trailing_silence_frames for processor: {pname}")
+            config_applied = True
+            logger.info(f"Updated max_trailing_silence_frames to: {frames}")
+    # Update VAD configuration for existing processors
+    vad_updates = False
+    if "vad_energy_threshold" in config_values:
+        threshold = float(config_values["vad_energy_threshold"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_energy_threshold = threshold
+                logger.info(f"Updated vad_energy_threshold to {threshold} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_energy_threshold for processor: {pname}")
+        vad_updates = True
+    if "vad_zcr_min" in config_values:
+        zcr_min = float(config_values["vad_zcr_min"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_zcr_min = zcr_min
+                logger.info(f"Updated vad_zcr_min to {zcr_min} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_zcr_min for processor: {pname}")
+        vad_updates = True
+    if "vad_zcr_max" in config_values:
+        zcr_max = float(config_values["vad_zcr_max"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_zcr_max = zcr_max
+                logger.info(f"Updated vad_zcr_max to {zcr_max} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_zcr_max for processor: {pname}")
+        vad_updates = True
+    if "vad_spectral_centroid_min" in config_values:
+        centroid_min = int(config_values["vad_spectral_centroid_min"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_spectral_centroid_min = centroid_min
+                logger.info(f"Updated vad_spectral_centroid_min to {centroid_min} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_spectral_centroid_min for processor: {pname}")
+        vad_updates = True
+    if "vad_spectral_centroid_max" in config_values:
+        centroid_max = int(config_values["vad_spectral_centroid_max"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_spectral_centroid_max = centroid_max
+                logger.info(f"Updated vad_spectral_centroid_max to {centroid_max} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_spectral_centroid_max for processor: {pname}")
+        vad_updates = True
+    if "vad_spectral_rolloff_threshold" in config_values:
+        rolloff = int(config_values["vad_spectral_rolloff_threshold"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_spectral_rolloff_threshold = rolloff
+                logger.info(f"Updated vad_spectral_rolloff_threshold to {rolloff} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_spectral_rolloff_threshold for processor: {pname}")
+        vad_updates = True
+    if "vad_minimum_duration" in config_values:
+        duration = float(config_values["vad_minimum_duration"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_minimum_duration = duration
+                logger.info(f"Updated vad_minimum_duration to {duration} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_minimum_duration for processor: {pname}")
+        vad_updates = True
+    if "vad_max_history" in config_values:
+        history = int(config_values["vad_max_history"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_max_history = history
+                logger.info(f"Updated vad_max_history to {history} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_max_history for processor: {pname}")
+        vad_updates = True
+    if "vad_noise_floor_energy" in config_values:
+        noise_floor = float(config_values["vad_noise_floor_energy"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_noise_floor_energy = noise_floor
+                logger.info(f"Updated vad_noise_floor_energy to {noise_floor} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_noise_floor_energy for processor: {pname}")
+        vad_updates = True
+    if "vad_adaptation_rate" in config_values:
+        adaptation_rate = float(config_values["vad_adaptation_rate"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_adaptation_rate = adaptation_rate
+                logger.info(f"Updated vad_adaptation_rate to {adaptation_rate} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_adaptation_rate for processor: {pname}")
+        vad_updates = True
+    if "vad_harmonic_threshold" in config_values:
+        harmonic_threshold = float(config_values["vad_harmonic_threshold"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.vad_harmonic_threshold = harmonic_threshold
+                logger.info(f"Updated vad_harmonic_threshold to {harmonic_threshold} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update vad_harmonic_threshold for processor: {pname}")
+        vad_updates = True
+    if vad_updates:
+        config_applied = True
+        logger.info("VAD configuration updated for existing processors")
+    # Normalization updates: apply to existing processors
+    norm_updates = False
+    if "normalization_enabled" in config_values:
+        enabled = bool(config_values["normalization_enabled"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.normalization_enabled = enabled
+                logger.info(f"Updated normalization_enabled to {enabled} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update normalization_enabled for processor: {pname}")
+        norm_updates = True
+    if "normalization_target_peak" in config_values:
+        target_peak = float(config_values["normalization_target_peak"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.normalization_target_peak = target_peak
+                logger.info(f"Updated normalization_target_peak to {target_peak} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update normalization_target_peak for processor: {pname}")
+        norm_updates = True
+    if "max_normalization_gain" in config_values:
+        max_gain = float(config_values["max_normalization_gain"])
+        for pname, proc in list(_audio_processors.items()):
+            try:
+                proc.max_normalization_gain = max_gain
+                logger.info(f"Updated max_normalization_gain to {max_gain} for processor: {pname}")
+            except Exception:
+                logger.debug(f"Failed to update max_normalization_gain for processor: {pname}")
+        norm_updates = True
+    if norm_updates:
+        config_applied = True
+        logger.info("Normalization configuration updated for existing processors")
     if config_applied:
         logger.info(f"Configuration update completed for lobby {lobby_id}")