WebRTC now working
This commit is contained in:
parent
728293ea2e
commit
6b47704723
@ -5,13 +5,17 @@ This module contains admin-only endpoints for managing users, sessions, and syst
|
||||
Extracted from main.py to improve maintainability and separation of concerns.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from fastapi import APIRouter, Request, Response, Body
|
||||
|
||||
# Import shared models
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# Add the parent directory of server to the path to access shared
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
server_dir = os.path.dirname(current_dir)
|
||||
project_root = os.path.dirname(server_dir)
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from fastapi import APIRouter, Request, Response, Body
|
||||
from shared.models import (
|
||||
AdminNamesResponse,
|
||||
AdminActionResponse,
|
||||
@ -21,13 +25,15 @@ from shared.models import (
|
||||
AdminMetricsResponse,
|
||||
AdminMetricsConfig,
|
||||
)
|
||||
|
||||
from logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..core.session_manager import SessionManager
|
||||
from ..core.lobby_manager import LobbyManager
|
||||
from ..core.auth_manager import AuthManager
|
||||
from core.session_manager import SessionManager, SessionConfig
|
||||
from core.lobby_manager import LobbyManager
|
||||
from core.auth_manager import AuthManager
|
||||
else:
|
||||
# Import for runtime
|
||||
from core.session_manager import SessionConfig
|
||||
|
||||
|
||||
class AdminAPI:
|
||||
@ -163,8 +169,6 @@ class AdminAPI:
|
||||
old_displaced = 0
|
||||
|
||||
for session in all_sessions:
|
||||
from ..core.session_manager import SessionConfig
|
||||
|
||||
# Anonymous sessions
|
||||
if (not session.ws and not session.name and
|
||||
current_time - session.created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT):
|
||||
|
@ -2,17 +2,6 @@
|
||||
Session management for the AI Voice Bot server.
|
||||
|
||||
This module handles session lifecycle, persistence, and cleanup operations.
|
||||
Extracted from m await lobby.add await lobby.removeSession(self)
|
||||
|
||||
# Publish event
|
||||
await event_bus.publish(SessionLeftLobby(
|
||||
session_id=self.id,
|
||||
lobby_id=lobby.id,(self)
|
||||
|
||||
# Publish event
|
||||
await event_bus.publish(SessionJoinedLobby(
|
||||
session_id=self.id,
|
||||
lobby_id=lobby.id,to improve maintainability and separation of concerns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -148,13 +137,93 @@ class Session:
|
||||
self.displaced_at = time.time()
|
||||
|
||||
async def join_lobby(self, lobby):
|
||||
"""Join a lobby and update peers"""
|
||||
"""Join a lobby and establish WebRTC peer connections"""
|
||||
with self.session_lock:
|
||||
if lobby not in self.lobbies:
|
||||
self.lobbies.append(lobby)
|
||||
|
||||
# Initialize lobby_peers for this lobby if not exists
|
||||
if lobby.id not in self.lobby_peers:
|
||||
self.lobby_peers[lobby.id] = []
|
||||
|
||||
# Add to lobby first
|
||||
await lobby.addSession(self)
|
||||
|
||||
# Get existing peer sessions in this lobby for WebRTC setup
|
||||
peer_sessions = []
|
||||
for session in lobby.sessions.values():
|
||||
if (
|
||||
session.id != self.id and session.ws
|
||||
): # Don't include self and only connected sessions
|
||||
peer_sessions.append(session)
|
||||
|
||||
# Establish WebRTC peer connections with existing sessions
|
||||
for peer_session in peer_sessions:
|
||||
# Only establish connections if at least one session has media
|
||||
if self.has_media or peer_session.has_media:
|
||||
logger.info(
|
||||
f"{self.getName()} <-> {peer_session.getName()} - Establishing WebRTC peer connection"
|
||||
)
|
||||
|
||||
# Add peer to our lobby_peers list
|
||||
with self.session_lock:
|
||||
if peer_session.id not in self.lobby_peers[lobby.id]:
|
||||
self.lobby_peers[lobby.id].append(peer_session.id)
|
||||
|
||||
# Add this session to peer's lobby_peers list
|
||||
with peer_session.session_lock:
|
||||
if lobby.id not in peer_session.lobby_peers:
|
||||
peer_session.lobby_peers[lobby.id] = []
|
||||
if self.id not in peer_session.lobby_peers[lobby.id]:
|
||||
peer_session.lobby_peers[lobby.id].append(self.id)
|
||||
|
||||
# Send addPeer to existing peer (they should not create offer)
|
||||
logger.info(
|
||||
f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False, has_media={self.has_media})"
|
||||
)
|
||||
try:
|
||||
await peer_session.ws.send_json(
|
||||
{
|
||||
"type": "addPeer",
|
||||
"data": {
|
||||
"peer_id": self.id,
|
||||
"peer_name": self.name,
|
||||
"has_media": self.has_media,
|
||||
"should_create_offer": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to send addPeer to {peer_session.getName()}: {e}"
|
||||
)
|
||||
|
||||
# Send addPeer to this session (they should create offer)
|
||||
if self.ws:
|
||||
logger.info(
|
||||
f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True, has_media={peer_session.has_media})"
|
||||
)
|
||||
try:
|
||||
await self.ws.send_json(
|
||||
{
|
||||
"type": "addPeer",
|
||||
"data": {
|
||||
"peer_id": peer_session.id,
|
||||
"peer_name": peer_session.name,
|
||||
"has_media": peer_session.has_media,
|
||||
"should_create_offer": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to send addPeer to {self.getName()}: {e}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"{self.getName()} - Skipping WebRTC connection with {peer_session.getName()} (neither has media: self={self.has_media}, peer={peer_session.has_media})"
|
||||
)
|
||||
|
||||
# Publish join event
|
||||
await event_bus.publish(SessionJoinedLobby(
|
||||
session_id=self.id,
|
||||
@ -163,13 +232,72 @@ class Session:
|
||||
))
|
||||
|
||||
async def leave_lobby(self, lobby):
|
||||
"""Leave a lobby and clean up peers"""
|
||||
"""Leave a lobby and clean up WebRTC peer connections"""
|
||||
# Get peer sessions before removing from lobby
|
||||
peer_sessions = []
|
||||
if lobby.id in self.lobby_peers:
|
||||
for peer_id in self.lobby_peers[lobby.id]:
|
||||
peer_session = None
|
||||
# Find peer session in lobby
|
||||
for session in lobby.sessions.values():
|
||||
if session.id == peer_id:
|
||||
peer_session = session
|
||||
break
|
||||
|
||||
if peer_session and peer_session.ws:
|
||||
peer_sessions.append(peer_session)
|
||||
|
||||
# Send removePeer messages to all peers
|
||||
for peer_session in peer_sessions:
|
||||
logger.info(f"{peer_session.getName()} <- remove_peer({self.getName()})")
|
||||
try:
|
||||
await peer_session.ws.send_json(
|
||||
{
|
||||
"type": "removePeer",
|
||||
"data": {"peer_name": self.name, "peer_id": self.id},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to send removePeer to {peer_session.getName()}: {e}"
|
||||
)
|
||||
|
||||
# Remove from peer's lobby_peers
|
||||
with peer_session.session_lock:
|
||||
if (
|
||||
lobby.id in peer_session.lobby_peers
|
||||
and self.id in peer_session.lobby_peers[lobby.id]
|
||||
):
|
||||
peer_session.lobby_peers[lobby.id].remove(self.id)
|
||||
|
||||
# Send removePeer to this session
|
||||
if self.ws:
|
||||
logger.info(
|
||||
f"{self.getName()} <- remove_peer({peer_session.getName()})"
|
||||
)
|
||||
try:
|
||||
await self.ws.send_json(
|
||||
{
|
||||
"type": "removePeer",
|
||||
"data": {
|
||||
"peer_name": peer_session.name,
|
||||
"peer_id": peer_session.id,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to send removePeer to {self.getName()}: {e}"
|
||||
)
|
||||
|
||||
# Clean up our lobby_peers and lobbies
|
||||
with self.session_lock:
|
||||
if lobby in self.lobbies:
|
||||
self.lobbies.remove(lobby)
|
||||
if lobby.id in self.lobby_peers:
|
||||
del self.lobby_peers[lobby.id]
|
||||
|
||||
# Remove from lobby
|
||||
await lobby.removeSession(self)
|
||||
|
||||
# Publish leave event
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,293 +0,0 @@
|
||||
"""
|
||||
Refactored main.py - Step 1 of Server Architecture Improvement
|
||||
|
||||
This is a refactored version of the original main.py that demonstrates the new
|
||||
modular architecture with separated concerns:
|
||||
|
||||
- SessionManager: Handles session lifecycle and persistence
|
||||
- LobbyManager: Handles lobby management and chat
|
||||
- AuthManager: Handles authentication and name protection
|
||||
- WebSocket message routing: Clean message handling
|
||||
- Separated API modules: Admin, session, and lobby endpoints
|
||||
|
||||
This maintains backward compatibility while providing a foundation for
|
||||
further improvements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, Path, Request, Response
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import httpx
|
||||
import ssl
|
||||
import websockets
|
||||
|
||||
# Import our new modular components
|
||||
try:
|
||||
from core.session_manager import SessionManager
|
||||
from core.lobby_manager import LobbyManager
|
||||
from core.auth_manager import AuthManager
|
||||
from websocket.connection import WebSocketConnectionManager
|
||||
from api.admin import AdminAPI
|
||||
from api.sessions import SessionAPI
|
||||
from api.lobbies import LobbyAPI
|
||||
except ImportError:
|
||||
# Handle relative imports when running as module
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.session_manager import SessionManager
|
||||
from core.lobby_manager import LobbyManager
|
||||
from core.auth_manager import AuthManager
|
||||
from websocket.connection import WebSocketConnectionManager
|
||||
from api.admin import AdminAPI
|
||||
from api.sessions import SessionAPI
|
||||
from api.lobbies import LobbyAPI
|
||||
|
||||
from logger import logger
|
||||
|
||||
|
||||
# Configuration
|
||||
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN")
|
||||
public_url = os.getenv("PUBLIC_URL", "/")
|
||||
if not public_url.endswith("/"):
|
||||
public_url += "/"
|
||||
|
||||
# Global managers - these replace the global variables from original main.py
|
||||
session_manager: SessionManager = None
|
||||
lobby_manager: LobbyManager = None
|
||||
auth_manager: AuthManager = None
|
||||
websocket_manager: WebSocketConnectionManager = None
|
||||
|
||||
# API instances
|
||||
admin_api: AdminAPI = None
|
||||
session_api: SessionAPI = None
|
||||
lobby_api: LobbyAPI = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown"""
|
||||
global session_manager, lobby_manager, auth_manager, websocket_manager
|
||||
global admin_api, session_api, lobby_api
|
||||
|
||||
logger.info("Starting AI Voice Bot server with modular architecture...")
|
||||
|
||||
# Initialize core managers
|
||||
session_manager = SessionManager()
|
||||
lobby_manager = LobbyManager(session_manager=session_manager)
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# Set up cross-manager dependencies
|
||||
session_manager.set_lobby_manager(lobby_manager)
|
||||
lobby_manager.set_name_protection_checker(auth_manager.is_name_protected)
|
||||
|
||||
# Initialize WebSocket manager
|
||||
websocket_manager = WebSocketConnectionManager(
|
||||
session_manager=session_manager,
|
||||
lobby_manager=lobby_manager
|
||||
)
|
||||
|
||||
# Initialize API routers
|
||||
admin_api = AdminAPI(
|
||||
session_manager=session_manager,
|
||||
lobby_manager=lobby_manager,
|
||||
auth_manager=auth_manager,
|
||||
admin_token=ADMIN_TOKEN,
|
||||
public_url=public_url
|
||||
)
|
||||
|
||||
session_api = SessionAPI(
|
||||
session_manager=session_manager,
|
||||
public_url=public_url
|
||||
)
|
||||
|
||||
lobby_api = LobbyAPI(
|
||||
session_manager=session_manager,
|
||||
lobby_manager=lobby_manager,
|
||||
public_url=public_url
|
||||
)
|
||||
|
||||
# Register API routes
|
||||
app.include_router(admin_api.router)
|
||||
app.include_router(session_api.router)
|
||||
app.include_router(lobby_api.router)
|
||||
|
||||
# Start background tasks
|
||||
await session_manager.start_background_tasks()
|
||||
|
||||
logger.info("AI Voice Bot server started successfully!")
|
||||
logger.info(f"Server URL: {public_url}")
|
||||
logger.info(f"Sessions loaded: {session_manager.get_session_count()}")
|
||||
logger.info(f"Lobbies available: {lobby_manager.get_lobby_count()}")
|
||||
logger.info(f"Protected names: {auth_manager.get_protection_count()}")
|
||||
|
||||
if ADMIN_TOKEN:
|
||||
logger.info("Admin endpoints protected with token")
|
||||
else:
|
||||
logger.warning("Admin endpoints are unprotected")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down AI Voice Bot server...")
|
||||
if session_manager:
|
||||
await session_manager.stop_background_tasks()
|
||||
await session_manager.cleanup_all_sessions()
|
||||
logger.info("Server shutdown complete")
|
||||
|
||||
|
||||
# Create FastAPI app with the new architecture
|
||||
app = FastAPI(
|
||||
title="AI Voice Bot Server",
|
||||
description="Modular AI Voice Bot Server with WebRTC support",
|
||||
version="2.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
logger.info(f"Starting server with public URL: {public_url}")
|
||||
|
||||
|
||||
@app.websocket(f"{public_url}" + "ws/lobby/{{lobby_id}}/{{session_id}}")
|
||||
async def lobby_websocket(
|
||||
websocket: WebSocket,
|
||||
lobby_id: str = Path(...),
|
||||
session_id: str = Path(...)
|
||||
):
|
||||
"""WebSocket endpoint for lobby connections - now uses WebSocketConnectionManager"""
|
||||
await websocket_manager.handle_connection(websocket, lobby_id, session_id)
|
||||
|
||||
|
||||
# WebSocket proxy for React dev server (development mode)
|
||||
PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true"
|
||||
|
||||
if not PRODUCTION:
|
||||
@app.websocket("/ws")
|
||||
async def websocket_proxy(websocket: WebSocket):
|
||||
"""Proxy WebSocket connections to React dev server"""
|
||||
logger.info("REACT: WebSocket proxy connection established.")
|
||||
target_url = "wss://client:3000/ws"
|
||||
await websocket.accept()
|
||||
try:
|
||||
# Accept self-signed certs in dev for WSS
|
||||
ssl_ctx = ssl.create_default_context()
|
||||
ssl_ctx.check_hostname = False
|
||||
ssl_ctx.verify_mode = ssl.CERT_NONE
|
||||
|
||||
async with websockets.connect(target_url, ssl=ssl_ctx) as target_ws:
|
||||
async def client_to_server():
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
await target_ws.send(data)
|
||||
except Exception as e:
|
||||
logger.debug(f"Client to server error: {e}")
|
||||
|
||||
async def server_to_client():
|
||||
try:
|
||||
while True:
|
||||
data = await target_ws.recv()
|
||||
await websocket.send_text(data)
|
||||
except Exception as e:
|
||||
logger.debug(f"Server to client error: {e}")
|
||||
|
||||
# Run both directions concurrently
|
||||
import asyncio
|
||||
await asyncio.gather(
|
||||
client_to_server(),
|
||||
server_to_client(),
|
||||
return_exceptions=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket proxy error: {e}")
|
||||
finally:
|
||||
try:
|
||||
await websocket.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Serve static files or proxy to frontend development server
|
||||
client_build_path = "/client/build"
|
||||
|
||||
if PRODUCTION:
|
||||
# In production, serve static files from the client build directory
|
||||
if os.path.exists(client_build_path):
|
||||
logger.info(f"Serving static files from: {client_build_path} at {public_url}")
|
||||
app.mount(
|
||||
public_url, StaticFiles(directory=client_build_path, html=True), name="static"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Client build directory not found: {client_build_path}")
|
||||
else:
|
||||
# In development, proxy to the React dev server
|
||||
logger.info(f"Proxying static files to http://client:3000 at {public_url}")
|
||||
|
||||
@app.api_route(
|
||||
f"{public_url}{{path:path}}",
|
||||
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
|
||||
)
|
||||
async def proxy_static(request: Request, path: str):
|
||||
# Do not proxy API or websocket paths
|
||||
if path.startswith("api/") or path.startswith("ws/"):
|
||||
return Response(status_code=404)
|
||||
|
||||
url = f"https://client:3000/{public_url.strip('/')}/{path}"
|
||||
if not path:
|
||||
url = f"https://client:3000/{public_url.strip('/')}"
|
||||
|
||||
# Prepare headers but remove problematic ones for proxying
|
||||
headers = dict(request.headers)
|
||||
# Remove host header to avoid conflicts
|
||||
headers.pop("host", None)
|
||||
# Remove accept-encoding to prevent compression issues
|
||||
headers.pop("accept-encoding", None)
|
||||
|
||||
try:
|
||||
# Use HTTP instead of HTTPS for internal container communication
|
||||
async with httpx.AsyncClient(verify=False) as client:
|
||||
proxy_req = client.build_request(
|
||||
request.method, url, headers=headers, content=await request.body()
|
||||
)
|
||||
proxy_resp = await client.send(proxy_req, stream=False)
|
||||
|
||||
# Get response headers but filter out problematic encoding headers
|
||||
response_headers = dict(proxy_resp.headers)
|
||||
# Remove content-encoding and transfer-encoding to prevent conflicts
|
||||
response_headers.pop("content-encoding", None)
|
||||
response_headers.pop("transfer-encoding", None)
|
||||
response_headers.pop("content-length", None) # Let FastAPI calculate this
|
||||
|
||||
return Response(
|
||||
content=proxy_resp.content,
|
||||
status_code=proxy_resp.status_code,
|
||||
headers=response_headers,
|
||||
media_type=proxy_resp.headers.get("content-type")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Proxy error for {path}: {e}")
|
||||
return Response(status_code=404)
|
||||
|
||||
|
||||
# Health check for the new architecture
|
||||
@app.get(f"{public_url}api/system/health")
|
||||
def system_health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"architecture": "modular",
|
||||
"version": "2.0.0",
|
||||
"managers": {
|
||||
"session_manager": "active" if session_manager else "inactive",
|
||||
"lobby_manager": "active" if lobby_manager else "inactive",
|
||||
"auth_manager": "active" if auth_manager else "inactive",
|
||||
"websocket_manager": "active" if websocket_manager else "inactive",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
File diff suppressed because it is too large
Load Diff
@ -1,213 +0,0 @@
|
||||
"""
|
||||
Refactored main.py - Step 1 of Server Architecture Improvement
|
||||
|
||||
This is a refactored version of the original main.py that demonstrates the new
|
||||
modular architecture with separated concerns:
|
||||
|
||||
- SessionManager: Handles session lifecycle and persistence
|
||||
- LobbyManager: Handles lobby management and chat
|
||||
- AuthManager: Handles authentication and name protection
|
||||
- WebSocket message routing: Clean message handling
|
||||
- Separated API modules: Admin, session, and lobby endpoints
|
||||
|
||||
This maintains backward compatibility while providing a foundation for
|
||||
further improvements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, Path
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# Import our new modular components
|
||||
try:
|
||||
from core.session_manager import SessionManager
|
||||
from core.lobby_manager import LobbyManager
|
||||
from core.auth_manager import AuthManager
|
||||
from websocket.connection import WebSocketConnectionManager
|
||||
from api.admin import AdminAPI
|
||||
from api.sessions import SessionAPI
|
||||
from api.lobbies import LobbyAPI
|
||||
except ImportError:
|
||||
# Handle relative imports when running as module
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.session_manager import SessionManager
|
||||
from core.lobby_manager import LobbyManager
|
||||
from core.auth_manager import AuthManager
|
||||
from websocket.connection import WebSocketConnectionManager
|
||||
from api.admin import AdminAPI
|
||||
from api.sessions import SessionAPI
|
||||
from api.lobbies import LobbyAPI
|
||||
|
||||
from logger import logger
|
||||
|
||||
|
||||
# Configuration
|
||||
public_url = os.getenv("PUBLIC_URL", "/")
|
||||
if not public_url.endswith("/"):
|
||||
public_url += "/"
|
||||
|
||||
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", None)
|
||||
|
||||
# Global managers - these replace the global variables from original main.py
|
||||
session_manager: SessionManager = None
|
||||
lobby_manager: LobbyManager = None
|
||||
auth_manager: AuthManager = None
|
||||
websocket_manager: WebSocketConnectionManager = None
|
||||
|
||||
# API routers
|
||||
admin_api: AdminAPI = None
|
||||
session_api: SessionAPI = None
|
||||
lobby_api: LobbyAPI = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events"""
|
||||
global session_manager, lobby_manager, auth_manager, websocket_manager
|
||||
global admin_api, session_api, lobby_api
|
||||
|
||||
# Startup
|
||||
logger.info("Starting AI Voice Bot server with modular architecture...")
|
||||
|
||||
# Initialize managers
|
||||
session_manager = SessionManager("sessions.json")
|
||||
lobby_manager = LobbyManager()
|
||||
auth_manager = AuthManager("sessions.json")
|
||||
|
||||
# Load existing data
|
||||
session_manager.load()
|
||||
|
||||
# Restore lobbies for existing sessions
|
||||
# Note: This is a simplified version - full lobby restoration would be more complex
|
||||
for session in session_manager.get_all_sessions():
|
||||
for lobby_info in session.lobbies:
|
||||
# Create lobby if it doesn't exist
|
||||
lobby = lobby_manager.create_or_get_lobby(
|
||||
name=lobby_info.name, private=lobby_info.private
|
||||
)
|
||||
# Add session to lobby (but don't trigger events during startup)
|
||||
with lobby.lock:
|
||||
lobby.sessions[session.id] = session
|
||||
|
||||
# Set up dependency injection for name protection
|
||||
lobby_manager.set_name_protection_checker(auth_manager.is_name_protected)
|
||||
|
||||
# Initialize WebSocket manager
|
||||
websocket_manager = WebSocketConnectionManager(
|
||||
session_manager=session_manager,
|
||||
lobby_manager=lobby_manager,
|
||||
auth_manager=auth_manager,
|
||||
)
|
||||
|
||||
# Initialize API routers
|
||||
admin_api = AdminAPI(
|
||||
session_manager=session_manager,
|
||||
lobby_manager=lobby_manager,
|
||||
auth_manager=auth_manager,
|
||||
admin_token=ADMIN_TOKEN,
|
||||
public_url=public_url,
|
||||
)
|
||||
|
||||
session_api = SessionAPI(session_manager=session_manager, public_url=public_url)
|
||||
|
||||
lobby_api = LobbyAPI(
|
||||
session_manager=session_manager,
|
||||
lobby_manager=lobby_manager,
|
||||
public_url=public_url,
|
||||
)
|
||||
|
||||
# Register API routes
|
||||
app.include_router(admin_api.router)
|
||||
app.include_router(session_api.router)
|
||||
app.include_router(lobby_api.router)
|
||||
|
||||
# Start background tasks
|
||||
await session_manager.start_background_tasks()
|
||||
|
||||
logger.info("AI Voice Bot server started successfully!")
|
||||
logger.info(f"Server URL: {public_url}")
|
||||
logger.info(f"Sessions loaded: {session_manager.get_session_count()}")
|
||||
logger.info(f"Lobbies available: {lobby_manager.get_lobby_count()}")
|
||||
logger.info(f"Protected names: {auth_manager.get_protection_count()}")
|
||||
|
||||
if ADMIN_TOKEN:
|
||||
logger.info("Admin endpoints protected with token")
|
||||
else:
|
||||
logger.warning("Admin endpoints are unprotected")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down AI Voice Bot server...")
|
||||
|
||||
# Stop background tasks
|
||||
if session_manager:
|
||||
await session_manager.stop_background_tasks()
|
||||
|
||||
logger.info("Server shutdown complete")
|
||||
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="AI Voice Bot Server (Refactored)",
|
||||
description="WebRTC voice chat server with modular architecture",
|
||||
version="2.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
logger.info(f"Starting server with public URL: {public_url}")
|
||||
|
||||
|
||||
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
|
||||
async def lobby_websocket(
|
||||
websocket: WebSocket,
|
||||
lobby_id: str | None = Path(...),
|
||||
session_id: str | None = Path(...),
|
||||
):
|
||||
"""WebSocket endpoint for lobby connections - now uses WebSocketConnectionManager"""
|
||||
await websocket_manager.handle_connection(websocket, lobby_id, session_id)
|
||||
|
||||
|
||||
# Serve static files if available (for client)
|
||||
try:
|
||||
app.mount(public_url + "static", StaticFiles(directory="static"), name="static")
|
||||
logger.info("Static files mounted at /static")
|
||||
except Exception:
|
||||
logger.info("No static directory found, skipping static file serving")
|
||||
|
||||
|
||||
# Health check for the new architecture
|
||||
@app.get(f"{public_url}api/system/health")
|
||||
def system_health():
|
||||
"""System health check showing manager status"""
|
||||
return {
|
||||
"status": "ok",
|
||||
"architecture": "modular",
|
||||
"version": "2.0.0",
|
||||
"managers": {
|
||||
"session_manager": "active" if session_manager else "inactive",
|
||||
"lobby_manager": "active" if lobby_manager else "inactive",
|
||||
"auth_manager": "active" if auth_manager else "inactive",
|
||||
"websocket_manager": "active" if websocket_manager else "inactive",
|
||||
},
|
||||
"statistics": {
|
||||
"sessions": session_manager.get_session_count() if session_manager else 0,
|
||||
"lobbies": lobby_manager.get_lobby_count() if lobby_manager else 0,
|
||||
"protected_names": auth_manager.get_protection_count()
|
||||
if auth_manager
|
||||
else 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
File diff suppressed because it is too large
Load Diff
@ -10,6 +10,7 @@ from typing import Dict, Any, TYPE_CHECKING
|
||||
from fastapi import WebSocket
|
||||
|
||||
from logger import logger
|
||||
from .webrtc_signaling import WebRTCSignalingHandlers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..core.session_manager import Session
|
||||
@ -255,6 +256,38 @@ class SendChatMessageHandler(MessageHandler):
|
||||
await lobby.broadcast_chat_message(chat_message)
|
||||
|
||||
|
||||
class RelayICECandidateHandler(MessageHandler):
|
||||
"""Handler for relayICECandidate messages - WebRTC signaling"""
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
session: "Session",
|
||||
lobby: "Lobby",
|
||||
data: Dict[str, Any],
|
||||
websocket: WebSocket,
|
||||
managers: Dict[str, Any],
|
||||
) -> None:
|
||||
await WebRTCSignalingHandlers.handle_relay_ice_candidate(
|
||||
websocket, session, lobby, data
|
||||
)
|
||||
|
||||
|
||||
class RelaySessionDescriptionHandler(MessageHandler):
|
||||
"""Handler for relaySessionDescription messages - WebRTC signaling"""
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
session: "Session",
|
||||
lobby: "Lobby",
|
||||
data: Dict[str, Any],
|
||||
websocket: WebSocket,
|
||||
managers: Dict[str, Any],
|
||||
) -> None:
|
||||
await WebRTCSignalingHandlers.handle_relay_session_description(
|
||||
websocket, session, lobby, data
|
||||
)
|
||||
|
||||
|
||||
class MessageRouter:
|
||||
"""Routes WebSocket messages to appropriate handlers"""
|
||||
|
||||
@ -271,6 +304,10 @@ class MessageRouter:
|
||||
self.register("get_chat_messages", GetChatMessagesHandler())
|
||||
self.register("send_chat_message", SendChatMessageHandler())
|
||||
|
||||
# WebRTC signaling handlers
|
||||
self.register("relayICECandidate", RelayICECandidateHandler())
|
||||
self.register("relaySessionDescription", RelaySessionDescriptionHandler())
|
||||
|
||||
def register(self, message_type: str, handler: MessageHandler):
|
||||
"""Register a handler for a message type"""
|
||||
self._handlers[message_type] = handler
|
||||
|
199
server/websocket/webrtc_signaling.py
Normal file
199
server/websocket/webrtc_signaling.py
Normal file
@ -0,0 +1,199 @@
|
||||
"""
|
||||
WebRTC Signaling Handlers
|
||||
|
||||
This module contains WebRTC signaling message handlers for peer-to-peer communication.
|
||||
Handles ICE candidate relay and session description exchange between peers.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, TYPE_CHECKING
|
||||
from fastapi import WebSocket
|
||||
|
||||
from logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.session_manager import Session
|
||||
from core.lobby_manager import Lobby
|
||||
|
||||
|
||||
class WebRTCSignalingHandlers:
|
||||
"""WebRTC signaling message handlers for peer-to-peer communication."""
|
||||
|
||||
@staticmethod
|
||||
async def handle_relay_ice_candidate(
|
||||
websocket: WebSocket,
|
||||
session: "Session",
|
||||
lobby: "Lobby",
|
||||
data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Handle ICE candidate relay between peers.
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection
|
||||
session: The sender session
|
||||
lobby: The lobby context
|
||||
data: Message data containing peer_id and candidate
|
||||
"""
|
||||
logger.info(f"{session.getName()} <- relayICECandidate")
|
||||
|
||||
if not data:
|
||||
logger.error(f"{session.getName()} - relayICECandidate missing data")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": "relayICECandidate missing data"}
|
||||
})
|
||||
return
|
||||
|
||||
# Check if session is properly joined to lobby with RTC peers
|
||||
with session.session_lock:
|
||||
if (lobby.id not in session.lobby_peers or
|
||||
session.id not in lobby.sessions):
|
||||
logger.error(
|
||||
f"{session.short}:{session.name} <- relayICECandidate - "
|
||||
f"Not an RTC peer ({session.id})"
|
||||
)
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": "Not joined to lobby"}
|
||||
})
|
||||
return
|
||||
|
||||
session_peers = session.lobby_peers[lobby.id]
|
||||
|
||||
# Validate peer_id
|
||||
peer_id = data.get("peer_id")
|
||||
if peer_id not in session_peers:
|
||||
logger.error(
|
||||
f"{session.getName()} <- relayICECandidate - "
|
||||
f"Not an RTC peer({peer_id}) in {session_peers}"
|
||||
)
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": f"Target peer {peer_id} not found"}
|
||||
})
|
||||
return
|
||||
|
||||
# Get candidate data
|
||||
candidate = data.get("candidate")
|
||||
|
||||
# Prepare message for target peer
|
||||
message: Dict[str, Any] = {
|
||||
"type": "iceCandidate",
|
||||
"data": {
|
||||
"peer_id": session.id,
|
||||
"peer_name": session.name,
|
||||
"candidate": candidate,
|
||||
},
|
||||
}
|
||||
|
||||
# Find target peer session and relay the message
|
||||
peer_session = lobby.getSession(peer_id)
|
||||
if not peer_session or not peer_session.ws:
|
||||
logger.warning(
|
||||
f"{session.getName()} - Live peer session {peer_id} "
|
||||
f"not found in lobby {lobby.getName()}."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"{session.getName()} -> iceCandidate({peer_session.getName()})"
|
||||
)
|
||||
|
||||
try:
|
||||
await peer_session.ws.send_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to relay ICE candidate: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def handle_relay_session_description(
|
||||
websocket: WebSocket,
|
||||
session: "Session",
|
||||
lobby: "Lobby",
|
||||
data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Handle session description relay between peers.
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection
|
||||
session: The sender session
|
||||
lobby: The lobby context
|
||||
data: Message data containing peer_id and session_description
|
||||
"""
|
||||
logger.info(f"{session.getName()} <- relaySessionDescription")
|
||||
|
||||
if not data:
|
||||
logger.error(f"{session.getName()} - relaySessionDescription missing data")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": "relaySessionDescription missing data"}
|
||||
})
|
||||
return
|
||||
|
||||
# Check if session is properly joined to lobby with RTC peers
|
||||
with session.session_lock:
|
||||
if (lobby.id not in session.lobby_peers or
|
||||
session.id not in lobby.sessions):
|
||||
logger.error(
|
||||
f"{session.short}:{session.name} <- relaySessionDescription - "
|
||||
f"Not an RTC peer ({session.id})"
|
||||
)
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": "Not joined to lobby"}
|
||||
})
|
||||
return
|
||||
|
||||
lobby_peers = session.lobby_peers[lobby.id]
|
||||
|
||||
# Validate peer_id
|
||||
peer_id = data.get("peer_id")
|
||||
if not peer_id:
|
||||
logger.error(f"{session.getName()} - relaySessionDescription missing peer_id")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": "relaySessionDescription missing peer_id"}
|
||||
})
|
||||
return
|
||||
|
||||
if peer_id not in lobby_peers:
|
||||
logger.error(
|
||||
f"{session.getName()} <- relaySessionDescription - "
|
||||
f"Not an RTC peer({peer_id}) in {lobby_peers}"
|
||||
)
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"error": f"Target peer {peer_id} not found"}
|
||||
})
|
||||
return
|
||||
|
||||
# Find target peer session
|
||||
peer_session = lobby.getSession(peer_id)
|
||||
if not peer_session or not peer_session.ws:
|
||||
logger.warning(
|
||||
f"{session.getName()} - Live peer session {peer_id} "
|
||||
f"not found in lobby {lobby.getName()}."
|
||||
)
|
||||
return
|
||||
|
||||
# Get session description data
|
||||
session_description = data.get("session_description")
|
||||
|
||||
# Prepare message for target peer
|
||||
message: Dict[str, Any] = {
|
||||
"type": "sessionDescription",
|
||||
"data": {
|
||||
"peer_id": session.id,
|
||||
"peer_name": session.name,
|
||||
"session_description": session_description,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"{session.getName()} -> sessionDescription({peer_session.getName()})"
|
||||
)
|
||||
|
||||
try:
|
||||
await peer_session.ws.send_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to relay session description: {e}")
|
71
tests/test-webrtc-signaling.py
Normal file
71
tests/test-webrtc-signaling.py
Normal file
@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test WebRTC signaling handlers registration
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import websockets
|
||||
import sys
|
||||
|
||||
async def test_webrtc_handlers():
|
||||
"""Test that WebRTC signaling handlers are properly registered"""
|
||||
try:
|
||||
# Connect to the WebSocket endpoint
|
||||
uri = "ws://localhost:8000/ai-voicebot/ws"
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print("Connected to WebSocket")
|
||||
|
||||
# Send a set_name message first
|
||||
await websocket.send(json.dumps({
|
||||
"type": "set_name",
|
||||
"data": {"name": "test_user"}
|
||||
}))
|
||||
|
||||
response = await websocket.recv()
|
||||
print(f"Set name response: {response}")
|
||||
|
||||
# Test relayICECandidate handler
|
||||
test_message = {
|
||||
"type": "relayICECandidate",
|
||||
"data": {
|
||||
"peer_id": "nonexistent_peer",
|
||||
"candidate": {"candidate": "test"}
|
||||
}
|
||||
}
|
||||
|
||||
await websocket.send(json.dumps(test_message))
|
||||
print("Sent relayICECandidate message")
|
||||
|
||||
# Expect an error response since we're not in a lobby
|
||||
response = await websocket.recv()
|
||||
print(f"ICE candidate response: {response}")
|
||||
|
||||
# Test relaySessionDescription handler
|
||||
test_message = {
|
||||
"type": "relaySessionDescription",
|
||||
"data": {
|
||||
"peer_id": "nonexistent_peer",
|
||||
"session_description": {"type": "offer"}
|
||||
}
|
||||
}
|
||||
|
||||
await websocket.send(json.dumps(test_message))
|
||||
print("Sent relaySessionDescription message")
|
||||
|
||||
# Expect an error response since we're not in a lobby
|
||||
response = await websocket.recv()
|
||||
print(f"Session description response: {response}")
|
||||
|
||||
print("WebRTC signaling handlers are working!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error testing WebRTC handlers: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = asyncio.run(test_webrtc_handlers())
|
||||
sys.exit(0 if success else 1)
|
41
tests/verify-webrtc-handlers.py
Normal file
41
tests/verify-webrtc-handlers.py
Normal file
@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify WebRTC signaling handlers are registered
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
# Add the server directory to Python path
|
||||
sys.path.insert(0, '/home/jketreno/docker/ai-voicebot/server')
|
||||
|
||||
from websocket.message_handlers import MessageRouter
|
||||
|
||||
def test_webrtc_handlers():
|
||||
"""Test that WebRTC signaling handlers are registered"""
|
||||
router = MessageRouter()
|
||||
supported_types = router.get_supported_types()
|
||||
|
||||
print("Supported message types:")
|
||||
for msg_type in sorted(supported_types):
|
||||
print(f" - {msg_type}")
|
||||
|
||||
# Check for WebRTC handlers
|
||||
webrtc_handlers = [
|
||||
"relayICECandidate",
|
||||
"relaySessionDescription"
|
||||
]
|
||||
|
||||
print("\nWebRTC signaling handlers:")
|
||||
for handler in webrtc_handlers:
|
||||
if handler in supported_types:
|
||||
print(f" ✓ {handler} - REGISTERED")
|
||||
else:
|
||||
print(f" ✗ {handler} - MISSING")
|
||||
return False
|
||||
|
||||
print("\n✅ All WebRTC signaling handlers are properly registered!")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_webrtc_handlers()
|
||||
sys.exit(0 if success else 1)
|
Loading…
x
Reference in New Issue
Block a user