ai-voicebot/server/core/lobby_manager.py
2025-09-16 12:25:04 -07:00

455 lines
18 KiB
Python

"""
Lobby management for the AI Voice Bot server.
This module handles lobby lifecycle, participants, and chat functionality.
Extracted from main.py to improve maintainability and separation of concerns.
"""
from __future__ import annotations

import os
import secrets
import threading
import time
from typing import Dict, List, Optional, TYPE_CHECKING
# Import shared models: try the package-relative import first, then fall back to
# an absolute import with the repository root placed on sys.path.
try:
    # Relative import (when running as part of the package)
    from ...shared.models import ChatMessageModel, ParticipantModel
except ImportError:
    try:
        # Absolute import (when running directly): this file lives three
        # directories below the directory that contains `shared/`.
        import sys
        import os
        sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
        from shared.models import ChatMessageModel, ParticipantModel
    except ImportError as e:
        # Bind `e` so the message can include the underlying cause, and chain
        # it so the original traceback is preserved.
        raise ImportError(
            f"Failed to import shared models: {e}. Ensure shared/models.py is accessible and PYTHONPATH is correctly set."
        ) from e
from shared.logger import logger
# Use try/except for importing events to handle both relative and absolute imports.
# If neither is available, fall back to inert stand-ins so this module can still
# be imported and exercised on its own (e.g. for standalone testing).
try:
    from ..models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
except ImportError:
    try:
        from models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
    except ImportError:

        class DummyEventBus:
            """No-op event bus used when the real event system is unavailable."""

            def subscribe(self, event_type, handler) -> None:
                # LobbyManager subscribes at construction time; accept and ignore.
                pass

            async def publish(self, event) -> None:
                pass

        event_bus = DummyEventBus()

        class _DummyEvent:
            """Stand-in event that stores whatever keyword arguments it receives.

            Callers construct events with keyword arguments (e.g.
            ``ChatMessageSent(session_id=..., lobby_id=..., ...)``), so the
            stand-ins must accept them rather than raise ``TypeError``.
            """

            def __init__(self, **kwargs) -> None:
                self.__dict__.update(kwargs)

        class ChatMessageSent(_DummyEvent):
            pass

        class SessionDisconnected(_DummyEvent):
            pass

        class SessionLeftLobby(_DummyEvent):
            pass
if TYPE_CHECKING:
from .session_manager import Session
class LobbyConfig:
    """Tunable limits for lobby management, read from the environment when the module loads."""

    # Maximum number of chat messages retained per lobby. Override with the
    # MAX_CHAT_MESSAGES_PER_LOBBY environment variable (defaults to "100").
    MAX_CHAT_MESSAGES_PER_LOBBY = int(
        os.environ.get("MAX_CHAT_MESSAGES_PER_LOBBY", "100")
    )
class Lobby:
    """Individual lobby representing a chat/voice room.

    Holds the member sessions (keyed by session id) and a bounded chat
    history. ``self.lock`` (a ``threading.RLock``) guards ``sessions`` and
    ``chat_messages``. The async methods take a snapshot of the member list
    under the lock and perform WebSocket sends after releasing it. As a
    result, the lock is never held across an ``await``, and a join or leave
    during a broadcast cannot break the iteration.
    """

    def __init__(self, name: str, id: Optional[str] = None, private: bool = False):
        """Create a lobby.

        Args:
            name: Display name of the lobby.
            id: Explicit lobby id; a random 32-hex-character id is generated
                when omitted.
            private: Whether the lobby is private. LobbyManager excludes
                private lobbies from public listings and empty-lobby cleanup.
        """
        self.id = secrets.token_hex(16) if id is None else id
        self.short = self.id[:8]  # abbreviated id used in log output
        self.name = name
        self.sessions: Dict[str, Session] = {}  # All lobby members, keyed by session id
        self.private = private
        self.chat_messages: List[ChatMessageModel] = []  # Oldest first; capped by LobbyConfig
        self.lock = threading.RLock()  # Guards sessions and chat_messages

    def getName(self) -> str:
        """Return a log-friendly ``"<short id>:<name>"`` label."""
        return f"{self.short}:{self.name}"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _snapshot_sessions(self) -> List[Session]:
        """Return a copy of the member list taken under the lock.

        Iterating the copy stays valid across ``await`` points even if
        sessions join or leave in the meantime.
        """
        with self.lock:
            return list(self.sessions.values())

    @staticmethod
    def _drop_failed_sockets(failed: list[tuple[Session, object]]) -> None:
        """Mark sessions whose send failed as disconnected (``ws = None``).

        A session is cleared only if it still holds the socket that failed.
        If it reconnected while other sends were in flight, the new socket
        is left alone.
        """
        for session, ws in failed:
            if session.ws is ws:
                session.ws = None

    async def _broadcast(
        self,
        message: dict,
        failure_label: str,
        log_label: Optional[str] = None,
    ) -> None:
        """Send *message* to every member that currently has a WebSocket.

        Args:
            message: Payload passed to each peer's ``ws.send_json``.
            failure_label: Description used in the warning logged when a send fails.
            log_label: When given, ``"<lobby> -> <log_label>(<peer>)"`` is
                logged at info level before each send.
        """
        failed: list[tuple[Session, object]] = []
        for peer in self._snapshot_sessions():
            ws = peer.ws
            if not ws:
                continue
            try:
                if log_label is not None:
                    logger.info(f"{self.getName()} -> {log_label}({peer.getName()})")
                await ws.send_json(message)
            except Exception as e:  # best-effort: one bad peer must not stop the broadcast
                logger.warning(f"Failed to send {failure_label} to {peer.getName()}: {e}")
                failed.append((peer, ws))
        self._drop_failed_sockets(failed)

    def _participant_payload(self) -> List[dict]:
        """Build the ``participants`` list sent in ``lobby_state`` messages.

        Only sessions with a non-empty name are listed. Each entry is the
        dumped ``ParticipantModel`` plus that same session's ``muted`` /
        ``video_on`` attributes (defaulting to ``False`` / ``True``). Building
        both from the same session keeps the media flags attached to the
        right participant even when unnamed sessions are skipped.
        """
        participants: List[dict] = []
        for s in self._snapshot_sessions():
            if not s.name:
                continue
            user = ParticipantModel(
                name=s.name,
                live=bool(s.ws),
                session_id=s.id,
                protected=bool(self._is_name_protected(s.name)),
                has_media=s.has_media,
                bot_run_id=s.bot_run_id,
                bot_provider_id=s.bot_provider_id,
                bot_instance_id=s.bot_instance_id,
            )
            participants.append(
                {
                    **user.model_dump(),
                    "muted": getattr(s, "muted", False),
                    "video_on": getattr(s, "video_on", True),
                }
            )
        return participants

    def _trim_chat_history(self) -> None:
        """Discard the oldest messages beyond ``LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY``.

        Caller must hold ``self.lock``.
        """
        limit = LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY
        if len(self.chat_messages) > limit:
            del self.chat_messages[:-limit]

    # ------------------------------------------------------------------
    # State / membership
    # ------------------------------------------------------------------

    async def update_state(self, requesting_session: Optional[Session] = None):
        """Send the current ``lobby_state`` to one session or to every member.

        Args:
            requesting_session: When given, only this session receives the
                state; a failed send is logged but its ``ws`` is kept. When
                omitted, every member with a WebSocket receives the state, and
                members whose send fails have ``ws`` cleared.
        """
        message = {
            "type": "lobby_state",
            "data": {"participants": self._participant_payload()},
        }

        if requesting_session:
            logger.info(
                f"{requesting_session.getName()} -> lobby_state({self.getName()})"
            )
            if not requesting_session.ws:
                logger.warning(
                    f"{requesting_session.getName()} - No WebSocket connection."
                )
                return
            try:
                await requesting_session.ws.send_json(message)
            except Exception as e:
                logger.warning(
                    f"Failed to send lobby state to {requesting_session.getName()}: {e}"
                )
            return

        # Send to all sessions in lobby
        failed: list[tuple[Session, object]] = []
        for s in self._snapshot_sessions():
            logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
            ws = s.ws
            if not ws:
                continue
            try:
                await ws.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send lobby state to {s.getName()}: {e}")
                failed.append((s, ws))
        # Clean up failed sessions
        self._drop_failed_sockets(failed)

    def _is_name_protected(self, name: str) -> bool:
        """Check if a name is protected (has password) - to be injected by AuthManager.

        This default always returns ``False``; ``LobbyManager.set_name_protection_checker``
        replaces it per instance with a plain callable taking ``name``.
        """
        # TODO: This will be handled by dependency injection from AuthManager
        return False

    def getSession(self, id: str) -> Optional[Session]:
        """Return the member session with the given id, or ``None``."""
        with self.lock:
            return self.sessions.get(id, None)

    async def addSession(self, session: Session) -> None:
        """Add *session* to the lobby and broadcast the new state.

        A session that is already a member is logged and ignored.
        """
        with self.lock:
            if session.id in self.sessions:
                logger.warning(
                    f"{session.getName()} - Already in lobby {self.getName()}."
                )
                return
            self.sessions[session.id] = session
        # Broadcast after releasing the lock so it is not held across awaits.
        await self.update_state()

    async def removeSession(self, session: Session) -> None:
        """Remove *session* from the lobby and broadcast the new state.

        A session that is not a member is logged and ignored.
        """
        with self.lock:
            if session.id not in self.sessions:
                logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.")
                return
            del self.sessions[session.id]
        await self.update_state()

    # ------------------------------------------------------------------
    # Chat
    # ------------------------------------------------------------------

    def add_chat_message(self, session: Session, message: str) -> ChatMessageModel:
        """Record a new chat message from *session* and return it.

        The sender name falls back to the session's short id when it has no
        name. History is trimmed to the configured maximum.
        """
        with self.lock:
            chat_message = ChatMessageModel(
                id=secrets.token_hex(8),
                message=message,
                sender_name=session.name or session.short,
                sender_session_id=session.id,
                timestamp=time.time(),
                lobby_id=self.id,
            )
            self.chat_messages.append(chat_message)
            self._trim_chat_history()
        return chat_message

    def update_chat_message(self, chat_message: ChatMessageModel) -> ChatMessageModel:
        """Replace the stored message with the same id and return the stored version.

        When a message with ``chat_message.id`` exists, its text is replaced
        and its timestamp refreshed. The original sender, session and lobby
        are kept. When no such message exists, *chat_message* is appended
        as-is (and history trimmed).
        """
        with self.lock:
            for i, existing_msg in enumerate(self.chat_messages):
                if existing_msg.id == chat_message.id:
                    updated_msg = ChatMessageModel(
                        id=chat_message.id,
                        message=chat_message.message,
                        sender_name=existing_msg.sender_name,  # Keep original sender
                        sender_session_id=existing_msg.sender_session_id,  # Keep original session
                        timestamp=time.time(),  # Update timestamp
                        lobby_id=existing_msg.lobby_id,  # Keep original lobby
                    )
                    self.chat_messages[i] = updated_msg
                    return updated_msg
            # If message not found, add it as new
            self.chat_messages.append(chat_message)
            self._trim_chat_history()
        return chat_message

    def get_chat_messages(self, limit: int = 50) -> List[ChatMessageModel]:
        """Return up to *limit* of the most recent chat messages, oldest first.

        A non-positive *limit* returns an empty list. Plain ``[-0:]`` slicing
        would otherwise return the whole history for ``limit=0``.
        """
        if limit <= 0:
            return []
        with self.lock:
            return self.chat_messages[-limit:]

    def clear_chat_messages(self) -> None:
        """Clear all chat messages from the lobby."""
        with self.lock:
            self.chat_messages.clear()

    # ------------------------------------------------------------------
    # Broadcasts
    # ------------------------------------------------------------------

    async def broadcast_json(self, message: dict) -> None:
        """Broadcast an arbitrary JSON message to all connected sessions in the lobby."""
        await self._broadcast(message, failure_label="broadcast_json message")

    async def broadcast_peer_state_update(self, update: dict) -> None:
        """Broadcast a peer state update to all connected sessions in the lobby."""
        await self._broadcast(update, failure_label="peer state update")

    async def broadcast_chat_clear(self) -> None:
        """Broadcast a chat clear event to all connected sessions in the lobby."""
        await self._broadcast(
            {"type": "chat_cleared", "data": {}},
            failure_label="chat clear message",
            log_label="chat_cleared",
        )

    async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
        """Broadcast a chat message to all connected sessions, then publish ``ChatMessageSent``."""
        await self._broadcast(
            {"type": "chat_message", "data": chat_message.model_dump()},
            failure_label="chat message",
            log_label="chat_message",
        )
        # Publish chat event
        await event_bus.publish(ChatMessageSent(
            session_id=chat_message.sender_session_id,
            lobby_id=chat_message.lobby_id,
            message=chat_message.message,
            sender_name=chat_message.sender_name
        ))

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_participant_count(self) -> int:
        """Get number of participants (member sessions) in the lobby."""
        with self.lock:
            return len(self.sessions)

    def is_empty(self) -> bool:
        """Check if the lobby has no member sessions."""
        with self.lock:
            return len(self.sessions) == 0
class LobbyManager:
    """Manages all lobbies and their lifecycle.

    ``self.lock`` guards ``self.lobbies``. Where both locks are needed, the
    manager lock is taken before a lobby's lock (never the reverse), and
    neither lock is held across an ``await``.
    """

    def __init__(self):
        self.lobbies: Dict[str, Lobby] = {}  # keyed by lobby id
        self.lock = threading.RLock()
        # Name-protection checker injected via set_name_protection_checker();
        # remembered so that lobbies created afterwards receive it too.
        self._name_protection_checker = None
        # Subscribe to session events. An event bus without subscribe() (e.g.
        # a stand-in used for standalone testing) means this manager simply
        # won't react to session events, rather than failing to construct.
        try:
            event_bus.subscribe(SessionDisconnected, self)
            event_bus.subscribe(SessionLeftLobby, self)
        except (ImportError, AttributeError) as e:
            logger.warning(
                f"LobbyManager: event subscription unavailable ({e}); session events will be ignored"
            )

    async def handle(self, event):
        """Handle events from the event bus; unrecognised events are ignored."""
        if isinstance(event, SessionDisconnected):
            await self._handle_session_disconnected(event)
        elif isinstance(event, SessionLeftLobby):
            await self._handle_session_left_lobby(event)

    async def _handle_session_disconnected(self, event):
        """Remove a disconnected session from every lobby that contains it.

        Each affected lobby broadcasts a fresh ``lobby_state``; a non-private
        lobby left empty is then removed. The lobby lock is released before
        awaiting and before ``_cleanup_empty_lobby`` takes the manager lock,
        preserving the manager-then-lobby lock order.
        """
        session_id = event.session_id
        with self.lock:
            lobbies_to_check = list(self.lobbies.values())
        for lobby in lobbies_to_check:
            with lobby.lock:
                if session_id not in lobby.sessions:
                    continue
                del lobby.sessions[session_id]
            logger.info(f"Removed disconnected session {session_id} from lobby {lobby.getName()}")
            # Update lobby state
            await lobby.update_state()
            # Check if lobby is now empty and should be cleaned up
            if lobby.is_empty() and not lobby.private:
                await self._cleanup_empty_lobby(lobby)

    async def _handle_session_left_lobby(self, event):
        """Handle explicit session leave (currently a no-op)."""
        # This is already handled by the session's leave_lobby method
        # but we could add additional cleanup logic here if needed
        pass

    def create_or_get_lobby(self, name: str, private: bool = False) -> Lobby:
        """Return the lobby matching *name* and *private*, creating it if absent.

        A newly created lobby receives the injected name-protection checker,
        if one has been set.
        """
        with self.lock:
            for lobby in self.lobbies.values():
                if lobby.name == name and lobby.private == private:
                    return lobby
            lobby = Lobby(name=name, private=private)
            if self._name_protection_checker is not None:
                lobby._is_name_protected = self._name_protection_checker
            self.lobbies[lobby.id] = lobby
            logger.info(f"Created new lobby: {lobby.getName()}")
            return lobby

    def get_lobby(self, lobby_id: str) -> Optional[Lobby]:
        """Get lobby by ID, or ``None`` if unknown."""
        with self.lock:
            return self.lobbies.get(lobby_id)

    def get_lobby_by_name(self, name: str) -> Optional[Lobby]:
        """Get the first lobby with the given name (public or private), or ``None``."""
        with self.lock:
            for lobby in self.lobbies.values():
                if lobby.name == name:
                    return lobby
            return None

    def list_lobbies(self, include_private: bool = False) -> List[Lobby]:
        """List all lobbies, optionally including private ones."""
        with self.lock:
            if include_private:
                return list(self.lobbies.values())
            return [lobby for lobby in self.lobbies.values() if not lobby.private]

    async def _cleanup_empty_lobby(self, lobby: Lobby):
        """Remove *lobby* if it is still registered and still empty."""
        with self.lock:
            if lobby.id in self.lobbies and lobby.is_empty():
                del self.lobbies[lobby.id]
                logger.info(f"Cleaned up empty lobby: {lobby.getName()}")

    def get_lobby_count(self) -> int:
        """Get total lobby count."""
        with self.lock:
            return len(self.lobbies)

    def get_total_participants(self) -> int:
        """Get total participants across all lobbies."""
        with self.lock:
            return sum(lobby.get_participant_count() for lobby in self.lobbies.values())

    async def cleanup_empty_lobbies(self) -> int:
        """Remove all empty non-private lobbies and return how many were removed."""
        with self.lock:
            lobbies_to_remove: list[Lobby] = [
                lobby
                for lobby in self.lobbies.values()
                if lobby.is_empty() and not lobby.private
            ]
            for lobby in lobbies_to_remove:
                del self.lobbies[lobby.id]
                logger.info(f"Cleaned up empty lobby: {lobby.getName()}")
        return len(lobbies_to_remove)

    def set_name_protection_checker(self, checker_func):
        """Inject the name-protection checker from AuthManager.

        *checker_func* takes a name and returns truthy when it is protected.
        It is installed on every existing lobby and remembered so that
        lobbies created later by ``create_or_get_lobby`` receive it too.
        """
        # This allows us to inject the name protection logic without tight coupling
        with self.lock:
            self._name_protection_checker = checker_func
            for lobby in self.lobbies.values():
                lobby._is_name_protected = checker_func