463 lines
19 KiB
Python
463 lines
19 KiB
Python
"""Bot Provider Management"""
|
|
import os
|
|
import time
|
|
import uuid
|
|
import secrets
|
|
import threading
|
|
import httpx
|
|
from typing import Dict, List, Optional
|
|
from pydantic import ValidationError
|
|
from logger import logger
|
|
|
|
# Import shared models
|
|
import sys
|
|
|
|
sys.path.append(
|
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
)
|
|
from shared.models import (
|
|
BotProviderModel,
|
|
BotProviderRegisterRequest,
|
|
BotProviderRegisterResponse,
|
|
BotProviderListResponse,
|
|
BotListResponse,
|
|
BotInfoModel,
|
|
BotJoinLobbyRequest,
|
|
BotJoinLobbyResponse,
|
|
BotLeaveLobbyRequest,
|
|
BotLeaveLobbyResponse,
|
|
BotProviderBotsResponse,
|
|
BotProviderJoinResponse,
|
|
BotJoinPayload,
|
|
BotInstanceModel,
|
|
)
|
|
from core.session_manager import SessionManager
|
|
from core.lobby_manager import LobbyManager
|
|
|
|
class BotProviderConfig:
    """Configuration class for bot provider management.

    The allow-list is read from the ``BOT_PROVIDER_KEYS`` environment variable
    when this class body is executed and stored in ``ALLOWED_PROVIDERS``.
    """

    # Comma-separated list of allowed provider keys
    # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys)
    ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "")

    @classmethod
    def get_allowed_providers(cls) -> Dict[str, str]:
        """Parse allowed providers from ``ALLOWED_PROVIDERS``.

        Whitespace around entries, keys and names is ignored. Only the first
        ``:`` in an entry separates key from name, so names may contain ``:``.

        Returns:
            dict mapping provider_key -> provider_name. Empty when no usable
            entry is configured.
        """
        providers: Dict[str, str] = {}
        for entry in cls.ALLOWED_PROVIDERS.split(","):
            key, _, name = entry.partition(":")
            key = key.strip()
            if not key:
                # Blank entries and entries such as ":name" must not register
                # the empty string as an accepted provider key.
                continue
            # An entry without a name (or with an empty one, "key:") is
            # displayed under its key.
            providers[key] = name.strip() or key
        return providers
|
|
|
|
|
|
class BotManager:
    """Manages bot providers and bot lifecycle.

    Registered providers and running bot instances are tracked in two
    dictionaries guarded by ``self.lock``. The lock is only held for
    dictionary access; it is released before any awaited network call so a
    slow provider cannot stall other users of the lock.
    """

    def __init__(self):
        # provider_id -> BotProviderModel
        self.bot_providers: Dict[str, BotProviderModel] = {}
        # bot_instance_id -> BotInstanceModel
        self.bot_instances: Dict[str, BotInstanceModel] = {}
        self.lock = threading.RLock()

        # Check if provider authentication is enabled
        if not BotProviderConfig.get_allowed_providers():
            logger.warning("Bot provider authentication disabled. Any provider can register.")

    async def _fetch_provider_bots(
        self, provider: BotProviderModel
    ) -> Optional[BotProviderBotsResponse]:
        """Fetch and validate the bot list from a provider's ``/bots`` endpoint.

        Returns:
            The validated response, or ``None`` when the provider answered
            with a non-200 status or the request/validation failed. Every
            failure is logged here, so callers only need to handle ``None``.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
            if response.status_code != 200:
                logger.warning(
                    f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}"
                )
                return None
            # Use Pydantic model to validate the response
            return BotProviderBotsResponse.model_validate(response.json())
        except Exception as e:
            logger.error(f"Error fetching bots from provider {provider.name}: {e}")
            return None

    async def _find_bot(
        self, provider: BotProviderModel, bot_name: str
    ) -> Optional[BotInfoModel]:
        """Return the provider's bot entry named ``bot_name``, or ``None``."""
        bots_response = await self._fetch_provider_bots(provider)
        if bots_response is None:
            return None
        return next((bot for bot in bots_response.bots if bot.name == bot_name), None)

    async def register_provider(
        self, request: BotProviderRegisterRequest
    ) -> BotProviderRegisterResponse:
        """Register a new bot provider with authentication.

        When an allow-list is configured, the request's provider key must be
        on it, and any provider previously registered with the same key is
        replaced by the new registration.

        Returns:
            Response carrying the newly generated provider id.

        Raises:
            ValueError: authentication is enabled and the key is not allowed.
        """
        allowed_providers = BotProviderConfig.get_allowed_providers()
        if allowed_providers:
            # Authentication is enabled - validate provider key
            if request.provider_key not in allowed_providers:
                # The provider key is a credential: never write it to the log.
                logger.warning(
                    f"Rejected bot provider registration with invalid key from: "
                    f"{request.name} ({request.base_url})"
                )
                raise ValueError(
                    "Invalid provider key. Bot provider is not authorized to register."
                )

            # Check if there's already an active provider with this key and remove it
            # NOTE(review): eviction is assumed to apply only when
            # authentication is enabled - confirm against intended behavior.
            with self.lock:
                stale_ids: List[str] = [
                    existing_id
                    for existing_id, existing in self.bot_providers.items()
                    if existing.provider_key == request.provider_key
                ]
                for stale_id in stale_ids:
                    stale = self.bot_providers.pop(stale_id)
                    logger.info(
                        f"Removing stale bot provider: {stale.name} (ID: {stale_id})"
                    )

        provider_id = str(uuid.uuid4())
        now = time.time()

        provider = BotProviderModel(
            provider_id=provider_id,
            base_url=request.base_url.rstrip("/"),
            name=request.name,
            description=request.description,
            provider_key=request.provider_key,
            registered_at=now,
            last_seen=now,
        )

        with self.lock:
            self.bot_providers[provider_id] = provider

        logger.info(f"Registered bot provider: {request.name} at {request.base_url}")
        return BotProviderRegisterResponse(provider_id=provider_id)

    def list_providers(self) -> BotProviderListResponse:
        """List all registered bot providers."""
        with self.lock:
            return BotProviderListResponse(providers=list(self.bot_providers.values()))

    async def list_bots(self) -> BotListResponse:
        """List all available bots from all registered providers.

        Providers are queried one after another. A provider that cannot be
        reached or returns an invalid answer is skipped.

        Returns:
            The consolidated bot list plus a mapping of bot name -> provider
            id. If two providers expose the same bot name, the provider
            queried last wins in that mapping.
        """
        bots: List[BotInfoModel] = []
        providers: Dict[str, str] = {}

        # Work on a snapshot so the lock is not held while talking to providers.
        with self.lock:
            providers_copy = dict(self.bot_providers)

        for provider_id, provider in providers_copy.items():
            # Update last_seen timestamp (the provider may have been removed
            # since the snapshot was taken).
            with self.lock:
                current = self.bot_providers.get(provider_id)
                if current is not None:
                    current.last_seen = time.time()

            bots_response = await self._fetch_provider_bots(provider)
            if bots_response is None:
                continue

            # Add each bot to the consolidated list
            for bot_info in bots_response.bots:
                bots.append(bot_info)
                providers[bot_info.name] = provider_id

        return BotListResponse(bots=bots, providers=providers)

    async def get_provider_bots(self, provider_id: str) -> BotProviderBotsResponse:
        """Get bots from a specific provider.

        Returns:
            The provider's bot list, or an empty list when the provider could
            not be queried successfully.

        Raises:
            ValueError: no provider is registered under ``provider_id``.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            raise ValueError(f"Provider {provider_id} not found")

        bots_response = await self._fetch_provider_bots(provider)
        if bots_response is None:
            return BotProviderBotsResponse(bots=[])
        return bots_response

    async def request_bot_join(
        self,
        bot_name: str,
        request: BotJoinLobbyRequest,
        session_manager: SessionManager,
        lobby_manager: LobbyManager,
    ) -> BotJoinLobbyResponse:
        """Request a bot to join a specific lobby.

        If the request names no provider, registered providers are queried in
        turn until one lists ``bot_name``. A session is then created for the
        bot and the provider is asked to start it.

        Returns:
            Response with status ``"requested"`` and the identifiers of the
            new bot instance, run, provider and session.

        Raises:
            ValueError: bot/provider/lobby not found, the provider timed out,
                answered with a non-200 status or an invalid body, or any
                other failure while issuing the join request.
        """
        # Find which provider has this bot and determine its media capability
        target_provider_id = request.provider_id
        bot_has_media = False

        if not target_provider_id:
            # Auto-discover provider for this bot
            with self.lock:
                candidates = dict(self.bot_providers)

            for candidate_id, candidate in candidates.items():
                bot_info = await self._find_bot(candidate, bot_name)
                if bot_info is not None:
                    target_provider_id = candidate_id
                    bot_has_media = bot_info.has_media
                    break
        else:
            # Query the specified provider for bot media capability.
            # Only the dictionary lookup happens under the lock.
            with self.lock:
                specified = self.bot_providers.get(target_provider_id)

            if specified is not None:
                bot_info = await self._find_bot(specified, bot_name)
                # Default to no media if we can't query or the bot is not listed
                if bot_info is not None:
                    bot_has_media = bot_info.has_media

        if not target_provider_id:
            raise ValueError("Bot or provider not found")

        with self.lock:
            provider = self.bot_providers.get(target_provider_id)
        if provider is None:
            raise ValueError("Provider not found")

        # Get the lobby to validate it exists
        lobby = lobby_manager.get_lobby(request.lobby_id)
        if not lobby:
            raise ValueError("Lobby not found")

        # Create a session for the bot
        bot_session_id = secrets.token_hex(16)
        bot_instance_id = str(uuid.uuid4())

        # NOTE(review): nothing below discards this session when the join
        # fails - confirm whether SessionManager should clean it up.
        bot_session = session_manager.get_or_create_session(
            bot_session_id, bot_instance_id=bot_instance_id, has_media=bot_has_media
        )
        logger.info(
            f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})"
        )

        # Determine server URL for the bot to connect back to
        server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000").rstrip("/")

        bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}"

        # Get public URL prefix from environment
        public_url = os.getenv("PUBLIC_URL_PREFIX", "/ai-voicebot")

        # Prepare the join request for the bot provider
        bot_join_payload = BotJoinPayload(
            lobby_id=request.lobby_id,
            session_id=bot_session_id,
            nick=bot_nick,
            server_url=f"{server_base_url}{public_url}".rstrip("/"),
            insecure=True,  # Accept self-signed certificates in development
        )

        # Failures are collected as a message and raised once after the try
        # block, so that errors produced deliberately below are not caught
        # again by the generic handler and re-wrapped as
        # "Internal server error: ...".
        failure: str
        try:
            # Make request to bot provider
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    f"{provider.base_url}/bots/{bot_name}/join",
                    json=bot_join_payload.model_dump(),
                    timeout=10.0,
                )

            if response.status_code == 200:
                try:
                    # Use Pydantic model to parse and validate response
                    join_response = BotProviderJoinResponse.model_validate(response.json())
                    run_id = join_response.run_id

                    # Update bot session with run and provider information
                    bot_session.setName(bot_nick)
                    bot_session.bot_run_id = run_id
                    bot_session.bot_provider_id = target_provider_id

                    # Track the bot instance
                    bot_instance = BotInstanceModel(
                        bot_instance_id=bot_instance_id,
                        bot_name=bot_name,
                        nick=bot_nick,
                        lobby_id=request.lobby_id,
                        session_id=bot_session_id,
                        provider_id=target_provider_id,
                        run_id=run_id,
                        has_media=bot_has_media,
                        created_at=time.time(),
                    )

                    # Set the bot_instance_id on the session as well
                    bot_session.bot_instance_id = bot_instance_id

                    with self.lock:
                        self.bot_instances[bot_instance_id] = bot_instance

                    logger.info(
                        f"Bot {bot_name} requested to join lobby {request.lobby_id} with instance ID {bot_instance_id}"
                    )

                    return BotJoinLobbyResponse(
                        status="requested",
                        bot_instance_id=bot_instance_id,
                        bot_name=bot_name,
                        run_id=run_id,
                        provider_id=target_provider_id,
                        session_id=bot_session_id,
                    )
                except ValidationError as e:
                    logger.error(f"Invalid response from bot provider: {e}")
                    failure = f"Bot provider returned invalid response: {str(e)}"
            else:
                logger.error(
                    f"Bot provider returned error: HTTP {response.status_code}: {response.text}"
                )
                failure = f"Bot provider error: {response.status_code}"
        except httpx.TimeoutException:
            failure = "Bot provider timeout"
        except Exception as e:
            logger.error(f"Error requesting bot join: {e}")
            failure = f"Internal server error: {str(e)}"

        raise ValueError(failure)

    async def request_bot_leave(
        self, request: BotLeaveLobbyRequest, session_manager: SessionManager
    ) -> BotLeaveLobbyResponse:
        """Request a bot to leave from all lobbies and disconnect.

        Stopping the run at the provider and disconnecting the session are
        both best-effort: failures are logged and the instance is removed from
        tracking regardless.

        Returns:
            Response with status ``"disconnected"`` and the identifiers of
            the removed bot instance.

        Raises:
            ValueError: the bot instance or its session is unknown, or the
                session is not a bot session.
        """
        # Find the bot instance
        with self.lock:
            bot_instance = self.bot_instances.get(request.bot_instance_id)
        if bot_instance is None:
            raise ValueError("Bot instance not found")

        # Find the bot session
        bot_session = session_manager.get_session(bot_instance.session_id)
        if not bot_session:
            raise ValueError("Bot session not found")

        if not bot_session.bot_instance_id:
            raise ValueError("Session is not a bot")

        logger.info(
            f"Requesting bot instance {bot_instance.bot_instance_id} to leave all lobbies"
        )

        # Try to stop the bot at the provider level. The provider is looked up
        # under the lock, but the HTTP call is made after releasing it.
        with self.lock:
            provider = self.bot_providers.get(bot_instance.provider_id)

        if provider is not None:
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.post(
                        f"{provider.base_url}/bots/runs/{bot_instance.run_id}/stop",
                        timeout=5.0,
                    )
                if response.status_code == 200:
                    logger.info(
                        f"Successfully requested bot provider to stop run {bot_instance.run_id}"
                    )
                else:
                    logger.warning(
                        f"Bot provider returned error when stopping: HTTP {response.status_code}"
                    )
            except Exception as e:
                logger.warning(f"Failed to request bot stop from provider: {e}")

        # Force disconnect the bot session from all lobbies
        try:
            # Iterate over a copy: leaving a lobby may mutate the session's list.
            for lobby in bot_session.lobbies[:]:
                try:
                    await bot_session.leave_lobby(lobby)
                except Exception as e:
                    logger.warning(f"Error removing bot from lobby {lobby.name}: {e}")

            # Close WebSocket connection if it exists
            if bot_session.ws:
                try:
                    await bot_session.ws.close()
                except Exception as e:
                    logger.warning(f"Error closing bot WebSocket: {e}")
                bot_session.ws = None
        except Exception as e:
            logger.warning(f"Error disconnecting bot session: {e}")

        # Remove bot instance from tracking
        with self.lock:
            self.bot_instances.pop(request.bot_instance_id, None)

        return BotLeaveLobbyResponse(
            status="disconnected",
            bot_instance_id=request.bot_instance_id,
            session_id=bot_instance.session_id,
            run_id=bot_instance.run_id,
        )

    async def get_bot_instance(self, bot_instance_id: str) -> BotInstanceModel:
        """Get information about a specific bot instance.

        Raises:
            ValueError: no instance is tracked under ``bot_instance_id``.
        """
        with self.lock:
            bot_instance = self.bot_instances.get(bot_instance_id)
        if bot_instance is None:
            raise ValueError("Bot instance not found")
        return bot_instance

    def get_bot_instance_id_by_session_id(self, session_id: str) -> Optional[str]:
        """Get bot_instance_id by session_id, or ``None`` if no instance matches."""
        with self.lock:
            for bot_instance_id, bot_instance in self.bot_instances.items():
                if bot_instance.session_id == session_id:
                    return bot_instance_id
        return None

    def get_provider(self, provider_id: str) -> Optional[BotProviderModel]:
        """Get a specific bot provider by ID, or ``None`` if it is not registered."""
        with self.lock:
            return self.bot_providers.get(provider_id)
|