Refactoring

This commit is contained in:
James Ketr 2025-09-03 17:04:31 -07:00
parent 9ce3d1b670
commit 7042a76d19
10 changed files with 301 additions and 156 deletions

View File

@ -35,7 +35,7 @@ interface BotManagerProps {
} }
const BotManager: React.FC<BotManagerProps> = ({ lobbyId, onBotAdded, sx }) => { const BotManager: React.FC<BotManagerProps> = ({ lobbyId, onBotAdded, sx }) => {
const [bots, setBots] = useState<Record<string, BotInfoModel>>({}); const [bots, setBots] = useState<BotInfoModel[]>([]);
const [providers, setProviders] = useState<Record<string, string>>({}); const [providers, setProviders] = useState<Record<string, string>>({});
const [botProviders, setBotProviders] = useState<BotProviderModel[]>([]); const [botProviders, setBotProviders] = useState<BotProviderModel[]>([]);
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
@ -49,10 +49,7 @@ const BotManager: React.FC<BotManagerProps> = ({ lobbyId, onBotAdded, sx }) => {
setLoading(true); setLoading(true);
setError(null); setError(null);
try { try {
const [botsResponse, providersResponse] = await Promise.all([ const [botsResponse, providersResponse] = await Promise.all([botsApi.getAvailable(), botsApi.getProviders()]);
botsApi.getAvailable(),
botsApi.getProviders(),
]);
setBots(botsResponse.bots); setBots(botsResponse.bots);
setProviders(botsResponse.providers); setProviders(botsResponse.providers);
@ -112,11 +109,11 @@ const BotManager: React.FC<BotManagerProps> = ({ lobbyId, onBotAdded, sx }) => {
}; };
const getProviderName = (providerId: string): string => { const getProviderName = (providerId: string): string => {
const provider = botProviders.find(p => p.provider_id === providerId); const provider = botProviders.find((p) => p.provider_id === providerId);
return provider ? provider.name : "Unknown Provider"; return provider ? provider.name : "Unknown Provider";
}; };
const botCount = Object.keys(bots).length; const botCount = bots.length;
const providerCount = botProviders.length; const providerCount = botProviders.length;
return ( return (
@ -169,12 +166,12 @@ const BotManager: React.FC<BotManagerProps> = ({ lobbyId, onBotAdded, sx }) => {
</AccordionSummary> </AccordionSummary>
<AccordionDetails> <AccordionDetails>
<List dense> <List dense>
{Object.entries(bots).map(([botName, botInfo]) => { {bots.map((botInfo) => {
const providerId = providers[botName]; const providerId = providers[botInfo.name];
const providerName = getProviderName(providerId); const providerName = getProviderName(providerId);
return ( return (
<ListItem key={botName}> <ListItem key={botInfo.name}>
<ListItemText <ListItemText
primary={botInfo.name} primary={botInfo.name}
secondary={ secondary={
@ -242,21 +239,21 @@ const BotManager: React.FC<BotManagerProps> = ({ lobbyId, onBotAdded, sx }) => {
Select Bot Select Bot
</Typography> </Typography>
<List> <List>
{Object.entries(bots).map(([botName, botInfo]) => ( {bots.map((botInfo) => (
<ListItem <ListItem
key={botName} key={botInfo.name}
component="div" component="div"
sx={{ sx={{
cursor: "pointer", cursor: "pointer",
backgroundColor: selectedBot === botName ? "action.selected" : "transparent", backgroundColor: selectedBot === botInfo.name ? "action.selected" : "transparent",
"&:hover": { "&:hover": {
backgroundColor: "action.hover", backgroundColor: "action.hover",
}, },
}} }}
onClick={() => setSelectedBot(botName)} onClick={() => setSelectedBot(botInfo.name)}
> >
<ListItemText primary={botInfo.name} secondary={botInfo.description} /> <ListItemText primary={botInfo.name} secondary={botInfo.description} />
<Chip label={getProviderName(providers[botName])} size="small" variant="outlined" /> <Chip label={getProviderName(providers[botInfo.name])} size="small" variant="outlined" />
</ListItem> </ListItem>
))} ))}
</List> </List>

View File

@ -6,4 +6,5 @@ export type Lobby = LobbyModel;
// Extended Session type that allows name to be null initially (before user sets it) // Extended Session type that allows name to be null initially (before user sets it)
export type Session = Omit<SessionResponse, "name"> & { export type Session = Omit<SessionResponse, "name"> & {
name: string | null; name: string | null;
has_media?: boolean; // Whether this session provides audio/video streams
}; };

View File

@ -140,6 +140,7 @@ export type { Peer };
interface AddPeerConfig { interface AddPeerConfig {
peer_id: string; peer_id: string;
peer_name: string; peer_name: string;
has_media?: boolean; // Whether this peer provides audio/video streams
should_create_offer?: boolean; should_create_offer?: boolean;
} }
@ -406,8 +407,13 @@ const MediaAgent = (props: MediaAgentProps) => {
} }
} }
// Queue peer if media not ready // Queue peer if we need local media but don't have it yet
if (!media) { // Only queue if we're expected to provide media (local user has media)
const localUserHasMedia = session?.has_media !== false; // Default to true for backward compatibility
const peerHasMedia = config.has_media !== false; // Default to true for backward compatibility
// Only need to wait for media if we (local user) are supposed to provide it
if (!media && localUserHasMedia) {
console.log(`media-agent - addPeer:${config.peer_name} - No local media yet, queuing peer`); console.log(`media-agent - addPeer:${config.peer_name} - No local media yet, queuing peer`);
setPendingPeers((prev) => { setPendingPeers((prev) => {
// Avoid duplicate queuing // Avoid duplicate queuing
@ -676,12 +682,20 @@ const MediaAgent = (props: MediaAgentProps) => {
} }
}; };
// Add local tracks // Add local tracks to the connection only if we have media and it's valid
console.log(`media-agent - addPeer:${peer.peer_name} Adding local tracks to new peer connection`); console.log(
`media-agent - addPeer:${peer.peer_name} Adding local tracks to new peer connection (localHasMedia=${localUserHasMedia})`
);
if (media && localUserHasMedia) {
media.getTracks().forEach((t) => { media.getTracks().forEach((t) => {
console.log(`media-agent - addPeer:${peer.peer_name} Adding track:`, t.kind, t.enabled); console.log(`media-agent - addPeer:${peer.peer_name} Adding track:`, t.kind, t.enabled);
connection.addTrack(t, media); connection.addTrack(t, media);
}); });
} else if (!localUserHasMedia) {
console.log(`media-agent - addPeer:${peer.peer_name} - Local user has no media, skipping track addition`);
} else {
console.log(`media-agent - addPeer:${peer.peer_name} - No local media available yet`);
}
// Update peers state // Update peers state
setPeers(updatedPeers); setPeers(updatedPeers);
@ -1056,8 +1070,11 @@ const MediaAgent = (props: MediaAgentProps) => {
useEffect(() => { useEffect(() => {
mountedRef.current = true; mountedRef.current = true;
const localUserHasMedia = session?.has_media !== false; // Default to true for backward compatibility
if (mediaStreamRef.current || readyState !== ReadyState.OPEN) return; if (mediaStreamRef.current || readyState !== ReadyState.OPEN) return;
if (localUserHasMedia) {
console.log(`media-agent - Setting up local media`); console.log(`media-agent - Setting up local media`);
setup_local_media().then((mediaStream) => { setup_local_media().then((mediaStream) => {
if (!mountedRef.current) { if (!mountedRef.current) {
@ -1073,6 +1090,13 @@ const MediaAgent = (props: MediaAgentProps) => {
mediaStreamRef.current = mediaStream; mediaStreamRef.current = mediaStream;
setMedia(mediaStream); setMedia(mediaStream);
}); });
} else {
console.log(`media-agent - Local user has no media, creating empty stream`);
// Create an empty media stream for users without media
const emptyStream = new MediaStream();
mediaStreamRef.current = emptyStream;
setMedia(emptyStream);
}
return () => { return () => {
mountedRef.current = false; mountedRef.current = false;
@ -1091,7 +1115,7 @@ const MediaAgent = (props: MediaAgentProps) => {
connectionsRef.current.forEach((connection) => connection.close()); connectionsRef.current.forEach((connection) => connection.close());
connectionsRef.current.clear(); connectionsRef.current.clear();
}; };
}, [readyState, setup_local_media]); }, [readyState, setup_local_media, session]);
return null; return null;
}; };

View File

@ -16,6 +16,7 @@ type User = {
local: boolean /* Client side variable */; local: boolean /* Client side variable */;
protected?: boolean; protected?: boolean;
is_bot?: boolean; is_bot?: boolean;
has_media?: boolean; // Whether this user provides audio/video streams
bot_run_id?: string; bot_run_id?: string;
bot_provider_id?: string; bot_provider_id?: string;
}; };
@ -168,13 +169,28 @@ const UserList: React.FC<UserListProps> = (props: UserListProps) => {
</div> </div>
{user.name && !user.live && <div className="NoNetwork"></div>} {user.name && !user.live && <div className="NoNetwork"></div>}
</div> </div>
{user.name && user.live && peers[user.session_id] ? ( {user.name && user.live && peers[user.session_id] && (user.local || user.has_media !== false) ? (
<MediaControl <MediaControl
className={videoClass} className={videoClass}
key={user.session_id} key={user.session_id}
peer={peers[user.session_id]} peer={peers[user.session_id]}
isSelf={user.local} isSelf={user.local}
/> />
) : user.name && user.live && user.has_media === false ? (
<div
className="Video"
style={{
background: "#333",
color: "#fff",
display: "flex",
alignItems: "center",
justifyContent: "center",
minHeight: "120px",
fontSize: "14px",
}}
>
💬 Chat Only
</div>
) : ( ) : (
<video className="Video"></video> <video className="Video"></video>
)} )}

View File

@ -39,7 +39,7 @@ export interface BotProviderListResponse {
} }
export interface BotListResponse { export interface BotListResponse {
bots: Record<string, BotInfoModel>; bots: BotInfoModel[];
providers: Record<string, string>; providers: Record<string, string>;
} }

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Optional from typing import Any, Optional, List
from fastapi import ( from fastapi import (
Body, Body,
Cookie, Cookie,
@ -26,7 +26,6 @@ from fastapi.staticfiles import StaticFiles
import httpx import httpx
from pydantic import ValidationError from pydantic import ValidationError
from logger import logger from logger import logger
# Import shared models # Import shared models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import ( from shared.models import (
@ -64,6 +63,8 @@ from shared.models import (
BotJoinPayload, BotJoinPayload,
BotLeaveLobbyRequest, BotLeaveLobbyRequest,
BotLeaveLobbyResponse, BotLeaveLobbyResponse,
BotProviderBotsResponse,
BotProviderJoinResponse,
) )
@ -371,6 +372,7 @@ class Lobby:
if s.name and s.name.lower() in name_passwords if s.name and s.name.lower() in name_passwords
else False, else False,
is_bot=s.is_bot, is_bot=s.is_bot,
has_media=s.has_media,
bot_run_id=s.bot_run_id, bot_run_id=s.bot_run_id,
bot_provider_id=s.bot_provider_id, bot_provider_id=s.bot_provider_id,
) )
@ -499,8 +501,10 @@ class Session:
_loaded = False _loaded = False
lock = threading.RLock() # Thread safety for class-level operations lock = threading.RLock() # Thread safety for class-level operations
def __init__(self, id: str, is_bot: bool = False): def __init__(self, id: str, is_bot: bool = False, has_media: bool = True):
logger.info(f"Instantiating new session {id} (bot: {is_bot})") logger.info(
f"Instantiating new session {id} (bot: {is_bot}, media: {has_media})"
)
with Session.lock: with Session.lock:
self._instances.append(self) self._instances.append(self)
self.id = id self.id = id
@ -515,6 +519,7 @@ class Session:
self.last_used = time.time() self.last_used = time.time()
self.displaced_at: float | None = None # When name was taken over self.displaced_at: float | None = None # When name was taken over
self.is_bot = is_bot # Whether this session represents a bot self.is_bot = is_bot # Whether this session represents a bot
self.has_media = has_media # Whether this session provides audio/video streams
self.bot_run_id: str | None = None # Bot run ID for tracking self.bot_run_id: str | None = None # Bot run ID for tracking
self.bot_provider_id: str | None = None # Bot provider ID self.bot_provider_id: str | None = None # Bot provider ID
self.session_lock = threading.RLock() # Instance-level lock self.session_lock = threading.RLock() # Instance-level lock
@ -542,6 +547,7 @@ class Session:
last_used=s.last_used, last_used=s.last_used,
displaced_at=s.displaced_at, displaced_at=s.displaced_at,
is_bot=s.is_bot, is_bot=s.is_bot,
has_media=s.has_media,
bot_run_id=s.bot_run_id, bot_run_id=s.bot_run_id,
bot_provider_id=s.bot_provider_id, bot_provider_id=s.bot_provider_id,
) )
@ -625,7 +631,11 @@ class Session:
logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load") logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load")
continue # Skip loading this expired session continue # Skip loading this expired session
session = Session(s_saved.id, is_bot=getattr(s_saved, "is_bot", False)) session = Session(
s_saved.id,
is_bot=getattr(s_saved, "is_bot", False),
has_media=getattr(s_saved, "has_media", True),
)
session.name = name session.name = name
# Load timestamps, with defaults for backward compatibility # Load timestamps, with defaults for backward compatibility
session.created_at = created_at session.created_at = created_at
@ -633,6 +643,7 @@ class Session:
session.displaced_at = displaced_at session.displaced_at = displaced_at
# Load bot information with defaults for backward compatibility # Load bot information with defaults for backward compatibility
session.is_bot = getattr(s_saved, "is_bot", False) session.is_bot = getattr(s_saved, "is_bot", False)
session.has_media = getattr(s_saved, "has_media", True)
session.bot_run_id = getattr(s_saved, "bot_run_id", None) session.bot_run_id = getattr(s_saved, "bot_run_id", None)
session.bot_provider_id = getattr(s_saved, "bot_provider_id", None) session.bot_provider_id = getattr(s_saved, "bot_provider_id", None)
for lobby_saved in s_saved.lobbies: for lobby_saved in s_saved.lobbies:
@ -1079,6 +1090,10 @@ class Session:
del lobby.sessions[peer_session.id] del lobby.sessions[peer_session.id]
continue continue
# Only create WebRTC peer connections if at least one participant has media
should_create_rtc_connection = self.has_media or peer_session.has_media
if should_create_rtc_connection:
# Add the peer to session's RTC peer list # Add the peer to session's RTC peer list
with self.session_lock: with self.session_lock:
self.lobby_peers[lobby.id].append(peer_session.id) self.lobby_peers[lobby.id].append(peer_session.id)
@ -1090,7 +1105,7 @@ class Session:
peer_session.lobby_peers[lobby.id].append(self.id) peer_session.lobby_peers[lobby.id].append(self.id)
logger.info( logger.info(
f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False)" f"{self.getName()} -> {peer_session.getName()}:addPeer({self.getName()}, {lobby.getName()}, should_create_offer=False, has_media={self.has_media})"
) )
try: try:
await peer_session.ws.send_json( await peer_session.ws.send_json(
@ -1099,6 +1114,7 @@ class Session:
"data": { "data": {
"peer_id": self.id, "peer_id": self.id,
"peer_name": self.name, "peer_name": self.name,
"has_media": self.has_media,
"should_create_offer": False, "should_create_offer": False,
}, },
} }
@ -1110,7 +1126,7 @@ class Session:
# Add each other peer to the caller # Add each other peer to the caller
logger.info( logger.info(
f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True)" f"{self.getName()} -> {self.getName()}:addPeer({peer_session.getName()}, {lobby.getName()}, should_create_offer=True, has_media={peer_session.has_media})"
) )
try: try:
await self.ws.send_json( await self.ws.send_json(
@ -1119,12 +1135,17 @@ class Session:
"data": { "data": {
"peer_id": peer_session.id, "peer_id": peer_session.id,
"peer_name": peer_session.name, "peer_name": peer_session.name,
"has_media": peer_session.has_media,
"should_create_offer": True, "should_create_offer": True,
}, },
} }
) )
except Exception as e: except Exception as e:
logger.warning(f"Failed to send addPeer to {self.getName()}: {e}") logger.warning(f"Failed to send addPeer to {self.getName()}: {e}")
else:
logger.info(
f"{self.getName()} - Skipping WebRTC connection with {peer_session.getName()} (neither has media: self={self.has_media}, peer={peer_session.has_media})"
)
# Add this user as an RTC peer # Add this user as an RTC peer
await lobby.addSession(self) await lobby.addSession(self)
@ -1452,7 +1473,7 @@ async def list_bot_providers() -> BotProviderListResponse:
@app.get(public_url + "api/bots", response_model=BotListResponse) @app.get(public_url + "api/bots", response_model=BotListResponse)
async def list_available_bots() -> BotListResponse: async def list_available_bots() -> BotListResponse:
"""List all available bots from all registered providers""" """List all available bots from all registered providers"""
bots: dict[str, BotInfoModel] = {} bots: List[BotInfoModel] = []
providers: dict[str, str] = {} providers: dict[str, str] = {}
# Update last_seen timestamps and fetch bots from each provider # Update last_seen timestamps and fetch bots from each provider
@ -1464,11 +1485,14 @@ async def list_available_bots() -> BotListResponse:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(f"{provider.base_url}/bots", timeout=5.0) response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
if response.status_code == 200: if response.status_code == 200:
provider_bots = response.json() # Use Pydantic model to validate the response
# provider_bots should be a dict of bot_name -> bot_info bots_response = BotProviderBotsResponse.model_validate(
for bot_name, bot_info in provider_bots.items(): response.json()
bots[bot_name] = BotInfoModel(**bot_info) )
providers[bot_name] = provider_id # Add each bot to the consolidated list
for bot_info in bots_response.bots:
bots.append(bot_info)
providers[bot_info.name] = provider_id
else: else:
logger.warning( logger.warning(
f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}" f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}"
@ -1486,8 +1510,9 @@ async def request_bot_join_lobby(
) -> BotJoinLobbyResponse: ) -> BotJoinLobbyResponse:
"""Request a bot to join a specific lobby""" """Request a bot to join a specific lobby"""
# Find which provider has this bot # Find which provider has this bot and determine its media capability
target_provider_id = request.provider_id target_provider_id = request.provider_id
bot_has_media = False
if not target_provider_id: if not target_provider_id:
# Auto-discover provider for this bot # Auto-discover provider for this bot
for provider_id, provider in bot_providers.items(): for provider_id, provider in bot_providers.items():
@ -1497,12 +1522,42 @@ async def request_bot_join_lobby(
f"{provider.base_url}/bots", timeout=5.0 f"{provider.base_url}/bots", timeout=5.0
) )
if response.status_code == 200: if response.status_code == 200:
provider_bots = response.json() # Use Pydantic model to validate the response
if bot_name in provider_bots: bots_response = BotProviderBotsResponse.model_validate(
response.json()
)
# Look for the bot by name
for bot_info in bots_response.bots:
if bot_info.name == bot_name:
target_provider_id = provider_id target_provider_id = provider_id
bot_has_media = bot_info.has_media
break
if target_provider_id:
break break
except Exception: except Exception:
continue continue
else:
# Query the specified provider for bot media capability
if target_provider_id in bot_providers:
provider = bot_providers[target_provider_id]
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{provider.base_url}/bots", timeout=5.0
)
if response.status_code == 200:
# Use Pydantic model to validate the response
bots_response = BotProviderBotsResponse.model_validate(
response.json()
)
# Look for the bot by name
for bot_info in bots_response.bots:
if bot_info.name == bot_name:
bot_has_media = bot_info.has_media
break
except Exception:
# Default to no media if we can't query
pass
if not target_provider_id or target_provider_id not in bot_providers: if not target_provider_id or target_provider_id not in bot_providers:
raise HTTPException(status_code=404, detail="Bot or provider not found") raise HTTPException(status_code=404, detail="Bot or provider not found")
@ -1519,8 +1574,10 @@ async def request_bot_join_lobby(
bot_session_id = secrets.token_hex(16) bot_session_id = secrets.token_hex(16)
# Create the Session object for the bot # Create the Session object for the bot
bot_session = Session(bot_session_id, is_bot=True) bot_session = Session(bot_session_id, is_bot=True, has_media=bot_has_media)
logger.info(f"Created bot session for: {bot_session.getName()}") logger.info(
f"Created bot session for: {bot_session.getName()} (has_media={bot_has_media})"
)
# Determine server URL for the bot to connect back to # Determine server URL for the bot to connect back to
# Use the server's public URL or construct from request # Use the server's public URL or construct from request
@ -1549,8 +1606,12 @@ async def request_bot_join_lobby(
) )
if response.status_code == 200: if response.status_code == 200:
result = response.json() # Use Pydantic model to parse and validate response
run_id = result.get("run_id", "unknown") try:
join_response = BotProviderJoinResponse.model_validate(
response.json()
)
run_id = join_response.run_id
# Update bot session with run and provider information # Update bot session with run and provider information
with bot_session.session_lock: with bot_session.session_lock:
@ -1568,6 +1629,12 @@ async def request_bot_join_lobby(
run_id=run_id, run_id=run_id,
provider_id=target_provider_id, provider_id=target_provider_id,
) )
except ValidationError as e:
logger.error(f"Invalid response from bot provider: {e}")
raise HTTPException(
status_code=502,
detail=f"Bot provider returned invalid response: {str(e)}",
)
else: else:
logger.error( logger.error(
f"Bot provider returned error: HTTP {response.status_code}: {response.text}" f"Bot provider returned error: HTTP {response.status_code}: {response.text}"

View File

@ -44,6 +44,7 @@ class ParticipantModel(BaseModel):
live: bool live: bool
protected: bool protected: bool
is_bot: bool = False is_bot: bool = False
has_media: bool = True # Whether this participant provides audio/video streams
bot_run_id: Optional[str] = None bot_run_id: Optional[str] = None
bot_provider_id: Optional[str] = None bot_provider_id: Optional[str] = None
@ -226,6 +227,7 @@ class AddPeerModel(BaseModel):
"""WebRTC add peer message""" """WebRTC add peer message"""
peer_id: str peer_id: str
peer_name: str peer_name: str
has_media: bool = True
should_create_offer: bool = False should_create_offer: bool = False
@ -316,6 +318,7 @@ class SessionSaved(BaseModel):
last_used: float = 0.0 last_used: float = 0.0
displaced_at: Optional[float] = None # When name was taken over displaced_at: Optional[float] = None # When name was taken over
is_bot: bool = False # Whether this session represents a bot is_bot: bool = False # Whether this session represents a bot
has_media: bool = True # Whether this session provides audio/video streams
bot_run_id: Optional[str] = None # Bot run ID for tracking bot_run_id: Optional[str] = None # Bot run ID for tracking
bot_provider_id: Optional[str] = None # Bot provider ID bot_provider_id: Optional[str] = None # Bot provider ID
@ -336,6 +339,13 @@ class BotInfoModel(BaseModel):
name: str name: str
description: str description: str
has_media: bool = True # Whether this bot provides audio/video streams
class BotProviderBotsResponse(BaseModel):
"""Response from bot provider's /bots endpoint"""
bots: List[BotInfoModel]
class BotProviderModel(BaseModel): class BotProviderModel(BaseModel):
@ -375,7 +385,7 @@ class BotProviderListResponse(BaseModel):
class BotListResponse(BaseModel): class BotListResponse(BaseModel):
"""Response listing all available bots from all providers""" """Response listing all available bots from all providers"""
bots: Dict[str, BotInfoModel] # bot_name -> bot_info bots: List[BotInfoModel] # List of available bots
providers: Dict[str, str] # bot_name -> provider_id providers: Dict[str, str] # bot_name -> provider_id
@ -407,6 +417,14 @@ class BotJoinLobbyResponse(BaseModel):
provider_id: str provider_id: str
class BotProviderJoinResponse(BaseModel):
"""Response from bot provider's /bots/{bot_name}/join endpoint"""
status: str
bot: str
run_id: str
class BotLeaveLobbyRequest(BaseModel): class BotLeaveLobbyRequest(BaseModel):
"""Request to make a bot leave a lobby""" """Request to make a bot leave a lobby"""

View File

@ -13,7 +13,7 @@ import sys
import os import os
import time import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Any from typing import Dict, Any, List
# Add the parent directory to sys.path to allow absolute imports # Add the parent directory to sys.path to allow absolute imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@ -27,7 +27,7 @@ from voicebot.webrtc_signaling import WebRTCSignalingClient
# Add shared models import for chat types # Add shared models import for chat types
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import ChatMessageModel from shared.models import ChatMessageModel, BotInfoModel, BotProviderBotsResponse
@asynccontextmanager @asynccontextmanager
@ -37,7 +37,8 @@ async def lifespan(app: FastAPI):
# Log the discovered bots # Log the discovered bots
bots = discover_bots() bots = discover_bots()
if bots: if bots:
logger.info(f"📋 Discovered {len(bots)} bots: {list(bots.keys())}") bot_names = [bot.name for bot in bots]
logger.info(f"📋 Discovered {len(bots)} bots: {bot_names}")
else: else:
logger.info("⚠️ No bots discovered") logger.info("⚠️ No bots discovered")
@ -72,13 +73,21 @@ registry: Dict[str, WebRTCSignalingClient] = {}
logger.info("📦 Bot orchestrator module imported/reloaded") logger.info("📦 Bot orchestrator module imported/reloaded")
def discover_bots() -> Dict[str, Dict[str, Any]]: # Global bot registry for internal use
_bot_registry: Dict[str, Dict[str, Any]] = {}
def discover_bots() -> "List[BotInfoModel]":
"""Discover bot modules under the voicebot.bots package that expose bot_info. """Discover bot modules under the voicebot.bots package that expose bot_info.
This intentionally imports modules under `voicebot.bots` so heavy bot This intentionally imports modules under `voicebot.bots` so heavy bot
implementations can remain in that package and be imported lazily. implementations can remain in that package and be imported lazily.
""" """
bots: Dict[str, Dict[str, Any]] = {} global _bot_registry
from shared.models import BotInfoModel
bots: List[BotInfoModel] = []
_bot_registry.clear() # Clear previous discoveries
try: try:
package = importlib.import_module("voicebot.bots") package = importlib.import_module("voicebot.bots")
package_path = package.__path__ package_path = package.__path__
@ -92,49 +101,62 @@ def discover_bots() -> Dict[str, Dict[str, Any]]:
except Exception: except Exception:
logger.exception("Failed to import voicebot.bots.%s", name) logger.exception("Failed to import voicebot.bots.%s", name)
continue continue
info = None
create_tracks = None
if hasattr(mod, "agent_info") and callable(getattr(mod, "agent_info")): if hasattr(mod, "agent_info") and callable(getattr(mod, "agent_info")):
try: try:
info = mod.agent_info() info = mod.agent_info()
# Note: Keep copy as is to maintain structure # Convert string has_media to boolean for compatibility
except Exception: processed_info = dict(info)
logger.exception("agent_info() failed for %s", name) has_media_value = processed_info.get("has_media", True)
if isinstance(has_media_value, str):
processed_info["has_media"] = has_media_value.lower() in ("true", "1", "yes")
# Create BotInfoModel using model_validate
bot_info = BotInfoModel.model_validate(processed_info)
bots.append(bot_info)
# Store additional metadata in registry
create_tracks = None
if hasattr(mod, "create_agent_tracks") and callable(getattr(mod, "create_agent_tracks")): if hasattr(mod, "create_agent_tracks") and callable(getattr(mod, "create_agent_tracks")):
create_tracks = getattr(mod, "create_agent_tracks") create_tracks = getattr(mod, "create_agent_tracks")
if info:
# Check for chat handler
chat_handler = None chat_handler = None
if hasattr(mod, "handle_chat_message") and callable(getattr(mod, "handle_chat_message")): if hasattr(mod, "handle_chat_message") and callable(getattr(mod, "handle_chat_message")):
chat_handler = getattr(mod, "handle_chat_message") chat_handler = getattr(mod, "handle_chat_message")
bots[info.get("name", name)] = { _bot_registry[bot_info.name] = {
"module": name, "module": name,
"info": info, "info": bot_info,
"create_tracks": create_tracks, "create_tracks": create_tracks,
"chat_handler": chat_handler "chat_handler": chat_handler
} }
except Exception:
logger.exception("agent_info() failed for %s", name)
return bots return bots
@app.get("/bots") @app.get("/bots")
def list_bots() -> Dict[str, Any]: def list_bots() -> "BotProviderBotsResponse":
"""List available bots.""" """List available bots."""
from shared.models import BotProviderBotsResponse
bots = discover_bots() bots = discover_bots()
return {k: v["info"] for k, v in bots.items()} return BotProviderBotsResponse(bots=bots)
@app.post("/bots/{bot_name}/join") @app.post("/bots/{bot_name}/join")
async def bot_join(bot_name: str, req: JoinRequest): async def bot_join(bot_name: str, req: JoinRequest):
"""Make a bot join a lobby.""" """Make a bot join a lobby."""
bots = discover_bots() # Ensure bots are discovered and registry is populated
bot = bots.get(bot_name) discover_bots()
if not bot:
if bot_name not in _bot_registry:
raise HTTPException(status_code=404, detail="Bot not found") raise HTTPException(status_code=404, detail="Bot not found")
create_tracks = bot.get("create_tracks") bot_data = _bot_registry[bot_name]
chat_handler = bot.get("chat_handler") create_tracks = bot_data.get("create_tracks")
chat_handler = bot_data.get("chat_handler")
logger.info(f"🤖 Bot {bot_name} joining lobby {req.lobby_id} with nick: '{req.nick}'") logger.info(f"🤖 Bot {bot_name} joining lobby {req.lobby_id} with nick: '{req.nick}'")
if chat_handler: if chat_handler:

View File

@ -37,7 +37,7 @@ RESPONSES = {
def agent_info() -> Dict[str, str]: def agent_info() -> Dict[str, str]:
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION} return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"}
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]: def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:

View File

@ -22,7 +22,7 @@ AGENT_DESCRIPTION = "Speech recognition agent (Whisper) - processes incoming aud
def agent_info() -> Dict[str, str]: def agent_info() -> Dict[str, str]:
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION} return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"}
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]: def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]: