Features added: - WebSocket chat message handling in the WebRTC signaling client - Bot chat handler discovery and automatic setup - Chat message sending/receiving capabilities - Example chatbot with conversation features - Enhanced whisper bot with chat commands - Comprehensive error handling and logging - Full integration with the existing WebRTC infrastructure. Bots can now: - Receive chat messages from lobby participants - Send responses back through the WebSocket - Process commands and keywords - Integrate seamlessly with voice/video functionality. Files modified: - voicebot/webrtc_signaling.py: Added chat message handling - voicebot/bot_orchestrator.py: Enhanced bot discovery for chat - voicebot/bots/whisper.py: Added chat command processing - voicebot/bots/chatbot.py: New conversational bot - voicebot/bots/__init__.py: Added chatbot module - CHAT_INTEGRATION.md: Comprehensive documentation - README.md: Updated with chat functionality info
1010 lines
39 KiB
Python
1010 lines
39 KiB
Python
"""
|
|
WebRTC signaling client for voicebot.
|
|
|
|
This module provides WebRTC signaling server communication and peer connection management.
|
|
Synthetic audio/video track creation is handled by the bots.synthetic_media module.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import websockets
|
|
import time
|
|
import re
|
|
from typing import (
|
|
Dict,
|
|
Optional,
|
|
Callable,
|
|
Awaitable,
|
|
Protocol,
|
|
AsyncIterator,
|
|
cast,
|
|
)
|
|
|
|
# Add the parent directory to sys.path to allow absolute imports
|
|
|
|
from pydantic import ValidationError
|
|
from aiortc import (
|
|
RTCPeerConnection,
|
|
RTCSessionDescription,
|
|
RTCIceCandidate,
|
|
MediaStreamTrack,
|
|
)
|
|
from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer
|
|
from aiortc.sdp import candidate_from_sdp
|
|
|
|
# Import shared models
|
|
import sys
|
|
import os
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
from shared.models import (
|
|
WebSocketMessageModel,
|
|
JoinStatusModel,
|
|
UserJoinedModel,
|
|
LobbyStateModel,
|
|
UpdateNameModel,
|
|
AddPeerModel,
|
|
RemovePeerModel,
|
|
SessionDescriptionModel,
|
|
IceCandidateModel,
|
|
ICECandidateDictModel,
|
|
SessionDescriptionTypedModel,
|
|
ChatMessageModel,
|
|
)
|
|
|
|
from logger import logger
|
|
from voicebot.bots.synthetic_media import create_synthetic_tracks
|
|
from voicebot.models import Peer, MessageData
|
|
from voicebot.utils import create_ssl_context, log_network_info
|
|
|
|
|
|
def _convert_http_to_ws_url(url: str) -> str:
|
|
"""Convert HTTP/HTTPS URL to WebSocket URL by replacing scheme."""
|
|
if url.startswith("https://"):
|
|
return url.replace("https://", "wss://", 1)
|
|
elif url.startswith("http://"):
|
|
return url.replace("http://", "ws://", 1)
|
|
else:
|
|
# Assume it's already a WebSocket URL
|
|
return url
|
|
|
|
|
|
class WebSocketProtocol(Protocol):
    """Structural type for the websocket client methods this module calls.

    The signaling client stores its connection as a plain ``object`` and
    casts it to this protocol wherever it sends, closes, or iterates.
    """

    def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]:
        """Send one message over the connection."""
        ...

    def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]:
        """Close the connection with the given close code and reason."""
        ...

    def __aiter__(self) -> AsyncIterator[str]:
        """Iterate over incoming messages."""
        ...
|
|
|
|
|
|
class WebRTCSignalingClient:
|
|
"""
|
|
WebRTC signaling client that communicates with the FastAPI signaling server.
|
|
Handles peer-to-peer connection establishment and media streaming.
|
|
"""
|
|
|
|
def __init__(
    self,
    server_url: str,
    lobby_id: str,
    session_id: str,
    session_name: str,
    insecure: bool = False,
    create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
    registration_check_interval: float = 30.0,
):
    """Record connection settings and start with empty peer/negotiation state.

    No network activity happens here; ``connect`` opens the websocket.

    Args:
        server_url: Signaling server base URL; an http(s) scheme is converted
            to ws(s) when connecting.
        lobby_id: Lobby to join; part of the websocket path.
        session_id: This client's session id; part of the websocket path and
            the key of the local peer in ``peers``.
        session_name: Name announced via ``set_name``; also the initial value
            of ``name_password``.
        insecure: Accept self-signed certificates on ``wss://`` URLs.
        create_tracks: Optional factory mapping the session name to local
            media tracks; synthetic tracks are used when it is None.
        registration_check_interval: Seconds between registration checks.
    """
    # --- connection settings ---
    self.server_url = server_url
    self.lobby_id = lobby_id
    self.session_id = session_id
    self.session_name = session_name
    self.insecure = insecure
    # Optional factory to create local media tracks for this client (bot provided)
    self.create_tracks = create_tracks
    # Optional password to register or take over a name; defaults to the name itself
    self.name_password: Optional[str] = session_name
    # WebSocket client protocol instance (typed as object to avoid Any)
    self.websocket: Optional[object] = None

    # --- peers and media ---
    self.peers: dict[str, Peer] = {}
    self.peer_connections: dict[str, RTCPeerConnection] = {}
    self.local_tracks: dict[str, MediaStreamTrack] = {}

    # --- per-peer negotiation state ---
    self.is_negotiating: dict[str, bool] = {}
    self.making_offer: dict[str, bool] = {}
    self.initiated_offer: set[str] = set()
    self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}

    # --- registration tracking ---
    self.is_registered: bool = False
    self.last_registration_check: float = 0
    self.registration_check_interval: float = registration_check_interval
    self.registration_check_task: Optional[asyncio.Task[None]] = None

    # Shutdown flag for graceful termination
    self.shutdown_requested: bool = False

    # --- event callbacks (None means "no callback") ---
    self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
    self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
    self.on_track_received: Optional[
        Callable[[Peer, MediaStreamTrack], Awaitable[None]]
    ] = None
    self.on_chat_message_received: Optional[
        Callable[[ChatMessageModel], Awaitable[None]]
    ] = None
|
|
|
|
async def connect(self):
    """Open the signaling websocket, announce this session, and run the message loop.

    Steps:
    1. Connect, using TLS for ``wss://`` URLs (self-signed certificates are
       accepted when ``insecure``).
    2. Create the local media.
    3. Send ``set_name`` (including the password when one is set) and ``join``.
    4. Mark the client registered and start the periodic registration check.
    5. Process messages until the loop ends.

    Once the message loop has started, ``disconnect`` runs when it ends.
    Any failure is logged and re-raised.

    NOTE(review): an exception raised by the message loop is logged twice:
    once as "Message handling stopped" and again as "Failed to connect".
    """
    ws_url = f"{_convert_http_to_ws_url(self.server_url)}/ws/lobby/{self.lobby_id}/{self.session_id}"
    logger.info(f"Connecting to signaling server: {ws_url}")

    # Log network information for debugging
    log_network_info()

    try:
        if not ws_url.startswith("wss://"):
            ws_ssl = None  # plain ws:// needs no TLS
        elif self.insecure:
            ws_ssl = create_ssl_context(insecure=True)  # accept self-signed certs
        else:
            ws_ssl = True  # default, certificate-verifying TLS

        logger.info(f"Attempting websocket connection to {ws_url} with ssl={ws_ssl}")
        self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
        logger.info("Connected to signaling server")

        await self._setup_local_media()

        name_payload: MessageData = {"name": self.session_name}
        if self.name_password:
            name_payload["password"] = self.name_password
        logger.info(f"Sending set_name: {name_payload}")
        await self._send_message("set_name", name_payload)
        logger.info("Sending join message")
        await self._send_message("join", {})

        # Mark as registered after successful join, then watch registration
        self.is_registered = True
        self.last_registration_check = time.time()
        self.registration_check_task = asyncio.create_task(
            self._periodic_registration_check()
        )

        logger.info("Starting message handler loop")
        try:
            await self._handle_messages()
        except Exception as e:
            logger.error(f"Message handling stopped: {e}")
            self.is_registered = False
            raise
        finally:
            # Clean disconnect when exiting
            await self.disconnect()
    except Exception as e:
        logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
        raise
|
|
|
|
async def _periodic_registration_check(self):
    """Keep the server registration alive until shutdown or cancellation.

    Every ``registration_check_interval`` seconds, provided at least that
    long has passed since ``last_registration_check``, this calls
    ``_check_registration_status``. When that check fails it calls
    ``_re_register``. Cancellation ends the loop quietly. Any other
    exception is logged and the loop keeps going.
    """
    while not self.shutdown_requested:
        try:
            await asyncio.sleep(self.registration_check_interval)

            # Check shutdown flag again after sleep
            if self.shutdown_requested:
                break

            now = time.time()
            if now - self.last_registration_check >= self.registration_check_interval:
                # Check if we're still connected and registered
                still_registered = await self._check_registration_status()
                if not still_registered:
                    logger.warning("Registration check failed, attempting to re-register")
                    await self._re_register()
                self.last_registration_check = now
        except asyncio.CancelledError:
            logger.info("Registration check task cancelled")
            break
        except Exception as e:
            # Continue checking even if one iteration fails
            logger.error(f"Error in periodic registration check: {e}", exc_info=True)

    logger.info("Registration check loop ended")
|
|
|
|
async def _check_registration_status(self) -> bool:
    """Return whether the signaling websocket still accepts messages.

    Sends a ``status_check`` message (same ``{"type", "data"}`` envelope that
    ``_send_message`` builds) directly on the websocket. It deliberately does
    not go through ``_send_message``, because that helper logs and swallows
    every send error. Routing the probe through it would make this check
    report success whenever a websocket object exists, even a dead one.

    Returns:
        False when there is no websocket or the send raises; True otherwise.
    """
    # First check if websocket is still connected
    if not self.websocket:
        logger.warning("WebSocket connection lost")
        return False

    probe = {"type": "status_check", "data": {"timestamp": time.time()}}
    ws = cast(WebSocketProtocol, self.websocket)
    try:
        await ws.send(json.dumps(probe))
    except Exception as e:
        logger.warning(f"Failed to send status check: {e}")
        return False

    logger.debug("Registration status check sent")
    return True
|
|
|
|
async def _re_register(self):
    """Re-announce this session to the server after a failed registration check.

    The method:
    - clears ``is_registered``;
    - reconnects only when there is no websocket object at all;
    - re-sends ``set_name`` (with the password when set) and ``join``;
    - marks the client registered again.

    Errors are logged rather than raised, and the next check interval
    retries.

    NOTE(review): ``_send_message`` logs and swallows send failures, so this
    reports success even if neither message was actually delivered.
    """
    try:
        logger.info("Attempting to re-register with server")

        # Mark as not registered during re-registration attempt
        self.is_registered = False

        if not self.websocket:
            logger.info("WebSocket lost, attempting to reconnect")
            await self._reconnect_websocket()

        name_payload: MessageData = {"name": self.session_name}
        if self.name_password:
            name_payload["password"] = self.name_password

        announcements = (
            ("Re-sending set_name message", "set_name", name_payload),
            ("Re-sending join message", "join", {}),
        )
        for log_line, msg_type, body in announcements:
            logger.info(log_line)
            await self._send_message(msg_type, body)

        self.is_registered = True
        self.last_registration_check = time.time()
        logger.info("Successfully re-registered with server")
    except Exception as e:
        # Will try again on next check interval
        logger.error(f"Failed to re-register with server: {e}", exc_info=True)
|
|
|
|
async def _reconnect_websocket(self):
    """Replace ``self.websocket`` with a fresh connection to the lobby endpoint.

    Any existing connection is first closed on a best-effort basis and
    dropped. The URL and TLS choice are derived the same way ``connect``
    derives them. Failures are logged with a traceback and re-raised.
    """
    try:
        previous = self.websocket
        if previous:
            try:
                await cast(WebSocketProtocol, previous).close()
            except Exception:
                pass  # the old connection is being discarded either way
            self.websocket = None

        ws_url = "{}/ws/lobby/{}/{}".format(
            _convert_http_to_ws_url(self.server_url), self.lobby_id, self.session_id
        )

        # TLS only for wss://; self-signed certificates allowed when insecure
        if ws_url.startswith("wss://"):
            ws_ssl = create_ssl_context(insecure=True) if self.insecure else True
        else:
            ws_ssl = None

        logger.info(f"Reconnecting to signaling server: {ws_url}")
        self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
        logger.info("Successfully reconnected to signaling server")
    except Exception as e:
        logger.error(f"Failed to reconnect websocket: {e}", exc_info=True)
        raise
|
|
|
|
async def disconnect(self):
    """Tear down the signaling session and release local resources.

    Steps, each best-effort (failures are logged and cleanup continues):
    1. Cancel the registration check task.
    2. Close the websocket.
    3. Close every peer connection.
    4. Stop every local track.

    Afterwards it drops the references to the closed websocket and peer
    connections and clears the per-peer negotiation state. As a result,
    later sends see "no websocket" instead of writing to a closed socket,
    and a repeated call does not re-close them. ``is_registered`` ends up
    False.
    """
    task = self.registration_check_task
    if task and not task.done():
        task.cancel()
        try:
            # Only await if we're in the same event loop
            if asyncio.get_running_loop() == task.get_loop():
                await task
            else:
                logger.warning("Registration check task in different event loop, skipping await")
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.warning(f"Error cancelling registration check task: {e}")
    self.registration_check_task = None

    if self.websocket:
        ws = cast(WebSocketProtocol, self.websocket)
        try:
            await ws.close()
        except Exception as e:
            logger.warning(f"Error closing websocket: {e}")
        self.websocket = None

    # Close all peer connections. Iterate over a snapshot: other coroutines
    # may add/remove peers while this one is suspended in pc.close().
    for pc in list(self.peer_connections.values()):
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"Error closing peer connection: {e}")
    self.peer_connections.clear()

    # Per-peer negotiation bookkeeping refers to the connections just closed
    self.is_negotiating.clear()
    self.making_offer.clear()
    self.initiated_offer.clear()
    self.pending_ice_candidates.clear()

    # Stop local tracks
    for track in self.local_tracks.values():
        try:
            track.stop()
        except Exception as e:
            logger.warning(f"Error stopping track: {e}")

    # Reset registration status
    self.is_registered = False

    logger.info("Disconnected from signaling server")
|
|
|
|
def request_shutdown(self):
    """Flag the client for graceful shutdown without closing anything itself.

    Only sets ``shutdown_requested`` and logs. The message loop reads the
    flag for each incoming message, and the registration check reads it on
    each iteration. Intended to be safe to call from any thread.
    """
    self.shutdown_requested = True
    logger.info("Shutdown requested for WebRTC signaling client")
|
|
|
|
async def send_chat_message(self, message: str):
    """Post *message* (whitespace-stripped) to the lobby chat.

    Does nothing but log a warning when the client is not registered or the
    message is blank. Send errors are logged rather than raised.
    """
    if not self.is_registered:
        logger.warning("Cannot send chat message: not registered")
        return

    text = message.strip()
    if not text:
        logger.warning("Cannot send empty chat message")
        return

    try:
        await self._send_message("send_chat_message", {"message": text})
        logger.info(f"Sent chat message: {message[:50]}...")
    except Exception as e:
        logger.error(f"Failed to send chat message: {e}", exc_info=True)
|
|
|
|
async def _setup_local_media(self):
    """Populate ``local_tracks`` and register the local peer in ``peers``.

    Track creation uses the bot-supplied ``create_tracks`` factory when one
    was given, and ``create_synthetic_tracks`` otherwise. A failure in either
    is logged and swallowed. The local peer entry, keyed by ``session_id``,
    is added regardless.
    """
    try:
        make_tracks = self.create_tracks or create_synthetic_tracks
        self.local_tracks.update(make_tracks(self.session_name))
    except Exception:
        logger.exception("Failed to create local tracks using bot factory")

    self.peers[self.session_id] = Peer(
        session_id=self.session_id,
        peer_name=self.session_name,
        local=True,
        attributes={"tracks": self.local_tracks},
    )

    logger.info("Local media tracks created")
|
|
|
|
async def _send_message(
    self, message_type: str, data: Optional[MessageData] = None
):
    """JSON-encode ``{"type": message_type, "data": data}`` and send it.

    The ``data`` key is omitted when *data* is None. With no websocket, an
    error is logged and nothing is sent. Encoding or send errors are logged
    with a traceback and NOT re-raised, so callers cannot detect a failed
    send.
    """
    if not self.websocket:
        logger.error("No websocket connection")
        return

    # Build message with explicit type to avoid type narrowing
    envelope: dict[str, object] = {"type": message_type}
    if data is not None:
        envelope["data"] = data

    connection = cast(WebSocketProtocol, self.websocket)
    try:
        logger.debug(f"_send_message: Sending {message_type} with data: {data}")
        await connection.send(json.dumps(envelope))
    except Exception as e:
        logger.error(
            f"_send_message: Failed to send {message_type}: {e}", exc_info=True
        )
    else:
        logger.debug(f"_send_message: Sent message: {message_type}")
|
|
|
|
async def _handle_messages(self):
    """Read websocket frames and hand each decoded message to ``_process_message``.

    The loop stops when any of these happens:
    - the connection closes;
    - shutdown is requested (checked as each frame arrives);
    - an exception escapes ``_process_message``.

    Frames that are not valid JSON are logged and skipped. Connection
    closure and other errors are logged, clear ``is_registered``, and are
    not re-raised.

    NOTE(review): the closed-connection comment below expects the periodic
    registration check to reconnect. However, ``connect`` calls
    ``disconnect`` (which cancels that task) as soon as this method returns.
    """
    try:
        connection = cast(WebSocketProtocol, self.websocket)
        async for raw in connection:
            if self.shutdown_requested:
                logger.info("Shutdown requested, breaking message loop")
                break

            logger.debug(f"_handle_messages: Received raw message: {raw}")
            try:
                parsed = cast(MessageData, json.loads(raw))
            except Exception as e:
                logger.error(
                    f"_handle_messages: Failed to parse message: {e}", exc_info=True
                )
            else:
                await self._process_message(parsed)
    except websockets.exceptions.ConnectionClosed as e:
        logger.warning(f"WebSocket connection closed: {e}")
        self.is_registered = False
        # The periodic registration check will detect this and attempt reconnection
    except Exception as e:
        logger.error(f"Error handling messages: {e}", exc_info=True)
        self.is_registered = False
|
|
|
|
async def _process_message(self, message: MessageData):
    """Validate one decoded server message and route it by its ``type``.

    Handling by message kind:
    - ``error`` messages (carrying a top-level ``error`` key) are logged and
      dropped.
    - Every other message must match ``WebSocketMessageModel``.
    - ``status_check`` is only logged.
    - For known types, ``data`` is validated against that type's payload
      model and passed to its handler. Invalid envelopes or payloads are
      logged and ignored.
    - Unknown types are logged.

    Exceptions raised by the peer handlers propagate to the caller.
    Exceptions from the chat callback are logged.
    """
    try:
        # Handle error messages specially since they have a different structure
        if message.get("type") == "error" and "error" in message:
            error_msg = message.get("error", "Unknown error")
            logger.error(f"Received error from signaling server: {error_msg}")
            return

        # Validate the base message structure for non-error messages
        envelope = WebSocketMessageModel.model_validate(message)
    except ValidationError as e:
        logger.error(f"Invalid message structure: {e}", exc_info=True)
        return

    msg_type = envelope.type
    data = envelope.data

    logger.debug(
        f"_process_message: Received message type: {msg_type} with data: {data}"
    )

    if msg_type == "status_check":
        # Handle status check messages - these are used to verify connection
        logger.debug(f"Received status check message: {data}")
        return

    async def _log_join_status(status: JoinStatusModel) -> None:
        logger.info(f"Join status: {status.status} - {status.message}")

    async def _log_user_joined(user: UserJoinedModel) -> None:
        logger.info(
            f"User joined: {user.name} (session: {user.session_id})"
        )

    async def _log_lobby_state(state: LobbyStateModel) -> None:
        participants = state.participants
        logger.info(f"Lobby state updated: {len(participants)} participants")

    async def _log_update_name(update: UpdateNameModel) -> None:
        logger.info(f"Received update message: {update}")

    async def _deliver_chat(chat: ChatMessageModel) -> None:
        logger.info(f"Received chat message from {chat.sender_name}: {chat.message[:50]}...")
        # Call the callback if it's set
        if self.on_chat_message_received:
            try:
                await self.on_chat_message_received(chat)
            except Exception as e:
                logger.error(f"Error in chat message callback: {e}", exc_info=True)

    # msg_type -> (payload model, name used in the validation-error log, handler)
    routes = {
        "addPeer": (AddPeerModel, "addPeer", self._handle_add_peer),
        "removePeer": (RemovePeerModel, "removePeer", self._handle_remove_peer),
        "sessionDescription": (
            SessionDescriptionModel,
            "sessionDescription",
            self._handle_session_description,
        ),
        "iceCandidate": (IceCandidateModel, "iceCandidate", self._handle_ice_candidate),
        "join_status": (JoinStatusModel, "join_status", _log_join_status),
        "user_joined": (UserJoinedModel, "user_joined", _log_user_joined),
        "lobby_state": (LobbyStateModel, "lobby_state", _log_lobby_state),
        "update_name": (UpdateNameModel, "update", _log_update_name),
        "chat_message": (ChatMessageModel, "chat_message", _deliver_chat),
    }

    route = routes.get(msg_type)
    if route is None:
        logger.info(f"Unhandled message type: {msg_type} with data: {data}")
        return

    payload_model, label, handler = route
    try:
        payload = payload_model.model_validate(data)
    except ValidationError as e:
        logger.error(f"Invalid {label} payload: {e}", exc_info=True)
        return
    await handler(payload)
|
|
|
|
# Continue with more methods in the next part...
|
|
|
|
async def _handle_add_peer(self, data: AddPeerModel):
    """Create the RTCPeerConnection for a peer announced via ``addPeer``.

    Existing connections:
    - One in state ``new``, ``connecting`` or ``connected`` is kept, and
      nothing else happens.
    - Any other existing connection is stale. It is closed, and the per-peer
      negotiation state tied to it is dropped. This is the same state
      ``_handle_remove_peer`` clears, so flags such as ``initiated_offer``
      from the old connection cannot skew offer-collision handling for the
      new one.

    The new connection gets:
    - state-logging handlers;
    - a ``track`` handler that stores remote tracks on the peer and forwards
      them to ``on_track_received``;
    - an ``icecandidate`` handler that relays candidates to the server;
    - every local track.

    An offer is created when ``should_create_offer`` is set, and
    ``on_peer_added`` is awaited last.

    Args:
        data: Validated ``addPeer`` payload.
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    should_create_offer = data.should_create_offer

    logger.info(
        f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})"
    )
    logger.debug(
        f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}"
    )

    # Check if peer already exists
    if peer_id in self.peer_connections:
        pc = self.peer_connections[peer_id]
        logger.debug(
            f"_handle_add_peer: Existing connection state: {pc.connectionState}"
        )
        if pc.connectionState in ["new", "connected", "connecting"]:
            logger.info(f"Peer connection already exists for {peer_name}")
            return
        # Clean up stale connection
        logger.debug(
            f"_handle_add_peer: Closing stale connection for {peer_name}"
        )
        await pc.close()
        del self.peer_connections[peer_id]
        # Drop negotiation state belonging to the closed connection (mirrors
        # _handle_remove_peer) so the replacement starts from a clean slate.
        self.is_negotiating.pop(peer_id, None)
        self.making_offer.pop(peer_id, None)
        self.initiated_offer.discard(peer_id)
        self.pending_ice_candidates.pop(peer_id, None)

    # Create new peer
    peer = Peer(session_id=peer_id, peer_name=peer_name, local=False)
    self.peers[peer_id] = peer

    # Create RTCPeerConnection
    # NOTE(review): TURN credentials are hard-coded in source; consider
    # loading them from configuration instead.
    config = RTCConfiguration(
        iceServers=[
            RTCIceServer(urls="stun:ketrenos.com:3478"),
            RTCIceServer(
                urls="turns:ketrenos.com:5349",
                username="ketra",
                credential="ketran",
            ),
            # Add Google's public STUN server as fallback
            RTCIceServer(urls="stun:stun.l.google.com:19302"),
        ],
    )
    logger.debug(
        f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}"
    )
    pc = RTCPeerConnection(configuration=config)

    # The event loop keeps only weak references to tasks. A task started from
    # a synchronous event handler, with no other reference, can therefore be
    # garbage-collected before it finishes. Hold each one here until it
    # completes; this set is referenced by the handler closures registered on
    # `pc` below.
    background_tasks: set[asyncio.Future[None]] = set()

    def _spawn(coro: Awaitable[None]) -> None:
        task = asyncio.ensure_future(coro)
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    # Add ICE gathering state change handler
    def on_ice_gathering_state_change() -> None:
        logger.info(f"ICE gathering state: {pc.iceGatheringState}")
        if pc.iceGatheringState == "complete":
            logger.info(
                f"ICE gathering complete for {peer_name} - checking if candidates were generated..."
            )

    pc.on("icegatheringstatechange")(on_ice_gathering_state_change)

    # Add connection state change handler
    def on_connection_state_change() -> None:
        logger.info(f"Connection state: {pc.connectionState}")

    pc.on("connectionstatechange")(on_connection_state_change)

    self.peer_connections[peer_id] = pc
    peer.connection = pc

    # Set up event handlers
    def on_track(track: MediaStreamTrack) -> None:
        logger.info(f"Received {track.kind} track from {peer_name}")
        logger.info(f"on_track: {track.kind} from {peer_name}, track={track}")
        peer.attributes[f"{track.kind}_track"] = track
        if self.on_track_received:
            _spawn(self.on_track_received(peer, track))

    pc.on("track")(on_track)

    def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None:
        logger.info(f"on_ice_candidate: {candidate}")
        logger.info(
            f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}"
        )
        if not candidate:
            logger.info(
                f"on_ice_candidate: End of candidates signal for {peer_name}"
            )
            return

        # Raw SDP fragment for the candidate
        raw = getattr(candidate, "candidate", None)

        # Try to infer candidate type from the SDP string (host/srflx/relay/prflx)
        def _parse_type(s: Optional[str]) -> str:
            if not s:
                return "eoc"
            m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", s)
            return m.group(1) if m else "unknown"

        cand_type = _parse_type(raw)
        protocol = getattr(candidate, "protocol", "unknown")
        logger.info(
            f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}"
        )

        candidate_model = ICECandidateDictModel(
            candidate=raw,
            sdpMid=getattr(candidate, "sdpMid", None),
            sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None),
        )
        payload_model = IceCandidateModel(
            peer_id=peer_id, peer_name=peer_name, candidate=candidate_model
        )
        logger.info(
            f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_model}"
        )
        _spawn(self._send_message("relayICECandidate", payload_model.model_dump()))

    pc.on("icecandidate")(on_ice_candidate)

    # Add local tracks
    for track in self.local_tracks.values():
        logger.debug(
            f"_handle_add_peer: Adding local track {track.kind} to {peer_name}"
        )
        pc.addTrack(track)

    # Create offer if needed
    if should_create_offer:
        await self._create_and_send_offer(peer_id, peer_name, pc)

    if self.on_peer_added:
        await self.on_peer_added(peer)
|
|
|
|
async def _create_and_send_offer(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
    """Make this side the offerer for *peer_id* and relay the offer.

    Steps:
    1. Flag this side as the offerer before the first await. These flags are
       what ``_handle_session_description`` reads for collision handling.
    2. Create and set the local offer.
    3. Wait for ICE gathering to report ``complete``.
    4. Relay the gathered candidates.
    5. Relay the offer SDP.

    Errors are logged with a traceback and swallowed. ``making_offer`` is
    always cleared at the end.
    """
    self.initiated_offer.add(peer_id)
    self.making_offer[peer_id] = True
    self.is_negotiating[peer_id] = True

    try:
        logger.debug(f"_handle_add_peer: Creating offer for {peer_name}")
        offer = await pc.createOffer()
        logger.debug(
            f"_handle_add_peer: Offer created for {peer_name}: {offer}"
        )
        await pc.setLocalDescription(offer)
        logger.debug(f"_handle_add_peer: Local description set for {peer_name}")

        # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
        # Use Method 2: Complete SDP approach to extract ICE candidates
        logger.debug(
            f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}"
        )
        while pc.iceGatheringState != "complete":
            await asyncio.sleep(0.1)
        logger.debug(
            f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}"
        )

        await self._extract_and_send_candidates(peer_id, peer_name, pc)

        relay = SessionDescriptionModel(
            peer_id=peer_id,
            peer_name=peer_name,
            session_description=SessionDescriptionTypedModel(
                type=offer.type, sdp=offer.sdp
            ),
        )
        await self._send_message("relaySessionDescription", relay.model_dump())

        logger.info(f"Offer sent to {peer_name}")
    except Exception as e:
        logger.error(
            f"Failed to create/send offer to {peer_name}: {e}", exc_info=True
        )
    finally:
        self.making_offer[peer_id] = False
|
|
|
|
async def _extract_and_send_candidates(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
    """Relay every ICE candidate in the local SDP to *peer_id*, then end-of-candidates.

    Walks ``pc.localDescription.sdp`` line by line, tracking two values:
    - the current media-section index, incremented at each ``m=`` line;
    - the latest ``a=mid:`` value.

    Each ``a=candidate:`` line is relayed, without its ``a=`` prefix, with
    the matching ``sdpMid`` and ``sdpMLineIndex``. An empty candidate is then
    sent as the end-of-candidates marker.

    SDP text uses CRLF line endings. It is therefore split with
    ``splitlines()``, because splitting on a bare LF would leave a trailing
    carriage return on every relayed candidate string.
    """
    sent = 0
    media_index = -1
    mid: Optional[str] = None

    for line in pc.localDescription.sdp.splitlines():
        if line.startswith("m="):  # Media section
            media_index += 1
        elif line.startswith("a=mid:"):  # Media ID
            mid = line.split(":", 1)[1].strip()
        elif line.startswith("a=candidate:"):
            candidate_sdp = line[2:]  # Remove 'a=' prefix, keep "candidate:..."
            payload_candidate = IceCandidateModel(
                peer_id=peer_id,
                peer_name=peer_name,
                candidate=ICECandidateDictModel(
                    candidate=candidate_sdp,
                    sdpMid=mid,
                    sdpMLineIndex=media_index,
                ),
            )
            logger.debug(
                f"_extract_and_send_candidates: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}..."
            )
            await self._send_message(
                "relayICECandidate", payload_candidate.model_dump()
            )
            sent += 1

    # Send end-of-candidates signal (empty candidate)
    payload_end = IceCandidateModel(
        peer_id=peer_id,
        peer_name=peer_name,
        candidate=ICECandidateDictModel(
            candidate="",
            sdpMid=None,
            sdpMLineIndex=None,
        ),
    )
    logger.debug(
        f"_extract_and_send_candidates: Sending end-of-candidates signal for {peer_name}"
    )
    await self._send_message("relayICECandidate", payload_end.model_dump())

    logger.debug(
        f"_extract_and_send_candidates: Sent {sent} ICE candidates to {peer_name}"
    )
|
|
|
|
async def _handle_remove_peer(self, data: RemovePeerModel):
    """Close and forget a peer that the server reports as gone.

    Steps:
    1. Close the peer's RTCPeerConnection, if there is one.
    2. Clear its negotiation bookkeeping and queued ICE candidates.
    3. Remove it from ``self.peers``.
    4. Await ``on_peer_removed`` for it.

    A failure while closing the connection is logged instead of raised,
    matching ``disconnect``. Otherwise the exception would skip the state
    cleanup below and propagate out of ``_process_message``, ending the
    message loop.
    """
    peer_id = data.peer_id
    peer_name = data.peer_name

    logger.info(f"Removing peer: {peer_name}")

    # Close peer connection
    pc = self.peer_connections.pop(peer_id, None)
    if pc is not None:
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"Error closing peer connection for {peer_name}: {e}")

    # Clean up state
    self.is_negotiating.pop(peer_id, None)
    self.making_offer.pop(peer_id, None)
    self.initiated_offer.discard(peer_id)
    self.pending_ice_candidates.pop(peer_id, None)

    # Remove peer
    peer = self.peers.pop(peer_id, None)
    if peer and self.on_peer_removed:
        await self.on_peer_removed(peer)
|
|
|
|
async def _handle_session_description(self, data: SessionDescriptionModel):
    """Apply a remote offer/answer and reply with an answer when it was an offer.

    Collision handling: an offer that collides with this side's own
    in-flight offer is ignored when this side initiated the offer. A
    collision means ``making_offer`` is set or the signaling state is not
    ``stable``.

    After ``setRemoteDescription`` succeeds, queued ICE candidates for the
    peer are applied. An empty candidate means end-of-candidates. If setting
    the description (or adding a queued end-of-candidates) fails, the error
    is logged and no answer is sent.
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    sd = data.session_description.model_dump()
    sd_type = sd["type"]

    logger.info(f"Received {sd_type} from {peer_name}")

    pc = self.peer_connections.get(peer_id)
    if not pc:
        logger.error(f"No peer connection for {peer_name}")
        return

    desc = RTCSessionDescription(sdp=sd["sdp"], type=sd_type)

    # Handle offer collision (polite peer pattern)
    busy = self.making_offer.get(peer_id, False) or pc.signalingState != "stable"
    offer_collision = desc.type == "offer" and busy
    if offer_collision and peer_id in self.initiated_offer:
        logger.info(f"Ignoring offer from {peer_name} due to collision")
        return

    try:
        await pc.setRemoteDescription(desc)
        self.is_negotiating[peer_id] = False
        logger.info(f"Remote description set for {peer_name}")

        # Process ICE candidates that arrived before the remote description
        for queued in self.pending_ice_candidates.pop(peer_id, []):
            raw = queued.candidate
            if not raw:
                # end-of-candidates marker
                await pc.addIceCandidate(None)
                logger.info(f"Added queued end-of-candidates for {peer_name}")
            else:
                # raw may be the full "candidate:..." string or the inner SDP part
                sdp_part = raw.split(":", 1)[1] if raw.startswith("candidate:") else raw
                try:
                    rtc_candidate = candidate_from_sdp(sdp_part)
                    rtc_candidate.sdpMid = queued.sdpMid
                    rtc_candidate.sdpMLineIndex = queued.sdpMLineIndex
                    await pc.addIceCandidate(rtc_candidate)
                    logger.info(f"Added queued ICE candidate for {peer_name}")
                except Exception as e:
                    logger.error(
                        f"Failed to add queued ICE candidate for {peer_name}: {e}"
                    )
    except Exception as e:
        logger.error(f"Failed to set remote description for {peer_name}: {e}")
        return

    # Create answer if this was an offer
    if sd_type == "offer":
        await self._create_and_send_answer(peer_id, peer_name, pc)
|
|
|
|
async def _create_and_send_answer(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
    """Answer the remote offer already applied to *pc* and relay the answer.

    Steps:
    1. Create and set the local answer.
    2. Wait until ICE gathering reports ``complete``.
    3. Relay the gathered candidates via ``_extract_and_send_candidates``.
    4. Relay the answer SDP.

    Errors are logged (without a traceback) and swallowed.
    """
    try:
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)

        # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
        # Use Method 2: Complete SDP approach to extract ICE candidates
        logger.debug(
            f"_create_and_send_answer: Waiting for ICE gathering to complete for {peer_name} (answer)"
        )
        while pc.iceGatheringState != "complete":
            await asyncio.sleep(0.1)
        logger.debug(
            f"_create_and_send_answer: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)"
        )

        await self._extract_and_send_candidates(peer_id, peer_name, pc)

        relay = SessionDescriptionModel(
            peer_id=peer_id,
            peer_name=peer_name,
            session_description=SessionDescriptionTypedModel(
                type=answer.type, sdp=answer.sdp
            ),
        )
        await self._send_message("relaySessionDescription", relay.model_dump())

        logger.info(f"Answer sent to {peer_name}")
    except Exception as e:
        logger.error(f"Failed to create/send answer to {peer_name}: {e}")
|
|
|
|
async def _handle_ice_candidate(self, data: IceCandidateModel):
    """Apply a remote ICE candidate to the peer's connection, or queue it.

    Candidates that arrive before the remote description is set are appended
    to ``pending_ice_candidates`` for this peer. An empty candidate string is
    treated as end-of-candidates. Otherwise a leading ``candidate:`` prefix is
    stripped before parsing with ``candidate_from_sdp``. Errors while applying
    a candidate are logged and swallowed.
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    candidate_data = data.candidate

    logger.info(f"Received ICE candidate from {peer_name}")

    pc = self.peer_connections.get(peer_id)
    if not pc:
        logger.error(f"No peer connection for {peer_name}")
        return

    # Queue candidate if remote description not set
    if not pc.remoteDescription:
        logger.info(
            f"Remote description not set, queuing ICE candidate for {peer_name}"
        )
        # candidate_data is an ICECandidateDictModel Pydantic model
        self.pending_ice_candidates.setdefault(peer_id, []).append(candidate_data)
        return

    try:
        cand = candidate_data.candidate
        if not cand:
            # end-of-candidates
            await pc.addIceCandidate(None)
            logger.info(f"End-of-candidates added for {peer_name}")
            return

        sdp_part = cand.split(":", 1)[1] if cand.startswith("candidate:") else cand

        # Detect type for logging
        try:
            found = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part)
            cand_type = found.group(1) if found else "unknown"
        except Exception:
            cand_type = "unknown"

        try:
            rtc_candidate = candidate_from_sdp(sdp_part)
            rtc_candidate.sdpMid = candidate_data.sdpMid
            rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex

            # aiortc expects an object with attributes (RTCIceCandidate)
            await pc.addIceCandidate(rtc_candidate)
            logger.info(f"ICE candidate added for {peer_name}: type={cand_type}")
        except Exception as e:
            logger.error(
                f"Failed to add ICE candidate for {peer_name}: type={cand_type} error={e} sdp='{sdp_part}'",
                exc_info=True,
            )
    except Exception as e:
        logger.error(
            f"Unexpected error handling ICE candidate for {peer_name}: {e}",
            exc_info=True,
        )
|