""" WebRTC Media Agent for Python This module provides synthetic audio/video track creation and WebRTC signaling server communication, ported from the JavaScript MediaControl implementation. """ from __future__ import annotations import asyncio import json import logging import websockets import numpy as np import cv2 import fractions from typing import Dict, Optional, Callable, Awaitable, TypedDict, Protocol, AsyncIterator, cast from dataclasses import dataclass, field from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate, MediaStreamTrack from av import VideoFrame, AudioFrame import time # TypedDict for ICE candidate payloads received from signalling class ICECandidateDict(TypedDict, total=False): candidate: str sdpMid: Optional[str] sdpMLineIndex: Optional[int] # Generic message payload type MessageData = dict[str, object] # Message TypedDicts for signaling payloads class BaseMessage(TypedDict, total=False): type: str data: object class AddPeerPayload(TypedDict): peer_id: str peer_name: str should_create_offer: bool class RemovePeerPayload(TypedDict): peer_id: str peer_name: str class SessionDescriptionTyped(TypedDict): type: str sdp: str class SessionDescriptionPayload(TypedDict): peer_id: str peer_name: str session_description: SessionDescriptionTyped class IceCandidatePayload(TypedDict): peer_id: str peer_name: str candidate: ICECandidateDict class WebSocketProtocol(Protocol): def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ... def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ... def __aiter__(self) -> AsyncIterator[str]: ... # (imports moved to top) # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def _default_attributes() -> Dict[str, object]: return {} @dataclass class Peer: """Represents a WebRTC peer in the session""" session_id: str peer_name: str # Generic attributes bag. Values can be tracks or simple metadata. attributes: Dict[str, object] = field(default_factory=_default_attributes) muted: bool = False video_on: bool = True local: bool = False dead: bool = False connection: Optional[RTCPeerConnection] = None class AnimatedVideoTrack(MediaStreamTrack): async def next_timestamp(self): # Returns (pts, time_base) for 15 FPS video pts = int(self.frame_count * (1/15) * 90000) time_base = 1/90000 return pts, time_base """ Synthetic video track that generates animated content with a bouncing ball. Ported from JavaScript createAnimatedVideoTrack function. """ kind = "video" def __init__(self, width: int = 320, height: int = 240, name: str = ""): super().__init__() self.width = width self.height = height self.name = name # Generate color from name hash (similar to JavaScript nameToColor) self.ball_color = self._name_to_color(name) if name else (0, 255, 136) # Default green # Ball properties self.ball = { 'x': width / 2, 'y': height / 2, 'radius': min(width, height) * 0.06, 'dx': 3.0, 'dy': 2.0 } self.frame_count = 0 self._start_time = time.time() def _name_to_color(self, name: str) -> tuple[int, int, int]: """Convert name to HSL color, then to RGB tuple""" # Simple hash function (djb2) hash_value = 5381 for char in name: hash_value = ((hash_value << 5) + hash_value + ord(char)) & 0xFFFFFFFF # Generate HSL color from hash hue = abs(hash_value) % 360 sat = 60 + (abs(hash_value) % 30) # 60-89% light = 45 + (abs(hash_value) % 30) # 45-74% # Convert HSL to RGB h = hue / 360.0 s = sat / 100.0 lightness = light / 100.0 c = (1 - abs(2 * lightness - 1)) * s x = c * (1 - abs((h * 6) % 2 - 1)) m = lightness - c / 2 if h < 1/6: r, g, b = c, x, 0 elif h < 2/6: r, g, b = x, c, 0 elif h < 3/6: r, g, b = 0, c, x elif h < 4/6: r, g, b = 0, x, c elif h < 5/6: r, g, b = x, 0, c else: r, g, b = c, 0, x return (int((b + m) * 255), int((g + m) * 255), int((r + m) * 255)) # BGR for OpenCV async def recv(self): """Generate video frames at 15 FPS""" pts, time_base = await self.next_timestamp() # Create black background frame_array = np.zeros((self.height, self.width, 3), dtype=np.uint8) # Update ball position ball = self.ball ball['x'] += ball['dx'] ball['y'] += ball['dy'] # Bounce off walls if ball['x'] + ball['radius'] >= self.width or ball['x'] - ball['radius'] <= 0: ball['dx'] = -ball['dx'] if ball['y'] + ball['radius'] >= self.height or ball['y'] - ball['radius'] <= 0: ball['dy'] = -ball['dy'] # Keep ball in bounds ball['x'] = max(ball['radius'], min(self.width - ball['radius'], ball['x'])) ball['y'] = max(ball['radius'], min(self.height - ball['radius'], ball['y'])) # Draw ball cv2.circle(frame_array, (int(ball['x']), int(ball['y'])), int(ball['radius']), self.ball_color, -1) # Add frame counter text frame_text = f"Frame: {int(time.time() * 1000) % 10000}" cv2.putText(frame_array, frame_text, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) # Convert to VideoFrame frame = VideoFrame.from_ndarray(frame_array, format="bgr24") frame.pts = pts frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000) self.frame_count += 1 return frame class SilentAudioTrack(MediaStreamTrack): async def next_timestamp(self): # Returns (pts, time_base) for 20ms audio frames at 48kHz pts = int(time.time() * self.sample_rate) time_base = 1/self.sample_rate return pts, time_base """ Synthetic audio track that generates silence. Ported from JavaScript createSilentAudioTrack function. """ kind = "audio" def __init__(self): super().__init__() self.sample_rate = 48000 self.samples_per_frame = 960 # 20ms at 48kHz async def recv(self): """Generate silent audio frames""" pts, time_base = await self.next_timestamp() # Create silent audio data samples = np.zeros((self.samples_per_frame,), dtype=np.float32) # Convert to AudioFrame frame = AudioFrame.from_ndarray(samples.reshape(1, -1), format="flt", layout="mono") frame.sample_rate = self.sample_rate frame.pts = pts frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000) return frame class WebRTCSignalingClient: """ WebRTC signaling client that communicates with the FastAPI signaling server. Handles peer-to-peer connection establishment and media streaming. """ def __init__(self, server_url: str, lobby_id: str, session_id: str, session_name: str): self.server_url = server_url self.lobby_id = lobby_id self.session_id = session_id self.session_name = session_name # WebSocket client protocol instance (typed as object to avoid Any) self.websocket: Optional[object] = None self.peers: dict[str, Peer] = {} self.peer_connections: dict[str, RTCPeerConnection] = {} self.local_tracks: dict[str, MediaStreamTrack] = {} # State management self.is_negotiating: dict[str, bool] = {} self.making_offer: dict[str, bool] = {} self.initiated_offer: set[str] = set() self.pending_ice_candidates: dict[str, list[ICECandidateDict]] = {} # Event callbacks self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None self.on_track_received: Optional[Callable[[Peer, MediaStreamTrack], Awaitable[None]]] = None async def connect(self): """Connect to the signaling server""" ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}" logger.info(f"Connecting to signaling server: {ws_url}") try: self.websocket = await websockets.connect(ws_url) logger.info("Connected to signaling server") # Set up local media await self._setup_local_media() # Set name and join lobby await self._send_message("set_name", {"name": self.session_name}) await self._send_message("join", {}) # Start message handling await self._handle_messages() except Exception as e: logger.error(f"Failed to connect to signaling server: {e}") raise async def disconnect(self): """Disconnect from signaling server and cleanup""" if self.websocket: ws = cast(WebSocketProtocol, self.websocket) await ws.close() # Close all peer connections for pc in self.peer_connections.values(): await pc.close() # Stop local tracks for track in self.local_tracks.values(): track.stop() logger.info("Disconnected from signaling server") async def _setup_local_media(self): """Create local synthetic media tracks""" # Create synthetic video track video_track = AnimatedVideoTrack(name=self.session_name) self.local_tracks["video"] = video_track # Create synthetic audio track audio_track = SilentAudioTrack() self.local_tracks["audio"] = audio_track # Add local peer to peers dict local_peer = Peer( session_id=self.session_id, peer_name=self.session_name, local=True, attributes={"tracks": self.local_tracks} ) self.peers[self.session_id] = local_peer logger.info("Local synthetic media tracks created") async def _send_message(self, message_type: str, data: Optional[MessageData] = None): """Send message to signaling server""" if not self.websocket: logger.error("No websocket connection") return # Build message with explicit type to avoid type narrowing message: dict[str, object] = {"type": message_type} if data is not None: message["data"] = data ws = cast(WebSocketProtocol, self.websocket) await ws.send(json.dumps(message)) logger.debug(f"Sent message: {message_type}") async def _handle_messages(self): """Handle incoming messages from signaling server""" try: ws = cast(WebSocketProtocol, self.websocket) async for message in ws: data = cast(MessageData, json.loads(message)) await self._process_message(data) except websockets.exceptions.ConnectionClosed: logger.info("WebSocket connection closed") except Exception as e: logger.error(f"Error handling messages: {e}") async def _process_message(self, message: MessageData): """Process incoming signaling messages""" msg_type = message.get("type") data = message.get("data", {}) logger.debug(f"Received message: {msg_type}") if msg_type == "addPeer": await self._handle_add_peer(cast(AddPeerPayload, data)) elif msg_type == "removePeer": await self._handle_remove_peer(cast(RemovePeerPayload, data)) elif msg_type == "sessionDescription": await self._handle_session_description(cast(SessionDescriptionPayload, data)) elif msg_type == "iceCandidate": await self._handle_ice_candidate(cast(IceCandidatePayload, data)) elif msg_type == "join_status": dd = cast(MessageData, data) logger.info(f"Join status: {dd.get('status')} - {dd.get('message', '')}") else: logger.debug(f"Unhandled message type: {msg_type}") async def _handle_add_peer(self, data: AddPeerPayload): """Handle addPeer message - create new peer connection""" peer_id = data["peer_id"] peer_name = data["peer_name"] should_create_offer = data.get("should_create_offer", False) logger.info(f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})") # Check if peer already exists if peer_id in self.peer_connections: pc = self.peer_connections[peer_id] if pc.connectionState in ["new", "connected", "connecting"]: logger.info(f"Peer connection already exists for {peer_name}") return else: # Clean up stale connection await pc.close() del self.peer_connections[peer_id] # Create new peer peer = Peer( session_id=peer_id, peer_name=peer_name, local=False ) self.peers[peer_id] = peer # Create RTCPeerConnection from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer config = RTCConfiguration(iceServers=[ RTCIceServer(urls="stun:stun.l.google.com:19302"), RTCIceServer(urls="stun:stun1.l.google.com:19302"), RTCIceServer(urls="turns:ketrenos.com:5349", username="ketra", credential="ketran") ]) pc = RTCPeerConnection(configuration=config) self.peer_connections[peer_id] = pc peer.connection = pc # Set up event handlers def on_track(track: MediaStreamTrack) -> None: logger.info(f"Received {track.kind} track from {peer_name}") peer.attributes[f"{track.kind}_track"] = track if self.on_track_received: asyncio.ensure_future(self.on_track_received(peer, track)) pc.on("track")(on_track) def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: if candidate: candidate_dict: MessageData = { "candidate": getattr(candidate, "candidate", None), "sdpMid": getattr(candidate, "sdpMid", None), "sdpMLineIndex": getattr(candidate, "sdpMLineIndex", None), } payload: MessageData = {"peer_id": peer_id, "candidate": candidate_dict} asyncio.ensure_future(self._send_message("relayICECandidate", payload)) pc.on("icecandidate")(on_ice_candidate) # Add local tracks for track in self.local_tracks.values(): pc.addTrack(track) # Create offer if needed if should_create_offer: self.initiated_offer.add(peer_id) self.making_offer[peer_id] = True self.is_negotiating[peer_id] = True try: offer = await pc.createOffer() await pc.setLocalDescription(offer) await self._send_message("relaySessionDescription", { "peer_id": peer_id, "session_description": { "type": offer.type, "sdp": offer.sdp } }) logger.info(f"Offer sent to {peer_name}") except Exception as e: logger.error(f"Failed to create/send offer to {peer_name}: {e}") finally: self.making_offer[peer_id] = False if self.on_peer_added: await self.on_peer_added(peer) async def _handle_remove_peer(self, data: RemovePeerPayload): """Handle removePeer message""" peer_id = data["peer_id"] peer_name = data["peer_name"] logger.info(f"Removing peer: {peer_name}") # Close peer connection if peer_id in self.peer_connections: pc = self.peer_connections[peer_id] await pc.close() del self.peer_connections[peer_id] # Clean up state self.is_negotiating.pop(peer_id, None) self.making_offer.pop(peer_id, None) self.initiated_offer.discard(peer_id) self.pending_ice_candidates.pop(peer_id, None) # Remove peer peer = self.peers.pop(peer_id, None) if peer and self.on_peer_removed: await self.on_peer_removed(peer) async def _handle_session_description(self, data: SessionDescriptionPayload): """Handle sessionDescription message""" peer_id = data["peer_id"] peer_name = data["peer_name"] session_description = data["session_description"] logger.info(f"Received {session_description['type']} from {peer_name}") pc = self.peer_connections.get(peer_id) if not pc: logger.error(f"No peer connection for {peer_name}") return desc = RTCSessionDescription(sdp=session_description["sdp"], type=session_description["type"]) # Handle offer collision (polite peer pattern) making_offer = self.making_offer.get(peer_id, False) offer_collision = desc.type == "offer" and (making_offer or pc.signalingState != "stable") we_initiated = peer_id in self.initiated_offer ignore_offer = we_initiated and offer_collision if ignore_offer: logger.info(f"Ignoring offer from {peer_name} due to collision") return try: await pc.setRemoteDescription(desc) self.is_negotiating[peer_id] = False logger.info(f"Remote description set for {peer_name}") # Process queued ICE candidates pending_candidates = self.pending_ice_candidates.pop(peer_id, []) for candidate_data in pending_candidates: # pass through the dict received from signaling server await pc.addIceCandidate(cast(RTCIceCandidate, candidate_data)) logger.info(f"Added queued ICE candidate for {peer_name}") except Exception as e: logger.error(f"Failed to set remote description for {peer_name}: {e}") return # Create answer if this was an offer if session_description["type"] == "offer": try: answer = await pc.createAnswer() await pc.setLocalDescription(answer) await self._send_message("relaySessionDescription", { "peer_id": peer_id, "session_description": { "type": answer.type, "sdp": answer.sdp } }) logger.info(f"Answer sent to {peer_name}") except Exception as e: logger.error(f"Failed to create/send answer to {peer_name}: {e}") async def _handle_ice_candidate(self, data: IceCandidatePayload): """Handle iceCandidate message""" peer_id = data["peer_id"] peer_name = data["peer_name"] candidate_data = data["candidate"] logger.debug(f"Received ICE candidate from {peer_name}") pc = self.peer_connections.get(peer_id) if not pc: logger.error(f"No peer connection for {peer_name}") return # Queue candidate if remote description not set if not pc.remoteDescription: logger.info(f"Remote description not set, queuing ICE candidate for {peer_name}") if peer_id not in self.pending_ice_candidates: self.pending_ice_candidates[peer_id] = [] self.pending_ice_candidates[peer_id].append(candidate_data) return try: await pc.addIceCandidate(cast(RTCIceCandidate, candidate_data)) logger.debug(f"ICE candidate added for {peer_name}") except Exception as e: logger.error(f"Failed to add ICE candidate for {peer_name}: {e}") # Example usage async def main(): """Example usage of the WebRTC signaling client""" # Configuration SERVER_URL = "ws://localhost:8000" # Adjust to your server LOBBY_ID = "test-lobby" SESSION_ID = "python-client-001" SESSION_NAME = "Python Bot" client = WebRTCSignalingClient(SERVER_URL, LOBBY_ID, SESSION_ID, SESSION_NAME) # Set up event handlers async def on_peer_added(peer: Peer): print(f"Peer added: {peer.peer_name}") async def on_peer_removed(peer: Peer): print(f"Peer removed: {peer.peer_name}") async def on_track_received(peer: Peer, track: MediaStreamTrack): print(f"Received {track.kind} track from {peer.peer_name}") # You could save/process the received media here client.on_peer_added = on_peer_added client.on_peer_removed = on_peer_removed client.on_track_received = on_track_received try: # Connect and run await client.connect() except KeyboardInterrupt: print("Shutting down...") finally: await client.disconnect() if __name__ == "__main__": # Install required packages: # pip install aiortc websockets opencv-python numpy asyncio.run(main())