Fixed pydantic models

This commit is contained in:
James Ketr 2025-09-01 13:24:09 -07:00
parent 282c0ffa9c
commit d4fdd917d1

View File

@ -250,6 +250,38 @@ class IceCandidateModel(BaseModel):
candidate: ICECandidateDictModel
class JoinStatusModel(BaseModel):
status: str
message: str = ""
class UserJoinedModel(BaseModel):
name: str
session_id: str
class ParticipantModel(BaseModel):
name: str
session_id: str
# Add other participant fields as needed based on actual data structure
class LobbyStateModel(BaseModel):
participants: List[ParticipantModel] = []
class UpdateModel(BaseModel):
# This can be extended based on the actual update message structure
# For now, we'll make it flexible to handle various update types
class Config:
extra = "allow"
class WebSocketMessageModel(BaseModel):
type: str
data: Dict[str, object] = {}
class WebSocketProtocol(Protocol):
def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ...
def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ...
@ -620,8 +652,14 @@ class WebRTCSignalingClient:
async def _process_message(self, message: MessageData):
"""Process incoming signaling messages"""
msg_type = message.get("type")
data = message.get("data", {})
try:
# Validate the base message structure first
validated_message = WebSocketMessageModel.model_validate(message)
msg_type = validated_message.type
data = validated_message.data
except ValidationError as e:
logger.error(f"Invalid message structure: {e}", exc_info=True)
return
logger.info(
f"_process_message: Received message type: {msg_type} with data: {data}"
@ -656,23 +694,36 @@ class WebRTCSignalingClient:
return
await self._handle_ice_candidate(validated)
elif msg_type == "join_status":
dd = cast(MessageData, data)
logger.info(f"Join status: {dd.get('status')} - {dd.get('message', '')}")
try:
validated = JoinStatusModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid join_status payload: {e}", exc_info=True)
return
logger.info(f"Join status: {validated.status} - {validated.message}")
elif msg_type == "user_joined":
dd = cast(MessageData, data)
try:
validated = UserJoinedModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid user_joined payload: {e}", exc_info=True)
return
logger.info(
f"User joined: {dd.get('name')} (session: {dd.get('session_id')})"
f"User joined: {validated.name} (session: {validated.session_id})"
)
elif msg_type == "lobby_state":
dd = cast(MessageData, data)
participants = dd.get("participants", [])
if isinstance(participants, list):
logger.info(f"Lobby state updated: {len(participants)} participants")
else:
logger.info("Lobby state updated: participants data received")
try:
validated = LobbyStateModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid lobby_state payload: {e}", exc_info=True)
return
participants = validated.participants
logger.info(f"Lobby state updated: {len(participants)} participants")
elif msg_type == "update":
dd = cast(MessageData, data)
logger.info(f"Received update message: {dd}")
try:
validated = UpdateModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid update payload: {e}", exc_info=True)
return
logger.info(f"Received update message: {validated}")
else:
logger.info(f"Unhandled message type: {msg_type} with data: {data}")