# File listing metadata (not part of the module): 1362 lines, 50 KiB, Python

"""
WebRTC Media Agent for Python
This module provides synthetic audio/video track creation and WebRTC signaling
server communication, ported from the JavaScript MediaControl implementation.
"""
from __future__ import annotations
import asyncio
import json
import websockets
import numpy as np
import cv2
import fractions
from typing import (
Dict,
Optional,
Callable,
Awaitable,
TypedDict,
Protocol,
AsyncIterator,
cast,
)
from dataclasses import dataclass, field
from pydantic import ValidationError
# types.SimpleNamespace removed — not used anymore after parsing candidates via aiortc.sdp
import argparse
import urllib.request
import urllib.error
import urllib.parse
import ssl
import sys
import os
# Import shared models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import (
SessionModel,
LobbyCreateResponse,
WebSocketMessageModel,
JoinStatusModel,
UserJoinedModel,
LobbyStateModel,
UpdateNameModel,
AddPeerModel,
RemovePeerModel,
SessionDescriptionModel,
IceCandidateModel,
)
from aiortc import (
RTCPeerConnection,
RTCSessionDescription,
RTCIceCandidate,
MediaStreamTrack,
)
from av import VideoFrame, AudioFrame
import time
from logger import logger
# Defensive monkeypatch: aioice Transaction.__retry may run after the
# underlying datagram transport or loop was torn down which results in
# AttributeError being raised and flooding logs. Wrap the original
# implementation to catch and suppress AttributeError; any other exception
# is logged with a traceback. This is a temporary mitigation to keep logs
# readable while we investigate/upstream a proper fix or upgrade aioice.

# Simple in-process dedupe cache so we only log the same AttributeError
# once per interval. This prevents flooding the logs when many
# transactions race to run after shutdown.
# These names are defined unconditionally (outside the try block) because
# the StunProtocol.send_stun wrapper in this module also calls
# _should_log_once, even when the Transaction patch could not be applied.
_MONKEYPATCH_LOG_CACHE: dict[str, float] = {}
_MONKEYPATCH_LOG_SUPPRESSION_INTERVAL = 5.0


def _should_log_once(key: str) -> bool:
    """Return True if *key* has not been logged within the suppression interval.

    A True result also records the current time for *key*, so subsequent
    calls return False until the interval has elapsed.
    """
    now = time.time()
    last = _MONKEYPATCH_LOG_CACHE.get(key)
    if last is None or (now - last) > _MONKEYPATCH_LOG_SUPPRESSION_INTERVAL:
        _MONKEYPATCH_LOG_CACHE[key] = now
        return True
    return False


try:
    import aioice.stun as _aioice_stun  # type: ignore

    # The method is defined with a double-underscore name (__retry) which
    # gets name-mangled. Detect the actual attribute name robustly.
    retry_attr_name = None
    for name in dir(_aioice_stun.Transaction):
        if name.endswith("retry"):
            obj = getattr(_aioice_stun.Transaction, name)
            if callable(obj):
                retry_attr_name = name
                _orig_retry = obj
                break

    if retry_attr_name is not None:

        def _safe_transaction_retry(self, *args, **kwargs):  # type: ignore
            """Call the original retry method, suppressing AttributeError."""
            try:
                return _orig_retry(self, *args, **kwargs)  # type: ignore
            except AttributeError as e:  # type: ignore
                # Transport or event-loop already closed; log once per key
                key = f"Transaction.{retry_attr_name}:{e}"
                if _should_log_once(key):
                    logger.warning(
                        "aioice Transaction.%s AttributeError suppressed: %s",
                        retry_attr_name,
                        e,
                    )
            except Exception:  # type: ignore
                # Preserve visibility for other unexpected exceptions
                logger.exception(
                    "aioice Transaction.%s raised an unexpected exception",
                    retry_attr_name,
                )

        setattr(_aioice_stun.Transaction, retry_attr_name, _safe_transaction_retry)  # type: ignore
        logger.info("Applied safe aioice Transaction.%s monkeypatch", retry_attr_name)
    else:
        logger.warning("aioice Transaction.__retry not found; skipping monkeypatch")
except Exception as e:
    logger.exception("Failed to apply aioice Transaction.__retry monkeypatch: %s", e)
# Second defensive patch: guard the protocol-level send_stun implementation
# (e.g. StunProtocol.send_stun), which ends up calling the datagram
# transport's sendto. Once the transport or its loop has been torn down,
# sendto can raise AttributeError, which in turn sends asyncio down its fatal
# error path (calling a None loop). The wrapper installed here stops the
# resulting flood of selector_events/_fatal_error AttributeError traces.
# Log de-duplication relies on the module-level _should_log_once helper.
try:
    import aioice.ice as _aioice_ice  # type: ignore

    # StunProtocol.send_stun is the entry point used by the ICE code.
    send_attr_name = None
    if hasattr(_aioice_ice, "StunProtocol"):
        proto_cls = getattr(_aioice_ice, "StunProtocol")
        for name in dir(proto_cls):
            if not name.endswith("send_stun"):
                continue
            attr = getattr(proto_cls, name)
            if callable(attr):
                send_attr_name = name
                _orig_send_stun = attr
                break

    if send_attr_name is None:
        logger.warning("aioice StunProtocol.send_stun not found; skipping monkeypatch")
    else:

        def _safe_send_stun(self, message, addr):  # type: ignore
            """Delegate to the original send_stun, suppressing AttributeError."""
            try:
                return _orig_send_stun(self, message, addr)  # type: ignore
            except AttributeError as e:  # type: ignore
                # Most likely transport._sock or transport._loop is None;
                # emit at most one warning per suppression interval.
                key = f"StunProtocol.{send_attr_name}:{e}"
                if _should_log_once(key):
                    logger.warning(
                        "aioice StunProtocol.%s AttributeError suppressed: %s",
                        send_attr_name,
                        e,
                    )
            except Exception:  # type: ignore
                logger.exception(
                    "aioice StunProtocol.%s raised unexpected exception", send_attr_name
                )

        setattr(proto_cls, send_attr_name, _safe_send_stun)  # type: ignore
        logger.info("Applied safe aioice StunProtocol.%s monkeypatch", send_attr_name)
except Exception as e:
    logger.exception("Failed to apply aioice StunProtocol.send_stun monkeypatch: %s", e)
# TypedDict for ICE candidate payloads received from signalling
class ICECandidateDict(TypedDict, total=False):
    """Dictionary form of an ICE candidate exchanged over signaling.

    Declared with ``total=False``, so every key may be absent.
    """

    # Candidate SDP string.
    candidate: str
    # Media stream identification tag, if any.
    sdpMid: Optional[str]
    # Index of the media description, if any.
    sdpMLineIndex: Optional[int]
# Generic message payload type: a string-keyed mapping with arbitrary values.
# Used in this module both for whole decoded signaling messages and for the
# "data" payload attached to outgoing messages.
MessageData = dict[str, object]
# Message TypedDicts for signaling payloads
class BaseMessage(TypedDict, total=False):
    """Outer envelope of a signaling message; both keys are optional."""

    # Message type discriminator.
    type: str
    # Type-specific payload.
    data: object
class AddPeerPayload(TypedDict):
    """Payload shape for an ``addPeer`` signaling message."""

    peer_id: str
    peer_name: str
    # True when the receiving side is expected to create the offer.
    should_create_offer: bool
class RemovePeerPayload(TypedDict):
    """Payload shape for a ``removePeer`` signaling message."""

    peer_id: str
    peer_name: str
class SessionDescriptionTyped(TypedDict):
    """Dictionary form of a session description: its type and SDP text."""

    type: str
    sdp: str
class SessionDescriptionPayload(TypedDict):
    """Payload shape for a ``sessionDescription`` signaling message."""

    peer_id: str
    peer_name: str
    # The relayed offer/answer.
    session_description: SessionDescriptionTyped
class IceCandidatePayload(TypedDict):
    """Payload shape for an ``iceCandidate`` signaling message."""

    peer_id: str
    peer_name: str
    # The relayed ICE candidate.
    candidate: ICECandidateDict
class WebSocketProtocol(Protocol):
    """Structural type for the parts of a websocket connection used here."""

    def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]:
        """Send one message over the connection."""
        ...

    def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]:
        """Close the connection."""
        ...

    def __aiter__(self) -> AsyncIterator[str]:
        """Iterate over incoming messages."""
        ...
def _default_attributes() -> Dict[str, object]:
return {}
@dataclass
class Peer:
    """A participant in the WebRTC session (local or remote)."""

    # Identity
    session_id: str
    peer_name: str

    # Free-form attribute bag; values may be media tracks or plain metadata.
    # A fresh dict is created per instance via the default factory.
    attributes: Dict[str, object] = field(default_factory=_default_attributes)

    # Status flags
    muted: bool = False
    video_on: bool = True
    local: bool = False
    dead: bool = False

    # Peer connection associated with this peer, once one has been created.
    connection: Optional[RTCPeerConnection] = None
class AnimatedVideoTrack(MediaStreamTrack):
    """
    Synthetic video track that generates animated content with a bouncing ball.
    Ported from JavaScript createAnimatedVideoTrack function.

    Frames are stamped on a 90 kHz clock at 15 frames per second, and frame
    delivery is paced against the wall clock so a consumer that calls
    ``recv`` in a tight loop is not handed frames faster than that rate.
    """

    kind = "video"

    def __init__(self, width: int = 320, height: int = 240, name: str = ""):
        """Create the track.

        Args:
            width: Frame width in pixels.
            height: Frame height in pixels.
            name: Optional name; when non-empty it determines the ball color.
        """
        super().__init__()
        self.width = width
        self.height = height
        self.name = name
        # Generate color from name hash (similar to JavaScript nameToColor)
        self.ball_color = (
            self._name_to_color(name) if name else (0, 255, 136)
        )  # Default green
        # Ball properties
        self.ball = {
            "x": width / 2,
            "y": height / 2,
            "radius": min(width, height) * 0.06,
            "dx": 3.0,
            "dy": 2.0,
        }
        self.frame_count = 0
        # Reference time for pacing; re-anchored when the first frame is sent.
        self._start_time = time.time()

    async def next_timestamp(self):
        """Return ``(pts, time_base)`` for the next frame at 15 FPS.

        Sleeps as needed so that frame N is not released before
        ``start + N / 15`` seconds. ``pts`` is an exact integer multiple of
        6000 ticks (90000 / 15); ``time_base`` is ``1 / 90000`` as a float.
        """
        if self.frame_count == 0:
            # Anchor the pacing clock at the first delivered frame rather than
            # at construction time, so a late start does not cause a burst.
            self._start_time = time.time()
        else:
            wait = self._start_time + self.frame_count / 15 - time.time()
            if wait > 0:
                await asyncio.sleep(wait)
        # Integer arithmetic avoids float truncation errors in the timestamp.
        pts = self.frame_count * (90000 // 15)
        time_base = 1 / 90000
        return pts, time_base

    def _name_to_color(self, name: str) -> tuple[int, int, int]:
        """Convert name to HSL color, then to a (B, G, R) integer tuple."""
        # Simple hash function (djb2), kept within 32 bits (always >= 0)
        hash_value = 5381
        for char in name:
            hash_value = ((hash_value << 5) + hash_value + ord(char)) & 0xFFFFFFFF
        # Generate HSL color from hash
        hue = hash_value % 360
        sat = 60 + (hash_value % 30)  # 60-89%
        light = 45 + (hash_value % 30)  # 45-74%
        # Convert HSL to RGB
        h = hue / 360.0
        s = sat / 100.0
        lightness = light / 100.0
        c = (1 - abs(2 * lightness - 1)) * s
        x = c * (1 - abs((h * 6) % 2 - 1))
        m = lightness - c / 2
        if h < 1 / 6:
            r, g, b = c, x, 0
        elif h < 2 / 6:
            r, g, b = x, c, 0
        elif h < 3 / 6:
            r, g, b = 0, c, x
        elif h < 4 / 6:
            r, g, b = 0, x, c
        elif h < 5 / 6:
            r, g, b = x, 0, c
        else:
            r, g, b = c, 0, x
        return (
            int((b + m) * 255),
            int((g + m) * 255),
            int((r + m) * 255),
        )  # BGR for OpenCV

    async def recv(self):
        """Generate the next video frame (paced to 15 FPS)."""
        pts, time_base = await self.next_timestamp()
        # Create black background
        frame_array = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        # Update ball position
        ball = self.ball
        ball["x"] += ball["dx"]
        ball["y"] += ball["dy"]
        # Bounce off walls
        if ball["x"] + ball["radius"] >= self.width or ball["x"] - ball["radius"] <= 0:
            ball["dx"] = -ball["dx"]
        if ball["y"] + ball["radius"] >= self.height or ball["y"] - ball["radius"] <= 0:
            ball["dy"] = -ball["dy"]
        # Keep ball in bounds
        ball["x"] = max(ball["radius"], min(self.width - ball["radius"], ball["x"]))
        ball["y"] = max(ball["radius"], min(self.height - ball["radius"], ball["y"]))
        # Draw ball
        cv2.circle(
            frame_array,
            (int(ball["x"]), int(ball["y"])),
            int(ball["radius"]),
            self.ball_color,
            -1,
        )
        # Overlay text: the current wall-clock milliseconds modulo 10000
        frame_text = f"Frame: {int(time.time() * 1000) % 10000}"
        cv2.putText(
            frame_array,
            frame_text,
            (10, 20),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )
        # Convert to VideoFrame
        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
        frame.pts = pts
        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
        self.frame_count += 1
        return frame
class SilentAudioTrack(MediaStreamTrack):
    """
    Synthetic audio track that generates silence.
    Ported from JavaScript createSilentAudioTrack function.

    Each frame holds 20 ms of mono float silence at 48 kHz. Timestamps count
    samples from zero and advance by exactly one frame per call, and frame
    delivery is paced against the wall clock.
    """

    kind = "audio"

    def __init__(self):
        super().__init__()
        self.sample_rate = 48000
        self.samples_per_frame = 960  # 20ms at 48kHz
        # Number of samples already delivered; doubles as the next frame's pts.
        self._samples_sent = 0
        # Wall-clock time of the first delivered frame (None until then).
        self._start_time = None

    async def next_timestamp(self):
        """Return ``(pts, time_base)`` for the next 20 ms frame at 48 kHz.

        ``pts`` is the running sample count, so consecutive frames differ by
        exactly ``samples_per_frame``. Sleeps as needed so frames are not
        released faster than real time. ``time_base`` is ``1 / sample_rate``
        as a float.
        """
        if self._start_time is None:
            self._start_time = time.time()
        else:
            due = self._start_time + self._samples_sent / self.sample_rate
            wait = due - time.time()
            if wait > 0:
                await asyncio.sleep(wait)
        pts = self._samples_sent
        time_base = 1 / self.sample_rate
        return pts, time_base

    async def recv(self):
        """Generate the next silent audio frame."""
        pts, time_base = await self.next_timestamp()
        # Create silent audio data
        samples = np.zeros((self.samples_per_frame,), dtype=np.float32)
        # Convert to AudioFrame
        frame = AudioFrame.from_ndarray(
            samples.reshape(1, -1), format="flt", layout="mono"
        )
        frame.sample_rate = self.sample_rate
        frame.pts = pts
        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
        self._samples_sent += self.samples_per_frame
        return frame
class WebRTCSignalingClient:
"""
WebRTC signaling client that communicates with the FastAPI signaling server.
Handles peer-to-peer connection establishment and media streaming.
"""
def __init__(
self,
server_url: str,
lobby_id: str,
session_id: str,
session_name: str,
insecure: bool = False,
):
self.server_url = server_url
self.lobby_id = lobby_id
self.session_id = session_id
self.session_name = session_name
self.insecure = insecure
# WebSocket client protocol instance (typed as object to avoid Any)
self.websocket: Optional[object] = None
# Optional password to register or takeover a name
self.name_password: Optional[str] = None
self.peers: dict[str, Peer] = {}
self.peer_connections: dict[str, RTCPeerConnection] = {}
self.local_tracks: dict[str, MediaStreamTrack] = {}
# State management
self.is_negotiating: dict[str, bool] = {}
self.making_offer: dict[str, bool] = {}
self.initiated_offer: set[str] = set()
self.pending_ice_candidates: dict[str, list[ICECandidateDict]] = {}
# Event callbacks
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
self.on_track_received: Optional[
Callable[[Peer, MediaStreamTrack], Awaitable[None]]
] = None
async def connect(self):
    """Connect to the signaling server, join the lobby and process messages.

    Opens the websocket, creates the local synthetic media, sends
    ``set_name`` and ``join``, then runs the message loop until it ends.

    Raises:
        Exception: any error raised while connecting or joining is logged
            and re-raised.
    """
    ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
    logger.info(f"Connecting to signaling server: {ws_url}")
    # Log network information for debugging (best effort only)
    try:
        import socket

        hostname = socket.gethostname()
        local_ip = socket.gethostbyname(hostname)
        logger.info(f"Container hostname: {hostname}, local IP: {local_ip}")
        # Get all network interfaces
        import subprocess

        # Bounded by a timeout so a hung command cannot stall the event loop
        # indefinitely; a timeout is reported by the handler below.
        result = subprocess.run(
            ["ip", "addr", "show"], capture_output=True, text=True, timeout=5
        )
        logger.info(f"Network interfaces:\n{result.stdout}")
    except Exception as e:
        logger.warning(f"Could not get network info: {e}")
    try:
        # If insecure (self-signed certs), create an SSL context for the
        # websocket. The ``ssl`` argument is only passed when a context was
        # actually built for a wss:// URL: websockets rejects an ssl argument
        # on ws:// URIs and does not accept an explicit ssl=None on wss://.
        ws_ssl = None
        if self.insecure and ws_url.startswith("wss://"):
            ws_ssl = ssl.create_default_context()
            ws_ssl.check_hostname = False
            ws_ssl.verify_mode = ssl.CERT_NONE
        logger.info(
            f"Attempting websocket connection to {ws_url} with ssl={bool(ws_ssl)}"
        )
        if ws_ssl is not None:
            self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
        else:
            self.websocket = await websockets.connect(ws_url)
        logger.info("Connected to signaling server")
        # Set up local media
        await self._setup_local_media()
        # Set name and join lobby
        name_payload: MessageData = {"name": self.session_name}
        if self.name_password:
            name_payload["password"] = self.name_password
        # Never write the password to the log.
        loggable_payload = {
            key: ("***" if key == "password" else value)
            for key, value in name_payload.items()
        }
        logger.info(f"Sending set_name: {loggable_payload}")
        await self._send_message("set_name", name_payload)
        logger.info("Sending join message")
        await self._send_message("join", {})
        # Start message handling
        logger.info("Starting message handler loop")
        await self._handle_messages()
    except Exception as e:
        logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
        raise
async def disconnect(self):
    """Disconnect from signaling server and cleanup.

    Closes the websocket, every peer connection and every local track.
    A failure while closing one resource is logged and does not prevent
    the remaining resources from being released.
    """
    if self.websocket:
        ws = cast(WebSocketProtocol, self.websocket)
        try:
            await ws.close()
        except Exception as e:
            logger.warning(f"Error closing websocket: {e}")
    # Close all peer connections. Iterate over a snapshot: close() yields to
    # the event loop, during which other handlers may modify the dict.
    for peer_id, pc in list(self.peer_connections.items()):
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"Error closing peer connection {peer_id}: {e}")
    # Stop local tracks
    for track in list(self.local_tracks.values()):
        try:
            track.stop()
        except Exception as e:
            logger.warning(f"Error stopping local {track.kind} track: {e}")
    logger.info("Disconnected from signaling server")
async def _setup_local_media(self):
    """Create the synthetic local tracks and register the local peer."""
    # Video first, then audio.
    self.local_tracks["video"] = AnimatedVideoTrack(name=self.session_name)
    self.local_tracks["audio"] = SilentAudioTrack()

    # Register ourselves in the peer table; the attribute bag references the
    # same track dict held by the client.
    self.peers[self.session_id] = Peer(
        session_id=self.session_id,
        peer_name=self.session_name,
        local=True,
        attributes={"tracks": self.local_tracks},
    )
    logger.info("Local synthetic media tracks created")
async def _send_message(
    self, message_type: str, data: Optional[MessageData] = None
):
    """Send message to signaling server.

    The message is serialised as JSON of the form
    ``{"type": message_type, "data": data}``; the ``data`` key is omitted
    when *data* is None. Send failures are logged, not raised. Nothing is
    sent when there is no websocket connection.

    Args:
        message_type: Value of the message's ``type`` field.
        data: Optional payload for the message's ``data`` field.
    """
    if not self.websocket:
        logger.error("No websocket connection")
        return
    # Build message with explicit type to avoid type narrowing
    message: dict[str, object] = {"type": message_type}
    if data is not None:
        message["data"] = data
    # Mask the password (sent with set_name) so it never reaches the log.
    loggable = data
    if data is not None and "password" in data:
        loggable = {**data, "password": "***"}
    ws = cast(WebSocketProtocol, self.websocket)
    try:
        logger.info(f"_send_message: Sending {message_type} with data: {loggable}")
        await ws.send(json.dumps(message))
        logger.info(f"_send_message: Sent message: {message_type}")
    except Exception as e:
        logger.error(
            f"_send_message: Failed to send {message_type}: {e}", exc_info=True
        )
async def _handle_messages(self):
    """Receive messages from the websocket until the connection ends.

    Each message is decoded as JSON and passed to ``_process_message``;
    messages that cannot be decoded are logged and skipped. A closed
    connection or any other error ends the loop and is logged.
    """
    try:
        connection = cast(WebSocketProtocol, self.websocket)
        async for raw in connection:
            logger.info(f"_handle_messages: Received raw message: {raw}")
            try:
                decoded = cast(MessageData, json.loads(raw))
            except Exception as e:
                logger.error(
                    f"_handle_messages: Failed to parse message: {e}", exc_info=True
                )
            else:
                await self._process_message(decoded)
    except websockets.exceptions.ConnectionClosed as e:
        logger.info(f"WebSocket connection closed: {e}")
    except Exception as e:
        logger.error(f"Error handling messages: {e}", exc_info=True)
async def _process_message(self, message: MessageData):
    """Validate one decoded signaling message and act on it.

    The envelope is validated first, then the payload is validated against
    the model registered for its type. Invalid envelopes or payloads are
    logged and dropped; unknown types are logged.
    """
    # Validate the envelope before looking at the payload.
    try:
        envelope = WebSocketMessageModel.model_validate(message)
        msg_type = envelope.type
        data = envelope.data
    except ValidationError as e:
        logger.error(f"Invalid message structure: {e}", exc_info=True)
        return
    logger.info(
        f"_process_message: Received message type: {msg_type} with data: {data}"
    )

    # message type -> (payload model, prefix of the validation error log)
    payload_models = {
        "addPeer": (AddPeerModel, "Invalid addPeer payload"),
        "removePeer": (RemovePeerModel, "Invalid removePeer payload"),
        "sessionDescription": (
            SessionDescriptionModel,
            "Invalid sessionDescription payload",
        ),
        "iceCandidate": (IceCandidateModel, "Invalid iceCandidate payload"),
        "join_status": (JoinStatusModel, "Invalid join_status payload"),
        "user_joined": (UserJoinedModel, "Invalid user_joined payload"),
        "lobby_state": (LobbyStateModel, "Invalid lobby_state payload"),
        "update_name": (UpdateNameModel, "Invalid update payload"),
    }
    if msg_type not in payload_models:
        logger.info(f"Unhandled message type: {msg_type} with data: {data}")
        return

    model_cls, error_prefix = payload_models[msg_type]
    try:
        validated = model_cls.model_validate(data)
    except ValidationError as e:
        logger.error(f"{error_prefix}: {e}", exc_info=True)
        return

    # WebRTC signaling messages are delegated to their handlers ...
    if msg_type == "addPeer":
        await self._handle_add_peer(validated)
        return
    if msg_type == "removePeer":
        await self._handle_remove_peer(validated)
        return
    if msg_type == "sessionDescription":
        await self._handle_session_description(validated)
        return
    if msg_type == "iceCandidate":
        await self._handle_ice_candidate(validated)
        return

    # ... lobby notifications are only logged.
    if msg_type == "join_status":
        logger.info(f"Join status: {validated.status} - {validated.message}")
    elif msg_type == "user_joined":
        logger.info(
            f"User joined: {validated.name} (session: {validated.session_id})"
        )
    elif msg_type == "lobby_state":
        participants = validated.participants
        logger.info(f"Lobby state updated: {len(participants)} participants")
    else:  # "update_name"
        logger.info(f"Received update message: {validated}")
async def _handle_add_peer(self, data: AddPeerModel):
    """Handle addPeer message - create new peer connection.

    Registers the peer, creates an RTCPeerConnection with the local tracks
    attached and, when ``should_create_offer`` is set, creates an offer,
    relays the gathered ICE candidates and then relays the offer. A live
    existing connection for the same peer is kept; a stale one is closed
    and replaced. ``on_peer_added`` is awaited at the end when set.
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    should_create_offer = data.should_create_offer
    logger.info(
        f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})"
    )
    logger.info(
        f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}"
    )
    # Check if peer already exists
    if peer_id in self.peer_connections:
        pc = self.peer_connections[peer_id]
        logger.info(
            f"_handle_add_peer: Existing connection state: {pc.connectionState}"
        )
        if pc.connectionState in ["new", "connected", "connecting"]:
            logger.info(f"Peer connection already exists for {peer_name}")
            return
        else:
            # Clean up stale connection
            logger.info(
                f"_handle_add_peer: Closing stale connection for {peer_name}"
            )
            await pc.close()
            del self.peer_connections[peer_id]
    # Create new peer
    peer = Peer(session_id=peer_id, peer_name=peer_name, local=False)
    self.peers[peer_id] = peer
    # Create RTCPeerConnection
    from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer

    # NOTE(review): the TURN credentials are hard-coded here; consider moving
    # them to configuration.
    config = RTCConfiguration(
        iceServers=[
            RTCIceServer(urls="stun:ketrenos.com:3478"),
            RTCIceServer(
                urls="turns:ketrenos.com:5349",
                username="ketra",
                credential="ketran",
            ),
            # Add Google's public STUN server as fallback
            RTCIceServer(urls="stun:stun.l.google.com:19302"),
        ],
    )
    logger.info(
        f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}"
    )
    pc = RTCPeerConnection(configuration=config)

    # Add ICE gathering state change handler (explicit registration to satisfy static analyzers)
    def on_ice_gathering_state_change() -> None:
        logger.info(f"ICE gathering state: {pc.iceGatheringState}")
        # Debug: Check if we have any local candidates when gathering is complete
        if pc.iceGatheringState == "complete":
            logger.info(f"ICE gathering complete for {peer_name} - checking if candidates were generated...")

    pc.on("icegatheringstatechange")(on_ice_gathering_state_change)

    # Add connection state change handler (explicit registration to satisfy static analyzers)
    def on_connection_state_change() -> None:
        logger.info(f"Connection state: {pc.connectionState}")

    pc.on("connectionstatechange")(on_connection_state_change)
    self.peer_connections[peer_id] = pc
    peer.connection = pc

    # Set up event handlers
    def on_track(track: MediaStreamTrack) -> None:
        logger.info(f"Received {track.kind} track from {peer_name}")
        logger.info(f"on_track: {track.kind} from {peer_name}, track={track}")
        peer.attributes[f"{track.kind}_track"] = track
        if self.on_track_received:
            asyncio.ensure_future(self.on_track_received(peer, track))

    pc.on("track")(on_track)

    def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None:
        logger.info(f"on_ice_candidate: {candidate}")
        logger.info(f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}")
        if not candidate:
            logger.info(f"on_ice_candidate: End of candidates signal for {peer_name}")
            return
        # Raw SDP fragment for the candidate
        raw = getattr(candidate, "candidate", None)

        # Try to infer candidate type from the SDP string (host/srflx/relay/prflx)
        def _parse_type(s: Optional[str]) -> str:
            if not s:
                return "eoc"
            import re

            m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", s)
            return m.group(1) if m else "unknown"

        cand_type = _parse_type(raw)
        protocol = getattr(candidate, "protocol", "unknown")
        logger.info(
            f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}"
        )
        candidate_dict: MessageData = {
            "candidate": raw,
            "sdpMid": getattr(candidate, "sdpMid", None),
            "sdpMLineIndex": getattr(candidate, "sdpMLineIndex", None),
        }
        payload: MessageData = {"peer_id": peer_id, "candidate": candidate_dict}
        logger.info(
            f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_dict}"
        )
        asyncio.ensure_future(self._send_message("relayICECandidate", payload))

    pc.on("icecandidate")(on_ice_candidate)
    # Add local tracks
    for track in self.local_tracks.values():
        logger.info(
            f"_handle_add_peer: Adding local track {track.kind} to {peer_name}"
        )
        pc.addTrack(track)
    # Create offer if needed
    if should_create_offer:
        self.initiated_offer.add(peer_id)
        self.making_offer[peer_id] = True
        self.is_negotiating[peer_id] = True
        try:
            logger.info(f"_handle_add_peer: Creating offer for {peer_name}")
            offer = await pc.createOffer()
            logger.info(f"_handle_add_peer: Offer created for {peer_name}: {offer}")
            await pc.setLocalDescription(offer)
            logger.info(f"_handle_add_peer: Local description set for {peer_name}")
            # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
            # Use Method 2: Complete SDP approach to extract ICE candidates
            logger.info(f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}")
            while pc.iceGatheringState != "complete":
                await asyncio.sleep(0.1)
            logger.info(f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}")
            # Parse ICE candidates from the local SDP. splitlines() handles
            # the CRLF line endings used by SDP, so no stray '\r' ends up in
            # the relayed candidate strings.
            sdp_lines = pc.localDescription.sdp.splitlines()
            # Track which media section we're in to determine sdpMid and sdpMLineIndex
            current_media_index = -1
            current_mid = None
            sent_candidates = 0
            for line in sdp_lines:
                if line.startswith('m='):  # Media section
                    current_media_index += 1
                elif line.startswith('a=mid:'):  # Media ID
                    current_mid = line.split(':', 1)[1].strip()
                elif line.startswith('a=candidate:'):
                    candidate_sdp = line[2:]  # Remove 'a=' prefix
                    candidate_dict: MessageData = {
                        "candidate": candidate_sdp,
                        "sdpMid": current_mid,
                        "sdpMLineIndex": current_media_index,
                    }
                    payload_candidate: MessageData = {
                        "peer_id": peer_id,
                        "candidate": candidate_dict,
                    }
                    logger.info(f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}...")
                    await self._send_message("relayICECandidate", payload_candidate)
                    sent_candidates += 1
            # Send end-of-candidates signal (empty candidate)
            end_candidate_dict: MessageData = {
                "candidate": "",
                "sdpMid": None,
                "sdpMLineIndex": None,
            }
            payload_end: MessageData = {
                "peer_id": peer_id,
                "candidate": end_candidate_dict,
            }
            logger.info(f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}")
            await self._send_message("relayICECandidate", payload_end)
            logger.info(f"_handle_add_peer: Sent {sent_candidates} ICE candidates to {peer_name}")
            await self._send_message(
                "relaySessionDescription",
                {
                    "peer_id": peer_id,
                    "session_description": {"type": offer.type, "sdp": offer.sdp},
                },
            )
            logger.info(f"Offer sent to {peer_name}")
        except Exception as e:
            logger.error(
                f"Failed to create/send offer to {peer_name}: {e}", exc_info=True
            )
        finally:
            self.making_offer[peer_id] = False
    if self.on_peer_added:
        await self.on_peer_added(peer)
async def _handle_remove_peer(self, data: RemovePeerModel):
    """Tear down everything held for a peer that left.

    Closes and forgets the peer connection, drops the per-peer negotiation
    state, removes the peer from the registry and awaits
    ``on_peer_removed`` if the peer was known and the callback is set.
    """
    peer_id, peer_name = data.peer_id, data.peer_name
    logger.info(f"Removing peer: {peer_name}")

    # Close the connection first, then forget it.
    connection = self.peer_connections.get(peer_id)
    if connection is not None:
        await connection.close()
        del self.peer_connections[peer_id]

    # Drop per-peer negotiation bookkeeping.
    for state in (self.is_negotiating, self.making_offer):
        state.pop(peer_id, None)
    self.initiated_offer.discard(peer_id)
    self.pending_ice_candidates.pop(peer_id, None)

    # Finally remove the peer itself and notify the listener.
    removed = self.peers.pop(peer_id, None)
    if removed and self.on_peer_removed:
        await self.on_peer_removed(removed)
async def _handle_session_description(self, data: SessionDescriptionModel):
    """Handle sessionDescription message.

    Applies the remote offer/answer to the peer's connection, flushes any
    ICE candidates queued while no remote description was set and, for an
    offer, creates an answer, relays the gathered ICE candidates and then
    relays the answer. A colliding offer is ignored when this side
    initiated the offer for that peer.
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    session_description = data.session_description.model_dump()
    logger.info(f"Received {session_description['type']} from {peer_name}")
    pc = self.peer_connections.get(peer_id)
    if not pc:
        logger.error(f"No peer connection for {peer_name}")
        return
    desc = RTCSessionDescription(
        sdp=session_description["sdp"], type=session_description["type"]
    )
    # Handle offer collision (polite peer pattern)
    making_offer = self.making_offer.get(peer_id, False)
    offer_collision = desc.type == "offer" and (
        making_offer or pc.signalingState != "stable"
    )
    we_initiated = peer_id in self.initiated_offer
    ignore_offer = we_initiated and offer_collision
    if ignore_offer:
        logger.info(f"Ignoring offer from {peer_name} due to collision")
        return
    try:
        await pc.setRemoteDescription(desc)
        self.is_negotiating[peer_id] = False
        logger.info(f"Remote description set for {peer_name}")
        # Process queued ICE candidates
        pending_candidates = self.pending_ice_candidates.pop(peer_id, [])
        from aiortc.sdp import candidate_from_sdp

        for candidate_data in pending_candidates:
            # candidate_data is a dict-like ICECandidateDict; convert SDP string
            cand = candidate_data.get("candidate")
            # handle end-of-candidates marker
            if not cand:
                await pc.addIceCandidate(None)
                logger.info(f"Added queued end-of-candidates for {peer_name}")
                continue
            # cand may be the full "candidate:..." string or the inner SDP part
            if cand.startswith("candidate:"):
                sdp_part = cand.split(":", 1)[1]
            else:
                sdp_part = cand
            try:
                rtc_candidate = candidate_from_sdp(sdp_part)
                rtc_candidate.sdpMid = candidate_data.get("sdpMid")
                rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")
                await pc.addIceCandidate(rtc_candidate)
                logger.info(f"Added queued ICE candidate for {peer_name}")
            except Exception as e:
                logger.error(
                    f"Failed to add queued ICE candidate for {peer_name}: {e}"
                )
    except Exception as e:
        logger.error(f"Failed to set remote description for {peer_name}: {e}")
        return
    # Create answer if this was an offer
    if session_description["type"] == "offer":
        try:
            answer = await pc.createAnswer()
            await pc.setLocalDescription(answer)
            # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
            # Use Method 2: Complete SDP approach to extract ICE candidates
            logger.info(f"_handle_session_description: Waiting for ICE gathering to complete for {peer_name} (answer)")
            while pc.iceGatheringState != "complete":
                await asyncio.sleep(0.1)
            logger.info(f"_handle_session_description: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)")
            # Parse ICE candidates from the local SDP. splitlines() handles
            # the CRLF line endings used by SDP, so no stray '\r' ends up in
            # the relayed candidate strings.
            sdp_lines = pc.localDescription.sdp.splitlines()
            # Track which media section we're in to determine sdpMid and sdpMLineIndex
            current_media_index = -1
            current_mid = None
            sent_candidates = 0
            for line in sdp_lines:
                if line.startswith('m='):  # Media section
                    current_media_index += 1
                elif line.startswith('a=mid:'):  # Media ID
                    current_mid = line.split(':', 1)[1].strip()
                elif line.startswith('a=candidate:'):
                    candidate_sdp = line[2:]  # Remove 'a=' prefix
                    candidate_dict: MessageData = {
                        "candidate": candidate_sdp,
                        "sdpMid": current_mid,
                        "sdpMLineIndex": current_media_index,
                    }
                    payload_candidate: MessageData = {
                        "peer_id": peer_id,
                        "candidate": candidate_dict,
                    }
                    logger.info(f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}...")
                    await self._send_message("relayICECandidate", payload_candidate)
                    sent_candidates += 1
            # Send end-of-candidates signal (empty candidate)
            end_candidate_dict: MessageData = {
                "candidate": "",
                "sdpMid": None,
                "sdpMLineIndex": None,
            }
            payload_end: MessageData = {
                "peer_id": peer_id,
                "candidate": end_candidate_dict,
            }
            logger.info(f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)")
            await self._send_message("relayICECandidate", payload_end)
            logger.info(f"_handle_session_description: Sent {sent_candidates} ICE candidates to {peer_name} (answer)")
            await self._send_message(
                "relaySessionDescription",
                {
                    "peer_id": peer_id,
                    "session_description": {"type": answer.type, "sdp": answer.sdp},
                },
            )
            logger.info(f"Answer sent to {peer_name}")
        except Exception as e:
            logger.error(f"Failed to create/send answer to {peer_name}: {e}")
async def _handle_ice_candidate(self, data: IceCandidateModel):
    """Apply a remote ICE candidate, or queue it until it can be applied.

    Candidates that arrive before the remote description is set are stored
    per peer. An empty candidate string is treated as the end-of-candidates
    marker. Failures are logged, never raised.
    """
    peer_id, peer_name = data.peer_id, data.peer_name
    candidate_data = data.candidate.model_dump()
    logger.info(f"Received ICE candidate from {peer_name}")

    pc = self.peer_connections.get(peer_id)
    if not pc:
        logger.error(f"No peer connection for {peer_name}")
        return

    # Without a remote description the candidate cannot be applied yet.
    if not pc.remoteDescription:
        logger.info(
            f"Remote description not set, queuing ICE candidate for {peer_name}"
        )
        # The model dump is a plain dict; cast it to the TypedDict for typing.
        queue = self.pending_ice_candidates.setdefault(peer_id, [])
        queue.append(cast(ICECandidateDict, candidate_data))
        return

    try:
        from aiortc.sdp import candidate_from_sdp

        raw_candidate = candidate_data.get("candidate")
        if not raw_candidate:
            # end-of-candidates
            await pc.addIceCandidate(None)
            logger.info(f"End-of-candidates added for {peer_name}")
            return

        # Accept both "candidate:<sdp>" and the bare "<sdp>" form.
        prefix = "candidate:"
        sdp_part = (
            raw_candidate[len(prefix):]
            if raw_candidate.startswith(prefix)
            else raw_candidate
        )

        # Candidate type is extracted for logging only.
        try:
            import re

            type_match = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part)
            cand_type = "unknown" if type_match is None else type_match.group(1)
        except Exception:
            cand_type = "unknown"

        try:
            # aiortc expects an object with attributes (RTCIceCandidate)
            rtc_candidate = candidate_from_sdp(sdp_part)
            rtc_candidate.sdpMid = candidate_data.get("sdpMid")
            rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")
            await pc.addIceCandidate(rtc_candidate)
            logger.info(f"ICE candidate added for {peer_name}: type={cand_type}")
        except Exception as e:
            logger.error(
                f"Failed to add ICE candidate for {peer_name}: type={cand_type} error={e} sdp='{sdp_part}'",
                exc_info=True,
            )
    except Exception as e:
        logger.error(
            f"Unexpected error handling ICE candidate for {peer_name}: {e}",
            exc_info=True,
        )
async def _handle_ice_connection_failure(self, peer_id: str, peer_name: str):
    """Log the connection and ICE states of a peer whose ICE connection failed.

    Only logs; no recovery is attempted.
    """
    logger.info(f"ICE connection failure detected for {peer_name}")
    connection = self.peer_connections.get(peer_id)
    if connection:
        logger.error(
            f"ICE connection failed for {peer_name}. Connection state: {connection.connectionState}, ICE state: {connection.iceConnectionState}"
        )
        # In a real implementation, you might want to notify the user or attempt reconnection
        return
    logger.error(
        f"No peer connection found for {peer_name} during ICE failure recovery"
    )
async def _schedule_ice_timeout(
    self, peer_id: str, peer_name: str, timeout_seconds: float = 30.0
):
    """Wait, then report whether ICE negotiation for a peer is still pending.

    Sleeps for ``timeout_seconds`` and then inspects the peer connection
    stored under ``peer_id``. If the connection is still in the
    ``"checking"`` ICE state, troubleshooting hints are logged as warnings;
    otherwise the state it ended up in is logged at info level. Nothing is
    logged if the peer connection is no longer registered.

    Args:
        peer_id: Key used to look up the connection in
            ``self.peer_connections``.
        peer_name: Display name used in the log messages.
        timeout_seconds: How long to wait before inspecting the ICE state.
            Defaults to 30 seconds, the previously hard-coded value.
    """
    await asyncio.sleep(timeout_seconds)

    pc = self.peer_connections.get(peer_id)
    if not pc:
        # Peer was removed while we were waiting; nothing to report.
        return

    # Read the state once so the logged value matches the branch taken.
    ice_state = pc.iceConnectionState
    if ice_state == "checking":
        logger.warning(
            f"ICE connection timeout for {peer_name} - still in checking state after {timeout_seconds:g} seconds"
        )
        logger.warning(
            f"Final connection state: {pc.connectionState}, ICE state: {ice_state}"
        )
        logger.warning(
            "This might be due to network connectivity issues between the browser and Docker container"
        )
        logger.warning(
            "Consider checking: 1) Port forwarding 2) TURN server config 3) Docker network mode"
        )
    elif ice_state in ("failed", "closed"):
        logger.info(f"ICE connection for {peer_name} resolved to: {ice_state}")
    else:
        # Any other state is reported as established.
        logger.info(f"ICE connection for {peer_name} established: {ice_state}")
# Example usage: URL helpers, session/lobby bootstrap, and the CLI entry point
def _http_base_url(server_url: str) -> str:
# Convert ws:// or wss:// to http(s) and ensure no trailing slash
if server_url.startswith("ws://"):
return "http://" + server_url[len("ws://") :].rstrip("/")
if server_url.startswith("wss://"):
return "https://" + server_url[len("wss://") :].rstrip("/")
return server_url.rstrip("/")
def _ws_url(server_url: str) -> str:
# Convert http(s) to ws(s) if needed
if server_url.startswith("http://"):
return "ws://" + server_url[len("http://") :].rstrip("/")
if server_url.startswith("https://"):
return "wss://" + server_url[len("https://") :].rstrip("/")
return server_url.rstrip("/")
def create_or_get_session(
    server_url: str, session_id: str | None = None, insecure: bool = False
) -> str:
    """Return a session id, requesting one from the server if none is given.

    When ``session_id`` is truthy it is returned as-is and no request is
    made. Otherwise ``GET {base}/api/session`` is issued, where ``base`` is
    ``server_url`` converted to an http(s) URL, and the response is validated
    with ``SessionModel``. Uses urllib so no extra runtime deps are required.

    Args:
        server_url: Server base URL (http(s):// or ws(s)://).
        session_id: Existing session id to reuse, if any.
        insecure: When True, TLS certificate and hostname verification are
            disabled (accepts self-signed certs).

    Returns:
        The session id.

    Raises:
        RuntimeError: On HTTP errors, transport or JSON errors, an invalid
            response shape, or a response without a session id. The original
            exception is attached as ``__cause__`` where there is one.
    """
    if session_id:
        return session_id

    url = f"{_http_base_url(server_url)}/api/session"
    req = urllib.request.Request(url, method="GET")

    # Prepare SSL context if requested (accept self-signed certs)
    ssl_ctx = None
    if insecure:
        ssl_ctx = ssl.create_default_context()
        ssl_ctx.check_hostname = False
        ssl_ctx.verify_mode = ssl.CERT_NONE

    try:
        with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp:
            data = json.loads(resp.read())

        # Validate response shape using Pydantic
        try:
            session = SessionModel.model_validate(data)
        except ValidationError as e:
            raise RuntimeError(f"Invalid session response from {url}: {e}") from e

        sid = session.id
        if not sid:
            raise RuntimeError(f"No session id returned from {url}: {data}")
        return sid
    except RuntimeError:
        # Already a descriptive error raised above; do not wrap it a second
        # time in the generic handler below.
        raise
    except urllib.error.HTTPError as e:
        raise RuntimeError(f"HTTP error getting session: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Error getting session: {e}") from e
def create_or_get_lobby(
    server_url: str,
    session_id: str,
    lobby_name: str,
    private: bool = False,
    insecure: bool = False,
) -> str:
    """Create or look up a lobby by name and return its id.

    Issues ``POST {base}/api/lobby/{session_id}`` with a ``lobby_create``
    JSON payload, where ``base`` is ``server_url`` converted to an http(s)
    URL, and validates the response with ``LobbyCreateResponse``.

    Args:
        server_url: Server base URL (http(s):// or ws(s)://).
        session_id: Session id placed in the request path.
        lobby_name: Name of the lobby to create or join.
        private: Sent as the ``private`` flag of the lobby payload.
        insecure: When True, TLS certificate and hostname verification are
            disabled (accepts self-signed certs).

    Returns:
        The lobby id.

    Raises:
        RuntimeError: On HTTP errors, transport or JSON errors, an invalid
            response shape, or a response without a lobby id. The original
            exception is attached as ``__cause__`` where there is one.
    """
    http_base = _http_base_url(server_url)
    # safe="" also escapes "/", so the session id always stays a single
    # path segment.
    url = f"{http_base}/api/lobby/{urllib.parse.quote(session_id, safe='')}"
    payload = json.dumps(
        {
            "type": "lobby_create",
            "data": {"name": lobby_name, "private": private},
        }
    ).encode("utf-8")
    req = urllib.request.Request(
        url, data=payload, headers={"Content-Type": "application/json"}, method="POST"
    )

    # Prepare SSL context if requested (accept self-signed certs)
    ssl_ctx = None
    if insecure:
        ssl_ctx = ssl.create_default_context()
        ssl_ctx.check_hostname = False
        ssl_ctx.verify_mode = ssl.CERT_NONE

    try:
        with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp:
            data = json.loads(resp.read())

        # Expect shape: { "type": "lobby_created", "data": {"id":..., ...}}
        try:
            lobby_resp = LobbyCreateResponse.model_validate(data)
        except ValidationError as e:
            raise RuntimeError(f"Invalid lobby response from {url}: {e}") from e

        lobby_id = lobby_resp.data.id
        if not lobby_id:
            raise RuntimeError(f"No lobby id returned from {url}: {data}")
        return lobby_id
    except RuntimeError:
        # Already a descriptive error raised above; do not wrap it a second
        # time in the generic handler below.
        raise
    except urllib.error.HTTPError as e:
        # Try to include the response body for debugging; fall back to the
        # error's own description when the body is unreadable or empty.
        try:
            msg = e.read().decode("utf-8", errors="ignore")
        except Exception:
            msg = ""
        raise RuntimeError(f"HTTP error creating lobby: {msg or e}") from e
    except Exception as e:
        raise RuntimeError(f"Error creating lobby: {e}") from e
async def main():
    """Run the example client from the command line.

    Parses CLI options, resolves (or creates) the session and lobby over
    HTTP, then connects a ``WebRTCSignalingClient`` that prints peer and
    track events. The client is disconnected on the way out.
    """
    cli = argparse.ArgumentParser(description="Python WebRTC voicebot client")
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        (
            "--server-url",
            {
                "default": "http://localhost:8000/ai-voicebot",
                "help": "AI-Voicebot lobby and signaling server base URL (http:// or https://)",
            },
        ),
        (
            "--lobby",
            {"default": "default", "help": "Lobby name to create or join"},
        ),
        (
            "--session-name",
            {"default": "Python Bot", "help": "Session (user) display name"},
        ),
        (
            "--session-id",
            {"default": None, "help": "Optional existing session id to reuse"},
        ),
        (
            "--password",
            {
                "default": None,
                "help": "Optional password to register or takeover a name",
            },
        ),
        (
            "--private",
            {"action": "store_true", "help": "Create the lobby as private"},
        ),
        (
            "--insecure",
            {
                "action": "store_true",
                "help": "Allow insecure server connections when using SSL (accept self-signed certs)",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()

    # Step 1: resolve the session id, creating one on the server if needed.
    try:
        session_id = create_or_get_session(
            opts.server_url, opts.session_id, insecure=opts.insecure
        )
        print(f"Using session id: {session_id}")
    except Exception as err:
        print(f"Failed to get/create session: {err}")
        return

    # Step 2: create the lobby or look up the existing one by name.
    try:
        lobby_id = create_or_get_lobby(
            opts.server_url,
            session_id,
            opts.lobby,
            opts.private,
            insecure=opts.insecure,
        )
        print(f"Using lobby id: {lobby_id} (name={opts.lobby})")
    except Exception as err:
        print(f"Failed to create/get lobby: {err}")
        return

    # Event handlers: each one just prints what happened.
    async def report_peer_added(peer: Peer):
        print(f"Peer added: {peer.peer_name}")

    async def report_peer_removed(peer: Peer):
        print(f"Peer removed: {peer.peer_name}")

    async def report_track_received(peer: Peer, track: MediaStreamTrack):
        print(f"Received {track.kind} track from {peer.peer_name}")

    # Step 3: hand the client a ws:// or wss:// base URL; it builds the final
    # websocket path (/ws/lobby/{lobby}/{session}) itself.
    client = WebRTCSignalingClient(
        _ws_url(opts.server_url),
        lobby_id,
        session_id,
        opts.session_name,
        insecure=opts.insecure,
    )
    client.on_peer_added = report_peer_added
    client.on_peer_removed = report_peer_removed
    client.on_track_received = report_track_received

    try:
        # A password given on the CLI is stored on the client for use when
        # setting the name.
        if opts.password:
            client.name_password = opts.password
        await client.connect()
    except KeyboardInterrupt:
        print("Shutting down...")
    finally:
        await client.disconnect()
if __name__ == "__main__":
    # Script entry point: run the async example client defined in main().
    # Install required packages (this module also imports pydantic):
    # pip install aiortc websockets opencv-python numpy
    asyncio.run(main())