""" WebRTC Media Agent for Python This module provides synthetic audio/video track creation and WebRTC signaling server communication, ported from the JavaScript MediaControl implementation. """ from __future__ import annotations import asyncio import json import websockets import numpy as np import cv2 import fractions from typing import ( Dict, Optional, Callable, Awaitable, TypedDict, Protocol, AsyncIterator, cast, ) from dataclasses import dataclass, field from pydantic import ValidationError # types.SimpleNamespace removed — not used anymore after parsing candidates via aiortc.sdp import argparse import urllib.request import urllib.error import urllib.parse import ssl import sys import os # Import shared models sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from shared.models import ( SessionModel, LobbyCreateResponse, WebSocketMessageModel, JoinStatusModel, UserJoinedModel, ParticipantModel, LobbyStateModel, UpdateModel, ICECandidateDictModel, AddPeerModel, RemovePeerModel, SessionDescriptionTypedModel, SessionDescriptionModel, IceCandidateModel, ) from aiortc import ( RTCPeerConnection, RTCSessionDescription, RTCIceCandidate, MediaStreamTrack, ) from av import VideoFrame, AudioFrame import time from logger import logger # Defensive monkeypatch: aioice Transaction.__retry may run after the # underlying datagram transport or loop was torn down which results in # AttributeError being raised and flooding logs. Wrap the original # implementation to catch and suppress AttributeError while preserving # other exceptions. This is a temporary mitigation to keep logs readable # while we investigate/upstream a proper fix or upgrade aioice. try: import aioice.stun as _aioice_stun # type: ignore # The method is defined with a double-underscore name (__retry) which # gets name-mangled. Detect the actual attribute name robustly. retry_attr_name = None for name in dir(_aioice_stun.Transaction): if name.endswith("retry"): obj = getattr(_aioice_stun.Transaction, name) if callable(obj): retry_attr_name = name _orig_retry = obj break if retry_attr_name is not None: # Simple in-process dedupe cache so we only log the same AttributeError # once per interval. This prevents flooding the logs when many # transactions race to run after shutdown. _MONKEYPATCH_LOG_CACHE: dict[str, float] = {} _MONKEYPATCH_LOG_SUPPRESSION_INTERVAL = 5.0 def _should_log_once(key: str) -> bool: now = time.time() last = _MONKEYPATCH_LOG_CACHE.get(key) if last is None or (now - last) > _MONKEYPATCH_LOG_SUPPRESSION_INTERVAL: _MONKEYPATCH_LOG_CACHE[key] = now return True return False def _safe_transaction_retry(self, *args, **kwargs): # type: ignore try: return _orig_retry(self, *args, **kwargs) # type: ignore except AttributeError as e: # type: ignore # Transport or event-loop already closed; log once per key key = f"Transaction.{retry_attr_name}:{e}" if _should_log_once(key): logger.warning( "aioice Transaction.%s AttributeError suppressed: %s", retry_attr_name, e, ) except Exception: # type: ignore # Preserve visibility for other unexpected exceptions logger.exception( "aioice Transaction.%s raised an unexpected exception", retry_attr_name, ) setattr(_aioice_stun.Transaction, retry_attr_name, _safe_transaction_retry) # type: ignore logger.info("Applied safe aioice Transaction.%s monkeypatch", retry_attr_name) else: logger.warning("aioice Transaction.__retry not found; skipping monkeypatch") except Exception as e: logger.exception("Failed to apply aioice Transaction.__retry monkeypatch: %s", e) # Additional defensive patch: wrap the protocol-level send_stun implementation # (e.g. StunProtocol.send_stun) which ultimately calls the datagram transport's # sendto. If the transport or its loop is already torn down, sendto can raise # AttributeError which then triggers asyncio's fatal error path (calling a None # loop). Wrapping here prevents the flood of selector_events/_fatal_error # AttributeError traces. try: import aioice.ice as _aioice_ice # type: ignore # Prefer to patch StunProtocol.send_stun which is used by the ICE code. send_attr_name = None if hasattr(_aioice_ice, "StunProtocol"): proto_cls = getattr(_aioice_ice, "StunProtocol") for name in dir(proto_cls): if name.endswith("send_stun"): attr = getattr(proto_cls, name) if callable(attr): send_attr_name = name _orig_send_stun = attr break if send_attr_name is not None: def _safe_send_stun(self, message, addr): # type: ignore try: return _orig_send_stun(self, message, addr) # type: ignore except AttributeError as e: # type: ignore # Likely transport._sock or transport._loop is None; log once key = f"StunProtocol.{send_attr_name}:{e}" if _should_log_once(key): logger.warning( "aioice StunProtocol.%s AttributeError suppressed: %s", send_attr_name, e, ) except Exception: # type: ignore logger.exception( "aioice StunProtocol.%s raised unexpected exception", send_attr_name ) setattr(proto_cls, send_attr_name, _safe_send_stun) # type: ignore logger.info("Applied safe aioice StunProtocol.%s monkeypatch", send_attr_name) else: logger.warning("aioice StunProtocol.send_stun not found; skipping monkeypatch") except Exception as e: logger.exception("Failed to apply aioice StunProtocol.send_stun monkeypatch: %s", e) # TypedDict for ICE candidate payloads received from signalling class ICECandidateDict(TypedDict, total=False): candidate: str sdpMid: Optional[str] sdpMLineIndex: Optional[int] # Generic message payload type MessageData = dict[str, object] # Message TypedDicts for signaling payloads class BaseMessage(TypedDict, total=False): type: str data: object class AddPeerPayload(TypedDict): peer_id: str peer_name: str should_create_offer: bool class RemovePeerPayload(TypedDict): peer_id: str peer_name: str class SessionDescriptionTyped(TypedDict): type: str sdp: str class SessionDescriptionPayload(TypedDict): peer_id: str peer_name: str session_description: SessionDescriptionTyped class IceCandidatePayload(TypedDict): peer_id: str peer_name: str candidate: ICECandidateDict class WebSocketProtocol(Protocol): def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ... def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ... def __aiter__(self) -> AsyncIterator[str]: ... def _default_attributes() -> Dict[str, object]: return {} @dataclass class Peer: """Represents a WebRTC peer in the session""" session_id: str peer_name: str # Generic attributes bag. Values can be tracks or simple metadata. attributes: Dict[str, object] = field(default_factory=_default_attributes) muted: bool = False video_on: bool = True local: bool = False dead: bool = False connection: Optional[RTCPeerConnection] = None class AnimatedVideoTrack(MediaStreamTrack): async def next_timestamp(self): # Returns (pts, time_base) for 15 FPS video pts = int(self.frame_count * (1 / 15) * 90000) time_base = 1 / 90000 return pts, time_base """ Synthetic video track that generates animated content with a bouncing ball. Ported from JavaScript createAnimatedVideoTrack function. """ kind = "video" def __init__(self, width: int = 320, height: int = 240, name: str = ""): super().__init__() self.width = width self.height = height self.name = name # Generate color from name hash (similar to JavaScript nameToColor) self.ball_color = ( self._name_to_color(name) if name else (0, 255, 136) ) # Default green # Ball properties self.ball = { "x": width / 2, "y": height / 2, "radius": min(width, height) * 0.06, "dx": 3.0, "dy": 2.0, } self.frame_count = 0 self._start_time = time.time() def _name_to_color(self, name: str) -> tuple[int, int, int]: """Convert name to HSL color, then to RGB tuple""" # Simple hash function (djb2) hash_value = 5381 for char in name: hash_value = ((hash_value << 5) + hash_value + ord(char)) & 0xFFFFFFFF # Generate HSL color from hash hue = abs(hash_value) % 360 sat = 60 + (abs(hash_value) % 30) # 60-89% light = 45 + (abs(hash_value) % 30) # 45-74% # Convert HSL to RGB h = hue / 360.0 s = sat / 100.0 lightness = light / 100.0 c = (1 - abs(2 * lightness - 1)) * s x = c * (1 - abs((h * 6) % 2 - 1)) m = lightness - c / 2 if h < 1 / 6: r, g, b = c, x, 0 elif h < 2 / 6: r, g, b = x, c, 0 elif h < 3 / 6: r, g, b = 0, c, x elif h < 4 / 6: r, g, b = 0, x, c elif h < 5 / 6: r, g, b = x, 0, c else: r, g, b = c, 0, x return ( int((b + m) * 255), int((g + m) * 255), int((r + m) * 255), ) # BGR for OpenCV async def recv(self): """Generate video frames at 15 FPS""" pts, time_base = await self.next_timestamp() # Create black background frame_array = np.zeros((self.height, self.width, 3), dtype=np.uint8) # Update ball position ball = self.ball ball["x"] += ball["dx"] ball["y"] += ball["dy"] # Bounce off walls if ball["x"] + ball["radius"] >= self.width or ball["x"] - ball["radius"] <= 0: ball["dx"] = -ball["dx"] if ball["y"] + ball["radius"] >= self.height or ball["y"] - ball["radius"] <= 0: ball["dy"] = -ball["dy"] # Keep ball in bounds ball["x"] = max(ball["radius"], min(self.width - ball["radius"], ball["x"])) ball["y"] = max(ball["radius"], min(self.height - ball["radius"], ball["y"])) # Draw ball cv2.circle( frame_array, (int(ball["x"]), int(ball["y"])), int(ball["radius"]), self.ball_color, -1, ) # Add frame counter text frame_text = f"Frame: {int(time.time() * 1000) % 10000}" logger.info(frame_text) cv2.putText( frame_array, frame_text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, ) # Convert to VideoFrame frame = VideoFrame.from_ndarray(frame_array, format="bgr24") frame.pts = pts frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000) self.frame_count += 1 return frame class SilentAudioTrack(MediaStreamTrack): async def next_timestamp(self): # Returns (pts, time_base) for 20ms audio frames at 48kHz pts = int(time.time() * self.sample_rate) time_base = 1 / self.sample_rate return pts, time_base """ Synthetic audio track that generates silence. Ported from JavaScript createSilentAudioTrack function. """ kind = "audio" def __init__(self): super().__init__() self.sample_rate = 48000 self.samples_per_frame = 960 # 20ms at 48kHz async def recv(self): """Generate silent audio frames""" pts, time_base = await self.next_timestamp() # Create silent audio data samples = np.zeros((self.samples_per_frame,), dtype=np.float32) # Convert to AudioFrame frame = AudioFrame.from_ndarray( samples.reshape(1, -1), format="flt", layout="mono" ) frame.sample_rate = self.sample_rate frame.pts = pts frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000) return frame class WebRTCSignalingClient: """ WebRTC signaling client that communicates with the FastAPI signaling server. Handles peer-to-peer connection establishment and media streaming. """ def __init__( self, server_url: str, lobby_id: str, session_id: str, session_name: str, insecure: bool = False, ): self.server_url = server_url self.lobby_id = lobby_id self.session_id = session_id self.session_name = session_name self.insecure = insecure # WebSocket client protocol instance (typed as object to avoid Any) self.websocket: Optional[object] = None # Optional password to register or takeover a name self.name_password: Optional[str] = None self.peers: dict[str, Peer] = {} self.peer_connections: dict[str, RTCPeerConnection] = {} self.local_tracks: dict[str, MediaStreamTrack] = {} # State management self.is_negotiating: dict[str, bool] = {} self.making_offer: dict[str, bool] = {} self.initiated_offer: set[str] = set() self.pending_ice_candidates: dict[str, list[ICECandidateDict]] = {} # Event callbacks self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None self.on_track_received: Optional[ Callable[[Peer, MediaStreamTrack], Awaitable[None]] ] = None async def connect(self): """Connect to the signaling server""" ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}" logger.info(f"Connecting to signaling server: {ws_url}") # Log network information for debugging try: import socket hostname = socket.gethostname() local_ip = socket.gethostbyname(hostname) logger.info(f"Container hostname: {hostname}, local IP: {local_ip}") # Get all network interfaces import subprocess result = subprocess.run( ["ip", "addr", "show"], capture_output=True, text=True ) logger.info(f"Network interfaces:\n{result.stdout}") except Exception as e: logger.warning(f"Could not get network info: {e}") try: # If insecure (self-signed certs), create an SSL context for the websocket ws_ssl = None if self.insecure: ws_ssl = ssl.create_default_context() ws_ssl.check_hostname = False ws_ssl.verify_mode = ssl.CERT_NONE logger.info( f"Attempting websocket connection to {ws_url} with ssl={bool(ws_ssl)}" ) self.websocket = await websockets.connect(ws_url, ssl=ws_ssl) logger.info("Connected to signaling server") # Set up local media await self._setup_local_media() # Set name and join lobby name_payload: MessageData = {"name": self.session_name} if self.name_password: name_payload["password"] = self.name_password logger.info(f"Sending set_name: {name_payload}") await self._send_message("set_name", name_payload) logger.info("Sending join message") await self._send_message("join", {}) # Start message handling logger.info("Starting message handler loop") await self._handle_messages() except Exception as e: logger.error(f"Failed to connect to signaling server: {e}", exc_info=True) raise async def disconnect(self): """Disconnect from signaling server and cleanup""" if self.websocket: ws = cast(WebSocketProtocol, self.websocket) await ws.close() # Close all peer connections for pc in self.peer_connections.values(): await pc.close() # Stop local tracks for track in self.local_tracks.values(): track.stop() logger.info("Disconnected from signaling server") async def _setup_local_media(self): """Create local synthetic media tracks""" # Create synthetic video track video_track = AnimatedVideoTrack(name=self.session_name) self.local_tracks["video"] = video_track # Create synthetic audio track audio_track = SilentAudioTrack() self.local_tracks["audio"] = audio_track # Add local peer to peers dict local_peer = Peer( session_id=self.session_id, peer_name=self.session_name, local=True, attributes={"tracks": self.local_tracks}, ) self.peers[self.session_id] = local_peer logger.info("Local synthetic media tracks created") async def _send_message( self, message_type: str, data: Optional[MessageData] = None ): """Send message to signaling server""" if not self.websocket: logger.error("No websocket connection") return # Build message with explicit type to avoid type narrowing message: dict[str, object] = {"type": message_type} if data is not None: message["data"] = data ws = cast(WebSocketProtocol, self.websocket) try: logger.info(f"_send_message: Sending {message_type} with data: {data}") await ws.send(json.dumps(message)) logger.info(f"_send_message: Sent message: {message_type}") except Exception as e: logger.error( f"_send_message: Failed to send {message_type}: {e}", exc_info=True ) async def _handle_messages(self): """Handle incoming messages from signaling server""" try: ws = cast(WebSocketProtocol, self.websocket) async for message in ws: logger.info(f"_handle_messages: Received raw message: {message}") try: data = cast(MessageData, json.loads(message)) except Exception as e: logger.error( f"_handle_messages: Failed to parse message: {e}", exc_info=True ) continue await self._process_message(data) except websockets.exceptions.ConnectionClosed as e: logger.info(f"WebSocket connection closed: {e}") except Exception as e: logger.error(f"Error handling messages: {e}", exc_info=True) async def _process_message(self, message: MessageData): """Process incoming signaling messages""" try: # Validate the base message structure first validated_message = WebSocketMessageModel.model_validate(message) msg_type = validated_message.type data = validated_message.data except ValidationError as e: logger.error(f"Invalid message structure: {e}", exc_info=True) return logger.info( f"_process_message: Received message type: {msg_type} with data: {data}" ) if msg_type == "addPeer": try: validated = AddPeerModel.model_validate(data) except ValidationError as e: logger.error(f"Invalid addPeer payload: {e}", exc_info=True) return await self._handle_add_peer(validated) elif msg_type == "removePeer": try: validated = RemovePeerModel.model_validate(data) except ValidationError as e: logger.error(f"Invalid removePeer payload: {e}", exc_info=True) return await self._handle_remove_peer(validated) elif msg_type == "sessionDescription": try: validated = SessionDescriptionModel.model_validate(data) except ValidationError as e: logger.error(f"Invalid sessionDescription payload: {e}", exc_info=True) return await self._handle_session_description(validated) elif msg_type == "iceCandidate": try: validated = IceCandidateModel.model_validate(data) except ValidationError as e: logger.error(f"Invalid iceCandidate payload: {e}", exc_info=True) return await self._handle_ice_candidate(validated) elif msg_type == "join_status": try: validated = JoinStatusModel.model_validate(data) except ValidationError as e: logger.error(f"Invalid join_status payload: {e}", exc_info=True) return logger.info(f"Join status: {validated.status} - {validated.message}") elif msg_type == "user_joined": try: validated = UserJoinedModel.model_validate(data) except ValidationError as e: logger.error(f"Invalid user_joined payload: {e}", exc_info=True) return logger.info( f"User joined: {validated.name} (session: {validated.session_id})" ) elif msg_type == "lobby_state": try: validated = LobbyStateModel.model_validate(data) except ValidationError as e: logger.error(f"Invalid lobby_state payload: {e}", exc_info=True) return participants = validated.participants logger.info(f"Lobby state updated: {len(participants)} participants") elif msg_type == "update": try: validated = UpdateModel.model_validate(data) except ValidationError as e: logger.error(f"Invalid update payload: {e}", exc_info=True) return logger.info(f"Received update message: {validated}") else: logger.info(f"Unhandled message type: {msg_type} with data: {data}") async def _handle_add_peer(self, data: AddPeerModel): """Handle addPeer message - create new peer connection""" peer_id = data.peer_id peer_name = data.peer_name should_create_offer = data.should_create_offer logger.info( f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})" ) logger.info( f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}" ) # Check if peer already exists if peer_id in self.peer_connections: pc = self.peer_connections[peer_id] logger.info( f"_handle_add_peer: Existing connection state: {pc.connectionState}" ) if pc.connectionState in ["new", "connected", "connecting"]: logger.info(f"Peer connection already exists for {peer_name}") return else: # Clean up stale connection logger.info( f"_handle_add_peer: Closing stale connection for {peer_name}" ) await pc.close() del self.peer_connections[peer_id] # Create new peer peer = Peer(session_id=peer_id, peer_name=peer_name, local=False) self.peers[peer_id] = peer # Create RTCPeerConnection from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer config = RTCConfiguration( iceServers=[ RTCIceServer(urls="stun:ketrenos.com:3478"), RTCIceServer( urls="turns:ketrenos.com:5349", username="ketra", credential="ketran", ), # Add Google's public STUN server as fallback RTCIceServer(urls="stun:stun.l.google.com:19302"), ], ) logger.info( f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}" ) pc = RTCPeerConnection(configuration=config) # Add ICE gathering state change handler (explicit registration to satisfy static analyzers) def on_ice_gathering_state_change() -> None: logger.info(f"ICE gathering state: {pc.iceGatheringState}") # Debug: Check if we have any local candidates when gathering is complete if pc.iceGatheringState == "complete": logger.info(f"ICE gathering complete for {peer_name} - checking if candidates were generated...") pc.on("icegatheringstatechange")(on_ice_gathering_state_change) # Add connection state change handler (explicit registration to satisfy static analyzers) def on_connection_state_change() -> None: logger.info(f"Connection state: {pc.connectionState}") pc.on("connectionstatechange")(on_connection_state_change) self.peer_connections[peer_id] = pc peer.connection = pc # Set up event handlers def on_track(track: MediaStreamTrack) -> None: logger.info(f"Received {track.kind} track from {peer_name}") logger.info(f"on_track: {track.kind} from {peer_name}, track={track}") peer.attributes[f"{track.kind}_track"] = track if self.on_track_received: asyncio.ensure_future(self.on_track_received(peer, track)) pc.on("track")(on_track) def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: logger.info(f"on_ice_candidate: {candidate}") logger.info(f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}") if not candidate: logger.info(f"on_ice_candidate: End of candidates signal for {peer_name}") return # Raw SDP fragment for the candidate raw = getattr(candidate, "candidate", None) # Try to infer candidate type from the SDP string (host/srflx/relay/prflx) def _parse_type(s: Optional[str]) -> str: if not s: return "eoc" import re m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", s) return m.group(1) if m else "unknown" cand_type = _parse_type(raw) protocol = getattr(candidate, "protocol", "unknown") logger.info( f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}" ) candidate_dict: MessageData = { "candidate": raw, "sdpMid": getattr(candidate, "sdpMid", None), "sdpMLineIndex": getattr(candidate, "sdpMLineIndex", None), } payload: MessageData = {"peer_id": peer_id, "candidate": candidate_dict} logger.info( f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_dict}" ) asyncio.ensure_future(self._send_message("relayICECandidate", payload)) pc.on("icecandidate")(on_ice_candidate) # Add local tracks for track in self.local_tracks.values(): logger.info( f"_handle_add_peer: Adding local track {track.kind} to {peer_name}" ) pc.addTrack(track) # Create offer if needed if should_create_offer: self.initiated_offer.add(peer_id) self.making_offer[peer_id] = True self.is_negotiating[peer_id] = True try: logger.info(f"_handle_add_peer: Creating offer for {peer_name}") offer = await pc.createOffer() logger.info(f"_handle_add_peer: Offer created for {peer_name}: {offer}") await pc.setLocalDescription(offer) logger.info(f"_handle_add_peer: Local description set for {peer_name}") # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344) # Use Method 2: Complete SDP approach to extract ICE candidates logger.info(f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}") while pc.iceGatheringState != "complete": await asyncio.sleep(0.1) logger.info(f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}") # Parse ICE candidates from the local SDP sdp_lines = pc.localDescription.sdp.split('\n') candidate_lines = [line for line in sdp_lines if line.startswith('a=candidate:')] # Track which media section we're in to determine sdpMid and sdpMLineIndex current_media_index = -1 current_mid = None for line in sdp_lines: if line.startswith('m='): # Media section current_media_index += 1 elif line.startswith('a=mid:'): # Media ID current_mid = line.split(':', 1)[1].strip() elif line.startswith('a=candidate:'): candidate_sdp = line[2:] # Remove 'a=' prefix candidate_dict: MessageData = { "candidate": candidate_sdp, "sdpMid": current_mid, "sdpMLineIndex": current_media_index, } payload_candidate: MessageData = { "peer_id": peer_id, "candidate": candidate_dict } logger.info(f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}...") await self._send_message("relayICECandidate", payload_candidate) # Send end-of-candidates signal (empty candidate) end_candidate_dict: MessageData = { "candidate": "", "sdpMid": None, "sdpMLineIndex": None, } payload_end: MessageData = { "peer_id": peer_id, "candidate": end_candidate_dict } logger.info(f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}") await self._send_message("relayICECandidate", payload_end) logger.info(f"_handle_add_peer: Sent {len(candidate_lines)} ICE candidates to {peer_name}") await self._send_message( "relaySessionDescription", { "peer_id": peer_id, "session_description": {"type": offer.type, "sdp": offer.sdp}, }, ) logger.info(f"Offer sent to {peer_name}") except Exception as e: logger.error( f"Failed to create/send offer to {peer_name}: {e}", exc_info=True ) finally: self.making_offer[peer_id] = False if self.on_peer_added: await self.on_peer_added(peer) async def _handle_remove_peer(self, data: RemovePeerModel): """Handle removePeer message""" peer_id = data.peer_id peer_name = data.peer_name logger.info(f"Removing peer: {peer_name}") # Close peer connection if peer_id in self.peer_connections: pc = self.peer_connections[peer_id] await pc.close() del self.peer_connections[peer_id] # Clean up state self.is_negotiating.pop(peer_id, None) self.making_offer.pop(peer_id, None) self.initiated_offer.discard(peer_id) self.pending_ice_candidates.pop(peer_id, None) # Remove peer peer = self.peers.pop(peer_id, None) if peer and self.on_peer_removed: await self.on_peer_removed(peer) async def _handle_session_description(self, data: SessionDescriptionModel): """Handle sessionDescription message""" peer_id = data.peer_id peer_name = data.peer_name session_description = data.session_description.model_dump() logger.info(f"Received {session_description['type']} from {peer_name}") pc = self.peer_connections.get(peer_id) if not pc: logger.error(f"No peer connection for {peer_name}") return desc = RTCSessionDescription( sdp=session_description["sdp"], type=session_description["type"] ) # Handle offer collision (polite peer pattern) making_offer = self.making_offer.get(peer_id, False) offer_collision = desc.type == "offer" and ( making_offer or pc.signalingState != "stable" ) we_initiated = peer_id in self.initiated_offer ignore_offer = we_initiated and offer_collision if ignore_offer: logger.info(f"Ignoring offer from {peer_name} due to collision") return try: await pc.setRemoteDescription(desc) self.is_negotiating[peer_id] = False logger.info(f"Remote description set for {peer_name}") # Process queued ICE candidates pending_candidates = self.pending_ice_candidates.pop(peer_id, []) from aiortc.sdp import candidate_from_sdp for candidate_data in pending_candidates: # candidate_data is a dict-like ICECandidateDict; convert SDP string cand = candidate_data.get("candidate") # handle end-of-candidates marker if not cand: await pc.addIceCandidate(None) logger.info(f"Added queued end-of-candidates for {peer_name}") continue # cand may be the full "candidate:..." string or the inner SDP part if cand and cand.startswith("candidate:"): sdp_part = cand.split(":", 1)[1] else: sdp_part = cand try: rtc_candidate = candidate_from_sdp(sdp_part) rtc_candidate.sdpMid = candidate_data.get("sdpMid") rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex") await pc.addIceCandidate(rtc_candidate) logger.info(f"Added queued ICE candidate for {peer_name}") except Exception as e: logger.error( f"Failed to add queued ICE candidate for {peer_name}: {e}" ) except Exception as e: logger.error(f"Failed to set remote description for {peer_name}: {e}") return # Create answer if this was an offer if session_description["type"] == "offer": try: answer = await pc.createAnswer() await pc.setLocalDescription(answer) # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344) # Use Method 2: Complete SDP approach to extract ICE candidates logger.info(f"_handle_session_description: Waiting for ICE gathering to complete for {peer_name} (answer)") while pc.iceGatheringState != "complete": await asyncio.sleep(0.1) logger.info(f"_handle_session_description: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)") # Parse ICE candidates from the local SDP sdp_lines = pc.localDescription.sdp.split('\n') candidate_lines = [line for line in sdp_lines if line.startswith('a=candidate:')] # Track which media section we're in to determine sdpMid and sdpMLineIndex current_media_index = -1 current_mid = None for line in sdp_lines: if line.startswith('m='): # Media section current_media_index += 1 elif line.startswith('a=mid:'): # Media ID current_mid = line.split(':', 1)[1].strip() elif line.startswith('a=candidate:'): candidate_sdp = line[2:] # Remove 'a=' prefix candidate_dict: MessageData = { "candidate": candidate_sdp, "sdpMid": current_mid, "sdpMLineIndex": current_media_index, } payload_candidate: MessageData = { "peer_id": peer_id, "candidate": candidate_dict } logger.info(f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}...") await self._send_message("relayICECandidate", payload_candidate) # Send end-of-candidates signal (empty candidate) end_candidate_dict: MessageData = { "candidate": "", "sdpMid": None, "sdpMLineIndex": None, } payload_end: MessageData = { "peer_id": peer_id, "candidate": end_candidate_dict } logger.info(f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)") await self._send_message("relayICECandidate", payload_end) logger.info(f"_handle_session_description: Sent {len(candidate_lines)} ICE candidates to {peer_name} (answer)") await self._send_message( "relaySessionDescription", { "peer_id": peer_id, "session_description": {"type": answer.type, "sdp": answer.sdp}, }, ) logger.info(f"Answer sent to {peer_name}") except Exception as e: logger.error(f"Failed to create/send answer to {peer_name}: {e}") async def _handle_ice_candidate(self, data: IceCandidateModel): """Handle iceCandidate message""" peer_id = data.peer_id peer_name = data.peer_name candidate_data = data.candidate.model_dump() logger.info(f"Received ICE candidate from {peer_name}") pc = self.peer_connections.get(peer_id) if not pc: logger.error(f"No peer connection for {peer_name}") return # Queue candidate if remote description not set if not pc.remoteDescription: logger.info( f"Remote description not set, queuing ICE candidate for {peer_name}" ) if peer_id not in self.pending_ice_candidates: self.pending_ice_candidates[peer_id] = [] # candidate_data is a dict from Pydantic model; cast to the TypedDict self.pending_ice_candidates[peer_id].append( cast(ICECandidateDict, candidate_data) ) return try: from aiortc.sdp import candidate_from_sdp cand = candidate_data.get("candidate") if not cand: # end-of-candidates await pc.addIceCandidate(None) logger.info(f"End-of-candidates added for {peer_name}") return if cand and cand.startswith("candidate:"): sdp_part = cand.split(":", 1)[1] else: sdp_part = cand # Detect type for logging try: import re m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part) cand_type = m.group(1) if m else "unknown" except Exception: cand_type = "unknown" try: rtc_candidate = candidate_from_sdp(sdp_part) rtc_candidate.sdpMid = candidate_data.get("sdpMid") rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex") # aiortc expects an object with attributes (RTCIceCandidate) await pc.addIceCandidate(rtc_candidate) logger.info(f"ICE candidate added for {peer_name}: type={cand_type}") except Exception as e: logger.error( f"Failed to add ICE candidate for {peer_name}: type={cand_type} error={e} sdp='{sdp_part}'", exc_info=True, ) except Exception as e: logger.error( f"Unexpected error handling ICE candidate for {peer_name}: {e}", exc_info=True, ) async def _handle_ice_connection_failure(self, peer_id: str, peer_name: str): """Handle ICE connection failure by logging details""" logger.info(f"ICE connection failure detected for {peer_name}") pc = self.peer_connections.get(peer_id) if not pc: logger.error( f"No peer connection found for {peer_name} during ICE failure recovery" ) return logger.error( f"ICE connection failed for {peer_name}. Connection state: {pc.connectionState}, ICE state: {pc.iceConnectionState}" ) # In a real implementation, you might want to notify the user or attempt reconnection async def _schedule_ice_timeout(self, peer_id: str, peer_name: str): """Schedule a timeout for ICE connection checking""" await asyncio.sleep(30) # Wait 30 seconds pc = self.peer_connections.get(peer_id) if not pc: return if pc.iceConnectionState == "checking": logger.warning( f"ICE connection timeout for {peer_name} - still in checking state after 30 seconds" ) logger.warning( f"Final connection state: {pc.connectionState}, ICE state: {pc.iceConnectionState}" ) logger.warning( "This might be due to network connectivity issues between the browser and Docker container" ) logger.warning( "Consider checking: 1) Port forwarding 2) TURN server config 3) Docker network mode" ) elif pc.iceConnectionState in ["failed", "closed"]: logger.info( f"ICE connection for {peer_name} resolved to: {pc.iceConnectionState}" ) else: logger.info( f"ICE connection for {peer_name} established: {pc.iceConnectionState}" ) # Example usage def _http_base_url(server_url: str) -> str: # Convert ws:// or wss:// to http(s) and ensure no trailing slash if server_url.startswith("ws://"): return "http://" + server_url[len("ws://") :].rstrip("/") if server_url.startswith("wss://"): return "https://" + server_url[len("wss://") :].rstrip("/") return server_url.rstrip("/") def _ws_url(server_url: str) -> str: # Convert http(s) to ws(s) if needed if server_url.startswith("http://"): return "ws://" + server_url[len("http://") :].rstrip("/") if server_url.startswith("https://"): return "wss://" + server_url[len("https://") :].rstrip("/") return server_url.rstrip("/") def create_or_get_session( server_url: str, session_id: str | None = None, insecure: bool = False ) -> str: """Call GET /api/session to obtain a session_id (unless one was provided). Uses urllib so no extra runtime deps are required. """ if session_id: return session_id http_base = _http_base_url(server_url) url = f"{http_base}/api/session" req = urllib.request.Request(url, method="GET") # Prepare SSL context if requested (accept self-signed certs) ssl_ctx = None if insecure: ssl_ctx = ssl.create_default_context() ssl_ctx.check_hostname = False ssl_ctx.verify_mode = ssl.CERT_NONE try: with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp: body = resp.read() data = json.loads(body) # Validate response shape using Pydantic try: session = SessionModel.model_validate(data) except ValidationError as e: raise RuntimeError(f"Invalid session response from {url}: {e}") sid = session.id if not sid: raise RuntimeError(f"No session id returned from {url}: {data}") return sid except urllib.error.HTTPError as e: raise RuntimeError(f"HTTP error getting session: {e}") except Exception as e: raise RuntimeError(f"Error getting session: {e}") def create_or_get_lobby( server_url: str, session_id: str, lobby_name: str, private: bool = False, insecure: bool = False, ) -> str: """Call POST /api/lobby/{session_id} to create or lookup a lobby by name. Returns the lobby id. """ http_base = _http_base_url(server_url) url = f"{http_base}/api/lobby/{urllib.parse.quote(session_id)}" payload = json.dumps( { "type": "lobby_create", "data": {"name": lobby_name, "private": private}, } ).encode("utf-8") req = urllib.request.Request( url, data=payload, headers={"Content-Type": "application/json"}, method="POST" ) # Prepare SSL context if requested (accept self-signed certs) ssl_ctx = None if insecure: ssl_ctx = ssl.create_default_context() ssl_ctx.check_hostname = False ssl_ctx.verify_mode = ssl.CERT_NONE try: with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp: body = resp.read() data = json.loads(body) # Expect shape: { "type": "lobby_created", "data": {"id":..., ...}} try: lobby_resp = LobbyCreateResponse.model_validate(data) except ValidationError as e: raise RuntimeError(f"Invalid lobby response from {url}: {e}") lobby_id = lobby_resp.data.id if not lobby_id: raise RuntimeError(f"No lobby id returned from {url}: {data}") return lobby_id except urllib.error.HTTPError as e: # Try to include response body for.infoging try: body = e.read() msg = body.decode("utf-8", errors="ignore") except Exception: msg = str(e) raise RuntimeError(f"HTTP error creating lobby: {msg}") except Exception as e: raise RuntimeError(f"Error creating lobby: {e}") async def main(): """Example usage of the WebRTC signaling client with CLI options to create session and lobby.""" parser = argparse.ArgumentParser(description="Python WebRTC voicebot client") parser.add_argument( "--server-url", default="http://localhost:8000/ai-voicebot", help="AI-Voicebot lobby and signaling server base URL (http:// or https://)", ) parser.add_argument( "--lobby", default="default", help="Lobby name to create or join" ) parser.add_argument( "--session-name", default="Python Bot", help="Session (user) display name" ) parser.add_argument( "--session-id", default=None, help="Optional existing session id to reuse" ) parser.add_argument( "--password", default=None, help="Optional password to register or takeover a name", ) parser.add_argument( "--private", action="store_true", help="Create the lobby as private" ) parser.add_argument( "--insecure", action="store_true", help="Allow insecure server connections when using SSL (accept self-signed certs)", ) args = parser.parse_args() # Resolve session id (create if needed) try: session_id = create_or_get_session( args.server_url, args.session_id, insecure=args.insecure ) print(f"Using session id: {session_id}") except Exception as e: print(f"Failed to get/create session: {e}") return # Create or get lobby id try: lobby_id = create_or_get_lobby( args.server_url, session_id, args.lobby, args.private, insecure=args.insecure, ) print(f"Using lobby id: {lobby_id} (name={args.lobby})") except Exception as e: print(f"Failed to create/get lobby: {e}") return # Build websocket base URL (ws:// or wss://) from server_url and pass to client so # it constructs the final websocket path (/ws/lobby/{lobby}/{session}) itself. ws_base = _ws_url(args.server_url) client = WebRTCSignalingClient( ws_base, lobby_id, session_id, args.session_name, insecure=args.insecure ) # Set up event handlers async def on_peer_added(peer: Peer): print(f"Peer added: {peer.peer_name}") async def on_peer_removed(peer: Peer): print(f"Peer removed: {peer.peer_name}") async def on_track_received(peer: Peer, track: MediaStreamTrack): print(f"Received {track.kind} track from {peer.peer_name}") client.on_peer_added = on_peer_added client.on_peer_removed = on_peer_removed client.on_track_received = on_track_received try: # Connect and run # If a password was provided on the CLI, store it on the client for use when setting name if args.password: client.name_password = args.password await client.connect() except KeyboardInterrupt: print("Shutting down...") finally: await client.disconnect() if __name__ == "__main__": # Install required packages: # pip install aiortc websockets opencv-python numpy asyncio.run(main())