ai-voicebot/server/core/bot_manager.py
2025-09-08 13:02:57 -07:00

581 lines
24 KiB
Python

"""Bot Provider Management - Test fix for stale lobby IDs after restart"""
import os
import time
import uuid
import secrets
import threading
import httpx
import json
import asyncio
from typing import Dict, List, Optional
from pydantic import ValidationError
from shared.logger import logger
# Import shared models
import sys
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from shared.models import (
BotProviderModel,
BotProviderPublicModel,
BotProviderRegisterRequest,
BotProviderRegisterResponse,
BotProviderListResponse,
BotListResponse,
BotInfoModel,
BotJoinLobbyRequest,
BotJoinLobbyResponse,
BotLeaveLobbyRequest,
BotLeaveLobbyResponse,
BotProviderBotsResponse,
BotProviderJoinResponse,
BotJoinPayload,
BotInstanceModel,
)
from core.session_manager import SessionManager
from core.lobby_manager import LobbyManager
class BotProviderConfig:
    """Configuration for bot provider authentication.

    The allow-list is read from the ``BOT_PROVIDER_KEYS`` environment variable
    once, when this module is imported.
    """

    # Comma-separated list of allowed provider keys.
    # Format: "key1:name1,key2:name2" or just "key1,key2" (names default to keys).
    # An empty/unset value disables provider authentication.
    ALLOWED_PROVIDERS = os.getenv("BOT_PROVIDER_KEYS", "")

    @classmethod
    def get_allowed_providers(cls) -> Dict[str, str]:
        """Parse ``ALLOWED_PROVIDERS`` into a provider-key -> display-name mapping.

        Whitespace around entries, keys and names is ignored. Entries whose key
        is empty (e.g. ``":name"``) are skipped so that an empty provider key
        can never be authorized, and an empty name (``"key:"``) falls back to
        the key, matching the documented "names default to keys" rule.

        Returns:
            dict mapping provider_key -> provider_name; empty when provider
            authentication is disabled.
        """
        providers: Dict[str, str] = {}
        for entry in cls.ALLOWED_PROVIDERS.split(","):
            # partition splits on the first ":" only, so names may contain ":".
            key, _, name = entry.partition(":")
            key = key.strip()
            if not key:
                continue
            providers[key] = name.strip() or key
        return providers
class BotManager:
    """Manages bot providers and bot lifecycle.

    Keeps an in-memory registry of bot providers (persisted to
    ``bot_providers.json``) and of the bot instances this server has asked
    providers to start. Providers are contacted over HTTP via ``GET /bots``,
    ``POST /bots/{bot_name}/join`` and ``POST /bots/runs/{run_id}/stop``.

    Shared state is guarded by ``self.lock``, a re-entrant lock so helpers such
    as ``_save_bot_providers`` can be called while it is already held. The lock
    is never held across an ``await``.
    """

    def __init__(self):
        # provider_id -> BotProviderModel
        self.bot_providers: Dict[str, BotProviderModel] = {}
        # bot_instance_id -> BotInstanceModel
        self.bot_instances: Dict[str, BotInstanceModel] = {}
        self.lock = threading.RLock()
        self.bot_providers_file = "bot_providers.json"
        self.cleanup_task: Optional["asyncio.Task[None]"] = None
        # Set by stop_cleanup(): ends the cleanup loop and prevents restarts.
        self._shutdown_event = asyncio.Event()

        # Check if provider authentication is enabled
        if not BotProviderConfig.get_allowed_providers():
            logger.warning("Bot provider authentication disabled. Any provider can register.")

        # Load persisted bot providers
        self._load_bot_providers()
        # The cleanup task is not started here because there may be no running
        # event loop yet; register_provider() starts it on demand.

    # ------------------------------------------------------------------
    # Stale-provider cleanup
    # ------------------------------------------------------------------

    def start_cleanup(self):
        """Start the periodic stale-provider cleanup task if it is not running.

        Safe to call repeatedly. A task that has finished is replaced, but the
        task is never restarted after stop_cleanup() has been called. Outside a
        running event loop this only logs at debug level.
        """
        if self._shutdown_event.is_set():
            return
        if self.cleanup_task is not None and not self.cleanup_task.done():
            return
        try:
            self.cleanup_task = asyncio.create_task(self._periodic_cleanup())
            logger.debug("Bot provider cleanup task started")
        except RuntimeError:
            # No event loop running yet, cleanup will be started later
            logger.debug("No event loop available for bot provider cleanup task")

    async def stop_cleanup(self):
        """Stop the cleanup task (if any) and wait for it to finish."""
        self._shutdown_event.set()
        task, self.cleanup_task = self.cleanup_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def _periodic_cleanup(self):
        """Every 5 minutes, drop providers not successfully contacted for 15 minutes.

        ``last_seen`` is refreshed on registration and on every successful
        ``GET /bots`` (see ``_fetch_provider_bots``). Removals are persisted.
        Unexpected errors are logged and the loop keeps running.
        """
        cleanup_interval = 300  # 5 minutes
        stale_threshold = 900  # 15 minutes
        while not self._shutdown_event.is_set():
            try:
                await asyncio.sleep(cleanup_interval)
                now = time.time()
                with self.lock:
                    stale_ids = [
                        provider_id
                        for provider_id, provider in self.bot_providers.items()
                        if now - provider.last_seen > stale_threshold
                    ]
                    for provider_id in stale_ids:
                        provider = self.bot_providers.pop(provider_id)
                        logger.info(f"Marking stale bot provider for removal: {provider.name} (ID: {provider_id}, last_seen: {now - provider.last_seen:.1f}s ago)")
                    if stale_ids:
                        self._save_bot_providers()
                if stale_ids:
                    logger.info(f"Cleaned up {len(stale_ids)} stale bot providers")
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in bot provider cleanup: {e}")

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _save_bot_providers(self):
        """Persist all registered providers to ``self.bot_providers_file`` as JSON.

        The data is written to a temporary sibling file and then swapped in
        with ``os.replace``, so a crash mid-write cannot leave a truncated file
        behind. The lock is held for the whole write so concurrent saves do not
        interleave on the temp file. Errors are logged, not raised.
        """
        tmp_path = f"{self.bot_providers_file}.tmp"
        try:
            with self.lock:
                providers_data = {
                    provider_id: provider.model_dump()
                    for provider_id, provider in self.bot_providers.items()
                }
                with open(tmp_path, "w", encoding="utf-8") as f:
                    json.dump(providers_data, f, indent=2)
                os.replace(tmp_path, self.bot_providers_file)
            logger.debug(f"Saved {len(providers_data)} bot providers to {self.bot_providers_file}")
        except Exception as e:
            logger.error(f"Failed to save bot providers: {e}")

    def _load_bot_providers(self):
        """Load providers from ``self.bot_providers_file`` into the registry.

        A missing file is not an error. Individual entries that fail model
        validation are skipped with a warning; any other failure is logged.
        """
        try:
            if not os.path.exists(self.bot_providers_file):
                logger.debug(f"No bot providers file found at {self.bot_providers_file}")
                return
            with open(self.bot_providers_file, "r", encoding="utf-8") as f:
                providers_data = json.load(f)
            with self.lock:
                for provider_id, provider_dict in providers_data.items():
                    try:
                        self.bot_providers[provider_id] = BotProviderModel.model_validate(provider_dict)
                    except Exception as e:
                        logger.warning(f"Failed to load bot provider {provider_id}: {e}")
                logger.info(f"Loaded {len(self.bot_providers)} bot providers from {self.bot_providers_file}")
        except Exception as e:
            logger.error(f"Failed to load bot providers: {e}")

    # ------------------------------------------------------------------
    # Provider registry
    # ------------------------------------------------------------------

    async def register_provider(self, request: BotProviderRegisterRequest) -> BotProviderRegisterResponse:
        """Register (or re-register) a bot provider.

        When ``BOT_PROVIDER_KEYS`` is configured, ``request.provider_key`` must
        be one of the allowed keys. An existing provider with the same
        ``provider_id`` is replaced, which handles provider restarts.

        Raises:
            ValueError: if the key is not authorized or ``provider_id`` is missing.
        """
        allowed_providers = BotProviderConfig.get_allowed_providers()
        # Authentication is enabled when the allow-list is non-empty.
        if allowed_providers and request.provider_key not in allowed_providers:
            # Never log the key itself: it is a shared secret.
            logger.warning(
                f"Rejected bot provider registration with invalid key (provider_id: {request.provider_id}, name: {request.name})"
            )
            raise ValueError("Invalid provider key. Bot provider is not authorized to register.")

        # Require provider_id for static registration
        if not request.provider_id:
            logger.warning("Rejected bot provider registration without provider_id")
            raise ValueError("provider_id is required for bot provider registration.")
        provider_id = request.provider_id

        # Build (and validate) the new model before touching the registry so a
        # bad re-registration does not drop the existing provider.
        now = time.time()
        provider = BotProviderModel(
            provider_id=provider_id,
            base_url=request.base_url.rstrip("/"),
            name=request.name,
            description=request.description,
            provider_key=request.provider_key,
            registered_at=now,
            last_seen=now,
        )
        with self.lock:
            if provider_id in self.bot_providers:
                logger.info(f"Removing existing provider with ID: {provider_id} for re-registration")
            self.bot_providers[provider_id] = provider
            self._save_bot_providers()

        # Start cleanup task if not already running
        self.start_cleanup()
        logger.info(f"Registered bot provider: {request.name} at {request.base_url} (ID: {provider_id})")
        return BotProviderRegisterResponse(provider_id=provider_id)

    def list_providers(self) -> BotProviderListResponse:
        """List all registered bot providers (public information only, no keys)."""
        with self.lock:
            public_providers: List[BotProviderPublicModel] = [
                BotProviderPublicModel(
                    provider_id=provider.provider_id,
                    name=provider.name,
                    description=provider.description,
                    registered_at=provider.registered_at,
                    last_seen=provider.last_seen,
                )
                for provider in self.bot_providers.values()
            ]
        return BotProviderListResponse(providers=public_providers)

    def get_provider(self, provider_id: str) -> Optional[BotProviderModel]:
        """Return the provider registered under *provider_id*, or ``None``."""
        with self.lock:
            return self.bot_providers.get(provider_id)

    def _touch_provider(self, provider_id: str) -> None:
        """Record a successful contact with *provider_id* by refreshing ``last_seen``."""
        with self.lock:
            provider = self.bot_providers.get(provider_id)
            if provider is not None:
                provider.last_seen = time.time()

    async def _fetch_provider_bots(
        self, provider_id: str, provider: BotProviderModel
    ) -> Optional[BotProviderBotsResponse]:
        """``GET {base_url}/bots`` from *provider*.

        Returns the validated response on HTTP 200 and refreshes the provider's
        ``last_seen``. Returns ``None`` (after logging a warning) on any other
        status. Network and validation errors propagate to the caller.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
        if response.status_code != 200:
            logger.warning(f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}")
            return None
        bots_response = BotProviderBotsResponse.model_validate(response.json())
        # Only a successful response proves the provider is alive.
        self._touch_provider(provider_id)
        return bots_response

    # ------------------------------------------------------------------
    # Bot discovery
    # ------------------------------------------------------------------

    async def list_bots(self) -> BotListResponse:
        """List all available bots from all registered providers.

        Providers are queried concurrently. A provider that errors or returns
        a non-200 status is logged and skipped. ``providers`` maps each bot name
        to the id of the provider offering it; if several providers offer the
        same name, the last one in registry order wins.
        """
        with self.lock:
            providers_copy = list(self.bot_providers.items())

        results = await asyncio.gather(
            *(self._fetch_provider_bots(pid, prov) for pid, prov in providers_copy),
            return_exceptions=True,
        )

        bots: List[BotInfoModel] = []
        providers: Dict[str, str] = {}
        for (provider_id, provider), result in zip(providers_copy, results):
            if isinstance(result, BaseException):
                logger.error(f"Error fetching bots from provider {provider.name}: {result}")
                continue
            if result is None:
                continue  # non-200; already logged by _fetch_provider_bots
            for bot_info in result.bots:
                bots.append(bot_info)
                providers[bot_info.name] = provider_id
        return BotListResponse(bots=bots, providers=providers)

    async def get_provider_bots(self, provider_id: str) -> BotProviderBotsResponse:
        """Get bots from a specific provider.

        Returns an empty bot list if the provider cannot be reached or answers
        with an error or invalid body.

        Raises:
            ValueError: if *provider_id* is not registered.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            raise ValueError(f"Provider {provider_id} not found")
        try:
            bots_response = await self._fetch_provider_bots(provider_id, provider)
        except Exception as e:
            logger.error(f"Error fetching bots from provider {provider.name}: {e}")
            return BotProviderBotsResponse(bots=[])
        return bots_response if bots_response is not None else BotProviderBotsResponse(bots=[])

    async def _locate_bot(
        self, bot_name: str, provider_id: Optional[str]
    ) -> "tuple[Optional[str], bool]":
        """Find the provider serving *bot_name* and whether the bot has media.

        With an explicit *provider_id*, only that provider is queried, and its
        id is returned even if the query fails or the bot is not listed; the
        caller then reports "Provider not found" if it is not registered.
        Otherwise providers are queried in registry order and the first that
        lists the bot wins. ``has_media`` defaults to False when unknown.

        Returns:
            (provider_id or a falsy value if not found, has_media)
        """
        with self.lock:
            if provider_id:
                candidates = [(provider_id, self.bot_providers.get(provider_id))]
            else:
                candidates = list(self.bot_providers.items())

        for candidate_id, provider in candidates:
            if provider is None:
                continue
            try:
                bots_response = await self._fetch_provider_bots(candidate_id, provider)
            except Exception:
                continue  # unreachable/invalid provider: try the next one
            if bots_response is None:
                continue
            for bot_info in bots_response.bots:
                if bot_info.name == bot_name:
                    return candidate_id, bot_info.has_media
        return provider_id, False

    # ------------------------------------------------------------------
    # Bot lifecycle
    # ------------------------------------------------------------------

    async def request_bot_join(
        self,
        bot_name: str,
        request: BotJoinLobbyRequest,
        session_manager: SessionManager,
        lobby_manager: LobbyManager,
    ) -> BotJoinLobbyResponse:
        """Ask a bot provider to start *bot_name* in ``request.lobby_id``.

        If ``request.provider_id`` is empty, all registered providers are
        searched for a bot with this name. A bot session is created in
        *session_manager* before the provider is contacted; on success the bot
        instance is recorded in ``self.bot_instances``.

        Raises:
            ValueError: if the bot, provider or lobby cannot be found; if the
                provider times out, returns a non-200 status or an invalid
                body; or for any other error while requesting the join.
        """
        target_provider_id, bot_has_media = await self._locate_bot(bot_name, request.provider_id)
        if not target_provider_id:
            raise ValueError("Bot or provider not found")

        provider = self.get_provider(target_provider_id)
        if provider is None:
            raise ValueError("Provider not found")

        # Get the lobby to validate it exists
        lobby = lobby_manager.get_lobby(request.lobby_id)
        if not lobby:
            raise ValueError("Lobby not found")

        # Create a session for the bot
        bot_session_id = secrets.token_hex(16)
        bot_instance_id = str(uuid.uuid4())
        bot_session = session_manager.get_or_create_session(
            bot_session_id, bot_instance_id=bot_instance_id, has_media=bot_has_media
        )
        logger.info(
            f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})"
        )
        # NOTE(review): if the join request below fails, this session is not
        # removed from session_manager; confirm whether it needs explicit cleanup.

        # URL the bot connects back to: PUBLIC_SERVER_URL + PUBLIC_URL_PREFIX.
        server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000").rstrip("/")
        public_url = os.getenv("PUBLIC_URL_PREFIX", "/ai-voicebot")
        bot_nick = request.nick or f"{bot_name}-bot-{bot_session_id[:8]}"

        # Prepare the join request for the bot provider
        bot_join_payload = BotJoinPayload(
            lobby_id=request.lobby_id,
            session_id=bot_session_id,
            nick=bot_nick,
            server_url=f"{server_base_url}{public_url}".rstrip("/"),
            insecure=True,  # Accept self-signed certificates in development
        )

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    f"{provider.base_url}/bots/{bot_name}/join",
                    json=bot_join_payload.model_dump(),
                    timeout=10.0,
                )

            if response.status_code != 200:
                logger.error(
                    f"Bot provider returned error: HTTP {response.status_code}: {response.text}"
                )
                raise ValueError(f"Bot provider error: {response.status_code}")

            try:
                join_response = BotProviderJoinResponse.model_validate(response.json())
            except (ValidationError, json.JSONDecodeError) as e:
                logger.error(f"Invalid response from bot provider: {e}")
                raise ValueError(f"Bot provider returned invalid response: {str(e)}") from e
            run_id = join_response.run_id

            # Update bot session with run and provider information
            bot_session.setName(bot_nick)
            bot_session.bot_run_id = run_id
            bot_session.bot_provider_id = target_provider_id
            bot_session.bot_instance_id = bot_instance_id

            bot_instance = BotInstanceModel(
                bot_instance_id=bot_instance_id,
                bot_name=bot_name,
                nick=bot_nick,
                lobby_id=request.lobby_id,
                session_id=bot_session_id,
                provider_id=target_provider_id,
                run_id=run_id,
                has_media=bot_has_media,
                created_at=time.time(),
            )
            with self.lock:
                self.bot_instances[bot_instance_id] = bot_instance
            logger.info(
                f"Bot {bot_name} requested to join lobby {request.lobby_id} with instance ID {bot_instance_id}"
            )
            return BotJoinLobbyResponse(
                status="requested",
                bot_instance_id=bot_instance_id,
                bot_name=bot_name,
                run_id=run_id,
                provider_id=target_provider_id,
                session_id=bot_session_id,
            )
        except ValueError:
            # Already a caller-facing error raised above; don't re-wrap it as
            # "Internal server error".
            raise
        except httpx.TimeoutException as e:
            raise ValueError("Bot provider timeout") from e
        except Exception as e:
            logger.error(f"Error requesting bot join: {e}")
            raise ValueError(f"Internal server error: {str(e)}") from e

    async def _stop_provider_run(self, bot_instance: BotInstanceModel) -> None:
        """Best-effort ``POST /bots/runs/{run_id}/stop`` to the instance's provider.

        Failures are logged as warnings and never raised.
        """
        provider = self.get_provider(bot_instance.provider_id)
        if provider is None:
            logger.warning(
                f"Bot provider {bot_instance.provider_id} is not registered; cannot stop run {bot_instance.run_id}"
            )
            return
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    f"{provider.base_url}/bots/runs/{bot_instance.run_id}/stop",
                    timeout=5.0,
                )
        except Exception as e:
            logger.warning(f"Failed to request bot stop from provider: {e}")
            return
        if response.status_code == 200:
            logger.info(
                f"Successfully requested bot provider to stop run {bot_instance.run_id}"
            )
        else:
            logger.warning(
                f"Bot provider returned error when stopping: HTTP {response.status_code}"
            )

    async def request_bot_leave(
        self, request: BotLeaveLobbyRequest, session_manager: SessionManager
    ) -> BotLeaveLobbyResponse:
        """Request a bot to leave all lobbies and disconnect.

        Asks the provider to stop the run (best effort), removes the bot
        session from every lobby, closes its WebSocket, and forgets the
        instance.

        Raises:
            ValueError: if the instance or its session is unknown, or the
                session is not a bot session.
        """
        with self.lock:
            bot_instance = self.bot_instances.get(request.bot_instance_id)
        if bot_instance is None:
            raise ValueError("Bot instance not found")

        bot_session = session_manager.get_session(bot_instance.session_id)
        if not bot_session:
            raise ValueError("Bot session not found")
        if not bot_session.bot_instance_id:
            raise ValueError("Session is not a bot")

        logger.info(
            f"Requesting bot instance {bot_instance.bot_instance_id} to leave all lobbies"
        )

        # Try to stop the bot at the provider level
        await self._stop_provider_run(bot_instance)

        # Force disconnect the bot session from all lobbies
        try:
            # Iterate a copy in case leave_lobby mutates bot_session.lobbies.
            for lobby in bot_session.lobbies[:]:
                try:
                    await bot_session.leave_lobby(lobby)
                except Exception as e:
                    logger.warning(f"Error removing bot from lobby {lobby.name}: {e}")
            # Close WebSocket connection if it exists
            if bot_session.ws:
                try:
                    await bot_session.ws.close()
                except Exception as e:
                    logger.warning(f"Error closing bot WebSocket: {e}")
                bot_session.ws = None
        except Exception as e:
            logger.warning(f"Error disconnecting bot session: {e}")

        # Remove bot instance from tracking
        with self.lock:
            self.bot_instances.pop(request.bot_instance_id, None)

        return BotLeaveLobbyResponse(
            status="disconnected",
            bot_instance_id=request.bot_instance_id,
            session_id=bot_instance.session_id,
            run_id=bot_instance.run_id,
        )

    async def get_bot_instance(self, bot_instance_id: str) -> BotInstanceModel:
        """Return the tracked bot instance.

        Raises:
            ValueError: if *bot_instance_id* is not tracked.
        """
        with self.lock:
            bot_instance = self.bot_instances.get(bot_instance_id)
        if bot_instance is None:
            raise ValueError("Bot instance not found")
        return bot_instance

    def get_bot_instance_id_by_session_id(self, session_id: str) -> Optional[str]:
        """Return the id of the bot instance using *session_id*, or ``None``."""
        with self.lock:
            return next(
                (
                    instance_id
                    for instance_id, instance in self.bot_instances.items()
                    if instance.session_id == session_id
                ),
                None,
            )