"""
WebSocket message routing and handling.

This module provides a clean way to route WebSocket messages to appropriate
handlers, replacing the massive switch statement from main.py.

Each message type is served by a ``MessageHandler`` subclass; ``MessageRouter``
keeps the ``message type -> handler`` table and dispatches incoming messages.
"""

from abc import ABC, abstractmethod
from typing import Dict, Any, TYPE_CHECKING

from fastapi import WebSocket

from logger import logger

if TYPE_CHECKING:
    from ..core.session_manager import Session
    from ..core.lobby_manager import Lobby
    from ..core.auth_manager import AuthManager


async def _send_error(websocket: WebSocket, message: Any) -> None:
    """Send an ``{"type": "error", "data": {"error": message}}`` frame."""
    await websocket.send_json({
        "type": "error",
        "data": {"error": message},
    })


class MessageHandler(ABC):
    """Base class for WebSocket message handlers"""

    @abstractmethod
    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Handle a WebSocket message.

        Args:
            session: Session the message was received from.
            lobby: Lobby the message applies to.
            data: Payload of the message (may be empty).
            websocket: Socket used to reply to the sender.
            managers: Shared manager objects looked up by string key.
        """
        pass


class SetNameHandler(MessageHandler):
    """Handler for set_name messages.

    Reads ``name`` and optional ``password`` from the payload. A unique name
    is assigned directly; a name already in use is only assigned when the
    auth manager allows a takeover, in which case the session currently
    holding the name is renamed first.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Validate the request and assign the requested name.

        Requires ``managers["auth_manager"]`` and
        ``managers["session_manager"]``. Every rejection is reported to the
        sender as an ``error`` frame and leaves the session untouched.
        """
        auth_manager: "AuthManager" = managers["auth_manager"]
        session_manager = managers["session_manager"]

        if not data:
            logger.error(f"{session.getName()} - set_name missing data")
            await _send_error(websocket, "set_name missing data")
            return

        name = data.get("name")
        password = data.get("password")

        # The password itself is never logged, only whether one was supplied.
        logger.info(f"{session.getName()} <- set_name({name}, {'***' if password else None})")

        if not name:
            logger.error(f"{session.getName()} - Name required")
            await _send_error(websocket, "Name required")
            return

        # Simple case: nobody holds the name yet.
        if session_manager.is_unique_name(name):
            await self._assign_name(
                session, lobby, websocket, auth_manager, name, password,
                takeover=False,
            )
            return

        # Name is taken - check takeover
        allowed, reason = auth_manager.check_name_takeover(name, password)
        if not allowed:
            logger.warning(f"{session.getName()} - {reason}")
            await _send_error(websocket, reason)
            return

        # Takeover allowed - move the current holder out of the way, unless
        # the holder is the requesting session itself.
        displaced = session_manager.get_session_by_name(name)
        if displaced and displaced.id != session.id:
            await self._displace_session(displaced, session_manager)

        await self._assign_name(
            session, lobby, websocket, auth_manager, name, password,
            takeover=True,
        )

    async def _assign_name(
        self,
        session: "Session",
        lobby: "Lobby",
        websocket: WebSocket,
        auth_manager: "AuthManager",
        name: Any,
        password: Any,
        *,
        takeover: bool,
    ) -> None:
        """Give ``name`` to ``session``, confirm it to the client, refresh the lobby."""
        # If a password was provided, save it for this name
        if password:
            auth_manager.set_password(name, password)

        session.setName(name)

        suffix = " (takeover)" if takeover else ""
        logger.info(f"{session.getName()}: -> update('name', {name}){suffix}")

        await websocket.send_json({
            "type": "update_name",
            "data": {
                "name": name,
                "protected": auth_manager.is_name_protected(name),
            },
        })

        # Update lobby state
        await lobby.update_state()

    async def _displace_session(self, displaced: "Session", session_manager: Any) -> None:
        """Rename the session losing its name and notify it and its lobbies.

        Notification failures are logged and otherwise ignored so that they
        cannot block the takeover.
        """
        # Create unique fallback name: "<name>-<short>", then numbered variants.
        base = f"{displaced.name}-{displaced.short}"
        fallback = base
        counter = 1
        while not session_manager.is_unique_name(fallback):
            fallback = f"{base}-{counter}"
            counter += 1

        displaced.setName(fallback)
        displaced.mark_displaced()
        logger.info(f"{displaced.getName()} <- displaced by takeover, new name {fallback}")

        # Notify displaced session
        if displaced.ws:
            try:
                await displaced.ws.send_json({
                    "type": "update_name",
                    "data": {
                        "name": fallback,
                        "protected": False,
                    },
                })
            except Exception:
                logger.exception("Failed to notify displaced session websocket")

        # Update lobbies for displaced session (iterate over a copy of the list)
        for d_lobby in displaced.lobbies[:]:
            try:
                await d_lobby.update_state()
            except Exception:
                logger.exception("Failed to update lobby state for displaced session")


class JoinHandler(MessageHandler):
    """Handler for join messages"""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Log the request and have the session join ``lobby``."""
        logger.info(f"{session.getName()} <- join({lobby.getName()})")
        await session.join_lobby(lobby)


class PartHandler(MessageHandler):
    """Handler for part messages"""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Log the request and have the session leave ``lobby``."""
        logger.info(f"{session.getName()} <- part {lobby.getName()}")
        await session.leave_lobby(lobby)


class ListUsersHandler(MessageHandler):
    """Handler for list_users messages"""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Call ``lobby.update_state`` with the requesting session as argument."""
        await lobby.update_state(session)


class GetChatMessagesHandler(MessageHandler):
    """Handler for get_chat_messages messages"""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Reply with a ``chat_messages`` frame built from ``lobby.get_chat_messages(50)``."""
        messages = lobby.get_chat_messages(50)

        await websocket.send_json({
            "type": "chat_messages",
            "data": {
                "messages": [msg.model_dump() for msg in messages]
            },
        })


class SendChatMessageHandler(MessageHandler):
    """Handler for send_chat_message messages"""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Validate a chat message, store it in the lobby and broadcast it.

        The sender gets an ``error`` frame when the payload has no
        ``message`` key or the session has no name. A message that is empty
        after stripping whitespace is dropped without a reply.
        """
        if not data or "message" not in data:
            logger.error(f"{session.getName()} - send_chat_message missing message")
            await _send_error(websocket, "send_chat_message missing message")
            return

        if not session.name:
            logger.error(f"{session.getName()} - Cannot send chat message without name")
            await _send_error(websocket, "Must set name before sending chat messages")
            return

        message_text = str(data["message"]).strip()
        if not message_text:
            return

        # Add the message to the lobby and broadcast it
        chat_message = lobby.add_chat_message(session, message_text)
        # Only the first 50 characters of the text go into the log line.
        logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)")
        await lobby.broadcast_chat_message(chat_message)


class MessageRouter:
    """Routes WebSocket messages to appropriate handlers"""

    def __init__(self):
        """Create the router with the default handlers registered."""
        self._handlers: Dict[str, MessageHandler] = {}
        self._register_default_handlers()

    def _register_default_handlers(self):
        """Register default message handlers"""
        self.register("set_name", SetNameHandler())
        self.register("join", JoinHandler())
        self.register("part", PartHandler())
        self.register("list_users", ListUsersHandler())
        self.register("get_chat_messages", GetChatMessagesHandler())
        self.register("send_chat_message", SendChatMessageHandler())

    def register(self, message_type: str, handler: MessageHandler):
        """Register a handler for a message type.

        A handler registered earlier for the same type is replaced.
        """
        self._handlers[message_type] = handler
        logger.debug(f"Registered handler for message type: {message_type}")

    async def route(
        self,
        message_type: str,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ):
        """Route a message to the appropriate handler.

        An unknown ``message_type`` is answered with an ``error`` frame. An
        exception raised by the handler is logged with its traceback and
        answered with a generic ``error`` frame. If sending that frame fails,
        the failure is not caught here and reaches the caller.
        """
        if message_type not in self._handlers:
            logger.warning(f"Unknown message type: {message_type}")
            await _send_error(websocket, f"Unknown message type: {message_type}")
            return

        try:
            await self._handlers[message_type].handle(
                session, lobby, data, websocket, managers
            )
        except Exception as e:
            # logger.exception attaches the active traceback to the record,
            # matching how the handlers in this module report failures.
            logger.exception(f"Error handling message type {message_type}: {e}")
            await _send_error(websocket, f"Internal error handling {message_type}")

    def get_supported_types(self) -> list[str]:
        """Get list of supported message types"""
        return list(self._handlers.keys())