1191 lines
46 KiB
Python
1191 lines
46 KiB
Python
from __future__ import annotations
|
|
from typing import Any, Optional, List
|
|
from fastapi import (
|
|
FastAPI,
|
|
HTTPException,
|
|
Path,
|
|
WebSocket,
|
|
Request,
|
|
Response,
|
|
WebSocketDisconnect,
|
|
)
|
|
import secrets
|
|
import os
|
|
import json
|
|
import hashlib
|
|
import binascii
|
|
import sys
|
|
import asyncio
|
|
import threading
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
import httpx
|
|
from pydantic import ValidationError
|
|
from logger import logger
|
|
|
|
# Import shared models
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
from shared.models import (
|
|
NamePasswordRecord,
|
|
LobbySaved,
|
|
SessionSaved,
|
|
SessionsPayload,
|
|
AdminMetricsResponse,
|
|
AdminMetricsConfig,
|
|
JoinStatusModel,
|
|
ChatMessageModel,
|
|
ParticipantModel,
|
|
# Bot provider models
|
|
BotProviderModel,
|
|
BotProviderRegisterRequest,
|
|
BotProviderRegisterResponse,
|
|
BotProviderListResponse,
|
|
BotListResponse,
|
|
BotInfoModel,
|
|
BotJoinLobbyRequest,
|
|
BotJoinLobbyResponse,
|
|
BotJoinPayload,
|
|
BotLeaveLobbyRequest,
|
|
BotLeaveLobbyResponse,
|
|
BotProviderBotsResponse,
|
|
BotProviderJoinResponse,
|
|
)
|
|
|
|
# Import modular components
|
|
from core.session_manager import SessionManager
|
|
from core.lobby_manager import LobbyManager
|
|
from core.auth_manager import AuthManager
|
|
from websocket.connection import WebSocketConnectionManager
|
|
from api.admin import AdminAPI
|
|
from api.sessions import SessionAPI
|
|
from api.lobbies import LobbyAPI
|
|
|
|
|
|
class SessionConfig:
    """Session-management tunables, each read once from the environment at import.

    Every value is parsed with ``int()``, so a non-numeric environment value
    raises ``ValueError`` when this module is imported. Durations are seconds.
    """

    # Timeout for anonymous sessions (default 60 s = 1 minute).
    ANONYMOUS_SESSION_TIMEOUT = int(os.getenv("ANONYMOUS_SESSION_TIMEOUT", "60"))
    # Timeout for displaced sessions (default 10800 s = 3 hours).
    DISPLACED_SESSION_TIMEOUT = int(os.getenv("DISPLACED_SESSION_TIMEOUT", "10800"))
    # Interval between cleanup passes (default 300 s = 5 minutes).
    CLEANUP_INTERVAL = int(os.getenv("CLEANUP_INTERVAL", "300"))
    # Circuit breaker: upper bound on sessions handled per cleanup pass.
    MAX_SESSIONS_PER_CLEANUP = int(os.getenv("MAX_SESSIONS_PER_CLEANUP", "100"))
    # Cap on stored chat messages per lobby (enforced outside this module).
    MAX_CHAT_MESSAGES_PER_LOBBY = int(os.getenv("MAX_CHAT_MESSAGES_PER_LOBBY", "100"))
    # Interval between session-integrity validation passes (default 1800 s = 30 minutes).
    SESSION_VALIDATION_INTERVAL = int(os.getenv("SESSION_VALIDATION_INTERVAL", "1800"))
|
|
|
|
|
|
class BotProviderConfig:
    """Allow-list of bot providers, read from ``BOT_PROVIDER_KEYS`` at import."""

    # Comma-separated list of allowed provider keys
    # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys)
    ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "")

    @classmethod
    def get_allowed_providers(cls) -> dict[str, str]:
        """Parse ``ALLOWED_PROVIDERS`` into a provider-key -> provider-name map.

        Entries are separated by ``,`` and stripped of surrounding whitespace.
        ``key:name`` maps *key* to *name* (only the first ``:`` splits, so
        names may contain colons). A bare ``key``, or ``key:`` with an empty
        name, maps the key to itself. Entries whose key is empty (``":name"``,
        blank entries) are ignored, so ``""`` can never become an allowed
        provider key.

        Returns:
            dict mapping provider_key -> provider_name; empty when unset/blank.
        """
        providers: dict[str, str] = {}
        for raw_entry in cls.ALLOWED_PROVIDERS.split(","):
            key, _, name = raw_entry.partition(":")
            key = key.strip()
            if not key:
                continue
            providers[key] = name.strip() or key
        return providers
|
|
|
|
|
|
# Re-entrant lock guarding session operations.
session_lock = threading.RLock()

# Reserved names -> password record: lowercased name -> {"salt": hex, "hash": hex}.
name_passwords: dict[str, dict[str, str]] = {}

# Bot provider registry: provider_id -> BotProviderModel.
bot_providers: dict[str, BotProviderModel] = {}

# Fixed bracketed labels (consumers live outside this part of the module).
all_label, info_label, todo_label, unset_label = (
    "[ all ]",
    "[ info ]",
    "[ todo ]",
    "[ ---- ]",
)
|
|
|
|
|
|
def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]:
|
|
"""Return (salt_hex, hash_hex) for the given password. If salt_hex is provided
|
|
it is used; otherwise a new salt is generated."""
|
|
if salt_hex:
|
|
salt = binascii.unhexlify(salt_hex)
|
|
else:
|
|
salt = secrets.token_bytes(16)
|
|
salt_hex = binascii.hexlify(salt).decode()
|
|
dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000)
|
|
hash_hex = binascii.hexlify(dk).decode()
|
|
return salt_hex, hash_hex
|
|
|
|
|
|
# Base path every route is mounted under, normalised to begin and end with
# "/" so suffixes such as "api/bots" can be appended directly. Incoming request
# paths always start with "/", so a prefix without a leading slash (e.g.
# PUBLIC_URL="myapp") would produce routes that can never match.
public_url = os.getenv("PUBLIC_URL", "/")
if not public_url.startswith("/"):
    public_url = "/" + public_url
if not public_url.endswith("/"):
    public_url += "/"
|
|
|
|
# Legacy background-task state, kept for compatibility. The tasks run
# periodic_cleanup/periodic_validation, which are now no-ops.
cleanup_task_running = False
cleanup_task: Optional[asyncio.Task] = None
validation_task_running = False
validation_task: Optional[asyncio.Task] = None

# Modular managers. They are created in ``lifespan`` at application startup,
# so each is ``None`` until then (``lobby_join`` checks for that case).
session_manager: Optional[SessionManager] = None
lobby_manager: Optional[LobbyManager] = None
auth_manager: Optional[AuthManager] = None
websocket_manager: Optional[WebSocketConnectionManager] = None
|
|
|
|
|
|
async def periodic_cleanup():
    """Legacy no-op retained for backward compatibility (DEPRECATED).

    Session cleanup now runs in ``SessionManager``'s background tasks.
    ``lifespan`` still schedules this coroutine, which returns immediately.
    """
    return None
|
|
|
|
|
|
async def periodic_validation():
    """Legacy no-op retained for backward compatibility (DEPRECATED).

    Session-integrity validation now runs in ``SessionManager``'s background
    tasks. ``lifespan`` still schedules this coroutine, which returns
    immediately.
    """
    return None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the modular managers and routers on startup; tear down on shutdown.

    Startup:
      * builds SessionManager, LobbyManager, AuthManager and the
        WebSocketConnectionManager, and publishes them as module globals;
      * includes the admin, session and lobby API routers, placed ahead of
        the import-time catch-all routes so they are reachable;
      * starts SessionManager's background tasks and the legacy no-op tasks.

    Shutdown (runs in a ``finally`` so it also happens when the app exits
    with an error):
      * stops SessionManager's background tasks, cancels the legacy tasks and
        calls ``SessionManager.cleanup_all()``.
    """
    global cleanup_task_running, cleanup_task, validation_task_running, validation_task
    global session_manager, lobby_manager, auth_manager, websocket_manager

    # Startup
    logger.info("Initializing modular architecture...")

    # Initialize core managers
    session_manager = SessionManager()
    lobby_manager = LobbyManager()
    auth_manager = AuthManager()

    # Let LobbyManager ask AuthManager whether a name is protected.
    lobby_manager.set_name_protection_checker(auth_manager.is_name_protected)

    websocket_manager = WebSocketConnectionManager(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        auth_manager=auth_manager,
    )

    admin_api = AdminAPI(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        auth_manager=auth_manager,
        admin_token=ADMIN_TOKEN or "",
        public_url=public_url,
    )
    session_api = SessionAPI(
        session_manager=session_manager,
        public_url=public_url,
    )
    lobby_api = LobbyAPI(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        public_url=public_url,
    )

    # Routes registered at import time include catch-alls: the StaticFiles
    # mount at ``public_url`` in production, or the ``{path:path}`` proxy in
    # development. Starlette dispatches to the first matching route, and
    # include_router appends, so these routers would be shadowed. Move the
    # newly added routes to the front of the routing table.
    n_existing = len(app.router.routes)
    app.include_router(admin_api.router)
    app.include_router(session_api.router)
    app.include_router(lobby_api.router)
    added_routes = app.router.routes[n_existing:]
    del app.router.routes[n_existing:]
    app.router.routes[:0] = added_routes

    logger.info("Starting background tasks...")
    cleanup_task_running = True
    validation_task_running = True

    # Start the new session manager background tasks
    await session_manager.start_background_tasks()

    # Keep the original (now no-op) background tasks for compatibility
    cleanup_task = asyncio.create_task(periodic_cleanup())
    validation_task = asyncio.create_task(periodic_validation())
    logger.info("Session cleanup and validation tasks started")

    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down background tasks...")
        cleanup_task_running = False
        validation_task_running = False

        # Stop modular manager tasks
        if session_manager:
            await session_manager.stop_background_tasks()

        # Cancel legacy tasks and wait for them to finish
        for task in (cleanup_task, validation_task):
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        # Clean up sessions gracefully using the session manager
        if session_manager:
            await session_manager.cleanup_all()
        logger.info("All background tasks stopped and sessions cleaned up")
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)

logger.info(f"Starting server with public URL: {public_url}")
logger.info(
    "Session config - "
    f"Anonymous timeout: {SessionConfig.ANONYMOUS_SESSION_TIMEOUT}s, "
    f"Displaced timeout: {SessionConfig.DISPLACED_SESSION_TIMEOUT}s, "
    f"Cleanup interval: {SessionConfig.CLEANUP_INTERVAL}s"
)

# Report whether bot-provider registration is restricted to an allow-list.
allowed_providers = BotProviderConfig.get_allowed_providers()
if not allowed_providers:
    logger.warning("Bot provider authentication disabled. Any provider can register.")
else:
    logger.info(
        f"Bot provider authentication enabled. Allowed providers: {list(allowed_providers)}"
    )

# Optional shared secret protecting admin endpoints; when unset (None),
# _require_admin allows every request. Defined after ``lifespan``, which only
# reads it at application startup.
ADMIN_TOKEN = os.environ.get("ADMIN_TOKEN")
|
|
|
|
|
|
def _require_admin(request: Request) -> bool:
    """Return True if *request* may use admin-only functionality.

    When ``ADMIN_TOKEN`` is unset or empty, admin access is open and every
    request is allowed. Otherwise the ``X-Admin-Token`` header must equal the
    token. The comparison is constant-time (:func:`secrets.compare_digest`), so
    response timing does not reveal how much of a guessed token was correct.
    """
    if not ADMIN_TOKEN:
        return True
    token = request.headers.get("X-Admin-Token")
    if token is None:
        return False
    # compare_digest rejects non-ASCII str, so compare UTF-8 bytes; equal bytes
    # iff equal strings, so the result matches a plain ``==``.
    return secrets.compare_digest(token.encode("utf-8"), ADMIN_TOKEN.encode("utf-8"))
|
|
|
|
|
|
# =============================================================================
# Bot Provider API Endpoints - DEPRECATED
# =============================================================================

# NOTE: These bot API endpoints should be moved to api/bots.py to match the
# modular architecture. They are still registered on the app below, but the
# join/leave endpoints still reference the removed Session/Lobby helpers
# (Session, getSession, getLobby), so they can fail at runtime until ported.
|
|
|
|
|
|
# (Bot provider endpoints follow; see the DEPRECATED note above.)
|
|
|
|
|
|
@app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse)
async def list_bot_providers() -> BotProviderListResponse:
    """List all registered bot providers.

    Returns a snapshot of the ``bot_providers`` registry values in the
    registry's insertion order.
    """
    return BotProviderListResponse(providers=list(bot_providers.values()))
|
|
|
|
|
|
@app.get(public_url + "api/bots", response_model=BotListResponse)
async def list_available_bots() -> BotListResponse:
    """List all available bots from all registered providers.

    Each provider's ``GET {base_url}/bots`` is queried (5 s timeout) and the
    body validated as ``BotProviderBotsResponse``. Providers that fail, time
    out or return a non-200 status are logged and skipped, so one bad provider
    never fails the whole listing. Each polled provider's ``last_seen`` is set
    to the time it was polled.

    Returns:
        BotListResponse with the combined bot list and a bot-name ->
        provider_id map (a later provider wins if two expose the same name).
    """
    bots: List[BotInfoModel] = []
    providers: dict[str, str] = {}

    # Snapshot the registry: the loop awaits network I/O, during which other
    # requests may add or remove providers, and iterating a dict that changes
    # size raises RuntimeError.
    registered = list(bot_providers.items())

    # One client (and connection pool) shared by all provider requests.
    async with httpx.AsyncClient() as client:
        for provider_id, provider in registered:
            try:
                provider.last_seen = time.time()
                response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
                if response.status_code != 200:
                    logger.warning(
                        f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}"
                    )
                    continue
                # Use Pydantic model to validate the response
                bots_response = BotProviderBotsResponse.model_validate(response.json())
            except Exception as e:
                logger.error(f"Error fetching bots from provider {provider.name}: {e}")
                continue

            for bot_info in bots_response.bots:
                bots.append(bot_info)
                providers[bot_info.name] = provider_id

    return BotListResponse(bots=bots, providers=providers)
|
|
|
|
|
|
async def _find_provider_bot(
    client: httpx.AsyncClient, provider: BotProviderModel, bot_name: str
) -> Optional[BotInfoModel]:
    """Look up *bot_name* in *provider*'s ``GET {base_url}/bots`` listing.

    Returns the matching ``BotInfoModel``, or ``None`` when the provider
    answers with a non-200 status or does not list the bot. Network, JSON and
    validation errors propagate to the caller.
    """
    response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
    if response.status_code != 200:
        return None
    bots_response = BotProviderBotsResponse.model_validate(response.json())
    return next((info for info in bots_response.bots if info.name == bot_name), None)


@app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse)
async def request_bot_join_lobby(
    bot_name: str, request: BotJoinLobbyRequest
) -> BotJoinLobbyResponse:
    """Request a bot to join a specific lobby.

    If ``request.provider_id`` is empty, the first registered provider that
    lists *bot_name* is used. The bot's ``has_media`` flag comes from the
    provider's listing (``False`` if it cannot be queried). A bot session is
    created, and the provider is asked via ``POST {base_url}/bots/{bot_name}/join``
    to start the bot and connect back to this server.

    Raises:
        HTTPException: 404 if the bot/provider or the lobby is not found;
            502 if the provider returns a non-200 status or an invalid body;
            504 on provider timeout; 500 for any other error.
    """
    # Find which provider has this bot and determine its media capability
    target_provider_id = request.provider_id
    bot_has_media = False

    async with httpx.AsyncClient() as client:
        if not target_provider_id:
            # Auto-discover. Iterate over a snapshot because other requests
            # may (un)register providers while we await.
            for provider_id, candidate in list(bot_providers.items()):
                try:
                    bot_info = await _find_provider_bot(client, candidate, bot_name)
                except Exception:
                    continue  # unreachable/misbehaving provider: try the next
                if bot_info is not None:
                    target_provider_id = provider_id
                    bot_has_media = bot_info.has_media
                    break
        elif target_provider_id in bot_providers:
            # Explicit provider: only its media capability needs looking up.
            try:
                bot_info = await _find_provider_bot(
                    client, bot_providers[target_provider_id], bot_name
                )
            except Exception:
                bot_info = None  # Default to no media if we can't query
            if bot_info is not None:
                bot_has_media = bot_info.has_media

    if not target_provider_id or target_provider_id not in bot_providers:
        raise HTTPException(status_code=404, detail="Bot or provider not found")

    provider = bot_providers[target_provider_id]

    # NOTE(review): getLobby and Session are not defined or imported in the
    # visible part of this module (legacy pre-modular helpers). Until this
    # endpoint is ported to LobbyManager/SessionManager, a NameError from
    # getLobby is reported as the 404 below.
    try:
        getLobby(request.lobby_id)  # Just validate it exists
    except Exception:
        raise HTTPException(status_code=404, detail="Lobby not found")

    # Create a session for the bot
    bot_session_id = secrets.token_hex(16)
    bot_session = Session(bot_session_id, is_bot=True, has_media=bot_has_media)
    logger.info(
        f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})"
    )

    # Base URL the bot connects back to (PUBLIC_SERVER_URL, one trailing "/" removed)
    server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000")
    if server_base_url.endswith("/"):
        server_base_url = server_base_url[:-1]

    bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}"

    bot_join_payload = BotJoinPayload(
        lobby_id=request.lobby_id,
        session_id=bot_session_id,
        nick=bot_nick,
        server_url=f"{server_base_url}{public_url}".rstrip("/"),
        # NOTE(review): always True, in production too; consider gating on config.
        insecure=True,  # Accept self-signed certificates in development
    )

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{provider.base_url}/bots/{bot_name}/join",
                json=bot_join_payload.model_dump(),
                timeout=10.0,
            )

        if response.status_code != 200:
            logger.error(
                f"Bot provider returned error: HTTP {response.status_code}: {response.text}"
            )
            raise HTTPException(
                status_code=502,
                detail=f"Bot provider error: {response.status_code}",
            )

        # Use Pydantic model to parse and validate response
        try:
            join_response = BotProviderJoinResponse.model_validate(response.json())
        except ValidationError as e:
            logger.error(f"Invalid response from bot provider: {e}")
            raise HTTPException(
                status_code=502,
                detail=f"Bot provider returned invalid response: {str(e)}",
            )
        run_id = join_response.run_id

        # Update bot session with run and provider information
        with bot_session.session_lock:
            bot_session.bot_run_id = run_id
            bot_session.bot_provider_id = target_provider_id
            bot_session.setName(bot_nick)

        logger.info(f"Bot {bot_name} requested to join lobby {request.lobby_id}")

        return BotJoinLobbyResponse(
            status="requested",
            bot_name=bot_name,
            run_id=run_id,
            provider_id=target_provider_id,
        )
    except HTTPException:
        # Already carries the right status (502); without this clause the
        # generic handler below would rewrite it as a 500.
        raise
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Bot provider timeout")
    except Exception as e:
        logger.error(f"Error requesting bot join: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
|
@app.post(public_url + "api/bots/leave", response_model=BotLeaveLobbyResponse)
async def request_bot_leave_lobby(
    request: BotLeaveLobbyRequest,
) -> BotLeaveLobbyResponse:
    """Disconnect a bot session from all lobbies and stop its provider run.

    Each step is best-effort, with failures logged as warnings:
      1. If the session records a run id and the id of a still-registered
         provider, ``POST {base_url}/bots/runs/{run_id}/stop`` on that provider.
      2. Part the session from every lobby it belongs to.
      3. Close the session's WebSocket, if any, and clear the reference.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if it is not a bot.
    """
    # NOTE(review): getSession is not defined or imported in the visible part
    # of this module (legacy pre-modular helper); confirm it exists at runtime.
    bot_session = getSession(request.session_id)
    if not bot_session:
        raise HTTPException(status_code=404, detail="Bot session not found")
    if not bot_session.is_bot:
        raise HTTPException(status_code=400, detail="Session is not a bot")

    run_id = bot_session.bot_run_id
    provider_id = bot_session.bot_provider_id
    logger.info(f"Requesting bot {bot_session.getName()} to leave all lobbies")

    can_stop_remotely = bool(provider_id and run_id) and provider_id in bot_providers
    if can_stop_remotely:
        stop_url = f"{bot_providers[provider_id].base_url}/bots/runs/{run_id}/stop"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(stop_url, timeout=5.0)
                if response.status_code != 200:
                    logger.warning(
                        f"Bot provider returned error when stopping: HTTP {response.status_code}"
                    )
                else:
                    logger.info(
                        f"Successfully requested bot provider to stop run {run_id}"
                    )
        except Exception as e:
            logger.warning(f"Failed to request bot stop from provider: {e}")

    # Copy the lobby list under the lock; part() is awaited outside it.
    with bot_session.session_lock:
        lobbies_to_part = bot_session.lobbies[:]

    for joined_lobby in lobbies_to_part:
        try:
            await bot_session.part(joined_lobby)
        except Exception as e:
            logger.warning(f"Error parting bot from lobby {joined_lobby.getName()}: {e}")

    if bot_session.ws:
        try:
            await bot_session.ws.close()
        except Exception as e:
            logger.warning(f"Error closing bot WebSocket: {e}")
        bot_session.ws = None

    return BotLeaveLobbyResponse(
        status="disconnected",
        session_id=request.session_id,
        run_id=run_id,
    )
|
|
|
|
|
|
# Register websocket endpoint directly on app with full public_url path
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
async def lobby_join(
    websocket: WebSocket,
    lobby_id: str = Path(...),
    session_id: str = Path(...),
):
    """WebSocket endpoint for lobby connections.

    Connection handling is delegated to the modular
    ``WebSocketConnectionManager`` created in ``lifespan``. If that manager
    does not exist yet, the socket is accepted, sent a single
    "Server not fully initialized" error and closed.
    """
    if websocket_manager:
        await websocket_manager.handle_connection(websocket, lobby_id, session_id)
        return

    # Fallback if manager not initialized
    await websocket.accept()
    await websocket.send_json(
        {"type": "error", "data": {"error": "Server not fully initialized"}}
    )
    await websocket.close()
|
|
|
|
|
|
# Serve static files or proxy to frontend development server
PRODUCTION = os.environ.get("PRODUCTION", "false").lower() == "true"

# NOTE: the second argument is absolute, so os.path.join discards the
# directory of this file; on POSIX this is simply "/client/build".
client_build_path = os.path.join(os.path.dirname(__file__), "/client/build")
|
|
|
|
if PRODUCTION:
    logger.info(f"Serving static files from: {client_build_path} at {public_url}")
    app.mount(
        public_url, StaticFiles(directory=client_build_path, html=True), name="static"
    )
else:
    logger.info(f"Proxying static files to http://client:3000 at {public_url}")

    import ssl

    @app.api_route(
        f"{public_url}{{path:path}}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_static(request: Request, path: str):
        """Forward a request under ``public_url`` to the frontend dev server.

        Paths beginning with ``api/`` or ``ws/`` are not proxied and get a 404.
        The dev server's status, headers, and body are relayed back. A 502 is
        returned if the upstream request fails.
        """
        # Do not proxy API or websocket paths
        if path.startswith(("api/", "ws/")):
            return Response(status_code=404)

        # Join only the non-empty segments so a public_url of "/" does not
        # produce "client:3000//<path>".
        upstream_path = "/".join(
            part for part in (public_url.strip("/"), path) if part
        )
        url = f"{request.url.scheme}://client:3000/{upstream_path}"
        # Preserve the query string (e.g. cache-busting or dev-server params).
        if request.url.query:
            url = f"{url}?{request.url.query}"
        headers = dict(request.headers)
        try:
            # Accept self-signed certs in dev
            async with httpx.AsyncClient(verify=False) as client:
                proxy_resp = await client.request(
                    request.method, url, headers=headers, content=await request.body()
                )

            # Remove problematic headers for browser decoding: httpx has already
            # decoded/de-chunked the body, so these would misdescribe it.
            filtered_headers = {
                k: v
                for k, v in proxy_resp.headers.items()
                if k.lower()
                not in ("content-encoding", "transfer-encoding", "content-length")
            }
            return Response(
                content=proxy_resp.content,
                status_code=proxy_resp.status_code,
                headers=filtered_headers,
            )
        except Exception as e:
            logger.error(f"Proxy error for {url}: {e}")
            return Response("Proxy error", status_code=502)

    # WebSocket proxy for /ws (for React DevTools, etc.)
    import websockets

    @app.websocket("/ws")
    async def websocket_proxy(websocket: WebSocket):
        """Bridge a browser WebSocket on ``/ws`` to ``wss://client:3000/ws``.

        Frames are pumped in both directions until either side finishes.
        The other direction is then cancelled, and the browser socket is
        closed on every exit path.
        """
        logger.info("REACT: WebSocket proxy connection established.")
        target_url = "wss://client:3000/ws"  # Use WSS since client uses HTTPS
        await websocket.accept()
        try:
            # Accept self-signed certs in dev for WSS
            ssl_ctx = ssl.create_default_context()
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
            async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws:

                async def client_to_server():
                    # receive() instead of receive_text() so binary frames are
                    # forwarded rather than raising.
                    while True:
                        message = await websocket.receive()
                        if message["type"] == "websocket.disconnect":
                            return
                        if message.get("text") is not None:
                            await target_ws.send(message["text"])
                        elif message.get("bytes") is not None:
                            await target_ws.send(message["bytes"])

                async def server_to_client():
                    while True:
                        msg = await target_ws.recv()
                        if isinstance(msg, str):
                            await websocket.send_text(msg)
                        else:
                            await websocket.send_bytes(msg)

                tasks = [
                    asyncio.create_task(client_to_server()),
                    asyncio.create_task(server_to_client()),
                ]
                try:
                    # Stop as soon as either direction ends. With gather() the
                    # other pump would keep running after a close.
                    await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    for task in tasks:
                        task.cancel()
                    # Collect every outcome so no task exception goes unretrieved.
                    results = await asyncio.gather(*tasks, return_exceptions=True)

                # A close from either side is normal; anything else is an error.
                for result in results:
                    if isinstance(result, Exception) and not isinstance(
                        result, (WebSocketDisconnect, websockets.ConnectionClosed)
                    ):
                        raise result
                logger.info("REACT: WebSocket proxy connection closed.")
        except Exception as e:
            logger.error(f"REACT: WebSocket proxy error: {e}")
        finally:
            # The browser may already have gone away; closing is best-effort.
            try:
                await websocket.close()
            except Exception as close_error:
                logger.debug(f"REACT: WebSocket proxy close skipped: {close_error}")
|