Working
This commit is contained in:
parent e96bd887ab
commit 90c3c6e19b
@@ -4,7 +4,6 @@
 !client
 !shared
 **/node_modules
-**/build
 **/dist
 **/__pycache__
 **/.venv
@@ -1,7 +1,5 @@
 """
-Lobby management for the AI Voice Bot server.
-
-This module handles lobby lifecycle, participants, and chat functionality.
+Lobby management for the AI Voice Bot server. Handles lobby lifecycle, participants, and chat functionality.
 Extracted from main.py to improve maintainability and separation of concerns.
 """
 
@@ -9,21 +7,30 @@ from __future__ import annotations
 import secrets
 import time
 import threading
-from typing import Dict, List, Optional, TYPE_CHECKING
+from typing import Dict, List, Optional, TYPE_CHECKING, Callable
+import os
 
 # Import shared models
 # Import shared models
 try:
     # Try relative import first (when running as part of the package)
-    from ...shared.models import ChatMessageModel, ParticipantModel
+    from ...shared.models import (
+        ChatMessageModel,
+        ParticipantModel,
+        WebSocketMessageModel,
+    )
 except ImportError:
     try:
         # Try absolute import (when running directly)
         import sys
         import os
         sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-        from shared.models import ChatMessageModel, ParticipantModel
-    except ImportError:
+        from shared.models import (
+            ChatMessageModel,
+            ParticipantModel,
+            WebSocketMessageModel,
+        )
+    except ImportError as e:
         raise ImportError(
             f"Failed to import shared models: {e}. Ensure shared/models.py is accessible and PYTHONPATH is correctly set."
         )
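The handlers below construct `WebSocketMessageModel(...)` and send the result of `.model_dump()`; the model itself is defined in shared/models.py, which is outside this diff. A minimal Pydantic sketch consistent with that usage (field names inferred from the call sites, not confirmed):

    from typing import Any
    from pydantic import BaseModel

    class WebSocketMessageModel(BaseModel):
        type: str         # e.g. "error", "update_name", "peer_state_update"
        data: Any = None  # payload: a nested model or a plain dict

    class WebSocketErrorModel(BaseModel):
        error: str

    class UpdateNameModel(BaseModel):
        name: str
        protected: bool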
@@ -32,25 +39,21 @@ from shared.logger import logger
 
 # Use try/except for importing events to handle both relative and absolute imports
 try:
-    from ..models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
+    from ..models.events import (
+        event_bus,
+        ChatMessageSent,
+        SessionDisconnected,
+        SessionLeftLobby,
+        Event,
+    )
 except ImportError:
-    try:
-        from models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
-    except ImportError:
-        # Create dummy event system for standalone testing
-        class DummyEventBus:
-            async def publish(self, event):
-                pass
-        event_bus = DummyEventBus()
-
-        class ChatMessageSent:
-            pass
-
-        class SessionDisconnected:
-            pass
-
-        class SessionLeftLobby:
-            pass
+    from models.events import (
+        event_bus,
+        ChatMessageSent,
+        SessionDisconnected,
+        SessionLeftLobby,
+        Event,
+    )
 
 if TYPE_CHECKING:
     from .session_manager import Session
@@ -62,34 +65,6 @@ class LobbyConfig:
 
 
 class Lobby:
-    async def broadcast_json(self, message: dict) -> None:
-        """Broadcast an arbitrary JSON message to all connected sessions in the lobby"""
-        failed_sessions: List[Session] = []
-        for peer in self.sessions.values():
-            if peer.ws:
-                try:
-                    await peer.ws.send_json(message)
-                except Exception as e:
-                    logger.warning(
-                        f"Failed to send broadcast_json message to {peer.getName()}: {e}"
-                    )
-                    failed_sessions.append(peer)
-        for failed_session in failed_sessions:
-            failed_session.ws = None
-    async def broadcast_peer_state_update(self, update: dict) -> None:
-        """Broadcast a peer state update to all connected sessions in the lobby"""
-        failed_sessions: List[Session] = []
-        for peer in self.sessions.values():
-            if peer.ws:
-                try:
-                    await peer.ws.send_json(update)
-                except Exception as e:
-                    logger.warning(
-                        f"Failed to send peer state update to {peer.getName()}: {e}"
-                    )
-                    failed_sessions.append(peer)
-        for failed_session in failed_sessions:
-            failed_session.ws = None
     """Individual lobby representing a chat/voice room"""
 
     def __init__(self, name: str, id: Optional[str] = None, private: bool = False):
@@ -104,6 +79,36 @@ class Lobby:
     def getName(self) -> str:
         return f"{self.short}:{self.name}"
 
+    async def broadcast_json(self, message: WebSocketMessageModel) -> None:
+        """Broadcast an arbitrary JSON message to all connected sessions in the lobby"""
+        failed_sessions: List[Session] = []
+        for peer in self.sessions.values():
+            if peer.ws:
+                try:
+                    await peer.ws.send_json(message.model_dump())
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to send broadcast_json message to {peer.getName()}: {e}"
+                    )
+                    failed_sessions.append(peer)
+        for failed_session in failed_sessions:
+            failed_session.ws = None
+
+    async def broadcast_peer_state_update(self, update: WebSocketMessageModel) -> None:
+        """Broadcast a peer state update to all connected sessions in the lobby"""
+        failed_sessions: List[Session] = []
+        for peer in self.sessions.values():
+            if peer.ws:
+                try:
+                    await peer.ws.send_json(update.model_dump())
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to send peer state update to {peer.getName()}: {e}"
+                    )
+                    failed_sessions.append(peer)
+        for failed_session in failed_sessions:
+            failed_session.ws = None
+
     async def update_state(self, requesting_session: Optional[Session] = None):
         """Update lobby state and notify participants"""
         with self.lock:
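With the typed signatures, callers wrap payloads in `WebSocketMessageModel` before broadcasting, and `send_json()` receives the `model_dump()` dict. A hypothetical call site (lobby and payload values assumed):

    message = WebSocketMessageModel(
        type="chat_message",
        data={"sender": "alice", "message": "hello"},
    )
    await lobby.broadcast_json(message)  # each peer gets message.model_dump()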
@@ -344,7 +349,7 @@ class LobbyManager:
             # Event system not available, skip subscriptions
             pass
 
-    async def handle(self, event):
+    async def handle(self, event: Event) -> None:
         """Handle events from the event bus"""
 
         if isinstance(event, SessionDisconnected):
@@ -352,7 +357,7 @@ class LobbyManager:
         elif isinstance(event, SessionLeftLobby):
             await self._handle_session_left_lobby(event)
 
-    async def _handle_session_disconnected(self, event):
+    async def _handle_session_disconnected(self, event: SessionDisconnected) -> None:
         """Handle session disconnection by removing from all lobbies"""
         session_id = event.session_id
 
@@ -372,7 +377,7 @@ class LobbyManager:
             if lobby.is_empty() and not lobby.private:
                 await self._cleanup_empty_lobby(lobby)
 
-    async def _handle_session_left_lobby(self, event):
+    async def _handle_session_left_lobby(self, event: SessionLeftLobby) -> None:
         """Handle explicit session leave"""
         # This is already handled by the session's leave_lobby method
         # but we could add additional cleanup logic here if needed
@@ -447,8 +452,8 @@ class LobbyManager:
 
         return removed_count
 
-    def set_name_protection_checker(self, checker_func):
+    def set_name_protection_checker(self, checker_func: Callable[[str], bool]) -> None:
         """Inject name protection checker from AuthManager"""
         # This allows us to inject the name protection logic without tight coupling
         for lobby in self.lobbies.values():
-            lobby._is_name_protected = checker_func
+            lobby._is_name_protected = checker_func  # type: ignore
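`set_name_protection_checker` now states its contract in the signature: any `Callable[[str], bool]` works. A sketch of the injection site (an assumed `auth_manager` instance; the actual wiring is outside this diff):

    lobby_manager.set_name_protection_checker(auth_manager.is_name_protected)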
@@ -10,7 +10,12 @@ from typing import Dict, Any, TYPE_CHECKING
 from fastapi import WebSocket
 
 from shared.logger import logger
-from shared.models import ChatMessageModel
+from shared.models import (
+    ChatMessageModel,
+    WebSocketMessageModel,
+    WebSocketErrorModel,
+    UpdateNameModel,
+)
 from .webrtc_signaling import WebRTCSignalingHandlers
 from core.error_handling import (
     error_handler,
@@ -45,11 +50,11 @@ class PeerStateUpdateHandler(MessageHandler):
 
     async def handle(
         self,
-        session: Any,
-        lobby: Any,
-        data: dict,
-        websocket: Any,
-        managers: dict,
+        session: "Session",
+        lobby: "Lobby",
+        data: Dict[str, Any],
+        websocket: WebSocket,
+        managers: Dict[str, Any],
     ) -> None:
         # Only allow a user to update their own state
         if not lobby or not session:
@@ -59,14 +64,14 @@ class PeerStateUpdateHandler(MessageHandler):
             # Ignore attempts to update other users' state
             # Optionally log or send error to client
             return
-        update = {
-            "type": "peer_state_update",
-            "data": {
+        update = WebSocketMessageModel(
+            type="peer_state_update",
+            data={
                 "peer_id": peer_id,
                 "muted": data.get("muted"),
                 "video_on": data.get("video_on"),
-            },
-        }
+            },  # type: ignore
+        )
         await lobby.broadcast_peer_state_update(update)
 
 
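On the wire nothing changes: `broadcast_peer_state_update` calls `update.model_dump()`, so clients still receive the same shape the old dict literal produced, e.g. (Python-side, values assumed):

    {
        "type": "peer_state_update",
        "data": {"peer_id": "abc123", "muted": True, "video_on": False},
    }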
@@ -86,10 +91,12 @@ class SetNameHandler(MessageHandler):
 
         if not data:
             logger.error(f"{session.getName()} - set_name missing data")
-            await websocket.send_json({
-                "type": "error",
-                "data": {"error": "set_name missing data"},
-            })
+            await websocket.send_json(
+                WebSocketMessageModel(
+                    type="error",
+                    data=WebSocketErrorModel(error="set_name missing data"),
+                ).model_dump()
+            )
             return
 
         name = data.get("name")
@@ -99,10 +106,11 @@ class SetNameHandler(MessageHandler):
 
         if not name:
             logger.error(f"{session.getName()} - Name required")
-            await websocket.send_json({
-                "type": "error",
-                "data": {"error": "Name required"}
-            })
+            await websocket.send_json(
+                WebSocketMessageModel(
+                    type="error", data=WebSocketErrorModel(error="Name required")
+                ).model_dump()
+            )
             return
 
         # Check if name is unique
@@ -114,13 +122,14 @@ class SetNameHandler(MessageHandler):
             session.setName(name)
             logger.info(f"{session.getName()}: -> update('name', {name})")
 
-            await websocket.send_json({
-                "type": "update_name",
-                "data": {
-                    "name": name,
-                    "protected": auth_manager.is_name_protected(name),
-                },
-            })
+            await websocket.send_json(
+                WebSocketMessageModel(
+                    type="update_name",
+                    data=UpdateNameModel(
+                        name=name, protected=auth_manager.is_name_protected(name)
+                    ),
+                ).model_dump()
+            )
 
             # Update lobby state
             await lobby.update_state()
@@ -131,10 +140,11 @@ class SetNameHandler(MessageHandler):
 
         if not allowed:
             logger.warning(f"{session.getName()} - {reason}")
-            await websocket.send_json({
-                "type": "error",
-                "data": {"error": reason}
-            })
+            await websocket.send_json(
+                WebSocketMessageModel(
+                    type="error", data=WebSocketErrorModel(error=reason)
+                ).model_dump()
+            )
             return
 
         # Takeover allowed - handle displacement
@@ -179,13 +189,14 @@ class SetNameHandler(MessageHandler):
         session.setName(name)
         logger.info(f"{session.getName()}: -> update('name', {name}) (takeover)")
 
-        await websocket.send_json({
-            "type": "update_name",
-            "data": {
-                "name": name,
-                "protected": auth_manager.is_name_protected(name),
-            },
-        })
+        await websocket.send_json(
+            WebSocketMessageModel(
+                type="update_name",
+                data=UpdateNameModel(
+                    name=name, protected=auth_manager.is_name_protected(name)
+                ),
+            ).model_dump()
+        )
 
         # Update lobby state
         await lobby.update_state()
@@ -460,15 +471,15 @@ class MessageRouter:
     ):
         """Route a message to the appropriate handler with enhanced error handling"""
         if message_type not in self._handlers:
-            await error_handler.handle_error(
+            await error_handler.handle_error(  # type: ignore
                 ValidationError(f"Unknown message type: {message_type}"),
                 context={
                     "message_type": message_type,
                     "session_id": session.id if session else "unknown",
-                    "data_keys": list(data.keys()) if data else []
+                    "data_keys": list(data.keys()) if data else [],
                 },
                 websocket=websocket,
-                session_id=session.id if session else None
+                session_id=session.id if session else None,
             )
             return
 
@@ -480,48 +491,50 @@ class MessageRouter:
 
         except WebSocketError as e:
             # WebSocket specific errors - attempt recovery
-            await error_handler.handle_error(
+            await error_handler.handle_error(  # type: ignore
                 e,
                 context={
                     "message_type": message_type,
                     "session_id": session.id if session else "unknown",
-                    "handler": type(self._handlers[message_type]).__name__
+                    "handler": type(self._handlers[message_type]).__name__,
                 },
                 websocket=websocket,
                 session_id=session.id if session else None,
-                recovery_action=lambda: self._websocket_recovery(websocket, session)
+                recovery_action=lambda: self._websocket_recovery(websocket, session),
             )
 
         except ValidationError as e:
             # Validation errors - usually client-side issues
-            await error_handler.handle_error(
+            await error_handler.handle_error(  # type: ignore
                 e,
                 context={
                     "message_type": message_type,
                     "session_id": session.id if session else "unknown",
-                    "data": str(data)[:500]  # Truncate large data
+                    "data": str(data)[:500],  # Truncate large data
                 },
                 websocket=websocket,
-                session_id=session.id if session else None
+                session_id=session.id if session else None,
             )
 
         except Exception as e:
             # Unexpected errors - enhanced logging and fallback
-            await error_handler.handle_error(
+            await error_handler.handle_error(  # type: ignore
                 WebSocketError(
                     f"Unexpected error in {message_type} handler: {e}",
-                    severity=ErrorSeverity.HIGH
+                    severity=ErrorSeverity.HIGH,
                 ),
                 context={
                     "message_type": message_type,
                     "session_id": session.id if session else "unknown",
                     "handler": type(self._handlers[message_type]).__name__,
                     "exception_type": type(e).__name__,
-                    "traceback": str(e)
+                    "traceback": str(e),
                 },
                 websocket=websocket,
                 session_id=session.id if session else None,
-                recovery_action=lambda: self._generic_recovery(message_type, session, lobby)
+                recovery_action=lambda: self._generic_recovery(
+                    message_type, session, lobby
+                ),
             )
 
     async def _websocket_recovery(self, websocket: WebSocket, session: "Session"):
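All three call sites pass the same keyword set to `error_handler.handle_error`, which implies a signature along these lines (a sketch inferred from the calls above, not the actual core.error_handling code):

    from typing import Any, Awaitable, Callable, Dict, Optional

    async def handle_error(
        self,
        error: Exception,
        context: Optional[Dict[str, Any]] = None,
        websocket: Optional[WebSocket] = None,
        session_id: Optional[str] = None,
        recovery_action: Optional[Callable[[], Awaitable[Any]]] = None,
    ) -> None:
        ...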
@@ -65,43 +65,6 @@ _device = "GPU.1" # Default to Intel Arc B580 GPU
 _generate_global_lock = threading.Lock()
 
 
-def _blocking_generate_decode(audio_array: AudioArray, sample_rate: int, generation_config: GenerationConfig | None = None) -> str:
-    """Blocking helper to run processor -> model.generate -> decode while
-    holding a global lock to serialize OpenVINO access.
-    """
-    try:
-        with _generate_global_lock:
-            ov_model = _ensure_model_loaded()
-            if ov_model.processor is None:
-                raise RuntimeError("Processor not initialized for OpenVINO model")
-
-            # Extract features
-            inputs = ov_model.processor(audio_array, sampling_rate=sample_rate, return_tensors="pt")
-            input_features = inputs.input_features
-
-            # Use a basic generation config if none provided
-            gen_cfg = generation_config or GenerationConfig(max_new_tokens=128)
-
-            gen_out = ov_model.ov_model.generate(input_features, generation_config=gen_cfg)  # type: ignore
-
-            # Prefer .sequences if available
-            if hasattr(gen_out, "sequences"):
-                ids = gen_out.sequences
-            else:
-                ids = gen_out
-
-            # Decode
-            try:
-                transcription = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
-            except Exception:
-                transcription = ""
-
-            return transcription
-    except Exception as e:
-        logger.error(f"blocking_generate_decode failed: {e}", exc_info=True)
-        return ""
-
-
 def get_available_devices() -> list[dict[str, Any]]:
     """List available OpenVINO devices with their properties."""
     try:
@@ -230,7 +193,7 @@ class OpenVINOConfig(BaseModel):
             cfg.update(
                 {
                     "CPU_THROUGHPUT_NUM_THREADS": str(self.max_threads),
-                    "CPU_BIND_THREAD": "YES",
+                    # "CPU_BIND_THREAD": "YES",  # Removed: not supported by CPU plugin
                 }
             )
 
@@ -245,7 +208,7 @@ CHUNK_DURATION_MS = 100  # Reduced latency - 100ms chunks
 VAD_THRESHOLD = 0.01  # Initial voice activity detection threshold
 MAX_SILENCE_FRAMES = 30  # 3 seconds of silence before stopping (for overall silence)
 MAX_TRAILING_SILENCE_FRAMES = 5  # 0.5 seconds of trailing silence
-VAD_CONFIG = {
+VAD_CONFIG: Dict[str, Any] = {
     "energy_threshold": 0.01,
     "zcr_threshold": 0.1,
     "adapt_thresholds": True,
@@ -301,7 +264,7 @@ def setup_intel_arc_environment() -> None:
 class AdvancedVAD:
     """Advanced Voice Activity Detection with noise rejection."""
 
-    def __init__(self, sample_rate: int = SAMPLE_RATE):
+    def __init__(self, sample_rate: int = 16000):
         self.sample_rate = sample_rate
         # More permissive thresholds based on research
         self.energy_threshold = 0.005  # Reduced from 0.02
@@ -315,7 +278,7 @@ class AdvancedVAD:
 
         # Relaxed temporal consistency
         self.minimum_duration = 0.2  # Reduced from 0.3s
-        self.speech_history = []
+        self.speech_history: List[bool] = []
        self.max_history = 8  # Reduced from 10
 
         # Adaptive noise floor
@@ -327,7 +290,7 @@ class AdvancedVAD:
         self.prev_magnitude = None
         self.harmonic_threshold = 0.15  # Reduced from 0.3
 
-    def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, dict]:
+    def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, Dict[str, Any]]:
         """Analyze audio frame for speech vs noise."""
 
         # Basic energy features
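`analyze_frame` keys its decision off frame energy and zero-crossing rate before the spectral checks. The standard definitions of those two features (the module's own implementation is not shown in this hunk):

    import numpy as np

    def frame_energy(x: np.ndarray) -> float:
        # RMS energy of one frame
        return float(np.sqrt(np.mean(x.astype(np.float64) ** 2)))

    def zero_crossing_rate(x: np.ndarray) -> float:
        # Fraction of adjacent samples that change sign
        signs = np.sign(x)
        return float(np.mean(signs[:-1] != signs[1:]))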
@@ -403,7 +366,7 @@ class AdvancedVAD:
                 (1 - self.adaptation_rate) * self.noise_floor_energy
             )
 
-        metrics = {
+        metrics: Dict[str, Any] = {
             'energy': energy,
             'zcr': zcr,
             'centroid': spectral_features['centroid'],
@@ -419,9 +382,9 @@ class AdvancedVAD:
             'temporal_consistency': recent_speech
         }
 
-        return recent_speech, metrics
+        return recent_speech, metrics  # type: ignore
 
-    def _compute_spectral_features(self, audio_data: AudioArray) -> dict:
+    def _compute_spectral_features(self, audio_data: AudioArray) -> Dict[str, Any]:
         """Compute spectral features for speech detection."""
 
         # Apply window to reduce spectral leakage
@@ -464,7 +427,7 @@ class AdvancedVAD:
             'harmonicity': harmonicity
         }
 
-    def _compute_harmonicity(self, magnitude: np.ndarray, freqs: np.ndarray) -> float:
+    def _compute_harmonicity(self, magnitude: npt.NDArray[np.float32], freqs: npt.NDArray[np.float32]) -> float:
         """Compute harmonicity score (0-1, higher = more harmonic/speech-like)."""
 
         # Find fundamental frequency candidate (peak in 80-400Hz range for speech)
@@ -483,24 +446,24 @@ class AdvancedVAD:
         # More robust F0 detection - find peaks instead of just max
         try:
             # Import scipy here to handle missing dependency gracefully
-            from scipy.signal import find_peaks
+            from scipy.signal import find_peaks  # type: ignore
 
             # Ensure distance is at least 1
             min_distance = max(1, int(len(speech_magnitude) * 0.05))
 
-            peaks, properties = find_peaks(
+            peaks, properties = find_peaks(  # type: ignore
                 speech_magnitude,
                 height=np.max(speech_magnitude) * 0.05,  # Lowered from 0.1
                 distance=min_distance,  # Minimum peak separation
             )
 
-            if len(peaks) == 0:
+            if len(peaks) == 0:  # type: ignore
                 # Fallback to simple max if no peaks found
                 f0_idx = np.argmax(speech_magnitude)
             else:
                 # Use the strongest peak
                 strongest_peak_idx = np.argmax(speech_magnitude[peaks])
-                f0_idx = peaks[strongest_peak_idx]
+                f0_idx = int(peaks[strongest_peak_idx])  # type: ignore
 
         except ImportError:
             # scipy not available, use simple max
@@ -526,8 +489,8 @@ class AdvancedVAD:
             harmonic_idx = np.argmin(np.abs(freqs - harmonic_freq))
 
             # Check a small neighborhood around the harmonic frequency
-            start_idx = max(0, harmonic_idx - 2)
-            end_idx = min(len(magnitude), harmonic_idx + 3)
+            start_idx = max(0, int(harmonic_idx) - 2)
+            end_idx = min(len(magnitude), int(harmonic_idx) + 3)
             local_max = np.max(magnitude[start_idx:end_idx])
 
             harmonic_strength += local_max
@@ -565,13 +528,13 @@ class OpenVINOWhisperModel:
             logger.info(
                 f"Loading Whisper model '{self.model_id}' on device: {self.device}"
             )
-            self.processor = WhisperProcessor.from_pretrained(
+            self.processor = WhisperProcessor.from_pretrained(  # type: ignore
                 self.model_id, use_fast=True
             )  # type: ignore
             logger.info("Whisper processor loaded successfully")
 
             # Export the model to OpenVINO IR if not already converted
-            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.model_id, export=True, device=self.device
             )  # type: ignore
 
@@ -614,7 +577,7 @@ class OpenVINOWhisperModel:
 
         try:
             # Convert to OpenVINO with FP16 for Arc GPU
-            ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.model_id,
                 ov_config=self.config.to_ov_config(),
                 export=True,
@@ -623,12 +586,13 @@ class OpenVINOWhisperModel:
             )
 
             # Enable FP16 for Intel Arc performance
-            ov_model.half()
-            ov_model.save_pretrained(self.model_path)
+            if hasattr(ov_model, 'half'):
+                ov_model.half()  # type: ignore
+            ov_model.save_pretrained(self.model_path)  # type: ignore
             logger.info("Model converted and saved in FP16 format")
 
             # Load the converted model
-            self.ov_model = ov_model
+            self.ov_model = ov_model  # type: ignore
             self._compile_model()
 
         except Exception as e:
@@ -639,38 +603,38 @@ class OpenVINOWhisperModel:
         """Basic model conversion without advanced features."""
         logger.info(f"Basic conversion of {self.model_id} to OpenVINO format...")
 
-        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
             self.model_id, export=True, compile=False
         )
-        ov_model.save_pretrained(self.model_path)
+        ov_model.save_pretrained(self.model_path)  # type: ignore
         logger.info("Basic model conversion completed")
 
     def _load_fp16_model(self) -> None:
         """Load existing FP16 OpenVINO model."""
         logger.info("Loading existing FP16 OpenVINO model...")
         try:
-            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.model_path, ov_config=self.config.to_ov_config(), compile=False
-            )
+            )  # type: ignore
             self._compile_model()
         except Exception as e:
             logger.error(f"Failed to load FP16 model: {e}")
             # Try basic loading
-            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.model_path, compile=False
-            )
+            )  # type: ignore
             self._compile_model()
 
     def _try_load_quantized_model(self) -> bool:
         """Try to load existing quantized model."""
         try:
             logger.info("Loading existing INT8 quantized model...")
-            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+            self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                 self.quantized_model_path,
                 ov_config=self.config.to_ov_config(),
                 compile=False,
-            )
+            )  # type: ignore
             self._compile_model()
             self.is_quantized = True
             logger.info("Quantized model loaded successfully")
@@ -690,13 +654,12 @@ class OpenVINOWhisperModel:
             return
 
         # Check if model components are available
-        if not hasattr(self.ov_model, "encoder") or self.ov_model.encoder is None:
+        if not hasattr(self.ov_model, "encoder"):
             logger.warning("Model encoder not available, skipping quantization")
             return
 
         if (
             not hasattr(self.ov_model, "decoder_with_past")
-            or self.ov_model.decoder_with_past is None
         ):
             logger.warning(
                 "Model decoder_with_past not available, skipping quantization"
@@ -761,14 +724,14 @@ class OpenVINOWhisperModel:
 
             # Save quantized models
             self.quantized_model_path.mkdir(parents=True, exist_ok=True)
-            ov.save_model(
+            ov.save_model(  # type: ignore
                 quantized_encoder,
                 self.quantized_model_path / "openvino_encoder_model.xml",
             )  # type: ignore
-            ov.save_model(
+            ov.save_model(  # type: ignore
                 quantized_decoder,
                 self.quantized_model_path / "openvino_decoder_with_past_model.xml",
             )  # type: ignore
 
             # Copy remaining files
             self._copy_model_files()
@@ -828,12 +791,12 @@ class OpenVINOWhisperModel:
         decoder_data: CalibrationData = []
 
         try:
-            self.ov_model.encoder.request = InferRequestWrapper(
-                original_encoder_request, encoder_data
-            )
+            self.ov_model.encoder.request = InferRequestWrapper(  # type: ignore
+                original_encoder_request, encoder_data  # type: ignore
+            )  # type: ignore
             self.ov_model.decoder_with_past.request = InferRequestWrapper(
                 original_decoder_request, decoder_data
-            )
+            )  # type: ignore
 
             # Generate synthetic calibration data instead of loading dataset
             logger.info("Generating synthetic calibration data...")
@@ -842,17 +805,17 @@ class OpenVINOWhisperModel:
                 # Generate random audio similar to speech
                 duration = 2.0 + np.random.random() * 3.0  # 2-5 seconds
                 synthetic_audio = (
-                    np.random.randn(int(SAMPLE_RATE * duration)).astype(np.float32)
+                    np.random.randn(int(16000 * duration)).astype(np.float32)
                     * 0.1
                 )
 
-                inputs: Any = self.processor(
-                    synthetic_audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
-                )
+                inputs: Any = self.processor(  # type: ignore
+                    synthetic_audio, sampling_rate=16000, return_tensors="pt"
+                )  # type: ignore
 
                 # Run inference to collect calibration data
-                generated_ids = self.ov_model.generate(
-                    inputs.input_features, max_new_tokens=10
+                _ = self.ov_model.generate(  # type: ignore
+                    inputs.input_features, max_new_tokens=10  # type: ignore
                 )
 
                 if i % 5 == 0:
@@ -882,7 +845,7 @@ class OpenVINOWhisperModel:
             result["decoder"] = decoder_data
             logger.info(f"Collected {len(decoder_data)} decoder calibration samples")
 
-        return result
+        return result  # type: ignore
 
     def _copy_model_files(self) -> None:
         """Copy necessary model files for quantized model."""
@@ -951,28 +914,29 @@ class OpenVINOWhisperModel:
             )
             # Try to reload using the existing saved model path if possible
             try:
-                self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+                self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                     self.model_path, ov_config=cpu_cfg.to_ov_config(), compile=False
-                )
+                )  # type: ignore
             except Exception:
                 # If loading the saved model failed, try loading without ov_config
-                self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
+                self.ov_model = OVModelForSpeechSeq2Seq.from_pretrained(  # type: ignore
                     self.model_path, compile=False
-                )
+                )  # type: ignore
 
             # Compile on CPU
-            self.ov_model.to("CPU")
-            # Provide CPU-only ov_config if supported
-            try:
-                self.ov_model.compile()
-            except Exception as compile_cpu_e:
-                logger.warning(
-                    f"CPU compile with CPU ov_config failed, retrying default compile: {compile_cpu_e}"
-                )
-                self.ov_model.compile()
+            if self.ov_model is not None:
+                self.ov_model.to("CPU")  # type: ignore
+                # Provide CPU-only ov_config if supported
+                try:
+                    self.ov_model.compile()  # type: ignore
+                except Exception as compile_cpu_e:
+                    logger.warning(
+                        f"CPU compile with CPU ov_config failed, retrying default compile: {compile_cpu_e}"
+                    )
+                    self.ov_model.compile()  # type: ignore
 
             self._warmup_model()
             logger.info("Model compiled for CPU successfully")
         except Exception as cpu_e:
             logger.error(f"Failed to compile for CPU as well: {cpu_e}")
             raise
@@ -984,14 +948,14 @@ class OpenVINOWhisperModel:
 
         try:
             logger.info("Warming up model...")
-            dummy_audio = np.random.randn(SAMPLE_RATE).astype(np.float32)  # 1 second
-            dummy_features = self.processor(
-                dummy_audio, sampling_rate=SAMPLE_RATE, return_tensors="pt"
+            dummy_audio = np.random.randn(16000).astype(np.float32)  # 1 second
+            dummy_features = self.processor(  # type: ignore
+                dummy_audio, sampling_rate=16000, return_tensors="pt"
             ).input_features
 
             # Run warmup iterations
             for i in range(3):
-                _ = self.ov_model.generate(dummy_features, max_new_tokens=10)
+                _ = self.ov_model.generate(dummy_features, max_new_tokens=10)  # type: ignore
                 if i == 0:
                     logger.debug("First warmup iteration completed")
         except Exception as e:
@@ -1004,9 +968,9 @@ class OpenVINOWhisperModel:
         if self.processor is None:
             raise RuntimeError("Processor not initialized")
 
-        return self.processor.batch_decode(
+        return self.processor.batch_decode(  # type: ignore
             token_ids, skip_special_tokens=skip_special_tokens
-        )
+        )  # type: ignore
 
 
 # Global model instance with deferred loading
@@ -1046,12 +1010,12 @@ def extract_input_features(audio_array: AudioArray, sampling_rate: int) -> torch
     if ov_model.processor is None:
         raise RuntimeError("Processor not initialized")
 
-    inputs = ov_model.processor(
+    inputs = ov_model.processor(  # type: ignore
         audio_array,
         sampling_rate=sampling_rate,
         return_tensors="pt",
-    )
-    return inputs.input_features
+    )  # type: ignore
+    return inputs.input_features  # type: ignore
 
 
 class VoiceActivityDetector(BaseModel):
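A hypothetical end-to-end use of `extract_input_features` with the module's global model, mirroring the generate/decode pattern elsewhere in this diff (`audio` is an assumed float32 mono array at 16 kHz):

    features = extract_input_features(audio, sampling_rate=16000)
    ov_model = _ensure_model_loaded()
    ids = ov_model.ov_model.generate(features, max_new_tokens=128)
    text = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()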
@@ -1064,7 +1028,7 @@ class VoiceActivityDetector(BaseModel):
 def simple_robust_vad(
     audio_data: AudioArray,
     energy_threshold: float = 0.01,
-    sample_rate: int = SAMPLE_RATE,
+    sample_rate: int = 16000,
 ) -> VoiceActivityDetector:
     """Simplified robust VAD."""
 
@@ -1091,7 +1055,7 @@ def enhanced_vad(
     audio_data: AudioArray,
     energy_threshold: float = 0.01,
     zcr_threshold: float = 0.1,
-    sample_rate: int = SAMPLE_RATE,
+    sample_rate: int = 16000,
 ) -> VoiceActivityDetector:
     """Enhanced VAD using multiple features.
 
@@ -1137,14 +1101,39 @@ class OptimizedAudioProcessor:
         self.peer_name = peer_name
         self.send_chat_func = send_chat_func
         self.create_chat_message_func = create_chat_message_func
-        self.sample_rate = SAMPLE_RATE
+
+        # Audio processing settings (use defaults, can be overridden per instance)
+        self.sample_rate = 16000  # Default Whisper sample rate
+        self.chunk_duration_ms = 100  # Default chunk duration
+        self.chunk_size = int(self.sample_rate * self.chunk_duration_ms / 1000)
+
+        # Silence handling parameters
+        self.max_silence_frames = 30  # Default max silence frames
+        self.max_trailing_silence_frames = 5  # Default trailing silence frames
+
+        # VAD settings (use defaults, can be overridden per instance)
+        self.vad_energy_threshold = 0.005
+        self.vad_zcr_min = 0.02
+        self.vad_zcr_max = 0.8
+        self.vad_spectral_centroid_min = 200
+        self.vad_spectral_centroid_max = 4000
+        self.vad_spectral_rolloff_threshold = 3000
+        self.vad_minimum_duration = 0.2
+        self.vad_max_history = 8
+        self.vad_noise_floor_energy = 0.001
+        self.vad_adaptation_rate = 0.05
+        self.vad_harmonic_threshold = 0.15
+
+        # Normalization settings
+        self.normalization_enabled = True  # Default normalization enabled
+        self.normalization_target_peak = 0.7  # Default target peak
+        self.max_normalization_gain = 10.0  # Default max gain
 
         # Initialize visualization buffer if not already done
         if self.peer_name not in WaveformVideoTrack.buffer:
             WaveformVideoTrack.buffer[self.peer_name] = np.array([], dtype=np.float32)
 
         # Optimized buffering parameters
-        self.chunk_size = int(self.sample_rate * CHUNK_DURATION_MS / 1000)
         self.buffer_size = self.chunk_size * 50
 
         # Circular buffer for zero-copy operations
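The buffering math follows directly from the new per-instance defaults:

    sample_rate = 16000
    chunk_duration_ms = 100
    chunk_size = int(sample_rate * chunk_duration_ms / 1000)  # 1600 samples = 100 ms
    buffer_size = chunk_size * 50                             # 80000 samples = 5 s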
@@ -1154,8 +1143,6 @@ class OptimizedAudioProcessor:
 
         # Silence handling parameters
         self.silence_frames: int = 0
-        self.max_silence_frames: int = MAX_SILENCE_FRAMES
-        self.max_trailing_silence_frames: int = MAX_TRAILING_SILENCE_FRAMES
 
         # Enhanced VAD parameters with EMA for noise adaptation
         self.advanced_vad = AdvancedVAD(sample_rate=self.sample_rate)
@@ -1165,9 +1152,6 @@ class OptimizedAudioProcessor:
         # maximum which helps models expect a consistent level across peers.
         # It's intentionally permissive and capped to avoid amplifying noise.
         self.max_observed_amplitude: float = 1e-6
-        self.normalization_enabled: bool = True
-        self.normalization_target_peak: float = 0.95
-        self.max_normalization_gain: float = 3.0  # avoid amplifying tiny noise too much
 
         # Processing state
         self.current_phrase_audio = np.array([], dtype=np.float32)
@@ -1476,9 +1460,9 @@ class OptimizedAudioProcessor:
             ov_model = _ensure_model_loaded()
 
             # Extract features (this is relatively cheap but keep on thread)
-            input_features = ov_model.processor(
+            input_features = ov_model.processor(  # type: ignore
                 audio_in, sampling_rate=self.sample_rate, return_tensors="pt"
-            ).input_features
+            ).input_features  # type: ignore
 
             # Perform generation (blocking)
             # Use the same generation configuration as the async path
@@ -1496,23 +1480,24 @@ class OptimizedAudioProcessor:
             # Serialize access to the underlying OpenVINO generation call
             # to avoid concurrency problems with the OpenVINO runtime.
             with _generate_global_lock:
-                gen_out = ov_model.ov_model.generate(
-                    input_features, generation_config=gen_cfg
+                gen_out = ov_model.ov_model.generate(  # type: ignore
+                    input_features, generation_config=gen_cfg  # type: ignore
                 )
 
             # Try to extract sequences if present
-            if hasattr(gen_out, "sequences"):
-                ids = gen_out.sequences
+            if hasattr(gen_out, "sequences"):  # type: ignore
+                ids = gen_out.sequences  # type: ignore
             else:
-                ids = gen_out
+                ids = gen_out  # type: ignore
 
             # Decode
+            text: str = ""
             try:
-                text = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
+                text = ov_model.processor.batch_decode(ids, skip_special_tokens=True)[0].strip()  # type: ignore
             except Exception:
                 text = ""
 
-            return text, 0.0
+            return text, 0.0  # type: ignore
         except Exception as e:
             logger.error(f"Blocking transcription failed for {self.peer_name}: {e}", exc_info=True)
             return "", 0.0
@@ -1933,7 +1918,7 @@ class OptimizedAudioProcessor:
             # Many generate implementations return an object with a
             # `.sequences` attribute, so prefer that when available.
             if hasattr(generation_output, "sequences"):
-                generated_ids = generation_output.sequences
+                generated_ids = generation_output.sequences  # type: ignore
             else:
                 generated_ids = generation_output
 
@@ -1958,9 +1943,9 @@ class OptimizedAudioProcessor:
             # Primary decode attempt
             transcription: str = ""
             try:
-                transcription = ov_model.processor.batch_decode(
+                transcription = ov_model.processor.batch_decode(  # type: ignore
                     generated_ids, skip_special_tokens=True
-                )[0].strip()
+                )[0].strip()  # type: ignore
             except Exception as decode_e:
                 logger.warning(f"{self.peer_name}: primary decode failed: {decode_e}")
 
@@ -1969,11 +1954,11 @@ class OptimizedAudioProcessor:
             if not transcription:
                 try:
                     if hasattr(generation_output, "sequences") and (
-                        generated_ids is not generation_output.sequences
+                        generated_ids is not generation_output.sequences  # type: ignore
                     ):
-                        transcription = ov_model.processor.batch_decode(
-                            generation_output.sequences, skip_special_tokens=True
-                        )[0].strip()
+                        transcription = ov_model.processor.batch_decode(  # type: ignore
+                            generation_output.sequences, skip_special_tokens=True  # type: ignore
+                        )[0].strip()  # type: ignore
                 except Exception as fallback_e:
                     logger.warning(f"{self.peer_name}: fallback decode failed: {fallback_e}")
 
@@ -1982,11 +1967,11 @@ class OptimizedAudioProcessor:
             try:
                 if is_final:
                     logger.info(
-                        f"{self.peer_name}: final transcription empty after decode; generated_ids repr/shape: {repr(generated_ids)[:200]}"
+                        f"{self.peer_name}: final transcription empty after decode"
                     )
                 else:
                     logger.debug(
-                        f"{self.peer_name}: streaming transcription empty after decode; generated_ids repr/shape: {repr(generated_ids)[:200]}"
+                        f"{self.peer_name}: streaming transcription empty after decode"
                     )
             except Exception:
                 logger.debug(f"{self.peer_name}: generated_ids unavailable for diagnostics")
@@ -2020,7 +2005,7 @@ class OptimizedAudioProcessor:
         # Avoid duplicates for streaming updates, but always send final
         # transcriptions so the UI/clients receive the final marker even
         # if the text matches a recent interim result.
-        if is_final or not self._is_duplicate(transcription):
+        if is_final or not self._is_duplicate(transcription):  # type: ignore
             # Reuse the existing message ID when possible so the frontend
             # updates the streaming message into a final message instead
             # of creating a new one. If there is no current_message, a
@@ -2163,7 +2148,7 @@ class WaveformVideoTrack(MediaStreamTrack):
 
     # Shared buffer for audio data
     buffer: Dict[str, npt.NDArray[np.float32]] = {}
-    speech_status: Dict[str, dict] = {}
+    speech_status: Dict[str, Dict[str, Any]] = {}
 
     def __init__(
         self, session_name: str, width: int = 640, height: int = 480, fps: int = 15
@@ -2182,7 +2167,7 @@ class WaveformVideoTrack(MediaStreamTrack):
         return pts, time_base
 
     async def recv(self) -> VideoFrame:
-        pts, time_base = await self.next_timestamp()
+        pts, _ = await self.next_timestamp()
 
         # schedule frame according to clock
         target_t = self._next_frame_index / self.fps
@@ -2224,7 +2209,7 @@ class WaveformVideoTrack(MediaStreamTrack):
 
         # Draw clock in lower right corner, right justified
         current_time = time.strftime("%H:%M:%S")
-        (text_width, text_height), _ = cv2.getTextSize(
+        (text_width, _), _ = cv2.getTextSize(
             current_time, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2
         )
         clock_x = self.width - text_width - 10  # 10px margin from right edge
@@ -2364,7 +2349,7 @@ class WaveformVideoTrack(MediaStreamTrack):
 
             # Label the peak with small text near the right edge
             label = f"Peak:{target_peak:.2f}"
-            (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+            (tw, _), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
             lx = max(10, self.width - tw - 12)
             ly = max(12, top_y - 6)
             cv2.putText(frame_array, label, (lx, ly), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 1)
@ -2391,7 +2376,7 @@ class WaveformVideoTrack(MediaStreamTrack):
         frame.time_base = fractions.Fraction(1 / 90000).limit_denominator(1000000)
         return frame

-    def _draw_speech_status(self, frame_array: np.ndarray, speech_info: dict, pname: str):
+    def _draw_speech_status(self, frame_array: npt.NDArray[np.uint8], speech_info: Dict[str, Any], pname: str):
         """Draw speech detection status information."""

         y_offset = 100
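Note: the signature tightens np.ndarray/dict to npt.NDArray[np.uint8] and Dict[str, Any]. The dtype parameter is only checked by static analyzers; a sketch of the annotation style with an optional runtime guard (illustrative function, not from this commit):

    from typing import Any, Dict

    import numpy as np
    import numpy.typing as npt

    def draw_status(frame_array: npt.NDArray[np.uint8],
                    speech_info: Dict[str, Any]) -> None:
        assert frame_array.dtype == np.uint8  # NDArray[...] is not enforced at runtime
        # Paint a small status square: white while speech is detected.
        frame_array[:10, :10] = 255 if speech_info.get("is_speech") else 0

    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    draw_status(frame, {"is_speech": True})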
@ -2414,7 +2399,7 @@ class WaveformVideoTrack(MediaStreamTrack):
             f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})"
         ]

-        for i, metric in enumerate(metrics):
+        for _, metric in enumerate(metrics):
             cv2.putText(frame_array, metric,
                 (320, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
                 (255, 255, 255), 1)
@ -2646,13 +2631,13 @@ def _resample_audio(
             audio_data = np.mean(audio_data, axis=1)

         # Use high-quality resampling
-        resampled = librosa.resample(
+        resampled = librosa.resample(  # type: ignore
             audio_data.astype(np.float64),
             orig_sr=orig_sr,
             target_sr=target_sr,
             res_type="kaiser_fast",  # Good balance of quality and speed
         )
-        return resampled.astype(np.float32)
+        return resampled.astype(np.float32)  # type: ignore
     except Exception as e:
         logger.error(f"Resampling failed: {e}")
         raise ValueError(
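Note: librosa.resample takes keyword-only orig_sr/target_sr, and res_type="kaiser_fast" trades a little quality for speed (it needs the resampy backend installed). A self-contained round trip under those assumptions:

    import librosa
    import numpy as np

    # One second of a 440 Hz tone at 48 kHz, downsampled to Whisper's 16 kHz.
    sr_in, sr_out = 48000, 16000
    t = np.linspace(0, 1, sr_in, endpoint=False)
    tone = 0.5 * np.sin(2 * np.pi * 440 * t).astype(np.float64)

    resampled = librosa.resample(
        tone, orig_sr=sr_in, target_sr=sr_out, res_type="kaiser_fast"
    )
    assert resampled.shape[0] == sr_out  # length scales with the rate ratio
    audio16 = resampled.astype(np.float32)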
@ -2760,7 +2745,7 @@ def get_config_schema() -> Dict[str, Any]:
             "type": "range",
             "label": "VAD Threshold",
             "description": "Voice activity detection threshold",
-            "default_value": VAD_THRESHOLD,
+            "default_value": 0.01,
             "required": False,
             "min_value": 0.001,
             "max_value": 0.1,
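Note: with default_value now a literal instead of the VAD_THRESHOLD global, the schema entry is fully self-describing. A hypothetical helper (not part of this commit) showing how min_value/max_value from such an entry could clamp a submitted value:

    from typing import Any, Dict

    def clamp_to_schema(entry: Dict[str, Any], value: float) -> float:
        """Clamp a submitted value into the entry's allowed range."""
        lo = entry.get("min_value", float("-inf"))
        hi = entry.get("max_value", float("inf"))
        return min(max(value, lo), hi)

    vad_entry = {"type": "range", "default_value": 0.01,
                 "min_value": 0.001, "max_value": 0.1}
    assert clamp_to_schema(vad_entry, 0.5) == 0.1    # clamped to max
    assert clamp_to_schema(vad_entry, 0.05) == 0.05  # in range, unchanged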
@ -2915,7 +2900,7 @@ def get_config_schema() -> Dict[str, Any]:
             "type": "boolean",
             "label": "Enable Normalization",
             "description": "Normalize incoming audio based on observed peak amplitude before transcription and visualization",
-            "default_value": NORMALIZATION_ENABLED,
+            "default_value": True,
             "required": False
         },
         {
@ -2923,7 +2908,7 @@ def get_config_schema() -> Dict[str, Any]:
             "type": "number",
             "label": "Normalization Target Peak",
             "description": "Target peak (0-1) used when normalizing audio",
-            "default_value": NORMALIZATION_TARGET_PEAK,
+            "default_value": 0.7,
             "required": False,
             "min_value": 0.5,
             "max_value": 1.0
@ -2933,7 +2918,7 @@ def get_config_schema() -> Dict[str, Any]:
             "type": "range",
             "label": "Max Normalization Gain",
             "description": "Maximum allowed gain applied during normalization",
-            "default_value": MAX_NORMALIZATION_GAIN,
+            "default_value": 10.0,
             "required": False,
             "min_value": 1.0,
             "max_value": 10.0,
@ -2951,15 +2936,14 @@ def get_config_schema() -> Dict[str, Any]:

 def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
     """Handle configuration update for a specific lobby"""
-    global _model_id, _device, _ov_config, SAMPLE_RATE, CHUNK_DURATION_MS, VAD_THRESHOLD
-    global MAX_SILENCE_FRAMES, MAX_TRAILING_SILENCE_FRAMES
+    global _model_id, _device, _ov_config

     try:
         logger.info(f"Updating Whisper config for lobby {lobby_id}: {config_values}")

         config_applied = False

-        # Update model configuration
+        # Update model configuration (global - affects all instances)
         if "model_id" in config_values:
             new_model_id = config_values["model_id"]
             if new_model_id in [model for models in model_ids.values() for model in models]:
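Note: dropping SAMPLE_RATE, CHUNK_DURATION_MS, VAD_THRESHOLD and friends from the global list matches the strategy in the hunks below: rebinding a module global never reaches processors that copied the value at construction time, so updates must be applied per instance. A minimal demonstration of the pitfall:

    SAMPLE_RATE = 16000

    class Processor:
        def __init__(self):
            self.sample_rate = SAMPLE_RATE  # snapshot taken at construction

    p = Processor()
    SAMPLE_RATE = 8000             # rebinding the global...
    assert p.sample_rate == 16000  # ...does not touch existing instances
    p.sample_rate = 8000           # only a per-instance update takes effect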
@ -2969,9 +2953,9 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
             else:
                 logger.warning(f"Invalid model_id: {new_model_id}")

-        # Update device configuration
+        # Update device configuration (global - affects all instances)
         if "device" in config_values:
-            new_device = config_values["device"]
+            new_device = config_values["device"]  # type: ignore
             available_devices = [d["name"] for d in get_available_devices()]
             if new_device in available_devices or new_device in ["CPU", "GPU", "GPU.1"]:
                 _device = new_device
@ -2981,7 +2965,7 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
             else:
                 logger.warning(f"Invalid device: {new_device}, available: {available_devices}")

-        # Update OpenVINO configuration
+        # Update OpenVINO configuration (global - affects all instances)
         if "enable_quantization" in config_values:
             _ov_config.enable_quantization = bool(config_values["enable_quantization"])
             config_applied = True
@ -3001,106 +2985,212 @@ def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
             config_applied = True
             logger.info(f"Updated max_threads to: {_ov_config.max_threads}")

-        # Update audio processing parameters
+        # Update audio processing parameters for existing processors
         if "sample_rate" in config_values:
             rate = int(config_values["sample_rate"])
             if 8000 <= rate <= 48000:
-                SAMPLE_RATE = rate
+                # Update existing processors
+                for pname, proc in list(_audio_processors.items()):
+                    try:
+                        proc.sample_rate = rate
+                        proc.chunk_size = int(proc.sample_rate * proc.chunk_duration_ms / 1000)
+                        logger.info(f"Updated sample_rate to {rate} for processor: {pname}")
+                    except Exception:
+                        logger.debug(f"Failed to update sample_rate for processor: {pname}")
                 config_applied = True
-                logger.info(f"Updated sample_rate to: {SAMPLE_RATE}")
+                logger.info(f"Updated sample_rate to: {rate}")

         if "chunk_duration_ms" in config_values:
             duration = int(config_values["chunk_duration_ms"])
             if 50 <= duration <= 500:
-                CHUNK_DURATION_MS = duration
+                # Update existing processors
+                for pname, proc in list(_audio_processors.items()):
+                    try:
+                        proc.chunk_duration_ms = duration
+                        proc.chunk_size = int(proc.sample_rate * proc.chunk_duration_ms / 1000)
+                        logger.info(f"Updated chunk_duration_ms to {duration} for processor: {pname}")
+                    except Exception:
+                        logger.debug(f"Failed to update chunk_duration_ms for processor: {pname}")
                 config_applied = True
-                logger.info(f"Updated chunk_duration_ms to: {CHUNK_DURATION_MS}")
+                logger.info(f"Updated chunk_duration_ms to: {duration}")

-        if "vad_threshold" in config_values:
-            threshold = float(config_values["vad_threshold"])
-            if 0.001 <= threshold <= 0.1:
-                VAD_THRESHOLD = threshold
-                config_applied = True
-                logger.info(f"Updated vad_threshold to: {VAD_THRESHOLD}")
-
         if "max_silence_frames" in config_values:
             frames = int(config_values["max_silence_frames"])
             if 10 <= frames <= 100:
-                MAX_SILENCE_FRAMES = frames
+                # Update existing processors
+                for pname, proc in list(_audio_processors.items()):
+                    try:
+                        proc.max_silence_frames = frames
+                        logger.info(f"Updated max_silence_frames to {frames} for processor: {pname}")
+                    except Exception:
+                        logger.debug(f"Failed to update max_silence_frames for processor: {pname}")
                 config_applied = True
-                logger.info(f"Updated max_silence_frames to: {MAX_SILENCE_FRAMES}")
+                logger.info(f"Updated max_silence_frames to: {frames}")

         if "max_trailing_silence_frames" in config_values:
             frames = int(config_values["max_trailing_silence_frames"])
             if 1 <= frames <= 20:
-                MAX_TRAILING_SILENCE_FRAMES = frames
-                config_applied = True
-                logger.info(f"Updated max_trailing_silence_frames to: {MAX_TRAILING_SILENCE_FRAMES}")
+                # Update existing processors
+                for pname, proc in list(_audio_processors.items()):
+                    try:
+                        proc.max_trailing_silence_frames = frames
+                        logger.info(f"Updated max_trailing_silence_frames to {frames} for processor: {pname}")
+                    except Exception:
+                        logger.debug(f"Failed to update max_trailing_silence_frames for processor: {pname}")
+                config_applied = True
+                logger.info(f"Updated max_trailing_silence_frames to: {frames}")

-        # Update VAD configuration (this would require updating existing processors)
-        vad_updates = {}
-        if "vad_energy_threshold" in config_values:
-            vad_updates["energy_threshold"] = float(config_values["vad_energy_threshold"])
-        if "vad_zcr_min" in config_values:
-            vad_updates["zcr_min"] = float(config_values["vad_zcr_min"])
-        if "vad_zcr_max" in config_values:
-            vad_updates["zcr_max"] = float(config_values["vad_zcr_max"])
-        if "vad_spectral_centroid_min" in config_values:
-            vad_updates["spectral_centroid_min"] = float(config_values["vad_spectral_centroid_min"])
-        if "vad_spectral_centroid_max" in config_values:
-            vad_updates["spectral_centroid_max"] = float(config_values["vad_spectral_centroid_max"])
-        if "vad_spectral_rolloff_threshold" in config_values:
-            vad_updates["spectral_rolloff_threshold"] = float(config_values["vad_spectral_rolloff_threshold"])
-        if "vad_minimum_duration" in config_values:
-            vad_updates["minimum_duration"] = float(config_values["vad_minimum_duration"])
-        if "vad_max_history" in config_values:
-            vad_updates["max_history"] = int(config_values["vad_max_history"])
-        if "vad_noise_floor_energy" in config_values:
-            vad_updates["noise_floor_energy"] = float(config_values["vad_noise_floor_energy"])
-        if "vad_adaptation_rate" in config_values:
-            vad_updates["adaptation_rate"] = float(config_values["vad_adaptation_rate"])
-        if "vad_harmonic_threshold" in config_values:
-            vad_updates["harmonic_threshold"] = float(config_values["vad_harmonic_threshold"])
-
-        if vad_updates:
-            # Update VAD_CONFIG global
-            VAD_CONFIG.update(vad_updates)
-            config_applied = True
-            logger.info(f"Updated VAD config: {vad_updates}")
-
-            # Note: Existing processors would need to be recreated to pick up VAD changes
-            # For now, we'll log that a restart may be needed
-            logger.info("VAD configuration updated - existing processors may need restart to take effect")
-
-        # Normalization updates: apply to global defaults and active processors
-        norm_updates = False
-        if "normalization_enabled" in config_values:
-            NORMALIZATION_ENABLED = bool(config_values["normalization_enabled"])
-            norm_updates = True
-            logger.info(f"Updated NORMALIZATION_ENABLED to: {NORMALIZATION_ENABLED}")
-        if "normalization_target_peak" in config_values:
-            NORMALIZATION_TARGET_PEAK = float(config_values["normalization_target_peak"])
-            norm_updates = True
-            logger.info(f"Updated NORMALIZATION_TARGET_PEAK to: {NORMALIZATION_TARGET_PEAK}")
-        if "max_normalization_gain" in config_values:
-            MAX_NORMALIZATION_GAIN = float(config_values["max_normalization_gain"])
-            norm_updates = True
-            logger.info(f"Updated MAX_NORMALIZATION_GAIN to: {MAX_NORMALIZATION_GAIN}")
-
-        if norm_updates:
-            # Propagate changes to existing processors
-            try:
-                for pname, proc in list(_audio_processors.items()):
-                    try:
-                        proc.normalization_enabled = NORMALIZATION_ENABLED
-                        proc.normalization_target_peak = NORMALIZATION_TARGET_PEAK
-                        proc.max_normalization_gain = MAX_NORMALIZATION_GAIN
-                        logger.info(f"Applied normalization config to processor: {pname}")
-                    except Exception:
-                        logger.debug(f"Failed to apply normalization config to processor: {pname}")
-                config_applied = True
-            except Exception:
-                logger.debug("Failed to propagate normalization settings to processors")
+        # Update VAD configuration for existing processors
+        vad_updates = False
+        if "vad_energy_threshold" in config_values:
+            threshold = float(config_values["vad_energy_threshold"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_energy_threshold = threshold
+                    logger.info(f"Updated vad_energy_threshold to {threshold} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_energy_threshold for processor: {pname}")
+            vad_updates = True
+
+        if "vad_zcr_min" in config_values:
+            zcr_min = float(config_values["vad_zcr_min"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_zcr_min = zcr_min
+                    logger.info(f"Updated vad_zcr_min to {zcr_min} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_zcr_min for processor: {pname}")
+            vad_updates = True
+
+        if "vad_zcr_max" in config_values:
+            zcr_max = float(config_values["vad_zcr_max"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_zcr_max = zcr_max
+                    logger.info(f"Updated vad_zcr_max to {zcr_max} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_zcr_max for processor: {pname}")
+            vad_updates = True
+
+        if "vad_spectral_centroid_min" in config_values:
+            centroid_min = int(config_values["vad_spectral_centroid_min"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_spectral_centroid_min = centroid_min
+                    logger.info(f"Updated vad_spectral_centroid_min to {centroid_min} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_spectral_centroid_min for processor: {pname}")
+            vad_updates = True
+
+        if "vad_spectral_centroid_max" in config_values:
+            centroid_max = int(config_values["vad_spectral_centroid_max"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_spectral_centroid_max = centroid_max
+                    logger.info(f"Updated vad_spectral_centroid_max to {centroid_max} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_spectral_centroid_max for processor: {pname}")
+            vad_updates = True
+
+        if "vad_spectral_rolloff_threshold" in config_values:
+            rolloff = int(config_values["vad_spectral_rolloff_threshold"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_spectral_rolloff_threshold = rolloff
+                    logger.info(f"Updated vad_spectral_rolloff_threshold to {rolloff} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_spectral_rolloff_threshold for processor: {pname}")
+            vad_updates = True
+
+        if "vad_minimum_duration" in config_values:
+            duration = float(config_values["vad_minimum_duration"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_minimum_duration = duration
+                    logger.info(f"Updated vad_minimum_duration to {duration} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_minimum_duration for processor: {pname}")
+            vad_updates = True
+
+        if "vad_max_history" in config_values:
+            history = int(config_values["vad_max_history"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_max_history = history
+                    logger.info(f"Updated vad_max_history to {history} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_max_history for processor: {pname}")
+            vad_updates = True
+
+        if "vad_noise_floor_energy" in config_values:
+            noise_floor = float(config_values["vad_noise_floor_energy"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_noise_floor_energy = noise_floor
+                    logger.info(f"Updated vad_noise_floor_energy to {noise_floor} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_noise_floor_energy for processor: {pname}")
+            vad_updates = True
+
+        if "vad_adaptation_rate" in config_values:
+            adaptation_rate = float(config_values["vad_adaptation_rate"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_adaptation_rate = adaptation_rate
+                    logger.info(f"Updated vad_adaptation_rate to {adaptation_rate} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_adaptation_rate for processor: {pname}")
+            vad_updates = True
+
+        if "vad_harmonic_threshold" in config_values:
+            harmonic_threshold = float(config_values["vad_harmonic_threshold"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.vad_harmonic_threshold = harmonic_threshold
+                    logger.info(f"Updated vad_harmonic_threshold to {harmonic_threshold} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update vad_harmonic_threshold for processor: {pname}")
+            vad_updates = True
+
+        if vad_updates:
+            config_applied = True
+            logger.info("VAD configuration updated for existing processors")
+
+        # Normalization updates: apply to existing processors
+        norm_updates = False
+        if "normalization_enabled" in config_values:
+            enabled = bool(config_values["normalization_enabled"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.normalization_enabled = enabled
+                    logger.info(f"Updated normalization_enabled to {enabled} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update normalization_enabled for processor: {pname}")
+            norm_updates = True
+
+        if "normalization_target_peak" in config_values:
+            target_peak = float(config_values["normalization_target_peak"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.normalization_target_peak = target_peak
+                    logger.info(f"Updated normalization_target_peak to {target_peak} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update normalization_target_peak for processor: {pname}")
+            norm_updates = True
+
+        if "max_normalization_gain" in config_values:
+            max_gain = float(config_values["max_normalization_gain"])
+            for pname, proc in list(_audio_processors.items()):
+                try:
+                    proc.max_normalization_gain = max_gain
+                    logger.info(f"Updated max_normalization_gain to {max_gain} for processor: {pname}")
+                except Exception:
+                    logger.debug(f"Failed to update max_normalization_gain for processor: {pname}")
+            norm_updates = True
+
+        if norm_updates:
+            config_applied = True
+            logger.info("Normalization configuration updated for existing processors")

         if config_applied:
             logger.info(f"Configuration update completed for lobby {lobby_id}")
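Note: the new per-key blocks all repeat the same loop/try/set/log shape. A hedged sketch of how they could collapse into one helper (hypothetical _apply_to_processors; the commit itself keeps the explicit blocks):

    from typing import Any, Dict

    def _apply_to_processors(processors: Dict[str, Any], attr: str, value: Any) -> bool:
        """Set one attribute on every live processor; True if any succeeded."""
        applied = False
        for proc in list(processors.values()):
            try:
                setattr(proc, attr, value)
                applied = True
            except Exception:
                pass  # one broken processor should not abort the whole update
        return applied

    # Hypothetical usage mirroring the hunk above:
    # if "vad_zcr_min" in config_values:
    #     config_applied |= _apply_to_processors(
    #         _audio_processors, "vad_zcr_min", float(config_values["vad_zcr_min"]))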