Improved transcription to current message

James Ketr 2025-09-14 15:48:17 -07:00
parent 50f290ac22
commit 110430d22a
9 changed files with 310 additions and 104 deletions

View File

@@ -40,7 +40,18 @@ const LobbyChat: React.FC<LobbyChatProps> = ({ socketUrl, session, lobbyId }) =>
switch (message.type) {
case "chat_message":
const chatMessage = data as ChatMessage;
setMessages((prev) => [...prev, chatMessage]);
setMessages((prev) => {
const existingIndex = prev.findIndex((msg) => msg.id === chatMessage.id);
if (existingIndex >= 0) {
// Update existing message
const updatedMessages = [...prev];
updatedMessages[existingIndex] = chatMessage;
return updatedMessages;
} else {
// Add new message
return [...prev, chatMessage];
}
});
break;
case "chat_messages":
const chatMessages = data.messages as ChatMessage[];

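The reducer above replaces a blind append with an upsert keyed on message id, which is what lets a streaming transcription refine its own chat bubble in place. The same semantics as a standalone Python sketch (hypothetical message dicts, not the repo's code):

def upsert_message(messages: list, incoming: dict) -> list:
    # Same id: replace in place so the UI updates the existing bubble.
    for i, msg in enumerate(messages):
        if msg["id"] == incoming["id"]:
            return messages[:i] + [incoming] + messages[i + 1:]
    # Unknown id: append as a new message.
    return messages + [incoming]

history = upsert_message([], {"id": "a1", "message": "hel"})
history = upsert_message(history, {"id": "a1", "message": "hello world"})
assert len(history) == 1 and history[0]["message"] == "hello world"
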
View File

@@ -44,7 +44,7 @@
height: 3.75rem;
min-width: 5rem;
min-height: 3.75rem;
z-index: 50000;
z-index: 1200;
border-radius: 0.25rem;
}

View File

@@ -1573,6 +1573,7 @@ const MediaControl: React.FC<MediaControlProps> = ({ isSelf, peer, className })
hideDefaultLines={false}
snappable={true}
snapThreshold={5}
origin={false}
edge
onDragStart={(e) => {
setIsDragging(true);

View File

@@ -191,6 +191,33 @@ class Lobby:
]
return chat_message
def update_chat_message(self, chat_message: ChatMessageModel) -> ChatMessageModel:
"""Update an existing chat message in the lobby and return the updated message"""
with self.lock:
# Find the existing message by ID
for i, existing_msg in enumerate(self.chat_messages):
if existing_msg.id == chat_message.id:
# Update the message content and timestamp
updated_msg = ChatMessageModel(
id=chat_message.id,
message=chat_message.message,
sender_name=existing_msg.sender_name, # Keep original sender
sender_session_id=existing_msg.sender_session_id, # Keep original session
timestamp=time.time(), # Update timestamp
lobby_id=existing_msg.lobby_id, # Keep original lobby
)
self.chat_messages[i] = updated_msg
return updated_msg
# If message not found, add it as new
self.chat_messages.append(chat_message)
# Keep only the latest messages per lobby
if len(self.chat_messages) > LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY:
self.chat_messages = self.chat_messages[
-LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY :
]
return chat_message
def get_chat_messages(self, limit: int = 50) -> List[ChatMessageModel]:
"""Get the most recent chat messages from the lobby"""
with self.lock:

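Lobby.update_chat_message treats a known id as a refinement (sender fields and lobby are preserved, the timestamp is re-stamped) and an unknown id as a plain append. A hedged usage sketch, assuming a live Lobby plus the repo's ChatMessageModel:

# Hypothetical usage; lobby, session, and ChatMessageModel come from the repo.
original = lobby.add_chat_message(session, "hel")   # first streaming chunk
refined = ChatMessageModel(
    id=original.id,                 # same id -> takes the update path
    message="hello world",          # refined transcription text
    sender_name="ignored",          # server keeps the original sender
    sender_session_id="ignored",    # and the original session id
    timestamp=0.0,                  # server re-stamps with time.time()
    lobby_id=original.lobby_id,
)
stored = lobby.update_chat_message(refined)
assert stored.id == original.id
assert stored.sender_name == original.sender_name   # preserved
assert stored.message == "hello world"              # refined
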
View File

@@ -10,6 +10,7 @@ from typing import Dict, Any, TYPE_CHECKING
from fastapi import WebSocket
from shared.logger import logger
from shared.models import ChatMessageModel
from .webrtc_signaling import WebRTCSignalingHandlers
from core.error_handling import (
error_handler,
@@ -252,13 +253,51 @@ class SendChatMessageHandler(MessageHandler):
})
return
message_text = str(data["message"]).strip()
if not message_text:
return
message_data = data["message"]
if isinstance(message_data, dict):
# Handle ChatMessageModel object
try:
chat_message = ChatMessageModel.model_validate(message_data)
# Validate that the sender matches the session
if chat_message.sender_session_id != session.id:
logger.error(
f"{session.getName()} - ChatMessageModel sender_session_id mismatch"
)
await websocket.send_json(
{
"type": "error",
"data": {
"error": "ChatMessageModel sender_session_id does not match session"
},
}
)
return
# Update or add the message
chat_message = lobby.update_chat_message(chat_message)
logger.info(
f"{session.getName()} -> update_chat_message({lobby.getName()}, {chat_message.message[:50]}...)"
)
except Exception as e:
logger.error(f"{session.getName()} - Invalid ChatMessageModel: {e}")
await websocket.send_json(
{
"type": "error",
"data": {"error": "Invalid ChatMessageModel format"},
}
)
return
else:
# Handle string message (legacy support)
message_text = str(message_data).strip()
if not message_text:
return
# Add the message to the lobby and broadcast it
chat_message = lobby.add_chat_message(session, message_text)
logger.info(
f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)"
)
# Add the message to the lobby and broadcast it
chat_message = lobby.add_chat_message(session, message_text)
logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)")
await lobby.broadcast_chat_message(chat_message)

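On the wire, the handler now accepts two shapes under data["message"]: the legacy plain string, and a serialized ChatMessageModel whose sender_session_id must match the sending session. Roughly (all field values illustrative):

# Legacy shape: the server mints id, sender, and timestamp itself.
legacy = {"type": "send_chat_message", "data": {"message": "hello"}}

# New shape: the client supplies the full model, enabling id-based updates.
update = {
    "type": "send_chat_message",
    "data": {
        "message": {
            "id": "9f2c4e1a0b3d5f67",             # reuse an id to refine a message
            "message": "hello world",
            "sender_name": "whisper-bot",
            "sender_session_id": "<session id>",  # must match the sender's session
            "timestamp": 1726354097.0,
            "lobby_id": "<lobby id>",
        }
    },
}
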
View File

@@ -10,7 +10,7 @@ This bot demonstrates the advanced capabilities including:
import os
import time
import uuid
from typing import Dict, Optional, Callable, Awaitable, Any
from typing import Dict, Optional, Callable, Awaitable, Any, Union
from aiortc import MediaStreamTrack
# Import system modules
@@ -257,7 +257,7 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
async def handle_chat_message(
chat_message: ChatMessageModel,
send_message_func: Callable[[str], Awaitable[None]]
send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
) -> Optional[str]:
"""Handle incoming chat messages and provide AI-powered responses."""
global _bot_instance

View File

@@ -4,7 +4,7 @@ This bot shows how to create an agent that primarily uses chat functionality
rather than media streams.
"""
from typing import Dict, Optional, Callable, Awaitable
from typing import Dict, Optional, Callable, Awaitable, Union
import time
import random
from shared.logger import logger
@@ -45,7 +45,7 @@ def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
return {}
async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]) -> Optional[str]:
async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]) -> Optional[str]:
"""Handle incoming chat messages and provide responses.
Args:

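With the widened callback type, a bot may still pass a plain string (the signaling client mints a fresh ChatMessageModel around it), or pass a full ChatMessageModel to update one of its own earlier messages. A minimal sketch of a conforming handler:

from typing import Awaitable, Callable, Optional, Union

from shared.models import ChatMessageModel  # repo import, assumed available

async def handle_chat_message(
    chat_message: ChatMessageModel,
    send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]],
) -> Optional[str]:
    # A plain string still works: the client wraps it in a new model with a new id.
    await send_message_func(f"echo: {chat_message.message}")
    return None
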
View File

@@ -253,23 +253,29 @@ class AdvancedVAD:
def __init__(self, sample_rate: int = SAMPLE_RATE):
self.sample_rate = sample_rate
# More conservative thresholds
self.energy_threshold = 0.02 # Increased from 0.01
self.zcr_min = 0.1
self.zcr_max = 0.5 # Reject very high ZCR (digital noise)
self.spectral_centroid_min = 300 # Human speech starts around 300Hz
self.spectral_centroid_max = 3400 # Human speech typically under 3400Hz
self.spectral_rolloff_threshold = 2000 # Speech energy concentrated lower
self.minimum_duration = 0.3 # Require 300ms of consistent speech
# More permissive thresholds based on research
self.energy_threshold = 0.005 # Reduced from 0.02
self.zcr_min = 0.02 # Reduced from 0.1 (voiced speech < 0.1)
self.zcr_max = 0.8 # Increased from 0.5 (unvoiced speech ~0.3-0.8)
# Temporal consistency tracking
# Spectral thresholds (keep existing - these work well)
self.spectral_centroid_min = 200 # Slightly lower
self.spectral_centroid_max = 4000 # Slightly higher
self.spectral_rolloff_threshold = 3000 # More permissive
# Relaxed temporal consistency
self.minimum_duration = 0.2 # Reduced from 0.3s
self.speech_history = []
self.max_history = 10 # Track last 10 frames (1 second at 100ms chunks)
self.max_history = 8 # Reduced from 10
# Adaptive noise floor
self.noise_floor_energy = 0.001
self.noise_floor_centroid = 1000
self.adaptation_rate = 0.02
self.adaptation_rate = 0.05
# Harmonicity improvements
self.prev_magnitude = None
self.harmonic_threshold = 0.15 # Reduced from 0.3
def analyze_frame(self, audio_data: AudioArray) -> Tuple[bool, dict]:
"""Analyze audio frame for speech vs noise."""
@@ -286,43 +292,66 @@ class AdvancedVAD:
spectral_features = self._compute_spectral_features(audio_data)
# Individual feature checks
energy_check = energy > max(self.energy_threshold, self.noise_floor_energy * 3)
zcr_check = self.zcr_min < zcr < self.zcr_max
# Use adaptive energy threshold (reduced multiplier)
adaptive_energy_threshold = max(
self.energy_threshold,
self.noise_floor_energy * 2.0 # Reduced from 3.0
)
energy_check = energy > adaptive_energy_threshold
# More permissive ZCR check (allow voiced OR unvoiced speech)
zcr_check = (
self.zcr_min < zcr < self.zcr_max or # General range
zcr < 0.1 or # Definitely voiced
(0.2 < zcr < 0.6 and energy > self.energy_threshold * 2) # Unvoiced with energy
)
# Spectral check (more permissive)
spectral_check = (
self.spectral_centroid_min < spectral_features['centroid'] < self.spectral_centroid_max
and spectral_features['rolloff'] < self.spectral_rolloff_threshold
and spectral_features['flux'] > 0.01 # Spectral change indicates speech
self.spectral_centroid_min < spectral_features['centroid'] < self.spectral_centroid_max and
spectral_features['rolloff'] < self.spectral_rolloff_threshold and
spectral_features['flux'] > 0.005 # Reduced threshold
)
# Harmonicity check (speech has harmonic structure)
harmonic_check = spectral_features['harmonicity'] > 0.3
# Combined decision with temporal consistency
# Improved harmonicity check
harmonic_check = spectral_features['harmonicity'] > self.harmonic_threshold
# More permissive combined decision (OR logic for some conditions)
frame_has_speech = (
energy_check and
zcr_check and
spectral_check and
harmonic_check
energy_check and (
zcr_check or # ZCR is good, OR
spectral_check or # Spectral features are good, OR
harmonic_check # Harmonicity is good
)
) or (
# Alternative path: strong energy + reasonable spectral
energy > adaptive_energy_threshold * 1.5 and spectral_check
)
# Update history
self.speech_history.append(frame_has_speech)
if len(self.speech_history) > self.max_history:
self.speech_history.pop(0)
# Require temporal consistency (at least 3 of last 5 frames)
recent_speech = sum(self.speech_history[-5:]) >= 3 if len(self.speech_history) >= 5 else frame_has_speech
# More permissive temporal consistency (2 of last 4, or 1 of last 2 if strong)
if len(self.speech_history) >= 4:
recent_speech = sum(self.speech_history[-4:]) >= 2
elif len(self.speech_history) >= 2:
# For shorter history, be more permissive if recent frame is strong
recent_speech = (
sum(self.speech_history[-2:]) >= 1 and
(energy > adaptive_energy_threshold * 1.2 or frame_has_speech)
)
else:
recent_speech = frame_has_speech
# Update noise floor during silence
# Faster noise floor adaptation during silence
if not frame_has_speech:
self.noise_floor_energy = (
self.adaptation_rate * energy +
(1 - self.adaptation_rate) * self.noise_floor_energy
)
self.noise_floor_centroid = (
self.adaptation_rate * spectral_features['centroid'] +
(1 - self.adaptation_rate) * self.noise_floor_centroid
)
metrics = {
'energy': energy,
@@ -332,6 +361,7 @@ class AdvancedVAD:
'flux': spectral_features['flux'],
'harmonicity': spectral_features['harmonicity'],
'noise_floor_energy': self.noise_floor_energy,
'adaptive_threshold': adaptive_energy_threshold,
'energy_check': energy_check,
'zcr_check': zcr_check,
'spectral_check': spectral_check,
@@ -386,44 +416,77 @@ class AdvancedVAD:
def _compute_harmonicity(self, magnitude: np.ndarray, freqs: np.ndarray) -> float:
"""Compute harmonicity score (0-1, higher = more harmonic/speech-like)."""
# Find fundamental frequency candidate (peak in 80-400Hz range for speech)
speech_range = (freqs >= 80) & (freqs <= 400)
# Expanded F0 range for better detection
speech_range = (freqs >= 60) & (freqs <= 500) # Expanded from 80-400Hz
if not np.any(speech_range):
return 0.0
speech_magnitude = magnitude[speech_range]
speech_freqs = freqs[speech_range]
if len(speech_magnitude) == 0:
return 0.0
# Find strongest peak in speech range
f0_idx = np.argmax(speech_magnitude)
# More robust F0 detection - find peaks instead of just max
try:
# Import scipy here to handle missing dependency gracefully
from scipy.signal import find_peaks
# Ensure distance is at least 1
min_distance = max(1, int(len(speech_magnitude) * 0.05))
peaks, properties = find_peaks(
speech_magnitude,
height=np.max(speech_magnitude) * 0.05, # Lowered from 0.1
distance=min_distance, # Minimum peak separation
)
if len(peaks) == 0:
# Fallback to simple max if no peaks found
f0_idx = np.argmax(speech_magnitude)
else:
# Use the strongest peak
strongest_peak_idx = np.argmax(speech_magnitude[peaks])
f0_idx = peaks[strongest_peak_idx]
except ImportError:
# scipy not available, use simple max
f0_idx = np.argmax(speech_magnitude)
f0 = speech_freqs[f0_idx]
f0_strength = speech_magnitude[f0_idx]
if f0_strength < np.max(magnitude) * 0.1: # F0 should be reasonably strong
# More lenient F0 strength requirement
if f0_strength < np.max(magnitude) * 0.03: # Reduced from 0.1
return 0.0
# Check for harmonics (2*f0, 3*f0, etc.)
harmonic_strength = 0.0
total_harmonics = 0
for harmonic in range(2, 6): # Check 2nd through 5th harmonics
for harmonic in range(2, 5): # Check 2nd through 4th harmonics
harmonic_freq = f0 * harmonic
if harmonic_freq > freqs[-1]:
break
# Find closest frequency bin
# Find closest frequency bins (check neighboring bins too)
harmonic_idx = np.argmin(np.abs(freqs - harmonic_freq))
harmonic_strength += magnitude[harmonic_idx]
# Check a small neighborhood around the harmonic frequency
start_idx = max(0, harmonic_idx - 2)
end_idx = min(len(magnitude), harmonic_idx + 3)
local_max = np.max(magnitude[start_idx:end_idx])
harmonic_strength += local_max
total_harmonics += 1
if total_harmonics == 0:
return 0.0
# Normalize by fundamental strength and number of harmonics
# Normalize and return
harmonicity = (harmonic_strength / total_harmonics) / f0_strength
return min(harmonicity, 1.0)
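
The try/except around scipy makes peak picking an optional upgrade rather than a hard dependency; the same pattern in isolation (a sketch, not the class method):

import numpy as np

def strongest_peak_index(magnitude: np.ndarray) -> int:
    """Prefer scipy's peak picking; fall back to a plain argmax."""
    try:
        from scipy.signal import find_peaks  # optional dependency
        peaks, _ = find_peaks(
            magnitude,
            height=float(np.max(magnitude)) * 0.05,
            distance=max(1, int(len(magnitude) * 0.05)),
        )
        if len(peaks) > 0:
            return int(peaks[np.argmax(magnitude[peaks])])
    except ImportError:
        pass  # scipy not installed; argmax is an acceptable fallback
    return int(np.argmax(magnitude))

spectrum = np.abs(np.fft.rfft(np.sin(np.linspace(0.0, 40.0 * np.pi, 800))))
print(strongest_peak_index(spectrum))  # peak near bin 20 (20 cycles in the window)
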
@@ -899,7 +962,8 @@ class OpenVINOWhisperModel:
# Global model instance with deferred loading
_whisper_model: Optional[OpenVINOWhisperModel] = None
_audio_processors: Dict[str, "OptimizedAudioProcessor"] = {}
_send_chat_func: Optional[Callable[[str], Awaitable[None]]] = None
_send_chat_func: Optional[Callable[[ChatMessageModel], Awaitable[None]]] = None
_create_chat_message_func: Optional[Callable[[str, Optional[str]], ChatMessageModel]] = None
# Model loading status for video display
_model_loading_status: str = "Not loaded"
@@ -1018,10 +1082,11 @@ class OptimizedAudioProcessor:
"""Optimized audio processor for Intel Arc B580 with reduced latency."""
def __init__(
self, peer_name: str, send_chat_func: Callable[[str], Awaitable[None]]
self, peer_name: str, send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel]
):
self.peer_name = peer_name
self.send_chat_func = send_chat_func
self.create_chat_message_func = create_chat_message_func
self.sample_rate = SAMPLE_RATE
# Initialize visualization buffer if not already done
@@ -1049,6 +1114,10 @@ class OptimizedAudioProcessor:
self.current_phrase_audio = np.array([], dtype=np.float32)
self.transcription_history: List[TranscriptionHistoryItem] = []
self.last_activity_time = time.time()
self.last_audio_time = time.time() # Track when any audio chunk is received
# Current transcription message for refinements
self.current_message: Optional[ChatMessageModel] = None
# Async processing
self.processing_queue: asyncio.Queue[AudioQueueItem] = asyncio.Queue(maxsize=10)
@@ -1076,6 +1145,9 @@ class OptimizedAudioProcessor:
logger.error("Processor not running or empty audio data")
return
# Update last audio time whenever any audio is received
self.last_audio_time = time.time()
is_speech, vad_metrics = self.advanced_vad.analyze_frame(audio_data)
# Update visualization status
@@ -1227,7 +1299,7 @@ class OptimizedAudioProcessor:
# Check for final transcription on timeout
if (
len(self.current_phrase_audio) > 0
and time.time() - self.last_activity_time > 2.0
and time.time() - self.last_audio_time > 2.0
):
logger.info(
f"Final transcription timeout for {self.peer_name} (asyncio.TimeoutError)"
@@ -1276,7 +1348,7 @@ class OptimizedAudioProcessor:
# Check for final transcription
if (
len(self.current_phrase_audio) > 0
and time.time() - self.last_activity_time > 2.0
and time.time() - self.last_audio_time > 2.0
):
if self.main_loop:
logger.info(
@@ -1409,29 +1481,43 @@ class OptimizedAudioProcessor:
# return # Skip very low confidence
# # Include confidence in message
# confidence_indicator = (
# "✓" if avg_confidence > 0.8 else "?" if avg_confidence < 0.6 else ""
# )
message = f"{self.peer_name}: {transcription}" # {confidence_indicator}[{avg_confidence:.1%}]"
await self.send_chat_func(message)
# Create ChatMessageModel for transcription
if transcription:
# Create message with timing
status_marker = "" if is_final else "🎤"
type_marker = "" if is_final else " [streaming]"
timing_info = f" (🚀 {transcription_time:.2f}s)"
message = f"{status_marker} {self.peer_name}{type_marker}: {transcription}{timing_info}"
message_text = f"{status_marker} {self.peer_name}{type_marker}: {transcription}{timing_info}"
# Avoid duplicates
if not self._is_duplicate(transcription):
await self.send_chat_func(message)
# For streaming transcriptions, reuse the current message ID if it exists
# For final transcriptions, always create a new message
if is_final:
# Final transcription - reset current message
self.current_message = None
message_id = None
else:
# Streaming transcription - reuse current message ID or create new
if self.current_message is None:
message_id = None # Will create new ID
else:
message_id = self.current_message.id
# Create ChatMessageModel
chat_message = self.create_chat_message_func(message_text, message_id)
# Update current message for streaming
if not is_final:
self.current_message = chat_message
await self.send_chat_func(chat_message)
# Update history
self.transcription_history.append(
TranscriptionHistoryItem(
message=message, timestamp=time.time(), is_final=is_final
message=message_text, timestamp=time.time(), is_final=is_final
)
)
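
The id lifecycle above in brief: streaming chunks funnel into one reusable message id, while a final transcription always posts as a brand-new message and clears the slot for the next phrase. A schematic reduction of that branch logic:

from typing import Optional

current_id: Optional[str] = None   # stands in for self.current_message

def choose_id(is_final: bool) -> Optional[str]:
    """Return the id to reuse for this transcription (None = mint a new one)."""
    global current_id
    if is_final:
        current_id = None          # final text always becomes a new message
        return None
    return current_id              # streaming: reuse the in-flight id, if any

assert choose_id(False) is None    # first chunk -> new message; suppose id "m1"
current_id = "m1"                  # processor remembers the streamed message
assert choose_id(False) == "m1"    # next chunk refines "m1" in place
assert choose_id(True) is None     # final posts separately and resets the slot
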
@@ -1676,7 +1762,7 @@ class WaveformVideoTrack(MediaStreamTrack):
# Draw colored waveform
if len(points) > 1:
for i in range(len(points) - 1):
cv2.line(frame_array, points[i], points[i+1], colors[i], 2)
cv2.line(frame_array, points[i], points[i+1], colors[i], 1)
# Add speech detection status overlay
if speech_info:
@@ -1701,29 +1787,41 @@
"""Draw speech detection status information."""
y_offset = 100
line_height = 25
# Main status
is_speech = speech_info.get('is_speech', False)
status_text = "🎤 SPEECH" if is_speech else "🔇 NOISE"
status_text = "SPEECH" if is_speech else "NOISE"
status_color = (0, 255, 0) if is_speech else (128, 128, 128)
cv2.putText(frame_array, f"{pname}: {status_text}",
(10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.8, status_color, 2)
adaptive_thresh = speech_info.get('adaptive_threshold', 0)
cv2.putText(frame_array, f"{pname}: {status_text} (thresh: {adaptive_thresh:.4f})",
(10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)
# Detailed metrics (smaller text)
metrics = [
f"Energy: {speech_info.get('energy', 0):.3f} ({'' if speech_info.get('energy_check', False) else ''})",
f"ZCR: {speech_info.get('zcr', 0):.3f} ({'' if speech_info.get('zcr_check', False) else ''})",
f"Spectral: ({'' if speech_info.get('spectral_check', False) else ''})",
f"Harmonic: ({'' if speech_info.get('harmonic_check', False) else ''})",
f"Temporal: ({'' if speech_info.get('temporal_consistency', False) else ''})"
f"Energy: {speech_info.get('energy', 0):.3f} ({'Y' if speech_info.get('energy_check', False) else 'N'})",
f"ZCR: {speech_info.get('zcr', 0):.3f} ({'Y' if speech_info.get('zcr_check', False) else 'N'})",
f"Spectral: {'Y' if 300 < speech_info.get('centroid', 0) < 3400 else 'N'}/{'Y' if speech_info.get('rolloff', 0) < 2000 else 'N'}/{'Y' if speech_info.get('flux', 0) > 0.01 else 'N'} ({'Y' if speech_info.get('spectral_check', False) else 'N'})",
f"Harmonic: {speech_info.get('hamonicity', 0):.3f} ({'Y' if speech_info.get('harmonic_check', False) else 'N'})",
f"Temporal: ({'Y' if speech_info.get('temporal_consistency', False) else 'N'})"
]
for i, metric in enumerate(metrics):
cv2.putText(frame_array, metric,
(320, y_offset + i * 15), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
(320, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
(255, 255, 255), 1)
y_offset += 15
logic_result = "E:" + ("Y" if speech_info.get('energy_check', False) else "N")
logic_result += " Z:" + ("Y" if speech_info.get('zcr_check', False) else "N")
logic_result += " S:" + ("Y" if speech_info.get('spectral_check', False) else "N")
logic_result += " H:" + ("Y" if speech_info.get('harmonic_check', False) else "N")
logic_result += " T:" + ("Y" if speech_info.get('temporal_consistency', False) else "N")
cv2.putText(frame_array, logic_result,
(320, y_offset + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
(255, 255, 255), 1)
# Noise floor indicator
noise_floor = speech_info.get('noise_floor_energy', 0)
@@ -1760,12 +1858,12 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
_model_loading_progress = 0.8
logger.info(f"Creating OptimizedAudioProcessor for {peer.peer_name}")
if _send_chat_func is None:
logger.error(f"No send_chat_func available for {peer.peer_name}")
if _send_chat_func is None or _create_chat_message_func is None:
logger.error(f"No send function available for {peer.peer_name}")
_model_loading_status = "Error: No send function available"
return
_audio_processors[peer.peer_name] = OptimizedAudioProcessor(
peer_name=peer.peer_name, send_chat_func=_send_chat_func
peer_name=peer.peer_name, send_chat_func=_send_chat_func, create_chat_message_func=_create_chat_message_func
)
_model_loading_status = "Ready for transcription"
@@ -1773,9 +1871,9 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
logger.info(f"OptimizedAudioProcessor ready for {peer.peer_name}")
if peer.peer_name not in _audio_processors:
if _send_chat_func is None:
if _send_chat_func is None or _create_chat_message_func is None:
logger.error(
f"Cannot create processor for {peer.peer_name}: no send_chat_func"
f"Cannot create processor for {peer.peer_name}: no send_chat_func or create_chat_message_func"
)
return
@@ -1993,7 +2091,7 @@ def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
async def handle_chat_message(
chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]
chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
) -> Optional[str]:
"""Handle incoming chat messages."""
return None
@@ -2009,16 +2107,18 @@ def get_track_handler() -> Callable[[Peer, MediaStreamTrack], Awaitable[None]]:
return on_track_received
def bind_send_chat_function(send_chat_func: Callable[[str], Awaitable[None]]) -> None:
def bind_send_chat_function(send_chat_func: Callable[[ChatMessageModel], Awaitable[None]], create_chat_message_func: Callable[[str, Optional[str]], ChatMessageModel]) -> None:
"""Bind the send chat function."""
global _send_chat_func, _audio_processors
global _send_chat_func, _create_chat_message_func, _audio_processors
logger.info("Binding send chat function to OpenVINO whisper agent")
_send_chat_func = send_chat_func
_create_chat_message_func = create_chat_message_func
# Update existing processors
for peer_name, processor in _audio_processors.items():
processor.send_chat_func = send_chat_func
processor.create_chat_message_func = create_chat_message_func
logger.debug(f"Updated processor for {peer_name} with new send chat function")

View File

@@ -12,6 +12,7 @@ import json
import websockets
import time
import re
import secrets
from typing import (
Dict,
Optional,
@@ -20,6 +21,7 @@ from typing import (
Protocol,
AsyncIterator,
cast,
Union,
)
# Add the parent directory to sys.path to allow absolute imports
@@ -90,7 +92,7 @@ class WebRTCSignalingClient:
session_name: str,
insecure: bool = False,
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
bind_send_chat_function: Optional[Callable[[Callable[[str], Awaitable[None]]], None]] = None,
bind_send_chat_function: Optional[Callable[[Callable[[Union[str, ChatMessageModel]], Awaitable[None]], Callable[[str, Optional[str]], ChatMessageModel]], None]] = None,
registration_check_interval: float = 30.0,
):
self.server_url = server_url
@@ -107,7 +109,7 @@ class WebRTCSignalingClient:
if self.bind_send_chat_function:
# Bind the send_chat_message method to the bot's send function
self.bind_send_chat_function(self.send_chat_message)
self.bind_send_chat_function(self.send_chat_message, self.create_chat_message)
# WebSocket client protocol instance (typed as object to avoid Any)
self.websocket: Optional[object] = None
@@ -424,22 +426,47 @@ class WebRTCSignalingClient:
self.shutdown_requested = True
logger.info("Shutdown requested for WebRTC signaling client")
async def send_chat_message(self, message: str):
async def send_chat_message(self, message: Union[str, ChatMessageModel]):
"""Send a chat message to the lobby"""
if not self.is_registered:
logger.warning("Cannot send chat message: not registered")
return
if not message.strip():
logger.warning("Cannot send empty chat message")
return
if isinstance(message, str):
if not message.strip():
logger.warning("Cannot send empty chat message")
return
# Create ChatMessageModel from string
chat_message = ChatMessageModel(
id=secrets.token_hex(8),
message=message.strip(),
sender_name=self.session_name,
sender_session_id=self.session_id,
timestamp=time.time(),
lobby_id=self.lobby_id,
)
message_data = chat_message.model_dump()
else:
# ChatMessageModel
message_data = message.model_dump()
try:
await self._send_message("send_chat_message", {"message": message.strip()})
logger.info(f"Sent chat message: {message[:50]}...")
await self._send_message("send_chat_message", {"message": message_data})
logger.info(f"Sent chat message: {str(message)[:50]}...")
except Exception as e:
logger.error(f"Failed to send chat message: {e}", exc_info=True)
def create_chat_message(self, message: str, message_id: Optional[str] = None) -> ChatMessageModel:
"""Create a ChatMessageModel with the correct session and lobby IDs"""
return ChatMessageModel(
id=message_id or secrets.token_hex(8),
message=message,
sender_name=self.session_name,
sender_session_id=self.session_id,
timestamp=time.time(),
lobby_id=self.lobby_id,
)
async def _setup_local_media(self):
"""Create local media tracks"""
# If a bot provided a create_tracks callable, use it to create tracks.
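
Putting the client pieces together: a plain string mints a fresh id, while create_chat_message plus a remembered id lets a caller refine its own earlier message. A hedged usage sketch against an already-registered client:

async def demo(client: "WebRTCSignalingClient") -> None:
    # New bubble: the string is wrapped in a fresh ChatMessageModel.
    await client.send_chat_message("transcribing...")
    # Mint a message whose id we keep, then refine it in place.
    draft = client.create_chat_message("partial transcr")
    await client.send_chat_message(draft)
    refined = client.create_chat_message(
        "partial transcription, now complete", message_id=draft.id
    )
    await client.send_chat_message(refined)
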
@@ -478,6 +505,7 @@ class WebRTCSignalingClient:
# Build message with explicit type to avoid type narrowing
# Always include data field to match WebSocketMessageModel
# Should be a ChatMessageModel, not custom dict
message: dict[str, object] = {
"type": message_type,
"data": data if data is not None else {}