239 lines
7.6 KiB
Python
239 lines
7.6 KiB
Python
"""
|
|
Bot orchestrator FastAPI service.
|
|
|
|
This module provides the FastAPI service for bot discovery and orchestration.
|
|
"""
|
|
|
|
import asyncio
|
|
import threading
|
|
import uuid
|
|
import importlib
|
|
import pkgutil
|
|
import sys
|
|
import os
|
|
from typing import Dict, Any
|
|
|
|
# Add the parent directory to sys.path to allow absolute imports
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import uvicorn
|
|
from fastapi import FastAPI, HTTPException
|
|
|
|
from logger import logger
|
|
from voicebot.models import JoinRequest
|
|
from voicebot.webrtc_signaling import WebRTCSignalingClient
|
|
|
|
|
|
# The FastAPI application that serves the bot discovery and run endpoints.
app = FastAPI(title="voicebot-bot-orchestrator")

# Process-local, in-memory map of run_id -> WebRTCSignalingClient for bots
# that are currently running. Nothing here is persisted.
registry: Dict[str, WebRTCSignalingClient] = dict()
|
|
|
|
|
|
def discover_bots() -> Dict[str, Dict[str, Any]]:
    """Discover bot modules under the ``voicebot.bots`` package.

    Every module directly inside ``voicebot.bots`` is imported, so heavy bot
    implementations can live in that package. A module counts as a bot when it
    defines a callable ``agent_info()`` that returns a non-empty mapping. If the
    module also defines a callable ``create_agent_tracks``, that callable is
    passed through as the bot's track factory.

    Repeated calls re-invoke ``agent_info()``. The module imports themselves
    are cached by Python's import system.

    Returns:
        Mapping of bot name -> ``{"module": <module name>, "info": <agent_info()
        result>, "create_tracks": <create_agent_tracks or None>}``. The bot name
        is ``info["name"]`` when present, otherwise the module name.

        Modules are skipped when they fail to import, have no callable
        ``agent_info``, raise from it, or return an empty or non-mapping value.
        Failures are logged. If the package itself cannot be imported, an empty
        dict is returned.
    """
    from collections.abc import Mapping

    bots: Dict[str, Dict[str, Any]] = {}
    try:
        package = importlib.import_module("voicebot.bots")
        package_path = package.__path__
    except Exception:
        logger.exception("Failed to import voicebot.bots package")
        return bots

    for _finder, module_name, _ispkg in pkgutil.iter_modules(package_path):
        try:
            mod = importlib.import_module(f"voicebot.bots.{module_name}")
        except Exception:
            # One broken bot module must not prevent discovery of the others.
            logger.exception("Failed to import voicebot.bots.%s", module_name)
            continue

        agent_info = getattr(mod, "agent_info", None)
        if not callable(agent_info):
            continue
        try:
            info = agent_info()
        except Exception:
            logger.exception("agent_info() failed for %s", module_name)
            continue
        if not info:
            continue
        if not isinstance(info, Mapping):
            # The name lookup below needs ``.get``. A non-mapping used to raise
            # AttributeError here and abort discovery of every bot.
            logger.warning(
                "agent_info() for %s returned %s, expected a mapping; skipping",
                module_name,
                type(info).__name__,
            )
            continue

        create_tracks = getattr(mod, "create_agent_tracks", None)
        if not callable(create_tracks):
            create_tracks = None

        bot_name = info.get("name", module_name)
        if bot_name in bots:
            # The later module wins, matching the previous behaviour, but the
            # collision is no longer silent.
            logger.warning(
                "Bot name %r from voicebot.bots.%s shadows voicebot.bots.%s",
                bot_name,
                module_name,
                bots[bot_name]["module"],
            )
        bots[bot_name] = {"module": module_name, "info": info, "create_tracks": create_tracks}

    return bots
|
|
|
|
|
|
@app.get("/bots")
def list_bots() -> Dict[str, Any]:
    """Return each discovered bot's name mapped to the info dict it published."""
    discovered = discover_bots()
    summary: Dict[str, Any] = {}
    for bot_name, entry in discovered.items():
        summary[bot_name] = entry["info"]
    return summary
|
|
|
|
|
|
# Strong references to in-flight bot tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task can be garbage-collected
# mid-run. Each task removes itself when it finishes.
_bot_tasks: set[asyncio.Task] = set()


@app.post("/bots/{bot_name}/join")
async def bot_join(bot_name: str, req: JoinRequest):
    """Make a bot join a lobby.

    Looks up *bot_name* via discover_bots() and builds a WebRTCSignalingClient
    from the join request. It then registers the client under a fresh run id
    and schedules ``client.connect()`` as a background task on the running
    event loop. The run is removed from ``registry`` when ``connect()``
    returns, raises, or is cancelled.

    Returns:
        ``{"status": "started", "bot": bot_name, "run_id": run_id}``.

    Raises:
        HTTPException: 404 if no discovered bot has that name.
    """
    bots = discover_bots()
    bot = bots.get(bot_name)
    if not bot:
        raise HTTPException(status_code=404, detail="Bot not found")

    client = WebRTCSignalingClient(
        server_url=req.server_url,
        lobby_id=req.lobby_id,
        session_id=req.session_id,
        session_name=req.nick,
        insecure=req.insecure,
        create_tracks=bot.get("create_tracks"),
    )

    run_id = str(uuid.uuid4())
    # Register before scheduling, so the run_id we return is immediately
    # listed by /bots/runs and stoppable via /bots/runs/{run_id}/stop.
    registry[run_id] = client

    async def run_client() -> None:
        try:
            await client.connect()
        except Exception:
            logger.exception("Bot client failed for run %s", run_id)
        finally:
            registry.pop(run_id, None)

    # Schedule on the loop that is already running this handler. The previous
    # code called loop.run_until_complete() on that same running loop from a
    # new thread, which always raises RuntimeError("This event loop is already
    # running"), so the bot never connected.
    task = asyncio.create_task(run_client(), name=f"bot-run-{run_id}")
    _bot_tasks.add(task)
    task.add_done_callback(_bot_tasks.discard)

    return {"status": "started", "bot": bot_name, "run_id": run_id}
|
|
|
|
|
|
@app.post("/bots/runs/{run_id}/stop")
async def stop_run(run_id: str):
    """Stop a running bot by disconnecting its signaling client.

    Raises:
        HTTPException: 404 when *run_id* is not registered. 500 when
            ``disconnect()`` raises; in that case the run stays in the registry.
    """
    running_client = registry.get(run_id)
    if not running_client:
        raise HTTPException(status_code=404, detail="Run not found")

    try:
        await running_client.disconnect()
    except Exception:
        logger.exception("Failed to stop run %s", run_id)
        raise HTTPException(status_code=500, detail="Failed to stop run")
    else:
        registry.pop(run_id, None)
        return {"status": "stopped", "run_id": run_id}
|
|
|
|
|
|
@app.get("/bots/runs")
def list_runs() -> Dict[str, Any]:
    """Describe every registered run: its id, session id and session name."""
    runs = []
    for rid, active_client in registry.items():
        runs.append(
            {
                "run_id": rid,
                "session_id": active_client.session_id,
                "session_name": active_client.session_name,
            }
        )
    return {"runs": runs}
|
|
|
|
|
|
def start_bot_api(host: str = "0.0.0.0", port: int = 8788):
    """Serve the orchestration API with uvicorn in the calling thread.

    Args:
        host: Interface to bind.
        port: TCP port to bind.
    """
    bind_options = {"host": host, "port": port}
    uvicorn.run(app, **bind_options)
|
|
|
|
|
|
async def register_with_server(server_url: str, voicebot_url: str, insecure: bool = False) -> str:
    """Register this voicebot instance as a bot provider with the main server.

    POSTs ``{"base_url", "name", "description"}`` to
    ``<server_url>/api/bots/providers/register``.

    Args:
        server_url: Base URL of the main server. A trailing slash is tolerated.
        voicebot_url: URL at which this provider's API is reachable.
        insecure: When true, TLS certificate verification is disabled.

    Returns:
        The ``provider_id`` reported by the server.

    Raises:
        RuntimeError: The server answered with a status other than 200, or its
            response did not contain a ``provider_id``.
        Exception: Transport or JSON errors from httpx are logged and re-raised
            unchanged.
    """
    # Imported lazily so this module can be loaded without httpx installed.
    import httpx

    payload = {
        "base_url": voicebot_url.rstrip('/'),
        "name": "voicebot-provider",
        "description": "AI voicebot provider with speech recognition and synthetic media capabilities"
    }
    # Strip a trailing slash so "http://host/" does not produce "//api/...".
    endpoint = f"{server_url.rstrip('/')}/api/bots/providers/register"

    try:
        async with httpx.AsyncClient(verify=not insecure) as client:
            response = await client.post(endpoint, json=payload, timeout=10.0)
    except Exception as e:
        logger.error("Error registering with server: %s", e)
        raise

    if response.status_code != 200:
        logger.error(
            "Failed to register with server: HTTP %s: %s",
            response.status_code,
            response.text,
        )
        raise RuntimeError(f"Registration failed: {response.status_code}")

    result = response.json()
    provider_id = result.get("provider_id") if isinstance(result, dict) else None
    if not provider_id:
        # Previously a missing id returned None despite the ``-> str`` contract.
        logger.error("Registration response did not include a provider_id: %s", response.text)
        raise RuntimeError("Registration failed: response missing provider_id")

    logger.info("Successfully registered with server as provider: %s", provider_id)
    return provider_id
|
|
|
|
|
|
def _advertised_url(host: str, port: int) -> str:
    """Return the URL to advertise for this provider, given its bind address."""
    import socket

    if host != "0.0.0.0":
        return f"http://{host}:{port}"
    # A wildcard bind address is not routable; advertise the machine's hostname.
    try:
        return f"http://{socket.gethostname()}:{port}"
    except OSError:
        return f"http://localhost:{port}"


def _wait_for_server(
    host: str, port: int, server_thread: threading.Thread, timeout: float = 10.0
) -> bool:
    """Poll until (host, port) accepts TCP connections.

    Returns False if *timeout* seconds pass or *server_thread* exits first.
    """
    import socket
    import time

    # Wildcard bind addresses cannot be connected to; probe loopback instead.
    probe_host = {"0.0.0.0": "127.0.0.1", "::": "::1", "": "127.0.0.1"}.get(host, host)
    deadline = time.monotonic() + timeout
    while server_thread.is_alive() and time.monotonic() < deadline:
        try:
            with socket.create_connection((probe_host, port), timeout=0.5):
                return True
        except OSError:
            time.sleep(0.1)
    return False


def start_bot_provider(
    host: str = "0.0.0.0",
    port: int = 8788,
    server_url: str | None = None,
    insecure: bool = False,
    reload: bool = False
):
    """Start the bot provider API server and optionally register it with the main server.

    uvicorn runs in a daemon thread.

    If *server_url* is given, this waits (up to about 10 s) for the API to
    accept connections, then calls register_with_server(). A registration
    failure is logged, not raised. If the server thread dies during startup,
    registration is skipped and the function returns.

    The calling thread then blocks until the server thread exits or Ctrl-C is
    pressed.

    Args:
        host: Interface for uvicorn to bind.
        port: TCP port for uvicorn to bind.
        server_url: Base URL of the main server. None skips registration.
        insecure: Disable TLS certificate verification for the registration call.
        reload: Ask uvicorn to auto-reload on changes under /voicebot and /shared.
    """
    if reload:
        # NOTE(review): uvicorn only honours reload when the app is passed as an
        # import string (and its reloader expects the main thread). With an app
        # object it logs a warning and exits. That exit is now surfaced by the
        # liveness loop below instead of leaving the process hanging.
        def serve() -> None:
            uvicorn.run(
                app,
                host=host,
                port=port,
                log_level="info",
                reload=True,
                reload_dirs=["/voicebot", "/shared"],
            )
    else:
        def serve() -> None:
            uvicorn.run(app, host=host, port=port, log_level="info")

    server_thread = threading.Thread(target=serve, daemon=True)
    logger.info("Starting bot provider API server on %s:%s...", host, port)
    server_thread.start()

    if server_url:
        if not _wait_for_server(host, port, server_thread):
            if not server_thread.is_alive():
                logger.error("Bot provider API server exited during startup; skipping registration")
                return
            logger.warning(
                "Bot provider API not reachable on %s:%s yet; registering anyway", host, port
            )

        voicebot_url = _advertised_url(host, port)
        try:
            asyncio.run(register_with_server(server_url, voicebot_url, insecure))
        except Exception as e:
            logger.error("Failed to register with server: %s", e)

    # Block while the server runs. Previously this slept forever even after the
    # server thread had died (e.g. port already in use).
    try:
        while server_thread.is_alive():
            server_thread.join(timeout=1.0)
        logger.warning("Bot provider API server thread exited")
    except KeyboardInterrupt:
        logger.info("Shutting down bot provider...")