from __future__ import annotations from typing import Any, Optional from fastapi import ( Body, Cookie, FastAPI, HTTPException, Path, WebSocket, Request, Response, WebSocketDisconnect, ) import secrets import os import json import hashlib import binascii import sys import asyncio import threading import time from contextlib import asynccontextmanager from fastapi.staticfiles import StaticFiles import httpx from pydantic import ValidationError from logger import logger # Import shared models sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from shared.models import ( HealthResponse, LobbiesResponse, LobbyCreateRequest, LobbyCreateResponse, LobbyListItem, LobbyModel, NamePasswordRecord, LobbySaved, SessionResponse, SessionSaved, SessionsPayload, AdminNamesResponse, AdminActionResponse, AdminSetPassword, AdminClearPassword, AdminValidationResponse, AdminMetricsResponse, AdminMetricsConfig, JoinStatusModel, ChatMessageModel, ChatMessagesResponse, ParticipantModel, # Bot provider models BotProviderModel, BotProviderRegisterRequest, BotProviderRegisterResponse, BotProviderListResponse, BotListResponse, BotInfoModel, BotJoinLobbyRequest, BotJoinLobbyResponse, BotJoinPayload, BotLeaveLobbyRequest, BotLeaveLobbyResponse, ) class SessionConfig: """Configuration class for session management""" ANONYMOUS_SESSION_TIMEOUT = int( os.getenv("ANONYMOUS_SESSION_TIMEOUT", "60") ) # 1 minute DISPLACED_SESSION_TIMEOUT = int( os.getenv("DISPLACED_SESSION_TIMEOUT", "10800") ) # 3 hours CLEANUP_INTERVAL = int(os.getenv("CLEANUP_INTERVAL", "300")) # 5 minutes MAX_SESSIONS_PER_CLEANUP = int( os.getenv("MAX_SESSIONS_PER_CLEANUP", "100") ) # Circuit breaker MAX_CHAT_MESSAGES_PER_LOBBY = int(os.getenv("MAX_CHAT_MESSAGES_PER_LOBBY", "100")) SESSION_VALIDATION_INTERVAL = int( os.getenv("SESSION_VALIDATION_INTERVAL", "1800") ) # 30 minutes class BotProviderConfig: """Configuration class for bot provider management""" # Comma-separated list of allowed provider keys # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys) ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "") @classmethod def get_allowed_providers(cls) -> dict[str, str]: """Parse allowed providers from environment variable Returns: dict mapping provider_key -> provider_name """ if not cls.ALLOWED_PROVIDERS.strip(): return {} providers: dict[str, str] = {} for entry in cls.ALLOWED_PROVIDERS.split(","): entry = entry.strip() if not entry: continue if ":" in entry: key, name = entry.split(":", 1) providers[key.strip()] = name.strip() else: providers[entry] = entry return providers # Thread lock for session operations session_lock = threading.RLock() # Mapping of reserved names to password records (lowercased name -> {salt:..., hash:...}) name_passwords: dict[str, dict[str, str]] = {} # Bot provider registry: provider_id -> BotProviderModel bot_providers: dict[str, BotProviderModel] = {} all_label = "[ all ]" info_label = "[ info ]" todo_label = "[ todo ]" unset_label = "[ ---- ]" def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]: """Return (salt_hex, hash_hex) for the given password. If salt_hex is provided it is used; otherwise a new salt is generated.""" if salt_hex: salt = binascii.unhexlify(salt_hex) else: salt = secrets.token_bytes(16) salt_hex = binascii.hexlify(salt).decode() dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000) hash_hex = binascii.hexlify(dk).decode() return salt_hex, hash_hex public_url = os.getenv("PUBLIC_URL", "/") if not public_url.endswith("/"): public_url += "/" # Global variables to control background tasks cleanup_task_running = False cleanup_task = None validation_task_running = False validation_task = None async def periodic_cleanup(): """Background task to periodically clean up old sessions""" global cleanup_task_running cleanup_errors = 0 max_consecutive_errors = 5 while cleanup_task_running: try: removed_count = Session.cleanup_old_sessions() if removed_count > 0: logger.info(f"Periodic cleanup removed {removed_count} old sessions") cleanup_errors = 0 # Reset error counter on success # Run cleanup at configured interval await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL) except Exception as e: cleanup_errors += 1 logger.error( f"Error in session cleanup task (attempt {cleanup_errors}): {e}" ) if cleanup_errors >= max_consecutive_errors: logger.error( f"Too many consecutive cleanup errors ({cleanup_errors}), stopping cleanup task" ) break # Exponential backoff on errors await asyncio.sleep(min(60 * cleanup_errors, 300)) async def periodic_validation(): """Background task to periodically validate session integrity""" global validation_task_running while validation_task_running: try: issues = Session.validate_session_integrity() if issues: logger.warning(f"Session integrity issues found: {len(issues)} issues") for issue in issues[:10]: # Log first 10 issues logger.warning(f"Integrity issue: {issue}") await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL) except Exception as e: logger.error(f"Error in session validation task: {e}") await asyncio.sleep(300) # Wait 5 minutes before retrying on error @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for startup and shutdown events""" global cleanup_task_running, cleanup_task, validation_task_running, validation_task # Startup logger.info("Starting background tasks...") cleanup_task_running = True validation_task_running = True cleanup_task = asyncio.create_task(periodic_cleanup()) validation_task = asyncio.create_task(periodic_validation()) logger.info("Session cleanup and validation tasks started") yield # Shutdown logger.info("Shutting down background tasks...") cleanup_task_running = False validation_task_running = False # Cancel tasks for task in [cleanup_task, validation_task]: if task: task.cancel() try: await task except asyncio.CancelledError: pass # Clean up all sessions gracefully await Session.cleanup_all_sessions() logger.info("All background tasks stopped and sessions cleaned up") app = FastAPI(lifespan=lifespan) logger.info(f"Starting server with public URL: {public_url}") logger.info( f"Session config - Anonymous timeout: {SessionConfig.ANONYMOUS_SESSION_TIMEOUT}s, " f"Displaced timeout: {SessionConfig.DISPLACED_SESSION_TIMEOUT}s, " f"Cleanup interval: {SessionConfig.CLEANUP_INTERVAL}s" ) # Log bot provider configuration allowed_providers = BotProviderConfig.get_allowed_providers() if allowed_providers: logger.info( f"Bot provider authentication enabled. Allowed providers: {list(allowed_providers.keys())}" ) else: logger.warning("Bot provider authentication disabled. Any provider can register.") # Optional admin token to protect admin endpoints ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None) def _require_admin(request: Request) -> bool: if not ADMIN_TOKEN: return True token = request.headers.get("X-Admin-Token") return token == ADMIN_TOKEN @app.get(public_url + "api/admin/names", response_model=AdminNamesResponse) def admin_list_names(request: Request): if not _require_admin(request): return Response(status_code=403) # Convert dict format to Pydantic models name_passwords_models = { name: NamePasswordRecord(**record) for name, record in name_passwords.items() } return AdminNamesResponse(name_passwords=name_passwords_models) @app.post(public_url + "api/admin/set_password", response_model=AdminActionResponse) def admin_set_password(request: Request, payload: AdminSetPassword = Body(...)): if not _require_admin(request): return Response(status_code=403) lname = payload.name.lower() salt, hash_hex = _hash_password(payload.password) name_passwords[lname] = {"salt": salt, "hash": hash_hex} Session.save() return AdminActionResponse(status="ok", name=payload.name) @app.post(public_url + "api/admin/clear_password", response_model=AdminActionResponse) def admin_clear_password(request: Request, payload: AdminClearPassword = Body(...)): if not _require_admin(request): return Response(status_code=403) lname = payload.name.lower() if lname in name_passwords: del name_passwords[lname] Session.save() return AdminActionResponse(status="ok", name=payload.name) return AdminActionResponse(status="not_found", name=payload.name) @app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse) def admin_cleanup_sessions(request: Request): if not _require_admin(request): return Response(status_code=403) try: removed_count = Session.cleanup_old_sessions() return AdminActionResponse( status="ok", name=f"Removed {removed_count} sessions" ) except Exception as e: logger.error(f"Error during manual session cleanup: {e}") return AdminActionResponse(status="error", name=f"Error: {str(e)}") @app.get(public_url + "api/admin/session_metrics", response_model=AdminMetricsResponse) def admin_session_metrics(request: Request): if not _require_admin(request): return Response(status_code=403) try: return Session.get_cleanup_metrics() except Exception as e: logger.error(f"Error getting session metrics: {e}") return Response(status_code=500) @app.get( public_url + "api/admin/validate_sessions", response_model=AdminValidationResponse ) def admin_validate_sessions(request: Request): if not _require_admin(request): return Response(status_code=403) try: issues = Session.validate_session_integrity() return AdminValidationResponse( status="ok", issues=issues, issue_count=len(issues) ) except Exception as e: logger.error(f"Error validating sessions: {e}") return AdminValidationResponse(status="error", error=str(e)) lobbies: dict[str, Lobby] = {} class Lobby: def __init__(self, name: str, id: str | None = None, private: bool = False): self.id = secrets.token_hex(16) if id is None else id self.short = self.id[:8] self.name = name self.sessions: dict[str, Session] = {} # All lobby members self.private = private self.chat_messages: list[ChatMessageModel] = [] # Store chat messages self.lock = threading.RLock() # Thread safety for lobby operations def getName(self) -> str: return f"{self.short}:{self.name}" async def update_state(self, requesting_session: Session | None = None): with self.lock: users: list[ParticipantModel] = [ ParticipantModel( name=s.name, live=True if s.ws else False, session_id=s.id, protected=True if s.name and s.name.lower() in name_passwords else False, is_bot=s.is_bot, bot_run_id=s.bot_run_id, bot_provider_id=s.bot_provider_id, ) for s in self.sessions.values() if s.name ] if requesting_session: logger.info( f"{requesting_session.getName()} -> lobby_state({self.getName()})" ) if requesting_session.ws: try: await requesting_session.ws.send_json( { "type": "lobby_state", "data": { "participants": [user.model_dump() for user in users] }, } ) except Exception as e: logger.warning( f"Failed to send lobby state to {requesting_session.getName()}: {e}" ) else: logger.warning( f"{requesting_session.getName()} - No WebSocket connection." ) else: # Send to all sessions in lobby failed_sessions: list[Session] = [] for s in self.sessions.values(): logger.info(f"{s.getName()} -> lobby_state({self.getName()})") if s.ws: try: await s.ws.send_json( { "type": "lobby_state", "data": { "participants": [ user.model_dump() for user in users ] }, } ) except Exception as e: logger.warning( f"Failed to send lobby state to {s.getName()}: {e}" ) failed_sessions.append(s) # Clean up failed sessions for failed_session in failed_sessions: failed_session.ws = None def getSession(self, id: str) -> Session | None: with self.lock: return self.sessions.get(id, None) async def addSession(self, session: Session) -> None: with self.lock: if session.id in self.sessions: logger.warning( f"{session.getName()} - Already in lobby {self.getName()}." ) return None self.sessions[session.id] = session await self.update_state() async def removeSession(self, session: Session) -> None: with self.lock: if session.id not in self.sessions: logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.") return None del self.sessions[session.id] await self.update_state() def add_chat_message(self, session: Session, message: str) -> ChatMessageModel: """Add a chat message to the lobby and return the message data""" with self.lock: chat_message = ChatMessageModel( id=secrets.token_hex(8), message=message, sender_name=session.name or session.short, sender_session_id=session.id, timestamp=time.time(), lobby_id=self.id, ) self.chat_messages.append(chat_message) # Keep only the latest messages per lobby if len(self.chat_messages) > SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY: self.chat_messages = self.chat_messages[ -SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY : ] return chat_message def get_chat_messages(self, limit: int = 50) -> list[ChatMessageModel]: """Get the most recent chat messages from the lobby""" with self.lock: return self.chat_messages[-limit:] if self.chat_messages else [] async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None: """Broadcast a chat message to all connected sessions in the lobby""" failed_sessions: list[Session] = [] for session in self.sessions.values(): if session.ws: try: await session.ws.send_json( {"type": "chat_message", "data": chat_message.model_dump()} ) except Exception as e: logger.warning( f"Failed to send chat message to {session.getName()}: {e}" ) failed_sessions.append(session) # Clean up failed sessions for failed_session in failed_sessions: failed_session.ws = None class Session: _instances: list[Session] = [] _save_file = "sessions.json" _loaded = False lock = threading.RLock() # Thread safety for class-level operations def __init__(self, id: str, is_bot: bool = False): logger.info(f"Instantiating new session {id} (bot: {is_bot})") with Session.lock: self._instances.append(self) self.id = id self.short = id[:8] self.name = "" self.lobbies: list[Lobby] = [] # List of lobby IDs this session is in self.lobby_peers: dict[ str, list[str] ] = {} # lobby ID -> list of peer session IDs self.ws: WebSocket | None = None self.created_at = time.time() self.last_used = time.time() self.displaced_at: float | None = None # When name was taken over self.is_bot = is_bot # Whether this session represents a bot self.bot_run_id: str | None = None # Bot run ID for tracking self.bot_provider_id: str | None = None # Bot provider ID self.session_lock = threading.RLock() # Instance-level lock self.save() @classmethod def save(cls): try: with cls.lock: sessions_list: list[SessionSaved] = [] for s in cls._instances: with s.session_lock: lobbies_list: list[LobbySaved] = [ LobbySaved( id=lobby.id, name=lobby.name, private=lobby.private ) for lobby in s.lobbies ] sessions_list.append( SessionSaved( id=s.id, name=s.name or "", lobbies=lobbies_list, created_at=s.created_at, last_used=s.last_used, displaced_at=s.displaced_at, is_bot=s.is_bot, bot_run_id=s.bot_run_id, bot_provider_id=s.bot_provider_id, ) ) # Prepare name password store for persistence (salt+hash). Only structured records are supported. saved_pw: dict[str, NamePasswordRecord] = { name: NamePasswordRecord(**record) for name, record in name_passwords.items() } payload_model = SessionsPayload( sessions=sessions_list, name_passwords=saved_pw ) payload = payload_model.model_dump() # Atomic write using temp file temp_file = cls._save_file + ".tmp" with open(temp_file, "w") as f: json.dump(payload, f, indent=2) # Atomic rename os.rename(temp_file, cls._save_file) logger.info( f"Saved {len(sessions_list)} sessions and {len(saved_pw)} name passwords to {cls._save_file}" ) except Exception as e: logger.error(f"Failed to save sessions: {e}") # Clean up temp file if it exists try: if os.path.exists(cls._save_file + ".tmp"): os.remove(cls._save_file + ".tmp") except Exception as e: pass @classmethod def load(cls): if not os.path.exists(cls._save_file): logger.info(f"No session save file found: {cls._save_file}") return try: with open(cls._save_file, "r") as f: raw = json.load(f) except Exception as e: logger.error(f"Failed to read session save file: {e}") return try: payload = SessionsPayload.model_validate(raw) except ValidationError as e: logger.exception(f"Failed to validate sessions payload: {e}") return # Populate in-memory structures from payload (no backwards compatibility code) name_passwords.clear() for name, rec in payload.name_passwords.items(): # rec is a NamePasswordRecord name_passwords[name] = {"salt": rec.salt, "hash": rec.hash} current_time = time.time() sessions_loaded = 0 sessions_expired = 0 with cls.lock: for s_saved in payload.sessions: # Check if this session should be expired during loading created_at = getattr(s_saved, "created_at", time.time()) last_used = getattr(s_saved, "last_used", time.time()) displaced_at = getattr(s_saved, "displaced_at", None) name = s_saved.name or "" # Apply same removal criteria as cleanup_old_sessions should_expire = cls._should_remove_session_static( name, None, created_at, last_used, displaced_at, current_time ) if should_expire: sessions_expired += 1 logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load") continue # Skip loading this expired session session = Session(s_saved.id, is_bot=getattr(s_saved, "is_bot", False)) session.name = name # Load timestamps, with defaults for backward compatibility session.created_at = created_at session.last_used = last_used session.displaced_at = displaced_at # Load bot information with defaults for backward compatibility session.is_bot = getattr(s_saved, "is_bot", False) session.bot_run_id = getattr(s_saved, "bot_run_id", None) session.bot_provider_id = getattr(s_saved, "bot_provider_id", None) for lobby_saved in s_saved.lobbies: session.lobbies.append( Lobby( name=lobby_saved.name, id=lobby_saved.id, private=lobby_saved.private, ) ) logger.info( f"Loaded session {session.getName()} with {len(session.lobbies)} lobbies" ) for lobby in session.lobbies: lobbies[lobby.id] = Lobby( name=lobby.name, id=lobby.id, private=lobby.private ) # Ensure lobby exists sessions_loaded += 1 logger.info( f"Loaded {sessions_loaded} sessions and {len(name_passwords)} name passwords from {cls._save_file}" ) if sessions_expired > 0: logger.info(f"Expired {sessions_expired} old sessions during load") # Save immediately to persist the cleanup cls.save() @classmethod def getSession(cls, id: str) -> Session | None: if not cls._loaded: cls.load() logger.info(f"Loaded {len(cls._instances)} sessions from disk...") cls._loaded = True with cls.lock: for s in cls._instances: if s.id == id: return s return None @classmethod def isUniqueName(cls, name: str) -> bool: if not name: return False with cls.lock: for s in cls._instances: with s.session_lock: if s.name.lower() == name.lower(): return False return True @classmethod def getSessionByName(cls, name: str) -> Optional["Session"]: if not name: return None lname = name.lower() with cls.lock: for s in cls._instances: with s.session_lock: if s.name and s.name.lower() == lname: return s return None def getName(self) -> str: with self.session_lock: return f"{self.short}:{self.name if self.name else unset_label}" def setName(self, name: str): with self.session_lock: self.name = name self.update_last_used() self.save() def update_last_used(self): """Update the last_used timestamp""" with self.session_lock: self.last_used = time.time() def mark_displaced(self): """Mark this session as having its name taken over""" with self.session_lock: self.displaced_at = time.time() @staticmethod def _should_remove_session_static( name: str, ws: WebSocket | None, created_at: float, last_used: float, displaced_at: float | None, current_time: float, ) -> bool: """Static method to determine if a session should be removed""" # Rule 1: Delete sessions with no active connection and no name that are older than threshold if ( not ws and not name and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT ): return True # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used recently if ( not ws and displaced_at is not None and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT ): return True return False def _should_remove(self, current_time: float) -> bool: """Check if this session should be removed""" with self.session_lock: return self._should_remove_session_static( self.name, self.ws, self.created_at, self.last_used, self.displaced_at, current_time, ) @classmethod def _remove_session_safely(cls, session: Session, empty_lobbies: set[str]) -> None: """Safely remove a session and track affected lobbies""" try: with session.session_lock: # Remove from lobbies first for lobby in session.lobbies[ : ]: # Copy list to avoid modification during iteration try: with lobby.lock: if session.id in lobby.sessions: del lobby.sessions[session.id] if len(lobby.sessions) == 0: empty_lobbies.add(lobby.id) if lobby.id in session.lobby_peers: del session.lobby_peers[lobby.id] except Exception as e: logger.warning( f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}" ) # Close WebSocket if open if session.ws: try: asyncio.create_task(session.ws.close()) except Exception as e: logger.warning( f"Error closing WebSocket for {session.getName()}: {e}" ) session.ws = None # Remove from instances list with cls.lock: if session in cls._instances: cls._instances.remove(session) except Exception as e: logger.error( f"Error during safe session removal for {session.getName()}: {e}" ) @classmethod def _cleanup_empty_lobbies(cls, empty_lobbies: set[str]) -> int: """Clean up empty lobbies from global lobbies dict""" removed_count = 0 for lobby_id in empty_lobbies: if lobby_id in lobbies: lobby_name = lobbies[lobby_id].getName() del lobbies[lobby_id] logger.info(f"Removed empty lobby {lobby_name}") removed_count += 1 return removed_count @classmethod def cleanup_old_sessions(cls) -> int: """Clean up old sessions based on the specified criteria with improved safety""" current_time = time.time() sessions_removed = 0 try: # Circuit breaker - don't remove too many sessions at once sessions_to_remove: list[Session] = [] empty_lobbies: set[str] = set() with cls.lock: # Identify sessions to remove (up to max limit) for session in cls._instances[:]: if ( len(sessions_to_remove) >= SessionConfig.MAX_SESSIONS_PER_CLEANUP ): logger.warning( f"Hit session cleanup limit ({SessionConfig.MAX_SESSIONS_PER_CLEANUP}), " f"stopping cleanup. Remaining sessions will be cleaned up in next cycle." ) break if session._should_remove(current_time): sessions_to_remove.append(session) logger.info( f"Marking session {session.getName()} for removal - " f"criteria: no_ws={session.ws is None}, no_name={not session.name}, " f"age={current_time - session.created_at:.0f}s, " f"displaced={session.displaced_at is not None}, " f"unused={current_time - session.last_used:.0f}s" ) # Remove the identified sessions for session in sessions_to_remove: cls._remove_session_safely(session, empty_lobbies) sessions_removed += 1 # Clean up empty lobbies empty_lobbies_removed = cls._cleanup_empty_lobbies(empty_lobbies) # Save state if we made changes if sessions_removed > 0: cls.save() logger.info( f"Session cleanup completed: removed {sessions_removed} sessions, " f"{empty_lobbies_removed} empty lobbies" ) except Exception as e: logger.error(f"Error during session cleanup: {e}") # Don't re-raise - cleanup should be resilient return sessions_removed @classmethod def get_cleanup_metrics(cls) -> AdminMetricsResponse: """Return cleanup metrics for monitoring""" current_time = time.time() with cls.lock: total_sessions = len(cls._instances) active_sessions = 0 named_sessions = 0 displaced_sessions = 0 old_anonymous = 0 old_displaced = 0 for s in cls._instances: with s.session_lock: if s.ws: active_sessions += 1 if s.name: named_sessions += 1 if s.displaced_at is not None: displaced_sessions += 1 if ( not s.ws and current_time - s.last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT ): old_displaced += 1 if ( not s.ws and not s.name and current_time - s.created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT ): old_anonymous += 1 config = AdminMetricsConfig( anonymous_timeout=SessionConfig.ANONYMOUS_SESSION_TIMEOUT, displaced_timeout=SessionConfig.DISPLACED_SESSION_TIMEOUT, cleanup_interval=SessionConfig.CLEANUP_INTERVAL, max_cleanup_per_cycle=SessionConfig.MAX_SESSIONS_PER_CLEANUP, ) return AdminMetricsResponse( total_sessions=total_sessions, active_sessions=active_sessions, named_sessions=named_sessions, displaced_sessions=displaced_sessions, old_anonymous_sessions=old_anonymous, old_displaced_sessions=old_displaced, total_lobbies=len(lobbies), cleanup_candidates=old_anonymous + old_displaced, config=config, ) @classmethod def validate_session_integrity(cls) -> list[str]: """Validate session data integrity""" issues: list[str] = [] try: with cls.lock: for session in cls._instances: with session.session_lock: # Check for orphaned lobby references for lobby in session.lobbies: if lobby.id not in lobbies: issues.append( f"Session {session.id[:8]}:{session.name} references missing lobby {lobby.id}" ) # Check for inconsistent peer relationships for lobby_id, peer_ids in session.lobby_peers.items(): lobby = lobbies.get(lobby_id) if lobby: with lobby.lock: if session.id not in lobby.sessions: issues.append( f"Session {session.id[:8]}:{session.name} has peers in lobby {lobby_id} but not in lobby.sessions" ) # Check if peer sessions actually exist for peer_id in peer_ids: if peer_id not in lobby.sessions: issues.append( f"Session {session.id[:8]}:{session.name} references non-existent peer {peer_id} in lobby {lobby_id}" ) else: issues.append( f"Session {session.id[:8]}:{session.name} has peer list for non-existent lobby {lobby_id}" ) # Check lobbies for consistency for lobby_id, lobby in lobbies.items(): with lobby.lock: for session_id in lobby.sessions: found_session = None for s in cls._instances: if s.id == session_id: found_session = s break if not found_session: issues.append( f"Lobby {lobby_id} references non-existent session {session_id}" ) else: with found_session.session_lock: if lobby not in found_session.lobbies: issues.append( f"Lobby {lobby_id} contains session {session_id} but session doesn't reference lobby" ) except Exception as e: logger.error(f"Error during session validation: {e}") issues.append(f"Validation error: {str(e)}") return issues @classmethod async def cleanup_all_sessions(cls): """Clean up all sessions during shutdown""" logger.info("Starting graceful session cleanup...") try: with cls.lock: sessions_to_cleanup = cls._instances[:] for session in sessions_to_cleanup: try: with session.session_lock: # Close WebSocket connections if session.ws: try: await session.ws.close() except Exception as e: logger.warning( f"Error closing WebSocket for {session.getName()}: {e}" ) session.ws = None # Remove from lobbies for lobby in session.lobbies[:]: try: await session.part(lobby) except Exception as e: logger.warning( f"Error removing {session.getName()} from lobby: {e}" ) except Exception as e: logger.error(f"Error cleaning up session {session.getName()}: {e}") # Clear all data structures with cls.lock: cls._instances.clear() lobbies.clear() logger.info( f"Graceful session cleanup completed for {len(sessions_to_cleanup)} sessions" ) except Exception as e: logger.error(f"Error during graceful session cleanup: {e}") async def join(self, lobby: Lobby): if not self.ws: logger.error( f"{self.getName()} - No WebSocket connection. Lobby not available." ) return with self.session_lock: if lobby.id in self.lobby_peers or self.id in lobby.sessions: logger.info(f"{self.getName()} - Already joined to {lobby.getName()}.") data = JoinStatusModel( status="Joined", message=f"Already joined to lobby {lobby.getName()}", ) try: await self.ws.send_json( {"type": "join_status", "data": data.model_dump()} ) except Exception as e: logger.warning( f"Failed to send join status to {self.getName()}: {e}" ) return # Initialize the peer list for this lobby with self.session_lock: self.lobbies.append(lobby) self.lobby_peers[lobby.id] = [] with lobby.lock: peer_sessions = list(lobby.sessions.values()) for peer_session in peer_sessions: if peer_session.id == self.id: logger.error( "Should not happen: self in lobby.sessions while not in lobby." ) continue if not peer_session.ws: logger.warning( f"{self.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Removing." ) with lobby.lock: if peer_session.id in lobby.sessions: del lobby.sessions[peer_session.id] continue # Add the peer to session's RTC peer list with self.session_lock: self.lobby_peers[lobby.id].append(peer_session.id) # Add this user as an RTC peer to each existing peer with peer_session.session_lock: if lobby.id not in peer_session.lobby_peers: peer_session.lobby_peers[lobby.id] = [] peer_session.lobby_peers[lobby.id].append(self.id) logger.info( f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False)" ) try: await peer_session.ws.send_json( { "type": "addPeer", "data": { "peer_id": self.id, "peer_name": self.name, "should_create_offer": False, }, } ) except Exception as e: logger.warning( f"Failed to send addPeer to {peer_session.getName()}: {e}" ) # Add each other peer to the caller logger.info( f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True)" ) try: await self.ws.send_json( { "type": "addPeer", "data": { "peer_id": peer_session.id, "peer_name": peer_session.name, "should_create_offer": True, }, } ) except Exception as e: logger.warning(f"Failed to send addPeer to {self.getName()}: {e}") # Add this user as an RTC peer await lobby.addSession(self) Session.save() try: await self.ws.send_json( {"type": "join_status", "data": {"status": "Joined"}} ) except Exception as e: logger.warning(f"Failed to send join confirmation to {self.getName()}: {e}") async def part(self, lobby: Lobby): with self.session_lock: if lobby.id not in self.lobby_peers or self.id not in lobby.sessions: logger.info( f"{self.getName()} - Attempt to part non-joined lobby {lobby.getName()}." ) if self.ws: try: await self.ws.send_json( { "type": "error", "error": "Attempt to part non-joined lobby", } ) except Exception: pass return logger.info(f"{self.getName()} <- part({lobby.getName()}) - Lobby part.") lobby_peers = self.lobby_peers[lobby.id][:] # Copy the list del self.lobby_peers[lobby.id] if lobby in self.lobbies: self.lobbies.remove(lobby) # Remove this peer from all other RTC peers, and remove each peer from this peer for peer_session_id in lobby_peers: peer_session = getSession(peer_session_id) if not peer_session: logger.warning( f"{self.getName()} <- part({lobby.getName()}) - Peer session {peer_session_id} not found. Skipping." ) continue if peer_session.ws: logger.info( f"{peer_session.getName()} <- remove_peer({self.getName()})" ) try: await peer_session.ws.send_json( { "type": "removePeer", "data": {"peer_name": self.name, "peer_id": self.id}, } ) except Exception as e: logger.warning( f"Failed to send removePeer to {peer_session.getName()}: {e}" ) else: logger.warning( f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection for {peer_session.getName()}. Skipping." ) # Remove from peer's lobby_peers with peer_session.session_lock: if ( lobby.id in peer_session.lobby_peers and self.id in peer_session.lobby_peers[lobby.id] ): peer_session.lobby_peers[lobby.id].remove(self.id) if self.ws: logger.info( f"{self.getName()} <- remove_peer({peer_session.getName()})" ) try: await self.ws.send_json( { "type": "removePeer", "data": { "peer_name": peer_session.name, "peer_id": peer_session.id, }, } ) except Exception as e: logger.warning( f"Failed to send removePeer to {self.getName()}: {e}" ) else: logger.error( f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection." ) await lobby.removeSession(self) Session.save() def getName(session: Session | None) -> str | None: if session and session.name: return session.name return None def getSession(session_id: str) -> Session | None: return Session.getSession(session_id) def getLobby(lobby_id: str) -> Lobby: lobby = lobbies.get(lobby_id, None) if not lobby: logger.error(f"Lobby not found: {lobby_id}") raise Exception(f"Lobby not found: {lobby_id}") return lobby def getLobbyByName(lobby_name: str) -> Lobby | None: for lobby in lobbies.values(): if lobby.name == lobby_name: return lobby return None # API endpoints @app.get(f"{public_url}api/health", response_model=HealthResponse) def health(): logger.info("Health check endpoint called.") return HealthResponse(status="ok") # A session (cookie) is bound to a single user (name). # A user can be in multiple lobbies, but a session is unique to a single user. # A user can change their name, but the session ID remains the same and the name # updates for all lobbies. @app.get(f"{public_url}api/session", response_model=SessionResponse) async def session( request: Request, response: Response, session_id: str | None = Cookie(default=None) ) -> Response | SessionResponse: if session_id is None: session_id = secrets.token_hex(16) response.set_cookie(key="session_id", value=session_id) # Validate that session_id is a hex string of length 32 elif len(session_id) != 32 or not all(c in "0123456789abcdef" for c in session_id): return Response( content=json.dumps({"error": "Invalid session_id"}), status_code=400, media_type="application/json", ) print(f"[{session_id[:8]}]: Browser hand-shake achieved.") session = getSession(session_id) if not session: session = Session(session_id) logger.info(f"{session.getName()}: New session created.") else: session.update_last_used() # Update activity on session resumption logger.info(f"{session.getName()}: Existing session resumed.") # Part all lobbies for this session that have no active websocket with session.session_lock: lobbies_to_part = session.lobbies[:] for lobby in lobbies_to_part: try: await session.part(lobby) except Exception as e: logger.error( f"{session.getName()} - Error parting lobby {lobby.getName()}: {e}" ) with session.session_lock: return SessionResponse( id=session_id, name=session.name if session.name else "", lobbies=[ LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private) for lobby in session.lobbies ], ) @app.get(public_url + "api/lobby", response_model=LobbiesResponse) async def get_lobbies(request: Request, response: Response) -> LobbiesResponse: return LobbiesResponse( lobbies=[ LobbyListItem(id=lobby.id, name=lobby.name) for lobby in lobbies.values() if not lobby.private ] ) @app.post(public_url + "api/lobby/{session_id}", response_model=LobbyCreateResponse) async def lobby_create( request: Request, response: Response, session_id: str = Path(...), create_request: LobbyCreateRequest = Body(...), ) -> Response | LobbyCreateResponse: if create_request.type != "lobby_create": return Response( content=json.dumps({"error": "Invalid request type"}), status_code=400, media_type="application/json", ) data = create_request.data session = getSession(session_id) if not session: return Response( content=json.dumps({"error": f"Session not found ({session_id})"}), status_code=404, media_type="application/json", ) logger.info( f"{session.getName()} lobby_create: {data.name} (private={data.private})" ) lobby = getLobbyByName(data.name) if not lobby: lobby = Lobby( data.name, private=data.private, ) lobbies[lobby.id] = lobby logger.info(f"{session.getName()} <- lobby_create({lobby.short}:{lobby.name})") return LobbyCreateResponse( type="lobby_created", data=LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private), ) @app.get(public_url + "api/lobby/{lobby_id}/chat", response_model=ChatMessagesResponse) async def get_chat_messages( request: Request, lobby_id: str = Path(...), limit: int = 50, ) -> Response | ChatMessagesResponse: """Get chat messages for a lobby""" try: lobby = getLobby(lobby_id) except Exception as e: return Response( content=json.dumps({"error": str(e)}), status_code=404, media_type="application/json", ) messages = lobby.get_chat_messages(limit) return ChatMessagesResponse(messages=messages) # ============================================================================= # Bot Provider API Endpoints # ============================================================================= @app.post( public_url + "api/bots/providers/register", response_model=BotProviderRegisterResponse, ) async def register_bot_provider( request: BotProviderRegisterRequest, ) -> BotProviderRegisterResponse: """Register a new bot provider with authentication""" import uuid # Check if provider authentication is enabled allowed_providers = BotProviderConfig.get_allowed_providers() if allowed_providers: # Authentication is enabled - validate provider key if request.provider_key not in allowed_providers: logger.warning( f"Rejected bot provider registration with invalid key: {request.provider_key}" ) raise HTTPException( status_code=403, detail="Invalid provider key. Bot provider is not authorized to register.", ) # Check if there's already an active provider with this key and remove it providers_to_remove: list[str] = [] for existing_provider_id, existing_provider in bot_providers.items(): if existing_provider.provider_key == request.provider_key: providers_to_remove.append(existing_provider_id) logger.info( f"Removing stale bot provider: {existing_provider.name} (ID: {existing_provider_id})" ) # Remove stale providers for provider_id_to_remove in providers_to_remove: del bot_providers[provider_id_to_remove] provider_id = str(uuid.uuid4()) now = time.time() provider = BotProviderModel( provider_id=provider_id, base_url=request.base_url.rstrip("/"), name=request.name, description=request.description, provider_key=request.provider_key, registered_at=now, last_seen=now, ) bot_providers[provider_id] = provider logger.info( f"Registered bot provider: {request.name} at {request.base_url} with key: {request.provider_key}" ) return BotProviderRegisterResponse(provider_id=provider_id) @app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse) async def list_bot_providers() -> BotProviderListResponse: """List all registered bot providers""" return BotProviderListResponse(providers=list(bot_providers.values())) @app.get(public_url + "api/bots", response_model=BotListResponse) async def list_available_bots() -> BotListResponse: """List all available bots from all registered providers""" bots: dict[str, BotInfoModel] = {} providers: dict[str, str] = {} # Update last_seen timestamps and fetch bots from each provider for provider_id, provider in bot_providers.items(): try: provider.last_seen = time.time() # Make HTTP request to provider's /bots endpoint async with httpx.AsyncClient() as client: response = await client.get(f"{provider.base_url}/bots", timeout=5.0) if response.status_code == 200: provider_bots = response.json() # provider_bots should be a dict of bot_name -> bot_info for bot_name, bot_info in provider_bots.items(): bots[bot_name] = BotInfoModel(**bot_info) providers[bot_name] = provider_id else: logger.warning( f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}" ) except Exception as e: logger.error(f"Error fetching bots from provider {provider.name}: {e}") continue return BotListResponse(bots=bots, providers=providers) @app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse) async def request_bot_join_lobby( bot_name: str, request: BotJoinLobbyRequest ) -> BotJoinLobbyResponse: """Request a bot to join a specific lobby""" # Find which provider has this bot target_provider_id = request.provider_id if not target_provider_id: # Auto-discover provider for this bot for provider_id, provider in bot_providers.items(): try: async with httpx.AsyncClient() as client: response = await client.get( f"{provider.base_url}/bots", timeout=5.0 ) if response.status_code == 200: provider_bots = response.json() if bot_name in provider_bots: target_provider_id = provider_id break except Exception: continue if not target_provider_id or target_provider_id not in bot_providers: raise HTTPException(status_code=404, detail="Bot or provider not found") provider = bot_providers[target_provider_id] # Get the lobby to validate it exists try: getLobby(request.lobby_id) # Just validate it exists except Exception: raise HTTPException(status_code=404, detail="Lobby not found") # Create a session for the bot bot_session_id = secrets.token_hex(16) # Create the Session object for the bot bot_session = Session(bot_session_id, is_bot=True) logger.info(f"Created bot session for: {bot_session.getName()}") # Determine server URL for the bot to connect back to # Use the server's public URL or construct from request server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000") if server_base_url.endswith("/"): server_base_url = server_base_url[:-1] bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}" # Prepare the join request for the bot provider bot_join_payload = BotJoinPayload( lobby_id=request.lobby_id, session_id=bot_session_id, nick=bot_nick, server_url=f"{server_base_url}{public_url}".rstrip("/"), insecure=True, # Accept self-signed certificates in development ) try: # Make request to bot provider async with httpx.AsyncClient() as client: response = await client.post( f"{provider.base_url}/bots/{bot_name}/join", json=bot_join_payload.model_dump(), timeout=10.0, ) if response.status_code == 200: result = response.json() run_id = result.get("run_id", "unknown") # Update bot session with run and provider information with bot_session.session_lock: bot_session.bot_run_id = run_id bot_session.bot_provider_id = target_provider_id bot_session.setName(bot_nick) logger.info( f"Bot {bot_name} requested to join lobby {request.lobby_id}" ) return BotJoinLobbyResponse( status="requested", bot_name=bot_name, run_id=run_id, provider_id=target_provider_id, ) else: logger.error( f"Bot provider returned error: HTTP {response.status_code}: {response.text}" ) raise HTTPException( status_code=502, detail=f"Bot provider error: {response.status_code}", ) except httpx.TimeoutException: raise HTTPException(status_code=504, detail="Bot provider timeout") except Exception as e: logger.error(f"Error requesting bot join: {e}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") @app.post(public_url + "api/bots/leave", response_model=BotLeaveLobbyResponse) async def request_bot_leave_lobby( request: BotLeaveLobbyRequest, ) -> BotLeaveLobbyResponse: """Request a bot to leave from all lobbies and disconnect""" # Find the bot session bot_session = getSession(request.session_id) if not bot_session: raise HTTPException(status_code=404, detail="Bot session not found") if not bot_session.is_bot: raise HTTPException(status_code=400, detail="Session is not a bot") run_id = bot_session.bot_run_id provider_id = bot_session.bot_provider_id logger.info(f"Requesting bot {bot_session.getName()} to leave all lobbies") # Try to stop the bot at the provider level if we have the information if provider_id and run_id and provider_id in bot_providers: provider = bot_providers[provider_id] try: async with httpx.AsyncClient() as client: response = await client.post( f"{provider.base_url}/bots/runs/{run_id}/stop", timeout=5.0, ) if response.status_code == 200: logger.info( f"Successfully requested bot provider to stop run {run_id}" ) else: logger.warning( f"Bot provider returned error when stopping: HTTP {response.status_code}" ) except Exception as e: logger.warning(f"Failed to request bot stop from provider: {e}") # Force disconnect the bot session from all lobbies with bot_session.session_lock: lobbies_to_part = bot_session.lobbies[:] for lobby in lobbies_to_part: try: await bot_session.part(lobby) except Exception as e: logger.warning(f"Error parting bot from lobby {lobby.getName()}: {e}") # Close WebSocket connection if it exists if bot_session.ws: try: await bot_session.ws.close() except Exception as e: logger.warning(f"Error closing bot WebSocket: {e}") bot_session.ws = None return BotLeaveLobbyResponse( status="disconnected", session_id=request.session_id, run_id=run_id, ) # Register websocket endpoint directly on app with full public_url path @app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}") async def lobby_join( websocket: WebSocket, lobby_id: str | None = Path(...), session_id: str | None = Path(...), ): await websocket.accept() if lobby_id is None: await websocket.send_json( {"type": "error", "error": "Invalid or missing lobby"} ) await websocket.close() return if session_id is None: await websocket.send_json( {"type": "error", "error": "Invalid or missing session"} ) await websocket.close() return session = getSession(session_id) if not session: # logger.error(f"Invalid session ID {session_id}") await websocket.send_json( {"type": "error", "error": f"Invalid session ID {session_id}"} ) await websocket.close() return lobby = None try: lobby = getLobby(lobby_id) except Exception as e: await websocket.send_json({"type": "error", "error": str(e)}) await websocket.close() return logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})") session.ws = websocket session.update_last_used() # Update activity timestamp # Check if session is already in lobby and clean up if needed with lobby.lock: if session.id in lobby.sessions: logger.info( f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining." ) try: await session.part(lobby) await lobby.removeSession(session) except Exception as e: logger.warning(f"Error cleaning up stale session: {e}") # Notify existing peers about new user failed_peers: list[str] = [] with lobby.lock: peer_sessions = list(lobby.sessions.values()) for peer_session in peer_sessions: if not peer_session.ws: logger.warning( f"{session.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Marking for removal." ) failed_peers.append(peer_session.id) continue logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})") try: await peer_session.ws.send_json( { "type": "user_joined", "data": { "session_id": session.id, "name": session.name, }, } ) except Exception as e: logger.warning( f"Failed to notify {peer_session.getName()} of user join: {e}" ) failed_peers.append(peer_session.id) # Clean up failed peers with lobby.lock: for failed_peer_id in failed_peers: if failed_peer_id in lobby.sessions: del lobby.sessions[failed_peer_id] try: while True: packet = await websocket.receive_json() session.update_last_used() # Update activity on each message type = packet.get("type", None) data: dict[str, Any] | None = packet.get("data", None) if not type: logger.error(f"{session.getName()} - Invalid request: {packet}") await websocket.send_json({"type": "error", "error": "Invalid request"}) continue # logger.info(f"{session.getName()} <- RAW Rx: {data}") match type: case "set_name": if not data: logger.error(f"{session.getName()} - set_name missing data") await websocket.send_json( {"type": "error", "error": "set_name missing data"} ) continue name = data.get("name") password = data.get("password") logger.info(f"{session.getName()} <- set_name({name})") if not name: logger.error(f"{session.getName()} - Name required") await websocket.send_json( {"type": "error", "error": "Name required"} ) continue # Name takeover / password logic lname = name.lower() # If name is unused, allow and optionally save password if Session.isUniqueName(name): # If a password was provided, save it (hash+salt) for this name if password: salt, hash_hex = _hash_password(password) name_passwords[lname] = {"salt": salt, "hash": hash_hex} session.setName(name) logger.info(f"{session.getName()}: -> update('name', {name})") await websocket.send_json( { "type": "update_name", "data": { "name": name, "protected": True if name.lower() in name_passwords else False, }, } ) # For any clients in any lobby with this session, update their user lists await lobby.update_state() continue # Name is taken. Check if a password exists for the name and matches. saved_pw = name_passwords.get(lname) if not saved_pw: logger.warning( f"{session.getName()} - Name already taken (no password set)" ) await websocket.send_json( {"type": "error", "error": "Name already taken"} ) continue # Expect structured record with salt+hash only match_password = False # saved_pw should be a dict[str,str] with 'salt' and 'hash' salt = saved_pw.get("salt") _, candidate_hash = _hash_password( password if password else "", salt_hex=salt ) if candidate_hash == saved_pw.get("hash"): match_password = True else: # No structured password record available match_password = False if not match_password: logger.warning( f"{session.getName()} - Name takeover attempted with wrong or missing password" ) await websocket.send_json( { "type": "error", "error": "Invalid password for name takeover", } ) continue # Password matches: perform takeover. Find the current session holding the name. # Find the currently existing session (if any) with that name displaced = Session.getSessionByName(name) if displaced and displaced.id == session.id: displaced = None # If found, change displaced session to a unique fallback name and notify peers if displaced: # Create a unique fallback name fallback = f"{displaced.name}-{displaced.short}" # Ensure uniqueness if not Session.isUniqueName(fallback): # append random suffix until unique while not Session.isUniqueName(fallback): fallback = f"{displaced.name}-{secrets.token_hex(3)}" displaced.setName(fallback) displaced.mark_displaced() logger.info( f"{displaced.getName()} <- displaced by takeover, new name {fallback}" ) # Notify displaced session (if connected) if displaced.ws: try: await displaced.ws.send_json( { "type": "update_name", "data": { "name": fallback, "protected": False, }, } ) except Exception: logger.exception( "Failed to notify displaced session websocket" ) # Update all lobbies the displaced session was in with displaced.session_lock: displaced_lobbies = displaced.lobbies[:] for d_lobby in displaced_lobbies: try: await d_lobby.update_state() except Exception: logger.exception( "Failed to update lobby state for displaced session" ) # Now assign the requested name to the current session session.setName(name) logger.info( f"{session.getName()}: -> update('name', {name}) (takeover)" ) await websocket.send_json( { "type": "update_name", "data": { "name": name, "protected": True if name.lower() in name_passwords else False, }, } ) # Notify lobbies for this session await lobby.update_state() case "list_users": await lobby.update_state(session) case "get_chat_messages": # Send recent chat messages to the requesting client messages = lobby.get_chat_messages(50) await websocket.send_json( { "type": "chat_messages", "data": { "messages": [msg.model_dump() for msg in messages] }, } ) case "send_chat_message": if not data or "message" not in data: logger.error( f"{session.getName()} - send_chat_message missing message" ) await websocket.send_json( { "type": "error", "error": "send_chat_message missing message", } ) continue if not session.name: logger.error( f"{session.getName()} - Cannot send chat message without name" ) await websocket.send_json( { "type": "error", "error": "Must set name before sending chat messages", } ) continue message_text = str(data["message"]).strip() if not message_text: continue # Add the message to the lobby and broadcast it chat_message = lobby.add_chat_message(session, message_text) await lobby.broadcast_chat_message(chat_message) logger.info( f"{session.getName()} sent chat message to {lobby.getName()}: {message_text[:50]}..." ) case "join": logger.info(f"{session.getName()} <- join({lobby.getName()})") await session.join(lobby=lobby) case "part": logger.info(f"{session.getName()} <- part {lobby.getName()}") await session.part(lobby=lobby) case "relayICECandidate": logger.info(f"{session.getName()} <- relayICECandidate") if not data: logger.error( f"{session.getName()} - relayICECandidate missing data" ) await websocket.send_json( {"type": "error", "error": "relayICECandidate missing data"} ) continue with session.session_lock: if ( lobby.id not in session.lobby_peers or session.id not in lobby.sessions ): logger.error( f"{session.short}:{session.name} <- relayICECandidate - Not an RTC peer ({session.id})" ) await websocket.send_json( {"type": "error", "error": "Not joined to lobby"} ) continue session_peers = session.lobby_peers[lobby.id] peer_id = data.get("peer_id") if peer_id not in session_peers: logger.error( f"{session.getName()} <- relayICECandidate - Not an RTC peer({peer_id}) in {session_peers}" ) await websocket.send_json( { "type": "error", "error": f"Target peer {peer_id} not found", } ) continue candidate = data.get("candidate") message: dict[str, Any] = { "type": "iceCandidate", "data": { "peer_id": session.id, "peer_name": session.name, "candidate": candidate, }, } peer_session = lobby.getSession(peer_id) if not peer_session or not peer_session.ws: logger.warning( f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." ) continue logger.info( f"{session.getName()} -> iceCandidate({peer_session.getName()})" ) try: await peer_session.ws.send_json(message) except Exception as e: logger.warning(f"Failed to relay ICE candidate: {e}") case "relaySessionDescription": logger.info(f"{session.getName()} <- relaySessionDescription") if not data: logger.error( f"{session.getName()} - relaySessionDescription missing data" ) await websocket.send_json( { "type": "error", "error": "relaySessionDescription missing data", } ) continue with session.session_lock: if ( lobby.id not in session.lobby_peers or session.id not in lobby.sessions ): logger.error( f"{session.short}:{session.name} <- relaySessionDescription - Not an RTC peer ({session.id})" ) await websocket.send_json( {"type": "error", "error": "Not joined to lobby"} ) continue lobby_peers = session.lobby_peers[lobby.id] peer_id = data.get("peer_id") if peer_id not in lobby_peers: logger.error( f"{session.getName()} <- relaySessionDescription - Not an RTC peer({peer_id}) in {lobby_peers}" ) await websocket.send_json( { "type": "error", "error": f"Target peer {peer_id} not found", } ) continue if not peer_id: logger.error( f"{session.getName()} - relaySessionDescription missing peer_id" ) await websocket.send_json( { "type": "error", "error": "relaySessionDescription missing peer_id", } ) continue peer_session = lobby.getSession(peer_id) if not peer_session or not peer_session.ws: logger.warning( f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." ) continue session_description = data.get("session_description") message = { "type": "sessionDescription", "data": { "peer_id": session.id, "peer_name": session.name, "session_description": session_description, }, } logger.info( f"{session.getName()} -> sessionDescription({peer_session.getName()})" ) try: await peer_session.ws.send_json(message) except Exception as e: logger.warning(f"Failed to relay session description: {e}") case _: await websocket.send_json( { "type": "error", "error": f"Unknown request type: {type}", } ) except WebSocketDisconnect: logger.info(f"{session.getName()} <- WebSocket disconnected for user.") # Cleanup: remove session from lobby and sessions dict session.ws = None if session.id in lobby.sessions: try: await session.part(lobby) except Exception as e: logger.warning(f"Error during websocket disconnect cleanup: {e}") try: await lobby.update_state() except Exception as e: logger.warning(f"Error updating lobby state after disconnect: {e}") # Clean up empty lobbies with lobby.lock: if not lobby.sessions: if lobby.id in lobbies: del lobbies[lobby.id] logger.info(f"Cleaned up empty lobby {lobby.getName()}") except Exception as e: logger.error( f"Unexpected error in websocket handler for {session.getName()}: {e}" ) try: await websocket.close() except Exception as e: pass # Serve static files or proxy to frontend development server PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true" client_build_path = os.path.join(os.path.dirname(__file__), "/client/build") if PRODUCTION: logger.info(f"Serving static files from: {client_build_path} at {public_url}") app.mount( public_url, StaticFiles(directory=client_build_path, html=True), name="static" ) else: logger.info(f"Proxying static files to http://client:3000 at {public_url}") import ssl @app.api_route( f"{public_url}{{path:path}}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"], ) async def proxy_static(request: Request, path: str): # Do not proxy API or websocket paths if path.startswith("api/") or path.startswith("ws/"): return Response(status_code=404) url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}/{path}" if not path: url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}" headers = dict(request.headers) try: # Accept self-signed certs in dev async with httpx.AsyncClient(verify=False) as client: proxy_req = client.build_request( request.method, url, headers=headers, content=await request.body() ) proxy_resp = await client.send(proxy_req, stream=True) content = await proxy_resp.aread() # Remove problematic headers for browser decoding filtered_headers = { k: v for k, v in proxy_resp.headers.items() if k.lower() not in ["content-encoding", "transfer-encoding", "content-length"] } return Response( content=content, status_code=proxy_resp.status_code, headers=filtered_headers, ) except Exception as e: logger.error(f"Proxy error for {url}: {e}") return Response("Proxy error", status_code=502) # WebSocket proxy for /ws (for React DevTools, etc.) import websockets @app.websocket("/ws") async def websocket_proxy(websocket: WebSocket): logger.info("REACT: WebSocket proxy connection established.") # Get scheme from websocket.url (should be 'ws' or 'wss') scheme = websocket.url.scheme if hasattr(websocket, "url") else "ws" target_url = f"{scheme}://client:3000/ws" await websocket.accept() try: # Accept self-signed certs in dev for WSS ssl_ctx = ssl.create_default_context() ssl_ctx.check_hostname = False ssl_ctx.verify_mode = ssl.CERT_NONE async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws: async def client_to_server(): while True: msg = await websocket.receive_text() await target_ws.send(msg) async def server_to_client(): while True: msg = await target_ws.recv() if isinstance(msg, str): await websocket.send_text(msg) else: await websocket.send_bytes(msg) try: await asyncio.gather(client_to_server(), server_to_client()) except (WebSocketDisconnect, websockets.ConnectionClosed): logger.info("REACT: WebSocket proxy connection closed.") except Exception as e: logger.error(f"REACT: WebSocket proxy error: {e}") await websocket.close()