Refactored to use shared/models

This commit is contained in:
James Ketr 2025-09-01 15:20:21 -07:00
parent 7c5616fbd9
commit a313209768

View File

@ -15,7 +15,6 @@ from typing import (
Optional,
Callable,
Awaitable,
TypedDict,
Protocol,
AsyncIterator,
cast,
@ -47,6 +46,8 @@ from shared.models import (
RemovePeerModel,
SessionDescriptionModel,
IceCandidateModel,
ICECandidateDictModel,
SessionDescriptionTypedModel,
)
from aiortc import (
RTCPeerConnection,
@ -60,51 +61,10 @@ from synthetic_media import create_synthetic_tracks
# import debug_aioice
# TypedDict for ICE candidate payloads received from signalling
class ICECandidateDict(TypedDict, total=False):
candidate: str
sdpMid: Optional[str]
sdpMLineIndex: Optional[int]
# Generic message payload type
MessageData = dict[str, object]
# Message TypedDicts for signaling payloads
class BaseMessage(TypedDict, total=False):
type: str
data: object
class AddPeerPayload(TypedDict):
peer_id: str
peer_name: str
should_create_offer: bool
class RemovePeerPayload(TypedDict):
peer_id: str
peer_name: str
class SessionDescriptionTyped(TypedDict):
type: str
sdp: str
class SessionDescriptionPayload(TypedDict):
peer_id: str
peer_name: str
session_description: SessionDescriptionTyped
class IceCandidatePayload(TypedDict):
peer_id: str
peer_name: str
candidate: ICECandidateDict
class WebSocketProtocol(Protocol):
def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ...
def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ...
@ -164,7 +124,7 @@ class WebRTCSignalingClient:
self.is_negotiating: dict[str, bool] = {}
self.making_offer: dict[str, bool] = {}
self.initiated_offer: set[str] = set()
self.pending_ice_candidates: dict[str, list[ICECandidateDict]] = {}
self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
# Event callbacks
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
@ -488,16 +448,20 @@ class WebRTCSignalingClient:
f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}"
)
candidate_dict: MessageData = {
"candidate": raw,
"sdpMid": getattr(candidate, "sdpMid", None),
"sdpMLineIndex": getattr(candidate, "sdpMLineIndex", None),
}
payload: MessageData = {"peer_id": peer_id, "candidate": candidate_dict}
logger.info(
f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_dict}"
candidate_model = ICECandidateDictModel(
candidate=raw,
sdpMid=getattr(candidate, "sdpMid", None),
sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None),
)
payload_model = IceCandidateModel(
peer_id=peer_id, peer_name=peer_name, candidate=candidate_model
)
logger.info(
f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_model}"
)
asyncio.ensure_future(
self._send_message("relayICECandidate", payload_model.model_dump())
)
asyncio.ensure_future(self._send_message("relayICECandidate", payload))
pc.on("icecandidate")(on_ice_candidate)
@ -544,41 +508,48 @@ class WebRTCSignalingClient:
current_mid = line.split(':', 1)[1].strip()
elif line.startswith('a=candidate:'):
candidate_sdp = line[2:] # Remove 'a=' prefix
candidate_dict: MessageData = {
"candidate": candidate_sdp,
"sdpMid": current_mid,
"sdpMLineIndex": current_media_index,
}
payload_candidate: MessageData = {
"peer_id": peer_id,
"candidate": candidate_dict
}
candidate_model = ICECandidateDictModel(
candidate=candidate_sdp,
sdpMid=current_mid,
sdpMLineIndex=current_media_index,
)
payload_candidate = IceCandidateModel(
peer_id=peer_id,
peer_name=peer_name,
candidate=candidate_model,
)
logger.info(f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}...")
await self._send_message("relayICECandidate", payload_candidate)
await self._send_message(
"relayICECandidate", payload_candidate.model_dump()
)
# Send end-of-candidates signal (empty candidate)
end_candidate_dict: MessageData = {
"candidate": "",
"sdpMid": None,
"sdpMLineIndex": None,
}
payload_end: MessageData = {
"peer_id": peer_id,
"candidate": end_candidate_dict
}
end_candidate_model = ICECandidateDictModel(
candidate="",
sdpMid=None,
sdpMLineIndex=None,
)
payload_end = IceCandidateModel(
peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
)
logger.info(f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}")
await self._send_message("relayICECandidate", payload_end)
await self._send_message("relayICECandidate", payload_end.model_dump())
logger.info(f"_handle_add_peer: Sent {len(candidate_lines)} ICE candidates to {peer_name}")
session_desc_typed = SessionDescriptionTypedModel(
type=offer.type, sdp=offer.sdp
)
session_desc_model = SessionDescriptionModel(
peer_id=peer_id,
peer_name=peer_name,
session_description=session_desc_typed,
)
await self._send_message(
"relaySessionDescription",
{
"peer_id": peer_id,
"session_description": {"type": offer.type, "sdp": offer.sdp},
},
session_desc_model.model_dump(),
)
logger.info(f"Offer sent to {peer_name}")
@ -655,8 +626,8 @@ class WebRTCSignalingClient:
from aiortc.sdp import candidate_from_sdp
for candidate_data in pending_candidates:
# candidate_data is a dict-like ICECandidateDict; convert SDP string
cand = candidate_data.get("candidate")
# candidate_data is an ICECandidateDictModel Pydantic model
cand = candidate_data.candidate
# handle end-of-candidates marker
if not cand:
await pc.addIceCandidate(None)
@ -671,8 +642,8 @@ class WebRTCSignalingClient:
try:
rtc_candidate = candidate_from_sdp(sdp_part)
rtc_candidate.sdpMid = candidate_data.get("sdpMid")
rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")
rtc_candidate.sdpMid = candidate_data.sdpMid
rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
await pc.addIceCandidate(rtc_candidate)
logger.info(f"Added queued ICE candidate for {peer_name}")
except Exception as e:
@ -713,41 +684,48 @@ class WebRTCSignalingClient:
current_mid = line.split(':', 1)[1].strip()
elif line.startswith('a=candidate:'):
candidate_sdp = line[2:] # Remove 'a=' prefix
candidate_dict: MessageData = {
"candidate": candidate_sdp,
"sdpMid": current_mid,
"sdpMLineIndex": current_media_index,
}
payload_candidate: MessageData = {
"peer_id": peer_id,
"candidate": candidate_dict
}
candidate_model = ICECandidateDictModel(
candidate=candidate_sdp,
sdpMid=current_mid,
sdpMLineIndex=current_media_index,
)
payload_candidate = IceCandidateModel(
peer_id=peer_id,
peer_name=peer_name,
candidate=candidate_model,
)
logger.info(f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}...")
await self._send_message("relayICECandidate", payload_candidate)
await self._send_message(
"relayICECandidate", payload_candidate.model_dump()
)
# Send end-of-candidates signal (empty candidate)
end_candidate_dict: MessageData = {
"candidate": "",
"sdpMid": None,
"sdpMLineIndex": None,
}
payload_end: MessageData = {
"peer_id": peer_id,
"candidate": end_candidate_dict
}
end_candidate_model = ICECandidateDictModel(
candidate="",
sdpMid=None,
sdpMLineIndex=None,
)
payload_end = IceCandidateModel(
peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
)
logger.info(f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)")
await self._send_message("relayICECandidate", payload_end)
await self._send_message("relayICECandidate", payload_end.model_dump())
logger.info(f"_handle_session_description: Sent {len(candidate_lines)} ICE candidates to {peer_name} (answer)")
session_desc_typed = SessionDescriptionTypedModel(
type=answer.type, sdp=answer.sdp
)
session_desc_model = SessionDescriptionModel(
peer_id=peer_id,
peer_name=peer_name,
session_description=session_desc_typed,
)
await self._send_message(
"relaySessionDescription",
{
"peer_id": peer_id,
"session_description": {"type": answer.type, "sdp": answer.sdp},
},
session_desc_model.model_dump(),
)
logger.info(f"Answer sent to {peer_name}")
@ -758,7 +736,7 @@ class WebRTCSignalingClient:
"""Handle iceCandidate message"""
peer_id = data.peer_id
peer_name = data.peer_name
candidate_data = data.candidate.model_dump()
candidate_data = data.candidate
logger.info(f"Received ICE candidate from {peer_name}")
@ -774,16 +752,14 @@ class WebRTCSignalingClient:
)
if peer_id not in self.pending_ice_candidates:
self.pending_ice_candidates[peer_id] = []
# candidate_data is a dict from Pydantic model; cast to the TypedDict
self.pending_ice_candidates[peer_id].append(
cast(ICECandidateDict, candidate_data)
)
# candidate_data is an ICECandidateDictModel Pydantic model
self.pending_ice_candidates[peer_id].append(candidate_data)
return
try:
from aiortc.sdp import candidate_from_sdp
cand = candidate_data.get("candidate")
cand = candidate_data.candidate
if not cand:
# end-of-candidates
await pc.addIceCandidate(None)
@ -806,8 +782,8 @@ class WebRTCSignalingClient:
try:
rtc_candidate = candidate_from_sdp(sdp_part)
rtc_candidate.sdpMid = candidate_data.get("sdpMid")
rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")
rtc_candidate.sdpMid = candidate_data.sdpMid
rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
# aiortc expects an object with attributes (RTCIceCandidate)
await pc.addIceCandidate(rtc_candidate)