"""
Lobby management for the AI Voice Bot server.

This module handles lobby lifecycle, participants, and chat functionality.
Extracted from main.py to improve maintainability and separation of concerns.
"""

from __future__ import annotations

import os
import secrets
import sys
import threading
import time
from typing import Dict, List, Optional, TYPE_CHECKING

# Import shared models
try:
    # Try relative import first (when running as part of the package)
    from ...shared.models import ChatMessageModel, ParticipantModel
except ImportError:
    try:
        # Try absolute import (when running directly): put the project root on sys.path
        sys.path.append(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        )
        from shared.models import ChatMessageModel, ParticipantModel
    except ImportError as e:
        # Chain the original error so the real cause stays visible in the traceback.
        raise ImportError(
            f"Failed to import shared models: {e}. Ensure shared/models.py is accessible and PYTHONPATH is correctly set."
        ) from e

from shared.logger import logger

# Use try/except for importing events to handle both relative and absolute imports
try:
    from ..models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
except ImportError:
    try:
        from models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
    except ImportError:
        # Fallback no-op event system for standalone testing. It mirrors the
        # interface this module uses: subscribe(), async publish(), and event
        # classes constructed with keyword arguments.
        class _DummyEvent:
            def __init__(self, **kwargs):
                self.__dict__.update(kwargs)

        class DummyEventBus:
            def subscribe(self, event_type, handler) -> None:
                pass

            async def publish(self, event) -> None:
                pass

        event_bus = DummyEventBus()

        class ChatMessageSent(_DummyEvent):
            pass

        class SessionDisconnected(_DummyEvent):
            pass

        class SessionLeftLobby(_DummyEvent):
            pass


if TYPE_CHECKING:
    from .session_manager import Session


class LobbyConfig:
    """Configuration for lobby management"""

    # Maximum chat history retained per lobby; oldest messages are dropped first.
    MAX_CHAT_MESSAGES_PER_LOBBY = int(os.getenv("MAX_CHAT_MESSAGES_PER_LOBBY", "100"))


class Lobby:
    """Individual lobby representing a chat/voice room"""

    def __init__(self, name: str, id: Optional[str] = None, private: bool = False):
        self.id = secrets.token_hex(16) if id is None else id
        self.short = self.id[:8]
        self.name = name
        self.sessions: Dict[str, Session] = {}  # All lobby members, keyed by session id
        self.private = private
        self.chat_messages: List[ChatMessageModel] = []  # Oldest first
        self.lock = threading.RLock()  # Thread safety for lobby operations

    def getName(self) -> str:
        return f"{self.short}:{self.name}"

    # ------------------------------------------------------------------ helpers

    def _snapshot_sessions(self) -> List[Session]:
        """Return a copy of the current members.

        Broadcasts await between sends. Iterating the live dict would raise
        "dictionary changed size during iteration" if a session joined or left
        meanwhile, so they iterate this snapshot instead.
        """
        with self.lock:
            return list(self.sessions.values())

    @staticmethod
    def _drop_failed_connections(failed: List[tuple]) -> None:
        """Clear ``ws`` on sessions whose send failed.

        Each entry is a ``(session, ws)`` pair. The websocket is cleared only if
        it is still the one that failed, so a session that reconnected during
        the broadcast keeps its new connection.
        """
        for session, ws in failed:
            if session.ws is ws:
                session.ws = None

    async def _broadcast(
        self, message: dict, failure_label: str, log_label: Optional[str] = None
    ) -> None:
        """Send ``message`` to every connected member.

        Args:
            message: JSON-serialisable payload passed to ``ws.send_json``.
            failure_label: Text used in the warning logged when a send fails.
            log_label: If given, an info line ``<lobby> -> <log_label>(<peer>)``
                is logged before each send.
        """
        failed: List[tuple] = []
        for peer in self._snapshot_sessions():
            ws = peer.ws
            if not ws:
                continue
            try:
                if log_label is not None:
                    logger.info(f"{self.getName()} -> {log_label}({peer.getName()})")
                await ws.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send {failure_label} to {peer.getName()}: {e}")
                failed.append((peer, ws))
        self._drop_failed_connections(failed)

    def _participant_entry(self, session: Session) -> dict:
        """Build one ``lobby_state`` participant dict, including media flags."""
        participant = ParticipantModel(
            name=session.name,
            live=bool(session.ws),
            session_id=session.id,
            protected=bool(session.name and self._is_name_protected(session.name)),
            has_media=session.has_media,
            bot_run_id=session.bot_run_id,
            bot_provider_id=session.bot_provider_id,
            bot_instance_id=session.bot_instance_id,
        )
        return {
            **participant.model_dump(),
            # Taken from the same session as the participant fields above, so the
            # flags cannot drift onto the wrong participant.
            "muted": getattr(session, "muted", False),
            "video_on": getattr(session, "video_on", True),
        }

    def _trim_chat_history(self) -> None:
        """Drop the oldest messages beyond the per-lobby limit (caller holds the lock)."""
        overflow = len(self.chat_messages) - LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY
        if overflow > 0:
            del self.chat_messages[:overflow]

    # --------------------------------------------------------------- broadcasts

    async def broadcast_json(self, message: dict) -> None:
        """Broadcast an arbitrary JSON message to all connected sessions in the lobby"""
        await self._broadcast(message, "broadcast_json message")

    async def broadcast_peer_state_update(self, update: dict) -> None:
        """Broadcast a peer state update to all connected sessions in the lobby"""
        await self._broadcast(update, "peer state update")

    async def update_state(self, requesting_session: Optional[Session] = None):
        """Send the ``lobby_state`` participant list.

        Only members with a name are listed as participants. If
        ``requesting_session`` is given, only that session is sent the state.
        Otherwise every member with a websocket is sent it. Members whose send
        fails have their ``ws`` cleared.
        """
        # Build the payload and recipient list under the lock. Sends happen
        # outside it, because a threading lock must not be held across awaits.
        with self.lock:
            payload = {
                "type": "lobby_state",
                "data": {
                    "participants": [
                        self._participant_entry(s)
                        for s in self.sessions.values()
                        if s.name
                    ]
                },
            }
            recipients = list(self.sessions.values())

        if requesting_session:
            logger.info(
                f"{requesting_session.getName()} -> lobby_state({self.getName()})"
            )
            if requesting_session.ws:
                try:
                    await requesting_session.ws.send_json(payload)
                except Exception as e:
                    logger.warning(
                        f"Failed to send lobby state to {requesting_session.getName()}: {e}"
                    )
            else:
                logger.warning(
                    f"{requesting_session.getName()} - No WebSocket connection."
                )
            return

        # Send to all sessions in lobby
        failed: List[tuple] = []
        for s in recipients:
            logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
            ws = s.ws
            if not ws:
                continue
            try:
                await ws.send_json(payload)
            except Exception as e:
                logger.warning(f"Failed to send lobby state to {s.getName()}: {e}")
                failed.append((s, ws))
        self._drop_failed_connections(failed)

    def _is_name_protected(self, name: str) -> bool:
        """Check if a name is protected (has password) - to be injected by AuthManager"""
        # TODO: This will be handled by dependency injection from AuthManager
        return False

    # ----------------------------------------------------------------- members

    def getSession(self, id: str) -> Optional[Session]:
        with self.lock:
            return self.sessions.get(id, None)

    async def addSession(self, session: Session) -> None:
        """Add ``session`` to the lobby and broadcast the new state; no-op if present."""
        with self.lock:
            if session.id in self.sessions:
                logger.warning(
                    f"{session.getName()} - Already in lobby {self.getName()}."
                )
                return
            self.sessions[session.id] = session
        await self.update_state()

    async def removeSession(self, session: Session) -> None:
        """Remove ``session`` from the lobby and broadcast the new state; no-op if absent."""
        with self.lock:
            if session.id not in self.sessions:
                logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.")
                return
            del self.sessions[session.id]
        await self.update_state()

    # -------------------------------------------------------------------- chat

    def add_chat_message(self, session: Session, message: str) -> ChatMessageModel:
        """Add a chat message to the lobby and return the message data"""
        with self.lock:
            chat_message = ChatMessageModel(
                id=secrets.token_hex(8),
                message=message,
                sender_name=session.name or session.short,
                sender_session_id=session.id,
                timestamp=time.time(),
                lobby_id=self.id,
            )
            self.chat_messages.append(chat_message)
            self._trim_chat_history()
            return chat_message

    def update_chat_message(self, chat_message: ChatMessageModel) -> ChatMessageModel:
        """Update an existing chat message in the lobby and return the updated message.

        The sender and lobby fields of the existing message are kept and the
        timestamp is refreshed. If no message has the same id, ``chat_message``
        is appended as a new message.
        """
        with self.lock:
            for i, existing_msg in enumerate(self.chat_messages):
                if existing_msg.id == chat_message.id:
                    updated_msg = ChatMessageModel(
                        id=chat_message.id,
                        message=chat_message.message,
                        sender_name=existing_msg.sender_name,  # Keep original sender
                        sender_session_id=existing_msg.sender_session_id,  # Keep original session
                        timestamp=time.time(),  # Update timestamp
                        lobby_id=existing_msg.lobby_id,  # Keep original lobby
                    )
                    self.chat_messages[i] = updated_msg
                    return updated_msg

            # If message not found, add it as new
            self.chat_messages.append(chat_message)
            self._trim_chat_history()
            return chat_message

    def get_chat_messages(self, limit: int = 50) -> List[ChatMessageModel]:
        """Get up to ``limit`` of the most recent chat messages, oldest first.

        A non-positive ``limit`` returns an empty list. Previously ``limit=0``
        sliced with ``[-0:]`` and returned the whole history.
        """
        with self.lock:
            if limit <= 0:
                return []
            return self.chat_messages[-limit:]

    def clear_chat_messages(self) -> None:
        """Clear all chat messages from the lobby"""
        with self.lock:
            self.chat_messages.clear()

    async def broadcast_chat_clear(self) -> None:
        """Broadcast a chat clear event to all connected sessions in the lobby"""
        await self._broadcast(
            {"type": "chat_cleared", "data": {}},
            "chat clear message",
            log_label="chat_cleared",
        )

    async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
        """Broadcast a chat message to all connected sessions, then publish ChatMessageSent"""
        await self._broadcast(
            {"type": "chat_message", "data": chat_message.model_dump()},
            "chat message",
            log_label="chat_message",
        )

        # Publish chat event
        await event_bus.publish(
            ChatMessageSent(
                session_id=chat_message.sender_session_id,
                lobby_id=chat_message.lobby_id,
                message=chat_message.message,
                sender_name=chat_message.sender_name,
            )
        )

    def get_participant_count(self) -> int:
        """Get number of participants in lobby"""
        with self.lock:
            return len(self.sessions)

    def is_empty(self) -> bool:
        """Check if lobby is empty"""
        with self.lock:
            return not self.sessions


class LobbyManager:
    """Manages all lobbies and their lifecycle"""

    def __init__(self):
        self.lobbies: Dict[str, Lobby] = {}  # keyed by lobby id
        self.lock = threading.RLock()
        # Checker injected via set_name_protection_checker; also applied to
        # lobbies created later.
        self._name_protection_checker = None

        # Subscribe to session events. This is best effort: the event system may
        # be unavailable (e.g. the standalone fallback), and that must not
        # prevent the manager from being constructed.
        try:
            event_bus.subscribe(SessionDisconnected, self)
            event_bus.subscribe(SessionLeftLobby, self)
        except (ImportError, AttributeError) as e:
            logger.warning(f"Event bus subscription unavailable, skipping: {e}")

    async def handle(self, event):
        """Handle events from the event bus"""
        if isinstance(event, SessionDisconnected):
            await self._handle_session_disconnected(event)
        elif isinstance(event, SessionLeftLobby):
            await self._handle_session_left_lobby(event)

    async def _handle_session_disconnected(self, event):
        """Remove a disconnected session from every lobby it belongs to.

        Each affected lobby broadcasts its new state. Public lobbies left empty
        are then removed.
        """
        session_id = event.session_id
        with self.lock:
            lobbies_to_check = list(self.lobbies.values())

        for lobby in lobbies_to_check:
            # Mutate under the lobby lock, but await outside it.
            with lobby.lock:
                removed = lobby.sessions.pop(session_id, None) is not None
            if not removed:
                continue

            logger.info(f"Removed disconnected session {session_id} from lobby {lobby.getName()}")
            await lobby.update_state()

            # Check if lobby is now empty and should be cleaned up
            if lobby.is_empty() and not lobby.private:
                await self._cleanup_empty_lobby(lobby)

    async def _handle_session_left_lobby(self, event):
        """Handle explicit session leave"""
        # This is already handled by the session's leave_lobby method
        # but we could add additional cleanup logic here if needed
        pass

    def create_or_get_lobby(self, name: str, private: bool = False) -> Lobby:
        """Return the lobby matching ``name`` and ``private``, creating it if needed"""
        with self.lock:
            for lobby in self.lobbies.values():
                if lobby.name == name and lobby.private == private:
                    return lobby

            lobby = Lobby(name=name, private=private)
            if self._name_protection_checker is not None:
                lobby._is_name_protected = self._name_protection_checker
            self.lobbies[lobby.id] = lobby
            logger.info(f"Created new lobby: {lobby.getName()}")
            return lobby

    def get_lobby(self, lobby_id: str) -> Optional[Lobby]:
        """Get lobby by ID"""
        with self.lock:
            return self.lobbies.get(lobby_id)

    def get_lobby_by_name(self, name: str) -> Optional[Lobby]:
        """Get the first lobby with this name (public or private), or None"""
        with self.lock:
            return next(
                (lobby for lobby in self.lobbies.values() if lobby.name == name), None
            )

    def list_lobbies(self, include_private: bool = False) -> List[Lobby]:
        """List all lobbies, optionally including private ones"""
        with self.lock:
            return [
                lobby
                for lobby in self.lobbies.values()
                if include_private or not lobby.private
            ]

    async def _cleanup_empty_lobby(self, lobby: Lobby):
        """Remove ``lobby`` if it is still registered and still empty"""
        with self.lock:
            if lobby.id in self.lobbies and lobby.is_empty():
                del self.lobbies[lobby.id]
                logger.info(f"Cleaned up empty lobby: {lobby.getName()}")

    def get_lobby_count(self) -> int:
        """Get total lobby count"""
        with self.lock:
            return len(self.lobbies)

    def get_total_participants(self) -> int:
        """Get total participants across all lobbies"""
        with self.lock:
            return sum(lobby.get_participant_count() for lobby in self.lobbies.values())

    async def cleanup_empty_lobbies(self) -> int:
        """Remove all empty non-private lobbies and return how many were removed"""
        with self.lock:
            lobbies_to_remove: List[Lobby] = [
                lobby
                for lobby in self.lobbies.values()
                if lobby.is_empty() and not lobby.private
            ]
            for lobby in lobbies_to_remove:
                del self.lobbies[lobby.id]
                logger.info(f"Cleaned up empty lobby: {lobby.getName()}")
        return len(lobbies_to_remove)

    def set_name_protection_checker(self, checker_func):
        """Inject the name protection checker from AuthManager.

        ``checker_func(name) -> bool`` replaces ``Lobby._is_name_protected`` on
        every existing lobby. It is also remembered, so lobbies created later
        get it too.
        """
        # This allows us to inject the name protection logic without tight coupling
        with self.lock:
            self._name_protection_checker = checker_func
            for lobby in self.lobbies.values():
                lobby._is_name_protected = checker_func