Working
This commit is contained in:
parent
e96bd887ab
commit
90c3c6e19b
@ -4,7 +4,6 @@
|
||||
!client
|
||||
!shared
|
||||
**/node_modules
|
||||
**/build
|
||||
**/dist
|
||||
**/__pycache__
|
||||
**/.venv
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""
|
||||
Lobby management for the AI Voice Bot server.
|
||||
|
||||
This module handles lobby lifecycle, participants, and chat functionality.
|
||||
Lobby management for the AI Voice Bot server.handles lobby lifecycle, participants, and chat functionality.
|
||||
Extracted from main.py to improve maintainability and separation of concerns.
|
||||
"""
|
||||
|
||||
@ -9,21 +7,30 @@ from __future__ import annotations
|
||||
import secrets
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Callable
|
||||
import os
|
||||
|
||||
# Import shared models
|
||||
# Import shared models
|
||||
try:
|
||||
# Try relative import first (when running as part of the package)
|
||||
from ...shared.models import ChatMessageModel, ParticipantModel
|
||||
from ...shared.models import (
|
||||
ChatMessageModel,
|
||||
ParticipantModel,
|
||||
WebSocketMessageModel,
|
||||
)
|
||||
except ImportError:
|
||||
try:
|
||||
# Try absolute import (when running directly)
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from shared.models import ChatMessageModel, ParticipantModel
|
||||
except ImportError:
|
||||
from shared.models import (
|
||||
ChatMessageModel,
|
||||
ParticipantModel,
|
||||
WebSocketMessageModel,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
f"Failed to import shared models: {e}. Ensure shared/models.py is accessible and PYTHONPATH is correctly set."
|
||||
)
|
||||
@ -32,25 +39,21 @@ from shared.logger import logger
|
||||
|
||||
# Use try/except for importing events to handle both relative and absolute imports
|
||||
try:
|
||||
from ..models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
|
||||
from ..models.events import (
|
||||
event_bus,
|
||||
ChatMessageSent,
|
||||
SessionDisconnected,
|
||||
SessionLeftLobby,
|
||||
Event,
|
||||
)
|
||||
except ImportError:
|
||||
try:
|
||||
from models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
|
||||
except ImportError:
|
||||
# Create dummy event system for standalone testing
|
||||
class DummyEventBus:
|
||||
async def publish(self, event):
|
||||
pass
|
||||
event_bus = DummyEventBus()
|
||||
|
||||
class ChatMessageSent:
|
||||
pass
|
||||
|
||||
class SessionDisconnected:
|
||||
pass
|
||||
|
||||
class SessionLeftLobby:
|
||||
pass
|
||||
from models.events import (
|
||||
event_bus,
|
||||
ChatMessageSent,
|
||||
SessionDisconnected,
|
||||
SessionLeftLobby,
|
||||
Event,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .session_manager import Session
|
||||
@ -62,34 +65,6 @@ class LobbyConfig:
|
||||
|
||||
|
||||
class Lobby:
|
||||
async def broadcast_json(self, message: dict) -> None:
|
||||
"""Broadcast an arbitrary JSON message to all connected sessions in the lobby"""
|
||||
failed_sessions: List[Session] = []
|
||||
for peer in self.sessions.values():
|
||||
if peer.ws:
|
||||
try:
|
||||
await peer.ws.send_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to send broadcast_json message to {peer.getName()}: {e}"
|
||||
)
|
||||
failed_sessions.append(peer)
|
||||
for failed_session in failed_sessions:
|
||||
failed_session.ws = None
|
||||
async def broadcast_peer_state_update(self, update: dict) -> None:
|
||||
"""Broadcast a peer state update to all connected sessions in the lobby"""
|
||||
failed_sessions: List[Session] = []
|
||||
for peer in self.sessions.values():
|
||||
if peer.ws:
|
||||
try:
|
||||
await peer.ws.send_json(update)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to send peer state update to {peer.getName()}: {e}"
|
||||
)
|
||||
failed_sessions.append(peer)
|
||||
for failed_session in failed_sessions:
|
||||
failed_session.ws = None
|
||||
"""Individual lobby representing a chat/voice room"""
|
||||
|
||||
def __init__(self, name: str, id: Optional[str] = None, private: bool = False):
|
||||
@ -104,6 +79,36 @@ class Lobby:
|
||||
def getName(self) -> str:
|
||||
return f"{self.short}:{self.name}"
|
||||
|
||||
async def broadcast_json(self, message: WebSocketMessageModel) -> None:
|
||||
"""Broadcast an arbitrary JSON message to all connected sessions in the lobby"""
|
||||
failed_sessions: List[Session] = []
|
||||
for peer in self.sessions.values():
|
||||
if peer.ws:
|
||||
try:
|
||||
await peer.ws.send_json(message.model_dump())
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to send broadcast_json message to {peer.getName()}: {e}"
|
||||
)
|
||||
failed_sessions.append(peer)
|
||||
for failed_session in failed_sessions:
|
||||
failed_session.ws = None
|
||||
|
||||
async def broadcast_peer_state_update(self, update: WebSocketMessageModel) -> None:
|
||||
"""Broadcast a peer state update to all connected sessions in the lobby"""
|
||||
failed_sessions: List[Session] = []
|
||||
for peer in self.sessions.values():
|
||||
if peer.ws:
|
||||
try:
|
||||
await peer.ws.send_json(update.model_dump())
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to send peer state update to {peer.getName()}: {e}"
|
||||
)
|
||||
failed_sessions.append(peer)
|
||||
for failed_session in failed_sessions:
|
||||
failed_session.ws = None
|
||||
|
||||
async def update_state(self, requesting_session: Optional[Session] = None):
|
||||
"""Update lobby state and notify participants"""
|
||||
with self.lock:
|
||||
@ -344,7 +349,7 @@ class LobbyManager:
|
||||
# Event system not available, skip subscriptions
|
||||
pass
|
||||
|
||||
async def handle(self, event):
|
||||
async def handle(self, event: Event) -> None:
|
||||
"""Handle events from the event bus"""
|
||||
|
||||
if isinstance(event, SessionDisconnected):
|
||||
@ -352,7 +357,7 @@ class LobbyManager:
|
||||
elif isinstance(event, SessionLeftLobby):
|
||||
await self._handle_session_left_lobby(event)
|
||||
|
||||
async def _handle_session_disconnected(self, event):
|
||||
async def _handle_session_disconnected(self, event: SessionDisconnected) -> None:
|
||||
"""Handle session disconnection by removing from all lobbies"""
|
||||
session_id = event.session_id
|
||||
|
||||
@ -372,7 +377,7 @@ class LobbyManager:
|
||||
if lobby.is_empty() and not lobby.private:
|
||||
await self._cleanup_empty_lobby(lobby)
|
||||
|
||||
async def _handle_session_left_lobby(self, event):
|
||||
async def _handle_session_left_lobby(self, event: SessionLeftLobby) -> None:
|
||||
"""Handle explicit session leave"""
|
||||
# This is already handled by the session's leave_lobby method
|
||||
# but we could add additional cleanup logic here if needed
|
||||
@ -447,8 +452,8 @@ class LobbyManager:
|
||||
|
||||
return removed_count
|
||||
|
||||
def set_name_protection_checker(self, checker_func):
|
||||
def set_name_protection_checker(self, checker_func: Callable[[str], bool]) -> None:
|
||||
"""Inject name protection checker from AuthManager"""
|
||||
# This allows us to inject the name protection logic without tight coupling
|
||||
for lobby in self.lobbies.values():
|
||||
lobby._is_name_protected = checker_func
|
||||
lobby._is_name_protected = checker_func # type: ignore
|
||||
|
@ -10,7 +10,12 @@ from typing import Dict, Any, TYPE_CHECKING
|
||||
from fastapi import WebSocket
|
||||
|
||||
from shared.logger import logger
|
||||
from shared.models import ChatMessageModel
|
||||
from shared.models import (
|
||||
ChatMessageModel,
|
||||
WebSocketMessageModel,
|
||||
WebSocketErrorModel,
|
||||
UpdateNameModel,
|
||||
)
|
||||
from .webrtc_signaling import WebRTCSignalingHandlers
|
||||
from core.error_handling import (
|
||||
error_handler,
|
||||
@ -45,11 +50,11 @@ class PeerStateUpdateHandler(MessageHandler):
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
session: Any,
|
||||
lobby: Any,
|
||||
data: dict,
|
||||
websocket: Any,
|
||||
managers: dict,
|
||||
session: "Session",
|
||||
lobby: "Lobby",
|
||||
data: Dict[str, Any],
|
||||
websocket: WebSocket,
|
||||
managers: Dict[str, Any],
|
||||
) -> None:
|
||||
# Only allow a user to update their own state
|
||||
if not lobby or not session:
|
||||
@ -59,14 +64,14 @@ class PeerStateUpdateHandler(MessageHandler):
|
||||
# Ignore attempts to update other users' state
|
||||
# Optionally log or send error to client
|
||||
return
|
||||
update = {
|
||||
"type": "peer_state_update",
|
||||
"data": {
|
||||
update = WebSocketMessageModel(
|
||||
type="peer_state_update",
|
||||
data={
|
||||
"peer_id": peer_id,
|
||||
"muted": data.get("muted"),
|
||||
"video_on": data.get("video_on"),
|
||||
},
|
||||
}
|
||||
}, # type: ignore
|
||||
)
|
||||
await lobby.broadcast_peer_state_update(update)
|
||||
|
||||
|
||||
@ -86,10 +91,12 @@ class SetNameHandler(MessageHandler):
|
||||
|
||||
if not data:
|
||||
logger.error(f"{session.getName()} - set_name missing data")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": "set_name missing data"},
|
||||
})
|
||||
await websocket.send_json(
|
||||
WebSocketMessageModel(
|
||||
type="error",
|
||||
data=WebSocketErrorModel(error="set_name missing data"),
|
||||
).model_dump()
|
||||
)
|
||||
return
|
||||
|
||||
name = data.get("name")
|
||||
@ -99,10 +106,11 @@ class SetNameHandler(MessageHandler):
|
||||
|
||||
if not name:
|
||||
logger.error(f"{session.getName()} - Name required")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": "Name required"}
|
||||
})
|
||||
await websocket.send_json(
|
||||
WebSocketMessageModel(
|
||||
type="error", data=WebSocketErrorModel(error="Name required")
|
||||
).model_dump()
|
||||
)
|
||||
return
|
||||
|
||||
# Check if name is unique
|
||||
@ -114,13 +122,14 @@ class SetNameHandler(MessageHandler):
|
||||
session.setName(name)
|
||||
logger.info(f"{session.getName()}: -> update('name', {name})")
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "update_name",
|
||||
"data": {
|
||||
"name": name,
|
||||
"protected": auth_manager.is_name_protected(name),
|
||||
},
|
||||
})
|
||||
await websocket.send_json(
|
||||
WebSocketMessageModel(
|
||||
type="update_name",
|
||||
data=UpdateNameModel(
|
||||
name=name, protected=auth_manager.is_name_protected(name)
|
||||
),
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# Update lobby state
|
||||
await lobby.update_state()
|
||||
@ -131,10 +140,11 @@ class SetNameHandler(MessageHandler):
|
||||
|
||||
if not allowed:
|
||||
logger.warning(f"{session.getName()} - {reason}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": reason}
|
||||
})
|
||||
await websocket.send_json(
|
||||
WebSocketMessageModel(
|
||||
type="error", data=WebSocketErrorModel(error=reason)
|
||||
).model_dump()
|
||||
)
|
||||
return
|
||||
|
||||
# Takeover allowed - handle displacement
|
||||
@ -179,13 +189,14 @@ class SetNameHandler(MessageHandler):
|
||||
session.setName(name)
|
||||
logger.info(f"{session.getName()}: -> update('name', {name}) (takeover)")
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "update_name",
|
||||
"data": {
|
||||
"name": name,
|
||||
"protected": auth_manager.is_name_protected(name),
|
||||
},
|
||||
})
|
||||
await websocket.send_json(
|
||||
WebSocketMessageModel(
|
||||
type="update_name",
|
||||
data=UpdateNameModel(
|
||||
name=name, protected=auth_manager.is_name_protected(name)
|
||||
),
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# Update lobby state
|
||||
await lobby.update_state()
|
||||
@ -460,15 +471,15 @@ class MessageRouter:
|
||||
):
|
||||
"""Route a message to the appropriate handler with enhanced error handling"""
|
||||
if message_type not in self._handlers:
|
||||
await error_handler.handle_error(
|
||||
await error_handler.handle_error( # type: ignore
|
||||
ValidationError(f"Unknown message type: {message_type}"),
|
||||
context={
|
||||
"message_type": message_type,
|
||||
"session_id": session.id if session else "unknown",
|
||||
"data_keys": list(data.keys()) if data else []
|
||||
"data_keys": list(data.keys()) if data else [],
|
||||
},
|
||||
websocket=websocket,
|
||||
session_id=session.id if session else None
|
||||
session_id=session.id if session else None,
|
||||
)
|
||||
return
|
||||
|
||||
@ -480,48 +491,50 @@ class MessageRouter:
|
||||
|
||||
except WebSocketError as e:
|
||||
# WebSocket specific errors - attempt recovery
|
||||
await error_handler.handle_error(
|
||||
await error_handler.handle_error( # type: ignore
|
||||
e,
|
||||
context={
|
||||
"message_type": message_type,
|
||||
"session_id": session.id if session else "unknown",
|
||||
"handler": type(self._handlers[message_type]).__name__
|
||||
"handler": type(self._handlers[message_type]).__name__,
|
||||
},
|
||||
websocket=websocket,
|
||||
session_id=session.id if session else None,
|
||||
recovery_action=lambda: self._websocket_recovery(websocket, session)
|
||||
recovery_action=lambda: self._websocket_recovery(websocket, session),
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
# Validation errors - usually client-side issues
|
||||
await error_handler.handle_error(
|
||||
await error_handler.handle_error( # type: ignore
|
||||
e,
|
||||
context={
|
||||
"message_type": message_type,
|
||||
"session_id": session.id if session else "unknown",
|
||||
"data": str(data)[:500] # Truncate large data
|
||||
"data": str(data)[:500], # Truncate large data
|
||||
},
|
||||
websocket=websocket,
|
||||
session_id=session.id if session else None
|
||||
session_id=session.id if session else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Unexpected errors - enhanced logging and fallback
|
||||
await error_handler.handle_error(
|
||||
await error_handler.handle_error( # type: ignore
|
||||
WebSocketError(
|
||||
f"Unexpected error in {message_type} handler: {e}",
|
||||
severity=ErrorSeverity.HIGH
|
||||
severity=ErrorSeverity.HIGH,
|
||||
),
|
||||
context={
|
||||
"message_type": message_type,
|
||||
"session_id": session.id if session else "unknown",
|
||||
"handler": type(self._handlers[message_type]).__name__,
|
||||
"exception_type": type(e).__name__,
|
||||
"traceback": str(e)
|
||||
"traceback": str(e),
|
||||
},
|
||||
websocket=websocket,
|
||||
session_id=session.id if session else None,
|
||||
recovery_action=lambda: self._generic_recovery(message_type, session, lobby)
|
||||
recovery_action=lambda: self._generic_recovery(
|
||||
message_type, session, lobby
|
||||
),
|
||||
)
|
||||
|
||||
async def _websocket_recovery(self, websocket: WebSocket, session: "Session"):
|
||||
|
@ -65,43 +65,6 @@ _device = "GPU.1" # Default to Intel Arc B580 GPU
|
||||
_generate_global_lock = threading.Lock()
|
||||
|
||||
|
||||
def _blocking_generate_decode(audio_array: AudioArray, sample_rate: int, generation_config: GenerationConfig | None = None) -> str:
|
||||
"""Blocking helper to run processor -> model.generate -> decode while
|
||||
holding a global lock to serialize OpenVINO access.
|
||||
"""
|
||||
try:
|
||||
with _generate_global_lock:
|
||||
ov_model = _ensure_model_loaded()
|
||||
if ov_model.processor is None:
|
||||
raise RuntimeError("Processor not initialized for OpenVINO model")
|
||||
|
||||
# Extract features
|
||||
inputs = ov_model.processor(audio_array, sampling_rate=sample_rate, return_tensors="pt")
|
||||
input_features = inputs.input_features
|
||||
|
||||
# Use a basic generation config if none provided
|
||||
gen_cfg = generation_config or GenerationConfig(max_new_tokens=128)
|
||||
|
||||
gen_out = ov_model.ov_model.generate(input_features, generation_config=gen_cfg) # type: ignore
|
||||
|
||||
# Prefer .sequences if available
|
||||
if hasattr(gen_out, "sequences"):
|
||||
ids = gen_out.sequences
|
||||
else:
|
||||
ids = gen_out
|
||||
|
||||
# Decode
|
||||
try:
|
||||
transcription = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
|
||||
except Exception:
|
||||
transcription = ""
|
||||
|
||||
return transcription
|
||||
except Exception as e:
|
||||
logger.error(f"blocking_generate_decode failed: {e}", exc_info=True)
|
||||
return ""
|
||||
|
||||
|
||||
def get_available_devices() -> list[dict[str, Any]]:
|
||||
"""List available OpenVINO devices with their properties."""
|
||||
try:
|
||||
@ -230,7 +193,7 @@ class OpenVINOConfig(BaseModel):
|
||||
cfg.update(
|
||||
{
|
||||
"CPU_THROUGHPUT_NUM_THREADS": str(self.max_threads),
|
||||
"CPU_BIND_THREAD": "YES",
|
||||
# "CPU_BIND_THREAD": "YES", # Removed: not supported by CPU plugin
|
||||
}
|
||||
)
|
||||
|
||||
@ -245,7 +208,7 @@ CHUNK_DURATION_MS = 100 # Reduced latency - 100ms chunks
|
||||
VAD_THRESHOLD = 0.01 # Initial voice activity detection threshold
|
||||
MAX_SILENCE_FRAMES = 30 # 3 seconds of silence before stopping (for overall silence)
|
||||
MAX_TRAILING_SILENCE_FRAMES = 5 # 0.5 seconds of trailing silence
|
||||
VAD_CONFIG = {
|
||||
VAD_CONFIG: Dict[str, Any] = {
|
||||
"energy_threshold": 0.01,
|
||||
"zcr_threshold": 0.1,
|
||||
"adapt_thresholds": True,
|
||||
@ -301,7 +264,7 @@ def setup_intel_arc_environment() -> None:
|
||||
class AdvancedVAD:
|
||||
"""Advanced Voice Activity Detection with noise rejection."""
|
||||
|
||||
def __init__(self, sample_rate: int = SAMPLE_RATE):
|
||||
def __init__(self, sample_rate: int = 16000):
|
||||
self.sample_rate = sample_rate
|
||||
# More permissive thresholds based on research
|
||||
self.energy_threshold = 0.005 # Reduced from 0.02
|
||||
@ -315,7 +278,7 @@ class AdvancedVAD:
|
||||
|
||||
# Relaxed temporal consistency
|
||||
self.minimum_duration = 0.2 # Reduced from 0.3s
|
||||
self.speech_history = []
|
||||
self.speech_history: List[bool] = []
|
||||
self.max_history = 8 # Reduced from 10
|
||||
|
||||
# Adaptive noise floor
|
||||
@ -327,7 +290,7 @@ class AdvancedVAD:
|
||||
self.prev_magnitude = None
|
||||
self.harmonic_threshold = 0.15 # Reduced from 0.3
|
||||
|
||||
def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, dict]:
|
||||
def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, Dict[str, Any]]:
|
||||
"""Analyze audio frame for speech vs noise."""
|
||||
|
||||
# Basic energy features
|
||||
@ -403,7 +366,7 @@ class AdvancedVAD:
|
||||
(1 - self.adaptation_rate) * self.noise_floor_energy
|
||||
)
|
||||
|
||||
metrics = {
|
||||
metrics: Dict[str, Any] = {
|
||||
'energy': energy,
|
||||
'zcr': zcr,
|
||||
'centroid': spectral_features['centroid'],
|
||||
@ -419,9 +382,9 @@ class AdvancedVAD:
|
||||
'temporal_consistency': recent_speech
|
||||
}
|
||||
|
||||
return recent_speech, metrics
|
||||
return recent_speech, metrics # type: ignore
|
||||
|
||||
def _compute_spectral_features(self, audio_data: AudioArray) -> dict:
|
||||
def _compute_spectral_features(self, audio_data: AudioArray) -> Dict[str, Any]:
|
||||
"""Compute spectral features for speech detection."""
|
||||
|
||||
# Apply window to reduce spectral leakage
|
||||
@ -464,7 +427,7 @@ class AdvancedVAD:
|
||||
'harmonicity': harmonicity
|
||||
}
|
||||
|
||||
def _compute_harmonicity(self, magnitude: np.ndarray, freqs: np.ndarray) -> float:
|
||||
def _compute_harmonicity(self, magnitude: npt.NDArray[np.float32], freqs: npt.NDArray[np.float32]) -> float:
|
||||
"""Compute harmonicity score (0-1, higher = more harmonic/speech-like)."""
|
||||
|
||||
# Find fundamental frequency candidate (peak in 80-400Hz range for speech)
|
||||
@ -483,24 +446,24 @@ class AdvancedVAD:
|
||||
# More robust F0 detection - find peaks instead of just max
|
||||
try:
|
||||
# Import scipy here to handle missing dependency gracefully
|
||||
from scipy.signal import find_peaks
|
||||
from scipy.signal import find_peaks # type: ignore
|
||||
|
||||
# Ensure distance is at least 1
|
||||
min_distance = max(1, int(len(speech_magnitude) * 0.05))
|
||||
|
||||
peaks, properties = find_peaks(
|
||||
peaks, properties = find_peaks( # type: ignore
|
||||
speech_magnitude,
|
||||
height=np.max(speech_magnitude) * 0.05, # Lowered from 0.1
|
||||
distance=min_distance, # Minimum peak separation
|
||||
)
|
||||
|
||||
if len(peaks) == 0:
|
||||
if len(peaks) == 0: # type: ignore
|
||||
# Fallback to simple max if no peaks found
|
||||
f0_idx = np.argmax(speech_magnitude)
|
||||
else:
|
||||
# Use the strongest peak
|
||||
strongest_peak_idx = np.argmax(speech_magnitude[peaks])
|
||||
f0_idx = peaks[strongest_peak_idx]
|
||||
f0_idx = int(peaks[strongest_peak_idx]) # type: ignore
|
||||
|
||||
except ImportError:
|
||||
# scipy not available, use simple max
|
||||
@ -526,8 +489,8 @@ class AdvancedVAD:
|
||||
harmonic_idx = np.argmin(np.abs(freqs - harmonic_freq))
|
||||
|
||||
# Check a small neighborhood around the harmonic frequency
|
||||
start_idx = max(0, harmonic_idx - 2)
|
||||
end_idx = min(len(magnitude), harmonic_idx + 3)
|
||||
start_idx = max(0, int(harmonic_idx) - 2)
|
||||
end_idx = min(len(magnitude), int(harmonic_idx) + 3)
|
||||
local_max = np.max(magnitude[start_idx:end_idx])
|
||||
|
||||
harmonic_strength += local_max
|
||||
@ -565,13 +528,13 @@ class OpenVINOWhisperModel:
|
||||
logger.info(
|
||||
f"Loading Whisper model '{self.model_id}' on device: {self.device}"
|
||||
)
|
||||
self.processor = WhisperProcessor.from_pretrained(
|
||||
self.processor = WhisperProcessor.from_pretrained( # type: ignore
|
||||
self.model_id, use_fast=True
|
||||
) # type: ignore
|
||||
logger.info("Whisper processor loaded successfully")
|
||||
|
||||
# Export the model to OpenVINO IR if not already converted
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained( # type: ignore
|
||||
self.model_id, export=True, device=self.device
|
||||
) # type: ignore
|
||||
|
||||
@ -614,7 +577,7 @@ class OpenVINOWhisperModel:
|
||||
|
||||
try:
|
||||
# Convert to OpenVINO with FP16 for Arc GPU
|
||||
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
|
||||
ov_model = OVModelForSpeechSeq2Seq.from_pretrained( # type: ignore
|
||||
self.model_id,
|
||||
ov_config=self.config.to_ov_config(),
|
||||
export=True,
|
||||
@ -623,12 +586,13 @@ class OpenVINOWhisperModel:
|
||||
)
|
||||
|
||||
# Enable FP16 for Intel Arc performance
|
||||
ov_model.half()
|
||||
ov_model.save_pretrained(self.model_path)
|
||||
if hasattr(ov_model, 'half'):
|
||||
ov_model.half() # type: ignore
|
||||
ov_model.save_pretrained(self.model_path) # type: ignore
|
||||
logger.info("Model converted and saved in FP16 format")
|
||||
|
||||
# Load the converted model
|
||||
self.ov_model = ov_model
|
||||
self.ov_model = ov_model # type: ignore
|
||||
self._compile_model()
|
||||
|
||||
except Exception as e:
|
||||
@ -639,38 +603,38 @@ class OpenVINOWhisperModel:
|
||||
"""Basic model conversion without advanced features."""
|
||||
logger.info(f"Basic conversion of {self.model_id} to OpenVINO format...")
|
||||
|
||||
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
|
||||
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(# type: ignore
|
||||
self.model_id, export=True, compile=False
|
||||
)
|
||||
|
||||
ov_model.save_pretrained(self.model_path)
|
||||
ov_model.save_pretrained(self.model_path)# type: ignore
|
||||
logger.info("Basic model conversion completed")
|
||||
|
||||
def _load_fp16_model(self) -> None:
|
||||
"""Load existing FP16 OpenVINO model."""
|
||||
logger.info("Loading existing FP16 OpenVINO model...")
|
||||
try:
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(# type: ignore
|
||||
self.model_path, ov_config=self.config.to_ov_config(), compile=False
|
||||
)
|
||||
) # type: ignore
|
||||
self._compile_model()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load FP16 model: {e}")
|
||||
# Try basic loading
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(# type: ignore
|
||||
self.model_path, compile=False
|
||||
)
|
||||
) # type: ignore
|
||||
self._compile_model()
|
||||
|
||||
def _try_load_quantized_model(self) -> bool:
|
||||
"""Try to load existing quantized model."""
|
||||
try:
|
||||
logger.info("Loading existing INT8 quantized model...")
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(# type: ignore
|
||||
self.quantized_model_path,
|
||||
ov_config=self.config.to_ov_config(),
|
||||
compile=False,
|
||||
)
|
||||
) # type: ignore
|
||||
self._compile_model()
|
||||
self.is_quantized = True
|
||||
logger.info("Quantized model loaded successfully")
|
||||
@ -690,13 +654,12 @@ class OpenVINOWhisperModel:
|
||||
return
|
||||
|
||||
# Check if model components are available
|
||||
if not hasattr(self.ov_model, "encoder") or self.ov_model.encoder is None:
|
||||
if not hasattr(self.ov_model, "encoder"):
|
||||
logger.warning("Model encoder not available, skipping quantization")
|
||||
return
|
||||
|
||||
if (
|
||||
not hasattr(self.ov_model, "decoder_with_past")
|
||||
or self.ov_model.decoder_with_past is None
|
||||
):
|
||||
logger.warning(
|
||||
"Model decoder_with_past not available, skipping quantization"
|
||||
@ -761,14 +724,14 @@ class OpenVINOWhisperModel:
|
||||
|
||||
# Save quantized models
|
||||
self.quantized_model_path.mkdir(parents=True, exist_ok=True)
|
||||
ov.save_model(
|
||||
ov.save_model(# type: ignore
|
||||
quantized_encoder,
|
||||
self.quantized_model_path / "openvino_encoder_model.xml",
|
||||
) # type: ignore
|
||||
ov.save_model(
|
||||
ov.save_model(# type: ignore
|
||||
quantized_decoder,
|
||||
self.quantized_model_path / "openvino_decoder_with_past_model.xml",
|
||||
) # type: ignore
|
||||
) # type: ignore # type: ignore
|
||||
|
||||
# Copy remaining files
|
||||
self._copy_model_files()
|
||||
@ -828,12 +791,12 @@ class OpenVINOWhisperModel:
|
||||
decoder_data: CalibrationData = []
|
||||
|
||||
try:
|
||||
self.ov_model.encoder.request = InferRequestWrapper(
|
||||
original_encoder_request, encoder_data
|
||||
)
|
||||
self.ov_model.encoder.request = InferRequestWrapper(# type: ignore
|
||||
original_encoder_request, encoder_data# type: ignore
|
||||
) # type: ignore
|
||||
self.ov_model.decoder_with_past.request = InferRequestWrapper(
|
||||
original_decoder_request, decoder_data
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
# Generate synthetic calibration data instead of loading dataset
|
||||
logger.info("Generating synthetic calibration data...")
|
||||
@ -842,17 +805,17 @@ class OpenVINOWhisperModel:
|
||||
# Generate random audio similar to speech
|
||||
duration = 2.0 + np.random.random() * 3.0 # 2-5 seconds
|
||||
synthetic_audio = (
|
||||
np.random.randn(int(SAMPLE_RATE * duration)).astype(np.float32)
|
||||
np.random.randn(int(16000 * duration)).astype(np.float32)
|
||||
* 0.1
|
||||
)
|
||||
|
||||
inputs: Any = self.processor(
|
||||
synthetic_audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
|
||||
)
|
||||
inputs: Any = self.processor(# type: ignore
|
||||
synthetic_audio, sampling_rate=16000, return_tensors="pt"
|
||||
) # type: ignore
|
||||
|
||||
# Run inference to collect calibration data
|
||||
generated_ids = self.ov_model.generate(
|
||||
inputs.input_features, max_new_tokens=10
|
||||
_ = self.ov_model.generate( # type: ignore
|
||||
inputs.input_features, max_new_tokens=10 # type: ignore
|
||||
)
|
||||
|
||||
if i % 5 == 0:
|
||||
@ -882,7 +845,7 @@ class OpenVINOWhisperModel:
|
||||
result["decoder"] = decoder_data
|
||||
logger.info(f"Collected {len(decoder_data)} decoder calibration samples")
|
||||
|
||||
return result
|
||||
return result # type: ignore
|
||||
|
||||
def _copy_model_files(self) -> None:
|
||||
"""Copy necessary model files for quantized model."""
|
||||
@ -951,28 +914,29 @@ class OpenVINOWhisperModel:
|
||||
)
|
||||
# Try to reload using the existing saved model path if possible
|
||||
try:
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(# type: ignore
|
||||
self.model_path, ov_config=cpu_cfg.to_ov_config(), compile=False
|
||||
)
|
||||
) # type: ignore
|
||||
except Exception:
|
||||
# If loading the saved model failed, try loading without ov_config
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
|
||||
self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(# type: ignore
|
||||
self.model_path, compile=False
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
# Compile on CPU
|
||||
self.ov_model.to("CPU")
|
||||
# Provide CPU-only ov_config if supported
|
||||
try:
|
||||
self.ov_model.compile()
|
||||
except Exception as compile_cpu_e:
|
||||
logger.warning(
|
||||
f"CPU compile with CPU ov_config failed, retrying default compile: {compile_cpu_e}"
|
||||
)
|
||||
self.ov_model.compile()
|
||||
if self.ov_model is not None:
|
||||
self.ov_model.to("CPU") # type: ignore
|
||||
# Provide CPU-only ov_config if supported
|
||||
try:
|
||||
self.ov_model.compile() # type: ignore
|
||||
except Exception as compile_cpu_e:
|
||||
logger.warning(
|
||||
f"CPU compile with CPU ov_config failed, retrying default compile: {compile_cpu_e}"
|
||||
)
|
||||
self.ov_model.compile() # type: ignore
|
||||
|
||||
self._warmup_model()
|
||||
logger.info("Model compiled for CPU successfully")
|
||||
self._warmup_model()
|
||||
logger.info("Model compiled for CPU successfully")
|
||||
except Exception as cpu_e:
|
||||
logger.error(f"Failed to compile for CPU as well: {cpu_e}")
|
||||
raise
|
||||
@ -984,14 +948,14 @@ class OpenVINOWhisperModel:
|
||||
|
||||
try:
|
||||
logger.info("Warming up model...")
|
||||
dummy_audio = np.random.randn(SAMPLE_RATE).astype(np.float32) # 1 second
|
||||
dummy_features = self.processor(
|
||||
dummy_audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
|
||||
dummy_audio = np.random.randn(16000).astype(np.float32) # 1 second
|
||||
dummy_features = self.processor(# type: ignore
|
||||
dummy_audio, sampling_rate=16000, return_tensors="pt"
|
||||
).input_features
|
||||
|
||||
# Run warmup iterations
|
||||
for i in range(3):
|
||||
_ = self.ov_model.generate(dummy_features, max_new_tokens=10)
|
||||
_ = self.ov_model.generate(dummy_features, max_new_tokens=10)# type: ignore
|
||||
if i == 0:
|
||||
logger.debug("First warmup iteration completed")
|
||||
except Exception as e:
|
||||
@ -1004,9 +968,9 @@ class OpenVINOWhisperModel:
|
||||
if self.processor is None:
|
||||
raise RuntimeError("Processor not initialized")
|
||||
|
||||
return self.processor.batch_decode(
|
||||
return self.processor.batch_decode(# type: ignore
|
||||
token_ids, skip_special_tokens=skip_special_tokens
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
# Global model instance with deferred loading
|
||||
@ -1046,12 +1010,12 @@ def extract_input_features(audio_array: AudioArray, sampling_rate: int) -> torch
|
||||
if ov_model.processor is None:
|
||||
raise RuntimeError("Processor not initialized")
|
||||
|
||||
inputs = ov_model.processor(
|
||||
inputs = ov_model.processor(# type: ignore
|
||||
audio_array,
|
||||
sampling_rate=sampling_rate,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return inputs.input_features
|
||||
) # type: ignore
|
||||
return inputs.input_features # type: ignore
|
||||
|
||||
|
||||
class VoiceActivityDetector(BaseModel):
|
||||
@ -1064,7 +1028,7 @@ class VoiceActivityDetector(BaseModel):
|
||||
def simple_robust_vad(
|
||||
audio_data: AudioArray,
|
||||
energy_threshold: float = 0.01,
|
||||
sample_rate: int = SAMPLE_RATE,
|
||||
sample_rate: int = 16000,
|
||||
) -> VoiceActivityDetector:
|
||||
"""Simplified robust VAD."""
|
||||
|
||||
@ -1091,7 +1055,7 @@ def enhanced_vad(
|
||||
audio_data: AudioArray,
|
||||
energy_threshold: float = 0.01,
|
||||
zcr_threshold: float = 0.1,
|
||||
sample_rate: int = SAMPLE_RATE,
|
||||
sample_rate: int = 16000,
|
||||
) -> VoiceActivityDetector:
|
||||
"""Enhanced VAD using multiple features.
|
||||
|
||||
@ -1137,14 +1101,39 @@ class OptimizedAudioProcessor:
|
||||
self.peer_name = peer_name
|
||||
self.send_chat_func = send_chat_func
|
||||
self.create_chat_message_func = create_chat_message_func
|
||||
self.sample_rate = SAMPLE_RATE
|
||||
|
||||
# Audio processing settings (use defaults, can be overridden per instance)
|
||||
self.sample_rate = 16000 # Default Whisper sample rate
|
||||
self.chunk_duration_ms = 100 # Default chunk duration
|
||||
self.chunk_size = int(self.sample_rate * self.chunk_duration_ms / 1000)
|
||||
|
||||
# Silence handling parameters
|
||||
self.max_silence_frames = 30 # Default max silence frames
|
||||
self.max_trailing_silence_frames = 5 # Default trailing silence frames
|
||||
|
||||
# VAD settings (use defaults, can be overridden per instance)
|
||||
self.vad_energy_threshold = 0.005
|
||||
self.vad_zcr_min = 0.02
|
||||
self.vad_zcr_max = 0.8
|
||||
self.vad_spectral_centroid_min = 200
|
||||
self.vad_spectral_centroid_max = 4000
|
||||
self.vad_spectral_rolloff_threshold = 3000
|
||||
self.vad_minimum_duration = 0.2
|
||||
self.vad_max_history = 8
|
||||
self.vad_noise_floor_energy = 0.001
|
||||
self.vad_adaptation_rate = 0.05
|
||||
self.vad_harmonic_threshold = 0.15
|
||||
|
||||
# Normalization settings
|
||||
self.normalization_enabled = True # Default normalization enabled
|
||||
self.normalization_target_peak = 0.7 # Default target peak
|
||||
self.max_normalization_gain = 10.0 # Default max gain
|
||||
|
||||
# Initialize visualization buffer if not already done
|
||||
if self.peer_name not in WaveformVideoTrack.buffer:
|
||||
WaveformVideoTrack.buffer[self.peer_name] = np.array([], dtype=np.float32)
|
||||
|
||||
# Optimized buffering parameters
|
||||
self.chunk_size = int(self.sample_rate * CHUNK_DURATION_MS / 1000)
|
||||
self.buffer_size = self.chunk_size * 50
|
||||
|
||||
# Circular buffer for zero-copy operations
|
||||
@ -1154,8 +1143,6 @@ class OptimizedAudioProcessor:
|
||||
|
||||
# Silence handling parameters
|
||||
self.silence_frames: int = 0
|
||||
self.max_silence_frames: int = MAX_SILENCE_FRAMES
|
||||
self.max_trailing_silence_frames: int = MAX_TRAILING_SILENCE_FRAMES
|
||||
|
||||
# Enhanced VAD parameters with EMA for noise adaptation
|
||||
self.advanced_vad = AdvancedVAD(sample_rate=self.sample_rate)
|
||||
@ -1165,9 +1152,6 @@ class OptimizedAudioProcessor:
|
||||
# maximum which helps models expect a consistent level across peers.
|
||||
# It's intentionally permissive and capped to avoid amplifying noise.
|
||||
self.max_observed_amplitude: float = 1e-6
|
||||
self.normalization_enabled: bool = True
|
||||
self.normalization_target_peak: float = 0.95
|
||||
self.max_normalization_gain: float = 3.0 # avoid amplifying tiny noise too much
|
||||
|
||||
# Processing state
|
||||
self.current_phrase_audio = np.array([], dtype=np.float32)
|
||||
@ -1476,9 +1460,9 @@ class OptimizedAudioProcessor:
|
||||
ov_model = _ensure_model_loaded()
|
||||
|
||||
# Extract features (this is relatively cheap but keep on thread)
|
||||
input_features = ov_model.processor(
|
||||
input_features = ov_model.processor(# type: ignore
|
||||
audio_in, sampling_rate=self.sample_rate, return_tensors="pt"
|
||||
).input_features
|
||||
).input_features # type: ignore
|
||||
|
||||
# Perform generation (blocking)
|
||||
# Use the same generation configuration as the async path
|
||||
@ -1496,23 +1480,24 @@ class OptimizedAudioProcessor:
|
||||
# Serialize access to the underlying OpenVINO generation call
|
||||
# to avoid concurrency problems with the OpenVINO runtime.
|
||||
with _generate_global_lock:
|
||||
gen_out = ov_model.ov_model.generate(
|
||||
input_features, generation_config=gen_cfg
|
||||
gen_out = ov_model.ov_model.generate(# type: ignore
|
||||
input_features, generation_config=gen_cfg# type: ignore
|
||||
)
|
||||
|
||||
# Try to extract sequences if present
|
||||
if hasattr(gen_out, "sequences"):
|
||||
ids = gen_out.sequences
|
||||
if hasattr(gen_out, "sequences"): # type: ignore
|
||||
ids = gen_out.sequences # type: ignore
|
||||
else:
|
||||
ids = gen_out
|
||||
ids = gen_out # type: ignore
|
||||
|
||||
# Decode
|
||||
text: str = ""
|
||||
try:
|
||||
text = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
|
||||
text = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip() # type: ignore
|
||||
except Exception:
|
||||
text = ""
|
||||
|
||||
return text, 0.0
|
||||
return text, 0.0 # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"Blocking transcription failed for {self.peer_name}: {e}", exc_info=True)
|
||||
return "", 0.0
|
||||
@ -1933,7 +1918,7 @@ class OptimizedAudioProcessor:
|
||||
# Many generate implementations return an object with a
|
||||
# `.sequences` attribute, so prefer that when available.
|
||||
if hasattr(generation_output, "sequences"):
|
||||
generated_ids = generation_output.sequences
|
||||
generated_ids = generation_output.sequences # type: ignore
|
||||
else:
|
||||
generated_ids = generation_output
|
||||
|
||||
@ -1958,9 +1943,9 @@ class OptimizedAudioProcessor:
|
||||
# Primary decode attempt
|
||||
transcription: str = ""
|
||||
try:
|
||||
transcription = ov_model.processor.batch_decode(
|
||||
transcription = ov_model.processor.batch_decode(# type: ignore
|
||||
generated_ids, skip_special_tokens=True
|
||||
)[0].strip()
|
||||
)[0].strip() # type: ignore
|
||||
except Exception as decode_e:
|
||||
logger.warning(f"{self.peer_name}: primary decode failed: {decode_e}")
|
||||
|
||||
@ -1969,11 +1954,11 @@ class OptimizedAudioProcessor:
|
||||
if not transcription:
|
||||
try:
|
||||
if hasattr(generation_output, "sequences") and (
|
||||
generated_ids is not generation_output.sequences
|
||||
generated_ids is not generation_output.sequences # type: ignore
|
||||
):
|
||||
transcription = ov_model.processor.batch_decode(
|
||||
generation_output.sequences, skip_special_tokens=True
|
||||
)[0].strip()
|
||||
transcription = ov_model.processor.batch_decode(# type: ignore
|
||||
generation_output.sequences, skip_special_tokens=True # type: ignore
|
||||
)[0].strip() # type: ignore
|
||||
except Exception as fallback_e:
|
||||
logger.warning(f"{self.peer_name}: fallback decode failed: {fallback_e}")
|
||||
|
||||
@ -1982,11 +1967,11 @@ class OptimizedAudioProcessor:
|
||||
try:
|
||||
if is_final:
|
||||
logger.info(
|
||||
f"{self.peer_name}: final transcription empty after decode; generated_ids repr/shape: {repr(generated_ids)[:200]}"
|
||||
f"{self.peer_name}: final transcription empty after decode"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"{self.peer_name}: streaming transcription empty after decode; generated_ids repr/shape: {repr(generated_ids)[:200]}"
|
||||
f"{self.peer_name}: streaming transcription empty after decode"
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(f"{self.peer_name}: generated_ids unavailable for diagnostics")
|
||||
@ -2020,7 +2005,7 @@ class OptimizedAudioProcessor:
|
||||
# Avoid duplicates for streaming updates, but always send final
|
||||
# transcriptions so the UI/clients receive the final marker even
|
||||
# if the text matches a recent interim result.
|
||||
if is_final or not self._is_duplicate(transcription):
|
||||
if is_final or not self._is_duplicate(transcription): # type: ignore
|
||||
# Reuse the existing message ID when possible so the frontend
|
||||
# updates the streaming message into a final message instead
|
||||
# of creating a new one. If there is no current_message, a
|
||||
@ -2163,7 +2148,7 @@ class WaveformVideoTrack(MediaStreamTrack):
|
||||
|
||||
# Shared buffer for audio data
|
||||
buffer: Dict[str, npt.NDArray[np.float32]] = {}
|
||||
speech_status: Dict[str, dict] = {}
|
||||
speech_status: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(
|
||||
self, session_name: str, width: int = 640, height: int = 480, fps: int = 15
|
||||
@ -2182,7 +2167,7 @@ class WaveformVideoTrack(MediaStreamTrack):
|
||||
return pts, time_base
|
||||
|
||||
async def recv(self) -> VideoFrame:
|
||||
pts, time_base = await self.next_timestamp()
|
||||
pts, _ = await self.next_timestamp()
|
||||
|
||||
# schedule frame according to clock
|
||||
target_t = self._next_frame_index / self.fps
|
||||
@ -2224,7 +2209,7 @@ class WaveformVideoTrack(MediaStreamTrack):
|
||||
|
||||
# Draw clock in lower right corner, right justified
|
||||
current_time = time.strftime("%H:%M:%S")
|
||||
(text_width, text_height), _ = cv2.getTextSize(
|
||||
(text_width, _), _ = cv2.getTextSize(
|
||||
current_time, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2
|
||||
)
|
||||
clock_x = self.width - text_width - 10 # 10px margin from right edge
|
||||
@ -2364,7 +2349,7 @@ class WaveformVideoTrack(MediaStreamTrack):
|
||||
|
||||
# Label the peak with small text near the right edge
|
||||
label = f"Peak:{target_peak:.2f}"
|
||||
(tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||
(tw, _), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||
lx = max(10, self.width - tw - 12)
|
||||
ly = max(12, top_y - 6)
|
||||
cv2.putText(frame_array, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 1)
|
||||
@ -2391,7 +2376,7 @@ class WaveformVideoTrack(MediaStreamTrack):
|
||||
frame.time_base = fractions.Fraction(1 / 90000).limit_denominator(1000000)
|
||||
return frame
|
||||
|
||||
def _draw_speech_status(self, frame_array: np.ndarray, speech_info: dict, pname: str):
|
||||
def _draw_speech_status(self, frame_array: npt.NDArray[np.uint8], speech_info: Dict[str, Any], pname: str):
|
||||
"""Draw speech detection status information."""
|
||||
|
||||
y_offset = 100
|
||||
@ -2414,7 +2399,7 @@ class WaveformVideoTrack(MediaStreamTrack):
|
||||
f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})"
|
||||
]
|
||||
|
||||
for i, metric in enumerate(metrics):
|
||||
for _, metric in enumerate(metrics):
|
||||
cv2.putText(frame_array, metric,
|
||||
(320, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
|
||||
(255, 255, 255), 1)
|
||||
@ -2646,13 +2631,13 @@ def _resample_audio(
|
||||
audio_data = np.mean(audio_data, axis=1)
|
||||
|
||||
# Use high-quality resampling
|
||||
resampled = librosa.resample(
|
||||
resampled = librosa.resample( # type: ignore
|
||||
audio_data.astype(np.float64),
|
||||
orig_sr=orig_sr,
|
||||
target_sr=target_sr,
|
||||
res_type="kaiser_fast", # Good balance of quality and speed
|
||||
)
|
||||
return resampled.astype(np.float32)
|
||||
return resampled.astype(np.float32) # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"Resampling failed: {e}")
|
||||
raise ValueError(
|
||||
@ -2760,7 +2745,7 @@ def get_config_schema() -> Dict[str, Any]:
|
||||
"type": "range",
|
||||
"label": "VAD Threshold",
|
||||
"description": "Voice activity detection threshold",
|
||||
"default_value": VAD_THRESHOLD,
|
||||
"default_value": 0.01,
|
||||
"required": False,
|
||||
"min_value": 0.001,
|
||||
"max_value": 0.1,
|
||||
@ -2915,7 +2900,7 @@ def get_config_schema() -> Dict[str, Any]:
|
||||
"type": "boolean",
|
||||
"label": "Enable Normalization",
|
||||
"description": "Normalize incoming audio based on observed peak amplitude before transcription and visualization",
|
||||
"default_value": NORMALIZATION_ENABLED,
|
||||
"default_value": True,
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
@ -2923,7 +2908,7 @@ def get_config_schema() -> Dict[str, Any]:
|
||||
"type": "number",
|
||||
"label": "Normalization Target Peak",
|
||||
"description": "Target peak (0-1) used when normalizing audio",
|
||||
"default_value": NORMALIZATION_TARGET_PEAK,
|
||||
"default_value": 0.7,
|
||||
"required": False,
|
||||
"min_value": 0.5,
|
||||
"max_value": 1.0
|
||||
@ -2933,7 +2918,7 @@ def get_config_schema() -> Dict[str, Any]:
|
||||
"type": "range",
|
||||
"label": "Max Normalization Gain",
|
||||
"description": "Maximum allowed gain applied during normalization",
|
||||
"default_value": MAX_NORMALIZATION_GAIN,
|
||||
"default_value": 10.0,
|
||||
"required": False,
|
||||
"min_value": 1.0,
|
||||
"max_value": 10.0,
|
||||
@ -2951,15 +2936,14 @@ def get_config_schema() -> Dict[str, Any]:
|
||||
|
||||
def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
|
||||
"""Handle configuration update for a specific lobby"""
|
||||
global _model_id, _device, _ov_config, SAMPLE_RATE, CHUNK_DURATION_MS, VAD_THRESHOLD
|
||||
global MAX_SILENCE_FRAMES, MAX_TRAILING_SILENCE_FRAMES
|
||||
global _model_id, _device, _ov_config
|
||||
|
||||
try:
|
||||
logger.info(f"Updating Whisper config for lobby {lobby_id}: {config_values}")
|
||||
|
||||
config_applied = False
|
||||
|
||||
# Update model configuration
|
||||
# Update model configuration (global - affects all instances)
|
||||
if "model_id" in config_values:
|
||||
new_model_id = config_values["model_id"]
|
||||
if new_model_id in [model for models in model_ids.values() for model in models]:
|
||||
@ -2969,9 +2953,9 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
|
||||
else:
|
||||
logger.warning(f"Invalid model_id: {new_model_id}")
|
||||
|
||||
# Update device configuration
|
||||
# Update device configuration (global - affects all instances)
|
||||
if "device" in config_values:
|
||||
new_device = config_values["device"]
|
||||
new_device = config_values["device"] # type: ignore
|
||||
available_devices = [d["name"] for d in get_available_devices()]
|
||||
if new_device in available_devices or new_device in ["CPU", "GPU", "GPU.1"]:
|
||||
_device = new_device
|
||||
@ -2981,7 +2965,7 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
|
||||
else:
|
||||
logger.warning(f"Invalid device: {new_device}, available: {available_devices}")
|
||||
|
||||
# Update OpenVINO configuration
|
||||
# Update OpenVINO configuration (global - affects all instances)
|
||||
if "enable_quantization" in config_values:
|
||||
_ov_config.enable_quantization = bool(config_values["enable_quantization"])
|
||||
config_applied = True
|
||||
@ -3001,106 +2985,212 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
|
||||
config_applied = True
|
||||
logger.info(f"Updated max_threads to: {_ov_config.max_threads}")
|
||||
|
||||
# Update audio processing parameters
|
||||
# Update audio processing parameters for existing processors
|
||||
if "sample_rate" in config_values:
|
||||
rate = int(config_values["sample_rate"])
|
||||
if 8000 <= rate <= 48000:
|
||||
SAMPLE_RATE = rate
|
||||
# Update existing processors
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.sample_rate = rate
|
||||
proc.chunk_size = int(proc.sample_rate * proc.chunk_duration_ms / 1000)
|
||||
logger.info(f"Updated sample_rate to {rate} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update sample_rate for processor: {pname}")
|
||||
config_applied = True
|
||||
logger.info(f"Updated sample_rate to: {SAMPLE_RATE}")
|
||||
logger.info(f"Updated sample_rate to: {rate}")
|
||||
|
||||
if "chunk_duration_ms" in config_values:
|
||||
duration = int(config_values["chunk_duration_ms"])
|
||||
if 50 <= duration <= 500:
|
||||
CHUNK_DURATION_MS = duration
|
||||
# Update existing processors
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.chunk_duration_ms = duration
|
||||
proc.chunk_size = int(proc.sample_rate * proc.chunk_duration_ms / 1000)
|
||||
logger.info(f"Updated chunk_duration_ms to {duration} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update chunk_duration_ms for processor: {pname}")
|
||||
config_applied = True
|
||||
logger.info(f"Updated chunk_duration_ms to: {CHUNK_DURATION_MS}")
|
||||
|
||||
if "vad_threshold" in config_values:
|
||||
threshold = float(config_values["vad_threshold"])
|
||||
if 0.001 <= threshold <= 0.1:
|
||||
VAD_THRESHOLD = threshold
|
||||
config_applied = True
|
||||
logger.info(f"Updated vad_threshold to: {VAD_THRESHOLD}")
|
||||
logger.info(f"Updated chunk_duration_ms to: {duration}")
|
||||
|
||||
if "max_silence_frames" in config_values:
|
||||
frames = int(config_values["max_silence_frames"])
|
||||
if 10 <= frames <= 100:
|
||||
MAX_SILENCE_FRAMES = frames
|
||||
# Update existing processors
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.max_silence_frames = frames
|
||||
logger.info(f"Updated max_silence_frames to {frames} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update max_silence_frames for processor: {pname}")
|
||||
config_applied = True
|
||||
logger.info(f"Updated max_silence_frames to: {MAX_SILENCE_FRAMES}")
|
||||
logger.info(f"Updated max_silence_frames to: {frames}")
|
||||
|
||||
if "max_trailing_silence_frames" in config_values:
|
||||
frames = int(config_values["max_trailing_silence_frames"])
|
||||
if 1 <= frames <= 20:
|
||||
MAX_TRAILING_SILENCE_FRAMES = frames
|
||||
config_applied = True
|
||||
logger.info(f"Updated max_trailing_silence_frames to: {MAX_TRAILING_SILENCE_FRAMES}")
|
||||
|
||||
# Update VAD configuration (this would require updating existing processors)
|
||||
vad_updates = {}
|
||||
if "vad_energy_threshold" in config_values:
|
||||
vad_updates["energy_threshold"] = float(config_values["vad_energy_threshold"])
|
||||
if "vad_zcr_min" in config_values:
|
||||
vad_updates["zcr_min"] = float(config_values["vad_zcr_min"])
|
||||
if "vad_zcr_max" in config_values:
|
||||
vad_updates["zcr_max"] = float(config_values["vad_zcr_max"])
|
||||
if "vad_spectral_centroid_min" in config_values:
|
||||
vad_updates["spectral_centroid_min"] = float(config_values["vad_spectral_centroid_min"])
|
||||
if "vad_spectral_centroid_max" in config_values:
|
||||
vad_updates["spectral_centroid_max"] = float(config_values["vad_spectral_centroid_max"])
|
||||
if "vad_spectral_rolloff_threshold" in config_values:
|
||||
vad_updates["spectral_rolloff_threshold"] = float(config_values["vad_spectral_rolloff_threshold"])
|
||||
if "vad_minimum_duration" in config_values:
|
||||
vad_updates["minimum_duration"] = float(config_values["vad_minimum_duration"])
|
||||
if "vad_max_history" in config_values:
|
||||
vad_updates["max_history"] = int(config_values["vad_max_history"])
|
||||
if "vad_noise_floor_energy" in config_values:
|
||||
vad_updates["noise_floor_energy"] = float(config_values["vad_noise_floor_energy"])
|
||||
if "vad_adaptation_rate" in config_values:
|
||||
vad_updates["adaptation_rate"] = float(config_values["vad_adaptation_rate"])
|
||||
if "vad_harmonic_threshold" in config_values:
|
||||
vad_updates["harmonic_threshold"] = float(config_values["vad_harmonic_threshold"])
|
||||
|
||||
if vad_updates:
|
||||
# Update VAD_CONFIG global
|
||||
VAD_CONFIG.update(vad_updates)
|
||||
config_applied = True
|
||||
logger.info(f"Updated VAD config: {vad_updates}")
|
||||
|
||||
# Note: Existing processors would need to be recreated to pick up VAD changes
|
||||
# For now, we'll log that a restart may be needed
|
||||
logger.info("VAD configuration updated - existing processors may need restart to take effect")
|
||||
|
||||
# Normalization updates: apply to global defaults and active processors
|
||||
norm_updates = False
|
||||
if "normalization_enabled" in config_values:
|
||||
NORMALIZATION_ENABLED = bool(config_values["normalization_enabled"])
|
||||
norm_updates = True
|
||||
logger.info(f"Updated NORMALIZATION_ENABLED to: {NORMALIZATION_ENABLED}")
|
||||
if "normalization_target_peak" in config_values:
|
||||
NORMALIZATION_TARGET_PEAK = float(config_values["normalization_target_peak"])
|
||||
norm_updates = True
|
||||
logger.info(f"Updated NORMALIZATION_TARGET_PEAK to: {NORMALIZATION_TARGET_PEAK}")
|
||||
if "max_normalization_gain" in config_values:
|
||||
MAX_NORMALIZATION_GAIN = float(config_values["max_normalization_gain"])
|
||||
norm_updates = True
|
||||
logger.info(f"Updated MAX_NORMALIZATION_GAIN to: {MAX_NORMALIZATION_GAIN}")
|
||||
|
||||
if norm_updates:
|
||||
# Propagate changes to existing processors
|
||||
try:
|
||||
# Update existing processors
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.normalization_enabled = NORMALIZATION_ENABLED
|
||||
proc.normalization_target_peak = NORMALIZATION_TARGET_PEAK
|
||||
proc.max_normalization_gain = MAX_NORMALIZATION_GAIN
|
||||
logger.info(f"Applied normalization config to processor: {pname}")
|
||||
proc.max_trailing_silence_frames = frames
|
||||
logger.info(f"Updated max_trailing_silence_frames to {frames} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to apply normalization config to processor: {pname}")
|
||||
logger.debug(f"Failed to update max_trailing_silence_frames for processor: {pname}")
|
||||
config_applied = True
|
||||
except Exception:
|
||||
logger.debug("Failed to propagate normalization settings to processors")
|
||||
logger.info(f"Updated max_trailing_silence_frames to: {frames}")
|
||||
|
||||
# Update VAD configuration for existing processors
|
||||
vad_updates = False
|
||||
if "vad_energy_threshold" in config_values:
|
||||
threshold = float(config_values["vad_energy_threshold"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_energy_threshold = threshold
|
||||
logger.info(f"Updated vad_energy_threshold to {threshold} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_energy_threshold for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_zcr_min" in config_values:
|
||||
zcr_min = float(config_values["vad_zcr_min"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_zcr_min = zcr_min
|
||||
logger.info(f"Updated vad_zcr_min to {zcr_min} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_zcr_min for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_zcr_max" in config_values:
|
||||
zcr_max = float(config_values["vad_zcr_max"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_zcr_max = zcr_max
|
||||
logger.info(f"Updated vad_zcr_max to {zcr_max} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_zcr_max for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_spectral_centroid_min" in config_values:
|
||||
centroid_min = int(config_values["vad_spectral_centroid_min"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_spectral_centroid_min = centroid_min
|
||||
logger.info(f"Updated vad_spectral_centroid_min to {centroid_min} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_spectral_centroid_min for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_spectral_centroid_max" in config_values:
|
||||
centroid_max = int(config_values["vad_spectral_centroid_max"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_spectral_centroid_max = centroid_max
|
||||
logger.info(f"Updated vad_spectral_centroid_max to {centroid_max} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_spectral_centroid_max for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_spectral_rolloff_threshold" in config_values:
|
||||
rolloff = int(config_values["vad_spectral_rolloff_threshold"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_spectral_rolloff_threshold = rolloff
|
||||
logger.info(f"Updated vad_spectral_rolloff_threshold to {rolloff} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_spectral_rolloff_threshold for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_minimum_duration" in config_values:
|
||||
duration = float(config_values["vad_minimum_duration"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_minimum_duration = duration
|
||||
logger.info(f"Updated vad_minimum_duration to {duration} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_minimum_duration for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_max_history" in config_values:
|
||||
history = int(config_values["vad_max_history"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_max_history = history
|
||||
logger.info(f"Updated vad_max_history to {history} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_max_history for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_noise_floor_energy" in config_values:
|
||||
noise_floor = float(config_values["vad_noise_floor_energy"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_noise_floor_energy = noise_floor
|
||||
logger.info(f"Updated vad_noise_floor_energy to {noise_floor} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_noise_floor_energy for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_adaptation_rate" in config_values:
|
||||
adaptation_rate = float(config_values["vad_adaptation_rate"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_adaptation_rate = adaptation_rate
|
||||
logger.info(f"Updated vad_adaptation_rate to {adaptation_rate} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_adaptation_rate for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if "vad_harmonic_threshold" in config_values:
|
||||
harmonic_threshold = float(config_values["vad_harmonic_threshold"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.vad_harmonic_threshold = harmonic_threshold
|
||||
logger.info(f"Updated vad_harmonic_threshold to {harmonic_threshold} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update vad_harmonic_threshold for processor: {pname}")
|
||||
vad_updates = True
|
||||
|
||||
if vad_updates:
|
||||
config_applied = True
|
||||
logger.info("VAD configuration updated for existing processors")
|
||||
|
||||
# Normalization updates: apply to existing processors
|
||||
norm_updates = False
|
||||
if "normalization_enabled" in config_values:
|
||||
enabled = bool(config_values["normalization_enabled"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.normalization_enabled = enabled
|
||||
logger.info(f"Updated normalization_enabled to {enabled} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update normalization_enabled for processor: {pname}")
|
||||
norm_updates = True
|
||||
|
||||
if "normalization_target_peak" in config_values:
|
||||
target_peak = float(config_values["normalization_target_peak"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.normalization_target_peak = target_peak
|
||||
logger.info(f"Updated normalization_target_peak to {target_peak} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update normalization_target_peak for processor: {pname}")
|
||||
norm_updates = True
|
||||
|
||||
if "max_normalization_gain" in config_values:
|
||||
max_gain = float(config_values["max_normalization_gain"])
|
||||
for pname, proc in list(_audio_processors.items()):
|
||||
try:
|
||||
proc.max_normalization_gain = max_gain
|
||||
logger.info(f"Updated max_normalization_gain to {max_gain} for processor: {pname}")
|
||||
except Exception:
|
||||
logger.debug(f"Failed to update max_normalization_gain for processor: {pname}")
|
||||
norm_updates = True
|
||||
|
||||
if norm_updates:
|
||||
config_applied = True
|
||||
logger.info("Normalization configuration updated for existing processors")
|
||||
|
||||
if config_applied:
|
||||
logger.info(f"Configuration update completed for lobby {lobby_id}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user