365 lines
14 KiB
Python
365 lines
14 KiB
Python
"""
Lobby management for the AI Voice Bot server.

This module handles the lobby lifecycle, participants, and chat functionality.
It was extracted from main.py to improve maintainability and separation of concerns.
"""
|
|
|
|
from __future__ import annotations

import os
import secrets
import threading
import time
from typing import Dict, List, Optional, TYPE_CHECKING
|
|
|
|
# Import shared models.
try:
    # Try relative import first (when running as part of the package).
    from ...shared.models import ChatMessageModel, ParticipantModel
except ImportError:
    import sys

    # Fall back to an absolute import (when running directly): put the
    # directory two levels above this file's directory on sys.path, once.
    _project_root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    if _project_root not in sys.path:
        sys.path.append(_project_root)
    try:
        from shared.models import ChatMessageModel, ParticipantModel
    except ImportError as e:
        # Bind the exception so the message can include it; chain it so the
        # original traceback is preserved.
        raise ImportError(
            f"Failed to import shared models: {e}. Ensure shared/models.py is accessible and PYTHONPATH is correctly set."
        ) from e
|
|
|
|
from logger import logger
|
|
|
|
# Use try/except for importing events to handle both relative and absolute imports.
try:
    from ..models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
except ImportError:
    try:
        from models.events import event_bus, ChatMessageSent, SessionDisconnected, SessionLeftLobby
    except ImportError:
        # Minimal stand-ins so this module can be imported for standalone
        # testing. They cover the surface this module uses: ``publish`` and
        # ``subscribe`` on the bus, and keyword construction of event objects.

        class DummyEventBus:
            """No-op event bus used when the real event system is unavailable."""

            async def publish(self, event):
                """Discard *event*."""

            def subscribe(self, event_type, handler):
                """Ignore the subscription request."""

        event_bus = DummyEventBus()

        class _DummyEvent:
            """Event stand-in that stores any keyword fields as attributes."""

            def __init__(self, **fields):
                self.__dict__.update(fields)

        class ChatMessageSent(_DummyEvent):
            pass

        class SessionDisconnected(_DummyEvent):
            pass

        class SessionLeftLobby(_DummyEvent):
            pass
|
|
|
|
if TYPE_CHECKING:
|
|
from .session_manager import Session
|
|
|
|
|
|
def _read_positive_int_env(name: str, default: int) -> int:
    """Return the positive integer stored in environment variable *name*.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        The parsed integer, or *default*.

    Raises:
        ValueError: If the variable is set to a non-integer or to an integer
            smaller than 1.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return default
    value = int(raw)  # malformed values still fail loudly at import time
    if value < 1:
        # A non-positive limit cannot bound the chat history; reject it early
        # instead of silently keeping every message.
        raise ValueError(f"{name} must be a positive integer, got {raw!r}")
    return value


class LobbyConfig:
    """Configuration for lobby management, read from the environment at import time."""

    # Maximum number of chat messages retained per lobby (default 100).
    MAX_CHAT_MESSAGES_PER_LOBBY = _read_positive_int_env("MAX_CHAT_MESSAGES_PER_LOBBY", 100)
|
|
|
|
|
|
class Lobby:
    """Individual lobby representing a chat/voice room.

    ``sessions`` and ``chat_messages`` are guarded by ``lock``. Network sends
    work on a snapshot of the members taken under the lock, so the lock is
    never held across an ``await``. This also means a member joining or leaving
    mid-broadcast cannot mutate the dict being iterated.
    """

    def __init__(self, name: str, id: Optional[str] = None, private: bool = False):
        """Create a lobby.

        Args:
            name: Human-readable lobby name.
            id: Explicit lobby id; a random 32-character hex token is generated
                when omitted.
            private: Whether the lobby is private.
        """
        self.id = secrets.token_hex(16) if id is None else id
        self.short = self.id[:8]  # Abbreviated id used in log names
        self.name = name
        self.sessions: Dict[str, Session] = {}  # All lobby members, keyed by session id
        self.private = private
        self.chat_messages: List[ChatMessageModel] = []  # Chat history, oldest first
        self.lock = threading.RLock()  # Guards sessions and chat_messages

    def getName(self) -> str:
        """Return ``"<short id>:<name>"`` for log messages."""
        return f"{self.short}:{self.name}"

    def _snapshot_sessions(self) -> List[Session]:
        """Return a copy of the current members, taken under the lock."""
        with self.lock:
            return list(self.sessions.values())

    def _build_participants(self, sessions: List[Session]) -> List[ParticipantModel]:
        """Describe every named session in *sessions* as a ParticipantModel."""
        return [
            ParticipantModel(
                name=s.name,
                live=bool(s.ws),
                session_id=s.id,
                protected=bool(s.name and self._is_name_protected(s.name)),
                has_media=s.has_media,
                bot_run_id=s.bot_run_id,
                bot_provider_id=s.bot_provider_id,
                bot_instance_id=s.bot_instance_id,
            )
            for s in sessions
            if s.name
        ]

    async def update_state(self, requesting_session: Optional[Session] = None):
        """Send the current participant list as a ``lobby_state`` message.

        Args:
            requesting_session: When given, only this session is notified;
                otherwise every member of the lobby is.

        During a lobby-wide update, members whose send fails have their ``ws``
        cleared. A failed send to *requesting_session* is only logged.
        """
        members = self._snapshot_sessions()
        # Serialize once; the same payload goes to every recipient.
        participants = [user.model_dump() for user in self._build_participants(members)]
        message = {"type": "lobby_state", "data": {"participants": participants}}

        if requesting_session:
            logger.info(
                f"{requesting_session.getName()} -> lobby_state({self.getName()})"
            )
            if requesting_session.ws:
                try:
                    await requesting_session.ws.send_json(message)
                except Exception as e:
                    logger.warning(
                        f"Failed to send lobby state to {requesting_session.getName()}: {e}"
                    )
            else:
                logger.warning(
                    f"{requesting_session.getName()} - No WebSocket connection."
                )
            return

        # Send to all sessions in lobby
        failed_sessions: List[Session] = []
        for s in members:
            logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
            if not s.ws:
                continue
            try:
                await s.ws.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send lobby state to {s.getName()}: {e}")
                failed_sessions.append(s)

        # Drop broken connections so later sends skip these sessions.
        for failed_session in failed_sessions:
            failed_session.ws = None

    def _is_name_protected(self, name: str) -> bool:
        """Return whether *name* is protected (has a password).

        This default always returns False. LobbyManager.set_name_protection_checker
        may replace it on the instance with an injected checker.
        """
        # TODO: This will be handled by dependency injection from AuthManager
        return False

    def getSession(self, id: str) -> Optional[Session]:
        """Return the member whose session id is *id*, or None."""
        with self.lock:
            return self.sessions.get(id)

    async def addSession(self, session: Session) -> None:
        """Add *session* to the lobby and broadcast the new state.

        A session that is already a member is logged and ignored.
        """
        with self.lock:
            if session.id in self.sessions:
                logger.warning(
                    f"{session.getName()} - Already in lobby {self.getName()}."
                )
                return
            self.sessions[session.id] = session
        # Broadcast outside the lock: a threading lock must not span an await.
        await self.update_state()

    async def removeSession(self, session: Session) -> None:
        """Remove *session* from the lobby and broadcast the new state.

        A session that is not a member is logged and ignored.
        """
        with self.lock:
            if session.id not in self.sessions:
                logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.")
                return
            del self.sessions[session.id]
        # Broadcast outside the lock: a threading lock must not span an await.
        await self.update_state()

    def add_chat_message(self, session: Session, message: str) -> ChatMessageModel:
        """Record a chat message from *session* and return it.

        History is trimmed to the newest
        ``LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY`` messages.
        """
        chat_message = ChatMessageModel(
            id=secrets.token_hex(8),
            message=message,
            sender_name=session.name or session.short,
            sender_session_id=session.id,
            timestamp=time.time(),
            lobby_id=self.id,
        )
        with self.lock:
            self.chat_messages.append(chat_message)
            # Keep only the latest messages per lobby (drop oldest in place).
            excess = len(self.chat_messages) - LobbyConfig.MAX_CHAT_MESSAGES_PER_LOBBY
            if excess > 0:
                del self.chat_messages[:excess]
        return chat_message

    def get_chat_messages(self, limit: int = 50) -> List[ChatMessageModel]:
        """Return up to *limit* of the most recent messages, oldest first.

        A non-positive *limit* yields an empty list. The result is a new list.
        """
        # Guard needed: ``lst[-0:]`` is the whole list, not an empty one.
        if limit <= 0:
            return []
        with self.lock:
            return self.chat_messages[-limit:]

    async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
        """Send *chat_message* to every connected member, then publish ChatMessageSent.

        Members whose send fails have their ``ws`` cleared.
        """
        payload = {"type": "chat_message", "data": chat_message.model_dump()}
        failed_sessions: List[Session] = []
        # Iterate a snapshot: joins/leaves during the awaits must not mutate it.
        for peer in self._snapshot_sessions():
            if not peer.ws:
                continue
            try:
                logger.info(f"{self.getName()} -> chat_message({peer.getName()})")
                await peer.ws.send_json(payload)
            except Exception as e:
                logger.warning(f"Failed to send chat message to {peer.getName()}: {e}")
                failed_sessions.append(peer)

        # Clean up failed sessions
        for failed_session in failed_sessions:
            failed_session.ws = None

        # Publish chat event
        await event_bus.publish(ChatMessageSent(
            session_id=chat_message.sender_session_id,
            lobby_id=chat_message.lobby_id,
            message=chat_message.message,
            sender_name=chat_message.sender_name
        ))

    def get_participant_count(self) -> int:
        """Get number of participants in lobby"""
        with self.lock:
            return len(self.sessions)

    def is_empty(self) -> bool:
        """Check if lobby is empty"""
        with self.lock:
            return not self.sessions
|
|
|
|
|
|
class LobbyManager:
    """Manages all lobbies and their lifecycle"""

    def __init__(self):
        self.lobbies: Dict[str, Lobby] = {}  # Keyed by lobby id
        self.lock = threading.RLock()  # Guards self.lobbies
        # Checker injected via set_name_protection_checker(); applied to every
        # lobby, including lobbies created after the injection.
        self._name_protection_checker = None

        # Subscribe to session events. The event bus may be a stand-in that
        # has no subscribe(); run without subscriptions in that case.
        subscribe = getattr(event_bus, "subscribe", None)
        if subscribe is not None:
            subscribe(SessionDisconnected, self)
            subscribe(SessionLeftLobby, self)

    async def handle(self, event):
        """Dispatch an event from the event bus to the matching handler."""
        if isinstance(event, SessionDisconnected):
            await self._handle_session_disconnected(event)
        elif isinstance(event, SessionLeftLobby):
            await self._handle_session_left_lobby(event)

    async def _handle_session_disconnected(self, event):
        """Remove the disconnected session from every lobby it belongs to.

        Each affected lobby broadcasts its new state. A public lobby left empty
        is then removed.
        """
        session_id = event.session_id

        with self.lock:
            lobbies_to_check = list(self.lobbies.values())

        for lobby in lobbies_to_check:
            with lobby.lock:
                if session_id not in lobby.sessions:
                    continue
                del lobby.sessions[session_id]
            logger.info(f"Removed disconnected session {session_id} from lobby {lobby.getName()}")

            # Update lobby state outside the lock: it awaits network sends.
            await lobby.update_state()

            # Check if lobby is now empty and should be cleaned up
            if lobby.is_empty() and not lobby.private:
                await self._cleanup_empty_lobby(lobby)

    async def _handle_session_left_lobby(self, event):
        """Handle explicit session leave"""
        # This is already handled by the session's leave_lobby method
        # but we could add additional cleanup logic here if needed
        pass

    def create_or_get_lobby(self, name: str, private: bool = False) -> Lobby:
        """Return the lobby matching *name* and *private*, creating it if needed.

        A newly created lobby receives the injected name-protection checker, if any.
        """
        with self.lock:
            # Look for existing lobby with same name and visibility
            for lobby in self.lobbies.values():
                if lobby.name == name and lobby.private == private:
                    return lobby

            lobby = Lobby(name=name, private=private)
            if self._name_protection_checker is not None:
                lobby._is_name_protected = self._name_protection_checker
            self.lobbies[lobby.id] = lobby
            logger.info(f"Created new lobby: {lobby.getName()}")
            return lobby

    def get_lobby(self, lobby_id: str) -> Optional[Lobby]:
        """Get lobby by ID"""
        with self.lock:
            return self.lobbies.get(lobby_id)

    def get_lobby_by_name(self, name: str) -> Optional[Lobby]:
        """Return the first lobby named *name* (public or private), or None."""
        with self.lock:
            return next(
                (lobby for lobby in self.lobbies.values() if lobby.name == name), None
            )

    def list_lobbies(self, include_private: bool = False) -> List[Lobby]:
        """List all lobbies, optionally including private ones"""
        with self.lock:
            return [
                lobby
                for lobby in self.lobbies.values()
                if include_private or not lobby.private
            ]

    async def _cleanup_empty_lobby(self, lobby: Lobby):
        """Remove *lobby* if it is still registered and still empty."""
        with self.lock:
            # Re-check under the lock: someone may have joined meanwhile.
            if lobby.id in self.lobbies and lobby.is_empty():
                del self.lobbies[lobby.id]
                logger.info(f"Cleaned up empty lobby: {lobby.getName()}")

    def get_lobby_count(self) -> int:
        """Get total lobby count"""
        with self.lock:
            return len(self.lobbies)

    def get_total_participants(self) -> int:
        """Get total participants across all lobbies"""
        with self.lock:
            return sum(lobby.get_participant_count() for lobby in self.lobbies.values())

    async def cleanup_empty_lobbies(self) -> int:
        """Remove every empty non-private lobby and return how many were removed."""
        with self.lock:
            lobbies_to_remove: list[Lobby] = [
                lobby
                for lobby in self.lobbies.values()
                if lobby.is_empty() and not lobby.private
            ]
            for lobby in lobbies_to_remove:
                del self.lobbies[lobby.id]
                logger.info(f"Cleaned up empty lobby: {lobby.getName()}")
        return len(lobbies_to_remove)

    def set_name_protection_checker(self, checker_func):
        """Inject the name-protection checker (e.g. from AuthManager).

        The checker is installed on all existing lobbies. It is also remembered
        so lobbies created later get it too.
        """
        # This allows us to inject the name protection logic without tight coupling
        with self.lock:
            self._name_protection_checker = checker_func
            for lobby in self.lobbies.values():
                lobby._is_name_protected = checker_func
|