Improved transcription to current message

James Ketr 2025-09-14 15:48:17 -07:00
parent 50f290ac22
commit 110430d22a
9 changed files with 310 additions and 104 deletions

View File

@@ -40,7 +40,18 @@ const LobbyChat: React.FC<LobbyChatProps> = ({ socketUrl, session, lobbyId }) =>
       switch (message.type) {
         case "chat_message":
           const chatMessage = data as ChatMessage;
-          setMessages((prev) => [...prev, chatMessage]);
+          setMessages((prev) => {
+            const existingIndex = prev.findIndex((msg) => msg.id === chatMessage.id);
+            if (existingIndex >= 0) {
+              // Update existing message
+              const updatedMessages = [...prev];
+              updatedMessages[existingIndex] = chatMessage;
+              return updatedMessages;
+            } else {
+              // Add new message
+              return [...prev, chatMessage];
+            }
+          });
           break;
         case "chat_messages":
           const chatMessages = data.messages as ChatMessage[];
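Editor's note: this reducer and the server-side Lobby.update_chat_message added later in this commit apply the same update-or-append rule, keyed on message id — a streaming transcription resends one id, so each refinement replaces the earlier chat bubble instead of appending a new one. A minimal Python sketch of that rule (the function name and dict shape are illustrative, not from the repo):

    from typing import List

    def upsert_message(messages: List[dict], incoming: dict) -> List[dict]:
        # Replace the message with a matching id, else append as new
        for i, msg in enumerate(messages):
            if msg["id"] == incoming["id"]:
                return messages[:i] + [incoming] + messages[i + 1:]
        return messages + [incoming]

    history: List[dict] = []
    history = upsert_message(history, {"id": "a1", "text": "Hel"})
    history = upsert_message(history, {"id": "a1", "text": "Hello world"})
    assert len(history) == 1 and history[0]["text"] == "Hello world"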

View File

@@ -44,7 +44,7 @@
   height: 3.75rem;
   min-width: 5rem;
   min-height: 3.75rem;
-  z-index: 50000;
+  z-index: 1200;
   border-radius: 0.25rem;
 }

View File

@@ -1573,6 +1573,7 @@ const MediaControl: React.FC<MediaControlProps> = ({ isSelf, peer, className })
   hideDefaultLines={false}
   snappable={true}
   snapThreshold={5}
+  origin={false}
   edge
   onDragStart={(e) => {
     setIsDragging(true);

View File

@@ -191,6 +191,33 @@ class Lobby:
                 ]
         return chat_message

+    def update_chat_message(self, chat_message: ChatMessageModel) -> ChatMessageModel:
+        """Update an existing chat message in the lobby and return the updated message"""
+        with self.lock:
+            # Find the existing message by ID
+            for i, existing_msg in enumerate(self.chat_messages):
+                if existing_msg.id == chat_message.id:
+                    # Update the message content and timestamp
+                    updated_msg = ChatMessageModel(
+                        id=chat_message.id,
+                        message=chat_message.message,
+                        sender_name=existing_msg.sender_name,  # Keep original sender
+                        sender_session_id=existing_msg.sender_session_id,  # Keep original session
+                        timestamp=time.time(),  # Update timestamp
+                        lobby_id=existing_msg.lobby_id,  # Keep original lobby
+                    )
+                    self.chat_messages[i] = updated_msg
+                    return updated_msg
+
+            # If message not found, add it as new
+            self.chat_messages.append(chat_message)
+            # Keep only the latest messages per lobby
+            if len(self.chat_messages) > LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY:
+                self.chat_messages = self.chat_messages[
+                    -LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY:
+                ]
+            return chat_message
+
     def get_chat_messages(self, limit: int = 50) -> List[ChatMessageModel]:
         """Get the most recent chat messages from the lobby"""
         with self.lock:

View File

@@ -10,6 +10,7 @@ from typing import Dict, Any, TYPE_CHECKING
 from fastapi import WebSocket
 from shared.logger import logger
+from shared.models import ChatMessageModel
 from .webrtc_signaling import WebRTCSignalingHandlers
 from core.error_handling import (
     error_handler,
@@ -252,13 +253,51 @@ class SendChatMessageHandler(MessageHandler):
             })
             return

-        message_text = str(data["message"]).strip()
-        if not message_text:
-            return
+        message_data = data["message"]
+
+        if isinstance(message_data, dict):
+            # Handle ChatMessageModel object
+            try:
+                chat_message = ChatMessageModel.model_validate(message_data)
+
+                # Validate that the sender matches the session
+                if chat_message.sender_session_id != session.id:
+                    logger.error(
+                        f"{session.getName()} - ChatMessageModel sender_session_id mismatch"
+                    )
+                    await websocket.send_json(
+                        {
+                            "type": "error",
+                            "data": {
+                                "error": "ChatMessageModel sender_session_id does not match session"
+                            },
+                        }
+                    )
+                    return
+
+                # Update or add the message
+                chat_message = lobby.update_chat_message(chat_message)
+                logger.info(
+                    f"{session.getName()} -> update_chat_message({lobby.getName()}, {chat_message.message[:50]}...)"
+                )
+            except Exception as e:
+                logger.error(f"{session.getName()} - Invalid ChatMessageModel: {e}")
+                await websocket.send_json(
+                    {
+                        "type": "error",
+                        "data": {"error": "Invalid ChatMessageModel format"},
+                    }
+                )
+                return
+        else:
+            # Handle string message (legacy support)
+            message_text = str(message_data).strip()
+            if not message_text:
+                return
+
+            # Add the message to the lobby and broadcast it
+            chat_message = lobby.add_chat_message(session, message_text)
+            logger.info(
+                f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)"
+            )
-        # Add the message to the lobby and broadcast it
-        chat_message = lobby.add_chat_message(session, message_text)
-        logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)")

         await lobby.broadcast_chat_message(chat_message)
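Editor's note: after this change, data["message"] may be either a legacy string or a serialized ChatMessageModel; resending a model whose id is already known routes through lobby.update_chat_message instead of add_chat_message. Two example payloads (field values are made up; field names match the model as used in this diff):

    # Legacy clients: bare string, appended as a new message
    legacy = {"type": "send_chat_message", "data": {"message": "hello"}}

    # Updated clients: full model; a repeated id edits the message in place
    update = {
        "type": "send_chat_message",
        "data": {
            "message": {
                "id": "3f9c2a1b8d4e5f60",        # stable id across refinements
                "message": "hello, world",
                "sender_name": "whisper-bot",
                "sender_session_id": "sess-123",  # must match the sender's session
                "timestamp": 1726353000.0,
                "lobby_id": "lobby-42",
            }
        },
    }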

View File

@@ -10,7 +10,7 @@ This bot demonstrates the advanced capabilities including:
 import os
 import time
 import uuid
-from typing import Dict, Optional, Callable, Awaitable, Any
+from typing import Dict, Optional, Callable, Awaitable, Any, Union

 from aiortc import MediaStreamTrack

 # Import system modules
@@ -257,7 +257,7 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
 async def handle_chat_message(
     chat_message: ChatMessageModel,
-    send_message_func: Callable[[str], Awaitable[None]]
+    send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
 ) -> Optional[str]:
     """Handle incoming chat messages and provide AI-powered responses."""
     global _bot_instance

View File

@@ -4,7 +4,7 @@ This bot shows how to create an agent that primarily uses chat functionality
 rather than media streams.
 """

-from typing import Dict, Optional, Callable, Awaitable
+from typing import Dict, Optional, Callable, Awaitable, Union
 import time
 import random
 from shared.logger import logger
@@ -45,7 +45,7 @@ def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
     return {}

-async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]) -> Optional[str]:
+async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]) -> Optional[str]:
     """Handle incoming chat messages and provide responses.

     Args:

View File

@@ -253,23 +253,29 @@ class AdvancedVAD:
     def __init__(self, sample_rate: int = SAMPLE_RATE):
         self.sample_rate = sample_rate

-        # More conservative thresholds
-        self.energy_threshold = 0.02  # Increased from 0.01
-        self.zcr_min = 0.1
-        self.zcr_max = 0.5  # Reject very high ZCR (digital noise)
-        self.spectral_centroid_min = 300  # Human speech starts around 300Hz
-        self.spectral_centroid_max = 3400  # Human speech typically under 3400Hz
-        self.spectral_rolloff_threshold = 2000  # Speech energy concentrated lower
-        self.minimum_duration = 0.3  # Require 300ms of consistent speech
-
-        # Temporal consistency tracking
+        # More permissive thresholds based on research
+        self.energy_threshold = 0.005  # Reduced from 0.02
+        self.zcr_min = 0.02  # Reduced from 0.1 (voiced speech < 0.1)
+        self.zcr_max = 0.8  # Increased from 0.5 (unvoiced speech ~0.3-0.8)
+
+        # Spectral thresholds (keep existing - these work well)
+        self.spectral_centroid_min = 200  # Slightly lower
+        self.spectral_centroid_max = 4000  # Slightly higher
+        self.spectral_rolloff_threshold = 3000  # More permissive
+
+        # Relaxed temporal consistency
+        self.minimum_duration = 0.2  # Reduced from 0.3s
         self.speech_history = []
-        self.max_history = 10  # Track last 10 frames (1 second at 100ms chunks)
+        self.max_history = 8  # Reduced from 10

         # Adaptive noise floor
         self.noise_floor_energy = 0.001
         self.noise_floor_centroid = 1000
-        self.adaptation_rate = 0.02
+        self.adaptation_rate = 0.05
+
+        # Harmonicity improvements
+        self.prev_magnitude = None
+        self.harmonic_threshold = 0.15  # Reduced from 0.3

     def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, dict]:
         """Analyze audio frame for speech vs noise."""
@@ -286,43 +292,66 @@ class AdvancedVAD:
         spectral_features = self._compute_spectral_features(audio_data)

         # Individual feature checks
-        energy_check = energy > max(self.energy_threshold, self.noise_floor_energy * 3)
-        zcr_check = self.zcr_min < zcr < self.zcr_max
+
+        # Use adaptive energy threshold (reduced multiplier)
+        adaptive_energy_threshold = max(
+            self.energy_threshold,
+            self.noise_floor_energy * 2.0  # Reduced from 3.0
+        )
+        energy_check = energy > adaptive_energy_threshold
+
+        # More permissive ZCR check (allow voiced OR unvoiced speech)
+        zcr_check = (
+            self.zcr_min < zcr < self.zcr_max or  # General range
+            zcr < 0.1 or  # Definitely voiced
+            (0.2 < zcr < 0.6 and energy > self.energy_threshold * 2)  # Unvoiced with energy
+        )
+
+        # Spectral check (more permissive)
         spectral_check = (
-            self.spectral_centroid_min < spectral_features['centroid'] < self.spectral_centroid_max
-            and spectral_features['rolloff'] < self.spectral_rolloff_threshold
-            and spectral_features['flux'] > 0.01  # Spectral change indicates speech
+            self.spectral_centroid_min < spectral_features['centroid'] < self.spectral_centroid_max and
+            spectral_features['rolloff'] < self.spectral_rolloff_threshold and
+            spectral_features['flux'] > 0.005  # Reduced threshold
         )

-        # Harmonicity check (speech has harmonic structure)
-        harmonic_check = spectral_features['harmonicity'] > 0.3
+        # Improved harmonicity check
+        harmonic_check = spectral_features['harmonicity'] > self.harmonic_threshold

-        # Combined decision with temporal consistency
+        # More permissive combined decision (OR logic for some conditions)
         frame_has_speech = (
-            energy_check and
-            zcr_check and
-            spectral_check and
-            harmonic_check
+            energy_check and (
+                zcr_check or  # ZCR is good, OR
+                spectral_check or  # Spectral features are good, OR
+                harmonic_check  # Harmonicity is good
+            )
+        ) or (
+            # Alternative path: strong energy + reasonable spectral
+            energy > adaptive_energy_threshold * 1.5 and spectral_check
         )

         # Update history
         self.speech_history.append(frame_has_speech)
         if len(self.speech_history) > self.max_history:
             self.speech_history.pop(0)

-        # Require temporal consistency (at least 3 of last 5 frames)
-        recent_speech = sum(self.speech_history[-5:]) >= 3 if len(self.speech_history) >= 5 else frame_has_speech
+        # More permissive temporal consistency (2 of last 4, or 1 of last 2 if strong)
+        if len(self.speech_history) >= 4:
+            recent_speech = sum(self.speech_history[-4:]) >= 2
+        elif len(self.speech_history) >= 2:
+            # For shorter history, be more permissive if recent frame is strong
+            recent_speech = (
+                sum(self.speech_history[-2:]) >= 1 and
+                (energy > adaptive_energy_threshold * 1.2 or frame_has_speech)
+            )
+        else:
+            recent_speech = frame_has_speech

-        # Update noise floor during silence
+        # Faster noise floor adaptation during silence
         if not frame_has_speech:
             self.noise_floor_energy = (
                 self.adaptation_rate * energy +
                 (1 - self.adaptation_rate) * self.noise_floor_energy
             )
-            self.noise_floor_centroid = (
-                self.adaptation_rate * spectral_features['centroid'] +
-                (1 - self.adaptation_rate) * self.noise_floor_centroid
-            )

         metrics = {
             'energy': energy,
@@ -332,6 +361,7 @@ class AdvancedVAD:
             'flux': spectral_features['flux'],
             'harmonicity': spectral_features['harmonicity'],
             'noise_floor_energy': self.noise_floor_energy,
+            'adaptive_threshold': adaptive_energy_threshold,
             'energy_check': energy_check,
             'zcr_check': zcr_check,
             'spectral_check': spectral_check,
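Editor's note: condensed, the new decision path is an AND-of-ORs with a strong-energy fallback, replacing the old rule that required all four checks to pass. A standalone sketch using the thresholds from this hunk (function and argument names are mine):

    def frame_decision(energy: float, zcr_ok: bool, spectral_ok: bool,
                       harmonic_ok: bool, noise_floor_energy: float,
                       energy_threshold: float = 0.005) -> bool:
        # Adaptive floor: 2x the tracked noise floor, never below the static threshold
        adaptive = max(energy_threshold, noise_floor_energy * 2.0)
        energy_ok = energy > adaptive
        # Energy must pass, then any one feature check suffices...
        primary = energy_ok and (zcr_ok or spectral_ok or harmonic_ok)
        # ...or strong energy plus plausible spectra wins outright.
        fallback = energy > adaptive * 1.5 and spectral_ok
        return primary or fallback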
@@ -386,44 +416,77 @@ class AdvancedVAD:
     def _compute_harmonicity(self, magnitude: np.ndarray, freqs: np.ndarray) -> float:
         """Compute harmonicity score (0-1, higher = more harmonic/speech-like)."""
         # Find fundamental frequency candidate (peak in 80-400Hz range for speech)
-        speech_range = (freqs >= 80) & (freqs <= 400)
+        # Expanded F0 range for better detection
+        speech_range = (freqs >= 60) & (freqs <= 500)  # Expanded from 80-400Hz
         if not np.any(speech_range):
             return 0.0

         speech_magnitude = magnitude[speech_range]
         speech_freqs = freqs[speech_range]

         if len(speech_magnitude) == 0:
             return 0.0

         # Find strongest peak in speech range
-        f0_idx = np.argmax(speech_magnitude)
+        # More robust F0 detection - find peaks instead of just max
+        try:
+            # Import scipy here to handle missing dependency gracefully
+            from scipy.signal import find_peaks
+
+            # Ensure distance is at least 1
+            min_distance = max(1, int(len(speech_magnitude) * 0.05))
+
+            peaks, properties = find_peaks(
+                speech_magnitude,
+                height=np.max(speech_magnitude) * 0.05,  # Lowered from 0.1
+                distance=min_distance,  # Minimum peak separation
+            )
+
+            if len(peaks) == 0:
+                # Fallback to simple max if no peaks found
+                f0_idx = np.argmax(speech_magnitude)
+            else:
+                # Use the strongest peak
+                strongest_peak_idx = np.argmax(speech_magnitude[peaks])
+                f0_idx = peaks[strongest_peak_idx]
+        except ImportError:
+            # scipy not available, use simple max
+            f0_idx = np.argmax(speech_magnitude)

         f0 = speech_freqs[f0_idx]
         f0_strength = speech_magnitude[f0_idx]

-        if f0_strength < np.max(magnitude) * 0.1:  # F0 should be reasonably strong
+        # More lenient F0 strength requirement
+        if f0_strength < np.max(magnitude) * 0.03:  # Reduced from 0.1
             return 0.0

         # Check for harmonics (2*f0, 3*f0, etc.)
         harmonic_strength = 0.0
         total_harmonics = 0

-        for harmonic in range(2, 6):  # Check 2nd through 5th harmonics
+        for harmonic in range(2, 5):  # Check 2nd through 4th harmonics
             harmonic_freq = f0 * harmonic
             if harmonic_freq > freqs[-1]:
                 break

-            # Find closest frequency bin
+            # Find closest frequency bins (check neighboring bins too)
             harmonic_idx = np.argmin(np.abs(freqs - harmonic_freq))
-            harmonic_strength += magnitude[harmonic_idx]
+
+            # Check a small neighborhood around the harmonic frequency
+            start_idx = max(0, harmonic_idx - 2)
+            end_idx = min(len(magnitude), harmonic_idx + 3)
+            local_max = np.max(magnitude[start_idx:end_idx])
+            harmonic_strength += local_max
             total_harmonics += 1

         if total_harmonics == 0:
             return 0.0

-        # Normalize by fundamental strength and number of harmonics
+        # Normalize and return
         harmonicity = (harmonic_strength / total_harmonics) / f0_strength
         return min(harmonicity, 1.0)
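Editor's note: as a sanity check on the relaxed scoring, a toy re-implementation (simplified, not the committed method) scores a synthetic 150 Hz voiced frame well above the new 0.15 threshold:

    import numpy as np

    sr = 16000
    t = np.arange(sr // 10) / sr  # one 100 ms frame
    # 150 Hz fundamental with decaying 2nd-4th harmonics, vowel-like
    frame = sum(np.sin(2 * np.pi * 150 * k * t) / k for k in range(1, 5))

    mag = np.abs(np.fft.rfft(frame * np.hanning(frame.size)))
    freqs = np.fft.rfftfreq(frame.size, 1 / sr)

    band = (freqs >= 60) & (freqs <= 500)               # expanded F0 range
    f0_idx = np.flatnonzero(band)[np.argmax(mag[band])]
    f0, f0_strength = freqs[f0_idx], mag[f0_idx]
    assert f0_strength >= mag.max() * 0.03              # lenient F0 gate

    peaks = []
    for k in range(2, 5):                               # 2nd-4th harmonics
        idx = int(np.argmin(np.abs(freqs - k * f0)))
        peaks.append(mag[max(0, idx - 2):idx + 3].max())  # neighborhood max

    harmonicity = min(np.mean(peaks) / f0_strength, 1.0)
    print(f"f0={f0:.0f} Hz, harmonicity={harmonicity:.2f}")  # ~0.36 > 0.15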
@@ -899,7 +962,8 @@ class OpenVINOWhisperModel:
 # Global model instance with deferred loading
 _whisper_model: Optional[OpenVINOWhisperModel] = None
 _audio_processors: Dict[str, "OptimizedAudioProcessor"] = {}
-_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None
+_send_chat_func: Optional[Callable[[ChatMessageModel], Awaitable[None]]] = None
+_create_chat_message_func: Optional[Callable[[str, Optional[str]], ChatMessageModel]] = None

 # Model loading status for video display
 _model_loading_status: str = "Not loaded"
@@ -1018,10 +1082,11 @@ class OptimizedAudioProcessor:
     """Optimized audio processor for Intel Arc B580 with reduced latency."""

     def __init__(
-        self, peer_name: str, send_chat_func: Callable[[str], Awaitable[None]]
+        self, peer_name: str, send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel]
     ):
         self.peer_name = peer_name
         self.send_chat_func = send_chat_func
+        self.create_chat_message_func = create_chat_message_func
         self.sample_rate = SAMPLE_RATE

         # Initialize visualization buffer if not already done
@@ -1049,6 +1114,10 @@ class OptimizedAudioProcessor:
         self.current_phrase_audio = np.array([], dtype=np.float32)
         self.transcription_history: List[TranscriptionHistoryItem] = []
         self.last_activity_time = time.time()
+        self.last_audio_time = time.time()  # Track when any audio chunk is received
+
+        # Current transcription message for refinements
+        self.current_message: Optional[ChatMessageModel] = None

         # Async processing
         self.processing_queue: asyncio.Queue[AudioQueueItem] = asyncio.Queue(maxsize=10)
@@ -1076,6 +1145,9 @@ class OptimizedAudioProcessor:
             logger.error("Processor not running or empty audio data")
             return

+        # Update last audio time whenever any audio is received
+        self.last_audio_time = time.time()
+
         is_speech, vad_metrics = self.advanced_vad.analyze_frame(audio_data)

         # Update visualization status
@@ -1227,7 +1299,7 @@ class OptimizedAudioProcessor:
             # Check for final transcription on timeout
             if (
                 len(self.current_phrase_audio) > 0
-                and time.time() - self.last_activity_time > 2.0
+                and time.time() - self.last_audio_time > 2.0
             ):
                 logger.info(
                     f"Final transcription timeout for {self.peer_name} (asyncio.TimeoutError)"
@@ -1276,7 +1348,7 @@ class OptimizedAudioProcessor:
             # Check for final transcription
             if (
                 len(self.current_phrase_audio) > 0
-                and time.time() - self.last_activity_time > 2.0
+                and time.time() - self.last_audio_time > 2.0
             ):
                 if self.main_loop:
                     logger.info(
@@ -1409,29 +1481,43 @@ class OptimizedAudioProcessor:
         # return  # Skip very low confidence

         # # Include confidence in message
-        # confidence_indicator = (
-        #     "✓" if avg_confidence > 0.8 else "?" if avg_confidence < 0.6 else ""
-        # )
-        message = f"{self.peer_name}: {transcription}"  # {confidence_indicator}[{avg_confidence:.1%}]"
-        await self.send_chat_func(message)
+        # Create ChatMessageModel for transcription
         if transcription:
             # Create message with timing
             status_marker = "" if is_final else "🎤"
             type_marker = "" if is_final else " [streaming]"
             timing_info = f" (🚀 {transcription_time:.2f}s)"
-            message = f"{status_marker} {self.peer_name}{type_marker}: {transcription}{timing_info}"
+            message_text = f"{status_marker} {self.peer_name}{type_marker}: {transcription}{timing_info}"

             # Avoid duplicates
             if not self._is_duplicate(transcription):
-                await self.send_chat_func(message)
+                # For streaming transcriptions, reuse the current message ID if it exists
+                # For final transcriptions, always create a new message
+                if is_final:
+                    # Final transcription - reset current message
+                    self.current_message = None
+                    message_id = None
+                else:
+                    # Streaming transcription - reuse current message ID or create new
+                    if self.current_message is None:
+                        message_id = None  # Will create new ID
+                    else:
+                        message_id = self.current_message.id
+
+                # Create ChatMessageModel
+                chat_message = self.create_chat_message_func(message_text, message_id)
+
+                # Update current message for streaming
+                if not is_final:
+                    self.current_message = chat_message
+
+                await self.send_chat_func(chat_message)

                 # Update history
                 self.transcription_history.append(
                     TranscriptionHistoryItem(
-                        message=message, timestamp=time.time(), is_final=is_final
+                        message=message_text, timestamp=time.time(), is_final=is_final
                     )
                 )
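Editor's note: the id lifecycle above, reduced to its essentials — partial results reuse the current message id so clients edit one bubble in place, while a final result goes out under a fresh id and resets the state so the next phrase starts a new bubble. An illustrative sketch (not the committed class):

    from typing import Optional

    class StreamingIds:
        def __init__(self) -> None:
            self.current_id: Optional[str] = None  # stands in for self.current_message

        def id_for(self, is_final: bool, minted_id: str) -> str:
            if is_final:
                self.current_id = None       # next phrase gets a new bubble
                return minted_id             # final goes out under a fresh id
            if self.current_id is None:
                self.current_id = minted_id  # first partial mints the id
            return self.current_id           # later partials reuse it

    ids = StreamingIds()
    assert ids.id_for(False, "a") == "a"     # first partial
    assert ids.id_for(False, "b") == "a"     # refinement edits message "a"
    assert ids.id_for(True, "c") == "c"      # final starts fresh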
@@ -1676,7 +1762,7 @@ class WaveformVideoTrack(MediaStreamTrack):
             # Draw colored waveform
             if len(points) > 1:
                 for i in range(len(points) - 1):
-                    cv2.line(frame_array, points[i], points[i+1], colors[i], 2)
+                    cv2.line(frame_array, points[i], points[i+1], colors[i], 1)

             # Add speech detection status overlay
             if speech_info:
@@ -1701,29 +1787,41 @@ class WaveformVideoTrack(MediaStreamTrack):
         """Draw speech detection status information."""
         y_offset = 100
-        line_height = 25

         # Main status
         is_speech = speech_info.get('is_speech', False)
-        status_text = "🎤 SPEECH" if is_speech else "🔇 NOISE"
+        status_text = "SPEECH" if is_speech else "NOISE"
         status_color = (0, 255, 0) if is_speech else (128, 128, 128)
-        cv2.putText(frame_array, f"{pname}: {status_text}",
-                    (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.8, status_color, 2)
+        adaptive_thresh = speech_info.get('adaptive_threshold', 0)
+        cv2.putText(frame_array, f"{pname}: {status_text} (thresh: {adaptive_thresh:.4f})",
+                    (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)

         # Detailed metrics (smaller text)
         metrics = [
-            f"Energy: {speech_info.get('energy', 0):.3f} ({'✓' if speech_info.get('energy_check', False) else '✗'})",
-            f"ZCR: {speech_info.get('zcr', 0):.3f} ({'✓' if speech_info.get('zcr_check', False) else '✗'})",
-            f"Spectral: ({'✓' if speech_info.get('spectral_check', False) else '✗'})",
-            f"Harmonic: ({'✓' if speech_info.get('harmonic_check', False) else '✗'})",
-            f"Temporal: ({'✓' if speech_info.get('temporal_consistency', False) else '✗'})"
+            f"Energy: {speech_info.get('energy', 0):.3f} ({'Y' if speech_info.get('energy_check', False) else 'N'})",
+            f"ZCR: {speech_info.get('zcr', 0):.3f} ({'Y' if speech_info.get('zcr_check', False) else 'N'})",
+            f"Spectral: {'Y' if 300 < speech_info.get('centroid', 0) < 3400 else 'N'}/{'Y' if speech_info.get('rolloff', 0) < 2000 else 'N'}/{'Y' if speech_info.get('flux', 0) > 0.01 else 'N'} ({'Y' if speech_info.get('spectral_check', False) else 'N'})",
+            f"Harmonic: {speech_info.get('harmonicity', 0):.3f} ({'Y' if speech_info.get('harmonic_check', False) else 'N'})",
+            f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})"
         ]

         for i, metric in enumerate(metrics):
             cv2.putText(frame_array, metric,
-                        (320, y_offset + i * 15), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
+                        (320, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
                         (255, 255, 255), 1)
+            y_offset += 15
+
+        logic_result = "E:" + ("Y" if speech_info.get('energy_check', False) else "N")
+        logic_result += " Z:" + ("Y" if speech_info.get('zcr_check', False) else "N")
+        logic_result += " S:" + ("Y" if speech_info.get('spectral_check', False) else "N")
+        logic_result += " H:" + ("Y" if speech_info.get('harmonic_check', False) else "N")
+        logic_result += " T:" + ("Y" if speech_info.get('temporal_consistency', False) else "N")
+        cv2.putText(frame_array, logic_result,
+                    (320, y_offset + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
+                    (255, 255, 255), 1)

         # Noise floor indicator
         noise_floor = speech_info.get('noise_floor_energy', 0)
@@ -1760,12 +1858,12 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
         _model_loading_progress = 0.8
         logger.info(f"Creating OptimizedAudioProcessor for {peer.peer_name}")

-        if _send_chat_func is None:
-            logger.error(f"No send_chat_func available for {peer.peer_name}")
+        if _send_chat_func is None or _create_chat_message_func is None:
+            logger.error(f"No send function available for {peer.peer_name}")
             _model_loading_status = "Error: No send function available"
             return

         _audio_processors[peer.peer_name] = OptimizedAudioProcessor(
-            peer_name=peer.peer_name, send_chat_func=_send_chat_func
+            peer_name=peer.peer_name, send_chat_func=_send_chat_func, create_chat_message_func=_create_chat_message_func
         )

         _model_loading_status = "Ready for transcription"
@@ -1773,9 +1871,9 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
     logger.info(f"OptimizedAudioProcessor ready for {peer.peer_name}")

     if peer.peer_name not in _audio_processors:
-        if _send_chat_func is None:
+        if _send_chat_func is None or _create_chat_message_func is None:
             logger.error(
-                f"Cannot create processor for {peer.peer_name}: no send_chat_func"
+                f"Cannot create processor for {peer.peer_name}: no send_chat_func or create_chat_message_func"
             )
             return
@@ -1993,7 +2091,7 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
 async def handle_chat_message(
-    chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]
+    chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
 ) -> Optional[str]:
     """Handle incoming chat messages."""
     return None
@@ -2009,16 +2107,18 @@ def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]:
     return on_track_received

-def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]) -> None:
+def bind_send_chat_function(send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel]) -> None:
     """Bind the send chat function."""
-    global _send_chat_func, _audio_processors
+    global _send_chat_func, _create_chat_message_func, _audio_processors

     logger.info("Binding send chat function to OpenVINO whisper agent")
     _send_chat_func = send_chat_func
+    _create_chat_message_func = create_chat_message_func

     # Update existing processors
     for peer_name, processor in _audio_processors.items():
         processor.send_chat_func = send_chat_func
+        processor.create_chat_message_func = create_chat_message_func
         logger.debug(f"Updated processor for {peer_name} with new send chat function")
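Editor's note: under the new two-argument binding, the agent holds both a sender and an id-stable message factory. A sketch of how the module-level globals above would be used (the helper name and body are illustrative, not from the repo):

    # The signaling client effectively calls:
    #   bind_send_chat_function(client.send_chat_message, client.create_chat_message)

    async def emit_partial(text: str, current_id: Optional[str]) -> ChatMessageModel:
        assert _send_chat_func is not None and _create_chat_message_func is not None
        msg = _create_chat_message_func(text, current_id)  # reuse id when refining
        await _send_chat_func(msg)
        return msg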

View File

@@ -12,6 +12,7 @@ import json
 import websockets
 import time
 import re
+import secrets
 from typing import (
     Dict,
     Optional,
@@ -20,6 +21,7 @@ from typing import (
     Protocol,
     AsyncIterator,
     cast,
+    Union,
 )

 # Add the parent directory to sys.path to allow absolute imports
@@ -90,7 +92,7 @@ class WebRTCSignalingClient:
         session_name: str,
         insecure: bool = False,
         create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
-        bind_send_chat_function: Optional[Callable[[Callable[[str], Awaitable[None]]], None]] = None,
+        bind_send_chat_function: Optional[Callable[[Callable[[Union[str, ChatMessageModel]], Awaitable[None]], Callable[[str, Optional[str]], ChatMessageModel]], None]] = None,
         registration_check_interval: float = 30.0,
     ):
         self.server_url = server_url
@@ -107,7 +109,7 @@ class WebRTCSignalingClient:
         if self.bind_send_chat_function:
             # Bind the send_chat_message method to the bot's send function
-            self.bind_send_chat_function(self.send_chat_message)
+            self.bind_send_chat_function(self.send_chat_message, self.create_chat_message)

         # WebSocket client protocol instance (typed as object to avoid Any)
         self.websocket: Optional[object] = None
@@ -424,22 +426,47 @@ class WebRTCSignalingClient:
         self.shutdown_requested = True
         logger.info("Shutdown requested for WebRTC signaling client")

-    async def send_chat_message(self, message: str):
+    async def send_chat_message(self, message: Union[str, ChatMessageModel]):
         """Send a chat message to the lobby"""
         if not self.is_registered:
             logger.warning("Cannot send chat message: not registered")
             return

-        if not message.strip():
-            logger.warning("Cannot send empty chat message")
-            return
+        if isinstance(message, str):
+            if not message.strip():
+                logger.warning("Cannot send empty chat message")
+                return
+
+            # Create ChatMessageModel from string
+            chat_message = ChatMessageModel(
+                id=secrets.token_hex(8),
+                message=message.strip(),
+                sender_name=self.session_name,
+                sender_session_id=self.session_id,
+                timestamp=time.time(),
+                lobby_id=self.lobby_id,
+            )
+            message_data = chat_message.model_dump()
+        else:
+            # ChatMessageModel
+            message_data = message.model_dump()

         try:
-            await self._send_message("send_chat_message", {"message": message.strip()})
-            logger.info(f"Sent chat message: {message[:50]}...")
+            await self._send_message("send_chat_message", {"message": message_data})
+            logger.info(f"Sent chat message: {str(message)[:50]}...")
         except Exception as e:
             logger.error(f"Failed to send chat message: {e}", exc_info=True)

+    def create_chat_message(self, message: str, message_id: Optional[str] = None) -> ChatMessageModel:
+        """Create a ChatMessageModel with the correct session and lobby IDs"""
+        return ChatMessageModel(
+            id=message_id or secrets.token_hex(8),
+            message=message,
+            sender_name=self.session_name,
+            sender_session_id=self.session_id,
+            timestamp=time.time(),
+            lobby_id=self.lobby_id,
+        )
+
     async def _setup_local_media(self):
         """Create local media tracks"""
         # If a bot provided a create_tracks callable, use it to create tracks.
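Editor's note: end to end, the client-side flow enabled by these two methods looks roughly like this (a sketch assuming a connected, registered client; the text content is made up):

    async def stream_demo(client: "WebRTCSignalingClient") -> None:
        # First partial mints a message id and appears as a new bubble
        draft = client.create_chat_message("🎤 alice [streaming]: hel")
        await client.send_chat_message(draft)

        # Refinement reuses the id, so the server updates it in place
        refined = client.create_chat_message(
            "🎤 alice [streaming]: hello world", message_id=draft.id
        )
        await client.send_chat_message(refined)

        # Plain strings still work and are wrapped in a fresh model
        await client.send_chat_message("done")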
@@ -478,6 +505,7 @@ class WebRTCSignalingClient:
         # Build message with explicit type to avoid type narrowing
         # Always include data field to match WebSocketMessageModel
+        # Should be a ChatMessageModel, not custom dict
         message: dict[str, object] = {
             "type": message_type,
             "data": data if data is not None else {}