1316 lines
49 KiB
Python
1316 lines
49 KiB
Python
from __future__ import annotations
|
|
from typing import Any, Optional, TypedDict
|
|
from fastapi import (
|
|
Body,
|
|
Cookie,
|
|
FastAPI,
|
|
Path,
|
|
WebSocket,
|
|
Request,
|
|
Response,
|
|
WebSocketDisconnect,
|
|
)
|
|
import secrets
|
|
import os
|
|
import json
|
|
import hashlib
|
|
import binascii
|
|
import sys
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
import httpx
|
|
from pydantic import ValidationError
|
|
from logger import logger
|
|
|
|
# Import shared models
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
from shared.models import (
|
|
HealthResponse,
|
|
LobbiesResponse,
|
|
LobbyCreateRequest,
|
|
LobbyCreateResponse,
|
|
LobbyListItem,
|
|
LobbyModel,
|
|
NamePasswordRecord,
|
|
LobbySaved,
|
|
SessionResponse,
|
|
SessionSaved,
|
|
SessionsPayload,
|
|
AdminNamesResponse,
|
|
AdminActionResponse,
|
|
AdminSetPassword,
|
|
AdminClearPassword,
|
|
JoinStatusModel,
|
|
ChatMessageModel,
|
|
ChatMessagesResponse,
|
|
)
|
|
|
|
|
|
# Reserved nicknames mapped to their password records. Keys are lowercased
# names; each value is a dict holding the "salt" and "hash" hex strings
# produced by _hash_password().
name_passwords: dict[str, dict[str, str]] = dict()

# Bracketed fixed-width labels. In the visible code only unset_label is
# referenced (Session.getName uses it when a session has no name yet).
all_label = "[ all ]"
info_label = "[ info ]"
todo_label = "[ todo ]"
unset_label = "[ ---- ]"
|
|
|
|
|
|
def _hash_password(password: str, salt_hex: str | None = None) -> tuple[str, str]:
|
|
"""Return (salt_hex, hash_hex) for the given password. If salt_hex is provided
|
|
it is used; otherwise a new salt is generated."""
|
|
if salt_hex:
|
|
salt = binascii.unhexlify(salt_hex)
|
|
else:
|
|
salt = secrets.token_bytes(16)
|
|
salt_hex = binascii.hexlify(salt).decode()
|
|
dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100000)
|
|
hash_hex = binascii.hexlify(dk).decode()
|
|
return salt_hex, hash_hex
|
|
|
|
|
|
# Base path every route is mounted under; normalized to always end with "/".
public_url = os.getenv("PUBLIC_URL", "/")
public_url = public_url if public_url.endswith("/") else f"{public_url}/"

# Background session-cleanup state shared by periodic_cleanup() and
# lifespan(): the flag keeps the loop running, the handle lets shutdown
# cancel the task.
cleanup_task_running = False
cleanup_task = None
|
|
|
|
|
|
async def periodic_cleanup():
    """Call Session.cleanup_old_sessions() repeatedly while cleanup_task_running is set.

    After a successful pass the loop waits 5 minutes; after a failure the
    error is logged and the next attempt comes after 1 minute.
    """
    global cleanup_task_running
    while cleanup_task_running:
        delay_seconds = 300  # normal cadence: every 5 minutes
        try:
            Session.cleanup_old_sessions()
        except Exception as e:
            logger.error(f"Error in session cleanup task: {e}")
            delay_seconds = 60  # retry sooner after an error
        await asyncio.sleep(delay_seconds)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the session-cleanup background task for the application's lifetime.

    Startup creates the periodic_cleanup() task. Shutdown clears its loop
    flag, cancels the task and waits for it to finish unwinding.
    """
    global cleanup_task_running, cleanup_task

    # -- startup --
    cleanup_task_running = True
    cleanup_task = asyncio.create_task(periodic_cleanup())
    logger.info("Session cleanup task started")

    yield  # the application serves requests while suspended here

    # -- shutdown --
    cleanup_task_running = False
    if not cleanup_task:
        return
    cleanup_task.cancel()
    try:
        await cleanup_task
    except asyncio.CancelledError:
        pass  # expected: we just cancelled it
    logger.info("Session cleanup task stopped")
|
|
|
|
|
|
# Optional shared secret for the admin endpoints. When it is unset (the
# default) _require_admin() admits every request.
ADMIN_TOKEN: str | None = os.environ.get("ADMIN_TOKEN")

app = FastAPI(lifespan=lifespan)
logger.info(f"Starting server with public URL: {public_url}")
|
|
|
|
|
|
def _require_admin(request: Request) -> bool:
    """Return True if *request* is allowed to use the admin endpoints.

    When ADMIN_TOKEN is not configured the admin API is open and every
    request is allowed. Otherwise the request must carry an
    ``X-Admin-Token`` header equal to ADMIN_TOKEN.
    """
    if not ADMIN_TOKEN:
        return True
    token = request.headers.get("X-Admin-Token")
    if token is None:
        return False
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a guess were correct. Compare bytes because
    # compare_digest() rejects str arguments containing non-ASCII characters.
    return secrets.compare_digest(token.encode("utf-8"), ADMIN_TOKEN.encode("utf-8"))
|
|
|
|
|
|
@app.get(public_url + "api/admin/names", response_model=AdminNamesResponse)
def admin_list_names(request: Request):
    """Return every protected name with its stored salt/hash record (admin only).

    Responds with 403 when the admin-token check fails.
    """
    allowed = _require_admin(request)
    if allowed:
        return {"name_passwords": name_passwords}
    return Response(status_code=403)
|
|
|
|
|
|
@app.post(public_url + "api/admin/set_password", response_model=AdminActionResponse)
def admin_set_password(request: Request, payload: AdminSetPassword = Body(...)):
    """Protect ``payload.name`` with ``payload.password`` (admin only).

    A freshly salted hash is stored under the lowercased name, replacing any
    existing record, and the store is persisted with Session.save().
    Responds with 403 when the admin-token check fails.
    """
    if not _require_admin(request):
        return Response(status_code=403)
    key = payload.name.lower()
    salt_hex, digest_hex = _hash_password(payload.password)
    name_passwords[key] = {"salt": salt_hex, "hash": digest_hex}
    Session.save()
    return {"status": "ok", "name": payload.name}
|
|
|
|
|
|
@app.post(public_url + "api/admin/clear_password", response_model=AdminActionResponse)
def admin_clear_password(request: Request, payload: AdminClearPassword = Body(...)):
    """Remove the password protection from ``payload.name`` (admin only).

    Returns status "not_found" when the name had no password, otherwise
    deletes the record, persists via Session.save() and returns "ok".
    Responds with 403 when the admin-token check fails.
    """
    if not _require_admin(request):
        return Response(status_code=403)
    key = payload.name.lower()
    if key not in name_passwords:
        return {"status": "not_found", "name": payload.name}
    name_passwords.pop(key)
    Session.save()
    return {"status": "ok", "name": payload.name}
|
|
|
|
|
|
@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse)
def admin_cleanup_sessions(request: Request):
    """Run Session.cleanup_old_sessions() on demand (admin only).

    The number of removed sessions, or the error text on failure, is
    reported in the ``name`` field of the response. Responds with 403 when
    the admin-token check fails.
    """
    if not _require_admin(request):
        return Response(status_code=403)
    try:
        removed_count = Session.cleanup_old_sessions()
    except Exception as e:
        logger.error(f"Error during manual session cleanup: {e}")
        return {"status": "not_found", "name": f"Error: {str(e)}"}
    return {"status": "ok", "name": f"Removed {removed_count} sessions"}
|
|
|
|
|
|
# Registry of every known lobby, keyed by Lobby.id.
lobbies: dict[str, Lobby] = dict()
|
|
|
|
|
|
class LobbyResponse(TypedDict):
    """Plain-dict description of a lobby: id, display name and privacy flag."""

    id: str  # mirrors Lobby.id
    name: str  # mirrors Lobby.name
    private: bool  # mirrors Lobby.private
|
|
|
|
|
|
class Lobby:
|
|
def __init__(self, name: str, id: str | None = None, private: bool = False):
|
|
self.id = secrets.token_hex(16) if id is None else id
|
|
self.short = self.id[:8]
|
|
self.name = name
|
|
self.sessions: dict[str, Session] = {} # All lobby members
|
|
self.private = private
|
|
self.chat_messages: list[ChatMessageModel] = [] # Store chat messages
|
|
|
|
def getName(self) -> str:
|
|
return f"{self.short}:{self.name}"
|
|
|
|
async def update_state(self, requesting_session: Session | None = None):
|
|
users: list[dict[str, str | bool]] = [
|
|
{
|
|
"name": s.name,
|
|
"live": True if s.ws else False,
|
|
"session_id": s.id,
|
|
"protected": True
|
|
if s.name and s.name.lower() in name_passwords
|
|
else False,
|
|
}
|
|
for s in self.sessions.values()
|
|
if s.name
|
|
]
|
|
if requesting_session:
|
|
logger.info(
|
|
f"{requesting_session.getName()} -> lobby_state({self.getName()})"
|
|
)
|
|
if requesting_session.ws:
|
|
await requesting_session.ws.send_json(
|
|
{"type": "lobby_state", "data": {"participants": users}}
|
|
)
|
|
else:
|
|
logger.warning(
|
|
f"{requesting_session.getName()} - No WebSocket connection."
|
|
)
|
|
else:
|
|
for s in self.sessions.values():
|
|
logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
|
|
if s.ws:
|
|
await s.ws.send_json(
|
|
{"type": "lobby_state", "data": {"participants": users}}
|
|
)
|
|
|
|
def getSession(self, id: str) -> Session | None:
|
|
return self.sessions.get(id, None)
|
|
|
|
async def addSession(self, session: Session) -> None:
|
|
if session.id in self.sessions:
|
|
logger.warning(f"{session.getName()} - Already in lobby {self.getName()}.")
|
|
return None
|
|
self.sessions[session.id] = session
|
|
await self.update_state()
|
|
|
|
async def removeSession(self, session: Session) -> None:
|
|
if session.id not in self.sessions:
|
|
logger.warning(f"{session.getName()} - Not in lobby {self.getName()}.")
|
|
return None
|
|
del self.sessions[session.id]
|
|
await self.update_state()
|
|
|
|
def add_chat_message(self, session: Session, message: str) -> ChatMessageModel:
|
|
"""Add a chat message to the lobby and return the message data"""
|
|
import time
|
|
|
|
chat_message = ChatMessageModel(
|
|
id=secrets.token_hex(8),
|
|
message=message,
|
|
sender_name=session.name or session.short,
|
|
sender_session_id=session.id,
|
|
timestamp=time.time(),
|
|
lobby_id=self.id,
|
|
)
|
|
self.chat_messages.append(chat_message)
|
|
# Keep only the latest 100 messages per lobby
|
|
if len(self.chat_messages) > 100:
|
|
self.chat_messages = self.chat_messages[-100:]
|
|
return chat_message
|
|
|
|
def get_chat_messages(self, limit: int = 50) -> list[ChatMessageModel]:
|
|
"""Get the most recent chat messages from the lobby"""
|
|
return self.chat_messages[-limit:] if self.chat_messages else []
|
|
|
|
async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
|
|
"""Broadcast a chat message to all connected sessions in the lobby"""
|
|
for session in self.sessions.values():
|
|
if session.ws:
|
|
try:
|
|
await session.ws.send_json(
|
|
{"type": "chat_message", "data": chat_message.model_dump()}
|
|
)
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"Failed to send chat message to {session.getName()}: {e}"
|
|
)
|
|
|
|
|
|
class Session:
    """A browser session (the ``session_id`` cookie) bound to at most one user name.

    Every instance is registered in ``Session._instances``. The full set,
    together with the module-level ``name_passwords`` store, is persisted to
    ``Session._save_file`` as JSON and read back lazily on first use.
    """

    _instances: "list[Session]" = []  # every registered Session, in creation order
    _save_file = "sessions.json"
    _loaded = False  # True once the save file has been (or is being) read

    def __init__(self, id: str, *, persist: bool = True):
        """Create a session with the given id and register it.

        Args:
            id: Session identifier (the ``session_id`` cookie value).
            persist: Save the full session list to disk immediately (the
                default). load() passes False so that restoring N sessions
                does not rewrite the file N times with half-restored records.
        """
        import time

        # Read the persisted sessions before registering a new one, so the
        # save() below cannot overwrite the file with only in-memory state.
        self._ensure_loaded()
        logger.info(f"Instantiating new session {id}")
        self._instances.append(self)
        self.id = id
        self.short = id[:8]  # abbreviated id used in log output
        self.name = ""
        self.lobbies: list[Lobby] = []  # Lobby objects this session is in
        # lobby ID -> list of peer session IDs (RTC peers inside that lobby)
        self.lobby_peers: dict[str, list[str]] = {}
        self.ws: WebSocket | None = None
        self.created_at = time.time()
        self.last_used = time.time()
        self.displaced_at: float | None = None  # When name was taken over
        if persist:
            self.save()

    @classmethod
    def _ensure_loaded(cls) -> None:
        """Read the save file exactly once, before in-memory state is first used."""
        if cls._loaded:
            return
        # Set the flag first: load() constructs Session objects, which must
        # not trigger a nested load.
        cls._loaded = True
        cls.load()
        logger.info(f"Loaded {len(cls._instances)} sessions from disk...")

    @classmethod
    def save(cls):
        """Persist all sessions and name-password records to ``_save_file``.

        The JSON is written to a sibling ``.tmp`` file that is then moved over
        the real file with os.replace(), so an interrupted write leaves the
        previous file intact instead of a truncated one.
        """
        # Never overwrite the save file before its contents have been read.
        cls._ensure_loaded()

        sessions_list: list[SessionSaved] = [
            SessionSaved(
                id=s.id,
                name=s.name or "",
                lobbies=[
                    LobbySaved(id=lobby.id, name=lobby.name, private=lobby.private)
                    for lobby in s.lobbies
                ],
                created_at=s.created_at,
                last_used=s.last_used,
                displaced_at=s.displaced_at,
            )
            for s in cls._instances
        ]
        # Only structured salt+hash records are persisted.
        saved_pw: dict[str, NamePasswordRecord] = {
            name: NamePasswordRecord(**record)
            for name, record in name_passwords.items()
        }
        payload = SessionsPayload(
            sessions=sessions_list, name_passwords=saved_pw
        ).model_dump()

        tmp_path = f"{cls._save_file}.tmp"
        with open(tmp_path, "w") as f:
            json.dump(payload, f, indent=2)
        os.replace(tmp_path, cls._save_file)

        logger.info(
            f"Saved {len(sessions_list)} sessions and {len(saved_pw)} name passwords to {cls._save_file}"
        )

    @classmethod
    def load(cls):
        """Restore sessions, their lobbies and name passwords from ``_save_file``.

        A missing, unreadable or invalid file is logged and ignored. Each saved
        lobby is registered once in the module-level ``lobbies`` dict (an
        already-registered Lobby with the same id is reused, keeping its
        ``private`` flag), and every session references that same object.
        """
        import time

        if not os.path.exists(cls._save_file):
            logger.info(f"No session save file found: {cls._save_file}")
            return
        try:
            with open(cls._save_file, "r") as f:
                raw = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            logger.exception(f"Failed to read sessions file {cls._save_file}: {e}")
            return

        try:
            payload = SessionsPayload.model_validate(raw)
        except ValidationError as e:
            logger.exception(f"Failed to validate sessions payload: {e}")
            return

        # Records already in memory win over the file: with lazy loading they
        # can only have been set (e.g. via the admin API) after it was written.
        for name, rec in payload.name_passwords.items():
            name_passwords.setdefault(name, {"salt": rec.salt, "hash": rec.hash})

        for s_saved in payload.sessions:
            session = Session(s_saved.id, persist=False)
            session.name = s_saved.name or ""
            # Load timestamps, with defaults for backward compatibility
            session.created_at = getattr(s_saved, "created_at", time.time())
            session.last_used = getattr(s_saved, "last_used", time.time())
            session.displaced_at = getattr(s_saved, "displaced_at", None)
            for lobby_saved in s_saved.lobbies:
                lobby = lobbies.get(lobby_saved.id)
                if lobby is None:
                    lobby = Lobby(
                        name=lobby_saved.name,
                        id=lobby_saved.id,
                        private=lobby_saved.private,
                    )
                    lobbies[lobby.id] = lobby
                session.lobbies.append(lobby)
            logger.info(
                f"Loaded session {session.getName()} with {len(session.lobbies)} lobbies"
            )

        logger.info(
            f"Loaded {len(payload.sessions)} sessions and {len(name_passwords)} name passwords from {cls._save_file}"
        )

    @classmethod
    def getSession(cls, id: str) -> "Session | None":
        """Return the registered session with *id*, loading from disk on first use."""
        cls._ensure_loaded()
        for s in cls._instances:
            if s.id == id:
                return s
        return None

    @classmethod
    def isUniqueName(cls, name: str) -> bool:
        """Return True if *name* is non-empty and no session uses it (case-insensitive)."""
        if not name:
            return False
        cls._ensure_loaded()
        lname = name.lower()
        return all(s.name.lower() != lname for s in cls._instances)

    @classmethod
    def getSessionByName(cls, name: str) -> Optional["Session"]:
        """Return the session whose name matches *name* case-insensitively, or None."""
        if not name:
            return None
        cls._ensure_loaded()
        lname = name.lower()
        for s in cls._instances:
            if s.name and s.name.lower() == lname:
                return s
        return None

    def getName(self) -> str:
        """Return ``"<short id>:<name>"`` (or the unset label) for log output."""
        return f"{self.short}:{self.name if self.name else unset_label}"

    def setName(self, name: str):
        """Set the user name, refresh last_used and persist all sessions."""
        self.name = name
        self.update_last_used()
        self.save()

    def update_last_used(self):
        """Update the last_used timestamp"""
        import time

        self.last_used = time.time()

    def mark_displaced(self):
        """Mark this session as having its name taken over"""
        import time

        self.displaced_at = time.time()

    @classmethod
    def cleanup_old_sessions(cls) -> int:
        """Remove stale sessions and return how many were removed.

        A session without a WebSocket is removed when either
          1. it has no name and was created more than 1 minute ago, or
          2. its name was taken over (``displaced_at`` set) and it has not
             been used for more than 24 hours.
        Removed sessions are detached from their lobbies; the file is
        re-saved only if something was removed.
        """
        import time

        cls._ensure_loaded()
        current_time = time.time()
        one_minute = 60.0
        twenty_four_hours = 24 * 60 * 60.0
        sessions_removed = 0

        # Collect first, remove afterwards, so _instances is not mutated mid-iteration.
        sessions_to_remove: list[Session] = []
        for session in cls._instances[:]:
            # Rule 1: no connection, no name, older than 1 minute
            if (
                not session.ws
                and not session.name
                and current_time - session.created_at > one_minute
            ):
                logger.info(
                    f"Removing session {session.getName()} - no connection, no name, older than 1 minute"
                )
                sessions_to_remove.append(session)
                continue

            # Rule 2: displaced (nick taken over) and unused for 24 hours
            if (
                not session.ws
                and session.displaced_at is not None
                and current_time - session.last_used > twenty_four_hours
            ):
                logger.info(
                    f"Removing session {session.getName()} - displaced and unused for 24+ hours"
                )
                sessions_to_remove.append(session)
                continue

        for session in sessions_to_remove:
            # Detach from lobby data structures (no client notifications here).
            for lobby in session.lobbies[:]:
                try:
                    if session.id in lobby.sessions:
                        del lobby.sessions[session.id]
                    if lobby.id in session.lobby_peers:
                        del session.lobby_peers[lobby.id]
                except Exception as e:
                    logger.warning(
                        f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}"
                    )

            if session in cls._instances:
                cls._instances.remove(session)
                sessions_removed += 1

        if sessions_removed > 0:
            cls.save()
            logger.info(f"Session cleanup: removed {sessions_removed} old sessions")

        return sessions_removed

    async def join(self, lobby: "Lobby"):
        """Join *lobby* and set up RTC peering with its live members.

        Each live member is sent ``addPeer`` for this session
        (should_create_offer False), and this session is sent ``addPeer`` for
        each member (should_create_offer True). Members without a WebSocket
        are dropped from the lobby. Requires this session to have a WebSocket.
        """
        if not self.ws:
            logger.error(
                f"{self.getName()} - No WebSocket connection. Lobby not available."
            )
            return

        if lobby.id in self.lobby_peers or self.id in lobby.sessions:
            logger.info(f"{self.getName()} - Already joined to {lobby.getName()}.")
            data = JoinStatusModel(
                status="Joined", message=f"Already joined to lobby {lobby.getName()}"
            )
            await self.ws.send_json({"type": "join_status", "data": data.model_dump()})
            return

        # Initialize the peer list for this lobby. load() may already have put
        # this very Lobby object into self.lobbies; do not list it twice.
        if lobby not in self.lobbies:
            self.lobbies.append(lobby)
        self.lobby_peers[lobby.id] = []

        # Iterate over a snapshot of member ids: stale members are deleted from
        # lobby.sessions inside the loop, and the awaits let other tasks change
        # it too -- both would break iteration over the live dict.
        for peer_id in list(lobby.sessions):
            if peer_id == self.id:
                raise Exception(
                    "Should not happen: self in lobby.sessions while not in lobby."
                )

            peer_session = lobby.getSession(peer_id)
            if not peer_session or not peer_session.ws:
                logger.warning(
                    f"{self.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}. Removing."
                )
                lobby.sessions.pop(peer_id, None)
                continue

            # Add the peer to session's RTC peer list
            self.lobby_peers[lobby.id].append(peer_id)
            # Add this user as an RTC peer to each existing peer
            peer_session.lobby_peers.setdefault(lobby.id, []).append(self.id)

            logger.info(
                f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False)"
            )
            await peer_session.ws.send_json(
                {
                    "type": "addPeer",
                    "data": {
                        "peer_id": self.id,
                        "peer_name": self.name,
                        "should_create_offer": False,
                    },
                }
            )

            # Add each other peer to the caller
            logger.info(
                f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True)"
            )
            await self.ws.send_json(
                {
                    "type": "addPeer",
                    "data": {
                        "peer_id": peer_session.id,
                        "peer_name": peer_session.name,
                        "should_create_offer": True,
                    },
                }
            )

        # Add this user as a lobby member (broadcasts the new participant list)
        await lobby.addSession(self)
        Session.save()

        await self.ws.send_json({"type": "join_status", "data": {"status": "Joined"}})

    async def part(self, lobby: "Lobby"):
        """Leave *lobby* and tear down RTC peering with the peers it had there.

        Each peer with a WebSocket is sent ``removePeer`` for this session, and
        this session (if connected) is sent ``removePeer`` for each of them.
        """
        if lobby.id not in self.lobby_peers or self.id not in lobby.sessions:
            logger.info(
                f"{self.getName()} - Attempt to part non-joined lobby {lobby.getName()}."
            )
            if self.ws:
                await self.ws.send_json(
                    {"type": "error", "error": "Attempt to part non-joined lobby"}
                )
            return

        logger.info(f"{self.getName()} <- part({lobby.getName()}) - Lobby part.")

        lobby_peers = self.lobby_peers.pop(lobby.id)
        if lobby in self.lobbies:
            self.lobbies.remove(lobby)

        # Remove this peer from all other RTC peers, and remove each peer from this peer
        for peer_session_id in lobby_peers:
            peer_session = Session.getSession(peer_session_id)
            if not peer_session:
                logger.warning(
                    f"{self.getName()} <- part({lobby.getName()}) - Peer session {peer_session_id} not found. Skipping."
                )
                continue

            # Drop this session from the peer's RTC peer list as well; otherwise
            # the peer keeps a stale entry (and gains a duplicate on re-join).
            peer_ids = peer_session.lobby_peers.get(lobby.id)
            if peer_ids:
                peer_ids[:] = [p for p in peer_ids if p != self.id]

            if not peer_session.ws:
                logger.warning(
                    f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection for {peer_session.getName()}. Skipping."
                )
                continue
            logger.info(f"{peer_session.getName()} <- remove_peer({self.getName()})")
            await peer_session.ws.send_json(
                {
                    "type": "removePeer",
                    "data": {"peer_name": self.name, "peer_id": self.id},
                }
            )

            if not self.ws:
                logger.error(
                    f"{self.getName()} <- part({lobby.getName()}) - No WebSocket connection."
                )
                continue

            logger.info(f"{self.getName()} <- remove_peer({peer_session.getName()})")
            await self.ws.send_json(
                {
                    "type": "removePeer",
                    "data": {
                        "peer_name": peer_session.name,
                        "peer_id": peer_session.id,
                    },
                }
            )

        await lobby.removeSession(self)
        Session.save()
|
|
|
|
|
|
def getName(session: Session | None) -> str | None:
    """Return the session's user name, or None when there is no session or no name."""
    if not session:
        return None
    return session.name or None
|
|
|
|
|
|
def getSession(session_id: str) -> Session | None:
    """Module-level alias for Session.getSession(): look a session up by id."""
    lookup = Session.getSession
    return lookup(session_id)
|
|
|
|
|
|
def getLobby(lobby_id: str) -> Lobby:
    """Return the registered lobby with id *lobby_id*.

    Raises:
        Exception: if no lobby with that id is registered. Callers forward
            ``str(e)`` to clients, so the message wording is part of the API.
    """
    found = lobbies.get(lobby_id)
    if found:
        return found
    logger.error(f"Lobby not found: {lobby_id}")
    raise Exception(f"Lobby not found: {lobby_id}")
|
|
|
|
|
|
def getLobbyByName(lobby_name: str) -> Lobby | None:
    """Return the first registered lobby whose name equals *lobby_name* exactly, else None."""
    return next(
        (candidate for candidate in lobbies.values() if candidate.name == lobby_name),
        None,
    )
|
|
|
|
|
|
# API endpoints
@app.get(f"{public_url}api/health", response_model=HealthResponse)
def health():
    """Liveness probe; always reports ``{"status": "ok"}``."""
    logger.info("Health check endpoint called.")
    body = {"status": "ok"}
    return body
|
|
|
|
|
|
# A session (cookie) is bound to a single user (name).
# A user can be in multiple lobbies, but a session is unique to a single user.
# A user can change their name, but the session ID remains the same and the name
# updates for all lobbies.
@app.get(f"{public_url}api/session", response_model=SessionResponse)
async def session(
    request: Request, response: Response, session_id: str | None = Cookie(default=None)
) -> Response | SessionResponse:
    """Create or resume the caller's session, identified by the ``session_id`` cookie.

    - No cookie: a random 32-hex-character id is generated and set as the cookie.
    - Malformed cookie (not 32 lowercase hex chars): 400 with a JSON error body.
    - Known id: the session is resumed and parted from every lobby it still
      holds RTC peer state for.
    - Unknown but well-formed id: a new session is created with that id.

    Returns the session id, its name ("" if unset) and its remaining lobbies.
    """
    if session_id is None:
        session_id = secrets.token_hex(16)
        response.set_cookie(key="session_id", value=session_id)
    # Validate that session_id is a hex string of length 32
    elif len(session_id) != 32 or not all(c in "0123456789abcdef" for c in session_id):
        return Response(
            content=json.dumps({"error": "Invalid session_id"}),
            status_code=400,
            media_type="application/json",
        )

    # Use the configured logger rather than print(), which bypasses log
    # levels/handlers and writes straight to stdout.
    logger.info(f"[{session_id[:8]}]: Browser hand-shake achieved.")

    # Named so it does not shadow this endpoint function.
    user_session = getSession(session_id)
    if not user_session:
        user_session = Session(session_id)
        logger.info(f"{user_session.getName()}: New session created.")
    else:
        user_session.update_last_used()  # Update activity on session resumption
        logger.info(f"{user_session.getName()}: Existing session resumed.")
        # Part every lobby this session still holds RTC peer state for.
        # NOTE(review): user_session.ws is not consulted, so a session that is
        # still connected elsewhere is parted as well -- confirm this is intended.
        for lobby_id in list(user_session.lobby_peers):
            try:
                lobby = getLobby(lobby_id)
            except Exception as e:
                logger.error(
                    f"{user_session.getName()} - Error getting lobby {lobby_id}: {e}"
                )
                continue
            await user_session.part(lobby)

    return SessionResponse(
        id=session_id,
        name=user_session.name if user_session.name else "",
        lobbies=[
            LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private)
            for lobby in user_session.lobbies
        ],
    )
|
|
|
|
|
|
@app.get(public_url + "api/lobby", response_model=LobbiesResponse)
async def get_lobbies(request: Request, response: Response) -> LobbiesResponse:
    """List every registered lobby that is not private."""
    public_lobbies = (entry for entry in lobbies.values() if not entry.private)
    items = [LobbyListItem(id=entry.id, name=entry.name) for entry in public_lobbies]
    return LobbiesResponse(lobbies=items)
|
|
|
|
|
|
@app.post(public_url + "api/lobby/{session_id}", response_model=LobbyCreateResponse)
async def lobby_create(
    request: Request,
    response: Response,
    session_id: str = Path(...),
    create_request: LobbyCreateRequest = Body(...),
) -> Response | LobbyCreateResponse:
    """Create a lobby on behalf of *session_id*, or return the one with the same name.

    Responds 400 when the request ``type`` is not ``"lobby_create"`` and 404
    when the session is unknown. Names are matched exactly; if a lobby with
    the name already exists it is returned unchanged (its existing
    ``private`` flag is kept).
    """
    if create_request.type != "lobby_create":
        # Return an explicit Response: a bare dict would be validated against
        # response_model (LobbyCreateResponse), fail, and surface as a 500
        # instead of a client error.
        return Response(
            content=json.dumps({"error": "Invalid request type"}),
            status_code=400,
            media_type="application/json",
        )

    data = create_request.data
    session = getSession(session_id)
    if not session:
        return Response(
            content=json.dumps({"error": f"Session not found ({session_id})"}),
            status_code=404,
            media_type="application/json",
        )
    logger.info(
        f"{session.getName()} lobby_create: {data.name} (private={data.private})"
    )

    lobby = getLobbyByName(data.name)
    if not lobby:
        lobby = Lobby(
            data.name,
            private=data.private,
        )
        lobbies[lobby.id] = lobby
        logger.info(f"{session.getName()} <- lobby_create({lobby.short}:{lobby.name})")

    return LobbyCreateResponse(
        type="lobby_created",
        data=LobbyModel(id=lobby.id, name=lobby.name, private=lobby.private),
    )
|
|
|
|
|
|
@app.get(public_url + "api/lobby/{lobby_id}/chat", response_model=ChatMessagesResponse)
async def get_chat_messages(
    request: Request,
    lobby_id: str = Path(...),
    limit: int = 50,
) -> Response | ChatMessagesResponse:
    """Get chat messages for a lobby.

    Returns up to *limit* (query parameter, default 50) of the lobby's most
    recent messages. A non-positive *limit* yields no messages. An unknown
    lobby gets a 404 with a JSON error body.
    """
    try:
        lobby = getLobby(lobby_id)
    except Exception as e:
        return Response(
            content=json.dumps({"error": str(e)}),
            status_code=404,
            media_type="application/json",
        )

    # *limit* is client-controlled. Reject non-positive values up front:
    # with negative-index slicing, [-0:] selects the entire history and a
    # negative bound trims from the front instead of capping the count.
    if limit <= 0:
        return ChatMessagesResponse(messages=[])

    messages = lobby.get_chat_messages(limit)
    return ChatMessagesResponse(messages=messages)
|
|
# Register websocket endpoint directly on app with full public_url path
|
|
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
|
|
async def lobby_join(
|
|
websocket: WebSocket,
|
|
lobby_id: str | None = Path(...),
|
|
session_id: str | None = Path(...),
|
|
):
|
|
await websocket.accept()
|
|
if lobby_id is None:
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Invalid or missing lobby"}
|
|
)
|
|
await websocket.close()
|
|
return
|
|
if session_id is None:
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Invalid or missing session"}
|
|
)
|
|
await websocket.close()
|
|
return
|
|
session = getSession(session_id)
|
|
if not session:
|
|
# logger.error(f"Invalid session ID {session_id}")
|
|
await websocket.send_json(
|
|
{"type": "error", "error": f"Invalid session ID {session_id}"}
|
|
)
|
|
await websocket.close()
|
|
return
|
|
|
|
lobby = None
|
|
try:
|
|
lobby = getLobby(lobby_id)
|
|
except Exception as e:
|
|
await websocket.send_json({"type": "error", "error": str(e)})
|
|
await websocket.close()
|
|
return
|
|
|
|
logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})")
|
|
|
|
session.ws = websocket
|
|
session.update_last_used() # Update activity timestamp
|
|
if session.id in lobby.sessions:
|
|
logger.info(
|
|
f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining."
|
|
)
|
|
await session.part(lobby)
|
|
await lobby.removeSession(session)
|
|
|
|
for peer_id in lobby.sessions:
|
|
peer_session = lobby.getSession(peer_id)
|
|
if not peer_session or not peer_session.ws:
|
|
logger.warning(
|
|
f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}. Removing."
|
|
)
|
|
del lobby.sessions[peer_id]
|
|
continue
|
|
logger.info(f"{session.getName()} -> user_joined({peer_session.getName()})")
|
|
await peer_session.ws.send_json(
|
|
{
|
|
"type": "user_joined",
|
|
"data": {
|
|
"session_id": session.id,
|
|
"name": session.name,
|
|
},
|
|
}
|
|
)
|
|
|
|
try:
|
|
while True:
|
|
packet = await websocket.receive_json()
|
|
session.update_last_used() # Update activity on each message
|
|
type = packet.get("type", None)
|
|
data: dict[str, Any] | None = packet.get("data", None)
|
|
if not type:
|
|
logger.error(f"{session.getName()} - Invalid request: {packet}")
|
|
await websocket.send_json({"type": "error", "error": "Invalid request"})
|
|
continue
|
|
# logger.info(f"{session.getName()} <- RAW Rx: {data}")
|
|
match type:
|
|
case "set_name":
|
|
if not data:
|
|
logger.error(f"{session.getName()} - set_name missing data")
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "set_name missing data"}
|
|
)
|
|
continue
|
|
name = data.get("name")
|
|
password = data.get("password")
|
|
logger.info(f"{session.getName()} <- set_name({name})")
|
|
if not name:
|
|
logger.error(f"{session.getName()} - Name required")
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Name required"}
|
|
)
|
|
continue
|
|
# Name takeover / password logic
|
|
lname = name.lower()
|
|
|
|
# If name is unused, allow and optionally save password
|
|
if Session.isUniqueName(name):
|
|
# If a password was provided, save it (hash+salt) for this name
|
|
if password:
|
|
salt, hash_hex = _hash_password(password)
|
|
name_passwords[lname] = {"salt": salt, "hash": hash_hex}
|
|
session.setName(name)
|
|
logger.info(f"{session.getName()}: -> update('name', {name})")
|
|
await websocket.send_json(
|
|
{
|
|
"type": "update_name",
|
|
"data": {
|
|
"name": name,
|
|
"protected": True
|
|
if name.lower() in name_passwords
|
|
else False,
|
|
},
|
|
}
|
|
)
|
|
# For any clients in any lobby with this session, update their user lists
|
|
await lobby.update_state()
|
|
continue
|
|
|
|
# Name is taken. Check if a password exists for the name and matches.
|
|
saved_pw = name_passwords.get(lname)
|
|
if not saved_pw:
|
|
logger.warning(
|
|
f"{session.getName()} - Name already taken (no password set)"
|
|
)
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Name already taken"}
|
|
)
|
|
continue
|
|
|
|
# Expect structured record with salt+hash only
|
|
match_password = False
|
|
# saved_pw should be a dict[str,str] with 'salt' and 'hash'
|
|
salt = saved_pw.get("salt")
|
|
_, candidate_hash = _hash_password(
|
|
password if password else "", salt_hex=salt
|
|
)
|
|
if candidate_hash == saved_pw.get("hash"):
|
|
match_password = True
|
|
else:
|
|
# No structured password record available
|
|
match_password = False
|
|
|
|
if not match_password:
|
|
logger.warning(
|
|
f"{session.getName()} - Name takeover attempted with wrong or missing password"
|
|
)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": "Invalid password for name takeover",
|
|
}
|
|
)
|
|
continue
|
|
|
|
# Password matches: perform takeover. Find the current session holding the name.
|
|
# Find the currently existing session (if any) with that name
|
|
displaced = Session.getSessionByName(name)
|
|
if displaced and displaced.id == session.id:
|
|
displaced = None
|
|
|
|
# If found, change displaced session to a unique fallback name and notify peers
|
|
if displaced:
|
|
# Create a unique fallback name
|
|
fallback = f"{displaced.name}-{displaced.short}"
|
|
# Ensure uniqueness
|
|
if not Session.isUniqueName(fallback):
|
|
# append random suffix until unique
|
|
while not Session.isUniqueName(fallback):
|
|
fallback = f"{displaced.name}-{secrets.token_hex(3)}"
|
|
|
|
displaced.setName(fallback)
|
|
displaced.mark_displaced()
|
|
logger.info(
|
|
f"{displaced.getName()} <- displaced by takeover, new name {fallback}"
|
|
)
|
|
# Notify displaced session (if connected)
|
|
if displaced.ws:
|
|
try:
|
|
await displaced.ws.send_json(
|
|
{
|
|
"type": "update_name",
|
|
"data": {
|
|
"name": fallback,
|
|
"protected": False,
|
|
},
|
|
}
|
|
)
|
|
except Exception:
|
|
logger.exception(
|
|
"Failed to notify displaced session websocket"
|
|
)
|
|
# Update all lobbies the displaced session was in
|
|
for d_lobby in list(displaced.lobbies):
|
|
try:
|
|
await d_lobby.update_state()
|
|
except Exception:
|
|
logger.exception(
|
|
"Failed to update lobby state for displaced session"
|
|
)
|
|
|
|
# Now assign the requested name to the current session
|
|
session.setName(name)
|
|
logger.info(
|
|
f"{session.getName()}: -> update('name', {name}) (takeover)"
|
|
)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "update_name",
|
|
"data": {
|
|
"name": name,
|
|
"protected": True
|
|
if name.lower() in name_passwords
|
|
else False,
|
|
},
|
|
}
|
|
)
|
|
# Notify lobbies for this session
|
|
await lobby.update_state()
|
|
|
|
case "list_users":
|
|
await lobby.update_state(session)
|
|
|
|
case "get_chat_messages":
|
|
# Send recent chat messages to the requesting client
|
|
messages = lobby.get_chat_messages(50)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "chat_messages",
|
|
"data": {
|
|
"messages": [msg.model_dump() for msg in messages]
|
|
},
|
|
}
|
|
)
|
|
|
|
case "send_chat_message":
|
|
if not data or "message" not in data:
|
|
logger.error(
|
|
f"{session.getName()} - send_chat_message missing message"
|
|
)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": "send_chat_message missing message",
|
|
}
|
|
)
|
|
continue
|
|
|
|
if not session.name:
|
|
logger.error(
|
|
f"{session.getName()} - Cannot send chat message without name"
|
|
)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": "Must set name before sending chat messages",
|
|
}
|
|
)
|
|
continue
|
|
|
|
message_text = str(data["message"]).strip()
|
|
if not message_text:
|
|
continue
|
|
|
|
# Add the message to the lobby and broadcast it
|
|
chat_message = lobby.add_chat_message(session, message_text)
|
|
await lobby.broadcast_chat_message(chat_message)
|
|
logger.info(
|
|
f"{session.getName()} sent chat message to {lobby.getName()}: {message_text[:50]}..."
|
|
)
|
|
|
|
case "join":
|
|
logger.info(f"{session.getName()} <- join({lobby.getName()})")
|
|
await session.join(lobby=lobby)
|
|
|
|
case "part":
|
|
logger.info(f"{session.getName()} <- part {lobby.getName()}")
|
|
await session.part(lobby=lobby)
|
|
|
|
case "relayICECandidate":
|
|
logger.info(f"{session.getName()} <- relayICECandidate")
|
|
if not data:
|
|
logger.error(
|
|
f"{session.getName()} - relayICECandidate missing data"
|
|
)
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "relayICECandidate missing data"}
|
|
)
|
|
continue
|
|
|
|
if (
|
|
lobby.id not in session.lobby_peers
|
|
or session.id not in lobby.sessions
|
|
):
|
|
logger.error(
|
|
f"{session.short}:{session.name} <- relayICECandidate - Not an RTC peer ({session.id})"
|
|
)
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Not joined to lobby"}
|
|
)
|
|
continue
|
|
session_peers = session.lobby_peers[lobby.id]
|
|
peer_id = data.get("peer_id")
|
|
if peer_id not in session_peers:
|
|
logger.error(
|
|
f"{session.getName()} <- relayICECandidate - Not an RTC peer({peer_id}) in {session_peers}"
|
|
)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": f"Target peer {peer_id} not found",
|
|
}
|
|
)
|
|
continue
|
|
|
|
candidate = data.get("candidate")
|
|
|
|
message: dict[str, Any] = {
|
|
"type": "iceCandidate",
|
|
"data": {
|
|
"peer_id": session.id,
|
|
"peer_name": session.name,
|
|
"candidate": candidate,
|
|
},
|
|
}
|
|
|
|
peer_session = lobby.getSession(peer_id)
|
|
if not peer_session or not peer_session.ws:
|
|
logger.warning(
|
|
f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}."
|
|
)
|
|
break
|
|
logger.info(
|
|
f"{session.getName()} -> iceCandidate({peer_session.getName()})"
|
|
)
|
|
await peer_session.ws.send_json(message)
|
|
|
|
case "relaySessionDescription":
|
|
logger.info(f"{session.getName()} <- relaySessionDescription")
|
|
if not data:
|
|
logger.error(
|
|
f"{session.getName()} - relaySessionDescription missing data"
|
|
)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": "relaySessionDescription missing data",
|
|
}
|
|
)
|
|
continue
|
|
|
|
if (
|
|
lobby.id not in session.lobby_peers
|
|
or session.id not in lobby.sessions
|
|
):
|
|
logger.error(
|
|
f"{session.short}:{session.name} <- relaySessionDescription - Not an RTC peer ({session.id})"
|
|
)
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Not joined to lobby"}
|
|
)
|
|
continue
|
|
|
|
lobby_peers = session.lobby_peers[lobby.id]
|
|
peer_id = data.get("peer_id")
|
|
if peer_id not in lobby_peers:
|
|
logger.error(
|
|
f"{session.getName()} <- relaySessionDescription - Not an RTC peer({peer_id}) in {lobby_peers}"
|
|
)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": f"Target peer {peer_id} not found",
|
|
}
|
|
)
|
|
continue
|
|
|
|
peer_id = data.get("peer_id", None)
|
|
if not peer_id:
|
|
logger.error(
|
|
f"{session.getName()} - relaySessionDescription missing peer_id"
|
|
)
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": "relaySessionDescription missing peer_id",
|
|
}
|
|
)
|
|
continue
|
|
peer_session = lobby.getSession(peer_id)
|
|
if not peer_session or not peer_session.ws:
|
|
logger.warning(
|
|
f"{session.getName()} - Live peer session {peer_id} not found in lobby {lobby.getName()}."
|
|
)
|
|
break
|
|
|
|
session_description = data.get("session_description")
|
|
message = {
|
|
"type": "sessionDescription",
|
|
"data": {
|
|
"peer_id": session.id,
|
|
"peer_name": session.name,
|
|
"session_description": session_description,
|
|
},
|
|
}
|
|
|
|
logger.info(
|
|
f"{session.getName()} -> sessionDescription({peer_session.getName()})"
|
|
)
|
|
await peer_session.ws.send_json(message)
|
|
|
|
case _:
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": f"Unknown request type: {type}",
|
|
}
|
|
)
|
|
|
|
except WebSocketDisconnect:
|
|
logger.info(f"{session.getName()} <- WebSocket disconnected for user.")
|
|
# Cleanup: remove session from lobby and sessions dict
|
|
session.ws = None
|
|
if session.id in lobby.sessions:
|
|
await session.part(lobby)
|
|
|
|
await lobby.update_state()
|
|
|
|
# Clean up empty lobbies
|
|
if not lobby.sessions:
|
|
if lobby.id in lobbies:
|
|
del lobbies[lobby.id]
|
|
logger.info(f"Cleaned up empty lobby {lobby.getName()}")
|
|
|
|
|
|
# Serve static files or proxy to frontend development server
PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true"

# Directory holding the built frontend that is served in production.
# Defaults to the absolute path "/client/build". Joining that onto this file's
# directory would be a no-op, because os.path.join discards earlier parts when
# a later part is absolute. Set CLIENT_BUILD_PATH to point elsewhere.
client_build_path = os.getenv("CLIENT_BUILD_PATH", "/client/build")

if PRODUCTION:
    logger.info(f"Serving static files from: {client_build_path} at {public_url}")
    app.mount(
        public_url, StaticFiles(directory=client_build_path, html=True), name="static"
    )
else:
    logger.info(f"Proxying static files to http://client:3000 at {public_url}")
    import ssl
|
|
|
|
@app.api_route(
    f"{public_url}{{path:path}}",
    methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
)
async def proxy_static(request: Request, path: str):
    """Forward a non-API request to the frontend dev server at client:3000.

    The upstream URL is built from the incoming scheme, PUBLIC_URL and the
    captured path, and the query string is appended. The method, headers and
    body are passed through unchanged.

    The upstream body is fully read (and thereby decoded) before being
    returned. Its status code is kept, and its headers are kept except for
    content-encoding / transfer-encoding / content-length, which no longer
    describe the decoded body.

    Paths under ``api/`` and ``ws/`` are not proxied and get a 404. Any error
    while talking to the dev server yields a 502 "Proxy error" response.
    """
    # Do not proxy API or websocket paths
    if path.startswith(("api/", "ws/")):
        return Response(status_code=404)

    # Join only the non-empty segments. With a root PUBLIC_URL ("/") this
    # avoids producing "client:3000//<path>".
    segments = [segment for segment in (public_url.strip("/"), path) if segment]
    url = f"{request.url.scheme}://client:3000/" + "/".join(segments)
    # Forward the query string unchanged.
    if request.url.query:
        url = f"{url}?{request.url.query}"

    headers = dict(request.headers)
    try:
        # Accept self-signed certs in dev
        async with httpx.AsyncClient(verify=False) as client:
            proxy_req = client.build_request(
                request.method, url, headers=headers, content=await request.body()
            )
            proxy_resp = await client.send(proxy_req, stream=True)
            # aread() consumes the stream and returns the decoded body.
            content = await proxy_resp.aread()

            # Remove problematic headers for browser decoding
            filtered_headers = {
                k: v
                for k, v in proxy_resp.headers.items()
                if k.lower()
                not in ["content-encoding", "transfer-encoding", "content-length"]
            }
            return Response(
                content=content,
                status_code=proxy_resp.status_code,
                headers=filtered_headers,
            )
    except Exception as e:
        logger.error(f"Proxy error for {url}: {e}")
        return Response("Proxy error", status_code=502)
|
|
|
|
# WebSocket proxy for /ws (for React DevTools, etc.)
|
|
import websockets
|
|
import asyncio
|
|
from starlette.websockets import WebSocket as StarletteWebSocket
|
|
|
|
@app.websocket("/ws")
async def websocket_proxy(websocket: StarletteWebSocket):
    """Relay the browser's /ws WebSocket to ws(s)://client:3000/ws.

    After accepting the browser connection, opens an upstream connection using
    the same scheme. It then forwards messages both ways until either side
    finishes:
    - text received from the browser is sent upstream;
    - upstream messages are sent back as text or bytes depending on their
      Python type.

    When one direction stops, the other is cancelled so no relay task outlives
    the connection. A normal disconnect on either side is logged at info
    level. Any other error is logged and the browser socket is closed.
    """
    logger.info("REACT: WebSocket proxy connection established.")
    # Get scheme from websocket.url (should be 'ws' or 'wss')
    scheme = websocket.url.scheme if hasattr(websocket, "url") else "ws"
    target_url = f"{scheme}://client:3000/ws"
    await websocket.accept()

    # websockets.connect() rejects an ssl argument for ws:// URIs, so only
    # build a TLS context for wss://. Certificate checks are disabled so
    # self-signed dev certificates are accepted.
    ssl_ctx: ssl.SSLContext | None = None
    if scheme == "wss":
        ssl_ctx = ssl.create_default_context()
        ssl_ctx.check_hostname = False
        ssl_ctx.verify_mode = ssl.CERT_NONE

    try:
        async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws:

            async def client_to_server():
                while True:
                    msg = await websocket.receive_text()
                    await target_ws.send(msg)

            async def server_to_client():
                while True:
                    msg = await target_ws.recv()
                    if isinstance(msg, str):
                        await websocket.send_text(msg)
                    else:
                        await websocket.send_bytes(msg)

            relay_tasks = [
                asyncio.create_task(client_to_server()),
                asyncio.create_task(server_to_client()),
            ]
            try:
                # Stop as soon as either direction ends (normally via a
                # disconnect exception).
                done, _ = await asyncio.wait(
                    relay_tasks, return_when=asyncio.FIRST_COMPLETED
                )
            finally:
                # Cancel whichever direction is still running and wait for it,
                # so its exception is retrieved and it cannot touch a closed
                # socket.
                for task in relay_tasks:
                    task.cancel()
                await asyncio.gather(*relay_tasks, return_exceptions=True)

            try:
                for task in done:
                    # Re-raise the finished task's exception, if any.
                    task.result()
            except (WebSocketDisconnect, websockets.ConnectionClosed):
                logger.info("REACT: WebSocket proxy connection closed.")
    except Exception as e:
        logger.error(f"REACT: WebSocket proxy error: {e}")
        await websocket.close()
|