From 6b4770472375e909c4a47776a9b4fca5f34d445e Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Thu, 4 Sep 2025 16:54:38 -0700 Subject: [PATCH] WebRTC now working --- server/api/admin.py | 28 +- server/core/session_manager.py | 160 +- server/main_backup_working.py | 2338 -------------------------- server/main_clean.py | 293 ---- server/main_original.py | 2338 -------------------------- server/main_refactored.py | 213 --- server/main_working.py | 2338 -------------------------- server/websocket/message_handlers.py | 37 + server/websocket/webrtc_signaling.py | 199 +++ tests/test-webrtc-signaling.py | 71 + tests/verify-webrtc-handlers.py | 41 + 11 files changed, 508 insertions(+), 7548 deletions(-) delete mode 100644 server/main_backup_working.py delete mode 100644 server/main_clean.py delete mode 100644 server/main_original.py delete mode 100644 server/main_refactored.py delete mode 100644 server/main_working.py create mode 100644 server/websocket/webrtc_signaling.py create mode 100644 tests/test-webrtc-signaling.py create mode 100644 tests/verify-webrtc-handlers.py diff --git a/server/api/admin.py b/server/api/admin.py index 513c907..e2f476f 100644 --- a/server/api/admin.py +++ b/server/api/admin.py @@ -5,13 +5,17 @@ This module contains admin-only endpoints for managing users, sessions, and syst Extracted from main.py to improve maintainability and separation of concerns. """ -from typing import TYPE_CHECKING -from fastapi import APIRouter, Request, Response, Body - -# Import shared models import sys import os -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) +from typing import TYPE_CHECKING + +# Add the parent directory of server to the path to access shared +current_dir = os.path.dirname(os.path.abspath(__file__)) +server_dir = os.path.dirname(current_dir) +project_root = os.path.dirname(server_dir) +sys.path.insert(0, project_root) + +from fastapi import APIRouter, Request, Response, Body from shared.models import ( AdminNamesResponse, AdminActionResponse, @@ -21,13 +25,15 @@ from shared.models import ( AdminMetricsResponse, AdminMetricsConfig, ) - from logger import logger if TYPE_CHECKING: - from ..core.session_manager import SessionManager - from ..core.lobby_manager import LobbyManager - from ..core.auth_manager import AuthManager + from core.session_manager import SessionManager, SessionConfig + from core.lobby_manager import LobbyManager + from core.auth_manager import AuthManager +else: + # Import for runtime + from core.session_manager import SessionConfig class AdminAPI: @@ -161,10 +167,8 @@ class AdminAPI: cleanup_candidates = 0 old_anonymous = 0 old_displaced = 0 - + for session in all_sessions: - from ..core.session_manager import SessionConfig - # Anonymous sessions if (not session.ws and not session.name and current_time - session.created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT): diff --git a/server/core/session_manager.py b/server/core/session_manager.py index 0d93f00..21527e0 100644 --- a/server/core/session_manager.py +++ b/server/core/session_manager.py @@ -2,17 +2,6 @@ Session management for the AI Voice Bot server. This module handles session lifecycle, persistence, and cleanup operations. 
-Extracted from m await lobby.add await lobby.removeSession(self) - - # Publish event - await event_bus.publish(SessionLeftLobby( - session_id=self.id, - lobby_id=lobby.id,(self) - - # Publish event - await event_bus.publish(SessionJoinedLobby( - session_id=self.id, - lobby_id=lobby.id,to improve maintainability and separation of concerns. """ from __future__ import annotations @@ -148,13 +137,93 @@ class Session: self.displaced_at = time.time() async def join_lobby(self, lobby): - """Join a lobby and update peers""" + """Join a lobby and establish WebRTC peer connections""" with self.session_lock: if lobby not in self.lobbies: self.lobbies.append(lobby) - + + # Initialize lobby_peers for this lobby if not exists + if lobby.id not in self.lobby_peers: + self.lobby_peers[lobby.id] = [] + + # Add to lobby first await lobby.addSession(self) - + + # Get existing peer sessions in this lobby for WebRTC setup + peer_sessions = [] + for session in lobby.sessions.values(): + if ( + session.id != self.id and session.ws + ): # Don't include self and only connected sessions + peer_sessions.append(session) + + # Establish WebRTC peer connections with existing sessions + for peer_session in peer_sessions: + # Only establish connections if at least one session has media + if self.has_media or peer_session.has_media: + logger.info( + f"{self.getName()} <-> {peer_session.getName()} - Establishing WebRTC peer connection" + ) + + # Add peer to our lobby_peers list + with self.session_lock: + if peer_session.id not in self.lobby_peers[lobby.id]: + self.lobby_peers[lobby.id].append(peer_session.id) + + # Add this session to peer's lobby_peers list + with peer_session.session_lock: + if lobby.id not in peer_session.lobby_peers: + peer_session.lobby_peers[lobby.id] = [] + if self.id not in peer_session.lobby_peers[lobby.id]: + peer_session.lobby_peers[lobby.id].append(self.id) + + # Send addPeer to existing peer (they should not create offer) + logger.info( + f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False, has_media={self.has_media})" + ) + try: + await peer_session.ws.send_json( + { + "type": "addPeer", + "data": { + "peer_id": self.id, + "peer_name": self.name, + "has_media": self.has_media, + "should_create_offer": False, + }, + } + ) + except Exception as e: + logger.warning( + f"Failed to send addPeer to {peer_session.getName()}: {e}" + ) + + # Send addPeer to this session (they should create offer) + if self.ws: + logger.info( + f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True, has_media={peer_session.has_media})" + ) + try: + await self.ws.send_json( + { + "type": "addPeer", + "data": { + "peer_id": peer_session.id, + "peer_name": peer_session.name, + "has_media": peer_session.has_media, + "should_create_offer": True, + }, + } + ) + except Exception as e: + logger.warning( + f"Failed to send addPeer to {self.getName()}: {e}" + ) + else: + logger.info( + f"{self.getName()} - Skipping WebRTC connection with {peer_session.getName()} (neither has media: self={self.has_media}, peer={peer_session.has_media})" + ) + # Publish join event await event_bus.publish(SessionJoinedLobby( session_id=self.id, @@ -163,13 +232,72 @@ class Session: )) async def leave_lobby(self, lobby): - """Leave a lobby and clean up peers""" + """Leave a lobby and clean up WebRTC peer connections""" + # Get peer sessions before removing from lobby + peer_sessions = [] + if lobby.id in self.lobby_peers: 
+ for peer_id in self.lobby_peers[lobby.id]: + peer_session = None + # Find peer session in lobby + for session in lobby.sessions.values(): + if session.id == peer_id: + peer_session = session + break + + if peer_session and peer_session.ws: + peer_sessions.append(peer_session) + + # Send removePeer messages to all peers + for peer_session in peer_sessions: + logger.info(f"{peer_session.getName()} <- remove_peer({self.getName()})") + try: + await peer_session.ws.send_json( + { + "type": "removePeer", + "data": {"peer_name": self.name, "peer_id": self.id}, + } + ) + except Exception as e: + logger.warning( + f"Failed to send removePeer to {peer_session.getName()}: {e}" + ) + + # Remove from peer's lobby_peers + with peer_session.session_lock: + if ( + lobby.id in peer_session.lobby_peers + and self.id in peer_session.lobby_peers[lobby.id] + ): + peer_session.lobby_peers[lobby.id].remove(self.id) + + # Send removePeer to this session + if self.ws: + logger.info( + f"{self.getName()} <- remove_peer({peer_session.getName()})" + ) + try: + await self.ws.send_json( + { + "type": "removePeer", + "data": { + "peer_name": peer_session.name, + "peer_id": peer_session.id, + }, + } + ) + except Exception as e: + logger.warning( + f"Failed to send removePeer to {self.getName()}: {e}" + ) + + # Clean up our lobby_peers and lobbies with self.session_lock: if lobby in self.lobbies: self.lobbies.remove(lobby) if lobby.id in self.lobby_peers: del self.lobby_peers[lobby.id] - + + # Remove from lobby await lobby.removeSession(self) # Publish leave event diff --git a/server/main_backup_working.py b/server/main_backup_working.py deleted file mode 100644 index d5e0118..0000000 --- a/server/main_backup_working.py +++ /dev/null @@ -1,2338 +0,0 @@ -from __future__ import annotations -from typing import Any, Optional, List -from fastapi import ( - Body, - Cookie, - FastAPI, - HTTPException, - Path, - WebSocket, - Request, - Response, - WebSocketDisconnect, -) -import secrets -import os -import json -import hashlib -import binascii -import sys -import asyncio -import threading -import time -from contextlib import asynccontextmanager - -from fastapi.staticfiles import StaticFiles -import httpx -from pydantic import ValidationError -from logger import logger - -# Import shared models -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from shared.models import ( - HealthResponse, - LobbiesResponse, - LobbyCreateRequest, - LobbyCreateResponse, - LobbyListItem, - LobbyModel, - NamePasswordRecord, - LobbySaved, - SessionResponse, - SessionSaved, - SessionsPayload, - AdminNamesResponse, - AdminActionResponse, - AdminSetPassword, - AdminClearPassword, - AdminValidationResponse, - AdminMetricsResponse, - AdminMetricsConfig, - JoinStatusModel, - ChatMessageModel, - ChatMessagesResponse, - ParticipantModel, - # Bot provider models - BotProviderModel, - BotProviderRegisterRequest, - BotProviderRegisterResponse, - BotProviderListResponse, - BotListResponse, - BotInfoModel, - BotJoinLobbyRequest, - BotJoinLobbyResponse, - BotJoinPayload, - BotLeaveLobbyRequest, - BotLeaveLobbyResponse, - BotProviderBotsResponse, - BotProviderJoinResponse, -) - - -class SessionConfig: - """Configuration class for session management""" - - ANONYMOUS_SESSION_TIMEOUT = int( - os.getenv("ANONYMOUS_SESSION_TIMEOUT", "60") - ) # 1 minute - DISPLACED_SESSION_TIMEOUT = int( - os.getenv("DISPLACED_SESSION_TIMEOUT", "10800") - ) # 3 hours - CLEANUP_INTERVAL = int(os.getenv("CLEANUP_INTERVAL", "300")) # 5 minutes - 
MAX_SESSIONS_PER_CLEANUP = int( - os.getenv("MAX_SESSIONS_PER_CLEANUP", "100") - ) # Circuit breaker - MAX_CHAT_MESSAGES_PER_LOBBY = int(os.getenv("MAX_CHAT_MESSAGES_PER_LOBBY", "100")) - SESSION_VALIDATION_INTERVAL = int( - os.getenv("SESSION_VALIDATION_INTERVAL", "1800") - ) # 30 minutes - - -class BotProviderConfig: - """Configuration class for bot provider management""" - - # Comma-separated list of allowed provider keys - # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys) - ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "") - - @classmethod - def get_allowed_providers(cls) -> dict[str, str]: - """Parse allowed providers from environment variable - - Returns: - dict mapping provider_key -> provider_name - """ - if not cls.ALLOWED_PROVIDERS.strip(): - return {} - - providers: dict[str, str] = {} - for entry in cls.ALLOWED_PROVIDERS.split(","): - entry = entry.strip() - if not entry: - continue - - if ":" in entry: - key, name = entry.split(":", 1) - providers[key.strip()] = name.strip() - else: - providers[entry] = entry - - return providers - - -# Thread lock for session operations -session_lock = threading.RLock() - -# Mapping of reserved names to password records (lowercased name -> {salt:..., hash:...}) -name_passwords: dict[str, dict[str, str]] = {} - -# Bot provider registry: provider_id -> BotProviderModel -bot_providers: dict[str, BotProviderModel] = {} - -all_label = "[ all ]" -info_label = "[ info ]" -todo_label = "[ todo ]" -unset_label = "[ ---- ]" - - -def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]: - """Return (salt_hex, hash_hex) for the given password. If salt_hex is provided - it is used; otherwise a new salt is generated.""" - if salt_hex: - salt = binascii.unhexlify(salt_hex) - else: - salt = secrets.token_bytes(16) - salt_hex = binascii.hexlify(salt).decode() - dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000) - hash_hex = binascii.hexlify(dk).decode() - return salt_hex, hash_hex - - -public_url = os.getenv("PUBLIC_URL", "/") -if not public_url.endswith("/"): - public_url += "/" - -# Global variables to control background tasks -cleanup_task_running = False -cleanup_task = None -validation_task_running = False -validation_task = None - - -async def periodic_cleanup(): - """Background task to periodically clean up old sessions""" - global cleanup_task_running - cleanup_errors = 0 - max_consecutive_errors = 5 - - while cleanup_task_running: - try: - removed_count = Session.cleanup_old_sessions() - if removed_count > 0: - logger.info(f"Periodic cleanup removed {removed_count} old sessions") - cleanup_errors = 0 # Reset error counter on success - - # Run cleanup at configured interval - await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL) - except Exception as e: - cleanup_errors += 1 - logger.error( - f"Error in session cleanup task (attempt {cleanup_errors}): {e}" - ) - - if cleanup_errors >= max_consecutive_errors: - logger.error( - f"Too many consecutive cleanup errors ({cleanup_errors}), stopping cleanup task" - ) - break - - # Exponential backoff on errors - await asyncio.sleep(min(60 * cleanup_errors, 300)) - - -async def periodic_validation(): - """Background task to periodically validate session integrity""" - global validation_task_running - - while validation_task_running: - try: - issues = Session.validate_session_integrity() - if issues: - logger.warning(f"Session integrity issues found: {len(issues)} issues") - for issue in issues[:10]: # Log first 10 issues - 
logger.warning(f"Integrity issue: {issue}") - - await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL) - except Exception as e: - logger.error(f"Error in session validation task: {e}") - await asyncio.sleep(300) # Wait 5 minutes before retrying on error - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events""" - global cleanup_task_running, cleanup_task, validation_task_running, validation_task - - # Startup - logger.info("Starting background tasks...") - cleanup_task_running = True - validation_task_running = True - cleanup_task = asyncio.create_task(periodic_cleanup()) - validation_task = asyncio.create_task(periodic_validation()) - logger.info("Session cleanup and validation tasks started") - - yield - - # Shutdown - logger.info("Shutting down background tasks...") - cleanup_task_running = False - validation_task_running = False - - # Cancel tasks - for task in [cleanup_task, validation_task]: - if task: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Clean up all sessions gracefully - await Session.cleanup_all_sessions() - logger.info("All background tasks stopped and sessions cleaned up") - - -app = FastAPI(lifespan=lifespan) - -logger.info(f"Starting server with public URL: {public_url}") -logger.info( - f"Session config - Anonymous timeout: {SessionConfig.ANONYMOUS_SESSION_TIMEOUT}s, " - f"Displaced timeout: {SessionConfig.DISPLACED_SESSION_TIMEOUT}s, " - f"Cleanup interval: {SessionConfig.CLEANUP_INTERVAL}s" -) - -# Log bot provider configuration -allowed_providers = BotProviderConfig.get_allowed_providers() -if allowed_providers: - logger.info( - f"Bot provider authentication enabled. Allowed providers: {list(allowed_providers.keys())}" - ) -else: - logger.warning("Bot provider authentication disabled. 
Any provider can register.") - -# Optional admin token to protect admin endpoints -ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None) - - -def _require_admin(request: Request) -> bool: - if not ADMIN_TOKEN: - return True - token = request.headers.get("X-Admin-Token") - return token == ADMIN_TOKEN - - -@app.get(public_url + "api/admin/names", response_model=AdminNamesResponse) -def admin_list_names(request: Request): - if not _require_admin(request): - return Response(status_code=403) - # Convert dict format to Pydantic models - name_passwords_models = { - name: NamePasswordRecord(**record) for name, record in name_passwords.items() - } - return AdminNamesResponse(name_passwords=name_passwords_models) - - -@app.post(public_url + "api/admin/set_password", response_model=AdminActionResponse) -def admin_set_password(request: Request, payload: AdminSetPassword = Body(...)): - if not _require_admin(request): - return Response(status_code=403) - lname = payload.name.lower() - salt, hash_hex = _hash_password(payload.password) - name_passwords[lname] = {"salt": salt, "hash": hash_hex} - Session.save() - return AdminActionResponse(status="ok", name=payload.name) - - -@app.post(public_url + "api/admin/clear_password", response_model=AdminActionResponse) -def admin_clear_password(request: Request, payload: AdminClearPassword = Body(...)): - if not _require_admin(request): - return Response(status_code=403) - lname = payload.name.lower() - if lname in name_passwords: - del name_passwords[lname] - Session.save() - return AdminActionResponse(status="ok", name=payload.name) - return AdminActionResponse(status="not_found", name=payload.name) - - -@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse) -def admin_cleanup_sessions(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - removed_count = Session.cleanup_old_sessions() - return AdminActionResponse( - status="ok", name=f"Removed {removed_count} sessions" - ) - except Exception as e: - logger.error(f"Error during manual session cleanup: {e}") - return AdminActionResponse(status="error", name=f"Error: {str(e)}") - - -@app.get(public_url + "api/admin/session_metrics", response_model=AdminMetricsResponse) -def admin_session_metrics(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - return Session.get_cleanup_metrics() - except Exception as e: - logger.error(f"Error getting session metrics: {e}") - return Response(status_code=500) - - -@app.get( - public_url + "api/admin/validate_sessions", response_model=AdminValidationResponse -) -def admin_validate_sessions(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - issues = Session.validate_session_integrity() - return AdminValidationResponse( - status="ok", issues=issues, issue_count=len(issues) - ) - except Exception as e: - logger.error(f"Error validating sessions: {e}") - return AdminValidationResponse(status="error", error=str(e)) - - -lobbies: dict[str, Lobby] = {} - - -class Lobby: - def __init__(self, name: str, id: str | None = None, private: bool = False): - self.id = secrets.token_hex(16) if id is None else id - self.short = self.id[:8] - self.name = name - self.sessions: dict[str, Session] = {} # All lobby members - self.private = private - self.chat_messages: list[ChatMessageModel] = [] # Store chat messages - self.lock = threading.RLock() # Thread safety for lobby operations - - def getName(self) -> str: - return 
f"{self.short}:{self.name}" - - async def update_state(self, requesting_session: Session | None = None): - with self.lock: - users: list[ParticipantModel] = [ - ParticipantModel( - name=s.name, - live=True if s.ws else False, - session_id=s.id, - protected=True - if s.name and s.name.lower() in name_passwords - else False, - is_bot=s.is_bot, - has_media=s.has_media, - bot_run_id=s.bot_run_id, - bot_provider_id=s.bot_provider_id, - ) - for s in self.sessions.values() - if s.name - ] - - if requesting_session: - logger.info( - f"{requesting_session.getName()} -> lobby_state({self.getName()})" - ) - if requesting_session.ws: - try: - await requesting_session.ws.send_json( - { - "type": "lobby_state", - "data": { - "participants": [user.model_dump() for user in users] - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send lobby state to {requesting_session.getName()}: {e}" - ) - else: - logger.warning( - f"{requesting_session.getName()} - No WebSocket connection." - ) - else: - # Send to all sessions in lobby - failed_sessions: list[Session] = [] - for s in self.sessions.values(): - logger.info(f"{s.getName()} -> lobby_state({self.getName()})") - if s.ws: - try: - await s.ws.send_json( - { - "type": "lobby_state", - "data": { - "participants": [ - user.model_dump() for user in users - ] - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send lobby state to {s.getName()}: {e}" - ) - failed_sessions.append(s) - - # Clean up failed sessions - for failed_session in failed_sessions: - failed_session.ws = None - - def getSession(self, id: str) -> Session | None: - with self.lock: - return self.sessions.get(id, None) - - async def addSession(self, session: Session) -> None: - with self.lock: - if session.id in self.sessions: - logger.warning( - f"{session.getName()} - Already in lobby {self.getName()}." 
- ) - return None - self.sessions[session.id] = session - await self.update_state() - - async def removeSession(self, session: Session) -> None: - with self.lock: - if session.id not in self.sessions: - logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.") - return None - del self.sessions[session.id] - await self.update_state() - - def add_chat_message(self, session: Session, message: str) -> ChatMessageModel: - """Add a chat message to the lobby and return the message data""" - with self.lock: - chat_message = ChatMessageModel( - id=secrets.token_hex(8), - message=message, - sender_name=session.name or session.short, - sender_session_id=session.id, - timestamp=time.time(), - lobby_id=self.id, - ) - self.chat_messages.append(chat_message) - # Keep only the latest messages per lobby - if len(self.chat_messages) > SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY: - self.chat_messages = self.chat_messages[ - -SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY : - ] - return chat_message - - def get_chat_messages(self, limit: int = 50) -> list[ChatMessageModel]: - """Get the most recent chat messages from the lobby""" - with self.lock: - return self.chat_messages[-limit:] if self.chat_messages else [] - - async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None: - """Broadcast a chat message to all connected sessions in the lobby""" - failed_sessions: list[Session] = [] - for peer in self.sessions.values(): - if peer.ws: - try: - logger.info(f"{self.getName()} -> chat_message({peer.getName()})") - await peer.ws.send_json( - {"type": "chat_message", "data": chat_message.model_dump()} - ) - except Exception as e: - logger.warning( - f"Failed to send chat message to {peer.getName()}: {e}" - ) - failed_sessions.append(peer) - - # Clean up failed sessions - for failed_session in failed_sessions: - failed_session.ws = None - - -class Session: - _instances: list[Session] = [] - _save_file = "sessions.json" - _loaded = False - lock = threading.RLock() # Thread safety for class-level operations - - def __init__(self, id: str, is_bot: bool = False, has_media: bool = True): - logger.info( - f"Instantiating new session {id} (bot: {is_bot}, media: {has_media})" - ) - with Session.lock: - self._instances.append(self) - self.id = id - self.short = id[:8] - self.name = "" - self.lobbies: list[Lobby] = [] # List of lobby IDs this session is in - self.lobby_peers: dict[ - str, list[str] - ] = {} # lobby ID -> list of peer session IDs - self.ws: WebSocket | None = None - self.created_at = time.time() - self.last_used = time.time() - self.displaced_at: float | None = None # When name was taken over - self.is_bot = is_bot # Whether this session represents a bot - self.has_media = has_media # Whether this session provides audio/video streams - self.bot_run_id: str | None = None # Bot run ID for tracking - self.bot_provider_id: str | None = None # Bot provider ID - self.session_lock = threading.RLock() # Instance-level lock - self.save() - - @classmethod - def save(cls): - try: - with cls.lock: - sessions_list: list[SessionSaved] = [] - for s in cls._instances: - with s.session_lock: - lobbies_list: list[LobbySaved] = [ - LobbySaved( - id=lobby.id, name=lobby.name, private=lobby.private - ) - for lobby in s.lobbies - ] - sessions_list.append( - SessionSaved( - id=s.id, - name=s.name or "", - lobbies=lobbies_list, - created_at=s.created_at, - last_used=s.last_used, - displaced_at=s.displaced_at, - is_bot=s.is_bot, - has_media=s.has_media, - bot_run_id=s.bot_run_id, - 
bot_provider_id=s.bot_provider_id, - ) - ) - - # Prepare name password store for persistence (salt+hash). Only structured records are supported. - saved_pw: dict[str, NamePasswordRecord] = { - name: NamePasswordRecord(**record) - for name, record in name_passwords.items() - } - - payload_model = SessionsPayload( - sessions=sessions_list, name_passwords=saved_pw - ) - payload = payload_model.model_dump() - - # Atomic write using temp file - temp_file = cls._save_file + ".tmp" - with open(temp_file, "w") as f: - json.dump(payload, f, indent=2) - - # Atomic rename - os.rename(temp_file, cls._save_file) - - logger.info( - f"Saved {len(sessions_list)} sessions and {len(saved_pw)} name passwords to {cls._save_file}" - ) - except Exception as e: - logger.error(f"Failed to save sessions: {e}") - # Clean up temp file if it exists - try: - if os.path.exists(cls._save_file + ".tmp"): - os.remove(cls._save_file + ".tmp") - except Exception as e: - pass - - @classmethod - def load(cls): - if not os.path.exists(cls._save_file): - logger.info(f"No session save file found: {cls._save_file}") - return - - try: - with open(cls._save_file, "r") as f: - raw = json.load(f) - except Exception as e: - logger.error(f"Failed to read session save file: {e}") - return - - try: - payload = SessionsPayload.model_validate(raw) - except ValidationError as e: - logger.exception(f"Failed to validate sessions payload: {e}") - return - - # Populate in-memory structures from payload (no backwards compatibility code) - name_passwords.clear() - for name, rec in payload.name_passwords.items(): - # rec is a NamePasswordRecord - name_passwords[name] = {"salt": rec.salt, "hash": rec.hash} - - current_time = time.time() - sessions_loaded = 0 - sessions_expired = 0 - - with cls.lock: - for s_saved in payload.sessions: - # Check if this session should be expired during loading - created_at = getattr(s_saved, "created_at", time.time()) - last_used = getattr(s_saved, "last_used", time.time()) - displaced_at = getattr(s_saved, "displaced_at", None) - name = s_saved.name or "" - - # Apply same removal criteria as cleanup_old_sessions - should_expire = cls._should_remove_session_static( - name, None, created_at, last_used, displaced_at, current_time - ) - - if should_expire: - sessions_expired += 1 - logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load") - continue # Skip loading this expired session - - session = Session( - s_saved.id, - is_bot=getattr(s_saved, "is_bot", False), - has_media=getattr(s_saved, "has_media", True), - ) - session.name = name - # Load timestamps, with defaults for backward compatibility - session.created_at = created_at - session.last_used = last_used - session.displaced_at = displaced_at - # Load bot information with defaults for backward compatibility - session.is_bot = getattr(s_saved, "is_bot", False) - session.has_media = getattr(s_saved, "has_media", True) - session.bot_run_id = getattr(s_saved, "bot_run_id", None) - session.bot_provider_id = getattr(s_saved, "bot_provider_id", None) - for lobby_saved in s_saved.lobbies: - session.lobbies.append( - Lobby( - name=lobby_saved.name, - id=lobby_saved.id, - private=lobby_saved.private, - ) - ) - logger.info( - f"Loaded session {session.getName()} with {len(session.lobbies)} lobbies" - ) - for lobby in session.lobbies: - lobbies[lobby.id] = Lobby( - name=lobby.name, id=lobby.id, private=lobby.private - ) # Ensure lobby exists - sessions_loaded += 1 - - logger.info( - f"Loaded {sessions_loaded} sessions and {len(name_passwords)} name passwords from 
{cls._save_file}" - ) - if sessions_expired > 0: - logger.info(f"Expired {sessions_expired} old sessions during load") - # Save immediately to persist the cleanup - cls.save() - - @classmethod - def getSession(cls, id: str) -> Session | None: - if not cls._loaded: - cls.load() - logger.info(f"Loaded {len(cls._instances)} sessions from disk...") - cls._loaded = True - - with cls.lock: - for s in cls._instances: - if s.id == id: - return s - return None - - @classmethod - def isUniqueName(cls, name: str) -> bool: - if not name: - return False - with cls.lock: - for s in cls._instances: - with s.session_lock: - if s.name.lower() == name.lower(): - return False - return True - - @classmethod - def getSessionByName(cls, name: str) -> Optional["Session"]: - if not name: - return None - lname = name.lower() - with cls.lock: - for s in cls._instances: - with s.session_lock: - if s.name and s.name.lower() == lname: - return s - return None - - def getName(self) -> str: - with self.session_lock: - return f"{self.short}:{self.name if self.name else unset_label}" - - def setName(self, name: str): - with self.session_lock: - self.name = name - self.update_last_used() - self.save() - - def update_last_used(self): - """Update the last_used timestamp""" - with self.session_lock: - self.last_used = time.time() - - def mark_displaced(self): - """Mark this session as having its name taken over""" - with self.session_lock: - self.displaced_at = time.time() - - @staticmethod - def _should_remove_session_static( - name: str, - ws: WebSocket | None, - created_at: float, - last_used: float, - displaced_at: float | None, - current_time: float, - ) -> bool: - """Static method to determine if a session should be removed""" - # Rule 1: Delete sessions with no active connection and no name that are older than threshold - if ( - not ws - and not name - and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT - ): - return True - - # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used recently - if ( - not ws - and displaced_at is not None - and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT - ): - return True - - return False - - def _should_remove(self, current_time: float) -> bool: - """Check if this session should be removed""" - with self.session_lock: - return self._should_remove_session_static( - self.name, - self.ws, - self.created_at, - self.last_used, - self.displaced_at, - current_time, - ) - - @classmethod - def _remove_session_safely(cls, session: Session, empty_lobbies: set[str]) -> None: - """Safely remove a session and track affected lobbies""" - try: - with session.session_lock: - # Remove from lobbies first - for lobby in session.lobbies[ - : - ]: # Copy list to avoid modification during iteration - try: - with lobby.lock: - if session.id in lobby.sessions: - del lobby.sessions[session.id] - if len(lobby.sessions) == 0: - empty_lobbies.add(lobby.id) - - if lobby.id in session.lobby_peers: - del session.lobby_peers[lobby.id] - except Exception as e: - logger.warning( - f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}" - ) - - # Close WebSocket if open - if session.ws: - try: - asyncio.create_task(session.ws.close()) - except Exception as e: - logger.warning( - f"Error closing WebSocket for {session.getName()}: {e}" - ) - session.ws = None - - # Remove from instances list - with cls.lock: - if session in cls._instances: - cls._instances.remove(session) - - except Exception as e: - logger.error( - f"Error 
during safe session removal for {session.getName()}: {e}" - ) - - @classmethod - def _cleanup_empty_lobbies(cls, empty_lobbies: set[str]) -> int: - """Clean up empty lobbies from global lobbies dict""" - removed_count = 0 - for lobby_id in empty_lobbies: - if lobby_id in lobbies: - lobby_name = lobbies[lobby_id].getName() - del lobbies[lobby_id] - logger.info(f"Removed empty lobby {lobby_name}") - removed_count += 1 - return removed_count - - @classmethod - def cleanup_old_sessions(cls) -> int: - """Clean up old sessions based on the specified criteria with improved safety""" - current_time = time.time() - sessions_removed = 0 - - try: - # Circuit breaker - don't remove too many sessions at once - sessions_to_remove: list[Session] = [] - empty_lobbies: set[str] = set() - - with cls.lock: - # Identify sessions to remove (up to max limit) - for session in cls._instances[:]: - if ( - len(sessions_to_remove) - >= SessionConfig.MAX_SESSIONS_PER_CLEANUP - ): - logger.warning( - f"Hit session cleanup limit ({SessionConfig.MAX_SESSIONS_PER_CLEANUP}), " - f"stopping cleanup. Remaining sessions will be cleaned up in next cycle." - ) - break - - if session._should_remove(current_time): - sessions_to_remove.append(session) - logger.info( - f"Marking session {session.getName()} for removal - " - f"criteria: no_ws={session.ws is None}, no_name={not session.name}, " - f"age={current_time - session.created_at:.0f}s, " - f"displaced={session.displaced_at is not None}, " - f"unused={current_time - session.last_used:.0f}s" - ) - - # Remove the identified sessions - for session in sessions_to_remove: - cls._remove_session_safely(session, empty_lobbies) - sessions_removed += 1 - - # Clean up empty lobbies - empty_lobbies_removed = cls._cleanup_empty_lobbies(empty_lobbies) - - # Save state if we made changes - if sessions_removed > 0: - cls.save() - logger.info( - f"Session cleanup completed: removed {sessions_removed} sessions, " - f"{empty_lobbies_removed} empty lobbies" - ) - - except Exception as e: - logger.error(f"Error during session cleanup: {e}") - # Don't re-raise - cleanup should be resilient - - return sessions_removed - - @classmethod - def get_cleanup_metrics(cls) -> AdminMetricsResponse: - """Return cleanup metrics for monitoring""" - current_time = time.time() - - with cls.lock: - total_sessions = len(cls._instances) - active_sessions = 0 - named_sessions = 0 - displaced_sessions = 0 - old_anonymous = 0 - old_displaced = 0 - - for s in cls._instances: - with s.session_lock: - if s.ws: - active_sessions += 1 - if s.name: - named_sessions += 1 - if s.displaced_at is not None: - displaced_sessions += 1 - if ( - not s.ws - and current_time - s.last_used - > SessionConfig.DISPLACED_SESSION_TIMEOUT - ): - old_displaced += 1 - if ( - not s.ws - and not s.name - and current_time - s.created_at - > SessionConfig.ANONYMOUS_SESSION_TIMEOUT - ): - old_anonymous += 1 - - config = AdminMetricsConfig( - anonymous_timeout=SessionConfig.ANONYMOUS_SESSION_TIMEOUT, - displaced_timeout=SessionConfig.DISPLACED_SESSION_TIMEOUT, - cleanup_interval=SessionConfig.CLEANUP_INTERVAL, - max_cleanup_per_cycle=SessionConfig.MAX_SESSIONS_PER_CLEANUP, - ) - - return AdminMetricsResponse( - total_sessions=total_sessions, - active_sessions=active_sessions, - named_sessions=named_sessions, - displaced_sessions=displaced_sessions, - old_anonymous_sessions=old_anonymous, - old_displaced_sessions=old_displaced, - total_lobbies=len(lobbies), - cleanup_candidates=old_anonymous + old_displaced, - config=config, - ) - - @classmethod - 
def validate_session_integrity(cls) -> list[str]: - """Validate session data integrity""" - issues: list[str] = [] - - try: - with cls.lock: - for session in cls._instances: - with session.session_lock: - # Check for orphaned lobby references - for lobby in session.lobbies: - if lobby.id not in lobbies: - issues.append( - f"Session {session.id[:8]}:{session.name} references missing lobby {lobby.id}" - ) - - # Check for inconsistent peer relationships - for lobby_id, peer_ids in session.lobby_peers.items(): - lobby = lobbies.get(lobby_id) - if lobby: - with lobby.lock: - if session.id not in lobby.sessions: - issues.append( - f"Session {session.id[:8]}:{session.name} has peers in lobby {lobby_id} but not in lobby.sessions" - ) - - # Check if peer sessions actually exist - for peer_id in peer_ids: - if peer_id not in lobby.sessions: - issues.append( - f"Session {session.id[:8]}:{session.name} references non-existent peer {peer_id} in lobby {lobby_id}" - ) - else: - issues.append( - f"Session {session.id[:8]}:{session.name} has peer list for non-existent lobby {lobby_id}" - ) - - # Check lobbies for consistency - for lobby_id, lobby in lobbies.items(): - with lobby.lock: - for session_id in lobby.sessions: - found_session = None - for s in cls._instances: - if s.id == session_id: - found_session = s - break - - if not found_session: - issues.append( - f"Lobby {lobby_id} references non-existent session {session_id}" - ) - else: - with found_session.session_lock: - if lobby not in found_session.lobbies: - issues.append( - f"Lobby {lobby_id} contains session {session_id} but session doesn't reference lobby" - ) - - except Exception as e: - logger.error(f"Error during session validation: {e}") - issues.append(f"Validation error: {str(e)}") - - return issues - - @classmethod - async def cleanup_all_sessions(cls): - """Clean up all sessions during shutdown""" - logger.info("Starting graceful session cleanup...") - - try: - with cls.lock: - sessions_to_cleanup = cls._instances[:] - - for session in sessions_to_cleanup: - try: - with session.session_lock: - # Close WebSocket connections - if session.ws: - try: - await session.ws.close() - except Exception as e: - logger.warning( - f"Error closing WebSocket for {session.getName()}: {e}" - ) - session.ws = None - - # Remove from lobbies - for lobby in session.lobbies[:]: - try: - await session.part(lobby) - except Exception as e: - logger.warning( - f"Error removing {session.getName()} from lobby: {e}" - ) - - except Exception as e: - logger.error(f"Error cleaning up session {session.getName()}: {e}") - - # Clear all data structures - with cls.lock: - cls._instances.clear() - lobbies.clear() - - logger.info( - f"Graceful session cleanup completed for {len(sessions_to_cleanup)} sessions" - ) - - except Exception as e: - logger.error(f"Error during graceful session cleanup: {e}") - - async def join(self, lobby: Lobby): - if not self.ws: - logger.error( - f"{self.getName()} - No WebSocket connection. Lobby not available." 
- ) - return - - with self.session_lock: - if lobby.id in self.lobby_peers or self.id in lobby.sessions: - logger.info(f"{self.getName()} - Already joined to {lobby.getName()}.") - data = JoinStatusModel( - status="Joined", - message=f"Already joined to lobby {lobby.getName()}", - ) - try: - await self.ws.send_json( - {"type": "join_status", "data": data.model_dump()} - ) - except Exception as e: - logger.warning( - f"Failed to send join status to {self.getName()}: {e}" - ) - return - - # Initialize the peer list for this lobby - with self.session_lock: - self.lobbies.append(lobby) - self.lobby_peers[lobby.id] = [] - - with lobby.lock: - peer_sessions = list(lobby.sessions.values()) - - for peer_session in peer_sessions: - if peer_session.id == self.id: - logger.error( - "Should not happen: self in lobby.sessions while not in lobby." - ) - continue - - if not peer_session.ws: - logger.warning( - f"{self.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Removing." - ) - with lobby.lock: - if peer_session.id in lobby.sessions: - del lobby.sessions[peer_session.id] - continue - - # Only create WebRTC peer connections if at least one participant has media - should_create_rtc_connection = self.has_media or peer_session.has_media - - if should_create_rtc_connection: - # Add the peer to session's RTC peer list - with self.session_lock: - self.lobby_peers[lobby.id].append(peer_session.id) - - # Add this user as an RTC peer to each existing peer - with peer_session.session_lock: - if lobby.id not in peer_session.lobby_peers: - peer_session.lobby_peers[lobby.id] = [] - peer_session.lobby_peers[lobby.id].append(self.id) - - logger.info( - f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False, has_media={self.has_media})" - ) - try: - await peer_session.ws.send_json( - { - "type": "addPeer", - "data": { - "peer_id": self.id, - "peer_name": self.name, - "has_media": self.has_media, - "should_create_offer": False, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send addPeer to {peer_session.getName()}: {e}" - ) - - # Add each other peer to the caller - logger.info( - f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True, has_media={peer_session.has_media})" - ) - try: - await self.ws.send_json( - { - "type": "addPeer", - "data": { - "peer_id": peer_session.id, - "peer_name": peer_session.name, - "has_media": peer_session.has_media, - "should_create_offer": True, - }, - } - ) - except Exception as e: - logger.warning(f"Failed to send addPeer to {self.getName()}: {e}") - else: - logger.info( - f"{self.getName()} - Skipping WebRTC connection with {peer_session.getName()} (neither has media: self={self.has_media}, peer={peer_session.has_media})" - ) - - # Add this user as an RTC peer - await lobby.addSession(self) - Session.save() - - try: - await self.ws.send_json( - {"type": "join_status", "data": {"status": "Joined"}} - ) - except Exception as e: - logger.warning(f"Failed to send join confirmation to {self.getName()}: {e}") - - async def part(self, lobby: Lobby): - with self.session_lock: - if lobby.id not in self.lobby_peers or self.id not in lobby.sessions: - logger.info( - f"{self.getName()} - Attempt to part non-joined lobby {lobby.getName()}." 
- ) - if self.ws: - try: - await self.ws.send_json( - { - "type": "error", - "data": { - "error": "Attempt to part non-joined lobby", - }, - } - ) - except Exception: - pass - return - - logger.info(f"{self.getName()} <- part({lobby.getName()}) - Lobby part.") - - lobby_peers = self.lobby_peers[lobby.id][:] # Copy the list - del self.lobby_peers[lobby.id] - if lobby in self.lobbies: - self.lobbies.remove(lobby) - - # Remove this peer from all other RTC peers, and remove each peer from this peer - for peer_session_id in lobby_peers: - peer_session = getSession(peer_session_id) - if not peer_session: - logger.warning( - f"{self.getName()} <- part({lobby.getName()}) - Peer session {peer_session_id} not found. Skipping." - ) - continue - - if peer_session.ws: - logger.info( - f"{peer_session.getName()} <- remove_peer({self.getName()})" - ) - try: - await peer_session.ws.send_json( - { - "type": "removePeer", - "data": {"peer_name": self.name, "peer_id": self.id}, - } - ) - except Exception as e: - logger.warning( - f"Failed to send removePeer to {peer_session.getName()}: {e}" - ) - else: - logger.warning( - f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection for {peer_session.getName()}. Skipping." - ) - - # Remove from peer's lobby_peers - with peer_session.session_lock: - if ( - lobby.id in peer_session.lobby_peers - and self.id in peer_session.lobby_peers[lobby.id] - ): - peer_session.lobby_peers[lobby.id].remove(self.id) - - if self.ws: - logger.info( - f"{self.getName()} <- remove_peer({peer_session.getName()})" - ) - try: - await self.ws.send_json( - { - "type": "removePeer", - "data": { - "peer_name": peer_session.name, - "peer_id": peer_session.id, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send removePeer to {self.getName()}: {e}" - ) - else: - logger.error( - f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection." - ) - - await lobby.removeSession(self) - Session.save() - - -def getName(session: Session | None) -> str | None: - if session and session.name: - return session.name - return None - - -def getSession(session_id: str) -> Session | None: - return Session.getSession(session_id) - - -def getLobby(lobby_id: str) -> Lobby: - lobby = lobbies.get(lobby_id, None) - if not lobby: - # Check if this might be a stale reference after cleanup - logger.warning(f"Lobby not found: {lobby_id} (may have been cleaned up)") - raise Exception(f"Lobby not found: {lobby_id}") - return lobby - - -def getLobbyByName(lobby_name: str) -> Lobby | None: - for lobby in lobbies.values(): - if lobby.name == lobby_name: - return lobby - return None - - -# API endpoints -@app.get(f"{public_url}api/health", response_model=HealthResponse) -def health(): - logger.info("Health check endpoint called.") - return HealthResponse(status="ok") - - -# A session (cookie) is bound to a single user (name). -# A user can be in multiple lobbies, but a session is unique to a single user. -# A user can change their name, but the session ID remains the same and the name -# updates for all lobbies. 
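# A minimal client-side sketch of the handshake described in the comment block
# above: one cookie-backed session per user, resumed across requests and lobbies.
# It assumes a server reachable at http://localhost:8000 with PUBLIC_URL="/" (the
# defaults used elsewhere in this file); illustrative only, not taken from the
# server or its tests.

import httpx

with httpx.Client(base_url="http://localhost:8000") as client:
    # First call: the server mints a 32-character hex session_id and returns it
    # both as a cookie and in the response body.
    first = client.get("/api/session")
    session_id = first.json()["id"]

    # The client's cookie jar replays session_id, so the same Session object
    # (name, lobby memberships) is resumed instead of a new one being created.
    resumed = client.get("/api/session")
    assert resumed.json()["id"] == session_id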
-@app.get(f"{public_url}api/session", response_model=SessionResponse) -async def session( - request: Request, response: Response, session_id: str | None = Cookie(default=None) -) -> Response | SessionResponse: - if session_id is None: - session_id = secrets.token_hex(16) - response.set_cookie(key="session_id", value=session_id) - # Validate that session_id is a hex string of length 32 - elif len(session_id) != 32 or not all(c in "0123456789abcdef" for c in session_id): - return Response( - content=json.dumps({"error": "Invalid session_id"}), - status_code=400, - media_type="application/json", - ) - - print(f"[{session_id[:8]}]: Browser hand-shake achieved.") - - session = getSession(session_id) - if not session: - session = Session(session_id) - logger.info(f"{session.getName()}: New session created.") - else: - session.update_last_used() # Update activity on session resumption - logger.info(f"{session.getName()}: Existing session resumed.") - # Part all lobbies for this session that have no active websocket - with session.session_lock: - lobbies_to_part = session.lobbies[:] - for lobby in lobbies_to_part: - try: - await session.part(lobby) - except Exception as e: - logger.error( - f"{session.getName()} - Error parting lobby {lobby.getName()}: {e}" - ) - - with session.session_lock: - return SessionResponse( - id=session_id, - name=session.name if session.name else "", - lobbies=[ - LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private) - for lobby in session.lobbies - ], - ) - - -@app.get(public_url + "api/lobby", response_model=LobbiesResponse) -async def get_lobbies(request: Request, response: Response) -> LobbiesResponse: - return LobbiesResponse( - lobbies=[ - LobbyListItem(id=lobby.id, name=lobby.name) - for lobby in lobbies.values() - if not lobby.private - ] - ) - - -@app.post(public_url + "api/lobby/{session_id}", response_model=LobbyCreateResponse) -async def lobby_create( - request: Request, - response: Response, - session_id: str = Path(...), - create_request: LobbyCreateRequest = Body(...), -) -> Response | LobbyCreateResponse: - if create_request.type != "lobby_create": - return Response( - content=json.dumps({"error": "Invalid request type"}), - status_code=400, - media_type="application/json", - ) - - data = create_request.data - session = getSession(session_id) - if not session: - return Response( - content=json.dumps({"error": f"Session not found ({session_id})"}), - status_code=404, - media_type="application/json", - ) - logger.info( - f"{session.getName()} lobby_create: {data.name} (private={data.private})" - ) - - lobby = getLobbyByName(data.name) - if not lobby: - lobby = Lobby( - data.name, - private=data.private, - ) - lobbies[lobby.id] = lobby - logger.info(f"{session.getName()} <- lobby_create({lobby.short}:{lobby.name})") - - return LobbyCreateResponse( - type="lobby_created", - data=LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private), - ) - - -@app.get(public_url + "api/lobby/{lobby_id}/chat", response_model=ChatMessagesResponse) -async def get_chat_messages( - request: Request, - lobby_id: str = Path(...), - limit: int = 50, -) -> Response | ChatMessagesResponse: - """Get chat messages for a lobby""" - try: - lobby = getLobby(lobby_id) - except Exception as e: - return Response( - content=json.dumps({"error": str(e)}), - status_code=404, - media_type="application/json", - ) - - messages = lobby.get_chat_messages(limit) - - return ChatMessagesResponse(messages=messages) - - -# 
============================================================================= -# Bot Provider API Endpoints -# ============================================================================= - - -@app.post( - public_url + "api/bots/providers/register", - response_model=BotProviderRegisterResponse, -) -async def register_bot_provider( - request: BotProviderRegisterRequest, -) -> BotProviderRegisterResponse: - """Register a new bot provider with authentication""" - import uuid - - # Check if provider authentication is enabled - allowed_providers = BotProviderConfig.get_allowed_providers() - if allowed_providers: - # Authentication is enabled - validate provider key - if request.provider_key not in allowed_providers: - logger.warning( - f"Rejected bot provider registration with invalid key: {request.provider_key}" - ) - raise HTTPException( - status_code=403, - detail="Invalid provider key. Bot provider is not authorized to register.", - ) - - # Check if there's already an active provider with this key and remove it - providers_to_remove: list[str] = [] - for existing_provider_id, existing_provider in bot_providers.items(): - if existing_provider.provider_key == request.provider_key: - providers_to_remove.append(existing_provider_id) - logger.info( - f"Removing stale bot provider: {existing_provider.name} (ID: {existing_provider_id})" - ) - - # Remove stale providers - for provider_id_to_remove in providers_to_remove: - del bot_providers[provider_id_to_remove] - - provider_id = str(uuid.uuid4()) - now = time.time() - - provider = BotProviderModel( - provider_id=provider_id, - base_url=request.base_url.rstrip("/"), - name=request.name, - description=request.description, - provider_key=request.provider_key, - registered_at=now, - last_seen=now, - ) - - bot_providers[provider_id] = provider - logger.info( - f"Registered bot provider: {request.name} at {request.base_url} with key: {request.provider_key}" - ) - - return BotProviderRegisterResponse(provider_id=provider_id) - - -@app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse) -async def list_bot_providers() -> BotProviderListResponse: - """List all registered bot providers""" - return BotProviderListResponse(providers=list(bot_providers.values())) - - -@app.get(public_url + "api/bots", response_model=BotListResponse) -async def list_available_bots() -> BotListResponse: - """List all available bots from all registered providers""" - bots: List[BotInfoModel] = [] - providers: dict[str, str] = {} - - # Update last_seen timestamps and fetch bots from each provider - for provider_id, provider in bot_providers.items(): - try: - provider.last_seen = time.time() - - # Make HTTP request to provider's /bots endpoint - async with httpx.AsyncClient() as client: - response = await client.get(f"{provider.base_url}/bots", timeout=5.0) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Add each bot to the consolidated list - for bot_info in bots_response.bots: - bots.append(bot_info) - providers[bot_info.name] = provider_id - else: - logger.warning( - f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}" - ) - except Exception as e: - logger.error(f"Error fetching bots from provider {provider.name}: {e}") - continue - - return BotListResponse(bots=bots, providers=providers) - - -@app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse) -async def 
request_bot_join_lobby( - bot_name: str, request: BotJoinLobbyRequest -) -> BotJoinLobbyResponse: - """Request a bot to join a specific lobby""" - - # Find which provider has this bot and determine its media capability - target_provider_id = request.provider_id - bot_has_media = False - if not target_provider_id: - # Auto-discover provider for this bot - for provider_id, provider in bot_providers.items(): - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{provider.base_url}/bots", timeout=5.0 - ) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Look for the bot by name - for bot_info in bots_response.bots: - if bot_info.name == bot_name: - target_provider_id = provider_id - bot_has_media = bot_info.has_media - break - if target_provider_id: - break - except Exception: - continue - else: - # Query the specified provider for bot media capability - if target_provider_id in bot_providers: - provider = bot_providers[target_provider_id] - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{provider.base_url}/bots", timeout=5.0 - ) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Look for the bot by name - for bot_info in bots_response.bots: - if bot_info.name == bot_name: - bot_has_media = bot_info.has_media - break - except Exception: - # Default to no media if we can't query - pass - - if not target_provider_id or target_provider_id not in bot_providers: - raise HTTPException(status_code=404, detail="Bot or provider not found") - - provider = bot_providers[target_provider_id] - - # Get the lobby to validate it exists - try: - getLobby(request.lobby_id) # Just validate it exists - except Exception: - raise HTTPException(status_code=404, detail="Lobby not found") - - # Create a session for the bot - bot_session_id = secrets.token_hex(16) - - # Create the Session object for the bot - bot_session = Session(bot_session_id, is_bot=True, has_media=bot_has_media) - logger.info( - f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})" - ) - - # Determine server URL for the bot to connect back to - # Use the server's public URL or construct from request - server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000") - if server_base_url.endswith("/"): - server_base_url = server_base_url[:-1] - - bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}" - - # Prepare the join request for the bot provider - bot_join_payload = BotJoinPayload( - lobby_id=request.lobby_id, - session_id=bot_session_id, - nick=bot_nick, - server_url=f"{server_base_url}{public_url}".rstrip("/"), - insecure=True, # Accept self-signed certificates in development - ) - - try: - # Make request to bot provider - async with httpx.AsyncClient() as client: - response = await client.post( - f"{provider.base_url}/bots/{bot_name}/join", - json=bot_join_payload.model_dump(), - timeout=10.0, - ) - - if response.status_code == 200: - # Use Pydantic model to parse and validate response - try: - join_response = BotProviderJoinResponse.model_validate( - response.json() - ) - run_id = join_response.run_id - - # Update bot session with run and provider information - with bot_session.session_lock: - bot_session.bot_run_id = run_id - bot_session.bot_provider_id = target_provider_id - 
bot_session.setName(bot_nick) - - logger.info( - f"Bot {bot_name} requested to join lobby {request.lobby_id}" - ) - - return BotJoinLobbyResponse( - status="requested", - bot_name=bot_name, - run_id=run_id, - provider_id=target_provider_id, - ) - except ValidationError as e: - logger.error(f"Invalid response from bot provider: {e}") - raise HTTPException( - status_code=502, - detail=f"Bot provider returned invalid response: {str(e)}", - ) - else: - logger.error( - f"Bot provider returned error: HTTP {response.status_code}: {response.text}" - ) - raise HTTPException( - status_code=502, - detail=f"Bot provider error: {response.status_code}", - ) - - except httpx.TimeoutException: - raise HTTPException(status_code=504, detail="Bot provider timeout") - except Exception as e: - logger.error(f"Error requesting bot join: {e}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post(public_url + "api/bots/leave", response_model=BotLeaveLobbyResponse) -async def request_bot_leave_lobby( - request: BotLeaveLobbyRequest, -) -> BotLeaveLobbyResponse: - """Request a bot to leave from all lobbies and disconnect""" - - # Find the bot session - bot_session = getSession(request.session_id) - if not bot_session: - raise HTTPException(status_code=404, detail="Bot session not found") - - if not bot_session.is_bot: - raise HTTPException(status_code=400, detail="Session is not a bot") - - run_id = bot_session.bot_run_id - provider_id = bot_session.bot_provider_id - - logger.info(f"Requesting bot {bot_session.getName()} to leave all lobbies") - - # Try to stop the bot at the provider level if we have the information - if provider_id and run_id and provider_id in bot_providers: - provider = bot_providers[provider_id] - try: - async with httpx.AsyncClient() as client: - response = await client.post( - f"{provider.base_url}/bots/runs/{run_id}/stop", - timeout=5.0, - ) - if response.status_code == 200: - logger.info( - f"Successfully requested bot provider to stop run {run_id}" - ) - else: - logger.warning( - f"Bot provider returned error when stopping: HTTP {response.status_code}" - ) - except Exception as e: - logger.warning(f"Failed to request bot stop from provider: {e}") - - # Force disconnect the bot session from all lobbies - with bot_session.session_lock: - lobbies_to_part = bot_session.lobbies[:] - - for lobby in lobbies_to_part: - try: - await bot_session.part(lobby) - except Exception as e: - logger.warning(f"Error parting bot from lobby {lobby.getName()}: {e}") - - # Close WebSocket connection if it exists - if bot_session.ws: - try: - await bot_session.ws.close() - except Exception as e: - logger.warning(f"Error closing bot WebSocket: {e}") - bot_session.ws = None - - return BotLeaveLobbyResponse( - status="disconnected", - session_id=request.session_id, - run_id=run_id, - ) - - -# Register websocket endpoint directly on app with full public_url path -@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}") -async def lobby_join( - websocket: WebSocket, - lobby_id: str | None = Path(...), - session_id: str | None = Path(...), -): - await websocket.accept() - if lobby_id is None: - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid or missing lobby"}} - ) - await websocket.close() - return - if session_id is None: - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid or missing session"}} - ) - await websocket.close() - return - session = getSession(session_id) - if not session: - # logger.error(f"Invalid 
session ID {session_id}") - await websocket.send_json( - {"type": "error", "data": {"error": f"Invalid session ID {session_id}"}} - ) - await websocket.close() - return - - lobby = None - try: - lobby = getLobby(lobby_id) - except Exception as e: - await websocket.send_json({"type": "error", "data": {"error": str(e)}}) - await websocket.close() - return - - logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})") - - session.ws = websocket - session.update_last_used() # Update activity timestamp - - # Check if session is already in lobby and clean up if needed - with lobby.lock: - if session.id in lobby.sessions: - logger.info( - f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining." - ) - try: - await session.part(lobby) - await lobby.removeSession(session) - except Exception as e: - logger.warning(f"Error cleaning up stale session: {e}") - - # Notify existing peers about new user - failed_peers: list[str] = [] - with lobby.lock: - peer_sessions = list(lobby.sessions.values()) - - for peer_session in peer_sessions: - if not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Marking for removal." - ) - failed_peers.append(peer_session.id) - continue - - logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})") - try: - await peer_session.ws.send_json( - { - "type": "user_joined", - "data": { - "session_id": session.id, - "name": session.name, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to notify {peer_session.getName()} of user join: {e}" - ) - failed_peers.append(peer_session.id) - - # Clean up failed peers - with lobby.lock: - for failed_peer_id in failed_peers: - if failed_peer_id in lobby.sessions: - del lobby.sessions[failed_peer_id] - - try: - while True: - packet = await websocket.receive_json() - session.update_last_used() # Update activity on each message - type = packet.get("type", None) - data: dict[str, Any] | None = packet.get("data", None) - if not type: - logger.error(f"{session.getName()} - Invalid request: {packet}") - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid request"}} - ) - continue - # logger.info(f"{session.getName()} <- RAW Rx: {data}") - match type: - case "set_name": - if not data: - logger.error(f"{session.getName()} - set_name missing data") - await websocket.send_json( - { - "type": "error", - "data": {"error": "set_name missing data"}, - } - ) - continue - name = data.get("name") - password = data.get("password") - logger.info(f"{session.getName()} <- set_name({name}, {password})") - if not name: - logger.error(f"{session.getName()} - Name required") - await websocket.send_json( - {"type": "error", "data": {"error": "Name required"}} - ) - continue - # Name takeover / password logic - lname = name.lower() - - # If name is unused, allow and optionally save password - if Session.isUniqueName(name): - # If a password was provided, save it (hash+salt) for this name - if password: - salt, hash_hex = _hash_password(password) - name_passwords[lname] = {"salt": salt, "hash": hash_hex} - session.setName(name) - logger.info(f"{session.getName()}: -> update('name', {name})") - await websocket.send_json( - { - "type": "update_name", - "data": { - "name": name, - "protected": True - if name.lower() in name_passwords - else False, - }, - } - ) - # For any clients in any lobby with this session, update their user lists - await lobby.update_state() - continue - - # Name is taken. 
Check if a password exists for the name and matches. - saved_pw = name_passwords.get(lname) - if not saved_pw and not password: - logger.warning( - f"{session.getName()} - Name already taken (no password set)" - ) - await websocket.send_json( - {"type": "error", "data": {"error": "Name already taken"}} - ) - continue - - if saved_pw and password: - # Expect structured record with salt+hash only - match_password = False - # saved_pw should be a dict[str,str] with 'salt' and 'hash' - salt = saved_pw.get("salt") - _, candidate_hash = _hash_password( - password if password else "", salt_hex=salt - ) - if candidate_hash == saved_pw.get("hash"): - match_password = True - else: - # No structured password record available - match_password = False - else: - match_password = True # No password set, but name taken and new password - allow takeover - - if not match_password: - logger.warning( - f"{session.getName()} - Name takeover attempted with wrong or missing password" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "Invalid password for name takeover", - }, - } - ) - continue - - # Password matches: perform takeover. Find the current session holding the name. - # Find the currently existing session (if any) with that name - displaced = Session.getSessionByName(name) - if displaced and displaced.id == session.id: - displaced = None - - # If found, change displaced session to a unique fallback name and notify peers - if displaced: - # Create a unique fallback name - fallback = f"{displaced.name}-{displaced.short}" - # Ensure uniqueness - if not Session.isUniqueName(fallback): - # append random suffix until unique - while not Session.isUniqueName(fallback): - fallback = f"{displaced.name}-{secrets.token_hex(3)}" - - displaced.setName(fallback) - displaced.mark_displaced() - logger.info( - f"{displaced.getName()} <- displaced by takeover, new name {fallback}" - ) - # Notify displaced session (if connected) - if displaced.ws: - try: - await displaced.ws.send_json( - { - "type": "update_name", - "data": { - "name": fallback, - "protected": False, - }, - } - ) - except Exception: - logger.exception( - "Failed to notify displaced session websocket" - ) - # Update all lobbies the displaced session was in - with displaced.session_lock: - displaced_lobbies = displaced.lobbies[:] - for d_lobby in displaced_lobbies: - try: - await d_lobby.update_state() - except Exception: - logger.exception( - "Failed to update lobby state for displaced session" - ) - - # Now assign the requested name to the current session - session.setName(name) - logger.info( - f"{session.getName()}: -> update('name', {name}) (takeover)" - ) - await websocket.send_json( - { - "type": "update_name", - "data": { - "name": name, - "protected": True - if name.lower() in name_passwords - else False, - }, - } - ) - # Notify lobbies for this session - await lobby.update_state() - - case "list_users": - await lobby.update_state(session) - - case "get_chat_messages": - # Send recent chat messages to the requesting client - messages = lobby.get_chat_messages(50) - await websocket.send_json( - { - "type": "chat_messages", - "data": { - "messages": [msg.model_dump() for msg in messages] - }, - } - ) - - case "send_chat_message": - if not data or "message" not in data: - logger.error( - f"{session.getName()} - send_chat_message missing message" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "send_chat_message missing message", - }, - } - ) - continue - - if not session.name: - 
logger.error( - f"{session.getName()} - Cannot send chat message without name" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "Must set name before sending chat messages", - }, - } - ) - continue - - message_text = str(data["message"]).strip() - if not message_text: - continue - - # Add the message to the lobby and broadcast it - chat_message = lobby.add_chat_message(session, message_text) - logger.info( - f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)" - ) - await lobby.broadcast_chat_message(chat_message) - - case "join": - logger.info(f"{session.getName()} <- join({lobby.getName()})") - await session.join(lobby=lobby) - - case "part": - logger.info(f"{session.getName()} <- part {lobby.getName()}") - await session.part(lobby=lobby) - - case "relayICECandidate": - logger.info(f"{session.getName()} <- relayICECandidate") - if not data: - logger.error( - f"{session.getName()} - relayICECandidate missing data" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "relayICECandidate missing data"}, - } - ) - continue - - with session.session_lock: - if ( - lobby.id not in session.lobby_peers - or session.id not in lobby.sessions - ): - logger.error( - f"{session.short}:{session.name} <- relayICECandidate - Not an RTC peer ({session.id})" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "Not joined to lobby"}, - } - ) - continue - session_peers = session.lobby_peers[lobby.id] - - peer_id = data.get("peer_id") - if peer_id not in session_peers: - logger.error( - f"{session.getName()} <- relayICECandidate - Not an RTC peer({peer_id}) in {session_peers}" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Target peer {peer_id} not found", - }, - } - ) - continue - - candidate = data.get("candidate") - - message: dict[str, Any] = { - "type": "iceCandidate", - "data": { - "peer_id": session.id, - "peer_name": session.name, - "candidate": candidate, - }, - } - - peer_session = lobby.getSession(peer_id) - if not peer_session or not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." 
- ) - continue - logger.info( - f"{session.getName()} -> iceCandidate({peer_session.getName()})" - ) - try: - await peer_session.ws.send_json(message) - except Exception as e: - logger.warning(f"Failed to relay ICE candidate: {e}") - - case "relaySessionDescription": - logger.info(f"{session.getName()} <- relaySessionDescription") - if not data: - logger.error( - f"{session.getName()} - relaySessionDescription missing data" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "relaySessionDescription missing data", - }, - } - ) - continue - - with session.session_lock: - if ( - lobby.id not in session.lobby_peers - or session.id not in lobby.sessions - ): - logger.error( - f"{session.short}:{session.name} <- relaySessionDescription - Not an RTC peer ({session.id})" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "Not joined to lobby"}, - } - ) - continue - - lobby_peers = session.lobby_peers[lobby.id] - - peer_id = data.get("peer_id") - if peer_id not in lobby_peers: - logger.error( - f"{session.getName()} <- relaySessionDescription - Not an RTC peer({peer_id}) in {lobby_peers}" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Target peer {peer_id} not found", - }, - } - ) - continue - - if not peer_id: - logger.error( - f"{session.getName()} - relaySessionDescription missing peer_id" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "relaySessionDescription missing peer_id", - }, - } - ) - continue - peer_session = lobby.getSession(peer_id) - if not peer_session or not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." - ) - continue - - session_description = data.get("session_description") - message = { - "type": "sessionDescription", - "data": { - "peer_id": session.id, - "peer_name": session.name, - "session_description": session_description, - }, - } - - logger.info( - f"{session.getName()} -> sessionDescription({peer_session.getName()})" - ) - try: - await peer_session.ws.send_json(message) - except Exception as e: - logger.warning(f"Failed to relay session description: {e}") - - case "status_check": - # Simple status check - just respond with success to keep connection alive - logger.debug(f"{session.getName()} <- status_check") - await websocket.send_json( - {"type": "status_ok", "data": {"timestamp": time.time()}} - ) - - case _: - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Unknown request type: {type}", - }, - } - ) - - except WebSocketDisconnect: - logger.info(f"{session.getName()} <- WebSocket disconnected for user.") - # Cleanup: remove session from lobby and sessions dict - session.ws = None - if session.id in lobby.sessions: - try: - await session.part(lobby) - except Exception as e: - logger.warning(f"Error during websocket disconnect cleanup: {e}") - - try: - await lobby.update_state() - except Exception as e: - logger.warning(f"Error updating lobby state after disconnect: {e}") - - # Clean up empty lobbies - with lobby.lock: - if not lobby.sessions: - if lobby.id in lobbies: - del lobbies[lobby.id] - logger.info(f"Cleaned up empty lobby {lobby.getName()}") - except Exception as e: - logger.error( - f"Unexpected error in websocket handler for {session.getName()}: {e}" - ) - try: - await websocket.close() - except Exception as e: - pass - - -# Serve static files or proxy to frontend development server -PRODUCTION = os.getenv("PRODUCTION", 
"false").lower() == "true" -client_build_path = os.path.join(os.path.dirname(__file__), "/client/build") - -if PRODUCTION: - logger.info(f"Serving static files from: {client_build_path} at {public_url}") - app.mount( - public_url, StaticFiles(directory=client_build_path, html=True), name="static" - ) - - -else: - logger.info(f"Proxying static files to http://client:3000 at {public_url}") - - import ssl - - @app.api_route( - f"{public_url}{{path:path}}", - methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"], - ) - async def proxy_static(request: Request, path: str): - # Do not proxy API or websocket paths - if path.startswith("api/") or path.startswith("ws/"): - return Response(status_code=404) - url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}/{path}" - if not path: - url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}" - headers = dict(request.headers) - try: - # Accept self-signed certs in dev - async with httpx.AsyncClient(verify=False) as client: - proxy_req = client.build_request( - request.method, url, headers=headers, content=await request.body() - ) - proxy_resp = await client.send(proxy_req, stream=True) - content = await proxy_resp.aread() - - # Remove problematic headers for browser decoding - filtered_headers = { - k: v - for k, v in proxy_resp.headers.items() - if k.lower() - not in ["content-encoding", "transfer-encoding", "content-length"] - } - return Response( - content=content, - status_code=proxy_resp.status_code, - headers=filtered_headers, - ) - except Exception as e: - logger.error(f"Proxy error for {url}: {e}") - return Response("Proxy error", status_code=502) - - # WebSocket proxy for /ws (for React DevTools, etc.) - import websockets - - @app.websocket("/ws") - async def websocket_proxy(websocket: WebSocket): - logger.info("REACT: WebSocket proxy connection established.") - # Get scheme from websocket.url (should be 'ws' or 'wss') - scheme = websocket.url.scheme if hasattr(websocket, "url") else "ws" - target_url = f"{scheme}://client:3000/ws" - await websocket.accept() - try: - # Accept self-signed certs in dev for WSS - ssl_ctx = ssl.create_default_context() - ssl_ctx.check_hostname = False - ssl_ctx.verify_mode = ssl.CERT_NONE - async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws: - - async def client_to_server(): - while True: - msg = await websocket.receive_text() - await target_ws.send(msg) - - async def server_to_client(): - while True: - msg = await target_ws.recv() - if isinstance(msg, str): - await websocket.send_text(msg) - else: - await websocket.send_bytes(msg) - - try: - await asyncio.gather(client_to_server(), server_to_client()) - except (WebSocketDisconnect, websockets.ConnectionClosed): - logger.info("REACT: WebSocket proxy connection closed.") - except Exception as e: - logger.error(f"REACT: WebSocket proxy error: {e}") - await websocket.close() diff --git a/server/main_clean.py b/server/main_clean.py deleted file mode 100644 index 2a96347..0000000 --- a/server/main_clean.py +++ /dev/null @@ -1,293 +0,0 @@ -""" -Refactored main.py - Step 1 of Server Architecture Improvement - -This is a refactored version of the original main.py that demonstrates the new -modular architecture with separated concerns: - -- SessionManager: Handles session lifecycle and persistence -- LobbyManager: Handles lobby management and chat -- AuthManager: Handles authentication and name protection -- WebSocket message routing: Clean message handling -- Separated API modules: Admin, session, and lobby 
endpoints - -This maintains backward compatibility while providing a foundation for -further improvements. -""" - -from __future__ import annotations -import os -from contextlib import asynccontextmanager - -from fastapi import FastAPI, WebSocket, Path, Request, Response -from fastapi.staticfiles import StaticFiles -import httpx -import ssl -import websockets - -# Import our new modular components -try: - from core.session_manager import SessionManager - from core.lobby_manager import LobbyManager - from core.auth_manager import AuthManager - from websocket.connection import WebSocketConnectionManager - from api.admin import AdminAPI - from api.sessions import SessionAPI - from api.lobbies import LobbyAPI -except ImportError: - # Handle relative imports when running as module - import sys - import os - sys.path.append(os.path.dirname(os.path.abspath(__file__))) - - from core.session_manager import SessionManager - from core.lobby_manager import LobbyManager - from core.auth_manager import AuthManager - from websocket.connection import WebSocketConnectionManager - from api.admin import AdminAPI - from api.sessions import SessionAPI - from api.lobbies import LobbyAPI - -from logger import logger - - -# Configuration -ADMIN_TOKEN = os.getenv("ADMIN_TOKEN") -public_url = os.getenv("PUBLIC_URL", "/") -if not public_url.endswith("/"): - public_url += "/" - -# Global managers - these replace the global variables from original main.py -session_manager: SessionManager = None -lobby_manager: LobbyManager = None -auth_manager: AuthManager = None -websocket_manager: WebSocketConnectionManager = None - -# API instances -admin_api: AdminAPI = None -session_api: SessionAPI = None -lobby_api: LobbyAPI = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown""" - global session_manager, lobby_manager, auth_manager, websocket_manager - global admin_api, session_api, lobby_api - - logger.info("Starting AI Voice Bot server with modular architecture...") - - # Initialize core managers - session_manager = SessionManager() - lobby_manager = LobbyManager(session_manager=session_manager) - auth_manager = AuthManager() - - # Set up cross-manager dependencies - session_manager.set_lobby_manager(lobby_manager) - lobby_manager.set_name_protection_checker(auth_manager.is_name_protected) - - # Initialize WebSocket manager - websocket_manager = WebSocketConnectionManager( - session_manager=session_manager, - lobby_manager=lobby_manager - ) - - # Initialize API routers - admin_api = AdminAPI( - session_manager=session_manager, - lobby_manager=lobby_manager, - auth_manager=auth_manager, - admin_token=ADMIN_TOKEN, - public_url=public_url - ) - - session_api = SessionAPI( - session_manager=session_manager, - public_url=public_url - ) - - lobby_api = LobbyAPI( - session_manager=session_manager, - lobby_manager=lobby_manager, - public_url=public_url - ) - - # Register API routes - app.include_router(admin_api.router) - app.include_router(session_api.router) - app.include_router(lobby_api.router) - - # Start background tasks - await session_manager.start_background_tasks() - - logger.info("AI Voice Bot server started successfully!") - logger.info(f"Server URL: {public_url}") - logger.info(f"Sessions loaded: {session_manager.get_session_count()}") - logger.info(f"Lobbies available: {lobby_manager.get_lobby_count()}") - logger.info(f"Protected names: {auth_manager.get_protection_count()}") - - if ADMIN_TOKEN: - logger.info("Admin endpoints protected with token") - 
else: - logger.warning("Admin endpoints are unprotected") - - yield - - # Shutdown - logger.info("Shutting down AI Voice Bot server...") - if session_manager: - await session_manager.stop_background_tasks() - await session_manager.cleanup_all_sessions() - logger.info("Server shutdown complete") - - -# Create FastAPI app with the new architecture -app = FastAPI( - title="AI Voice Bot Server", - description="Modular AI Voice Bot Server with WebRTC support", - version="2.0.0", - lifespan=lifespan -) - -logger.info(f"Starting server with public URL: {public_url}") - - -@app.websocket(f"{public_url}" + "ws/lobby/{{lobby_id}}/{{session_id}}") -async def lobby_websocket( - websocket: WebSocket, - lobby_id: str = Path(...), - session_id: str = Path(...) -): - """WebSocket endpoint for lobby connections - now uses WebSocketConnectionManager""" - await websocket_manager.handle_connection(websocket, lobby_id, session_id) - - -# WebSocket proxy for React dev server (development mode) -PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true" - -if not PRODUCTION: - @app.websocket("/ws") - async def websocket_proxy(websocket: WebSocket): - """Proxy WebSocket connections to React dev server""" - logger.info("REACT: WebSocket proxy connection established.") - target_url = "wss://client:3000/ws" - await websocket.accept() - try: - # Accept self-signed certs in dev for WSS - ssl_ctx = ssl.create_default_context() - ssl_ctx.check_hostname = False - ssl_ctx.verify_mode = ssl.CERT_NONE - - async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws: - async def client_to_server(): - try: - while True: - data = await websocket.receive_text() - await target_ws.send(data) - except Exception as e: - logger.debug(f"Client to server error: {e}") - - async def server_to_client(): - try: - while True: - data = await target_ws.recv() - await websocket.send_text(data) - except Exception as e: - logger.debug(f"Server to client error: {e}") - - # Run both directions concurrently - import asyncio - await asyncio.gather( - client_to_server(), - server_to_client(), - return_exceptions=True - ) - except Exception as e: - logger.warning(f"WebSocket proxy error: {e}") - finally: - try: - await websocket.close() - except: - pass - - -# Serve static files or proxy to frontend development server -client_build_path = "/client/build" - -if PRODUCTION: - # In production, serve static files from the client build directory - if os.path.exists(client_build_path): - logger.info(f"Serving static files from: {client_build_path} at {public_url}") - app.mount( - public_url, StaticFiles(directory=client_build_path, html=True), name="static" - ) - else: - logger.warning(f"Client build directory not found: {client_build_path}") -else: - # In development, proxy to the React dev server - logger.info(f"Proxying static files to http://client:3000 at {public_url}") - - @app.api_route( - f"{public_url}{{path:path}}", - methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"], - ) - async def proxy_static(request: Request, path: str): - # Do not proxy API or websocket paths - if path.startswith("api/") or path.startswith("ws/"): - return Response(status_code=404) - - url = f"https://client:3000/{public_url.strip('/')}/{path}" - if not path: - url = f"https://client:3000/{public_url.strip('/')}" - - # Prepare headers but remove problematic ones for proxying - headers = dict(request.headers) - # Remove host header to avoid conflicts - headers.pop("host", None) - # Remove accept-encoding to prevent compression issues - 
headers.pop("accept-encoding", None) - - try: - # Use HTTP instead of HTTPS for internal container communication - async with httpx.AsyncClient(verify=False) as client: - proxy_req = client.build_request( - request.method, url, headers=headers, content=await request.body() - ) - proxy_resp = await client.send(proxy_req, stream=False) - - # Get response headers but filter out problematic encoding headers - response_headers = dict(proxy_resp.headers) - # Remove content-encoding and transfer-encoding to prevent conflicts - response_headers.pop("content-encoding", None) - response_headers.pop("transfer-encoding", None) - response_headers.pop("content-length", None) # Let FastAPI calculate this - - return Response( - content=proxy_resp.content, - status_code=proxy_resp.status_code, - headers=response_headers, - media_type=proxy_resp.headers.get("content-type") - ) - except Exception as e: - logger.warning(f"Proxy error for {path}: {e}") - return Response(status_code=404) - - -# Health check for the new architecture -@app.get(f"{public_url}api/system/health") -def system_health(): - return { - "status": "ok", - "architecture": "modular", - "version": "2.0.0", - "managers": { - "session_manager": "active" if session_manager else "inactive", - "lobby_manager": "active" if lobby_manager else "inactive", - "auth_manager": "active" if auth_manager else "inactive", - "websocket_manager": "active" if websocket_manager else "inactive", - } - } - - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/server/main_original.py b/server/main_original.py deleted file mode 100644 index d5e0118..0000000 --- a/server/main_original.py +++ /dev/null @@ -1,2338 +0,0 @@ -from __future__ import annotations -from typing import Any, Optional, List -from fastapi import ( - Body, - Cookie, - FastAPI, - HTTPException, - Path, - WebSocket, - Request, - Response, - WebSocketDisconnect, -) -import secrets -import os -import json -import hashlib -import binascii -import sys -import asyncio -import threading -import time -from contextlib import asynccontextmanager - -from fastapi.staticfiles import StaticFiles -import httpx -from pydantic import ValidationError -from logger import logger - -# Import shared models -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from shared.models import ( - HealthResponse, - LobbiesResponse, - LobbyCreateRequest, - LobbyCreateResponse, - LobbyListItem, - LobbyModel, - NamePasswordRecord, - LobbySaved, - SessionResponse, - SessionSaved, - SessionsPayload, - AdminNamesResponse, - AdminActionResponse, - AdminSetPassword, - AdminClearPassword, - AdminValidationResponse, - AdminMetricsResponse, - AdminMetricsConfig, - JoinStatusModel, - ChatMessageModel, - ChatMessagesResponse, - ParticipantModel, - # Bot provider models - BotProviderModel, - BotProviderRegisterRequest, - BotProviderRegisterResponse, - BotProviderListResponse, - BotListResponse, - BotInfoModel, - BotJoinLobbyRequest, - BotJoinLobbyResponse, - BotJoinPayload, - BotLeaveLobbyRequest, - BotLeaveLobbyResponse, - BotProviderBotsResponse, - BotProviderJoinResponse, -) - - -class SessionConfig: - """Configuration class for session management""" - - ANONYMOUS_SESSION_TIMEOUT = int( - os.getenv("ANONYMOUS_SESSION_TIMEOUT", "60") - ) # 1 minute - DISPLACED_SESSION_TIMEOUT = int( - os.getenv("DISPLACED_SESSION_TIMEOUT", "10800") - ) # 3 hours - CLEANUP_INTERVAL = int(os.getenv("CLEANUP_INTERVAL", "300")) # 5 minutes - MAX_SESSIONS_PER_CLEANUP = int( - 
os.getenv("MAX_SESSIONS_PER_CLEANUP", "100") - ) # Circuit breaker - MAX_CHAT_MESSAGES_PER_LOBBY = int(os.getenv("MAX_CHAT_MESSAGES_PER_LOBBY", "100")) - SESSION_VALIDATION_INTERVAL = int( - os.getenv("SESSION_VALIDATION_INTERVAL", "1800") - ) # 30 minutes - - -class BotProviderConfig: - """Configuration class for bot provider management""" - - # Comma-separated list of allowed provider keys - # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys) - ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "") - - @classmethod - def get_allowed_providers(cls) -> dict[str, str]: - """Parse allowed providers from environment variable - - Returns: - dict mapping provider_key -> provider_name - """ - if not cls.ALLOWED_PROVIDERS.strip(): - return {} - - providers: dict[str, str] = {} - for entry in cls.ALLOWED_PROVIDERS.split(","): - entry = entry.strip() - if not entry: - continue - - if ":" in entry: - key, name = entry.split(":", 1) - providers[key.strip()] = name.strip() - else: - providers[entry] = entry - - return providers - - -# Thread lock for session operations -session_lock = threading.RLock() - -# Mapping of reserved names to password records (lowercased name -> {salt:..., hash:...}) -name_passwords: dict[str, dict[str, str]] = {} - -# Bot provider registry: provider_id -> BotProviderModel -bot_providers: dict[str, BotProviderModel] = {} - -all_label = "[ all ]" -info_label = "[ info ]" -todo_label = "[ todo ]" -unset_label = "[ ---- ]" - - -def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]: - """Return (salt_hex, hash_hex) for the given password. If salt_hex is provided - it is used; otherwise a new salt is generated.""" - if salt_hex: - salt = binascii.unhexlify(salt_hex) - else: - salt = secrets.token_bytes(16) - salt_hex = binascii.hexlify(salt).decode() - dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000) - hash_hex = binascii.hexlify(dk).decode() - return salt_hex, hash_hex - - -public_url = os.getenv("PUBLIC_URL", "/") -if not public_url.endswith("/"): - public_url += "/" - -# Global variables to control background tasks -cleanup_task_running = False -cleanup_task = None -validation_task_running = False -validation_task = None - - -async def periodic_cleanup(): - """Background task to periodically clean up old sessions""" - global cleanup_task_running - cleanup_errors = 0 - max_consecutive_errors = 5 - - while cleanup_task_running: - try: - removed_count = Session.cleanup_old_sessions() - if removed_count > 0: - logger.info(f"Periodic cleanup removed {removed_count} old sessions") - cleanup_errors = 0 # Reset error counter on success - - # Run cleanup at configured interval - await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL) - except Exception as e: - cleanup_errors += 1 - logger.error( - f"Error in session cleanup task (attempt {cleanup_errors}): {e}" - ) - - if cleanup_errors >= max_consecutive_errors: - logger.error( - f"Too many consecutive cleanup errors ({cleanup_errors}), stopping cleanup task" - ) - break - - # Exponential backoff on errors - await asyncio.sleep(min(60 * cleanup_errors, 300)) - - -async def periodic_validation(): - """Background task to periodically validate session integrity""" - global validation_task_running - - while validation_task_running: - try: - issues = Session.validate_session_integrity() - if issues: - logger.warning(f"Session integrity issues found: {len(issues)} issues") - for issue in issues[:10]: # Log first 10 issues - logger.warning(f"Integrity issue: 
{issue}") - - await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL) - except Exception as e: - logger.error(f"Error in session validation task: {e}") - await asyncio.sleep(300) # Wait 5 minutes before retrying on error - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events""" - global cleanup_task_running, cleanup_task, validation_task_running, validation_task - - # Startup - logger.info("Starting background tasks...") - cleanup_task_running = True - validation_task_running = True - cleanup_task = asyncio.create_task(periodic_cleanup()) - validation_task = asyncio.create_task(periodic_validation()) - logger.info("Session cleanup and validation tasks started") - - yield - - # Shutdown - logger.info("Shutting down background tasks...") - cleanup_task_running = False - validation_task_running = False - - # Cancel tasks - for task in [cleanup_task, validation_task]: - if task: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Clean up all sessions gracefully - await Session.cleanup_all_sessions() - logger.info("All background tasks stopped and sessions cleaned up") - - -app = FastAPI(lifespan=lifespan) - -logger.info(f"Starting server with public URL: {public_url}") -logger.info( - f"Session config - Anonymous timeout: {SessionConfig.ANONYMOUS_SESSION_TIMEOUT}s, " - f"Displaced timeout: {SessionConfig.DISPLACED_SESSION_TIMEOUT}s, " - f"Cleanup interval: {SessionConfig.CLEANUP_INTERVAL}s" -) - -# Log bot provider configuration -allowed_providers = BotProviderConfig.get_allowed_providers() -if allowed_providers: - logger.info( - f"Bot provider authentication enabled. Allowed providers: {list(allowed_providers.keys())}" - ) -else: - logger.warning("Bot provider authentication disabled. 
Any provider can register.") - -# Optional admin token to protect admin endpoints -ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None) - - -def _require_admin(request: Request) -> bool: - if not ADMIN_TOKEN: - return True - token = request.headers.get("X-Admin-Token") - return token == ADMIN_TOKEN - - -@app.get(public_url + "api/admin/names", response_model=AdminNamesResponse) -def admin_list_names(request: Request): - if not _require_admin(request): - return Response(status_code=403) - # Convert dict format to Pydantic models - name_passwords_models = { - name: NamePasswordRecord(**record) for name, record in name_passwords.items() - } - return AdminNamesResponse(name_passwords=name_passwords_models) - - -@app.post(public_url + "api/admin/set_password", response_model=AdminActionResponse) -def admin_set_password(request: Request, payload: AdminSetPassword = Body(...)): - if not _require_admin(request): - return Response(status_code=403) - lname = payload.name.lower() - salt, hash_hex = _hash_password(payload.password) - name_passwords[lname] = {"salt": salt, "hash": hash_hex} - Session.save() - return AdminActionResponse(status="ok", name=payload.name) - - -@app.post(public_url + "api/admin/clear_password", response_model=AdminActionResponse) -def admin_clear_password(request: Request, payload: AdminClearPassword = Body(...)): - if not _require_admin(request): - return Response(status_code=403) - lname = payload.name.lower() - if lname in name_passwords: - del name_passwords[lname] - Session.save() - return AdminActionResponse(status="ok", name=payload.name) - return AdminActionResponse(status="not_found", name=payload.name) - - -@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse) -def admin_cleanup_sessions(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - removed_count = Session.cleanup_old_sessions() - return AdminActionResponse( - status="ok", name=f"Removed {removed_count} sessions" - ) - except Exception as e: - logger.error(f"Error during manual session cleanup: {e}") - return AdminActionResponse(status="error", name=f"Error: {str(e)}") - - -@app.get(public_url + "api/admin/session_metrics", response_model=AdminMetricsResponse) -def admin_session_metrics(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - return Session.get_cleanup_metrics() - except Exception as e: - logger.error(f"Error getting session metrics: {e}") - return Response(status_code=500) - - -@app.get( - public_url + "api/admin/validate_sessions", response_model=AdminValidationResponse -) -def admin_validate_sessions(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - issues = Session.validate_session_integrity() - return AdminValidationResponse( - status="ok", issues=issues, issue_count=len(issues) - ) - except Exception as e: - logger.error(f"Error validating sessions: {e}") - return AdminValidationResponse(status="error", error=str(e)) - - -lobbies: dict[str, Lobby] = {} - - -class Lobby: - def __init__(self, name: str, id: str | None = None, private: bool = False): - self.id = secrets.token_hex(16) if id is None else id - self.short = self.id[:8] - self.name = name - self.sessions: dict[str, Session] = {} # All lobby members - self.private = private - self.chat_messages: list[ChatMessageModel] = [] # Store chat messages - self.lock = threading.RLock() # Thread safety for lobby operations - - def getName(self) -> str: - return 
f"{self.short}:{self.name}" - - async def update_state(self, requesting_session: Session | None = None): - with self.lock: - users: list[ParticipantModel] = [ - ParticipantModel( - name=s.name, - live=True if s.ws else False, - session_id=s.id, - protected=True - if s.name and s.name.lower() in name_passwords - else False, - is_bot=s.is_bot, - has_media=s.has_media, - bot_run_id=s.bot_run_id, - bot_provider_id=s.bot_provider_id, - ) - for s in self.sessions.values() - if s.name - ] - - if requesting_session: - logger.info( - f"{requesting_session.getName()} -> lobby_state({self.getName()})" - ) - if requesting_session.ws: - try: - await requesting_session.ws.send_json( - { - "type": "lobby_state", - "data": { - "participants": [user.model_dump() for user in users] - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send lobby state to {requesting_session.getName()}: {e}" - ) - else: - logger.warning( - f"{requesting_session.getName()} - No WebSocket connection." - ) - else: - # Send to all sessions in lobby - failed_sessions: list[Session] = [] - for s in self.sessions.values(): - logger.info(f"{s.getName()} -> lobby_state({self.getName()})") - if s.ws: - try: - await s.ws.send_json( - { - "type": "lobby_state", - "data": { - "participants": [ - user.model_dump() for user in users - ] - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send lobby state to {s.getName()}: {e}" - ) - failed_sessions.append(s) - - # Clean up failed sessions - for failed_session in failed_sessions: - failed_session.ws = None - - def getSession(self, id: str) -> Session | None: - with self.lock: - return self.sessions.get(id, None) - - async def addSession(self, session: Session) -> None: - with self.lock: - if session.id in self.sessions: - logger.warning( - f"{session.getName()} - Already in lobby {self.getName()}." 
- ) - return None - self.sessions[session.id] = session - await self.update_state() - - async def removeSession(self, session: Session) -> None: - with self.lock: - if session.id not in self.sessions: - logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.") - return None - del self.sessions[session.id] - await self.update_state() - - def add_chat_message(self, session: Session, message: str) -> ChatMessageModel: - """Add a chat message to the lobby and return the message data""" - with self.lock: - chat_message = ChatMessageModel( - id=secrets.token_hex(8), - message=message, - sender_name=session.name or session.short, - sender_session_id=session.id, - timestamp=time.time(), - lobby_id=self.id, - ) - self.chat_messages.append(chat_message) - # Keep only the latest messages per lobby - if len(self.chat_messages) > SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY: - self.chat_messages = self.chat_messages[ - -SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY : - ] - return chat_message - - def get_chat_messages(self, limit: int = 50) -> list[ChatMessageModel]: - """Get the most recent chat messages from the lobby""" - with self.lock: - return self.chat_messages[-limit:] if self.chat_messages else [] - - async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None: - """Broadcast a chat message to all connected sessions in the lobby""" - failed_sessions: list[Session] = [] - for peer in self.sessions.values(): - if peer.ws: - try: - logger.info(f"{self.getName()} -> chat_message({peer.getName()})") - await peer.ws.send_json( - {"type": "chat_message", "data": chat_message.model_dump()} - ) - except Exception as e: - logger.warning( - f"Failed to send chat message to {peer.getName()}: {e}" - ) - failed_sessions.append(peer) - - # Clean up failed sessions - for failed_session in failed_sessions: - failed_session.ws = None - - -class Session: - _instances: list[Session] = [] - _save_file = "sessions.json" - _loaded = False - lock = threading.RLock() # Thread safety for class-level operations - - def __init__(self, id: str, is_bot: bool = False, has_media: bool = True): - logger.info( - f"Instantiating new session {id} (bot: {is_bot}, media: {has_media})" - ) - with Session.lock: - self._instances.append(self) - self.id = id - self.short = id[:8] - self.name = "" - self.lobbies: list[Lobby] = [] # List of lobby IDs this session is in - self.lobby_peers: dict[ - str, list[str] - ] = {} # lobby ID -> list of peer session IDs - self.ws: WebSocket | None = None - self.created_at = time.time() - self.last_used = time.time() - self.displaced_at: float | None = None # When name was taken over - self.is_bot = is_bot # Whether this session represents a bot - self.has_media = has_media # Whether this session provides audio/video streams - self.bot_run_id: str | None = None # Bot run ID for tracking - self.bot_provider_id: str | None = None # Bot provider ID - self.session_lock = threading.RLock() # Instance-level lock - self.save() - - @classmethod - def save(cls): - try: - with cls.lock: - sessions_list: list[SessionSaved] = [] - for s in cls._instances: - with s.session_lock: - lobbies_list: list[LobbySaved] = [ - LobbySaved( - id=lobby.id, name=lobby.name, private=lobby.private - ) - for lobby in s.lobbies - ] - sessions_list.append( - SessionSaved( - id=s.id, - name=s.name or "", - lobbies=lobbies_list, - created_at=s.created_at, - last_used=s.last_used, - displaced_at=s.displaced_at, - is_bot=s.is_bot, - has_media=s.has_media, - bot_run_id=s.bot_run_id, - 
bot_provider_id=s.bot_provider_id, - ) - ) - - # Prepare name password store for persistence (salt+hash). Only structured records are supported. - saved_pw: dict[str, NamePasswordRecord] = { - name: NamePasswordRecord(**record) - for name, record in name_passwords.items() - } - - payload_model = SessionsPayload( - sessions=sessions_list, name_passwords=saved_pw - ) - payload = payload_model.model_dump() - - # Atomic write using temp file - temp_file = cls._save_file + ".tmp" - with open(temp_file, "w") as f: - json.dump(payload, f, indent=2) - - # Atomic rename - os.rename(temp_file, cls._save_file) - - logger.info( - f"Saved {len(sessions_list)} sessions and {len(saved_pw)} name passwords to {cls._save_file}" - ) - except Exception as e: - logger.error(f"Failed to save sessions: {e}") - # Clean up temp file if it exists - try: - if os.path.exists(cls._save_file + ".tmp"): - os.remove(cls._save_file + ".tmp") - except Exception as e: - pass - - @classmethod - def load(cls): - if not os.path.exists(cls._save_file): - logger.info(f"No session save file found: {cls._save_file}") - return - - try: - with open(cls._save_file, "r") as f: - raw = json.load(f) - except Exception as e: - logger.error(f"Failed to read session save file: {e}") - return - - try: - payload = SessionsPayload.model_validate(raw) - except ValidationError as e: - logger.exception(f"Failed to validate sessions payload: {e}") - return - - # Populate in-memory structures from payload (no backwards compatibility code) - name_passwords.clear() - for name, rec in payload.name_passwords.items(): - # rec is a NamePasswordRecord - name_passwords[name] = {"salt": rec.salt, "hash": rec.hash} - - current_time = time.time() - sessions_loaded = 0 - sessions_expired = 0 - - with cls.lock: - for s_saved in payload.sessions: - # Check if this session should be expired during loading - created_at = getattr(s_saved, "created_at", time.time()) - last_used = getattr(s_saved, "last_used", time.time()) - displaced_at = getattr(s_saved, "displaced_at", None) - name = s_saved.name or "" - - # Apply same removal criteria as cleanup_old_sessions - should_expire = cls._should_remove_session_static( - name, None, created_at, last_used, displaced_at, current_time - ) - - if should_expire: - sessions_expired += 1 - logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load") - continue # Skip loading this expired session - - session = Session( - s_saved.id, - is_bot=getattr(s_saved, "is_bot", False), - has_media=getattr(s_saved, "has_media", True), - ) - session.name = name - # Load timestamps, with defaults for backward compatibility - session.created_at = created_at - session.last_used = last_used - session.displaced_at = displaced_at - # Load bot information with defaults for backward compatibility - session.is_bot = getattr(s_saved, "is_bot", False) - session.has_media = getattr(s_saved, "has_media", True) - session.bot_run_id = getattr(s_saved, "bot_run_id", None) - session.bot_provider_id = getattr(s_saved, "bot_provider_id", None) - for lobby_saved in s_saved.lobbies: - session.lobbies.append( - Lobby( - name=lobby_saved.name, - id=lobby_saved.id, - private=lobby_saved.private, - ) - ) - logger.info( - f"Loaded session {session.getName()} with {len(session.lobbies)} lobbies" - ) - for lobby in session.lobbies: - lobbies[lobby.id] = Lobby( - name=lobby.name, id=lobby.id, private=lobby.private - ) # Ensure lobby exists - sessions_loaded += 1 - - logger.info( - f"Loaded {sessions_loaded} sessions and {len(name_passwords)} name passwords from 
{cls._save_file}" - ) - if sessions_expired > 0: - logger.info(f"Expired {sessions_expired} old sessions during load") - # Save immediately to persist the cleanup - cls.save() - - @classmethod - def getSession(cls, id: str) -> Session | None: - if not cls._loaded: - cls.load() - logger.info(f"Loaded {len(cls._instances)} sessions from disk...") - cls._loaded = True - - with cls.lock: - for s in cls._instances: - if s.id == id: - return s - return None - - @classmethod - def isUniqueName(cls, name: str) -> bool: - if not name: - return False - with cls.lock: - for s in cls._instances: - with s.session_lock: - if s.name.lower() == name.lower(): - return False - return True - - @classmethod - def getSessionByName(cls, name: str) -> Optional["Session"]: - if not name: - return None - lname = name.lower() - with cls.lock: - for s in cls._instances: - with s.session_lock: - if s.name and s.name.lower() == lname: - return s - return None - - def getName(self) -> str: - with self.session_lock: - return f"{self.short}:{self.name if self.name else unset_label}" - - def setName(self, name: str): - with self.session_lock: - self.name = name - self.update_last_used() - self.save() - - def update_last_used(self): - """Update the last_used timestamp""" - with self.session_lock: - self.last_used = time.time() - - def mark_displaced(self): - """Mark this session as having its name taken over""" - with self.session_lock: - self.displaced_at = time.time() - - @staticmethod - def _should_remove_session_static( - name: str, - ws: WebSocket | None, - created_at: float, - last_used: float, - displaced_at: float | None, - current_time: float, - ) -> bool: - """Static method to determine if a session should be removed""" - # Rule 1: Delete sessions with no active connection and no name that are older than threshold - if ( - not ws - and not name - and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT - ): - return True - - # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used recently - if ( - not ws - and displaced_at is not None - and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT - ): - return True - - return False - - def _should_remove(self, current_time: float) -> bool: - """Check if this session should be removed""" - with self.session_lock: - return self._should_remove_session_static( - self.name, - self.ws, - self.created_at, - self.last_used, - self.displaced_at, - current_time, - ) - - @classmethod - def _remove_session_safely(cls, session: Session, empty_lobbies: set[str]) -> None: - """Safely remove a session and track affected lobbies""" - try: - with session.session_lock: - # Remove from lobbies first - for lobby in session.lobbies[ - : - ]: # Copy list to avoid modification during iteration - try: - with lobby.lock: - if session.id in lobby.sessions: - del lobby.sessions[session.id] - if len(lobby.sessions) == 0: - empty_lobbies.add(lobby.id) - - if lobby.id in session.lobby_peers: - del session.lobby_peers[lobby.id] - except Exception as e: - logger.warning( - f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}" - ) - - # Close WebSocket if open - if session.ws: - try: - asyncio.create_task(session.ws.close()) - except Exception as e: - logger.warning( - f"Error closing WebSocket for {session.getName()}: {e}" - ) - session.ws = None - - # Remove from instances list - with cls.lock: - if session in cls._instances: - cls._instances.remove(session) - - except Exception as e: - logger.error( - f"Error 
during safe session removal for {session.getName()}: {e}" - ) - - @classmethod - def _cleanup_empty_lobbies(cls, empty_lobbies: set[str]) -> int: - """Clean up empty lobbies from global lobbies dict""" - removed_count = 0 - for lobby_id in empty_lobbies: - if lobby_id in lobbies: - lobby_name = lobbies[lobby_id].getName() - del lobbies[lobby_id] - logger.info(f"Removed empty lobby {lobby_name}") - removed_count += 1 - return removed_count - - @classmethod - def cleanup_old_sessions(cls) -> int: - """Clean up old sessions based on the specified criteria with improved safety""" - current_time = time.time() - sessions_removed = 0 - - try: - # Circuit breaker - don't remove too many sessions at once - sessions_to_remove: list[Session] = [] - empty_lobbies: set[str] = set() - - with cls.lock: - # Identify sessions to remove (up to max limit) - for session in cls._instances[:]: - if ( - len(sessions_to_remove) - >= SessionConfig.MAX_SESSIONS_PER_CLEANUP - ): - logger.warning( - f"Hit session cleanup limit ({SessionConfig.MAX_SESSIONS_PER_CLEANUP}), " - f"stopping cleanup. Remaining sessions will be cleaned up in next cycle." - ) - break - - if session._should_remove(current_time): - sessions_to_remove.append(session) - logger.info( - f"Marking session {session.getName()} for removal - " - f"criteria: no_ws={session.ws is None}, no_name={not session.name}, " - f"age={current_time - session.created_at:.0f}s, " - f"displaced={session.displaced_at is not None}, " - f"unused={current_time - session.last_used:.0f}s" - ) - - # Remove the identified sessions - for session in sessions_to_remove: - cls._remove_session_safely(session, empty_lobbies) - sessions_removed += 1 - - # Clean up empty lobbies - empty_lobbies_removed = cls._cleanup_empty_lobbies(empty_lobbies) - - # Save state if we made changes - if sessions_removed > 0: - cls.save() - logger.info( - f"Session cleanup completed: removed {sessions_removed} sessions, " - f"{empty_lobbies_removed} empty lobbies" - ) - - except Exception as e: - logger.error(f"Error during session cleanup: {e}") - # Don't re-raise - cleanup should be resilient - - return sessions_removed - - @classmethod - def get_cleanup_metrics(cls) -> AdminMetricsResponse: - """Return cleanup metrics for monitoring""" - current_time = time.time() - - with cls.lock: - total_sessions = len(cls._instances) - active_sessions = 0 - named_sessions = 0 - displaced_sessions = 0 - old_anonymous = 0 - old_displaced = 0 - - for s in cls._instances: - with s.session_lock: - if s.ws: - active_sessions += 1 - if s.name: - named_sessions += 1 - if s.displaced_at is not None: - displaced_sessions += 1 - if ( - not s.ws - and current_time - s.last_used - > SessionConfig.DISPLACED_SESSION_TIMEOUT - ): - old_displaced += 1 - if ( - not s.ws - and not s.name - and current_time - s.created_at - > SessionConfig.ANONYMOUS_SESSION_TIMEOUT - ): - old_anonymous += 1 - - config = AdminMetricsConfig( - anonymous_timeout=SessionConfig.ANONYMOUS_SESSION_TIMEOUT, - displaced_timeout=SessionConfig.DISPLACED_SESSION_TIMEOUT, - cleanup_interval=SessionConfig.CLEANUP_INTERVAL, - max_cleanup_per_cycle=SessionConfig.MAX_SESSIONS_PER_CLEANUP, - ) - - return AdminMetricsResponse( - total_sessions=total_sessions, - active_sessions=active_sessions, - named_sessions=named_sessions, - displaced_sessions=displaced_sessions, - old_anonymous_sessions=old_anonymous, - old_displaced_sessions=old_displaced, - total_lobbies=len(lobbies), - cleanup_candidates=old_anonymous + old_displaced, - config=config, - ) - - @classmethod - 
def validate_session_integrity(cls) -> list[str]: - """Validate session data integrity""" - issues: list[str] = [] - - try: - with cls.lock: - for session in cls._instances: - with session.session_lock: - # Check for orphaned lobby references - for lobby in session.lobbies: - if lobby.id not in lobbies: - issues.append( - f"Session {session.id[:8]}:{session.name} references missing lobby {lobby.id}" - ) - - # Check for inconsistent peer relationships - for lobby_id, peer_ids in session.lobby_peers.items(): - lobby = lobbies.get(lobby_id) - if lobby: - with lobby.lock: - if session.id not in lobby.sessions: - issues.append( - f"Session {session.id[:8]}:{session.name} has peers in lobby {lobby_id} but not in lobby.sessions" - ) - - # Check if peer sessions actually exist - for peer_id in peer_ids: - if peer_id not in lobby.sessions: - issues.append( - f"Session {session.id[:8]}:{session.name} references non-existent peer {peer_id} in lobby {lobby_id}" - ) - else: - issues.append( - f"Session {session.id[:8]}:{session.name} has peer list for non-existent lobby {lobby_id}" - ) - - # Check lobbies for consistency - for lobby_id, lobby in lobbies.items(): - with lobby.lock: - for session_id in lobby.sessions: - found_session = None - for s in cls._instances: - if s.id == session_id: - found_session = s - break - - if not found_session: - issues.append( - f"Lobby {lobby_id} references non-existent session {session_id}" - ) - else: - with found_session.session_lock: - if lobby not in found_session.lobbies: - issues.append( - f"Lobby {lobby_id} contains session {session_id} but session doesn't reference lobby" - ) - - except Exception as e: - logger.error(f"Error during session validation: {e}") - issues.append(f"Validation error: {str(e)}") - - return issues - - @classmethod - async def cleanup_all_sessions(cls): - """Clean up all sessions during shutdown""" - logger.info("Starting graceful session cleanup...") - - try: - with cls.lock: - sessions_to_cleanup = cls._instances[:] - - for session in sessions_to_cleanup: - try: - with session.session_lock: - # Close WebSocket connections - if session.ws: - try: - await session.ws.close() - except Exception as e: - logger.warning( - f"Error closing WebSocket for {session.getName()}: {e}" - ) - session.ws = None - - # Remove from lobbies - for lobby in session.lobbies[:]: - try: - await session.part(lobby) - except Exception as e: - logger.warning( - f"Error removing {session.getName()} from lobby: {e}" - ) - - except Exception as e: - logger.error(f"Error cleaning up session {session.getName()}: {e}") - - # Clear all data structures - with cls.lock: - cls._instances.clear() - lobbies.clear() - - logger.info( - f"Graceful session cleanup completed for {len(sessions_to_cleanup)} sessions" - ) - - except Exception as e: - logger.error(f"Error during graceful session cleanup: {e}") - - async def join(self, lobby: Lobby): - if not self.ws: - logger.error( - f"{self.getName()} - No WebSocket connection. Lobby not available." 
- ) - return - - with self.session_lock: - if lobby.id in self.lobby_peers or self.id in lobby.sessions: - logger.info(f"{self.getName()} - Already joined to {lobby.getName()}.") - data = JoinStatusModel( - status="Joined", - message=f"Already joined to lobby {lobby.getName()}", - ) - try: - await self.ws.send_json( - {"type": "join_status", "data": data.model_dump()} - ) - except Exception as e: - logger.warning( - f"Failed to send join status to {self.getName()}: {e}" - ) - return - - # Initialize the peer list for this lobby - with self.session_lock: - self.lobbies.append(lobby) - self.lobby_peers[lobby.id] = [] - - with lobby.lock: - peer_sessions = list(lobby.sessions.values()) - - for peer_session in peer_sessions: - if peer_session.id == self.id: - logger.error( - "Should not happen: self in lobby.sessions while not in lobby." - ) - continue - - if not peer_session.ws: - logger.warning( - f"{self.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Removing." - ) - with lobby.lock: - if peer_session.id in lobby.sessions: - del lobby.sessions[peer_session.id] - continue - - # Only create WebRTC peer connections if at least one participant has media - should_create_rtc_connection = self.has_media or peer_session.has_media - - if should_create_rtc_connection: - # Add the peer to session's RTC peer list - with self.session_lock: - self.lobby_peers[lobby.id].append(peer_session.id) - - # Add this user as an RTC peer to each existing peer - with peer_session.session_lock: - if lobby.id not in peer_session.lobby_peers: - peer_session.lobby_peers[lobby.id] = [] - peer_session.lobby_peers[lobby.id].append(self.id) - - logger.info( - f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False, has_media={self.has_media})" - ) - try: - await peer_session.ws.send_json( - { - "type": "addPeer", - "data": { - "peer_id": self.id, - "peer_name": self.name, - "has_media": self.has_media, - "should_create_offer": False, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send addPeer to {peer_session.getName()}: {e}" - ) - - # Add each other peer to the caller - logger.info( - f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True, has_media={peer_session.has_media})" - ) - try: - await self.ws.send_json( - { - "type": "addPeer", - "data": { - "peer_id": peer_session.id, - "peer_name": peer_session.name, - "has_media": peer_session.has_media, - "should_create_offer": True, - }, - } - ) - except Exception as e: - logger.warning(f"Failed to send addPeer to {self.getName()}: {e}") - else: - logger.info( - f"{self.getName()} - Skipping WebRTC connection with {peer_session.getName()} (neither has media: self={self.has_media}, peer={peer_session.has_media})" - ) - - # Add this user as an RTC peer - await lobby.addSession(self) - Session.save() - - try: - await self.ws.send_json( - {"type": "join_status", "data": {"status": "Joined"}} - ) - except Exception as e: - logger.warning(f"Failed to send join confirmation to {self.getName()}: {e}") - - async def part(self, lobby: Lobby): - with self.session_lock: - if lobby.id not in self.lobby_peers or self.id not in lobby.sessions: - logger.info( - f"{self.getName()} - Attempt to part non-joined lobby {lobby.getName()}." 
- ) - if self.ws: - try: - await self.ws.send_json( - { - "type": "error", - "data": { - "error": "Attempt to part non-joined lobby", - }, - } - ) - except Exception: - pass - return - - logger.info(f"{self.getName()} <- part({lobby.getName()}) - Lobby part.") - - lobby_peers = self.lobby_peers[lobby.id][:] # Copy the list - del self.lobby_peers[lobby.id] - if lobby in self.lobbies: - self.lobbies.remove(lobby) - - # Remove this peer from all other RTC peers, and remove each peer from this peer - for peer_session_id in lobby_peers: - peer_session = getSession(peer_session_id) - if not peer_session: - logger.warning( - f"{self.getName()} <- part({lobby.getName()}) - Peer session {peer_session_id} not found. Skipping." - ) - continue - - if peer_session.ws: - logger.info( - f"{peer_session.getName()} <- remove_peer({self.getName()})" - ) - try: - await peer_session.ws.send_json( - { - "type": "removePeer", - "data": {"peer_name": self.name, "peer_id": self.id}, - } - ) - except Exception as e: - logger.warning( - f"Failed to send removePeer to {peer_session.getName()}: {e}" - ) - else: - logger.warning( - f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection for {peer_session.getName()}. Skipping." - ) - - # Remove from peer's lobby_peers - with peer_session.session_lock: - if ( - lobby.id in peer_session.lobby_peers - and self.id in peer_session.lobby_peers[lobby.id] - ): - peer_session.lobby_peers[lobby.id].remove(self.id) - - if self.ws: - logger.info( - f"{self.getName()} <- remove_peer({peer_session.getName()})" - ) - try: - await self.ws.send_json( - { - "type": "removePeer", - "data": { - "peer_name": peer_session.name, - "peer_id": peer_session.id, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send removePeer to {self.getName()}: {e}" - ) - else: - logger.error( - f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection." - ) - - await lobby.removeSession(self) - Session.save() - - -def getName(session: Session | None) -> str | None: - if session and session.name: - return session.name - return None - - -def getSession(session_id: str) -> Session | None: - return Session.getSession(session_id) - - -def getLobby(lobby_id: str) -> Lobby: - lobby = lobbies.get(lobby_id, None) - if not lobby: - # Check if this might be a stale reference after cleanup - logger.warning(f"Lobby not found: {lobby_id} (may have been cleaned up)") - raise Exception(f"Lobby not found: {lobby_id}") - return lobby - - -def getLobbyByName(lobby_name: str) -> Lobby | None: - for lobby in lobbies.values(): - if lobby.name == lobby_name: - return lobby - return None - - -# API endpoints -@app.get(f"{public_url}api/health", response_model=HealthResponse) -def health(): - logger.info("Health check endpoint called.") - return HealthResponse(status="ok") - - -# A session (cookie) is bound to a single user (name). -# A user can be in multiple lobbies, but a session is unique to a single user. -# A user can change their name, but the session ID remains the same and the name -# updates for all lobbies. 
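A minimal client-side sketch of the handshake described in the comment above, shown for reference only: it assumes a local server on port 8000 with the default PUBLIC_URL of "/", the lobby id and the name "alice" are placeholders, and it reuses httpx and websockets, the same libraries this file already depends on.

import asyncio
import json

import httpx
import websockets

BASE = "http://localhost:8000"  # assumption: local dev server, PUBLIC_URL="/"


async def handshake_and_join(lobby_id: str) -> None:
    # Step 1: HTTP handshake; the server issues (or re-validates) the
    # 32-character hex session_id cookie and returns the session state.
    async with httpx.AsyncClient(base_url=BASE) as client:
        resp = await client.get("/api/session")
        resp.raise_for_status()
        body = resp.json()
        session_id = client.cookies.get("session_id") or body["id"]

    # Step 2: open the lobby WebSocket, carrying the same session id in the path.
    ws_url = f"ws://localhost:8000/ws/lobby/{lobby_id}/{session_id}"
    async with websockets.connect(ws_url) as ws:
        # All messages use the {"type": ..., "data": {...}} envelope handled below.
        await ws.send(json.dumps({"type": "set_name", "data": {"name": "alice"}}))
        await ws.send(json.dumps({"type": "join", "data": {}}))
        print(await ws.recv())  # e.g. update_name, lobby_state, or join_status


# asyncio.run(handshake_and_join("<lobby-id>"))

The session id lives in the cookie jar after the first request, so the WebSocket URL can be built from it directly; a name change later on keeps the same session id, as the comment above states.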
-@app.get(f"{public_url}api/session", response_model=SessionResponse) -async def session( - request: Request, response: Response, session_id: str | None = Cookie(default=None) -) -> Response | SessionResponse: - if session_id is None: - session_id = secrets.token_hex(16) - response.set_cookie(key="session_id", value=session_id) - # Validate that session_id is a hex string of length 32 - elif len(session_id) != 32 or not all(c in "0123456789abcdef" for c in session_id): - return Response( - content=json.dumps({"error": "Invalid session_id"}), - status_code=400, - media_type="application/json", - ) - - print(f"[{session_id[:8]}]: Browser hand-shake achieved.") - - session = getSession(session_id) - if not session: - session = Session(session_id) - logger.info(f"{session.getName()}: New session created.") - else: - session.update_last_used() # Update activity on session resumption - logger.info(f"{session.getName()}: Existing session resumed.") - # Part all lobbies for this session that have no active websocket - with session.session_lock: - lobbies_to_part = session.lobbies[:] - for lobby in lobbies_to_part: - try: - await session.part(lobby) - except Exception as e: - logger.error( - f"{session.getName()} - Error parting lobby {lobby.getName()}: {e}" - ) - - with session.session_lock: - return SessionResponse( - id=session_id, - name=session.name if session.name else "", - lobbies=[ - LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private) - for lobby in session.lobbies - ], - ) - - -@app.get(public_url + "api/lobby", response_model=LobbiesResponse) -async def get_lobbies(request: Request, response: Response) -> LobbiesResponse: - return LobbiesResponse( - lobbies=[ - LobbyListItem(id=lobby.id, name=lobby.name) - for lobby in lobbies.values() - if not lobby.private - ] - ) - - -@app.post(public_url + "api/lobby/{session_id}", response_model=LobbyCreateResponse) -async def lobby_create( - request: Request, - response: Response, - session_id: str = Path(...), - create_request: LobbyCreateRequest = Body(...), -) -> Response | LobbyCreateResponse: - if create_request.type != "lobby_create": - return Response( - content=json.dumps({"error": "Invalid request type"}), - status_code=400, - media_type="application/json", - ) - - data = create_request.data - session = getSession(session_id) - if not session: - return Response( - content=json.dumps({"error": f"Session not found ({session_id})"}), - status_code=404, - media_type="application/json", - ) - logger.info( - f"{session.getName()} lobby_create: {data.name} (private={data.private})" - ) - - lobby = getLobbyByName(data.name) - if not lobby: - lobby = Lobby( - data.name, - private=data.private, - ) - lobbies[lobby.id] = lobby - logger.info(f"{session.getName()} <- lobby_create({lobby.short}:{lobby.name})") - - return LobbyCreateResponse( - type="lobby_created", - data=LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private), - ) - - -@app.get(public_url + "api/lobby/{lobby_id}/chat", response_model=ChatMessagesResponse) -async def get_chat_messages( - request: Request, - lobby_id: str = Path(...), - limit: int = 50, -) -> Response | ChatMessagesResponse: - """Get chat messages for a lobby""" - try: - lobby = getLobby(lobby_id) - except Exception as e: - return Response( - content=json.dumps({"error": str(e)}), - status_code=404, - media_type="application/json", - ) - - messages = lobby.get_chat_messages(limit) - - return ChatMessagesResponse(messages=messages) - - -# 
============================================================================= -# Bot Provider API Endpoints -# ============================================================================= - - -@app.post( - public_url + "api/bots/providers/register", - response_model=BotProviderRegisterResponse, -) -async def register_bot_provider( - request: BotProviderRegisterRequest, -) -> BotProviderRegisterResponse: - """Register a new bot provider with authentication""" - import uuid - - # Check if provider authentication is enabled - allowed_providers = BotProviderConfig.get_allowed_providers() - if allowed_providers: - # Authentication is enabled - validate provider key - if request.provider_key not in allowed_providers: - logger.warning( - f"Rejected bot provider registration with invalid key: {request.provider_key}" - ) - raise HTTPException( - status_code=403, - detail="Invalid provider key. Bot provider is not authorized to register.", - ) - - # Check if there's already an active provider with this key and remove it - providers_to_remove: list[str] = [] - for existing_provider_id, existing_provider in bot_providers.items(): - if existing_provider.provider_key == request.provider_key: - providers_to_remove.append(existing_provider_id) - logger.info( - f"Removing stale bot provider: {existing_provider.name} (ID: {existing_provider_id})" - ) - - # Remove stale providers - for provider_id_to_remove in providers_to_remove: - del bot_providers[provider_id_to_remove] - - provider_id = str(uuid.uuid4()) - now = time.time() - - provider = BotProviderModel( - provider_id=provider_id, - base_url=request.base_url.rstrip("/"), - name=request.name, - description=request.description, - provider_key=request.provider_key, - registered_at=now, - last_seen=now, - ) - - bot_providers[provider_id] = provider - logger.info( - f"Registered bot provider: {request.name} at {request.base_url} with key: {request.provider_key}" - ) - - return BotProviderRegisterResponse(provider_id=provider_id) - - -@app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse) -async def list_bot_providers() -> BotProviderListResponse: - """List all registered bot providers""" - return BotProviderListResponse(providers=list(bot_providers.values())) - - -@app.get(public_url + "api/bots", response_model=BotListResponse) -async def list_available_bots() -> BotListResponse: - """List all available bots from all registered providers""" - bots: List[BotInfoModel] = [] - providers: dict[str, str] = {} - - # Update last_seen timestamps and fetch bots from each provider - for provider_id, provider in bot_providers.items(): - try: - provider.last_seen = time.time() - - # Make HTTP request to provider's /bots endpoint - async with httpx.AsyncClient() as client: - response = await client.get(f"{provider.base_url}/bots", timeout=5.0) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Add each bot to the consolidated list - for bot_info in bots_response.bots: - bots.append(bot_info) - providers[bot_info.name] = provider_id - else: - logger.warning( - f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}" - ) - except Exception as e: - logger.error(f"Error fetching bots from provider {provider.name}: {e}") - continue - - return BotListResponse(bots=bots, providers=providers) - - -@app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse) -async def 
request_bot_join_lobby( - bot_name: str, request: BotJoinLobbyRequest -) -> BotJoinLobbyResponse: - """Request a bot to join a specific lobby""" - - # Find which provider has this bot and determine its media capability - target_provider_id = request.provider_id - bot_has_media = False - if not target_provider_id: - # Auto-discover provider for this bot - for provider_id, provider in bot_providers.items(): - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{provider.base_url}/bots", timeout=5.0 - ) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Look for the bot by name - for bot_info in bots_response.bots: - if bot_info.name == bot_name: - target_provider_id = provider_id - bot_has_media = bot_info.has_media - break - if target_provider_id: - break - except Exception: - continue - else: - # Query the specified provider for bot media capability - if target_provider_id in bot_providers: - provider = bot_providers[target_provider_id] - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{provider.base_url}/bots", timeout=5.0 - ) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Look for the bot by name - for bot_info in bots_response.bots: - if bot_info.name == bot_name: - bot_has_media = bot_info.has_media - break - except Exception: - # Default to no media if we can't query - pass - - if not target_provider_id or target_provider_id not in bot_providers: - raise HTTPException(status_code=404, detail="Bot or provider not found") - - provider = bot_providers[target_provider_id] - - # Get the lobby to validate it exists - try: - getLobby(request.lobby_id) # Just validate it exists - except Exception: - raise HTTPException(status_code=404, detail="Lobby not found") - - # Create a session for the bot - bot_session_id = secrets.token_hex(16) - - # Create the Session object for the bot - bot_session = Session(bot_session_id, is_bot=True, has_media=bot_has_media) - logger.info( - f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})" - ) - - # Determine server URL for the bot to connect back to - # Use the server's public URL or construct from request - server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000") - if server_base_url.endswith("/"): - server_base_url = server_base_url[:-1] - - bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}" - - # Prepare the join request for the bot provider - bot_join_payload = BotJoinPayload( - lobby_id=request.lobby_id, - session_id=bot_session_id, - nick=bot_nick, - server_url=f"{server_base_url}{public_url}".rstrip("/"), - insecure=True, # Accept self-signed certificates in development - ) - - try: - # Make request to bot provider - async with httpx.AsyncClient() as client: - response = await client.post( - f"{provider.base_url}/bots/{bot_name}/join", - json=bot_join_payload.model_dump(), - timeout=10.0, - ) - - if response.status_code == 200: - # Use Pydantic model to parse and validate response - try: - join_response = BotProviderJoinResponse.model_validate( - response.json() - ) - run_id = join_response.run_id - - # Update bot session with run and provider information - with bot_session.session_lock: - bot_session.bot_run_id = run_id - bot_session.bot_provider_id = target_provider_id - 
bot_session.setName(bot_nick) - - logger.info( - f"Bot {bot_name} requested to join lobby {request.lobby_id}" - ) - - return BotJoinLobbyResponse( - status="requested", - bot_name=bot_name, - run_id=run_id, - provider_id=target_provider_id, - ) - except ValidationError as e: - logger.error(f"Invalid response from bot provider: {e}") - raise HTTPException( - status_code=502, - detail=f"Bot provider returned invalid response: {str(e)}", - ) - else: - logger.error( - f"Bot provider returned error: HTTP {response.status_code}: {response.text}" - ) - raise HTTPException( - status_code=502, - detail=f"Bot provider error: {response.status_code}", - ) - - except httpx.TimeoutException: - raise HTTPException(status_code=504, detail="Bot provider timeout") - except Exception as e: - logger.error(f"Error requesting bot join: {e}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post(public_url + "api/bots/leave", response_model=BotLeaveLobbyResponse) -async def request_bot_leave_lobby( - request: BotLeaveLobbyRequest, -) -> BotLeaveLobbyResponse: - """Request a bot to leave from all lobbies and disconnect""" - - # Find the bot session - bot_session = getSession(request.session_id) - if not bot_session: - raise HTTPException(status_code=404, detail="Bot session not found") - - if not bot_session.is_bot: - raise HTTPException(status_code=400, detail="Session is not a bot") - - run_id = bot_session.bot_run_id - provider_id = bot_session.bot_provider_id - - logger.info(f"Requesting bot {bot_session.getName()} to leave all lobbies") - - # Try to stop the bot at the provider level if we have the information - if provider_id and run_id and provider_id in bot_providers: - provider = bot_providers[provider_id] - try: - async with httpx.AsyncClient() as client: - response = await client.post( - f"{provider.base_url}/bots/runs/{run_id}/stop", - timeout=5.0, - ) - if response.status_code == 200: - logger.info( - f"Successfully requested bot provider to stop run {run_id}" - ) - else: - logger.warning( - f"Bot provider returned error when stopping: HTTP {response.status_code}" - ) - except Exception as e: - logger.warning(f"Failed to request bot stop from provider: {e}") - - # Force disconnect the bot session from all lobbies - with bot_session.session_lock: - lobbies_to_part = bot_session.lobbies[:] - - for lobby in lobbies_to_part: - try: - await bot_session.part(lobby) - except Exception as e: - logger.warning(f"Error parting bot from lobby {lobby.getName()}: {e}") - - # Close WebSocket connection if it exists - if bot_session.ws: - try: - await bot_session.ws.close() - except Exception as e: - logger.warning(f"Error closing bot WebSocket: {e}") - bot_session.ws = None - - return BotLeaveLobbyResponse( - status="disconnected", - session_id=request.session_id, - run_id=run_id, - ) - - -# Register websocket endpoint directly on app with full public_url path -@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}") -async def lobby_join( - websocket: WebSocket, - lobby_id: str | None = Path(...), - session_id: str | None = Path(...), -): - await websocket.accept() - if lobby_id is None: - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid or missing lobby"}} - ) - await websocket.close() - return - if session_id is None: - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid or missing session"}} - ) - await websocket.close() - return - session = getSession(session_id) - if not session: - # logger.error(f"Invalid 
session ID {session_id}") - await websocket.send_json( - {"type": "error", "data": {"error": f"Invalid session ID {session_id}"}} - ) - await websocket.close() - return - - lobby = None - try: - lobby = getLobby(lobby_id) - except Exception as e: - await websocket.send_json({"type": "error", "data": {"error": str(e)}}) - await websocket.close() - return - - logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})") - - session.ws = websocket - session.update_last_used() # Update activity timestamp - - # Check if session is already in lobby and clean up if needed - with lobby.lock: - if session.id in lobby.sessions: - logger.info( - f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining." - ) - try: - await session.part(lobby) - await lobby.removeSession(session) - except Exception as e: - logger.warning(f"Error cleaning up stale session: {e}") - - # Notify existing peers about new user - failed_peers: list[str] = [] - with lobby.lock: - peer_sessions = list(lobby.sessions.values()) - - for peer_session in peer_sessions: - if not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Marking for removal." - ) - failed_peers.append(peer_session.id) - continue - - logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})") - try: - await peer_session.ws.send_json( - { - "type": "user_joined", - "data": { - "session_id": session.id, - "name": session.name, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to notify {peer_session.getName()} of user join: {e}" - ) - failed_peers.append(peer_session.id) - - # Clean up failed peers - with lobby.lock: - for failed_peer_id in failed_peers: - if failed_peer_id in lobby.sessions: - del lobby.sessions[failed_peer_id] - - try: - while True: - packet = await websocket.receive_json() - session.update_last_used() # Update activity on each message - type = packet.get("type", None) - data: dict[str, Any] | None = packet.get("data", None) - if not type: - logger.error(f"{session.getName()} - Invalid request: {packet}") - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid request"}} - ) - continue - # logger.info(f"{session.getName()} <- RAW Rx: {data}") - match type: - case "set_name": - if not data: - logger.error(f"{session.getName()} - set_name missing data") - await websocket.send_json( - { - "type": "error", - "data": {"error": "set_name missing data"}, - } - ) - continue - name = data.get("name") - password = data.get("password") - logger.info(f"{session.getName()} <- set_name({name}, {password})") - if not name: - logger.error(f"{session.getName()} - Name required") - await websocket.send_json( - {"type": "error", "data": {"error": "Name required"}} - ) - continue - # Name takeover / password logic - lname = name.lower() - - # If name is unused, allow and optionally save password - if Session.isUniqueName(name): - # If a password was provided, save it (hash+salt) for this name - if password: - salt, hash_hex = _hash_password(password) - name_passwords[lname] = {"salt": salt, "hash": hash_hex} - session.setName(name) - logger.info(f"{session.getName()}: -> update('name', {name})") - await websocket.send_json( - { - "type": "update_name", - "data": { - "name": name, - "protected": True - if name.lower() in name_passwords - else False, - }, - } - ) - # For any clients in any lobby with this session, update their user lists - await lobby.update_state() - continue - - # Name is taken. 
Check if a password exists for the name and matches. - saved_pw = name_passwords.get(lname) - if not saved_pw and not password: - logger.warning( - f"{session.getName()} - Name already taken (no password set)" - ) - await websocket.send_json( - {"type": "error", "data": {"error": "Name already taken"}} - ) - continue - - if saved_pw and password: - # Expect structured record with salt+hash only - match_password = False - # saved_pw should be a dict[str,str] with 'salt' and 'hash' - salt = saved_pw.get("salt") - _, candidate_hash = _hash_password( - password if password else "", salt_hex=salt - ) - if candidate_hash == saved_pw.get("hash"): - match_password = True - else: - # No structured password record available - match_password = False - else: - match_password = True # No password set, but name taken and new password - allow takeover - - if not match_password: - logger.warning( - f"{session.getName()} - Name takeover attempted with wrong or missing password" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "Invalid password for name takeover", - }, - } - ) - continue - - # Password matches: perform takeover. Find the current session holding the name. - # Find the currently existing session (if any) with that name - displaced = Session.getSessionByName(name) - if displaced and displaced.id == session.id: - displaced = None - - # If found, change displaced session to a unique fallback name and notify peers - if displaced: - # Create a unique fallback name - fallback = f"{displaced.name}-{displaced.short}" - # Ensure uniqueness - if not Session.isUniqueName(fallback): - # append random suffix until unique - while not Session.isUniqueName(fallback): - fallback = f"{displaced.name}-{secrets.token_hex(3)}" - - displaced.setName(fallback) - displaced.mark_displaced() - logger.info( - f"{displaced.getName()} <- displaced by takeover, new name {fallback}" - ) - # Notify displaced session (if connected) - if displaced.ws: - try: - await displaced.ws.send_json( - { - "type": "update_name", - "data": { - "name": fallback, - "protected": False, - }, - } - ) - except Exception: - logger.exception( - "Failed to notify displaced session websocket" - ) - # Update all lobbies the displaced session was in - with displaced.session_lock: - displaced_lobbies = displaced.lobbies[:] - for d_lobby in displaced_lobbies: - try: - await d_lobby.update_state() - except Exception: - logger.exception( - "Failed to update lobby state for displaced session" - ) - - # Now assign the requested name to the current session - session.setName(name) - logger.info( - f"{session.getName()}: -> update('name', {name}) (takeover)" - ) - await websocket.send_json( - { - "type": "update_name", - "data": { - "name": name, - "protected": True - if name.lower() in name_passwords - else False, - }, - } - ) - # Notify lobbies for this session - await lobby.update_state() - - case "list_users": - await lobby.update_state(session) - - case "get_chat_messages": - # Send recent chat messages to the requesting client - messages = lobby.get_chat_messages(50) - await websocket.send_json( - { - "type": "chat_messages", - "data": { - "messages": [msg.model_dump() for msg in messages] - }, - } - ) - - case "send_chat_message": - if not data or "message" not in data: - logger.error( - f"{session.getName()} - send_chat_message missing message" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "send_chat_message missing message", - }, - } - ) - continue - - if not session.name: - 
logger.error( - f"{session.getName()} - Cannot send chat message without name" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "Must set name before sending chat messages", - }, - } - ) - continue - - message_text = str(data["message"]).strip() - if not message_text: - continue - - # Add the message to the lobby and broadcast it - chat_message = lobby.add_chat_message(session, message_text) - logger.info( - f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)" - ) - await lobby.broadcast_chat_message(chat_message) - - case "join": - logger.info(f"{session.getName()} <- join({lobby.getName()})") - await session.join(lobby=lobby) - - case "part": - logger.info(f"{session.getName()} <- part {lobby.getName()}") - await session.part(lobby=lobby) - - case "relayICECandidate": - logger.info(f"{session.getName()} <- relayICECandidate") - if not data: - logger.error( - f"{session.getName()} - relayICECandidate missing data" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "relayICECandidate missing data"}, - } - ) - continue - - with session.session_lock: - if ( - lobby.id not in session.lobby_peers - or session.id not in lobby.sessions - ): - logger.error( - f"{session.short}:{session.name} <- relayICECandidate - Not an RTC peer ({session.id})" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "Not joined to lobby"}, - } - ) - continue - session_peers = session.lobby_peers[lobby.id] - - peer_id = data.get("peer_id") - if peer_id not in session_peers: - logger.error( - f"{session.getName()} <- relayICECandidate - Not an RTC peer({peer_id}) in {session_peers}" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Target peer {peer_id} not found", - }, - } - ) - continue - - candidate = data.get("candidate") - - message: dict[str, Any] = { - "type": "iceCandidate", - "data": { - "peer_id": session.id, - "peer_name": session.name, - "candidate": candidate, - }, - } - - peer_session = lobby.getSession(peer_id) - if not peer_session or not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." 
- ) - continue - logger.info( - f"{session.getName()} -> iceCandidate({peer_session.getName()})" - ) - try: - await peer_session.ws.send_json(message) - except Exception as e: - logger.warning(f"Failed to relay ICE candidate: {e}") - - case "relaySessionDescription": - logger.info(f"{session.getName()} <- relaySessionDescription") - if not data: - logger.error( - f"{session.getName()} - relaySessionDescription missing data" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "relaySessionDescription missing data", - }, - } - ) - continue - - with session.session_lock: - if ( - lobby.id not in session.lobby_peers - or session.id not in lobby.sessions - ): - logger.error( - f"{session.short}:{session.name} <- relaySessionDescription - Not an RTC peer ({session.id})" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "Not joined to lobby"}, - } - ) - continue - - lobby_peers = session.lobby_peers[lobby.id] - - peer_id = data.get("peer_id") - if peer_id not in lobby_peers: - logger.error( - f"{session.getName()} <- relaySessionDescription - Not an RTC peer({peer_id}) in {lobby_peers}" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Target peer {peer_id} not found", - }, - } - ) - continue - - if not peer_id: - logger.error( - f"{session.getName()} - relaySessionDescription missing peer_id" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "relaySessionDescription missing peer_id", - }, - } - ) - continue - peer_session = lobby.getSession(peer_id) - if not peer_session or not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." - ) - continue - - session_description = data.get("session_description") - message = { - "type": "sessionDescription", - "data": { - "peer_id": session.id, - "peer_name": session.name, - "session_description": session_description, - }, - } - - logger.info( - f"{session.getName()} -> sessionDescription({peer_session.getName()})" - ) - try: - await peer_session.ws.send_json(message) - except Exception as e: - logger.warning(f"Failed to relay session description: {e}") - - case "status_check": - # Simple status check - just respond with success to keep connection alive - logger.debug(f"{session.getName()} <- status_check") - await websocket.send_json( - {"type": "status_ok", "data": {"timestamp": time.time()}} - ) - - case _: - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Unknown request type: {type}", - }, - } - ) - - except WebSocketDisconnect: - logger.info(f"{session.getName()} <- WebSocket disconnected for user.") - # Cleanup: remove session from lobby and sessions dict - session.ws = None - if session.id in lobby.sessions: - try: - await session.part(lobby) - except Exception as e: - logger.warning(f"Error during websocket disconnect cleanup: {e}") - - try: - await lobby.update_state() - except Exception as e: - logger.warning(f"Error updating lobby state after disconnect: {e}") - - # Clean up empty lobbies - with lobby.lock: - if not lobby.sessions: - if lobby.id in lobbies: - del lobbies[lobby.id] - logger.info(f"Cleaned up empty lobby {lobby.getName()}") - except Exception as e: - logger.error( - f"Unexpected error in websocket handler for {session.getName()}: {e}" - ) - try: - await websocket.close() - except Exception as e: - pass - - -# Serve static files or proxy to frontend development server -PRODUCTION = os.getenv("PRODUCTION", 
"false").lower() == "true" -client_build_path = os.path.join(os.path.dirname(__file__), "/client/build") - -if PRODUCTION: - logger.info(f"Serving static files from: {client_build_path} at {public_url}") - app.mount( - public_url, StaticFiles(directory=client_build_path, html=True), name="static" - ) - - -else: - logger.info(f"Proxying static files to http://client:3000 at {public_url}") - - import ssl - - @app.api_route( - f"{public_url}{{path:path}}", - methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"], - ) - async def proxy_static(request: Request, path: str): - # Do not proxy API or websocket paths - if path.startswith("api/") or path.startswith("ws/"): - return Response(status_code=404) - url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}/{path}" - if not path: - url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}" - headers = dict(request.headers) - try: - # Accept self-signed certs in dev - async with httpx.AsyncClient(verify=False) as client: - proxy_req = client.build_request( - request.method, url, headers=headers, content=await request.body() - ) - proxy_resp = await client.send(proxy_req, stream=True) - content = await proxy_resp.aread() - - # Remove problematic headers for browser decoding - filtered_headers = { - k: v - for k, v in proxy_resp.headers.items() - if k.lower() - not in ["content-encoding", "transfer-encoding", "content-length"] - } - return Response( - content=content, - status_code=proxy_resp.status_code, - headers=filtered_headers, - ) - except Exception as e: - logger.error(f"Proxy error for {url}: {e}") - return Response("Proxy error", status_code=502) - - # WebSocket proxy for /ws (for React DevTools, etc.) - import websockets - - @app.websocket("/ws") - async def websocket_proxy(websocket: WebSocket): - logger.info("REACT: WebSocket proxy connection established.") - # Get scheme from websocket.url (should be 'ws' or 'wss') - scheme = websocket.url.scheme if hasattr(websocket, "url") else "ws" - target_url = f"{scheme}://client:3000/ws" - await websocket.accept() - try: - # Accept self-signed certs in dev for WSS - ssl_ctx = ssl.create_default_context() - ssl_ctx.check_hostname = False - ssl_ctx.verify_mode = ssl.CERT_NONE - async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws: - - async def client_to_server(): - while True: - msg = await websocket.receive_text() - await target_ws.send(msg) - - async def server_to_client(): - while True: - msg = await target_ws.recv() - if isinstance(msg, str): - await websocket.send_text(msg) - else: - await websocket.send_bytes(msg) - - try: - await asyncio.gather(client_to_server(), server_to_client()) - except (WebSocketDisconnect, websockets.ConnectionClosed): - logger.info("REACT: WebSocket proxy connection closed.") - except Exception as e: - logger.error(f"REACT: WebSocket proxy error: {e}") - await websocket.close() diff --git a/server/main_refactored.py b/server/main_refactored.py deleted file mode 100644 index 545ffc9..0000000 --- a/server/main_refactored.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Refactored main.py - Step 1 of Server Architecture Improvement - -This is a refactored version of the original main.py that demonstrates the new -modular architecture with separated concerns: - -- SessionManager: Handles session lifecycle and persistence -- LobbyManager: Handles lobby management and chat -- AuthManager: Handles authentication and name protection -- WebSocket message routing: Clean message handling -- Separated API modules: Admin, session, and 
lobby endpoints - -This maintains backward compatibility while providing a foundation for -further improvements. -""" - -from __future__ import annotations -import os -from contextlib import asynccontextmanager - -from fastapi import FastAPI, WebSocket, Path -from fastapi.staticfiles import StaticFiles - -# Import our new modular components -try: - from core.session_manager import SessionManager - from core.lobby_manager import LobbyManager - from core.auth_manager import AuthManager - from websocket.connection import WebSocketConnectionManager - from api.admin import AdminAPI - from api.sessions import SessionAPI - from api.lobbies import LobbyAPI -except ImportError: - # Handle relative imports when running as module - import sys - import os - - sys.path.append(os.path.dirname(os.path.abspath(__file__))) - - from core.session_manager import SessionManager - from core.lobby_manager import LobbyManager - from core.auth_manager import AuthManager - from websocket.connection import WebSocketConnectionManager - from api.admin import AdminAPI - from api.sessions import SessionAPI - from api.lobbies import LobbyAPI - -from logger import logger - - -# Configuration -public_url = os.getenv("PUBLIC_URL", "/") -if not public_url.endswith("/"): - public_url += "/" - -ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None) - -# Global managers - these replace the global variables from original main.py -session_manager: SessionManager = None -lobby_manager: LobbyManager = None -auth_manager: AuthManager = None -websocket_manager: WebSocketConnectionManager = None - -# API routers -admin_api: AdminAPI = None -session_api: SessionAPI = None -lobby_api: LobbyAPI = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events""" - global session_manager, lobby_manager, auth_manager, websocket_manager - global admin_api, session_api, lobby_api - - # Startup - logger.info("Starting AI Voice Bot server with modular architecture...") - - # Initialize managers - session_manager = SessionManager("sessions.json") - lobby_manager = LobbyManager() - auth_manager = AuthManager("sessions.json") - - # Load existing data - session_manager.load() - - # Restore lobbies for existing sessions - # Note: This is a simplified version - full lobby restoration would be more complex - for session in session_manager.get_all_sessions(): - for lobby_info in session.lobbies: - # Create lobby if it doesn't exist - lobby = lobby_manager.create_or_get_lobby( - name=lobby_info.name, private=lobby_info.private - ) - # Add session to lobby (but don't trigger events during startup) - with lobby.lock: - lobby.sessions[session.id] = session - - # Set up dependency injection for name protection - lobby_manager.set_name_protection_checker(auth_manager.is_name_protected) - - # Initialize WebSocket manager - websocket_manager = WebSocketConnectionManager( - session_manager=session_manager, - lobby_manager=lobby_manager, - auth_manager=auth_manager, - ) - - # Initialize API routers - admin_api = AdminAPI( - session_manager=session_manager, - lobby_manager=lobby_manager, - auth_manager=auth_manager, - admin_token=ADMIN_TOKEN, - public_url=public_url, - ) - - session_api = SessionAPI(session_manager=session_manager, public_url=public_url) - - lobby_api = LobbyAPI( - session_manager=session_manager, - lobby_manager=lobby_manager, - public_url=public_url, - ) - - # Register API routes - app.include_router(admin_api.router) - app.include_router(session_api.router) - 
app.include_router(lobby_api.router) - - # Start background tasks - await session_manager.start_background_tasks() - - logger.info("AI Voice Bot server started successfully!") - logger.info(f"Server URL: {public_url}") - logger.info(f"Sessions loaded: {session_manager.get_session_count()}") - logger.info(f"Lobbies available: {lobby_manager.get_lobby_count()}") - logger.info(f"Protected names: {auth_manager.get_protection_count()}") - - if ADMIN_TOKEN: - logger.info("Admin endpoints protected with token") - else: - logger.warning("Admin endpoints are unprotected") - - yield - - # Shutdown - logger.info("Shutting down AI Voice Bot server...") - - # Stop background tasks - if session_manager: - await session_manager.stop_background_tasks() - - logger.info("Server shutdown complete") - - -# Create FastAPI app -app = FastAPI( - title="AI Voice Bot Server (Refactored)", - description="WebRTC voice chat server with modular architecture", - version="2.0.0", - lifespan=lifespan, -) - -logger.info(f"Starting server with public URL: {public_url}") - - -@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}") -async def lobby_websocket( - websocket: WebSocket, - lobby_id: str | None = Path(...), - session_id: str | None = Path(...), -): - """WebSocket endpoint for lobby connections - now uses WebSocketConnectionManager""" - await websocket_manager.handle_connection(websocket, lobby_id, session_id) - - -# Serve static files if available (for client) -try: - app.mount(public_url + "static", StaticFiles(directory="static"), name="static") - logger.info("Static files mounted at /static") -except Exception: - logger.info("No static directory found, skipping static file serving") - - -# Health check for the new architecture -@app.get(f"{public_url}api/system/health") -def system_health(): - """System health check showing manager status""" - return { - "status": "ok", - "architecture": "modular", - "version": "2.0.0", - "managers": { - "session_manager": "active" if session_manager else "inactive", - "lobby_manager": "active" if lobby_manager else "inactive", - "auth_manager": "active" if auth_manager else "inactive", - "websocket_manager": "active" if websocket_manager else "inactive", - }, - "statistics": { - "sessions": session_manager.get_session_count() if session_manager else 0, - "lobbies": lobby_manager.get_lobby_count() if lobby_manager else 0, - "protected_names": auth_manager.get_protection_count() - if auth_manager - else 0, - }, - } - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/server/main_working.py b/server/main_working.py deleted file mode 100644 index b9fe291..0000000 --- a/server/main_working.py +++ /dev/null @@ -1,2338 +0,0 @@ -from __future__ import annotations -from typing import Any, Optional, List -from fastapi import ( - Body, - Cookie, - FastAPI, - HTTPException, - Path, - WebSocket, - Request, - Response, - WebSocketDisconnect, -) -import secrets -import os -import json -import hashlib -import binascii -import sys -import asyncio -import threading -import time -from contextlib import asynccontextmanager - -from fastapi.staticfiles import StaticFiles -import httpx -from pydantic import ValidationError -from logger import logger - -# Import shared models -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from shared.models import ( - HealthResponse, - LobbiesResponse, - LobbyCreateRequest, - LobbyCreateResponse, - LobbyListItem, - LobbyModel, - NamePasswordRecord, - LobbySaved, - 
SessionResponse, - SessionSaved, - SessionsPayload, - AdminNamesResponse, - AdminActionResponse, - AdminSetPassword, - AdminClearPassword, - AdminValidationResponse, - AdminMetricsResponse, - AdminMetricsConfig, - JoinStatusModel, - ChatMessageModel, - ChatMessagesResponse, - ParticipantModel, - # Bot provider models - BotProviderModel, - BotProviderRegisterRequest, - BotProviderRegisterResponse, - BotProviderListResponse, - BotListResponse, - BotInfoModel, - BotJoinLobbyRequest, - BotJoinLobbyResponse, - BotJoinPayload, - BotLeaveLobbyRequest, - BotLeaveLobbyResponse, - BotProviderBotsResponse, - BotProviderJoinResponse, -) - - -class SessionConfig: - """Configuration class for session management""" - - ANONYMOUS_SESSION_TIMEOUT = int( - os.getenv("ANONYMOUS_SESSION_TIMEOUT", "60") - ) # 1 minute - DISPLACED_SESSION_TIMEOUT = int( - os.getenv("DISPLACED_SESSION_TIMEOUT", "10800") - ) # 3 hours - CLEANUP_INTERVAL = int(os.getenv("CLEANUP_INTERVAL", "300")) # 5 minutes - MAX_SESSIONS_PER_CLEANUP = int( - os.getenv("MAX_SESSIONS_PER_CLEANUP", "100") - ) # Circuit breaker - MAX_CHAT_MESSAGES_PER_LOBBY = int(os.getenv("MAX_CHAT_MESSAGES_PER_LOBBY", "100")) - SESSION_VALIDATION_INTERVAL = int( - os.getenv("SESSION_VALIDATION_INTERVAL", "1800") - ) # 30 minutes - - -class BotProviderConfig: - """Configuration class for bot provider management""" - - # Comma-separated list of allowed provider keys - # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys) - ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "") - - @classmethod - def get_allowed_providers(cls) -> dict[str, str]: - """Parse allowed providers from environment variable - - Returns: - dict mapping provider_key -> provider_name - """ - if not cls.ALLOWED_PROVIDERS.strip(): - return {} - - providers: dict[str, str] = {} - for entry in cls.ALLOWED_PROVIDERS.split(","): - entry = entry.strip() - if not entry: - continue - - if ":" in entry: - key, name = entry.split(":", 1) - providers[key.strip()] = name.strip() - else: - providers[entry] = entry - - return providers - - -# Thread lock for session operations -session_lock = threading.RLock() - -# Mapping of reserved names to password records (lowercased name -> {salt:..., hash:...}) -name_passwords: dict[str, dict[str, str]] = {} - -# Bot provider registry: provider_id -> BotProviderModel -bot_providers: dict[str, BotProviderModel] = {} - -all_label = "[ all ]" -info_label = "[ info ]" -todo_label = "[ todo ]" -unset_label = "[ ---- ]" - - -def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]: - """Return (salt_hex, hash_hex) for the given password. 
If salt_hex is provided - it is used; otherwise a new salt is generated.""" - if salt_hex: - salt = binascii.unhexlify(salt_hex) - else: - salt = secrets.token_bytes(16) - salt_hex = binascii.hexlify(salt).decode() - dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000) - hash_hex = binascii.hexlify(dk).decode() - return salt_hex, hash_hex - - -public_url = os.getenv("PUBLIC_URL", "/") -if not public_url.endswith("/"): - public_url += "/" - -# Global variables to control background tasks -cleanup_task_running = False -cleanup_task = None -validation_task_running = False -validation_task = None - - -async def periodic_cleanup(): - """Background task to periodically clean up old sessions""" - global cleanup_task_running - cleanup_errors = 0 - max_consecutive_errors = 5 - - while cleanup_task_running: - try: - removed_count = Session.cleanup_old_sessions() - if removed_count > 0: - logger.info(f"Periodic cleanup removed {removed_count} old sessions") - cleanup_errors = 0 # Reset error counter on success - - # Run cleanup at configured interval - await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL) - except Exception as e: - cleanup_errors += 1 - logger.error( - f"Error in session cleanup task (attempt {cleanup_errors}): {e}" - ) - - if cleanup_errors >= max_consecutive_errors: - logger.error( - f"Too many consecutive cleanup errors ({cleanup_errors}), stopping cleanup task" - ) - break - - # Exponential backoff on errors - await asyncio.sleep(min(60 * cleanup_errors, 300)) - - -async def periodic_validation(): - """Background task to periodically validate session integrity""" - global validation_task_running - - while validation_task_running: - try: - issues = Session.validate_session_integrity() - if issues: - logger.warning(f"Session integrity issues found: {len(issues)} issues") - for issue in issues[:10]: # Log first 10 issues - logger.warning(f"Integrity issue: {issue}") - - await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL) - except Exception as e: - logger.error(f"Error in session validation task: {e}") - await asyncio.sleep(300) # Wait 5 minutes before retrying on error - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events""" - global cleanup_task_running, cleanup_task, validation_task_running, validation_task - - # Startup - logger.info("Starting background tasks...") - cleanup_task_running = True - validation_task_running = True - cleanup_task = asyncio.create_task(periodic_cleanup()) - validation_task = asyncio.create_task(periodic_validation()) - logger.info("Session cleanup and validation tasks started") - - yield - - # Shutdown - logger.info("Shutting down background tasks...") - cleanup_task_running = False - validation_task_running = False - - # Cancel tasks - for task in [cleanup_task, validation_task]: - if task: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Clean up all sessions gracefully - await Session.cleanup_all_sessions() - logger.info("All background tasks stopped and sessions cleaned up") - - -app = FastAPI(lifespan=lifespan) - -logger.info(f"Starting server with public URL: {public_url}") -logger.info( - f"Session config - Anonymous timeout: {SessionConfig.ANONYMOUS_SESSION_TIMEOUT}s, " - f"Displaced timeout: {SessionConfig.DISPLACED_SESSION_TIMEOUT}s, " - f"Cleanup interval: {SessionConfig.CLEANUP_INTERVAL}s" -) - -# Log bot provider configuration -allowed_providers = BotProviderConfig.get_allowed_providers() -if 
allowed_providers: - logger.info( - f"Bot provider authentication enabled. Allowed providers: {list(allowed_providers.keys())}" - ) -else: - logger.warning("Bot provider authentication disabled. Any provider can register.") - -# Optional admin token to protect admin endpoints -ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None) - - -def _require_admin(request: Request) -> bool: - if not ADMIN_TOKEN: - return True - token = request.headers.get("X-Admin-Token") - return token == ADMIN_TOKEN - - -@app.get(public_url + "api/admin/names", response_model=AdminNamesResponse) -def admin_list_names(request: Request): - if not _require_admin(request): - return Response(status_code=403) - # Convert dict format to Pydantic models - name_passwords_models = { - name: NamePasswordRecord(**record) for name, record in name_passwords.items() - } - return AdminNamesResponse(name_passwords=name_passwords_models) - - -@app.post(public_url + "api/admin/set_password", response_model=AdminActionResponse) -def admin_set_password(request: Request, payload: AdminSetPassword = Body(...)): - if not _require_admin(request): - return Response(status_code=403) - lname = payload.name.lower() - salt, hash_hex = _hash_password(payload.password) - name_passwords[lname] = {"salt": salt, "hash": hash_hex} - Session.save() - return AdminActionResponse(status="ok", name=payload.name) - - -@app.post(public_url + "api/admin/clear_password", response_model=AdminActionResponse) -def admin_clear_password(request: Request, payload: AdminClearPassword = Body(...)): - if not _require_admin(request): - return Response(status_code=403) - lname = payload.name.lower() - if lname in name_passwords: - del name_passwords[lname] - Session.save() - return AdminActionResponse(status="ok", name=payload.name) - return AdminActionResponse(status="not_found", name=payload.name) - - -@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse) -def admin_cleanup_sessions(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - removed_count = Session.cleanup_old_sessions() - return AdminActionResponse( - status="ok", name=f"Removed {removed_count} sessions" - ) - except Exception as e: - logger.error(f"Error during manual session cleanup: {e}") - return AdminActionResponse(status="error", name=f"Error: {str(e)}") - - -@app.get(public_url + "api/admin/session_metrics", response_model=AdminMetricsResponse) -def admin_session_metrics(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - return Session.get_cleanup_metrics() - except Exception as e: - logger.error(f"Error getting session metrics: {e}") - return Response(status_code=500) - - -@app.get( - public_url + "api/admin/validate_sessions", response_model=AdminValidationResponse -) -def admin_validate_sessions(request: Request): - if not _require_admin(request): - return Response(status_code=403) - try: - issues = Session.validate_session_integrity() - return AdminValidationResponse( - status="ok", issues=issues, issue_count=len(issues) - ) - except Exception as e: - logger.error(f"Error validating sessions: {e}") - return AdminValidationResponse(status="error", error=str(e)) - - -lobbies: dict[str, Lobby] = {} - - -class Lobby: - def __init__(self, name: str, id: str | None = None, private: bool = False): - self.id = secrets.token_hex(16) if id is None else id - self.short = self.id[:8] - self.name = name - self.sessions: dict[str, Session] = {} # All lobby members - self.private = private - 
self.chat_messages: list[ChatMessageModel] = [] # Store chat messages - self.lock = threading.RLock() # Thread safety for lobby operations - - def getName(self) -> str: - return f"{self.short}:{self.name}" - - async def update_state(self, requesting_session: Session | None = None): - with self.lock: - users: list[ParticipantModel] = [ - ParticipantModel( - name=s.name, - live=True if s.ws else False, - session_id=s.id, - protected=True - if s.name and s.name.lower() in name_passwords - else False, - is_bot=s.is_bot, - has_media=s.has_media, - bot_run_id=s.bot_run_id, - bot_provider_id=s.bot_provider_id, - ) - for s in self.sessions.values() - if s.name - ] - - if requesting_session: - logger.info( - f"{requesting_session.getName()} -> lobby_state({self.getName()})" - ) - if requesting_session.ws: - try: - await requesting_session.ws.send_json( - { - "type": "lobby_state", - "data": { - "participants": [user.model_dump() for user in users] - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send lobby state to {requesting_session.getName()}: {e}" - ) - else: - logger.warning( - f"{requesting_session.getName()} - No WebSocket connection." - ) - else: - # Send to all sessions in lobby - failed_sessions: list[Session] = [] - for s in self.sessions.values(): - logger.info(f"{s.getName()} -> lobby_state({self.getName()})") - if s.ws: - try: - await s.ws.send_json( - { - "type": "lobby_state", - "data": { - "participants": [ - user.model_dump() for user in users - ] - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send lobby state to {s.getName()}: {e}" - ) - failed_sessions.append(s) - - # Clean up failed sessions - for failed_session in failed_sessions: - failed_session.ws = None - - def getSession(self, id: str) -> Session | None: - with self.lock: - return self.sessions.get(id, None) - - async def addSession(self, session: Session) -> None: - with self.lock: - if session.id in self.sessions: - logger.warning( - f"{session.getName()} - Already in lobby {self.getName()}." 
- ) - return None - self.sessions[session.id] = session - await self.update_state() - - async def removeSession(self, session: Session) -> None: - with self.lock: - if session.id not in self.sessions: - logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.") - return None - del self.sessions[session.id] - await self.update_state() - - def add_chat_message(self, session: Session, message: str) -> ChatMessageModel: - """Add a chat message to the lobby and return the message data""" - with self.lock: - chat_message = ChatMessageModel( - id=secrets.token_hex(8), - message=message, - sender_name=session.name or session.short, - sender_session_id=session.id, - timestamp=time.time(), - lobby_id=self.id, - ) - self.chat_messages.append(chat_message) - # Keep only the latest messages per lobby - if len(self.chat_messages) > SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY: - self.chat_messages = self.chat_messages[ - -SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY : - ] - return chat_message - - def get_chat_messages(self, limit: int = 50) -> list[ChatMessageModel]: - """Get the most recent chat messages from the lobby""" - with self.lock: - return self.chat_messages[-limit:] if self.chat_messages else [] - - async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None: - """Broadcast a chat message to all connected sessions in the lobby""" - failed_sessions: list[Session] = [] - for peer in self.sessions.values(): - if peer.ws: - try: - logger.info(f"{self.getName()} -> chat_message({peer.getName()})") - await peer.ws.send_json( - {"type": "chat_message", "data": chat_message.model_dump()} - ) - except Exception as e: - logger.warning( - f"Failed to send chat message to {peer.getName()}: {e}" - ) - failed_sessions.append(peer) - - # Clean up failed sessions - for failed_session in failed_sessions: - failed_session.ws = None - - -class Session: - _instances: list[Session] = [] - _save_file = "sessions.json" - _loaded = False - lock = threading.RLock() # Thread safety for class-level operations - - def __init__(self, id: str, is_bot: bool = False, has_media: bool = True): - logger.info( - f"Instantiating new session {id} (bot: {is_bot}, media: {has_media})" - ) - with Session.lock: - self._instances.append(self) - self.id = id - self.short = id[:8] - self.name = "" - self.lobbies: list[Lobby] = [] # List of lobby IDs this session is in - self.lobby_peers: dict[ - str, list[str] - ] = {} # lobby ID -> list of peer session IDs - self.ws: WebSocket | None = None - self.created_at = time.time() - self.last_used = time.time() - self.displaced_at: float | None = None # When name was taken over - self.is_bot = is_bot # Whether this session represents a bot - self.has_media = has_media # Whether this session provides audio/video streams - self.bot_run_id: str | None = None # Bot run ID for tracking - self.bot_provider_id: str | None = None # Bot provider ID - self.session_lock = threading.RLock() # Instance-level lock - self.save() - - @classmethod - def save(cls): - try: - with cls.lock: - sessions_list: list[SessionSaved] = [] - for s in cls._instances: - with s.session_lock: - lobbies_list: list[LobbySaved] = [ - LobbySaved( - id=lobby.id, name=lobby.name, private=lobby.private - ) - for lobby in s.lobbies - ] - sessions_list.append( - SessionSaved( - id=s.id, - name=s.name or "", - lobbies=lobbies_list, - created_at=s.created_at, - last_used=s.last_used, - displaced_at=s.displaced_at, - is_bot=s.is_bot, - has_media=s.has_media, - bot_run_id=s.bot_run_id, - 
bot_provider_id=s.bot_provider_id, - ) - ) - - # Prepare name password store for persistence (salt+hash). Only structured records are supported. - saved_pw: dict[str, NamePasswordRecord] = { - name: NamePasswordRecord(**record) - for name, record in name_passwords.items() - } - - payload_model = SessionsPayload( - sessions=sessions_list, name_passwords=saved_pw - ) - payload = payload_model.model_dump() - - # Atomic write using temp file - temp_file = cls._save_file + ".tmp" - with open(temp_file, "w") as f: - json.dump(payload, f, indent=2) - - # Atomic rename - os.rename(temp_file, cls._save_file) - - logger.info( - f"Saved {len(sessions_list)} sessions and {len(saved_pw)} name passwords to {cls._save_file}" - ) - except Exception as e: - logger.error(f"Failed to save sessions: {e}") - # Clean up temp file if it exists - try: - if os.path.exists(cls._save_file + ".tmp"): - os.remove(cls._save_file + ".tmp") - except Exception as e: - pass - - @classmethod - def load(cls): - if not os.path.exists(cls._save_file): - logger.info(f"No session save file found: {cls._save_file}") - return - - try: - with open(cls._save_file, "r") as f: - raw = json.load(f) - except Exception as e: - logger.error(f"Failed to read session save file: {e}") - return - - try: - payload = SessionsPayload.model_validate(raw) - except ValidationError as e: - logger.exception(f"Failed to validate sessions payload: {e}") - return - - # Populate in-memory structures from payload (no backwards compatibility code) - name_passwords.clear() - for name, rec in payload.name_passwords.items(): - # rec is a NamePasswordRecord - name_passwords[name] = {"salt": rec.salt, "hash": rec.hash} - - current_time = time.time() - sessions_loaded = 0 - sessions_expired = 0 - - with cls.lock: - for s_saved in payload.sessions: - # Check if this session should be expired during loading - created_at = getattr(s_saved, "created_at", time.time()) - last_used = getattr(s_saved, "last_used", time.time()) - displaced_at = getattr(s_saved, "displaced_at", None) - name = s_saved.name or "" - - # Apply same removal criteria as cleanup_old_sessions - should_expire = cls._should_remove_session_static( - name, None, created_at, last_used, displaced_at, current_time - ) - - if should_expire: - sessions_expired += 1 - logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load") - continue # Skip loading this expired session - - session = Session( - s_saved.id, - is_bot=getattr(s_saved, "is_bot", False), - has_media=getattr(s_saved, "has_media", True), - ) - session.name = name - # Load timestamps, with defaults for backward compatibility - session.created_at = created_at - session.last_used = last_used - session.displaced_at = displaced_at - # Load bot information with defaults for backward compatibility - session.is_bot = getattr(s_saved, "is_bot", False) - session.has_media = getattr(s_saved, "has_media", True) - session.bot_run_id = getattr(s_saved, "bot_run_id", None) - session.bot_provider_id = getattr(s_saved, "bot_provider_id", None) - for lobby_saved in s_saved.lobbies: - session.lobbies.append( - Lobby( - name=lobby_saved.name, - id=lobby_saved.id, - private=lobby_saved.private, - ) - ) - logger.info( - f"Loaded session {session.getName()} with {len(session.lobbies)} lobbies" - ) - for lobby in session.lobbies: - lobbies[lobby.id] = Lobby( - name=lobby.name, id=lobby.id, private=lobby.private - ) # Ensure lobby exists - sessions_loaded += 1 - - logger.info( - f"Loaded {sessions_loaded} sessions and {len(name_passwords)} name passwords from 
{cls._save_file}" - ) - if sessions_expired > 0: - logger.info(f"Expired {sessions_expired} old sessions during load") - # Save immediately to persist the cleanup - cls.save() - - @classmethod - def getSession(cls, id: str) -> Session | None: - if not cls._loaded: - cls.load() - logger.info(f"Loaded {len(cls._instances)} sessions from disk...") - cls._loaded = True - - with cls.lock: - for s in cls._instances: - if s.id == id: - return s - return None - - @classmethod - def isUniqueName(cls, name: str) -> bool: - if not name: - return False - with cls.lock: - for s in cls._instances: - with s.session_lock: - if s.name.lower() == name.lower(): - return False - return True - - @classmethod - def getSessionByName(cls, name: str) -> Optional["Session"]: - if not name: - return None - lname = name.lower() - with cls.lock: - for s in cls._instances: - with s.session_lock: - if s.name and s.name.lower() == lname: - return s - return None - - def getName(self) -> str: - with self.session_lock: - return f"{self.short}:{self.name if self.name else unset_label}" - - def setName(self, name: str): - with self.session_lock: - self.name = name - self.update_last_used() - self.save() - - def update_last_used(self): - """Update the last_used timestamp""" - with self.session_lock: - self.last_used = time.time() - - def mark_displaced(self): - """Mark this session as having its name taken over""" - with self.session_lock: - self.displaced_at = time.time() - - @staticmethod - def _should_remove_session_static( - name: str, - ws: WebSocket | None, - created_at: float, - last_used: float, - displaced_at: float | None, - current_time: float, - ) -> bool: - """Static method to determine if a session should be removed""" - # Rule 1: Delete sessions with no active connection and no name that are older than threshold - if ( - not ws - and not name - and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT - ): - return True - - # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used recently - if ( - not ws - and displaced_at is not None - and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT - ): - return True - - return False - - def _should_remove(self, current_time: float) -> bool: - """Check if this session should be removed""" - with self.session_lock: - return self._should_remove_session_static( - self.name, - self.ws, - self.created_at, - self.last_used, - self.displaced_at, - current_time, - ) - - @classmethod - def _remove_session_safely(cls, session: Session, empty_lobbies: set[str]) -> None: - """Safely remove a session and track affected lobbies""" - try: - with session.session_lock: - # Remove from lobbies first - for lobby in session.lobbies[ - : - ]: # Copy list to avoid modification during iteration - try: - with lobby.lock: - if session.id in lobby.sessions: - del lobby.sessions[session.id] - if len(lobby.sessions) == 0: - empty_lobbies.add(lobby.id) - - if lobby.id in session.lobby_peers: - del session.lobby_peers[lobby.id] - except Exception as e: - logger.warning( - f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}" - ) - - # Close WebSocket if open - if session.ws: - try: - asyncio.create_task(session.ws.close()) - except Exception as e: - logger.warning( - f"Error closing WebSocket for {session.getName()}: {e}" - ) - session.ws = None - - # Remove from instances list - with cls.lock: - if session in cls._instances: - cls._instances.remove(session) - - except Exception as e: - logger.error( - f"Error 
during safe session removal for {session.getName()}: {e}" - ) - - @classmethod - def _cleanup_empty_lobbies(cls, empty_lobbies: set[str]) -> int: - """Clean up empty lobbies from global lobbies dict""" - removed_count = 0 - for lobby_id in empty_lobbies: - if lobby_id in lobbies: - lobby_name = lobbies[lobby_id].getName() - del lobbies[lobby_id] - logger.info(f"Removed empty lobby {lobby_name}") - removed_count += 1 - return removed_count - - @classmethod - def cleanup_old_sessions(cls) -> int: - """Clean up old sessions based on the specified criteria with improved safety""" - current_time = time.time() - sessions_removed = 0 - - try: - # Circuit breaker - don't remove too many sessions at once - sessions_to_remove: list[Session] = [] - empty_lobbies: set[str] = set() - - with cls.lock: - # Identify sessions to remove (up to max limit) - for session in cls._instances[:]: - if ( - len(sessions_to_remove) - >= SessionConfig.MAX_SESSIONS_PER_CLEANUP - ): - logger.warning( - f"Hit session cleanup limit ({SessionConfig.MAX_SESSIONS_PER_CLEANUP}), " - f"stopping cleanup. Remaining sessions will be cleaned up in next cycle." - ) - break - - if session._should_remove(current_time): - sessions_to_remove.append(session) - logger.info( - f"Marking session {session.getName()} for removal - " - f"criteria: no_ws={session.ws is None}, no_name={not session.name}, " - f"age={current_time - session.created_at:.0f}s, " - f"displaced={session.displaced_at is not None}, " - f"unused={current_time - session.last_used:.0f}s" - ) - - # Remove the identified sessions - for session in sessions_to_remove: - cls._remove_session_safely(session, empty_lobbies) - sessions_removed += 1 - - # Clean up empty lobbies - empty_lobbies_removed = cls._cleanup_empty_lobbies(empty_lobbies) - - # Save state if we made changes - if sessions_removed > 0: - cls.save() - logger.info( - f"Session cleanup completed: removed {sessions_removed} sessions, " - f"{empty_lobbies_removed} empty lobbies" - ) - - except Exception as e: - logger.error(f"Error during session cleanup: {e}") - # Don't re-raise - cleanup should be resilient - - return sessions_removed - - @classmethod - def get_cleanup_metrics(cls) -> AdminMetricsResponse: - """Return cleanup metrics for monitoring""" - current_time = time.time() - - with cls.lock: - total_sessions = len(cls._instances) - active_sessions = 0 - named_sessions = 0 - displaced_sessions = 0 - old_anonymous = 0 - old_displaced = 0 - - for s in cls._instances: - with s.session_lock: - if s.ws: - active_sessions += 1 - if s.name: - named_sessions += 1 - if s.displaced_at is not None: - displaced_sessions += 1 - if ( - not s.ws - and current_time - s.last_used - > SessionConfig.DISPLACED_SESSION_TIMEOUT - ): - old_displaced += 1 - if ( - not s.ws - and not s.name - and current_time - s.created_at - > SessionConfig.ANONYMOUS_SESSION_TIMEOUT - ): - old_anonymous += 1 - - config = AdminMetricsConfig( - anonymous_timeout=SessionConfig.ANONYMOUS_SESSION_TIMEOUT, - displaced_timeout=SessionConfig.DISPLACED_SESSION_TIMEOUT, - cleanup_interval=SessionConfig.CLEANUP_INTERVAL, - max_cleanup_per_cycle=SessionConfig.MAX_SESSIONS_PER_CLEANUP, - ) - - return AdminMetricsResponse( - total_sessions=total_sessions, - active_sessions=active_sessions, - named_sessions=named_sessions, - displaced_sessions=displaced_sessions, - old_anonymous_sessions=old_anonymous, - old_displaced_sessions=old_displaced, - total_lobbies=len(lobbies), - cleanup_candidates=old_anonymous + old_displaced, - config=config, - ) - - @classmethod - 
def validate_session_integrity(cls) -> list[str]: - """Validate session data integrity""" - issues: list[str] = [] - - try: - with cls.lock: - for session in cls._instances: - with session.session_lock: - # Check for orphaned lobby references - for lobby in session.lobbies: - if lobby.id not in lobbies: - issues.append( - f"Session {session.id[:8]}:{session.name} references missing lobby {lobby.id}" - ) - - # Check for inconsistent peer relationships - for lobby_id, peer_ids in session.lobby_peers.items(): - lobby = lobbies.get(lobby_id) - if lobby: - with lobby.lock: - if session.id not in lobby.sessions: - issues.append( - f"Session {session.id[:8]}:{session.name} has peers in lobby {lobby_id} but not in lobby.sessions" - ) - - # Check if peer sessions actually exist - for peer_id in peer_ids: - if peer_id not in lobby.sessions: - issues.append( - f"Session {session.id[:8]}:{session.name} references non-existent peer {peer_id} in lobby {lobby_id}" - ) - else: - issues.append( - f"Session {session.id[:8]}:{session.name} has peer list for non-existent lobby {lobby_id}" - ) - - # Check lobbies for consistency - for lobby_id, lobby in lobbies.items(): - with lobby.lock: - for session_id in lobby.sessions: - found_session = None - for s in cls._instances: - if s.id == session_id: - found_session = s - break - - if not found_session: - issues.append( - f"Lobby {lobby_id} references non-existent session {session_id}" - ) - else: - with found_session.session_lock: - if lobby not in found_session.lobbies: - issues.append( - f"Lobby {lobby_id} contains session {session_id} but session doesn't reference lobby" - ) - - except Exception as e: - logger.error(f"Error during session validation: {e}") - issues.append(f"Validation error: {str(e)}") - - return issues - - @classmethod - async def cleanup_all_sessions(cls): - """Clean up all sessions during shutdown""" - logger.info("Starting graceful session cleanup...") - - try: - with cls.lock: - sessions_to_cleanup = cls._instances[:] - - for session in sessions_to_cleanup: - try: - with session.session_lock: - # Close WebSocket connections - if session.ws: - try: - await session.ws.close() - except Exception as e: - logger.warning( - f"Error closing WebSocket for {session.getName()}: {e}" - ) - session.ws = None - - # Remove from lobbies - for lobby in session.lobbies[:]: - try: - await session.part(lobby) - except Exception as e: - logger.warning( - f"Error removing {session.getName()} from lobby: {e}" - ) - - except Exception as e: - logger.error(f"Error cleaning up session {session.getName()}: {e}") - - # Clear all data structures - with cls.lock: - cls._instances.clear() - lobbies.clear() - - logger.info( - f"Graceful session cleanup completed for {len(sessions_to_cleanup)} sessions" - ) - - except Exception as e: - logger.error(f"Error during graceful session cleanup: {e}") - - async def join(self, lobby: Lobby): - if not self.ws: - logger.error( - f"{self.getName()} - No WebSocket connection. Lobby not available." 
- ) - return - - with self.session_lock: - if lobby.id in self.lobby_peers or self.id in lobby.sessions: - logger.info(f"{self.getName()} - Already joined to {lobby.getName()}.") - data = JoinStatusModel( - status="Joined", - message=f"Already joined to lobby {lobby.getName()}", - ) - try: - await self.ws.send_json( - {"type": "join_status", "data": data.model_dump()} - ) - except Exception as e: - logger.warning( - f"Failed to send join status to {self.getName()}: {e}" - ) - return - - # Initialize the peer list for this lobby - with self.session_lock: - self.lobbies.append(lobby) - self.lobby_peers[lobby.id] = [] - - with lobby.lock: - peer_sessions = list(lobby.sessions.values()) - - for peer_session in peer_sessions: - if peer_session.id == self.id: - logger.error( - "Should not happen: self in lobby.sessions while not in lobby." - ) - continue - - if not peer_session.ws: - logger.warning( - f"{self.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Removing." - ) - with lobby.lock: - if peer_session.id in lobby.sessions: - del lobby.sessions[peer_session.id] - continue - - # Only create WebRTC peer connections if at least one participant has media - should_create_rtc_connection = self.has_media or peer_session.has_media - - if should_create_rtc_connection: - # Add the peer to session's RTC peer list - with self.session_lock: - self.lobby_peers[lobby.id].append(peer_session.id) - - # Add this user as an RTC peer to each existing peer - with peer_session.session_lock: - if lobby.id not in peer_session.lobby_peers: - peer_session.lobby_peers[lobby.id] = [] - peer_session.lobby_peers[lobby.id].append(self.id) - - logger.info( - f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False, has_media={self.has_media})" - ) - try: - await peer_session.ws.send_json( - { - "type": "addPeer", - "data": { - "peer_id": self.id, - "peer_name": self.name, - "has_media": self.has_media, - "should_create_offer": False, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send addPeer to {peer_session.getName()}: {e}" - ) - - # Add each other peer to the caller - logger.info( - f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True, has_media={peer_session.has_media})" - ) - try: - await self.ws.send_json( - { - "type": "addPeer", - "data": { - "peer_id": peer_session.id, - "peer_name": peer_session.name, - "has_media": peer_session.has_media, - "should_create_offer": True, - }, - } - ) - except Exception as e: - logger.warning(f"Failed to send addPeer to {self.getName()}: {e}") - else: - logger.info( - f"{self.getName()} - Skipping WebRTC connection with {peer_session.getName()} (neither has media: self={self.has_media}, peer={peer_session.has_media})" - ) - - # Add this user as an RTC peer - await lobby.addSession(self) - Session.save() - - try: - await self.ws.send_json( - {"type": "join_status", "data": {"status": "Joined"}} - ) - except Exception as e: - logger.warning(f"Failed to send join confirmation to {self.getName()}: {e}") - - async def part(self, lobby: Lobby): - with self.session_lock: - if lobby.id not in self.lobby_peers or self.id not in lobby.sessions: - logger.info( - f"{self.getName()} - Attempt to part non-joined lobby {lobby.getName()}." 
- ) - if self.ws: - try: - await self.ws.send_json( - { - "type": "error", - "data": { - "error": "Attempt to part non-joined lobby", - }, - } - ) - except Exception: - pass - return - - logger.info(f"{self.getName()} <- part({lobby.getName()}) - Lobby part.") - - lobby_peers = self.lobby_peers[lobby.id][:] # Copy the list - del self.lobby_peers[lobby.id] - if lobby in self.lobbies: - self.lobbies.remove(lobby) - - # Remove this peer from all other RTC peers, and remove each peer from this peer - for peer_session_id in lobby_peers: - peer_session = getSession(peer_session_id) - if not peer_session: - logger.warning( - f"{self.getName()} <- part({lobby.getName()}) - Peer session {peer_session_id} not found. Skipping." - ) - continue - - if peer_session.ws: - logger.info( - f"{peer_session.getName()} <- remove_peer({self.getName()})" - ) - try: - await peer_session.ws.send_json( - { - "type": "removePeer", - "data": {"peer_name": self.name, "peer_id": self.id}, - } - ) - except Exception as e: - logger.warning( - f"Failed to send removePeer to {peer_session.getName()}: {e}" - ) - else: - logger.warning( - f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection for {peer_session.getName()}. Skipping." - ) - - # Remove from peer's lobby_peers - with peer_session.session_lock: - if ( - lobby.id in peer_session.lobby_peers - and self.id in peer_session.lobby_peers[lobby.id] - ): - peer_session.lobby_peers[lobby.id].remove(self.id) - - if self.ws: - logger.info( - f"{self.getName()} <- remove_peer({peer_session.getName()})" - ) - try: - await self.ws.send_json( - { - "type": "removePeer", - "data": { - "peer_name": peer_session.name, - "peer_id": peer_session.id, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send removePeer to {self.getName()}: {e}" - ) - else: - logger.error( - f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection." - ) - - await lobby.removeSession(self) - Session.save() - - -def getName(session: Session | None) -> str | None: - if session and session.name: - return session.name - return None - - -def getSession(session_id: str) -> Session | None: - return Session.getSession(session_id) - - -def getLobby(lobby_id: str) -> Lobby: - lobby = lobbies.get(lobby_id, None) - if not lobby: - # Check if this might be a stale reference after cleanup - logger.warning(f"Lobby not found: {lobby_id} (may have been cleaned up)") - raise Exception(f"Lobby not found: {lobby_id}") - return lobby - - -def getLobbyByName(lobby_name: str) -> Lobby | None: - for lobby in lobbies.values(): - if lobby.name == lobby_name: - return lobby - return None - - -# API endpoints -@app.get(f"{public_url}api/health", response_model=HealthResponse) -def health(): - logger.info("Health check endpoint called.") - return HealthResponse(status="ok") - - -# A session (cookie) is bound to a single user (name). -# A user can be in multiple lobbies, but a session is unique to a single user. -# A user can change their name, but the session ID remains the same and the name -# updates for all lobbies. 
-@app.get(f"{public_url}api/session", response_model=SessionResponse) -async def session( - request: Request, response: Response, session_id: str | None = Cookie(default=None) -) -> Response | SessionResponse: - if session_id is None: - session_id = secrets.token_hex(16) - response.set_cookie(key="session_id", value=session_id) - # Validate that session_id is a hex string of length 32 - elif len(session_id) != 32 or not all(c in "0123456789abcdef" for c in session_id): - return Response( - content=json.dumps({"error": "Invalid session_id"}), - status_code=400, - media_type="application/json", - ) - - print(f"[{session_id[:8]}]: Browser hand-shake achieved.") - - session = getSession(session_id) - if not session: - session = Session(session_id) - logger.info(f"{session.getName()}: New session created.") - else: - session.update_last_used() # Update activity on session resumption - logger.info(f"{session.getName()}: Existing session resumed.") - # Part all lobbies for this session that have no active websocket - with session.session_lock: - lobbies_to_part = session.lobbies[:] - for lobby in lobbies_to_part: - try: - await session.part(lobby) - except Exception as e: - logger.error( - f"{session.getName()} - Error parting lobby {lobby.getName()}: {e}" - ) - - with session.session_lock: - return SessionResponse( - id=session_id, - name=session.name if session.name else "", - lobbies=[ - LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private) - for lobby in session.lobbies - ], - ) - - -@app.get(public_url + "api/lobby", response_model=LobbiesResponse) -async def get_lobbies(request: Request, response: Response) -> LobbiesResponse: - return LobbiesResponse( - lobbies=[ - LobbyListItem(id=lobby.id, name=lobby.name) - for lobby in lobbies.values() - if not lobby.private - ] - ) - - -@app.post(public_url + "api/lobby/{session_id}", response_model=LobbyCreateResponse) -async def lobby_create( - request: Request, - response: Response, - session_id: str = Path(...), - create_request: LobbyCreateRequest = Body(...), -) -> Response | LobbyCreateResponse: - if create_request.type != "lobby_create": - return Response( - content=json.dumps({"error": "Invalid request type"}), - status_code=400, - media_type="application/json", - ) - - data = create_request.data - session = getSession(session_id) - if not session: - return Response( - content=json.dumps({"error": f"Session not found ({session_id})"}), - status_code=404, - media_type="application/json", - ) - logger.info( - f"{session.getName()} lobby_create: {data.name} (private={data.private})" - ) - - lobby = getLobbyByName(data.name) - if not lobby: - lobby = Lobby( - data.name, - private=data.private, - ) - lobbies[lobby.id] = lobby - logger.info(f"{session.getName()} <- lobby_create({lobby.short}:{lobby.name})") - - return LobbyCreateResponse( - type="lobby_created", - data=LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private), - ) - - -@app.get(public_url + "api/lobby/{lobby_id}/chat", response_model=ChatMessagesResponse) -async def get_chat_messages( - request: Request, - lobby_id: str = Path(...), - limit: int = 50, -) -> Response | ChatMessagesResponse: - """Get chat messages for a lobby""" - try: - lobby = getLobby(lobby_id) - except Exception as e: - return Response( - content=json.dumps({"error": str(e)}), - status_code=404, - media_type="application/json", - ) - - messages = lobby.get_chat_messages(limit) - - return ChatMessagesResponse(messages=messages) - - -# 
============================================================================= -# Bot Provider API Endpoints -# ============================================================================= - - -@app.post( - public_url + "api/bots/providers/register", - response_model=BotProviderRegisterResponse, -) -async def register_bot_provider( - request: BotProviderRegisterRequest, -) -> BotProviderRegisterResponse: - """Register a new bot provider with authentication""" - import uuid - - # Check if provider authentication is enabled - allowed_providers = BotProviderConfig.get_allowed_providers() - if allowed_providers: - # Authentication is enabled - validate provider key - if request.provider_key not in allowed_providers: - logger.warning( - f"Rejected bot provider registration with invalid key: {request.provider_key}" - ) - raise HTTPException( - status_code=403, - detail="Invalid provider key. Bot provider is not authorized to register.", - ) - - # Check if there's already an active provider with this key and remove it - providers_to_remove: list[str] = [] - for existing_provider_id, existing_provider in bot_providers.items(): - if existing_provider.provider_key == request.provider_key: - providers_to_remove.append(existing_provider_id) - logger.info( - f"Removing stale bot provider: {existing_provider.name} (ID: {existing_provider_id})" - ) - - # Remove stale providers - for provider_id_to_remove in providers_to_remove: - del bot_providers[provider_id_to_remove] - - provider_id = str(uuid.uuid4()) - now = time.time() - - provider = BotProviderModel( - provider_id=provider_id, - base_url=request.base_url.rstrip("/"), - name=request.name, - description=request.description, - provider_key=request.provider_key, - registered_at=now, - last_seen=now, - ) - - bot_providers[provider_id] = provider - logger.info( - f"Registered bot provider: {request.name} at {request.base_url} with key: {request.provider_key}" - ) - - return BotProviderRegisterResponse(provider_id=provider_id) - - -@app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse) -async def list_bot_providers() -> BotProviderListResponse: - """List all registered bot providers""" - return BotProviderListResponse(providers=list(bot_providers.values())) - - -@app.get(public_url + "api/bots", response_model=BotListResponse) -async def list_available_bots() -> BotListResponse: - """List all available bots from all registered providers""" - bots: List[BotInfoModel] = [] - providers: dict[str, str] = {} - - # Update last_seen timestamps and fetch bots from each provider - for provider_id, provider in bot_providers.items(): - try: - provider.last_seen = time.time() - - # Make HTTP request to provider's /bots endpoint - async with httpx.AsyncClient() as client: - response = await client.get(f"{provider.base_url}/bots", timeout=5.0) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Add each bot to the consolidated list - for bot_info in bots_response.bots: - bots.append(bot_info) - providers[bot_info.name] = provider_id - else: - logger.warning( - f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}" - ) - except Exception as e: - logger.error(f"Error fetching bots from provider {provider.name}: {e}") - continue - - return BotListResponse(bots=bots, providers=providers) - - -@app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse) -async def 
request_bot_join_lobby( - bot_name: str, request: BotJoinLobbyRequest -) -> BotJoinLobbyResponse: - """Request a bot to join a specific lobby""" - - # Find which provider has this bot and determine its media capability - target_provider_id = request.provider_id - bot_has_media = False - if not target_provider_id: - # Auto-discover provider for this bot - for provider_id, provider in bot_providers.items(): - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{provider.base_url}/bots", timeout=5.0 - ) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Look for the bot by name - for bot_info in bots_response.bots: - if bot_info.name == bot_name: - target_provider_id = provider_id - bot_has_media = bot_info.has_media - break - if target_provider_id: - break - except Exception: - continue - else: - # Query the specified provider for bot media capability - if target_provider_id in bot_providers: - provider = bot_providers[target_provider_id] - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{provider.base_url}/bots", timeout=5.0 - ) - if response.status_code == 200: - # Use Pydantic model to validate the response - bots_response = BotProviderBotsResponse.model_validate( - response.json() - ) - # Look for the bot by name - for bot_info in bots_response.bots: - if bot_info.name == bot_name: - bot_has_media = bot_info.has_media - break - except Exception: - # Default to no media if we can't query - pass - - if not target_provider_id or target_provider_id not in bot_providers: - raise HTTPException(status_code=404, detail="Bot or provider not found") - - provider = bot_providers[target_provider_id] - - # Get the lobby to validate it exists - try: - getLobby(request.lobby_id) # Just validate it exists - except Exception: - raise HTTPException(status_code=404, detail="Lobby not found") - - # Create a session for the bot - bot_session_id = secrets.token_hex(16) - - # Create the Session object for the bot - bot_session = Session(bot_session_id, is_bot=True, has_media=bot_has_media) - logger.info( - f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})" - ) - - # Determine server URL for the bot to connect back to - # Use the server's public URL or construct from request - server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000") - if server_base_url.endswith("/"): - server_base_url = server_base_url[:-1] - - bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}" - - # Prepare the join request for the bot provider - bot_join_payload = BotJoinPayload( - lobby_id=request.lobby_id, - session_id=bot_session_id, - nick=bot_nick, - server_url=f"{server_base_url}{public_url}".rstrip("/"), - insecure=True, # Accept self-signed certificates in development - ) - - try: - # Make request to bot provider - async with httpx.AsyncClient() as client: - response = await client.post( - f"{provider.base_url}/bots/{bot_name}/join", - json=bot_join_payload.model_dump(), - timeout=10.0, - ) - - if response.status_code == 200: - # Use Pydantic model to parse and validate response - try: - join_response = BotProviderJoinResponse.model_validate( - response.json() - ) - run_id = join_response.run_id - - # Update bot session with run and provider information - with bot_session.session_lock: - bot_session.bot_run_id = run_id - bot_session.bot_provider_id = target_provider_id - 
bot_session.setName(bot_nick) - - logger.info( - f"Bot {bot_name} requested to join lobby {request.lobby_id}" - ) - - return BotJoinLobbyResponse( - status="requested", - bot_name=bot_name, - run_id=run_id, - provider_id=target_provider_id, - ) - except ValidationError as e: - logger.error(f"Invalid response from bot provider: {e}") - raise HTTPException( - status_code=502, - detail=f"Bot provider returned invalid response: {str(e)}", - ) - else: - logger.error( - f"Bot provider returned error: HTTP {response.status_code}: {response.text}" - ) - raise HTTPException( - status_code=502, - detail=f"Bot provider error: {response.status_code}", - ) - - except httpx.TimeoutException: - raise HTTPException(status_code=504, detail="Bot provider timeout") - except Exception as e: - logger.error(f"Error requesting bot join: {e}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - - -@app.post(public_url + "api/bots/leave", response_model=BotLeaveLobbyResponse) -async def request_bot_leave_lobby( - request: BotLeaveLobbyRequest, -) -> BotLeaveLobbyResponse: - """Request a bot to leave from all lobbies and disconnect""" - - # Find the bot session - bot_session = getSession(request.session_id) - if not bot_session: - raise HTTPException(status_code=404, detail="Bot session not found") - - if not bot_session.is_bot: - raise HTTPException(status_code=400, detail="Session is not a bot") - - run_id = bot_session.bot_run_id - provider_id = bot_session.bot_provider_id - - logger.info(f"Requesting bot {bot_session.getName()} to leave all lobbies") - - # Try to stop the bot at the provider level if we have the information - if provider_id and run_id and provider_id in bot_providers: - provider = bot_providers[provider_id] - try: - async with httpx.AsyncClient() as client: - response = await client.post( - f"{provider.base_url}/bots/runs/{run_id}/stop", - timeout=5.0, - ) - if response.status_code == 200: - logger.info( - f"Successfully requested bot provider to stop run {run_id}" - ) - else: - logger.warning( - f"Bot provider returned error when stopping: HTTP {response.status_code}" - ) - except Exception as e: - logger.warning(f"Failed to request bot stop from provider: {e}") - - # Force disconnect the bot session from all lobbies - with bot_session.session_lock: - lobbies_to_part = bot_session.lobbies[:] - - for lobby in lobbies_to_part: - try: - await bot_session.part(lobby) - except Exception as e: - logger.warning(f"Error parting bot from lobby {lobby.getName()}: {e}") - - # Close WebSocket connection if it exists - if bot_session.ws: - try: - await bot_session.ws.close() - except Exception as e: - logger.warning(f"Error closing bot WebSocket: {e}") - bot_session.ws = None - - return BotLeaveLobbyResponse( - status="disconnected", - session_id=request.session_id, - run_id=run_id, - ) - - -# Register websocket endpoint directly on app with full public_url path -@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}") -async def lobby_join( - websocket: WebSocket, - lobby_id: str | None = Path(...), - session_id: str | None = Path(...), -): - await websocket.accept() - if lobby_id is None: - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid or missing lobby"}} - ) - await websocket.close() - return - if session_id is None: - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid or missing session"}} - ) - await websocket.close() - return - session = getSession(session_id) - if not session: - # logger.error(f"Invalid 
session ID {session_id}") - await websocket.send_json( - {"type": "error", "data": {"error": f"Invalid session ID {session_id}"}} - ) - await websocket.close() - return - - lobby = None - try: - lobby = getLobby(lobby_id) - except Exception as e: - await websocket.send_json({"type": "error", "data": {"error": str(e)}}) - await websocket.close() - return - - logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})") - - session.ws = websocket - session.update_last_used() # Update activity timestamp - - # Check if session is already in lobby and clean up if needed - with lobby.lock: - if session.id in lobby.sessions: - logger.info( - f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining." - ) - try: - await session.part(lobby) - await lobby.removeSession(session) - except Exception as e: - logger.warning(f"Error cleaning up stale session: {e}") - - # Notify existing peers about new user - failed_peers: list[str] = [] - with lobby.lock: - peer_sessions = list(lobby.sessions.values()) - - for peer_session in peer_sessions: - if not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Marking for removal." - ) - failed_peers.append(peer_session.id) - continue - - logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})") - try: - await peer_session.ws.send_json( - { - "type": "user_joined", - "data": { - "session_id": session.id, - "name": session.name, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to notify {peer_session.getName()} of user join: {e}" - ) - failed_peers.append(peer_session.id) - - # Clean up failed peers - with lobby.lock: - for failed_peer_id in failed_peers: - if failed_peer_id in lobby.sessions: - del lobby.sessions[failed_peer_id] - - try: - while True: - packet = await websocket.receive_json() - session.update_last_used() # Update activity on each message - type = packet.get("type", None) - data: dict[str, Any] | None = packet.get("data", None) - if not type: - logger.error(f"{session.getName()} - Invalid request: {packet}") - await websocket.send_json( - {"type": "error", "data": {"error": "Invalid request"}} - ) - continue - # logger.info(f"{session.getName()} <- RAW Rx: {data}") - match type: - case "set_name": - if not data: - logger.error(f"{session.getName()} - set_name missing data") - await websocket.send_json( - { - "type": "error", - "data": {"error": "set_name missing data"}, - } - ) - continue - name = data.get("name") - password = data.get("password") - logger.info(f"{session.getName()} <- set_name({name}, {password})") - if not name: - logger.error(f"{session.getName()} - Name required") - await websocket.send_json( - {"type": "error", "data": {"error": "Name required"}} - ) - continue - # Name takeover / password logic - lname = name.lower() - - # If name is unused, allow and optionally save password - if Session.isUniqueName(name): - # If a password was provided, save it (hash+salt) for this name - if password: - salt, hash_hex = _hash_password(password) - name_passwords[lname] = {"salt": salt, "hash": hash_hex} - session.setName(name) - logger.info(f"{session.getName()}: -> update('name', {name})") - await websocket.send_json( - { - "type": "update_name", - "data": { - "name": name, - "protected": True - if name.lower() in name_passwords - else False, - }, - } - ) - # For any clients in any lobby with this session, update their user lists - await lobby.update_state() - continue - - # Name is taken. 
Check if a password exists for the name and matches. - saved_pw = name_passwords.get(lname) - if not saved_pw and not password: - logger.warning( - f"{session.getName()} - Name already taken (no password set)" - ) - await websocket.send_json( - {"type": "error", "data": {"error": "Name already taken"}} - ) - continue - - if saved_pw and password: - # Expect structured record with salt+hash only - match_password = False - # saved_pw should be a dict[str,str] with 'salt' and 'hash' - salt = saved_pw.get("salt") - _, candidate_hash = _hash_password( - password if password else "", salt_hex=salt - ) - if candidate_hash == saved_pw.get("hash"): - match_password = True - else: - # No structured password record available - match_password = False - else: - match_password = True # No password set, but name taken and new password - allow takeover - - if not match_password: - logger.warning( - f"{session.getName()} - Name takeover attempted with wrong or missing password" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "Invalid password for name takeover", - }, - } - ) - continue - - # Password matches: perform takeover. Find the current session holding the name. - # Find the currently existing session (if any) with that name - displaced = Session.getSessionByName(name) - if displaced and displaced.id == session.id: - displaced = None - - # If found, change displaced session to a unique fallback name and notify peers - if displaced: - # Create a unique fallback name - fallback = f"{displaced.name}-{displaced.short}" - # Ensure uniqueness - if not Session.isUniqueName(fallback): - # append random suffix until unique - while not Session.isUniqueName(fallback): - fallback = f"{displaced.name}-{secrets.token_hex(3)}" - - displaced.setName(fallback) - displaced.mark_displaced() - logger.info( - f"{displaced.getName()} <- displaced by takeover, new name {fallback}" - ) - # Notify displaced session (if connected) - if displaced.ws: - try: - await displaced.ws.send_json( - { - "type": "update_name", - "data": { - "name": fallback, - "protected": False, - }, - } - ) - except Exception: - logger.exception( - "Failed to notify displaced session websocket" - ) - # Update all lobbies the displaced session was in - with displaced.session_lock: - displaced_lobbies = displaced.lobbies[:] - for d_lobby in displaced_lobbies: - try: - await d_lobby.update_state() - except Exception: - logger.exception( - "Failed to update lobby state for displaced session" - ) - - # Now assign the requested name to the current session - session.setName(name) - logger.info( - f"{session.getName()}: -> update('name', {name}) (takeover)" - ) - await websocket.send_json( - { - "type": "update_name", - "data": { - "name": name, - "protected": True - if name.lower() in name_passwords - else False, - }, - } - ) - # Notify lobbies for this session - await lobby.update_state() - - case "list_users": - await lobby.update_state(session) - - case "get_chat_messages": - # Send recent chat messages to the requesting client - messages = lobby.get_chat_messages(50) - await websocket.send_json( - { - "type": "chat_messages", - "data": { - "messages": [msg.model_dump() for msg in messages] - }, - } - ) - - case "send_chat_message": - if not data or "message" not in data: - logger.error( - f"{session.getName()} - send_chat_message missing message" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "send_chat_message missing message", - }, - } - ) - continue - - if not session.name: - 
logger.error( - f"{session.getName()} - Cannot send chat message without name" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "Must set name before sending chat messages", - }, - } - ) - continue - - message_text = str(data["message"]).strip() - if not message_text: - continue - - # Add the message to the lobby and broadcast it - chat_message = lobby.add_chat_message(session, message_text) - logger.info( - f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)" - ) - await lobby.broadcast_chat_message(chat_message) - - case "join": - logger.info(f"{session.getName()} <- join({lobby.getName()})") - await session.join(lobby=lobby) - - case "part": - logger.info(f"{session.getName()} <- part {lobby.getName()}") - await session.part(lobby=lobby) - - case "relayICECandidate": - logger.info(f"{session.getName()} <- relayICECandidate") - if not data: - logger.error( - f"{session.getName()} - relayICECandidate missing data" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "relayICECandidate missing data"}, - } - ) - continue - - with session.session_lock: - if ( - lobby.id not in session.lobby_peers - or session.id not in lobby.sessions - ): - logger.error( - f"{session.short}:{session.name} <- relayICECandidate - Not an RTC peer ({session.id})" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "Not joined to lobby"}, - } - ) - continue - session_peers = session.lobby_peers[lobby.id] - - peer_id = data.get("peer_id") - if peer_id not in session_peers: - logger.error( - f"{session.getName()} <- relayICECandidate - Not an RTC peer({peer_id}) in {session_peers}" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Target peer {peer_id} not found", - }, - } - ) - continue - - candidate = data.get("candidate") - - message: dict[str, Any] = { - "type": "iceCandidate", - "data": { - "peer_id": session.id, - "peer_name": session.name, - "candidate": candidate, - }, - } - - peer_session = lobby.getSession(peer_id) - if not peer_session or not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." 
- ) - continue - logger.info( - f"{session.getName()} -> iceCandidate({peer_session.getName()})" - ) - try: - await peer_session.ws.send_json(message) - except Exception as e: - logger.warning(f"Failed to relay ICE candidate: {e}") - - case "relaySessionDescription": - logger.info(f"{session.getName()} <- relaySessionDescription") - if not data: - logger.error( - f"{session.getName()} - relaySessionDescription missing data" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "relaySessionDescription missing data", - }, - } - ) - continue - - with session.session_lock: - if ( - lobby.id not in session.lobby_peers - or session.id not in lobby.sessions - ): - logger.error( - f"{session.short}:{session.name} <- relaySessionDescription - Not an RTC peer ({session.id})" - ) - await websocket.send_json( - { - "type": "error", - "data": {"error": "Not joined to lobby"}, - } - ) - continue - - lobby_peers = session.lobby_peers[lobby.id] - - peer_id = data.get("peer_id") - if peer_id not in lobby_peers: - logger.error( - f"{session.getName()} <- relaySessionDescription - Not an RTC peer({peer_id}) in {lobby_peers}" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Target peer {peer_id} not found", - }, - } - ) - continue - - if not peer_id: - logger.error( - f"{session.getName()} - relaySessionDescription missing peer_id" - ) - await websocket.send_json( - { - "type": "error", - "data": { - "error": "relaySessionDescription missing peer_id", - }, - } - ) - continue - peer_session = lobby.getSession(peer_id) - if not peer_session or not peer_session.ws: - logger.warning( - f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." - ) - continue - - session_description = data.get("session_description") - message = { - "type": "sessionDescription", - "data": { - "peer_id": session.id, - "peer_name": session.name, - "session_description": session_description, - }, - } - - logger.info( - f"{session.getName()} -> sessionDescription({peer_session.getName()})" - ) - try: - await peer_session.ws.send_json(message) - except Exception as e: - logger.warning(f"Failed to relay session description: {e}") - - case "status_check": - # Simple status check - just respond with success to keep connection alive - logger.debug(f"{session.getName()} <- status_check") - await websocket.send_json( - {"type": "status_ok", "data": {"timestamp": time.time()}} - ) - - case _: - await websocket.send_json( - { - "type": "error", - "data": { - "error": f"Unknown request type: {type}", - }, - } - ) - - except WebSocketDisconnect: - logger.info(f"{session.getName()} <- WebSocket disconnected for user.") - # Cleanup: remove session from lobby and sessions dict - session.ws = None - if session.id in lobby.sessions: - try: - await session.part(lobby) - except Exception as e: - logger.warning(f"Error during websocket disconnect cleanup: {e}") - - try: - await lobby.update_state() - except Exception as e: - logger.warning(f"Error updating lobby state after disconnect: {e}") - - # Clean up empty lobbies - with lobby.lock: - if not lobby.sessions: - if lobby.id in lobbies: - del lobbies[lobby.id] - logger.info(f"Cleaned up empty lobby {lobby.getName()}") - except Exception as e: - logger.error( - f"Unexpected error in websocket handler for {session.getName()}: {e}" - ) - try: - await websocket.close() - except Exception as e: - pass - - -# Serve static files or proxy to frontend development server -PRODUCTION = os.getenv("PRODUCTION", 
"false").lower() == "true" -client_build_path = os.path.join(os.path.dirname(__file__), "/client/build") - -if PRODUCTION: - logger.info(f"Serving static files from: {client_build_path} at {public_url}") - app.mount( - public_url, StaticFiles(directory=client_build_path, html=True), name="static" - ) - - -else: - logger.info(f"Proxying static files to http://client:3000 at {public_url}") - - import ssl - - @app.api_route( - f"{public_url}{{path:path}}", - methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"], - ) - async def proxy_static(request: Request, path: str): - # Do not proxy API or websocket paths - if path.startswith("api/") or path.startswith("ws/"): - return Response(status_code=404) - url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}/{path}" - if not path: - url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}" - headers = dict(request.headers) - try: - # Accept self-signed certs in dev - async with httpx.AsyncClient(verify=False) as client: - proxy_req = client.build_request( - request.method, url, headers=headers, content=await request.body() - ) - proxy_resp = await client.send(proxy_req, stream=True) - content = await proxy_resp.aread() - - # Remove problematic headers for browser decoding - filtered_headers = { - k: v - for k, v in proxy_resp.headers.items() - if k.lower() - not in ["content-encoding", "transfer-encoding", "content-length"] - } - return Response( - content=content, - status_code=proxy_resp.status_code, - headers=filtered_headers, - ) - except Exception as e: - logger.error(f"Proxy error for {url}: {e}") - return Response("Proxy error", status_code=502) - - # WebSocket proxy for /ws (for React DevTools, etc.) - import websockets - - @app.websocket("/ws") - async def websocket_proxy(websocket: WebSocket): - logger.info("REACT: WebSocket proxy connection established.") - # Get scheme from websocket.url (should be 'ws' or 'wss') - scheme = websocket.url.scheme if hasattr(websocket, "url") else "ws" - target_url = "wss://client:3000/ws" # Use WSS since client uses HTTPS - await websocket.accept() - try: - # Accept self-signed certs in dev for WSS - ssl_ctx = ssl.create_default_context() - ssl_ctx.check_hostname = False - ssl_ctx.verify_mode = ssl.CERT_NONE - async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws: - - async def client_to_server(): - while True: - msg = await websocket.receive_text() - await target_ws.send(msg) - - async def server_to_client(): - while True: - msg = await target_ws.recv() - if isinstance(msg, str): - await websocket.send_text(msg) - else: - await websocket.send_bytes(msg) - - try: - await asyncio.gather(client_to_server(), server_to_client()) - except (WebSocketDisconnect, websockets.ConnectionClosed): - logger.info("REACT: WebSocket proxy connection closed.") - except Exception as e: - logger.error(f"REACT: WebSocket proxy error: {e}") - await websocket.close() diff --git a/server/websocket/message_handlers.py b/server/websocket/message_handlers.py index 39f37e1..0030e6b 100644 --- a/server/websocket/message_handlers.py +++ b/server/websocket/message_handlers.py @@ -10,6 +10,7 @@ from typing import Dict, Any, TYPE_CHECKING from fastapi import WebSocket from logger import logger +from .webrtc_signaling import WebRTCSignalingHandlers if TYPE_CHECKING: from ..core.session_manager import Session @@ -255,6 +256,38 @@ class SendChatMessageHandler(MessageHandler): await lobby.broadcast_chat_message(chat_message) +class RelayICECandidateHandler(MessageHandler): + """Handler for 
relayICECandidate messages - WebRTC signaling""" + + async def handle( + self, + session: "Session", + lobby: "Lobby", + data: Dict[str, Any], + websocket: WebSocket, + managers: Dict[str, Any], + ) -> None: + await WebRTCSignalingHandlers.handle_relay_ice_candidate( + websocket, session, lobby, data + ) + + +class RelaySessionDescriptionHandler(MessageHandler): + """Handler for relaySessionDescription messages - WebRTC signaling""" + + async def handle( + self, + session: "Session", + lobby: "Lobby", + data: Dict[str, Any], + websocket: WebSocket, + managers: Dict[str, Any], + ) -> None: + await WebRTCSignalingHandlers.handle_relay_session_description( + websocket, session, lobby, data + ) + + class MessageRouter: """Routes WebSocket messages to appropriate handlers""" @@ -270,6 +303,10 @@ class MessageRouter: self.register("list_users", ListUsersHandler()) self.register("get_chat_messages", GetChatMessagesHandler()) self.register("send_chat_message", SendChatMessageHandler()) + + # WebRTC signaling handlers + self.register("relayICECandidate", RelayICECandidateHandler()) + self.register("relaySessionDescription", RelaySessionDescriptionHandler()) def register(self, message_type: str, handler: MessageHandler): """Register a handler for a message type""" diff --git a/server/websocket/webrtc_signaling.py b/server/websocket/webrtc_signaling.py new file mode 100644 index 0000000..cf027da --- /dev/null +++ b/server/websocket/webrtc_signaling.py @@ -0,0 +1,199 @@ +""" +WebRTC Signaling Handlers + +This module contains WebRTC signaling message handlers for peer-to-peer communication. +Handles ICE candidate relay and session description exchange between peers. +""" + +from typing import Any, Dict, TYPE_CHECKING +from fastapi import WebSocket + +from logger import logger + +if TYPE_CHECKING: + from core.session_manager import Session + from core.lobby_manager import Lobby + + +class WebRTCSignalingHandlers: + """WebRTC signaling message handlers for peer-to-peer communication.""" + + @staticmethod + async def handle_relay_ice_candidate( + websocket: WebSocket, + session: "Session", + lobby: "Lobby", + data: Dict[str, Any] + ) -> None: + """ + Handle ICE candidate relay between peers. 
+ + Args: + websocket: The WebSocket connection + session: The sender session + lobby: The lobby context + data: Message data containing peer_id and candidate + """ + logger.info(f"{session.getName()} <- relayICECandidate") + + if not data: + logger.error(f"{session.getName()} - relayICECandidate missing data") + await websocket.send_json({ + "type": "error", + "data": {"error": "relayICECandidate missing data"} + }) + return + + # Check if session is properly joined to lobby with RTC peers + with session.session_lock: + if (lobby.id not in session.lobby_peers or + session.id not in lobby.sessions): + logger.error( + f"{session.short}:{session.name} <- relayICECandidate - " + f"Not an RTC peer ({session.id})" + ) + await websocket.send_json({ + "type": "error", + "data": {"error": "Not joined to lobby"} + }) + return + + session_peers = session.lobby_peers[lobby.id] + + # Validate peer_id + peer_id = data.get("peer_id") + if peer_id not in session_peers: + logger.error( + f"{session.getName()} <- relayICECandidate - " + f"Not an RTC peer({peer_id}) in {session_peers}" + ) + await websocket.send_json({ + "type": "error", + "data": {"error": f"Target peer {peer_id} not found"} + }) + return + + # Get candidate data + candidate = data.get("candidate") + + # Prepare message for target peer + message: Dict[str, Any] = { + "type": "iceCandidate", + "data": { + "peer_id": session.id, + "peer_name": session.name, + "candidate": candidate, + }, + } + + # Find target peer session and relay the message + peer_session = lobby.getSession(peer_id) + if not peer_session or not peer_session.ws: + logger.warning( + f"{session.getName()} - Live peer session {peer_id} " + f"not found in lobby {lobby.getName()}." + ) + return + + logger.info( + f"{session.getName()} -> iceCandidate({peer_session.getName()})" + ) + + try: + await peer_session.ws.send_json(message) + except Exception as e: + logger.warning(f"Failed to relay ICE candidate: {e}") + + @staticmethod + async def handle_relay_session_description( + websocket: WebSocket, + session: "Session", + lobby: "Lobby", + data: Dict[str, Any] + ) -> None: + """ + Handle session description relay between peers. 
+ + Args: + websocket: The WebSocket connection + session: The sender session + lobby: The lobby context + data: Message data containing peer_id and session_description + """ + logger.info(f"{session.getName()} <- relaySessionDescription") + + if not data: + logger.error(f"{session.getName()} - relaySessionDescription missing data") + await websocket.send_json({ + "type": "error", + "data": {"error": "relaySessionDescription missing data"} + }) + return + + # Check if session is properly joined to lobby with RTC peers + with session.session_lock: + if (lobby.id not in session.lobby_peers or + session.id not in lobby.sessions): + logger.error( + f"{session.short}:{session.name} <- relaySessionDescription - " + f"Not an RTC peer ({session.id})" + ) + await websocket.send_json({ + "type": "error", + "data": {"error": "Not joined to lobby"} + }) + return + + lobby_peers = session.lobby_peers[lobby.id] + + # Validate peer_id + peer_id = data.get("peer_id") + if not peer_id: + logger.error(f"{session.getName()} - relaySessionDescription missing peer_id") + await websocket.send_json({ + "type": "error", + "data": {"error": "relaySessionDescription missing peer_id"} + }) + return + + if peer_id not in lobby_peers: + logger.error( + f"{session.getName()} <- relaySessionDescription - " + f"Not an RTC peer({peer_id}) in {lobby_peers}" + ) + await websocket.send_json({ + "type": "error", + "data": {"error": f"Target peer {peer_id} not found"} + }) + return + + # Find target peer session + peer_session = lobby.getSession(peer_id) + if not peer_session or not peer_session.ws: + logger.warning( + f"{session.getName()} - Live peer session {peer_id} " + f"not found in lobby {lobby.getName()}." + ) + return + + # Get session description data + session_description = data.get("session_description") + + # Prepare message for target peer + message: Dict[str, Any] = { + "type": "sessionDescription", + "data": { + "peer_id": session.id, + "peer_name": session.name, + "session_description": session_description, + }, + } + + logger.info( + f"{session.getName()} -> sessionDescription({peer_session.getName()})" + ) + + try: + await peer_session.ws.send_json(message) + except Exception as e: + logger.warning(f"Failed to relay session description: {e}") diff --git a/tests/test-webrtc-signaling.py b/tests/test-webrtc-signaling.py new file mode 100644 index 0000000..1fabf97 --- /dev/null +++ b/tests/test-webrtc-signaling.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +""" +Test WebRTC signaling handlers registration +""" + +import asyncio +import json +import websockets +import sys + +async def test_webrtc_handlers(): + """Test that WebRTC signaling handlers are properly registered""" + try: + # Connect to the WebSocket endpoint + uri = "ws://localhost:8000/ai-voicebot/ws" + + async with websockets.connect(uri) as websocket: + print("Connected to WebSocket") + + # Send a set_name message first + await websocket.send(json.dumps({ + "type": "set_name", + "data": {"name": "test_user"} + })) + + response = await websocket.recv() + print(f"Set name response: {response}") + + # Test relayICECandidate handler + test_message = { + "type": "relayICECandidate", + "data": { + "peer_id": "nonexistent_peer", + "candidate": {"candidate": "test"} + } + } + + await websocket.send(json.dumps(test_message)) + print("Sent relayICECandidate message") + + # Expect an error response since we're not in a lobby + response = await websocket.recv() + print(f"ICE candidate response: {response}") + + # Test relaySessionDescription handler + 
+            test_message = {
+                "type": "relaySessionDescription",
+                "data": {
+                    "peer_id": "nonexistent_peer",
+                    "session_description": {"type": "offer"}
+                }
+            }
+
+            await websocket.send(json.dumps(test_message))
+            print("Sent relaySessionDescription message")
+
+            # Expect an error response since we're not in a lobby
+            response = await websocket.recv()
+            print(f"Session description response: {response}")
+
+            print("WebRTC signaling handlers are working!")
+
+    except Exception as e:
+        print(f"Error testing WebRTC handlers: {e}")
+        return False
+
+    return True
+
+if __name__ == "__main__":
+    success = asyncio.run(test_webrtc_handlers())
+    sys.exit(0 if success else 1)
diff --git a/tests/verify-webrtc-handlers.py b/tests/verify-webrtc-handlers.py
new file mode 100644
index 0000000..f1a297f
--- /dev/null
+++ b/tests/verify-webrtc-handlers.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+"""
+Test script to verify WebRTC signaling handlers are registered
+"""
+
+import sys
+
+# Add the server directory to Python path
+sys.path.insert(0, '/home/jketreno/docker/ai-voicebot/server')
+
+from websocket.message_handlers import MessageRouter
+
+def test_webrtc_handlers():
+    """Test that WebRTC signaling handlers are registered"""
+    router = MessageRouter()
+    supported_types = router.get_supported_types()
+
+    print("Supported message types:")
+    for msg_type in sorted(supported_types):
+        print(f"  - {msg_type}")
+
+    # Check for WebRTC handlers
+    webrtc_handlers = [
+        "relayICECandidate",
+        "relaySessionDescription"
+    ]
+
+    print("\nWebRTC signaling handlers:")
+    for handler in webrtc_handlers:
+        if handler in supported_types:
+            print(f"  ✓ {handler} - REGISTERED")
+        else:
+            print(f"  ✗ {handler} - MISSING")
+            return False
+
+    print("\n✅ All WebRTC signaling handlers are properly registered!")
+    return True
+
+if __name__ == "__main__":
+    success = test_webrtc_handlers()
+    sys.exit(0 if success else 1)
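
Illustrative usage (not part of the patch): a minimal sketch of a peer driving the signaling relays end to end. It assumes the lobby WebSocket endpoint ws/lobby/{lobby_id}/{session_id} and the "join" message as handled by the legacy code removed above (the refactored router may wire these slightly differently), and it uses the third-party websockets package already relied on by tests/test-webrtc-signaling.py. The server URL, lobby id, and session id are placeholders; in practice the session id comes from GET /api/session (32-char hex) and the lobby from POST /api/lobby. Message shapes mirror webrtc_signaling.py and session_manager.py.

    #!/usr/bin/env python3
    """Illustrative only: exercise the relay handlers from a peer's point of view."""

    import asyncio
    import json

    import websockets

    SERVER_WS = "ws://localhost:8000/ai-voicebot"          # placeholder base URL
    LOBBY_ID = "example-lobby-id"                          # placeholder lobby id
    SESSION_ID = "0123456789abcdef0123456789abcdef"        # placeholder 32-char hex session id


    async def signaling_peer() -> None:
        uri = f"{SERVER_WS}/ws/lobby/{LOBBY_ID}/{SESSION_ID}"
        async with websockets.connect(uri) as ws:
            # Ask to join the lobby; for each existing media-capable peer the server
            # answers with an addPeer message (should_create_offer=True for this side).
            await ws.send(json.dumps({"type": "join", "data": {}}))

            async for raw in ws:
                msg = json.loads(raw)
                msg_type, data = msg.get("type"), msg.get("data", {})

                if msg_type == "addPeer" and data.get("should_create_offer"):
                    # A real client would build an RTCPeerConnection offer here; this only
                    # shows the relay message shapes the server-side handlers expect back.
                    await ws.send(json.dumps({
                        "type": "relaySessionDescription",
                        "data": {
                            "peer_id": data["peer_id"],
                            "session_description": {"type": "offer", "sdp": "<offer sdp>"},
                        },
                    }))
                    await ws.send(json.dumps({
                        "type": "relayICECandidate",
                        "data": {
                            "peer_id": data["peer_id"],
                            "candidate": {"candidate": "<ice candidate>", "sdpMid": "0"},
                        },
                    }))
                elif msg_type == "sessionDescription":
                    print(f"SDP from {data.get('peer_name')}")
                elif msg_type == "iceCandidate":
                    print(f"ICE candidate from {data.get('peer_name')}")
                elif msg_type == "removePeer":
                    print(f"Peer left: {data.get('peer_name')}")


    if __name__ == "__main__":
        asyncio.run(signaling_peer())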
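
The new handlers plug into the MessageRouter registry in server/websocket/message_handlers.py, so further signaling messages can be added without touching the WebSocket receive loop. A sketch of that extension pattern follows, assuming the same sys.path setup as tests/verify-webrtc-handlers.py and that MessageHandler is importable from the module (the built-in handlers subclass it there); the handler class and "relayRenegotiate" message type below are hypothetical.

    # Hypothetical extension following the RelayICECandidateHandler pattern; the
    # "relayRenegotiate" message type does not exist in this patch.
    import sys
    from typing import Any, Dict

    from fastapi import WebSocket

    sys.path.insert(0, '/home/jketreno/docker/ai-voicebot/server')  # as in the verify script

    from websocket.message_handlers import MessageHandler, MessageRouter


    class RelayRenegotiateHandler(MessageHandler):
        """Hypothetical handler for a 'relayRenegotiate' signaling message."""

        async def handle(
            self,
            session: "Session",   # forward references, as in message_handlers.py
            lobby: "Lobby",
            data: Dict[str, Any],
            websocket: WebSocket,
            managers: Dict[str, Any],
        ) -> None:
            # A real implementation would validate lobby membership and forward the
            # payload to the target peer, as WebRTCSignalingHandlers does.
            ...


    router = MessageRouter()
    router.register("relayRenegotiate", RelayRenegotiateHandler())
    assert "relayRenegotiate" in router.get_supported_types()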