Lots of changes/refactorings
This commit is contained in:
parent
3674d57b0a
commit
b7e5963597
@ -35,7 +35,7 @@
|
||||
"test": "react-scripts test",
|
||||
"eject": "react-scripts eject",
|
||||
"type-check": "tsc --noEmit",
|
||||
"generate-schema": "cd ../server && uv run python3 generate_schema_simple.py",
|
||||
"generate-schema": "cd ../server && python3 generate_schema_simple.py",
|
||||
"generate-types": "npx openapi-typescript openapi-schema.json -o src/api-types.ts",
|
||||
"generate-api-types": "npm run generate-schema && npm run generate-types",
|
||||
"check-api-evolution": "node check-api-evolution.js",
|
||||
|
@ -4,6 +4,7 @@ import { Input, Paper, Typography } from "@mui/material";
|
||||
import { Session, Lobby } from "./GlobalContext";
|
||||
import { UserList } from "./UserList";
|
||||
import { LobbyChat } from "./LobbyChat";
|
||||
import BotManager from "./BotManager";
|
||||
import "./App.css";
|
||||
import { ws_base, base } from "./Common";
|
||||
import { Box, Button, Tooltip } from "@mui/material";
|
||||
@ -29,6 +30,15 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
||||
const [creatingLobby, setCreatingLobby] = useState<boolean>(false);
|
||||
const [reconnectAttempt, setReconnectAttempt] = useState<number>(0);
|
||||
|
||||
// Check if lobbyName looks like a lobby ID (32 hex characters) and redirect to default
|
||||
useEffect(() => {
|
||||
if (lobbyName && /^[a-f0-9]{32}$/i.test(lobbyName)) {
|
||||
console.log(`Lobby - Detected lobby ID in URL (${lobbyName}), redirecting to default lobby`);
|
||||
window.history.replaceState(null, "", `${base}/lobby/default`);
|
||||
window.location.reload(); // Force reload to use the new URL
|
||||
}
|
||||
}, [lobbyName]);
|
||||
|
||||
const { sendJsonMessage, lastJsonMessage, readyState } = useWebSocket(socketUrl, {
|
||||
onOpen: () => {
|
||||
console.log("app - WebSocket connection opened.");
|
||||
@ -38,8 +48,24 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
||||
console.log("app - WebSocket connection closed.");
|
||||
setReconnectAttempt((prev) => prev + 1);
|
||||
},
|
||||
onError: (event: Event) => console.error("app - WebSocket error observed:", event),
|
||||
shouldReconnect: (closeEvent) => true, // Will attempt to reconnect on all close events
|
||||
onError: (event: Event) => {
|
||||
console.error("app - WebSocket error observed:", event);
|
||||
// If we get a WebSocket error, it might be due to invalid lobby ID
|
||||
// Reset the lobby state to force recreation
|
||||
if (lobby) {
|
||||
console.log("app - WebSocket error, clearing lobby state to force refresh");
|
||||
setLobby(null);
|
||||
setSocketUrl(null);
|
||||
}
|
||||
},
|
||||
shouldReconnect: (closeEvent) => {
|
||||
// Don't reconnect if the lobby doesn't exist (4xx errors)
|
||||
if (closeEvent.code >= 4000 && closeEvent.code < 5000) {
|
||||
console.log("app - WebSocket closed with client error, not reconnecting");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
reconnectInterval: 5000, // Retry every 5 seconds
|
||||
onReconnectStop: (numAttempts) => {
|
||||
console.log(`Stopped reconnecting after ${numAttempts} attempts`);
|
||||
@ -68,6 +94,13 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
||||
case "error":
|
||||
console.error(`Lobby - Server error: ${data.error}`);
|
||||
setError(data.error);
|
||||
|
||||
// If the error is about lobby not found, reset the lobby state
|
||||
if (data.error && data.error.includes("Lobby not found")) {
|
||||
console.log("Lobby - Lobby not found error, clearing lobby state");
|
||||
setLobby(null);
|
||||
setSocketUrl(null);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@ -82,50 +115,65 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
||||
if (!session || !lobbyName || creatingLobby || (lobby && lobby.name === lobbyName)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Clear any existing lobby state when switching to a new lobby name
|
||||
if (lobby && lobby.name !== lobbyName) {
|
||||
console.log(`Lobby - Clearing previous lobby state: ${lobby.name} -> ${lobbyName}`);
|
||||
setLobby(null);
|
||||
setSocketUrl(null);
|
||||
}
|
||||
|
||||
const getLobby = async (lobbyName: string, session: Session) => {
|
||||
const res = await fetch(`${base}/api/lobby/${session.id}`, {
|
||||
method: "POST",
|
||||
cache: "no-cache",
|
||||
credentials: "same-origin",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
type: "lobby_create",
|
||||
data: {
|
||||
name: lobbyName,
|
||||
private: false,
|
||||
try {
|
||||
const res = await fetch(`${base}/api/lobby/${session.id}`, {
|
||||
method: "POST",
|
||||
cache: "no-cache",
|
||||
credentials: "same-origin",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}),
|
||||
});
|
||||
body: JSON.stringify({
|
||||
type: "lobby_create",
|
||||
data: {
|
||||
name: lobbyName,
|
||||
private: false,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
if (res.status >= 400) {
|
||||
const error = `Unable to connect to AI Voice Chat server! Try refreshing your browser in a few seconds.`;
|
||||
console.error(error);
|
||||
setError(error);
|
||||
}
|
||||
if (res.status >= 400) {
|
||||
const error = `Unable to connect to AI Voice Chat server! Try refreshing your browser in a few seconds.`;
|
||||
console.error(error);
|
||||
setError(error);
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await res.json();
|
||||
if (data.error) {
|
||||
console.error(`Lobby - Server error: ${data.error}`);
|
||||
setError(data.error);
|
||||
return;
|
||||
const data = await res.json();
|
||||
if (data.error) {
|
||||
console.error(`Lobby - Server error: ${data.error}`);
|
||||
setError(data.error);
|
||||
return;
|
||||
}
|
||||
if (data.type !== "lobby_created") {
|
||||
console.error(`Lobby - Unexpected response type: ${data.type}`);
|
||||
setError(`Unexpected response from server`);
|
||||
return;
|
||||
}
|
||||
const lobby: Lobby = data.data;
|
||||
console.log(`Lobby - Joined lobby`, lobby);
|
||||
setLobby(lobby);
|
||||
} catch (err) {
|
||||
const errorMessage = err instanceof Error ? err.message : "Failed to create/join lobby";
|
||||
console.error("Lobby creation error:", errorMessage);
|
||||
setError(errorMessage);
|
||||
}
|
||||
if (data.type !== "lobby_created") {
|
||||
console.error(`Lobby - Unexpected response type: ${data.type}`);
|
||||
setError(`Unexpected response from server`);
|
||||
return;
|
||||
}
|
||||
const lobby: Lobby = data.data;
|
||||
console.log(`Lobby - Joined lobby`, lobby);
|
||||
setLobby(lobby);
|
||||
};
|
||||
|
||||
setCreatingLobby(true);
|
||||
getLobby(lobbyName, session).then(() => {
|
||||
getLobby(lobbyName, session).finally(() => {
|
||||
setCreatingLobby(false);
|
||||
});
|
||||
}, [session, lobbyName, setLobby, setError]);
|
||||
}, [session, lobbyName, lobby, setLobby, setError]);
|
||||
|
||||
const setName = (name: string) => {
|
||||
sendJsonMessage({
|
||||
@ -216,6 +264,13 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
||||
{session && socketUrl && lobby && (
|
||||
<LobbyChat socketUrl={socketUrl} session={session} lobbyId={lobby.id} />
|
||||
)}
|
||||
{session && lobby && (
|
||||
<BotManager
|
||||
lobbyId={lobby.id}
|
||||
onBotAdded={(botName) => console.log(`Bot ${botName} added to lobby`)}
|
||||
sx={{ minWidth: "300px" }}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
303
client/src/BotManager.tsx
Normal file
303
client/src/BotManager.tsx
Normal file
@ -0,0 +1,303 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import {
|
||||
Paper,
|
||||
Button,
|
||||
List,
|
||||
ListItem,
|
||||
ListItemText,
|
||||
Typography,
|
||||
Box,
|
||||
Chip,
|
||||
IconButton,
|
||||
Dialog,
|
||||
DialogTitle,
|
||||
DialogContent,
|
||||
DialogActions,
|
||||
TextField,
|
||||
CircularProgress,
|
||||
Alert,
|
||||
Accordion,
|
||||
AccordionSummary,
|
||||
AccordionDetails,
|
||||
} from "@mui/material";
|
||||
import {
|
||||
SmartToy as BotIcon,
|
||||
Add as AddIcon,
|
||||
ExpandMore as ExpandMoreIcon,
|
||||
Refresh as RefreshIcon,
|
||||
} from "@mui/icons-material";
|
||||
import { botsApi, BotInfoModel, BotProviderModel, BotJoinLobbyRequest } from "./api-client";
|
||||
|
||||
interface BotManagerProps {
|
||||
lobbyId: string;
|
||||
onBotAdded?: (botName: string) => void;
|
||||
sx?: any;
|
||||
}
|
||||
|
||||
const BotManager: React.FC<BotManagerProps> = ({ lobbyId, onBotAdded, sx }) => {
|
||||
const [bots, setBots] = useState<Record<string, BotInfoModel>>({});
|
||||
const [providers, setProviders] = useState<Record<string, string>>({});
|
||||
const [botProviders, setBotProviders] = useState<BotProviderModel[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [addDialogOpen, setAddDialogOpen] = useState(false);
|
||||
const [selectedBot, setSelectedBot] = useState<string>("");
|
||||
const [botNick, setBotNick] = useState("");
|
||||
const [addingBot, setAddingBot] = useState(false);
|
||||
|
||||
const loadBots = async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const [botsResponse, providersResponse] = await Promise.all([
|
||||
botsApi.getAvailable(),
|
||||
botsApi.getProviders(),
|
||||
]);
|
||||
|
||||
setBots(botsResponse.bots);
|
||||
setProviders(botsResponse.providers);
|
||||
setBotProviders(providersResponse.providers);
|
||||
} catch (err) {
|
||||
console.error("Failed to load bots:", err);
|
||||
setError("Failed to load available bots");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
loadBots();
|
||||
}, []);
|
||||
|
||||
const handleAddBot = async () => {
|
||||
if (!selectedBot) return;
|
||||
|
||||
setAddingBot(true);
|
||||
try {
|
||||
const request: BotJoinLobbyRequest = {
|
||||
bot_name: selectedBot,
|
||||
lobby_id: lobbyId,
|
||||
nick: botNick || `${selectedBot}-bot`,
|
||||
provider_id: providers[selectedBot],
|
||||
};
|
||||
|
||||
const response = await botsApi.requestJoinLobby(selectedBot, request);
|
||||
|
||||
if (response.status === "requested") {
|
||||
setAddDialogOpen(false);
|
||||
setSelectedBot("");
|
||||
setBotNick("");
|
||||
onBotAdded?.(selectedBot);
|
||||
|
||||
// Show success feedback could be added here
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Failed to add bot:", err);
|
||||
setError("Failed to add bot to lobby");
|
||||
} finally {
|
||||
setAddingBot(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOpenAddDialog = () => {
|
||||
setAddDialogOpen(true);
|
||||
setError(null);
|
||||
};
|
||||
|
||||
const handleCloseAddDialog = () => {
|
||||
setAddDialogOpen(false);
|
||||
setSelectedBot("");
|
||||
setBotNick("");
|
||||
setError(null);
|
||||
};
|
||||
|
||||
const getProviderName = (providerId: string): string => {
|
||||
const provider = botProviders.find(p => p.provider_id === providerId);
|
||||
return provider ? provider.name : "Unknown Provider";
|
||||
};
|
||||
|
||||
const botCount = Object.keys(bots).length;
|
||||
const providerCount = botProviders.length;
|
||||
|
||||
return (
|
||||
<Paper sx={{ p: 2, ...sx }}>
|
||||
<Box display="flex" alignItems="center" justifyContent="space-between" mb={2}>
|
||||
<Typography variant="h6" display="flex" alignItems="center" gap={1}>
|
||||
<BotIcon />
|
||||
AI Bots
|
||||
</Typography>
|
||||
<Box display="flex" gap={1}>
|
||||
<IconButton onClick={loadBots} disabled={loading} size="small">
|
||||
<RefreshIcon />
|
||||
</IconButton>
|
||||
<Button
|
||||
variant="contained"
|
||||
size="small"
|
||||
startIcon={<AddIcon />}
|
||||
onClick={handleOpenAddDialog}
|
||||
disabled={loading || botCount === 0}
|
||||
>
|
||||
Add Bot
|
||||
</Button>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
{error && (
|
||||
<Alert severity="error" sx={{ mb: 2 }}>
|
||||
{error}
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{loading ? (
|
||||
<Box display="flex" justifyContent="center" p={2}>
|
||||
<CircularProgress size={24} />
|
||||
</Box>
|
||||
) : (
|
||||
<Box>
|
||||
<Typography variant="body2" color="text.secondary" mb={1}>
|
||||
{botCount} bots available from {providerCount} providers
|
||||
</Typography>
|
||||
|
||||
{botCount === 0 ? (
|
||||
<Typography variant="body2" color="text.secondary" style={{ fontStyle: "italic" }}>
|
||||
No bots available. Make sure bot providers are registered and running.
|
||||
</Typography>
|
||||
) : (
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Typography variant="subtitle2">Available Bots ({botCount})</Typography>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
<List dense>
|
||||
{Object.entries(bots).map(([botName, botInfo]) => {
|
||||
const providerId = providers[botName];
|
||||
const providerName = getProviderName(providerId);
|
||||
|
||||
return (
|
||||
<ListItem key={botName}>
|
||||
<ListItemText
|
||||
primary={botInfo.name}
|
||||
secondary={
|
||||
<Box>
|
||||
<Typography variant="body2" component="div">
|
||||
{botInfo.description}
|
||||
</Typography>
|
||||
<Chip
|
||||
label={providerName}
|
||||
size="small"
|
||||
variant="outlined"
|
||||
sx={{ mt: 0.5 }}
|
||||
/>
|
||||
</Box>
|
||||
}
|
||||
/>
|
||||
</ListItem>
|
||||
);
|
||||
})}
|
||||
</List>
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
)}
|
||||
|
||||
{providerCount > 0 && (
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Typography variant="subtitle2">Bot Providers ({providerCount})</Typography>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
<List dense>
|
||||
{botProviders.map((provider) => (
|
||||
<ListItem key={provider.provider_id}>
|
||||
<ListItemText
|
||||
primary={provider.name}
|
||||
secondary={
|
||||
<Box>
|
||||
<Typography variant="body2" component="div">
|
||||
{provider.description || "No description"}
|
||||
</Typography>
|
||||
<Typography variant="caption" color="text.secondary">
|
||||
{provider.base_url}
|
||||
</Typography>
|
||||
</Box>
|
||||
}
|
||||
/>
|
||||
</ListItem>
|
||||
))}
|
||||
</List>
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
)}
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{/* Add Bot Dialog */}
|
||||
<Dialog open={addDialogOpen} onClose={handleCloseAddDialog} maxWidth="sm" fullWidth>
|
||||
<DialogTitle>Add Bot to Lobby</DialogTitle>
|
||||
<DialogContent>
|
||||
{error && (
|
||||
<Alert severity="error" sx={{ mb: 2 }}>
|
||||
{error}
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
<Box sx={{ mt: 1 }}>
|
||||
<Typography variant="subtitle2" gutterBottom>
|
||||
Select Bot
|
||||
</Typography>
|
||||
<List>
|
||||
{Object.entries(bots).map(([botName, botInfo]) => (
|
||||
<ListItem
|
||||
key={botName}
|
||||
component="div"
|
||||
sx={{
|
||||
cursor: "pointer",
|
||||
backgroundColor: selectedBot === botName ? "action.selected" : "transparent",
|
||||
"&:hover": {
|
||||
backgroundColor: "action.hover",
|
||||
},
|
||||
}}
|
||||
onClick={() => setSelectedBot(botName)}
|
||||
>
|
||||
<ListItemText
|
||||
primary={botInfo.name}
|
||||
secondary={botInfo.description}
|
||||
/>
|
||||
<Chip
|
||||
label={getProviderName(providers[botName])}
|
||||
size="small"
|
||||
variant="outlined"
|
||||
/>
|
||||
</ListItem>
|
||||
))}
|
||||
</List>
|
||||
|
||||
{selectedBot && (
|
||||
<TextField
|
||||
label="Bot Nickname (optional)"
|
||||
value={botNick}
|
||||
onChange={(e) => setBotNick(e.target.value)}
|
||||
fullWidth
|
||||
margin="normal"
|
||||
placeholder={`${selectedBot}-bot`}
|
||||
helperText="Custom name for this bot instance in the lobby"
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
</DialogContent>
|
||||
<DialogActions>
|
||||
<Button onClick={handleCloseAddDialog}>Cancel</Button>
|
||||
<Button
|
||||
onClick={handleAddBot}
|
||||
variant="contained"
|
||||
disabled={!selectedBot || addingBot}
|
||||
startIcon={addingBot ? <CircularProgress size={16} /> : <AddIcon />}
|
||||
>
|
||||
{addingBot ? "Adding..." : "Add Bot"}
|
||||
</Button>
|
||||
</DialogActions>
|
||||
</Dialog>
|
||||
</Paper>
|
||||
);
|
||||
};
|
||||
|
||||
export default BotManager;
|
@ -18,6 +18,44 @@ export type SessionResponse = components['schemas']['SessionResponse'];
|
||||
export type LobbyCreateRequest = components['schemas']['LobbyCreateRequest'];
|
||||
export type LobbyCreateResponse = components['schemas']['LobbyCreateResponse'];
|
||||
|
||||
// Bot Provider Types (manually defined until API types are regenerated)
|
||||
export interface BotInfoModel {
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface BotProviderModel {
|
||||
provider_id: string;
|
||||
base_url: string;
|
||||
name: string;
|
||||
description: string;
|
||||
registered_at: number;
|
||||
last_seen: number;
|
||||
}
|
||||
|
||||
export interface BotProviderListResponse {
|
||||
providers: BotProviderModel[];
|
||||
}
|
||||
|
||||
export interface BotListResponse {
|
||||
bots: Record<string, BotInfoModel>;
|
||||
providers: Record<string, string>;
|
||||
}
|
||||
|
||||
export interface BotJoinLobbyRequest {
|
||||
bot_name: string;
|
||||
lobby_id: string;
|
||||
nick?: string;
|
||||
provider_id?: string;
|
||||
}
|
||||
|
||||
export interface BotJoinLobbyResponse {
|
||||
status: string;
|
||||
bot_name: string;
|
||||
run_id: string;
|
||||
provider_id: string;
|
||||
}
|
||||
|
||||
export class ApiError extends Error {
|
||||
constructor(
|
||||
public status: number,
|
||||
@ -34,17 +72,20 @@ export class ApiClient {
|
||||
private defaultHeaders: Record<string, string>;
|
||||
|
||||
constructor(baseURL?: string) {
|
||||
this.baseURL = baseURL || process.env.REACT_APP_API_URL || 'http://localhost:8001';
|
||||
this.baseURL = baseURL || process.env.REACT_APP_API_URL || "http://localhost:8001";
|
||||
this.defaultHeaders = {};
|
||||
}
|
||||
|
||||
private async request<T>(path: string, options: {
|
||||
method: string;
|
||||
body?: any;
|
||||
params?: Record<string, string>;
|
||||
}): Promise<T> {
|
||||
private async request<T>(
|
||||
path: string,
|
||||
options: {
|
||||
method: string;
|
||||
body?: any;
|
||||
params?: Record<string, string>;
|
||||
}
|
||||
): Promise<T> {
|
||||
const url = new URL(path, this.baseURL);
|
||||
|
||||
|
||||
if (options.params) {
|
||||
Object.entries(options.params).forEach(([key, value]) => {
|
||||
url.searchParams.append(key, value);
|
||||
@ -54,17 +95,17 @@ export class ApiClient {
|
||||
const requestInit: RequestInit = {
|
||||
method: options.method,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
"Content-Type": "application/json",
|
||||
...this.defaultHeaders,
|
||||
},
|
||||
};
|
||||
|
||||
if (options.body && options.method !== 'GET') {
|
||||
if (options.body && options.method !== "GET") {
|
||||
requestInit.body = JSON.stringify(options.body);
|
||||
}
|
||||
|
||||
const response = await fetch(url.toString(), requestInit);
|
||||
|
||||
|
||||
if (!response.ok) {
|
||||
let errorData;
|
||||
try {
|
||||
@ -75,44 +116,60 @@ export class ApiClient {
|
||||
throw new ApiError(response.status, response.statusText, errorData);
|
||||
}
|
||||
|
||||
const contentType = response.headers.get('content-type');
|
||||
if (contentType && contentType.includes('application/json')) {
|
||||
const contentType = response.headers.get("content-type");
|
||||
if (contentType && contentType.includes("application/json")) {
|
||||
return response.json();
|
||||
}
|
||||
|
||||
|
||||
return response.text() as unknown as T;
|
||||
}
|
||||
|
||||
// Admin API methods
|
||||
async adminListNames(): Promise<AdminNamesResponse> {
|
||||
return this.request<AdminNamesResponse>('/ai-voicebot/api/admin/names', { method: 'GET' });
|
||||
return this.request<AdminNamesResponse>("/ai-voicebot/api/admin/names", { method: "GET" });
|
||||
}
|
||||
|
||||
async adminSetPassword(data: AdminSetPassword): Promise<AdminActionResponse> {
|
||||
return this.request<AdminActionResponse>('/ai-voicebot/api/admin/set_password', { method: 'POST', body: data });
|
||||
return this.request<AdminActionResponse>("/ai-voicebot/api/admin/set_password", { method: "POST", body: data });
|
||||
}
|
||||
|
||||
async adminClearPassword(data: AdminClearPassword): Promise<AdminActionResponse> {
|
||||
return this.request<AdminActionResponse>('/ai-voicebot/api/admin/clear_password', { method: 'POST', body: data });
|
||||
return this.request<AdminActionResponse>("/ai-voicebot/api/admin/clear_password", { method: "POST", body: data });
|
||||
}
|
||||
|
||||
// Health check
|
||||
async healthCheck(): Promise<HealthResponse> {
|
||||
return this.request<HealthResponse>('/ai-voicebot/api/health', { method: 'GET' });
|
||||
return this.request<HealthResponse>("/ai-voicebot/api/health", { method: "GET" });
|
||||
}
|
||||
|
||||
// Session methods
|
||||
async getSession(): Promise<SessionResponse> {
|
||||
return this.request<SessionResponse>('/ai-voicebot/api/session', { method: 'GET' });
|
||||
return this.request<SessionResponse>("/ai-voicebot/api/session", { method: "GET" });
|
||||
}
|
||||
|
||||
// Lobby methods
|
||||
async getLobbies(): Promise<LobbiesResponse> {
|
||||
return this.request<LobbiesResponse>('/ai-voicebot/api/lobby', { method: 'GET' });
|
||||
return this.request<LobbiesResponse>("/ai-voicebot/api/lobby", { method: "GET" });
|
||||
}
|
||||
|
||||
async createLobby(sessionId: string, data: LobbyCreateRequest): Promise<LobbyCreateResponse> {
|
||||
return this.request<LobbyCreateResponse>(`/ai-voicebot/api/lobby/${sessionId}`, { method: 'POST', body: data });
|
||||
return this.request<LobbyCreateResponse>(`/ai-voicebot/api/lobby/${sessionId}`, { method: "POST", body: data });
|
||||
}
|
||||
|
||||
// Bot Provider methods
|
||||
async getBotProviders(): Promise<BotProviderListResponse> {
|
||||
return this.request<BotProviderListResponse>("/ai-voicebot/api/bots/providers", { method: "GET" });
|
||||
}
|
||||
|
||||
async getAvailableBots(): Promise<BotListResponse> {
|
||||
return this.request<BotListResponse>("/ai-voicebot/api/bots", { method: "GET" });
|
||||
}
|
||||
|
||||
async requestBotJoinLobby(botName: string, request: BotJoinLobbyRequest): Promise<BotJoinLobbyResponse> {
|
||||
return this.request<BotJoinLobbyResponse>(`/ai-voicebot/api/bots/${encodeURIComponent(botName)}/join`, {
|
||||
method: "POST",
|
||||
body: request,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -313,6 +370,12 @@ export const sessionsApi = {
|
||||
createLobby: (sessionId: string, data: LobbyCreateRequest) => apiClient.createLobby(sessionId, data),
|
||||
};
|
||||
|
||||
export const botsApi = {
|
||||
getProviders: () => apiClient.getBotProviders(),
|
||||
getAvailable: () => apiClient.getAvailableBots(),
|
||||
requestJoinLobby: (botName: string, request: BotJoinLobbyRequest) => apiClient.requestBotJoinLobby(botName, request),
|
||||
};
|
||||
|
||||
// Automatically check for API evolution when this module is loaded
|
||||
// This will warn developers if new endpoints are available but not implemented
|
||||
if (process.env.NODE_ENV === 'development') {
|
||||
|
@ -13,6 +13,7 @@ services:
|
||||
- "3456:3000"
|
||||
restart: no
|
||||
volumes:
|
||||
- ./server:/server:ro # So the frontend can read the OpenAPI spec
|
||||
- ./client:/client:rw
|
||||
- ./dev-keys:/keys:ro # So the frontend entrypoint can check for SSL files
|
||||
networks:
|
||||
@ -61,7 +62,8 @@ services:
|
||||
- ./.env
|
||||
environment:
|
||||
- PRODUCTION=${PRODUCTION:-false}
|
||||
restart: always
|
||||
- VOICEBOT_MODE=provider
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
volumes:
|
||||
- ./cache:/root/.cache:rw
|
||||
|
287
server/main.py
287
server/main.py
@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Optional, TypedDict
|
||||
from typing import Any, Optional
|
||||
from fastapi import (
|
||||
Body,
|
||||
Cookie,
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
Path,
|
||||
WebSocket,
|
||||
Request,
|
||||
@ -45,12 +46,26 @@ from shared.models import (
|
||||
JoinStatusModel,
|
||||
ChatMessageModel,
|
||||
ChatMessagesResponse,
|
||||
ParticipantModel,
|
||||
# Bot provider models
|
||||
BotProviderModel,
|
||||
BotProviderRegisterRequest,
|
||||
BotProviderRegisterResponse,
|
||||
BotProviderListResponse,
|
||||
BotListResponse,
|
||||
BotInfoModel,
|
||||
BotJoinLobbyRequest,
|
||||
BotJoinLobbyResponse,
|
||||
BotJoinPayload,
|
||||
)
|
||||
|
||||
|
||||
# Mapping of reserved names to password records (lowercased name -> {salt:..., hash:...})
|
||||
name_passwords: dict[str, dict[str, str]] = {}
|
||||
|
||||
# Bot provider registry: provider_id -> BotProviderModel
|
||||
bot_providers: dict[str, BotProviderModel] = {}
|
||||
|
||||
all_label = "[ all ]"
|
||||
info_label = "[ info ]"
|
||||
todo_label = "[ todo ]"
|
||||
@ -175,12 +190,6 @@ def admin_cleanup_sessions(request: Request):
|
||||
lobbies: dict[str, Lobby] = {}
|
||||
|
||||
|
||||
class LobbyResponse(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
private: bool
|
||||
|
||||
|
||||
class Lobby:
|
||||
def __init__(self, name: str, id: str | None = None, private: bool = False):
|
||||
self.id = secrets.token_hex(16) if id is None else id
|
||||
@ -194,15 +203,15 @@ class Lobby:
|
||||
return f"{self.short}:{self.name}"
|
||||
|
||||
async def update_state(self, requesting_session: Session | None = None):
|
||||
users: list[dict[str, str | bool]] = [
|
||||
{
|
||||
"name": s.name,
|
||||
"live": True if s.ws else False,
|
||||
"session_id": s.id,
|
||||
"protected": True
|
||||
users: list[ParticipantModel] = [
|
||||
ParticipantModel(
|
||||
name=s.name,
|
||||
live=True if s.ws else False,
|
||||
session_id=s.id,
|
||||
protected=True
|
||||
if s.name and s.name.lower() in name_passwords
|
||||
else False,
|
||||
}
|
||||
)
|
||||
for s in self.sessions.values()
|
||||
if s.name
|
||||
]
|
||||
@ -212,7 +221,10 @@ class Lobby:
|
||||
)
|
||||
if requesting_session.ws:
|
||||
await requesting_session.ws.send_json(
|
||||
{"type": "lobby_state", "data": {"participants": users}}
|
||||
{
|
||||
"type": "lobby_state",
|
||||
"data": {"participants": [user.model_dump() for user in users]},
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
@ -223,7 +235,12 @@ class Lobby:
|
||||
logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
|
||||
if s.ws:
|
||||
await s.ws.send_json(
|
||||
{"type": "lobby_state", "data": {"participants": users}}
|
||||
{
|
||||
"type": "lobby_state",
|
||||
"data": {
|
||||
"participants": [user.model_dump() for user in users]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def getSession(self, id: str) -> Session | None:
|
||||
@ -356,13 +373,46 @@ class Session:
|
||||
# rec is a NamePasswordRecord
|
||||
name_passwords[name] = {"salt": rec.salt, "hash": rec.hash}
|
||||
|
||||
current_time = time.time()
|
||||
one_minute = 60.0
|
||||
three_hours = 3 * 60 * 60.0
|
||||
sessions_loaded = 0
|
||||
sessions_expired = 0
|
||||
|
||||
for s_saved in payload.sessions:
|
||||
# Check if this session should be expired during loading
|
||||
created_at = getattr(s_saved, "created_at", time.time())
|
||||
last_used = getattr(s_saved, "last_used", time.time())
|
||||
displaced_at = getattr(s_saved, "displaced_at", None)
|
||||
name = s_saved.name or ""
|
||||
|
||||
# Apply same removal criteria as cleanup_old_sessions
|
||||
should_expire = False
|
||||
|
||||
# Rule 1: Sessions with no name that are older than 1 minute (no connection assumed for disk sessions)
|
||||
if not name and current_time - created_at > one_minute:
|
||||
should_expire = True
|
||||
logger.info(
|
||||
f"Expiring session {s_saved.id[:8]} during load - no name, older than 1 minute"
|
||||
)
|
||||
|
||||
# Rule 2: Displaced sessions unused for 3+ hours (no connection assumed for disk sessions)
|
||||
elif displaced_at is not None and current_time - last_used > three_hours:
|
||||
should_expire = True
|
||||
logger.info(
|
||||
f"Expiring session {s_saved.id[:8]}:{name} during load - displaced and unused for 3+ hours"
|
||||
)
|
||||
|
||||
if should_expire:
|
||||
sessions_expired += 1
|
||||
continue # Skip loading this expired session
|
||||
|
||||
session = Session(s_saved.id)
|
||||
session.name = s_saved.name or ""
|
||||
session.name = name
|
||||
# Load timestamps, with defaults for backward compatibility
|
||||
session.created_at = getattr(s_saved, "created_at", time.time())
|
||||
session.last_used = getattr(s_saved, "last_used", time.time())
|
||||
session.displaced_at = getattr(s_saved, "displaced_at", None)
|
||||
session.created_at = created_at
|
||||
session.last_used = last_used
|
||||
session.displaced_at = displaced_at
|
||||
for lobby_saved in s_saved.lobbies:
|
||||
session.lobbies.append(
|
||||
Lobby(
|
||||
@ -378,10 +428,15 @@ class Session:
|
||||
lobbies[lobby.id] = Lobby(
|
||||
name=lobby.name, id=lobby.id
|
||||
) # Ensure lobby exists
|
||||
sessions_loaded += 1
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(payload.sessions)} sessions and {len(name_passwords)} name passwords from {cls._save_file}"
|
||||
f"Loaded {sessions_loaded} sessions and {len(name_passwords)} name passwords from {cls._save_file}"
|
||||
)
|
||||
if sessions_expired > 0:
|
||||
logger.info(f"Expired {sessions_expired} old sessions during load")
|
||||
# Save immediately to persist the cleanup
|
||||
cls.save()
|
||||
|
||||
@classmethod
|
||||
def getSession(cls, id: str) -> Session | None:
|
||||
@ -440,7 +495,7 @@ class Session:
|
||||
|
||||
current_time = time.time()
|
||||
one_minute = 60.0
|
||||
twenty_four_hours = 24 * 60 * 60.0
|
||||
three_hours = 3 * 60 * 60.0
|
||||
sessions_removed = 0
|
||||
|
||||
# Make a copy of the list to avoid modifying it while iterating
|
||||
@ -459,14 +514,14 @@ class Session:
|
||||
sessions_to_remove.append(session)
|
||||
continue
|
||||
|
||||
# Rule 2: Delete inactive sessions that had their nick taken over and haven't been used in 24 hours
|
||||
# Rule 2: Delete inactive sessions that had their nick taken over and haven't been used in 3 hours
|
||||
if (
|
||||
not session.ws
|
||||
and session.displaced_at is not None
|
||||
and current_time - session.last_used > twenty_four_hours
|
||||
and current_time - session.last_used > three_hours
|
||||
):
|
||||
logger.info(
|
||||
f"Removing session {session.getName()} - displaced and unused for 24+ hours"
|
||||
f"Removing session {session.getName()} - displaced and unused for 3+ hours"
|
||||
)
|
||||
sessions_to_remove.append(session)
|
||||
continue
|
||||
@ -493,10 +548,23 @@ class Session:
|
||||
cls._instances.remove(session)
|
||||
sessions_removed += 1
|
||||
|
||||
# Clean up empty lobbies from global lobbies dict
|
||||
empty_lobbies: list[str] = []
|
||||
for lobby_id, lobby in lobbies.items():
|
||||
if len(lobby.sessions) == 0:
|
||||
empty_lobbies.append(lobby_id)
|
||||
|
||||
for lobby_id in empty_lobbies:
|
||||
del lobbies[lobby_id]
|
||||
logger.info(f"Removed empty lobby {lobby_id}")
|
||||
|
||||
if sessions_removed > 0:
|
||||
cls.save()
|
||||
logger.info(f"Session cleanup: removed {sessions_removed} old sessions")
|
||||
|
||||
if empty_lobbies:
|
||||
logger.info(f"Session cleanup: removed {len(empty_lobbies)} empty lobbies")
|
||||
|
||||
return sessions_removed
|
||||
|
||||
async def join(self, lobby: Lobby):
|
||||
@ -784,6 +852,175 @@ async def get_chat_messages(
|
||||
messages = lobby.get_chat_messages(limit)
|
||||
|
||||
return ChatMessagesResponse(messages=messages)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Bot Provider API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.post(
|
||||
public_url + "api/bots/providers/register",
|
||||
response_model=BotProviderRegisterResponse,
|
||||
)
|
||||
async def register_bot_provider(
|
||||
request: BotProviderRegisterRequest,
|
||||
) -> BotProviderRegisterResponse:
|
||||
"""Register a new bot provider"""
|
||||
import time
|
||||
import uuid
|
||||
|
||||
provider_id = str(uuid.uuid4())
|
||||
now = time.time()
|
||||
|
||||
provider = BotProviderModel(
|
||||
provider_id=provider_id,
|
||||
base_url=request.base_url.rstrip("/"),
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
registered_at=now,
|
||||
last_seen=now,
|
||||
)
|
||||
|
||||
bot_providers[provider_id] = provider
|
||||
logger.info(f"Registered bot provider: {request.name} at {request.base_url}")
|
||||
|
||||
return BotProviderRegisterResponse(provider_id=provider_id)
|
||||
|
||||
|
||||
@app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse)
|
||||
async def list_bot_providers() -> BotProviderListResponse:
|
||||
"""List all registered bot providers"""
|
||||
return BotProviderListResponse(providers=list(bot_providers.values()))
|
||||
|
||||
|
||||
@app.get(public_url + "api/bots", response_model=BotListResponse)
|
||||
async def list_available_bots() -> BotListResponse:
|
||||
"""List all available bots from all registered providers"""
|
||||
bots: dict[str, BotInfoModel] = {}
|
||||
providers: dict[str, str] = {}
|
||||
|
||||
# Update last_seen timestamps and fetch bots from each provider
|
||||
for provider_id, provider in bot_providers.items():
|
||||
try:
|
||||
import time
|
||||
|
||||
provider.last_seen = time.time()
|
||||
|
||||
# Make HTTP request to provider's /bots endpoint
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
provider_bots = response.json()
|
||||
# provider_bots should be a dict of bot_name -> bot_info
|
||||
for bot_name, bot_info in provider_bots.items():
|
||||
bots[bot_name] = BotInfoModel(**bot_info)
|
||||
providers[bot_name] = provider_id
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching bots from provider {provider.name}: {e}")
|
||||
continue
|
||||
|
||||
return BotListResponse(bots=bots, providers=providers)
|
||||
|
||||
|
||||
@app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse)
|
||||
async def request_bot_join_lobby(
|
||||
bot_name: str, request: BotJoinLobbyRequest
|
||||
) -> BotJoinLobbyResponse:
|
||||
"""Request a bot to join a specific lobby"""
|
||||
|
||||
# Find which provider has this bot
|
||||
target_provider_id = request.provider_id
|
||||
if not target_provider_id:
|
||||
# Auto-discover provider for this bot
|
||||
for provider_id, provider in bot_providers.items():
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{provider.base_url}/bots", timeout=5.0
|
||||
)
|
||||
if response.status_code == 200:
|
||||
provider_bots = response.json()
|
||||
if bot_name in provider_bots:
|
||||
target_provider_id = provider_id
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not target_provider_id or target_provider_id not in bot_providers:
|
||||
raise HTTPException(status_code=404, detail="Bot or provider not found")
|
||||
|
||||
provider = bot_providers[target_provider_id]
|
||||
|
||||
# Get the lobby to validate it exists
|
||||
try:
|
||||
getLobby(request.lobby_id) # Just validate it exists
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Lobby not found")
|
||||
|
||||
# Create a session for the bot
|
||||
bot_session_id = secrets.token_hex(16)
|
||||
|
||||
# Determine server URL for the bot to connect back to
|
||||
# Use the server's public URL or construct from request
|
||||
server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000")
|
||||
if server_base_url.endswith("/"):
|
||||
server_base_url = server_base_url[:-1]
|
||||
|
||||
bot_nick = request.nick or f"{bot_name}-bot"
|
||||
|
||||
# Prepare the join request for the bot provider
|
||||
bot_join_payload = BotJoinPayload(
|
||||
lobby_id=request.lobby_id,
|
||||
session_id=bot_session_id,
|
||||
nick=bot_nick,
|
||||
server_url=f"{server_base_url}{public_url}".rstrip("/"),
|
||||
insecure=False, # Assume secure by default
|
||||
)
|
||||
|
||||
try:
|
||||
# Make request to bot provider
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{provider.base_url}/bots/{bot_name}/join",
|
||||
json=bot_join_payload.model_dump(),
|
||||
timeout=10.0,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
run_id = result.get("run_id", "unknown")
|
||||
|
||||
logger.info(
|
||||
f"Bot {bot_name} requested to join lobby {request.lobby_id}"
|
||||
)
|
||||
|
||||
return BotJoinLobbyResponse(
|
||||
status="requested",
|
||||
bot_name=bot_name,
|
||||
run_id=run_id,
|
||||
provider_id=target_provider_id,
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Bot provider returned error: HTTP {response.status_code}: {response.text}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Bot provider error: {response.status_code}",
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
raise HTTPException(status_code=504, detail="Bot provider timeout")
|
||||
except Exception as e:
|
||||
logger.error(f"Error requesting bot join: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
# Register websocket endpoint directly on app with full public_url path
|
||||
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
|
||||
async def lobby_join(
|
||||
|
@ -41,7 +41,8 @@ class ParticipantModel(BaseModel):
|
||||
"""Represents a participant in a lobby/session"""
|
||||
name: str
|
||||
session_id: str
|
||||
# Add other participant fields as needed based on actual data structure
|
||||
live: bool
|
||||
protected: bool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@ -75,6 +76,15 @@ class HealthResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class ClientStatusResponse(BaseModel):
|
||||
"""Client status response"""
|
||||
|
||||
client_running: bool
|
||||
session_name: str
|
||||
lobby: str
|
||||
server_url: str
|
||||
|
||||
|
||||
class LobbyListItem(BaseModel):
|
||||
"""Lobby item for list responses"""
|
||||
id: str
|
||||
@ -268,3 +278,82 @@ class SessionsPayload(BaseModel):
|
||||
"""Complete sessions data for persistence"""
|
||||
sessions: List[SessionSaved] = []
|
||||
name_passwords: Dict[str, NamePasswordRecord] = {}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Bot Provider Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BotInfoModel(BaseModel):
|
||||
"""Information about a specific bot"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class BotProviderModel(BaseModel):
|
||||
"""Bot provider registration information"""
|
||||
|
||||
provider_id: str
|
||||
base_url: str
|
||||
name: str
|
||||
description: str = ""
|
||||
registered_at: float
|
||||
last_seen: float
|
||||
|
||||
|
||||
class BotProviderRegisterRequest(BaseModel):
|
||||
"""Request to register a bot provider"""
|
||||
|
||||
base_url: str
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class BotProviderRegisterResponse(BaseModel):
|
||||
"""Response after registering a bot provider"""
|
||||
|
||||
provider_id: str
|
||||
status: str = "registered"
|
||||
|
||||
|
||||
class BotProviderListResponse(BaseModel):
|
||||
"""Response listing all registered bot providers"""
|
||||
|
||||
providers: List[BotProviderModel]
|
||||
|
||||
|
||||
class BotListResponse(BaseModel):
|
||||
"""Response listing all available bots from all providers"""
|
||||
|
||||
bots: Dict[str, BotInfoModel] # bot_name -> bot_info
|
||||
providers: Dict[str, str] # bot_name -> provider_id
|
||||
|
||||
|
||||
class BotJoinLobbyRequest(BaseModel):
|
||||
"""Request to make a bot join a lobby"""
|
||||
|
||||
bot_name: str
|
||||
lobby_id: str
|
||||
nick: str = ""
|
||||
provider_id: Optional[str] = None # Optional: specify which provider to use
|
||||
|
||||
|
||||
class BotJoinPayload(BaseModel):
|
||||
"""Payload sent to bot provider to make a bot join a lobby"""
|
||||
|
||||
lobby_id: str
|
||||
session_id: str
|
||||
nick: str
|
||||
server_url: str
|
||||
insecure: bool = False
|
||||
|
||||
|
||||
class BotJoinLobbyResponse(BaseModel):
|
||||
"""Response after requesting a bot to join a lobby"""
|
||||
|
||||
status: str
|
||||
bot_name: str
|
||||
run_id: str
|
||||
provider_id: str
|
||||
|
@ -0,0 +1,302 @@
|
||||
# AI Voicebot
|
||||
|
||||
A WebRTC-enabled AI voicebot system with speech recognition and synthetic media capabilities. The voicebot can run in two modes: as a client connecting to lobbies or as a provider serving bots to other applications.
|
||||
|
||||
## Features
|
||||
|
||||
- **Speech Recognition**: Uses Whisper models for real-time audio transcription
|
||||
- **Synthetic Media**: Generates animated video and audio tracks
|
||||
- **WebRTC Integration**: Real-time peer-to-peer communication
|
||||
- **Bot Provider System**: Can register with a main server to provide bot services
|
||||
- **Flexible Deployment**: Docker-based with development and production modes
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Docker and Docker Compose
|
||||
- Python 3.12+ (if running locally)
|
||||
- Access to a compatible signaling server
|
||||
|
||||
### Running with Docker
|
||||
|
||||
#### 1. Bot Provider Mode (Recommended)
|
||||
|
||||
Run the voicebot as a bot provider that registers with the main server:
|
||||
|
||||
```bash
|
||||
# Development mode with auto-reload
|
||||
VOICEBOT_MODE=provider PRODUCTION=false docker-compose up voicebot
|
||||
|
||||
# Production mode
|
||||
VOICEBOT_MODE=provider PRODUCTION=true docker-compose up voicebot
|
||||
```
|
||||
|
||||
#### 2. Direct Client Mode
|
||||
|
||||
Run the voicebot as a direct client connecting to a lobby:
|
||||
|
||||
```bash
|
||||
# Development mode
|
||||
VOICEBOT_MODE=client PRODUCTION=false docker-compose up voicebot
|
||||
|
||||
# Production mode
|
||||
VOICEBOT_MODE=client PRODUCTION=true docker-compose up voicebot
|
||||
```
|
||||
|
||||
### Running Locally
|
||||
|
||||
#### 1. Setup Environment
|
||||
|
||||
```bash
|
||||
cd voicebot/
|
||||
|
||||
# Create virtual environment
|
||||
uv init --python /usr/bin/python3.12 --name "ai-voicebot-agent"
|
||||
uv add -r requirements.txt
|
||||
|
||||
# Activate environment
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
#### 2. Bot Provider Mode
|
||||
|
||||
```bash
|
||||
# Development with auto-reload
|
||||
python main.py --mode provider --server-url https://your-server.com/ai-voicebot --reload --insecure
|
||||
|
||||
# Production
|
||||
python main.py --mode provider --server-url https://your-server.com/ai-voicebot
|
||||
```
|
||||
|
||||
#### 3. Direct Client Mode
|
||||
|
||||
```bash
|
||||
python main.py --mode client \
|
||||
--server-url https://your-server.com/ai-voicebot \
|
||||
--lobby "my-lobby" \
|
||||
--session-name "My Bot" \
|
||||
--insecure
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description | Default | Example |
|
||||
|----------|-------------|---------|---------|
|
||||
| `VOICEBOT_MODE` | Operating mode: `client` or `provider` | `client` | `provider` |
|
||||
| `PRODUCTION` | Production mode flag | `false` | `true` |
|
||||
|
||||
### Command Line Arguments
|
||||
|
||||
#### Common Arguments
|
||||
- `--mode`: Run as `client` or `provider`
|
||||
- `--server-url`: Main server URL
|
||||
- `--insecure`: Allow insecure SSL connections
|
||||
- `--help`: Show all available options
|
||||
|
||||
#### Provider Mode Arguments
|
||||
- `--host`: Host to bind the provider server (default: `0.0.0.0`)
|
||||
- `--port`: Port for the provider server (default: `8788`)
|
||||
- `--reload`: Enable auto-reload for development
|
||||
|
||||
#### Client Mode Arguments
|
||||
- `--lobby`: Lobby name to join (default: `default`)
|
||||
- `--session-name`: Display name for the bot (default: `Python Bot`)
|
||||
- `--session-id`: Existing session ID to reuse
|
||||
- `--password`: Password for protected names
|
||||
- `--private`: Create/join private lobby
|
||||
|
||||
## Available Bots
|
||||
|
||||
The voicebot system includes the following bot types:
|
||||
|
||||
### 1. Whisper Bot
|
||||
- **Name**: `whisper`
|
||||
- **Description**: Speech recognition agent using OpenAI Whisper models
|
||||
- **Capabilities**: Real-time audio transcription, multiple language support
|
||||
- **Models**: Supports various Whisper and Distil-Whisper models
|
||||
|
||||
### 2. Synthetic Media Bot
|
||||
- **Name**: `synthetic_media`
|
||||
- **Description**: Generates animated video and audio tracks
|
||||
- **Capabilities**: Animated video generation, synthetic audio, edge detection on incoming video
|
||||
|
||||
## Architecture
|
||||
|
||||
### Bot Provider System
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
|
||||
│ Main Server │ │ Bot Provider │ │ Client App │
|
||||
│ │◄───┤ (Voicebot) │ │ │
|
||||
│ - Bot Registry │ │ - Whisper Bot │ │ - Bot Manager │
|
||||
│ - Lobby Management │ - Synthetic Bot │ │ - UI Controls │
|
||||
│ - API Endpoints │ │ - API Server │ │ - Lobby View │
|
||||
└─────────────────┘ └──────────────────┘ └─────────────────┘
|
||||
```
|
||||
|
||||
### Flow
|
||||
1. Voicebot registers as bot provider with main server
|
||||
2. Main server discovers available bots from providers
|
||||
3. Client requests bot to join lobby via main server
|
||||
4. Main server forwards request to appropriate provider
|
||||
5. Provider creates bot instance that connects to the lobby
|
||||
|
||||
## Development
|
||||
|
||||
### Auto-Reload
|
||||
|
||||
In development mode, the bot provider supports auto-reload using uvicorn:
|
||||
|
||||
```bash
|
||||
# Watches /voicebot and /shared directories for changes
|
||||
python main.py --mode provider --reload
|
||||
```
|
||||
|
||||
### Adding New Bots
|
||||
|
||||
1. Create a new module in `voicebot/bots/`
|
||||
2. Implement required functions:
|
||||
```python
|
||||
def agent_info() -> dict:
|
||||
return {"name": "my_bot", "description": "My custom bot"}
|
||||
|
||||
def create_agent_tracks(session_name: str) -> dict:
|
||||
# Return MediaStreamTrack instances
|
||||
return {"audio": my_audio_track, "video": my_video_track}
|
||||
```
|
||||
3. The bot will be automatically discovered and available
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
# Test bot discovery
|
||||
python test_bot_api.py
|
||||
|
||||
# Test client connection
|
||||
python main.py --mode client --lobby test --session-name "Test Bot"
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### Docker Compose
|
||||
|
||||
```yaml
|
||||
version: '3.8'
|
||||
services:
|
||||
voicebot-provider:
|
||||
build: .
|
||||
environment:
|
||||
- VOICEBOT_MODE=provider
|
||||
- PRODUCTION=true
|
||||
ports:
|
||||
- "8788:8788"
|
||||
volumes:
|
||||
- ./cache:/voicebot/cache
|
||||
```
|
||||
|
||||
### Kubernetes
|
||||
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: voicebot-provider
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: voicebot-provider
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: voicebot-provider
|
||||
spec:
|
||||
containers:
|
||||
- name: voicebot
|
||||
image: ai-voicebot:latest
|
||||
env:
|
||||
- name: VOICEBOT_MODE
|
||||
value: "provider"
|
||||
- name: PRODUCTION
|
||||
value: "true"
|
||||
ports:
|
||||
- containerPort: 8788
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Bot Provider Endpoints
|
||||
|
||||
The voicebot provider exposes the following HTTP API:
|
||||
|
||||
- `GET /bots` - List available bots
|
||||
- `POST /bots/{bot_name}/join` - Request bot to join lobby
|
||||
- `GET /bots/runs` - List active bot instances
|
||||
- `POST /bots/runs/{run_id}/stop` - Stop a bot instance
|
||||
|
||||
### Example API Usage
|
||||
|
||||
```bash
|
||||
# List available bots
|
||||
curl http://localhost:8788/bots
|
||||
|
||||
# Request whisper bot to join lobby
|
||||
curl -X POST http://localhost:8788/bots/whisper/join \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"lobby_id": "lobby-123",
|
||||
"session_id": "session-456",
|
||||
"nick": "Speech Bot",
|
||||
"server_url": "https://server.com/ai-voicebot"
|
||||
}'
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Bot provider not registering:**
|
||||
- Check server URL is correct and accessible
|
||||
- Verify network connectivity between provider and server
|
||||
- Check logs for registration errors
|
||||
|
||||
**Auto-reload not working:**
|
||||
- Ensure `--reload` flag is used in development
|
||||
- Check file permissions on watched directories
|
||||
- Verify uvicorn version supports reload functionality
|
||||
|
||||
**WebRTC connection issues:**
|
||||
- Check STUN/TURN server configuration
|
||||
- Verify network ports are not blocked
|
||||
- Check browser console for ICE connection errors
|
||||
|
||||
### Logs
|
||||
|
||||
Logs are written to stdout and include:
|
||||
- Bot registration status
|
||||
- WebRTC connection events
|
||||
- Media track creation/destruction
|
||||
- API request/response details
|
||||
|
||||
### Debug Mode
|
||||
|
||||
Enable verbose logging:
|
||||
|
||||
```bash
|
||||
python main.py --mode provider --server-url https://server.com --debug
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Add tests for new functionality
|
||||
5. Submit a pull request
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License - see the LICENSE file for details.
|
@ -1,3 +1,6 @@
|
||||
"""Bots package for discoverable agent modules."""
|
||||
|
||||
from . import synthetic_media
|
||||
from . import whisper
|
||||
|
||||
__all__ = ["synthetic_media", "whisper"]
|
||||
|
@ -18,21 +18,58 @@ fi
|
||||
export VIRTUAL_ENV=/voicebot/.venv
|
||||
export PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# Determine mode - provider or client
|
||||
MODE="${VOICEBOT_MODE:-client}"
|
||||
|
||||
# Launch voicebot in production or development mode
|
||||
if [ "$PRODUCTION" != "true" ]; then
|
||||
echo "Starting voicebot in development mode with auto-reload..."
|
||||
# Fix: Use single --watch argument with multiple paths instead of multiple --watch arguments
|
||||
python3 -u scripts/reload_runner.py --delay-restart 3 --watch . /shared --verbose --interval 0.5 -- uv run main.py \
|
||||
--insecure \
|
||||
--server-url https://ketrenos.com/ai-voicebot \
|
||||
--lobby default \
|
||||
--session-name "Python Voicebot" \
|
||||
--password "v01c3b0t"
|
||||
echo "Starting voicebot in development mode..."
|
||||
if [ "$MODE" = "provider" ]; then
|
||||
echo "Running as bot provider with auto-reload..."
|
||||
export VOICEBOT_MODE=provider
|
||||
exec uv run uvicorn main:uvicorn_app \
|
||||
--host 0.0.0.0 \
|
||||
--port 8788 \
|
||||
--reload \
|
||||
--reload-dir /voicebot \
|
||||
--reload-dir /shared \
|
||||
--log-level info
|
||||
else
|
||||
echo "Running as client (connecting to lobby)..."
|
||||
export VOICEBOT_MODE=client
|
||||
export VOICEBOT_SERVER_URL="https://ketrenos.com/ai-voicebot"
|
||||
export VOICEBOT_LOBBY="default"
|
||||
export VOICEBOT_SESSION_NAME="Python Voicebot"
|
||||
export VOICEBOT_PASSWORD="v01c3b0t"
|
||||
export VOICEBOT_INSECURE="true"
|
||||
exec uv run uvicorn main:uvicorn_app \
|
||||
--host 0.0.0.0 \
|
||||
--port 8789 \
|
||||
--reload \
|
||||
--reload-dir /voicebot \
|
||||
--reload-dir /shared \
|
||||
--log-level info
|
||||
fi
|
||||
else
|
||||
echo "Starting voicebot in production mode..."
|
||||
exec uv run main.py \
|
||||
--server-url https://ai-voicebot.ketrenos.com \
|
||||
--lobby default \
|
||||
--session-name "Python Voicebot" \
|
||||
--password "v01c3b0t"
|
||||
if [ "$MODE" = "provider" ]; then
|
||||
echo "Running as bot provider..."
|
||||
export VOICEBOT_MODE=provider
|
||||
exec uv run uvicorn main:uvicorn_app \
|
||||
--host 0.0.0.0 \
|
||||
--port 8788 \
|
||||
--log-level info
|
||||
else
|
||||
echo "Running as client (connecting to lobby)..."
|
||||
export VOICEBOT_MODE=client
|
||||
export VOICEBOT_SERVER_URL="https://ai-voicebot.ketrenos.com"
|
||||
export VOICEBOT_LOBBY="default"
|
||||
export VOICEBOT_SESSION_NAME="Python Voicebot"
|
||||
export VOICEBOT_PASSWORD="v01c3b0t"
|
||||
export VOICEBOT_INSECURE="false"
|
||||
exec uv run uvicorn main:uvicorn_app \
|
||||
--host 0.0.0.0 \
|
||||
--port 8789 \
|
||||
--log-level info
|
||||
fi
|
||||
fi
|
||||
|
347
voicebot/main.py
347
voicebot/main.py
@ -50,6 +50,7 @@ from shared.models import (
|
||||
IceCandidateModel,
|
||||
ICECandidateDictModel,
|
||||
SessionDescriptionTypedModel,
|
||||
ClientStatusResponse,
|
||||
)
|
||||
from aiortc import (
|
||||
RTCPeerConnection,
|
||||
@ -60,6 +61,76 @@ from aiortc import (
|
||||
from logger import logger
|
||||
from voicebot.bots.synthetic_media import create_synthetic_tracks, AnimatedVideoTrack
|
||||
|
||||
# Pydantic model for voicebot arguments
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
class VoicebotMode(str, Enum):
|
||||
"""Voicebot operation modes."""
|
||||
CLIENT = "client"
|
||||
PROVIDER = "provider"
|
||||
|
||||
class VoicebotArgs(BaseModel):
|
||||
"""Pydantic model for voicebot CLI arguments and configuration."""
|
||||
|
||||
# Mode selection
|
||||
mode: VoicebotMode = Field(default=VoicebotMode.CLIENT, description="Run as client (connect to lobby) or provider (serve bots)")
|
||||
|
||||
# Provider mode arguments
|
||||
host: str = Field(default="0.0.0.0", description="Host for provider mode")
|
||||
port: int = Field(default=8788, description="Port for provider mode", ge=1, le=65535)
|
||||
reload: bool = Field(default=False, description="Enable auto-reload for development")
|
||||
|
||||
# Client mode arguments
|
||||
server_url: str = Field(
|
||||
default="http://localhost:8000/ai-voicebot",
|
||||
description="AI-Voicebot lobby and signaling server base URL (http:// or https://)"
|
||||
)
|
||||
lobby: str = Field(default="default", description="Lobby name to create or join")
|
||||
session_name: str = Field(default="Python Bot", description="Session (user) display name")
|
||||
session_id: Optional[str] = Field(default=None, description="Optional existing session id to reuse")
|
||||
password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
|
||||
private: bool = Field(default=False, description="Create the lobby as private")
|
||||
insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
|
||||
|
||||
@classmethod
|
||||
def from_environment(cls) -> 'VoicebotArgs':
|
||||
"""Create VoicebotArgs from environment variables."""
|
||||
import os
|
||||
|
||||
mode_str = os.getenv('VOICEBOT_MODE', 'client')
|
||||
return cls(
|
||||
mode=VoicebotMode(mode_str),
|
||||
host=os.getenv('VOICEBOT_HOST', '0.0.0.0'),
|
||||
port=int(os.getenv('VOICEBOT_PORT', '8788')),
|
||||
reload=os.getenv('VOICEBOT_RELOAD', 'false').lower() == 'true',
|
||||
server_url=os.getenv('VOICEBOT_SERVER_URL', 'http://localhost:8000/ai-voicebot'),
|
||||
lobby=os.getenv('VOICEBOT_LOBBY', 'default'),
|
||||
session_name=os.getenv('VOICEBOT_SESSION_NAME', 'Python Bot'),
|
||||
session_id=os.getenv('VOICEBOT_SESSION_ID', None),
|
||||
password=os.getenv('VOICEBOT_PASSWORD', None),
|
||||
private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true',
|
||||
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_argparse(cls, args: 'argparse.Namespace') -> 'VoicebotArgs':
|
||||
"""Create VoicebotArgs from argparse Namespace."""
|
||||
mode_str = getattr(args, 'mode', 'client')
|
||||
return cls(
|
||||
mode=VoicebotMode(mode_str),
|
||||
host=getattr(args, 'host', '0.0.0.0'),
|
||||
port=getattr(args, 'port', 8788),
|
||||
reload=getattr(args, 'reload', False),
|
||||
server_url=getattr(args, 'server_url', 'http://localhost:8000/ai-voicebot'),
|
||||
lobby=getattr(args, 'lobby', 'default'),
|
||||
session_name=getattr(args, 'session_name', 'Python Bot'),
|
||||
session_id=getattr(args, 'session_id', None),
|
||||
password=getattr(args, 'password', None),
|
||||
private=getattr(args, 'private', False),
|
||||
insecure=getattr(args, 'insecure', False)
|
||||
)
|
||||
|
||||
# Bot orchestration imports
|
||||
import importlib
|
||||
import pkgutil
|
||||
@ -1048,6 +1119,14 @@ async def main():
|
||||
help="Allow insecure server connections when using SSL (accept self-signed certs)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert argparse Namespace to VoicebotArgs for type safety
|
||||
voicebot_args = VoicebotArgs.from_argparse(args)
|
||||
await main_with_args(voicebot_args)
|
||||
|
||||
|
||||
async def main_with_args(args: VoicebotArgs):
|
||||
"""Main voicebot client logic that accepts arguments object."""
|
||||
|
||||
# Resolve session id (create if needed)
|
||||
try:
|
||||
@ -1131,6 +1210,112 @@ async def main():
|
||||
|
||||
app = FastAPI(title="voicebot-bot-orchestrator")
|
||||
|
||||
# Global client app instance for uvicorn import
|
||||
client_app = None
|
||||
|
||||
# Global client arguments storage
|
||||
_client_args: Optional[VoicebotArgs] = None
|
||||
|
||||
def create_client_app(args: VoicebotArgs):
|
||||
"""Create a FastAPI app for client mode that uvicorn can import."""
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
import fcntl
|
||||
|
||||
global _client_args
|
||||
_client_args = args
|
||||
|
||||
# Store the client task globally so we can manage it
|
||||
client_task = None
|
||||
lock_file = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
nonlocal client_task, lock_file
|
||||
# Startup
|
||||
# Use a file lock to prevent multiple instances from starting
|
||||
lock_file_path = "/tmp/voicebot_client.lock"
|
||||
|
||||
try:
|
||||
lock_file = open(lock_file_path, 'w')
|
||||
# Try to acquire an exclusive lock (non-blocking)
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
|
||||
if _client_args is None:
|
||||
logger.error("Client args not initialized")
|
||||
if lock_file:
|
||||
lock_file.close()
|
||||
lock_file = None
|
||||
yield
|
||||
return
|
||||
|
||||
logger.info("Starting voicebot client...")
|
||||
client_task = asyncio.create_task(main_with_args(_client_args))
|
||||
|
||||
except (IOError, OSError):
|
||||
# Another process already has the lock
|
||||
logger.info("Another instance is already running - skipping client startup")
|
||||
if lock_file:
|
||||
lock_file.close()
|
||||
lock_file = None
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
if client_task and not client_task.done():
|
||||
logger.info("Shutting down voicebot client...")
|
||||
client_task.cancel()
|
||||
try:
|
||||
await client_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if lock_file:
|
||||
try:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
lock_file.close()
|
||||
os.unlink(lock_file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create the client FastAPI app
|
||||
app = FastAPI(title="voicebot-client", lifespan=lifespan)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():# pyright: ignore
|
||||
"""Simple health check endpoint"""
|
||||
return {"status": "running", "mode": "client"}
|
||||
|
||||
@app.get("/status", response_model=ClientStatusResponse)
|
||||
async def client_status() -> ClientStatusResponse:# pyright: ignore
|
||||
"""Get client status"""
|
||||
return ClientStatusResponse(
|
||||
client_running=client_task is not None and not client_task.done(),
|
||||
session_name=_client_args.session_name if _client_args else 'unknown',
|
||||
lobby=_client_args.lobby if _client_args else 'unknown',
|
||||
server_url=_client_args.server_url if _client_args else 'unknown'
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
# Function to get the appropriate app based on environment variable
|
||||
def get_app():
|
||||
"""Get the appropriate FastAPI app based on VOICEBOT_MODE environment variable."""
|
||||
import os
|
||||
mode = os.getenv('VOICEBOT_MODE', 'provider')
|
||||
|
||||
if mode == 'client':
|
||||
# For client mode, we need to create the client app with args from environment
|
||||
args = VoicebotArgs.from_environment()
|
||||
return create_client_app(args)
|
||||
else:
|
||||
# Provider mode - return the main bot orchestration app
|
||||
return app
|
||||
|
||||
# Create app instance for uvicorn import
|
||||
uvicorn_app = get_app()
|
||||
|
||||
|
||||
class JoinRequest(BaseModel):
|
||||
lobby_id: str
|
||||
@ -1251,11 +1436,165 @@ def start_bot_api(host: str = "0.0.0.0", port: int = 8788):
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
async def register_with_server(server_url: str, voicebot_url: str, insecure: bool = False) -> str:
|
||||
"""Register this voicebot instance as a bot provider with the main server"""
|
||||
try:
|
||||
# Import httpx locally to avoid dependency issues
|
||||
import httpx
|
||||
|
||||
payload = {
|
||||
"base_url": voicebot_url.rstrip('/'),
|
||||
"name": "voicebot-provider",
|
||||
"description": "AI voicebot provider with speech recognition and synthetic media capabilities"
|
||||
}
|
||||
|
||||
# Prepare SSL context if needed
|
||||
verify = not insecure
|
||||
|
||||
async with httpx.AsyncClient(verify=verify) as client:
|
||||
response = await client.post(
|
||||
f"{server_url}/api/bots/providers/register",
|
||||
json=payload,
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
provider_id = result.get("provider_id")
|
||||
logger.info(f"Successfully registered with server as provider: {provider_id}")
|
||||
return provider_id
|
||||
else:
|
||||
logger.error(f"Failed to register with server: HTTP {response.status_code}: {response.text}")
|
||||
raise RuntimeError(f"Registration failed: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering with server: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def start_bot_provider(
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8788,
|
||||
server_url: str | None = None,
|
||||
insecure: bool = False,
|
||||
reload: bool = False
|
||||
):
|
||||
"""Start the bot provider API server and optionally register with main server"""
|
||||
import time
|
||||
|
||||
# Start the FastAPI server in a background thread
|
||||
import threading
|
||||
|
||||
# Add reload functionality for development
|
||||
if reload:
|
||||
server_thread = threading.Thread(
|
||||
target=lambda: uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info",
|
||||
reload=True,
|
||||
reload_dirs=["/voicebot", "/shared"]
|
||||
),
|
||||
daemon=True
|
||||
)
|
||||
else:
|
||||
server_thread = threading.Thread(
|
||||
target=lambda: uvicorn.run(app, host=host, port=port, log_level="info"),
|
||||
daemon=True
|
||||
)
|
||||
|
||||
server_thread.start()
|
||||
|
||||
# If server_url is provided, register with the main server
|
||||
if server_url:
|
||||
# Give the server a moment to start
|
||||
time.sleep(2)
|
||||
|
||||
# Construct the voicebot URL
|
||||
voicebot_url = f"http://{host}:{port}"
|
||||
if host == "0.0.0.0":
|
||||
# Try to get a better hostname
|
||||
import socket
|
||||
try:
|
||||
hostname = socket.gethostname()
|
||||
voicebot_url = f"http://{hostname}:{port}"
|
||||
except Exception:
|
||||
voicebot_url = f"http://localhost:{port}"
|
||||
|
||||
try:
|
||||
asyncio.run(register_with_server(server_url, voicebot_url, insecure))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register with server: {e}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutting down bot provider...")
|
||||
|
||||
|
||||
def start_client_with_reload(args: VoicebotArgs):
|
||||
"""Start the client with auto-reload functionality."""
|
||||
global client_app
|
||||
|
||||
logger.info("Creating client app for uvicorn...")
|
||||
client_app = create_client_app(args)
|
||||
|
||||
# Note: This function is called when --reload is specified
|
||||
# The actual uvicorn execution should be handled by the entrypoint script
|
||||
logger.info("Client app created. Uvicorn should be started by entrypoint script.")
|
||||
|
||||
# Fall back to running client directly if not using uvicorn
|
||||
asyncio.run(main_with_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Install required packages:
|
||||
# pip install aiortc websockets opencv-python numpy
|
||||
|
||||
asyncio.run(main())
|
||||
# test modification
|
||||
# Test comment Mon Sep 1 03:48:19 PM PDT 2025
|
||||
# Test change at Mon Sep 1 03:52:13 PM PDT 2025
|
||||
import argparse
|
||||
|
||||
# Check if we're being run as a bot provider or as a client
|
||||
parser = argparse.ArgumentParser(description="AI Voicebot - WebRTC client or bot provider")
|
||||
parser.add_argument("--mode", choices=["client", "provider"], default="client",
|
||||
help="Run as client (connect to lobby) or provider (serve bots)")
|
||||
|
||||
# Provider mode arguments
|
||||
parser.add_argument("--host", default="0.0.0.0", help="Host for provider mode")
|
||||
parser.add_argument("--port", type=int, default=8788, help="Port for provider mode")
|
||||
parser.add_argument("--reload", action="store_true",
|
||||
help="Enable auto-reload for development")
|
||||
|
||||
# Client mode arguments
|
||||
parser.add_argument("--server-url",
|
||||
default="http://localhost:8000/ai-voicebot",
|
||||
help="AI-Voicebot lobby and signaling server base URL (for client mode) or provider registration URL (for provider mode)")
|
||||
parser.add_argument("--lobby", default="default", help="Lobby name to create or join (client mode)")
|
||||
parser.add_argument("--session-name", default="Python Bot", help="Session (user) display name (client mode)")
|
||||
parser.add_argument("--session-id", default=None, help="Optional existing session id to reuse (client mode)")
|
||||
parser.add_argument("--password", default=None, help="Optional password to register or takeover a name (client mode)")
|
||||
parser.add_argument("--private", action="store_true", help="Create the lobby as private (client mode)")
|
||||
parser.add_argument("--insecure", action="store_true",
|
||||
help="Allow insecure connections")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert argparse Namespace to VoicebotArgs for type safety
|
||||
voicebot_args = VoicebotArgs.from_argparse(args)
|
||||
|
||||
if voicebot_args.mode == VoicebotMode.PROVIDER:
|
||||
start_bot_provider(
|
||||
host=voicebot_args.host,
|
||||
port=voicebot_args.port,
|
||||
server_url=voicebot_args.server_url,
|
||||
insecure=voicebot_args.insecure,
|
||||
reload=voicebot_args.reload
|
||||
)
|
||||
else:
|
||||
if voicebot_args.reload:
|
||||
start_client_with_reload(voicebot_args)
|
||||
else:
|
||||
asyncio.run(main_with_args(voicebot_args))
|
||||
|
||||
|
@ -1,287 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple file-watcher that restarts a command when Python source files change.
|
||||
|
||||
Usage:
|
||||
python scripts/reload_runner.py --watch voicebot -- python voicebot/main.py
|
||||
|
||||
This is intentionally dependency-free so it works in minimal dev environments
|
||||
and inside containers without installing extra packages.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from types import FrameType
|
||||
|
||||
|
||||
def scan_py_mtimes(paths: List[str]) -> Dict[str, float]:
|
||||
# Directories to skip during scanning
|
||||
SKIP_DIRS = {
|
||||
".venv",
|
||||
"__pycache__",
|
||||
".git",
|
||||
"node_modules",
|
||||
".mypy_cache",
|
||||
".pytest_cache",
|
||||
"build",
|
||||
"dist",
|
||||
}
|
||||
|
||||
mtimes: Dict[str, float] = {}
|
||||
for p in paths:
|
||||
if os.path.isfile(p) and p.endswith('.py'):
|
||||
try:
|
||||
# Use both mtime and ctime to catch more changes in Docker environments
|
||||
stat = os.stat(p)
|
||||
mtimes[p] = max(stat.st_mtime, stat.st_ctime)
|
||||
except OSError:
|
||||
pass
|
||||
continue
|
||||
|
||||
for root, dirs, files in os.walk(p):
|
||||
# Skip common directories that shouldn't trigger reloads
|
||||
dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
|
||||
|
||||
for f in files:
|
||||
if not f.endswith('.py'):
|
||||
continue
|
||||
fp = os.path.join(root, f)
|
||||
try:
|
||||
# Use both mtime and ctime to catch more changes in Docker environments
|
||||
stat = os.stat(fp)
|
||||
mtimes[fp] = max(stat.st_mtime, stat.st_ctime)
|
||||
except OSError:
|
||||
# file might disappear between walk and stat
|
||||
pass
|
||||
return mtimes
|
||||
|
||||
|
||||
def scan_py_hashes(paths: List[str]) -> Dict[str, str]:
|
||||
"""Fallback method: scan file content hashes for change detection."""
|
||||
# Directories to skip during scanning
|
||||
SKIP_DIRS = {
|
||||
".venv",
|
||||
"__pycache__",
|
||||
".git",
|
||||
"node_modules",
|
||||
".mypy_cache",
|
||||
".pytest_cache",
|
||||
"build",
|
||||
"dist",
|
||||
}
|
||||
|
||||
hashes: Dict[str, str] = {}
|
||||
for p in paths:
|
||||
if os.path.isfile(p) and p.endswith(".py"):
|
||||
try:
|
||||
with open(p, "rb") as f:
|
||||
content = f.read()
|
||||
hashes[p] = hashlib.md5(content).hexdigest()
|
||||
except OSError:
|
||||
pass
|
||||
continue
|
||||
|
||||
for root, dirs, files in os.walk(p):
|
||||
# Skip common directories that shouldn't trigger reloads
|
||||
dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
|
||||
|
||||
for f in files:
|
||||
if not f.endswith(".py"):
|
||||
continue
|
||||
fp = os.path.join(root, f)
|
||||
try:
|
||||
with open(fp, "rb") as file:
|
||||
content = file.read()
|
||||
hashes[fp] = hashlib.md5(content).hexdigest()
|
||||
except OSError:
|
||||
# file might disappear between walk and read
|
||||
pass
|
||||
return hashes
|
||||
|
||||
|
||||
def start_process(cmd: List[str]) -> subprocess.Popen[bytes]:
|
||||
print("Starting:", " ".join(cmd))
|
||||
return subprocess.Popen(cmd)
|
||||
|
||||
|
||||
def terminate_process(p: subprocess.Popen[bytes], timeout: float = 5.0) -> None:
|
||||
if p.poll() is not None:
|
||||
return
|
||||
try:
|
||||
p.terminate()
|
||||
waited = 0.0
|
||||
while p.poll() is None and waited < timeout:
|
||||
time.sleep(0.1)
|
||||
waited += 0.1
|
||||
if p.poll() is None:
|
||||
p.kill()
|
||||
except Exception as e:
|
||||
print("Error terminating process:", e)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Restart a command when .py files change")
|
||||
parser.add_argument("--watch", "-w", nargs="+", default=["."], help="Directories or files to watch")
|
||||
parser.add_argument(
|
||||
"--interval", "-i", type=float, default=0.5, help="Polling interval in seconds"
|
||||
)
|
||||
parser.add_argument("--delay-restart", type=float, default=0.1, help="Delay after change before restarting")
|
||||
parser.add_argument("--no-restart-on-exit", action="store_true", help="Don't restart if the process exits on its own")
|
||||
parser.add_argument("--pass-sigterm", action="store_true", help="Forward SIGTERM to child and exit when received")
|
||||
parser.add_argument(
|
||||
"--verbose", "-v", action="store_true", help="Enable verbose logging"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-hash-fallback",
|
||||
action="store_true",
|
||||
help="Use content hashing as fallback for Docker environments",
|
||||
)
|
||||
# Accept the command to run as a positional "remainder" so callers can
|
||||
# separate options with `--` and have everything after it treated as the
|
||||
# command. Defining an option named "--" doesn't work reliably with
|
||||
# argparse; use a positional argument instead.
|
||||
parser.add_argument("cmd", nargs=argparse.REMAINDER, help="Command to run (required)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# args.cmd is the remainder of the command-line. Users typically call this
|
||||
# script like: `reload_runner.py --watch . -- mycmd arg1 arg2`.
|
||||
# argparse will include a literal leading '--' in the remainder list, so
|
||||
# strip it if present.
|
||||
raw_cmd = args.cmd
|
||||
if raw_cmd and raw_cmd[0] == "--":
|
||||
cmd = raw_cmd[1:]
|
||||
else:
|
||||
cmd = raw_cmd
|
||||
|
||||
if not cmd:
|
||||
parser.error("Missing command to run. Put `--` before the command. See help.")
|
||||
|
||||
watch_paths = args.watch
|
||||
|
||||
last_mtimes = scan_py_mtimes(watch_paths)
|
||||
last_hashes = scan_py_hashes(watch_paths) if args.use_hash_fallback else {}
|
||||
|
||||
if args.verbose:
|
||||
print(f"Watching {len(last_mtimes)} Python files in paths: {watch_paths}")
|
||||
print(f"Working directory: {os.getcwd()}")
|
||||
print(f"Resolved watch paths: {[os.path.abspath(p) for p in watch_paths]}")
|
||||
print(f"Polling interval: {args.interval}s")
|
||||
if args.use_hash_fallback:
|
||||
print("Using content hash fallback for change detection")
|
||||
print("Sample files being watched:")
|
||||
for fp in sorted(last_mtimes.keys())[:5]:
|
||||
print(f" {fp}")
|
||||
if len(last_mtimes) > 5:
|
||||
print(f" ... and {len(last_mtimes) - 5} more")
|
||||
|
||||
child = start_process(cmd)
|
||||
|
||||
def handle_sigterm(signum: int, frame: Optional[FrameType]) -> None:
|
||||
if args.pass_sigterm:
|
||||
try:
|
||||
child.send_signal(signum)
|
||||
except Exception:
|
||||
pass
|
||||
print("Received signal, stopping watcher.")
|
||||
try:
|
||||
terminate_process(child)
|
||||
finally:
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, handle_sigterm)
|
||||
signal.signal(signal.SIGTERM, handle_sigterm)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Sleep in small increments so Ctrl-C is responsive
|
||||
time.sleep(args.interval)
|
||||
|
||||
# If the child exited on its own
|
||||
if child.poll() is not None:
|
||||
rc = child.returncode
|
||||
print(f"Process exited with code {rc}.")
|
||||
if args.no_restart_on_exit:
|
||||
return rc
|
||||
# else restart immediately
|
||||
child = start_process(cmd)
|
||||
last_mtimes = scan_py_mtimes(watch_paths)
|
||||
last_hashes = (
|
||||
scan_py_hashes(watch_paths) if args.use_hash_fallback else {}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check for source changes
|
||||
current = scan_py_mtimes(watch_paths)
|
||||
changed = False
|
||||
change_reason = ""
|
||||
|
||||
# Check for new or changed files
|
||||
for fp, m in current.items():
|
||||
if fp not in last_mtimes or last_mtimes.get(fp) != m:
|
||||
print("Detected change in:", fp)
|
||||
if args.verbose:
|
||||
old_mtime = last_mtimes.get(fp, 0)
|
||||
print(f" Old mtime: {old_mtime}, New mtime: {m}")
|
||||
changed = True
|
||||
change_reason = f"mtime change in {fp}"
|
||||
break
|
||||
|
||||
# Hash-based fallback check if mtime didn't detect changes
|
||||
if not changed and args.use_hash_fallback:
|
||||
current_hashes = scan_py_hashes(watch_paths)
|
||||
for fp, h in current_hashes.items():
|
||||
if fp not in last_hashes or last_hashes.get(fp) != h:
|
||||
print("Detected content change in:", fp)
|
||||
if args.verbose:
|
||||
print(
|
||||
f" Hash changed: {last_hashes.get(fp, 'None')} -> {h}"
|
||||
)
|
||||
changed = True
|
||||
change_reason = f"content change in {fp}"
|
||||
break
|
||||
# Update hash cache
|
||||
last_hashes = current_hashes
|
||||
|
||||
# Check for deleted files
|
||||
if not changed:
|
||||
for fp in list(last_mtimes.keys()):
|
||||
if fp not in current:
|
||||
print("Detected deleted file:", fp)
|
||||
changed = True
|
||||
change_reason = f"deleted file {fp}"
|
||||
break
|
||||
|
||||
# Additional debug output
|
||||
if args.verbose and not changed:
|
||||
num_files = len(current)
|
||||
if num_files != len(last_mtimes):
|
||||
print(f"File count changed: {len(last_mtimes)} -> {num_files}")
|
||||
changed = True
|
||||
change_reason = "file count change"
|
||||
|
||||
if changed:
|
||||
if args.verbose:
|
||||
print(f"Restarting due to: {change_reason}")
|
||||
# Small debounce
|
||||
time.sleep(args.delay_restart)
|
||||
terminate_process(child)
|
||||
child = start_process(cmd)
|
||||
last_mtimes = scan_py_mtimes(watch_paths)
|
||||
if args.use_hash_fallback:
|
||||
last_hashes = scan_py_hashes(watch_paths)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Interrupted, shutting down.")
|
||||
terminate_process(child)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
Loading…
x
Reference in New Issue
Block a user