"""
WebSocket connection management.

This module handles WebSocket connections and integrates with the message router.
"""

import asyncio
from typing import Dict, Any, Optional, TYPE_CHECKING

from fastapi import WebSocket, WebSocketDisconnect

from shared.logger import logger

from .message_handlers import MessageRouter

if TYPE_CHECKING:
    # Use absolute imports to avoid relative import issues
    try:
        from core.session_manager import Session, SessionManager
        from core.lobby_manager import Lobby, LobbyManager
        from core.auth_manager import AuthManager
    except ImportError:
        # Fallback for when running from different directory structure
        from ..core.session_manager import Session, SessionManager
        from ..core.lobby_manager import Lobby, LobbyManager
        from ..core.auth_manager import AuthManager


class WebSocketConnectionManager:
    """Manages WebSocket connections and message processing"""

    # After ``leave_lobby()`` on a stale session, poll up to this many times
    # (sleeping ``_STALE_REMOVAL_INTERVAL`` seconds between polls) for the
    # session to disappear from ``lobby.sessions`` before removing it by force.
    _STALE_REMOVAL_ATTEMPTS = 10
    _STALE_REMOVAL_INTERVAL = 0.01

    def __init__(
        self,
        session_manager: "SessionManager",
        lobby_manager: "LobbyManager",
        auth_manager: "AuthManager",
    ):
        """Store the managers and create the message router.

        Args:
            session_manager: Used to look up sessions by ID.
            lobby_manager: Used to look up lobbies by ID.
            auth_manager: Stored and forwarded to message handlers.
        """
        self.session_manager = session_manager
        self.lobby_manager = lobby_manager
        self.auth_manager = auth_manager
        self.message_router = MessageRouter()

        # Managers dict for injection into handlers
        self.managers = {
            "session_manager": session_manager,
            "lobby_manager": lobby_manager,
            "auth_manager": auth_manager,
        }

        # session.id -> lobby of the connection that currently owns session.ws.
        # Lets the cleanup of a superseded (older) connection tell whether the
        # newer connection is in the same lobby.
        self._connection_lobbies: Dict[str, "Lobby"] = {}

    async def handle_connection(
        self, websocket: WebSocket, lobby_id: str, session_id: str
    ):
        """Handle a WebSocket connection for a session in a lobby.

        Accepts the socket, validates the lobby and session, binds the socket
        to the session, performs the join/notify step, then routes incoming
        messages until the socket disconnects. On any validation failure an
        ``error`` frame is sent and the socket is closed.

        Args:
            websocket: The incoming (not yet accepted) WebSocket.
            lobby_id: ID of the lobby to connect to.
            session_id: ID of the connecting session.
        """
        await websocket.accept()

        # Validate inputs
        if not lobby_id:
            await self._reject(websocket, "Invalid or missing lobby")
            return
        if not session_id:
            await self._reject(websocket, "Invalid or missing session")
            return

        session = self.session_manager.get_session(session_id)
        if not session:
            await self._reject(websocket, f"Invalid session ID {session_id}")
            return

        lobby = self.lobby_manager.get_lobby(lobby_id)
        if not lobby:
            await self._reject(websocket, f"Lobby not found: {lobby_id}")
            return

        logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})")

        # Set up connection; from here on cleanup must always run.
        session.ws = websocket
        self._connection_lobbies[session.id] = lobby
        session.update_last_used()

        try:
            if session.id in lobby.sessions:
                # A stale entry means a previous connection never cleaned up.
                # Do a full leave + re-join so WebRTC peer setup runs again
                # (this ensures bots and other peers receive addPeer signaling).
                await self._purge_stale_session(session, lobby)
                try:
                    await session.join_lobby(lobby)
                except Exception as e:
                    logger.warning(f"Error re-joining stale session to lobby: {e}")
            else:
                # Notify existing peers about a fresh new connection (no automatic join yet)
                await self._notify_peers_of_join(session, lobby)

            await self._message_loop(websocket, session, lobby)
        except WebSocketDisconnect:
            logger.info(f"{session.getName()} <- WebSocket disconnected")
        except Exception as e:
            logger.error(f"Error in WebSocket connection for {session.getName()}: {e}")
        finally:
            # Clean up connection (only if this socket still owns the session)
            await self._cleanup_connection(session, lobby, websocket)

    async def _send_error(self, websocket: WebSocket, message: str) -> None:
        """Send an ``error`` frame carrying *message* to *websocket*."""
        await websocket.send_json({"type": "error", "data": {"error": message}})

    async def _reject(self, websocket: WebSocket, message: str) -> None:
        """Send an ``error`` frame and close *websocket*."""
        await self._send_error(websocket, message)
        await websocket.close()

    async def _purge_stale_session(self, session: "Session", lobby: "Lobby") -> None:
        """Remove a stale entry for *session* from *lobby*, forcibly if needed.

        Calls ``session.leave_lobby`` so peer connections are torn down, then
        polls briefly for the removal to land; if the entry is still present
        after the bounded wait, it is popped directly under ``lobby.lock``.
        Errors are logged and swallowed (best-effort cleanup).
        """
        logger.info(f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining.")
        try:
            # Leave the lobby to clean up peer connections
            await session.leave_lobby(lobby)

            for _ in range(self._STALE_REMOVAL_ATTEMPTS):
                if session.id not in lobby.sessions:
                    return
                logger.debug(
                    f"Waiting for session {session.getName()} to be fully removed from lobby"
                )
                # Brief yield to allow cleanup to complete
                await asyncio.sleep(self._STALE_REMOVAL_INTERVAL)

            # Bounded wait exhausted: force removal rather than loop forever.
            if session.id in lobby.sessions:
                logger.warning(
                    f"Force removing stale session {session.getName()} from lobby"
                )
                with lobby.lock:
                    lobby.sessions.pop(session.id, None)
        except Exception as e:
            logger.warning(f"Error cleaning up stale session: {e}")

    async def _message_loop(
        self, websocket: WebSocket, session: "Session", lobby: "Lobby"
    ) -> None:
        """Receive JSON packets and route them until the socket raises.

        Malformed packets (invalid JSON, a non-object packet, a missing or
        non-string ``type``, or non-object ``data``) are answered with an
        ``error`` frame and skipped instead of terminating the connection.
        ``WebSocketDisconnect`` and other errors propagate to the caller.
        """
        while True:
            try:
                packet = await websocket.receive_json()
            except ValueError as e:
                # json.JSONDecodeError subclasses ValueError; one bad frame
                # should not tear down an otherwise healthy connection.
                logger.error(f"{session.getName()} - Invalid JSON in request: {e}")
                await self._send_error(websocket, "Invalid request")
                continue

            session.update_last_used()

            # A valid JSON value need not be an object (e.g. a list or a string).
            if not isinstance(packet, dict):
                logger.error(f"{session.getName()} - Invalid request: {packet}")
                await self._send_error(websocket, "Invalid request")
                continue

            message_type = packet.get("type", None)
            data: Optional[Dict[str, Any]] = packet.get("data", None)

            if (
                not message_type
                or not isinstance(message_type, str)
                or (data is not None and not isinstance(data, dict))
            ):
                logger.error(f"{session.getName()} - Invalid request: {packet}")
                await self._send_error(websocket, "Invalid request")
                continue

            # Route message to appropriate handler
            await self.message_router.route(
                message_type, session, lobby, data or {}, websocket, self.managers
            )

    async def _notify_peers_of_join(self, session: "Session", lobby: "Lobby"):
        """Send ``user_joined`` for *session* to every peer in *lobby*.

        Peers with no WebSocket, or whose send fails, are removed from
        ``lobby.sessions`` afterwards.
        """
        failed_peers = []

        # Snapshot under the lock so sends happen without holding it.
        with lobby.lock:
            peer_sessions = list(lobby.sessions.values())

        for peer_session in peer_sessions:
            if not peer_session.ws:
                logger.warning(
                    f"{session.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Marking for removal."
                )
                failed_peers.append(peer_session.id)
                continue

            logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})")
            try:
                await peer_session.ws.send_json({
                    "type": "user_joined",
                    "data": {
                        "session_id": session.id,
                        "name": session.name,
                    },
                })
            except Exception as e:
                logger.warning(f"Failed to notify {peer_session.getName()} of user join: {e}")
                failed_peers.append(peer_session.id)

        # Clean up failed peers
        with lobby.lock:
            for failed_peer_id in failed_peers:
                lobby.sessions.pop(failed_peer_id, None)

    async def _cleanup_connection(
        self,
        session: "Session",
        lobby: "Lobby",
        websocket: Optional[WebSocket] = None,
    ):
        """Clean up when connection is closed.

        Args:
            session: The session whose connection ended.
            lobby: The lobby the connection was attached to.
            websocket: The socket that ended. When given and the session is
                already bound to a *different* socket (the client
                reconnected), that newer socket is left untouched and, if the
                newer connection is in this same lobby, the session is not
                removed from it. When omitted, cleanup is unconditional.
        """
        try:
            superseded = (
                websocket is not None
                and session.ws is not None
                and session.ws is not websocket
            )
            if superseded:
                if self._connection_lobbies.get(session.id) is lobby:
                    # The newer connection re-joined this lobby; evicting the
                    # session here would break it.
                    logger.debug(
                        f"{session.getName()} - Skipping cleanup of superseded connection"
                    )
                    return
            else:
                # Clear WebSocket reference
                session.ws = None
                if self._connection_lobbies.get(session.id) is lobby:
                    self._connection_lobbies.pop(session.id, None)

            # Remove from lobby if present
            if session.id in lobby.sessions:
                await session.leave_lobby(lobby)
                logger.info(f"Removed {session.getName()} from lobby {lobby.getName()} on disconnect")
        except Exception as e:
            logger.error(f"Error during connection cleanup for {session.getName()}: {e}")

    def add_message_handler(self, message_type: str, handler):
        """Add a custom message handler"""
        self.message_router.register(message_type, handler)

    def get_supported_message_types(self) -> list[str]:
        """Get list of supported message types"""
        return self.message_router.get_supported_types()