from __future__ import annotations
from typing import Any, Optional
from fastapi import (
    Body,
    Cookie,
    FastAPI,
    HTTPException,
    Path,
    WebSocket,
    Request,
    Response,
    WebSocketDisconnect,
)
import secrets
import os
import json
import hashlib
import binascii
import sys
import asyncio
from contextlib import asynccontextmanager
from fastapi.staticfiles import StaticFiles
import httpx
from pydantic import ValidationError
from logger import logger

# Import shared models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import (
    HealthResponse,
    LobbiesResponse,
    LobbyCreateRequest,
    LobbyCreateResponse,
    LobbyListItem,
    LobbyModel,
    NamePasswordRecord,
    LobbySaved,
    SessionResponse,
    SessionSaved,
    SessionsPayload,
    AdminNamesResponse,
    AdminActionResponse,
    AdminSetPassword,
    AdminClearPassword,
    JoinStatusModel,
    ChatMessageModel,
    ChatMessagesResponse,
    ParticipantModel,
    # Bot provider models
    BotProviderModel,
    BotProviderRegisterRequest,
    BotProviderRegisterResponse,
    BotProviderListResponse,
    BotListResponse,
    BotInfoModel,
    BotJoinLobbyRequest,
    BotJoinLobbyResponse,
    BotJoinPayload,
)

# Mapping of reserved names to password records (lowercased name -> {salt:..., hash:...})
name_passwords: dict[str, dict[str, str]] = {}

# Bot provider registry: provider_id -> BotProviderModel
bot_providers: dict[str, BotProviderModel] = {}

all_label = "[ all ]"
info_label = "[ info ]"
todo_label = "[ todo ]"
unset_label = "[ ---- ]"


def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]:
    """Return (salt_hex, hash_hex) for the given password.

    If salt_hex is provided (non-empty) it is used; otherwise a new random
    16-byte salt is generated. The hash is PBKDF2-HMAC-SHA256 with 100000
    iterations; both values are hex-encoded. Verification re-derives the hash
    with the stored salt and compares it to the stored hash.
    """
    if salt_hex:
        salt = binascii.unhexlify(salt_hex)
    else:
        salt = secrets.token_bytes(16)
        salt_hex = binascii.hexlify(salt).decode()
    dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000)
    hash_hex = binascii.hexlify(dk).decode()
    return salt_hex, hash_hex


public_url = os.getenv("PUBLIC_URL", "/")
if not public_url.endswith("/"):
    public_url += "/"

# Global variable to control the cleanup task
cleanup_task_running = False
cleanup_task = None


async def periodic_cleanup():
    """Background task: run Session.cleanup_old_sessions every 5 minutes.

    Loops while the module-level ``cleanup_task_running`` flag is set. Errors
    are logged and the task retries after 1 minute instead of dying.
    """
    while cleanup_task_running:
        try:
            Session.cleanup_old_sessions()
            # Run cleanup every 5 minutes
            await asyncio.sleep(300)
        except Exception as e:
            logger.error(f"Error in session cleanup task: {e}")
            await asyncio.sleep(60)  # Wait 1 minute before retrying on error


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the session cleanup task on startup and cancel it on shutdown."""
    global cleanup_task_running, cleanup_task

    # Startup
    cleanup_task_running = True
    cleanup_task = asyncio.create_task(periodic_cleanup())
    logger.info("Session cleanup task started")

    yield

    # Shutdown
    cleanup_task_running = False
    if cleanup_task:
        cleanup_task.cancel()
        try:
            await cleanup_task
        except asyncio.CancelledError:
            pass
    logger.info("Session cleanup task stopped")


app = FastAPI(lifespan=lifespan)

logger.info(f"Starting server with public URL: {public_url}")

# Optional admin token to protect admin endpoints
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None)


def _require_admin(request: Request) -> bool:
    """Return True if ``request`` may use the admin endpoints.

    When ADMIN_TOKEN is unset the admin endpoints are open to everyone.
    Otherwise the ``X-Admin-Token`` header must equal ADMIN_TOKEN; the
    comparison is constant-time so the token cannot be probed via timing.
    """
    if not ADMIN_TOKEN:
        return True
    token = request.headers.get("X-Admin-Token")
    if token is None:
        return False
    # Compare bytes: secrets.compare_digest rejects non-ASCII str arguments.
    return secrets.compare_digest(token.encode("utf-8"), ADMIN_TOKEN.encode("utf-8"))


@app.get(public_url + "api/admin/names", response_model=AdminNamesResponse)
def admin_list_names(request: Request):
    """Return every stored name-password record (admin only)."""
    if not _require_admin(request):
        return Response(status_code=403)
    # Persisted passwords are loaded lazily; make sure they are in memory.
    Session.ensure_loaded()
    return {"name_passwords": name_passwords}


@app.post(public_url + "api/admin/set_password", response_model=AdminActionResponse)
def admin_set_password(request: Request, payload: AdminSetPassword = Body(...)):
    """Reserve ``payload.name`` with a salted password hash (admin only)."""
    if not _require_admin(request):
        return Response(status_code=403)
    # Load persisted state first: saving before the first load would overwrite
    # the save file with an empty session list.
    Session.ensure_loaded()
    lname = payload.name.lower()
    salt, hash_hex = _hash_password(payload.password)
    name_passwords[lname] = {"salt": salt, "hash": hash_hex}
    Session.save()
    return {"status": "ok", "name": payload.name}


@app.post(public_url + "api/admin/clear_password", response_model=AdminActionResponse)
def admin_clear_password(request: Request, payload: AdminClearPassword = Body(...)):
    """Remove the password reservation for ``payload.name`` (admin only).

    Returns status "ok" if a record was removed, "not_found" otherwise.
    """
    if not _require_admin(request):
        return Response(status_code=403)
    # Load persisted state first so the lookup sees stored records and the
    # save below does not discard the other sessions on disk.
    Session.ensure_loaded()
    lname = payload.name.lower()
    if lname in name_passwords:
        del name_passwords[lname]
        Session.save()
        return {"status": "ok", "name": payload.name}
    return {"status": "not_found", "name": payload.name}


@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse)
def admin_cleanup_sessions(request: Request):
    """Run the session cleanup immediately (admin only).

    The removed-session count (or the error text) is reported in ``name``.
    """
    if not _require_admin(request):
        return Response(status_code=403)
    try:
        removed_count = Session.cleanup_old_sessions()
        return {"status": "ok", "name": f"Removed {removed_count} sessions"}
    except Exception as e:
        logger.error(f"Error during manual session cleanup: {e}")
        return {"status": "not_found", "name": f"Error: {str(e)}"}


lobbies: dict[str, Lobby] = {}


async def _send_json_safely(
    ws: WebSocket | None, payload: dict[str, Any], target: str
) -> bool:
    """Send ``payload`` on ``ws`` without letting a dead socket abort the caller.

    Returns True if ``send_json`` completed; False if there was no socket or
    the send raised (the failure is logged as a warning with ``target``).
    """
    if not ws:
        return False
    try:
        await ws.send_json(payload)
    except Exception as e:
        logger.warning(f"Failed to send {payload.get('type')} to {target}: {e}")
        return False
    return True


class Lobby:
    """A named room of sessions with a bounded chat history.

    ``sessions`` maps session id -> Session for current members. Private
    lobbies are omitted from the public lobby listing.
    """

    # Maximum number of chat messages retained per lobby.
    max_chat_messages: int = 100

    def __init__(self, name: str, id: str | None = None, private: bool = False):
        self.id = secrets.token_hex(16) if id is None else id
        self.short = self.id[:8]
        self.name = name
        self.sessions: dict[str, Session] = {}  # All lobby members
        self.private = private
        self.chat_messages: list[ChatMessageModel] = []  # Store chat messages

    def getName(self) -> str:
        """Return a log-friendly "<short id>:<name>" label."""
        return f"{self.short}:{self.name}"

    async def update_state(self, requesting_session: Session | None = None):
        """Send the participant list as a ``lobby_state`` message.

        With ``requesting_session`` only that session is answered; otherwise
        every member with a websocket receives it. Only named sessions are
        listed as participants.
        """
        users: list[ParticipantModel] = [
            ParticipantModel(
                name=s.name,
                live=bool(s.ws),
                session_id=s.id,
                protected=bool(s.name and s.name.lower() in name_passwords),
            )
            for s in self.sessions.values()
            if s.name
        ]
        message = {
            "type": "lobby_state",
            "data": {"participants": [user.model_dump() for user in users]},
        }

        if requesting_session:
            logger.info(
                f"{requesting_session.getName()} -> lobby_state({self.getName()})"
            )
            if requesting_session.ws:
                await requesting_session.ws.send_json(message)
            else:
                logger.warning(
                    f"{requesting_session.getName()} - No WebSocket connection."
                )
            return

        # Snapshot the members: sessions may join or leave while we await, and
        # one dead socket must not stop the broadcast to the others.
        for s in list(self.sessions.values()):
            logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
            if s.ws:
                await _send_json_safely(s.ws, message, s.getName())

    def getSession(self, id: str) -> Session | None:
        """Return the member session with this id, or None."""
        return self.sessions.get(id, None)

    async def addSession(self, session: Session) -> None:
        """Add ``session`` as a member and broadcast the new lobby state."""
        if session.id in self.sessions:
            logger.warning(f"{session.getName()} - Already in lobby {self.getName()}.")
            return None
        self.sessions[session.id] = session
        await self.update_state()

    async def removeSession(self, session: Session) -> None:
        """Remove ``session`` from the members and broadcast the new lobby state."""
        if session.id not in self.sessions:
            logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.")
            return None
        del self.sessions[session.id]
        await self.update_state()

    def add_chat_message(self, session: Session, message: str) -> ChatMessageModel:
        """Record a chat message from ``session`` and return it.

        The sender name falls back to the short session id when the session
        has no name. History is trimmed to ``max_chat_messages`` entries.
        """
        import time

        chat_message = ChatMessageModel(
            id=secrets.token_hex(8),
            message=message,
            sender_name=session.name or session.short,
            sender_session_id=session.id,
            timestamp=time.time(),
            lobby_id=self.id,
        )
        self.chat_messages.append(chat_message)
        # Keep only the newest messages so per-lobby memory stays bounded.
        overflow = len(self.chat_messages) - self.max_chat_messages
        if overflow > 0:
            del self.chat_messages[:overflow]
        return chat_message

    def get_chat_messages(self, limit: int = 50) -> list[ChatMessageModel]:
        """Return up to ``limit`` of the most recent chat messages, oldest first.

        A non-positive ``limit`` returns an empty list (a plain ``[-0:]`` slice
        would return the entire history).
        """
        if limit <= 0:
            return []
        return self.chat_messages[-limit:]

    async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
        """Send ``chat_message`` to every member with a websocket.

        Send failures are logged and skipped.
        """
        # Snapshot the members: the dict may change while we await each send.
        for session in list(self.sessions.values()):
            if session.ws:
                try:
                    await session.ws.send_json(
                        {"type": "chat_message", "data": chat_message.model_dump()}
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to send chat message to {session.getName()}: {e}"
                    )


class Session:
    """A user identity bound to a session id (the ``session_id`` cookie).

    Every instance is tracked in ``_instances``. All sessions, together with
    the name-password store, are persisted as JSON to ``_save_file``, which is
    loaded lazily by ``ensure_loaded``.
    """

    _instances: list[Session] = []
    _save_file = "sessions.json"
    _loaded = False
    # True while load() rebuilds instances. save() is suppressed meanwhile so
    # the file is never rewritten with half-initialized sessions.
    _loading = False

    def __init__(self, id: str):
        import time

        logger.info(f"Instantiating new session {id}")
        self._instances.append(self)
        self.id = id
        self.short = id[:8]
        self.name = ""
        self.lobbies: list[Lobby] = []  # Lobby objects this session belongs to
        self.lobby_peers: dict[str, list[str]] = {}  # lobby ID -> peer session IDs
        self.ws: WebSocket | None = None
        self.created_at = time.time()
        self.last_used = time.time()
        self.displaced_at: float | None = None  # When name was taken over
        self.save()

    @classmethod
    def save(cls):
        """Persist all sessions and name-password records to ``_save_file``.

        Does nothing while load() is running. The JSON is written to a
        temporary file and moved into place with os.replace (atomic on
        POSIX), so a crash mid-write cannot truncate the save file.
        """
        if cls._loading:
            return

        sessions_list: list[SessionSaved] = []
        for s in cls._instances:
            lobbies_list: list[LobbySaved] = [
                LobbySaved(id=lobby.id, name=lobby.name, private=lobby.private)
                for lobby in s.lobbies
            ]
            sessions_list.append(
                SessionSaved(
                    id=s.id,
                    name=s.name or "",
                    lobbies=lobbies_list,
                    created_at=s.created_at,
                    last_used=s.last_used,
                    displaced_at=s.displaced_at,
                )
            )

        # Prepare name password store for persistence (salt+hash). Only structured records are supported.
        saved_pw: dict[str, NamePasswordRecord] = {
            name: NamePasswordRecord(**record) for name, record in name_passwords.items()
        }

        payload_model = SessionsPayload(sessions=sessions_list, name_passwords=saved_pw)
        payload = payload_model.model_dump()

        tmp_path = f"{cls._save_file}.tmp"
        with open(tmp_path, "w") as f:
            json.dump(payload, f, indent=2)
        os.replace(tmp_path, cls._save_file)

        logger.info(
            f"Saved {len(sessions_list)} sessions and {len(saved_pw)} name passwords to {cls._save_file}"
        )

    @classmethod
    def load(cls):
        """Rebuild sessions, lobbies and name passwords from ``_save_file``.

        Sessions that cleanup_old_sessions would remove are dropped while
        loading; if any were dropped, the pruned state is saved back. Each
        saved lobby id maps to a single Lobby object, shared by the global
        ``lobbies`` dict and every session that lists it. A missing,
        unreadable or invalid save file is logged and leaves state empty.
        """
        import time

        if not os.path.exists(cls._save_file):
            logger.info(f"No session save file found: {cls._save_file}")
            return

        try:
            with open(cls._save_file, "r") as f:
                raw = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            logger.exception(f"Failed to read session save file {cls._save_file}: {e}")
            return

        try:
            payload = SessionsPayload.model_validate(raw)
        except ValidationError as e:
            logger.exception(f"Failed to validate sessions payload: {e}")
            return

        # Populate in-memory structures from payload (no backwards compatibility code)
        name_passwords.clear()
        for name, rec in payload.name_passwords.items():
            # rec is a NamePasswordRecord
            name_passwords[name] = {"salt": rec.salt, "hash": rec.hash}

        current_time = time.time()
        one_minute = 60.0
        three_hours = 3 * 60 * 60.0
        sessions_loaded = 0
        sessions_expired = 0

        cls._loading = True
        try:
            for s_saved in payload.sessions:
                # Missing (or null) timestamps default to "now".
                created_at = getattr(s_saved, "created_at", None)
                if created_at is None:
                    created_at = current_time
                last_used = getattr(s_saved, "last_used", None)
                if last_used is None:
                    last_used = current_time
                displaced_at = getattr(s_saved, "displaced_at", None)
                name = s_saved.name or ""

                # Apply same removal criteria as cleanup_old_sessions
                # (no connection is assumed for sessions read from disk).
                # Rule 1: no name and older than 1 minute.
                if not name and current_time - created_at > one_minute:
                    logger.info(
                        f"Expiring session {s_saved.id[:8]} during load - no name, older than 1 minute"
                    )
                    sessions_expired += 1
                    continue
                # Rule 2: displaced and unused for 3+ hours.
                if displaced_at is not None and current_time - last_used > three_hours:
                    logger.info(
                        f"Expiring session {s_saved.id[:8]}:{name} during load - displaced and unused for 3+ hours"
                    )
                    sessions_expired += 1
                    continue

                session = Session(s_saved.id)
                session.name = name
                session.created_at = created_at
                session.last_used = last_used
                session.displaced_at = displaced_at

                for lobby_saved in s_saved.lobbies:
                    # Reuse one Lobby per id so the registry and every member
                    # session agree on identity, privacy and chat history.
                    lobby = lobbies.get(lobby_saved.id)
                    if lobby is None:
                        lobby = Lobby(
                            name=lobby_saved.name,
                            id=lobby_saved.id,
                            private=lobby_saved.private,
                        )
                        lobbies[lobby.id] = lobby
                    # Tolerate duplicate entries in the save file; keep one per id.
                    if all(existing.id != lobby.id for existing in session.lobbies):
                        session.lobbies.append(lobby)
                logger.info(
                    f"Loaded session {session.getName()} with {len(session.lobbies)} lobbies"
                )
                sessions_loaded += 1
        finally:
            cls._loading = False

        logger.info(
            f"Loaded {sessions_loaded} sessions and {len(name_passwords)} name passwords from {cls._save_file}"
        )
        if sessions_expired > 0:
            logger.info(f"Expired {sessions_expired} old sessions during load")
            # Save immediately to persist the cleanup
            cls.save()

    @classmethod
    def ensure_loaded(cls) -> None:
        """Load persisted sessions and passwords from disk on the first call only."""
        if cls._loaded:
            return
        cls.load()
        logger.info(f"Loaded {len(cls._instances)} sessions from disk...")
        cls._loaded = True

    @classmethod
    def getSession(cls, id: str) -> Session | None:
        """Return the session with this id, or None (loads persisted state first)."""
        cls.ensure_loaded()
        return next((s for s in cls._instances if s.id == id), None)

    @classmethod
    def isUniqueName(cls, name: str) -> bool:
        """Return True if no session uses ``name`` (case-insensitive); False for ""."""
        if not name:
            return False
        for s in cls._instances:
            if s.name.lower() == name.lower():
                return False
        return True

    @classmethod
    def getSessionByName(cls, name: str) -> Optional["Session"]:
        """Return the first session whose name matches case-insensitively, or None."""
        if not name:
            return None
        lname = name.lower()
        for s in cls._instances:
            if s.name and s.name.lower() == lname:
                return s
        return None

    def getName(self) -> str:
        """Return a log-friendly "<short id>:<name>" label."""
        return f"{self.short}:{self.name if self.name else unset_label}"

    def setName(self, name: str):
        """Rename the session, mark it used, and persist."""
        self.name = name
        self.update_last_used()
        self.save()

    def update_last_used(self):
        """Update the last_used timestamp"""
        import time

        self.last_used = time.time()

    def mark_displaced(self):
        """Mark this session as having its name taken over"""
        import time

        self.displaced_at = time.time()

    @classmethod
    def cleanup_old_sessions(cls) -> int:
        """Remove stale sessions and empty lobbies; return the sessions removed.

        A session without a websocket is removed when either (1) it has no
        name and is older than 1 minute, or (2) it was displaced by a name
        takeover and has not been used for 3 hours. Removed sessions are
        dropped from their lobbies, then every lobby without members is
        deleted from ``lobbies``. State is saved if any session was removed.
        """
        import time

        current_time = time.time()
        one_minute = 60.0
        three_hours = 3 * 60 * 60.0
        sessions_removed = 0

        # Collect first and remove afterwards so _instances is not mutated while iterated.
        sessions_to_remove: list[Session] = []

        for session in cls._instances[:]:
            # Rule 1: Delete sessions with no active connection and no name that are older than 1 minute
            if (
                not session.ws
                and not session.name
                and current_time - session.created_at > one_minute
            ):
                logger.info(
                    f"Removing session {session.getName()} - no connection, no name, older than 1 minute"
                )
                sessions_to_remove.append(session)
                continue

            # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used in 3 hours
            if (
                not session.ws
                and session.displaced_at is not None
                and current_time - session.last_used > three_hours
            ):
                logger.info(
                    f"Removing session {session.getName()} - displaced and unused for 3+ hours"
                )
                sessions_to_remove.append(session)

        for session in sessions_to_remove:
            # Remove from lobbies first (iterate a copy of the list).
            for lobby in session.lobbies[:]:
                try:
                    # Data-structure removal only; no websocket notifications here.
                    if session.id in lobby.sessions:
                        del lobby.sessions[session.id]
                    if lobby.id in session.lobby_peers:
                        del session.lobby_peers[lobby.id]
                except Exception as e:
                    logger.warning(
                        f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}"
                    )

            if session in cls._instances:
                cls._instances.remove(session)
                sessions_removed += 1

        # Clean up empty lobbies from global lobbies dict
        empty_lobbies: list[str] = [
            lobby_id for lobby_id, lobby in lobbies.items() if len(lobby.sessions) == 0
        ]
        for lobby_id in empty_lobbies:
            del lobbies[lobby_id]
            logger.info(f"Removed empty lobby {lobby_id}")

        if sessions_removed > 0:
            cls.save()
            logger.info(f"Session cleanup: removed {sessions_removed} old sessions")

        if empty_lobbies:
            logger.info(f"Session cleanup: removed {len(empty_lobbies)} empty lobbies")

        return sessions_removed

    async def join(self, lobby: Lobby):
        """Join ``lobby`` over this session's websocket and wire up RTC peers.

        Each live member is sent ``addPeer`` for this session
        (should_create_offer=False), and this session is sent ``addPeer`` for
        each member (should_create_offer=True). Members without a websocket
        are evicted from the lobby. Ends with a ``join_status`` message.
        Notifications are best-effort so a dead socket cannot leave the join
        half-applied.
        """
        if not self.ws:
            logger.error(
                f"{self.getName()} - No WebSocket connection. Lobby not available."
            )
            return

        if lobby.id in self.lobby_peers or self.id in lobby.sessions:
            logger.info(f"{self.getName()} - Already joined to {lobby.getName()}.")
            data = JoinStatusModel(
                status="Joined", message=f"Already joined to lobby {lobby.getName()}"
            )
            await self.ws.send_json({"type": "join_status", "data": data.model_dump()})
            return

        # Record membership. A restored session may already list this lobby;
        # replace that entry rather than appending a duplicate.
        for index, existing in enumerate(self.lobbies):
            if existing.id == lobby.id:
                self.lobbies[index] = lobby
                break
        else:
            self.lobbies.append(lobby)
        # Initialize the peer list for this lobby
        self.lobby_peers[lobby.id] = []

        # Iterate a snapshot: dead members are deleted from lobby.sessions
        # inside the loop, and other coroutines may change it while we await.
        for peer_id in list(lobby.sessions):
            if peer_id == self.id:
                raise Exception(
                    "Should not happen: self in lobby.sessions while not in lobby."
                )

            peer_session = lobby.getSession(peer_id)
            if not peer_session or not peer_session.ws:
                logger.warning(
                    f"{self.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}. Removing."
                )
                lobby.sessions.pop(peer_id, None)
                continue

            # Add the peer to session's RTC peer list
            self.lobby_peers[lobby.id].append(peer_id)

            # Add this user as an RTC peer to each existing peer
            peer_session.lobby_peers.setdefault(lobby.id, []).append(self.id)
            logger.info(
                f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName(), lobby.getName()}, should_create_offer=False)"
            )
            await _send_json_safely(
                peer_session.ws,
                {
                    "type": "addPeer",
                    "data": {
                        "peer_id": self.id,
                        "peer_name": self.name,
                        "should_create_offer": False,
                    },
                },
                peer_session.getName(),
            )

            # Add each other peer to the caller
            logger.info(
                f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName(), lobby.getName()}, should_create_offer=True)"
            )
            await _send_json_safely(
                self.ws,
                {
                    "type": "addPeer",
                    "data": {
                        "peer_id": peer_session.id,
                        "peer_name": peer_session.name,
                        "should_create_offer": True,
                    },
                },
                self.getName(),
            )

        # Add this user as an RTC peer
        await lobby.addSession(self)
        Session.save()

        await _send_json_safely(
            self.ws, {"type": "join_status", "data": {"status": "Joined"}}, self.getName()
        )

    async def part(self, lobby: Lobby):
        """Leave ``lobby`` and tear down RTC peers in both directions.

        Each former peer is sent ``removePeer`` for this session, and this
        session is sent ``removePeer`` for each peer. Sends are best-effort so
        a dead socket cannot leave the session half-removed. If the session is
        not in the lobby, an ``error`` message is sent instead.
        """
        if lobby.id not in self.lobby_peers or self.id not in lobby.sessions:
            logger.info(
                f"{self.getName()} - Attempt to part non-joined lobby {lobby.getName()}."
            )
            if self.ws:
                await self.ws.send_json(
                    {"type": "error", "error": "Attempt to part non-joined lobby"}
                )
            return

        logger.info(f"{self.getName()} <- part({lobby.getName()}) - Lobby part.")

        lobby_peers = self.lobby_peers.pop(lobby.id)
        # Remove by id: the list may hold another Lobby object with the same id.
        self.lobbies[:] = [joined for joined in self.lobbies if joined.id != lobby.id]

        # Remove this peer from all other RTC peers, and remove each peer from this peer
        for peer_session_id in lobby_peers:
            peer_session = getSession(peer_session_id)
            if not peer_session:
                logger.warning(
                    f"{self.getName()} <- part({lobby.getName()}) - Peer session {peer_session_id} not found. Skipping."
                )
                continue
            # Drop this session from the peer's list as well, so it stops being
            # an accepted relay target and is not duplicated on a later re-join.
            peer_list = peer_session.lobby_peers.get(lobby.id)
            if peer_list is not None:
                peer_list[:] = [pid for pid in peer_list if pid != self.id]
            if not peer_session.ws:
                logger.warning(
                    f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection for {peer_session.getName()}. Skipping."
                )
                continue
            logger.info(f"{peer_session.getName()} <- remove_peer({self.getName()})")
            await _send_json_safely(
                peer_session.ws,
                {
                    "type": "removePeer",
                    "data": {"peer_name": self.name, "peer_id": self.id},
                },
                peer_session.getName(),
            )
            if not self.ws:
                logger.error(
                    f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection."
                )
                continue
            logger.info(f"{self.getName()} <- remove_peer({peer_session.getName()})")
            await _send_json_safely(
                self.ws,
                {
                    "type": "removePeer",
                    "data": {
                        "peer_name": peer_session.name,
                        "peer_id": peer_session.id,
                    },
                },
                self.getName(),
            )

        await lobby.removeSession(self)
        Session.save()


def getName(session: Session | None) -> str | None:
    """Return the session's name, or None when there is no session or no name."""
    if session and session.name:
        return session.name
    return None


def getSession(session_id: str) -> Session | None:
    """Module-level shortcut for Session.getSession."""
    return Session.getSession(session_id)


def getLobby(lobby_id: str) -> Lobby:
    """Return the registered lobby with this id.

    Raises:
        Exception: if no such lobby is registered.
    """
    lobby = lobbies.get(lobby_id, None)
    if not lobby:
        logger.error(f"Lobby not found: {lobby_id}")
        raise Exception(f"Lobby not found: {lobby_id}")
    return lobby


def getLobbyByName(lobby_name: str) -> Lobby | None:
    """Return the first registered lobby whose name matches exactly, or None."""
    return next((lobby for lobby in lobbies.values() if lobby.name == lobby_name), None)


# API endpoints
@app.get(f"{public_url}api/health", response_model=HealthResponse)
def health():
    """Liveness probe."""
    logger.info("Health check endpoint called.")
    return {
        "status": "ok",
    }


# A session (cookie) is bound to a single user (name).
# A user can be in multiple lobbies, but a session is unique to a single user.
# A user can change their name, but the session ID remains the same and the name
# updates for all lobbies.
@app.get(f"{public_url}api/session", response_model=SessionResponse) async def session( request: Request, response: Response, session_id: str | None = Cookie(default=None) ) -> Response | SessionResponse: if session_id is None: session_id = secrets.token_hex(16) response.set_cookie(key="session_id", value=session_id) # Validate that session_id is a hex string of length 32 elif len(session_id) != 32 or not all(c in "0123456789abcdef" for c in session_id): return Response( content=json.dumps({"error": "Invalid session_id"}), status_code=400, media_type="application/json", ) print(f"[{session_id[:8]}]: Browser hand-shake achieved.") session = getSession(session_id) if not session: session = Session(session_id) logger.info(f"{session.getName()}: New session created.") else: session.update_last_used() # Update activity on session resumption logger.info(f"{session.getName()}: Existing session resumed.") # Part all lobbies for this session that have no active websocket for lobby_id in list(session.lobby_peers.keys()): lobby = None try: lobby = getLobby(lobby_id) except Exception as e: logger.error( f"{session.getName()} - Error getting lobby {lobby_id}: {e}" ) continue await session.part(lobby) return SessionResponse( id=session_id, name=session.name if session.name else "", lobbies=[ LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private) for lobby in session.lobbies ], ) @app.get(public_url + "api/lobby", response_model=LobbiesResponse) async def get_lobbies(request: Request, response: Response) -> LobbiesResponse: return LobbiesResponse( lobbies=[ LobbyListItem(id=lobby.id, name=lobby.name) for lobby in lobbies.values() if not lobby.private ] ) @app.post(public_url + "api/lobby/{session_id}", response_model=LobbyCreateResponse) async def lobby_create( request: Request, response: Response, session_id: str = Path(...), create_request: LobbyCreateRequest = Body(...), ) -> Response | LobbyCreateResponse: if create_request.type != "lobby_create": return {"error": 
"Invalid request type"} data = create_request.data session = getSession(session_id) if not session: return Response( content=json.dumps({"error": f"Session not found ({session_id})"}), status_code=404, media_type="application/json", ) logger.info( f"{session.getName()} lobby_create: {data.name} (private={data.private})" ) lobby = getLobbyByName(data.name) if not lobby: lobby = Lobby( data.name, private=data.private, ) lobbies[lobby.id] = lobby logger.info(f"{session.getName()} <- lobby_create({lobby.short}:{lobby.name})") return LobbyCreateResponse( type="lobby_created", data=LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private), ) @app.get(public_url + "api/lobby/{lobby_id}/chat", response_model=ChatMessagesResponse) async def get_chat_messages( request: Request, lobby_id: str = Path(...), limit: int = 50, ) -> Response | ChatMessagesResponse: """Get chat messages for a lobby""" try: lobby = getLobby(lobby_id) except Exception as e: return Response( content=json.dumps({"error": str(e)}), status_code=404, media_type="application/json", ) messages = lobby.get_chat_messages(limit) return ChatMessagesResponse(messages=messages) # ============================================================================= # Bot Provider API Endpoints # ============================================================================= @app.post( public_url + "api/bots/providers/register", response_model=BotProviderRegisterResponse, ) async def register_bot_provider( request: BotProviderRegisterRequest, ) -> BotProviderRegisterResponse: """Register a new bot provider""" import time import uuid provider_id = str(uuid.uuid4()) now = time.time() provider = BotProviderModel( provider_id=provider_id, base_url=request.base_url.rstrip("/"), name=request.name, description=request.description, registered_at=now, last_seen=now, ) bot_providers[provider_id] = provider logger.info(f"Registered bot provider: {request.name} at {request.base_url}") return 
BotProviderRegisterResponse(provider_id=provider_id) @app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse) async def list_bot_providers() -> BotProviderListResponse: """List all registered bot providers""" return BotProviderListResponse(providers=list(bot_providers.values())) @app.get(public_url + "api/bots", response_model=BotListResponse) async def list_available_bots() -> BotListResponse: """List all available bots from all registered providers""" bots: dict[str, BotInfoModel] = {} providers: dict[str, str] = {} # Update last_seen timestamps and fetch bots from each provider for provider_id, provider in bot_providers.items(): try: import time provider.last_seen = time.time() # Make HTTP request to provider's /bots endpoint async with httpx.AsyncClient() as client: response = await client.get(f"{provider.base_url}/bots", timeout=5.0) if response.status_code == 200: provider_bots = response.json() # provider_bots should be a dict of bot_name -> bot_info for bot_name, bot_info in provider_bots.items(): bots[bot_name] = BotInfoModel(**bot_info) providers[bot_name] = provider_id else: logger.warning( f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}" ) except Exception as e: logger.error(f"Error fetching bots from provider {provider.name}: {e}") continue return BotListResponse(bots=bots, providers=providers) @app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse) async def request_bot_join_lobby( bot_name: str, request: BotJoinLobbyRequest ) -> BotJoinLobbyResponse: """Request a bot to join a specific lobby""" # Find which provider has this bot target_provider_id = request.provider_id if not target_provider_id: # Auto-discover provider for this bot for provider_id, provider in bot_providers.items(): try: async with httpx.AsyncClient() as client: response = await client.get( f"{provider.base_url}/bots", timeout=5.0 ) if response.status_code == 200: provider_bots = 
response.json() if bot_name in provider_bots: target_provider_id = provider_id break except Exception: continue if not target_provider_id or target_provider_id not in bot_providers: raise HTTPException(status_code=404, detail="Bot or provider not found") provider = bot_providers[target_provider_id] # Get the lobby to validate it exists try: getLobby(request.lobby_id) # Just validate it exists except Exception: raise HTTPException(status_code=404, detail="Lobby not found") # Create a session for the bot bot_session_id = secrets.token_hex(16) # Determine server URL for the bot to connect back to # Use the server's public URL or construct from request server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000") if server_base_url.endswith("/"): server_base_url = server_base_url[:-1] bot_nick = request.nick or f"{bot_name}-bot" # Prepare the join request for the bot provider bot_join_payload = BotJoinPayload( lobby_id=request.lobby_id, session_id=bot_session_id, nick=bot_nick, server_url=f"{server_base_url}{public_url}".rstrip("/"), insecure=False, # Assume secure by default ) try: # Make request to bot provider async with httpx.AsyncClient() as client: response = await client.post( f"{provider.base_url}/bots/{bot_name}/join", json=bot_join_payload.model_dump(), timeout=10.0, ) if response.status_code == 200: result = response.json() run_id = result.get("run_id", "unknown") logger.info( f"Bot {bot_name} requested to join lobby {request.lobby_id}" ) return BotJoinLobbyResponse( status="requested", bot_name=bot_name, run_id=run_id, provider_id=target_provider_id, ) else: logger.error( f"Bot provider returned error: HTTP {response.status_code}: {response.text}" ) raise HTTPException( status_code=502, detail=f"Bot provider error: {response.status_code}", ) except httpx.TimeoutException: raise HTTPException(status_code=504, detail="Bot provider timeout") except Exception as e: logger.error(f"Error requesting bot join: {e}") raise HTTPException(status_code=500, 
detail=f"Internal server error: {str(e)}") # Register websocket endpoint directly on app with full public_url path @app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}") async def lobby_join( websocket: WebSocket, lobby_id: str | None = Path(...), session_id: str | None = Path(...), ): await websocket.accept() if lobby_id is None: await websocket.send_json( {"type": "error", "error": "Invalid or missing lobby"} ) await websocket.close() return if session_id is None: await websocket.send_json( {"type": "error", "error": "Invalid or missing session"} ) await websocket.close() return session = getSession(session_id) if not session: # logger.error(f"Invalid session ID {session_id}") await websocket.send_json( {"type": "error", "error": f"Invalid session ID {session_id}"} ) await websocket.close() return lobby = None try: lobby = getLobby(lobby_id) except Exception as e: await websocket.send_json({"type": "error", "error": str(e)}) await websocket.close() return logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})") session.ws = websocket session.update_last_used() # Update activity timestamp if session.id in lobby.sessions: logger.info( f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining." ) await session.part(lobby) await lobby.removeSession(session) for peer_id in lobby.sessions: peer_session = lobby.getSession(peer_id) if not peer_session or not peer_session.ws: logger.warning( f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}. Removing." 
) del lobby.sessions[peer_id] continue logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})") await peer_session.ws.send_json( { "type": "user_joined", "data": { "session_id": session.id, "name": session.name, }, } ) try: while True: packet = await websocket.receive_json() session.update_last_used() # Update activity on each message type = packet.get("type", None) data: dict[str, Any] | None = packet.get("data", None) if not type: logger.error(f"{session.getName()} - Invalid request: {packet}") await websocket.send_json({"type": "error", "error": "Invalid request"}) continue # logger.info(f"{session.getName()} <- RAW Rx: {data}") match type: case "set_name": if not data: logger.error(f"{session.getName()} - set_name missing data") await websocket.send_json( {"type": "error", "error": "set_name missing data"} ) continue name = data.get("name") password = data.get("password") logger.info(f"{session.getName()} <- set_name({name})") if not name: logger.error(f"{session.getName()} - Name required") await websocket.send_json( {"type": "error", "error": "Name required"} ) continue # Name takeover / password logic lname = name.lower() # If name is unused, allow and optionally save password if Session.isUniqueName(name): # If a password was provided, save it (hash+salt) for this name if password: salt, hash_hex = _hash_password(password) name_passwords[lname] = {"salt": salt, "hash": hash_hex} session.setName(name) logger.info(f"{session.getName()}: -> update('name', {name})") await websocket.send_json( { "type": "update_name", "data": { "name": name, "protected": True if name.lower() in name_passwords else False, }, } ) # For any clients in any lobby with this session, update their user lists await lobby.update_state() continue # Name is taken. Check if a password exists for the name and matches. 
saved_pw = name_passwords.get(lname) if not saved_pw: logger.warning( f"{session.getName()} - Name already taken (no password set)" ) await websocket.send_json( {"type": "error", "error": "Name already taken"} ) continue # Expect structured record with salt+hash only match_password = False # saved_pw should be a dict[str,str] with 'salt' and 'hash' salt = saved_pw.get("salt") _, candidate_hash = _hash_password( password if password else "", salt_hex=salt ) if candidate_hash == saved_pw.get("hash"): match_password = True else: # No structured password record available match_password = False if not match_password: logger.warning( f"{session.getName()} - Name takeover attempted with wrong or missing password" ) await websocket.send_json( { "type": "error", "error": "Invalid password for name takeover", } ) continue # Password matches: perform takeover. Find the current session holding the name. # Find the currently existing session (if any) with that name displaced = Session.getSessionByName(name) if displaced and displaced.id == session.id: displaced = None # If found, change displaced session to a unique fallback name and notify peers if displaced: # Create a unique fallback name fallback = f"{displaced.name}-{displaced.short}" # Ensure uniqueness if not Session.isUniqueName(fallback): # append random suffix until unique while not Session.isUniqueName(fallback): fallback = f"{displaced.name}-{secrets.token_hex(3)}" displaced.setName(fallback) displaced.mark_displaced() logger.info( f"{displaced.getName()} <- displaced by takeover, new name {fallback}" ) # Notify displaced session (if connected) if displaced.ws: try: await displaced.ws.send_json( { "type": "update_name", "data": { "name": fallback, "protected": False, }, } ) except Exception: logger.exception( "Failed to notify displaced session websocket" ) # Update all lobbies the displaced session was in for d_lobby in list(displaced.lobbies): try: await d_lobby.update_state() except Exception: logger.exception( 
"Failed to update lobby state for displaced session" ) # Now assign the requested name to the current session session.setName(name) logger.info( f"{session.getName()}: -> update('name', {name}) (takeover)" ) await websocket.send_json( { "type": "update_name", "data": { "name": name, "protected": True if name.lower() in name_passwords else False, }, } ) # Notify lobbies for this session await lobby.update_state() case "list_users": await lobby.update_state(session) case "get_chat_messages": # Send recent chat messages to the requesting client messages = lobby.get_chat_messages(50) await websocket.send_json( { "type": "chat_messages", "data": { "messages": [msg.model_dump() for msg in messages] }, } ) case "send_chat_message": if not data or "message" not in data: logger.error( f"{session.getName()} - send_chat_message missing message" ) await websocket.send_json( { "type": "error", "error": "send_chat_message missing message", } ) continue if not session.name: logger.error( f"{session.getName()} - Cannot send chat message without name" ) await websocket.send_json( { "type": "error", "error": "Must set name before sending chat messages", } ) continue message_text = str(data["message"]).strip() if not message_text: continue # Add the message to the lobby and broadcast it chat_message = lobby.add_chat_message(session, message_text) await lobby.broadcast_chat_message(chat_message) logger.info( f"{session.getName()} sent chat message to {lobby.getName()}: {message_text[:50]}..." 
) case "join": logger.info(f"{session.getName()} <- join({lobby.getName()})") await session.join(lobby=lobby) case "part": logger.info(f"{session.getName()} <- part {lobby.getName()}") await session.part(lobby=lobby) case "relayICECandidate": logger.info(f"{session.getName()} <- relayICECandidate") if not data: logger.error( f"{session.getName()} - relayICECandidate missing data" ) await websocket.send_json( {"type": "error", "error": "relayICECandidate missing data"} ) continue if ( lobby.id not in session.lobby_peers or session.id not in lobby.sessions ): logger.error( f"{session.short}:{session.name} <- relayICECandidate - Not an RTC peer ({session.id})" ) await websocket.send_json( {"type": "error", "error": "Not joined to lobby"} ) continue session_peers = session.lobby_peers[lobby.id] peer_id = data.get("peer_id") if peer_id not in session_peers: logger.error( f"{session.getName()} <- relayICECandidate - Not an RTC peer({peer_id}) in {session_peers}" ) await websocket.send_json( { "type": "error", "error": f"Target peer {peer_id} not found", } ) continue candidate = data.get("candidate") message: dict[str, Any] = { "type": "iceCandidate", "data": { "peer_id": session.id, "peer_name": session.name, "candidate": candidate, }, } peer_session = lobby.getSession(peer_id) if not peer_session or not peer_session.ws: logger.warning( f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." 
) break logger.info( f"{session.getName()} -> iceCandidate({peer_session.getName()})" ) await peer_session.ws.send_json(message) case "relaySessionDescription": logger.info(f"{session.getName()} <- relaySessionDescription") if not data: logger.error( f"{session.getName()} - relaySessionDescription missing data" ) await websocket.send_json( { "type": "error", "error": "relaySessionDescription missing data", } ) continue if ( lobby.id not in session.lobby_peers or session.id not in lobby.sessions ): logger.error( f"{session.short}:{session.name} <- relaySessionDescription - Not an RTC peer ({session.id})" ) await websocket.send_json( {"type": "error", "error": "Not joined to lobby"} ) continue lobby_peers = session.lobby_peers[lobby.id] peer_id = data.get("peer_id") if peer_id not in lobby_peers: logger.error( f"{session.getName()} <- relaySessionDescription - Not an RTC peer({peer_id}) in {lobby_peers}" ) await websocket.send_json( { "type": "error", "error": f"Target peer {peer_id} not found", } ) continue peer_id = data.get("peer_id", None) if not peer_id: logger.error( f"{session.getName()} - relaySessionDescription missing peer_id" ) await websocket.send_json( { "type": "error", "error": "relaySessionDescription missing peer_id", } ) continue peer_session = lobby.getSession(peer_id) if not peer_session or not peer_session.ws: logger.warning( f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}." 
) break session_description = data.get("session_description") message = { "type": "sessionDescription", "data": { "peer_id": session.id, "peer_name": session.name, "session_description": session_description, }, } logger.info( f"{session.getName()} -> sessionDescription({peer_session.getName()})" ) await peer_session.ws.send_json(message) case _: await websocket.send_json( { "type": "error", "error": f"Unknown request type: {type}", } ) except WebSocketDisconnect: logger.info(f"{session.getName()} <- WebSocket disconnected for user.") # Cleanup: remove session from lobby and sessions dict session.ws = None if session.id in lobby.sessions: await session.part(lobby) await lobby.update_state() # Clean up empty lobbies if not lobby.sessions: if lobby.id in lobbies: del lobbies[lobby.id] logger.info(f"Cleaned up empty lobby {lobby.getName()}") # Serve static files or proxy to frontend development server PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true" client_build_path = os.path.join(os.path.dirname(__file__), "/client/build") if PRODUCTION: logger.info(f"Serving static files from: {client_build_path} at {public_url}") app.mount( public_url, StaticFiles(directory=client_build_path, html=True), name="static" ) else: logger.info(f"Proxying static files to http://client:3000 at {public_url}") import ssl @app.api_route( f"{public_url}{{path:path}}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"], ) async def proxy_static(request: Request, path: str): # Do not proxy API or websocket paths if path.startswith("api/") or path.startswith("ws/"): return Response(status_code=404) url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}/{path}" if not path: url = f"{request.url.scheme}://client:3000/{public_url.strip('/')}" headers = dict(request.headers) try: # Accept self-signed certs in dev async with httpx.AsyncClient(verify=False) as client: proxy_req = client.build_request( request.method, url, headers=headers, content=await 
request.body() ) proxy_resp = await client.send(proxy_req, stream=True) content = await proxy_resp.aread() # Remove problematic headers for browser decoding filtered_headers = { k: v for k, v in proxy_resp.headers.items() if k.lower() not in ["content-encoding", "transfer-encoding", "content-length"] } return Response( content=content, status_code=proxy_resp.status_code, headers=filtered_headers, ) except Exception as e: logger.error(f"Proxy error for {url}: {e}") return Response("Proxy error", status_code=502) # WebSocket proxy for /ws (for React DevTools, etc.) import websockets import asyncio from starlette.websockets import WebSocket as StarletteWebSocket @app.websocket("/ws") async def websocket_proxy(websocket: StarletteWebSocket): logger.info("REACT: WebSocket proxy connection established.") # Get scheme from websocket.url (should be 'ws' or 'wss') scheme = websocket.url.scheme if hasattr(websocket, "url") else "ws" target_url = f"{scheme}://client:3000/ws" await websocket.accept() try: # Accept self-signed certs in dev for WSS ssl_ctx = ssl.create_default_context() ssl_ctx.check_hostname = False ssl_ctx.verify_mode = ssl.CERT_NONE async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws: async def client_to_server(): while True: msg = await websocket.receive_text() await target_ws.send(msg) async def server_to_client(): while True: msg = await target_ws.recv() if isinstance(msg, str): await websocket.send_text(msg) else: await websocket.send_bytes(msg) try: await asyncio.gather(client_to_server(), server_to_client()) except (WebSocketDisconnect, websockets.ConnectionClosed): logger.info("REACT: WebSocket proxy connection closed.") except Exception as e: logger.error(f"REACT: WebSocket proxy error: {e}") await websocket.close()