"""
|
|
WebRTC Media Agent for Python
|
|
|
|
This module provides synthetic audio/video track creation and WebRTC signaling
|
|
server communication, ported from the JavaScript MediaControl implementation.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import websockets
|
|
import numpy as np
|
|
import cv2
|
|
import fractions
|
|
from typing import Dict, Optional, Callable, Awaitable, TypedDict, Protocol, AsyncIterator, cast
|
|
|
|
from dataclasses import dataclass, field
|
|
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCIceCandidate, MediaStreamTrack
|
|
from av import VideoFrame, AudioFrame
|
|
import time
|
|
|
|


# TypedDict for ICE candidate payloads received from signalling
class ICECandidateDict(TypedDict, total=False):
    candidate: str
    sdpMid: Optional[str]
    sdpMLineIndex: Optional[int]
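

def _ice_candidate_from_dict(data: ICECandidateDict) -> Optional[RTCIceCandidate]:
    """Convert a browser-style ICE candidate payload into an RTCIceCandidate.

    Helper used by the ICE handlers below: aiortc's addIceCandidate() expects an
    RTCIceCandidate object rather than the raw JSON dict, so the candidate line is
    parsed with aiortc's candidate_from_sdp() (the usual aiortc pattern; adjust if
    your signaling payload uses different field names).
    """
    from aiortc.sdp import candidate_from_sdp

    sdp = data.get("candidate") or ""
    if not sdp:
        return None
    # Browsers prefix the line with "candidate:"; candidate_from_sdp() expects it stripped.
    if sdp.startswith("candidate:"):
        sdp = sdp[len("candidate:"):]
    candidate = candidate_from_sdp(sdp)
    candidate.sdpMid = data.get("sdpMid")
    candidate.sdpMLineIndex = data.get("sdpMLineIndex")
    return candidate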


# Generic message payload type
MessageData = dict[str, object]


# Message TypedDicts for signaling payloads
class BaseMessage(TypedDict, total=False):
    type: str
    data: object


class AddPeerPayload(TypedDict):
    peer_id: str
    peer_name: str
    should_create_offer: bool


class RemovePeerPayload(TypedDict):
    peer_id: str
    peer_name: str


class SessionDescriptionTyped(TypedDict):
    type: str
    sdp: str


class SessionDescriptionPayload(TypedDict):
    peer_id: str
    peer_name: str
    session_description: SessionDescriptionTyped


class IceCandidatePayload(TypedDict):
    peer_id: str
    peer_name: str
    candidate: ICECandidateDict


class WebSocketProtocol(Protocol):
    def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ...
    def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ...
    def __aiter__(self) -> AsyncIterator[str]: ...


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _default_attributes() -> Dict[str, object]:
    return {}


@dataclass
class Peer:
    """Represents a WebRTC peer in the session"""
    session_id: str
    peer_name: str
    # Generic attributes bag. Values can be tracks or simple metadata.
    attributes: Dict[str, object] = field(default_factory=_default_attributes)
    muted: bool = False
    video_on: bool = True
    local: bool = False
    dead: bool = False
    connection: Optional[RTCPeerConnection] = None


class AnimatedVideoTrack(MediaStreamTrack):
    """
    Synthetic video track that generates animated content with a bouncing ball.
    Ported from the JavaScript createAnimatedVideoTrack function.
    """

    kind = "video"

    def __init__(self, width: int = 320, height: int = 240, name: str = ""):
        super().__init__()
        self.width = width
        self.height = height
        self.name = name

        # Generate color from name hash (similar to JavaScript nameToColor)
        self.ball_color = self._name_to_color(name) if name else (0, 255, 136)  # Default green

        # Ball properties
        self.ball = {
            'x': width / 2,
            'y': height / 2,
            'radius': min(width, height) * 0.06,
            'dx': 3.0,
            'dy': 2.0
        }

        self.frame_count = 0
        self._start_time = time.time()

    async def next_timestamp(self):
        # (pts, time_base) for 15 FPS video using the 90 kHz clock. aiortc's sender
        # pulls frames as fast as recv() yields them, so sleep to pace generation.
        time_base = fractions.Fraction(1, 90000)
        pts = self.frame_count * 6000  # 90000 / 15
        wait = self._start_time + self.frame_count / 15 - time.time()
        if wait > 0:
            await asyncio.sleep(wait)
        return pts, time_base

    def _name_to_color(self, name: str) -> tuple[int, int, int]:
        """Convert name to an HSL-derived color, returned as a BGR tuple for OpenCV"""
        # Simple hash function (djb2)
        hash_value = 5381
        for char in name:
            hash_value = ((hash_value << 5) + hash_value + ord(char)) & 0xFFFFFFFF

        # Generate HSL color from hash
        hue = abs(hash_value) % 360
        sat = 60 + (abs(hash_value) % 30)  # 60-89%
        light = 45 + (abs(hash_value) % 30)  # 45-74%

        # Convert HSL to RGB
        h = hue / 360.0
        s = sat / 100.0
        lightness = light / 100.0

        c = (1 - abs(2 * lightness - 1)) * s
        x = c * (1 - abs((h * 6) % 2 - 1))
        m = lightness - c / 2

        if h < 1/6:
            r, g, b = c, x, 0
        elif h < 2/6:
            r, g, b = x, c, 0
        elif h < 3/6:
            r, g, b = 0, c, x
        elif h < 4/6:
            r, g, b = 0, x, c
        elif h < 5/6:
            r, g, b = x, 0, c
        else:
            r, g, b = c, 0, x

        return (int((b + m) * 255), int((g + m) * 255), int((r + m) * 255))  # BGR for OpenCV

    async def recv(self):
        """Generate video frames at 15 FPS"""
        pts, time_base = await self.next_timestamp()

        # Create black background
        frame_array = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        # Update ball position
        ball = self.ball
        ball['x'] += ball['dx']
        ball['y'] += ball['dy']

        # Bounce off walls
        if ball['x'] + ball['radius'] >= self.width or ball['x'] - ball['radius'] <= 0:
            ball['dx'] = -ball['dx']
        if ball['y'] + ball['radius'] >= self.height or ball['y'] - ball['radius'] <= 0:
            ball['dy'] = -ball['dy']

        # Keep ball in bounds
        ball['x'] = max(ball['radius'], min(self.width - ball['radius'], ball['x']))
        ball['y'] = max(ball['radius'], min(self.height - ball['radius'], ball['y']))

        # Draw ball
        cv2.circle(frame_array,
                   (int(ball['x']), int(ball['y'])),
                   int(ball['radius']),
                   self.ball_color,
                   -1)

        # Overlay a changing counter (wall-clock milliseconds) so motion is visible
        frame_text = f"Frame: {int(time.time() * 1000) % 10000}"
        cv2.putText(frame_array, frame_text, (10, 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Convert to VideoFrame
        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
        frame.pts = pts
        frame.time_base = time_base

        self.frame_count += 1
        return frame


class SilentAudioTrack(MediaStreamTrack):
    """
    Synthetic audio track that generates silence.
    Ported from the JavaScript createSilentAudioTrack function.
    """

    kind = "audio"

    def __init__(self):
        super().__init__()
        self.sample_rate = 48000
        self.samples_per_frame = 960  # 20ms at 48kHz
        self._samples = 0  # running pts, counted in samples
        self._start_time = time.time()

    async def next_timestamp(self):
        # (pts, time_base) for 20 ms audio frames at 48 kHz. pts counts samples;
        # sleep so frames are produced in real time rather than as fast as possible.
        time_base = fractions.Fraction(1, self.sample_rate)
        pts = self._samples
        wait = self._start_time + pts / self.sample_rate - time.time()
        if wait > 0:
            await asyncio.sleep(wait)
        return pts, time_base

    async def recv(self):
        """Generate silent audio frames"""
        pts, time_base = await self.next_timestamp()

        # Create silent audio data
        samples = np.zeros((self.samples_per_frame,), dtype=np.float32)

        # Convert to AudioFrame (packed float, mono)
        frame = AudioFrame.from_ndarray(samples.reshape(1, -1), format="flt", layout="mono")
        frame.sample_rate = self.sample_rate
        frame.pts = pts
        frame.time_base = time_base

        self._samples += self.samples_per_frame
        return frame


class WebRTCSignalingClient:
    """
    WebRTC signaling client that communicates with the FastAPI signaling server.
    Handles peer-to-peer connection establishment and media streaming.
    """

    def __init__(self, server_url: str, lobby_id: str, session_id: str, session_name: str):
        self.server_url = server_url
        self.lobby_id = lobby_id
        self.session_id = session_id
        self.session_name = session_name

        # WebSocket client protocol instance (typed as object to avoid Any)
        self.websocket: Optional[object] = None

        self.peers: dict[str, Peer] = {}
        self.peer_connections: dict[str, RTCPeerConnection] = {}
        self.local_tracks: dict[str, MediaStreamTrack] = {}

        # State management
        self.is_negotiating: dict[str, bool] = {}
        self.making_offer: dict[str, bool] = {}
        self.initiated_offer: set[str] = set()
        self.pending_ice_candidates: dict[str, list[ICECandidateDict]] = {}

        # Event callbacks
        self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
        self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
        self.on_track_received: Optional[Callable[[Peer, MediaStreamTrack], Awaitable[None]]] = None

    async def connect(self):
        """Connect to the signaling server"""
        ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
        logger.info(f"Connecting to signaling server: {ws_url}")

        try:
            self.websocket = await websockets.connect(ws_url)
            logger.info("Connected to signaling server")

            # Set up local media
            await self._setup_local_media()

            # Set name and join lobby
            await self._send_message("set_name", {"name": self.session_name})
            await self._send_message("join", {})

            # Start message handling
            await self._handle_messages()

        except Exception as e:
            logger.error(f"Failed to connect to signaling server: {e}")
            raise

    async def disconnect(self):
        """Disconnect from signaling server and cleanup"""
        if self.websocket:
            ws = cast(WebSocketProtocol, self.websocket)
            await ws.close()

        # Close all peer connections
        for pc in self.peer_connections.values():
            await pc.close()

        # Stop local tracks
        for track in self.local_tracks.values():
            track.stop()

        logger.info("Disconnected from signaling server")

    async def _setup_local_media(self):
        """Create local synthetic media tracks"""
        # Create synthetic video track
        video_track = AnimatedVideoTrack(name=self.session_name)
        self.local_tracks["video"] = video_track

        # Create synthetic audio track
        audio_track = SilentAudioTrack()
        self.local_tracks["audio"] = audio_track

        # Add local peer to peers dict
        local_peer = Peer(
            session_id=self.session_id,
            peer_name=self.session_name,
            local=True,
            attributes={"tracks": self.local_tracks}
        )
        self.peers[self.session_id] = local_peer

        logger.info("Local synthetic media tracks created")

    async def _send_message(self, message_type: str, data: Optional[MessageData] = None):
        """Send message to signaling server"""
        if not self.websocket:
            logger.error("No websocket connection")
            return

        # Build message with explicit type to avoid type narrowing
        message: dict[str, object] = {"type": message_type}
        if data is not None:
            message["data"] = data

        ws = cast(WebSocketProtocol, self.websocket)
        await ws.send(json.dumps(message))
        logger.debug(f"Sent message: {message_type}")

    async def _handle_messages(self):
        """Handle incoming messages from signaling server"""
        if not self.websocket:
            logger.error("No websocket connection")
            return
        try:
            ws = cast(WebSocketProtocol, self.websocket)
            async for message in ws:
                data = cast(MessageData, json.loads(message))
                await self._process_message(data)
        except websockets.exceptions.ConnectionClosed:
            logger.info("WebSocket connection closed")
        except Exception as e:
            logger.error(f"Error handling messages: {e}")

    async def _process_message(self, message: MessageData):
        """Process incoming signaling messages"""
        msg_type = message.get("type")
        data = message.get("data", {})

        logger.debug(f"Received message: {msg_type}")

        if msg_type == "addPeer":
            await self._handle_add_peer(cast(AddPeerPayload, data))
        elif msg_type == "removePeer":
            await self._handle_remove_peer(cast(RemovePeerPayload, data))
        elif msg_type == "sessionDescription":
            await self._handle_session_description(cast(SessionDescriptionPayload, data))
        elif msg_type == "iceCandidate":
            await self._handle_ice_candidate(cast(IceCandidatePayload, data))
        elif msg_type == "join_status":
            dd = cast(MessageData, data)
            logger.info(f"Join status: {dd.get('status')} - {dd.get('message', '')}")
        else:
            logger.debug(f"Unhandled message type: {msg_type}")

    async def _handle_add_peer(self, data: AddPeerPayload):
        """Handle addPeer message - create new peer connection"""
        peer_id = data["peer_id"]
        peer_name = data["peer_name"]
        should_create_offer = data.get("should_create_offer", False)

        logger.info(f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})")

        # Check if peer already exists
        if peer_id in self.peer_connections:
            pc = self.peer_connections[peer_id]
            if pc.connectionState in ["new", "connected", "connecting"]:
                logger.info(f"Peer connection already exists for {peer_name}")
                return
            else:
                # Clean up stale connection
                await pc.close()
                del self.peer_connections[peer_id]

        # Create new peer
        peer = Peer(
            session_id=peer_id,
            peer_name=peer_name,
            local=False
        )
        self.peers[peer_id] = peer

        # Create RTCPeerConnection
        from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer
        config = RTCConfiguration(iceServers=[
            RTCIceServer(urls="stun:stun.l.google.com:19302"),
            RTCIceServer(urls="stun:stun1.l.google.com:19302"),
            RTCIceServer(urls="turns:ketrenos.com:5349", username="ketra", credential="ketran")
        ])
        pc = RTCPeerConnection(configuration=config)

        self.peer_connections[peer_id] = pc
        peer.connection = pc

        # Set up event handlers
        def on_track(track: MediaStreamTrack) -> None:
            logger.info(f"Received {track.kind} track from {peer_name}")
            peer.attributes[f"{track.kind}_track"] = track
            if self.on_track_received:
                asyncio.ensure_future(self.on_track_received(peer, track))

        pc.on("track")(on_track)
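
        # NOTE: aiortc gathers ICE candidates during setLocalDescription() and includes
        # them in the SDP; it does not emit per-candidate "icecandidate" events, so the
        # handler below is kept mainly for parity with the JavaScript implementation.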
        def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None:
            if candidate:
                candidate_dict: MessageData = {
                    "candidate": getattr(candidate, "candidate", None),
                    "sdpMid": getattr(candidate, "sdpMid", None),
                    "sdpMLineIndex": getattr(candidate, "sdpMLineIndex", None),
                }
                payload: MessageData = {"peer_id": peer_id, "candidate": candidate_dict}
                asyncio.ensure_future(self._send_message("relayICECandidate", payload))

        pc.on("icecandidate")(on_ice_candidate)

        # Add local tracks
        for track in self.local_tracks.values():
            pc.addTrack(track)

        # Create offer if needed
        if should_create_offer:
            self.initiated_offer.add(peer_id)
            self.making_offer[peer_id] = True
            self.is_negotiating[peer_id] = True

            try:
                offer = await pc.createOffer()
                await pc.setLocalDescription(offer)

                # Relay pc.localDescription rather than the bare offer: aiortc finishes
                # ICE gathering inside setLocalDescription(), so this SDP carries the
                # candidates the remote peer needs.
                local_desc = pc.localDescription
                await self._send_message("relaySessionDescription", {
                    "peer_id": peer_id,
                    "session_description": {
                        "type": local_desc.type,
                        "sdp": local_desc.sdp
                    }
                })

                logger.info(f"Offer sent to {peer_name}")
            except Exception as e:
                logger.error(f"Failed to create/send offer to {peer_name}: {e}")
            finally:
                self.making_offer[peer_id] = False

        if self.on_peer_added:
            await self.on_peer_added(peer)

    async def _handle_remove_peer(self, data: RemovePeerPayload):
        """Handle removePeer message"""
        peer_id = data["peer_id"]
        peer_name = data["peer_name"]

        logger.info(f"Removing peer: {peer_name}")

        # Close peer connection
        if peer_id in self.peer_connections:
            pc = self.peer_connections[peer_id]
            await pc.close()
            del self.peer_connections[peer_id]

        # Clean up state
        self.is_negotiating.pop(peer_id, None)
        self.making_offer.pop(peer_id, None)
        self.initiated_offer.discard(peer_id)
        self.pending_ice_candidates.pop(peer_id, None)

        # Remove peer
        peer = self.peers.pop(peer_id, None)
        if peer and self.on_peer_removed:
            await self.on_peer_removed(peer)

    async def _handle_session_description(self, data: SessionDescriptionPayload):
        """Handle sessionDescription message"""
        peer_id = data["peer_id"]
        peer_name = data["peer_name"]
        session_description = data["session_description"]

        logger.info(f"Received {session_description['type']} from {peer_name}")

        pc = self.peer_connections.get(peer_id)
        if not pc:
            logger.error(f"No peer connection for {peer_name}")
            return

        desc = RTCSessionDescription(sdp=session_description["sdp"], type=session_description["type"])

        # Handle offer collision (polite peer pattern)
        making_offer = self.making_offer.get(peer_id, False)
        offer_collision = desc.type == "offer" and (making_offer or pc.signalingState != "stable")
        we_initiated = peer_id in self.initiated_offer
        ignore_offer = we_initiated and offer_collision

        if ignore_offer:
            logger.info(f"Ignoring offer from {peer_name} due to collision")
            return

        try:
            await pc.setRemoteDescription(desc)
            self.is_negotiating[peer_id] = False
            logger.info(f"Remote description set for {peer_name}")

            # Process queued ICE candidates
            pending_candidates = self.pending_ice_candidates.pop(peer_id, [])
            for candidate_data in pending_candidates:
                # Convert the dict received from the signaling server before adding it
                candidate = _ice_candidate_from_dict(candidate_data)
                if candidate is not None:
                    await pc.addIceCandidate(candidate)
                    logger.info(f"Added queued ICE candidate for {peer_name}")

        except Exception as e:
            logger.error(f"Failed to set remote description for {peer_name}: {e}")
            return

        # Create answer if this was an offer
        if session_description["type"] == "offer":
            try:
                answer = await pc.createAnswer()
                await pc.setLocalDescription(answer)

                # Relay pc.localDescription: aiortc completes ICE gathering inside
                # setLocalDescription(), so this SDP already carries the candidates.
                local_desc = pc.localDescription
                await self._send_message("relaySessionDescription", {
                    "peer_id": peer_id,
                    "session_description": {
                        "type": local_desc.type,
                        "sdp": local_desc.sdp
                    }
                })

                logger.info(f"Answer sent to {peer_name}")
            except Exception as e:
                logger.error(f"Failed to create/send answer to {peer_name}: {e}")

    async def _handle_ice_candidate(self, data: IceCandidatePayload):
        """Handle iceCandidate message"""
        peer_id = data["peer_id"]
        peer_name = data["peer_name"]
        candidate_data = data["candidate"]

        logger.debug(f"Received ICE candidate from {peer_name}")

        pc = self.peer_connections.get(peer_id)
        if not pc:
            logger.error(f"No peer connection for {peer_name}")
            return

        # Queue candidate if remote description not set
        if not pc.remoteDescription:
            logger.info(f"Remote description not set, queuing ICE candidate for {peer_name}")
            if peer_id not in self.pending_ice_candidates:
                self.pending_ice_candidates[peer_id] = []
            self.pending_ice_candidates[peer_id].append(candidate_data)
            return

        try:
            # Convert the signaling payload to an RTCIceCandidate before adding it
            candidate = _ice_candidate_from_dict(candidate_data)
            if candidate is not None:
                await pc.addIceCandidate(candidate)
                logger.debug(f"ICE candidate added for {peer_name}")
        except Exception as e:
            logger.error(f"Failed to add ICE candidate for {peer_name}: {e}")


# Example usage
async def main():
    """Example usage of the WebRTC signaling client"""

    # Configuration
    SERVER_URL = "ws://localhost:8000"  # Adjust to your server
    LOBBY_ID = "test-lobby"
    SESSION_ID = "python-client-001"
    SESSION_NAME = "Python Bot"

    client = WebRTCSignalingClient(SERVER_URL, LOBBY_ID, SESSION_ID, SESSION_NAME)

    # Set up event handlers
    async def on_peer_added(peer: Peer):
        print(f"Peer added: {peer.peer_name}")

    async def on_peer_removed(peer: Peer):
        print(f"Peer removed: {peer.peer_name}")

    async def on_track_received(peer: Peer, track: MediaStreamTrack):
        print(f"Received {track.kind} track from {peer.peer_name}")
        # You could save/process the received media here

    client.on_peer_added = on_peer_added
    client.on_peer_removed = on_peer_removed
    client.on_track_received = on_track_received

    try:
        # Connect and run (blocks until the websocket closes)
        await client.connect()
    except KeyboardInterrupt:
        print("Shutting down...")
    finally:
        await client.disconnect()


if __name__ == "__main__":
    # Install required packages:
    # pip install aiortc websockets opencv-python numpy
    asyncio.run(main())