ai-voicebot/server/websocket/message_handlers.py

310 lines
9.9 KiB
Python

"""
WebSocket message routing and handling.
This module provides a clean way to route WebSocket messages to appropriate handlers,
replacing the massive switch statement from main.py.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, TYPE_CHECKING
from fastapi import WebSocket
from logger import logger
if TYPE_CHECKING:
from ..core.session_manager import Session
from ..core.lobby_manager import Lobby
from ..core.auth_manager import AuthManager
class MessageHandler(ABC):
    """Abstract interface implemented by every WebSocket message handler.

    A concrete handler receives the sending session, the lobby the message
    targets, the message payload, the websocket to reply on, and a mapping
    of shared manager objects (e.g. ``"auth_manager"``, ``"session_manager"``).
    """

    @abstractmethod
    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Process a single incoming message; subclasses must override this."""
class SetNameHandler(MessageHandler):
    """Handler for ``set_name`` messages.

    ``data`` should contain ``name`` and optionally ``password``. A name no
    other session holds is assigned directly (storing the password if one is
    given). A taken name is only assigned when
    ``auth_manager.check_name_takeover`` allows it. In that case the current
    holder is renamed to a unique fallback name, marked displaced, and
    notified before the name moves to the requesting session.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        auth_manager: "AuthManager" = managers["auth_manager"]
        session_manager = managers["session_manager"]

        # A non-dict payload (e.g. a bare string or list) would otherwise fail
        # on ``.get`` and reach the client only as a generic internal error.
        if not data or not isinstance(data, dict):
            logger.error(f"{session.getName()} - set_name missing data")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "set_name missing data"},
            })
            return

        name = data.get("name")
        password = data.get("password")
        logger.info(f"{session.getName()} <- set_name({name}, {'***' if password else None})")

        # Only non-blank strings are valid names. Stripping prevents
        # "bob" and " bob " from coexisting as visually identical names.
        name = name.strip() if isinstance(name, str) else None
        if not name:
            logger.error(f"{session.getName()} - Name required")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "Name required"}
            })
            return

        # Free name: claim it directly.
        if session_manager.is_unique_name(name):
            if password:
                auth_manager.set_password(name, password)
            await self._assign_name(session, lobby, websocket, auth_manager, name)
            return

        # Name is taken: the auth manager decides whether a takeover is allowed.
        allowed, reason = auth_manager.check_name_takeover(name, password)
        if not allowed:
            logger.warning(f"{session.getName()} - {reason}")
            await websocket.send_json({
                "type": "error",
                "data": {"error": reason}
            })
            return

        # Takeover allowed: move the current holder (if it is another session)
        # out of the way first.
        displaced = session_manager.get_session_by_name(name)
        if displaced and displaced.id != session.id:
            await self._displace(displaced, session_manager)

        if password:
            auth_manager.set_password(name, password)
        await self._assign_name(
            session, lobby, websocket, auth_manager, name, takeover=True
        )

    @staticmethod
    async def _assign_name(
        session: "Session",
        lobby: "Lobby",
        websocket: WebSocket,
        auth_manager: "AuthManager",
        name: str,
        *,
        takeover: bool = False,
    ) -> None:
        """Give ``name`` to ``session``, confirm it to the client, and refresh the lobby."""
        session.setName(name)
        suffix = " (takeover)" if takeover else ""
        logger.info(f"{session.getName()}: -> update('name', {name}){suffix}")
        await websocket.send_json({
            "type": "update_name",
            "data": {
                "name": name,
                "protected": auth_manager.is_name_protected(name),
            },
        })
        await lobby.update_state()

    @staticmethod
    async def _displace(displaced: "Session", session_manager: Any) -> None:
        """Rename ``displaced`` to a unique fallback, then notify it and its lobbies.

        Notification failures are logged and not raised, so a dead peer
        cannot block the takeover.
        """
        base = f"{displaced.name}-{displaced.short}"
        fallback = base
        counter = 1
        while not session_manager.is_unique_name(fallback):
            fallback = f"{base}-{counter}"
            counter += 1
        displaced.setName(fallback)
        displaced.mark_displaced()
        logger.info(f"{displaced.getName()} <- displaced by takeover, new name {fallback}")

        if displaced.ws:
            try:
                await displaced.ws.send_json({
                    "type": "update_name",
                    "data": {
                        "name": fallback,
                        "protected": False,
                    },
                })
            except Exception:
                logger.exception("Failed to notify displaced session websocket")

        # Iterate over a copy in case an update mutates the session's lobby list.
        for d_lobby in displaced.lobbies[:]:
            try:
                await d_lobby.update_state()
            except Exception:
                logger.exception("Failed to update lobby state for displaced session")
class JoinHandler(MessageHandler):
    """Handles ``join``: adds the sending session to the target lobby."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        # The payload is unused; the target lobby is supplied by the caller.
        who = session.getName()
        where = lobby.getName()
        logger.info(f"{who} <- join({where})")
        await session.join_lobby(lobby)
class PartHandler(MessageHandler):
    """Handles ``part``: removes the sending session from the target lobby."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        # The payload is unused; the target lobby is supplied by the caller.
        who = session.getName()
        where = lobby.getName()
        logger.info(f"{who} <- part {where}")
        await session.leave_lobby(lobby)
class ListUsersHandler(MessageHandler):
    """Handles ``list_users`` by asking the lobby to push its state to the sender."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        # Passes the requesting session to update_state; presumably this limits
        # the update to that session only (TODO confirm against Lobby).
        requester = session
        await lobby.update_state(requester)
class GetChatMessagesHandler(MessageHandler):
    """Handles ``get_chat_messages`` by replying with the lobby's chat history.

    Requests up to 50 messages from the lobby and serialises each one with
    its ``model_dump()`` method.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        history_limit = 50
        recent = lobby.get_chat_messages(history_limit)
        serialised = [entry.model_dump() for entry in recent]
        reply = {
            "type": "chat_messages",
            "data": {"messages": serialised},
        }
        await websocket.send_json(reply)
class SendChatMessageHandler(MessageHandler):
    """Handler for ``send_chat_message`` messages.

    Validates the payload, requires the sender to have set a name, then adds
    the stripped text to the lobby chat and broadcasts it. Messages that are
    empty after stripping are ignored without a reply.
    """

    # Maximum number of message characters echoed into the log line.
    _LOG_PREVIEW_CHARS = 50

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        # A non-dict payload, a missing key, and an explicit null are all
        # treated as "missing". Previously a null was str()'d and broadcast as
        # the text "None", and a truthy non-dict could pass the `in` check and
        # then fail on indexing.
        raw_message = data.get("message") if isinstance(data, dict) else None
        if raw_message is None:
            logger.error(f"{session.getName()} - send_chat_message missing message")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "send_chat_message missing message"},
            })
            return

        if not session.name:
            logger.error(f"{session.getName()} - Cannot send chat message without name")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "Must set name before sending chat messages"},
            })
            return

        message_text = str(raw_message).strip()
        if not message_text:
            return

        # Add the message to the lobby and broadcast it
        chat_message = lobby.add_chat_message(session, message_text)

        # Only mark the preview with an ellipsis when it was actually truncated.
        preview = message_text[:self._LOG_PREVIEW_CHARS]
        if len(message_text) > self._LOG_PREVIEW_CHARS:
            preview += "..."
        logger.info(f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {preview})")
        await lobby.broadcast_chat_message(chat_message)
class MessageRouter:
    """Dispatch incoming WebSocket messages to handlers keyed by message type.

    The built-in handlers (set_name, join, part, list_users,
    get_chat_messages, send_chat_message) are registered on construction.
    ``register`` adds a handler for a new type or replaces an existing one.
    """

    def __init__(self):
        self._handlers: Dict[str, MessageHandler] = {}
        self._register_default_handlers()

    def _register_default_handlers(self):
        """Register the built-in handlers for the lobby/chat protocol."""
        self.register("set_name", SetNameHandler())
        self.register("join", JoinHandler())
        self.register("part", PartHandler())
        self.register("list_users", ListUsersHandler())
        self.register("get_chat_messages", GetChatMessagesHandler())
        self.register("send_chat_message", SendChatMessageHandler())

    def register(self, message_type: str, handler: MessageHandler):
        """Register ``handler`` for ``message_type``, replacing any previous one."""
        self._handlers[message_type] = handler
        logger.debug(f"Registered handler for message type: {message_type}")

    async def route(
        self,
        message_type: str,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Route a message to its registered handler.

        An unknown message type, and any exception raised by the handler, is
        reported to the client as an ``error`` message on ``websocket``.
        Exceptions from sending that error reply are not caught here.
        """
        # message_type may come from untrusted input. An unhashable value
        # (list/dict) makes the dict lookup raise TypeError, so treat it as
        # an unknown type instead of crashing the router.
        try:
            handler = self._handlers.get(message_type)
        except TypeError:
            handler = None

        if handler is None:
            logger.warning(f"Unknown message type: {message_type}")
            await websocket.send_json({
                "type": "error",
                "data": {"error": f"Unknown message type: {message_type}"}
            })
            return

        try:
            await handler.handle(session, lobby, data, websocket, managers)
        except Exception as e:
            # logger.exception records the full traceback, replacing the
            # manual traceback.format_exc() logging.
            logger.exception(f"Error handling message type {message_type}: {e}")
            await websocket.send_json({
                "type": "error",
                "data": {"error": f"Internal error handling {message_type}"}
            })

    def get_supported_types(self) -> list[str]:
        """Return the message types that currently have a registered handler."""
        return list(self._handlers.keys())