ai-voicebot/server/websocket/message_handlers.py
2025-09-16 12:25:04 -07:00

560 lines
19 KiB
Python

"""
WebSocket message routing and handling.
This module provides a clean way to route WebSocket messages to appropriate handlers,
replacing the massive switch statement from main.py.
"""
import traceback
from abc import ABC, abstractmethod
from typing import Dict, Any, TYPE_CHECKING

from fastapi import WebSocket

from shared.logger import logger
from shared.models import ChatMessageModel
from .webrtc_signaling import WebRTCSignalingHandlers
from core.error_handling import (
    error_handler,
    WebSocketError,
    ValidationError,
    ErrorSeverity,
)

if TYPE_CHECKING:
    from ..core.session_manager import Session
    from ..core.lobby_manager import Lobby
    from ..core.auth_manager import AuthManager
class MessageHandler(ABC):
    """Abstract interface implemented by every WebSocket message handler.

    Concrete handlers are registered with ``MessageRouter`` under a message
    type string and are invoked with the payload of that message type.
    """

    @abstractmethod
    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        """Process a single incoming message.

        Args:
            session: The session that sent the message.
            lobby: The lobby the message was routed to.
            data: The message payload.
            websocket: The sender's socket, used for direct replies.
            managers: Shared manager objects keyed by name
                (e.g. ``"auth_manager"``, ``"session_manager"``).
        """
        ...
class PeerStateUpdateHandler(MessageHandler):
    """Handler for peer_state_update messages.

    Broadcasts the sender's ``muted`` / ``video_on`` state to its lobby via
    ``lobby.broadcast_peer_state_update``. A session may only update its own
    state; requests that name a different peer are logged and ignored.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        if not lobby or not session:
            return
        # The payload may be missing (None / {}); treat that as "no fields".
        payload = data or {}
        session_id = getattr(session, "id", None)
        # peer_id is optional and defaults to the sender's own id.
        requested_peer = payload.get("peer_id", session_id)
        if str(requested_peer) != str(session_id):
            # Only allow a user to update their own state.
            logger.warning(
                f"Session {session_id} attempted peer_state_update for "
                f"peer {requested_peer!r}; ignoring"
            )
            return
        update = {
            "type": "peer_state_update",
            "data": {
                # Use the server-side id rather than echoing the client's
                # value, so recipients always get the canonical form.
                "peer_id": session_id,
                "muted": payload.get("muted"),
                "video_on": payload.get("video_on"),
            },
        }
        await lobby.broadcast_peer_state_update(update)
class SetNameHandler(MessageHandler):
    """Handler for set_name messages.

    Assigns a display name to the sending session. A free name is taken
    directly. A taken name is only granted when
    ``auth_manager.check_name_takeover`` allows it; the current holder is
    then renamed to a unique fallback name and notified. An optional
    password is stored for the name via ``auth_manager.set_password``.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        auth_manager: "AuthManager" = managers["auth_manager"]
        session_manager = managers["session_manager"]
        if not data:
            logger.error(f"{session.getName()} - set_name missing data")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "set_name missing data"},
            })
            return
        name = data.get("name")
        password = data.get("password")
        logger.info(f"{session.getName()} <- set_name({name}, {'***' if password else None})")
        if not name:
            logger.error(f"{session.getName()} - Name required")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "Name required"}
            })
            return
        # The payload comes from the client; reject non-string names rather
        # than storing e.g. a number or dict as a session name.
        if not isinstance(name, str):
            logger.error(f"{session.getName()} - Name must be a string")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "Name must be a string"}
            })
            return

        # Free name: claim it directly.
        if session_manager.is_unique_name(name):
            if password:
                auth_manager.set_password(name, password)
            await self._assign_name(session, lobby, websocket, auth_manager, name)
            return

        # Name is taken: the auth manager decides whether takeover is allowed.
        allowed, reason = auth_manager.check_name_takeover(name, password)
        if not allowed:
            logger.warning(f"{session.getName()} - {reason}")
            await websocket.send_json({
                "type": "error",
                "data": {"error": reason}
            })
            return

        displaced = session_manager.get_session_by_name(name)
        if displaced and displaced.id != session.id:
            await self._displace_session(displaced, session_manager)
        if password:
            auth_manager.set_password(name, password)
        await self._assign_name(
            session, lobby, websocket, auth_manager, name, takeover=True
        )

    @staticmethod
    async def _assign_name(
        session: "Session",
        lobby: "Lobby",
        websocket: WebSocket,
        auth_manager: "AuthManager",
        name: str,
        takeover: bool = False,
    ) -> None:
        """Give *name* to *session*, confirm it to the client and refresh the lobby."""
        session.setName(name)
        suffix = " (takeover)" if takeover else ""
        logger.info(f"{session.getName()}: -> update('name', {name}){suffix}")
        await websocket.send_json({
            "type": "update_name",
            "data": {
                "name": name,
                "protected": auth_manager.is_name_protected(name),
            },
        })
        await lobby.update_state()

    @staticmethod
    async def _displace_session(displaced: "Session", session_manager: Any) -> None:
        """Rename *displaced* to a unique fallback name, notify it and refresh its lobbies."""
        fallback = f"{displaced.name}-{displaced.short}"
        counter = 1
        while not session_manager.is_unique_name(fallback):
            fallback = f"{displaced.name}-{displaced.short}-{counter}"
            counter += 1
        displaced.setName(fallback)
        displaced.mark_displaced()
        logger.info(f"{displaced.getName()} <- displaced by takeover, new name {fallback}")
        # Notification is best-effort: the displaced socket may be gone.
        if displaced.ws:
            try:
                await displaced.ws.send_json({
                    "type": "update_name",
                    "data": {
                        "name": fallback,
                        "protected": False,
                    },
                })
            except Exception:
                logger.exception("Failed to notify displaced session websocket")
        # Iterate over a copy in case a lobby update mutates the list.
        for d_lobby in displaced.lobbies[:]:
            try:
                await d_lobby.update_state()
            except Exception:
                logger.exception("Failed to update lobby state for displaced session")
class JoinHandler(MessageHandler):
    """Adds the sending session to the target lobby (``join`` message)."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        who = session.getName()
        where = lobby.getName()
        logger.info(f"{who} <- join({where})")
        await session.join_lobby(lobby)
class PartHandler(MessageHandler):
    """Removes the sending session from the target lobby (``part`` message)."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        who = session.getName()
        where = lobby.getName()
        logger.info(f"{who} <- part {where}")
        await session.leave_lobby(lobby)
class ListUsersHandler(MessageHandler):
    """Answers ``list_users`` by asking the lobby to refresh its state."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        # NOTE(review): passing the session presumably targets the refresh at
        # the requester only — confirm against Lobby.update_state.
        refresh = lobby.update_state
        await refresh(session)
class GetChatMessagesHandler(MessageHandler):
    """Replies to ``get_chat_messages`` with the lobby's recent chat history."""

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        history_size = 50
        recent = lobby.get_chat_messages(history_size)
        # Serialize each chat model to a plain dict for JSON transport.
        serialized = [entry.model_dump() for entry in recent]
        reply = {
            "type": "chat_messages",
            "data": {"messages": serialized},
        }
        await websocket.send_json(reply)
class SendChatMessageHandler(MessageHandler):
    """Handler for send_chat_message messages.

    Accepts either a ``ChatMessageModel`` dict, which is added or updated via
    ``lobby.update_chat_message``, or, for legacy clients, plain text, which is
    added via ``lobby.add_chat_message``. The resulting message is broadcast
    to the lobby. The sender must have set a name first.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ) -> None:
        # A null message counts as missing; otherwise the legacy path would
        # stringify it and broadcast the literal text "None".
        if not data or data.get("message") is None:
            logger.error(f"{session.getName()} - send_chat_message missing message")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "send_chat_message missing message"},
            })
            return
        if not session.name:
            logger.error(f"{session.getName()} - Cannot send chat message without name")
            await websocket.send_json({
                "type": "error",
                "data": {"error": "Must set name before sending chat messages"},
            })
            return
        message_data = data["message"]
        if isinstance(message_data, dict):
            # Only model parsing is treated as a client format error. pydantic's
            # ValidationError subclasses ValueError. Failures inside
            # lobby.update_chat_message now propagate to the router instead
            # of being misreported as "Invalid ChatMessageModel format".
            try:
                chat_message = ChatMessageModel.model_validate(message_data)
            except ValueError as e:
                logger.error(f"{session.getName()} - Invalid ChatMessageModel: {e}")
                await websocket.send_json(
                    {
                        "type": "error",
                        "data": {"error": "Invalid ChatMessageModel format"},
                    }
                )
                return
            # A client may only author/update messages as itself.
            if chat_message.sender_session_id != session.id:
                logger.error(
                    f"{session.getName()} - ChatMessageModel sender_session_id mismatch"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "data": {
                            "error": "ChatMessageModel sender_session_id does not match session"
                        },
                    }
                )
                return
            # Update or add the message
            chat_message = lobby.update_chat_message(chat_message)
            logger.info(
                f"{session.getName()} -> update_chat_message({lobby.getName()}, {chat_message.message[:50]}...)"
            )
        else:
            # Handle string message (legacy support)
            message_text = str(message_data).strip()
            if not message_text:
                return
            chat_message = lobby.add_chat_message(session, message_text)
            logger.info(
                f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)"
            )
        await lobby.broadcast_chat_message(chat_message)
class ClearChatMessagesHandler(MessageHandler):
    """Wipes the lobby's chat history and broadcasts the clear event.

    Sessions that have not set a name yet receive an error instead.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        if session.name:
            lobby.clear_chat_messages()
            logger.info(f"{session.getName()} -> clear_chat_messages({lobby.getName()})")
            await lobby.broadcast_chat_clear()
            return

        logger.error(f"{session.getName()} - Cannot clear chat messages without name")
        refusal = {
            "type": "error",
            "data": {"error": "Must set name before clearing chat messages"},
        }
        await websocket.send_json(refusal)
class RelayICECandidateHandler(MessageHandler):
    """Forwards ``relayICECandidate`` WebRTC signaling messages.

    All work is delegated to ``WebRTCSignalingHandlers.handle_relay_ice_candidate``.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        relay = WebRTCSignalingHandlers.handle_relay_ice_candidate
        await relay(websocket, session, lobby, data)
class RelaySessionDescriptionHandler(MessageHandler):
    """Forwards ``relaySessionDescription`` WebRTC signaling messages.

    All work is delegated to
    ``WebRTCSignalingHandlers.handle_relay_session_description``.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        relay = WebRTCSignalingHandlers.handle_relay_session_description
        await relay(websocket, session, lobby, data)
class StatusCheckHandler(MessageHandler):
    """Handler for status_check messages (bot health monitoring).

    Replies with a ``status_response`` that echoes the client's ``timestamp``
    (``None`` if it was not sent) together with the session id and lobby name.
    """

    async def handle(
        self,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any],
    ) -> None:
        # A status_check may arrive without a payload; don't crash on None.
        payload = data or {}
        logger.debug(f"{session.getName()} <- status_check")
        await websocket.send_json(
            {
                "type": "status_response",
                "data": {
                    "status": "ok",
                    "timestamp": payload.get("timestamp"),
                    "session_id": session.id,
                    "lobby": lobby.getName() if lobby else None,
                },
            }
        )
class MessageRouter:
    """Routes WebSocket messages to the handler registered for their type.

    Default handlers are registered on construction; ``register`` can add
    or replace handlers afterwards. Errors raised by handlers are reported
    through ``error_handler.handle_error`` rather than propagated.
    """

    def __init__(self):
        self._handlers: Dict[str, MessageHandler] = {}
        self._register_default_handlers()

    def _register_default_handlers(self):
        """Register the built-in message handlers."""
        self.register("set_name", SetNameHandler())
        self.register("join", JoinHandler())
        self.register("part", PartHandler())
        self.register("list_users", ListUsersHandler())
        self.register("get_chat_messages", GetChatMessagesHandler())
        self.register("send_chat_message", SendChatMessageHandler())
        self.register("clear_chat_messages", ClearChatMessagesHandler())
        # WebRTC signaling handlers
        self.register("relayICECandidate", RelayICECandidateHandler())
        self.register("relaySessionDescription", RelaySessionDescriptionHandler())
        # Bot monitoring handlers
        self.register("status_check", StatusCheckHandler())
        self.register("peer_state_update", PeerStateUpdateHandler())

    def register(self, message_type: str, handler: MessageHandler):
        """Register *handler* for *message_type*, replacing any previous one."""
        self._handlers[message_type] = handler
        logger.debug(f"Registered handler for message type: {message_type}")

    async def route(
        self,
        message_type: str,
        session: "Session",
        lobby: "Lobby",
        data: Dict[str, Any],
        websocket: WebSocket,
        managers: Dict[str, Any]
    ):
        """Dispatch a message to its handler and report any failure.

        Unknown message types are reported as a ``ValidationError``.
        ``WebSocketError`` and ``ValidationError`` raised by a handler are
        reported as-is. Any other exception is wrapped in a high-severity
        ``WebSocketError`` together with its formatted traceback.
        """
        # Context uses "unknown" for a missing session; the handle_error
        # session_id kwarg uses None.
        context_session_id = session.id if session else "unknown"
        session_id = session.id if session else None

        if message_type not in self._handlers:
            await error_handler.handle_error(
                ValidationError(f"Unknown message type: {message_type}"),
                context={
                    "message_type": message_type,
                    "session_id": context_session_id,
                    "data_keys": list(data.keys()) if data else []
                },
                websocket=websocket,
                session_id=session_id
            )
            return

        handler = self._handlers[message_type]
        handler_name = type(handler).__name__
        try:
            await handler.handle(session, lobby, data, websocket, managers)
        except WebSocketError as e:
            # WebSocket specific errors - attempt recovery
            await error_handler.handle_error(
                e,
                context={
                    "message_type": message_type,
                    "session_id": context_session_id,
                    "handler": handler_name
                },
                websocket=websocket,
                session_id=session_id,
                recovery_action=lambda: self._websocket_recovery(websocket, session)
            )
        except ValidationError as e:
            # Validation errors - usually client-side issues
            await error_handler.handle_error(
                e,
                context={
                    "message_type": message_type,
                    "session_id": context_session_id,
                    "data": str(data)[:500]  # Truncate large data
                },
                websocket=websocket,
                session_id=session_id
            )
        except Exception as e:
            # Unexpected errors - enhanced logging and fallback
            await error_handler.handle_error(
                WebSocketError(
                    f"Unexpected error in {message_type} handler: {e}",
                    severity=ErrorSeverity.HIGH
                ),
                context={
                    "message_type": message_type,
                    "session_id": context_session_id,
                    "handler": handler_name,
                    "exception_type": type(e).__name__,
                    # Previously this held only str(e). Record the actual
                    # stack so unexpected handler failures can be diagnosed.
                    "traceback": traceback.format_exc(),
                },
                websocket=websocket,
                session_id=session_id,
                recovery_action=lambda: self._generic_recovery(message_type, session, lobby)
            )

    async def _websocket_recovery(self, websocket: WebSocket, session: "Session"):
        """Send a ``connection_status`` "recovered" notice to the client."""
        if websocket and session:
            await websocket.send_json({
                "type": "connection_status",
                "data": {
                    "status": "recovered",
                    "session_id": session.id,
                    "message": "Connection recovered from error"
                }
            })

    async def _generic_recovery(self, message_type: str, session: "Session", lobby: "Lobby"):
        """Best-effort recovery after an unexpected handler error."""
        logger.info(f"Attempting recovery for {message_type} error")
        if message_type in ["join", "part"]:
            # Lobby membership may be inconsistent; re-publish lobby state.
            if session and lobby:
                await lobby.update_state()
        elif message_type == "set_name":
            # For name operations, validate session state
            if session:
                logger.info(f"Validating session state for {session.id}")

    def get_supported_types(self) -> list[str]:
        """Return the message types that currently have a handler."""
        return list(self._handlers.keys())