diff --git a/server/api/bots.py b/server/api/bots.py new file mode 100644 index 0000000..dea85bb --- /dev/null +++ b/server/api/bots.py @@ -0,0 +1,124 @@ +"""Bot API endpoints""" +from fastapi import APIRouter, HTTPException +from logger import logger + +# Import shared models with fallback handling +try: + from ...shared.models import ( + BotProviderRegisterRequest, + BotProviderRegisterResponse, + BotProviderListResponse, + BotListResponse, + BotJoinLobbyRequest, + BotJoinLobbyResponse, + BotLeaveLobbyRequest, + BotLeaveLobbyResponse, + ) +except ImportError: + try: + from shared.models import ( + BotProviderRegisterRequest, + BotProviderRegisterResponse, + BotProviderListResponse, + BotListResponse, + BotJoinLobbyRequest, + BotJoinLobbyResponse, + BotLeaveLobbyRequest, + BotLeaveLobbyResponse, + ) + except ImportError: + # Create dummy models for standalone testing + from pydantic import BaseModel + from typing import List, Dict, Optional + + class BotProviderRegisterRequest(BaseModel): + base_url: str + name: str + description: str + provider_key: str + + class BotProviderRegisterResponse(BaseModel): + provider_id: str + + class BotProviderListResponse(BaseModel): + providers: List[dict] + + class BotListResponse(BaseModel): + bots: List[dict] + providers: Dict[str, str] + + class BotJoinLobbyRequest(BaseModel): + lobby_id: str + provider_id: Optional[str] = None + nick: Optional[str] = None + + class BotJoinLobbyResponse(BaseModel): + status: str + bot_name: str + run_id: str + provider_id: str + + class BotLeaveLobbyRequest(BaseModel): + session_id: str + + class BotLeaveLobbyResponse(BaseModel): + status: str + session_id: str + run_id: Optional[str] = None + + +def create_bot_router(bot_manager, session_manager, lobby_manager): + """Create bot API router with dependencies""" + + router = APIRouter(prefix="/bots", tags=["bots"]) + + @router.post("/providers/register", response_model=BotProviderRegisterResponse) + async def register_bot_provider(request: BotProviderRegisterRequest) -> BotProviderRegisterResponse: + """Register a new bot provider with authentication""" + try: + return await bot_manager.register_provider(request) + except ValueError as e: + if "Invalid provider key" in str(e): + raise HTTPException(status_code=403, detail=str(e)) + else: + raise HTTPException(status_code=400, detail=str(e)) + + @router.get("/providers", response_model=BotProviderListResponse) + async def list_bot_providers() -> BotProviderListResponse: + """List all registered bot providers""" + return bot_manager.list_providers() + + @router.get("", response_model=BotListResponse) + async def list_available_bots() -> BotListResponse: + """List all available bots from all registered providers""" + return await bot_manager.list_bots() + + @router.post("/{bot_name}/join", response_model=BotJoinLobbyResponse) + async def request_bot_join_lobby(bot_name: str, request: BotJoinLobbyRequest) -> BotJoinLobbyResponse: + """Request a bot to join a specific lobby""" + try: + return await bot_manager.request_bot_join(bot_name, request, session_manager, lobby_manager) + except ValueError as e: + if "not found" in str(e).lower(): + raise HTTPException(status_code=404, detail=str(e)) + elif "timeout" in str(e).lower(): + raise HTTPException(status_code=504, detail=str(e)) + elif "provider error" in str(e).lower(): + raise HTTPException(status_code=502, detail=str(e)) + else: + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/leave", response_model=BotLeaveLobbyResponse) + async def request_bot_leave_lobby(request: BotLeaveLobbyRequest) -> BotLeaveLobbyResponse: + """Request a bot to leave from all lobbies and disconnect""" + try: + return await bot_manager.request_bot_leave(request, session_manager) + except ValueError as e: + if "not found" in str(e).lower(): + raise HTTPException(status_code=404, detail=str(e)) + elif "not a bot" in str(e).lower(): + raise HTTPException(status_code=400, detail=str(e)) + else: + raise HTTPException(status_code=500, detail=str(e)) + + return router diff --git a/server/core/bot_manager.py b/server/core/bot_manager.py new file mode 100644 index 0000000..ca7f8cd --- /dev/null +++ b/server/core/bot_manager.py @@ -0,0 +1,428 @@ +"""Bot Provider Management""" +import os +import time +import uuid +import secrets +import threading +import httpx +from typing import Dict, List, Optional +from pydantic import ValidationError +from logger import logger + +# Import shared models with fallback handling +try: + from ...shared.models import ( + BotProviderModel, + BotProviderRegisterRequest, + BotProviderRegisterResponse, + BotProviderListResponse, + BotListResponse, + BotInfoModel, + BotJoinLobbyRequest, + BotJoinLobbyResponse, + BotLeaveLobbyRequest, + BotLeaveLobbyResponse, + BotProviderBotsResponse, + BotProviderJoinResponse, + BotJoinPayload, + ) +except ImportError: + try: + from shared.models import ( + BotProviderModel, + BotProviderRegisterRequest, + BotProviderRegisterResponse, + BotProviderListResponse, + BotListResponse, + BotInfoModel, + BotJoinLobbyRequest, + BotJoinLobbyResponse, + BotLeaveLobbyRequest, + BotLeaveLobbyResponse, + BotProviderBotsResponse, + BotProviderJoinResponse, + BotJoinPayload, + ) + except ImportError: + # Create dummy models for standalone testing + from pydantic import BaseModel + + class BotProviderModel(BaseModel): + provider_id: str + base_url: str + name: str + description: str + provider_key: str + registered_at: float + last_seen: float + + class BotProviderRegisterRequest(BaseModel): + base_url: str + name: str + description: str + provider_key: str + + class BotProviderRegisterResponse(BaseModel): + provider_id: str + + class BotProviderListResponse(BaseModel): + providers: List[BotProviderModel] + + class BotInfoModel(BaseModel): + name: str + description: str + has_media: bool = False + + class BotListResponse(BaseModel): + bots: List[BotInfoModel] + providers: Dict[str, str] + + class BotJoinLobbyRequest(BaseModel): + lobby_id: str + provider_id: Optional[str] = None + nick: Optional[str] = None + + class BotJoinLobbyResponse(BaseModel): + status: str + bot_name: str + run_id: str + provider_id: str + + class BotLeaveLobbyRequest(BaseModel): + session_id: str + + class BotLeaveLobbyResponse(BaseModel): + status: str + session_id: str + run_id: Optional[str] = None + + class BotProviderBotsResponse(BaseModel): + bots: List[BotInfoModel] + + class BotProviderJoinResponse(BaseModel): + run_id: str + + class BotJoinPayload(BaseModel): + lobby_id: str + session_id: str + nick: str + server_url: str + insecure: bool = True + + +class BotProviderConfig: + """Configuration class for bot provider management""" + + # Comma-separated list of allowed provider keys + # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys) + ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "") + + @classmethod + def get_allowed_providers(cls) -> Dict[str, str]: + """Parse allowed providers from environment variable + + Returns: + dict mapping provider_key -> provider_name + """ + if not cls.ALLOWED_PROVIDERS.strip(): + return {} + + providers: Dict[str, str] = {} + for entry in cls.ALLOWED_PROVIDERS.split(","): + entry = entry.strip() + if not entry: + continue + + if ":" in entry: + key, name = entry.split(":", 1) + providers[key.strip()] = name.strip() + else: + providers[entry] = entry + + return providers + + +class BotManager: + """Manages bot providers and bot lifecycle""" + + def __init__(self): + self.bot_providers: Dict[str, BotProviderModel] = {} + self.lock = threading.RLock() + + # Check if provider authentication is enabled + allowed_providers = BotProviderConfig.get_allowed_providers() + if not allowed_providers: + logger.warning("Bot provider authentication disabled. Any provider can register.") + + async def register_provider(self, request: BotProviderRegisterRequest) -> BotProviderRegisterResponse: + """Register a new bot provider with authentication""" + + # Check if provider authentication is enabled + allowed_providers = BotProviderConfig.get_allowed_providers() + if allowed_providers: + # Authentication is enabled - validate provider key + if request.provider_key not in allowed_providers: + logger.warning(f"Rejected bot provider registration with invalid key: {request.provider_key}") + raise ValueError("Invalid provider key. Bot provider is not authorized to register.") + + # Check if there's already an active provider with this key and remove it + providers_to_remove: List[str] = [] + with self.lock: + for existing_provider_id, existing_provider in self.bot_providers.items(): + if existing_provider.provider_key == request.provider_key: + providers_to_remove.append(existing_provider_id) + logger.info(f"Removing stale bot provider: {existing_provider.name} (ID: {existing_provider_id})") + + # Remove stale providers + for provider_id_to_remove in providers_to_remove: + del self.bot_providers[provider_id_to_remove] + + provider_id = str(uuid.uuid4()) + now = time.time() + + provider = BotProviderModel( + provider_id=provider_id, + base_url=request.base_url.rstrip("/"), + name=request.name, + description=request.description, + provider_key=request.provider_key, + registered_at=now, + last_seen=now, + ) + + with self.lock: + self.bot_providers[provider_id] = provider + + logger.info(f"Registered bot provider: {request.name} at {request.base_url} with key: {request.provider_key}") + return BotProviderRegisterResponse(provider_id=provider_id) + + def list_providers(self) -> BotProviderListResponse: + """List all registered bot providers""" + with self.lock: + return BotProviderListResponse(providers=list(self.bot_providers.values())) + + async def list_bots(self) -> BotListResponse: + """List all available bots from all registered providers""" + bots: List[BotInfoModel] = [] + providers: Dict[str, str] = {} + + # Update last_seen timestamps and fetch bots from each provider + with self.lock: + providers_copy = dict(self.bot_providers.items()) + + for provider_id, provider in providers_copy.items(): + try: + # Update last_seen timestamp + with self.lock: + if provider_id in self.bot_providers: + self.bot_providers[provider_id].last_seen = time.time() + + # Make HTTP request to provider's /bots endpoint + async with httpx.AsyncClient() as client: + response = await client.get(f"{provider.base_url}/bots", timeout=5.0) + if response.status_code == 200: + # Use Pydantic model to validate the response + bots_response = BotProviderBotsResponse.model_validate(response.json()) + # Add each bot to the consolidated list + for bot_info in bots_response.bots: + bots.append(bot_info) + providers[bot_info.name] = provider_id + else: + logger.warning(f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}") + except Exception as e: + logger.error(f"Error fetching bots from provider {provider.name}: {e}") + continue + + return BotListResponse(bots=bots, providers=providers) + + async def request_bot_join(self, bot_name: str, request: BotJoinLobbyRequest, session_manager, lobby_manager) -> BotJoinLobbyResponse: + """Request a bot to join a specific lobby""" + + # Find which provider has this bot and determine its media capability + target_provider_id = request.provider_id + bot_has_media = False + + if not target_provider_id: + # Auto-discover provider for this bot + with self.lock: + providers_copy = dict(self.bot_providers.items()) + + for provider_id, provider in providers_copy.items(): + try: + async with httpx.AsyncClient() as client: + response = await client.get(f"{provider.base_url}/bots", timeout=5.0) + if response.status_code == 200: + # Use Pydantic model to validate the response + bots_response = BotProviderBotsResponse.model_validate(response.json()) + # Look for the bot by name + for bot_info in bots_response.bots: + if bot_info.name == bot_name: + target_provider_id = provider_id + bot_has_media = bot_info.has_media + break + if target_provider_id: + break + except Exception: + continue + else: + # Query the specified provider for bot media capability + with self.lock: + if target_provider_id in self.bot_providers: + provider = self.bot_providers[target_provider_id] + try: + async with httpx.AsyncClient() as client: + response = await client.get(f"{provider.base_url}/bots", timeout=5.0) + if response.status_code == 200: + # Use Pydantic model to validate the response + bots_response = BotProviderBotsResponse.model_validate(response.json()) + # Look for the bot by name + for bot_info in bots_response.bots: + if bot_info.name == bot_name: + bot_has_media = bot_info.has_media + break + except Exception: + # Default to no media if we can't query + pass + + if not target_provider_id: + raise ValueError("Bot or provider not found") + + with self.lock: + if target_provider_id not in self.bot_providers: + raise ValueError("Provider not found") + provider = self.bot_providers[target_provider_id] + + # Get the lobby to validate it exists + lobby = lobby_manager.get_lobby(request.lobby_id) + if not lobby: + raise ValueError("Lobby not found") + + # Create a session for the bot + bot_session_id = secrets.token_hex(16) + + # Create the Session object for the bot + bot_session = session_manager.create_session(bot_session_id, is_bot=True, has_media=bot_has_media) + logger.info(f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})") + + # Determine server URL for the bot to connect back to + # Use the server's public URL or construct from request + server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000") + if server_base_url.endswith("/"): + server_base_url = server_base_url[:-1] + + bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}" + + # Get public URL prefix from environment + public_url = os.getenv("PUBLIC_URL_PREFIX", "/ai-voicebot") + + # Prepare the join request for the bot provider + bot_join_payload = BotJoinPayload( + lobby_id=request.lobby_id, + session_id=bot_session_id, + nick=bot_nick, + server_url=f"{server_base_url}{public_url}".rstrip("/"), + insecure=True, # Accept self-signed certificates in development + ) + + try: + # Make request to bot provider + async with httpx.AsyncClient() as client: + response = await client.post( + f"{provider.base_url}/bots/{bot_name}/join", + json=bot_join_payload.model_dump(), + timeout=10.0, + ) + + if response.status_code == 200: + # Use Pydantic model to parse and validate response + try: + join_response = BotProviderJoinResponse.model_validate(response.json()) + run_id = join_response.run_id + + # Update bot session with run and provider information + bot_session.setName(bot_nick) + bot_session.bot_run_id = run_id + bot_session.bot_provider_id = target_provider_id + + logger.info(f"Bot {bot_name} requested to join lobby {request.lobby_id}") + + return BotJoinLobbyResponse( + status="requested", + bot_name=bot_name, + run_id=run_id, + provider_id=target_provider_id, + ) + except ValidationError as e: + logger.error(f"Invalid response from bot provider: {e}") + raise ValueError(f"Bot provider returned invalid response: {str(e)}") + else: + logger.error(f"Bot provider returned error: HTTP {response.status_code}: {response.text}") + raise ValueError(f"Bot provider error: {response.status_code}") + + except httpx.TimeoutException: + raise ValueError("Bot provider timeout") + except Exception as e: + logger.error(f"Error requesting bot join: {e}") + raise ValueError(f"Internal server error: {str(e)}") + + async def request_bot_leave(self, request: BotLeaveLobbyRequest, session_manager) -> BotLeaveLobbyResponse: + """Request a bot to leave from all lobbies and disconnect""" + + # Find the bot session + bot_session = session_manager.get_session(request.session_id) + if not bot_session: + raise ValueError("Bot session not found") + + if not bot_session.is_bot: + raise ValueError("Session is not a bot") + + run_id = bot_session.bot_run_id + provider_id = bot_session.bot_provider_id + + logger.info(f"Requesting bot {bot_session.getName()} to leave all lobbies") + + # Try to stop the bot at the provider level if we have the information + if provider_id and run_id: + with self.lock: + if provider_id in self.bot_providers: + provider = self.bot_providers[provider_id] + try: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{provider.base_url}/bots/runs/{run_id}/stop", + timeout=5.0, + ) + if response.status_code == 200: + logger.info(f"Successfully requested bot provider to stop run {run_id}") + else: + logger.warning(f"Bot provider returned error when stopping: HTTP {response.status_code}") + except Exception as e: + logger.warning(f"Failed to request bot stop from provider: {e}") + + # Force disconnect the bot session from all lobbies + lobbies_to_leave = bot_session.lobbies[:] + + for lobby in lobbies_to_leave: + try: + await bot_session.leave_lobby(lobby) + except Exception as e: + logger.warning(f"Error removing bot from lobby {lobby.name}: {e}") + + # Close WebSocket connection if it exists + if bot_session.ws: + try: + await bot_session.ws.close() + except Exception as e: + logger.warning(f"Error closing bot WebSocket: {e}") + bot_session.ws = None + + return BotLeaveLobbyResponse( + status="disconnected", + session_id=request.session_id, + run_id=run_id, + ) + + def get_provider(self, provider_id: str) -> Optional[BotProviderModel]: + """Get a specific bot provider by ID""" + with self.lock: + return self.bot_providers.get(provider_id) diff --git a/server/main.py b/server/main.py index 1ee3532..7d7d97f 100644 --- a/server/main.py +++ b/server/main.py @@ -29,10 +29,12 @@ try: from core.session_manager import SessionManager from core.lobby_manager import LobbyManager from core.auth_manager import AuthManager + from core.bot_manager import BotManager from websocket.connection import WebSocketConnectionManager from api.admin import AdminAPI from api.sessions import SessionAPI from api.lobbies import LobbyAPI + from api.bots import create_bot_router except ImportError: # Handle relative imports when running as module import sys @@ -43,10 +45,12 @@ except ImportError: from core.session_manager import SessionManager from core.lobby_manager import LobbyManager from core.auth_manager import AuthManager + from core.bot_manager import BotManager from websocket.connection import WebSocketConnectionManager from api.admin import AdminAPI from api.sessions import SessionAPI from api.lobbies import LobbyAPI + from api.bots import create_bot_router from logger import logger @@ -72,13 +76,14 @@ logger.info(f"Starting server with public URL: {public_url}") session_manager: SessionManager = None lobby_manager: LobbyManager = None auth_manager: AuthManager = None +bot_manager: BotManager = None websocket_manager: WebSocketConnectionManager = None @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for startup and shutdown events""" - global session_manager, lobby_manager, auth_manager, websocket_manager + global session_manager, lobby_manager, auth_manager, bot_manager, websocket_manager # Startup logger.info("Starting AI Voice Bot server with modular architecture...") @@ -87,6 +92,7 @@ async def lifespan(app: FastAPI): session_manager = SessionManager("sessions.json") lobby_manager = LobbyManager() auth_manager = AuthManager("sessions.json") + bot_manager = BotManager() # Load existing data session_manager.load() @@ -127,10 +133,14 @@ async def lifespan(app: FastAPI): public_url=public_url, ) + # Create bot API router + bot_router = create_bot_router(bot_manager, session_manager, lobby_manager) + # Register API routes during startup app.include_router(admin_api.router) app.include_router(session_api.router) app.include_router(lobby_api.router) + app.include_router(bot_router, prefix=public_url.rstrip("/") + "/api") # Register static file serving AFTER API routes to avoid conflicts PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true" @@ -283,6 +293,7 @@ def system_health(): "session_manager": "active" if session_manager else "inactive", "lobby_manager": "active" if lobby_manager else "inactive", "auth_manager": "active" if auth_manager else "inactive", + "bot_manager": "active" if bot_manager else "inactive", "websocket_manager": "active" if websocket_manager else "inactive", }, "statistics": {