Improved transcription to current message
This commit is contained in:
parent
50f290ac22
commit
110430d22a
@ -40,7 +40,18 @@ const LobbyChat: React.FC<LobbyChatProps> = ({ socketUrl, session, lobbyId }) =>
|
|||||||
switch (message.type) {
|
switch (message.type) {
|
||||||
case "chat_message":
|
case "chat_message":
|
||||||
const chatMessage = data as ChatMessage;
|
const chatMessage = data as ChatMessage;
|
||||||
setMessages((prev) => [...prev, chatMessage]);
|
setMessages((prev) => {
|
||||||
|
const existingIndex = prev.findIndex((msg) => msg.id === chatMessage.id);
|
||||||
|
if (existingIndex >= 0) {
|
||||||
|
// Update existing message
|
||||||
|
const updatedMessages = [...prev];
|
||||||
|
updatedMessages[existingIndex] = chatMessage;
|
||||||
|
return updatedMessages;
|
||||||
|
} else {
|
||||||
|
// Add new message
|
||||||
|
return [...prev, chatMessage];
|
||||||
|
}
|
||||||
|
});
|
||||||
break;
|
break;
|
||||||
case "chat_messages":
|
case "chat_messages":
|
||||||
const chatMessages = data.messages as ChatMessage[];
|
const chatMessages = data.messages as ChatMessage[];
|
||||||
|
@ -44,7 +44,7 @@
|
|||||||
height: 3.75rem;
|
height: 3.75rem;
|
||||||
min-width: 5rem;
|
min-width: 5rem;
|
||||||
min-height: 3.75rem;
|
min-height: 3.75rem;
|
||||||
z-index: 50000;
|
z-index: 1200;
|
||||||
border-radius: 0.25rem;
|
border-radius: 0.25rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1573,6 +1573,7 @@ const MediaControl: React.FC<MediaControlProps> = ({ isSelf, peer, className })
|
|||||||
hideDefaultLines={false}
|
hideDefaultLines={false}
|
||||||
snappable={true}
|
snappable={true}
|
||||||
snapThreshold={5}
|
snapThreshold={5}
|
||||||
|
origin={false}
|
||||||
edge
|
edge
|
||||||
onDragStart={(e) => {
|
onDragStart={(e) => {
|
||||||
setIsDragging(true);
|
setIsDragging(true);
|
||||||
|
@ -191,6 +191,33 @@ class Lobby:
|
|||||||
]
|
]
|
||||||
return chat_message
|
return chat_message
|
||||||
|
|
||||||
|
def update_chat_message(self, chat_message: ChatMessageModel) -> ChatMessageModel:
|
||||||
|
"""Update an existing chat message in the lobby and return the updated message"""
|
||||||
|
with self.lock:
|
||||||
|
# Find the existing message by ID
|
||||||
|
for i, existing_msg in enumerate(self.chat_messages):
|
||||||
|
if existing_msg.id == chat_message.id:
|
||||||
|
# Update the message content and timestamp
|
||||||
|
updated_msg = ChatMessageModel(
|
||||||
|
id=chat_message.id,
|
||||||
|
message=chat_message.message,
|
||||||
|
sender_name=existing_msg.sender_name, # Keep original sender
|
||||||
|
sender_session_id=existing_msg.sender_session_id, # Keep original session
|
||||||
|
timestamp=time.time(), # Update timestamp
|
||||||
|
lobby_id=existing_msg.lobby_id, # Keep original lobby
|
||||||
|
)
|
||||||
|
self.chat_messages[i] = updated_msg
|
||||||
|
return updated_msg
|
||||||
|
|
||||||
|
# If message not found, add it as new
|
||||||
|
self.chat_messages.append(chat_message)
|
||||||
|
# Keep only the latest messages per lobby
|
||||||
|
if len(self.chat_messages) > LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY:
|
||||||
|
self.chat_messages = self.chat_messages[
|
||||||
|
-LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY :
|
||||||
|
]
|
||||||
|
return chat_message
|
||||||
|
|
||||||
def get_chat_messages(self, limit: int = 50) -> List[ChatMessageModel]:
|
def get_chat_messages(self, limit: int = 50) -> List[ChatMessageModel]:
|
||||||
"""Get the most recent chat messages from the lobby"""
|
"""Get the most recent chat messages from the lobby"""
|
||||||
with self.lock:
|
with self.lock:
|
||||||
|
@ -10,6 +10,7 @@ from typing import Dict, Any, TYPE_CHECKING
|
|||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
|
|
||||||
from shared.logger import logger
|
from shared.logger import logger
|
||||||
|
from shared.models import ChatMessageModel
|
||||||
from .webrtc_signaling import WebRTCSignalingHandlers
|
from .webrtc_signaling import WebRTCSignalingHandlers
|
||||||
from core.error_handling import (
|
from core.error_handling import (
|
||||||
error_handler,
|
error_handler,
|
||||||
@ -252,13 +253,51 @@ class SendChatMessageHandler(MessageHandler):
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
||||||
message_text = str(data["message"]).strip()
|
message_data = data["message"]
|
||||||
if not message_text:
|
|
||||||
return
|
if isinstance(message_data, dict):
|
||||||
|
# Handle ChatMessageModel object
|
||||||
|
try:
|
||||||
|
chat_message = ChatMessageModel.model_validate(message_data)
|
||||||
|
# Validate that the sender matches the session
|
||||||
|
if chat_message.sender_session_id != session.id:
|
||||||
|
logger.error(
|
||||||
|
f"{session.getName()} - ChatMessageModel sender_session_id mismatch"
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"error": "ChatMessageModel sender_session_id does not match session"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
# Update or add the message
|
||||||
|
chat_message = lobby.update_chat_message(chat_message)
|
||||||
|
logger.info(
|
||||||
|
f"{session.getName()} -> update_chat_message({lobby.getName()}, {chat_message.message[:50]}...)"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{session.getName()} - Invalid ChatMessageModel: {e}")
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {"error": "Invalid ChatMessageModel format"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Handle string message (legacy support)
|
||||||
|
message_text = str(message_data).strip()
|
||||||
|
if not message_text:
|
||||||
|
return
|
||||||
|
# Add the message to the lobby and broadcast it
|
||||||
|
chat_message = lobby.add_chat_message(session, message_text)
|
||||||
|
logger.info(
|
||||||
|
f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)"
|
||||||
|
)
|
||||||
|
|
||||||
# Add the message to the lobby and broadcast it
|
|
||||||
chat_message = lobby.add_chat_message(session, message_text)
|
|
||||||
logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)")
|
|
||||||
await lobby.broadcast_chat_message(chat_message)
|
await lobby.broadcast_chat_message(chat_message)
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ This bot demonstrates the advanced capabilities including:
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Optional, Callable, Awaitable, Any
|
from typing import Dict, Optional, Callable, Awaitable, Any, Union
|
||||||
from aiortc import MediaStreamTrack
|
from aiortc import MediaStreamTrack
|
||||||
|
|
||||||
# Import system modules
|
# Import system modules
|
||||||
@ -257,7 +257,7 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
|
|||||||
|
|
||||||
async def handle_chat_message(
|
async def handle_chat_message(
|
||||||
chat_message: ChatMessageModel,
|
chat_message: ChatMessageModel,
|
||||||
send_message_func: Callable[[str], Awaitable[None]]
|
send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Handle incoming chat messages and provide AI-powered responses."""
|
"""Handle incoming chat messages and provide AI-powered responses."""
|
||||||
global _bot_instance
|
global _bot_instance
|
||||||
|
@ -4,7 +4,7 @@ This bot shows how to create an agent that primarily uses chat functionality
|
|||||||
rather than media streams.
|
rather than media streams.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Optional, Callable, Awaitable
|
from typing import Dict, Optional, Callable, Awaitable, Union
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
from shared.logger import logger
|
from shared.logger import logger
|
||||||
@ -45,7 +45,7 @@ def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]) -> Optional[str]:
|
async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]) -> Optional[str]:
|
||||||
"""Handle incoming chat messages and provide responses.
|
"""Handle incoming chat messages and provide responses.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -253,23 +253,29 @@ class AdvancedVAD:
|
|||||||
|
|
||||||
def __init__(self, sample_rate: int = SAMPLE_RATE):
|
def __init__(self, sample_rate: int = SAMPLE_RATE):
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
# More conservative thresholds
|
# More permissive thresholds based on research
|
||||||
self.energy_threshold = 0.02 # Increased from 0.01
|
self.energy_threshold = 0.005 # Reduced from 0.02
|
||||||
self.zcr_min = 0.1
|
self.zcr_min = 0.02 # Reduced from 0.1 (voiced speech < 0.1)
|
||||||
self.zcr_max = 0.5 # Reject very high ZCR (digital noise)
|
self.zcr_max = 0.8 # Increased from 0.5 (unvoiced speech ~0.3-0.8)
|
||||||
self.spectral_centroid_min = 300 # Human speech starts around 300Hz
|
|
||||||
self.spectral_centroid_max = 3400 # Human speech typically under 3400Hz
|
|
||||||
self.spectral_rolloff_threshold = 2000 # Speech energy concentrated lower
|
|
||||||
self.minimum_duration = 0.3 # Require 300ms of consistent speech
|
|
||||||
|
|
||||||
# Temporal consistency tracking
|
# Spectral thresholds (keep existing - these work well)
|
||||||
|
self.spectral_centroid_min = 200 # Slightly lower
|
||||||
|
self.spectral_centroid_max = 4000 # Slightly higher
|
||||||
|
self.spectral_rolloff_threshold = 3000 # More permissive
|
||||||
|
|
||||||
|
# Relaxed temporal consistency
|
||||||
|
self.minimum_duration = 0.2 # Reduced from 0.3s
|
||||||
self.speech_history = []
|
self.speech_history = []
|
||||||
self.max_history = 10 # Track last 10 frames (1 second at 100ms chunks)
|
self.max_history = 8 # Reduced from 10
|
||||||
|
|
||||||
# Adaptive noise floor
|
# Adaptive noise floor
|
||||||
self.noise_floor_energy = 0.001
|
self.noise_floor_energy = 0.001
|
||||||
self.noise_floor_centroid = 1000
|
self.noise_floor_centroid = 1000
|
||||||
self.adaptation_rate = 0.02
|
self.adaptation_rate = 0.05
|
||||||
|
|
||||||
|
# Harmonicity improvements
|
||||||
|
self.prev_magnitude = None
|
||||||
|
self.harmonic_threshold = 0.15 # Reduced from 0.3
|
||||||
|
|
||||||
def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, dict]:
|
def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, dict]:
|
||||||
"""Analyze audio frame for speech vs noise."""
|
"""Analyze audio frame for speech vs noise."""
|
||||||
@ -286,43 +292,66 @@ class AdvancedVAD:
|
|||||||
spectral_features = self._compute_spectral_features(audio_data)
|
spectral_features = self._compute_spectral_features(audio_data)
|
||||||
|
|
||||||
# Individual feature checks
|
# Individual feature checks
|
||||||
energy_check = energy > max(self.energy_threshold, self.noise_floor_energy * 3)
|
|
||||||
zcr_check = self.zcr_min < zcr < self.zcr_max
|
# Use adaptive energy threshold (reduced multiplier)
|
||||||
|
adaptive_energy_threshold = max(
|
||||||
|
self.energy_threshold,
|
||||||
|
self.noise_floor_energy * 2.0 # Reduced from 3.0
|
||||||
|
)
|
||||||
|
energy_check = energy > adaptive_energy_threshold
|
||||||
|
|
||||||
|
# More permissive ZCR check (allow voiced OR unvoiced speech)
|
||||||
|
zcr_check = (
|
||||||
|
self.zcr_min < zcr < self.zcr_max or # General range
|
||||||
|
zcr < 0.1 or # Definitely voiced
|
||||||
|
(0.2 < zcr < 0.6 and energy > self.energy_threshold * 2) # Unvoiced with energy
|
||||||
|
)
|
||||||
|
|
||||||
|
# Spectral check (more permissive)
|
||||||
spectral_check = (
|
spectral_check = (
|
||||||
self.spectral_centroid_min < spectral_features['centroid'] < self.spectral_centroid_max
|
self.spectral_centroid_min < spectral_features['centroid'] < self.spectral_centroid_max and
|
||||||
and spectral_features['rolloff'] < self.spectral_rolloff_threshold
|
spectral_features['rolloff'] < self.spectral_rolloff_threshold and
|
||||||
and spectral_features['flux'] > 0.01 # Spectral change indicates speech
|
spectral_features['flux'] > 0.005 # Reduced threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
# Harmonicity check (speech has harmonic structure)
|
# Improved harmonicity check
|
||||||
harmonic_check = spectral_features['harmonicity'] > 0.3
|
harmonic_check = spectral_features['harmonicity'] > self.harmonic_threshold
|
||||||
|
|
||||||
# Combined decision with temporal consistency
|
# More permissive combined decision (OR logic for some conditions)
|
||||||
frame_has_speech = (
|
frame_has_speech = (
|
||||||
energy_check and
|
energy_check and (
|
||||||
zcr_check and
|
zcr_check or # ZCR is good, OR
|
||||||
spectral_check and
|
spectral_check or # Spectral features are good, OR
|
||||||
harmonic_check
|
harmonic_check # Harmonicity is good
|
||||||
|
)
|
||||||
|
) or (
|
||||||
|
# Alternative path: strong energy + reasonable spectral
|
||||||
|
energy > adaptive_energy_threshold * 1.5 and spectral_check
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update history
|
# Update history
|
||||||
self.speech_history.append(frame_has_speech)
|
self.speech_history.append(frame_has_speech)
|
||||||
if len(self.speech_history) > self.max_history:
|
if len(self.speech_history) > self.max_history:
|
||||||
self.speech_history.pop(0)
|
self.speech_history.pop(0)
|
||||||
|
|
||||||
# Require temporal consistency (at least 3 of last 5 frames)
|
# More permissive temporal consistency (2 of last 4, or 1 of last 2 if strong)
|
||||||
recent_speech = sum(self.speech_history[-5:]) >= 3 if len(self.speech_history) >= 5 else frame_has_speech
|
if len(self.speech_history) >= 4:
|
||||||
|
recent_speech = sum(self.speech_history[-4:]) >= 2
|
||||||
|
elif len(self.speech_history) >= 2:
|
||||||
|
# For shorter history, be more permissive if recent frame is strong
|
||||||
|
recent_speech = (
|
||||||
|
sum(self.speech_history[-2:]) >= 1 and
|
||||||
|
(energy > adaptive_energy_threshold * 1.2 or frame_has_speech)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
recent_speech = frame_has_speech
|
||||||
|
|
||||||
# Update noise floor during silence
|
# Faster noise floor adaptation during silence
|
||||||
if not frame_has_speech:
|
if not frame_has_speech:
|
||||||
self.noise_floor_energy = (
|
self.noise_floor_energy = (
|
||||||
self.adaptation_rate * energy +
|
self.adaptation_rate * energy +
|
||||||
(1 - self.adaptation_rate) * self.noise_floor_energy
|
(1 - self.adaptation_rate) * self.noise_floor_energy
|
||||||
)
|
)
|
||||||
self.noise_floor_centroid = (
|
|
||||||
self.adaptation_rate * spectral_features['centroid'] +
|
|
||||||
(1 - self.adaptation_rate) * self.noise_floor_centroid
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = {
|
metrics = {
|
||||||
'energy': energy,
|
'energy': energy,
|
||||||
@ -332,6 +361,7 @@ class AdvancedVAD:
|
|||||||
'flux': spectral_features['flux'],
|
'flux': spectral_features['flux'],
|
||||||
'harmonicity': spectral_features['harmonicity'],
|
'harmonicity': spectral_features['harmonicity'],
|
||||||
'noise_floor_energy': self.noise_floor_energy,
|
'noise_floor_energy': self.noise_floor_energy,
|
||||||
|
'adaptive_threshold': adaptive_energy_threshold,
|
||||||
'energy_check': energy_check,
|
'energy_check': energy_check,
|
||||||
'zcr_check': zcr_check,
|
'zcr_check': zcr_check,
|
||||||
'spectral_check': spectral_check,
|
'spectral_check': spectral_check,
|
||||||
@ -386,44 +416,77 @@ class AdvancedVAD:
|
|||||||
|
|
||||||
def _compute_harmonicity(self, magnitude: np.ndarray, freqs: np.ndarray) -> float:
|
def _compute_harmonicity(self, magnitude: np.ndarray, freqs: np.ndarray) -> float:
|
||||||
"""Compute harmonicity score (0-1, higher = more harmonic/speech-like)."""
|
"""Compute harmonicity score (0-1, higher = more harmonic/speech-like)."""
|
||||||
|
|
||||||
# Find fundamental frequency candidate (peak in 80-400Hz range for speech)
|
# Find fundamental frequency candidate (peak in 80-400Hz range for speech)
|
||||||
speech_range = (freqs >= 80) & (freqs <= 400)
|
# Expanded F0 range for better detection
|
||||||
|
speech_range = (freqs >= 60) & (freqs <= 500) # Expanded from 80-400Hz
|
||||||
if not np.any(speech_range):
|
if not np.any(speech_range):
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
speech_magnitude = magnitude[speech_range]
|
speech_magnitude = magnitude[speech_range]
|
||||||
speech_freqs = freqs[speech_range]
|
speech_freqs = freqs[speech_range]
|
||||||
|
|
||||||
if len(speech_magnitude) == 0:
|
if len(speech_magnitude) == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# Find strongest peak in speech range
|
# Find strongest peak in speech range
|
||||||
f0_idx = np.argmax(speech_magnitude)
|
# More robust F0 detection - find peaks instead of just max
|
||||||
|
try:
|
||||||
|
# Import scipy here to handle missing dependency gracefully
|
||||||
|
from scipy.signal import find_peaks
|
||||||
|
|
||||||
|
# Ensure distance is at least 1
|
||||||
|
min_distance = max(1, int(len(speech_magnitude) * 0.05))
|
||||||
|
|
||||||
|
peaks, properties = find_peaks(
|
||||||
|
speech_magnitude,
|
||||||
|
height=np.max(speech_magnitude) * 0.05, # Lowered from 0.1
|
||||||
|
distance=min_distance, # Minimum peak separation
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(peaks) == 0:
|
||||||
|
# Fallback to simple max if no peaks found
|
||||||
|
f0_idx = np.argmax(speech_magnitude)
|
||||||
|
else:
|
||||||
|
# Use the strongest peak
|
||||||
|
strongest_peak_idx = np.argmax(speech_magnitude[peaks])
|
||||||
|
f0_idx = peaks[strongest_peak_idx]
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# scipy not available, use simple max
|
||||||
|
f0_idx = np.argmax(speech_magnitude)
|
||||||
|
|
||||||
f0 = speech_freqs[f0_idx]
|
f0 = speech_freqs[f0_idx]
|
||||||
f0_strength = speech_magnitude[f0_idx]
|
f0_strength = speech_magnitude[f0_idx]
|
||||||
|
|
||||||
if f0_strength < np.max(magnitude) * 0.1: # F0 should be reasonably strong
|
# More lenient F0 strength requirement
|
||||||
|
if f0_strength < np.max(magnitude) * 0.03: # Reduced from 0.1
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# Check for harmonics (2*f0, 3*f0, etc.)
|
# Check for harmonics (2*f0, 3*f0, etc.)
|
||||||
harmonic_strength = 0.0
|
harmonic_strength = 0.0
|
||||||
total_harmonics = 0
|
total_harmonics = 0
|
||||||
|
|
||||||
for harmonic in range(2, 6): # Check 2nd through 5th harmonics
|
for harmonic in range(2, 5): # Check 2nd through 4th harmonics
|
||||||
harmonic_freq = f0 * harmonic
|
harmonic_freq = f0 * harmonic
|
||||||
if harmonic_freq > freqs[-1]:
|
if harmonic_freq > freqs[-1]:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Find closest frequency bin
|
# Find closest frequency bins (check neighboring bins too)
|
||||||
harmonic_idx = np.argmin(np.abs(freqs - harmonic_freq))
|
harmonic_idx = np.argmin(np.abs(freqs - harmonic_freq))
|
||||||
harmonic_strength += magnitude[harmonic_idx]
|
|
||||||
|
# Check a small neighborhood around the harmonic frequency
|
||||||
|
start_idx = max(0, harmonic_idx - 2)
|
||||||
|
end_idx = min(len(magnitude), harmonic_idx + 3)
|
||||||
|
local_max = np.max(magnitude[start_idx:end_idx])
|
||||||
|
|
||||||
|
harmonic_strength += local_max
|
||||||
total_harmonics += 1
|
total_harmonics += 1
|
||||||
|
|
||||||
if total_harmonics == 0:
|
if total_harmonics == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# Normalize by fundamental strength and number of harmonics
|
# Normalize and return
|
||||||
harmonicity = (harmonic_strength / total_harmonics) / f0_strength
|
harmonicity = (harmonic_strength / total_harmonics) / f0_strength
|
||||||
return min(harmonicity, 1.0)
|
return min(harmonicity, 1.0)
|
||||||
|
|
||||||
@ -899,7 +962,8 @@ class OpenVINOWhisperModel:
|
|||||||
# Global model instance with deferred loading
|
# Global model instance with deferred loading
|
||||||
_whisper_model: Optional[OpenVINOWhisperModel] = None
|
_whisper_model: Optional[OpenVINOWhisperModel] = None
|
||||||
_audio_processors: Dict[str, "OptimizedAudioProcessor"] = {}
|
_audio_processors: Dict[str, "OptimizedAudioProcessor"] = {}
|
||||||
_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None
|
_send_chat_func: Optional[Callable[[ChatMessageModel], Awaitable[None]]] = None
|
||||||
|
_create_chat_message_func: Optional[Callable[[str, Optional[str]], ChatMessageModel]] = None
|
||||||
|
|
||||||
# Model loading status for video display
|
# Model loading status for video display
|
||||||
_model_loading_status: str = "Not loaded"
|
_model_loading_status: str = "Not loaded"
|
||||||
@ -1018,10 +1082,11 @@ class OptimizedAudioProcessor:
|
|||||||
"""Optimized audio processor for Intel Arc B580 with reduced latency."""
|
"""Optimized audio processor for Intel Arc B580 with reduced latency."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, peer_name: str, send_chat_func: Callable[[str], Awaitable[None]]
|
self, peer_name: str, send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel]
|
||||||
):
|
):
|
||||||
self.peer_name = peer_name
|
self.peer_name = peer_name
|
||||||
self.send_chat_func = send_chat_func
|
self.send_chat_func = send_chat_func
|
||||||
|
self.create_chat_message_func = create_chat_message_func
|
||||||
self.sample_rate = SAMPLE_RATE
|
self.sample_rate = SAMPLE_RATE
|
||||||
|
|
||||||
# Initialize visualization buffer if not already done
|
# Initialize visualization buffer if not already done
|
||||||
@ -1049,6 +1114,10 @@ class OptimizedAudioProcessor:
|
|||||||
self.current_phrase_audio = np.array([], dtype=np.float32)
|
self.current_phrase_audio = np.array([], dtype=np.float32)
|
||||||
self.transcription_history: List[TranscriptionHistoryItem] = []
|
self.transcription_history: List[TranscriptionHistoryItem] = []
|
||||||
self.last_activity_time = time.time()
|
self.last_activity_time = time.time()
|
||||||
|
self.last_audio_time = time.time() # Track when any audio chunk is received
|
||||||
|
|
||||||
|
# Current transcription message for refinements
|
||||||
|
self.current_message: Optional[ChatMessageModel] = None
|
||||||
|
|
||||||
# Async processing
|
# Async processing
|
||||||
self.processing_queue: asyncio.Queue[AudioQueueItem] = asyncio.Queue(maxsize=10)
|
self.processing_queue: asyncio.Queue[AudioQueueItem] = asyncio.Queue(maxsize=10)
|
||||||
@ -1076,6 +1145,9 @@ class OptimizedAudioProcessor:
|
|||||||
logger.error("Processor not running or empty audio data")
|
logger.error("Processor not running or empty audio data")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Update last audio time whenever any audio is received
|
||||||
|
self.last_audio_time = time.time()
|
||||||
|
|
||||||
is_speech, vad_metrics = self.advanced_vad.analyze_frame(audio_data)
|
is_speech, vad_metrics = self.advanced_vad.analyze_frame(audio_data)
|
||||||
|
|
||||||
# Update visualization status
|
# Update visualization status
|
||||||
@ -1227,7 +1299,7 @@ class OptimizedAudioProcessor:
|
|||||||
# Check for final transcription on timeout
|
# Check for final transcription on timeout
|
||||||
if (
|
if (
|
||||||
len(self.current_phrase_audio) > 0
|
len(self.current_phrase_audio) > 0
|
||||||
and time.time() - self.last_activity_time > 2.0
|
and time.time() - self.last_audio_time > 2.0
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Final transcription timeout for {self.peer_name} (asyncio.TimeoutError)"
|
f"Final transcription timeout for {self.peer_name} (asyncio.TimeoutError)"
|
||||||
@ -1276,7 +1348,7 @@ class OptimizedAudioProcessor:
|
|||||||
# Check for final transcription
|
# Check for final transcription
|
||||||
if (
|
if (
|
||||||
len(self.current_phrase_audio) > 0
|
len(self.current_phrase_audio) > 0
|
||||||
and time.time() - self.last_activity_time > 2.0
|
and time.time() - self.last_audio_time > 2.0
|
||||||
):
|
):
|
||||||
if self.main_loop:
|
if self.main_loop:
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -1409,29 +1481,43 @@ class OptimizedAudioProcessor:
|
|||||||
# return # Skip very low confidence
|
# return # Skip very low confidence
|
||||||
|
|
||||||
# # Include confidence in message
|
# # Include confidence in message
|
||||||
# confidence_indicator = (
|
# Create ChatMessageModel for transcription
|
||||||
# "✓" if avg_confidence > 0.8 else "?" if avg_confidence < 0.6 else ""
|
|
||||||
# )
|
|
||||||
message = f"{self.peer_name}: {transcription}" # {confidence_indicator}[{avg_confidence:.1%}]"
|
|
||||||
|
|
||||||
await self.send_chat_func(message)
|
|
||||||
|
|
||||||
if transcription:
|
if transcription:
|
||||||
# Create message with timing
|
# Create message with timing
|
||||||
status_marker = "⚡" if is_final else "🎤"
|
status_marker = "⚡" if is_final else "🎤"
|
||||||
type_marker = "" if is_final else " [streaming]"
|
type_marker = "" if is_final else " [streaming]"
|
||||||
timing_info = f" (🚀 {transcription_time:.2f}s)"
|
timing_info = f" (🚀 {transcription_time:.2f}s)"
|
||||||
|
|
||||||
message = f"{status_marker} {self.peer_name}{type_marker}: {transcription}{timing_info}"
|
message_text = f"{status_marker} {self.peer_name}{type_marker}: {transcription}{timing_info}"
|
||||||
|
|
||||||
# Avoid duplicates
|
# Avoid duplicates
|
||||||
if not self._is_duplicate(transcription):
|
if not self._is_duplicate(transcription):
|
||||||
await self.send_chat_func(message)
|
# For streaming transcriptions, reuse the current message ID if it exists
|
||||||
|
# For final transcriptions, always create a new message
|
||||||
|
if is_final:
|
||||||
|
# Final transcription - reset current message
|
||||||
|
self.current_message = None
|
||||||
|
message_id = None
|
||||||
|
else:
|
||||||
|
# Streaming transcription - reuse current message ID or create new
|
||||||
|
if self.current_message is None:
|
||||||
|
message_id = None # Will create new ID
|
||||||
|
else:
|
||||||
|
message_id = self.current_message.id
|
||||||
|
|
||||||
|
# Create ChatMessageModel
|
||||||
|
chat_message = self.create_chat_message_func(message_text, message_id)
|
||||||
|
|
||||||
|
# Update current message for streaming
|
||||||
|
if not is_final:
|
||||||
|
self.current_message = chat_message
|
||||||
|
|
||||||
|
await self.send_chat_func(chat_message)
|
||||||
|
|
||||||
# Update history
|
# Update history
|
||||||
self.transcription_history.append(
|
self.transcription_history.append(
|
||||||
TranscriptionHistoryItem(
|
TranscriptionHistoryItem(
|
||||||
message=message, timestamp=time.time(), is_final=is_final
|
message=message_text, timestamp=time.time(), is_final=is_final
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1676,7 +1762,7 @@ class WaveformVideoTrack(MediaStreamTrack):
|
|||||||
# Draw colored waveform
|
# Draw colored waveform
|
||||||
if len(points) > 1:
|
if len(points) > 1:
|
||||||
for i in range(len(points) - 1):
|
for i in range(len(points) - 1):
|
||||||
cv2.line(frame_array, points[i], points[i+1], colors[i], 2)
|
cv2.line(frame_array, points[i], points[i+1], colors[i], 1)
|
||||||
|
|
||||||
# Add speech detection status overlay
|
# Add speech detection status overlay
|
||||||
if speech_info:
|
if speech_info:
|
||||||
@ -1701,29 +1787,41 @@ class WaveformVideoTrack(MediaStreamTrack):
|
|||||||
"""Draw speech detection status information."""
|
"""Draw speech detection status information."""
|
||||||
|
|
||||||
y_offset = 100
|
y_offset = 100
|
||||||
line_height = 25
|
|
||||||
|
|
||||||
# Main status
|
# Main status
|
||||||
is_speech = speech_info.get('is_speech', False)
|
is_speech = speech_info.get('is_speech', False)
|
||||||
status_text = "🎤 SPEECH" if is_speech else "🔇 NOISE"
|
status_text = "SPEECH" if is_speech else "NOISE"
|
||||||
status_color = (0, 255, 0) if is_speech else (128, 128, 128)
|
status_color = (0, 255, 0) if is_speech else (128, 128, 128)
|
||||||
|
|
||||||
cv2.putText(frame_array, f"{pname}: {status_text}",
|
adaptive_thresh = speech_info.get('adaptive_threshold', 0)
|
||||||
(10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.8, status_color, 2)
|
cv2.putText(frame_array, f"{pname}: {status_text} (thresh: {adaptive_thresh:.4f})",
|
||||||
|
(10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)
|
||||||
|
|
||||||
# Detailed metrics (smaller text)
|
# Detailed metrics (smaller text)
|
||||||
metrics = [
|
metrics = [
|
||||||
f"Energy: {speech_info.get('energy', 0):.3f} ({'✓' if speech_info.get('energy_check', False) else '✗'})",
|
f"Energy: {speech_info.get('energy', 0):.3f} ({'Y' if speech_info.get('energy_check', False) else 'N'})",
|
||||||
f"ZCR: {speech_info.get('zcr', 0):.3f} ({'✓' if speech_info.get('zcr_check', False) else '✗'})",
|
f"ZCR: {speech_info.get('zcr', 0):.3f} ({'Y' if speech_info.get('zcr_check', False) else 'N'})",
|
||||||
f"Spectral: ({'✓' if speech_info.get('spectral_check', False) else '✗'})",
|
f"Spectral: {'Y' if 300 < speech_info.get('centroid', 0) < 3400 else 'N'}/{'Y' if speech_info.get('rolloff', 0) < 2000 else 'N'}/{'Y' if speech_info.get('flux', 0) > 0.01 else 'N'} ({'Y' if speech_info.get('spectral_check', False) else 'N'})",
|
||||||
f"Harmonic: ({'✓' if speech_info.get('harmonic_check', False) else '✗'})",
|
f"Harmonic: {speech_info.get('hamonicity', 0):.3f} ({'Y' if speech_info.get('harmonic_check', False) else 'N'})",
|
||||||
f"Temporal: ({'✓' if speech_info.get('temporal_consistency', False) else '✗'})"
|
f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})"
|
||||||
]
|
]
|
||||||
|
|
||||||
for i, metric in enumerate(metrics):
|
for i, metric in enumerate(metrics):
|
||||||
cv2.putText(frame_array, metric,
|
cv2.putText(frame_array, metric,
|
||||||
(320, y_offset + i * 15), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
|
(320, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
|
||||||
(255, 255, 255), 1)
|
(255, 255, 255), 1)
|
||||||
|
y_offset += 15
|
||||||
|
|
||||||
|
logic_result = "E:" + ("Y" if speech_info.get('energy_check', False) else "N")
|
||||||
|
logic_result += " Z:" + ("Y" if speech_info.get('zcr_check', False) else "N")
|
||||||
|
logic_result += " S:" + ("Y" if speech_info.get('spectral_check', False) else "N")
|
||||||
|
logic_result += " H:" + ("Y" if speech_info.get('harmonic_check', False) else "N")
|
||||||
|
logic_result += " T:" + ("Y" if speech_info.get('temporal_consistency', False) else "N")
|
||||||
|
|
||||||
|
cv2.putText(frame_array, logic_result,
|
||||||
|
(320, y_offset + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
|
||||||
|
(255, 255, 255), 1)
|
||||||
|
|
||||||
|
|
||||||
# Noise floor indicator
|
# Noise floor indicator
|
||||||
noise_floor = speech_info.get('noise_floor_energy', 0)
|
noise_floor = speech_info.get('noise_floor_energy', 0)
|
||||||
@ -1760,12 +1858,12 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
|
|||||||
_model_loading_progress = 0.8
|
_model_loading_progress = 0.8
|
||||||
|
|
||||||
logger.info(f"Creating OptimizedAudioProcessor for {peer.peer_name}")
|
logger.info(f"Creating OptimizedAudioProcessor for {peer.peer_name}")
|
||||||
if _send_chat_func is None:
|
if _send_chat_func is None or _create_chat_message_func is None:
|
||||||
logger.error(f"No send_chat_func available for {peer.peer_name}")
|
logger.error(f"No send function available for {peer.peer_name}")
|
||||||
_model_loading_status = "Error: No send function available"
|
_model_loading_status = "Error: No send function available"
|
||||||
return
|
return
|
||||||
_audio_processors[peer.peer_name] = OptimizedAudioProcessor(
|
_audio_processors[peer.peer_name] = OptimizedAudioProcessor(
|
||||||
peer_name=peer.peer_name, send_chat_func=_send_chat_func
|
peer_name=peer.peer_name, send_chat_func=_send_chat_func, create_chat_message_func=_create_chat_message_func
|
||||||
)
|
)
|
||||||
|
|
||||||
_model_loading_status = "Ready for transcription"
|
_model_loading_status = "Ready for transcription"
|
||||||
@ -1773,9 +1871,9 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
|
|||||||
logger.info(f"OptimizedAudioProcessor ready for {peer.peer_name}")
|
logger.info(f"OptimizedAudioProcessor ready for {peer.peer_name}")
|
||||||
|
|
||||||
if peer.peer_name not in _audio_processors:
|
if peer.peer_name not in _audio_processors:
|
||||||
if _send_chat_func is None:
|
if _send_chat_func is None or _create_chat_message_func is None:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Cannot create processor for {peer.peer_name}: no send_chat_func"
|
f"Cannot create processor for {peer.peer_name}: no send_chat_func or create_chat_message_func"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -1993,7 +2091,7 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
|
|||||||
|
|
||||||
|
|
||||||
async def handle_chat_message(
|
async def handle_chat_message(
|
||||||
chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]
|
chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Handle incoming chat messages."""
|
"""Handle incoming chat messages."""
|
||||||
return None
|
return None
|
||||||
@ -2009,16 +2107,18 @@ def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]:
|
|||||||
return on_track_received
|
return on_track_received
|
||||||
|
|
||||||
|
|
||||||
def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]) -> None:
|
def bind_send_chat_function(send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel]) -> None:
|
||||||
"""Bind the send chat function."""
|
"""Bind the send chat function."""
|
||||||
global _send_chat_func, _audio_processors
|
global _send_chat_func, _create_chat_message_func, _audio_processors
|
||||||
|
|
||||||
logger.info("Binding send chat function to OpenVINO whisper agent")
|
logger.info("Binding send chat function to OpenVINO whisper agent")
|
||||||
_send_chat_func = send_chat_func
|
_send_chat_func = send_chat_func
|
||||||
|
_create_chat_message_func = create_chat_message_func
|
||||||
|
|
||||||
# Update existing processors
|
# Update existing processors
|
||||||
for peer_name, processor in _audio_processors.items():
|
for peer_name, processor in _audio_processors.items():
|
||||||
processor.send_chat_func = send_chat_func
|
processor.send_chat_func = send_chat_func
|
||||||
|
processor.create_chat_message_func = create_chat_message_func
|
||||||
logger.debug(f"Updated processor for {peer_name} with new send chat function")
|
logger.debug(f"Updated processor for {peer_name} with new send chat function")
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ import json
|
|||||||
import websockets
|
import websockets
|
||||||
import time
|
import time
|
||||||
import re
|
import re
|
||||||
|
import secrets
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Optional,
|
Optional,
|
||||||
@ -20,6 +21,7 @@ from typing import (
|
|||||||
Protocol,
|
Protocol,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
cast,
|
cast,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the parent directory to sys.path to allow absolute imports
|
# Add the parent directory to sys.path to allow absolute imports
|
||||||
@ -90,7 +92,7 @@ class WebRTCSignalingClient:
|
|||||||
session_name: str,
|
session_name: str,
|
||||||
insecure: bool = False,
|
insecure: bool = False,
|
||||||
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
|
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
|
||||||
bind_send_chat_function: Optional[Callable[[Callable[[str], Awaitable[None]]], None]] = None,
|
bind_send_chat_function: Optional[Callable[[Callable[[Union[str, ChatMessageModel]], Awaitable[None]], Callable[[str, Optional[str]], ChatMessageModel]], None]] = None,
|
||||||
registration_check_interval: float = 30.0,
|
registration_check_interval: float = 30.0,
|
||||||
):
|
):
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
@ -107,7 +109,7 @@ class WebRTCSignalingClient:
|
|||||||
|
|
||||||
if self.bind_send_chat_function:
|
if self.bind_send_chat_function:
|
||||||
# Bind the send_chat_message method to the bot's send function
|
# Bind the send_chat_message method to the bot's send function
|
||||||
self.bind_send_chat_function(self.send_chat_message)
|
self.bind_send_chat_function(self.send_chat_message, self.create_chat_message)
|
||||||
|
|
||||||
# WebSocket client protocol instance (typed as object to avoid Any)
|
# WebSocket client protocol instance (typed as object to avoid Any)
|
||||||
self.websocket: Optional[object] = None
|
self.websocket: Optional[object] = None
|
||||||
@ -424,22 +426,47 @@ class WebRTCSignalingClient:
|
|||||||
self.shutdown_requested = True
|
self.shutdown_requested = True
|
||||||
logger.info("Shutdown requested for WebRTC signaling client")
|
logger.info("Shutdown requested for WebRTC signaling client")
|
||||||
|
|
||||||
async def send_chat_message(self, message: str):
|
async def send_chat_message(self, message: Union[str, ChatMessageModel]):
|
||||||
"""Send a chat message to the lobby"""
|
"""Send a chat message to the lobby"""
|
||||||
if not self.is_registered:
|
if not self.is_registered:
|
||||||
logger.warning("Cannot send chat message: not registered")
|
logger.warning("Cannot send chat message: not registered")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not message.strip():
|
if isinstance(message, str):
|
||||||
logger.warning("Cannot send empty chat message")
|
if not message.strip():
|
||||||
return
|
logger.warning("Cannot send empty chat message")
|
||||||
|
return
|
||||||
|
# Create ChatMessageModel from string
|
||||||
|
chat_message = ChatMessageModel(
|
||||||
|
id=secrets.token_hex(8),
|
||||||
|
message=message.strip(),
|
||||||
|
sender_name=self.session_name,
|
||||||
|
sender_session_id=self.session_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
lobby_id=self.lobby_id,
|
||||||
|
)
|
||||||
|
message_data = chat_message.model_dump()
|
||||||
|
else:
|
||||||
|
# ChatMessageModel
|
||||||
|
message_data = message.model_dump()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_message("send_chat_message", {"message": message.strip()})
|
await self._send_message("send_chat_message", {"message": message_data})
|
||||||
logger.info(f"Sent chat message: {message[:50]}...")
|
logger.info(f"Sent chat message: {str(message)[:50]}...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send chat message: {e}", exc_info=True)
|
logger.error(f"Failed to send chat message: {e}", exc_info=True)
|
||||||
|
|
||||||
|
def create_chat_message(self, message: str, message_id: Optional[str] = None) -> ChatMessageModel:
|
||||||
|
"""Create a ChatMessageModel with the correct session and lobby IDs"""
|
||||||
|
return ChatMessageModel(
|
||||||
|
id=message_id or secrets.token_hex(8),
|
||||||
|
message=message,
|
||||||
|
sender_name=self.session_name,
|
||||||
|
sender_session_id=self.session_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
lobby_id=self.lobby_id,
|
||||||
|
)
|
||||||
|
|
||||||
async def _setup_local_media(self):
|
async def _setup_local_media(self):
|
||||||
"""Create local media tracks"""
|
"""Create local media tracks"""
|
||||||
# If a bot provided a create_tracks callable, use it to create tracks.
|
# If a bot provided a create_tracks callable, use it to create tracks.
|
||||||
@ -478,6 +505,7 @@ class WebRTCSignalingClient:
|
|||||||
|
|
||||||
# Build message with explicit type to avoid type narrowing
|
# Build message with explicit type to avoid type narrowing
|
||||||
# Always include data field to match WebSocketMessageModel
|
# Always include data field to match WebSocketMessageModel
|
||||||
|
# Should be a ChatMessageModel, not custom dict
|
||||||
message: dict[str, object] = {
|
message: dict[str, object] = {
|
||||||
"type": message_type,
|
"type": message_type,
|
||||||
"data": data if data is not None else {}
|
"data": data if data is not None else {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user