Refactored to use shared/models
This commit is contained in:
parent
7c5616fbd9
commit
a313209768
204
voicebot/main.py
204
voicebot/main.py
@ -15,7 +15,6 @@ from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
Awaitable,
|
||||
TypedDict,
|
||||
Protocol,
|
||||
AsyncIterator,
|
||||
cast,
|
||||
@ -47,6 +46,8 @@ from shared.models import (
|
||||
RemovePeerModel,
|
||||
SessionDescriptionModel,
|
||||
IceCandidateModel,
|
||||
ICECandidateDictModel,
|
||||
SessionDescriptionTypedModel,
|
||||
)
|
||||
from aiortc import (
|
||||
RTCPeerConnection,
|
||||
@ -60,51 +61,10 @@ from synthetic_media import create_synthetic_tracks
|
||||
# import debug_aioice
|
||||
|
||||
|
||||
# TypedDict for ICE candidate payloads received from signalling
|
||||
class ICECandidateDict(TypedDict, total=False):
|
||||
candidate: str
|
||||
sdpMid: Optional[str]
|
||||
sdpMLineIndex: Optional[int]
|
||||
|
||||
|
||||
# Generic message payload type
|
||||
MessageData = dict[str, object]
|
||||
|
||||
|
||||
# Message TypedDicts for signaling payloads
|
||||
class BaseMessage(TypedDict, total=False):
|
||||
type: str
|
||||
data: object
|
||||
|
||||
|
||||
class AddPeerPayload(TypedDict):
|
||||
peer_id: str
|
||||
peer_name: str
|
||||
should_create_offer: bool
|
||||
|
||||
|
||||
class RemovePeerPayload(TypedDict):
|
||||
peer_id: str
|
||||
peer_name: str
|
||||
|
||||
|
||||
class SessionDescriptionTyped(TypedDict):
|
||||
type: str
|
||||
sdp: str
|
||||
|
||||
|
||||
class SessionDescriptionPayload(TypedDict):
|
||||
peer_id: str
|
||||
peer_name: str
|
||||
session_description: SessionDescriptionTyped
|
||||
|
||||
|
||||
class IceCandidatePayload(TypedDict):
|
||||
peer_id: str
|
||||
peer_name: str
|
||||
candidate: ICECandidateDict
|
||||
|
||||
|
||||
class WebSocketProtocol(Protocol):
|
||||
def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ...
|
||||
def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ...
|
||||
@ -164,7 +124,7 @@ class WebRTCSignalingClient:
|
||||
self.is_negotiating: dict[str, bool] = {}
|
||||
self.making_offer: dict[str, bool] = {}
|
||||
self.initiated_offer: set[str] = set()
|
||||
self.pending_ice_candidates: dict[str, list[ICECandidateDict]] = {}
|
||||
self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
|
||||
|
||||
# Event callbacks
|
||||
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
|
||||
@ -488,16 +448,20 @@ class WebRTCSignalingClient:
|
||||
f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}"
|
||||
)
|
||||
|
||||
candidate_dict: MessageData = {
|
||||
"candidate": raw,
|
||||
"sdpMid": getattr(candidate, "sdpMid", None),
|
||||
"sdpMLineIndex": getattr(candidate, "sdpMLineIndex", None),
|
||||
}
|
||||
payload: MessageData = {"peer_id": peer_id, "candidate": candidate_dict}
|
||||
logger.info(
|
||||
f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_dict}"
|
||||
candidate_model = ICECandidateDictModel(
|
||||
candidate=raw,
|
||||
sdpMid=getattr(candidate, "sdpMid", None),
|
||||
sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None),
|
||||
)
|
||||
payload_model = IceCandidateModel(
|
||||
peer_id=peer_id, peer_name=peer_name, candidate=candidate_model
|
||||
)
|
||||
logger.info(
|
||||
f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_model}"
|
||||
)
|
||||
asyncio.ensure_future(
|
||||
self._send_message("relayICECandidate", payload_model.model_dump())
|
||||
)
|
||||
asyncio.ensure_future(self._send_message("relayICECandidate", payload))
|
||||
|
||||
pc.on("icecandidate")(on_ice_candidate)
|
||||
|
||||
@ -544,41 +508,48 @@ class WebRTCSignalingClient:
|
||||
current_mid = line.split(':', 1)[1].strip()
|
||||
elif line.startswith('a=candidate:'):
|
||||
candidate_sdp = line[2:] # Remove 'a=' prefix
|
||||
|
||||
candidate_dict: MessageData = {
|
||||
"candidate": candidate_sdp,
|
||||
"sdpMid": current_mid,
|
||||
"sdpMLineIndex": current_media_index,
|
||||
}
|
||||
payload_candidate: MessageData = {
|
||||
"peer_id": peer_id,
|
||||
"candidate": candidate_dict
|
||||
}
|
||||
|
||||
candidate_model = ICECandidateDictModel(
|
||||
candidate=candidate_sdp,
|
||||
sdpMid=current_mid,
|
||||
sdpMLineIndex=current_media_index,
|
||||
)
|
||||
payload_candidate = IceCandidateModel(
|
||||
peer_id=peer_id,
|
||||
peer_name=peer_name,
|
||||
candidate=candidate_model,
|
||||
)
|
||||
|
||||
logger.info(f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}...")
|
||||
await self._send_message("relayICECandidate", payload_candidate)
|
||||
await self._send_message(
|
||||
"relayICECandidate", payload_candidate.model_dump()
|
||||
)
|
||||
|
||||
# Send end-of-candidates signal (empty candidate)
|
||||
end_candidate_dict: MessageData = {
|
||||
"candidate": "",
|
||||
"sdpMid": None,
|
||||
"sdpMLineIndex": None,
|
||||
}
|
||||
payload_end: MessageData = {
|
||||
"peer_id": peer_id,
|
||||
"candidate": end_candidate_dict
|
||||
}
|
||||
end_candidate_model = ICECandidateDictModel(
|
||||
candidate="",
|
||||
sdpMid=None,
|
||||
sdpMLineIndex=None,
|
||||
)
|
||||
payload_end = IceCandidateModel(
|
||||
peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
|
||||
)
|
||||
logger.info(f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}")
|
||||
await self._send_message("relayICECandidate", payload_end)
|
||||
await self._send_message("relayICECandidate", payload_end.model_dump())
|
||||
|
||||
logger.info(f"_handle_add_peer: Sent {len(candidate_lines)} ICE candidates to {peer_name}")
|
||||
|
||||
session_desc_typed = SessionDescriptionTypedModel(
|
||||
type=offer.type, sdp=offer.sdp
|
||||
)
|
||||
session_desc_model = SessionDescriptionModel(
|
||||
peer_id=peer_id,
|
||||
peer_name=peer_name,
|
||||
session_description=session_desc_typed,
|
||||
)
|
||||
await self._send_message(
|
||||
"relaySessionDescription",
|
||||
{
|
||||
"peer_id": peer_id,
|
||||
"session_description": {"type": offer.type, "sdp": offer.sdp},
|
||||
},
|
||||
session_desc_model.model_dump(),
|
||||
)
|
||||
|
||||
logger.info(f"Offer sent to {peer_name}")
|
||||
@ -655,8 +626,8 @@ class WebRTCSignalingClient:
|
||||
from aiortc.sdp import candidate_from_sdp
|
||||
|
||||
for candidate_data in pending_candidates:
|
||||
# candidate_data is a dict-like ICECandidateDict; convert SDP string
|
||||
cand = candidate_data.get("candidate")
|
||||
# candidate_data is an ICECandidateDictModel Pydantic model
|
||||
cand = candidate_data.candidate
|
||||
# handle end-of-candidates marker
|
||||
if not cand:
|
||||
await pc.addIceCandidate(None)
|
||||
@ -671,8 +642,8 @@ class WebRTCSignalingClient:
|
||||
|
||||
try:
|
||||
rtc_candidate = candidate_from_sdp(sdp_part)
|
||||
rtc_candidate.sdpMid = candidate_data.get("sdpMid")
|
||||
rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")
|
||||
rtc_candidate.sdpMid = candidate_data.sdpMid
|
||||
rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
|
||||
await pc.addIceCandidate(rtc_candidate)
|
||||
logger.info(f"Added queued ICE candidate for {peer_name}")
|
||||
except Exception as e:
|
||||
@ -713,41 +684,48 @@ class WebRTCSignalingClient:
|
||||
current_mid = line.split(':', 1)[1].strip()
|
||||
elif line.startswith('a=candidate:'):
|
||||
candidate_sdp = line[2:] # Remove 'a=' prefix
|
||||
|
||||
candidate_dict: MessageData = {
|
||||
"candidate": candidate_sdp,
|
||||
"sdpMid": current_mid,
|
||||
"sdpMLineIndex": current_media_index,
|
||||
}
|
||||
payload_candidate: MessageData = {
|
||||
"peer_id": peer_id,
|
||||
"candidate": candidate_dict
|
||||
}
|
||||
|
||||
candidate_model = ICECandidateDictModel(
|
||||
candidate=candidate_sdp,
|
||||
sdpMid=current_mid,
|
||||
sdpMLineIndex=current_media_index,
|
||||
)
|
||||
payload_candidate = IceCandidateModel(
|
||||
peer_id=peer_id,
|
||||
peer_name=peer_name,
|
||||
candidate=candidate_model,
|
||||
)
|
||||
|
||||
logger.info(f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}...")
|
||||
await self._send_message("relayICECandidate", payload_candidate)
|
||||
await self._send_message(
|
||||
"relayICECandidate", payload_candidate.model_dump()
|
||||
)
|
||||
|
||||
# Send end-of-candidates signal (empty candidate)
|
||||
end_candidate_dict: MessageData = {
|
||||
"candidate": "",
|
||||
"sdpMid": None,
|
||||
"sdpMLineIndex": None,
|
||||
}
|
||||
payload_end: MessageData = {
|
||||
"peer_id": peer_id,
|
||||
"candidate": end_candidate_dict
|
||||
}
|
||||
end_candidate_model = ICECandidateDictModel(
|
||||
candidate="",
|
||||
sdpMid=None,
|
||||
sdpMLineIndex=None,
|
||||
)
|
||||
payload_end = IceCandidateModel(
|
||||
peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
|
||||
)
|
||||
logger.info(f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)")
|
||||
await self._send_message("relayICECandidate", payload_end)
|
||||
await self._send_message("relayICECandidate", payload_end.model_dump())
|
||||
|
||||
logger.info(f"_handle_session_description: Sent {len(candidate_lines)} ICE candidates to {peer_name} (answer)")
|
||||
|
||||
session_desc_typed = SessionDescriptionTypedModel(
|
||||
type=answer.type, sdp=answer.sdp
|
||||
)
|
||||
session_desc_model = SessionDescriptionModel(
|
||||
peer_id=peer_id,
|
||||
peer_name=peer_name,
|
||||
session_description=session_desc_typed,
|
||||
)
|
||||
await self._send_message(
|
||||
"relaySessionDescription",
|
||||
{
|
||||
"peer_id": peer_id,
|
||||
"session_description": {"type": answer.type, "sdp": answer.sdp},
|
||||
},
|
||||
session_desc_model.model_dump(),
|
||||
)
|
||||
|
||||
logger.info(f"Answer sent to {peer_name}")
|
||||
@ -758,7 +736,7 @@ class WebRTCSignalingClient:
|
||||
"""Handle iceCandidate message"""
|
||||
peer_id = data.peer_id
|
||||
peer_name = data.peer_name
|
||||
candidate_data = data.candidate.model_dump()
|
||||
candidate_data = data.candidate
|
||||
|
||||
logger.info(f"Received ICE candidate from {peer_name}")
|
||||
|
||||
@ -774,16 +752,14 @@ class WebRTCSignalingClient:
|
||||
)
|
||||
if peer_id not in self.pending_ice_candidates:
|
||||
self.pending_ice_candidates[peer_id] = []
|
||||
# candidate_data is a dict from Pydantic model; cast to the TypedDict
|
||||
self.pending_ice_candidates[peer_id].append(
|
||||
cast(ICECandidateDict, candidate_data)
|
||||
)
|
||||
# candidate_data is an ICECandidateDictModel Pydantic model
|
||||
self.pending_ice_candidates[peer_id].append(candidate_data)
|
||||
return
|
||||
|
||||
try:
|
||||
from aiortc.sdp import candidate_from_sdp
|
||||
|
||||
cand = candidate_data.get("candidate")
|
||||
cand = candidate_data.candidate
|
||||
if not cand:
|
||||
# end-of-candidates
|
||||
await pc.addIceCandidate(None)
|
||||
@ -806,8 +782,8 @@ class WebRTCSignalingClient:
|
||||
|
||||
try:
|
||||
rtc_candidate = candidate_from_sdp(sdp_part)
|
||||
rtc_candidate.sdpMid = candidate_data.get("sdpMid")
|
||||
rtc_candidate.sdpMLineIndex = candidate_data.get("sdpMLineIndex")
|
||||
rtc_candidate.sdpMid = candidate_data.sdpMid
|
||||
rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
|
||||
|
||||
# aiortc expects an object with attributes (RTCIceCandidate)
|
||||
await pc.addIceCandidate(rtc_candidate)
|
||||
|
Loading…
x
Reference in New Issue
Block a user