2025-08-23 20:33:34 -07:00

438 lines
15 KiB
Python

from fastapi import (
Cookie,
FastAPI,
Path,
WebSocket,
WebSocketDisconnect,
Request,
Response,
)
from fastapi.staticfiles import StaticFiles
import secrets
import os
import httpx
from logger import logger
# Base path every route below is registered under. It is normalised to end
# with a slash so route suffixes can be appended directly.
public_url = os.getenv("PUBLIC_URL", "/")
public_url = public_url if public_url.endswith("/") else public_url + "/"

app = FastAPI()
logger.info(f"Starting server with public URL: {public_url}")
class Session:
    """State tracked for one browser session, keyed by its cookie value."""

    def __init__(self, id):
        # Full session id plus the 8-character prefix used in log lines.
        self.id, self.short = id, id[:8]
        # Display name; stays empty until a "set_name" request assigns one.
        self.name = ""
        # Media flags start disabled.
        self.has_audio = self.has_video = False
        # Attached when the session opens its lobby WebSocket.
        self.ws: WebSocket | None = None
        # Lobby objects keyed by string; nothing in the visible code fills it.
        self.lobbies: dict[str, Lobby] = {}
def getName(session: Session | None) -> str:
    """Return a display label for *session*.

    A missing session is labelled "Admin". Otherwise the session's name is
    used, falling back to its full id while no name has been set.
    """
    if session:
        return session.name or session.id
    return "Admin"
class Lobby:
    """A room identified by ``id`` holding the sessions that have joined it."""

    def __init__(self, id):
        # Full lobby id plus the 8-character prefix used in log lines.
        self.id, self.short = id, id[:8]
        # Members keyed by session id.
        self.sessions: dict[str, Session] = {}

    def addSession(self, session: Session):
        """Register *session* under its id unless that id is already present."""
        self.sessions.setdefault(session.id, session)

    def removeSession(self, session: Session):
        """Forget *session*; a session that is not a member is ignored."""
        self.sessions.pop(session.id, None)

    def getSession(self, id) -> Session | None:
        """Look up a member by session id, returning None when absent."""
        return self.sessions.get(id)
# Module-level in-memory registries: lobbies keyed by lobby id and sessions
# keyed by session id (the value of the "session_id" cookie).
lobbies: dict[str, Lobby] = {}
sessions: dict[str, Session] = {}
def getSession(session_id) -> Session | None:
    """Return the registered session for *session_id*, or None if unknown."""
    try:
        return sessions[session_id]
    except KeyError:
        return None
def getLobby(lobby_id) -> Lobby | None:
    """Return the registered lobby for *lobby_id*, or None if unknown."""
    try:
        return lobbies[lobby_id]
    except KeyError:
        return None
# API endpoints
@app.get(f"{public_url}api/health")
def health():
    """Report a fixed "ok" status with the current session and lobby counts."""
    logger.info("Health check endpoint called.")
    counts = {"sessions": len(sessions), "lobbies": len(lobbies)}
    return {"status": "ok", **counts}
# A session (cookie) is bound to a single user (name).
# A user can be in multiple lobbies, but a session is unique to a single user.
# A user can change their name, but the session ID remains the same and the name
# updates for all lobbies.
@app.get(f"{public_url}api/lobby")
async def lobby(
    request: Request, response: Response, session_id: str = Cookie(default=None)
):
    """Create or resume the caller's session and return its id.

    When the request carries no ``session_id`` cookie, a random 32-hex-digit
    id is generated and set as that cookie on the response. A session object
    is then registered for the id if none exists yet.

    Args:
        request: The incoming request (not read by this handler).
        response: Response object used to set the session cookie.
        session_id: Value of the ``session_id`` cookie, or None when absent.

    Returns:
        ``{"session": <session id>}``.
    """
    if session_id is None:
        session_id = secrets.token_hex(16)
        response.set_cookie(key="session_id", value=session_id)
        # Routed through the logger like every other message in this module.
        logger.info(f"[{session_id[:8]}]: Browser hand-shake achieved.")

    # NOTE(review): a cookie value that is not registered is adopted as a new
    # session id as-is, so clients can pick their own ids — confirm that this
    # is intended.
    existing = sessions.get(session_id)
    if existing is None:
        sessions[session_id] = Session(session_id)
        logger.info(f"{session_id[:8]}: New session created.")
    else:
        name = existing.name or "UNSET"
        logger.info(f"{session_id[:8]}: Existing session resumed for {name}.")
    return {"session": session_id}
# Bracketed tag strings; none of them is referenced in the visible code.
# NOTE(review): `all` rebinds the builtin all() at module scope, so any call
# to all(...) later in this module would hit this string instead.
all = "[ all ]"
info = "[ info ]"
todo = "[ todo ]"
async def join(
    lobby: Lobby,
    session: Session,
    has_video: bool,
    has_audio: bool,
):
    """Add *session* to *lobby* and introduce it to every current member.

    For each member with a WebSocket, two ``addPeer`` messages are sent: one
    to the member announcing the newcomer (``should_create_offer`` False) and
    one to the newcomer describing the member (``should_create_offer`` True).
    The session is registered in the lobby only after all members have been
    notified.

    The call is a logged no-op when the session has no name, has no
    WebSocket, or is already a member of the lobby.

    Args:
        lobby: Lobby being joined.
        session: Session that is joining.
        has_video: Video flag announced to existing members for the newcomer.
        has_audio: Audio flag announced to existing members for the newcomer.
    """
    if not session.name:
        logger.error(
            f"{session.short}:[UNSET] <- join - No name set yet. Audio not available."
        )
        return
    if not session.ws:
        logger.error(
            f"{session.short}:{session.name} - No WebSocket connection. Audio not available."
        )
        return

    logger.info(f"{lobby.short}: <- join - {session.short}:{session.name}")
    if session.id in lobby.sessions:
        logger.info(f"{session.short}:{session.name} - Already joined to Audio.")
        return

    logger.info(f"{lobby.short}: -> addPeer - {session.short}:{session.name}")
    # Iterate over a snapshot: every send suspends this coroutine, and another
    # connection joining or parting meanwhile would otherwise raise
    # "dictionary changed size during iteration".
    for peer in list(lobby.sessions.values()):
        if not peer.ws:
            logger.warning(
                f"{peer.short}:{peer.name} - No WebSocket connection. Skipping."
            )
            continue
        # Add this caller to all peers
        await peer.ws.send_json(
            {
                "type": "addPeer",
                "data": {
                    "peer_id": session.id,
                    "should_create_offer": False,
                    "has_audio": has_audio,
                    "has_video": has_video,
                },
            }
        )
        # Add each other peer to the caller. The peer is identified by its id;
        # sending the Session object itself is not JSON-serialisable.
        await session.ws.send_json(
            {
                "type": "addPeer",
                "data": {
                    "peer_id": peer.id,
                    "should_create_offer": True,
                    "has_audio": peer.has_audio,
                    "has_video": peer.has_video,
                },
            }
        )

    # Add this user as a peer connected to this WebSocket
    lobby.addSession(session)
async def part(
    lobby: Lobby,
    session: Session,
):
    """Remove *session* from *lobby* and tear down its peer links.

    After the session is removed, every remaining member with a WebSocket is
    sent a ``remove_peer`` message naming the departing session, and the
    departing session is sent one ``remove_peer`` per such member.

    The call is a logged no-op when the session has no WebSocket or is not a
    member of the lobby.

    Args:
        lobby: Lobby being left.
        session: Session that is leaving.
    """
    if not session.ws:
        logger.error(
            f"{session.id}:{session.name} - No WebSocket connection. Audio not available."
        )
        return
    if session.id not in lobby.sessions:
        logger.info(f"{session.id}: <- {session.name} - Does not exist in lobby audio.")
        return

    logger.info(f"{session.id}: <- {session.name} - Audio part.")
    logger.info(f"{lobby.short}: -> remove_peer - {session.short}:{session.name}")
    del lobby.sessions[session.id]

    # Remove this peer from all other peers, and remove each peer from this
    # peer. Iterate over a snapshot: every send suspends this coroutine, and a
    # concurrent join/part would otherwise raise "dictionary changed size
    # during iteration".
    for peer in list(lobby.sessions.values()):
        if not peer.ws:
            logger.warning(
                f"{peer.short}:{peer.name} - No WebSocket connection. Skipping."
            )
            continue
        await peer.ws.send_json(
            {"type": "remove_peer", "data": {"peer_id": session.id}}
        )
        # Re-checked on every pass because the socket may have been detached
        # while this coroutine was suspended in a send.
        if session.ws:
            await session.ws.send_json(
                {"type": "remove_peer", "data": {"peer_id": peer.id}}
            )
        else:
            logger.warning(f"{session.short}:{session.name} - No WebSocket connection.")
# Register websocket endpoint directly on app with full public_url path
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}")
async def websocket_lobby(
websocket: WebSocket,
lobby_id: str | None = Path(...),
session_id: str = Cookie(default=None),
):
await websocket.accept()
if session_id is None:
await websocket.send_json(
{"type": "error", "error": "Invalid or missing user session"}
)
await websocket.close()
return
short = session_id[:8]
logger.info(f"Session ID from cookie: {session_id}")
session = getSession(session_id)
if not session:
logger.error(f"{short}: Invalid session ID {session_id}")
await websocket.send_json({"type": "error", "error": f"Invalid session ID {session_id}"})
await websocket.close()
return
session.ws = websocket
if lobby_id is None:
lobby_id = "default"
lobby = getLobby(lobby_id)
if not lobby:
lobby = Lobby(lobby_id)
lobbies[lobby_id] = lobby
logger.info(f"{short}: Lobby {lobby_id} - New Lobby")
try:
while True:
data = await websocket.receive_json()
match data.get("type"):
case "set_name":
name = data.get("name")
if not name:
await websocket.send_json(
{"type": "error", "error": "Name required"}
)
continue
# Check for duplicate name
if any(s.name.lower() == name.lower() for s in sessions.values()):
await websocket.send_json(
{"type": "error", "error": "Name already taken"}
)
continue
session.name = name
logger.info(f"{session.short}: Name set to {session.name}")
await websocket.send_json({"type": "update", "name": name})
case "list_users":
users = [{"name": s.name, "live": True} for s in sessions.values()]
await websocket.send_json({"type": "users", "users": users})
case 'media_status':
has_audio = data.get("audio", False)
has_video = data.get("video", False)
logger.info(f"{session.short}: <- media-status - audio: {has_audio}, video: {has_video}")
session.has_audio = has_audio
session.has_video = has_video
case "join":
has_audio = data.get("audio", False)
has_video = data.get("video", False)
await join(lobby, session, has_video, has_audio)
case "part":
await part(lobby, session)
case "relayICECandidate":
if id not in lobby.sessions:
logger.error(
f"{session.short}:{session.name} <- relayICECandidate - Does not have Audio"
)
return
peer_id = data.peer_id
candidate = data.candidate
message = {
type: "iceCandidate",
data: {"peer_id": session.id, "candidate": candidate},
}
if peer_id in lobby.sessions:
ws = lobby.sessions[peer_id].ws
if not ws:
logger.warning(
f"{lobby.sessions[peer_id].short}:{lobby.sessions[peer_id].name} - No WebSocket connection. Skipping."
)
continue
await ws.send_json(message)
case "relaySessionDescription":
# todo: if audio doesn't work, figure out if its because of peer_id/session_description missing
if session.id not in lobby.sessions:
logger.error(
f"{session.short}:{session.name} - relaySessionDescription - Does not have Audio"
)
peer_id = data.peer_id
session_description = data.session_description
message = {
type: "sessionDescription",
data: {
"peer_id": session.name,
"session_description": session_description,
},
}
if peer_id in lobby.sessions:
ws = lobby.sessions[peer_id].ws
if not ws:
logger.warning(
f"{lobby.sessions[peer_id].short}:{lobby.sessions[peer_id].name} - No WebSocket connection. Skipping."
)
continue
await ws.send_json(message)
case _:
await websocket.send_json(
{
"type": "error",
"error": f"Unknown request type: {data.get('type')}",
}
)
except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for user {session_id}")
# Serve static files or proxy to frontend development server
PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true"
# NOTE(review): the second argument is absolute, so os.path.join discards the
# module directory and this always evaluates to "/client/build". Left as-is
# because a deployment may rely on that location - confirm before making the
# path relative to this file.
client_build_path = os.path.join(os.path.dirname(__file__), "/client/build")

if PRODUCTION:
    # Production: serve the built client directly from disk.
    logger.info(f"Serving static files from: {client_build_path} at {public_url}")
    app.mount(
        public_url, StaticFiles(directory=client_build_path, html=True), name="static"
    )
else:
    # Development: forward everything that is not API/WebSocket traffic to the
    # frontend dev server at static-frontend:3000.
    logger.info(f"Proxying static files to http://static-frontend:3000 at {public_url}")

    import asyncio
    import ssl

    import websockets
    from starlette.websockets import WebSocket as StarletteWebSocket

    @app.api_route(
        f"{public_url}{{path:path}}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_static(request: Request, path: str):
        """Forward an HTTP request to the frontend dev server.

        Args:
            request: Incoming request; its method, headers, body and query
                string are forwarded.
            path: Path below ``public_url``.

        Returns:
            The upstream response with encoding/length headers removed, a 404
            for ``api/`` and ``ws/`` paths, or a 502 if the upstream request
            fails.
        """
        # Do not proxy API or websocket paths
        if path.startswith(("api/", "ws/")):
            return Response(status_code=404)

        # Build the upstream URL piecewise so that a root public_url ("/")
        # does not yield a double slash before the path.
        origin = f"{request.url.scheme}://static-frontend:3000"
        prefix = public_url.strip("/")
        base = f"{origin}/{prefix}" if prefix else origin
        url = f"{base}/{path}" if path else base
        # Keep the query string; without it the upstream sees a different URL.
        if request.url.query:
            url = f"{url}?{request.url.query}"

        headers = dict(request.headers)
        try:
            # Accept self-signed certs in dev
            async with httpx.AsyncClient(verify=False) as client:
                proxy_req = client.build_request(
                    request.method, url, headers=headers, content=await request.body()
                )
                proxy_resp = await client.send(proxy_req, stream=True)
                content = await proxy_resp.aread()
                # Remove problematic headers for browser decoding: the body
                # has been read in full above, so the upstream encoding and
                # length headers must not be passed on.
                dropped = ("content-encoding", "transfer-encoding", "content-length")
                filtered_headers = {
                    k: v
                    for k, v in proxy_resp.headers.items()
                    if k.lower() not in dropped
                }
                return Response(
                    content=content,
                    status_code=proxy_resp.status_code,
                    headers=filtered_headers,
                )
        except Exception as e:
            # Boundary handler: any upstream failure becomes a 502.
            logger.error(f"Proxy error for {url}: {e}")
            return Response("Proxy error", status_code=502)

    # WebSocket proxy for /ws (for React DevTools, etc.)
    @app.websocket("/ws")
    async def websocket_proxy(websocket: StarletteWebSocket):
        """Relay WebSocket traffic between the browser and the dev server.

        Messages are pumped in both directions until either side closes, at
        which point the other direction is cancelled.

        Args:
            websocket: The browser-side connection.
        """
        logger.info("WebSocket proxy connection established.")
        # Get scheme from websocket.url (should be 'ws' or 'wss')
        scheme = websocket.url.scheme if hasattr(websocket, "url") else "ws"
        target_url = f"{scheme}://static-frontend:3000/ws"
        await websocket.accept()

        # An SSL context may only be passed for wss:// targets;
        # websockets.connect rejects it for plain ws://.
        connect_kwargs = {}
        if scheme == "wss":
            # Accept self-signed certs in dev for WSS
            ssl_ctx = ssl.create_default_context()
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
            connect_kwargs["ssl"] = ssl_ctx

        try:
            async with websockets.connect(target_url, **connect_kwargs) as target_ws:

                async def client_to_server():
                    while True:
                        msg = await websocket.receive_text()
                        await target_ws.send(msg)

                async def server_to_client():
                    while True:
                        msg = await target_ws.recv()
                        if isinstance(msg, str):
                            await websocket.send_text(msg)
                        else:
                            await websocket.send_bytes(msg)

                pumps = [
                    asyncio.create_task(client_to_server()),
                    asyncio.create_task(server_to_client()),
                ]
                try:
                    finished, _ = await asyncio.wait(
                        pumps, return_when=asyncio.FIRST_COMPLETED
                    )
                    for pump in finished:
                        # Re-raise whatever ended this direction.
                        pump.result()
                except (WebSocketDisconnect, websockets.ConnectionClosed):
                    logger.info("WebSocket proxy connection closed.")
                finally:
                    # Stop the surviving direction so no task outlives the
                    # handler, and collect its outcome.
                    for pump in pumps:
                        pump.cancel()
                    await asyncio.gather(*pumps, return_exceptions=True)
        except Exception as e:
            logger.error(f"WebSocket proxy error: {e}")
            try:
                await websocket.close()
            except RuntimeError:
                # The browser side is already closed; nothing left to do.
                pass