# 2025-09-01 20:34:01 -07:00
#
# 1316 lines
# 49 KiB
# Python

from __future__ import annotations
from typing import Any, Optional, TypedDict
from fastapi import (
Body,
Cookie,
FastAPI,
Path,
WebSocket,
Request,
Response,
WebSocketDisconnect,
)
import secrets
import os
import json
import hashlib
import binascii
import sys
import asyncio
from contextlib import asynccontextmanager
from fastapi.staticfiles import StaticFiles
import httpx
from pydantic import ValidationError
from logger import logger
# Import shared models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import (
HealthResponse,
LobbiesResponse,
LobbyCreateRequest,
LobbyCreateResponse,
LobbyListItem,
LobbyModel,
NamePasswordRecord,
LobbySaved,
SessionResponse,
SessionSaved,
SessionsPayload,
AdminNamesResponse,
AdminActionResponse,
AdminSetPassword,
AdminClearPassword,
JoinStatusModel,
ChatMessageModel,
ChatMessagesResponse,
)
# Reserved-name password store: lowercased name -> {"salt": <hex>, "hash": <hex>}.
# Filled from disk by Session.load() and written back by Session.save().
name_passwords: dict[str, dict[str, str]] = {}

# Bracketed display labels.
all_label = "[ all ]"
info_label = "[ info ]"
todo_label = "[ todo ]"
unset_label = "[ ---- ]"  # used by Session.getName() when no name is set
def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]:
"""Return (salt_hex, hash_hex) for the given password. If salt_hex is provided
it is used; otherwise a new salt is generated."""
if salt_hex:
salt = binascii.unhexlify(salt_hex)
else:
salt = secrets.token_bytes(16)
salt_hex = binascii.hexlify(salt).decode()
dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000)
hash_hex = binascii.hexlify(dk).decode()
return salt_hex, hash_hex
# URL path prefix for every route; normalised to always end with "/".
public_url = os.getenv("PUBLIC_URL", "/")
public_url = public_url if public_url.endswith("/") else public_url + "/"
# State of the background session-cleanup task (managed by lifespan()).
cleanup_task_running = False
cleanup_task = None


async def periodic_cleanup():
    """Call Session.cleanup_old_sessions() every 5 minutes while enabled.

    Any exception is logged and followed by a 1-minute back-off. The loop
    ends once ``cleanup_task_running`` becomes False.
    """
    while cleanup_task_running:
        try:
            Session.cleanup_old_sessions()
            await asyncio.sleep(300)  # next pass in 5 minutes
        except Exception as exc:
            logger.error(f"Error in session cleanup task: {exc}")
            await asyncio.sleep(60)  # back off for 1 minute after a failure


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the cleanup task on startup; stop, cancel and await it on shutdown."""
    global cleanup_task_running, cleanup_task
    cleanup_task_running = True
    cleanup_task = asyncio.create_task(periodic_cleanup())
    logger.info("Session cleanup task started")
    yield
    # Shutdown: end the loop, then cancel any sleep currently in progress.
    cleanup_task_running = False
    if cleanup_task is not None:
        cleanup_task.cancel()
        try:
            await cleanup_task
        except asyncio.CancelledError:
            pass
        logger.info("Session cleanup task stopped")
app = FastAPI(lifespan=lifespan)
logger.info(f"Starting server with public URL: {public_url}")

# Optional admin token protecting the /api/admin/* endpoints. When it is not
# set, _require_admin() lets every request through.
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None)


def _require_admin(request: Request) -> bool:
    """Return True if *request* is allowed to use the admin endpoints.

    With no ADMIN_TOKEN configured every request is allowed. Otherwise the
    ``X-Admin-Token`` header must equal ADMIN_TOKEN. The comparison is
    constant-time so the token cannot be recovered through response timing.
    """
    if not ADMIN_TOKEN:
        return True
    token = request.headers.get("X-Admin-Token")
    if token is None:
        return False
    # Compare bytes so non-ASCII header values cannot raise TypeError.
    return secrets.compare_digest(
        token.encode("utf-8"), ADMIN_TOKEN.encode("utf-8")
    )
@app.get(public_url + "api/admin/names", response_model=AdminNamesResponse)
def admin_list_names(request: Request):
    """Return every stored name -> {salt, hash} record (403 unless admin)."""
    if _require_admin(request):
        return {"name_passwords": name_passwords}
    return Response(status_code=403)


@app.post(public_url + "api/admin/set_password", response_model=AdminActionResponse)
def admin_set_password(request: Request, payload: AdminSetPassword = Body(...)):
    """Set or replace the password protecting ``payload.name``, then persist."""
    if not _require_admin(request):
        return Response(status_code=403)
    salt_hex, digest_hex = _hash_password(payload.password)
    name_passwords[payload.name.lower()] = {"salt": salt_hex, "hash": digest_hex}
    Session.save()
    return {"status": "ok", "name": payload.name}


@app.post(public_url + "api/admin/clear_password", response_model=AdminActionResponse)
def admin_clear_password(request: Request, payload: AdminClearPassword = Body(...)):
    """Remove the password for ``payload.name``; report ``not_found`` if absent."""
    if not _require_admin(request):
        return Response(status_code=403)
    key = payload.name.lower()
    if key not in name_passwords:
        return {"status": "not_found", "name": payload.name}
    del name_passwords[key]
    Session.save()
    return {"status": "ok", "name": payload.name}


@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse)
def admin_cleanup_sessions(request: Request):
    """Run session cleanup on demand and report how many sessions were removed."""
    if not _require_admin(request):
        return Response(status_code=403)
    try:
        removed_count = Session.cleanup_old_sessions()
    except Exception as e:
        logger.error(f"Error during manual session cleanup: {e}")
        return {"status": "not_found", "name": f"Error: {str(e)}"}
    return {"status": "ok", "name": f"Removed {removed_count} sessions"}
# Every known lobby, keyed by lobby id.
lobbies: dict[str, Lobby] = {}


class LobbyResponse(TypedDict):
    """Lobby description: id, display name and privacy flag."""

    id: str
    name: str
    private: bool
class Lobby:
    """A lobby: its member sessions, privacy flag and recent chat history."""

    # Maximum number of chat messages kept per lobby (older ones are dropped).
    MAX_CHAT_MESSAGES = 100

    def __init__(self, name: str, id: str | None = None, private: bool = False):
        """Create a lobby.

        Args:
            name: Display name of the lobby.
            id: Lobby id; a random 32-character hex id is generated when None.
            private: Private lobbies are left out of the public lobby listing.
        """
        self.id = secrets.token_hex(16) if id is None else id
        self.short = self.id[:8]
        self.name = name
        self.sessions: dict[str, Session] = {}  # All lobby members, by session id
        self.private = private
        self.chat_messages: list[ChatMessageModel] = []  # Oldest first

    def getName(self) -> str:
        """Return a ``<short-id>:<name>`` label for log messages."""
        return f"{self.short}:{self.name}"

    def _participants(self) -> list[dict[str, str | bool]]:
        """Describe each named member: name, live socket, password protection."""
        return [
            {
                "name": s.name,
                "live": bool(s.ws),
                "session_id": s.id,
                "protected": s.name.lower() in name_passwords,
            }
            for s in self.sessions.values()
            if s.name
        ]

    async def update_state(self, requesting_session: Session | None = None):
        """Send the participant list as a ``lobby_state`` message.

        Args:
            requesting_session: If given, only this session receives the
                state. Otherwise every member with a WebSocket does. A failed
                send to one member is logged and does not stop the others.
        """
        message = {
            "type": "lobby_state",
            "data": {"participants": self._participants()},
        }
        if requesting_session:
            logger.info(
                f"{requesting_session.getName()} -> lobby_state({self.getName()})"
            )
            if requesting_session.ws:
                await requesting_session.ws.send_json(message)
            else:
                logger.warning(
                    f"{requesting_session.getName()} - No WebSocket connection."
                )
            return
        # Iterate a snapshot: members may join or leave while each send awaits.
        for s in list(self.sessions.values()):
            logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
            if not s.ws:
                continue
            try:
                await s.ws.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send lobby state to {s.getName()}: {e}")

    def getSession(self, id: str) -> Session | None:
        """Return the member session with *id*, or None."""
        return self.sessions.get(id, None)

    async def addSession(self, session: Session) -> None:
        """Add *session* as a member and broadcast the new state."""
        if session.id in self.sessions:
            logger.warning(f"{session.getName()} - Already in lobby {self.getName()}.")
            return None
        self.sessions[session.id] = session
        await self.update_state()

    async def removeSession(self, session: Session) -> None:
        """Remove *session* from the members and broadcast the new state."""
        if session.id not in self.sessions:
            logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.")
            return None
        del self.sessions[session.id]
        await self.update_state()

    def add_chat_message(self, session: Session, message: str) -> ChatMessageModel:
        """Record a chat message from *session* and return it.

        Only the newest MAX_CHAT_MESSAGES messages are kept.
        """
        import time

        chat_message = ChatMessageModel(
            id=secrets.token_hex(8),
            message=message,
            sender_name=session.name or session.short,
            sender_session_id=session.id,
            timestamp=time.time(),
            lobby_id=self.id,
        )
        self.chat_messages.append(chat_message)
        if len(self.chat_messages) > self.MAX_CHAT_MESSAGES:
            self.chat_messages = self.chat_messages[-self.MAX_CHAT_MESSAGES :]
        return chat_message

    def get_chat_messages(self, limit: int = 50) -> list[ChatMessageModel]:
        """Return up to *limit* most recent messages, oldest first.

        A non-positive *limit* yields an empty list. Without this check,
        ``xs[-0:]`` would return the entire history.
        """
        if limit <= 0:
            return []
        return self.chat_messages[-limit:]

    async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
        """Send *chat_message* to every member with a WebSocket (best effort)."""
        # Snapshot members: the dict may change while a send is awaiting.
        for session in list(self.sessions.values()):
            if not session.ws:
                continue
            try:
                await session.ws.send_json(
                    {"type": "chat_message", "data": chat_message.model_dump()}
                )
            except Exception as e:
                logger.warning(
                    f"Failed to send chat message to {session.getName()}: {e}"
                )
class Session:
    """A browser session (the ``session_id`` cookie) bound to one user name.

    Every instance is registered in ``Session._instances``. ``save()`` writes
    all of them, together with the module-level ``name_passwords`` store, to
    ``Session._save_file`` as JSON, and ``load()`` restores them.
    """

    _instances: list[Session] = []  # every registered Session
    _save_file = "sessions.json"
    _loaded = False  # True once getSession() has run load()

    def __init__(self, id: str, *, persist: bool = True):
        """Create and register a session.

        Args:
            id: Session id (the cookie value).
            persist: When True (default), the whole store is saved at once.
                ``load()`` passes False so it does not rewrite the file with
                half-initialised sessions while it is still reading.
        """
        import time

        logger.info(f"Instantiating new session {id}")
        self._instances.append(self)
        self.id = id
        self.short = id[:8]
        self.name = ""
        self.lobbies: list[Lobby] = []  # Lobby objects this session belongs to
        self.lobby_peers: dict[str, list[str]] = {}  # lobby ID -> peer session IDs
        self.ws: WebSocket | None = None
        now = time.time()
        self.created_at = now
        self.last_used = now
        self.displaced_at: float | None = None  # When name was taken over
        if persist:
            self.save()

    @classmethod
    def save(cls):
        """Write all sessions and name passwords to ``_save_file``.

        The data goes to a temporary file that is then renamed over the
        target. A crash mid-write therefore never leaves a truncated file.
        """
        sessions_list: list[SessionSaved] = [
            SessionSaved(
                id=s.id,
                name=s.name or "",
                lobbies=[
                    LobbySaved(id=lobby.id, name=lobby.name, private=lobby.private)
                    for lobby in s.lobbies
                ],
                created_at=s.created_at,
                last_used=s.last_used,
                displaced_at=s.displaced_at,
            )
            for s in cls._instances
        ]
        # Only structured salt+hash records are persisted.
        saved_pw: dict[str, NamePasswordRecord] = {
            name: NamePasswordRecord(**record)
            for name, record in name_passwords.items()
        }
        payload = SessionsPayload(
            sessions=sessions_list, name_passwords=saved_pw
        ).model_dump()
        tmp_path = f"{cls._save_file}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        os.replace(tmp_path, cls._save_file)
        logger.info(
            f"Saved {len(sessions_list)} sessions and {len(saved_pw)} name passwords to {cls._save_file}"
        )

    @classmethod
    def load(cls):
        """Restore sessions, their lobbies and name passwords from ``_save_file``.

        A missing, unreadable or invalid file is logged and otherwise ignored.
        Lobbies are shared by id between ``lobbies`` and every session, and
        keep their saved privacy flag.
        """
        import time

        if not os.path.exists(cls._save_file):
            logger.info(f"No session save file found: {cls._save_file}")
            return
        try:
            with open(cls._save_file, "r", encoding="utf-8") as f:
                raw = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
            logger.exception(f"Failed to read sessions file {cls._save_file}: {e}")
            return
        try:
            payload = SessionsPayload.model_validate(raw)
        except ValidationError as e:
            logger.exception(f"Failed to validate sessions payload: {e}")
            return
        name_passwords.clear()
        for name, rec in payload.name_passwords.items():
            name_passwords[name] = {"salt": rec.salt, "hash": rec.hash}
        now = time.time()
        for s_saved in payload.sessions:
            session = Session(s_saved.id, persist=False)
            session.name = s_saved.name or ""
            # Fall back to "now" when a timestamp is absent or null.
            created_at = getattr(s_saved, "created_at", None)
            last_used = getattr(s_saved, "last_used", None)
            session.created_at = created_at if created_at is not None else now
            session.last_used = last_used if last_used is not None else now
            session.displaced_at = getattr(s_saved, "displaced_at", None)
            for lobby_saved in s_saved.lobbies:
                lobby = lobbies.get(lobby_saved.id)
                if lobby is None:
                    lobby = Lobby(
                        name=lobby_saved.name,
                        id=lobby_saved.id,
                        private=lobby_saved.private,
                    )
                    lobbies[lobby.id] = lobby
                if lobby not in session.lobbies:
                    session.lobbies.append(lobby)
            logger.info(
                f"Loaded session {session.getName()} with {len(session.lobbies)} lobbies"
            )
        logger.info(
            f"Loaded {len(payload.sessions)} sessions and {len(name_passwords)} name passwords from {cls._save_file}"
        )

    @classmethod
    def getSession(cls, id: str) -> Session | None:
        """Return the session with *id*, loading from disk on first use."""
        if not cls._loaded:
            cls.load()
            logger.info(f"Loaded {len(cls._instances)} sessions from disk...")
            cls._loaded = True
        for s in cls._instances:
            if s.id == id:
                return s
        return None

    @classmethod
    def isUniqueName(cls, name: str) -> bool:
        """True if *name* is non-empty and no session holds it (case-insensitive)."""
        if not name:
            return False
        lname = name.lower()
        return all(s.name.lower() != lname for s in cls._instances)

    @classmethod
    def getSessionByName(cls, name: str) -> Session | None:
        """Return the session holding *name* (case-insensitive), or None."""
        if not name:
            return None
        lname = name.lower()
        for s in cls._instances:
            if s.name and s.name.lower() == lname:
                return s
        return None

    def getName(self) -> str:
        """Return a ``<short-id>:<name>`` label for log messages."""
        return f"{self.short}:{self.name if self.name else unset_label}"

    def setName(self, name: str):
        """Set the user name, bump last_used and persist the store."""
        self.name = name
        self.update_last_used()
        self.save()

    def update_last_used(self):
        """Update the last_used timestamp (not persisted by itself)."""
        import time

        self.last_used = time.time()

    def mark_displaced(self):
        """Record that this session's name was taken over (not persisted by itself)."""
        import time

        self.displaced_at = time.time()

    @classmethod
    def cleanup_old_sessions(cls) -> int:
        """Remove stale sessions and return how many were removed.

        A session is removed when it has no WebSocket and either:
          1. has no name and was created more than 1 minute ago, or
          2. had its name taken over and was not used for 24+ hours.
        Removed sessions are also purged from lobby membership and from the
        remaining members' peer lists. The store is saved when anything is
        removed.
        """
        import time

        current_time = time.time()
        one_minute = 60.0
        twenty_four_hours = 24 * 60 * 60.0
        sessions_to_remove: list[Session] = []
        for session in list(cls._instances):
            if session.ws:
                continue
            if not session.name and current_time - session.created_at > one_minute:
                logger.info(
                    f"Removing session {session.getName()} - no connection, no name, older than 1 minute"
                )
                sessions_to_remove.append(session)
            elif (
                session.displaced_at is not None
                and current_time - session.last_used > twenty_four_hours
            ):
                logger.info(
                    f"Removing session {session.getName()} - displaced and unused for 24+ hours"
                )
                sessions_to_remove.append(session)

        sessions_removed = 0
        for session in sessions_to_remove:
            for lobby in list(session.lobbies):
                try:
                    lobby.sessions.pop(session.id, None)
                    session.lobby_peers.pop(lobby.id, None)
                    # Drop dangling references held by the remaining members.
                    for peer in lobby.sessions.values():
                        peer_list = peer.lobby_peers.get(lobby.id)
                        if peer_list:
                            peer_list[:] = [p for p in peer_list if p != session.id]
                except Exception as e:
                    logger.warning(
                        f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}"
                    )
            if session in cls._instances:
                cls._instances.remove(session)
                sessions_removed += 1

        if sessions_removed > 0:
            cls.save()
            logger.info(f"Session cleanup: removed {sessions_removed} old sessions")
        return sessions_removed

    async def join(self, lobby: Lobby):
        """Join *lobby*: pair with every live member as RTC peers, then register.

        Each existing member is told to ``addPeer`` this session without
        creating an offer. This session is told to ``addPeer`` each member and
        create the offer.
        """
        if not self.ws:
            logger.error(
                f"{self.getName()} - No WebSocket connection. Lobby not available."
            )
            return
        if lobby.id in self.lobby_peers or self.id in lobby.sessions:
            logger.info(f"{self.getName()} - Already joined to {lobby.getName()}.")
            data = JoinStatusModel(
                status="Joined", message=f"Already joined to lobby {lobby.getName()}"
            )
            await self.ws.send_json({"type": "join_status", "data": data.model_dump()})
            return
        if lobby not in self.lobbies:
            self.lobbies.append(lobby)
        self.lobby_peers[lobby.id] = []
        # Iterate a snapshot: dead members are deleted from lobby.sessions in
        # the loop, and other coroutines may join or leave during the awaits.
        for peer_id in list(lobby.sessions):
            if peer_id == self.id:
                raise Exception(
                    "Should not happen: self in lobby.sessions while not in lobby."
                )
            peer_session = lobby.getSession(peer_id)
            if not peer_session or not peer_session.ws:
                logger.warning(
                    f"{self.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}. Removing."
                )
                lobby.sessions.pop(peer_id, None)
                continue
            self.lobby_peers[lobby.id].append(peer_id)
            peer_session.lobby_peers.setdefault(lobby.id, []).append(self.id)
            logger.info(
                f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName(), lobby.getName()}, should_create_offer=False)"
            )
            await peer_session.ws.send_json(
                {
                    "type": "addPeer",
                    "data": {
                        "peer_id": self.id,
                        "peer_name": self.name,
                        "should_create_offer": False,
                    },
                }
            )
            logger.info(
                f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName(), lobby.getName()}, should_create_offer=True)"
            )
            await self.ws.send_json(
                {
                    "type": "addPeer",
                    "data": {
                        "peer_id": peer_session.id,
                        "peer_name": peer_session.name,
                        "should_create_offer": True,
                    },
                }
            )
        await lobby.addSession(self)
        Session.save()
        await self.ws.send_json({"type": "join_status", "data": {"status": "Joined"}})

    async def part(self, lobby: Lobby):
        """Leave *lobby*: unpair from every RTC peer (both directions), then deregister."""
        if lobby.id not in self.lobby_peers or self.id not in lobby.sessions:
            logger.info(
                f"{self.getName()} - Attempt to part non-joined lobby {lobby.getName()}."
            )
            if self.ws:
                await self.ws.send_json(
                    {"type": "error", "error": "Attempt to part non-joined lobby"}
                )
            return
        logger.info(f"{self.getName()} <- part({lobby.getName()}) - Lobby part.")
        lobby_peers = self.lobby_peers.pop(lobby.id)
        if lobby in self.lobbies:
            self.lobbies.remove(lobby)
        for peer_session_id in lobby_peers:
            peer_session = getSession(peer_session_id)
            if not peer_session:
                logger.warning(
                    f"{self.getName()} <- part({lobby.getName()}) - Peer session {peer_session_id} not found. Skipping."
                )
                continue
            # Remove ourselves from the peer's RTC list, so that later relays
            # to us are rejected and a rejoin does not create a duplicate entry.
            peer_list = peer_session.lobby_peers.get(lobby.id)
            if peer_list:
                peer_list[:] = [p for p in peer_list if p != self.id]
            if not peer_session.ws:
                logger.warning(
                    f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection for {peer_session.getName()}. Skipping."
                )
                continue
            logger.info(f"{peer_session.getName()} <- remove_peer({self.getName()})")
            await peer_session.ws.send_json(
                {
                    "type": "removePeer",
                    "data": {"peer_name": self.name, "peer_id": self.id},
                }
            )
            if not self.ws:
                logger.error(
                    f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection."
                )
                continue
            logger.info(f"{self.getName()} <- remove_peer({peer_session.getName()})")
            await self.ws.send_json(
                {
                    "type": "removePeer",
                    "data": {
                        "peer_name": peer_session.name,
                        "peer_id": peer_session.id,
                    },
                }
            )
        await lobby.removeSession(self)
        Session.save()
def getName(session: Session | None) -> str | None:
    """Return the session's user name, or None if there is no session or no name."""
    return session.name if session is not None and session.name else None


def getSession(session_id: str) -> Session | None:
    """Look up a session by id; thin wrapper over Session.getSession."""
    return Session.getSession(session_id)


def getLobby(lobby_id: str) -> Lobby:
    """Return the registered lobby *lobby_id*.

    Raises:
        Exception: if no lobby with that id is registered.
    """
    found = lobbies.get(lobby_id)
    if found is None:
        logger.error(f"Lobby not found: {lobby_id}")
        raise Exception(f"Lobby not found: {lobby_id}")
    return found


def getLobbyByName(lobby_name: str) -> Lobby | None:
    """Return the first registered lobby named exactly *lobby_name*, or None."""
    return next(
        (candidate for candidate in lobbies.values() if candidate.name == lobby_name),
        None,
    )
# API endpoints
@app.get(f"{public_url}api/health", response_model=HealthResponse)
def health():
    """Liveness probe; always answers ``{"status": "ok"}``."""
    logger.info("Health check endpoint called.")
    return dict(status="ok")
# A session (cookie) is bound to a single user (name).
# A user can be in multiple lobbies, but a session is unique to a single user.
# A user can change their name, but the session ID remains the same and the name
# updates for all lobbies.
@app.get(f"{public_url}api/session", response_model=SessionResponse)
async def session(
    request: Request, response: Response, session_id: str | None = Cookie(default=None)
) -> Response | SessionResponse:
    """Create or resume the caller's session and report its name and lobbies.

    Without a ``session_id`` cookie, a fresh random id is issued and set as a
    cookie. A malformed id (anything other than 32 lowercase hex characters)
    gets HTTP 400. A resumed session is parted from every lobby it is still
    registered in.
    """
    if session_id is None:
        session_id = secrets.token_hex(16)
        response.set_cookie(key="session_id", value=session_id)
    elif len(session_id) != 32 or not all(c in "0123456789abcdef" for c in session_id):
        return Response(
            content=json.dumps({"error": "Invalid session_id"}),
            status_code=400,
            media_type="application/json",
        )
    logger.info(f"[{session_id[:8]}]: Browser hand-shake achieved.")
    session = getSession(session_id)
    if not session:
        session = Session(session_id)
        logger.info(f"{session.getName()}: New session created.")
    else:
        session.update_last_used()  # Update activity on session resumption
        logger.info(f"{session.getName()}: Existing session resumed.")
        # Part every lobby the resumed session is still registered in,
        # whether or not it currently has a WebSocket.
        for lobby_id in list(session.lobby_peers.keys()):
            try:
                lobby = getLobby(lobby_id)
            except Exception as e:
                logger.error(
                    f"{session.getName()} - Error getting lobby {lobby_id}: {e}"
                )
                continue
            await session.part(lobby)
    return SessionResponse(
        id=session_id,
        name=session.name if session.name else "",
        lobbies=[
            LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private)
            for lobby in session.lobbies
        ],
    )
@app.get(public_url + "api/lobby", response_model=LobbiesResponse)
async def get_lobbies(request: Request, response: Response) -> LobbiesResponse:
    """List every non-private lobby by id and name."""
    public_lobbies = filter(lambda entry: not entry.private, lobbies.values())
    items = [LobbyListItem(id=entry.id, name=entry.name) for entry in public_lobbies]
    return LobbiesResponse(lobbies=items)
@app.post(public_url + "api/lobby/{session_id}", response_model=LobbyCreateResponse)
async def lobby_create(
    request: Request,
    response: Response,
    session_id: str = Path(...),
    create_request: LobbyCreateRequest = Body(...),
) -> Response | LobbyCreateResponse:
    """Return the lobby named ``create_request.data.name``, creating it if needed.

    Errors: 400 for a request type other than ``lobby_create``, and 404 for
    an unknown session. An existing lobby with the same name is returned
    unchanged; its privacy flag is not updated.
    """
    if create_request.type != "lobby_create":
        # A bare dict would not match response_model=LobbyCreateResponse, so
        # return an explicit JSON 400 instead.
        return Response(
            content=json.dumps({"error": "Invalid request type"}),
            status_code=400,
            media_type="application/json",
        )
    data = create_request.data
    session = getSession(session_id)
    if not session:
        return Response(
            content=json.dumps({"error": f"Session not found ({session_id})"}),
            status_code=404,
            media_type="application/json",
        )
    logger.info(
        f"{session.getName()} lobby_create: {data.name} (private={data.private})"
    )
    lobby = getLobbyByName(data.name)
    if not lobby:
        lobby = Lobby(
            data.name,
            private=data.private,
        )
        lobbies[lobby.id] = lobby
    logger.info(f"{session.getName()} <- lobby_create({lobby.short}:{lobby.name})")
    return LobbyCreateResponse(
        type="lobby_created",
        data=LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private),
    )
@app.get(public_url + "api/lobby/{lobby_id}/chat", response_model=ChatMessagesResponse)
async def get_chat_messages(
    request: Request,
    lobby_id: str = Path(...),
    limit: int = 50,
) -> Response | ChatMessagesResponse:
    """Get up to *limit* most recent chat messages for a lobby.

    Returns 404 for an unknown lobby. A non-positive *limit* yields no
    messages rather than the whole history or an arbitrary slice.
    """
    try:
        lobby = getLobby(lobby_id)
    except Exception as e:
        return Response(
            content=json.dumps({"error": str(e)}),
            status_code=404,
            media_type="application/json",
        )
    if limit <= 0:
        return ChatMessagesResponse(messages=[])
    return ChatMessagesResponse(messages=lobby.get_chat_messages(limit))
# Register websocket endpoint directly on app with full public_url path
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
async def lobby_join(
websocket: WebSocket,
lobby_id: str | None = Path(...),
session_id: str | None = Path(...),
):
await websocket.accept()
if lobby_id is None:
await websocket.send_json(
{"type": "error", "error": "Invalid or missing lobby"}
)
await websocket.close()
return
if session_id is None:
await websocket.send_json(
{"type": "error", "error": "Invalid or missing session"}
)
await websocket.close()
return
session = getSession(session_id)
if not session:
# logger.error(f"Invalid session ID {session_id}")
await websocket.send_json(
{"type": "error", "error": f"Invalid session ID {session_id}"}
)
await websocket.close()
return
lobby = None
try:
lobby = getLobby(lobby_id)
except Exception as e:
await websocket.send_json({"type": "error", "error": str(e)})
await websocket.close()
return
logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})")
session.ws = websocket
session.update_last_used() # Update activity timestamp
if session.id in lobby.sessions:
logger.info(
f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining."
)
await session.part(lobby)
await lobby.removeSession(session)
for peer_id in lobby.sessions:
peer_session = lobby.getSession(peer_id)
if not peer_session or not peer_session.ws:
logger.warning(
f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}. Removing."
)
del lobby.sessions[peer_id]
continue
logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})")
await peer_session.ws.send_json(
{
"type": "user_joined",
"data": {
"session_id": session.id,
"name": session.name,
},
}
)
try:
while True:
packet = await websocket.receive_json()
session.update_last_used() # Update activity on each message
type = packet.get("type", None)
data: dict[str, Any] | None = packet.get("data", None)
if not type:
logger.error(f"{session.getName()} - Invalid request: {packet}")
await websocket.send_json({"type": "error", "error": "Invalid request"})
continue
# logger.info(f"{session.getName()} <- RAW Rx: {data}")
match type:
case "set_name":
if not data:
logger.error(f"{session.getName()} - set_name missing data")
await websocket.send_json(
{"type": "error", "error": "set_name missing data"}
)
continue
name = data.get("name")
password = data.get("password")
logger.info(f"{session.getName()} <- set_name({name})")
if not name:
logger.error(f"{session.getName()} - Name required")
await websocket.send_json(
{"type": "error", "error": "Name required"}
)
continue
# Name takeover / password logic
lname = name.lower()
# If name is unused, allow and optionally save password
if Session.isUniqueName(name):
# If a password was provided, save it (hash+salt) for this name
if password:
salt, hash_hex = _hash_password(password)
name_passwords[lname] = {"salt": salt, "hash": hash_hex}
session.setName(name)
logger.info(f"{session.getName()}: -> update('name', {name})")
await websocket.send_json(
{
"type": "update_name",
"data": {
"name": name,
"protected": True
if name.lower() in name_passwords
else False,
},
}
)
# For any clients in any lobby with this session, update their user lists
await lobby.update_state()
continue
# Name is taken. Check if a password exists for the name and matches.
saved_pw = name_passwords.get(lname)
if not saved_pw:
logger.warning(
f"{session.getName()} - Name already taken (no password set)"
)
await websocket.send_json(
{"type": "error", "error": "Name already taken"}
)
continue
# Expect structured record with salt+hash only
match_password = False
# saved_pw should be a dict[str,str] with 'salt' and 'hash'
salt = saved_pw.get("salt")
_, candidate_hash = _hash_password(
password if password else "", salt_hex=salt
)
if candidate_hash == saved_pw.get("hash"):
match_password = True
else:
# No structured password record available
match_password = False
if not match_password:
logger.warning(
f"{session.getName()} - Name takeover attempted with wrong or missing password"
)
await websocket.send_json(
{
"type": "error",
"error": "Invalid password for name takeover",
}
)
continue
# Password matches: perform takeover. Find the current session holding the name.
# Find the currently existing session (if any) with that name
displaced = Session.getSessionByName(name)
if displaced and displaced.id == session.id:
displaced = None
# If found, change displaced session to a unique fallback name and notify peers
if displaced:
# Create a unique fallback name
fallback = f"{displaced.name}-{displaced.short}"
# Ensure uniqueness
if not Session.isUniqueName(fallback):
# append random suffix until unique
while not Session.isUniqueName(fallback):
fallback = f"{displaced.name}-{secrets.token_hex(3)}"
displaced.setName(fallback)
displaced.mark_displaced()
logger.info(
f"{displaced.getName()} <- displaced by takeover, new name {fallback}"
)
# Notify displaced session (if connected)
if displaced.ws:
try:
await displaced.ws.send_json(
{
"type": "update_name",
"data": {
"name": fallback,
"protected": False,
},
}
)
except Exception:
logger.exception(
"Failed to notify displaced session websocket"
)
# Update all lobbies the displaced session was in
for d_lobby in list(displaced.lobbies):
try:
await d_lobby.update_state()
except Exception:
logger.exception(
"Failed to update lobby state for displaced session"
)
# Now assign the requested name to the current session
session.setName(name)
logger.info(
f"{session.getName()}: -> update('name', {name}) (takeover)"
)
await websocket.send_json(
{
"type": "update_name",
"data": {
"name": name,
"protected": True
if name.lower() in name_passwords
else False,
},
}
)
# Notify lobbies for this session
await lobby.update_state()
case "list_users":
await lobby.update_state(session)
case "get_chat_messages":
# Send recent chat messages to the requesting client
messages = lobby.get_chat_messages(50)
await websocket.send_json(
{
"type": "chat_messages",
"data": {
"messages": [msg.model_dump() for msg in messages]
},
}
)
case "send_chat_message":
if not data or "message" not in data:
logger.error(
f"{session.getName()} - send_chat_message missing message"
)
await websocket.send_json(
{
"type": "error",
"error": "send_chat_message missing message",
}
)
continue
if not session.name:
logger.error(
f"{session.getName()} - Cannot send chat message without name"
)
await websocket.send_json(
{
"type": "error",
"error": "Must set name before sending chat messages",
}
)
continue
message_text = str(data["message"]).strip()
if not message_text:
continue
# Add the message to the lobby and broadcast it
chat_message = lobby.add_chat_message(session, message_text)
await lobby.broadcast_chat_message(chat_message)
logger.info(
f"{session.getName()} sent chat message to {lobby.getName()}: {message_text[:50]}..."
)
case "join":
logger.info(f"{session.getName()} <- join({lobby.getName()})")
await session.join(lobby=lobby)
case "part":
logger.info(f"{session.getName()} <- part {lobby.getName()}")
await session.part(lobby=lobby)
case "relayICECandidate":
logger.info(f"{session.getName()} <- relayICECandidate")
if not data:
logger.error(
f"{session.getName()} - relayICECandidate missing data"
)
await websocket.send_json(
{"type": "error", "error": "relayICECandidate missing data"}
)
continue
if (
lobby.id not in session.lobby_peers
or session.id not in lobby.sessions
):
logger.error(
f"{session.short}:{session.name} <- relayICECandidate - Not an RTC peer ({session.id})"
)
await websocket.send_json(
{"type": "error", "error": "Not joined to lobby"}
)
continue
session_peers = session.lobby_peers[lobby.id]
peer_id = data.get("peer_id")
if peer_id not in session_peers:
logger.error(
f"{session.getName()} <- relayICECandidate - Not an RTC peer({peer_id}) in {session_peers}"
)
await websocket.send_json(
{
"type": "error",
"error": f"Target peer {peer_id} not found",
}
)
continue
candidate = data.get("candidate")
message: dict[str, Any] = {
"type": "iceCandidate",
"data": {
"peer_id": session.id,
"peer_name": session.name,
"candidate": candidate,
},
}
peer_session = lobby.getSession(peer_id)
if not peer_session or not peer_session.ws:
logger.warning(
f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}."
)
break
logger.info(
f"{session.getName()} -> iceCandidate({peer_session.getName()})"
)
await peer_session.ws.send_json(message)
case "relaySessionDescription":
logger.info(f"{session.getName()} <- relaySessionDescription")
if not data:
logger.error(
f"{session.getName()} - relaySessionDescription missing data"
)
await websocket.send_json(
{
"type": "error",
"error": "relaySessionDescription missing data",
}
)
continue
if (
lobby.id not in session.lobby_peers
or session.id not in lobby.sessions
):
logger.error(
f"{session.short}:{session.name} <- relaySessionDescription - Not an RTC peer ({session.id})"
)
await websocket.send_json(
{"type": "error", "error": "Not joined to lobby"}
)
continue
lobby_peers = session.lobby_peers[lobby.id]
peer_id = data.get("peer_id")
if peer_id not in lobby_peers:
logger.error(
f"{session.getName()} <- relaySessionDescription - Not an RTC peer({peer_id}) in {lobby_peers}"
)
await websocket.send_json(
{
"type": "error",
"error": f"Target peer {peer_id} not found",
}
)
continue
peer_id = data.get("peer_id", None)
if not peer_id:
logger.error(
f"{session.getName()} - relaySessionDescription missing peer_id"
)
await websocket.send_json(
{
"type": "error",
"error": "relaySessionDescription missing peer_id",
}
)
continue
peer_session = lobby.getSession(peer_id)
if not peer_session or not peer_session.ws:
logger.warning(
f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}."
)
break
session_description = data.get("session_description")
message = {
"type": "sessionDescription",
"data": {
"peer_id": session.id,
"peer_name": session.name,
"session_description": session_description,
},
}
logger.info(
f"{session.getName()} -> sessionDescription({peer_session.getName()})"
)
await peer_session.ws.send_json(message)
case _:
await websocket.send_json(
{
"type": "error",
"error": f"Unknown request type: {type}",
}
)
except WebSocketDisconnect:
logger.info(f"{session.getName()} <- WebSocket disconnected for user.")
# Cleanup: remove session from lobby and sessions dict
session.ws = None
if session.id in lobby.sessions:
await session.part(lobby)
await lobby.update_state()
# Clean up empty lobbies
if not lobby.sessions:
if lobby.id in lobbies:
del lobbies[lobby.id]
logger.info(f"Cleaned up empty lobby {lobby.getName()}")
# Serve static files or proxy to frontend development server
PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true"
# Directory holding the built frontend. Defaults to the absolute path
# /client/build. Joining an absolute path onto dirname(__file__) discards
# the dirname, so that has always been the effective value.
# CLIENT_BUILD_PATH overrides it.
client_build_path = os.getenv("CLIENT_BUILD_PATH", "/client/build")

if PRODUCTION:
    logger.info(f"Serving static files from: {client_build_path} at {public_url}")
    app.mount(
        public_url, StaticFiles(directory=client_build_path, html=True), name="static"
    )
else:
    logger.info(f"Proxying static files to http://client:3000 at {public_url}")
    import ssl

    import websockets
    from starlette.websockets import WebSocket as StarletteWebSocket

    def _dev_server_url(scheme: str, path: str, query: str) -> str:
        """Build the dev-server URL for *path* under public_url, plus *query*.

        Empty segments are skipped so a public_url of "/" does not yield "//".
        """
        segments = [seg for seg in (public_url.strip("/"), path) if seg]
        url = f"{scheme}://client:3000/" + "/".join(segments)
        return f"{url}?{query}" if query else url

    @app.api_route(
        f"{public_url}{{path:path}}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_static(request: Request, path: str):
        """Forward a non-API request (including its query string) to client:3000."""
        # Do not proxy API or websocket paths
        if path.startswith(("api/", "ws/")):
            return Response(status_code=404)
        url = _dev_server_url(request.url.scheme, path, request.url.query)
        headers = dict(request.headers)
        try:
            # Accept self-signed certs in dev
            async with httpx.AsyncClient(verify=False) as client:
                proxy_req = client.build_request(
                    request.method, url, headers=headers, content=await request.body()
                )
                proxy_resp = await client.send(proxy_req)
            # httpx has already decoded the body, so drop headers that would
            # now misdescribe it.
            filtered_headers = {
                k: v
                for k, v in proxy_resp.headers.items()
                if k.lower()
                not in ["content-encoding", "transfer-encoding", "content-length"]
            }
            return Response(
                content=proxy_resp.content,
                status_code=proxy_resp.status_code,
                headers=filtered_headers,
            )
        except Exception as e:
            logger.error(f"Proxy error for {url}: {e}")
            return Response("Proxy error", status_code=502)

    # WebSocket proxy for /ws (for React DevTools, etc.)
    @app.websocket("/ws")
    async def websocket_proxy(websocket: StarletteWebSocket):
        """Pump text/bytes frames between the browser and the dev server's /ws."""
        logger.info("REACT: WebSocket proxy connection established.")
        # Get scheme from websocket.url (should be 'ws' or 'wss')
        scheme = websocket.url.scheme if hasattr(websocket, "url") else "ws"
        target_url = f"{scheme}://client:3000/ws"
        await websocket.accept()
        # websockets.connect() rejects an ssl= argument for ws:// URIs, so a
        # (non-verifying, dev-only) TLS context is only passed for wss://.
        connect_kwargs: dict[str, Any] = {}
        if scheme == "wss":
            ssl_ctx = ssl.create_default_context()
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
            connect_kwargs["ssl"] = ssl_ctx
        try:
            async with websockets.connect(target_url, **connect_kwargs) as target_ws:

                async def client_to_server():
                    while True:
                        msg = await websocket.receive_text()
                        await target_ws.send(msg)

                async def server_to_client():
                    while True:
                        msg = await target_ws.recv()
                        if isinstance(msg, str):
                            await websocket.send_text(msg)
                        else:
                            await websocket.send_bytes(msg)

                # Stop both pumps as soon as either side finishes, so that no
                # orphaned task is left running or raises unobserved.
                tasks = {
                    asyncio.create_task(client_to_server()),
                    asyncio.create_task(server_to_client()),
                }
                done, pending = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
                for task in pending:
                    task.cancel()
                await asyncio.gather(*pending, return_exceptions=True)
                for task in done:
                    exc = task.exception()
                    if exc is not None and not isinstance(
                        exc, (WebSocketDisconnect, websockets.ConnectionClosed)
                    ):
                        raise exc
                logger.info("REACT: WebSocket proxy connection closed.")
        except Exception as e:
            logger.error(f"REACT: WebSocket proxy error: {e}")
        try:
            await websocket.close()
        except (RuntimeError, OSError, WebSocketDisconnect):
            pass  # best effort: the client side is already closed