From a313209768b9dea7eb35cce1be2848c930ea2ee4 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Mon, 1 Sep 2025 15:20:21 -0700 Subject: [PATCH] Refactored to use shared/models --- voicebot/main.py | 204 +++++++++++++++++++++-------------------------- 1 file changed, 90 insertions(+), 114 deletions(-) diff --git a/voicebot/main.py b/voicebot/main.py index 4cd4580..2f878b7 100644 --- a/voicebot/main.py +++ b/voicebot/main.py @@ -15,7 +15,6 @@ from typing import ( Optional, Callable, Awaitable, - TypedDict, Protocol, AsyncIterator, cast, @@ -47,6 +46,8 @@ from shared.models import ( RemovePeerModel, SessionDescriptionModel, IceCandidateModel, + ICECandidateDictModel, + SessionDescriptionTypedModel, ) from aiortc import ( RTCPeerConnection, @@ -60,51 +61,10 @@ from synthetic_media import create_synthetic_tracks # import debug_aioice -# TypedDict for ICE candidate payloads received from signalling -class ICECandidateDict(TypedDict, total=False): - candidate: str - sdpMid: Optional[str] - sdpMLineIndex: Optional[int] - - # Generic message payload type MessageData = dict[str, object] -# Message TypedDicts for signaling payloads -class BaseMessage(TypedDict, total=False): - type: str - data: object - - -class AddPeerPayload(TypedDict): - peer_id: str - peer_name: str - should_create_offer: bool - - -class RemovePeerPayload(TypedDict): - peer_id: str - peer_name: str - - -class SessionDescriptionTyped(TypedDict): - type: str - sdp: str - - -class SessionDescriptionPayload(TypedDict): - peer_id: str - peer_name: str - session_description: SessionDescriptionTyped - - -class IceCandidatePayload(TypedDict): - peer_id: str - peer_name: str - candidate: ICECandidateDict - - class WebSocketProtocol(Protocol): def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ... def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ... @@ -164,7 +124,7 @@ class WebRTCSignalingClient: self.is_negotiating: dict[str, bool] = {} self.making_offer: dict[str, bool] = {} self.initiated_offer: set[str] = set() - self.pending_ice_candidates: dict[str, list[ICECandidateDict]] = {} + self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {} # Event callbacks self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None @@ -488,16 +448,20 @@ class WebRTCSignalingClient: f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}" ) - candidate_dict: MessageData = { - "candidate": raw, - "sdpMid": getattr(candidate, "sdpMid", None), - "sdpMLineIndex": getattr(candidate, "sdpMLineIndex", None), - } - payload: MessageData = {"peer_id": peer_id, "candidate": candidate_dict} - logger.info( - f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_dict}" + candidate_model = ICECandidateDictModel( + candidate=raw, + sdpMid=getattr(candidate, "sdpMid", None), + sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None), + ) + payload_model = IceCandidateModel( + peer_id=peer_id, peer_name=peer_name, candidate=candidate_model + ) + logger.info( + f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_model}" + ) + asyncio.ensure_future( + self._send_message("relayICECandidate", payload_model.model_dump()) ) - asyncio.ensure_future(self._send_message("relayICECandidate", payload)) pc.on("icecandidate")(on_ice_candidate) @@ -544,41 +508,48 @@ class WebRTCSignalingClient: current_mid = line.split(':', 1)[1].strip() elif line.startswith('a=candidate:'): candidate_sdp = line[2:] # Remove 'a=' prefix - - candidate_dict: MessageData = { - "candidate": candidate_sdp, - "sdpMid": current_mid, - "sdpMLineIndex": current_media_index, - } - payload_candidate: MessageData = { - "peer_id": peer_id, - "candidate": candidate_dict - } + + candidate_model = ICECandidateDictModel( + candidate=candidate_sdp, + sdpMid=current_mid, + sdpMLineIndex=current_media_index, + ) + payload_candidate = IceCandidateModel( + peer_id=peer_id, + peer_name=peer_name, + candidate=candidate_model, + ) logger.info(f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}...") - await self._send_message("relayICECandidate", payload_candidate) + await self._send_message( + "relayICECandidate", payload_candidate.model_dump() + ) # Send end-of-candidates signal (empty candidate) - end_candidate_dict: MessageData = { - "candidate": "", - "sdpMid": None, - "sdpMLineIndex": None, - } - payload_end: MessageData = { - "peer_id": peer_id, - "candidate": end_candidate_dict - } + end_candidate_model = ICECandidateDictModel( + candidate="", + sdpMid=None, + sdpMLineIndex=None, + ) + payload_end = IceCandidateModel( + peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model + ) logger.info(f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}") - await self._send_message("relayICECandidate", payload_end) + await self._send_message("relayICECandidate", payload_end.model_dump()) logger.info(f"_handle_add_peer: Sent {len(candidate_lines)} ICE candidates to {peer_name}") + session_desc_typed = SessionDescriptionTypedModel( + type=offer.type, sdp=offer.sdp + ) + session_desc_model = SessionDescriptionModel( + peer_id=peer_id, + peer_name=peer_name, + session_description=session_desc_typed, + ) await self._send_message( "relaySessionDescription", - { - "peer_id": peer_id, - "session_description": {"type": offer.type, "sdp": offer.sdp}, - }, + session_desc_model.model_dump(), ) logger.info(f"Offer sent to {peer_name}") @@ -655,8 +626,8 @@ class WebRTCSignalingClient: from aiortc.sdp import candidate_from_sdp for candidate_data in pending_candidates: - # candidate_data is a dict-like ICECandidateDict; convert SDP string - cand = candidate_data.get("candidate") + # candidate_data is an ICECandidateDictModel Pydantic model + cand = candidate_data.candidate # handle end-of-candidates marker if not cand: await pc.addIceCandidate(None) @@ -671,8 +642,8 @@ class WebRTCSignalingClient: try: rtc_candidate = candidate_from_sdp(sdp_part) - rtc_candidate.sdpMid = candidate_data.get("sdpMid") - rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex") + rtc_candidate.sdpMid = candidate_data.sdpMid + rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex await pc.addIceCandidate(rtc_candidate) logger.info(f"Added queued ICE candidate for {peer_name}") except Exception as e: @@ -713,41 +684,48 @@ class WebRTCSignalingClient: current_mid = line.split(':', 1)[1].strip() elif line.startswith('a=candidate:'): candidate_sdp = line[2:] # Remove 'a=' prefix - - candidate_dict: MessageData = { - "candidate": candidate_sdp, - "sdpMid": current_mid, - "sdpMLineIndex": current_media_index, - } - payload_candidate: MessageData = { - "peer_id": peer_id, - "candidate": candidate_dict - } + + candidate_model = ICECandidateDictModel( + candidate=candidate_sdp, + sdpMid=current_mid, + sdpMLineIndex=current_media_index, + ) + payload_candidate = IceCandidateModel( + peer_id=peer_id, + peer_name=peer_name, + candidate=candidate_model, + ) logger.info(f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}...") - await self._send_message("relayICECandidate", payload_candidate) + await self._send_message( + "relayICECandidate", payload_candidate.model_dump() + ) # Send end-of-candidates signal (empty candidate) - end_candidate_dict: MessageData = { - "candidate": "", - "sdpMid": None, - "sdpMLineIndex": None, - } - payload_end: MessageData = { - "peer_id": peer_id, - "candidate": end_candidate_dict - } + end_candidate_model = ICECandidateDictModel( + candidate="", + sdpMid=None, + sdpMLineIndex=None, + ) + payload_end = IceCandidateModel( + peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model + ) logger.info(f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)") - await self._send_message("relayICECandidate", payload_end) + await self._send_message("relayICECandidate", payload_end.model_dump()) logger.info(f"_handle_session_description: Sent {len(candidate_lines)} ICE candidates to {peer_name} (answer)") + session_desc_typed = SessionDescriptionTypedModel( + type=answer.type, sdp=answer.sdp + ) + session_desc_model = SessionDescriptionModel( + peer_id=peer_id, + peer_name=peer_name, + session_description=session_desc_typed, + ) await self._send_message( "relaySessionDescription", - { - "peer_id": peer_id, - "session_description": {"type": answer.type, "sdp": answer.sdp}, - }, + session_desc_model.model_dump(), ) logger.info(f"Answer sent to {peer_name}") @@ -758,7 +736,7 @@ class WebRTCSignalingClient: """Handle iceCandidate message""" peer_id = data.peer_id peer_name = data.peer_name - candidate_data = data.candidate.model_dump() + candidate_data = data.candidate logger.info(f"Received ICE candidate from {peer_name}") @@ -774,16 +752,14 @@ class WebRTCSignalingClient: ) if peer_id not in self.pending_ice_candidates: self.pending_ice_candidates[peer_id] = [] - # candidate_data is a dict from Pydantic model; cast to the TypedDict - self.pending_ice_candidates[peer_id].append( - cast(ICECandidateDict, candidate_data) - ) + # candidate_data is an ICECandidateDictModel Pydantic model + self.pending_ice_candidates[peer_id].append(candidate_data) return try: from aiortc.sdp import candidate_from_sdp - cand = candidate_data.get("candidate") + cand = candidate_data.candidate if not cand: # end-of-candidates await pc.addIceCandidate(None) @@ -806,8 +782,8 @@ class WebRTCSignalingClient: try: rtc_candidate = candidate_from_sdp(sdp_part) - rtc_candidate.sdpMid = candidate_data.get("sdpMid") - rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex") + rtc_candidate.sdpMid = candidate_data.sdpMid + rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex # aiortc expects an object with attributes (RTCIceCandidate) await pc.addIceCandidate(rtc_candidate)