"""Bot Provider Management""" import os import time import uuid import secrets import threading import httpx from typing import Dict, List, Optional from pydantic import ValidationError from logger import logger # Import shared models with fallback handling try: from ...shared.models import ( BotProviderModel, BotProviderRegisterRequest, BotProviderRegisterResponse, BotProviderListResponse, BotListResponse, BotInfoModel, BotJoinLobbyRequest, BotJoinLobbyResponse, BotLeaveLobbyRequest, BotLeaveLobbyResponse, BotProviderBotsResponse, BotProviderJoinResponse, BotJoinPayload, ) except ImportError: try: from shared.models import ( BotProviderModel, BotProviderRegisterRequest, BotProviderRegisterResponse, BotProviderListResponse, BotListResponse, BotInfoModel, BotJoinLobbyRequest, BotJoinLobbyResponse, BotLeaveLobbyRequest, BotLeaveLobbyResponse, BotProviderBotsResponse, BotProviderJoinResponse, BotJoinPayload, ) except ImportError: # Create dummy models for standalone testing from pydantic import BaseModel class BotProviderModel(BaseModel): provider_id: str base_url: str name: str description: str provider_key: str registered_at: float last_seen: float class BotProviderRegisterRequest(BaseModel): base_url: str name: str description: str provider_key: str class BotProviderRegisterResponse(BaseModel): provider_id: str class BotProviderListResponse(BaseModel): providers: List[BotProviderModel] class BotInfoModel(BaseModel): name: str description: str has_media: bool = False class BotListResponse(BaseModel): bots: List[BotInfoModel] providers: Dict[str, str] class BotJoinLobbyRequest(BaseModel): lobby_id: str provider_id: Optional[str] = None nick: Optional[str] = None class BotJoinLobbyResponse(BaseModel): status: str bot_name: str run_id: str provider_id: str class BotLeaveLobbyRequest(BaseModel): session_id: str class BotLeaveLobbyResponse(BaseModel): status: str session_id: str run_id: Optional[str] = None class BotProviderBotsResponse(BaseModel): bots: List[BotInfoModel] class BotProviderJoinResponse(BaseModel): run_id: str class BotJoinPayload(BaseModel): lobby_id: str session_id: str nick: str server_url: str insecure: bool = True class BotProviderConfig: """Configuration class for bot provider management""" # Comma-separated list of allowed provider keys # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys) ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "") @classmethod def get_allowed_providers(cls) -> Dict[str, str]: """Parse allowed providers from environment variable Returns: dict mapping provider_key -> provider_name """ if not cls.ALLOWED_PROVIDERS.strip(): return {} providers: Dict[str, str] = {} for entry in cls.ALLOWED_PROVIDERS.split(","): entry = entry.strip() if not entry: continue if ":" in entry: key, name = entry.split(":", 1) providers[key.strip()] = name.strip() else: providers[entry] = entry return providers class BotManager: """Manages bot providers and bot lifecycle""" def __init__(self): self.bot_providers: Dict[str, BotProviderModel] = {} self.lock = threading.RLock() # Check if provider authentication is enabled allowed_providers = BotProviderConfig.get_allowed_providers() if not allowed_providers: logger.warning("Bot provider authentication disabled. Any provider can register.") async def register_provider(self, request: BotProviderRegisterRequest) -> BotProviderRegisterResponse: """Register a new bot provider with authentication""" # Check if provider authentication is enabled allowed_providers = BotProviderConfig.get_allowed_providers() if allowed_providers: # Authentication is enabled - validate provider key if request.provider_key not in allowed_providers: logger.warning(f"Rejected bot provider registration with invalid key: {request.provider_key}") raise ValueError("Invalid provider key. Bot provider is not authorized to register.") # Check if there's already an active provider with this key and remove it providers_to_remove: List[str] = [] with self.lock: for existing_provider_id, existing_provider in self.bot_providers.items(): if existing_provider.provider_key == request.provider_key: providers_to_remove.append(existing_provider_id) logger.info(f"Removing stale bot provider: {existing_provider.name} (ID: {existing_provider_id})") # Remove stale providers for provider_id_to_remove in providers_to_remove: del self.bot_providers[provider_id_to_remove] provider_id = str(uuid.uuid4()) now = time.time() provider = BotProviderModel( provider_id=provider_id, base_url=request.base_url.rstrip("/"), name=request.name, description=request.description, provider_key=request.provider_key, registered_at=now, last_seen=now, ) with self.lock: self.bot_providers[provider_id] = provider logger.info(f"Registered bot provider: {request.name} at {request.base_url} with key: {request.provider_key}") return BotProviderRegisterResponse(provider_id=provider_id) def list_providers(self) -> BotProviderListResponse: """List all registered bot providers""" with self.lock: return BotProviderListResponse(providers=list(self.bot_providers.values())) async def list_bots(self) -> BotListResponse: """List all available bots from all registered providers""" bots: List[BotInfoModel] = [] providers: Dict[str, str] = {} # Update last_seen timestamps and fetch bots from each provider with self.lock: providers_copy = dict(self.bot_providers.items()) for provider_id, provider in providers_copy.items(): try: # Update last_seen timestamp with self.lock: if provider_id in self.bot_providers: self.bot_providers[provider_id].last_seen = time.time() # Make HTTP request to provider's /bots endpoint async with httpx.AsyncClient() as client: response = await client.get(f"{provider.base_url}/bots", timeout=5.0) if response.status_code == 200: # Use Pydantic model to validate the response bots_response = BotProviderBotsResponse.model_validate(response.json()) # Add each bot to the consolidated list for bot_info in bots_response.bots: bots.append(bot_info) providers[bot_info.name] = provider_id else: logger.warning(f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}") except Exception as e: logger.error(f"Error fetching bots from provider {provider.name}: {e}") continue return BotListResponse(bots=bots, providers=providers) async def request_bot_join(self, bot_name: str, request: BotJoinLobbyRequest, session_manager, lobby_manager) -> BotJoinLobbyResponse: """Request a bot to join a specific lobby""" # Find which provider has this bot and determine its media capability target_provider_id = request.provider_id bot_has_media = False if not target_provider_id: # Auto-discover provider for this bot with self.lock: providers_copy = dict(self.bot_providers.items()) for provider_id, provider in providers_copy.items(): try: async with httpx.AsyncClient() as client: response = await client.get(f"{provider.base_url}/bots", timeout=5.0) if response.status_code == 200: # Use Pydantic model to validate the response bots_response = BotProviderBotsResponse.model_validate(response.json()) # Look for the bot by name for bot_info in bots_response.bots: if bot_info.name == bot_name: target_provider_id = provider_id bot_has_media = bot_info.has_media break if target_provider_id: break except Exception: continue else: # Query the specified provider for bot media capability with self.lock: if target_provider_id in self.bot_providers: provider = self.bot_providers[target_provider_id] try: async with httpx.AsyncClient() as client: response = await client.get(f"{provider.base_url}/bots", timeout=5.0) if response.status_code == 200: # Use Pydantic model to validate the response bots_response = BotProviderBotsResponse.model_validate(response.json()) # Look for the bot by name for bot_info in bots_response.bots: if bot_info.name == bot_name: bot_has_media = bot_info.has_media break except Exception: # Default to no media if we can't query pass if not target_provider_id: raise ValueError("Bot or provider not found") with self.lock: if target_provider_id not in self.bot_providers: raise ValueError("Provider not found") provider = self.bot_providers[target_provider_id] # Get the lobby to validate it exists lobby = lobby_manager.get_lobby(request.lobby_id) if not lobby: raise ValueError("Lobby not found") # Create a session for the bot bot_session_id = secrets.token_hex(16) # Create the Session object for the bot bot_session = session_manager.create_session(bot_session_id, is_bot=True, has_media=bot_has_media) logger.info(f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})") # Determine server URL for the bot to connect back to # Use the server's public URL or construct from request server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000") if server_base_url.endswith("/"): server_base_url = server_base_url[:-1] bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}" # Get public URL prefix from environment public_url = os.getenv("PUBLIC_URL_PREFIX", "/ai-voicebot") # Prepare the join request for the bot provider bot_join_payload = BotJoinPayload( lobby_id=request.lobby_id, session_id=bot_session_id, nick=bot_nick, server_url=f"{server_base_url}{public_url}".rstrip("/"), insecure=True, # Accept self-signed certificates in development ) try: # Make request to bot provider async with httpx.AsyncClient() as client: response = await client.post( f"{provider.base_url}/bots/{bot_name}/join", json=bot_join_payload.model_dump(), timeout=10.0, ) if response.status_code == 200: # Use Pydantic model to parse and validate response try: join_response = BotProviderJoinResponse.model_validate(response.json()) run_id = join_response.run_id # Update bot session with run and provider information bot_session.setName(bot_nick) bot_session.bot_run_id = run_id bot_session.bot_provider_id = target_provider_id logger.info(f"Bot {bot_name} requested to join lobby {request.lobby_id}") return BotJoinLobbyResponse( status="requested", bot_name=bot_name, run_id=run_id, provider_id=target_provider_id, ) except ValidationError as e: logger.error(f"Invalid response from bot provider: {e}") raise ValueError(f"Bot provider returned invalid response: {str(e)}") else: logger.error(f"Bot provider returned error: HTTP {response.status_code}: {response.text}") raise ValueError(f"Bot provider error: {response.status_code}") except httpx.TimeoutException: raise ValueError("Bot provider timeout") except Exception as e: logger.error(f"Error requesting bot join: {e}") raise ValueError(f"Internal server error: {str(e)}") async def request_bot_leave(self, request: BotLeaveLobbyRequest, session_manager) -> BotLeaveLobbyResponse: """Request a bot to leave from all lobbies and disconnect""" # Find the bot session bot_session = session_manager.get_session(request.session_id) if not bot_session: raise ValueError("Bot session not found") if not bot_session.is_bot: raise ValueError("Session is not a bot") run_id = bot_session.bot_run_id provider_id = bot_session.bot_provider_id logger.info(f"Requesting bot {bot_session.getName()} to leave all lobbies") # Try to stop the bot at the provider level if we have the information if provider_id and run_id: with self.lock: if provider_id in self.bot_providers: provider = self.bot_providers[provider_id] try: async with httpx.AsyncClient() as client: response = await client.post( f"{provider.base_url}/bots/runs/{run_id}/stop", timeout=5.0, ) if response.status_code == 200: logger.info(f"Successfully requested bot provider to stop run {run_id}") else: logger.warning(f"Bot provider returned error when stopping: HTTP {response.status_code}") except Exception as e: logger.warning(f"Failed to request bot stop from provider: {e}") # Force disconnect the bot session from all lobbies lobbies_to_leave = bot_session.lobbies[:] for lobby in lobbies_to_leave: try: await bot_session.leave_lobby(lobby) except Exception as e: logger.warning(f"Error removing bot from lobby {lobby.name}: {e}") # Close WebSocket connection if it exists if bot_session.ws: try: await bot_session.ws.close() except Exception as e: logger.warning(f"Error closing bot WebSocket: {e}") bot_session.ws = None return BotLeaveLobbyResponse( status="disconnected", session_id=request.session_id, run_id=run_id, ) def get_provider(self, provider_id: str) -> Optional[BotProviderModel]: """Get a specific bot provider by ID""" with self.lock: return self.bot_providers.get(provider_id)