# ai-voicebot/voicebot/webrtc_signaling.py
# (1330 lines, 56 KiB, Python)
"""
WebRTC signaling client for voicebot.
This module provides WebRTC signaling server communication and peer connection management.
Synthetic audio/video track creation is handled by the bots.synthetic_media module.
"""
from __future__ import annotations
import asyncio
import json
import websockets
import time
import re
import secrets
from typing import (
Dict,
Optional,
Callable,
Awaitable,
Protocol,
AsyncIterator,
cast,
Union,
)
# Add the parent directory to sys.path to allow absolute imports
from pydantic import ValidationError
from aiortc import (
RTCPeerConnection,
RTCSessionDescription,
RTCIceCandidate,
MediaStreamTrack,
)
from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer
from aiortc.sdp import candidate_from_sdp
# Import shared models
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import (
WebSocketMessageModel,
JoinStatusModel,
UserJoinedModel,
LobbyStateModel,
UpdateNameModel,
AddPeerModel,
RemovePeerModel,
SessionDescriptionModel,
IceCandidateModel,
ICECandidateDictModel,
SessionDescriptionTypedModel,
ChatMessageModel,
)
from shared.logger import logger
from voicebot.bots.synthetic_media import create_synthetic_tracks
from voicebot.models import Peer, MessageData
from voicebot.utils import create_ssl_context, log_network_info
def _convert_http_to_ws_url(url: str) -> str:
"""Convert HTTP/HTTPS URL to WebSocket URL by replacing scheme."""
if url.startswith("https://"):
return url.replace("https://", "wss://", 1)
elif url.startswith("http://"):
return url.replace("http://", "ws://", 1)
else:
# Assume it's already a WebSocket URL
return url
class WebSocketProtocol(Protocol):
    """Structural type for the subset of the websockets client API used here.

    Matches any object offering awaitable send/close plus async iteration
    over incoming text messages (e.g. websockets' client protocol objects).
    Used with typing.cast() because self.websocket is stored as object.
    """

    def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ...
    def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ...
    def __aiter__(self) -> AsyncIterator[str]: ...
class WebRTCSignalingClient:
"""
WebRTC signaling client that communicates with the FastAPI signaling server.
Handles peer-to-peer connection establishment and media streaming.
"""
def __init__(
    self,
    server_url: str,
    lobby_id: str,
    session_id: str,
    session_name: str,
    insecure: bool = False,
    create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
    bind_send_chat_function: Optional[Callable[[Callable[[Union[str, ChatMessageModel]], Awaitable[None]], Callable[[str, Optional[str]], ChatMessageModel]], None]] = None,
    registration_check_interval: float = 30.0,
):
    """Initialize the signaling client.

    Args:
        server_url: Base HTTP(S)/WS(S) URL of the signaling server.
        lobby_id: Lobby to join on the server.
        session_id: Unique session identifier for this client.
        session_name: Display name for this client.
        insecure: If True, accept self-signed TLS certificates.
        create_tracks: Optional bot-supplied factory that builds local media
            tracks keyed by track id, given the session name.
        bind_send_chat_function: Optional hook invoked immediately with this
            client's send_chat_message and create_chat_message callables so
            a bot can send chat through this connection.
        registration_check_interval: Seconds between registration health checks.
    """
    self.server_url = server_url
    self.lobby_id = lobby_id
    self.session_id = session_id
    self.session_name = session_name
    self.insecure = insecure
    # Optional factory to create local media tracks for this client (bot provided)
    self.create_tracks = create_tracks
    # Optional functions to bind/unbind chat message sending (bot provided)
    self.bind_send_chat_function = bind_send_chat_function
    if self.bind_send_chat_function:
        # Bind the send_chat_message method to the bot's send function
        self.bind_send_chat_function(self.send_chat_message, self.create_chat_message)
    # WebSocket client protocol instance (typed as object to avoid Any)
    self.websocket: Optional[object] = None
    # Optional password to register or takeover a name
    # NOTE(review): the password defaults to the session *name* — confirm
    # this is intended and not a placeholder.
    self.name_password: Optional[str] = session_name
    self.peers: dict[str, Peer] = {}
    self.peer_connections: dict[str, RTCPeerConnection] = {}
    self.local_tracks: dict[str, MediaStreamTrack] = {}
    # State management (per-peer negotiation bookkeeping, keyed by peer id)
    self.is_negotiating: dict[str, bool] = {}
    self.making_offer: dict[str, bool] = {}
    self.initiated_offer: set[str] = set()
    self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
    # Registration status tracking
    self.is_registered: bool = False
    self.last_registration_check: float = 0
    self.registration_check_interval: float = registration_check_interval
    self.registration_check_task: Optional[asyncio.Task[None]] = None
    # Shutdown flag for graceful termination
    self.shutdown_requested: bool = False
    # Event callbacks (set by the embedding bot after construction)
    self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
    self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
    self.on_track_received: Optional[
        Callable[[Peer, MediaStreamTrack], Awaitable[None]]
    ] = None
    self.on_chat_message_received: Optional[
        Callable[[ChatMessageModel], Awaitable[None]]
    ] = None
async def connect(self):
    """Connect to the signaling server and run the message loop.

    Establishes the websocket (with TLS handling per the URL scheme and the
    insecure flag), sets up local media, registers the session name, joins
    the lobby, starts the periodic registration check, then blocks in the
    message-handling loop until the connection ends. On exit the connection
    is cleaned up via disconnect().

    Raises:
        Exception: Connection or message-handling failures are logged and
            re-raised to the caller.
    """
    base_ws_url = _convert_http_to_ws_url(self.server_url)
    ws_url = f"{base_ws_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
    logger.info(f"Connecting to signaling server: {ws_url}")
    # Log network information for debugging
    log_network_info()
    try:
        # Create SSL context based on URL scheme and insecure setting
        if ws_url.startswith("wss://"):
            # For wss://, we need an SSL context
            if self.insecure:
                # Accept self-signed certificates
                ws_ssl = create_ssl_context(insecure=True)
            else:
                # Use default SSL context for secure connections
                ws_ssl = True
        else:
            # For ws://, no SSL context needed
            ws_ssl = None
        logger.info(
            f"Attempting websocket connection to {ws_url} with ssl={ws_ssl}"
        )
        self.websocket = await websockets.connect(
            ws_url,
            ssl=ws_ssl,
            ping_interval=30,  # Send ping every 30 seconds
            ping_timeout=10,  # Wait up to 10 seconds for pong response
            close_timeout=10  # Wait up to 10 seconds for close handshake
        )
        logger.info("Connected to signaling server")
        # Set up local media
        await self._setup_local_media()
        # Set name and join lobby
        name_payload: MessageData = {"name": self.session_name}
        if self.name_password:
            name_payload["password"] = self.name_password
        logger.info(f"Sending set_name: {name_payload}")
        await self._send_message("set_name", name_payload)
        logger.info("Sending join message")
        await self._send_message("join", {})
        # Mark as registered after successful join
        self.is_registered = True
        self.last_registration_check = time.time()
        # Start periodic registration check
        self.registration_check_task = asyncio.create_task(self._periodic_registration_check())
        # Start message handling
        logger.info("Starting message handler loop")
        try:
            await self._handle_messages()
        except Exception as e:
            logger.error(f"Message handling stopped: {e}")
            self.is_registered = False
            raise
        finally:
            # Clean disconnect when exiting
            await self.disconnect()
    except Exception as e:
        logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
        raise
async def _periodic_registration_check(self):
    """Periodically check registration status and re-register if needed.

    Runs until shutdown_requested is set or the task is cancelled. Each
    iteration sleeps for registration_check_interval, probes the server via
    _check_registration_status(), and calls _re_register() on failure.
    Exceptions inside an iteration are logged and swallowed so the watchdog
    keeps running.
    """
    # Add a small random delay to prevent all bots from checking at the same time
    import random
    initial_delay = random.uniform(0, 5)
    await asyncio.sleep(initial_delay)
    while not self.shutdown_requested:
        try:
            await asyncio.sleep(self.registration_check_interval)
            # Check shutdown flag again after sleep
            if self.shutdown_requested:
                break
            current_time = time.time()
            # Skip this round if another code path refreshed the check recently.
            if current_time - self.last_registration_check < self.registration_check_interval:
                continue
            # Check if we're still connected and registered
            if not await self._check_registration_status():
                logger.warning("Registration check failed, attempting to re-register")
                await self._re_register()
            self.last_registration_check = current_time
        except asyncio.CancelledError:
            logger.info("Registration check task cancelled")
            break
        except Exception as e:
            logger.error(f"Error in periodic registration check: {e}", exc_info=True)
            # Continue checking even if one iteration fails
            continue
    logger.info("Registration check loop ended")
async def _check_registration_status(self) -> bool:
    """Probe the signaling connection by sending a lightweight status message.

    Returns:
        True when a status_check message was sent successfully; False when
        the websocket is missing, closed, or the send fails for any reason.
        A closed connection also clears self.websocket so the re-register
        path knows to reconnect.
    """
    try:
        # No socket at all means we are definitely not registered.
        if not self.websocket:
            logger.warning("WebSocket connection lost")
            return False
        try:
            await self._send_message("status_check", {"timestamp": time.time()})
        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"WebSocket connection closed during status check: {e}")
            # Mark websocket as None to ensure reconnection attempts
            self.websocket = None
            return False
        except Exception as e:
            logger.warning(f"Failed to send status check: {e}")
            return False
        logger.debug("Registration status check sent")
        return True
    except Exception as e:
        logger.error(f"Error checking registration status: {e}")
        return False
async def _re_register(self):
    """Attempt to re-register with the server.

    Reconnects the websocket if it was lost, then replays the set_name and
    join handshake. Any failure aborts quietly (logged, not raised); the
    periodic registration check retries on its next interval.
    """
    try:
        logger.info("Attempting to re-register with server")
        # Mark as not registered during re-registration attempt
        self.is_registered = False
        # Try to reconnect the websocket if it's lost
        if not self.websocket:
            logger.info("WebSocket lost, attempting to reconnect")
            await self._reconnect_websocket()
        # Double-check that websocket is connected after reconnect attempt
        if not self.websocket:
            logger.error("Failed to establish WebSocket connection for re-registration")
            return
        # Re-send name and join messages with retry logic
        name_payload: MessageData = {"name": self.session_name}
        if self.name_password:
            name_payload["password"] = self.name_password
        logger.info("Re-sending set_name message")
        try:
            await self._send_message("set_name", name_payload)
        except websockets.exceptions.ConnectionClosed:
            logger.warning("Connection closed while sending set_name, aborting re-registration")
            return
        except Exception as e:
            logger.error(f"Failed to send set_name during re-registration: {e}")
            return
        logger.info("Re-sending join message")
        try:
            await self._send_message("join", {})
        except websockets.exceptions.ConnectionClosed:
            logger.warning("Connection closed while sending join, aborting re-registration")
            return
        except Exception as e:
            logger.error(f"Failed to send join during re-registration: {e}")
            return
        # Mark as registered after successful re-join
        self.is_registered = True
        self.last_registration_check = time.time()
        logger.info("Successfully re-registered with server")
    except Exception as e:
        logger.error(f"Failed to re-register with server: {e}", exc_info=True)
        # Will try again on next check interval
async def _reconnect_websocket(self):
    """Reconnect the WebSocket connection.

    Closes any stale socket (best-effort), then re-dials the lobby endpoint
    with the same SSL/scheme handling as connect().

    Raises:
        Exception: Re-raised after logging when the new connection fails.
    """
    try:
        # Close existing connection if any
        if self.websocket:
            try:
                ws = cast(WebSocketProtocol, self.websocket)
                await ws.close()
            except Exception:
                pass
            self.websocket = None
        # Reconnect
        base_ws_url = _convert_http_to_ws_url(self.server_url)
        ws_url = f"{base_ws_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
        # Create SSL context based on URL scheme and insecure setting
        if ws_url.startswith("wss://"):
            # For wss://, we need an SSL context
            if self.insecure:
                # Accept self-signed certificates
                ws_ssl = create_ssl_context(insecure=True)
            else:
                # Use default SSL context for secure connections
                ws_ssl = True
        else:
            # For ws://, no SSL context needed
            ws_ssl = None
        logger.info(f"Reconnecting to signaling server: {ws_url}")
        self.websocket = await websockets.connect(
            ws_url,
            ssl=ws_ssl,
            ping_interval=30,  # Send ping every 30 seconds
            ping_timeout=10,  # Wait up to 10 seconds for pong response
            close_timeout=10  # Wait up to 10 seconds for close handshake
        )
        logger.info("Successfully reconnected to signaling server")
    except Exception as e:
        logger.error(f"Failed to reconnect websocket: {e}", exc_info=True)
        raise
async def disconnect(self):
    """Disconnect from signaling server and cleanup.

    Cancels the registration-check task, closes the websocket, closes all
    peer connections, stops local tracks, and clears the registered flag.
    Every step is best-effort: failures are logged, never raised.
    """
    # Cancel the registration check task
    if self.registration_check_task and not self.registration_check_task.done():
        self.registration_check_task.cancel()
        try:
            # Only await if we're in the same event loop
            current_loop = asyncio.get_running_loop()
            task_loop = self.registration_check_task.get_loop()
            if current_loop == task_loop:
                await self.registration_check_task
            else:
                logger.warning("Registration check task in different event loop, skipping await")
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.warning(f"Error cancelling registration check task: {e}")
        self.registration_check_task = None
    if self.websocket:
        ws = cast(WebSocketProtocol, self.websocket)
        try:
            await ws.close()
        except Exception as e:
            logger.warning(f"Error closing websocket: {e}")
    # Close all peer connections
    for pc in self.peer_connections.values():
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"Error closing peer connection: {e}")
    # Stop local tracks
    for track in self.local_tracks.values():
        try:
            track.stop()
        except Exception as e:
            logger.warning(f"Error stopping track: {e}")
    # Reset registration status
    self.is_registered = False
    logger.info("Disconnected from signaling server")
def request_shutdown(self):
    """Flag the client for graceful shutdown.

    Safe to call from any thread: it only flips a boolean that the async
    loops poll; no awaitables or event-loop objects are touched here.
    """
    self.shutdown_requested = True
    logger.info("Shutdown requested for WebRTC signaling client")
async def send_chat_message(self, message: Union[str, ChatMessageModel]):
    """Send a chat message to the lobby.

    Args:
        message: Either raw text (wrapped in a new ChatMessageModel stamped
            with this client's identity) or a pre-built ChatMessageModel.

    No-ops with a warning when not registered or when given empty/blank
    text. Send failures are logged, not raised.
    """
    if not self.is_registered:
        logger.warning("Cannot send chat message: not registered")
        return
    if isinstance(message, str):
        if not message.strip():
            logger.warning("Cannot send empty chat message")
            return
        # Create ChatMessageModel from string
        chat_message = ChatMessageModel(
            id=secrets.token_hex(8),
            message=message.strip(),
            sender_name=self.session_name,
            sender_session_id=self.session_id,
            timestamp=time.time(),
            lobby_id=self.lobby_id,
        )
        message_data = chat_message.model_dump()
    else:
        # ChatMessageModel
        message_data = message.model_dump()
    try:
        await self._send_message("send_chat_message", {"message": message_data})
        logger.info(f"Sent chat message: {str(message)[:50]}...")
    except Exception as e:
        logger.error(f"Failed to send chat message: {e}", exc_info=True)
def create_chat_message(self, message: str, message_id: Optional[str] = None) -> ChatMessageModel:
    """Build a ChatMessageModel stamped with this client's identity.

    Args:
        message: Chat text to wrap.
        message_id: Optional explicit id; a random hex token is generated
            when omitted (or falsy).

    Returns:
        A ChatMessageModel carrying this client's name, session id, lobby
        id, and the current timestamp.
    """
    chat_id = message_id or secrets.token_hex(8)
    return ChatMessageModel(
        id=chat_id,
        message=message,
        sender_name=self.session_name,
        sender_session_id=self.session_id,
        timestamp=time.time(),
        lobby_id=self.lobby_id,
    )
async def _setup_local_media(self):
    """Create local media tracks and register the local peer entry.

    Uses the bot-provided create_tracks factory when available, otherwise
    falls back to synthetic tracks. Track-creation failures are logged and
    swallowed so the client can still join (e.g. as a receive-only bot).
    """
    # If a bot provided a create_tracks callable, use it to create tracks.
    # Otherwise, use default synthetic tracks.
    try:
        if self.create_tracks:
            tracks = self.create_tracks(self.session_name)
            self.local_tracks.update(tracks)
        else:
            # Default fallback to synthetic tracks
            tracks = create_synthetic_tracks(self.session_name)
            self.local_tracks.update(tracks)
    except Exception:
        logger.exception("Failed to create local tracks using bot factory")
    # Add local peer to peers dict
    local_peer = Peer(
        session_id=self.session_id,
        peer_name=self.session_name,
        local=True,
        attributes={"tracks": self.local_tracks},
    )
    self.peers[self.session_id] = local_peer
    logger.info("Local media tracks created")
async def _send_message(
    self, message_type: str, data: Optional[MessageData] = None
):
    """Serialize and send a typed message to the signaling server.

    Args:
        message_type: Protocol message type (e.g. "join", "set_name").
        data: Optional payload; defaults to an empty dict so the wire shape
            always matches WebSocketMessageModel.

    Raises:
        Exception: Send failures are logged and re-raised. When the
            connection is closed, self.websocket is dropped and
            is_registered cleared so the periodic check reconnects.
    """
    if not self.websocket:
        logger.error("No websocket connection")
        return
    ws = cast(WebSocketProtocol, self.websocket)
    # Build message with explicit type to avoid type narrowing
    # Always include data field to match WebSocketMessageModel
    # Should be a ChatMessageModel, not custom dict
    message: dict[str, object] = {
        "type": message_type,
        "data": data if data is not None else {}
    }
    try:
        logger.debug(f"_send_message: Sending {message_type} with data: {data}")
        await ws.send(json.dumps(message))
        logger.debug(f"_send_message: Sent message: {message_type}")
    except websockets.exceptions.ConnectionClosed as e:
        logger.error(
            f"_send_message: WebSocket connection closed while sending {message_type}: {e}"
        )
        # Mark websocket as None to trigger reconnection
        self.websocket = None
        self.is_registered = False
        raise
    except Exception as e:
        logger.error(
            f"_send_message: Failed to send {message_type}: {e}", exc_info=True
        )
        raise
async def _handle_messages(self):
    """Receive and dispatch signaling messages until the socket closes.

    Each incoming message is JSON-decoded and handed to _process_message
    with a 30-second timeout so one stuck handler cannot stall the receive
    loop. Parse and handler errors are logged and skipped. Sets
    is_registered = False when the connection drops or the loop fails.
    (Also removes a previously duplicated "Exited message loop" log line.)
    """
    logger.info("_handle_messages: Starting message handling loop")
    message_count = 0
    try:
        ws = cast(WebSocketProtocol, self.websocket)
        logger.info("_handle_messages: WebSocket cast successful, entering message loop")
        async for message in ws:
            message_count += 1
            logger.debug(f"_handle_messages: [#{message_count}] Entering message processing")
            # Honor graceful-shutdown requests between messages.
            if self.shutdown_requested:
                logger.info("Shutdown requested, breaking message loop")
                break
            logger.debug(f"_handle_messages: [#{message_count}] Received raw message: {message}")
            try:
                data = cast(MessageData, json.loads(message))
                logger.debug(f"_handle_messages: [#{message_count}] Successfully parsed message data: {data}")
            except Exception as e:
                logger.error(
                    f"_handle_messages: [#{message_count}] Failed to parse message: {e}", exc_info=True
                )
                continue
            logger.debug(f"_handle_messages: [#{message_count}] About to call _process_message")
            # Bound message processing so a slow handler can't block the loop.
            processing_timeout = 30.0  # seconds
            try:
                await asyncio.wait_for(self._process_message(data), timeout=processing_timeout)
                logger.debug(f"_handle_messages: [#{message_count}] Completed _process_message")
            except asyncio.TimeoutError:
                logger.error(f"_handle_messages: [#{message_count}] Message processing timeout ({processing_timeout}s) for message type: {data.get('type', 'unknown')}")
            except Exception as e:
                logger.error(f"_handle_messages: [#{message_count}] Error in _process_message: {e}", exc_info=True)
        logger.info(f"_handle_messages: Exited message loop (async for ended) after processing {message_count} messages")
    except websockets.exceptions.ConnectionClosed as e:
        logger.warning(f"WebSocket connection closed: {e}")
        self.is_registered = False
        # The periodic registration check will detect this and attempt reconnection
    except Exception as e:
        logger.error(f"Error handling messages: {e}", exc_info=True)
        self.is_registered = False
async def _process_message(self, message: MessageData):
    """Validate and dispatch a single signaling message by its type.

    Server "error" messages are handled before model validation because
    they do not match WebSocketMessageModel. Every other known type is
    validated against its pydantic payload model; invalid payloads are
    logged and dropped (never raised) so one malformed message cannot
    break the receive loop.

    Fixes applied: the original had a second, unreachable set of elif
    branches for addPeer/sessionDescription/iceCandidate (dead code — the
    identical earlier conditions always matched first) and a duplicated
    self-message check for chat_message; their useful info logs have been
    merged into the single reachable branch for each type.
    """
    logger.debug(f"_process_message: ENTRY - Processing message: {message}")
    try:
        # Handle error messages specially since they have a different structure
        if message.get("type") == "error" and "error" in message:
            error_msg = message.get("error", "Unknown error")
            logger.error(f"Received error from signaling server: {error_msg}")
            return
        # Validate the base message structure for non-error messages
        validated_message = WebSocketMessageModel.model_validate(message)
        msg_type = validated_message.type
        data = validated_message.data
        logger.debug(f"_process_message: Validated message - type: {msg_type}, data: {data}")
    except ValidationError as e:
        logger.error(f"Invalid message structure: {e}", exc_info=True)
        return
    logger.debug(
        f"_process_message: Received message type: {msg_type} with data: {data}"
    )
    if msg_type == "addPeer":
        if data is None:
            logger.error("addPeer message missing required data")
            return
        try:
            validated = AddPeerModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid addPeer payload: {e}", exc_info=True)
            return
        logger.info(f"Received addPeer for peer {validated.peer_name} (peer_id={validated.peer_id}) should_create_offer={validated.should_create_offer}")
        logger.debug(f"addPeer payload: {validated}")
        await self._handle_add_peer(validated)
    elif msg_type == "removePeer":
        try:
            validated = RemovePeerModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid removePeer payload: {e}", exc_info=True)
            return
        await self._handle_remove_peer(validated)
    elif msg_type == "sessionDescription":
        try:
            validated = SessionDescriptionModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid sessionDescription payload: {e}", exc_info=True)
            return
        logger.info(f"Received sessionDescription from {validated.peer_name} (peer_id={validated.peer_id})")
        await self._handle_session_description(validated)
    elif msg_type == "iceCandidate":
        try:
            validated = IceCandidateModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid iceCandidate payload: {e}", exc_info=True)
            return
        logger.info(f"Received iceCandidate from {validated.peer_name} (peer_id={validated.peer_id})")
        logger.debug(f"iceCandidate payload: {validated}")
        await self._handle_ice_candidate(validated)
    elif msg_type == "join_status":
        try:
            validated = JoinStatusModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid join_status payload: {e}", exc_info=True)
            return
        logger.info(f"Join status: {validated.status} - {validated.message}")
    elif msg_type == "user_joined":
        try:
            validated = UserJoinedModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid user_joined payload: {e}", exc_info=True)
            return
        logger.info(
            f"User joined: {validated.name} (session: {validated.session_id})"
        )
        logger.debug(f"user_joined payload: {validated}")
    elif msg_type == "lobby_state":
        try:
            validated = LobbyStateModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid lobby_state payload: {e}", exc_info=True)
            return
        participants = validated.participants
        logger.info(f"Lobby state updated: {len(participants)} participants")
    elif msg_type == "update_name":
        try:
            validated = UpdateNameModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid update payload: {e}", exc_info=True)
            return
        logger.info(f"Received update message: {validated}")
    elif msg_type == "status_check":
        # Connection-health probe; receipt alone confirms the link works.
        logger.debug(f"Received status check message with data: {data}")
    elif msg_type == "status_ok":
        # Server acknowledgement that the connection is healthy.
        logger.debug(f"Received status_ok from server with data: {data}")
    elif msg_type == "status_response":
        # Response to our status_check; confirms health and carries session info.
        logger.debug(f"Received status_response from server with data: {data}")
    elif msg_type == "chat_message":
        try:
            validated = ChatMessageModel.model_validate(data)
        except ValidationError as e:
            logger.error(f"Invalid chat_message payload: {e}", exc_info=True)
            return
        # Ignore our own messages echoed back by the server.
        if validated.sender_session_id == self.session_id:
            logger.debug("Ignoring chat message from ourselves")
            return
        logger.info(f"Received chat message from {validated.sender_name}: {validated.message[:50]}...")
        # Call the callback if it's set
        if self.on_chat_message_received:
            try:
                await self.on_chat_message_received(validated)
            except Exception as e:
                logger.error(f"Error in chat message callback: {e}", exc_info=True)
    elif msg_type == "error":
        logger.error(f"Received error from signaling server: {data}")
    else:
        logger.info(f"Unhandled message type: {msg_type} with data: {data}")
# Continue with more methods in the next part...
async def _handle_add_peer(self, data: AddPeerModel):
    """Handle addPeer message - create new peer connection.

    Creates (or reuses) an RTCPeerConnection for the peer, wires up ICE /
    connection-state / track callbacks, attaches local tracks (or adds
    receive-only transceivers when there are none), and, when the server
    instructs via should_create_offer, kicks off offer creation.

    Args:
        data: Validated addPeer payload (peer_id, peer_name, should_create_offer).
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    should_create_offer = data.should_create_offer
    logger.info(
        f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})"
    )
    logger.debug(
        f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}"
    )
    # Check if peer already exists
    if peer_id in self.peer_connections:
        pc = self.peer_connections[peer_id]
        logger.debug(
            f"_handle_add_peer: Existing connection state: {pc.connectionState}"
        )
        if pc.connectionState in ["new", "connected", "connecting"]:
            logger.info(f"Peer connection already exists for {peer_name}")
            return
        else:
            # Clean up stale connection
            logger.debug(
                f"_handle_add_peer: Closing stale connection for {peer_name}"
            )
            await pc.close()
            del self.peer_connections[peer_id]
    # Create new peer
    peer = Peer(session_id=peer_id, peer_name=peer_name, local=False)
    self.peers[peer_id] = peer
    # Create RTCPeerConnection
    # NOTE(review): STUN/TURN server addresses and TURN credentials are
    # hard-coded here; consider moving them to configuration (plaintext
    # credentials in source are a security concern).
    config = RTCConfiguration(
        iceServers=[
            RTCIceServer(urls="stun:ketrenos.com:3478"),
            RTCIceServer(
                urls="turns:ketrenos.com:5349",
                username="ketra",
                credential="ketran",
            ),
            # Add Google's public STUN server as fallback
            RTCIceServer(urls="stun:stun.l.google.com:19302"),
        ],
    )
    logger.debug(
        f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}"
    )
    pc = RTCPeerConnection(configuration=config)

    # Add ICE gathering state change handler
    def on_ice_gathering_state_change() -> None:
        logger.info(f"ICE gathering state: {pc.iceGatheringState}")
        if pc.iceGatheringState == "complete":
            logger.info(
                f"ICE gathering complete for {peer_name} - checking if candidates were generated..."
            )
    pc.on("icegatheringstatechange")(on_ice_gathering_state_change)

    # Add connection state change handler
    def on_connection_state_change() -> None:
        logger.info(f"Connection state: {pc.connectionState}")
    pc.on("connectionstatechange")(on_connection_state_change)

    self.peer_connections[peer_id] = pc
    peer.connection = pc

    # Set up event handlers
    def on_track(track: MediaStreamTrack) -> None:
        # Stash the incoming track on the peer and notify the bot callback.
        logger.info(f"Received {track.kind} track from {peer_name}")
        logger.info(f"on_track: {track.kind} from {peer_name}, track={track}")
        peer.attributes[f"{track.kind}_track"] = track
        if self.on_track_received:
            asyncio.ensure_future(self.on_track_received(peer, track))
    pc.on("track")(on_track)

    def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None:
        # Relay each locally gathered ICE candidate to the peer via the server.
        logger.info(f"on_ice_candidate: {candidate}")
        logger.info(
            f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}"
        )
        if not candidate:
            logger.info(
                f"on_ice_candidate: End of candidates signal for {peer_name}"
            )
            return
        # Raw SDP fragment for the candidate
        raw = getattr(candidate, "candidate", None)
        # Try to infer candidate type from the SDP string (host/srflx/relay/prflx)
        def _parse_type(s: Optional[str]) -> str:
            if not s:
                return "eoc"
            m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", s)
            return m.group(1) if m else "unknown"
        cand_type = _parse_type(raw)
        protocol = getattr(candidate, "protocol", "unknown")
        logger.info(
            f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}"
        )
        # Ensure candidate has the proper SDP format
        if raw and not raw.startswith("candidate:"):
            raw = f"candidate:{raw}"
        # Clean up any extra spaces
        if raw:
            raw = raw.replace("candidate: ", "candidate:")
        candidate_model = ICECandidateDictModel(
            candidate=raw,
            sdpMid=getattr(candidate, "sdpMid", None),
            sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None),
        )
        payload_model = IceCandidateModel(
            peer_id=peer_id, peer_name=peer_name, candidate=candidate_model
        )
        logger.info(
            f"on_ice_candidate: Sending relayICECandidate for {peer_name}: candidate='{candidate_model.candidate}' sdpMid={candidate_model.sdpMid} sdpMLineIndex={candidate_model.sdpMLineIndex}"
        )
        asyncio.ensure_future(
            self._send_message("relayICECandidate", payload_model.model_dump())
        )
    pc.on("icecandidate")(on_ice_candidate)

    # Add local tracks with proper transceiver configuration
    for track in self.local_tracks.values():
        logger.debug(
            f"_handle_add_peer: Adding local track {track.kind} to {peer_name}"
        )
        # Add track with explicit transceiver direction to ensure proper SDP generation
        transceiver = pc.addTransceiver(track, direction="sendrecv")
        logger.debug(f"_handle_add_peer: Added transceiver for {track.kind}, mid: {transceiver.mid}")
    # For bots with no local tracks (like whisper), add receive-only transceivers
    # to ensure proper SDP negotiation for incoming audio/video
    if not self.local_tracks:
        logger.info(f"_handle_add_peer: No local tracks for {peer_name}, adding receive-only transceivers")
        # Add receive-only audio transceiver for bots that need to receive audio
        audio_transceiver = pc.addTransceiver("audio", direction="recvonly")
        logger.debug(f"_handle_add_peer: Added receive-only audio transceiver, mid: {audio_transceiver.mid}")
        # Add receive-only video transceiver to handle any incoming video
        video_transceiver = pc.addTransceiver("video", direction="recvonly")
        logger.debug(f"_handle_add_peer: Added receive-only video transceiver, mid: {video_transceiver.mid}")
    # Create offer if instructed by server coordination
    if should_create_offer:
        logger.info(f"Creating offer for {peer_name} as instructed by server")
        await self._create_and_send_offer(peer_id, peer_name, pc)
    if self.on_peer_added:
        await self.on_peer_added(peer)
def _clean_sdp(self, sdp: str) -> str:
"""Clean and validate SDP to ensure proper BUNDLE groups"""
lines = sdp.split('\n')
cleaned_lines: list[str] = []
bundle_mids: list[str] = []
current_mid: str | None = None
for line in lines:
if line.startswith('a=mid:'):
current_mid = line.split(':', 1)[1].strip()
if current_mid: # Only add non-empty MIDs
bundle_mids.append(current_mid)
cleaned_lines.append(line)
elif line.startswith('a=group:BUNDLE'):
# Rebuild BUNDLE group with valid MIDs only
if bundle_mids:
cleaned_lines.append(f'a=group:BUNDLE {" ".join(bundle_mids)}')
# Skip original BUNDLE line to avoid duplicates
else:
cleaned_lines.append(line)
return '\n'.join(cleaned_lines)
async def _create_and_send_offer(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
    """Create and send an offer to a peer.

    Creates an SDP offer, sets it locally (re-setting it if _clean_sdp
    changed the SDP), waits for ICE gathering to finish (bounded by a
    timeout), extracts/relays candidates from the SDP, and relays the
    session description to the peer via the signaling server.

    NOTE(review): is_negotiating[peer_id] is set True here but never reset
    in this method — confirm it is cleared elsewhere (e.g. on answer).
    """
    self.initiated_offer.add(peer_id)
    self.making_offer[peer_id] = True
    self.is_negotiating[peer_id] = True
    try:
        logger.debug(f"_handle_add_peer: Creating offer for {peer_name}")
        offer = await pc.createOffer()
        logger.debug(
            f"_handle_add_peer: Offer created for {peer_name}: {offer}"
        )
        await pc.setLocalDescription(offer)
        logger.debug(f"_handle_add_peer: Local description set for {peer_name}")
        # Clean the SDP to ensure proper BUNDLE groups
        cleaned_sdp = self._clean_sdp(offer.sdp)
        if cleaned_sdp != offer.sdp:
            logger.debug(f"_create_and_send_offer: Cleaned SDP for {peer_name}")
            # Update the offer with cleaned SDP
            offer = RTCSessionDescription(sdp=cleaned_sdp, type=offer.type)
            await pc.setLocalDescription(offer)
        # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
        # Use Method 2: Complete SDP approach to extract ICE candidates
        logger.debug(
            f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}"
        )
        # Add timeout to prevent blocking message processing indefinitely
        ice_gathering_timeout = 10.0  # 10 seconds timeout
        start_time = asyncio.get_event_loop().time()
        # Poll (0.1s granularity) until gathering completes or times out.
        while pc.iceGatheringState != "complete":
            current_time = asyncio.get_event_loop().time()
            if current_time - start_time > ice_gathering_timeout:
                logger.warning(f"ICE gathering timeout ({ice_gathering_timeout}s) for {peer_name}, proceeding anyway")
                break
            await asyncio.sleep(0.1)
        logger.debug(
            f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}"
        )
        await self._extract_and_send_candidates(peer_id, peer_name, pc)
        session_desc_typed = SessionDescriptionTypedModel(
            type=offer.type, sdp=offer.sdp
        )
        session_desc_model = SessionDescriptionModel(
            peer_id=peer_id,
            peer_name=peer_name,
            session_description=session_desc_typed,
        )
        # Debug log the SDP to help diagnose issues
        logger.debug(f"_create_and_send_offer: SDP for {peer_name}:\n{offer.sdp[:500]}...")
        await self._send_message(
            "relaySessionDescription",
            session_desc_model.model_dump(),
        )
        logger.info(f"Offer sent to {peer_name}")
    except Exception as e:
        logger.error(
            f"Failed to create/send offer to {peer_name}: {e}", exc_info=True
        )
    finally:
        self.making_offer[peer_id] = False
async def _extract_and_send_candidates(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
"""Extract ICE candidates from SDP and send them"""
if not pc.localDescription or not pc.localDescription.sdp:
logger.warning(f"_extract_and_send_candidates: No local description for {peer_name}")
return
sdp_lines = pc.localDescription.sdp.split("\n")
# Parse SDP structure to map media sections to their MIDs
media_sections: list[tuple[int, str]] = [] # List of (media_index, mid)
current_media_index = -1
current_mid: str | None = None
# First pass: identify media sections and their MIDs
for line in sdp_lines:
line = line.strip()
if line.startswith("m="): # Media section start
current_media_index += 1
current_mid = None # Reset MID for this section
elif line.startswith("a=mid:"): # Media ID
current_mid = line.split(":", 1)[1].strip()
if current_mid:
media_sections.append((current_media_index, current_mid))
logger.debug(f"_extract_and_send_candidates: Found media sections for {peer_name}: {media_sections}")
# If no MIDs found, fall back to using the transceivers
if not media_sections:
transceivers = pc.getTransceivers()
for i, transceiver in enumerate(transceivers):
if hasattr(transceiver, 'mid') and transceiver.mid:
media_sections.append((i, str(transceiver.mid)))
else:
# Fallback to default MID pattern
media_sections.append((i, str(i)))
logger.debug(f"_extract_and_send_candidates: Using transceiver MIDs for {peer_name}: {media_sections}")
# Second pass: extract candidates and assign them to media sections
current_media_index = -1
current_section_mid: str | None = None
candidates_sent = 0
for line in sdp_lines:
line = line.strip()
if line.startswith("m="): # Media section start
current_media_index += 1
# Find the MID for this media section
current_section_mid = None
for media_idx, mid in media_sections:
if media_idx == current_media_index:
current_section_mid = mid
break
# If no MID found, use the media index as string
if current_section_mid is None:
current_section_mid = str(current_media_index)
elif line.startswith("a=candidate:"):
candidate_sdp = line[2:] # Remove 'a=' prefix, keeping "candidate:..."
# Clean up any extra spaces
if candidate_sdp:
candidate_sdp = candidate_sdp.replace("candidate: ", "candidate:")
# Only send if we have valid MID and media index
if current_section_mid is not None and current_media_index >= 0:
candidate_model = ICECandidateDictModel(
candidate=candidate_sdp,
sdpMid=current_section_mid,
sdpMLineIndex=current_media_index,
)
payload_candidate = IceCandidateModel(
peer_id=peer_id,
peer_name=peer_name,
candidate=candidate_model,
)
logger.debug(
f"_extract_and_send_candidates: Sending ICE candidate for {peer_name} (mid={current_section_mid}, idx={current_media_index}): candidate='{candidate_sdp}'"
)
await self._send_message(
"relayICECandidate", payload_candidate.model_dump()
)
candidates_sent += 1
else:
logger.warning(f"_extract_and_send_candidates: Skipping candidate with invalid MID/index for {peer_name}")
# Send end-of-candidates signal only if we have valid media sections
if media_sections:
# Use the first media section for end-of-candidates
first_media_idx, first_mid = media_sections[0]
end_candidate_model = ICECandidateDictModel(
candidate="",
sdpMid=first_mid,
sdpMLineIndex=first_media_idx,
)
payload_end = IceCandidateModel(
peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
)
logger.debug(
f"_extract_and_send_candidates: Sending end-of-candidates signal for {peer_name} (mid={first_mid}, idx={first_media_idx})"
)
await self._send_message("relayICECandidate", payload_end.model_dump())
logger.debug(
f"_extract_and_send_candidates: Sent {candidates_sent} ICE candidates to {peer_name}"
)
async def _handle_remove_peer(self, data: RemovePeerModel):
"""Handle removePeer message"""
peer_id = data.peer_id
peer_name = data.peer_name
logger.info(f"Removing peer: {peer_name}")
# Close peer connection
if peer_id in self.peer_connections:
pc = self.peer_connections[peer_id]
await pc.close()
del self.peer_connections[peer_id]
# Clean up state
self.is_negotiating.pop(peer_id, None)
self.making_offer.pop(peer_id, None)
self.initiated_offer.discard(peer_id)
self.pending_ice_candidates.pop(peer_id, None)
# Remove peer
peer = self.peers.pop(peer_id, None)
if peer and self.on_peer_removed:
await self.on_peer_removed(peer)
async def _handle_session_description(self, data: SessionDescriptionModel) -> None:
    """Handle a sessionDescription (offer/answer) relayed by the signaling server.

    Applies the remote description, replays any ICE candidates that were
    queued while the remote description was missing, and — when the message
    was an offer — creates and sends an answer. Implements the "polite peer"
    half of perfect negotiation: on an offer collision where we initiated
    our own offer, the bot yields and lets the remote (human) side drive.
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    session_description = data.session_description.model_dump()
    logger.info(f"Received {session_description['type']} from {peer_name}")
    # Debug log the received SDP to help diagnose issues
    logger.debug(f"Received SDP from {peer_name}:\n{session_description['sdp'][:500]}...")
    pc = self.peer_connections.get(peer_id)
    if not pc:
        # Connection should have been created by addPeer; bail out if it wasn't.
        logger.error(f"No peer connection for {peer_name}")
        return
    desc = RTCSessionDescription(
        sdp=session_description["sdp"], type=session_description["type"]
    )
    # Handle offer collision (polite peer pattern): a collision means an offer
    # arrived while we were mid-offer or not in a stable signaling state.
    making_offer = self.making_offer.get(peer_id, False)
    offer_collision = desc.type == "offer" and (
        making_offer or pc.signalingState != "stable"
    )
    we_initiated = peer_id in self.initiated_offer
    # For bots, be more polite - always yield to human users in collision
    # Bots should generally be the polite peer
    ignore_offer = offer_collision and we_initiated
    if ignore_offer:
        logger.info(f"Ignoring offer from {peer_name} due to collision (bot being polite)")
        # Reset our offer state to allow the remote offer to proceed
        if peer_id in self.initiated_offer:
            self.initiated_offer.remove(peer_id)
        self.making_offer[peer_id] = False
        self.is_negotiating[peer_id] = False
        # Wait a bit and then retry if connection isn't established
        # NOTE(review): this task only logs; the remote peer is expected to
        # re-offer, so no local renegotiation is started here.
        async def retry_connection():
            await asyncio.sleep(1.0)  # Wait 1 second
            if pc.connectionState not in ["connected", "closed", "failed"]:
                logger.info(f"Retrying connection setup for {peer_name} after collision")
                # Don't create offer, let the remote peer drive
        asyncio.create_task(retry_connection())
        return
    try:
        await pc.setRemoteDescription(desc)
        self.is_negotiating[peer_id] = False
        logger.info(f"Remote description set for {peer_name}")
        # Process queued ICE candidates (those that arrived before the remote
        # description was available — see _handle_ice_candidate).
        pending_candidates = self.pending_ice_candidates.pop(peer_id, [])
        for candidate_data in pending_candidates:
            # candidate_data is an ICECandidateDictModel Pydantic model
            cand = candidate_data.candidate
            # handle end-of-candidates marker (empty candidate string)
            if not cand:
                await pc.addIceCandidate(None)
                logger.info(f"Added queued end-of-candidates for {peer_name}")
                continue
            # cand may be the full "candidate:..." string or the inner SDP part
            if cand and cand.startswith("candidate:"):
                sdp_part = cand.split(":", 1)[1]
            else:
                sdp_part = cand
            try:
                # Parse the SDP fragment and re-attach the MID/index metadata
                # that candidate_from_sdp does not carry.
                rtc_candidate = candidate_from_sdp(sdp_part)
                rtc_candidate.sdpMid = candidate_data.sdpMid
                rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
                await pc.addIceCandidate(rtc_candidate)
                logger.info(f"Added queued ICE candidate for {peer_name}")
            except Exception as e:
                # A single bad candidate must not abort the rest of the queue.
                logger.error(
                    f"Failed to add queued ICE candidate for {peer_name}: {e}"
                )
    except Exception as e:
        logger.error(f"Failed to set remote description for {peer_name}: {e}")
        return
    # Create answer if this was an offer
    if session_description["type"] == "offer":
        await self._create_and_send_answer(peer_id, peer_name, pc)
async def _create_and_send_answer(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
"""Create and send an answer to a peer"""
try:
answer = await pc.createAnswer()
# Clean the SDP to ensure proper BUNDLE groups
cleaned_sdp = self._clean_sdp(answer.sdp)
if cleaned_sdp != answer.sdp:
logger.debug(f"_create_and_send_answer: Cleaned SDP for {peer_name}")
# Update the answer with cleaned SDP
answer = RTCSessionDescription(sdp=cleaned_sdp, type=answer.type)
await pc.setLocalDescription(answer)
# WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
# Use Method 2: Complete SDP approach to extract ICE candidates
logger.debug(
f"_create_and_send_answer: Waiting for ICE gathering to complete for {peer_name} (answer)"
)
# Add timeout to prevent blocking message processing indefinitely
ice_gathering_timeout = 10.0 # 10 seconds timeout
start_time = asyncio.get_event_loop().time()
while pc.iceGatheringState != "complete":
current_time = asyncio.get_event_loop().time()
if current_time - start_time > ice_gathering_timeout:
logger.warning(f"ICE gathering timeout ({ice_gathering_timeout}s) for {peer_name} (answer), proceeding anyway")
break
await asyncio.sleep(0.1)
logger.debug(
f"_create_and_send_answer: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)"
)
await self._extract_and_send_candidates(peer_id, peer_name, pc)
session_desc_typed = SessionDescriptionTypedModel(
type=answer.type, sdp=answer.sdp
)
session_desc_model = SessionDescriptionModel(
peer_id=peer_id,
peer_name=peer_name,
session_description=session_desc_typed,
)
# Debug log the SDP to help diagnose issues
logger.debug(f"_create_and_send_answer: SDP for {peer_name}:\n{answer.sdp[:500]}...")
await self._send_message(
"relaySessionDescription",
session_desc_model.model_dump(),
)
logger.info(f"Answer sent to {peer_name}")
except Exception as e:
logger.error(f"Failed to create/send answer to {peer_name}: {e}")
async def _handle_ice_candidate(self, data: IceCandidateModel):
"""Handle iceCandidate message"""
peer_id = data.peer_id
peer_name = data.peer_name
candidate_data = data.candidate
logger.info(f"Received ICE candidate from {peer_name}")
pc = self.peer_connections.get(peer_id)
if not pc:
logger.error(f"No peer connection for {peer_name}")
return
# Queue candidate if remote description not set
if not pc.remoteDescription:
logger.info(
f"Remote description not set, queuing ICE candidate for {peer_name}"
)
if peer_id not in self.pending_ice_candidates:
self.pending_ice_candidates[peer_id] = []
# candidate_data is an ICECandidateDictModel Pydantic model
self.pending_ice_candidates[peer_id].append(candidate_data)
return
try:
cand = candidate_data.candidate
if not cand:
# end-of-candidates
await pc.addIceCandidate(None)
logger.info(f"End-of-candidates added for {peer_name}")
return
if cand and cand.startswith("candidate:"):
sdp_part = cand.split(":", 1)[1]
else:
sdp_part = cand
# Detect type for logging
try:
m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part)
cand_type = m.group(1) if m else "unknown"
except Exception:
cand_type = "unknown"
try:
rtc_candidate = candidate_from_sdp(sdp_part)
rtc_candidate.sdpMid = candidate_data.sdpMid
rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
# aiortc expects an object with attributes (RTCIceCandidate)
await pc.addIceCandidate(rtc_candidate)
logger.info(f"ICE candidate added for {peer_name}: type={cand_type}")
except Exception as e:
logger.error(
f"Failed to add ICE candidate for {peer_name}: type={cand_type} error={e} sdp='{sdp_part}'",
exc_info=True,
)
except Exception as e:
logger.error(
f"Unexpected error handling ICE candidate for {peer_name}: {e}",
exc_info=True,
)