402 lines
15 KiB
Python
402 lines
15 KiB
Python
"""
|
|
Refactored main.py - Step 1 of Server Architecture Improvement
|
|
|
|
This is a refactored version of the original main.py that demonstrates the new
|
|
modular architecture with separated concerns:
|
|
|
|
- SessionManager: Handles session lifecycle and persistence
|
|
- LobbyManager: Handles lobby management and chat
|
|
- AuthManager: Handles authentication and name protection
|
|
- WebSocket message routing: Clean message handling
|
|
- Separated API modules: Admin, session, and lobby endpoints
|
|
|
|
This maintains backward compatibility while providing a foundation for
|
|
further improvements.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
import os
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime
|
|
|
|
from fastapi import FastAPI, WebSocket, Path, Request
|
|
from fastapi.responses import Response
|
|
from fastapi.staticfiles import StaticFiles
|
|
from starlette.websockets import WebSocketDisconnect
|
|
|
|
# Import our new modular components
|
|
try:
|
|
from core.session_manager import SessionManager
|
|
from core.lobby_manager import LobbyManager
|
|
from core.auth_manager import AuthManager
|
|
from core.bot_manager import BotManager
|
|
from websocket.connection import WebSocketConnectionManager
|
|
from api.admin import AdminAPI
|
|
from api.sessions import SessionAPI
|
|
from api.lobbies import LobbyAPI
|
|
from api.bots import create_bot_router
|
|
except ImportError:
|
|
# Handle relative imports when running as module
|
|
import sys
|
|
import os
|
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from core.session_manager import SessionManager
|
|
from core.lobby_manager import LobbyManager
|
|
from core.auth_manager import AuthManager
|
|
from core.bot_manager import BotManager
|
|
from websocket.connection import WebSocketConnectionManager
|
|
from api.admin import AdminAPI
|
|
from api.sessions import SessionAPI
|
|
from api.lobbies import LobbyAPI
|
|
from api.bots import create_bot_router
|
|
|
|
from logger import logger
|
|
|
|
# Import performance monitoring components
|
|
try:
|
|
from api.monitoring import router as monitoring_router
|
|
from core.performance import metrics_collector
|
|
from core.health import (
|
|
health_monitor, DatabaseHealthCheck, WebSocketHealthCheck,
|
|
LobbyHealthCheck, SystemResourceHealthCheck
|
|
)
|
|
from core.cache import cache_manager
|
|
monitoring_available = True
|
|
logger.info("Performance monitoring modules loaded successfully")
|
|
except ImportError as e:
|
|
logger.warning(f"Performance monitoring not available: {e}")
|
|
monitoring_router = None
|
|
metrics_collector = None
|
|
health_monitor = None
|
|
cache_manager = None
|
|
monitoring_available = False
|
|
|
|
|
|
# Configuration
# PUBLIC_URL is the path prefix the whole server is mounted under. Route paths
# below are built as f"{public_url}...", and Starlette requires every route path
# to start with "/", so normalise to the "/prefix/" form (a bare "" becomes "/").
public_url = os.getenv("PUBLIC_URL", "/")
if not public_url.startswith("/"):
    public_url = "/" + public_url
if not public_url.endswith("/"):
    public_url += "/"

# Optional token passed to AdminAPI; when unset, lifespan() logs that the admin
# endpoints are unprotected.
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN")

# Create FastAPI app first; the lifespan handler is attached later
# (app.router.lifespan_context) once it has been defined.
app = FastAPI(
    title="AI Voice Bot Server (Refactored)",
    description="WebRTC voice chat server with modular architecture",
    version="2.0.0",
)

logger.info(f"Starting server with public URL: {public_url}")


# Global managers - these replace the global variables from original main.py.
# They are assigned by lifespan() at startup and are None until then, so code
# that may run before startup must check them.
session_manager: SessionManager | None = None
lobby_manager: LobbyManager | None = None
auth_manager: AuthManager | None = None
bot_manager: BotManager | None = None
websocket_manager: WebSocketConnectionManager | None = None
|
|
|
|
|
|
def _register_api_routes(app: FastAPI) -> None:
    """Create the REST API objects from the live managers and include their routers.

    Must run after ``lifespan`` has assigned the module-level managers, because
    the API objects receive those instances as constructor arguments.

    Args:
        app: The application to register the routers on.
    """
    admin_api = AdminAPI(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        auth_manager=auth_manager,
        admin_token=ADMIN_TOKEN,
        public_url=public_url,
    )
    session_api = SessionAPI(session_manager=session_manager, public_url=public_url)
    lobby_api = LobbyAPI(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        public_url=public_url,
    )
    bot_router = create_bot_router(bot_manager, session_manager, lobby_manager)

    app.include_router(admin_api.router)
    app.include_router(session_api.router)
    app.include_router(lobby_api.router)
    app.include_router(bot_router, prefix=public_url.rstrip("/") + "/api")

    # Monitoring endpoints exist only if the optional monitoring modules imported.
    if monitoring_available and monitoring_router:
        app.include_router(monitoring_router, prefix=public_url.rstrip("/"))
        logger.info("Monitoring API endpoints registered")


async def _start_monitoring() -> None:
    """Register health checks and start the optional monitoring services.

    Does nothing beyond logging when the monitoring modules failed to import.
    """
    if not monitoring_available:
        logger.info("Performance monitoring disabled - running in basic mode")
        return

    logger.info("Starting performance monitoring...")

    # Register health check components
    if health_monitor:
        health_monitor.register_component(DatabaseHealthCheck(session_manager))
        health_monitor.register_component(WebSocketHealthCheck(session_manager))
        health_monitor.register_component(LobbyHealthCheck(lobby_manager))
        health_monitor.register_component(SystemResourceHealthCheck(metrics_collector))

    # Start monitoring tasks
    if metrics_collector:
        await metrics_collector.start_collection()
    if health_monitor:
        await health_monitor.start_monitoring()
    if cache_manager:
        await cache_manager.start_all()
        # Warm up caches with current data
        await cache_manager.warm_cache(session_manager, lobby_manager)

    logger.info("Performance monitoring started successfully!")


async def _stop_monitoring() -> None:
    """Stop whichever optional monitoring services are available."""
    if not monitoring_available:
        return
    logger.info("Stopping performance monitoring...")
    if metrics_collector:
        await metrics_collector.stop_collection()
    if health_monitor:
        await health_monitor.stop_monitoring()
    if cache_manager:
        await cache_manager.stop_all()
    logger.info("Performance monitoring stopped")


def _dev_client_url(scheme: str, base: str, path: str, query: str = "") -> str:
    """Build the development-server (client:3000) URL for a proxied request.

    Args:
        scheme: Scheme of the incoming request ("http" or "https").
        base: The public URL prefix, e.g. "/" or "/app/".
        path: Request path relative to ``base``; may be empty.
        query: Raw query string without the leading "?"; may be empty.

    Returns:
        ``{scheme}://client:3000/{base}/{path}[?query]`` with empty segments
        omitted, so a root ``base`` does not yield a ``//path`` URL and the
        query string is forwarded.
    """
    target = "/".join(part for part in (base.strip("/"), path) if part)
    url = f"{scheme}://client:3000/{target}"
    return f"{url}?{query}" if query else url


async def _run_until_first_exits(*coros) -> None:
    """Run coroutines concurrently until the first one finishes, then cancel the rest.

    Every task is awaited before returning, so none is left pending or with an
    unretrieved exception. If a finished coroutine raised, that exception is
    re-raised here.
    """
    tasks = [asyncio.ensure_future(coro) for coro in coros]
    try:
        done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            task.cancel()  # no-op for tasks that already finished
        await asyncio.gather(*tasks, return_exceptions=True)
    for task in done:
        task.result()


def _mount_static_client(app: FastAPI) -> None:
    """Serve the production client build as static files at ``public_url``."""
    # The build lives at the absolute path /client/build:
    # os.path.join(os.path.dirname(__file__), "/client/build") evaluates to
    # exactly that, because join discards earlier parts before an absolute one.
    # State it directly and allow an override.
    client_build_path = os.getenv("CLIENT_BUILD_PATH", "/client/build")
    logger.info(f"Serving static files from: {client_build_path} at {public_url}")
    app.mount(
        public_url,
        StaticFiles(directory=client_build_path, html=True),
        name="static",
    )


def _register_dev_proxies(app: FastAPI) -> None:
    """Forward client HTTP traffic and the /ws WebSocket to the dev server at client:3000."""
    # Imported lazily: only needed when not running in production.
    import ssl

    import httpx
    import websockets

    logger.info(f"Proxying static files to http://client:3000 at {public_url}")

    @app.api_route(
        f"{public_url}{{path:path}}",
        methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
    )
    async def proxy_static(request: Request, path: str):
        # Do not proxy API or websocket paths
        if path.startswith(("api/", "ws/")):
            return Response(status_code=404)
        url = _dev_client_url(request.url.scheme, public_url, path, request.url.query)
        headers = dict(request.headers)
        try:
            # Accept self-signed certs in dev
            async with httpx.AsyncClient(verify=False) as client:
                proxy_req = client.build_request(
                    request.method,
                    url,
                    headers=headers,
                    content=await request.body(),
                )
                proxy_resp = await client.send(proxy_req, stream=True)
                content = await proxy_resp.aread()

                # Drop headers describing the upstream wire encoding/length; the
                # body is re-sent here as a single, already-read buffer.
                filtered_headers = {
                    k: v
                    for k, v in proxy_resp.headers.items()
                    if k.lower() not in ("content-encoding", "transfer-encoding", "content-length")
                }
                return Response(
                    content=content,
                    status_code=proxy_resp.status_code,
                    headers=filtered_headers,
                )
        except Exception as e:
            logger.error(f"Proxy error for {url}: {e}")
            return Response("Proxy error", status_code=502)

    # WebSocket proxy for /ws (for React DevTools, etc.)
    @app.websocket("/ws")
    async def websocket_proxy(websocket: WebSocket):
        logger.info("REACT: WebSocket proxy connection established.")
        # Mirror the incoming scheme ('ws' or 'wss') towards the dev server.
        scheme = websocket.url.scheme
        target_url = f"{scheme}://client:3000/ws"
        await websocket.accept()
        try:
            connect_kwargs = {}
            if scheme == "wss":
                # Accept self-signed certs in dev. websockets rejects an ssl
                # argument for a ws:// URI, so only pass it for wss.
                ssl_ctx = ssl.create_default_context()
                ssl_ctx.check_hostname = False
                ssl_ctx.verify_mode = ssl.CERT_NONE
                connect_kwargs["ssl"] = ssl_ctx
            async with websockets.connect(target_url, **connect_kwargs) as target_ws:

                async def client_to_server():
                    # receive() (rather than receive_text()) so binary frames
                    # are forwarded instead of failing on a missing "text" key.
                    while True:
                        message = await websocket.receive()
                        if message["type"] == "websocket.disconnect":
                            raise WebSocketDisconnect(message.get("code", 1000))
                        if message.get("text") is not None:
                            await target_ws.send(message["text"])
                        elif message.get("bytes") is not None:
                            await target_ws.send(message["bytes"])

                async def server_to_client():
                    while True:
                        msg = await target_ws.recv()
                        if isinstance(msg, str):
                            await websocket.send_text(msg)
                        else:
                            await websocket.send_bytes(msg)

                try:
                    # Stop both pumps as soon as either side goes away.
                    await _run_until_first_exits(client_to_server(), server_to_client())
                except (WebSocketDisconnect, websockets.ConnectionClosed):
                    logger.info("REACT: WebSocket proxy connection closed.")
        except Exception as e:
            logger.error(f"REACT: WebSocket proxy error: {e}")
            try:
                await websocket.close()
            except RuntimeError:
                pass  # best effort: the client side is already closed


def _register_frontend(app: FastAPI) -> None:
    """Register the web-client routes; call only after every API router is included.

    With ``PRODUCTION=true`` the built client is served as static files;
    otherwise requests are proxied to the development server. Registering
    last keeps these catch-all routes from shadowing API endpoints.
    """
    if os.getenv("PRODUCTION", "false").lower() == "true":
        _mount_static_client(app)
    else:
        _register_dev_proxies(app)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start and stop the server's managers, routes and background services.

    Startup creates the managers, restores lobby membership for persisted
    sessions, registers API and frontend routes, starts optional monitoring and
    the session manager's background tasks. Shutdown runs in a ``finally`` so
    it also happens when an exception is thrown into the context at ``yield``.
    """
    global session_manager, lobby_manager, auth_manager, bot_manager, websocket_manager

    # Startup
    logger.info("Starting AI Voice Bot server with modular architecture...")

    # Initialize managers (session and auth data share sessions.json)
    session_manager = SessionManager("sessions.json")
    lobby_manager = LobbyManager()
    auth_manager = AuthManager("sessions.json")
    bot_manager = BotManager()

    # Load existing data
    session_manager.load()

    # Restore lobbies for existing sessions
    for session in session_manager.get_all_sessions():
        for lobby_info in session.lobbies:
            lobby = lobby_manager.create_or_get_lobby(
                name=lobby_info.name, private=lobby_info.private
            )
            with lobby.lock:
                lobby.sessions[session.id] = session

    # Set up dependency injection for name protection
    lobby_manager.set_name_protection_checker(auth_manager.is_name_protected)

    websocket_manager = WebSocketConnectionManager(
        session_manager=session_manager,
        lobby_manager=lobby_manager,
        auth_manager=auth_manager,
    )

    # API routes first, then monitoring, then the catch-all frontend routes.
    _register_api_routes(app)
    await _start_monitoring()
    _register_frontend(app)

    # Start background tasks
    await session_manager.start_background_tasks()

    logger.info("AI Voice Bot server started successfully!")
    logger.info(f"Server URL: {public_url}")
    logger.info(f"Sessions loaded: {session_manager.get_session_count()}")
    logger.info(f"Lobbies available: {lobby_manager.get_lobby_count()}")
    logger.info(f"Protected names: {auth_manager.get_protection_count()}")

    if ADMIN_TOKEN:
        logger.info("Admin endpoints protected with token")
    else:
        logger.warning("Admin endpoints are unprotected")

    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down AI Voice Bot server...")
        await _stop_monitoring()
        if session_manager:
            await session_manager.stop_background_tasks()
        logger.info("Server shutdown complete")


# Set the lifespan
app.router.lifespan_context = lifespan
|
|
|
|
|
|
@app.websocket(f"{public_url}ws/lobby/{{lobby_id}}/{{session_id}}")
async def lobby_websocket(
    websocket: WebSocket,
    lobby_id: str = Path(...),
    session_id: str = Path(...),
):
    """WebSocket endpoint for lobby connections.

    Delegates the whole connection to the module-level WebSocketConnectionManager.
    Both path parameters are always present in a matched URL, so they are
    typed ``str``.
    """
    if websocket_manager is None:
        # The manager is only assigned during lifespan startup; if that has not
        # happened, refuse the handshake (1013 = try again later) rather than
        # failing with AttributeError on None.
        logger.error("Lobby WebSocket rejected: server managers are not initialized")
        await websocket.close(code=1013)
        return
    await websocket_manager.handle_connection(websocket, lobby_id, session_id)
|
|
|
|
|
|
# Enhanced health check showing monitoring capabilities
@app.get(f"{public_url}api/system/health")
async def system_health():
    """Summarise manager liveness, monitoring status and basic counters.

    Returns a dict with ``status`` "ok" describing managers, monitoring
    services, statistics and a performance health hint. If gathering any of
    that data raises, the error is logged and a dict with ``status`` "error"
    and the exception message is returned instead (HTTP 200 either way).
    """

    def activity(component) -> str:
        # A component counts as active once it is set (truthy).
        return "active" if component else "inactive"

    try:
        managers = {
            "session_manager": activity(session_manager),
            "lobby_manager": activity(lobby_manager),
            "auth_manager": activity(auth_manager),
            "bot_manager": activity(bot_manager),
            "websocket_manager": activity(websocket_manager),
        }
        monitoring = {
            "performance_monitoring": activity(metrics_collector),
            "health_monitoring": activity(health_monitor),
            "cache_management": activity(cache_manager),
        }
        counters = {
            "sessions": session_manager.get_session_count() if session_manager else 0,
            "lobbies": lobby_manager.get_lobby_count() if lobby_manager else 0,
            "protected_names": auth_manager.get_protection_count() if auth_manager else 0,
        }

        summary = metrics_collector.get_performance_summary() if metrics_collector else {}
        health_hint = summary.get("health_status", "unknown") if summary else "unknown"

        return {
            "status": "ok",
            "architecture": "modular_with_monitoring",
            "version": "2.1.0",  # NOTE(review): differs from the FastAPI app version "2.0.0"
            "managers": managers,
            "monitoring": monitoring,
            "statistics": counters,
            "performance": health_hint,
            "timestamp": datetime.now().isoformat(),
        }
    except Exception as exc:
        logger.error(f"Error in system health check: {exc}")
        return {
            "status": "error",
            "message": str(exc),
            "timestamp": datetime.now().isoformat(),
        }
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn on all interfaces.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)
|