- Created shared/models.py with all common Pydantic models
- Updated voicebot/main.py to use the shared models and remove dd.get() calls
- Updated server/main.py to use the shared models
- Fixed lobby_state message handling with proper validation
- Updated Dockerfiles to include the shared/ directory
- Added comprehensive documentation and a migration guide

Benefits:
- Type safety across both components
- Consistent data validation
- Eliminated unsafe dictionary access patterns
- Centralized model definitions for easier maintenance
1365 lines
50 KiB
Python
1365 lines
50 KiB
Python
"""
|
|
WebRTC Media Agent for Python
|
|
|
|
This module provides synthetic audio/video track creation and WebRTC signaling
|
|
server communication, ported from the JavaScript MediaControl implementation.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import websockets
|
|
import numpy as np
|
|
import cv2
|
|
import fractions
|
|
from typing import (
|
|
Dict,
|
|
Optional,
|
|
Callable,
|
|
Awaitable,
|
|
TypedDict,
|
|
Protocol,
|
|
AsyncIterator,
|
|
cast,
|
|
)
|
|
|
|
from dataclasses import dataclass, field
|
|
from pydantic import ValidationError
|
|
|
|
# types.SimpleNamespace removed — not used anymore after parsing candidates via aiortc.sdp
|
|
import argparse
|
|
import urllib.request
|
|
import urllib.error
|
|
import urllib.parse
|
|
import ssl
|
|
import sys
|
|
import os
|
|
|
|
# Import shared models
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
from shared.models import (
|
|
SessionModel,
|
|
LobbyCreateResponse,
|
|
WebSocketMessageModel,
|
|
JoinStatusModel,
|
|
UserJoinedModel,
|
|
ParticipantModel,
|
|
LobbyStateModel,
|
|
UpdateModel,
|
|
ICECandidateDictModel,
|
|
AddPeerModel,
|
|
RemovePeerModel,
|
|
SessionDescriptionTypedModel,
|
|
SessionDescriptionModel,
|
|
IceCandidateModel,
|
|
)
|
|
from aiortc import (
|
|
RTCPeerConnection,
|
|
RTCSessionDescription,
|
|
RTCIceCandidate,
|
|
MediaStreamTrack,
|
|
)
|
|
from av import VideoFrame, AudioFrame
|
|
import time
|
|
from logger import logger
|
|
|
|
# Defensive monkeypatch: aioice Transaction.__retry may run after the
# underlying datagram transport or loop was torn down which results in
# AttributeError being raised and flooding logs. Wrap the original
# implementation to catch and suppress AttributeError; any other exception is
# logged with its traceback (and not re-raised). This is a temporary
# mitigation to keep logs readable while we investigate/upstream a proper fix
# or upgrade aioice.

# Simple in-process dedupe cache so each distinct suppressed AttributeError is
# logged at most once per interval. This prevents flooding the logs when many
# transactions race to run after shutdown. It is defined at module level,
# outside the try below, because the StunProtocol.send_stun patch also calls
# _should_log_once. It must exist even when the Transaction patch cannot be
# applied.
_MONKEYPATCH_LOG_CACHE: dict[str, float] = {}
_MONKEYPATCH_LOG_SUPPRESSION_INTERVAL = 5.0


def _should_log_once(key: str) -> bool:
    """Return True if ``key`` has not been logged within the suppression interval.

    When returning True, the current time is recorded for ``key``. A monotonic
    clock is used so wall-clock adjustments cannot shorten or extend the
    suppression window.
    """
    now = time.monotonic()
    last = _MONKEYPATCH_LOG_CACHE.get(key)
    if last is None or (now - last) > _MONKEYPATCH_LOG_SUPPRESSION_INTERVAL:
        _MONKEYPATCH_LOG_CACHE[key] = now
        return True
    return False


try:
    import aioice.stun as _aioice_stun  # type: ignore

    # The method is defined with a double-underscore name (__retry), which
    # gets name-mangled (to _Transaction__retry). Detect the actual attribute
    # name robustly.
    retry_attr_name = None
    for name in dir(_aioice_stun.Transaction):
        if name.endswith("retry"):
            obj = getattr(_aioice_stun.Transaction, name)
            if callable(obj):
                retry_attr_name = name
                _orig_retry = obj
                break

    if retry_attr_name is not None:

        def _safe_transaction_retry(self, *args, **kwargs):  # type: ignore
            """Run the original retry, suppressing AttributeError after teardown.

            AttributeError is logged at most once per interval per distinct
            message. Any other exception is logged with a traceback and
            swallowed. In both cases None is returned.
            """
            try:
                return _orig_retry(self, *args, **kwargs)  # type: ignore
            except AttributeError as e:  # type: ignore
                # Transport or event-loop already closed; log once per key
                key = f"Transaction.{retry_attr_name}:{e}"
                if _should_log_once(key):
                    logger.warning(
                        "aioice Transaction.%s AttributeError suppressed: %s",
                        retry_attr_name,
                        e,
                    )
            except Exception:  # type: ignore
                # Keep other unexpected exceptions visible in the log
                logger.exception(
                    "aioice Transaction.%s raised an unexpected exception",
                    retry_attr_name,
                )

        setattr(_aioice_stun.Transaction, retry_attr_name, _safe_transaction_retry)  # type: ignore
        logger.info("Applied safe aioice Transaction.%s monkeypatch", retry_attr_name)
    else:
        logger.warning("aioice Transaction.__retry not found; skipping monkeypatch")
except Exception as e:
    logger.exception("Failed to apply aioice Transaction.__retry monkeypatch: %s", e)
|
|
|
|
# Additional defensive patch: wrap the protocol-level send_stun implementation
# (e.g. StunProtocol.send_stun) which ultimately calls the datagram transport's
# sendto. If the transport or its loop is already torn down, sendto can raise
# AttributeError which then triggers asyncio's fatal error path (calling a None
# loop). Wrapping here prevents the flood of selector_events/_fatal_error
# AttributeError traces.
try:
    import aioice.ice as _aioice_ice  # type: ignore

    # Prefer to patch StunProtocol.send_stun which is used by the ICE code.
    send_attr_name = None
    proto_cls = getattr(_aioice_ice, "StunProtocol", None)
    if proto_cls is not None:
        for name in dir(proto_cls):
            if name.endswith("send_stun"):
                attr = getattr(proto_cls, name)
                if callable(attr):
                    send_attr_name = name
                    _orig_send_stun = attr
                    break

    if send_attr_name is not None:

        def _safe_send_stun(self, message, addr):  # type: ignore
            """Run the original send_stun, suppressing AttributeError after teardown.

            AttributeError (e.g. transport._sock or transport._loop is None) is
            logged, deduplicated when possible. Other exceptions are logged with
            a traceback and swallowed. Returns None in both cases.
            """
            try:
                return _orig_send_stun(self, message, addr)  # type: ignore
            except AttributeError as e:  # type: ignore
                key = f"StunProtocol.{send_attr_name}:{e}"
                # _should_log_once may be undefined if the Transaction patch
                # above was skipped or failed. Fall back to always logging
                # rather than letting a suppressed error turn into a NameError.
                log_once = globals().get("_should_log_once")
                if log_once is None or log_once(key):
                    logger.warning(
                        "aioice StunProtocol.%s AttributeError suppressed: %s",
                        send_attr_name,
                        e,
                    )
            except Exception:  # type: ignore
                logger.exception(
                    "aioice StunProtocol.%s raised unexpected exception", send_attr_name
                )

        setattr(proto_cls, send_attr_name, _safe_send_stun)  # type: ignore
        logger.info("Applied safe aioice StunProtocol.%s monkeypatch", send_attr_name)
    else:
        logger.warning("aioice StunProtocol.send_stun not found; skipping monkeypatch")
except Exception as e:
    logger.exception("Failed to apply aioice StunProtocol.send_stun monkeypatch: %s", e)
|
|
|
|
# ICE candidate payload exchanged over the signaling channel. Every key is
# optional (total=False); the key names use the browser's camelCase spelling.
class ICECandidateDict(TypedDict, total=False):
    """JSON shape of a single ICE candidate sent or received via signaling."""

    # Candidate SDP fragment; an empty value marks end-of-candidates.
    candidate: str
    # Media stream identification tag of the m= section, if known.
    sdpMid: Optional[str]
    # Zero-based index of the m= section, if known.
    sdpMLineIndex: Optional[int]
|
|
|
|
|
|
# Generic JSON-object payload type: signaling message envelopes and their
# "data" sections are built and passed around as plain dicts of this shape.
MessageData = dict[str, object]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# TypedDict shapes for signaling messages and their payloads
# ---------------------------------------------------------------------------
class BaseMessage(TypedDict, total=False):
    """Top-level signaling envelope: a message ``type`` plus optional ``data``."""

    type: str  # message discriminator, e.g. "addPeer"
    data: object  # type-specific payload
|
|
|
|
|
|
class AddPeerPayload(TypedDict):
    """Payload of an ``addPeer`` message; all keys are required."""

    peer_id: str  # remote peer's session id
    peer_name: str  # remote peer's display name
    should_create_offer: bool  # True when this side must send the SDP offer
|
|
|
|
|
|
class RemovePeerPayload(TypedDict):
    """Payload of a ``removePeer`` message identifying the departing peer."""

    peer_id: str  # session id of the peer to drop
    peer_name: str  # its display name (used for logging)
|
|
|
|
|
|
class SessionDescriptionTyped(TypedDict):
    """An SDP session description as carried in signaling messages."""

    type: str  # description type, e.g. "offer" or "answer"
    sdp: str  # full SDP text
|
|
|
|
|
|
class SessionDescriptionPayload(TypedDict):
    """Payload of a ``sessionDescription`` message relayed from a peer."""

    peer_id: str  # sender's session id
    peer_name: str  # sender's display name
    session_description: SessionDescriptionTyped  # the offer/answer itself
|
|
|
|
|
|
class IceCandidatePayload(TypedDict):
    """Payload of an ``iceCandidate`` message relayed from a peer."""

    peer_id: str  # sender's session id
    peer_name: str  # sender's display name
    candidate: ICECandidateDict  # the candidate (or end-of-candidates marker)
|
|
|
|
|
|
|
|
class WebSocketProtocol(Protocol):
    """Structural type for the websocket connection used by the signaling client.

    Only the operations the client relies on are declared, so any connection
    object exposing them (such as the one returned by ``websockets.connect``)
    satisfies this protocol.
    """

    def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]:
        """Send a single message over the connection."""

    def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]:
        """Close the connection with the given close ``code`` and ``reason``."""

    def __aiter__(self) -> AsyncIterator[str]:
        """Iterate asynchronously over incoming messages."""
|
|
|
|
|
|
def _default_attributes() -> Dict[str, object]:
    """Build a new, empty attributes mapping (a per-instance default factory)."""
    fresh: Dict[str, object] = dict()
    return fresh
|
|
|
|
|
|
@dataclass
class Peer:
    """A participant in the lobby: this client (``local=True``) or a remote peer.

    Only ``session_id`` and ``peer_name`` are required; every other field has
    a default.
    """

    # Signaling session id; the signaling client also uses it as the key in
    # its ``peers`` dict.
    session_id: str
    # Human-readable display name.
    peer_name: str
    # Free-form bag of tracks or metadata. Each instance gets its own dict.
    attributes: Dict[str, object] = field(default_factory=dict)
    muted: bool = False
    video_on: bool = True
    # True only for the Peer representing this client.
    local: bool = False
    dead: bool = False
    # Peer connection to this peer; None until one has been created.
    connection: Optional[RTCPeerConnection] = None
|
|
|
|
|
|
class AnimatedVideoTrack(MediaStreamTrack):
    """
    Synthetic video track that generates animated content with a bouncing ball.

    Ported from JavaScript createAnimatedVideoTrack function.

    Each recv() returns one BGR frame: a ball (colour derived from ``name``)
    bouncing on a black background plus a small text overlay. Frames are
    paced in real time at ``fps`` frames per second and stamped on a 90 kHz
    clock.

    NOTE(review): the pacing assumes a single consumer. If the same instance is
    added to several peer connections, concurrent recv() calls share one frame
    schedule. aiortc's MediaRelay is the usual way to fan out one track.
    """

    kind = "video"

    # Clock rate (ticks per second) of the presentation timestamps.
    CLOCK_RATE = 90000

    def __init__(
        self, width: int = 320, height: int = 240, name: str = "", fps: int = 15
    ):
        """
        Args:
            width: frame width in pixels.
            height: frame height in pixels.
            name: seed for the ball colour; "" selects the default green.
            fps: target frame rate used for pacing and timestamps.

        Raises:
            ValueError: if ``fps`` is not positive.
        """
        super().__init__()
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")
        self.width = width
        self.height = height
        self.name = name
        self.fps = fps

        # Generate color from name hash (similar to JavaScript nameToColor)
        self.ball_color = (
            self._name_to_color(name) if name else (0, 255, 136)
        )  # Default green

        # Ball state: centre position, radius and per-frame velocity, in pixels.
        self.ball = {
            "x": width / 2,
            "y": height / 2,
            "radius": min(width, height) * 0.06,
            "dx": 3.0,
            "dy": 2.0,
        }

        self.frame_count = 0
        # Wall-clock time of frame 0; re-anchored on the first next_timestamp().
        self._start_time = time.time()

    async def next_timestamp(self):
        """Wait until frame ``frame_count`` is due, then return ``(pts, time_base)``.

        The sleep here is what paces the track. The RTP sender awaits recv()
        back-to-back, so without it frames would be produced as fast as they
        can be encoded. The schedule is anchored at the first call, so a track
        created long before it is consumed does not burst catch-up frames.
        ``time_base`` is returned as a float (1 / CLOCK_RATE).
        """
        if self.frame_count == 0:
            self._start_time = time.time()
        else:
            wait = self._start_time + self.frame_count / self.fps - time.time()
            if wait > 0:
                await asyncio.sleep(wait)
        # Integer arithmetic keeps pts exact (no float rounding).
        pts = self.frame_count * self.CLOCK_RATE // self.fps
        time_base = 1 / self.CLOCK_RATE
        return pts, time_base

    def _name_to_color(self, name: str) -> tuple[int, int, int]:
        """Convert name to HSL color, then to a BGR tuple for OpenCV."""
        # Simple hash function (djb2)
        hash_value = 5381
        for char in name:
            hash_value = ((hash_value << 5) + hash_value + ord(char)) & 0xFFFFFFFF

        # Generate HSL color from hash
        hue = abs(hash_value) % 360
        sat = 60 + (abs(hash_value) % 30)  # 60-89%
        light = 45 + (abs(hash_value) % 30)  # 45-74%

        # Convert HSL to RGB
        h = hue / 360.0
        s = sat / 100.0
        lightness = light / 100.0

        c = (1 - abs(2 * lightness - 1)) * s
        x = c * (1 - abs((h * 6) % 2 - 1))
        m = lightness - c / 2

        if h < 1 / 6:
            r, g, b = c, x, 0
        elif h < 2 / 6:
            r, g, b = x, c, 0
        elif h < 3 / 6:
            r, g, b = 0, c, x
        elif h < 4 / 6:
            r, g, b = 0, x, c
        elif h < 5 / 6:
            r, g, b = x, 0, c
        else:
            r, g, b = c, 0, x

        return (
            int((b + m) * 255),
            int((g + m) * 255),
            int((r + m) * 255),
        )  # BGR for OpenCV

    async def recv(self):
        """Render and return the next frame, paced at ``fps`` frames per second."""
        pts, time_base = await self.next_timestamp()

        # Create black background
        frame_array = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        # Update ball position
        ball = self.ball
        ball["x"] += ball["dx"]
        ball["y"] += ball["dy"]

        # Bounce off walls
        if ball["x"] + ball["radius"] >= self.width or ball["x"] - ball["radius"] <= 0:
            ball["dx"] = -ball["dx"]
        if ball["y"] + ball["radius"] >= self.height or ball["y"] - ball["radius"] <= 0:
            ball["dy"] = -ball["dy"]

        # Keep ball in bounds
        ball["x"] = max(ball["radius"], min(self.width - ball["radius"], ball["x"]))
        ball["y"] = max(ball["radius"], min(self.height - ball["radius"], ball["y"]))

        # Draw ball
        cv2.circle(
            frame_array,
            (int(ball["x"]), int(ball["y"])),
            int(ball["radius"]),
            self.ball_color,
            -1,
        )

        # Text overlay. Despite the "Frame:" label, this shows the wall-clock
        # time in milliseconds modulo 10000, not self.frame_count.
        frame_text = f"Frame: {int(time.time() * 1000) % 10000}"
        # Debug level: this runs for every frame.
        logger.debug(frame_text)
        cv2.putText(
            frame_array,
            frame_text,
            (10, 20),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )

        # Convert to VideoFrame
        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
        frame.pts = pts
        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)

        self.frame_count += 1
        return frame
|
|
|
|
|
|
class SilentAudioTrack(MediaStreamTrack):
    """
    Synthetic audio track that generates silence.

    Ported from JavaScript createSilentAudioTrack function.

    recv() returns 20 ms mono frames of 16-bit zero samples at 48 kHz. Frames
    are paced in real time, and pts is counted in samples (time_base 1/48000).
    """

    kind = "audio"

    def __init__(self):
        super().__init__()
        self.sample_rate = 48000
        self.samples_per_frame = 960  # 20ms at 48kHz
        # pts (in samples) of the most recently scheduled frame; None until the
        # first next_timestamp() anchors the real-time schedule.
        self._timestamp: Optional[int] = None
        self._start_time = 0.0

    async def next_timestamp(self):
        """Wait until the next frame is due and return ``(pts, time_base)``.

        The RTP sender awaits recv() back-to-back, so this sleep is what limits
        output to one frame per 20 ms. Each call reserves its pts before
        sleeping, so concurrent callers still get distinct, increasing values.
        ``time_base`` is returned as a float (1 / sample_rate).
        """
        if self._timestamp is None:
            self._start_time = time.time()
            self._timestamp = 0
            pts = 0
        else:
            self._timestamp += self.samples_per_frame
            pts = self._timestamp
            wait = self._start_time + pts / self.sample_rate - time.time()
            if wait > 0:
                await asyncio.sleep(wait)
        return pts, 1 / self.sample_rate

    async def recv(self):
        """Return the next 20 ms frame of silence."""
        pts, time_base = await self.next_timestamp()

        # Packed "s16" mono expects shape (1, samples). s16 is what aiortc's
        # own AudioStreamTrack emits; NOTE: some aiortc encoder versions
        # assert s16 input, so float ("flt") frames are avoided.
        samples = np.zeros((1, self.samples_per_frame), dtype=np.int16)
        frame = AudioFrame.from_ndarray(samples, format="s16", layout="mono")
        frame.sample_rate = self.sample_rate
        frame.pts = pts
        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)

        return frame
|
|
|
|
|
|
class WebRTCSignalingClient:
|
|
"""
|
|
WebRTC signaling client that communicates with the FastAPI signaling server.
|
|
Handles peer-to-peer connection establishment and media streaming.
|
|
"""
|
|
|
|
def __init__(
    self,
    server_url: str,
    lobby_id: str,
    session_id: str,
    session_name: str,
    insecure: bool = False,
):
    """Create an unconnected signaling client for one lobby session.

    Args:
        server_url: base URL of the signaling server; connect() appends
            ``/ws/lobby/{lobby_id}/{session_id}``.
        lobby_id: lobby to join.
        session_id: this client's session id.
        session_name: display name sent with ``set_name``.
        insecure: when True, connect() disables TLS certificate verification.

    Only in-memory state is initialised; no I/O happens until connect().
    """
    # -- connection parameters -------------------------------------------
    self.server_url, self.lobby_id = server_url, lobby_id
    self.session_id, self.session_name = session_id, session_name
    self.insecure = insecure

    # Live websocket set by connect(); typed loosely as object to avoid Any.
    self.websocket: Optional[object] = None
    # Password sent with set_name to register or take over the name, if any.
    self.name_password: Optional[str] = None

    # -- peers, connections and local media ------------------------------
    self.peers: dict[str, Peer] = {}
    self.peer_connections: dict[str, RTCPeerConnection] = {}
    self.local_tracks: dict[str, MediaStreamTrack] = {}

    # -- per-peer negotiation bookkeeping --------------------------------
    self.is_negotiating: dict[str, bool] = {}
    self.making_offer: dict[str, bool] = {}
    self.initiated_offer: set[str] = set()
    self.pending_ice_candidates: dict[str, list[ICECandidateDict]] = {}

    # -- optional async event callbacks (assigned by the owner) ----------
    self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
    self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
    self.on_track_received: Optional[
        Callable[[Peer, MediaStreamTrack], Awaitable[None]]
    ] = None
|
|
|
|
async def connect(self):
    """Connect to the signaling server, join the lobby and process messages.

    Opens a websocket to ``{server_url}/ws/lobby/{lobby_id}/{session_id}``,
    creates the local synthetic media, sends ``set_name`` (including
    ``name_password`` when set) and ``join``, then runs the message loop until
    the connection ends.

    Raises:
        Exception: any failure while connecting, setting up media or sending
            the handshake is logged and re-raised.
    """
    ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
    logger.info(f"Connecting to signaling server: {ws_url}")

    # Log network information for debugging. The DNS lookup and the `ip`
    # subprocess are blocking calls, so they run in a worker thread instead of
    # stalling the event loop. The subprocess also gets a timeout.
    try:
        import socket
        import subprocess

        hostname = socket.gethostname()
        local_ip = await asyncio.to_thread(socket.gethostbyname, hostname)
        logger.info(f"Container hostname: {hostname}, local IP: {local_ip}")

        # Get all network interfaces
        result = await asyncio.to_thread(
            subprocess.run,
            ["ip", "addr", "show"],
            capture_output=True,
            text=True,
            timeout=5,
        )
        logger.info(f"Network interfaces:\n{result.stdout}")
    except Exception as e:
        logger.warning(f"Could not get network info: {e}")

    try:
        # If insecure (self-signed certs), build an SSL context that skips
        # verification. It only applies to wss:// URLs; websockets rejects an
        # ssl argument on a ws:// URI.
        ws_ssl = None
        if self.insecure and urllib.parse.urlparse(ws_url).scheme == "wss":
            ws_ssl = ssl.create_default_context()
            ws_ssl.check_hostname = False
            ws_ssl.verify_mode = ssl.CERT_NONE

        logger.info(
            f"Attempting websocket connection to {ws_url} with ssl={bool(ws_ssl)}"
        )
        # Pass ssl= only when a custom context exists. An explicit ssl=None
        # would override websockets' default TLS handling for wss:// URLs.
        connect_kwargs: dict[str, object] = {} if ws_ssl is None else {"ssl": ws_ssl}
        self.websocket = await websockets.connect(ws_url, **connect_kwargs)
        logger.info("Connected to signaling server")

        # Set up local media
        await self._setup_local_media()

        # Set name and join lobby
        name_payload: MessageData = {"name": self.session_name}
        if self.name_password:
            name_payload["password"] = self.name_password
        # Mask the password in the log; the real value is still sent.
        logged_payload = (
            {**name_payload, "password": "***"} if self.name_password else name_payload
        )
        logger.info(f"Sending set_name: {logged_payload}")
        await self._send_message("set_name", name_payload)
        logger.info("Sending join message")
        await self._send_message("join", {})

        # Start message handling
        logger.info("Starting message handler loop")
        await self._handle_messages()

    except Exception as e:
        logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
        raise
|
|
|
|
async def disconnect(self):
    """Disconnect from the signaling server and release local resources.

    Closes the websocket (if any) and drops the reference to it, closes every
    peer connection and stops the local tracks. Each close is attempted even
    if an earlier one fails, so a broken websocket or peer connection cannot
    leak the others. Failures are logged, not raised.
    """
    if self.websocket:
        ws = cast(WebSocketProtocol, self.websocket)
        try:
            await ws.close()
        except Exception as e:
            logger.warning(f"Error closing websocket: {e}")
        self.websocket = None

    # Close all peer connections. Iterate a snapshot: close() awaits, and
    # message handlers may add or remove entries meanwhile.
    for peer_id, pc in list(self.peer_connections.items()):
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"Error closing peer connection for {peer_id}: {e}")

    # Stop local tracks
    for track in list(self.local_tracks.values()):
        track.stop()

    logger.info("Disconnected from signaling server")
|
|
|
|
async def _setup_local_media(self):
    """Create the synthetic local tracks and register the local peer.

    Stores an AnimatedVideoTrack (ball colour seeded by ``session_name``)
    under "video" and a SilentAudioTrack under "audio" in
    ``self.local_tracks``. It then records a local Peer whose "tracks"
    attribute refers to that same dict.
    """
    tracks = self.local_tracks
    tracks["video"] = AnimatedVideoTrack(name=self.session_name)
    tracks["audio"] = SilentAudioTrack()

    self.peers[self.session_id] = Peer(
        session_id=self.session_id,
        peer_name=self.session_name,
        local=True,
        attributes={"tracks": tracks},
    )

    logger.info("Local synthetic media tracks created")
|
|
|
|
async def _send_message(
    self, message_type: str, data: Optional[MessageData] = None
):
    """Send ``{"type": message_type, "data": data}`` as JSON to the server.

    ``data`` is omitted from the envelope when it is None. Without a websocket
    the message is dropped with an error log. Send failures are logged and
    never raised to the caller.
    """
    if not self.websocket:
        logger.error("No websocket connection")
        return
    # Build message with explicit type to avoid type narrowing
    message: dict[str, object] = {"type": message_type}
    if data is not None:
        message["data"] = data

    # Mask credentials in the log only (set_name may carry a password); the
    # real payload is still what gets sent.
    logged_data = (
        {**data, "password": "***"}
        if isinstance(data, dict) and "password" in data
        else data
    )

    ws = cast(WebSocketProtocol, self.websocket)
    try:
        logger.info(f"_send_message: Sending {message_type} with data: {logged_data}")
        await ws.send(json.dumps(message))
        logger.info(f"_send_message: Sent message: {message_type}")
    except Exception as e:
        logger.error(
            f"_send_message: Failed to send {message_type}: {e}", exc_info=True
        )
|
|
|
|
async def _handle_messages(self):
    """Read and dispatch signaling messages until the websocket closes.

    Each frame is JSON-decoded and passed to ``_process_message``. A failure
    on a single message (bad JSON or an exception from its handler) is logged
    and the loop continues. Only connection closure, or an error raised by the
    websocket iterator itself, ends the loop.
    """
    try:
        ws = cast(WebSocketProtocol, self.websocket)
        async for message in ws:
            logger.info(f"_handle_messages: Received raw message: {message}")
            try:
                data = cast(MessageData, json.loads(message))
            except Exception as e:
                logger.error(
                    f"_handle_messages: Failed to parse message: {e}", exc_info=True
                )
                continue
            # Isolate handler errors so one bad message cannot stop signaling.
            try:
                await self._process_message(data)
            except Exception as e:
                logger.error(
                    f"_handle_messages: Error processing message: {e}", exc_info=True
                )
    except websockets.exceptions.ConnectionClosed as e:
        logger.info(f"WebSocket connection closed: {e}")
    except Exception as e:
        logger.error(f"Error handling messages: {e}", exc_info=True)
|
|
|
|
async def _process_message(self, message: MessageData):
    """Validate one decoded signaling message and route it by its ``type``.

    The envelope is validated with WebSocketMessageModel, then its ``data`` is
    validated with the model registered for that type. Peer, SDP and ICE
    messages go to the matching ``_handle_*`` coroutine. join_status,
    user_joined, lobby_state and update are only logged. Validation failures
    are logged and the message is dropped, and unknown types are logged.
    """
    try:
        envelope = WebSocketMessageModel.model_validate(message)
        msg_type = envelope.type
        data = envelope.data
    except ValidationError as exc:
        logger.error(f"Invalid message structure: {exc}", exc_info=True)
        return

    logger.info(
        f"_process_message: Received message type: {msg_type} with data: {data}"
    )

    # Log-only handlers for informational message types.
    async def log_join_status(payload: JoinStatusModel) -> None:
        logger.info(f"Join status: {payload.status} - {payload.message}")

    async def log_user_joined(payload: UserJoinedModel) -> None:
        logger.info(
            f"User joined: {payload.name} (session: {payload.session_id})"
        )

    async def log_lobby_state(payload: LobbyStateModel) -> None:
        logger.info(f"Lobby state updated: {len(payload.participants)} participants")

    async def log_update(payload: UpdateModel) -> None:
        logger.info(f"Received update message: {payload}")

    # message type -> (payload model, coroutine receiving the validated payload)
    routes = {
        "addPeer": (AddPeerModel, self._handle_add_peer),
        "removePeer": (RemovePeerModel, self._handle_remove_peer),
        "sessionDescription": (SessionDescriptionModel, self._handle_session_description),
        "iceCandidate": (IceCandidateModel, self._handle_ice_candidate),
        "join_status": (JoinStatusModel, log_join_status),
        "user_joined": (UserJoinedModel, log_user_joined),
        "lobby_state": (LobbyStateModel, log_lobby_state),
        "update": (UpdateModel, log_update),
    }

    route = routes.get(msg_type)
    if route is None:
        logger.info(f"Unhandled message type: {msg_type} with data: {data}")
        return

    payload_model, handler = route
    try:
        payload = payload_model.model_validate(data)
    except ValidationError as exc:
        logger.error(f"Invalid {msg_type} payload: {exc}", exc_info=True)
        return
    await handler(payload)
|
|
|
|
async def _handle_add_peer(self, data: AddPeerModel):
    """Handle an ``addPeer`` message by creating an RTCPeerConnection.

    Steps:
    - If the peer's existing connection is still ``new``, ``connecting`` or
      ``connected``, return without changes. Otherwise close any stale
      connection and build a fresh one.
    - Register logging, track and ICE handlers, then attach every local track.
    - When ``should_create_offer`` is true, create an offer and wait (bounded)
      for ICE gathering. Send the offer, then every candidate found in the
      local SDP, then an end-of-candidates marker.
    - Await ``on_peer_added`` if one is registered.

    Args:
        data: validated addPeer payload (peer_id, peer_name, should_create_offer).
    """
    import re

    from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer

    # Upper bound (seconds) on waiting for ICE gathering after setLocalDescription.
    ice_gathering_timeout = 10.0
    candidate_type_re = re.compile(r"\btyp\s+(host|srflx|relay|prflx)\b")

    peer_id = data.peer_id
    peer_name = data.peer_name
    should_create_offer = data.should_create_offer

    logger.info(
        f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})"
    )
    logger.info(
        f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}"
    )

    # Check if peer already exists
    if peer_id in self.peer_connections:
        pc = self.peer_connections[peer_id]
        logger.info(
            f"_handle_add_peer: Existing connection state: {pc.connectionState}"
        )
        if pc.connectionState in ("new", "connected", "connecting"):
            logger.info(f"Peer connection already exists for {peer_name}")
            return
        # Clean up stale connection
        logger.info(f"_handle_add_peer: Closing stale connection for {peer_name}")
        await pc.close()
        del self.peer_connections[peer_id]

    # Create new peer
    peer = Peer(session_id=peer_id, peer_name=peer_name, local=False)
    self.peers[peer_id] = peer

    # NOTE(review): TURN credentials are hard-coded in source; consider loading
    # them from configuration or a secret store instead.
    config = RTCConfiguration(
        iceServers=[
            RTCIceServer(urls="stun:ketrenos.com:3478"),
            RTCIceServer(
                urls="turns:ketrenos.com:5349",
                username="ketra",
                credential="ketran",
            ),
            # Add Google's public STUN server as fallback
            RTCIceServer(urls="stun:stun.l.google.com:19302"),
        ],
    )
    # Log only the server URLs: the config repr includes the TURN credential.
    ice_urls = [server.urls for server in (config.iceServers or [])]
    logger.info(
        f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with ICE servers: {ice_urls}"
    )
    pc = RTCPeerConnection(configuration=config)

    # Handlers are registered explicitly (not as decorators) to satisfy static analyzers.
    def on_ice_gathering_state_change() -> None:
        logger.info(f"ICE gathering state: {pc.iceGatheringState}")
        if pc.iceGatheringState == "complete":
            logger.info(
                f"ICE gathering complete for {peer_name} - checking if candidates were generated..."
            )

    pc.on("icegatheringstatechange")(on_ice_gathering_state_change)

    def on_connection_state_change() -> None:
        logger.info(f"Connection state: {pc.connectionState}")

    pc.on("connectionstatechange")(on_connection_state_change)

    self.peer_connections[peer_id] = pc
    peer.connection = pc

    def on_track(track: MediaStreamTrack) -> None:
        logger.info(f"Received {track.kind} track from {peer_name}")
        logger.info(f"on_track: {track.kind} from {peer_name}, track={track}")
        peer.attributes[f"{track.kind}_track"] = track
        if self.on_track_received:
            asyncio.ensure_future(self.on_track_received(peer, track))

    pc.on("track")(on_track)

    def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None:
        # Trickle path. Per the workaround below, aiortc may never fire this.
        logger.info(f"on_ice_candidate: {candidate}")
        logger.info(f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}")
        if not candidate:
            logger.info(f"on_ice_candidate: End of candidates signal for {peer_name}")
            return

        # Raw SDP fragment for the candidate
        raw = getattr(candidate, "candidate", None)
        # Infer candidate type from the SDP string (host/srflx/relay/prflx)
        if not raw:
            cand_type = "eoc"
        else:
            match = candidate_type_re.search(raw)
            cand_type = match.group(1) if match else "unknown"
        protocol = getattr(candidate, "protocol", "unknown")
        logger.info(
            f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}"
        )

        candidate_dict: MessageData = {
            "candidate": raw,
            "sdpMid": getattr(candidate, "sdpMid", None),
            "sdpMLineIndex": getattr(candidate, "sdpMLineIndex", None),
        }
        payload: MessageData = {"peer_id": peer_id, "candidate": candidate_dict}
        logger.info(
            f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_dict}"
        )
        asyncio.ensure_future(self._send_message("relayICECandidate", payload))

    pc.on("icecandidate")(on_ice_candidate)

    async def _wait_for_ice_gathering() -> bool:
        """Poll until pc finishes ICE gathering; return False on timeout."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + ice_gathering_timeout
        while pc.iceGatheringState != "complete":
            if loop.time() >= deadline:
                return False
            await asyncio.sleep(0.1)
        return True

    def _candidates_from_sdp(sdp: str) -> list[MessageData]:
        """Return relayICECandidate dicts for every ``a=candidate:`` line in sdp.

        sdpMLineIndex is the 0-based index of the enclosing m= section, and
        sdpMid is that section's a=mid value (None if the section has none).
        """
        candidates: list[MessageData] = []
        media_index = -1
        mid: Optional[str] = None
        # splitlines() also removes the "\r" of SDP's CRLF line endings, which
        # would otherwise stay attached to every candidate string.
        for line in sdp.splitlines():
            if line.startswith("m="):  # Media section
                media_index += 1
                mid = None  # a=mid is per media section; never inherit it
            elif line.startswith("a=mid:"):  # Media ID
                mid = line.split(":", 1)[1].strip()
            elif line.startswith("a=candidate:"):
                candidates.append(
                    {
                        "candidate": line[2:],  # drop the "a=" prefix
                        "sdpMid": mid,
                        "sdpMLineIndex": media_index,
                    }
                )
        return candidates

    # Add local tracks
    for track in self.local_tracks.values():
        logger.info(
            f"_handle_add_peer: Adding local track {track.kind} to {peer_name}"
        )
        pc.addTrack(track)

    # Create offer if needed
    if should_create_offer:
        self.initiated_offer.add(peer_id)
        self.making_offer[peer_id] = True
        self.is_negotiating[peer_id] = True

        try:
            logger.info(f"_handle_add_peer: Creating offer for {peer_name}")
            offer = await pc.createOffer()
            logger.info(f"_handle_add_peer: Offer created for {peer_name}: {offer}")
            await pc.setLocalDescription(offer)
            logger.info(f"_handle_add_peer: Local description set for {peer_name}")

            # WORKAROUND for aiortc icecandidate event not firing (GitHub issue
            # #1344): wait for gathering, then read the candidates out of the
            # complete local SDP and relay them individually.
            logger.info(
                f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}"
            )
            if await _wait_for_ice_gathering():
                logger.info(
                    f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}"
                )
            else:
                logger.warning(
                    f"_handle_add_peer: ICE gathering not complete after {ice_gathering_timeout}s for {peer_name}; sending candidates gathered so far"
                )

            local_description = pc.localDescription
            candidates = _candidates_from_sdp(
                local_description.sdp if local_description else ""
            )

            # Send the offer before its candidates (the usual trickle-ICE
            # order), so the remote can apply candidates as they arrive.
            await self._send_message(
                "relaySessionDescription",
                {
                    "peer_id": peer_id,
                    "session_description": {"type": offer.type, "sdp": offer.sdp},
                },
            )
            logger.info(f"Offer sent to {peer_name}")

            for candidate_dict in candidates:
                logger.info(
                    f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {str(candidate_dict['candidate'])[:60]}..."
                )
                await self._send_message(
                    "relayICECandidate",
                    {"peer_id": peer_id, "candidate": candidate_dict},
                )

            # Send end-of-candidates signal (empty candidate)
            end_candidate_dict: MessageData = {
                "candidate": "",
                "sdpMid": None,
                "sdpMLineIndex": None,
            }
            logger.info(
                f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}"
            )
            await self._send_message(
                "relayICECandidate",
                {"peer_id": peer_id, "candidate": end_candidate_dict},
            )
            logger.info(
                f"_handle_add_peer: Sent {len(candidates)} ICE candidates to {peer_name}"
            )
        except Exception as e:
            logger.error(
                f"Failed to create/send offer to {peer_name}: {e}", exc_info=True
            )
        finally:
            self.making_offer[peer_id] = False

    if self.on_peer_added:
        await self.on_peer_added(peer)
|
|
|
|
async def _handle_remove_peer(self, data: RemovePeerModel):
    """Tear down everything held for the peer named in a removePeer message.

    Closes and forgets the peer's RTCPeerConnection if one exists. Drops its
    negotiation bookkeeping and queued ICE candidates, and removes it from
    ``self.peers``. When the peer was known, awaits ``on_peer_removed(peer)``.
    """
    peer_id, peer_name = data.peer_id, data.peer_name
    logger.info(f"Removing peer: {peer_name}")

    if peer_id in self.peer_connections:
        await self.peer_connections[peer_id].close()
        del self.peer_connections[peer_id]

    # Per-peer negotiation state keyed by peer id.
    for per_peer_state in (
        self.is_negotiating,
        self.making_offer,
        self.pending_ice_candidates,
    ):
        per_peer_state.pop(peer_id, None)
    self.initiated_offer.discard(peer_id)

    removed = self.peers.pop(peer_id, None)
    if removed and self.on_peer_removed:
        await self.on_peer_removed(removed)
|
|
|
|
async def _handle_session_description(self, data: SessionDescriptionModel):
    """Handle a ``sessionDescription`` signaling message (an offer or answer).

    Flow:

    1. Look up the RTCPeerConnection for ``data.peer_id``; log and return if
       there is none.
    2. Offer-collision check: if this client initiated an offer to the peer
       (``peer_id in self.initiated_offer``) and the incoming offer collides
       (we are currently making an offer, or the signaling state is not
       ``"stable"``), the incoming offer is ignored.
    3. Set the remote description, then flush ICE candidates queued in
       ``self.pending_ice_candidates`` while no remote description existed.
       A failure on any queued candidate (including the end-of-candidates
       marker) is logged and does not abort answering.
    4. If the description was an offer: create and set an answer, wait (for
       a bounded time) for ICE gathering, send the answer via
       ``relaySessionDescription``, then relay every ``a=candidate`` line of
       the local SDP plus an end-of-candidates marker via
       ``relayICECandidate``.
    """
    # Upper bound (seconds) on the wait for local ICE gathering; prevents a
    # gatherer that never reaches "complete" from blocking this handler.
    ice_gathering_timeout = 10.0

    peer_id = data.peer_id
    peer_name = data.peer_name
    session_description = data.session_description.model_dump()

    logger.info(f"Received {session_description['type']} from {peer_name}")

    pc = self.peer_connections.get(peer_id)
    if not pc:
        logger.error(f"No peer connection for {peer_name}")
        return

    desc = RTCSessionDescription(
        sdp=session_description["sdp"], type=session_description["type"]
    )

    # Handle offer collision: if we initiated negotiation with this peer and
    # their offer collides with ours, keep ours and ignore theirs.
    making_offer = self.making_offer.get(peer_id, False)
    offer_collision = desc.type == "offer" and (
        making_offer or pc.signalingState != "stable"
    )
    we_initiated = peer_id in self.initiated_offer
    if we_initiated and offer_collision:
        logger.info(f"Ignoring offer from {peer_name} due to collision")
        return

    try:
        await pc.setRemoteDescription(desc)
    except Exception as e:
        logger.error(f"Failed to set remote description for {peer_name}: {e}")
        return

    self.is_negotiating[peer_id] = False
    logger.info(f"Remote description set for {peer_name}")

    # Process ICE candidates that arrived before the remote description was
    # set. Each one is guarded individually so a bad candidate (or a failing
    # end-of-candidates call) cannot prevent the answer below from being sent.
    from aiortc.sdp import candidate_from_sdp

    for candidate_data in self.pending_ice_candidates.pop(peer_id, []):
        # candidate_data is a dict-like ICECandidateDict.
        cand = candidate_data.get("candidate")
        try:
            if not cand:
                # Empty candidate string is the end-of-candidates marker.
                await pc.addIceCandidate(None)
                logger.info(f"Added queued end-of-candidates for {peer_name}")
                continue

            # cand may be the full "candidate:..." string or the inner SDP part.
            prefix = "candidate:"
            sdp_part = cand[len(prefix):] if cand.startswith(prefix) else cand

            rtc_candidate = candidate_from_sdp(sdp_part)
            rtc_candidate.sdpMid = candidate_data.get("sdpMid")
            rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")
            await pc.addIceCandidate(rtc_candidate)
            logger.info(f"Added queued ICE candidate for {peer_name}")
        except Exception as e:
            logger.error(
                f"Failed to add queued ICE candidate for {peer_name}: {e}"
            )

    # Only offers require an answer.
    if session_description["type"] != "offer":
        return

    try:
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)

        # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
        # Use Method 2: Complete SDP approach to extract ICE candidates
        logger.info(
            f"_handle_session_description: Waiting for ICE gathering to complete for {peer_name} (answer)"
        )
        loop = asyncio.get_running_loop()
        deadline = loop.time() + ice_gathering_timeout
        while pc.iceGatheringState != "complete" and loop.time() < deadline:
            await asyncio.sleep(0.1)

        if pc.iceGatheringState == "complete":
            logger.info(
                f"_handle_session_description: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)"
            )
        else:
            logger.warning(
                f"_handle_session_description: ICE gathering not complete after {ice_gathering_timeout:g}s for {peer_name} (answer); relaying candidates gathered so far"
            )

        # Send the answer before its candidates so the remote side already
        # has a remote description when the trickled candidates arrive.
        await self._send_message(
            "relaySessionDescription",
            {
                "peer_id": peer_id,
                "session_description": {"type": answer.type, "sdp": answer.sdp},
            },
        )
        logger.info(f"Answer sent to {peer_name}")

        # Parse ICE candidates from the local SDP. splitlines() also strips
        # the "\r" of SDP's CRLF line endings, which a plain split("\n")
        # would leave on every extracted candidate.
        local_sdp = pc.localDescription.sdp if pc.localDescription else ""

        # Track which media section we're in to determine sdpMid and sdpMLineIndex.
        current_media_index = -1
        current_mid = None
        sent_count = 0

        for line in local_sdp.splitlines():
            if line.startswith("m="):  # New media section
                current_media_index += 1
                current_mid = None  # a=mid is scoped to its own media section
            elif line.startswith("a=mid:"):  # Media ID
                current_mid = line.split(":", 1)[1].strip()
            elif line.startswith("a=candidate:"):
                candidate_sdp = line[2:].strip()  # Remove 'a=' prefix

                candidate_dict: MessageData = {
                    "candidate": candidate_sdp,
                    "sdpMid": current_mid,
                    "sdpMLineIndex": current_media_index,
                }
                payload_candidate: MessageData = {
                    "peer_id": peer_id,
                    "candidate": candidate_dict,
                }

                logger.info(
                    f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}..."
                )
                await self._send_message("relayICECandidate", payload_candidate)
                sent_count += 1

        # Send end-of-candidates signal (empty candidate)
        end_candidate_dict: MessageData = {
            "candidate": "",
            "sdpMid": None,
            "sdpMLineIndex": None,
        }
        payload_end: MessageData = {
            "peer_id": peer_id,
            "candidate": end_candidate_dict,
        }
        logger.info(
            f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)"
        )
        await self._send_message("relayICECandidate", payload_end)

        logger.info(
            f"_handle_session_description: Sent {sent_count} ICE candidates to {peer_name} (answer)"
        )
    except Exception as e:
        logger.error(f"Failed to create/send answer to {peer_name}: {e}")
|
|
|
|
async def _handle_ice_candidate(self, data: IceCandidateModel):
    """Handle an ``iceCandidate`` message relayed by the signaling server.

    If the peer connection has no remote description yet, the candidate is
    appended to ``self.pending_ice_candidates[peer_id]`` and nothing else
    happens. Otherwise an empty candidate string is passed to
    ``addIceCandidate(None)`` as end-of-candidates, and a non-empty one is
    parsed with ``aiortc.sdp.candidate_from_sdp`` (with an optional
    ``candidate:`` prefix removed), tagged with ``sdpMid``/``sdpMLineIndex``
    and added to the connection. Errors are logged, not raised.
    """
    remote_id = data.peer_id
    remote_name = data.peer_name
    payload = data.candidate.model_dump()

    logger.info(f"Received ICE candidate from {remote_name}")

    connection = self.peer_connections.get(remote_id)
    if not connection:
        logger.error(f"No peer connection for {remote_name}")
        return

    if not connection.remoteDescription:
        # Too early to apply: park it until the remote description is set.
        logger.info(
            f"Remote description not set, queuing ICE candidate for {remote_name}"
        )
        # payload is a plain dict from the Pydantic model; cast for typing only.
        backlog = self.pending_ice_candidates.setdefault(remote_id, [])
        backlog.append(cast(ICECandidateDict, payload))
        return

    try:
        import re
        from aiortc.sdp import candidate_from_sdp

        raw = payload.get("candidate")
        if not raw:
            # An empty candidate string marks end-of-candidates.
            await connection.addIceCandidate(None)
            logger.info(f"End-of-candidates added for {remote_name}")
            return

        prefix = "candidate:"
        sdp_part = raw[len(prefix):] if raw.startswith(prefix) else raw

        # Extract the candidate type purely for log output.
        try:
            found = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part)
        except Exception:
            found = None
        cand_type = found.group(1) if found else "unknown"

        try:
            ice = candidate_from_sdp(sdp_part)
            ice.sdpMid = payload.get("sdpMid")
            ice.sdpMLineIndex = payload.get("sdpMLineIndex")
            # addIceCandidate wants an RTCIceCandidate object, not a dict.
            await connection.addIceCandidate(ice)
            logger.info(f"ICE candidate added for {remote_name}: type={cand_type}")
        except Exception as e:
            logger.error(
                f"Failed to add ICE candidate for {remote_name}: type={cand_type} error={e} sdp='{sdp_part}'",
                exc_info=True,
            )
    except Exception as e:
        logger.error(
            f"Unexpected error handling ICE candidate for {remote_name}: {e}",
            exc_info=True,
        )
|
|
|
|
async def _handle_ice_connection_failure(self, peer_id: str, peer_name: str):
    """Log diagnostics after an ICE failure has been detected for a peer.

    Reports the peer connection's ``connectionState`` and
    ``iceConnectionState``. When no connection is registered for
    ``peer_id``, that is logged instead. No recovery is attempted.
    """
    logger.info(f"ICE connection failure detected for {peer_name}")

    connection = self.peer_connections.get(peer_id)
    if connection:
        logger.error(
            f"ICE connection failed for {peer_name}. Connection state: {connection.connectionState}, ICE state: {connection.iceConnectionState}"
        )
        # A fuller implementation might notify the user or attempt reconnection here.
        return

    logger.error(
        f"No peer connection found for {peer_name} during ICE failure recovery"
    )
|
|
|
|
async def _schedule_ice_timeout(
    self, peer_id: str, peer_name: str, timeout: float = 30.0
):
    """Sleep ``timeout`` seconds, then log the ICE state of a peer connection.

    If the connection for ``peer_id`` no longer exists, nothing is logged.
    If ICE is still ``"checking"``, troubleshooting warnings are emitted.

    Args:
        peer_id: Key into ``self.peer_connections``.
        peer_name: Display name used in log messages.
        timeout: Seconds to wait before inspecting the state. The default of
            30.0 preserves the original behavior and log text.
    """
    await asyncio.sleep(timeout)

    pc = self.peer_connections.get(peer_id)
    if not pc:
        return

    state = pc.iceConnectionState
    if state == "checking":
        logger.warning(
            f"ICE connection timeout for {peer_name} - still in checking state after {timeout:g} seconds"
        )
        logger.warning(
            f"Final connection state: {pc.connectionState}, ICE state: {pc.iceConnectionState}"
        )
        logger.warning(
            "This might be due to network connectivity issues between the browser and Docker container"
        )
        logger.warning(
            "Consider checking: 1) Port forwarding 2) TURN server config 3) Docker network mode"
        )
    elif state in ("failed", "closed"):
        logger.info(f"ICE connection for {peer_name} resolved to: {state}")
    elif state in ("connected", "completed"):
        logger.info(f"ICE connection for {peer_name} established: {state}")
    else:
        # e.g. "new" or "disconnected": not an established connection, so do
        # not report it as one.
        logger.info(
            f"ICE connection for {peer_name} is '{state}' after {timeout:g} seconds"
        )
|
|
|
|
|
|
# Example usage: URL helpers, REST session/lobby bootstrap, and the CLI entry point
|
|
def _http_base_url(server_url: str) -> str:
|
|
# Convert ws:// or wss:// to http(s) and ensure no trailing slash
|
|
if server_url.startswith("ws://"):
|
|
return "http://" + server_url[len("ws://") :].rstrip("/")
|
|
if server_url.startswith("wss://"):
|
|
return "https://" + server_url[len("wss://") :].rstrip("/")
|
|
return server_url.rstrip("/")
|
|
|
|
|
|
def _ws_url(server_url: str) -> str:
|
|
# Convert http(s) to ws(s) if needed
|
|
if server_url.startswith("http://"):
|
|
return "ws://" + server_url[len("http://") :].rstrip("/")
|
|
if server_url.startswith("https://"):
|
|
return "wss://" + server_url[len("https://") :].rstrip("/")
|
|
return server_url.rstrip("/")
|
|
|
|
|
|
def create_or_get_session(
|
|
server_url: str, session_id: str | None = None, insecure: bool = False
|
|
) -> str:
|
|
"""Call GET /api/session to obtain a session_id (unless one was provided).
|
|
|
|
Uses urllib so no extra runtime deps are required.
|
|
"""
|
|
if session_id:
|
|
return session_id
|
|
|
|
http_base = _http_base_url(server_url)
|
|
url = f"{http_base}/api/session"
|
|
req = urllib.request.Request(url, method="GET")
|
|
# Prepare SSL context if requested (accept self-signed certs)
|
|
ssl_ctx = None
|
|
if insecure:
|
|
ssl_ctx = ssl.create_default_context()
|
|
ssl_ctx.check_hostname = False
|
|
ssl_ctx.verify_mode = ssl.CERT_NONE
|
|
|
|
try:
|
|
with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp:
|
|
body = resp.read()
|
|
data = json.loads(body)
|
|
|
|
# Validate response shape using Pydantic
|
|
try:
|
|
session = SessionModel.model_validate(data)
|
|
except ValidationError as e:
|
|
raise RuntimeError(f"Invalid session response from {url}: {e}")
|
|
|
|
sid = session.id
|
|
if not sid:
|
|
raise RuntimeError(f"No session id returned from {url}: {data}")
|
|
return sid
|
|
except urllib.error.HTTPError as e:
|
|
raise RuntimeError(f"HTTP error getting session: {e}")
|
|
except Exception as e:
|
|
raise RuntimeError(f"Error getting session: {e}")
|
|
|
|
|
|
def create_or_get_lobby(
    server_url: str,
    session_id: str,
    lobby_name: str,
    private: bool = False,
    insecure: bool = False,
) -> str:
    """Call POST /api/lobby/{session_id} to create or lookup a lobby by name.

    The request body is
    ``{"type": "lobby_create", "data": {"name": lobby_name, "private": private}}``,
    and the reply is validated with ``LobbyCreateResponse``.

    Args:
        server_url: Server base URL; ws(s):// is converted to http(s)://.
        session_id: Session id, URL-quoted into the request path.
        lobby_name: Name of the lobby to create or look up.
        private: Sent as the ``private`` flag of the request.
        insecure: When True, TLS certificate and hostname verification are
            disabled (for self-signed certificates).

    Returns:
        The lobby id (``data.id`` of the response).

    Raises:
        RuntimeError: On an HTTP error (the message includes the response
            body when available), a network or JSON decode error, a response
            that fails validation, or an empty id. The underlying exception
            is chained as ``__cause__``.
    """
    http_base = _http_base_url(server_url)
    url = f"{http_base}/api/lobby/{urllib.parse.quote(session_id)}"
    payload = json.dumps(
        {
            "type": "lobby_create",
            "data": {"name": lobby_name, "private": private},
        }
    ).encode("utf-8")
    req = urllib.request.Request(
        url, data=payload, headers={"Content-Type": "application/json"}, method="POST"
    )

    # Prepare SSL context if requested (accept self-signed certs)
    ssl_ctx = None
    if insecure:
        ssl_ctx = ssl.create_default_context()
        ssl_ctx.check_hostname = False
        ssl_ctx.verify_mode = ssl.CERT_NONE

    try:
        with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp:
            data = json.loads(resp.read())
    except urllib.error.HTTPError as e:
        # Try to include the response body for debugging; fall back to the
        # status line when the body is empty or unreadable.
        try:
            detail = e.read().decode("utf-8", errors="ignore")
        except Exception:
            detail = ""
        msg = f"{e}: {detail}" if detail else str(e)
        raise RuntimeError(f"HTTP error creating lobby: {msg}") from e
    except Exception as e:
        raise RuntimeError(f"Error creating lobby: {e}") from e

    # Expect shape: { "type": "lobby_created", "data": {"id":..., ...}}.
    # Validated outside the network try-block so these RuntimeErrors are not
    # re-wrapped as "Error creating lobby: ...".
    try:
        lobby_resp = LobbyCreateResponse.model_validate(data)
    except ValidationError as e:
        raise RuntimeError(f"Invalid lobby response from {url}: {e}") from e

    lobby_id = lobby_resp.data.id
    if not lobby_id:
        raise RuntimeError(f"No lobby id returned from {url}: {data}")
    return lobby_id
|
|
|
|
|
|
async def main():
    """Command-line entry point for the example voicebot client.

    Steps:
      1. Parse the CLI options.
      2. Resolve a session id with ``create_or_get_session``, reusing
         ``--session-id`` if given.
      3. Resolve the lobby id for ``--lobby`` with ``create_or_get_lobby``.
      4. Build a ``WebRTCSignalingClient`` against the ws(s):// form of
         ``--server-url``, with callbacks that print peer and track events.
      5. Await ``client.connect()``.

    ``client.disconnect()`` is awaited on the way out however ``connect()``
    finishes. If the session or lobby step fails, the error is printed and
    the function returns before any client is created.
    """
    # (flag, argparse options) in the order they appear in --help.
    cli_options = (
        (
            "--server-url",
            dict(
                default="http://localhost:8000/ai-voicebot",
                help="AI-Voicebot lobby and signaling server base URL (http:// or https://)",
            ),
        ),
        ("--lobby", dict(default="default", help="Lobby name to create or join")),
        (
            "--session-name",
            dict(default="Python Bot", help="Session (user) display name"),
        ),
        (
            "--session-id",
            dict(default=None, help="Optional existing session id to reuse"),
        ),
        (
            "--password",
            dict(
                default=None,
                help="Optional password to register or takeover a name",
            ),
        ),
        (
            "--private",
            dict(action="store_true", help="Create the lobby as private"),
        ),
        (
            "--insecure",
            dict(
                action="store_true",
                help="Allow insecure server connections when using SSL (accept self-signed certs)",
            ),
        ),
    )
    parser = argparse.ArgumentParser(description="Python WebRTC voicebot client")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Step 2: session id (created on the server when none was supplied).
    try:
        session_id = create_or_get_session(
            args.server_url, args.session_id, insecure=args.insecure
        )
        print(f"Using session id: {session_id}")
    except Exception as err:
        print(f"Failed to get/create session: {err}")
        return

    # Step 3: lobby id.
    try:
        lobby_id = create_or_get_lobby(
            args.server_url,
            session_id,
            args.lobby,
            args.private,
            insecure=args.insecure,
        )
        print(f"Using lobby id: {lobby_id} (name={args.lobby})")
    except Exception as err:
        print(f"Failed to create/get lobby: {err}")
        return

    # Step 4: the client receives only the ws(s):// base URL and builds the
    # final websocket path (/ws/lobby/{lobby}/{session}) itself.
    client = WebRTCSignalingClient(
        _ws_url(args.server_url),
        lobby_id,
        session_id,
        args.session_name,
        insecure=args.insecure,
    )

    async def on_peer_added(peer: Peer):
        print(f"Peer added: {peer.peer_name}")

    async def on_peer_removed(peer: Peer):
        print(f"Peer removed: {peer.peer_name}")

    async def on_track_received(peer: Peer, track: MediaStreamTrack):
        print(f"Received {track.kind} track from {peer.peer_name}")

    client.on_peer_added = on_peer_added
    client.on_peer_removed = on_peer_removed
    client.on_track_received = on_track_received

    # Step 5: connect; always disconnect afterwards.
    try:
        # A CLI password is stored on the client for use when setting its name.
        if args.password:
            client.name_password = args.password
        await client.connect()
    except KeyboardInterrupt:
        print("Shutting down...")
    finally:
        await client.disconnect()
|
|
|
|
|
|
if __name__ == "__main__":
    # Install required packages (third-party imports used by this file):
    # pip install aiortc websockets opencv-python numpy pydantic
    # The local `shared` package (shared.models) must also be importable; this
    # file appends its parent directory to sys.path for that.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # On Python 3.11+, asyncio.run() handles Ctrl+C by cancelling main(),
        # so main's `finally` still disconnects the client. It then re-raises
        # KeyboardInterrupt here. Exit with the conventional SIGINT status
        # (128 + 2) instead of printing a traceback.
        sys.exit(130)
|