From 7042a76d19076b46af715a0bd4940500fe73d28a Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Wed, 3 Sep 2025 17:04:31 -0700 Subject: [PATCH] Refactoring --- client/src/BotManager.tsx | 33 +++--- client/src/GlobalContext.tsx | 1 + client/src/MediaControl.tsx | 78 ++++++++----- client/src/UserList.tsx | 18 ++- client/src/api-client.ts | 2 +- server/main.py | 217 +++++++++++++++++++++++------------ shared/models.py | 20 +++- voicebot/bot_orchestrator.py | 84 +++++++++----- voicebot/bots/chatbot.py | 2 +- voicebot/bots/whisper.py | 2 +- 10 files changed, 301 insertions(+), 156 deletions(-) diff --git a/client/src/BotManager.tsx b/client/src/BotManager.tsx index 5c6735c..f92e27d 100644 --- a/client/src/BotManager.tsx +++ b/client/src/BotManager.tsx @@ -35,7 +35,7 @@ interface BotManagerProps { } const BotManager: React.FC = ({ lobbyId, onBotAdded, sx }) => { - const [bots, setBots] = useState>({}); + const [bots, setBots] = useState([]); const [providers, setProviders] = useState>({}); const [botProviders, setBotProviders] = useState([]); const [loading, setLoading] = useState(false); @@ -49,11 +49,8 @@ const BotManager: React.FC = ({ lobbyId, onBotAdded, sx }) => { setLoading(true); setError(null); try { - const [botsResponse, providersResponse] = await Promise.all([ - botsApi.getAvailable(), - botsApi.getProviders(), - ]); - + const [botsResponse, providersResponse] = await Promise.all([botsApi.getAvailable(), botsApi.getProviders()]); + setBots(botsResponse.bots); setProviders(botsResponse.providers); setBotProviders(providersResponse.providers); @@ -82,13 +79,13 @@ const BotManager: React.FC = ({ lobbyId, onBotAdded, sx }) => { }; const response = await botsApi.requestJoinLobby(selectedBot, request); - + if (response.status === "requested") { setAddDialogOpen(false); setSelectedBot(""); setBotNick(""); onBotAdded?.(selectedBot); - + // Show success feedback could be added here } } catch (err) { @@ -112,11 +109,11 @@ const BotManager: React.FC = ({ lobbyId, onBotAdded, sx }) => { }; const getProviderName = (providerId: string): string => { - const provider = botProviders.find(p => p.provider_id === providerId); + const provider = botProviders.find((p) => p.provider_id === providerId); return provider ? provider.name : "Unknown Provider"; }; - const botCount = Object.keys(bots).length; + const botCount = bots.length; const providerCount = botProviders.length; return ( @@ -169,12 +166,12 @@ const BotManager: React.FC = ({ lobbyId, onBotAdded, sx }) => { - {Object.entries(bots).map(([botName, botInfo]) => { - const providerId = providers[botName]; + {bots.map((botInfo) => { + const providerId = providers[botInfo.name]; const providerName = getProviderName(providerId); return ( - + = ({ lobbyId, onBotAdded, sx }) => { Select Bot - {Object.entries(bots).map(([botName, botInfo]) => ( + {bots.map((botInfo) => ( setSelectedBot(botName)} + onClick={() => setSelectedBot(botInfo.name)} > - + ))} diff --git a/client/src/GlobalContext.tsx b/client/src/GlobalContext.tsx index 7cee9ee..c20aae0 100644 --- a/client/src/GlobalContext.tsx +++ b/client/src/GlobalContext.tsx @@ -6,4 +6,5 @@ export type Lobby = LobbyModel; // Extended Session type that allows name to be null initially (before user sets it) export type Session = Omit & { name: string | null; + has_media?: boolean; // Whether this session provides audio/video streams }; diff --git a/client/src/MediaControl.tsx b/client/src/MediaControl.tsx index 6486580..f7370bf 100644 --- a/client/src/MediaControl.tsx +++ b/client/src/MediaControl.tsx @@ -140,6 +140,7 @@ export type { Peer }; interface AddPeerConfig { peer_id: string; peer_name: string; + has_media?: boolean; // Whether this peer provides audio/video streams should_create_offer?: boolean; } @@ -406,9 +407,14 @@ const MediaAgent = (props: MediaAgentProps) => { } } - // Queue peer if media not ready - if (!media) { - console.log(`media-agent - addPeer:${config.peer_name} - No local media yet, queuing peer`); + // Queue peer if we need local media but don't have it yet + // Only queue if we're expected to provide media (local user has media) + const localUserHasMedia = session?.has_media !== false; // Default to true for backward compatibility + const peerHasMedia = config.has_media !== false; // Default to true for backward compatibility + + // Only need to wait for media if we (local user) are supposed to provide it + if (!media && localUserHasMedia) { + console.log(`media-agent - addPeer:${config.peer_name} - No local media yet, queuing peer`); setPendingPeers((prev) => { // Avoid duplicate queuing if (!prev.some((p) => p.peer_id === config.peer_id)) { @@ -495,7 +501,7 @@ const MediaAgent = (props: MediaAgentProps) => { } console.log(`media-agent - addPeer:${peer.peer_name} Handling negotiationneeded for ${peer.peer_name}`); - + // Mark as negotiating isNegotiatingRef.current.set(peer_id, true); updatePeerConnectionState(peer_id, connection.connectionState, true); @@ -541,10 +547,10 @@ const MediaAgent = (props: MediaAgentProps) => { connection.connectionState, event ); - + // Update peer connection state updatePeerConnectionState(peer_id, connection.connectionState); - + if (connection.connectionState === "failed") { setTimeout(() => { if (connection.connectionState === "failed") { @@ -676,12 +682,20 @@ const MediaAgent = (props: MediaAgentProps) => { } }; - // Add local tracks - console.log(`media-agent - addPeer:${peer.peer_name} Adding local tracks to new peer connection`); - media.getTracks().forEach((t) => { - console.log(`media-agent - addPeer:${peer.peer_name} Adding track:`, t.kind, t.enabled); - connection.addTrack(t, media); - }); + // Add local tracks to the connection only if we have media and it's valid + console.log( + `media-agent - addPeer:${peer.peer_name} Adding local tracks to new peer connection (localHasMedia=${localUserHasMedia})` + ); + if (media && localUserHasMedia) { + media.getTracks().forEach((t) => { + console.log(`media-agent - addPeer:${peer.peer_name} Adding track:`, t.kind, t.enabled); + connection.addTrack(t, media); + }); + } else if (!localUserHasMedia) { + console.log(`media-agent - addPeer:${peer.peer_name} - Local user has no media, skipping track addition`); + } else { + console.log(`media-agent - addPeer:${peer.peer_name} - No local media available yet`); + } // Update peers state setPeers(updatedPeers); @@ -1056,23 +1070,33 @@ const MediaAgent = (props: MediaAgentProps) => { useEffect(() => { mountedRef.current = true; + const localUserHasMedia = session?.has_media !== false; // Default to true for backward compatibility + if (mediaStreamRef.current || readyState !== ReadyState.OPEN) return; - console.log(`media-agent - Setting up local media`); - setup_local_media().then((mediaStream) => { - if (!mountedRef.current) { - // Component unmounted, clean up - mediaStream.getTracks().forEach((track) => { - track.stop(); - if ((track as any).stopAnimation) (track as any).stopAnimation(); - if ((track as any).stopOscillator) (track as any).stopOscillator(); - }); - return; - } + if (localUserHasMedia) { + console.log(`media-agent - Setting up local media`); + setup_local_media().then((mediaStream) => { + if (!mountedRef.current) { + // Component unmounted, clean up + mediaStream.getTracks().forEach((track) => { + track.stop(); + if ((track as any).stopAnimation) (track as any).stopAnimation(); + if ((track as any).stopOscillator) (track as any).stopOscillator(); + }); + return; + } - mediaStreamRef.current = mediaStream; - setMedia(mediaStream); - }); + mediaStreamRef.current = mediaStream; + setMedia(mediaStream); + }); + } else { + console.log(`media-agent - Local user has no media, creating empty stream`); + // Create an empty media stream for users without media + const emptyStream = new MediaStream(); + mediaStreamRef.current = emptyStream; + setMedia(emptyStream); + } return () => { mountedRef.current = false; @@ -1091,7 +1115,7 @@ const MediaAgent = (props: MediaAgentProps) => { connectionsRef.current.forEach((connection) => connection.close()); connectionsRef.current.clear(); }; - }, [readyState, setup_local_media]); + }, [readyState, setup_local_media, session]); return null; }; diff --git a/client/src/UserList.tsx b/client/src/UserList.tsx index f431fca..ff32595 100644 --- a/client/src/UserList.tsx +++ b/client/src/UserList.tsx @@ -16,6 +16,7 @@ type User = { local: boolean /* Client side variable */; protected?: boolean; is_bot?: boolean; + has_media?: boolean; // Whether this user provides audio/video streams bot_run_id?: string; bot_provider_id?: string; }; @@ -168,13 +169,28 @@ const UserList: React.FC = (props: UserListProps) => { {user.name && !user.live &&
} - {user.name && user.live && peers[user.session_id] ? ( + {user.name && user.live && peers[user.session_id] && (user.local || user.has_media !== false) ? ( + ) : user.name && user.live && user.has_media === false ? ( +
+ 💬 Chat Only +
) : ( )} diff --git a/client/src/api-client.ts b/client/src/api-client.ts index fff0a67..c4859b7 100644 --- a/client/src/api-client.ts +++ b/client/src/api-client.ts @@ -39,7 +39,7 @@ export interface BotProviderListResponse { } export interface BotListResponse { - bots: Record; + bots: BotInfoModel[]; providers: Record; } diff --git a/server/main.py b/server/main.py index d504f5a..d4d7579 100644 --- a/server/main.py +++ b/server/main.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, List from fastapi import ( Body, Cookie, @@ -26,7 +26,6 @@ from fastapi.staticfiles import StaticFiles import httpx from pydantic import ValidationError from logger import logger - # Import shared models sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from shared.models import ( @@ -64,6 +63,8 @@ from shared.models import ( BotJoinPayload, BotLeaveLobbyRequest, BotLeaveLobbyResponse, + BotProviderBotsResponse, + BotProviderJoinResponse, ) @@ -371,6 +372,7 @@ class Lobby: if s.name and s.name.lower() in name_passwords else False, is_bot=s.is_bot, + has_media=s.has_media, bot_run_id=s.bot_run_id, bot_provider_id=s.bot_provider_id, ) @@ -499,8 +501,10 @@ class Session: _loaded = False lock = threading.RLock() # Thread safety for class-level operations - def __init__(self, id: str, is_bot: bool = False): - logger.info(f"Instantiating new session {id} (bot: {is_bot})") + def __init__(self, id: str, is_bot: bool = False, has_media: bool = True): + logger.info( + f"Instantiating new session {id} (bot: {is_bot}, media: {has_media})" + ) with Session.lock: self._instances.append(self) self.id = id @@ -515,6 +519,7 @@ class Session: self.last_used = time.time() self.displaced_at: float | None = None # When name was taken over self.is_bot = is_bot # Whether this session represents a bot + self.has_media = has_media # Whether this session provides audio/video streams self.bot_run_id: str | None = None # Bot run ID for tracking self.bot_provider_id: str | None = None # Bot provider ID self.session_lock = threading.RLock() # Instance-level lock @@ -542,6 +547,7 @@ class Session: last_used=s.last_used, displaced_at=s.displaced_at, is_bot=s.is_bot, + has_media=s.has_media, bot_run_id=s.bot_run_id, bot_provider_id=s.bot_provider_id, ) @@ -625,7 +631,11 @@ class Session: logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load") continue # Skip loading this expired session - session = Session(s_saved.id, is_bot=getattr(s_saved, "is_bot", False)) + session = Session( + s_saved.id, + is_bot=getattr(s_saved, "is_bot", False), + has_media=getattr(s_saved, "has_media", True), + ) session.name = name # Load timestamps, with defaults for backward compatibility session.created_at = created_at @@ -633,6 +643,7 @@ class Session: session.displaced_at = displaced_at # Load bot information with defaults for backward compatibility session.is_bot = getattr(s_saved, "is_bot", False) + session.has_media = getattr(s_saved, "has_media", True) session.bot_run_id = getattr(s_saved, "bot_run_id", None) session.bot_provider_id = getattr(s_saved, "bot_provider_id", None) for lobby_saved in s_saved.lobbies: @@ -1079,52 +1090,62 @@ class Session: del lobby.sessions[peer_session.id] continue - # Add the peer to session's RTC peer list - with self.session_lock: - self.lobby_peers[lobby.id].append(peer_session.id) + # Only create WebRTC peer connections if at least one participant has media + should_create_rtc_connection = self.has_media or peer_session.has_media - # Add this user as an RTC peer to each existing peer - with peer_session.session_lock: - if lobby.id not in peer_session.lobby_peers: - peer_session.lobby_peers[lobby.id] = [] - peer_session.lobby_peers[lobby.id].append(self.id) + if should_create_rtc_connection: + # Add the peer to session's RTC peer list + with self.session_lock: + self.lobby_peers[lobby.id].append(peer_session.id) - logger.info( - f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False)" - ) - try: - await peer_session.ws.send_json( - { - "type": "addPeer", - "data": { - "peer_id": self.id, - "peer_name": self.name, - "should_create_offer": False, - }, - } - ) - except Exception as e: - logger.warning( - f"Failed to send addPeer to {peer_session.getName()}: {e}" - ) + # Add this user as an RTC peer to each existing peer + with peer_session.session_lock: + if lobby.id not in peer_session.lobby_peers: + peer_session.lobby_peers[lobby.id] = [] + peer_session.lobby_peers[lobby.id].append(self.id) - # Add each other peer to the caller - logger.info( - f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True)" - ) - try: - await self.ws.send_json( - { - "type": "addPeer", - "data": { - "peer_id": peer_session.id, - "peer_name": peer_session.name, - "should_create_offer": True, - }, - } + logger.info( + f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False, has_media={self.has_media})" + ) + try: + await peer_session.ws.send_json( + { + "type": "addPeer", + "data": { + "peer_id": self.id, + "peer_name": self.name, + "has_media": self.has_media, + "should_create_offer": False, + }, + } + ) + except Exception as e: + logger.warning( + f"Failed to send addPeer to {peer_session.getName()}: {e}" + ) + + # Add each other peer to the caller + logger.info( + f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True, has_media={peer_session.has_media})" + ) + try: + await self.ws.send_json( + { + "type": "addPeer", + "data": { + "peer_id": peer_session.id, + "peer_name": peer_session.name, + "has_media": peer_session.has_media, + "should_create_offer": True, + }, + } + ) + except Exception as e: + logger.warning(f"Failed to send addPeer to {self.getName()}: {e}") + else: + logger.info( + f"{self.getName()} - Skipping WebRTC connection with {peer_session.getName()} (neither has media: self={self.has_media}, peer={peer_session.has_media})" ) - except Exception as e: - logger.warning(f"Failed to send addPeer to {self.getName()}: {e}") # Add this user as an RTC peer await lobby.addSession(self) @@ -1452,7 +1473,7 @@ async def list_bot_providers() -> BotProviderListResponse: @app.get(public_url + "api/bots", response_model=BotListResponse) async def list_available_bots() -> BotListResponse: """List all available bots from all registered providers""" - bots: dict[str, BotInfoModel] = {} + bots: List[BotInfoModel] = [] providers: dict[str, str] = {} # Update last_seen timestamps and fetch bots from each provider @@ -1464,11 +1485,14 @@ async def list_available_bots() -> BotListResponse: async with httpx.AsyncClient() as client: response = await client.get(f"{provider.base_url}/bots", timeout=5.0) if response.status_code == 200: - provider_bots = response.json() - # provider_bots should be a dict of bot_name -> bot_info - for bot_name, bot_info in provider_bots.items(): - bots[bot_name] = BotInfoModel(**bot_info) - providers[bot_name] = provider_id + # Use Pydantic model to validate the response + bots_response = BotProviderBotsResponse.model_validate( + response.json() + ) + # Add each bot to the consolidated list + for bot_info in bots_response.bots: + bots.append(bot_info) + providers[bot_info.name] = provider_id else: logger.warning( f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}" @@ -1486,8 +1510,9 @@ async def request_bot_join_lobby( ) -> BotJoinLobbyResponse: """Request a bot to join a specific lobby""" - # Find which provider has this bot + # Find which provider has this bot and determine its media capability target_provider_id = request.provider_id + bot_has_media = False if not target_provider_id: # Auto-discover provider for this bot for provider_id, provider in bot_providers.items(): @@ -1497,12 +1522,42 @@ async def request_bot_join_lobby( f"{provider.base_url}/bots", timeout=5.0 ) if response.status_code == 200: - provider_bots = response.json() - if bot_name in provider_bots: - target_provider_id = provider_id + # Use Pydantic model to validate the response + bots_response = BotProviderBotsResponse.model_validate( + response.json() + ) + # Look for the bot by name + for bot_info in bots_response.bots: + if bot_info.name == bot_name: + target_provider_id = provider_id + bot_has_media = bot_info.has_media + break + if target_provider_id: break except Exception: continue + else: + # Query the specified provider for bot media capability + if target_provider_id in bot_providers: + provider = bot_providers[target_provider_id] + try: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{provider.base_url}/bots", timeout=5.0 + ) + if response.status_code == 200: + # Use Pydantic model to validate the response + bots_response = BotProviderBotsResponse.model_validate( + response.json() + ) + # Look for the bot by name + for bot_info in bots_response.bots: + if bot_info.name == bot_name: + bot_has_media = bot_info.has_media + break + except Exception: + # Default to no media if we can't query + pass if not target_provider_id or target_provider_id not in bot_providers: raise HTTPException(status_code=404, detail="Bot or provider not found") @@ -1519,8 +1574,10 @@ async def request_bot_join_lobby( bot_session_id = secrets.token_hex(16) # Create the Session object for the bot - bot_session = Session(bot_session_id, is_bot=True) - logger.info(f"Created bot session for: {bot_session.getName()}") + bot_session = Session(bot_session_id, is_bot=True, has_media=bot_has_media) + logger.info( + f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})" + ) # Determine server URL for the bot to connect back to # Use the server's public URL or construct from request @@ -1549,25 +1606,35 @@ async def request_bot_join_lobby( ) if response.status_code == 200: - result = response.json() - run_id = result.get("run_id", "unknown") + # Use Pydantic model to parse and validate response + try: + join_response = BotProviderJoinResponse.model_validate( + response.json() + ) + run_id = join_response.run_id - # Update bot session with run and provider information - with bot_session.session_lock: - bot_session.bot_run_id = run_id - bot_session.bot_provider_id = target_provider_id - bot_session.setName(bot_nick) + # Update bot session with run and provider information + with bot_session.session_lock: + bot_session.bot_run_id = run_id + bot_session.bot_provider_id = target_provider_id + bot_session.setName(bot_nick) - logger.info( - f"Bot {bot_name} requested to join lobby {request.lobby_id}" - ) + logger.info( + f"Bot {bot_name} requested to join lobby {request.lobby_id}" + ) - return BotJoinLobbyResponse( - status="requested", - bot_name=bot_name, - run_id=run_id, - provider_id=target_provider_id, - ) + return BotJoinLobbyResponse( + status="requested", + bot_name=bot_name, + run_id=run_id, + provider_id=target_provider_id, + ) + except ValidationError as e: + logger.error(f"Invalid response from bot provider: {e}") + raise HTTPException( + status_code=502, + detail=f"Bot provider returned invalid response: {str(e)}", + ) else: logger.error( f"Bot provider returned error: HTTP {response.status_code}: {response.text}" diff --git a/shared/models.py b/shared/models.py index 1debee6..a9f0e6a 100644 --- a/shared/models.py +++ b/shared/models.py @@ -44,6 +44,7 @@ class ParticipantModel(BaseModel): live: bool protected: bool is_bot: bool = False + has_media: bool = True # Whether this participant provides audio/video streams bot_run_id: Optional[str] = None bot_provider_id: Optional[str] = None @@ -226,6 +227,7 @@ class AddPeerModel(BaseModel): """WebRTC add peer message""" peer_id: str peer_name: str + has_media: bool = True should_create_offer: bool = False @@ -316,6 +318,7 @@ class SessionSaved(BaseModel): last_used: float = 0.0 displaced_at: Optional[float] = None # When name was taken over is_bot: bool = False # Whether this session represents a bot + has_media: bool = True # Whether this session provides audio/video streams bot_run_id: Optional[str] = None # Bot run ID for tracking bot_provider_id: Optional[str] = None # Bot provider ID @@ -336,6 +339,13 @@ class BotInfoModel(BaseModel): name: str description: str + has_media: bool = True # Whether this bot provides audio/video streams + + +class BotProviderBotsResponse(BaseModel): + """Response from bot provider's /bots endpoint""" + + bots: List[BotInfoModel] class BotProviderModel(BaseModel): @@ -375,7 +385,7 @@ class BotProviderListResponse(BaseModel): class BotListResponse(BaseModel): """Response listing all available bots from all providers""" - bots: Dict[str, BotInfoModel] # bot_name -> bot_info + bots: List[BotInfoModel] # List of available bots providers: Dict[str, str] # bot_name -> provider_id @@ -407,6 +417,14 @@ class BotJoinLobbyResponse(BaseModel): provider_id: str +class BotProviderJoinResponse(BaseModel): + """Response from bot provider's /bots/{bot_name}/join endpoint""" + + status: str + bot: str + run_id: str + + class BotLeaveLobbyRequest(BaseModel): """Request to make a bot leave a lobby""" diff --git a/voicebot/bot_orchestrator.py b/voicebot/bot_orchestrator.py index f323728..e28e008 100644 --- a/voicebot/bot_orchestrator.py +++ b/voicebot/bot_orchestrator.py @@ -13,7 +13,7 @@ import sys import os import time from contextlib import asynccontextmanager -from typing import Dict, Any +from typing import Dict, Any, List # Add the parent directory to sys.path to allow absolute imports sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -27,7 +27,7 @@ from voicebot.webrtc_signaling import WebRTCSignalingClient # Add shared models import for chat types sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from shared.models import ChatMessageModel +from shared.models import ChatMessageModel, BotInfoModel, BotProviderBotsResponse @asynccontextmanager @@ -37,7 +37,8 @@ async def lifespan(app: FastAPI): # Log the discovered bots bots = discover_bots() if bots: - logger.info(f"📋 Discovered {len(bots)} bots: {list(bots.keys())}") + bot_names = [bot.name for bot in bots] + logger.info(f"📋 Discovered {len(bots)} bots: {bot_names}") else: logger.info("⚠️ No bots discovered") @@ -72,13 +73,21 @@ registry: Dict[str, WebRTCSignalingClient] = {} logger.info("📦 Bot orchestrator module imported/reloaded") -def discover_bots() -> Dict[str, Dict[str, Any]]: +# Global bot registry for internal use +_bot_registry: Dict[str, Dict[str, Any]] = {} + +def discover_bots() -> "List[BotInfoModel]": """Discover bot modules under the voicebot.bots package that expose bot_info. This intentionally imports modules under `voicebot.bots` so heavy bot implementations can remain in that package and be imported lazily. """ - bots: Dict[str, Dict[str, Any]] = {} + global _bot_registry + from shared.models import BotInfoModel + + bots: List[BotInfoModel] = [] + _bot_registry.clear() # Clear previous discoveries + try: package = importlib.import_module("voicebot.bots") package_path = package.__path__ @@ -92,49 +101,62 @@ def discover_bots() -> Dict[str, Dict[str, Any]]: except Exception: logger.exception("Failed to import voicebot.bots.%s", name) continue - info = None - create_tracks = None + if hasattr(mod, "agent_info") and callable(getattr(mod, "agent_info")): try: info = mod.agent_info() - # Note: Keep copy as is to maintain structure + # Convert string has_media to boolean for compatibility + processed_info = dict(info) + has_media_value = processed_info.get("has_media", True) + if isinstance(has_media_value, str): + processed_info["has_media"] = has_media_value.lower() in ("true", "1", "yes") + + # Create BotInfoModel using model_validate + bot_info = BotInfoModel.model_validate(processed_info) + bots.append(bot_info) + + # Store additional metadata in registry + create_tracks = None + if hasattr(mod, "create_agent_tracks") and callable(getattr(mod, "create_agent_tracks")): + create_tracks = getattr(mod, "create_agent_tracks") + + chat_handler = None + if hasattr(mod, "handle_chat_message") and callable(getattr(mod, "handle_chat_message")): + chat_handler = getattr(mod, "handle_chat_message") + + _bot_registry[bot_info.name] = { + "module": name, + "info": bot_info, + "create_tracks": create_tracks, + "chat_handler": chat_handler + } + except Exception: logger.exception("agent_info() failed for %s", name) - if hasattr(mod, "create_agent_tracks") and callable(getattr(mod, "create_agent_tracks")): - create_tracks = getattr(mod, "create_agent_tracks") - - if info: - # Check for chat handler - chat_handler = None - if hasattr(mod, "handle_chat_message") and callable(getattr(mod, "handle_chat_message")): - chat_handler = getattr(mod, "handle_chat_message") - - bots[info.get("name", name)] = { - "module": name, - "info": info, - "create_tracks": create_tracks, - "chat_handler": chat_handler - } + return bots @app.get("/bots") -def list_bots() -> Dict[str, Any]: +def list_bots() -> "BotProviderBotsResponse": """List available bots.""" + from shared.models import BotProviderBotsResponse bots = discover_bots() - return {k: v["info"] for k, v in bots.items()} + return BotProviderBotsResponse(bots=bots) @app.post("/bots/{bot_name}/join") async def bot_join(bot_name: str, req: JoinRequest): """Make a bot join a lobby.""" - bots = discover_bots() - bot = bots.get(bot_name) - if not bot: + # Ensure bots are discovered and registry is populated + discover_bots() + + if bot_name not in _bot_registry: raise HTTPException(status_code=404, detail="Bot not found") - - create_tracks = bot.get("create_tracks") - chat_handler = bot.get("chat_handler") + + bot_data = _bot_registry[bot_name] + create_tracks = bot_data.get("create_tracks") + chat_handler = bot_data.get("chat_handler") logger.info(f"🤖 Bot {bot_name} joining lobby {req.lobby_id} with nick: '{req.nick}'") if chat_handler: diff --git a/voicebot/bots/chatbot.py b/voicebot/bots/chatbot.py index 16ab458..8ce1104 100644 --- a/voicebot/bots/chatbot.py +++ b/voicebot/bots/chatbot.py @@ -37,7 +37,7 @@ RESPONSES = { def agent_info() -> Dict[str, str]: - return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION} + return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"} def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]: diff --git a/voicebot/bots/whisper.py b/voicebot/bots/whisper.py index 6edef33..6674e43 100644 --- a/voicebot/bots/whisper.py +++ b/voicebot/bots/whisper.py @@ -22,7 +22,7 @@ AGENT_DESCRIPTION = "Speech recognition agent (Whisper) - processes incoming aud def agent_info() -> Dict[str, str]: - return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION} + return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"} def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]: