431 lines
14 KiB
Python
431 lines
14 KiB
Python
"""
|
|
WebSocket message routing and handling.
|
|
|
|
This module provides a clean way to route WebSocket messages to appropriate handlers,
|
|
replacing the massive switch statement from main.py.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, TYPE_CHECKING
|
|
from fastapi import WebSocket
|
|
|
|
from logger import logger
|
|
from .webrtc_signaling import WebRTCSignalingHandlers
|
|
from core.error_handling import (
|
|
error_handler,
|
|
WebSocketError,
|
|
ValidationError,
|
|
with_websocket_error_handling,
|
|
ErrorSeverity
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from ..core.session_manager import Session
|
|
from ..core.lobby_manager import Lobby
|
|
from ..core.auth_manager import AuthManager
|
|
|
|
|
|
class MessageHandler(ABC):
    """Abstract interface implemented by every WebSocket message handler.

    Concrete subclasses are registered with ``MessageRouter.register`` under a
    message-type string, and the router calls ``handle`` for each incoming
    message of that type.
    """

    @abstractmethod
    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Process a single incoming message.

        Args:
            session: Session that sent the message.
            lobby: Lobby the message is addressed to.
            data: Message payload exactly as received from the client
                (untrusted; shape is not validated by the router).
            websocket: Connection the message arrived on, used for replies.
            managers: Shared manager objects keyed by name, e.g.
                ``"auth_manager"`` and ``"session_manager"``.
        """
|
|
|
|
|
|
class SetNameHandler(MessageHandler):
    """Handler for ``set_name`` messages.

    Assigns a display name to the sending session:

    * If the name is free, it is taken directly. A supplied password is
      stored for it via ``auth_manager.set_password``.
    * Otherwise ``auth_manager.check_name_takeover`` decides whether the
      caller may take it over. On success, any other session holding the
      name is renamed to a unique fallback, marked displaced and notified.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Validate the requested name and assign it to ``session``.

        Args:
            session: Session requesting the name.
            lobby: Lobby whose state is refreshed after a successful rename.
            data: Client payload; expects ``"name"`` (str) and optionally
                ``"password"`` (str).
            websocket: Connection used for the reply.
            managers: Must contain ``"auth_manager"`` and ``"session_manager"``.

        Validation failures and refused takeovers are reported to the client
        as ``{"type": "error", ...}`` messages rather than raised.
        """
        auth_manager: "AuthManager" = managers["auth_manager"]
        session_manager = managers["session_manager"]

        # The payload comes straight from the client. Anything that is not a
        # non-empty object cannot carry a name (and would break ``.get``).
        if not data or not isinstance(data, dict):
            logger.error(f"{session.getName()} - set_name missing data")
            await self._send_error(websocket, "set_name missing data")
            return

        name = data.get("name")
        password = data.get("password")

        logger.info(f"{session.getName()} <- set_name({name}, {'***' if password else None})")

        # Reject non-string and blank names. Strip surrounding whitespace so
        # "alice " cannot be registered as a name distinct from "alice".
        if not isinstance(name, str) or not name.strip():
            logger.error(f"{session.getName()} - Name required")
            await self._send_error(websocket, "Name required")
            return
        name = name.strip()

        # A falsy password means "no password" (as before). A truthy password
        # must be a string before it is handed to the auth manager.
        if password and not isinstance(password, str):
            logger.error(f"{session.getName()} - Password must be a string")
            await self._send_error(websocket, "Password must be a string")
            return

        # Fast path: nobody holds this name.
        if session_manager.is_unique_name(name):
            if password:
                auth_manager.set_password(name, password)

            session.setName(name)
            logger.info(f"{session.getName()}: -> update('name', {name})")
            await self._send_name_update(websocket, name, auth_manager.is_name_protected(name))
            await lobby.update_state()
            return

        # Name is taken: let the auth manager decide whether a takeover is allowed.
        allowed, reason = auth_manager.check_name_takeover(name, password)
        if not allowed:
            logger.warning(f"{session.getName()} - {reason}")
            await self._send_error(websocket, reason)
            return

        displaced = session_manager.get_session_by_name(name)
        if displaced and displaced.id != session.id:
            await self._displace(displaced, session_manager)

        if password:
            auth_manager.set_password(name, password)

        session.setName(name)
        logger.info(f"{session.getName()}: -> update('name', {name}) (takeover)")
        await self._send_name_update(websocket, name, auth_manager.is_name_protected(name))
        await lobby.update_state()

    async def _send_error(self, websocket: WebSocket, message: str) -> None:
        """Send an ``error`` message carrying ``message`` to the client."""
        await websocket.send_json({
            "type": "error",
            "data": {"error": message},
        })

    async def _send_name_update(self, websocket: WebSocket, name: str, protected: bool) -> None:
        """Send an ``update_name`` message confirming ``name`` to the client."""
        await websocket.send_json({
            "type": "update_name",
            "data": {
                "name": name,
                "protected": protected,
            },
        })

    async def _displace(self, displaced: "Session", session_manager: Any) -> None:
        """Move ``displaced`` off its name after a successful takeover.

        The session is renamed to ``<name>-<short>`` (suffixed with a counter
        until unique), marked displaced, told its new name, and its lobbies
        are refreshed. Notification and lobby refresh are best-effort:
        failures are logged and do not abort the takeover.
        """
        base = f"{displaced.name}-{displaced.short}"
        fallback = base
        counter = 1
        while not session_manager.is_unique_name(fallback):
            fallback = f"{base}-{counter}"
            counter += 1

        displaced.setName(fallback)
        displaced.mark_displaced()
        logger.info(f"{displaced.getName()} <- displaced by takeover, new name {fallback}")

        if displaced.ws:
            try:
                await self._send_name_update(displaced.ws, fallback, False)
            except Exception:
                logger.exception("Failed to notify displaced session websocket")

        # Iterate over a copy: refreshing lobby state may modify the list.
        for d_lobby in displaced.lobbies[:]:
            try:
                await d_lobby.update_state()
            except Exception:
                logger.exception("Failed to update lobby state for displaced session")
|
|
|
|
|
|
class JoinHandler(MessageHandler):
    """Adds the sending session to the addressed lobby (``join`` message)."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Log the request and delegate to ``session.join_lobby(lobby)``.

        ``data``, ``websocket`` and ``managers`` are not used.
        """
        who = session.getName()
        where = lobby.getName()
        logger.info(f"{who} <- join({where})")
        await session.join_lobby(lobby)
|
|
|
|
|
|
class PartHandler(MessageHandler):
    """Removes the sending session from the addressed lobby (``part`` message)."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Log the request and delegate to ``session.leave_lobby(lobby)``.

        ``data``, ``websocket`` and ``managers`` are not used.
        """
        who = session.getName()
        where = lobby.getName()
        logger.info(f"{who} <- part {where}")
        await session.leave_lobby(lobby)
|
|
|
|
|
|
class ListUsersHandler(MessageHandler):
    """Handles ``list_users`` by asking the lobby to publish its state."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Call ``lobby.update_state`` with the requesting session.

        NOTE(review): passing ``session`` presumably limits the update to that
        one session; confirm against ``Lobby.update_state``.
        """
        requester = session
        await lobby.update_state(requester)
|
|
|
|
|
|
class GetChatMessagesHandler(MessageHandler):
    """Replies to ``get_chat_messages`` with the lobby's recent chat history."""

    # Count passed to ``lobby.get_chat_messages`` for each request.
    HISTORY_LIMIT = 50

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Send a ``chat_messages`` reply whose ``messages`` list holds each
        history entry's ``model_dump()`` output."""
        history = lobby.get_chat_messages(self.HISTORY_LIMIT)
        serialized = [entry.model_dump() for entry in history]
        reply = {
            "type": "chat_messages",
            "data": {"messages": serialized},
        }
        await websocket.send_json(reply)
|
|
|
|
|
|
class SendChatMessageHandler(MessageHandler):
    """Handler for ``send_chat_message`` messages.

    Stores the message in the lobby via ``lobby.add_chat_message`` and
    broadcasts the result with ``lobby.broadcast_chat_message``.
    """

    # Maximum number of message characters echoed into the log line.
    LOG_PREVIEW_CHARS = 50

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Validate and broadcast a chat message from ``session``.

        Args:
            session: Sender; must already have a name.
            lobby: Lobby the message is posted to.
            data: Client payload; must be an object with a ``"message"`` key.
            websocket: Connection used for error replies.
            managers: Unused.

        Errors (missing message, unnamed sender) are sent to the client as
        ``{"type": "error", ...}`` messages. Messages that are empty after
        stripping whitespace are silently ignored.
        """
        # The payload is untrusted. A non-dict (e.g. a list containing
        # "message") would pass a bare ``in`` check and then fail on indexing.
        if not isinstance(data, dict) or "message" not in data:
            logger.error(f"{session.getName()} - send_chat_message missing message")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "send_chat_message missing message"},
            })
            return

        if not session.name:
            logger.error(f"{session.getName()} - Cannot send chat message without name")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "Must set name before sending chat messages"},
            })
            return

        message_text = str(data["message"]).strip()
        if not message_text:
            return

        # Add the message to the lobby and broadcast it.
        chat_message = lobby.add_chat_message(session, message_text)

        # Only mark the preview as truncated when it actually was.
        limit = self.LOG_PREVIEW_CHARS
        preview = message_text if len(message_text) <= limit else f"{message_text[:limit]}..."
        logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {preview})")

        await lobby.broadcast_chat_message(chat_message)
|
|
|
|
|
|
class RelayICECandidateHandler(MessageHandler):
    """Routes ``relayICECandidate`` WebRTC signaling messages.

    All processing is delegated to
    ``WebRTCSignalingHandlers.handle_relay_ice_candidate``.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        """Forward the message to the WebRTC signaling helper; ``managers`` is unused."""
        relay = WebRTCSignalingHandlers.handle_relay_ice_candidate
        await relay(websocket, session, lobby, data)
|
|
|
|
|
|
class RelaySessionDescriptionHandler(MessageHandler):
    """Routes ``relaySessionDescription`` WebRTC signaling messages.

    All processing is delegated to
    ``WebRTCSignalingHandlers.handle_relay_session_description``.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        """Forward the message to the WebRTC signaling helper; ``managers`` is unused."""
        relay = WebRTCSignalingHandlers.handle_relay_session_description
        await relay(websocket, session, lobby, data)
|
|
|
|
|
|
class MessageRouter:
    """Dispatches incoming WebSocket messages to handlers by message type.

    The built-in handlers are registered on construction. ``register`` adds a
    handler or replaces the one already bound to a message type.
    """

    def __init__(self):
        self._handlers: Dict[str, MessageHandler] = {}
        self._register_default_handlers()

    def _register_default_handlers(self):
        """Register the handlers for every built-in message type."""
        self.register("set_name", SetNameHandler())
        self.register("join", JoinHandler())
        self.register("part", PartHandler())
        self.register("list_users", ListUsersHandler())
        self.register("get_chat_messages", GetChatMessagesHandler())
        self.register("send_chat_message", SendChatMessageHandler())

        # WebRTC signaling handlers
        self.register("relayICECandidate", RelayICECandidateHandler())
        self.register("relaySessionDescription", RelaySessionDescriptionHandler())

    def register(self, message_type: str, handler: MessageHandler):
        """Bind ``handler`` to ``message_type``, replacing any previous binding."""
        self._handlers[message_type] = handler
        logger.debug(f"Registered handler for message type: {message_type}")

    async def route(
        self,
        message_type: str,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ):
        """Route a message to its handler and report any failure.

        Unknown message types, and ``WebSocketError`` / ``ValidationError`` /
        other ``Exception``s raised by the handler, are passed to
        ``error_handler.handle_error`` instead of being re-raised.
        WebSocket errors and unexpected errors also get a recovery action.

        Args:
            message_type: Type string sent by the client.
            session: Sending session (may be falsy; ids then fall back).
            lobby: Lobby the message targets.
            data: Untrusted client payload; not guaranteed to be a dict.
            websocket: Connection the message arrived on.
            managers: Shared managers forwarded to the handler.
        """
        context_session_id = session.id if session else "unknown"
        session_id = session.id if session else None

        handler = self._handlers.get(message_type)
        if handler is None:
            await error_handler.handle_error(
                ValidationError(f"Unknown message type: {message_type}"),
                context={
                    "message_type": message_type,
                    "session_id": context_session_id,
                    # The payload comes from the client and may not be a dict;
                    # calling .keys() on it would crash the router itself.
                    "data_keys": list(data.keys()) if isinstance(data, dict) else [],
                },
                websocket=websocket,
                session_id=session_id,
            )
            return

        handler_name = type(handler).__name__

        try:
            await handler.handle(session, lobby, data, websocket, managers)

        except WebSocketError as e:
            # WebSocket-specific errors: report and attempt recovery.
            await error_handler.handle_error(
                e,
                context={
                    "message_type": message_type,
                    "session_id": context_session_id,
                    "handler": handler_name,
                },
                websocket=websocket,
                session_id=session_id,
                recovery_action=lambda: self._websocket_recovery(websocket, session),
            )

        except ValidationError as e:
            # Validation errors: usually caused by bad client input.
            await error_handler.handle_error(
                e,
                context={
                    "message_type": message_type,
                    "session_id": context_session_id,
                    "handler": handler_name,
                    "data": str(data)[:500],  # Truncate large payloads
                },
                websocket=websocket,
                session_id=session_id,
            )

        except Exception as e:
            # Unexpected errors: wrap as a high-severity WebSocketError.
            import traceback

            await error_handler.handle_error(
                WebSocketError(
                    f"Unexpected error in {message_type} handler: {e}",
                    severity=ErrorSeverity.HIGH
                ),
                context={
                    "message_type": message_type,
                    "session_id": context_session_id,
                    "handler": handler_name,
                    "exception_type": type(e).__name__,
                    "error_message": str(e),
                    # Previously this key held only str(e); record the real
                    # traceback of the exception being handled.
                    "traceback": traceback.format_exc(),
                },
                websocket=websocket,
                session_id=session_id,
                recovery_action=lambda: self._generic_recovery(message_type, session, lobby),
            )

    async def _websocket_recovery(self, websocket: WebSocket, session: "Session"):
        """Recovery action for WebSocket errors: send a ``connection_status`` update."""
        if websocket and session:
            await websocket.send_json({
                "type": "connection_status",
                "data": {
                    "status": "recovered",
                    "session_id": session.id,
                    "message": "Connection recovered from error"
                }
            })

    async def _generic_recovery(self, message_type: str, session: "Session", lobby: "Lobby"):
        """Recovery action for unexpected handler errors.

        For ``join``/``part``, refreshes the lobby state. For ``set_name``,
        only logs. Other types get only the initial log line.
        """
        logger.info(f"Attempting recovery for {message_type} error")

        if message_type in ["join", "part"]:
            # Lobby membership may be half-updated; republish lobby state.
            if session and lobby:
                await lobby.update_state()

        elif message_type == "set_name":
            if session:
                logger.info(f"Validating session state for {session.id}")

    def get_supported_types(self) -> list[str]:
        """Return the message types that currently have a registered handler."""
        return list(self._handlers.keys())
|