"""
WebSocket message routing and handling.

This module provides a clean way to route WebSocket messages to appropriate
handlers, replacing the massive switch statement from main.py.
"""

from abc import ABC, abstractmethod
from typing import Dict, Any, TYPE_CHECKING

from fastapi import WebSocket

from logger import logger
from .webrtc_signaling import WebRTCSignalingHandlers
from core.error_handling import (
    error_handler,
    WebSocketError,
    ValidationError,
    ErrorSeverity,
)

if TYPE_CHECKING:
    # NOTE(review): these are package-relative while core.error_handling above
    # is imported absolutely — confirm both resolve to the same package.
    from ..core.session_manager import Session
    from ..core.lobby_manager import Lobby
    from ..core.auth_manager import AuthManager


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send an ``{"type": "error", "data": {"error": message}}`` frame."""
    await websocket.send_json({"type": "error", "data": {"error": message}})


class MessageHandler(ABC):
    """Base class for WebSocket message handlers"""

    @abstractmethod
    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        """Handle a WebSocket message.

        Args:
            session: The session that sent the message.
            lobby: The lobby the message is addressed to.
            data: The message payload; handlers must tolerate it being empty.
            websocket: The sender's socket, used for direct replies.
            managers: Shared service objects looked up by key
                (e.g. ``"auth_manager"``, ``"session_manager"``).
        """
        pass


class SetNameHandler(MessageHandler):
    """Handler for set_name messages.

    Assigns ``data["name"]`` to the sending session. If the name is already in
    use, ``auth_manager.check_name_takeover`` decides whether the caller may
    take it over; on takeover the previous holder is renamed to a unique
    fallback name. An optional ``data["password"]`` is stored for the name.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        auth_manager: "AuthManager" = managers["auth_manager"]
        session_manager = managers["session_manager"]

        if not data:
            logger.error(f"{session.getName()} - set_name missing data")
            await _send_error(websocket, "set_name missing data")
            return

        name = data.get("name")
        password = data.get("password")
        # Never log the password itself, only whether one was supplied.
        logger.info(f"{session.getName()} <- set_name({name}, {'***' if password else None})")

        if not name:
            logger.error(f"{session.getName()} - Name required")
            await _send_error(websocket, "Name required")
            return

        # Fast path: nobody holds this name yet.
        if session_manager.is_unique_name(name):
            await self._assign_name(session, lobby, websocket, auth_manager, name, password)
            return

        # Name is taken - check takeover
        allowed, reason = auth_manager.check_name_takeover(name, password)
        if not allowed:
            logger.warning(f"{session.getName()} - {reason}")
            await _send_error(websocket, reason)
            return

        # Takeover allowed - move the current holder (if it is another session)
        # out of the way before claiming the name.
        displaced = session_manager.get_session_by_name(name)
        if displaced and displaced.id != session.id:
            await self._displace(session_manager, displaced)

        await self._assign_name(
            session, lobby, websocket, auth_manager, name, password, log_suffix=" (takeover)"
        )

    @staticmethod
    async def _assign_name(
        session: "Session",
        lobby: "Lobby",
        websocket: WebSocket,
        auth_manager: "AuthManager",
        name: str,
        password: Any,
        log_suffix: str = "",
    ) -> None:
        """Store the optional password, rename the session, confirm to the client and refresh the lobby."""
        if password:
            auth_manager.set_password(name, password)
        session.setName(name)
        logger.info(f"{session.getName()}: -> update('name', {name}){log_suffix}")
        await websocket.send_json({
            "type": "update_name",
            "data": {
                "name": name,
                "protected": auth_manager.is_name_protected(name),
            },
        })
        await lobby.update_state()

    @staticmethod
    async def _displace(session_manager: Any, displaced: "Session") -> None:
        """Rename *displaced* to a unique fallback name and notify it (best effort)."""
        # Fallback is "<name>-<short>", then "<name>-<short>-1", "-2", ... until unique.
        base = f"{displaced.name}-{displaced.short}"
        fallback = base
        counter = 1
        while not session_manager.is_unique_name(fallback):
            fallback = f"{base}-{counter}"
            counter += 1

        displaced.setName(fallback)
        displaced.mark_displaced()
        logger.info(f"{displaced.getName()} <- displaced by takeover, new name {fallback}")

        # Notify displaced session; a dead socket must not abort the takeover.
        if displaced.ws:
            try:
                await displaced.ws.send_json({
                    "type": "update_name",
                    "data": {
                        "name": fallback,
                        "protected": False,
                    },
                })
            except Exception:
                logger.exception("Failed to notify displaced session websocket")

        # Iterate over a copy: update_state may change the lobby membership list.
        for d_lobby in displaced.lobbies[:]:
            try:
                await d_lobby.update_state()
            except Exception:
                logger.exception("Failed to update lobby state for displaced session")


class JoinHandler(MessageHandler):
    """Handler for join messages: adds the session to the lobby."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        logger.info(f"{session.getName()} <- join({lobby.getName()})")
        await session.join_lobby(lobby)


class PartHandler(MessageHandler):
    """Handler for part messages: removes the session from the lobby."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        logger.info(f"{session.getName()} <- part {lobby.getName()}")
        await session.leave_lobby(lobby)


class ListUsersHandler(MessageHandler):
    """Handler for list_users messages: pushes lobby state to the requester."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        await lobby.update_state(session)


class GetChatMessagesHandler(MessageHandler):
    """Handler for get_chat_messages messages: replies with up to 50 recent messages."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        messages = lobby.get_chat_messages(50)
        await websocket.send_json({
            "type": "chat_messages",
            "data": {
                "messages": [msg.model_dump() for msg in messages]
            },
        })


class SendChatMessageHandler(MessageHandler):
    """Handler for send_chat_message messages.

    Requires ``data["message"]`` and a named session. Whitespace-only messages
    are silently ignored; otherwise the message is stored in the lobby and
    broadcast.
    """

    # Maximum number of message characters echoed into the log line.
    _LOG_PREVIEW_CHARS = 50

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        if not data or "message" not in data:
            logger.error(f"{session.getName()} - send_chat_message missing message")
            await _send_error(websocket, "send_chat_message missing message")
            return

        if not session.name:
            logger.error(f"{session.getName()} - Cannot send chat message without name")
            await _send_error(websocket, "Must set name before sending chat messages")
            return

        message_text = str(data["message"]).strip()
        if not message_text:
            return

        # Add the message to the lobby and broadcast it
        chat_message = lobby.add_chat_message(session, message_text)

        # Only mark the preview with "..." when it was actually truncated.
        limit = self._LOG_PREVIEW_CHARS
        preview = message_text if len(message_text) <= limit else f"{message_text[:limit]}..."
        logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {preview})")
        await lobby.broadcast_chat_message(chat_message)


class RelayICECandidateHandler(MessageHandler):
    """Handler for relayICECandidate messages - WebRTC signaling"""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        await WebRTCSignalingHandlers.handle_relay_ice_candidate(
            websocket, session, lobby, data
        )


class RelaySessionDescriptionHandler(MessageHandler):
    """Handler for relaySessionDescription messages - WebRTC signaling"""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        await WebRTCSignalingHandlers.handle_relay_session_description(
            websocket, session, lobby, data
        )


class StatusCheckHandler(MessageHandler):
    """Handler for status_check messages - Bot health monitoring.

    Replies with a ``status_response`` echoing the client's ``timestamp``
    (or ``None`` if absent), the session id and the lobby name.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        logger.debug(f"{session.getName()} <- status_check")

        # Guard like the other handlers: an empty/None payload is still a
        # valid health probe and must not raise AttributeError.
        timestamp = data.get("timestamp") if data else None

        await websocket.send_json(
            {
                "type": "status_response",
                "data": {
                    "status": "ok",
                    "timestamp": timestamp,
                    "session_id": session.id,
                    "lobby": lobby.getName() if lobby else None,
                },
            }
        )


class MessageRouter:
    """Routes WebSocket messages to appropriate handlers"""

    def __init__(self):
        self._handlers: Dict[str, MessageHandler] = {}
        self._register_default_handlers()

    def _register_default_handlers(self) -> None:
        """Register default message handlers"""
        self.register("set_name", SetNameHandler())
        self.register("join", JoinHandler())
        self.register("part", PartHandler())
        self.register("list_users", ListUsersHandler())
        self.register("get_chat_messages", GetChatMessagesHandler())
        self.register("send_chat_message", SendChatMessageHandler())
        # WebRTC signaling handlers
        self.register("relayICECandidate", RelayICECandidateHandler())
        self.register("relaySessionDescription", RelaySessionDescriptionHandler())
        # Bot monitoring handlers
        self.register("status_check", StatusCheckHandler())

    def register(self, message_type: str, handler: MessageHandler) -> None:
        """Register (or replace) the handler for a message type"""
        self._handlers[message_type] = handler
        logger.debug(f"Registered handler for message type: {message_type}")

    async def route(
        self,
        message_type: str,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        """Route a message to its handler, reporting failures via ``error_handler``.

        Unknown message types are reported as ``ValidationError``. Exceptions
        raised by a handler are not propagated: ``WebSocketError`` and
        ``ValidationError`` are forwarded as-is, and anything else is wrapped
        in a HIGH-severity ``WebSocketError``. A recovery action is attached
        for the two WebSocket-error cases.
        """
        session_label = session.id if session else "unknown"
        session_id = session.id if session else None

        handler = self._handlers.get(message_type)
        if handler is None:
            await error_handler.handle_error(
                ValidationError(f"Unknown message type: {message_type}"),
                context={
                    "message_type": message_type,
                    "session_id": session_label,
                    "data_keys": list(data.keys()) if data else [],
                },
                websocket=websocket,
                session_id=session_id,
            )
            return

        handler_name = type(handler).__name__
        try:
            await handler.handle(session, lobby, data, websocket, managers)
        except WebSocketError as e:
            # WebSocket specific errors - attempt recovery
            await error_handler.handle_error(
                e,
                context={
                    "message_type": message_type,
                    "session_id": session_label,
                    "handler": handler_name,
                },
                websocket=websocket,
                session_id=session_id,
                recovery_action=lambda: self._websocket_recovery(websocket, session),
            )
        except ValidationError as e:
            # Validation errors - usually client-side issues
            await error_handler.handle_error(
                e,
                context={
                    "message_type": message_type,
                    "session_id": session_label,
                    "data": str(data)[:500],  # Truncate large payloads
                },
                websocket=websocket,
                session_id=session_id,
            )
        except Exception as e:
            # Unexpected errors - wrap so error_handler sees a uniform type.
            await error_handler.handle_error(
                WebSocketError(
                    f"Unexpected error in {message_type} handler: {e}",
                    severity=ErrorSeverity.HIGH,
                ),
                context={
                    "message_type": message_type,
                    "session_id": session_label,
                    "handler": handler_name,
                    "exception_type": type(e).__name__,
                    # NOTE(review): this is the exception message, not a full
                    # traceback, despite the key name.
                    "traceback": str(e),
                },
                websocket=websocket,
                session_id=session_id,
                recovery_action=lambda: self._generic_recovery(message_type, session, lobby),
            )

    async def _websocket_recovery(self, websocket: WebSocket, session: "Session") -> None:
        """Send a ``connection_status`` frame reporting recovery to the client."""
        if websocket and session:
            await websocket.send_json({
                "type": "connection_status",
                "data": {
                    "status": "recovered",
                    "session_id": session.id,
                    "message": "Connection recovered from error"
                }
            })

    async def _generic_recovery(self, message_type: str, session: "Session", lobby: "Lobby") -> None:
        """Best-effort recovery after an unexpected handler error."""
        logger.info(f"Attempting recovery for {message_type} error")

        if message_type in ("join", "part"):
            # For lobby operations, re-broadcast lobby state so clients resync.
            if session and lobby:
                await lobby.update_state()
        elif message_type == "set_name":
            # For name operations, only log the session being validated.
            if session:
                logger.info(f"Validating session state for {session.id}")

    def get_supported_types(self) -> list[str]:
        """Get list of supported message types"""
        return list(self._handlers.keys())