ai-voicebot/server/main_working.py

2339 lines
90 KiB
Python

from __future__ import annotations
from typing import Any, Optional, List
from fastapi import (
Body,
Cookie,
FastAPI,
HTTPException,
Path,
WebSocket,
Request,
Response,
WebSocketDisconnect,
)
import secrets
import os
import json
import hashlib
import binascii
import sys
import asyncio
import threading
import time
from contextlib import asynccontextmanager
from fastapi.staticfiles import StaticFiles
import httpx
from pydantic import ValidationError
from logger import logger
# Import shared models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import (
HealthResponse,
LobbiesResponse,
LobbyCreateRequest,
LobbyCreateResponse,
LobbyListItem,
LobbyModel,
NamePasswordRecord,
LobbySaved,
SessionResponse,
SessionSaved,
SessionsPayload,
AdminNamesResponse,
AdminActionResponse,
AdminSetPassword,
AdminClearPassword,
AdminValidationResponse,
AdminMetricsResponse,
AdminMetricsConfig,
JoinStatusModel,
ChatMessageModel,
ChatMessagesResponse,
ParticipantModel,
# Bot provider models
BotProviderModel,
BotProviderRegisterRequest,
BotProviderRegisterResponse,
BotProviderListResponse,
BotListResponse,
BotInfoModel,
BotJoinLobbyRequest,
BotJoinLobbyResponse,
BotJoinPayload,
BotLeaveLobbyRequest,
BotLeaveLobbyResponse,
BotProviderBotsResponse,
BotProviderJoinResponse,
)
class SessionConfig:
    """Tunable limits and intervals for session management.

    Each value is read once, when the class body executes, from the
    environment variable of the same name and converted with ``int``; the
    second argument is the fallback used when the variable is unset.
    """

    # Seconds before an unnamed, disconnected session may be removed (1 minute).
    ANONYMOUS_SESSION_TIMEOUT = int(os.environ.get("ANONYMOUS_SESSION_TIMEOUT", "60"))

    # Seconds of inactivity before a displaced session may be removed (3 hours).
    DISPLACED_SESSION_TIMEOUT = int(os.environ.get("DISPLACED_SESSION_TIMEOUT", "10800"))

    # Seconds between runs of the periodic cleanup task (5 minutes).
    CLEANUP_INTERVAL = int(os.environ.get("CLEANUP_INTERVAL", "300"))

    # Circuit breaker: upper bound on sessions removed in one cleanup pass.
    MAX_SESSIONS_PER_CLEANUP = int(os.environ.get("MAX_SESSIONS_PER_CLEANUP", "100"))

    # Number of chat messages retained per lobby.
    MAX_CHAT_MESSAGES_PER_LOBBY = int(
        os.environ.get("MAX_CHAT_MESSAGES_PER_LOBBY", "100")
    )

    # Seconds between runs of the periodic integrity validation (30 minutes).
    SESSION_VALIDATION_INTERVAL = int(
        os.environ.get("SESSION_VALIDATION_INTERVAL", "1800")
    )
class BotProviderConfig:
    """Configuration class for bot provider management."""

    # Comma-separated list of allowed provider keys, read from BOT_PROVIDER_KEYS.
    # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys)
    ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "")

    @classmethod
    def get_allowed_providers(cls) -> dict[str, str]:
        """Parse the allowed providers from ``ALLOWED_PROVIDERS``.

        Entries are separated by commas and surrounding whitespace is
        ignored. Each entry is either ``key`` or ``key:name``; only the first
        colon separates the key from the name. An entry whose name is missing
        or blank uses the key as its name, and an entry whose key is blank is
        skipped so that an empty string can never act as a valid provider key.

        Returns:
            dict mapping provider_key -> provider_name (empty when nothing is
            configured).
        """
        providers: dict[str, str] = {}
        for entry in cls.ALLOWED_PROVIDERS.split(","):
            key, _, name = entry.partition(":")
            key = key.strip()
            if not key:
                # Covers empty entries ("a,,b") and malformed ones (":name").
                continue
            # "key" and "key:" both fall back to the key as the display name.
            providers[key] = name.strip() or key
        return providers
# Reserved names and their password records, keyed by the lower-cased name.
# Each record has the shape {"salt": <hex>, "hash": <hex>}.
name_passwords: dict[str, dict[str, str]] = {}

# Registry of bot providers: provider_id -> BotProviderModel.
bot_providers: dict[str, BotProviderModel] = {}

# Re-entrant thread lock for session operations.
session_lock = threading.RLock()

# Bracketed text labels; unset_label is the placeholder shown for a session
# that has no name yet.
all_label, info_label = "[ all ]", "[ info ]"
todo_label, unset_label = "[ todo ]", "[ ---- ]"
def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]:
    """Hash ``password`` with PBKDF2-HMAC-SHA256 (100000 iterations).

    Args:
        password: Clear-text password; encoded as UTF-8 before hashing.
        salt_hex: Hex-encoded salt to reuse. When omitted (or empty) a fresh
            16-byte random salt is generated.

    Returns:
        ``(salt_hex, hash_hex)``: the salt that was used and the derived key,
        both hex encoded.
    """
    if not salt_hex:
        salt_bytes = secrets.token_bytes(16)
        salt_hex = salt_bytes.hex()
    else:
        salt_bytes = binascii.unhexlify(salt_hex)
    derived = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt_bytes, 100000
    )
    return salt_hex, derived.hex()
# Prefix under which every route is registered; normalised to end with "/".
public_url = os.getenv("PUBLIC_URL", "/")
public_url = public_url if public_url.endswith("/") else public_url + "/"

# Global variables to control background tasks: the *_running flags are the
# loop conditions of the periodic tasks, the *_task names hold their handles.
cleanup_task_running = validation_task_running = False
cleanup_task = validation_task = None
async def periodic_cleanup():
    """Background task that removes old sessions at a fixed interval.

    Loops while ``cleanup_task_running`` is true. After each successful pass
    it sleeps for ``SessionConfig.CLEANUP_INTERVAL`` seconds. After a failed
    pass it sleeps 60 seconds per consecutive failure (capped at 300) and
    gives up entirely once five passes in a row have failed.
    """
    global cleanup_task_running
    failure_streak = 0
    failure_limit = 5

    while cleanup_task_running:
        try:
            purged = Session.cleanup_old_sessions()
            if purged > 0:
                logger.info(f"Periodic cleanup removed {purged} old sessions")
            # A clean pass resets the consecutive-failure counter.
            failure_streak = 0
            await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL)
        except Exception as exc:
            failure_streak += 1
            logger.error(
                f"Error in session cleanup task (attempt {failure_streak}): {exc}"
            )
            if failure_streak >= failure_limit:
                logger.error(
                    f"Too many consecutive cleanup errors ({failure_streak}), stopping cleanup task"
                )
                break
            # Back off longer after each consecutive failure, at most 5 minutes.
            await asyncio.sleep(min(60 * failure_streak, 300))
async def periodic_validation():
    """Background task that checks session integrity at a fixed interval.

    Loops while ``validation_task_running`` is true. Each pass logs the number
    of issues found plus the first ten of them, then sleeps for
    ``SessionConfig.SESSION_VALIDATION_INTERVAL`` seconds. If a pass raises,
    the error is logged and the next attempt is made after 300 seconds.
    """
    global validation_task_running

    while validation_task_running:
        try:
            problems = Session.validate_session_integrity()
            if problems:
                logger.warning(
                    f"Session integrity issues found: {len(problems)} issues"
                )
                # Cap the detail lines so one bad state cannot flood the log.
                for problem in problems[:10]:
                    logger.warning(f"Integrity issue: {problem}")
            await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL)
        except Exception as exc:
            logger.error(f"Error in session validation task: {exc}")
            await asyncio.sleep(300)  # Wait 5 minutes before retrying on error
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the background tasks on startup and tear them down on shutdown.

    Before yielding, the cleanup and validation loops are scheduled as asyncio
    tasks. After the yield, both run flags are cleared, each task is cancelled
    and awaited, and every session is cleaned up.
    """
    global cleanup_task_running, cleanup_task, validation_task_running, validation_task

    # --- startup ---
    logger.info("Starting background tasks...")
    cleanup_task_running = validation_task_running = True
    cleanup_task = asyncio.create_task(periodic_cleanup())
    validation_task = asyncio.create_task(periodic_validation())
    logger.info("Session cleanup and validation tasks started")

    yield

    # --- shutdown ---
    logger.info("Shutting down background tasks...")
    cleanup_task_running = validation_task_running = False

    for background in (cleanup_task, validation_task):
        if not background:
            continue
        background.cancel()
        try:
            # Wait for the cancellation to be delivered to the task.
            await background
        except asyncio.CancelledError:
            pass

    # Clean up all sessions gracefully
    await Session.cleanup_all_sessions()
    logger.info("All background tasks stopped and sessions cleaned up")
# The ASGI application; lifespan() manages the background tasks around it.
app = FastAPI(lifespan=lifespan)

logger.info(f"Starting server with public URL: {public_url}")
logger.info(
    f"Session config - Anonymous timeout: {SessionConfig.ANONYMOUS_SESSION_TIMEOUT}s, "
    f"Displaced timeout: {SessionConfig.DISPLACED_SESSION_TIMEOUT}s, "
    f"Cleanup interval: {SessionConfig.CLEANUP_INTERVAL}s"
)

# Log bot provider configuration: an empty allow-list means no provider
# authentication is configured.
allowed_providers = BotProviderConfig.get_allowed_providers()
if not allowed_providers:
    logger.warning("Bot provider authentication disabled. Any provider can register.")
else:
    logger.info(
        f"Bot provider authentication enabled. Allowed providers: {list(allowed_providers)}"
    )
# Optional admin token to protect admin endpoints.
# NOTE: when ADMIN_TOKEN is unset, _require_admin() accepts every request.
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None)


def _require_admin(request: Request) -> bool:
    """Return True when ``request`` is allowed to use the admin endpoints.

    If no ``ADMIN_TOKEN`` is configured every request is allowed. Otherwise
    the request must carry an ``X-Admin-Token`` header equal to the token.

    Args:
        request: Incoming request; only ``request.headers`` is consulted.

    Returns:
        True if access is granted, False otherwise.
    """
    if not ADMIN_TOKEN:
        return True
    token = request.headers.get("X-Admin-Token")
    if token is None:
        return False
    # The header is attacker-controlled: compare in constant time so response
    # timing does not leak how much of the token matched. Both sides are
    # encoded to bytes because compare_digest rejects non-ASCII str values.
    return secrets.compare_digest(
        token.encode("utf-8"), ADMIN_TOKEN.encode("utf-8")
    )
@app.get(public_url + "api/admin/names", response_model=AdminNamesResponse)
def admin_list_names(request: Request):
    """Return every reserved name together with its stored password record.

    Responds with a bare 403 when the admin check fails.
    """
    if not _require_admin(request):
        return Response(status_code=403)

    # Convert the plain dict records into their Pydantic model form.
    records = {}
    for reserved_name, raw_record in name_passwords.items():
        records[reserved_name] = NamePasswordRecord(**raw_record)
    return AdminNamesResponse(name_passwords=records)
@app.post(public_url + "api/admin/set_password", response_model=AdminActionResponse)
def admin_set_password(request: Request, payload: AdminSetPassword = Body(...)):
    """Reserve ``payload.name`` by storing a salted hash of ``payload.password``.

    The record is keyed by the lower-cased name and persisted immediately.
    Responds with a bare 403 when the admin check fails.
    """
    if not _require_admin(request):
        return Response(status_code=403)

    salt_hex, digest_hex = _hash_password(payload.password)
    name_passwords[payload.name.lower()] = {"salt": salt_hex, "hash": digest_hex}
    Session.save()
    return AdminActionResponse(status="ok", name=payload.name)
@app.post(public_url + "api/admin/clear_password", response_model=AdminActionResponse)
def admin_clear_password(request: Request, payload: AdminClearPassword = Body(...)):
    """Remove the password record for ``payload.name`` (matched lower-cased).

    Returns status "ok" after deleting and persisting, or "not_found" when
    the name has no record. Responds with a bare 403 when the admin check
    fails.
    """
    if not _require_admin(request):
        return Response(status_code=403)

    key = payload.name.lower()
    if key not in name_passwords:
        return AdminActionResponse(status="not_found", name=payload.name)

    del name_passwords[key]
    Session.save()
    return AdminActionResponse(status="ok", name=payload.name)
@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse)
def admin_cleanup_sessions(request: Request):
    """Run one session cleanup pass on demand.

    The outcome is reported through the ``name`` field of the response: the
    number of removed sessions on success, the error text on failure.
    Responds with a bare 403 when the admin check fails.
    """
    if not _require_admin(request):
        return Response(status_code=403)

    try:
        purged = Session.cleanup_old_sessions()
        summary = f"Removed {purged} sessions"
        return AdminActionResponse(status="ok", name=summary)
    except Exception as exc:
        logger.error(f"Error during manual session cleanup: {exc}")
        return AdminActionResponse(status="error", name=f"Error: {exc!s}")
@app.get(public_url + "api/admin/session_metrics", response_model=AdminMetricsResponse)
def admin_session_metrics(request: Request):
    """Return the current session/cleanup metrics.

    Responds with a bare 403 when the admin check fails and a bare 500 when
    collecting the metrics raises.
    """
    if not _require_admin(request):
        return Response(status_code=403)

    try:
        metrics = Session.get_cleanup_metrics()
        return metrics
    except Exception as exc:
        logger.error(f"Error getting session metrics: {exc}")
        return Response(status_code=500)
@app.get(
    public_url + "api/admin/validate_sessions", response_model=AdminValidationResponse
)
def admin_validate_sessions(request: Request):
    """Run the session integrity check and return the issues it found.

    On failure the response carries status "error" and the error text.
    Responds with a bare 403 when the admin check fails.
    """
    if not _require_admin(request):
        return Response(status_code=403)

    try:
        found = Session.validate_session_integrity()
        issue_total = len(found)
        return AdminValidationResponse(
            status="ok", issues=found, issue_count=issue_total
        )
    except Exception as exc:
        logger.error(f"Error validating sessions: {exc}")
        return AdminValidationResponse(status="error", error=str(exc))
# Global registry of lobbies, keyed by lobby id.
lobbies: dict[str, Lobby] = {}


class Lobby:
    """A named room holding its member sessions and a bounded chat history."""

    def __init__(self, name: str, id: str | None = None, private: bool = False):
        """Create a lobby.

        Args:
            name: Display name of the lobby.
            id: Existing lobby id to reuse; a random 32-character hex id is
                generated when omitted.
            private: Stored on the lobby as ``self.private``.
        """
        self.id = secrets.token_hex(16) if id is None else id
        self.short = self.id[:8]
        self.name = name
        self.sessions: dict[str, Session] = {}  # All lobby members
        self.private = private
        self.chat_messages: list[ChatMessageModel] = []  # Store chat messages
        self.lock = threading.RLock()  # Thread safety for lobby operations

    def getName(self) -> str:
        """Return the log label ``"<first 8 id chars>:<name>"``."""
        return f"{self.short}:{self.name}"

    async def update_state(self, requesting_session: Session | None = None):
        """Send the participant list as a ``lobby_state`` message.

        Only members that have a name are listed as participants.

        Args:
            requesting_session: When given, only that session receives the
                message. Otherwise it is sent to every member with a
                WebSocket, and members whose send fails get ``ws`` reset to
                ``None``.
        """
        # Take the snapshot under the lock, then release it before awaiting:
        # members can join or leave while a send is in flight, and iterating
        # the live dict across an await could raise "dictionary changed size
        # during iteration".
        with self.lock:
            members = list(self.sessions.values())
            users: list[ParticipantModel] = [
                ParticipantModel(
                    name=s.name,
                    live=bool(s.ws),
                    session_id=s.id,
                    protected=s.name.lower() in name_passwords,
                    is_bot=s.is_bot,
                    has_media=s.has_media,
                    bot_run_id=s.bot_run_id,
                    bot_provider_id=s.bot_provider_id,
                )
                for s in members
                if s.name
            ]

        # The payload is identical for every recipient, so build it once.
        message = {
            "type": "lobby_state",
            "data": {"participants": [user.model_dump() for user in users]},
        }

        if requesting_session:
            logger.info(
                f"{requesting_session.getName()} -> lobby_state({self.getName()})"
            )
            if requesting_session.ws:
                try:
                    await requesting_session.ws.send_json(message)
                except Exception as e:
                    logger.warning(
                        f"Failed to send lobby state to {requesting_session.getName()}: {e}"
                    )
            else:
                logger.warning(
                    f"{requesting_session.getName()} - No WebSocket connection."
                )
            return

        # Send to all sessions in lobby
        failed_sessions: list[Session] = []
        for s in members:
            logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
            if s.ws:
                try:
                    await s.ws.send_json(message)
                except Exception as e:
                    logger.warning(
                        f"Failed to send lobby state to {s.getName()}: {e}"
                    )
                    failed_sessions.append(s)

        # Clean up failed sessions
        for failed_session in failed_sessions:
            failed_session.ws = None

    def getSession(self, id: str) -> Session | None:
        """Return the member with session id ``id``, or None."""
        with self.lock:
            return self.sessions.get(id)

    async def addSession(self, session: Session) -> None:
        """Add ``session`` as a member and broadcast the new lobby state.

        A session that is already a member is left alone (a warning is
        logged and nothing is broadcast).
        """
        with self.lock:
            if session.id in self.sessions:
                logger.warning(
                    f"{session.getName()} - Already in lobby {self.getName()}."
                )
                return None
            self.sessions[session.id] = session
        await self.update_state()

    async def removeSession(self, session: Session) -> None:
        """Remove ``session`` from the members and broadcast the new state.

        A session that is not a member is left alone (a warning is logged
        and nothing is broadcast).
        """
        with self.lock:
            if session.id not in self.sessions:
                logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.")
                return None
            del self.sessions[session.id]
        await self.update_state()

    def add_chat_message(self, session: Session, message: str) -> ChatMessageModel:
        """Add a chat message to the lobby and return the message data.

        The sender is recorded by name, falling back to the short session id
        for unnamed sessions. Only the newest
        ``SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY`` messages are kept.
        """
        with self.lock:
            chat_message = ChatMessageModel(
                id=secrets.token_hex(8),
                message=message,
                sender_name=session.name or session.short,
                sender_session_id=session.id,
                timestamp=time.time(),
                lobby_id=self.id,
            )
            self.chat_messages.append(chat_message)
            # Keep only the latest messages per lobby. Dropping the overflow
            # from the front (rather than slicing with [-MAX:]) also behaves
            # correctly when MAX is 0, where [-0:] would keep everything.
            overflow = (
                len(self.chat_messages) - SessionConfig.MAX_CHAT_MESSAGES_PER_LOBBY
            )
            if overflow > 0:
                del self.chat_messages[:overflow]
            return chat_message

    def get_chat_messages(self, limit: int = 50) -> list[ChatMessageModel]:
        """Get the most recent chat messages from the lobby.

        Args:
            limit: Maximum number of messages to return, newest last. A limit
                of zero or less returns an empty list.

        Returns:
            A new list; mutating it does not affect the stored history.
        """
        if limit <= 0:
            # [-0:] would be the whole list and a negative limit would drop
            # the oldest messages instead of limiting the count.
            return []
        with self.lock:
            return self.chat_messages[-limit:]

    async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
        """Broadcast a chat message to all connected sessions in the lobby.

        Members whose send fails get ``ws`` reset to ``None``.
        """
        # Snapshot the members so joins/parts during an await cannot
        # invalidate the iteration.
        with self.lock:
            recipients = list(self.sessions.values())
        message = {"type": "chat_message", "data": chat_message.model_dump()}

        failed_sessions: list[Session] = []
        for peer in recipients:
            if peer.ws:
                try:
                    logger.info(f"{self.getName()} -> chat_message({peer.getName()})")
                    await peer.ws.send_json(message)
                except Exception as e:
                    logger.warning(
                        f"Failed to send chat message to {peer.getName()}: {e}"
                    )
                    failed_sessions.append(peer)

        # Clean up failed sessions
        for failed_session in failed_sessions:
            failed_session.ws = None
class Session:
    """A browser or bot session, identified by its session id.

    The class keeps every live instance in ``_instances`` and persists that
    registry, together with the reserved-name password store, to the JSON
    file named by ``_save_file``.

    Lock ordering: ``Session.lock`` (class level) is taken before an
    instance's ``session_lock``; ``save()`` relies on that order.
    """

    _instances: list[Session] = []
    _save_file = "sessions.json"
    _loaded = False
    # True while load() is rebuilding sessions. Every Session() constructor
    # calls save(); without this guard each restored session would rewrite
    # the save file with half-restored state.
    _loading = False
    lock = threading.RLock()  # Thread safety for class-level operations

    def __init__(self, id: str, is_bot: bool = False, has_media: bool = True):
        """Create a session, register it and persist the registry.

        Args:
            id: Session id; its first eight characters become ``short``.
            is_bot: Whether this session represents a bot.
            has_media: Whether this session provides audio/video streams.
        """
        logger.info(
            f"Instantiating new session {id} (bot: {is_bot}, media: {has_media})"
        )
        self.id = id
        self.short = id[:8]
        self.name = ""
        self.lobbies: list[Lobby] = []  # Lobbies this session is in
        self.lobby_peers: dict[
            str, list[str]
        ] = {}  # lobby ID -> list of peer session IDs
        self.ws: WebSocket | None = None
        self.created_at = time.time()
        self.last_used = time.time()
        self.displaced_at: float | None = None  # When name was taken over
        self.is_bot = is_bot  # Whether this session represents a bot
        self.has_media = has_media  # Whether this session provides audio/video streams
        self.bot_run_id: str | None = None  # Bot run ID for tracking
        self.bot_provider_id: str | None = None  # Bot provider ID
        self.session_lock = threading.RLock()  # Instance-level lock
        # Register only once fully initialised, so a concurrent save() never
        # sees an instance that is missing attributes.
        with Session.lock:
            self._instances.append(self)
        self.save()

    @classmethod
    def save(cls):
        """Persist all sessions and the name/password store to ``_save_file``.

        The payload is written to ``<_save_file>.tmp`` and then moved over
        the real file. Failures are logged, never raised. Does nothing while
        ``load()`` is in progress.
        """
        if cls._loading:
            return
        temp_file = cls._save_file + ".tmp"
        try:
            with cls.lock:
                sessions_list: list[SessionSaved] = []
                for s in cls._instances:
                    with s.session_lock:
                        lobbies_list: list[LobbySaved] = [
                            LobbySaved(
                                id=lobby.id, name=lobby.name, private=lobby.private
                            )
                            for lobby in s.lobbies
                        ]
                        sessions_list.append(
                            SessionSaved(
                                id=s.id,
                                name=s.name or "",
                                lobbies=lobbies_list,
                                created_at=s.created_at,
                                last_used=s.last_used,
                                displaced_at=s.displaced_at,
                                is_bot=s.is_bot,
                                has_media=s.has_media,
                                bot_run_id=s.bot_run_id,
                                bot_provider_id=s.bot_provider_id,
                            )
                        )
                # Prepare name password store for persistence (salt+hash).
                # Only structured records are supported.
                saved_pw: dict[str, NamePasswordRecord] = {
                    name: NamePasswordRecord(**record)
                    for name, record in name_passwords.items()
                }
                payload_model = SessionsPayload(
                    sessions=sessions_list, name_passwords=saved_pw
                )
                payload = payload_model.model_dump()
                # Write to a temp file first. The write stays inside the lock
                # so two savers never interleave on the same temp file.
                with open(temp_file, "w") as f:
                    json.dump(payload, f, indent=2)
                # os.replace overwrites an existing target on every platform;
                # os.rename raises on Windows when the target already exists.
                os.replace(temp_file, cls._save_file)
            logger.info(
                f"Saved {len(sessions_list)} sessions and {len(saved_pw)} name passwords to {cls._save_file}"
            )
        except Exception as e:
            logger.error(f"Failed to save sessions: {e}")
            # Clean up temp file if it exists (best effort).
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
            except Exception as cleanup_error:
                logger.warning(
                    f"Failed to remove temp file {temp_file}: {cleanup_error}"
                )

    @classmethod
    def load(cls):
        """Restore sessions and the name/password store from ``_save_file``.

        A missing, unreadable or invalid file is logged and leaves the
        in-memory state untouched. Sessions that already meet the removal
        criteria are skipped; if any were skipped, the file is rewritten at
        the end to persist that cleanup.
        """
        if not os.path.exists(cls._save_file):
            logger.info(f"No session save file found: {cls._save_file}")
            return
        try:
            with open(cls._save_file, "r") as f:
                raw = json.load(f)
        except Exception as e:
            logger.error(f"Failed to read session save file: {e}")
            return
        try:
            payload = SessionsPayload.model_validate(raw)
        except ValidationError as e:
            logger.exception(f"Failed to validate sessions payload: {e}")
            return

        # Populate in-memory structures from payload (no backwards compatibility code)
        name_passwords.clear()
        for name, rec in payload.name_passwords.items():
            # rec is a NamePasswordRecord
            name_passwords[name] = {"salt": rec.salt, "hash": rec.hash}

        current_time = time.time()
        sessions_loaded = 0
        sessions_expired = 0
        with cls.lock:
            # Suppress the save() triggered by every Session() constructor:
            # it would overwrite the file before the name, timestamps and
            # lobbies of the session being restored have been applied.
            cls._loading = True
            try:
                for s_saved in payload.sessions:
                    # Check if this session should be expired during loading
                    created_at = getattr(s_saved, "created_at", time.time())
                    last_used = getattr(s_saved, "last_used", time.time())
                    displaced_at = getattr(s_saved, "displaced_at", None)
                    name = s_saved.name or ""
                    # Apply same removal criteria as cleanup_old_sessions
                    should_expire = cls._should_remove_session_static(
                        name, None, created_at, last_used, displaced_at, current_time
                    )
                    if should_expire:
                        sessions_expired += 1
                        logger.info(
                            f"Expiring session {s_saved.id[:8]}:{name} during load"
                        )
                        continue  # Skip loading this expired session

                    session = Session(
                        s_saved.id,
                        is_bot=getattr(s_saved, "is_bot", False),
                        has_media=getattr(s_saved, "has_media", True),
                    )
                    session.name = name
                    # Load timestamps, with defaults for backward compatibility
                    session.created_at = created_at
                    session.last_used = last_used
                    session.displaced_at = displaced_at
                    # Load bot information with defaults for backward compatibility
                    session.bot_run_id = getattr(s_saved, "bot_run_id", None)
                    session.bot_provider_id = getattr(s_saved, "bot_provider_id", None)

                    for lobby_saved in s_saved.lobbies:
                        # The session and the global registry must share ONE
                        # Lobby object per id: membership tests such as
                        # `lobby in session.lobbies` compare by identity, and
                        # an already registered lobby must not be replaced by
                        # an empty copy.
                        lobby = lobbies.get(lobby_saved.id)
                        if lobby is None:
                            lobby = Lobby(
                                name=lobby_saved.name,
                                id=lobby_saved.id,
                                private=lobby_saved.private,
                            )
                            lobbies[lobby.id] = lobby
                        session.lobbies.append(lobby)
                    logger.info(
                        f"Loaded session {session.getName()} with {len(session.lobbies)} lobbies"
                    )
                    sessions_loaded += 1
            finally:
                cls._loading = False

        logger.info(
            f"Loaded {sessions_loaded} sessions and {len(name_passwords)} name passwords from {cls._save_file}"
        )
        if sessions_expired > 0:
            logger.info(f"Expired {sessions_expired} old sessions during load")
            # Save immediately to persist the cleanup
            cls.save()

    @classmethod
    def getSession(cls, id: str) -> Session | None:
        """Return the session with the given id, or None.

        The first call loads the persisted sessions from disk.
        """
        if not cls._loaded:
            cls.load()
            logger.info(f"Loaded {len(cls._instances)} sessions from disk...")
            cls._loaded = True
        with cls.lock:
            for s in cls._instances:
                if s.id == id:
                    return s
        return None

    @classmethod
    def isUniqueName(cls, name: str) -> bool:
        """Return True if no session uses ``name`` (case-insensitive).

        An empty name is never unique.
        """
        if not name:
            return False
        wanted = name.lower()
        with cls.lock:
            for s in cls._instances:
                with s.session_lock:
                    if s.name.lower() == wanted:
                        return False
        return True

    @classmethod
    def getSessionByName(cls, name: str) -> Optional["Session"]:
        """Return the first session named ``name`` (case-insensitive), or None."""
        if not name:
            return None
        lname = name.lower()
        with cls.lock:
            for s in cls._instances:
                with s.session_lock:
                    if s.name and s.name.lower() == lname:
                        return s
        return None

    def getName(self) -> str:
        """Return the log label ``"<short id>:<name>"`` (placeholder if unnamed)."""
        with self.session_lock:
            return f"{self.short}:{self.name if self.name else unset_label}"

    def setName(self, name: str):
        """Set the session name, refresh ``last_used`` and persist."""
        with self.session_lock:
            self.name = name
            self.update_last_used()
        # save() takes Session.lock and then each instance lock; calling it
        # while still holding our own instance lock would invert that order
        # and can deadlock against a concurrent save().
        self.save()

    def update_last_used(self):
        """Update the last_used timestamp"""
        with self.session_lock:
            self.last_used = time.time()

    def mark_displaced(self):
        """Mark this session as having its name taken over"""
        with self.session_lock:
            self.displaced_at = time.time()

    @staticmethod
    def _should_remove_session_static(
        name: str,
        ws: WebSocket | None,
        created_at: float,
        last_used: float,
        displaced_at: float | None,
        current_time: float,
    ) -> bool:
        """Decide from plain values whether a session should be removed.

        A session with an active connection (``ws``) is never removed.
        """
        # Rule 1: Delete sessions with no active connection and no name that
        # are older than threshold
        if (
            not ws
            and not name
            and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT
        ):
            return True
        # Rule 2: Delete inactive sessions that had their nick taken over and
        # haven't been used recently
        if (
            not ws
            and displaced_at is not None
            and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT
        ):
            return True
        return False

    def _should_remove(self, current_time: float) -> bool:
        """Check if this session should be removed"""
        with self.session_lock:
            return self._should_remove_session_static(
                self.name,
                self.ws,
                self.created_at,
                self.last_used,
                self.displaced_at,
                current_time,
            )

    @classmethod
    def _remove_session_safely(cls, session: Session, empty_lobbies: set[str]) -> None:
        """Remove ``session`` everywhere and record lobbies left empty.

        Args:
            session: The session to drop.
            empty_lobbies: Output set; ids of lobbies with no members left
                are added to it.
        """
        try:
            with session.session_lock:
                # Remove from lobbies first (iterate a copy of the list).
                for lobby in session.lobbies[:]:
                    try:
                        with lobby.lock:
                            if session.id in lobby.sessions:
                                del lobby.sessions[session.id]
                            if not lobby.sessions:
                                empty_lobbies.add(lobby.id)
                        if lobby.id in session.lobby_peers:
                            del session.lobby_peers[lobby.id]
                    except Exception as e:
                        logger.warning(
                            f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}"
                        )
                # Close WebSocket if open
                if session.ws:
                    try:
                        # NOTE(review): create_task needs a running event loop
                        # in the calling thread; otherwise it raises and the
                        # error is logged below. Confirm all callers run on
                        # the event loop.
                        asyncio.create_task(session.ws.close())
                    except Exception as e:
                        logger.warning(
                            f"Error closing WebSocket for {session.getName()}: {e}"
                        )
                    session.ws = None
            # Remove from instances list
            with cls.lock:
                if session in cls._instances:
                    cls._instances.remove(session)
        except Exception as e:
            logger.error(
                f"Error during safe session removal for {session.getName()}: {e}"
            )

    @classmethod
    def _cleanup_empty_lobbies(cls, empty_lobbies: set[str]) -> int:
        """Remove the given lobby ids from the global registry.

        Returns:
            The number of lobbies actually removed.
        """
        removed_count = 0
        for lobby_id in empty_lobbies:
            lobby = lobbies.pop(lobby_id, None)
            if lobby is not None:
                logger.info(f"Removed empty lobby {lobby.getName()}")
                removed_count += 1
        return removed_count

    @classmethod
    def cleanup_old_sessions(cls) -> int:
        """Remove sessions that meet the removal criteria.

        At most ``SessionConfig.MAX_SESSIONS_PER_CLEANUP`` sessions are
        removed per call. Errors are logged, never raised.

        Returns:
            The number of sessions removed.
        """
        current_time = time.time()
        sessions_removed = 0
        try:
            # Circuit breaker - don't remove too many sessions at once
            sessions_to_remove: list[Session] = []
            empty_lobbies: set[str] = set()
            with cls.lock:
                # Identify sessions to remove (up to max limit)
                for session in cls._instances[:]:
                    if (
                        len(sessions_to_remove)
                        >= SessionConfig.MAX_SESSIONS_PER_CLEANUP
                    ):
                        logger.warning(
                            f"Hit session cleanup limit ({SessionConfig.MAX_SESSIONS_PER_CLEANUP}), "
                            f"stopping cleanup. Remaining sessions will be cleaned up in next cycle."
                        )
                        break
                    if session._should_remove(current_time):
                        sessions_to_remove.append(session)
                        logger.info(
                            f"Marking session {session.getName()} for removal - "
                            f"criteria: no_ws={session.ws is None}, no_name={not session.name}, "
                            f"age={current_time - session.created_at:.0f}s, "
                            f"displaced={session.displaced_at is not None}, "
                            f"unused={current_time - session.last_used:.0f}s"
                        )
            # Remove the identified sessions
            for session in sessions_to_remove:
                cls._remove_session_safely(session, empty_lobbies)
                sessions_removed += 1
            # Clean up empty lobbies
            empty_lobbies_removed = cls._cleanup_empty_lobbies(empty_lobbies)
            # Save state if we made changes
            if sessions_removed > 0:
                cls.save()
                logger.info(
                    f"Session cleanup completed: removed {sessions_removed} sessions, "
                    f"{empty_lobbies_removed} empty lobbies"
                )
        except Exception as e:
            logger.error(f"Error during session cleanup: {e}")
            # Don't re-raise - cleanup should be resilient
        return sessions_removed

    @classmethod
    def get_cleanup_metrics(cls) -> AdminMetricsResponse:
        """Return cleanup metrics for monitoring"""
        current_time = time.time()
        with cls.lock:
            total_sessions = len(cls._instances)
            active_sessions = 0
            named_sessions = 0
            displaced_sessions = 0
            old_anonymous = 0
            old_displaced = 0
            cleanup_candidates = 0
            for s in cls._instances:
                with s.session_lock:
                    if s.ws:
                        active_sessions += 1
                    if s.name:
                        named_sessions += 1
                    if s.displaced_at is not None:
                        displaced_sessions += 1
                        if (
                            not s.ws
                            and current_time - s.last_used
                            > SessionConfig.DISPLACED_SESSION_TIMEOUT
                        ):
                            old_displaced += 1
                    if (
                        not s.ws
                        and not s.name
                        and current_time - s.created_at
                        > SessionConfig.ANONYMOUS_SESSION_TIMEOUT
                    ):
                        old_anonymous += 1
                    # Count each removable session once; summing the two
                    # counters above would count a session twice when it is
                    # both old-anonymous and old-displaced.
                    if s._should_remove(current_time):
                        cleanup_candidates += 1
        config = AdminMetricsConfig(
            anonymous_timeout=SessionConfig.ANONYMOUS_SESSION_TIMEOUT,
            displaced_timeout=SessionConfig.DISPLACED_SESSION_TIMEOUT,
            cleanup_interval=SessionConfig.CLEANUP_INTERVAL,
            max_cleanup_per_cycle=SessionConfig.MAX_SESSIONS_PER_CLEANUP,
        )
        return AdminMetricsResponse(
            total_sessions=total_sessions,
            active_sessions=active_sessions,
            named_sessions=named_sessions,
            displaced_sessions=displaced_sessions,
            old_anonymous_sessions=old_anonymous,
            old_displaced_sessions=old_displaced,
            total_lobbies=len(lobbies),
            cleanup_candidates=cleanup_candidates,
            config=config,
        )

    @classmethod
    def validate_session_integrity(cls) -> list[str]:
        """Cross-check sessions and lobbies and describe every inconsistency.

        Returns:
            Human-readable issue descriptions; empty when everything is
            consistent. An error during validation is reported as one more
            issue instead of being raised.
        """
        issues: list[str] = []
        try:
            with cls.lock:
                for session in cls._instances:
                    with session.session_lock:
                        # Check for orphaned lobby references
                        for lobby in session.lobbies:
                            if lobby.id not in lobbies:
                                issues.append(
                                    f"Session {session.id[:8]}:{session.name} references missing lobby {lobby.id}"
                                )
                        # Check for inconsistent peer relationships
                        for lobby_id, peer_ids in session.lobby_peers.items():
                            lobby = lobbies.get(lobby_id)
                            if lobby:
                                with lobby.lock:
                                    if session.id not in lobby.sessions:
                                        issues.append(
                                            f"Session {session.id[:8]}:{session.name} has peers in lobby {lobby_id} but not in lobby.sessions"
                                        )
                                    # Check if peer sessions actually exist
                                    for peer_id in peer_ids:
                                        if peer_id not in lobby.sessions:
                                            issues.append(
                                                f"Session {session.id[:8]}:{session.name} references non-existent peer {peer_id} in lobby {lobby_id}"
                                            )
                            else:
                                issues.append(
                                    f"Session {session.id[:8]}:{session.name} has peer list for non-existent lobby {lobby_id}"
                                )

                # Check lobbies for consistency. Index the sessions by id once
                # instead of scanning the list per member; iterating in
                # reverse keeps the FIRST instance if an id were duplicated.
                sessions_by_id = {s.id: s for s in reversed(cls._instances)}
                for lobby_id, lobby in lobbies.items():
                    with lobby.lock:
                        for session_id in lobby.sessions:
                            found_session = sessions_by_id.get(session_id)
                            if not found_session:
                                issues.append(
                                    f"Lobby {lobby_id} references non-existent session {session_id}"
                                )
                            else:
                                with found_session.session_lock:
                                    if lobby not in found_session.lobbies:
                                        issues.append(
                                            f"Lobby {lobby_id} contains session {session_id} but session doesn't reference lobby"
                                        )
        except Exception as e:
            logger.error(f"Error during session validation: {e}")
            issues.append(f"Validation error: {str(e)}")
        return issues

    @classmethod
    async def cleanup_all_sessions(cls):
        """Clean up all sessions during shutdown.

        Closes every WebSocket, parts every lobby, then clears the session
        registry and the global lobby registry. Errors are logged, never
        raised.
        """
        logger.info("Starting graceful session cleanup...")
        try:
            with cls.lock:
                sessions_to_cleanup = cls._instances[:]
            for session in sessions_to_cleanup:
                try:
                    with session.session_lock:
                        # Close WebSocket connections
                        if session.ws:
                            try:
                                await session.ws.close()
                            except Exception as e:
                                logger.warning(
                                    f"Error closing WebSocket for {session.getName()}: {e}"
                                )
                            session.ws = None
                        # Remove from lobbies
                        for lobby in session.lobbies[:]:
                            try:
                                await session.part(lobby)
                            except Exception as e:
                                logger.warning(
                                    f"Error removing {session.getName()} from lobby: {e}"
                                )
                except Exception as e:
                    logger.error(f"Error cleaning up session {session.getName()}: {e}")
            # Clear all data structures
            with cls.lock:
                cls._instances.clear()
                lobbies.clear()
            logger.info(
                f"Graceful session cleanup completed for {len(sessions_to_cleanup)} sessions"
            )
        except Exception as e:
            logger.error(f"Error during graceful session cleanup: {e}")

    async def join(self, lobby: Lobby):
        """Join ``lobby`` and exchange ``addPeer`` messages with its members.

        Requires an open WebSocket. Joining a lobby twice only re-sends a
        "Joined" status. Existing members without a WebSocket are dropped
        from the lobby. A peer link is created with an existing member when
        at least one side has media; the joining side is told to create the
        offer.
        """
        if not self.ws:
            logger.error(
                f"{self.getName()} - No WebSocket connection. Lobby not available."
            )
            return

        with self.session_lock:
            if lobby.id in self.lobby_peers or self.id in lobby.sessions:
                logger.info(f"{self.getName()} - Already joined to {lobby.getName()}.")
                data = JoinStatusModel(
                    status="Joined",
                    message=f"Already joined to lobby {lobby.getName()}",
                )
                try:
                    await self.ws.send_json(
                        {"type": "join_status", "data": data.model_dump()}
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to send join status to {self.getName()}: {e}"
                    )
                return

        # Initialize the peer list for this lobby
        with self.session_lock:
            # A session restored from disk already lists its lobbies; replace
            # any entry with the same id instead of appending a duplicate
            # that would then be persisted.
            self.lobbies = [
                known for known in self.lobbies if known.id != lobby.id
            ]
            self.lobbies.append(lobby)
            self.lobby_peers[lobby.id] = []

        with lobby.lock:
            peer_sessions = list(lobby.sessions.values())

        for peer_session in peer_sessions:
            if peer_session.id == self.id:
                logger.error(
                    "Should not happen: self in lobby.sessions while not in lobby."
                )
                continue
            if not peer_session.ws:
                logger.warning(
                    f"{self.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Removing."
                )
                with lobby.lock:
                    if peer_session.id in lobby.sessions:
                        del lobby.sessions[peer_session.id]
                continue

            # Only create WebRTC peer connections if at least one participant has media
            should_create_rtc_connection = self.has_media or peer_session.has_media
            if should_create_rtc_connection:
                # Add the peer to session's RTC peer list
                with self.session_lock:
                    self.lobby_peers[lobby.id].append(peer_session.id)
                # Add this user as an RTC peer to each existing peer
                with peer_session.session_lock:
                    if lobby.id not in peer_session.lobby_peers:
                        peer_session.lobby_peers[lobby.id] = []
                    peer_session.lobby_peers[lobby.id].append(self.id)
                logger.info(
                    f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False, has_media={self.has_media})"
                )
                try:
                    await peer_session.ws.send_json(
                        {
                            "type": "addPeer",
                            "data": {
                                "peer_id": self.id,
                                "peer_name": self.name,
                                "has_media": self.has_media,
                                "should_create_offer": False,
                            },
                        }
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to send addPeer to {peer_session.getName()}: {e}"
                    )
                # Add each other peer to the caller
                logger.info(
                    f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True, has_media={peer_session.has_media})"
                )
                try:
                    await self.ws.send_json(
                        {
                            "type": "addPeer",
                            "data": {
                                "peer_id": peer_session.id,
                                "peer_name": peer_session.name,
                                "has_media": peer_session.has_media,
                                "should_create_offer": True,
                            },
                        }
                    )
                except Exception as e:
                    logger.warning(f"Failed to send addPeer to {self.getName()}: {e}")
            else:
                logger.info(
                    f"{self.getName()} - Skipping WebRTC connection with {peer_session.getName()} (neither has media: self={self.has_media}, peer={peer_session.has_media})"
                )

        # Add this user as a lobby member
        await lobby.addSession(self)
        Session.save()
        try:
            await self.ws.send_json(
                {"type": "join_status", "data": {"status": "Joined"}}
            )
        except Exception as e:
            logger.warning(f"Failed to send join confirmation to {self.getName()}: {e}")

    async def part(self, lobby: Lobby):
        """Leave ``lobby`` and exchange ``removePeer`` messages with peers.

        Parting a lobby the session has not joined only sends an error
        message (when a WebSocket is available).
        """
        with self.session_lock:
            if lobby.id not in self.lobby_peers or self.id not in lobby.sessions:
                logger.info(
                    f"{self.getName()} - Attempt to part non-joined lobby {lobby.getName()}."
                )
                if self.ws:
                    try:
                        await self.ws.send_json(
                            {
                                "type": "error",
                                "data": {
                                    "error": "Attempt to part non-joined lobby",
                                },
                            }
                        )
                    except Exception:
                        # Best effort: the socket may already be gone.
                        pass
                return

            logger.info(f"{self.getName()} <- part({lobby.getName()}) - Lobby part.")
            lobby_peers = self.lobby_peers[lobby.id][:]  # Copy the list
            del self.lobby_peers[lobby.id]
            # Drop every entry for this lobby id, including a duplicate or a
            # stale object carrying the same id.
            self.lobbies = [
                known for known in self.lobbies if known.id != lobby.id
            ]

        # Remove this peer from all other RTC peers, and remove each peer from this peer
        for peer_session_id in lobby_peers:
            peer_session = Session.getSession(peer_session_id)
            if not peer_session:
                logger.warning(
                    f"{self.getName()} <- part({lobby.getName()}) - Peer session {peer_session_id} not found. Skipping."
                )
                continue

            if peer_session.ws:
                logger.info(
                    f"{peer_session.getName()} <- remove_peer({self.getName()})"
                )
                try:
                    await peer_session.ws.send_json(
                        {
                            "type": "removePeer",
                            "data": {"peer_name": self.name, "peer_id": self.id},
                        }
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to send removePeer to {peer_session.getName()}: {e}"
                    )
            else:
                logger.warning(
                    f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection for {peer_session.getName()}. Skipping."
                )

            # Remove from peer's lobby_peers
            with peer_session.session_lock:
                if (
                    lobby.id in peer_session.lobby_peers
                    and self.id in peer_session.lobby_peers[lobby.id]
                ):
                    peer_session.lobby_peers[lobby.id].remove(self.id)

            if self.ws:
                logger.info(
                    f"{self.getName()} <- remove_peer({peer_session.getName()})"
                )
                try:
                    await self.ws.send_json(
                        {
                            "type": "removePeer",
                            "data": {
                                "peer_name": peer_session.name,
                                "peer_id": peer_session.id,
                            },
                        }
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to send removePeer to {self.getName()}: {e}"
                    )
            else:
                logger.error(
                    f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection."
                )

        await lobby.removeSession(self)
        Session.save()
def getName(session: Session | None) -> str | None:
    """Return ``session.name``, or ``None`` when there is no session or the name is empty."""
    if not session:
        return None
    return session.name or None
def getSession(session_id: str) -> Session | None:
    """Module-level shortcut that delegates the lookup to ``Session.getSession``."""
    found = Session.getSession(session_id)
    return found
def getLobby(lobby_id: str) -> Lobby:
    """Return the lobby registered under ``lobby_id``.

    Raises:
        Exception: When no (truthy) lobby is stored under that id; a warning is
            logged first.
    """
    found = lobbies.get(lobby_id)
    if found:
        return found

    # Check if this might be a stale reference after cleanup
    logger.warning(f"Lobby not found: {lobby_id} (may have been cleaned up)")
    raise Exception(f"Lobby not found: {lobby_id}")
def getLobbyByName(lobby_name: str) -> Lobby | None:
    """Return the first lobby whose ``name`` equals ``lobby_name``, or ``None``."""
    return next(
        (candidate for candidate in lobbies.values() if candidate.name == lobby_name),
        None,
    )
# API endpoints
@app.get(f"{public_url}api/health", response_model=HealthResponse)
def health():
    """Health probe: log the call and answer with ``status="ok"``."""
    logger.info("Health check endpoint called.")
    reply = HealthResponse(status="ok")
    return reply
# A session (cookie) is bound to a single user (name).
# A user can be in multiple lobbies, but a session is unique to a single user.
# A user can change their name, but the session ID remains the same and the name
# updates for all lobbies.
@app.get(f"{public_url}api/session", response_model=SessionResponse)
async def session(
    request: Request, response: Response, session_id: str | None = Cookie(default=None)
) -> Response | SessionResponse:
    """Create or resume the session identified by the ``session_id`` cookie.

    * No cookie: a fresh 32-character hex id is generated and set as the
      ``session_id`` cookie on the response.
    * Cookie present but not 32 lowercase hex characters: a 400 JSON error is
      returned.
    * Otherwise the matching session is looked up (or created). A resumed
      session has its activity timestamp refreshed and is parted from every
      lobby it was still listed in.

    Returns:
        A ``SessionResponse`` with the session id, its name (``""`` when unset)
        and the lobbies it is currently in, or a 400 ``Response`` for a
        malformed cookie.
    """
    if session_id is None:
        session_id = secrets.token_hex(16)
        response.set_cookie(key="session_id", value=session_id)
    # Validate that session_id is a hex string of length 32
    elif len(session_id) != 32 or not all(c in "0123456789abcdef" for c in session_id):
        return Response(
            content=json.dumps({"error": "Invalid session_id"}),
            status_code=400,
            media_type="application/json",
        )

    # Use the module logger (not print) so the line follows the configured
    # log handlers and format like every other message in this file.
    logger.info(f"[{session_id[:8]}]: Browser hand-shake achieved.")

    current = getSession(session_id)
    if not current:
        current = Session(session_id)
        logger.info(f"{current.getName()}: New session created.")
    else:
        current.update_last_used()  # Update activity on session resumption
        logger.info(f"{current.getName()}: Existing session resumed.")
        # Part every lobby this session is still listed in.
        # NOTE(review): no check is made for an active websocket before
        # parting, although an earlier comment here said only lobbies "that
        # have no active websocket" are parted - confirm which is intended.
        with current.session_lock:
            lobbies_to_part = current.lobbies[:]
        for lobby in lobbies_to_part:
            try:
                await current.part(lobby)
            except Exception as e:
                logger.error(
                    f"{current.getName()} - Error parting lobby {lobby.getName()}: {e}"
                )

    with current.session_lock:
        return SessionResponse(
            id=session_id,
            name=current.name if current.name else "",
            lobbies=[
                LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private)
                for lobby in current.lobbies
            ],
        )
@app.get(public_url + "api/lobby", response_model=LobbiesResponse)
async def get_lobbies(request: Request, response: Response) -> LobbiesResponse:
    """List every lobby that is not flagged ``private`` as ``(id, name)`` items."""
    visible: list[LobbyListItem] = []
    for candidate in lobbies.values():
        if candidate.private:
            continue
        visible.append(LobbyListItem(id=candidate.id, name=candidate.name))
    return LobbiesResponse(lobbies=visible)
@app.post(public_url + "api/lobby/{session_id}", response_model=LobbyCreateResponse)
async def lobby_create(
    request: Request,
    response: Response,
    session_id: str = Path(...),
    create_request: LobbyCreateRequest = Body(...),
) -> Response | LobbyCreateResponse:
    """Create a lobby by name, or return the existing lobby with that name.

    Returns a 400 JSON error when the request ``type`` is not ``"lobby_create"``
    and a 404 JSON error when ``session_id`` does not resolve to a session.
    When a lobby with the requested name already exists it is returned as-is;
    the requested ``private`` flag is only applied to newly created lobbies.
    """
    if create_request.type != "lobby_create":
        bad_type = json.dumps({"error": "Invalid request type"})
        return Response(
            content=bad_type, status_code=400, media_type="application/json"
        )

    payload = create_request.data
    requester = getSession(session_id)
    if not requester:
        missing = json.dumps({"error": f"Session not found ({session_id})"})
        return Response(
            content=missing, status_code=404, media_type="application/json"
        )

    logger.info(
        f"{requester.getName()} lobby_create: {payload.name} (private={payload.private})"
    )

    existing = getLobbyByName(payload.name)
    if existing:
        lobby = existing
    else:
        lobby = Lobby(payload.name, private=payload.private)
        lobbies[lobby.id] = lobby
        logger.info(f"{requester.getName()} <- lobby_create({lobby.short}:{lobby.name})")

    created = LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private)
    return LobbyCreateResponse(type="lobby_created", data=created)
@app.get(public_url + "api/lobby/{lobby_id}/chat", response_model=ChatMessagesResponse)
async def get_chat_messages(
    request: Request,
    lobby_id: str = Path(...),
    limit: int = 50,
) -> Response | ChatMessagesResponse:
    """Get chat messages for a lobby.

    ``limit`` is forwarded unchanged to ``lobby.get_chat_messages``. An unknown
    lobby yields a 404 JSON error whose text is the lookup exception message.
    """
    try:
        target = getLobby(lobby_id)
    except Exception as lookup_error:
        body = json.dumps({"error": str(lookup_error)})
        return Response(content=body, status_code=404, media_type="application/json")

    return ChatMessagesResponse(messages=target.get_chat_messages(limit))
# =============================================================================
# Bot Provider API Endpoints
# =============================================================================
@app.post(
    public_url + "api/bots/providers/register",
    response_model=BotProviderRegisterResponse,
)
async def register_bot_provider(
    request: BotProviderRegisterRequest,
) -> BotProviderRegisterResponse:
    """Register a new bot provider with authentication.

    When ``BotProviderConfig.get_allowed_providers()`` returns a non-empty
    collection, ``request.provider_key`` must be a member of it, otherwise the
    registration is rejected with HTTP 403. Any provider already registered
    with the same key is dropped before the new one is stored under a fresh
    UUID.

    The provider key is an authentication credential, so it is deliberately
    kept out of every log line written here.

    Raises:
        HTTPException: 403 when authentication is enabled and the key is not
            allowed.
    """
    import uuid

    # Check if provider authentication is enabled
    allowed_providers = BotProviderConfig.get_allowed_providers()
    if allowed_providers:
        # Authentication is enabled - validate provider key
        if request.provider_key not in allowed_providers:
            logger.warning(
                f"Rejected bot provider registration with invalid key: {request.name} at {request.base_url}"
            )
            raise HTTPException(
                status_code=403,
                detail="Invalid provider key. Bot provider is not authorized to register.",
            )

    # Check if there's already an active provider with this key and remove it
    providers_to_remove: list[str] = [
        existing_provider_id
        for existing_provider_id, existing_provider in bot_providers.items()
        if existing_provider.provider_key == request.provider_key
    ]

    # Remove stale providers
    for provider_id_to_remove in providers_to_remove:
        stale = bot_providers.pop(provider_id_to_remove)
        logger.info(
            f"Removing stale bot provider: {stale.name} (ID: {provider_id_to_remove})"
        )

    provider_id = str(uuid.uuid4())
    now = time.time()

    provider = BotProviderModel(
        provider_id=provider_id,
        base_url=request.base_url.rstrip("/"),
        name=request.name,
        description=request.description,
        provider_key=request.provider_key,
        registered_at=now,
        last_seen=now,
    )

    bot_providers[provider_id] = provider
    logger.info(f"Registered bot provider: {request.name} at {request.base_url}")

    return BotProviderRegisterResponse(provider_id=provider_id)
@app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse)
async def list_bot_providers() -> BotProviderListResponse:
    """List all registered bot providers.

    NOTE(review): each stored ``BotProviderModel`` was built with its
    ``provider_key``; confirm the response model does not expose that key.
    """
    registered = [*bot_providers.values()]
    return BotProviderListResponse(providers=registered)
@app.get(public_url + "api/bots", response_model=BotListResponse)
async def list_available_bots() -> BotListResponse:
    """List all available bots from all registered providers.

    Each provider's ``/bots`` endpoint is queried with a 5 second timeout and
    its answer validated with ``BotProviderBotsResponse``. A provider that
    answers with a non-200 status, or raises, is logged and skipped.

    Returns:
        A ``BotListResponse`` with the consolidated bot list and a mapping of
        bot name to the id of the provider that reported it (a later provider
        overwrites an earlier one that reported the same bot name).
    """
    bots: List[BotInfoModel] = []
    providers: dict[str, str] = {}

    # Iterate over a snapshot: this coroutine suspends on every HTTP call, and
    # a registration handled in the meantime mutates bot_providers, which would
    # raise "dictionary changed size during iteration" on the live view.
    registered = list(bot_providers.items())

    # One client for the whole sweep instead of a new one per provider.
    async with httpx.AsyncClient() as client:
        # Update last_seen timestamps and fetch bots from each provider
        for provider_id, provider in registered:
            try:
                provider.last_seen = time.time()

                # Make HTTP request to provider's /bots endpoint
                response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
                if response.status_code == 200:
                    # Use Pydantic model to validate the response
                    bots_response = BotProviderBotsResponse.model_validate(
                        response.json()
                    )
                    # Add each bot to the consolidated list
                    for bot_info in bots_response.bots:
                        bots.append(bot_info)
                        providers[bot_info.name] = provider_id
                else:
                    logger.warning(
                        f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}"
                    )
            except Exception as e:
                logger.error(f"Error fetching bots from provider {provider.name}: {e}")
                continue

    return BotListResponse(bots=bots, providers=providers)
async def _find_provider_bot(
    provider: BotProviderModel, bot_name: str
) -> BotInfoModel | None:
    """Return ``provider``'s entry for ``bot_name`` from its ``/bots`` endpoint, or ``None``.

    ``None`` is returned when the provider answers with a non-200 status, does
    not list the bot, or the request/validation raises (the failure is logged).
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
            if response.status_code != 200:
                return None
            # Use Pydantic model to validate the response
            bots_response = BotProviderBotsResponse.model_validate(response.json())
            # Look for the bot by name
            for bot_info in bots_response.bots:
                if bot_info.name == bot_name:
                    return bot_info
    except Exception as e:
        logger.warning(f"Failed to query bots from provider {provider.name}: {e}")
    return None


@app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse)
async def request_bot_join_lobby(
    bot_name: str, request: BotJoinLobbyRequest
) -> BotJoinLobbyResponse:
    """Request a bot to join a specific lobby.

    The provider is taken from ``request.provider_id`` or, when that is empty,
    discovered by asking each registered provider whether it lists
    ``bot_name``. A bot session is created, and the provider's
    ``/bots/{bot_name}/join`` endpoint is called with the lobby id, the new
    session id, the nick and the server URL the bot should connect back to.

    Raises:
        HTTPException: 404 when the bot/provider or the lobby is unknown;
            502 when the provider answers with an error status or an invalid
            body; 504 when the provider times out; 500 for any other failure.
    """
    # Find which provider has this bot and determine its media capability
    target_provider_id = request.provider_id
    bot_has_media = False

    if not target_provider_id:
        # Auto-discover provider for this bot. Iterate over a snapshot: each
        # lookup awaits, and a concurrent registration may mutate the dict.
        for provider_id, provider in list(bot_providers.items()):
            bot_info = await _find_provider_bot(provider, bot_name)
            if bot_info is not None:
                target_provider_id = provider_id
                bot_has_media = bot_info.has_media
                break
    elif target_provider_id in bot_providers:
        # Query the specified provider for bot media capability;
        # default to no media if we can't query or the bot is not listed.
        bot_info = await _find_provider_bot(bot_providers[target_provider_id], bot_name)
        if bot_info is not None:
            bot_has_media = bot_info.has_media

    if not target_provider_id or target_provider_id not in bot_providers:
        raise HTTPException(status_code=404, detail="Bot or provider not found")

    provider = bot_providers[target_provider_id]

    # Get the lobby to validate it exists
    try:
        getLobby(request.lobby_id)  # Just validate it exists
    except Exception:
        raise HTTPException(status_code=404, detail="Lobby not found")

    # Create a session for the bot
    bot_session_id = secrets.token_hex(16)

    # Create the Session object for the bot
    # NOTE(review): this session is not removed again if the provider call
    # below fails - confirm that the regular session cleanup reclaims it.
    bot_session = Session(bot_session_id, is_bot=True, has_media=bot_has_media)
    logger.info(
        f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})"
    )

    # Determine server URL for the bot to connect back to
    # Use the server's public URL or construct from request
    server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000")
    if server_base_url.endswith("/"):
        server_base_url = server_base_url[:-1]

    bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}"

    # Prepare the join request for the bot provider
    bot_join_payload = BotJoinPayload(
        lobby_id=request.lobby_id,
        session_id=bot_session_id,
        nick=bot_nick,
        server_url=f"{server_base_url}{public_url}".rstrip("/"),
        insecure=True,  # Accept self-signed certificates in development
    )

    try:
        # Make request to bot provider
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{provider.base_url}/bots/{bot_name}/join",
                json=bot_join_payload.model_dump(),
                timeout=10.0,
            )

            if response.status_code == 200:
                # Use Pydantic model to parse and validate response
                try:
                    join_response = BotProviderJoinResponse.model_validate(
                        response.json()
                    )
                    run_id = join_response.run_id

                    # Update bot session with run and provider information
                    with bot_session.session_lock:
                        bot_session.bot_run_id = run_id
                        bot_session.bot_provider_id = target_provider_id
                    bot_session.setName(bot_nick)

                    logger.info(
                        f"Bot {bot_name} requested to join lobby {request.lobby_id}"
                    )

                    return BotJoinLobbyResponse(
                        status="requested",
                        bot_name=bot_name,
                        run_id=run_id,
                        provider_id=target_provider_id,
                    )
                except ValidationError as e:
                    logger.error(f"Invalid response from bot provider: {e}")
                    raise HTTPException(
                        status_code=502,
                        detail=f"Bot provider returned invalid response: {str(e)}",
                    )
            else:
                logger.error(
                    f"Bot provider returned error: HTTP {response.status_code}: {response.text}"
                )
                raise HTTPException(
                    status_code=502,
                    detail=f"Bot provider error: {response.status_code}",
                )
    except HTTPException:
        # The 502s raised above are deliberate; without this clause the
        # generic handler below would catch them and turn them into a 500.
        raise
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Bot provider timeout")
    except Exception as e:
        logger.error(f"Error requesting bot join: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post(public_url + "api/bots/leave", response_model=BotLeaveLobbyResponse)
async def request_bot_leave_lobby(
    request: BotLeaveLobbyRequest,
) -> BotLeaveLobbyResponse:
    """Request a bot to leave from all lobbies and disconnect.

    The provider is asked to stop the bot's run when the session carries both
    a provider id and a run id and the provider is still registered; failures
    there are only logged. The session is then parted from every lobby and its
    websocket, if any, is closed.

    Raises:
        HTTPException: 404 when the session is unknown, 400 when it is not a
            bot session.
    """
    # Find the bot session
    bot = getSession(request.session_id)
    if not bot:
        raise HTTPException(status_code=404, detail="Bot session not found")
    if not bot.is_bot:
        raise HTTPException(status_code=400, detail="Session is not a bot")

    run_id = bot.bot_run_id
    provider_id = bot.bot_provider_id

    logger.info(f"Requesting bot {bot.getName()} to leave all lobbies")

    # Try to stop the bot at the provider level if we have the information
    can_stop_remotely = provider_id and run_id and provider_id in bot_providers
    if can_stop_remotely:
        owner = bot_providers[provider_id]
        try:
            async with httpx.AsyncClient() as http:
                stop_reply = await http.post(
                    f"{owner.base_url}/bots/runs/{run_id}/stop",
                    timeout=5.0,
                )
                if stop_reply.status_code != 200:
                    logger.warning(
                        f"Bot provider returned error when stopping: HTTP {stop_reply.status_code}"
                    )
                else:
                    logger.info(
                        f"Successfully requested bot provider to stop run {run_id}"
                    )
        except Exception as stop_error:
            logger.warning(f"Failed to request bot stop from provider: {stop_error}")

    # Force disconnect the bot session from all lobbies
    with bot.session_lock:
        joined_lobbies = bot.lobbies[:]

    for joined in joined_lobbies:
        try:
            await bot.part(joined)
        except Exception as part_error:
            logger.warning(
                f"Error parting bot from lobby {joined.getName()}: {part_error}"
            )

    # Close WebSocket connection if it exists
    if bot.ws:
        try:
            await bot.ws.close()
        except Exception as close_error:
            logger.warning(f"Error closing bot WebSocket: {close_error}")
        bot.ws = None

    return BotLeaveLobbyResponse(
        status="disconnected",
        session_id=request.session_id,
        run_id=run_id,
    )
# Register websocket endpoint directly on app with full public_url path
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
async def lobby_join(
websocket: WebSocket,
lobby_id: str | None = Path(...),
session_id: str | None = Path(...),
):
await websocket.accept()
if lobby_id is None:
await websocket.send_json(
{"type": "error", "data": {"error": "Invalid or missing lobby"}}
)
await websocket.close()
return
if session_id is None:
await websocket.send_json(
{"type": "error", "data": {"error": "Invalid or missing session"}}
)
await websocket.close()
return
session = getSession(session_id)
if not session:
# logger.error(f"Invalid session ID {session_id}")
await websocket.send_json(
{"type": "error", "data": {"error": f"Invalid session ID {session_id}"}}
)
await websocket.close()
return
lobby = None
try:
lobby = getLobby(lobby_id)
except Exception as e:
await websocket.send_json({"type": "error", "data": {"error": str(e)}})
await websocket.close()
return
logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})")
session.ws = websocket
session.update_last_used() # Update activity timestamp
# Check if session is already in lobby and clean up if needed
with lobby.lock:
if session.id in lobby.sessions:
logger.info(
f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining."
)
try:
await session.part(lobby)
await lobby.removeSession(session)
except Exception as e:
logger.warning(f"Error cleaning up stale session: {e}")
# Notify existing peers about new user
failed_peers: list[str] = []
with lobby.lock:
peer_sessions = list(lobby.sessions.values())
for peer_session in peer_sessions:
if not peer_session.ws:
logger.warning(
f"{session.getName()} - Live peer session {peer_session.id} not found in lobby {lobby.getName()}. Marking for removal."
)
failed_peers.append(peer_session.id)
continue
logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})")
try:
await peer_session.ws.send_json(
{
"type": "user_joined",
"data": {
"session_id": session.id,
"name": session.name,
},
}
)
except Exception as e:
logger.warning(
f"Failed to notify {peer_session.getName()} of user join: {e}"
)
failed_peers.append(peer_session.id)
# Clean up failed peers
with lobby.lock:
for failed_peer_id in failed_peers:
if failed_peer_id in lobby.sessions:
del lobby.sessions[failed_peer_id]
try:
while True:
packet = await websocket.receive_json()
session.update_last_used() # Update activity on each message
type = packet.get("type", None)
data: dict[str, Any] | None = packet.get("data", None)
if not type:
logger.error(f"{session.getName()} - Invalid request: {packet}")
await websocket.send_json(
{"type": "error", "data": {"error": "Invalid request"}}
)
continue
# logger.info(f"{session.getName()} <- RAW Rx: {data}")
match type:
case "set_name":
if not data:
logger.error(f"{session.getName()} - set_name missing data")
await websocket.send_json(
{
"type": "error",
"data": {"error": "set_name missing data"},
}
)
continue
name = data.get("name")
password = data.get("password")
logger.info(f"{session.getName()} <- set_name({name}, {password})")
if not name:
logger.error(f"{session.getName()} - Name required")
await websocket.send_json(
{"type": "error", "data": {"error": "Name required"}}
)
continue
# Name takeover / password logic
lname = name.lower()
# If name is unused, allow and optionally save password
if Session.isUniqueName(name):
# If a password was provided, save it (hash+salt) for this name
if password:
salt, hash_hex = _hash_password(password)
name_passwords[lname] = {"salt": salt, "hash": hash_hex}
session.setName(name)
logger.info(f"{session.getName()}: -> update('name', {name})")
await websocket.send_json(
{
"type": "update_name",
"data": {
"name": name,
"protected": True
if name.lower() in name_passwords
else False,
},
}
)
# For any clients in any lobby with this session, update their user lists
await lobby.update_state()
continue
# Name is taken. Check if a password exists for the name and matches.
saved_pw = name_passwords.get(lname)
if not saved_pw and not password:
logger.warning(
f"{session.getName()} - Name already taken (no password set)"
)
await websocket.send_json(
{"type": "error", "data": {"error": "Name already taken"}}
)
continue
if saved_pw and password:
# Expect structured record with salt+hash only
match_password = False
# saved_pw should be a dict[str,str] with 'salt' and 'hash'
salt = saved_pw.get("salt")
_, candidate_hash = _hash_password(
password if password else "", salt_hex=salt
)
if candidate_hash == saved_pw.get("hash"):
match_password = True
else:
# No structured password record available
match_password = False
else:
match_password = True # No password set, but name taken and new password - allow takeover
if not match_password:
logger.warning(
f"{session.getName()} - Name takeover attempted with wrong or missing password"
)
await websocket.send_json(
{
"type": "error",
"data": {
"error": "Invalid password for name takeover",
},
}
)
continue
# Password matches: perform takeover. Find the current session holding the name.
# Find the currently existing session (if any) with that name
displaced = Session.getSessionByName(name)
if displaced and displaced.id == session.id:
displaced = None
# If found, change displaced session to a unique fallback name and notify peers
if displaced:
# Create a unique fallback name
fallback = f"{displaced.name}-{displaced.short}"
# Ensure uniqueness
if not Session.isUniqueName(fallback):
# append random suffix until unique
while not Session.isUniqueName(fallback):
fallback = f"{displaced.name}-{secrets.token_hex(3)}"
displaced.setName(fallback)
displaced.mark_displaced()
logger.info(
f"{displaced.getName()} <- displaced by takeover, new name {fallback}"
)
# Notify displaced session (if connected)
if displaced.ws:
try:
await displaced.ws.send_json(
{
"type": "update_name",
"data": {
"name": fallback,
"protected": False,
},
}
)
except Exception:
logger.exception(
"Failed to notify displaced session websocket"
)
# Update all lobbies the displaced session was in
with displaced.session_lock:
displaced_lobbies = displaced.lobbies[:]
for d_lobby in displaced_lobbies:
try:
await d_lobby.update_state()
except Exception:
logger.exception(
"Failed to update lobby state for displaced session"
)
# Now assign the requested name to the current session
session.setName(name)
logger.info(
f"{session.getName()}: -> update('name', {name}) (takeover)"
)
await websocket.send_json(
{
"type": "update_name",
"data": {
"name": name,
"protected": True
if name.lower() in name_passwords
else False,
},
}
)
# Notify lobbies for this session
await lobby.update_state()
case "list_users":
await lobby.update_state(session)
case "get_chat_messages":
# Send recent chat messages to the requesting client
messages = lobby.get_chat_messages(50)
await websocket.send_json(
{
"type": "chat_messages",
"data": {
"messages": [msg.model_dump() for msg in messages]
},
}
)
case "send_chat_message":
if not data or "message" not in data:
logger.error(
f"{session.getName()} - send_chat_message missing message"
)
await websocket.send_json(
{
"type": "error",
"data": {
"error": "send_chat_message missing message",
},
}
)
continue
if not session.name:
logger.error(
f"{session.getName()} - Cannot send chat message without name"
)
await websocket.send_json(
{
"type": "error",
"data": {
"error": "Must set name before sending chat messages",
},
}
)
continue
message_text = str(data["message"]).strip()
if not message_text:
continue
# Add the message to the lobby and broadcast it
chat_message = lobby.add_chat_message(session, message_text)
logger.info(
f"{session.getName()} -> broadcast_chat_message({lobby.getName()}, {message_text[:50]}...)"
)
await lobby.broadcast_chat_message(chat_message)
case "join":
logger.info(f"{session.getName()} <- join({lobby.getName()})")
await session.join(lobby=lobby)
case "part":
logger.info(f"{session.getName()} <- part {lobby.getName()}")
await session.part(lobby=lobby)
case "relayICECandidate":
logger.info(f"{session.getName()} <- relayICECandidate")
if not data:
logger.error(
f"{session.getName()} - relayICECandidate missing data"
)
await websocket.send_json(
{
"type": "error",
"data": {"error": "relayICECandidate missing data"},
}
)
continue
with session.session_lock:
if (
lobby.id not in session.lobby_peers
or session.id not in lobby.sessions
):
logger.error(
f"{session.short}:{session.name} <- relayICECandidate - Not an RTC peer ({session.id})"
)
await websocket.send_json(
{
"type": "error",
"data": {"error": "Not joined to lobby"},
}
)
continue
session_peers = session.lobby_peers[lobby.id]
peer_id = data.get("peer_id")
if peer_id not in session_peers:
logger.error(
f"{session.getName()} <- relayICECandidate - Not an RTC peer({peer_id}) in {session_peers}"
)
await websocket.send_json(
{
"type": "error",
"data": {
"error": f"Target peer {peer_id} not found",
},
}
)
continue
candidate = data.get("candidate")
message: dict[str, Any] = {
"type": "iceCandidate",
"data": {
"peer_id": session.id,
"peer_name": session.name,
"candidate": candidate,
},
}
peer_session = lobby.getSession(peer_id)
if not peer_session or not peer_session.ws:
logger.warning(
f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}."
)
continue
logger.info(
f"{session.getName()} -> iceCandidate({peer_session.getName()})"
)
try:
await peer_session.ws.send_json(message)
except Exception as e:
logger.warning(f"Failed to relay ICE candidate: {e}")
case "relaySessionDescription":
logger.info(f"{session.getName()} <- relaySessionDescription")
if not data:
logger.error(
f"{session.getName()} - relaySessionDescription missing data"
)
await websocket.send_json(
{
"type": "error",
"data": {
"error": "relaySessionDescription missing data",
},
}
)
continue
with session.session_lock:
if (
lobby.id not in session.lobby_peers
or session.id not in lobby.sessions
):
logger.error(
f"{session.short}:{session.name} <- relaySessionDescription - Not an RTC peer ({session.id})"
)
await websocket.send_json(
{
"type": "error",
"data": {"error": "Not joined to lobby"},
}
)
continue
lobby_peers = session.lobby_peers[lobby.id]
peer_id = data.get("peer_id")
if peer_id not in lobby_peers:
logger.error(
f"{session.getName()} <- relaySessionDescription - Not an RTC peer({peer_id}) in {lobby_peers}"
)
await websocket.send_json(
{
"type": "error",
"data": {
"error": f"Target peer {peer_id} not found",
},
}
)
continue
if not peer_id:
logger.error(
f"{session.getName()} - relaySessionDescription missing peer_id"
)
await websocket.send_json(
{
"type": "error",
"data": {
"error": "relaySessionDescription missing peer_id",
},
}
)
continue
peer_session = lobby.getSession(peer_id)
if not peer_session or not peer_session.ws:
logger.warning(
f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}."
)
continue
session_description = data.get("session_description")
message = {
"type": "sessionDescription",
"data": {
"peer_id": session.id,
"peer_name": session.name,
"session_description": session_description,
},
}
logger.info(
f"{session.getName()} -> sessionDescription({peer_session.getName()})"
)
try:
await peer_session.ws.send_json(message)
except Exception as e:
logger.warning(f"Failed to relay session description: {e}")
case "status_check":
# Simple status check - just respond with success to keep connection alive
logger.debug(f"{session.getName()} <- status_check")
await websocket.send_json(
{"type": "status_ok", "data": {"timestamp": time.time()}}
)
case _:
await websocket.send_json(
{
"type": "error",
"data": {
"error": f"Unknown request type: {type}",
},
}
)
except WebSocketDisconnect:
logger.info(f"{session.getName()} <- WebSocket disconnected for user.")
# Cleanup: remove session from lobby and sessions dict
session.ws = None
if session.id in lobby.sessions:
try:
await session.part(lobby)
except Exception as e:
logger.warning(f"Error during websocket disconnect cleanup: {e}")
try:
await lobby.update_state()
except Exception as e:
logger.warning(f"Error updating lobby state after disconnect: {e}")
# Clean up empty lobbies
with lobby.lock:
if not lobby.sessions:
if lobby.id in lobbies:
del lobbies[lobby.id]
logger.info(f"Cleaned up empty lobby {lobby.getName()}")
except Exception as e:
logger.error(
f"Unexpected error in websocket handler for {session.getName()}: {e}"
)
try:
await websocket.close()
except Exception as e:
pass
# Serve static files or proxy to frontend development server
PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true"
# NOTE(review): the second component is absolute, so os.path.join() discards
# the dirname and this resolves to "/client/build" on POSIX - confirm that the
# build is really expected at the filesystem root rather than next to this file.
client_build_path = os.path.join(os.path.dirname(__file__), "/client/build")

if PRODUCTION:
    logger.info(f"Serving static files from: {client_build_path} at {public_url}")
    app.mount(
        public_url, StaticFiles(directory=client_build_path, html=True), name="static"
    )
else:
    logger.info(f"Proxying static files to http://client:3000 at {public_url}")

    import ssl

    @app.api_route(
        f"{public_url}{{path:path}}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_static(request: Request, path: str):
        """Forward a non-API request to the frontend dev server at ``client:3000``.

        Paths starting with ``api/`` or ``ws/`` are answered with 404. The
        upstream body is read in full and returned with the upstream status
        and headers, minus ``content-encoding``, ``transfer-encoding`` and
        ``content-length``. Any failure is logged and reported as 502.
        """
        # Do not proxy API or websocket paths
        if path.startswith(("api/", "ws/")):
            return Response(status_code=404)

        base = f"{request.url.scheme}://client:3000/{public_url.strip('/')}"
        url = f"{base}/{path}" if path else base
        headers = dict(request.headers)
        try:
            # Accept self-signed certs in dev
            async with httpx.AsyncClient(verify=False) as client:
                proxy_req = client.build_request(
                    request.method, url, headers=headers, content=await request.body()
                )
                proxy_resp = await client.send(proxy_req, stream=True)
                content = await proxy_resp.aread()

                # Remove problematic headers for browser decoding
                filtered_headers = {
                    k: v
                    for k, v in proxy_resp.headers.items()
                    if k.lower()
                    not in ["content-encoding", "transfer-encoding", "content-length"]
                }
                return Response(
                    content=content,
                    status_code=proxy_resp.status_code,
                    headers=filtered_headers,
                )
        except Exception as e:
            logger.error(f"Proxy error for {url}: {e}")
            return Response("Proxy error", status_code=502)

    # WebSocket proxy for /ws (for React DevTools, etc.)
    import websockets

    @app.websocket("/ws")
    async def websocket_proxy(websocket: WebSocket):
        """Bridge a browser websocket on ``/ws`` to ``wss://client:3000/ws``.

        Text frames from the browser are forwarded upstream; text and binary
        frames from upstream are forwarded to the browser. Certificate
        verification is disabled for the upstream connection.
        """
        logger.info("REACT: WebSocket proxy connection established.")
        target_url = "wss://client:3000/ws"  # Use WSS since client uses HTTPS
        await websocket.accept()
        try:
            # Accept self-signed certs in dev for WSS
            ssl_ctx = ssl.create_default_context()
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
            async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws:

                async def client_to_server():
                    # Browser -> dev server (text frames only).
                    while True:
                        msg = await websocket.receive_text()
                        await target_ws.send(msg)

                async def server_to_client():
                    # Dev server -> browser, preserving the frame kind.
                    while True:
                        msg = await target_ws.recv()
                        if isinstance(msg, str):
                            await websocket.send_text(msg)
                        else:
                            await websocket.send_bytes(msg)

                try:
                    await asyncio.gather(client_to_server(), server_to_client())
                except (WebSocketDisconnect, websockets.ConnectionClosed):
                    logger.info("REACT: WebSocket proxy connection closed.")
        except Exception as e:
            logger.error(f"REACT: WebSocket proxy error: {e}")
            await websocket.close()