581 lines
24 KiB
Python
581 lines
24 KiB
Python
"""Bot Provider Management - Test fix for stale lobby IDs after restart"""
|
|
import os
|
|
import time
|
|
import uuid
|
|
import secrets
|
|
import threading
|
|
import httpx
|
|
import json
|
|
import asyncio
|
|
from typing import Dict, List, Optional
|
|
from pydantic import ValidationError
|
|
from logger import logger
|
|
|
|
# Import shared models
|
|
import sys
|
|
|
|
sys.path.append(
|
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
)
|
|
from shared.models import (
|
|
BotProviderModel,
|
|
BotProviderPublicModel,
|
|
BotProviderRegisterRequest,
|
|
BotProviderRegisterResponse,
|
|
BotProviderListResponse,
|
|
BotListResponse,
|
|
BotInfoModel,
|
|
BotJoinLobbyRequest,
|
|
BotJoinLobbyResponse,
|
|
BotLeaveLobbyRequest,
|
|
BotLeaveLobbyResponse,
|
|
BotProviderBotsResponse,
|
|
BotProviderJoinResponse,
|
|
BotJoinPayload,
|
|
BotInstanceModel,
|
|
)
|
|
from core.session_manager import SessionManager
|
|
from core.lobby_manager import LobbyManager
|
|
|
|
class BotProviderConfig:
    """Configuration for bot provider authentication.

    The raw allow-list is read once, when this class is defined, from the
    ``BOT_PROVIDER_KEYS`` environment variable.
    """

    # Comma-separated list of allowed provider keys.
    # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys).
    ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "")

    @classmethod
    def get_allowed_providers(cls) -> Dict[str, str]:
        """Parse ``ALLOWED_PROVIDERS`` into a key -> name mapping.

        Entries are separated by commas; surrounding whitespace is ignored.
        Each entry is either ``key`` or ``key:name`` (only the first colon
        splits, so names may themselves contain colons).

        Entries whose key is empty (for example ``":name"``) are skipped so a
        malformed value can never authorize an empty provider key. An empty
        name (for example ``"key:"``) falls back to the key.

        Returns:
            dict mapping provider_key -> provider_name; empty when nothing
            usable is configured.
        """
        providers: Dict[str, str] = {}
        for entry in cls.ALLOWED_PROVIDERS.split(","):
            key, _, name = entry.partition(":")
            key = key.strip()
            if not key:
                # Blank entry or missing key: never register "" as a valid key.
                continue
            providers[key] = name.strip() or key
        return providers
|
|
|
|
|
|
class BotManager:
|
|
"""Manages bot providers and bot lifecycle"""
|
|
|
|
def __init__(self):
    """Initialise in-memory state and restore persisted bot providers.

    Logs a warning when no provider keys are configured, then loads any
    providers previously written to ``self.bot_providers_file``.
    """
    # Persistence target and synchronisation primitives.
    self.bot_providers_file = "bot_providers.json"
    self.lock = threading.RLock()
    self._shutdown_event = asyncio.Event()
    self.cleanup_task: Optional["asyncio.Task[None]"] = None

    # Registries: provider_id -> BotProviderModel, bot_instance_id -> BotInstanceModel.
    self.bot_providers: Dict[str, BotProviderModel] = {}
    self.bot_instances: Dict[str, BotInstanceModel] = {}

    # An empty allow-list means registration is not restricted by key.
    if not BotProviderConfig.get_allowed_providers():
        logger.warning("Bot provider authentication disabled. Any provider can register.")

    # Restore providers persisted by a previous run.
    self._load_bot_providers()

    # Note: the cleanup task is not started here - it is started when needed.
|
|
|
|
def start_cleanup(self):
    """Create the periodic cleanup task unless one already exists.

    ``asyncio.create_task`` raises ``RuntimeError`` when no event loop is
    running; that case is logged at debug level and the task stays unset so a
    later call can try again.
    """
    if self.cleanup_task is not None:
        return

    try:
        self.cleanup_task = asyncio.create_task(self._periodic_cleanup())
    except RuntimeError:
        # No event loop running yet, cleanup will be started later
        logger.debug("No event loop available for bot provider cleanup task")
    else:
        logger.debug("Bot provider cleanup task started")
|
|
|
|
async def stop_cleanup(self):
    """Signal shutdown, cancel the cleanup task and wait for it to finish."""
    self._shutdown_event.set()

    task = self.cleanup_task
    if not task:
        return

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected outcome of the cancel() issued just above.
        pass
|
|
|
|
async def _periodic_cleanup(self):
    """Periodically remove bot providers that have not been seen recently.

    Every ``cleanup_interval`` seconds, providers whose ``last_seen`` is more
    than ``stale_threshold`` seconds old are removed from
    ``self.bot_providers`` and the remaining providers are saved to disk.
    The loop ends when the shutdown event is set or the task is cancelled;
    any other exception is logged and the loop continues.
    """
    cleanup_interval = 300  # 5 minutes
    stale_threshold = 900  # 15 minutes

    while not self._shutdown_event.is_set():
        try:
            await asyncio.sleep(cleanup_interval)

            now = time.time()

            # Detect and delete inside ONE critical section. Doing the check
            # and the deletion under separate lock acquisitions would let a
            # provider that was refreshed or re-registered in between be
            # removed by mistake.
            with self.lock:
                providers_to_remove = [
                    provider_id
                    for provider_id, provider in self.bot_providers.items()
                    if now - provider.last_seen > stale_threshold
                ]
                for provider_id in providers_to_remove:
                    provider = self.bot_providers.pop(provider_id)
                    logger.info(f"Marking stale bot provider for removal: {provider.name} (ID: {provider_id}, last_seen: {now - provider.last_seen:.1f}s ago)")

            if providers_to_remove:
                self._save_bot_providers()
                logger.info(f"Cleaned up {len(providers_to_remove)} stale bot providers")

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Error in bot provider cleanup: {e}")
|
|
|
|
def _save_bot_providers(self):
    """Save bot providers to disk as JSON.

    The data is first written to a sibling temporary file and then moved over
    ``self.bot_providers_file`` with ``os.replace``. Opening the real file in
    ``'w'`` mode would truncate it immediately, so a failure part-way through
    serialization would leave a corrupt file behind for the next load.

    Failures are logged and not raised.
    """
    tmp_path = f"{self.bot_providers_file}.tmp"
    try:
        # The lock is held for the whole write so concurrent saves cannot
        # interleave on the shared temporary file.
        with self.lock:
            providers_data = {
                provider_id: provider.model_dump()
                for provider_id, provider in self.bot_providers.items()
            }

            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(providers_data, f, indent=2)
            os.replace(tmp_path, self.bot_providers_file)

        logger.debug(f"Saved {len(providers_data)} bot providers to {self.bot_providers_file}")
    except Exception as e:
        logger.error(f"Failed to save bot providers: {e}")
        # Best-effort removal of a partially written temporary file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
|
|
|
def _load_bot_providers(self):
    """Populate ``self.bot_providers`` from the JSON file on disk.

    A missing file is not an error. Entries that fail model validation are
    skipped with a warning; any other failure is logged and swallowed.
    """
    try:
        file_missing = not os.path.exists(self.bot_providers_file)
        if file_missing:
            logger.debug(f"No bot providers file found at {self.bot_providers_file}")
            return

        with open(self.bot_providers_file, 'r') as f:
            raw_providers = json.load(f)

        with self.lock:
            for provider_id, raw_entry in raw_providers.items():
                try:
                    # Validate each entry on its own so one bad record does not
                    # prevent the others from loading.
                    self.bot_providers[provider_id] = BotProviderModel.model_validate(raw_entry)
                except Exception as e:
                    logger.warning(f"Failed to load bot provider {provider_id}: {e}")

        logger.info(f"Loaded {len(self.bot_providers)} bot providers from {self.bot_providers_file}")
    except Exception as e:
        logger.error(f"Failed to load bot providers: {e}")
|
|
|
|
async def register_provider(self, request: BotProviderRegisterRequest) -> BotProviderRegisterResponse:
    """Register (or re-register) a bot provider.

    When provider keys are configured, ``request.provider_key`` must be one
    of them. ``request.provider_id`` is always required. An existing provider
    with the same ID is replaced, which handles provider restarts.

    The provider key is a credential and is deliberately never written to the
    logs.

    Args:
        request: registration details supplied by the provider.

    Returns:
        BotProviderRegisterResponse carrying the registered ``provider_id``.

    Raises:
        ValueError: if the provider key is not allowed or ``provider_id`` is
            missing.
    """
    # Check if provider authentication is enabled
    allowed_providers = BotProviderConfig.get_allowed_providers()
    if allowed_providers and request.provider_key not in allowed_providers:
        # Do not log the presented key: a near-miss of a valid key would leak.
        logger.warning(
            f"Rejected bot provider registration with invalid key (provider_id: {request.provider_id}, name: {request.name})"
        )
        raise ValueError("Invalid provider key. Bot provider is not authorized to register.")

    # Require provider_id for static registration
    if not request.provider_id:
        logger.warning("Rejected bot provider registration without provider_id")
        raise ValueError("provider_id is required for bot provider registration.")

    provider_id = request.provider_id
    now = time.time()

    # Build the new model BEFORE touching the registry, so a validation
    # failure cannot leave the previously registered provider removed.
    provider = BotProviderModel(
        provider_id=provider_id,
        base_url=request.base_url.rstrip("/"),
        name=request.name,
        description=request.description,
        provider_key=request.provider_key,
        registered_at=now,
        last_seen=now,
    )

    # Swap old for new in one critical section so other threads never observe
    # the provider as missing during re-registration.
    with self.lock:
        if provider_id in self.bot_providers:
            logger.info(
                f"Removing existing provider with ID: {provider_id} for re-registration"
            )
            # Delete first so the re-registered provider moves to the end of
            # the dict, as it did before this change.
            del self.bot_providers[provider_id]
        self.bot_providers[provider_id] = provider

    # Save to disk
    self._save_bot_providers()

    # Start cleanup task if not already running
    self.start_cleanup()

    logger.info(f"Registered bot provider: {request.name} at {request.base_url} (ID: {provider_id})")
    return BotProviderRegisterResponse(provider_id=provider_id)
|
|
|
|
def list_providers(self) -> BotProviderListResponse:
    """Return the public view of every registered bot provider.

    Only the fields copied below are exposed; ``base_url`` and
    ``provider_key`` are not included.
    """
    with self.lock:
        # Build the public projection of each provider while holding the lock.
        public_providers: List[BotProviderPublicModel] = [
            BotProviderPublicModel(
                provider_id=provider.provider_id,
                name=provider.name,
                description=provider.description,
                registered_at=provider.registered_at,
                last_seen=provider.last_seen
            )
            for provider in self.bot_providers.values()
        ]
        return BotProviderListResponse(providers=public_providers)
|
|
|
|
async def list_bots(self) -> BotListResponse:
    """Collect the bots advertised by every registered provider.

    Each provider's ``last_seen`` is refreshed before its ``/bots`` endpoint
    is queried. Providers that fail or return a non-200 status are logged and
    skipped.

    Returns:
        BotListResponse with all bots and a bot-name -> provider_id mapping
        (a later provider overwrites an earlier one for the same bot name).
    """
    collected: List[BotInfoModel] = []
    owner_by_bot_name: Dict[str, str] = {}

    # Work from a snapshot so the lock is not held during network calls.
    with self.lock:
        snapshot = dict(self.bot_providers)

    for provider_id, provider in snapshot.items():
        try:
            # Refresh last_seen if the provider is still registered.
            with self.lock:
                if provider_id in self.bot_providers:
                    self.bot_providers[provider_id].last_seen = time.time()

            # Ask the provider for its bots.
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{provider.base_url}/bots", timeout=5.0)

            if response.status_code != 200:
                logger.warning(f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}")
                continue

            # Validate the payload, then merge it into the combined result.
            listing = BotProviderBotsResponse.model_validate(response.json())
            for bot_info in listing.bots:
                collected.append(bot_info)
                owner_by_bot_name[bot_info.name] = provider_id
        except Exception as e:
            logger.error(f"Error fetching bots from provider {provider.name}: {e}")

    return BotListResponse(bots=collected, providers=owner_by_bot_name)
|
|
|
|
async def get_provider_bots(self, provider_id: str) -> BotProviderBotsResponse:
    """Fetch the bot list of one registered provider.

    Args:
        provider_id: ID of the provider to query.

    Returns:
        The provider's validated ``/bots`` response, or an empty
        ``BotProviderBotsResponse`` if the request fails or returns a non-200
        status.

    Raises:
        ValueError: if no provider with ``provider_id`` is registered.
    """
    provider = self.get_provider(provider_id)
    if not provider:
        raise ValueError(f"Provider {provider_id} not found")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{provider.base_url}/bots", timeout=5.0)

        if response.status_code == 200:
            return BotProviderBotsResponse.model_validate(response.json())

        logger.warning(
            f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}"
        )
    except Exception as e:
        logger.error(f"Error fetching bots from provider {provider.name}: {e}")

    # Both failure paths fall through to an empty result.
    return BotProviderBotsResponse(bots=[])
|
|
|
|
async def request_bot_join(
    self,
    bot_name: str,
    request: BotJoinLobbyRequest,
    session_manager: SessionManager,
    lobby_manager: LobbyManager,
) -> BotJoinLobbyResponse:
    """Request a bot to join a specific lobby.

    Steps:
      1. Resolve the provider: use ``request.provider_id`` when given,
         otherwise query each registered provider's ``/bots`` endpoint until
         one advertises ``bot_name``. The same query yields the bot's
         ``has_media`` flag (defaults to False when it cannot be determined).
      2. Check that the lobby exists and create a session for the bot.
      3. POST the join payload to the provider and record the bot instance.

    No network request is made while ``self.lock`` is held.

    Args:
        bot_name: name of the bot as advertised by its provider.
        request: join parameters (lobby_id, optional provider_id and nick).
        session_manager: used to create the bot's session.
        lobby_manager: used to check that the lobby exists.

    Returns:
        BotJoinLobbyResponse with status "requested".

    Raises:
        ValueError: if the bot, provider or lobby cannot be found, the
            provider times out, returns a non-200 status or an invalid body,
            or any other error occurs while contacting it.
    """
    # Find which provider has this bot and determine its media capability
    target_provider_id = request.provider_id
    bot_has_media = False

    # Snapshot the providers to query under the lock; the HTTP calls below
    # then run without it, so other threads are not blocked on network I/O.
    with self.lock:
        if target_provider_id:
            # Explicit provider: only that one is queried (for has_media).
            known = self.bot_providers.get(target_provider_id)
            candidates = [(target_provider_id, known)] if known is not None else []
        else:
            # Auto-discover provider for this bot
            candidates = list(self.bot_providers.items())

    for provider_id, provider in candidates:
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{provider.base_url}/bots", timeout=5.0
                )
            if response.status_code != 200:
                continue

            # Use Pydantic model to validate the response
            bots_response = BotProviderBotsResponse.model_validate(
                response.json()
            )
            # Look for the bot by name
            matching_bot = next(
                (info for info in bots_response.bots if info.name == bot_name),
                None,
            )
            if matching_bot is not None:
                target_provider_id = provider_id
                bot_has_media = matching_bot.has_media
                break
        except Exception as e:
            # Best effort: an unreachable provider must not abort the join,
            # but the failure should at least be visible in the logs.
            logger.warning(f"Could not query bots from provider {provider_id}: {e}")
            continue

    if not target_provider_id:
        raise ValueError("Bot or provider not found")

    with self.lock:
        if target_provider_id not in self.bot_providers:
            raise ValueError("Provider not found")
        provider = self.bot_providers[target_provider_id]

    # Get the lobby to validate it exists
    lobby = lobby_manager.get_lobby(request.lobby_id)
    if not lobby:
        raise ValueError("Lobby not found")

    # Create a session for the bot
    bot_session_id = secrets.token_hex(16)
    bot_instance_id = str(uuid.uuid4())

    # Create the Session object for the bot
    bot_session = session_manager.get_or_create_session(
        bot_session_id, bot_instance_id=bot_instance_id, has_media=bot_has_media
    )
    logger.info(
        f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})"
    )

    # Determine server URL for the bot to connect back to
    server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000")
    if server_base_url.endswith("/"):
        server_base_url = server_base_url[:-1]

    bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}"

    # Get public URL prefix from environment
    public_url = os.getenv("PUBLIC_URL_PREFIX", "/ai-voicebot")

    # Prepare the join request for the bot provider
    bot_join_payload = BotJoinPayload(
        lobby_id=request.lobby_id,
        session_id=bot_session_id,
        nick=bot_nick,
        server_url=f"{server_base_url}{public_url}".rstrip("/"),
        insecure=True,  # Accept self-signed certificates in development
    )

    try:
        # Make request to bot provider
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{provider.base_url}/bots/{bot_name}/join",
                json=bot_join_payload.model_dump(),
                timeout=10.0,
            )

        if response.status_code != 200:
            logger.error(
                f"Bot provider returned error: HTTP {response.status_code}: {response.text}"
            )
            raise ValueError(f"Bot provider error: {response.status_code}")

        # Use Pydantic model to parse and validate response
        try:
            join_response = BotProviderJoinResponse.model_validate(
                response.json()
            )
            run_id = join_response.run_id

            # Update bot session with run and provider information
            bot_session.setName(bot_nick)
            bot_session.bot_run_id = run_id
            bot_session.bot_provider_id = target_provider_id

            # Track the bot instance
            bot_instance = BotInstanceModel(
                bot_instance_id=bot_instance_id,
                bot_name=bot_name,
                nick=bot_nick,
                lobby_id=request.lobby_id,
                session_id=bot_session_id,
                provider_id=target_provider_id,
                run_id=run_id,
                has_media=bot_has_media,
                created_at=time.time(),
            )

            # Set the bot_instance_id on the session as well
            bot_session.bot_instance_id = bot_instance_id

            with self.lock:
                self.bot_instances[bot_instance_id] = bot_instance

            logger.info(
                f"Bot {bot_name} requested to join lobby {request.lobby_id} with instance ID {bot_instance_id}"
            )

            return BotJoinLobbyResponse(
                status="requested",
                bot_instance_id=bot_instance_id,
                bot_name=bot_name,
                run_id=run_id,
                provider_id=target_provider_id,
                session_id=bot_session_id,
            )
        except (ValidationError, json.JSONDecodeError) as e:
            # A body that is not JSON is just as invalid as one that fails
            # schema validation.
            logger.error(f"Invalid response from bot provider: {e}")
            raise ValueError(
                f"Bot provider returned invalid response: {str(e)}"
            ) from e

    except httpx.TimeoutException:
        raise ValueError("Bot provider timeout")
    except ValueError:
        # Caller-facing errors raised (and already logged) above. Without this
        # clause the blanket handler below would re-catch them and relabel
        # them as "Internal server error".
        raise
    except Exception as e:
        logger.error(f"Error requesting bot join: {e}")
        raise ValueError(f"Internal server error: {str(e)}")
|
|
|
|
async def request_bot_leave(
    self, request: BotLeaveLobbyRequest, session_manager: SessionManager
) -> BotLeaveLobbyResponse:
    """Request a bot to leave all lobbies and disconnect.

    The provider is asked to stop the bot's run (best effort), the bot's
    session is removed from every lobby it is in, its WebSocket is closed,
    and the instance is dropped from tracking. Failures in the provider call
    or the disconnect steps are logged as warnings and do not abort the
    operation.

    The provider request is made without holding ``self.lock``.

    Args:
        request: identifies the bot instance to remove.
        session_manager: used to look up the bot's session.

    Returns:
        BotLeaveLobbyResponse with status "disconnected".

    Raises:
        ValueError: if the bot instance or its session is unknown, or the
            session is not a bot session.
    """
    # Find the bot instance
    with self.lock:
        if request.bot_instance_id not in self.bot_instances:
            raise ValueError("Bot instance not found")
        bot_instance = self.bot_instances[request.bot_instance_id]

    # Find the bot session
    bot_session = session_manager.get_session(bot_instance.session_id)
    if not bot_session:
        raise ValueError("Bot session not found")

    if not bot_session.bot_instance_id:
        raise ValueError("Session is not a bot")

    logger.info(
        f"Requesting bot instance {bot_instance.bot_instance_id} to leave all lobbies"
    )

    # Try to stop the bot at the provider level. Only the registry lookup is
    # done under the lock; awaiting the HTTP call while holding a threading
    # lock would block every other thread for up to the request timeout.
    with self.lock:
        provider = self.bot_providers.get(bot_instance.provider_id)

    if provider is None:
        logger.warning(
            f"Bot provider {bot_instance.provider_id} is not registered; skipping provider stop request"
        )
    else:
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    f"{provider.base_url}/bots/runs/{bot_instance.run_id}/stop",
                    timeout=5.0,
                )
            if response.status_code == 200:
                logger.info(
                    f"Successfully requested bot provider to stop run {bot_instance.run_id}"
                )
            else:
                logger.warning(
                    f"Bot provider returned error when stopping: HTTP {response.status_code}"
                )
        except Exception as e:
            logger.warning(f"Failed to request bot stop from provider: {e}")

    # Force disconnect the bot session from all lobbies
    try:
        # Iterate over a copy of the lobby list while leaving each lobby.
        lobbies_to_leave = bot_session.lobbies[:]

        for lobby in lobbies_to_leave:
            try:
                await bot_session.leave_lobby(lobby)
            except Exception as e:
                logger.warning(f"Error removing bot from lobby {lobby.name}: {e}")

        # Close WebSocket connection if it exists
        if bot_session.ws:
            try:
                await bot_session.ws.close()
            except Exception as e:
                logger.warning(f"Error closing bot WebSocket: {e}")
            bot_session.ws = None
    except Exception as e:
        logger.warning(f"Error disconnecting bot session: {e}")

    # Remove bot instance from tracking
    with self.lock:
        self.bot_instances.pop(request.bot_instance_id, None)

    return BotLeaveLobbyResponse(
        status="disconnected",
        bot_instance_id=request.bot_instance_id,
        session_id=bot_instance.session_id,
        run_id=bot_instance.run_id,
    )
|
|
|
|
async def get_bot_instance(self, bot_instance_id: str) -> BotInstanceModel:
    """Return the tracked bot instance with the given ID.

    Raises:
        ValueError: if no such instance is tracked.
    """
    with self.lock:
        if bot_instance_id in self.bot_instances:
            return self.bot_instances[bot_instance_id]

    raise ValueError("Bot instance not found")
|
|
|
|
def get_bot_instance_id_by_session_id(self, session_id: str) -> Optional[str]:
    """Return the ID of the first tracked bot instance using ``session_id``.

    Returns None when no tracked instance has that session ID.
    """
    with self.lock:
        return next(
            (
                instance_id
                for instance_id, instance in self.bot_instances.items()
                if instance.session_id == session_id
            ),
            None,
        )
|
|
|
|
def get_provider(self, provider_id: str) -> Optional[BotProviderModel]:
    """Look up a registered bot provider by ID; None if it is not registered."""
    with self.lock:
        found = self.bot_providers.get(provider_id)
    return found
|