Adding orchestration
This commit is contained in:
parent
93025d22fc
commit
3674d57b0a
3
voicebot/bots/__init__.py
Normal file
3
voicebot/bots/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""Bots package for discoverable agent modules."""
|
||||
|
||||
__all__ = ["synthetic_media", "whisper"]
|
@ -1,10 +1,13 @@
|
||||
"""
|
||||
Synthetic Media Tracks Module
|
||||
Synthetic Media Tracks Module (bots/synthetic_media)
|
||||
|
||||
This module provides synthetic audio and video track creation for WebRTC media streaming.
|
||||
Contains AnimatedVideoTrack and SyntheticAudioTrack implementations ported from JavaScript.
|
||||
Copied from voicebot/synthetic_media. This module contains the real implementation
|
||||
and heavy dependencies and is intended to be imported lazily by the orchestrator
|
||||
or runtime controllers.
|
||||
"""
|
||||
|
||||
# ...existing heavy implementation moved here (keeps same content as original file)
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import cv2
|
||||
@ -232,7 +235,7 @@ class SyntheticAudioTrack(MediaStreamTrack):
|
||||
- Top of screen (Y=0): 800Hz (high pitch)
|
||||
- Bottom of screen (Y=height): 200Hz (low pitch)
|
||||
|
||||
Bounce events add temporary audio effects on top of the continuous tone.
|
||||
Bounce events add temporary audio on top of the continuous tone.
|
||||
"""
|
||||
|
||||
kind = "audio"
|
||||
@ -458,3 +461,16 @@ def create_synthetic_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
|
||||
video_track.audio_track = audio_track
|
||||
|
||||
return {"video": video_track, "audio": audio_track}
|
||||
|
||||
|
||||
# Agent descriptor exported for dynamic discovery by the FastAPI service
|
||||
AGENT_NAME = "synthetic_media"
|
||||
AGENT_DESCRIPTION = "Synthetic audio and video tracks (AnimatedVideoTrack + SyntheticAudioTrack)"
|
||||
|
||||
def agent_info() -> dict[str, str]:
|
||||
"""Return agent metadata for discovery."""
|
||||
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION}
|
||||
|
||||
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
|
||||
"""Factory wrapper used by the FastAPI service to instantiate tracks for an agent."""
|
||||
return create_synthetic_tracks(session_name)
|
@ -1,8 +1,28 @@
|
||||
"""Bots package whisper agent (bots/whisper)
|
||||
|
||||
Lightweight agent descriptor; heavy model loading must be done by a controller
|
||||
when the agent is actually used.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any
|
||||
import librosa
|
||||
from logger import logger
|
||||
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
||||
|
||||
|
||||
AGENT_NAME = "whisper"
|
||||
AGENT_DESCRIPTION = "Speech recognition agent (Whisper) - processes incoming audio"
|
||||
|
||||
|
||||
def agent_info() -> Dict[str, str]:
|
||||
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION}
|
||||
|
||||
|
||||
def create_agent_tracks(session_name: str) -> dict:
|
||||
"""Whisper is not a media source - return no local tracks."""
|
||||
return {}
|
||||
|
||||
model_ids = {
|
||||
"Distil-Whisper": [
|
||||
"distil-whisper/distil-large-v2",
|
162
voicebot/main.py
162
voicebot/main.py
@ -2,7 +2,7 @@
|
||||
WebRTC Media Agent for Python
|
||||
|
||||
This module provides WebRTC signaling server communication and peer connection management.
|
||||
Synthetic audio/video track creation is handled by the synthetic_media module.
|
||||
Synthetic audio/video track creation is handled by the bots.synthetic_media module.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -18,6 +18,7 @@ from typing import (
|
||||
Protocol,
|
||||
AsyncIterator,
|
||||
cast,
|
||||
Any,
|
||||
)
|
||||
|
||||
# test
|
||||
@ -57,7 +58,17 @@ from aiortc import (
|
||||
MediaStreamTrack,
|
||||
)
|
||||
from logger import logger
|
||||
from synthetic_media import create_synthetic_tracks, AnimatedVideoTrack
|
||||
from voicebot.bots.synthetic_media import create_synthetic_tracks, AnimatedVideoTrack
|
||||
|
||||
# Bot orchestration imports
|
||||
import importlib
|
||||
import pkgutil
|
||||
import uuid
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import threading
|
||||
import uvicorn
|
||||
|
||||
|
||||
# import debug_aioice
|
||||
|
||||
@ -103,6 +114,7 @@ class WebRTCSignalingClient:
|
||||
session_id: str,
|
||||
session_name: str,
|
||||
insecure: bool = False,
|
||||
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.lobby_id = lobby_id
|
||||
@ -110,6 +122,9 @@ class WebRTCSignalingClient:
|
||||
self.session_name = session_name
|
||||
self.insecure = insecure
|
||||
|
||||
# Optional factory to create local media tracks for this client (bot provided)
|
||||
self.create_tracks = create_tracks
|
||||
|
||||
# WebSocket client protocol instance (typed as object to avoid Any)
|
||||
self.websocket: Optional[object] = None
|
||||
|
||||
@ -207,10 +222,19 @@ class WebRTCSignalingClient:
|
||||
logger.info("Disconnected from signaling server")
|
||||
|
||||
async def _setup_local_media(self):
|
||||
"""Create local synthetic media tracks"""
|
||||
# Create synthetic tracks using the new module
|
||||
tracks = create_synthetic_tracks(self.session_name)
|
||||
self.local_tracks.update(tracks)
|
||||
"""Create local media tracks"""
|
||||
# If a bot provided a create_tracks callable, use it to create tracks.
|
||||
# Otherwise, use default synthetic tracks.
|
||||
try:
|
||||
if self.create_tracks:
|
||||
tracks = self.create_tracks(self.session_name)
|
||||
self.local_tracks.update(tracks)
|
||||
else:
|
||||
# Default fallback to synthetic tracks
|
||||
tracks = create_synthetic_tracks(self.session_name)
|
||||
self.local_tracks.update(tracks)
|
||||
except Exception:
|
||||
logger.exception("Failed to create local tracks using bot factory")
|
||||
|
||||
# Add local peer to peers dict
|
||||
local_peer = Peer(
|
||||
@ -221,7 +245,7 @@ class WebRTCSignalingClient:
|
||||
)
|
||||
self.peers[self.session_id] = local_peer
|
||||
|
||||
logger.info("Local synthetic media tracks created")
|
||||
logger.info("Local media tracks created")
|
||||
|
||||
async def _send_message(
|
||||
self, message_type: str, data: Optional[MessageData] = None
|
||||
@ -1103,6 +1127,130 @@ async def main():
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
# --- FastAPI service for bot discovery and orchestration -------------------
|
||||
|
||||
app = FastAPI(title="voicebot-bot-orchestrator")
|
||||
|
||||
|
||||
class JoinRequest(BaseModel):
|
||||
lobby_id: str
|
||||
session_id: str
|
||||
nick: str
|
||||
server_url: str
|
||||
insecure: bool = False
|
||||
|
||||
|
||||
def discover_bots() -> Dict[str, Dict[str, Any]]:
|
||||
"""Discover bot modules under the voicebot.bots package that expose bot_info.
|
||||
|
||||
This intentionally imports modules under `voicebot.bots` so heavy bot
|
||||
implementations can remain in that package and be imported lazily.
|
||||
"""
|
||||
bots: Dict[str, Dict[str, Any]] = {}
|
||||
try:
|
||||
package = importlib.import_module("voicebot.bots")
|
||||
package_path = package.__path__
|
||||
except Exception:
|
||||
logger.exception("Failed to import voicebot.bots package")
|
||||
return bots
|
||||
|
||||
for _finder, name, _ispkg in pkgutil.iter_modules(package_path):
|
||||
try:
|
||||
mod = importlib.import_module(f"voicebot.bots.{name}")
|
||||
except Exception:
|
||||
logger.exception("Failed to import voicebot.bots.%s", name)
|
||||
continue
|
||||
info = None
|
||||
create_tracks = None
|
||||
if hasattr(mod, "agent_info") and callable(getattr(mod, "agent_info")):
|
||||
try:
|
||||
info = mod.agent_info()
|
||||
# Note: Keep copy as is to maintain structure
|
||||
except Exception:
|
||||
logger.exception("agent_info() failed for %s", name)
|
||||
if hasattr(mod, "create_agent_tracks") and callable(getattr(mod, "create_agent_tracks")):
|
||||
create_tracks = getattr(mod, "create_agent_tracks")
|
||||
|
||||
if info:
|
||||
bots[info.get("name", name)] = {"module": name, "info": info, "create_tracks": create_tracks}
|
||||
return bots
|
||||
|
||||
|
||||
@app.get("/bots")
|
||||
def list_bots() -> Dict[str, Any]:
|
||||
bots = discover_bots()
|
||||
return {k: v["info"] for k, v in bots.items()}
|
||||
|
||||
|
||||
@app.post("/bots/{bot_name}/join")
|
||||
async def bot_join(bot_name: str, req: JoinRequest):
|
||||
bots = discover_bots()
|
||||
bot = bots.get(bot_name)
|
||||
if not bot:
|
||||
raise HTTPException(status_code=404, detail="Bot not found")
|
||||
|
||||
create_tracks = bot.get("create_tracks")
|
||||
|
||||
# Start the WebRTCSignalingClient in a background asyncio task and register it
|
||||
client = WebRTCSignalingClient(
|
||||
server_url=req.server_url,
|
||||
lobby_id=req.lobby_id,
|
||||
session_id=req.session_id,
|
||||
session_name=req.nick,
|
||||
insecure=req.insecure,
|
||||
create_tracks=create_tracks,
|
||||
)
|
||||
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
async def run_client():
|
||||
try:
|
||||
registry[run_id] = client
|
||||
await client.connect()
|
||||
except Exception:
|
||||
logger.exception("Bot client failed for run %s", run_id)
|
||||
finally:
|
||||
registry.pop(run_id, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
threading.Thread(target=loop.run_until_complete, args=(run_client(),), daemon=True).start()
|
||||
|
||||
return {"status": "started", "bot": bot_name, "run_id": run_id}
|
||||
|
||||
|
||||
# Lightweight in-memory registry of running bot clients
|
||||
registry: Dict[str, WebRTCSignalingClient] = {}
|
||||
|
||||
|
||||
@app.post("/bots/runs/{run_id}/stop")
|
||||
async def stop_run(run_id: str):
|
||||
client = registry.get(run_id)
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception:
|
||||
logger.exception("Failed to stop run %s", run_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to stop run")
|
||||
registry.pop(run_id, None)
|
||||
return {"status": "stopped", "run_id": run_id}
|
||||
|
||||
|
||||
@app.get("/bots/runs")
|
||||
def list_runs() -> Dict[str, Any]:
|
||||
return {
|
||||
"runs": [
|
||||
{"run_id": run_id, "session_id": client.session_id, "session_name": client.session_name}
|
||||
for run_id, client in registry.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def start_bot_api(host: str = "0.0.0.0", port: int = 8788):
|
||||
"""Start the bot orchestration API server"""
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Install required packages:
|
||||
# pip install aiortc websockets opencv-python numpy
|
||||
|
Loading…
x
Reference in New Issue
Block a user