1191 lines
46 KiB
Python

from __future__ import annotations
from typing import Any, Optional, List
from fastapi import (
FastAPI,
HTTPException,
Path,
WebSocket,
Request,
Response,
WebSocketDisconnect,
)
import secrets
import os
import json
import hashlib
import binascii
import sys
import asyncio
import threading
import time
from contextlib import asynccontextmanager
from fastapi.staticfiles import StaticFiles
import httpx
from pydantic import ValidationError
from logger import logger
# Import shared models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import (
NamePasswordRecord,
LobbySaved,
SessionSaved,
SessionsPayload,
AdminMetricsResponse,
AdminMetricsConfig,
JoinStatusModel,
ChatMessageModel,
ParticipantModel,
# Bot provider models
BotProviderModel,
BotProviderRegisterRequest,
BotProviderRegisterResponse,
BotProviderListResponse,
BotListResponse,
BotInfoModel,
BotJoinLobbyRequest,
BotJoinLobbyResponse,
BotJoinPayload,
BotLeaveLobbyRequest,
BotLeaveLobbyResponse,
BotProviderBotsResponse,
BotProviderJoinResponse,
)
# Import modular components
from core.session_manager import SessionManager
from core.lobby_manager import LobbyManager
from core.auth_manager import AuthManager
from websocket.connection import WebSocketConnectionManager
from api.admin import AdminAPI
from api.sessions import SessionAPI
from api.lobbies import LobbyAPI
class SessionConfig:
    """Session-management tunables.

    Every value is an integer read once, at import time, from the environment
    variable of the same name; the literal in each lookup is the default.
    Durations are expressed in seconds.
    """

    # Timeout applied to anonymous sessions (default 60 s = 1 minute).
    ANONYMOUS_SESSION_TIMEOUT = int(os.environ.get("ANONYMOUS_SESSION_TIMEOUT", "60"))
    # Timeout applied to displaced sessions (default 10800 s = 3 hours).
    DISPLACED_SESSION_TIMEOUT = int(os.environ.get("DISPLACED_SESSION_TIMEOUT", "10800"))
    # Seconds between cleanup passes (default 300 s = 5 minutes).
    CLEANUP_INTERVAL = int(os.environ.get("CLEANUP_INTERVAL", "300"))
    # Upper bound on sessions handled per cleanup pass (circuit breaker).
    MAX_SESSIONS_PER_CLEANUP = int(os.environ.get("MAX_SESSIONS_PER_CLEANUP", "100"))
    # Upper bound on chat messages kept per lobby.
    MAX_CHAT_MESSAGES_PER_LOBBY = int(os.environ.get("MAX_CHAT_MESSAGES_PER_LOBBY", "100"))
    # Seconds between session validation passes (default 1800 s = 30 minutes).
    SESSION_VALIDATION_INTERVAL = int(os.environ.get("SESSION_VALIDATION_INTERVAL", "1800"))
class BotProviderConfig:
    """Configuration class for bot provider management"""

    # Comma-separated list of allowed provider keys, read from BOT_PROVIDER_KEYS.
    # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys)
    ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "")

    @classmethod
    def get_allowed_providers(cls) -> dict[str, str]:
        """Parse ALLOWED_PROVIDERS into a provider_key -> provider_name mapping.

        Whitespace around keys and names is stripped. Blank entries and entries
        whose key is empty (e.g. ":name") are ignored. An entry without a name,
        or with an empty name (e.g. "key" or "key:"), uses the key as its name.

        Returns:
            dict mapping provider_key -> provider_name; empty when no providers
            are configured.
        """
        providers: dict[str, str] = {}
        for entry in cls.ALLOWED_PROVIDERS.split(","):
            # partition() yields an empty name when there is no ":" at all.
            key, _, name = entry.partition(":")
            key = key.strip()
            if not key:
                # An empty key could never match a real provider key.
                continue
            providers[key] = name.strip() or key
        return providers
# Re-entrant thread lock for session operations.
session_lock = threading.RLock()

# Reserved-name password store: lowercased name -> {"salt": <hex>, "hash": <hex>}.
name_passwords: dict[str, dict[str, str]] = dict()

# Registry of bot providers, keyed by provider_id.
bot_providers: dict[str, BotProviderModel] = dict()

# Fixed display labels.
all_label, info_label, todo_label, unset_label = (
    "[ all ]",
    "[ info ]",
    "[ todo ]",
    "[ ---- ]",
)
def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]:
"""Return (salt_hex, hash_hex) for the given password. If salt_hex is provided
it is used; otherwise a new salt is generated."""
if salt_hex:
salt = binascii.unhexlify(salt_hex)
else:
salt = secrets.token_bytes(16)
salt_hex = binascii.hexlify(salt).decode()
dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000)
hash_hex = binascii.hexlify(dk).decode()
return salt_hex, hash_hex
# URL path prefix under which every route is served. Routes are built as
# f"{public_url}api/..." and Starlette requires route paths to begin with "/",
# so the value is normalised to both start and end with "/".
public_url = os.getenv("PUBLIC_URL", "/")
if not public_url.startswith("/"):
    public_url = "/" + public_url
if not public_url.endswith("/"):
    public_url += "/"

# Global variables to control background tasks
cleanup_task_running = False
cleanup_task: Optional[asyncio.Task] = None
validation_task_running = False
validation_task: Optional[asyncio.Task] = None

# Global modular managers; None until the application lifespan starts up.
session_manager: Optional[SessionManager] = None
lobby_manager: Optional[LobbyManager] = None
auth_manager: Optional[AuthManager] = None
websocket_manager: Optional[WebSocketConnectionManager] = None
async def periodic_cleanup():
    """Deprecated no-op, retained only for compatibility.

    Session cleanup is performed by SessionManager's own background tasks;
    this coroutine does nothing and returns immediately.
    """
    return None
async def periodic_validation():
    """Deprecated no-op, retained only for compatibility.

    Session integrity validation is performed by SessionManager; this
    coroutine does nothing and returns immediately.
    """
    return None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build the modular architecture, then tear it down.

    Startup creates the session/lobby/auth managers and the
    WebSocketConnectionManager, registers the admin, session and lobby API
    routers on *app*, and starts the session manager's background tasks.
    Shutdown runs in a ``finally`` clause, so it also executes when the server
    exits because of an error. It stops those tasks, cancels the legacy
    compatibility tasks and asks the session manager to clean up all sessions.
    """
    global cleanup_task_running, cleanup_task, validation_task_running, validation_task
    global session_manager, lobby_manager, auth_manager, websocket_manager

    # Startup
    logger.info("Initializing modular architecture...")

    # Initialize core managers
    session_manager = SessionManager()
    lobby_manager = LobbyManager()
    auth_manager = AuthManager()

    # Set up cross-manager dependencies
    lobby_manager.set_name_protection_checker(auth_manager.is_name_protected)

    # Initialize WebSocket manager
    websocket_manager = WebSocketConnectionManager(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        auth_manager=auth_manager,
    )

    # Register API routes using modular components
    admin_api = AdminAPI(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        auth_manager=auth_manager,
        admin_token=ADMIN_TOKEN or "",
        public_url=public_url,
    )
    session_api = SessionAPI(
        session_manager=session_manager,
        public_url=public_url,
    )
    lobby_api = LobbyAPI(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        public_url=public_url,
    )

    # These routers are added at startup, after module import has already
    # registered catch-all routes at public_url: the StaticFiles mount in
    # production, the dev-server proxy route otherwise. Starlette tries routes
    # in list order, so the new API routes are moved to the front; if left at
    # the end of the list they would be shadowed by the catch-all.
    first_new_route = len(app.router.routes)
    app.include_router(admin_api.router)
    app.include_router(session_api.router)
    app.include_router(lobby_api.router)
    api_routes = app.router.routes[first_new_route:]
    del app.router.routes[first_new_route:]
    app.router.routes[:0] = api_routes

    logger.info("Starting background tasks...")
    cleanup_task_running = True
    validation_task_running = True

    # Start the new session manager background tasks
    await session_manager.start_background_tasks()

    # Keep the original (no-op) background tasks for compatibility
    cleanup_task = asyncio.create_task(periodic_cleanup())
    validation_task = asyncio.create_task(periodic_validation())
    logger.info("Session cleanup and validation tasks started")

    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down background tasks...")
        cleanup_task_running = False
        validation_task_running = False

        # Stop modular manager tasks
        if session_manager:
            await session_manager.stop_background_tasks()

        # Cancel the compatibility tasks and wait for them to finish.
        for task in (cleanup_task, validation_task):
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        # Clean up sessions gracefully using the session manager
        if session_manager:
            await session_manager.cleanup_all()
        logger.info("All background tasks stopped and sessions cleaned up")
app = FastAPI(lifespan=lifespan)

logger.info(f"Starting server with public URL: {public_url}")
logger.info(
    "Session config - "
    f"Anonymous timeout: {SessionConfig.ANONYMOUS_SESSION_TIMEOUT}s, "
    f"Displaced timeout: {SessionConfig.DISPLACED_SESSION_TIMEOUT}s, "
    f"Cleanup interval: {SessionConfig.CLEANUP_INTERVAL}s"
)

# Report whether bot provider authentication is enabled, i.e. whether
# BOT_PROVIDER_KEYS yielded any allowed providers.
allowed_providers = BotProviderConfig.get_allowed_providers()
if not allowed_providers:
    logger.warning("Bot provider authentication disabled. Any provider can register.")
else:
    logger.info(
        "Bot provider authentication enabled. Allowed providers: "
        f"{list(allowed_providers.keys())}"
    )
# Optional admin token to protect admin endpoints
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None)
def _require_admin(request: Request) -> bool:
if not ADMIN_TOKEN:
return True
token = request.headers.get("X-Admin-Token")
return token == ADMIN_TOKEN
# =============================================================================
# Bot Provider API Endpoints - DEPRECATED
# =============================================================================
# NOTE: Bot API endpoints should be moved to api/bots.py using the modular
# architecture. The endpoints below are still registered on the app, but
# request_bot_join_lobby and request_bot_leave_lobby still reference the removed
# Session/Lobby helpers (Session, getSession, getLobby) and will fail at runtime
# until they are ported.
@app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse)
async def list_bot_providers() -> BotProviderListResponse:
    """Return every bot provider currently held in the in-memory registry."""
    return BotProviderListResponse(providers=list(bot_providers.values()))
@app.get(public_url + "api/bots", response_model=BotListResponse)
async def list_available_bots() -> BotListResponse:
    """List all available bots from all registered providers.

    Each provider's ``GET {base_url}/bots`` endpoint is queried concurrently
    over one shared HTTP client, with a 5 s timeout per request. A provider
    that errors, times out or answers with a non-200 status is logged and
    contributes no bots. Every queried provider has ``last_seen`` refreshed.

    Returns:
        BotListResponse. ``bots`` lists the bots in provider registration order.
        ``providers`` maps bot name -> provider id; if two providers expose the
        same bot name, the later provider wins.
    """
    # Snapshot the registry: other requests may register providers while this
    # coroutine is suspended on network I/O, and mutating a dict that is being
    # iterated raises RuntimeError.
    registered = list(bot_providers.items())

    async def fetch_bots(client: httpx.AsyncClient, provider: BotProviderModel) -> List[BotInfoModel]:
        """Return *provider*'s bot list, or an empty list on any failure."""
        provider.last_seen = time.time()
        try:
            response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
            if response.status_code != 200:
                logger.warning(
                    f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}"
                )
                return []
            # Use Pydantic model to validate the response
            return list(BotProviderBotsResponse.model_validate(response.json()).bots)
        except Exception as e:
            logger.error(f"Error fetching bots from provider {provider.name}: {e}")
            return []

    async with httpx.AsyncClient() as client:
        per_provider = await asyncio.gather(
            *(fetch_bots(client, provider) for _, provider in registered)
        )

    bots: List[BotInfoModel] = []
    providers: dict[str, str] = {}
    for (provider_id, _), provider_bots in zip(registered, per_provider):
        for bot_info in provider_bots:
            bots.append(bot_info)
            providers[bot_info.name] = provider_id
    return BotListResponse(bots=bots, providers=providers)
async def _find_provider_bot(
    client: httpx.AsyncClient, provider: BotProviderModel, bot_name: str
) -> Optional[BotInfoModel]:
    """Look up *bot_name* in *provider*'s ``GET /bots`` listing.

    Returns the matching BotInfoModel, or None when the provider answers with a
    non-200 status or does not list the bot. Network, JSON and validation
    errors propagate to the caller.
    """
    response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
    if response.status_code != 200:
        return None
    listing = BotProviderBotsResponse.model_validate(response.json())
    return next((info for info in listing.bots if info.name == bot_name), None)


@app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse)
async def request_bot_join_lobby(
    bot_name: str, request: BotJoinLobbyRequest
) -> BotJoinLobbyResponse:
    """Request a bot to join a specific lobby.

    If ``request.provider_id`` is empty, the first registered provider whose
    bot listing contains *bot_name* is used. Otherwise the named provider is
    queried only to learn whether the bot has media; if that query fails, the
    bot is treated as having no media. A bot session is then created and the
    provider is asked (POST ``/bots/{bot_name}/join``) to connect it.

    Raises:
        HTTPException: 404 if no provider/bot or lobby is found; 502 if the
            provider returns a non-200 status or an invalid body; 504 on a
            provider timeout; 500 for any other error contacting the provider.
    """
    target_provider_id = request.provider_id
    bot_has_media = False

    async with httpx.AsyncClient() as client:
        if not target_provider_id:
            # Auto-discover provider for this bot; iterate a snapshot because
            # the registry may change while we await network I/O.
            for provider_id, candidate in list(bot_providers.items()):
                try:
                    bot_info = await _find_provider_bot(client, candidate, bot_name)
                except Exception as e:
                    logger.warning(
                        f"Error querying provider {candidate.name} for bot {bot_name}: {e}"
                    )
                    continue
                if bot_info is not None:
                    target_provider_id = provider_id
                    bot_has_media = bot_info.has_media
                    break
        elif target_provider_id in bot_providers:
            # Query the specified provider for bot media capability
            try:
                bot_info = await _find_provider_bot(
                    client, bot_providers[target_provider_id], bot_name
                )
            except Exception as e:
                # Default to no media if we can't query
                logger.warning(f"Error querying provider {target_provider_id} for bot {bot_name}: {e}")
                bot_info = None
            if bot_info is not None:
                bot_has_media = bot_info.has_media

    if not target_provider_id or target_provider_id not in bot_providers:
        raise HTTPException(status_code=404, detail="Bot or provider not found")
    provider = bot_providers[target_provider_id]

    # Get the lobby to validate it exists.
    # NOTE(review): getLobby is not defined or imported in this module, so the
    # NameError is caught here and reported as 404; port to LobbyManager.
    try:
        getLobby(request.lobby_id)
    except Exception:
        raise HTTPException(status_code=404, detail="Lobby not found")

    # Create a session for the bot.
    # NOTE(review): Session is not defined or imported in this module; this
    # endpoint needs porting to SessionManager before it can work.
    bot_session_id = secrets.token_hex(16)
    bot_session = Session(bot_session_id, is_bot=True, has_media=bot_has_media)
    logger.info(
        f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})"
    )

    # Determine server URL for the bot to connect back to
    server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000").rstrip("/")
    bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}"

    # Prepare the join request for the bot provider
    bot_join_payload = BotJoinPayload(
        lobby_id=request.lobby_id,
        session_id=bot_session_id,
        nick=bot_nick,
        server_url=f"{server_base_url}{public_url}".rstrip("/"),
        insecure=True,  # Accept self-signed certificates in development
    )

    # Only the network call sits inside this try, so the HTTPExceptions raised
    # below are not re-wrapped as 500s.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{provider.base_url}/bots/{bot_name}/join",
                json=bot_join_payload.model_dump(),
                timeout=10.0,
            )
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Bot provider timeout")
    except Exception as e:
        logger.error(f"Error requesting bot join: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

    if response.status_code != 200:
        logger.error(
            f"Bot provider returned error: HTTP {response.status_code}: {response.text}"
        )
        raise HTTPException(
            status_code=502,
            detail=f"Bot provider error: {response.status_code}",
        )

    # Use Pydantic model to parse and validate response; a non-JSON body is
    # equally a bad gateway response.
    try:
        join_response = BotProviderJoinResponse.model_validate(response.json())
    except (ValidationError, json.JSONDecodeError) as e:
        logger.error(f"Invalid response from bot provider: {e}")
        raise HTTPException(
            status_code=502,
            detail=f"Bot provider returned invalid response: {str(e)}",
        ) from e

    run_id = join_response.run_id
    # Update bot session with run and provider information
    with bot_session.session_lock:
        bot_session.bot_run_id = run_id
        bot_session.bot_provider_id = target_provider_id
        bot_session.setName(bot_nick)
    logger.info(f"Bot {bot_name} requested to join lobby {request.lobby_id}")
    return BotJoinLobbyResponse(
        status="requested",
        bot_name=bot_name,
        run_id=run_id,
        provider_id=target_provider_id,
    )
@app.post(public_url + "api/bots/leave", response_model=BotLeaveLobbyResponse)
async def request_bot_leave_lobby(
    request: BotLeaveLobbyRequest,
) -> BotLeaveLobbyResponse:
    """Disconnect a bot session from every lobby it is in.

    If the session records both a run id and a provider id that is still
    registered, the provider is first asked, best-effort, to stop that run.
    The session then parts each of its lobbies and its WebSocket (if any) is
    closed and cleared.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it is not a bot.
    """
    # NOTE(review): getSession is not defined or imported in this module;
    # this endpoint needs porting to SessionManager before it can work.
    session = getSession(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Bot session not found")
    if not session.is_bot:
        raise HTTPException(status_code=400, detail="Session is not a bot")

    run_id = session.bot_run_id
    provider_id = session.bot_provider_id
    logger.info(f"Requesting bot {session.getName()} to leave all lobbies")

    # Provider-level stop is only possible when we know both the run and a
    # still-registered provider.
    provider = bot_providers.get(provider_id) if (provider_id and run_id) else None
    if provider is not None:
        stop_url = f"{provider.base_url}/bots/runs/{run_id}/stop"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(stop_url, timeout=5.0)
            if response.status_code != 200:
                logger.warning(
                    f"Bot provider returned error when stopping: HTTP {response.status_code}"
                )
            else:
                logger.info(f"Successfully requested bot provider to stop run {run_id}")
        except Exception as e:
            logger.warning(f"Failed to request bot stop from provider: {e}")

    # Force disconnect: copy the lobby list under the lock, part outside it.
    with session.session_lock:
        joined_lobbies = list(session.lobbies)
    for joined in joined_lobbies:
        try:
            await session.part(joined)
        except Exception as e:
            logger.warning(f"Error parting bot from lobby {joined.getName()}: {e}")

    # Close WebSocket connection if it exists
    socket = session.ws
    if socket:
        try:
            await socket.close()
        except Exception as e:
            logger.warning(f"Error closing bot WebSocket: {e}")
        session.ws = None

    return BotLeaveLobbyResponse(
        status="disconnected",
        session_id=request.session_id,
        run_id=run_id,
    )
# Register websocket endpoint directly on app with full public_url path
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
async def lobby_join(
    websocket: WebSocket,
    lobby_id: str = Path(...),
    session_id: str = Path(...),
):
    """WebSocket endpoint for lobby connections.

    All protocol handling is delegated to the WebSocketConnectionManager
    created by ``lifespan`` at startup. If a connection arrives before that
    manager exists, the socket is accepted, sent a single error message, and
    closed.
    """
    if websocket_manager is None:
        # Fallback if manager not initialized
        await websocket.accept()
        await websocket.send_json(
            {"type": "error", "data": {"error": "Server not fully initialized"}}
        )
        await websocket.close()
        return
    await websocket_manager.handle_connection(websocket, lobby_id, session_id)
# Serve static files or proxy to frontend development server
PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true"
# Location of the built client. Defaults to the absolute path "/client/build";
# override with CLIENT_BUILD_PATH.
client_build_path = os.getenv("CLIENT_BUILD_PATH", "/client/build")

if PRODUCTION:
    logger.info(f"Serving static files from: {client_build_path} at {public_url}")
    app.mount(
        public_url, StaticFiles(directory=client_build_path, html=True), name="static"
    )
else:
    logger.info(f"Proxying static files to http://client:3000 at {public_url}")
    import ssl

    import websockets

    # Upstream headers not copied back: httpx has already decoded the body,
    # so the original encoding/length headers no longer describe it.
    _STRIPPED_PROXY_HEADERS = {"content-encoding", "transfer-encoding", "content-length"}

    @app.api_route(
        f"{public_url}{{path:path}}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_static(request: Request, path: str):
        """Forward a non-API HTTP request to the frontend dev server (client:3000).

        Paths under ``api/`` and ``ws/`` are answered with 404 instead of being
        proxied. The query string is forwarded. Any proxying failure yields a
        502 "Proxy error" response.
        """
        # Do not proxy API or websocket paths
        if path.startswith(("api/", "ws/")):
            return Response(status_code=404)

        # Join only non-empty segments so public_url == "/" does not produce
        # "client:3000//path".
        segments = [part for part in (public_url.strip("/"), path) if part]
        url = f"{request.url.scheme}://client:3000/" + "/".join(segments)
        if request.url.query:
            url += f"?{request.url.query}"

        headers = dict(request.headers)
        try:
            # Accept self-signed certs in dev
            async with httpx.AsyncClient(verify=False) as client:
                proxy_resp = await client.request(
                    request.method, url, headers=headers, content=await request.body()
                )
        except Exception as e:
            logger.error(f"Proxy error for {url}: {e}")
            return Response("Proxy error", status_code=502)

        response = Response(content=proxy_resp.content, status_code=proxy_resp.status_code)
        # multi_items() keeps repeated headers (e.g. several Set-Cookie lines)
        # separate instead of merging them into one comma-joined value.
        for key, value in proxy_resp.headers.multi_items():
            if key.lower() not in _STRIPPED_PROXY_HEADERS:
                response.headers.append(key, value)
        return response

    # WebSocket proxy for /ws (for React DevTools, etc.)
    @app.websocket("/ws")
    async def websocket_proxy(websocket: WebSocket):
        """Relay text and binary frames between the browser and wss://client:3000/ws.

        Relaying stops as soon as either side closes or fails; the other
        direction is then cancelled so no task is left running.
        """
        logger.info("REACT: WebSocket proxy connection established.")
        target_url = "wss://client:3000/ws"  # Use WSS since client uses HTTPS
        await websocket.accept()
        try:
            # Accept self-signed certs in dev for WSS
            ssl_ctx = ssl.create_default_context()
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
            async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws:

                async def client_to_server():
                    while True:
                        message = await websocket.receive()
                        if message["type"] == "websocket.disconnect":
                            raise WebSocketDisconnect(message.get("code", 1000))
                        if message.get("text") is not None:
                            await target_ws.send(message["text"])
                        elif message.get("bytes") is not None:
                            await target_ws.send(message["bytes"])

                async def server_to_client():
                    while True:
                        msg = await target_ws.recv()
                        if isinstance(msg, str):
                            await websocket.send_text(msg)
                        else:
                            await websocket.send_bytes(msg)

                relays = [
                    asyncio.create_task(client_to_server()),
                    asyncio.create_task(server_to_client()),
                ]
                try:
                    done, _ = await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    # Cancel the surviving direction and reap both tasks so
                    # no exception is left unretrieved.
                    for relay in relays:
                        relay.cancel()
                    await asyncio.gather(*relays, return_exceptions=True)

                for relay in done:
                    exc = relay.exception()
                    if exc is not None and not isinstance(
                        exc, (WebSocketDisconnect, websockets.ConnectionClosed)
                    ):
                        raise exc
                logger.info("REACT: WebSocket proxy connection closed.")
        except Exception as e:
            logger.error(f"REACT: WebSocket proxy error: {e}")
        try:
            await websocket.close()
        except Exception as e:
            # Expected when the browser already closed its side.
            logger.debug(f"REACT: WebSocket proxy close skipped: {e}")