438 lines
15 KiB
Python
438 lines
15 KiB
Python
from fastapi import (
|
|
Cookie,
|
|
FastAPI,
|
|
Path,
|
|
WebSocket,
|
|
WebSocketDisconnect,
|
|
Request,
|
|
Response,
|
|
)
|
|
from fastapi.staticfiles import StaticFiles
|
|
import secrets
|
|
import os
|
|
import httpx
|
|
|
|
from logger import logger
|
|
|
|
# Base path under which every route is mounted. Normalised to both start and
# end with "/" so route suffixes such as "api/health" can be appended directly
# (Starlette refuses routed paths that do not begin with "/").
public_url = os.getenv("PUBLIC_URL", "/")
if not public_url.startswith("/"):
    public_url = "/" + public_url
if not public_url.endswith("/"):
    public_url += "/"

app = FastAPI()

logger.info(f"Starting server with public URL: {public_url}")
|
|
|
|
|
|
class Session:
    """Server-side state for one browser, keyed by its ``session_id`` cookie.

    Tracks the display name, the WebSocket currently bound to the session
    (``None`` until a lobby socket connects), and the audio/video flags the
    client last reported. ``lobbies`` is declared as a mapping of Lobby
    objects but is not populated anywhere in this module.
    """

    def __init__(self, id):
        # Identity: the full cookie value plus an 8-char prefix for log lines.
        self.id = id
        self.short = self.id[:8]
        # Empty until the client sends a "set_name" message.
        self.name = ""
        self.lobbies: dict[str, Lobby] = {}
        self.ws: WebSocket | None = None
        # Media flags as last reported via "media_status".
        self.has_audio, self.has_video = False, False
|
|
|
|
|
|
def getName(session: Session | None) -> str:
    """Return a display label for *session*.

    A missing session is labelled ``"Admin"``. Otherwise the chosen display
    name is returned, falling back to the raw session id while no name has
    been set.
    """
    return (session.name or session.id) if session else "Admin"
|
|
|
|
|
|
class Lobby:
    """A named room and the sessions that are members of it.

    ``sessions`` maps session id -> Session for each member.
    """

    def __init__(self, id):
        self.id = id
        # 8-character prefix used in log lines.
        self.short = id[:8]
        self.sessions: dict[str, Session] = {}

    def addSession(self, session: Session):
        """Add *session* as a member; leaves an existing entry untouched."""
        self.sessions.setdefault(session.id, session)

    def removeSession(self, session: Session):
        """Remove *session* if it is a member; otherwise do nothing."""
        self.sessions.pop(session.id, None)

    def getSession(self, id) -> Session | None:
        """Return the member whose session id is *id*, or ``None``."""
        return self.sessions.get(id)
|
|
|
|
|
|
# Process-lifetime, in-memory registries (nothing here is persisted).
# Lobby id -> Lobby; entries are created when a WebSocket first connects.
lobbies: dict[str, Lobby] = dict()
# Session id (the "session_id" cookie value) -> Session.
sessions: dict[str, Session] = dict()
|
|
|
|
|
|
def getSession(session_id) -> Session | None:
    """Look up a registered session by cookie id; ``None`` when unknown."""
    try:
        return sessions[session_id]
    except KeyError:
        return None
|
|
|
|
|
|
def getLobby(lobby_id) -> Lobby | None:
    """Return the lobby registered under *lobby_id*, or ``None`` if absent."""
    return lobbies[lobby_id] if lobby_id in lobbies else None
|
|
|
|
|
|
# API endpoints
@app.get(f"{public_url}api/health")
def health():
    """Liveness probe reporting how many sessions and lobbies are in memory."""
    logger.info("Health check endpoint called.")
    counts = {"sessions": len(sessions), "lobbies": len(lobbies)}
    return {"status": "ok", **counts}
|
|
|
|
|
|
# A session (cookie) is bound to a single user (name).
# A user can be in multiple lobbies, but a session is unique to a single user.
# A user can change their name, but the session ID remains the same and the name
# updates for all lobbies.
@app.get(f"{public_url}api/lobby")
async def lobby(
    request: Request, response: Response, session_id: str = Cookie(default=None)
):
    """Issue a new session or resume the caller's existing one.

    When the request carries no ``session_id`` cookie, a random 32-hex-char id
    is generated and set as a cookie on the response. The id is then ensured
    to be present in ``sessions`` (creating a ``Session`` on first sight).

    Returns:
        ``{"session": <session id>}``.
    """
    if session_id is None:
        session_id = secrets.token_hex(16)
        response.set_cookie(key="session_id", value=session_id)

    # Routed through the module logger like every other diagnostic in this
    # file (previously a bare print that bypassed logging configuration).
    logger.info(f"[{session_id[:8]}]: Browser hand-shake achieved.")

    if session_id not in sessions:
        sessions[session_id] = Session(session_id)
        logger.info(f"{session_id[:8]}: New session created.")
    else:
        name = sessions[session_id].name or "UNSET"
        logger.info(f"{session_id[:8]}: Existing session resumed for {name}.")

    return {"session": session_id}
|
|
|
|
|
|
# Bracketed tag strings.
# NOTE(review): `all` shadows the builtin all() for the rest of this module.
# None of these three names is referenced elsewhere in this file, so renaming
# them (e.g. ALL/INFO/TODO) looks safe -- confirm no external importer uses them.
all, info, todo = "[ all ]", "[ info ]", "[ todo ]"
|
|
|
|
|
|
async def join(
    lobby: Lobby,
    session: Session,
    has_video: bool,
    has_audio: bool,
):
    """Add *session* to *lobby*'s peer mesh and introduce it to every member.

    For each existing member that has a WebSocket, the member is sent an
    ``addPeer`` for the newcomer with ``should_create_offer=False``. The
    newcomer is sent an ``addPeer`` for that member with
    ``should_create_offer=True``. The session is then recorded in
    ``lobby.sessions``.

    Returns early (after logging) if the session has no name yet, has no
    WebSocket, or is already a member.

    Args:
        lobby: Lobby being joined.
        session: Joining session; must have a name and a bound WebSocket.
        has_video: Video flag advertised to existing members for the newcomer.
        has_audio: Audio flag advertised to existing members for the newcomer.
    """
    if not session.name:
        logger.error(
            f"{session.short}:[UNSET] <- join - No name set yet. Audio not available."
        )
        return

    caller_ws = session.ws
    if not caller_ws:
        logger.error(
            f"{session.short}:{session.name} - No WebSocket connection. Audio not available."
        )
        return

    logger.info(f"{lobby.short}: <- join - {session.short}:{session.name}")

    if session.id in lobby.sessions:
        logger.info(f"{session.short}:{session.name} - Already joined to Audio.")
        return

    logger.info(f"{lobby.short}: -> addPeer - {session.short}:{session.name}")

    # Iterate a snapshot: every send_json awaits, which lets other connections'
    # handlers run join/part and mutate lobby.sessions; mutating a dict while
    # iterating its live view raises RuntimeError.
    for peer in list(lobby.sessions.values()):
        if not peer.ws:
            logger.warning(
                f"{peer.short}:{peer.name} - No WebSocket connection. Skipping."
            )
            continue

        # Add this caller to all peers
        await peer.ws.send_json(
            {
                "type": "addPeer",
                "data": {
                    "peer_id": session.id,
                    "should_create_offer": False,
                    "has_audio": has_audio,
                    "has_video": has_video,
                },
            }
        )

        # Add each other peer to the caller. peer_id must be the peer's id
        # string -- the Session object itself is not JSON-serialisable.
        # NOTE(review): peer.has_audio/has_video come only from "media_status";
        # the flags passed to join() are not stored on the session.
        await caller_ws.send_json(
            {
                "type": "addPeer",
                "data": {
                    "peer_id": peer.id,
                    "should_create_offer": True,
                    "has_audio": peer.has_audio,
                    "has_video": peer.has_video,
                },
            }
        )

    # Add this user as a peer connected to this WebSocket
    lobby.sessions[session.id] = session
|
|
|
|
|
|
async def part(
    lobby: Lobby,
    session: Session,
):
    """Remove *session* from *lobby*'s peer mesh.

    The session is deleted from ``lobby.sessions``. Every remaining member
    that has a WebSocket is then sent a ``remove_peer`` for the departing
    session, and the departing session is sent a ``remove_peer`` for each of
    them.

    Returns early (after logging) if the session has no WebSocket or is not a
    member of the lobby.
    """
    if not session.ws:
        logger.error(
            f"{session.short}:{session.name} - No WebSocket connection. Audio not available."
        )
        return

    # Log the short prefix only: the full id is the session cookie value and
    # must not end up in logs.
    if session.id not in lobby.sessions:
        logger.info(
            f"{session.short}: <- {session.name} - Does not exist in lobby audio."
        )
        return

    logger.info(f"{session.short}: <- {session.name} - Audio part.")
    logger.info(f"{lobby.short}: -> remove_peer - {session.short}:{session.name}")

    del lobby.sessions[session.id]

    # Remove this peer from all other peers, and remove each peer from this
    # peer. Iterate a snapshot because each await may let other handlers
    # mutate lobby.sessions.
    for peer in list(lobby.sessions.values()):
        if not peer.ws:
            logger.warning(
                f"{peer.short}:{peer.name} - No WebSocket connection. Skipping."
            )
            continue
        await peer.ws.send_json(
            {"type": "remove_peer", "data": {"peer_id": session.id}}
        )

        # Re-checked because the session's socket may be unbound while the
        # send above is awaited.
        if session.ws:
            await session.ws.send_json(
                {"type": "remove_peer", "data": {"peer_id": peer.id}}
            )
        else:
            logger.warning(f"{session.short}:{session.name} - No WebSocket connection.")
|
|
|
|
|
|
# Register websocket endpoint directly on app with full public_url path
|
|
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}")
|
|
async def websocket_lobby(
|
|
websocket: WebSocket,
|
|
lobby_id: str | None = Path(...),
|
|
session_id: str = Cookie(default=None),
|
|
):
|
|
await websocket.accept()
|
|
if session_id is None:
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Invalid or missing user session"}
|
|
)
|
|
await websocket.close()
|
|
return
|
|
short = session_id[:8]
|
|
logger.info(f"Session ID from cookie: {session_id}")
|
|
|
|
session = getSession(session_id)
|
|
if not session:
|
|
logger.error(f"{short}: Invalid session ID {session_id}")
|
|
await websocket.send_json({"type": "error", "error": f"Invalid session ID {session_id}"})
|
|
await websocket.close()
|
|
return
|
|
session.ws = websocket
|
|
|
|
if lobby_id is None:
|
|
lobby_id = "default"
|
|
|
|
lobby = getLobby(lobby_id)
|
|
if not lobby:
|
|
lobby = Lobby(lobby_id)
|
|
lobbies[lobby_id] = lobby
|
|
logger.info(f"{short}: Lobby {lobby_id} - New Lobby")
|
|
|
|
try:
|
|
while True:
|
|
data = await websocket.receive_json()
|
|
match data.get("type"):
|
|
case "set_name":
|
|
name = data.get("name")
|
|
if not name:
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Name required"}
|
|
)
|
|
continue
|
|
# Check for duplicate name
|
|
if any(s.name.lower() == name.lower() for s in sessions.values()):
|
|
await websocket.send_json(
|
|
{"type": "error", "error": "Name already taken"}
|
|
)
|
|
continue
|
|
|
|
session.name = name
|
|
logger.info(f"{session.short}: Name set to {session.name}")
|
|
await websocket.send_json({"type": "update", "name": name})
|
|
case "list_users":
|
|
users = [{"name": s.name, "live": True} for s in sessions.values()]
|
|
await websocket.send_json({"type": "users", "users": users})
|
|
|
|
case 'media_status':
|
|
has_audio = data.get("audio", False)
|
|
has_video = data.get("video", False)
|
|
logger.info(f"{session.short}: <- media-status - audio: {has_audio}, video: {has_video}")
|
|
session.has_audio = has_audio
|
|
session.has_video = has_video
|
|
|
|
case "join":
|
|
has_audio = data.get("audio", False)
|
|
has_video = data.get("video", False)
|
|
await join(lobby, session, has_video, has_audio)
|
|
|
|
case "part":
|
|
await part(lobby, session)
|
|
|
|
case "relayICECandidate":
|
|
if id not in lobby.sessions:
|
|
logger.error(
|
|
f"{session.short}:{session.name} <- relayICECandidate - Does not have Audio"
|
|
)
|
|
return
|
|
|
|
peer_id = data.peer_id
|
|
candidate = data.candidate
|
|
|
|
message = {
|
|
type: "iceCandidate",
|
|
data: {"peer_id": session.id, "candidate": candidate},
|
|
}
|
|
|
|
if peer_id in lobby.sessions:
|
|
ws = lobby.sessions[peer_id].ws
|
|
if not ws:
|
|
logger.warning(
|
|
f"{lobby.sessions[peer_id].short}:{lobby.sessions[peer_id].name} - No WebSocket connection. Skipping."
|
|
)
|
|
continue
|
|
await ws.send_json(message)
|
|
|
|
case "relaySessionDescription":
|
|
# todo: if audio doesn't work, figure out if its because of peer_id/session_description missing
|
|
if session.id not in lobby.sessions:
|
|
logger.error(
|
|
f"{session.short}:{session.name} - relaySessionDescription - Does not have Audio"
|
|
)
|
|
|
|
peer_id = data.peer_id
|
|
session_description = data.session_description
|
|
message = {
|
|
type: "sessionDescription",
|
|
data: {
|
|
"peer_id": session.name,
|
|
"session_description": session_description,
|
|
},
|
|
}
|
|
if peer_id in lobby.sessions:
|
|
ws = lobby.sessions[peer_id].ws
|
|
if not ws:
|
|
logger.warning(
|
|
f"{lobby.sessions[peer_id].short}:{lobby.sessions[peer_id].name} - No WebSocket connection. Skipping."
|
|
)
|
|
continue
|
|
await ws.send_json(message)
|
|
|
|
case _:
|
|
await websocket.send_json(
|
|
{
|
|
"type": "error",
|
|
"error": f"Unknown request type: {data.get('type')}",
|
|
}
|
|
)
|
|
|
|
except WebSocketDisconnect:
|
|
logger.info(f"WebSocket disconnected for user {session_id}")
|
|
|
|
|
|
# Serve static files or proxy to frontend development server
PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true"
# NOTE(review): os.path.join discards every component before an absolute one,
# so this evaluates to "/client/build" (on POSIX) regardless of where this file
# lives. Kept unchanged; confirm whether "client/build" relative to this file
# was intended.
client_build_path = os.path.join(os.path.dirname(__file__), "/client/build")

if PRODUCTION:
    logger.info(f"Serving static files from: {client_build_path} at {public_url}")
    app.mount(
        public_url, StaticFiles(directory=client_build_path, html=True), name="static"
    )

else:
    logger.info(f"Proxying static files to http://static-frontend:3000 at {public_url}")

    # Only needed by the development proxy below.
    import asyncio
    import contextlib
    import ssl

    import websockets
    from starlette.websockets import WebSocket as StarletteWebSocket
    from starlette.websockets import WebSocketState

    def _upstream_url(scheme: str, path: str, query: str) -> str:
        """Build the dev-server URL that mirrors *path* under ``public_url``.

        Empty segments are skipped, so a root ``public_url`` of ``"/"`` no
        longer yields ``//path``. The request's query string is forwarded.
        """
        url_path = "/".join(p for p in (public_url.strip("/"), path) if p)
        url = f"{scheme}://static-frontend:3000/{url_path}"
        if query:
            url += f"?{query}"
        return url

    @app.api_route(
        f"{public_url}{{path:path}}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_static(request: Request, path: str):
        """Forward a non-API HTTP request to the frontend dev server.

        Returns the upstream status and body. Encoding/length headers are
        stripped because the body is re-emitted already decoded. Any proxy
        failure becomes a 502.
        """
        # Do not proxy API or websocket paths
        if path.startswith(("api/", "ws/")):
            return Response(status_code=404)
        url = _upstream_url(request.url.scheme, path, request.url.query)
        headers = dict(request.headers)
        try:
            # Accept self-signed certs in dev
            async with httpx.AsyncClient(verify=False) as client:
                proxy_req = client.build_request(
                    request.method, url, headers=headers, content=await request.body()
                )
                proxy_resp = await client.send(proxy_req, stream=True)
                content = await proxy_resp.aread()

                # Remove problematic headers for browser decoding
                filtered_headers = {
                    k: v
                    for k, v in proxy_resp.headers.items()
                    if k.lower() not in ["content-encoding", "transfer-encoding", "content-length"]
                }
                return Response(
                    content=content,
                    status_code=proxy_resp.status_code,
                    headers=filtered_headers,
                )
        except Exception as e:
            logger.error(f"Proxy error for {url}: {e}")
            return Response("Proxy error", status_code=502)

    # WebSocket proxy for /ws (for React DevTools, etc.)
    @app.websocket("/ws")
    async def websocket_proxy(websocket: StarletteWebSocket):
        """Relay frames between the browser and the dev server's ``/ws``."""
        logger.info("WebSocket proxy connection established.")
        # Scheme of the incoming socket: 'ws' or 'wss'.
        scheme = websocket.url.scheme
        target_url = f"{scheme}://static-frontend:3000/ws"
        await websocket.accept()

        # websockets.connect() rejects an ``ssl`` argument for ws:// URIs, so
        # the permissive context (self-signed certs in dev) is wss-only.
        connect_kwargs = {}
        if scheme == "wss":
            ssl_ctx = ssl.create_default_context()
            ssl_ctx.check_hostname = False
            ssl_ctx.verify_mode = ssl.CERT_NONE
            connect_kwargs["ssl"] = ssl_ctx

        try:
            async with websockets.connect(target_url, **connect_kwargs) as target_ws:

                async def client_to_server():
                    while True:
                        msg = await websocket.receive_text()
                        await target_ws.send(msg)

                async def server_to_client():
                    while True:
                        msg = await target_ws.recv()
                        if isinstance(msg, str):
                            await websocket.send_text(msg)
                        else:
                            await websocket.send_bytes(msg)

                # Stop as soon as either direction ends and cancel the other,
                # so no orphaned task keeps reading a closed socket.
                tasks = [
                    asyncio.create_task(client_to_server()),
                    asyncio.create_task(server_to_client()),
                ]
                done, pending = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
                for task in pending:
                    task.cancel()
                await asyncio.gather(*pending, return_exceptions=True)
                for task in done:
                    task.result()  # re-raise whatever ended the relay
        except (WebSocketDisconnect, websockets.ConnectionClosed):
            logger.info("WebSocket proxy connection closed.")
        except Exception as e:
            logger.error(f"WebSocket proxy error: {e}")
        finally:
            # Close our side only if the client has not already gone away;
            # closing after a disconnect can raise and is best effort anyway.
            if websocket.client_state == WebSocketState.CONNECTED:
                with contextlib.suppress(Exception):
                    await websocket.close()
|