ai-voicebot/server/websocket/connection.py
2025-09-08 16:03:53 -07:00

224 lines
8.5 KiB
Python

"""
WebSocket connection management.
This module handles WebSocket connections and integrates with the message router.
"""
import asyncio
from typing import Dict, Any, Optional, TYPE_CHECKING
from fastapi import WebSocket, WebSocketDisconnect
from shared.logger import logger
from .message_handlers import MessageRouter
if TYPE_CHECKING:
# Use absolute imports to avoid relative import issues
try:
from core.session_manager import Session, SessionManager
from core.lobby_manager import Lobby, LobbyManager
from core.auth_manager import AuthManager
except ImportError:
# Fallback for when running from different directory structure
from ..core.session_manager import Session, SessionManager
from ..core.lobby_manager import Lobby, LobbyManager
from ..core.auth_manager import AuthManager
class WebSocketConnectionManager:
    """Manage WebSocket connections and route their messages.

    One instance is shared by all connections. For each connection it
    validates the requested session and lobby, attaches the socket to the
    session, forwards every incoming packet to the ``MessageRouter`` and
    detaches the session again when the socket goes away.
    """

    def __init__(
        self,
        session_manager: "SessionManager",
        lobby_manager: "LobbyManager",
        auth_manager: "AuthManager",
    ):
        """Store the managers and create the message router.

        Args:
            session_manager: Used to look up sessions by ID.
            lobby_manager: Used to look up lobbies by ID.
            auth_manager: Not used directly here; handed to message handlers.
        """
        self.session_manager = session_manager
        self.lobby_manager = lobby_manager
        self.auth_manager = auth_manager
        self.message_router = MessageRouter()

        # Managers dict for injection into handlers
        self.managers = {
            "session_manager": session_manager,
            "lobby_manager": lobby_manager,
            "auth_manager": auth_manager,
        }

    async def handle_connection(
        self,
        websocket: WebSocket,
        lobby_id: str,
        session_id: str,
    ) -> None:
        """Handle a WebSocket connection for a session in a lobby.

        Accepts the socket, validates ``lobby_id`` / ``session_id`` (replying
        with an ``error`` packet and closing on failure), then processes
        incoming packets until the socket disconnects or a handler raises.

        Args:
            websocket: The not-yet-accepted WebSocket.
            lobby_id: ID of the lobby the client wants to talk to.
            session_id: ID of the client's session.
        """
        await websocket.accept()

        # Validate inputs
        if not lobby_id:
            await self._reject(websocket, "Invalid or missing lobby")
            return
        if not session_id:
            await self._reject(websocket, "Invalid or missing session")
            return

        session = self.session_manager.get_session(session_id)
        if not session:
            await self._reject(websocket, f"Invalid session ID {session_id}")
            return

        lobby = self.lobby_manager.get_lobby(lobby_id)
        if not lobby:
            await self._reject(websocket, f"Lobby not found: {lobby_id}")
            return

        logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})")

        # Set up connection
        session.ws = websocket
        session.update_last_used()

        if session.id in lobby.sessions:
            # A stale entry for this session is still in the lobby: perform a
            # full leave + join so peer setup runs again.
            await self._rejoin_stale_session(session, lobby)
        else:
            # Notify existing peers about a fresh new connection (no automatic join yet)
            await self._notify_peers_of_join(session, lobby)

        try:
            # Message processing loop
            while True:
                packet = await websocket.receive_json()
                session.update_last_used()

                # Valid JSON that is not an object (list, number, ...) has no
                # "type"; treat it as an invalid request instead of letting
                # the AttributeError tear the connection down.
                is_object = isinstance(packet, dict)
                message_type = packet.get("type", None) if is_object else None
                if not message_type:
                    logger.error(f"{session.getName()} - Invalid request: {packet}")
                    await websocket.send_json({
                        "type": "error",
                        "data": {"error": "Invalid request"}
                    })
                    continue

                data: Optional[Dict[str, Any]] = packet.get("data", None)

                # Route message to appropriate handler
                await self.message_router.route(
                    message_type, session, lobby, data or {}, websocket, self.managers
                )
        except WebSocketDisconnect:
            logger.info(f"{session.getName()} <- WebSocket disconnected")
        except Exception as e:
            logger.error(f"Error in WebSocket connection for {session.getName()}: {e}")
        finally:
            # Pass our own socket so a newer connection for the same session
            # is not torn down by this (older) handler.
            await self._cleanup_connection(session, lobby, websocket)

    async def _reject(self, websocket: WebSocket, message: str) -> None:
        """Send an ``error`` packet carrying ``message`` and close the socket."""
        await websocket.send_json({
            "type": "error",
            "data": {"error": message}
        })
        await websocket.close()

    async def _rejoin_stale_session(self, session: "Session", lobby: "Lobby") -> None:
        """Remove a stale lobby entry for ``session`` and join the lobby again.

        Both steps are best-effort: failures are logged as warnings and do not
        propagate to the caller.
        """
        logger.info(f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining.")
        try:
            # Leave the lobby to clean up peer connections
            await session.leave_lobby(lobby)

            # Verify the session has been properly removed before proceeding
            if session.id in lobby.sessions:
                logger.debug(
                    f"Waiting for session {session.getName()} to be fully removed from lobby"
                )
                # Brief yield to allow cleanup to complete
                await asyncio.sleep(0.01)
                # Still present after the grace period: drop the entry ourselves
                if session.id in lobby.sessions:
                    logger.warning(
                        f"Force removing stale session {session.getName()} from lobby"
                    )
                    with lobby.lock:
                        lobby.sessions.pop(session.id, None)
        except Exception as e:
            logger.warning(f"Error cleaning up stale session: {e}")

        # Re-add the session to the lobby, which is expected to trigger
        # WebRTC signaling (addPeer) with existing peers (including bots).
        try:
            await session.join_lobby(lobby)
        except Exception as e:
            logger.warning(f"Error re-joining stale session to lobby: {e}")

    async def _notify_peers_of_join(self, session: "Session", lobby: "Lobby") -> None:
        """Notify existing peers about a new user joining.

        Sends a ``user_joined`` packet to every session currently in the lobby.
        Peers without a WebSocket, or whose send raises, are removed from
        ``lobby.sessions`` afterwards.
        """
        failed_peers = []

        # Snapshot under the lock; the sends below happen without holding it.
        with lobby.lock:
            peer_sessions = list(lobby.sessions.values())

        for peer_session in peer_sessions:
            if not peer_session.ws:
                logger.warning(
                    f"{session.getName()} - Peer session {peer_session.id} in lobby {lobby.getName()} has no active WebSocket. Marking for removal."
                )
                failed_peers.append(peer_session.id)
                continue

            logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})")
            try:
                await peer_session.ws.send_json({
                    "type": "user_joined",
                    "data": {
                        "session_id": session.id,
                        "name": session.name,
                    },
                })
            except Exception as e:
                logger.warning(f"Failed to notify {peer_session.getName()} of user join: {e}")
                failed_peers.append(peer_session.id)

        # Clean up failed peers
        with lobby.lock:
            for failed_peer_id in failed_peers:
                lobby.sessions.pop(failed_peer_id, None)

    async def _cleanup_connection(
        self,
        session: "Session",
        lobby: "Lobby",
        websocket: Optional[WebSocket] = None,
    ) -> None:
        """Clean up when a connection is closed.

        Clears ``session.ws`` and removes the session from the lobby if it is
        still listed there. Errors are logged, never raised.

        Args:
            session: Session whose connection ended.
            lobby: Lobby the connection was attached to.
            websocket: The socket that ended. When given and the session is
                already bound to a different socket (the client reconnected),
                nothing is cleaned up so the newer connection stays intact.
                ``None`` keeps the unconditional cleanup.
        """
        try:
            current_ws = session.ws
            if websocket is not None and current_ws is not None and current_ws is not websocket:
                logger.debug(
                    f"{session.getName()} - Skipping cleanup; session is bound to a newer WebSocket"
                )
                return

            # Clear WebSocket reference
            session.ws = None

            # Remove from lobby if present
            if session.id in lobby.sessions:
                await session.leave_lobby(lobby)
                logger.info(f"Removed {session.getName()} from lobby {lobby.getName()} on disconnect")
        except Exception as e:
            logger.error(f"Error during connection cleanup for {session.getName()}: {e}")

    def add_message_handler(self, message_type: str, handler) -> None:
        """Add a custom message handler for ``message_type`` to the router."""
        self.message_router.register(message_type, handler)

    def get_supported_message_types(self) -> list[str]:
        """Get the list of message types the router reports as supported."""
        return self.message_router.get_supported_types()