diff --git a/client/src/LobbyChat.tsx b/client/src/LobbyChat.tsx
index 992f4bd..7fea855 100644
--- a/client/src/LobbyChat.tsx
+++ b/client/src/LobbyChat.tsx
@@ -40,7 +40,18 @@ const LobbyChat: React.FC = ({ socketUrl, session, lobbyId }) =>
       switch (message.type) {
         case "chat_message":
           const chatMessage = data as ChatMessage;
-          setMessages((prev) => [...prev, chatMessage]);
+          setMessages((prev) => {
+            const existingIndex = prev.findIndex((msg) => msg.id === chatMessage.id);
+            if (existingIndex >= 0) {
+              // Update existing message
+              const updatedMessages = [...prev];
+              updatedMessages[existingIndex] = chatMessage;
+              return updatedMessages;
+            } else {
+              // Add new message
+              return [...prev, chatMessage];
+            }
+          });
           break;
         case "chat_messages":
           const chatMessages = data.messages as ChatMessage[];
diff --git a/client/src/MediaControl.css b/client/src/MediaControl.css
index 0e5a221..4c8cea7 100644
--- a/client/src/MediaControl.css
+++ b/client/src/MediaControl.css
@@ -44,7 +44,7 @@
   height: 3.75rem;
   min-width: 5rem;
   min-height: 3.75rem;
-  z-index: 50000;
+  z-index: 1200;
   border-radius: 0.25rem;
 }
diff --git a/client/src/MediaControl.tsx b/client/src/MediaControl.tsx
index c5f4913..20facaa 100644
--- a/client/src/MediaControl.tsx
+++ b/client/src/MediaControl.tsx
@@ -1573,6 +1573,7 @@ const MediaControl: React.FC = ({ isSelf, peer, className })
         hideDefaultLines={false}
         snappable={true}
         snapThreshold={5}
+        origin={false}
         edge
         onDragStart={(e) => {
           setIsDragging(true);
diff --git a/server/core/lobby_manager.py b/server/core/lobby_manager.py
index 4c56500..c70fc6d 100644
--- a/server/core/lobby_manager.py
+++ b/server/core/lobby_manager.py
@@ -191,6 +191,33 @@ class Lobby:
                 ]
             return chat_message
 
+    def update_chat_message(self, chat_message: ChatMessageModel) -> ChatMessageModel:
+        """Update an existing chat message in the lobby and return the updated message"""
+        with self.lock:
+            # Find the existing message by ID
+            for i, existing_msg in enumerate(self.chat_messages):
+                if existing_msg.id == chat_message.id:
+                    # Update the message content and timestamp
+                    updated_msg = ChatMessageModel(
+                        id=chat_message.id,
+                        message=chat_message.message,
+                        sender_name=existing_msg.sender_name,  # Keep original sender
+                        sender_session_id=existing_msg.sender_session_id,  # Keep original session
+                        timestamp=time.time(),  # Update timestamp
+                        lobby_id=existing_msg.lobby_id,  # Keep original lobby
+                    )
+                    self.chat_messages[i] = updated_msg
+                    return updated_msg
+
+            # If message not found, add it as new
+            self.chat_messages.append(chat_message)
+            # Keep only the latest messages per lobby
+            if len(self.chat_messages) > LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY:
+                self.chat_messages = self.chat_messages[
+                    -LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY :
+                ]
+            return chat_message
+
     def get_chat_messages(self, limit: int = 50) -> List[ChatMessageModel]:
         """Get the most recent chat messages from the lobby"""
         with self.lock:
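
The new update_chat_message gives the lobby upsert semantics: a message arriving with a known id replaces the stored copy (keeping the original sender and lobby), while an unknown id is appended like a normal message, which is what the client-side findIndex logic above mirrors. A minimal, self-contained sketch of that behaviour follows; the dataclass stands in for the project's ChatMessageModel (field names taken from the diff) and the cap constant is illustrative, not the project's value.

import time
from dataclasses import dataclass, replace
from typing import List

MAX_CHAT_MESSAGES_PER_LOBBY = 100  # illustrative cap, not the project's value


@dataclass
class ChatMessage:
    id: str
    message: str
    sender_name: str
    sender_session_id: str
    timestamp: float
    lobby_id: str


def upsert(history: List[ChatMessage], incoming: ChatMessage) -> ChatMessage:
    """Update the stored message with the same id, or append a new one."""
    for i, existing in enumerate(history):
        if existing.id == incoming.id:
            # Keep the original sender/lobby, take the new text, refresh the timestamp
            updated = replace(existing, message=incoming.message, timestamp=time.time())
            history[i] = updated
            return updated
    history.append(incoming)
    del history[:-MAX_CHAT_MESSAGES_PER_LOBBY]  # keep only the newest messages
    return incoming


history: List[ChatMessage] = []
first = ChatMessage("abc123", "partial transcri...", "whisper-bot", "sess-1", time.time(), "lobby-1")
upsert(history, first)
upsert(history, replace(first, message="partial transcription, refined"))
print(len(history), history[0].message)  # 1 "partial transcription, refined"
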
diff --git a/server/websocket/message_handlers.py b/server/websocket/message_handlers.py
index aa92bd1..477508a 100644
--- a/server/websocket/message_handlers.py
+++ b/server/websocket/message_handlers.py
@@ -10,6 +10,7 @@ from typing import Dict, Any, TYPE_CHECKING
 from fastapi import WebSocket
 
 from shared.logger import logger
+from shared.models import ChatMessageModel
 from .webrtc_signaling import WebRTCSignalingHandlers
 from core.error_handling import (
     error_handler,
@@ -252,13 +253,51 @@ class SendChatMessageHandler(MessageHandler):
             })
             return
 
-        message_text = str(data["message"]).strip()
-        if not message_text:
-            return
+        message_data = data["message"]
+
+        if isinstance(message_data, dict):
+            # Handle ChatMessageModel object
+            try:
+                chat_message = ChatMessageModel.model_validate(message_data)
+                # Validate that the sender matches the session
+                if chat_message.sender_session_id != session.id:
+                    logger.error(
+                        f"{session.getName()} - ChatMessageModel sender_session_id mismatch"
+                    )
+                    await websocket.send_json(
+                        {
+                            "type": "error",
+                            "data": {
+                                "error": "ChatMessageModel sender_session_id does not match session"
+                            },
+                        }
+                    )
+                    return
+                # Update or add the message
+                chat_message = lobby.update_chat_message(chat_message)
+                logger.info(
+                    f"{session.getName()} -> update_chat_message({lobby.getName()}, {chat_message.message[:50]}...)"
+                )
+            except Exception as e:
+                logger.error(f"{session.getName()} - Invalid ChatMessageModel: {e}")
+                await websocket.send_json(
+                    {
+                        "type": "error",
+                        "data": {"error": "Invalid ChatMessageModel format"},
+                    }
+                )
+                return
+        else:
+            # Handle string message (legacy support)
+            message_text = str(message_data).strip()
+            if not message_text:
+                return
+            # Add the message to the lobby and broadcast it
+            chat_message = lobby.add_chat_message(session, message_text)
+            logger.info(
+                f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)"
+            )
 
-        # Add the message to the lobby and broadcast it
-        chat_message = lobby.add_chat_message(session, message_text)
-        logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)")
         await lobby.broadcast_chat_message(chat_message)
diff --git a/voicebot/bots/ai_chatbot.py b/voicebot/bots/ai_chatbot.py
index f02c163..02e7aa6 100644
--- a/voicebot/bots/ai_chatbot.py
+++ b/voicebot/bots/ai_chatbot.py
@@ -10,7 +10,7 @@ This bot demonstrates the advanced capabilities including:
 import os
 import time
 import uuid
-from typing import Dict, Optional, Callable, Awaitable, Any
+from typing import Dict, Optional, Callable, Awaitable, Any, Union
 from aiortc import MediaStreamTrack
 
 # Import system modules
@@ -257,7 +257,7 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
 
 async def handle_chat_message(
     chat_message: ChatMessageModel,
-    send_message_func: Callable[[str], Awaitable[None]]
+    send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
 ) -> Optional[str]:
     """Handle incoming chat messages and provide AI-powered responses."""
     global _bot_instance
diff --git a/voicebot/bots/chatbot.py b/voicebot/bots/chatbot.py
index 80fdb0c..28fc66e 100644
--- a/voicebot/bots/chatbot.py
+++ b/voicebot/bots/chatbot.py
@@ -4,7 +4,7 @@ This bot shows how to create an agent that primarily uses chat functionality
 rather than media streams.
 """
 
-from typing import Dict, Optional, Callable, Awaitable
+from typing import Dict, Optional, Callable, Awaitable, Union
 import time
 import random
 from shared.logger import logger
@@ -45,7 +45,7 @@ def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
     return {}
 
 
-async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]) -> Optional[str]:
+async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]) -> Optional[str]:
     """Handle incoming chat messages and provide responses.
 
     Args:
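
SendChatMessageHandler now accepts either the legacy plain-string payload or a fully serialized ChatMessageModel, and a dict payload is validated and routed to the lobby's upsert path. A rough sketch of the two wire shapes, assuming the {"type": ..., "data": ...} envelope used elsewhere in this diff; the concrete ids and names are placeholders.

# Two payload shapes the handler now accepts.
legacy_payload = {
    "type": "send_chat_message",
    "data": {"message": "hello lobby"},  # plain string -> lobby.add_chat_message()
}

upsert_payload = {
    "type": "send_chat_message",
    "data": {
        "message": {  # dict -> ChatMessageModel.model_validate() -> lobby.update_chat_message()
            "id": "abc123",                 # same id as an earlier message => edit in place
            "message": "hello lobby (refined)",
            "sender_name": "whisper-bot",
            "sender_session_id": "sess-1",  # must match the sending session or the server rejects it
            "timestamp": 1700000000.0,
            "lobby_id": "lobby-1",
        }
    },
}

print(isinstance(legacy_payload["data"]["message"], dict))  # False -> legacy branch
print(isinstance(upsert_payload["data"]["message"], dict))  # True  -> upsert branch

diff --git a/voicebot/bots/whisper.py b/voicebot/bots/whisper.py
index e17c791..f8f60be 100644
--- a/voicebot/bots/whisper.py
+++ b/voicebot/bots/whisper.py
@@ -253,23 +253,29 @@ class AdvancedVAD:
     def __init__(self, sample_rate: int = SAMPLE_RATE):
         self.sample_rate = sample_rate
 
-        # More conservative thresholds
-        self.energy_threshold = 0.02  # Increased from 0.01
-        self.zcr_min = 0.1
-        self.zcr_max = 0.5  # Reject very high ZCR (digital noise)
-        self.spectral_centroid_min = 300  # Human speech starts around 300Hz
-        self.spectral_centroid_max = 3400  # Human speech typically under 3400Hz
-        self.spectral_rolloff_threshold = 2000  # Speech energy concentrated lower
-        self.minimum_duration = 0.3  # Require 300ms of consistent speech
+        # More permissive thresholds based on research
+        self.energy_threshold = 0.005  # Reduced from 0.02
+        self.zcr_min = 0.02  # Reduced from 0.1 (voiced speech < 0.1)
+        self.zcr_max = 0.8  # Increased from 0.5 (unvoiced speech ~0.3-0.8)
 
-        # Temporal consistency tracking
+        # Spectral thresholds (keep existing - these work well)
+        self.spectral_centroid_min = 200  # Slightly lower
+        self.spectral_centroid_max = 4000  # Slightly higher
+        self.spectral_rolloff_threshold = 3000  # More permissive
+
+        # Relaxed temporal consistency
+        self.minimum_duration = 0.2  # Reduced from 0.3s
         self.speech_history = []
-        self.max_history = 10  # Track last 10 frames (1 second at 100ms chunks)
+        self.max_history = 8  # Reduced from 10
 
         # Adaptive noise floor
         self.noise_floor_energy = 0.001
         self.noise_floor_centroid = 1000
-        self.adaptation_rate = 0.02
+        self.adaptation_rate = 0.05
+
+        # Harmonicity improvements
+        self.prev_magnitude = None
+        self.harmonic_threshold = 0.15  # Reduced from 0.3
 
     def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, dict]:
         """Analyze audio frame for speech vs noise."""

The reworked VAD keys everything off an adaptive energy gate: during silence the noise floor follows an exponential moving average (adaptation_rate = 0.05), and the gate becomes max(energy_threshold, 2 x noise floor). A stripped-down sketch of just that adaptation loop, using the constants from the diff and synthetic frame energies:

ENERGY_THRESHOLD = 0.005   # static floor from the diff
ADAPTATION_RATE = 0.05     # EMA coefficient from the diff

def step(noise_floor: float, energy: float, is_speech: bool) -> tuple[float, float]:
    """Return (new_noise_floor, adaptive_threshold) for one frame."""
    if not is_speech:  # only adapt the floor during silence
        noise_floor = ADAPTATION_RATE * energy + (1 - ADAPTATION_RATE) * noise_floor
    threshold = max(ENERGY_THRESHOLD, noise_floor * 2.0)
    return noise_floor, threshold

noise_floor = 0.001
for frame_energy in [0.002, 0.002, 0.003, 0.02, 0.002]:  # a louder room slowly raises the gate
    noise_floor, threshold = step(noise_floor, frame_energy, is_speech=frame_energy > 0.01)
    print(f"energy={frame_energy:.3f} floor={noise_floor:.4f} gate={threshold:.4f}")

@@ -286,43 +292,66 @@
         spectral_features = self._compute_spectral_features(audio_data)
 
         # Individual feature checks
-        energy_check = energy > max(self.energy_threshold, self.noise_floor_energy * 3)
-        zcr_check = self.zcr_min < zcr < self.zcr_max
+
+        # Use adaptive energy threshold (reduced multiplier)
+        adaptive_energy_threshold = max(
+            self.energy_threshold,
+            self.noise_floor_energy * 2.0  # Reduced from 3.0
+        )
+        energy_check = energy > adaptive_energy_threshold
+
+        # More permissive ZCR check (allow voiced OR unvoiced speech)
+        zcr_check = (
+            self.zcr_min < zcr < self.zcr_max or  # General range
+            zcr < 0.1 or  # Definitely voiced
+            (0.2 < zcr < 0.6 and energy > self.energy_threshold * 2)  # Unvoiced with energy
+        )
+
+        # Spectral check (more permissive)
         spectral_check = (
-            self.spectral_centroid_min < spectral_features['centroid'] < self.spectral_centroid_max
-            and spectral_features['rolloff'] < self.spectral_rolloff_threshold
-            and spectral_features['flux'] > 0.01  # Spectral change indicates speech
+            self.spectral_centroid_min < spectral_features['centroid'] < self.spectral_centroid_max and
+            spectral_features['rolloff'] < self.spectral_rolloff_threshold and
+            spectral_features['flux'] > 0.005  # Reduced threshold
         )
-
-        # Harmonicity check (speech has harmonic structure)
-        harmonic_check = spectral_features['harmonicity'] > 0.3
-
-        # Combined decision with temporal consistency
+
+        # Improved harmonicity check
+        harmonic_check = spectral_features['harmonicity'] > self.harmonic_threshold
+
+        # More permissive combined decision (OR logic for some conditions)
         frame_has_speech = (
-            energy_check and
-            zcr_check and
-            spectral_check and
-            harmonic_check
+            energy_check and (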
+                zcr_check or  # ZCR is good, OR
+                spectral_check or  # Spectral features are good, OR
+                harmonic_check  # Harmonicity is good
+            )
+        ) or (
+            # Alternative path: strong energy + reasonable spectral
+            energy > adaptive_energy_threshold * 1.5 and spectral_check
         )
-
+
         # Update history
         self.speech_history.append(frame_has_speech)
         if len(self.speech_history) > self.max_history:
             self.speech_history.pop(0)
 
-        # Require temporal consistency (at least 3 of last 5 frames)
-        recent_speech = sum(self.speech_history[-5:]) >= 3 if len(self.speech_history) >= 5 else frame_has_speech
+        # More permissive temporal consistency (2 of last 4, or 1 of last 2 if strong)
+        if len(self.speech_history) >= 4:
+            recent_speech = sum(self.speech_history[-4:]) >= 2
+        elif len(self.speech_history) >= 2:
+            # For shorter history, be more permissive if recent frame is strong
+            recent_speech = (
+                sum(self.speech_history[-2:]) >= 1 and
+                (energy > adaptive_energy_threshold * 1.2 or frame_has_speech)
+            )
+        else:
+            recent_speech = frame_has_speech
 
-        # Update noise floor during silence
+        # Faster noise floor adaptation during silence
         if not frame_has_speech:
             self.noise_floor_energy = (
                 self.adaptation_rate * energy +
                 (1 - self.adaptation_rate) * self.noise_floor_energy
             )
-            self.noise_floor_centroid = (
-                self.adaptation_rate * spectral_features['centroid'] +
-                (1 - self.adaptation_rate) * self.noise_floor_centroid
-            )
 
         metrics = {
             'energy': energy,
@@ -332,6 +361,7 @@
             'flux': spectral_features['flux'],
             'harmonicity': spectral_features['harmonicity'],
             'noise_floor_energy': self.noise_floor_energy,
+            'adaptive_threshold': adaptive_energy_threshold,
             'energy_check': energy_check,
             'zcr_check': zcr_check,
             'spectral_check': spectral_check,
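
Temporal smoothing is now a 2-of-the-last-4 vote instead of 3-of-5, so speech onsets pass roughly 100 ms sooner. A toy illustration of the long-history voting branch only; the frame values are made up and no project imports are used.

from collections import deque

MAX_HISTORY = 8  # frames kept, from the diff

def vote(history: deque) -> bool:
    """2 of the last 4 frames must look like speech (long-history case only)."""
    recent = list(history)[-4:]
    return sum(recent) >= 2

history = deque(maxlen=MAX_HISTORY)
for frame in [False, False, True, True, False, True]:
    history.append(frame)
    if len(history) >= 4:
        print(f"frames={list(history)} -> speech={vote(history)}")

@@ -386,44 +416,77 @@
     def _compute_harmonicity(self, magnitude: np.ndarray, freqs: np.ndarray) -> float:
         """Compute harmonicity score (0-1, higher = more harmonic/speech-like)."""
-
+
         # Find fundamental frequency candidate (peak in 80-400Hz range for speech)
-        speech_range = (freqs >= 80) & (freqs <= 400)
+        # Expanded F0 range for better detection
+        speech_range = (freqs >= 60) & (freqs <= 500)  # Expanded from 80-400Hz
         if not np.any(speech_range):
             return 0.0
-
+
         speech_magnitude = magnitude[speech_range]
         speech_freqs = freqs[speech_range]
-
+
         if len(speech_magnitude) == 0:
             return 0.0
-
+
         # Find strongest peak in speech range
-        f0_idx = np.argmax(speech_magnitude)
+        # More robust F0 detection - find peaks instead of just max
+        try:
+            # Import scipy here to handle missing dependency gracefully
+            from scipy.signal import find_peaks
+
+            # Ensure distance is at least 1
+            min_distance = max(1, int(len(speech_magnitude) * 0.05))
+
+            peaks, properties = find_peaks(
+                speech_magnitude,
+                height=np.max(speech_magnitude) * 0.05,  # Lowered from 0.1
+                distance=min_distance,  # Minimum peak separation
+            )
+
+            if len(peaks) == 0:
+                # Fallback to simple max if no peaks found
+                f0_idx = np.argmax(speech_magnitude)
+            else:
+                # Use the strongest peak
+                strongest_peak_idx = np.argmax(speech_magnitude[peaks])
+                f0_idx = peaks[strongest_peak_idx]
+
+        except ImportError:
+            # scipy not available, use simple max
+            f0_idx = np.argmax(speech_magnitude)
+
         f0 = speech_freqs[f0_idx]
         f0_strength = speech_magnitude[f0_idx]
-
-        if f0_strength < np.max(magnitude) * 0.1:  # F0 should be reasonably strong
+
+        # More lenient F0 strength requirement
+        if f0_strength < np.max(magnitude) * 0.03:  # Reduced from 0.1
            return 0.0
-
+
         # Check for harmonics (2*f0, 3*f0, etc.)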
         harmonic_strength = 0.0
         total_harmonics = 0
-
-        for harmonic in range(2, 6):  # Check 2nd through 5th harmonics
+
+        for harmonic in range(2, 5):  # Check 2nd through 4th harmonics
             harmonic_freq = f0 * harmonic
             if harmonic_freq > freqs[-1]:
                 break
-
-            # Find closest frequency bin
+
+            # Find closest frequency bins (check neighboring bins too)
             harmonic_idx = np.argmin(np.abs(freqs - harmonic_freq))
-            harmonic_strength += magnitude[harmonic_idx]
+
+            # Check a small neighborhood around the harmonic frequency
+            start_idx = max(0, harmonic_idx - 2)
+            end_idx = min(len(magnitude), harmonic_idx + 3)
+            local_max = np.max(magnitude[start_idx:end_idx])
+
+            harmonic_strength += local_max
             total_harmonics += 1
-
+
         if total_harmonics == 0:
             return 0.0
-
-        # Normalize by fundamental strength and number of harmonics
+
+        # Normalize and return
         harmonicity = (harmonic_strength / total_harmonics) / f0_strength
         return min(harmonicity, 1.0)
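
The harmonicity score now searches a widened 60-500 Hz band for F0 with scipy's find_peaks and sums local maxima around harmonics 2-4. A standalone re-derivation of that scoring on a synthetic harmonic tone; the 16 kHz sample rate and FFT size are assumptions for illustration, not the bot's actual framing.

import numpy as np
from scipy.signal import find_peaks

SAMPLE_RATE = 16000  # assumed; matches the SAMPLE_RATE constant used by the bot
N_FFT = 1024

# Synthetic frame: 150 Hz fundamental with two harmonics, plus a little noise
t = np.arange(int(0.2 * SAMPLE_RATE)) / SAMPLE_RATE
frame = (np.sin(2 * np.pi * 150 * t)
         + 0.5 * np.sin(2 * np.pi * 300 * t)
         + 0.25 * np.sin(2 * np.pi * 450 * t)
         + 0.01 * np.random.randn(t.size))

magnitude = np.abs(np.fft.rfft(frame, n=N_FFT))
freqs = np.fft.rfftfreq(N_FFT, d=1.0 / SAMPLE_RATE)

# F0 search over the widened 60-500 Hz range, as in the diff
band = (freqs >= 60) & (freqs <= 500)
band_mag, band_freqs = magnitude[band], freqs[band]
peaks, _ = find_peaks(band_mag, height=band_mag.max() * 0.05,
                      distance=max(1, int(len(band_mag) * 0.05)))
f0_idx = peaks[np.argmax(band_mag[peaks])] if len(peaks) else int(np.argmax(band_mag))
f0, f0_strength = band_freqs[f0_idx], band_mag[f0_idx]

# Harmonics 2-4, taking the local maximum around each expected bin
harmonic_strength, count = 0.0, 0
for h in range(2, 5):
    idx = int(np.argmin(np.abs(freqs - h * f0)))
    harmonic_strength += magnitude[max(0, idx - 2):idx + 3].max()
    count += 1
harmonicity = min((harmonic_strength / count) / f0_strength, 1.0)
print(f"f0 ~ {f0:.0f} Hz, harmonicity = {harmonicity:.2f}")  # above the 0.15 gate

@@ -899,7 +962,8 @@
 # Global model instance with deferred loading
 _whisper_model: Optional[OpenVINOWhisperModel] = None
 _audio_processors: Dict[str, "OptimizedAudioProcessor"] = {}
-_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None
+_send_chat_func: Optional[Callable[[ChatMessageModel], Awaitable[None]]] = None
+_create_chat_message_func: Optional[Callable[[str, Optional[str]], ChatMessageModel]] = None
 
 # Model loading status for video display
 _model_loading_status: str = "Not loaded"
@@ -1018,10 +1082,11 @@ class OptimizedAudioProcessor:
     """Optimized audio processor for Intel Arc B580 with reduced latency."""
 
     def __init__(
-        self, peer_name: str, send_chat_func: Callable[[str], Awaitable[None]]
+        self, peer_name: str, send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel]
     ):
         self.peer_name = peer_name
         self.send_chat_func = send_chat_func
+        self.create_chat_message_func = create_chat_message_func
         self.sample_rate = SAMPLE_RATE
 
         # Initialize visualization buffer if not already done
@@ -1049,6 +1114,10 @@
         self.current_phrase_audio = np.array([], dtype=np.float32)
         self.transcription_history: List[TranscriptionHistoryItem] = []
         self.last_activity_time = time.time()
+        self.last_audio_time = time.time()  # Track when any audio chunk is received
+
+        # Current transcription message for refinements
+        self.current_message: Optional[ChatMessageModel] = None
 
         # Async processing
         self.processing_queue: asyncio.Queue[AudioQueueItem] = asyncio.Queue(maxsize=10)
@@ -1076,6 +1145,9 @@
             logger.error("Processor not running or empty audio data")
             return
 
+        # Update last audio time whenever any audio is received
+        self.last_audio_time = time.time()
+
         is_speech, vad_metrics = self.advanced_vad.analyze_frame(audio_data)
 
         # Update visualization status
@@ -1227,7 +1299,7 @@
                 # Check for final transcription on timeout
                 if (
                     len(self.current_phrase_audio) > 0
-                    and time.time() - self.last_activity_time > 2.0
+                    and time.time() - self.last_audio_time > 2.0
                 ):
                     logger.info(
                         f"Final transcription timeout for {self.peer_name} (asyncio.TimeoutError)"
@@ -1276,7 +1348,7 @@
                 # Check for final transcription
                 if (
                     len(self.current_phrase_audio) > 0
-                    and time.time() - self.last_activity_time > 2.0
+                    and time.time() - self.last_audio_time > 2.0
                 ):
                     if self.main_loop:
                         logger.info(
@@ -1409,29 +1481,43 @@
         #     return  # Skip very low confidence
 
         # # Include confidence in message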
-        # confidence_indicator = (
-        #     "✓" if avg_confidence > 0.8 else "?" if avg_confidence < 0.6 else ""
-        # )
-        message = f"{self.peer_name}: {transcription}"  # {confidence_indicator}[{avg_confidence:.1%}]"
-
-        await self.send_chat_func(message)
-
+        # Create ChatMessageModel for transcription
         if transcription:
             # Create message with timing
             status_marker = "⚡" if is_final else "🎤"
             type_marker = "" if is_final else " [streaming]"
             timing_info = f" (🚀 {transcription_time:.2f}s)"
 
-            message = f"{status_marker} {self.peer_name}{type_marker}: {transcription}{timing_info}"
+            message_text = f"{status_marker} {self.peer_name}{type_marker}: {transcription}{timing_info}"
 
             # Avoid duplicates
             if not self._is_duplicate(transcription):
-                await self.send_chat_func(message)
+                # For streaming transcriptions, reuse the current message ID if it exists
+                # For final transcriptions, always create a new message
+                if is_final:
+                    # Final transcription - reset current message
+                    self.current_message = None
+                    message_id = None
+                else:
+                    # Streaming transcription - reuse current message ID or create new
+                    if self.current_message is None:
+                        message_id = None  # Will create new ID
+                    else:
+                        message_id = self.current_message.id
+
+                # Create ChatMessageModel
+                chat_message = self.create_chat_message_func(message_text, message_id)
+
+                # Update current message for streaming
+                if not is_final:
+                    self.current_message = chat_message
+
+                await self.send_chat_func(chat_message)
 
                 # Update history
                 self.transcription_history.append(
                     TranscriptionHistoryItem(
-                        message=message, timestamp=time.time(), is_final=is_final
+                        message=message_text, timestamp=time.time(), is_final=is_final
                     )
                 )
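
The transcription sender now keeps one ChatMessageModel id alive for all streaming updates of a phrase and only mints a new id for the next message, which is what lets the server/client upsert above replace the partial text in place. A condensed sketch of that id lifecycle; the Msg dataclass and counter-based ids are stand-ins for ChatMessageModel and secrets.token_hex, and the create/send shapes mirror the bound _create_chat_message_func/_send_chat_func.

import itertools
from dataclasses import dataclass
from typing import Optional

_ids = (f"msg-{n}" for n in itertools.count(1))  # stand-in for secrets.token_hex(8)

@dataclass
class Msg:
    id: str
    message: str

def create_chat_message(text: str, message_id: Optional[str] = None) -> Msg:
    return Msg(id=message_id or next(_ids), message=text)

current: Optional[Msg] = None

def emit(text: str, is_final: bool) -> Msg:
    """Streaming updates reuse the current id; a final result starts fresh."""
    global current
    message_id = None if is_final or current is None else current.id
    msg = create_chat_message(text, message_id)
    current = None if is_final else msg
    return msg

print(emit("🎤 alice [streaming]: hel", is_final=False))        # msg-1
print(emit("🎤 alice [streaming]: hello wor", is_final=False))  # still msg-1 -> edited in place
print(emit("⚡ alice: hello world", is_final=True))             # msg-2, a brand-new message

@@ -1676,7 +1762,7 @@ class WaveformVideoTrack(MediaStreamTrack):
                 # Draw colored waveform
                 if len(points) > 1:
                     for i in range(len(points) - 1):
-                        cv2.line(frame_array, points[i], points[i+1], colors[i], 2)
+                        cv2.line(frame_array, points[i], points[i+1], colors[i], 1)
 
                 # Add speech detection status overlay
                 if speech_info:
@@ -1701,29 +1787,41 @@
         """Draw speech detection status information."""
         y_offset = 100
-        line_height = 25
 
         # Main status
         is_speech = speech_info.get('is_speech', False)
-        status_text = "🎤 SPEECH" if is_speech else "🔇 NOISE"
+        status_text = "SPEECH" if is_speech else "NOISE"
         status_color = (0, 255, 0) if is_speech else (128, 128, 128)
 
-        cv2.putText(frame_array, f"{pname}: {status_text}",
-                   (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.8, status_color, 2)
+        adaptive_thresh = speech_info.get('adaptive_threshold', 0)
+        cv2.putText(frame_array, f"{pname}: {status_text} (thresh: {adaptive_thresh:.4f})",
+                   (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)
 
         # Detailed metrics (smaller text)
         metrics = [
-            f"Energy: {speech_info.get('energy', 0):.3f} ({'✓' if speech_info.get('energy_check', False) else '✗'})",
-            f"ZCR: {speech_info.get('zcr', 0):.3f} ({'✓' if speech_info.get('zcr_check', False) else '✗'})",
-            f"Spectral: ({'✓' if speech_info.get('spectral_check', False) else '✗'})",
-            f"Harmonic: ({'✓' if speech_info.get('harmonic_check', False) else '✗'})",
-            f"Temporal: ({'✓' if speech_info.get('temporal_consistency', False) else '✗'})"
+            f"Energy: {speech_info.get('energy', 0):.3f} ({'Y' if speech_info.get('energy_check', False) else 'N'})",
+            f"ZCR: {speech_info.get('zcr', 0):.3f} ({'Y' if speech_info.get('zcr_check', False) else 'N'})",
+            f"Spectral: {'Y' if 300 < speech_info.get('centroid', 0) < 3400 else 'N'}/{'Y' if speech_info.get('rolloff', 0) < 2000 else 'N'}/{'Y' if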
 speech_info.get('flux', 0) > 0.01 else 'N'} ({'Y' if speech_info.get('spectral_check', False) else 'N'})",
+            f"Harmonic: {speech_info.get('harmonicity', 0):.3f} ({'Y' if speech_info.get('harmonic_check', False) else 'N'})",
+            f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})"
         ]
-
+
         for i, metric in enumerate(metrics):
             cv2.putText(frame_array, metric,
-                       (320, y_offset + i * 15), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
+                       (320, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
                        (255, 255, 255), 1)
+            y_offset += 15
+
+        logic_result = "E:" + ("Y" if speech_info.get('energy_check', False) else "N")
+        logic_result += " Z:" + ("Y" if speech_info.get('zcr_check', False) else "N")
+        logic_result += " S:" + ("Y" if speech_info.get('spectral_check', False) else "N")
+        logic_result += " H:" + ("Y" if speech_info.get('harmonic_check', False) else "N")
+        logic_result += " T:" + ("Y" if speech_info.get('temporal_consistency', False) else "N")
+
+        cv2.putText(frame_array, logic_result,
+                   (320, y_offset + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
+                   (255, 255, 255), 1)
 
         # Noise floor indicator
         noise_floor = speech_info.get('noise_floor_energy', 0)
@@ -1760,12 +1858,12 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
     _model_loading_progress = 0.8
 
     logger.info(f"Creating OptimizedAudioProcessor for {peer.peer_name}")
-    if _send_chat_func is None:
-        logger.error(f"No send_chat_func available for {peer.peer_name}")
+    if _send_chat_func is None or _create_chat_message_func is None:
+        logger.error(f"No send function available for {peer.peer_name}")
         _model_loading_status = "Error: No send function available"
         return
     _audio_processors[peer.peer_name] = OptimizedAudioProcessor(
-        peer_name=peer.peer_name, send_chat_func=_send_chat_func
+        peer_name=peer.peer_name, send_chat_func=_send_chat_func, create_chat_message_func=_create_chat_message_func
     )
 
     _model_loading_status = "Ready for transcription"
@@ -1773,9 +1871,9 @@
     logger.info(f"OptimizedAudioProcessor ready for {peer.peer_name}")
 
     if peer.peer_name not in _audio_processors:
-        if _send_chat_func is None:
+        if _send_chat_func is None or _create_chat_message_func is None:
             logger.error(
-                f"Cannot create processor for {peer.peer_name}: no send_chat_func"
+                f"Cannot create processor for {peer.peer_name}: no send_chat_func or create_chat_message_func"
             )
             return
@@ -1993,7 +2091,7 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
 
 async def handle_chat_message(
-    chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]
+    chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
 ) -> Optional[str]:
     """Handle incoming chat messages."""
     return None
@@ -2009,16 +2107,18 @@ def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]:
     return on_track_received
 
 
-def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]) -> None:
+def bind_send_chat_function(send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel]) -> None:
     """Bind the send chat function."""
-    global _send_chat_func, _audio_processors
+    global _send_chat_func, _create_chat_message_func, _audio_processors
 
     logger.info("Binding send chat function to OpenVINO whisper agent")
     _send_chat_func = send_chat_func
+    _create_chat_message_func = create_chat_message_func
 
     # Update existing processors
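     for peer_name, processor in _audio_processors.items():
         processor.send_chat_func = send_chat_func
+        processor.create_chat_message_func = create_chat_message_func
         logger.debug(f"Updated processor for {peer_name} with new send chat function")

bind_send_chat_function now hands the agent two callables: the async sender and the id-aware message factory, and both are also patched onto any processors that already exist. A rough sketch of that handshake from the bot module's side; the names mirror the diff, while the fake_send/fake_create callables are stand-ins for WebRTCSignalingClient.send_chat_message and create_chat_message.

import asyncio
from typing import Awaitable, Callable, Optional

# Module-level slots, as in the whisper agent
_send_chat_func: Optional[Callable[..., Awaitable[None]]] = None
_create_chat_message_func: Optional[Callable[..., object]] = None

def bind_send_chat_function(send_chat_func, create_chat_message_func) -> None:
    """Store both callables so audio processors can emit and refine messages."""
    global _send_chat_func, _create_chat_message_func
    _send_chat_func = send_chat_func
    _create_chat_message_func = create_chat_message_func

# Stand-ins for WebRTCSignalingClient.send_chat_message / create_chat_message
async def fake_send(message) -> None:
    print("sent:", message)

def fake_create(text: str, message_id: Optional[str] = None):
    return {"id": message_id or "generated", "message": text}

bind_send_chat_function(fake_send, fake_create)
msg = _create_chat_message_func("partial transcript")
asyncio.run(_send_chat_func(msg))
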
diff --git a/voicebot/webrtc_signaling.py b/voicebot/webrtc_signaling.py
index b58d965..6559cc1 100644
--- a/voicebot/webrtc_signaling.py
+++ b/voicebot/webrtc_signaling.py
@@ -12,6 +12,7 @@ import json
 import websockets
 import time
 import re
+import secrets
 from typing import (
     Dict,
     Optional,
@@ -20,6 +21,7 @@ from typing import (
     Protocol,
     AsyncIterator,
     cast,
+    Union,
 )
 
 # Add the parent directory to sys.path to allow absolute imports
@@ -90,7 +92,7 @@ class WebRTCSignalingClient:
         session_name: str,
         insecure: bool = False,
         create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
-        bind_send_chat_function: Optional[Callable[[Callable[[str], Awaitable[None]]], None]] = None,
+        bind_send_chat_function: Optional[Callable[[Callable[[Union[str, ChatMessageModel]], Awaitable[None]], Callable[[str, Optional[str]], ChatMessageModel]], None]] = None,
         registration_check_interval: float = 30.0,
     ):
         self.server_url = server_url
@@ -107,7 +109,7 @@ class WebRTCSignalingClient:
 
         if self.bind_send_chat_function:
             # Bind the send_chat_message method to the bot's send function
-            self.bind_send_chat_function(self.send_chat_message)
+            self.bind_send_chat_function(self.send_chat_message, self.create_chat_message)
 
         # WebSocket client protocol instance (typed as object to avoid Any)
         self.websocket: Optional[object] = None
@@ -424,22 +426,47 @@ class WebRTCSignalingClient:
         self.shutdown_requested = True
         logger.info("Shutdown requested for WebRTC signaling client")
 
-    async def send_chat_message(self, message: str):
+    async def send_chat_message(self, message: Union[str, ChatMessageModel]):
         """Send a chat message to the lobby"""
         if not self.is_registered:
             logger.warning("Cannot send chat message: not registered")
             return
 
-        if not message.strip():
-            logger.warning("Cannot send empty chat message")
-            return
+        if isinstance(message, str):
+            if not message.strip():
+                logger.warning("Cannot send empty chat message")
+                return
+            # Create ChatMessageModel from string
+            chat_message = ChatMessageModel(
+                id=secrets.token_hex(8),
+                message=message.strip(),
+                sender_name=self.session_name,
+                sender_session_id=self.session_id,
+                timestamp=time.time(),
+                lobby_id=self.lobby_id,
+            )
+            message_data = chat_message.model_dump()
+        else:
+            # ChatMessageModel
+            message_data = message.model_dump()
 
         try:
-            await self._send_message("send_chat_message", {"message": message.strip()})
-            logger.info(f"Sent chat message: {message[:50]}...")
+            await self._send_message("send_chat_message", {"message": message_data})
+            logger.info(f"Sent chat message: {str(message)[:50]}...")
         except Exception as e:
             logger.error(f"Failed to send chat message: {e}", exc_info=True)
 
+    def create_chat_message(self, message: str, message_id: Optional[str] = None) -> ChatMessageModel:
+        """Create a ChatMessageModel with the correct session and lobby IDs"""
+        return ChatMessageModel(
+            id=message_id or secrets.token_hex(8),
+            message=message,
+            sender_name=self.session_name,
+            sender_session_id=self.session_id,
+            timestamp=time.time(),
+            lobby_id=self.lobby_id,
+        )
+
     async def _setup_local_media(self):
         """Create local media tracks"""
         # If a bot provided a create_tracks callable, use it to create tracks.
@@ -478,6 +505,7 @@ class WebRTCSignalingClient:
 
         # Build message with explicit type to avoid type narrowing
         # Always include data field to match WebSocketMessageModel
+        # Should be a ChatMessageModel, not custom dict
         message: dict[str, object] = {
             "type": message_type,
             "data": data if data is not None else {}