diff --git a/voicebot/main.py b/voicebot/main.py index ee75d4a..ae5f476 100644 --- a/voicebot/main.py +++ b/voicebot/main.py @@ -250,6 +250,38 @@ class IceCandidateModel(BaseModel): candidate: ICECandidateDictModel +class JoinStatusModel(BaseModel): + status: str + message: str = "" + + +class UserJoinedModel(BaseModel): + name: str + session_id: str + + +class ParticipantModel(BaseModel): + name: str + session_id: str + # Add other participant fields as needed based on actual data structure + + +class LobbyStateModel(BaseModel): + participants: List[ParticipantModel] = [] + + +class UpdateModel(BaseModel): + # This can be extended based on the actual update message structure + # For now, we'll make it flexible to handle various update types + class Config: + extra = "allow" + + +class WebSocketMessageModel(BaseModel): + type: str + data: Dict[str, object] = {} + + class WebSocketProtocol(Protocol): def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ... def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ... @@ -620,8 +652,14 @@ class WebRTCSignalingClient: async def _process_message(self, message: MessageData): """Process incoming signaling messages""" - msg_type = message.get("type") - data = message.get("data", {}) + try: + # Validate the base message structure first + validated_message = WebSocketMessageModel.model_validate(message) + msg_type = validated_message.type + data = validated_message.data + except ValidationError as e: + logger.error(f"Invalid message structure: {e}", exc_info=True) + return logger.info( f"_process_message: Received message type: {msg_type} with data: {data}" @@ -656,23 +694,36 @@ class WebRTCSignalingClient: return await self._handle_ice_candidate(validated) elif msg_type == "join_status": - dd = cast(MessageData, data) - logger.info(f"Join status: {dd.get('status')} - {dd.get('message', '')}") + try: + validated = JoinStatusModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid join_status payload: {e}", exc_info=True) + return + logger.info(f"Join status: {validated.status} - {validated.message}") elif msg_type == "user_joined": - dd = cast(MessageData, data) + try: + validated = UserJoinedModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid user_joined payload: {e}", exc_info=True) + return logger.info( - f"User joined: {dd.get('name')} (session: {dd.get('session_id')})" + f"User joined: {validated.name} (session: {validated.session_id})" ) elif msg_type == "lobby_state": - dd = cast(MessageData, data) - participants = dd.get("participants", []) - if isinstance(participants, list): - logger.info(f"Lobby state updated: {len(participants)} participants") - else: - logger.info("Lobby state updated: participants data received") + try: + validated = LobbyStateModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid lobby_state payload: {e}", exc_info=True) + return + participants = validated.participants + logger.info(f"Lobby state updated: {len(participants)} participants") elif msg_type == "update": - dd = cast(MessageData, data) - logger.info(f"Received update message: {dd}") + try: + validated = UpdateModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid update payload: {e}", exc_info=True) + return + logger.info(f"Received update message: {validated}") else: logger.info(f"Unhandled message type: {msg_type} with data: {data}")