347 lines
11 KiB
Python
347 lines
11 KiB
Python
"""
|
|
WebSocket message routing and handling.
|
|
|
|
This module provides a clean way to route WebSocket messages to appropriate handlers,
|
|
replacing the massive switch statement from main.py.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
from typing import Dict, Any, TYPE_CHECKING

from fastapi import WebSocket

from logger import logger
from .webrtc_signaling import WebRTCSignalingHandlers

if TYPE_CHECKING:
    from ..core.session_manager import Session
    from ..core.lobby_manager import Lobby
    from ..core.auth_manager import AuthManager


class MessageHandler(ABC):
    """Abstract interface implemented by every WebSocket message handler.

    Subclasses provide :meth:`handle`, which is awaited once per incoming
    message of the type the handler is registered for.
    """

    @abstractmethod
    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Process a single incoming message.

        Args:
            session: The session that sent the message.
            lobby: The lobby the message applies to.
            data: The message payload (presumably the envelope's ``data``
                field; confirm against the caller of ``MessageRouter.route``).
            websocket: Socket used to reply to the sender.
            managers: Named manager objects shared across handlers.
        """
        ...


class SetNameHandler(MessageHandler):
    """Handler for set_name messages.

    Assigns a display name to the sending session. A name that
    ``session_manager.is_unique_name`` reports as free is claimed directly.
    Otherwise ``auth_manager.check_name_takeover`` decides whether the name
    may be taken over. If it may, another session currently holding the name
    is renamed to a unique fallback, marked displaced and notified.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Validate a set_name request and assign the name to *session*.

        Args:
            session: Session requesting the name.
            lobby: Lobby whose state is refreshed after a successful change.
            data: Message payload. ``name`` (required, non-blank string) and
                ``password`` (optional string) are read from it.
            websocket: Socket used to reply to the requesting session.
            managers: Must provide ``"auth_manager"`` and ``"session_manager"``.

        Validation and takeover failures are reported to the client as
        ``{"type": "error", ...}`` messages instead of being raised.
        """
        auth_manager: "AuthManager" = managers["auth_manager"]
        session_manager = managers["session_manager"]

        # A non-dict payload (bare string, list, ...) would otherwise fail on
        # .get() and only surface as a generic "Internal error".
        if not data or not isinstance(data, dict):
            logger.error(f"{session.getName()} - set_name missing data")
            await self._send_error(websocket, "set_name missing data")
            return

        name = data.get("name")
        password = data.get("password")

        logger.info(f"{session.getName()} <- set_name({name}, {'***' if password else None})")

        # Names come from client JSON: reject non-strings (e.g. numbers) and
        # whitespace-only strings, which were previously stored as-is.
        if not isinstance(name, str) or not name.strip():
            logger.error(f"{session.getName()} - Name required")
            await self._send_error(websocket, "Name required")
            return

        # Falsy values still mean "no password" (as before); any other
        # non-string would be handed to the auth manager unchecked.
        if password and not isinstance(password, str):
            logger.error(f"{session.getName()} - Invalid password")
            await self._send_error(websocket, "Invalid password")
            return

        # Free name: claim it directly.
        if session_manager.is_unique_name(name):
            await self._assign_name(
                session, lobby, websocket, auth_manager, name, password, takeover=False
            )
            return

        # Name is taken - the auth manager decides whether a takeover is allowed.
        allowed, reason = auth_manager.check_name_takeover(name, password)
        if not allowed:
            logger.warning(f"{session.getName()} - {reason}")
            await self._send_error(websocket, reason)
            return

        # Takeover allowed - move the current holder (if it is another session)
        # out of the way before assigning the name.
        displaced = session_manager.get_session_by_name(name)
        if displaced and displaced.id != session.id:
            await self._displace(displaced, session_manager)

        await self._assign_name(
            session, lobby, websocket, auth_manager, name, password, takeover=True
        )

    @staticmethod
    async def _send_error(websocket: WebSocket, message: Any) -> None:
        """Send an ``error`` message carrying *message* to *websocket*."""
        await websocket.send_json({
            "type": "error",
            "data": {"error": message},
        })

    @staticmethod
    def _fallback_name(displaced: "Session", session_manager: Any) -> str:
        """Return ``<name>-<short>``, or ``<name>-<short>-N`` with the first N that is unique."""
        base = f"{displaced.name}-{displaced.short}"
        fallback = base
        counter = 1
        while not session_manager.is_unique_name(fallback):
            fallback = f"{base}-{counter}"
            counter += 1
        return fallback

    async def _displace(self, displaced: "Session", session_manager: Any) -> None:
        """Rename *displaced* to a fallback name, mark it displaced and notify it.

        Notification and lobby refreshes are best-effort: failures are logged
        and do not abort the takeover.
        """
        fallback = self._fallback_name(displaced, session_manager)

        displaced.setName(fallback)
        displaced.mark_displaced()
        logger.info(f"{displaced.getName()} <- displaced by takeover, new name {fallback}")

        # Notify displaced session
        if displaced.ws:
            try:
                await displaced.ws.send_json({
                    "type": "update_name",
                    "data": {
                        "name": fallback,
                        "protected": False,
                    },
                })
            except Exception:
                logger.exception("Failed to notify displaced session websocket")

        # Iterate over a copy so the refresh cannot be disturbed if the
        # lobby list changes while we await.
        for d_lobby in displaced.lobbies[:]:
            try:
                await d_lobby.update_state()
            except Exception:
                logger.exception("Failed to update lobby state for displaced session")

    @staticmethod
    async def _assign_name(
        session: "Session",
        lobby: "Lobby",
        websocket: WebSocket,
        auth_manager: "AuthManager",
        name: str,
        password: Any,
        *,
        takeover: bool,
    ) -> None:
        """Store *password* (if any), rename *session*, confirm to the client and refresh *lobby*."""
        if password:
            auth_manager.set_password(name, password)

        session.setName(name)
        suffix = " (takeover)" if takeover else ""
        logger.info(f"{session.getName()}: -> update('name', {name}){suffix}")

        await websocket.send_json({
            "type": "update_name",
            "data": {
                "name": name,
                "protected": auth_manager.is_name_protected(name),
            },
        })

        # Update lobby state
        await lobby.update_state()


class JoinHandler(MessageHandler):
    """Handle ``join``: add the sending session to the target lobby."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Log the request and delegate to ``session.join_lobby``.

        ``data``, ``websocket`` and ``managers`` are part of the handler
        interface but are not used here.
        """
        who, where = session.getName(), lobby.getName()
        logger.info(f"{who} <- join({where})")
        await session.join_lobby(lobby)


class PartHandler(MessageHandler):
    """Handle ``part``: remove the sending session from the target lobby."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Log the request and delegate to ``session.leave_lobby``.

        ``data``, ``websocket`` and ``managers`` are part of the handler
        interface but are not used here.
        """
        entry = f"{session.getName()} <- part {lobby.getName()}"
        logger.info(entry)
        await session.leave_lobby(lobby)


class ListUsersHandler(MessageHandler):
    """Handle ``list_users``: refresh lobby state for the requesting session."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Call ``lobby.update_state`` with the requesting session.

        Passing *session* presumably limits the update to that one client;
        confirm against ``Lobby.update_state``.
        """
        pending = lobby.update_state(session)
        await pending


class GetChatMessagesHandler(MessageHandler):
    """Handle ``get_chat_messages``: send recent lobby chat history to the requester."""

    # Number of messages requested from the lobby when no limit is given;
    # this matches the value that was previously hard-coded.
    DEFAULT_LIMIT = 50

    def __init__(self, limit: int = DEFAULT_LIMIT) -> None:
        """Create the handler.

        Args:
            limit: Value passed to ``lobby.get_chat_messages`` on each request.
                Defaults to ``DEFAULT_LIMIT`` (50).
        """
        self.limit = limit

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Reply with a ``chat_messages`` message built from the lobby history.

        Each message object is serialized with ``model_dump()``. ``session``,
        ``data`` and ``managers`` are part of the handler interface but are
        not used here.
        """
        messages = lobby.get_chat_messages(self.limit)
        await websocket.send_json({
            "type": "chat_messages",
            "data": {
                "messages": [msg.model_dump() for msg in messages]
            },
        })


class SendChatMessageHandler(MessageHandler):
    """Handle ``send_chat_message``: store a chat message in the lobby and broadcast it."""

    # Number of characters of the message text included in the log line.
    LOG_PREVIEW_CHARS = 50

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Validate the payload, then add the message to *lobby* and broadcast it.

        Args:
            session: Sender. Must already have a name.
            lobby: Lobby that receives the message.
            data: Payload containing a string ``message``.
            websocket: Socket used to report validation errors to the sender.
            managers: Unused.

        Messages that are empty after stripping whitespace are silently
        dropped. Validation failures are reported as ``error`` messages.
        """
        # Check the payload type first: for a str payload, `"message" in data`
        # is a substring test, and `data["message"]` would then raise TypeError.
        if not isinstance(data, dict) or "message" not in data:
            logger.error(f"{session.getName()} - send_chat_message missing message")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "send_chat_message missing message"},
            })
            return

        if not session.name:
            logger.error(f"{session.getName()} - Cannot send chat message without name")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "Must set name before sending chat messages"},
            })
            return

        raw_message = data["message"]
        # Only text is accepted. Previously str(None) posted the literal
        # "None", and dicts/lists were posted as their Python repr.
        if not isinstance(raw_message, str):
            logger.error(f"{session.getName()} - send_chat_message message must be a string")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "send_chat_message message must be a string"},
            })
            return

        message_text = raw_message.strip()
        if not message_text:
            return

        # Add the message to the lobby and broadcast it
        chat_message = lobby.add_chat_message(session, message_text)

        limit = self.LOG_PREVIEW_CHARS
        preview = message_text if len(message_text) <= limit else f"{message_text[:limit]}..."
        logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {preview})")

        await lobby.broadcast_chat_message(chat_message)


class RelayICECandidateHandler(MessageHandler):
    """WebRTC signaling: handle ``relayICECandidate`` by delegating to the signaling helpers."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        """Forward the ICE-candidate payload to
        ``WebRTCSignalingHandlers.handle_relay_ice_candidate``.

        ``managers`` is not used.
        """
        relay = WebRTCSignalingHandlers.handle_relay_ice_candidate
        await relay(websocket, session, lobby, data)


class RelaySessionDescriptionHandler(MessageHandler):
    """WebRTC signaling: handle ``relaySessionDescription`` by delegating to the signaling helpers."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        """Forward the session-description payload to
        ``WebRTCSignalingHandlers.handle_relay_session_description``.

        ``managers`` is not used.
        """
        relay = WebRTCSignalingHandlers.handle_relay_session_description
        await relay(websocket, session, lobby, data)


class MessageRouter:
    """Routes WebSocket messages to the handler registered for their type."""

    def __init__(self):
        """Create a router pre-populated with the default handlers."""
        self._handlers: Dict[str, MessageHandler] = {}
        self._register_default_handlers()

    def _register_default_handlers(self):
        """Register the built-in lobby, chat and WebRTC signaling handlers."""
        self.register("set_name", SetNameHandler())
        self.register("join", JoinHandler())
        self.register("part", PartHandler())
        self.register("list_users", ListUsersHandler())
        self.register("get_chat_messages", GetChatMessagesHandler())
        self.register("send_chat_message", SendChatMessageHandler())

        # WebRTC signaling handlers
        self.register("relayICECandidate", RelayICECandidateHandler())
        self.register("relaySessionDescription", RelaySessionDescriptionHandler())

    def register(self, message_type: str, handler: MessageHandler):
        """Register *handler* for *message_type*, replacing any previous handler."""
        self._handlers[message_type] = handler
        logger.debug(f"Registered handler for message type: {message_type}")

    async def route(
        self,
        message_type: str,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Dispatch a message to the handler registered for *message_type*.

        Unknown types and exceptions raised by a handler are reported to the
        client as ``error`` messages. Handler exceptions are logged with a
        traceback and not re-raised. Errors raised while sending those
        error replies still propagate to the caller.
        """
        # message_type comes from client JSON. An unhashable value (list/dict)
        # would make the dict lookup raise TypeError outside the try below,
        # so only strings are looked up.
        handler = self._handlers.get(message_type) if isinstance(message_type, str) else None

        if handler is None:
            logger.warning(f"Unknown message type: {message_type}")
            await websocket.send_json({
                "type": "error",
                "data": {"error": f"Unknown message type: {message_type}"}
            })
            return

        try:
            await handler.handle(session, lobby, data, websocket, managers)
        except Exception as e:
            # logger.exception records the traceback in the same log entry.
            logger.exception(f"Error handling message type {message_type}: {e}")
            await websocket.send_json({
                "type": "error",
                "data": {"error": f"Internal error handling {message_type}"}
            })

    def get_supported_types(self) -> list[str]:
        """Return the message types that currently have a registered handler."""
        return list(self._handlers.keys())