Lots of changes/refactorings
This commit is contained in:
parent
3674d57b0a
commit
b7e5963597
@ -35,7 +35,7 @@
|
|||||||
"test": "react-scripts test",
|
"test": "react-scripts test",
|
||||||
"eject": "react-scripts eject",
|
"eject": "react-scripts eject",
|
||||||
"type-check": "tsc --noEmit",
|
"type-check": "tsc --noEmit",
|
||||||
"generate-schema": "cd ../server && uv run python3 generate_schema_simple.py",
|
"generate-schema": "cd ../server && python3 generate_schema_simple.py",
|
||||||
"generate-types": "npx openapi-typescript openapi-schema.json -o src/api-types.ts",
|
"generate-types": "npx openapi-typescript openapi-schema.json -o src/api-types.ts",
|
||||||
"generate-api-types": "npm run generate-schema && npm run generate-types",
|
"generate-api-types": "npm run generate-schema && npm run generate-types",
|
||||||
"check-api-evolution": "node check-api-evolution.js",
|
"check-api-evolution": "node check-api-evolution.js",
|
||||||
|
@ -4,6 +4,7 @@ import { Input, Paper, Typography } from "@mui/material";
|
|||||||
import { Session, Lobby } from "./GlobalContext";
|
import { Session, Lobby } from "./GlobalContext";
|
||||||
import { UserList } from "./UserList";
|
import { UserList } from "./UserList";
|
||||||
import { LobbyChat } from "./LobbyChat";
|
import { LobbyChat } from "./LobbyChat";
|
||||||
|
import BotManager from "./BotManager";
|
||||||
import "./App.css";
|
import "./App.css";
|
||||||
import { ws_base, base } from "./Common";
|
import { ws_base, base } from "./Common";
|
||||||
import { Box, Button, Tooltip } from "@mui/material";
|
import { Box, Button, Tooltip } from "@mui/material";
|
||||||
@ -29,6 +30,15 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
|||||||
const [creatingLobby, setCreatingLobby] = useState<boolean>(false);
|
const [creatingLobby, setCreatingLobby] = useState<boolean>(false);
|
||||||
const [reconnectAttempt, setReconnectAttempt] = useState<number>(0);
|
const [reconnectAttempt, setReconnectAttempt] = useState<number>(0);
|
||||||
|
|
||||||
|
// Check if lobbyName looks like a lobby ID (32 hex characters) and redirect to default
|
||||||
|
useEffect(() => {
|
||||||
|
if (lobbyName && /^[a-f0-9]{32}$/i.test(lobbyName)) {
|
||||||
|
console.log(`Lobby - Detected lobby ID in URL (${lobbyName}), redirecting to default lobby`);
|
||||||
|
window.history.replaceState(null, "", `${base}/lobby/default`);
|
||||||
|
window.location.reload(); // Force reload to use the new URL
|
||||||
|
}
|
||||||
|
}, [lobbyName]);
|
||||||
|
|
||||||
const { sendJsonMessage, lastJsonMessage, readyState } = useWebSocket(socketUrl, {
|
const { sendJsonMessage, lastJsonMessage, readyState } = useWebSocket(socketUrl, {
|
||||||
onOpen: () => {
|
onOpen: () => {
|
||||||
console.log("app - WebSocket connection opened.");
|
console.log("app - WebSocket connection opened.");
|
||||||
@ -38,8 +48,24 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
|||||||
console.log("app - WebSocket connection closed.");
|
console.log("app - WebSocket connection closed.");
|
||||||
setReconnectAttempt((prev) => prev + 1);
|
setReconnectAttempt((prev) => prev + 1);
|
||||||
},
|
},
|
||||||
onError: (event: Event) => console.error("app - WebSocket error observed:", event),
|
onError: (event: Event) => {
|
||||||
shouldReconnect: (closeEvent) => true, // Will attempt to reconnect on all close events
|
console.error("app - WebSocket error observed:", event);
|
||||||
|
// If we get a WebSocket error, it might be due to invalid lobby ID
|
||||||
|
// Reset the lobby state to force recreation
|
||||||
|
if (lobby) {
|
||||||
|
console.log("app - WebSocket error, clearing lobby state to force refresh");
|
||||||
|
setLobby(null);
|
||||||
|
setSocketUrl(null);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
shouldReconnect: (closeEvent) => {
|
||||||
|
// Don't reconnect if the lobby doesn't exist (4xx errors)
|
||||||
|
if (closeEvent.code >= 4000 && closeEvent.code < 5000) {
|
||||||
|
console.log("app - WebSocket closed with client error, not reconnecting");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
},
|
||||||
reconnectInterval: 5000, // Retry every 5 seconds
|
reconnectInterval: 5000, // Retry every 5 seconds
|
||||||
onReconnectStop: (numAttempts) => {
|
onReconnectStop: (numAttempts) => {
|
||||||
console.log(`Stopped reconnecting after ${numAttempts} attempts`);
|
console.log(`Stopped reconnecting after ${numAttempts} attempts`);
|
||||||
@ -68,6 +94,13 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
|||||||
case "error":
|
case "error":
|
||||||
console.error(`Lobby - Server error: ${data.error}`);
|
console.error(`Lobby - Server error: ${data.error}`);
|
||||||
setError(data.error);
|
setError(data.error);
|
||||||
|
|
||||||
|
// If the error is about lobby not found, reset the lobby state
|
||||||
|
if (data.error && data.error.includes("Lobby not found")) {
|
||||||
|
console.log("Lobby - Lobby not found error, clearing lobby state");
|
||||||
|
setLobby(null);
|
||||||
|
setSocketUrl(null);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
@ -82,50 +115,65 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
|||||||
if (!session || !lobbyName || creatingLobby || (lobby && lobby.name === lobbyName)) {
|
if (!session || !lobbyName || creatingLobby || (lobby && lobby.name === lobbyName)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear any existing lobby state when switching to a new lobby name
|
||||||
|
if (lobby && lobby.name !== lobbyName) {
|
||||||
|
console.log(`Lobby - Clearing previous lobby state: ${lobby.name} -> ${lobbyName}`);
|
||||||
|
setLobby(null);
|
||||||
|
setSocketUrl(null);
|
||||||
|
}
|
||||||
|
|
||||||
const getLobby = async (lobbyName: string, session: Session) => {
|
const getLobby = async (lobbyName: string, session: Session) => {
|
||||||
const res = await fetch(`${base}/api/lobby/${session.id}`, {
|
try {
|
||||||
method: "POST",
|
const res = await fetch(`${base}/api/lobby/${session.id}`, {
|
||||||
cache: "no-cache",
|
method: "POST",
|
||||||
credentials: "same-origin",
|
cache: "no-cache",
|
||||||
headers: {
|
credentials: "same-origin",
|
||||||
"Content-Type": "application/json",
|
headers: {
|
||||||
},
|
"Content-Type": "application/json",
|
||||||
body: JSON.stringify({
|
|
||||||
type: "lobby_create",
|
|
||||||
data: {
|
|
||||||
name: lobbyName,
|
|
||||||
private: false,
|
|
||||||
},
|
},
|
||||||
}),
|
body: JSON.stringify({
|
||||||
});
|
type: "lobby_create",
|
||||||
|
data: {
|
||||||
|
name: lobbyName,
|
||||||
|
private: false,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
if (res.status >= 400) {
|
if (res.status >= 400) {
|
||||||
const error = `Unable to connect to AI Voice Chat server! Try refreshing your browser in a few seconds.`;
|
const error = `Unable to connect to AI Voice Chat server! Try refreshing your browser in a few seconds.`;
|
||||||
console.error(error);
|
console.error(error);
|
||||||
setError(error);
|
setError(error);
|
||||||
}
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const data = await res.json();
|
const data = await res.json();
|
||||||
if (data.error) {
|
if (data.error) {
|
||||||
console.error(`Lobby - Server error: ${data.error}`);
|
console.error(`Lobby - Server error: ${data.error}`);
|
||||||
setError(data.error);
|
setError(data.error);
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
if (data.type !== "lobby_created") {
|
||||||
|
console.error(`Lobby - Unexpected response type: ${data.type}`);
|
||||||
|
setError(`Unexpected response from server`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const lobby: Lobby = data.data;
|
||||||
|
console.log(`Lobby - Joined lobby`, lobby);
|
||||||
|
setLobby(lobby);
|
||||||
|
} catch (err) {
|
||||||
|
const errorMessage = err instanceof Error ? err.message : "Failed to create/join lobby";
|
||||||
|
console.error("Lobby creation error:", errorMessage);
|
||||||
|
setError(errorMessage);
|
||||||
}
|
}
|
||||||
if (data.type !== "lobby_created") {
|
|
||||||
console.error(`Lobby - Unexpected response type: ${data.type}`);
|
|
||||||
setError(`Unexpected response from server`);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const lobby: Lobby = data.data;
|
|
||||||
console.log(`Lobby - Joined lobby`, lobby);
|
|
||||||
setLobby(lobby);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
setCreatingLobby(true);
|
setCreatingLobby(true);
|
||||||
getLobby(lobbyName, session).then(() => {
|
getLobby(lobbyName, session).finally(() => {
|
||||||
setCreatingLobby(false);
|
setCreatingLobby(false);
|
||||||
});
|
});
|
||||||
}, [session, lobbyName, setLobby, setError]);
|
}, [session, lobbyName, lobby, setLobby, setError]);
|
||||||
|
|
||||||
const setName = (name: string) => {
|
const setName = (name: string) => {
|
||||||
sendJsonMessage({
|
sendJsonMessage({
|
||||||
@ -216,6 +264,13 @@ const LobbyView: React.FC<LobbyProps> = (props: LobbyProps) => {
|
|||||||
{session && socketUrl && lobby && (
|
{session && socketUrl && lobby && (
|
||||||
<LobbyChat socketUrl={socketUrl} session={session} lobbyId={lobby.id} />
|
<LobbyChat socketUrl={socketUrl} session={session} lobbyId={lobby.id} />
|
||||||
)}
|
)}
|
||||||
|
{session && lobby && (
|
||||||
|
<BotManager
|
||||||
|
lobbyId={lobby.id}
|
||||||
|
onBotAdded={(botName) => console.log(`Bot ${botName} added to lobby`)}
|
||||||
|
sx={{ minWidth: "300px" }}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
303
client/src/BotManager.tsx
Normal file
303
client/src/BotManager.tsx
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
import React, { useState, useEffect } from "react";
|
||||||
|
import {
|
||||||
|
Paper,
|
||||||
|
Button,
|
||||||
|
List,
|
||||||
|
ListItem,
|
||||||
|
ListItemText,
|
||||||
|
Typography,
|
||||||
|
Box,
|
||||||
|
Chip,
|
||||||
|
IconButton,
|
||||||
|
Dialog,
|
||||||
|
DialogTitle,
|
||||||
|
DialogContent,
|
||||||
|
DialogActions,
|
||||||
|
TextField,
|
||||||
|
CircularProgress,
|
||||||
|
Alert,
|
||||||
|
Accordion,
|
||||||
|
AccordionSummary,
|
||||||
|
AccordionDetails,
|
||||||
|
} from "@mui/material";
|
||||||
|
import {
|
||||||
|
SmartToy as BotIcon,
|
||||||
|
Add as AddIcon,
|
||||||
|
ExpandMore as ExpandMoreIcon,
|
||||||
|
Refresh as RefreshIcon,
|
||||||
|
} from "@mui/icons-material";
|
||||||
|
import { botsApi, BotInfoModel, BotProviderModel, BotJoinLobbyRequest } from "./api-client";
|
||||||
|
|
||||||
|
interface BotManagerProps {
|
||||||
|
lobbyId: string;
|
||||||
|
onBotAdded?: (botName: string) => void;
|
||||||
|
sx?: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BotManager: React.FC<BotManagerProps> = ({ lobbyId, onBotAdded, sx }) => {
|
||||||
|
const [bots, setBots] = useState<Record<string, BotInfoModel>>({});
|
||||||
|
const [providers, setProviders] = useState<Record<string, string>>({});
|
||||||
|
const [botProviders, setBotProviders] = useState<BotProviderModel[]>([]);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
const [addDialogOpen, setAddDialogOpen] = useState(false);
|
||||||
|
const [selectedBot, setSelectedBot] = useState<string>("");
|
||||||
|
const [botNick, setBotNick] = useState("");
|
||||||
|
const [addingBot, setAddingBot] = useState(false);
|
||||||
|
|
||||||
|
const loadBots = async () => {
|
||||||
|
setLoading(true);
|
||||||
|
setError(null);
|
||||||
|
try {
|
||||||
|
const [botsResponse, providersResponse] = await Promise.all([
|
||||||
|
botsApi.getAvailable(),
|
||||||
|
botsApi.getProviders(),
|
||||||
|
]);
|
||||||
|
|
||||||
|
setBots(botsResponse.bots);
|
||||||
|
setProviders(botsResponse.providers);
|
||||||
|
setBotProviders(providersResponse.providers);
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Failed to load bots:", err);
|
||||||
|
setError("Failed to load available bots");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
loadBots();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleAddBot = async () => {
|
||||||
|
if (!selectedBot) return;
|
||||||
|
|
||||||
|
setAddingBot(true);
|
||||||
|
try {
|
||||||
|
const request: BotJoinLobbyRequest = {
|
||||||
|
bot_name: selectedBot,
|
||||||
|
lobby_id: lobbyId,
|
||||||
|
nick: botNick || `${selectedBot}-bot`,
|
||||||
|
provider_id: providers[selectedBot],
|
||||||
|
};
|
||||||
|
|
||||||
|
const response = await botsApi.requestJoinLobby(selectedBot, request);
|
||||||
|
|
||||||
|
if (response.status === "requested") {
|
||||||
|
setAddDialogOpen(false);
|
||||||
|
setSelectedBot("");
|
||||||
|
setBotNick("");
|
||||||
|
onBotAdded?.(selectedBot);
|
||||||
|
|
||||||
|
// Show success feedback could be added here
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Failed to add bot:", err);
|
||||||
|
setError("Failed to add bot to lobby");
|
||||||
|
} finally {
|
||||||
|
setAddingBot(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleOpenAddDialog = () => {
|
||||||
|
setAddDialogOpen(true);
|
||||||
|
setError(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleCloseAddDialog = () => {
|
||||||
|
setAddDialogOpen(false);
|
||||||
|
setSelectedBot("");
|
||||||
|
setBotNick("");
|
||||||
|
setError(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
const getProviderName = (providerId: string): string => {
|
||||||
|
const provider = botProviders.find(p => p.provider_id === providerId);
|
||||||
|
return provider ? provider.name : "Unknown Provider";
|
||||||
|
};
|
||||||
|
|
||||||
|
const botCount = Object.keys(bots).length;
|
||||||
|
const providerCount = botProviders.length;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Paper sx={{ p: 2, ...sx }}>
|
||||||
|
<Box display="flex" alignItems="center" justifyContent="space-between" mb={2}>
|
||||||
|
<Typography variant="h6" display="flex" alignItems="center" gap={1}>
|
||||||
|
<BotIcon />
|
||||||
|
AI Bots
|
||||||
|
</Typography>
|
||||||
|
<Box display="flex" gap={1}>
|
||||||
|
<IconButton onClick={loadBots} disabled={loading} size="small">
|
||||||
|
<RefreshIcon />
|
||||||
|
</IconButton>
|
||||||
|
<Button
|
||||||
|
variant="contained"
|
||||||
|
size="small"
|
||||||
|
startIcon={<AddIcon />}
|
||||||
|
onClick={handleOpenAddDialog}
|
||||||
|
disabled={loading || botCount === 0}
|
||||||
|
>
|
||||||
|
Add Bot
|
||||||
|
</Button>
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
|
|
||||||
|
{error && (
|
||||||
|
<Alert severity="error" sx={{ mb: 2 }}>
|
||||||
|
{error}
|
||||||
|
</Alert>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{loading ? (
|
||||||
|
<Box display="flex" justifyContent="center" p={2}>
|
||||||
|
<CircularProgress size={24} />
|
||||||
|
</Box>
|
||||||
|
) : (
|
||||||
|
<Box>
|
||||||
|
<Typography variant="body2" color="text.secondary" mb={1}>
|
||||||
|
{botCount} bots available from {providerCount} providers
|
||||||
|
</Typography>
|
||||||
|
|
||||||
|
{botCount === 0 ? (
|
||||||
|
<Typography variant="body2" color="text.secondary" style={{ fontStyle: "italic" }}>
|
||||||
|
No bots available. Make sure bot providers are registered and running.
|
||||||
|
</Typography>
|
||||||
|
) : (
|
||||||
|
<Accordion>
|
||||||
|
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||||
|
<Typography variant="subtitle2">Available Bots ({botCount})</Typography>
|
||||||
|
</AccordionSummary>
|
||||||
|
<AccordionDetails>
|
||||||
|
<List dense>
|
||||||
|
{Object.entries(bots).map(([botName, botInfo]) => {
|
||||||
|
const providerId = providers[botName];
|
||||||
|
const providerName = getProviderName(providerId);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ListItem key={botName}>
|
||||||
|
<ListItemText
|
||||||
|
primary={botInfo.name}
|
||||||
|
secondary={
|
||||||
|
<Box>
|
||||||
|
<Typography variant="body2" component="div">
|
||||||
|
{botInfo.description}
|
||||||
|
</Typography>
|
||||||
|
<Chip
|
||||||
|
label={providerName}
|
||||||
|
size="small"
|
||||||
|
variant="outlined"
|
||||||
|
sx={{ mt: 0.5 }}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</ListItem>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</List>
|
||||||
|
</AccordionDetails>
|
||||||
|
</Accordion>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{providerCount > 0 && (
|
||||||
|
<Accordion>
|
||||||
|
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||||
|
<Typography variant="subtitle2">Bot Providers ({providerCount})</Typography>
|
||||||
|
</AccordionSummary>
|
||||||
|
<AccordionDetails>
|
||||||
|
<List dense>
|
||||||
|
{botProviders.map((provider) => (
|
||||||
|
<ListItem key={provider.provider_id}>
|
||||||
|
<ListItemText
|
||||||
|
primary={provider.name}
|
||||||
|
secondary={
|
||||||
|
<Box>
|
||||||
|
<Typography variant="body2" component="div">
|
||||||
|
{provider.description || "No description"}
|
||||||
|
</Typography>
|
||||||
|
<Typography variant="caption" color="text.secondary">
|
||||||
|
{provider.base_url}
|
||||||
|
</Typography>
|
||||||
|
</Box>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</ListItem>
|
||||||
|
))}
|
||||||
|
</List>
|
||||||
|
</AccordionDetails>
|
||||||
|
</Accordion>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Add Bot Dialog */}
|
||||||
|
<Dialog open={addDialogOpen} onClose={handleCloseAddDialog} maxWidth="sm" fullWidth>
|
||||||
|
<DialogTitle>Add Bot to Lobby</DialogTitle>
|
||||||
|
<DialogContent>
|
||||||
|
{error && (
|
||||||
|
<Alert severity="error" sx={{ mb: 2 }}>
|
||||||
|
{error}
|
||||||
|
</Alert>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<Box sx={{ mt: 1 }}>
|
||||||
|
<Typography variant="subtitle2" gutterBottom>
|
||||||
|
Select Bot
|
||||||
|
</Typography>
|
||||||
|
<List>
|
||||||
|
{Object.entries(bots).map(([botName, botInfo]) => (
|
||||||
|
<ListItem
|
||||||
|
key={botName}
|
||||||
|
component="div"
|
||||||
|
sx={{
|
||||||
|
cursor: "pointer",
|
||||||
|
backgroundColor: selectedBot === botName ? "action.selected" : "transparent",
|
||||||
|
"&:hover": {
|
||||||
|
backgroundColor: "action.hover",
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
onClick={() => setSelectedBot(botName)}
|
||||||
|
>
|
||||||
|
<ListItemText
|
||||||
|
primary={botInfo.name}
|
||||||
|
secondary={botInfo.description}
|
||||||
|
/>
|
||||||
|
<Chip
|
||||||
|
label={getProviderName(providers[botName])}
|
||||||
|
size="small"
|
||||||
|
variant="outlined"
|
||||||
|
/>
|
||||||
|
</ListItem>
|
||||||
|
))}
|
||||||
|
</List>
|
||||||
|
|
||||||
|
{selectedBot && (
|
||||||
|
<TextField
|
||||||
|
label="Bot Nickname (optional)"
|
||||||
|
value={botNick}
|
||||||
|
onChange={(e) => setBotNick(e.target.value)}
|
||||||
|
fullWidth
|
||||||
|
margin="normal"
|
||||||
|
placeholder={`${selectedBot}-bot`}
|
||||||
|
helperText="Custom name for this bot instance in the lobby"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
</DialogContent>
|
||||||
|
<DialogActions>
|
||||||
|
<Button onClick={handleCloseAddDialog}>Cancel</Button>
|
||||||
|
<Button
|
||||||
|
onClick={handleAddBot}
|
||||||
|
variant="contained"
|
||||||
|
disabled={!selectedBot || addingBot}
|
||||||
|
startIcon={addingBot ? <CircularProgress size={16} /> : <AddIcon />}
|
||||||
|
>
|
||||||
|
{addingBot ? "Adding..." : "Add Bot"}
|
||||||
|
</Button>
|
||||||
|
</DialogActions>
|
||||||
|
</Dialog>
|
||||||
|
</Paper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default BotManager;
|
@ -18,6 +18,44 @@ export type SessionResponse = components['schemas']['SessionResponse'];
|
|||||||
export type LobbyCreateRequest = components['schemas']['LobbyCreateRequest'];
|
export type LobbyCreateRequest = components['schemas']['LobbyCreateRequest'];
|
||||||
export type LobbyCreateResponse = components['schemas']['LobbyCreateResponse'];
|
export type LobbyCreateResponse = components['schemas']['LobbyCreateResponse'];
|
||||||
|
|
||||||
|
// Bot Provider Types (manually defined until API types are regenerated)
|
||||||
|
export interface BotInfoModel {
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BotProviderModel {
|
||||||
|
provider_id: string;
|
||||||
|
base_url: string;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
registered_at: number;
|
||||||
|
last_seen: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BotProviderListResponse {
|
||||||
|
providers: BotProviderModel[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BotListResponse {
|
||||||
|
bots: Record<string, BotInfoModel>;
|
||||||
|
providers: Record<string, string>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BotJoinLobbyRequest {
|
||||||
|
bot_name: string;
|
||||||
|
lobby_id: string;
|
||||||
|
nick?: string;
|
||||||
|
provider_id?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BotJoinLobbyResponse {
|
||||||
|
status: string;
|
||||||
|
bot_name: string;
|
||||||
|
run_id: string;
|
||||||
|
provider_id: string;
|
||||||
|
}
|
||||||
|
|
||||||
export class ApiError extends Error {
|
export class ApiError extends Error {
|
||||||
constructor(
|
constructor(
|
||||||
public status: number,
|
public status: number,
|
||||||
@ -34,15 +72,18 @@ export class ApiClient {
|
|||||||
private defaultHeaders: Record<string, string>;
|
private defaultHeaders: Record<string, string>;
|
||||||
|
|
||||||
constructor(baseURL?: string) {
|
constructor(baseURL?: string) {
|
||||||
this.baseURL = baseURL || process.env.REACT_APP_API_URL || 'http://localhost:8001';
|
this.baseURL = baseURL || process.env.REACT_APP_API_URL || "http://localhost:8001";
|
||||||
this.defaultHeaders = {};
|
this.defaultHeaders = {};
|
||||||
}
|
}
|
||||||
|
|
||||||
private async request<T>(path: string, options: {
|
private async request<T>(
|
||||||
method: string;
|
path: string,
|
||||||
body?: any;
|
options: {
|
||||||
params?: Record<string, string>;
|
method: string;
|
||||||
}): Promise<T> {
|
body?: any;
|
||||||
|
params?: Record<string, string>;
|
||||||
|
}
|
||||||
|
): Promise<T> {
|
||||||
const url = new URL(path, this.baseURL);
|
const url = new URL(path, this.baseURL);
|
||||||
|
|
||||||
if (options.params) {
|
if (options.params) {
|
||||||
@ -54,12 +95,12 @@ export class ApiClient {
|
|||||||
const requestInit: RequestInit = {
|
const requestInit: RequestInit = {
|
||||||
method: options.method,
|
method: options.method,
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
"Content-Type": "application/json",
|
||||||
...this.defaultHeaders,
|
...this.defaultHeaders,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
if (options.body && options.method !== 'GET') {
|
if (options.body && options.method !== "GET") {
|
||||||
requestInit.body = JSON.stringify(options.body);
|
requestInit.body = JSON.stringify(options.body);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,8 +116,8 @@ export class ApiClient {
|
|||||||
throw new ApiError(response.status, response.statusText, errorData);
|
throw new ApiError(response.status, response.statusText, errorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
const contentType = response.headers.get('content-type');
|
const contentType = response.headers.get("content-type");
|
||||||
if (contentType && contentType.includes('application/json')) {
|
if (contentType && contentType.includes("application/json")) {
|
||||||
return response.json();
|
return response.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,34 +126,50 @@ export class ApiClient {
|
|||||||
|
|
||||||
// Admin API methods
|
// Admin API methods
|
||||||
async adminListNames(): Promise<AdminNamesResponse> {
|
async adminListNames(): Promise<AdminNamesResponse> {
|
||||||
return this.request<AdminNamesResponse>('/ai-voicebot/api/admin/names', { method: 'GET' });
|
return this.request<AdminNamesResponse>("/ai-voicebot/api/admin/names", { method: "GET" });
|
||||||
}
|
}
|
||||||
|
|
||||||
async adminSetPassword(data: AdminSetPassword): Promise<AdminActionResponse> {
|
async adminSetPassword(data: AdminSetPassword): Promise<AdminActionResponse> {
|
||||||
return this.request<AdminActionResponse>('/ai-voicebot/api/admin/set_password', { method: 'POST', body: data });
|
return this.request<AdminActionResponse>("/ai-voicebot/api/admin/set_password", { method: "POST", body: data });
|
||||||
}
|
}
|
||||||
|
|
||||||
async adminClearPassword(data: AdminClearPassword): Promise<AdminActionResponse> {
|
async adminClearPassword(data: AdminClearPassword): Promise<AdminActionResponse> {
|
||||||
return this.request<AdminActionResponse>('/ai-voicebot/api/admin/clear_password', { method: 'POST', body: data });
|
return this.request<AdminActionResponse>("/ai-voicebot/api/admin/clear_password", { method: "POST", body: data });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health check
|
// Health check
|
||||||
async healthCheck(): Promise<HealthResponse> {
|
async healthCheck(): Promise<HealthResponse> {
|
||||||
return this.request<HealthResponse>('/ai-voicebot/api/health', { method: 'GET' });
|
return this.request<HealthResponse>("/ai-voicebot/api/health", { method: "GET" });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session methods
|
// Session methods
|
||||||
async getSession(): Promise<SessionResponse> {
|
async getSession(): Promise<SessionResponse> {
|
||||||
return this.request<SessionResponse>('/ai-voicebot/api/session', { method: 'GET' });
|
return this.request<SessionResponse>("/ai-voicebot/api/session", { method: "GET" });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lobby methods
|
// Lobby methods
|
||||||
async getLobbies(): Promise<LobbiesResponse> {
|
async getLobbies(): Promise<LobbiesResponse> {
|
||||||
return this.request<LobbiesResponse>('/ai-voicebot/api/lobby', { method: 'GET' });
|
return this.request<LobbiesResponse>("/ai-voicebot/api/lobby", { method: "GET" });
|
||||||
}
|
}
|
||||||
|
|
||||||
async createLobby(sessionId: string, data: LobbyCreateRequest): Promise<LobbyCreateResponse> {
|
async createLobby(sessionId: string, data: LobbyCreateRequest): Promise<LobbyCreateResponse> {
|
||||||
return this.request<LobbyCreateResponse>(`/ai-voicebot/api/lobby/${sessionId}`, { method: 'POST', body: data });
|
return this.request<LobbyCreateResponse>(`/ai-voicebot/api/lobby/${sessionId}`, { method: "POST", body: data });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bot Provider methods
|
||||||
|
async getBotProviders(): Promise<BotProviderListResponse> {
|
||||||
|
return this.request<BotProviderListResponse>("/ai-voicebot/api/bots/providers", { method: "GET" });
|
||||||
|
}
|
||||||
|
|
||||||
|
async getAvailableBots(): Promise<BotListResponse> {
|
||||||
|
return this.request<BotListResponse>("/ai-voicebot/api/bots", { method: "GET" });
|
||||||
|
}
|
||||||
|
|
||||||
|
async requestBotJoinLobby(botName: string, request: BotJoinLobbyRequest): Promise<BotJoinLobbyResponse> {
|
||||||
|
return this.request<BotJoinLobbyResponse>(`/ai-voicebot/api/bots/${encodeURIComponent(botName)}/join`, {
|
||||||
|
method: "POST",
|
||||||
|
body: request,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -313,6 +370,12 @@ export const sessionsApi = {
|
|||||||
createLobby: (sessionId: string, data: LobbyCreateRequest) => apiClient.createLobby(sessionId, data),
|
createLobby: (sessionId: string, data: LobbyCreateRequest) => apiClient.createLobby(sessionId, data),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const botsApi = {
|
||||||
|
getProviders: () => apiClient.getBotProviders(),
|
||||||
|
getAvailable: () => apiClient.getAvailableBots(),
|
||||||
|
requestJoinLobby: (botName: string, request: BotJoinLobbyRequest) => apiClient.requestBotJoinLobby(botName, request),
|
||||||
|
};
|
||||||
|
|
||||||
// Automatically check for API evolution when this module is loaded
|
// Automatically check for API evolution when this module is loaded
|
||||||
// This will warn developers if new endpoints are available but not implemented
|
// This will warn developers if new endpoints are available but not implemented
|
||||||
if (process.env.NODE_ENV === 'development') {
|
if (process.env.NODE_ENV === 'development') {
|
||||||
|
@ -13,6 +13,7 @@ services:
|
|||||||
- "3456:3000"
|
- "3456:3000"
|
||||||
restart: no
|
restart: no
|
||||||
volumes:
|
volumes:
|
||||||
|
- ./server:/server:ro # So the frontend can read the OpenAPI spec
|
||||||
- ./client:/client:rw
|
- ./client:/client:rw
|
||||||
- ./dev-keys:/keys:ro # So the frontend entrypoint can check for SSL files
|
- ./dev-keys:/keys:ro # So the frontend entrypoint can check for SSL files
|
||||||
networks:
|
networks:
|
||||||
@ -61,7 +62,8 @@ services:
|
|||||||
- ./.env
|
- ./.env
|
||||||
environment:
|
environment:
|
||||||
- PRODUCTION=${PRODUCTION:-false}
|
- PRODUCTION=${PRODUCTION:-false}
|
||||||
restart: always
|
- VOICEBOT_MODE=provider
|
||||||
|
restart: unless-stopped
|
||||||
network_mode: host
|
network_mode: host
|
||||||
volumes:
|
volumes:
|
||||||
- ./cache:/root/.cache:rw
|
- ./cache:/root/.cache:rw
|
||||||
|
287
server/main.py
287
server/main.py
@ -1,9 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Any, Optional, TypedDict
|
from typing import Any, Optional
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
Body,
|
Body,
|
||||||
Cookie,
|
Cookie,
|
||||||
FastAPI,
|
FastAPI,
|
||||||
|
HTTPException,
|
||||||
Path,
|
Path,
|
||||||
WebSocket,
|
WebSocket,
|
||||||
Request,
|
Request,
|
||||||
@ -45,12 +46,26 @@ from shared.models import (
|
|||||||
JoinStatusModel,
|
JoinStatusModel,
|
||||||
ChatMessageModel,
|
ChatMessageModel,
|
||||||
ChatMessagesResponse,
|
ChatMessagesResponse,
|
||||||
|
ParticipantModel,
|
||||||
|
# Bot provider models
|
||||||
|
BotProviderModel,
|
||||||
|
BotProviderRegisterRequest,
|
||||||
|
BotProviderRegisterResponse,
|
||||||
|
BotProviderListResponse,
|
||||||
|
BotListResponse,
|
||||||
|
BotInfoModel,
|
||||||
|
BotJoinLobbyRequest,
|
||||||
|
BotJoinLobbyResponse,
|
||||||
|
BotJoinPayload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Mapping of reserved names to password records (lowercased name -> {salt:..., hash:...})
|
# Mapping of reserved names to password records (lowercased name -> {salt:..., hash:...})
|
||||||
name_passwords: dict[str, dict[str, str]] = {}
|
name_passwords: dict[str, dict[str, str]] = {}
|
||||||
|
|
||||||
|
# Bot provider registry: provider_id -> BotProviderModel
|
||||||
|
bot_providers: dict[str, BotProviderModel] = {}
|
||||||
|
|
||||||
all_label = "[ all ]"
|
all_label = "[ all ]"
|
||||||
info_label = "[ info ]"
|
info_label = "[ info ]"
|
||||||
todo_label = "[ todo ]"
|
todo_label = "[ todo ]"
|
||||||
@ -175,12 +190,6 @@ def admin_cleanup_sessions(request: Request):
|
|||||||
lobbies: dict[str, Lobby] = {}
|
lobbies: dict[str, Lobby] = {}
|
||||||
|
|
||||||
|
|
||||||
class LobbyResponse(TypedDict):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
private: bool
|
|
||||||
|
|
||||||
|
|
||||||
class Lobby:
|
class Lobby:
|
||||||
def __init__(self, name: str, id: str | None = None, private: bool = False):
|
def __init__(self, name: str, id: str | None = None, private: bool = False):
|
||||||
self.id = secrets.token_hex(16) if id is None else id
|
self.id = secrets.token_hex(16) if id is None else id
|
||||||
@ -194,15 +203,15 @@ class Lobby:
|
|||||||
return f"{self.short}:{self.name}"
|
return f"{self.short}:{self.name}"
|
||||||
|
|
||||||
async def update_state(self, requesting_session: Session | None = None):
|
async def update_state(self, requesting_session: Session | None = None):
|
||||||
users: list[dict[str, str | bool]] = [
|
users: list[ParticipantModel] = [
|
||||||
{
|
ParticipantModel(
|
||||||
"name": s.name,
|
name=s.name,
|
||||||
"live": True if s.ws else False,
|
live=True if s.ws else False,
|
||||||
"session_id": s.id,
|
session_id=s.id,
|
||||||
"protected": True
|
protected=True
|
||||||
if s.name and s.name.lower() in name_passwords
|
if s.name and s.name.lower() in name_passwords
|
||||||
else False,
|
else False,
|
||||||
}
|
)
|
||||||
for s in self.sessions.values()
|
for s in self.sessions.values()
|
||||||
if s.name
|
if s.name
|
||||||
]
|
]
|
||||||
@ -212,7 +221,10 @@ class Lobby:
|
|||||||
)
|
)
|
||||||
if requesting_session.ws:
|
if requesting_session.ws:
|
||||||
await requesting_session.ws.send_json(
|
await requesting_session.ws.send_json(
|
||||||
{"type": "lobby_state", "data": {"participants": users}}
|
{
|
||||||
|
"type": "lobby_state",
|
||||||
|
"data": {"participants": [user.model_dump() for user in users]},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -223,7 +235,12 @@ class Lobby:
|
|||||||
logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
|
logger.info(f"{s.getName()} -> lobby_state({self.getName()})")
|
||||||
if s.ws:
|
if s.ws:
|
||||||
await s.ws.send_json(
|
await s.ws.send_json(
|
||||||
{"type": "lobby_state", "data": {"participants": users}}
|
{
|
||||||
|
"type": "lobby_state",
|
||||||
|
"data": {
|
||||||
|
"participants": [user.model_dump() for user in users]
|
||||||
|
},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def getSession(self, id: str) -> Session | None:
|
def getSession(self, id: str) -> Session | None:
|
||||||
@ -356,13 +373,46 @@ class Session:
|
|||||||
# rec is a NamePasswordRecord
|
# rec is a NamePasswordRecord
|
||||||
name_passwords[name] = {"salt": rec.salt, "hash": rec.hash}
|
name_passwords[name] = {"salt": rec.salt, "hash": rec.hash}
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
one_minute = 60.0
|
||||||
|
three_hours = 3 * 60 * 60.0
|
||||||
|
sessions_loaded = 0
|
||||||
|
sessions_expired = 0
|
||||||
|
|
||||||
for s_saved in payload.sessions:
|
for s_saved in payload.sessions:
|
||||||
|
# Check if this session should be expired during loading
|
||||||
|
created_at = getattr(s_saved, "created_at", time.time())
|
||||||
|
last_used = getattr(s_saved, "last_used", time.time())
|
||||||
|
displaced_at = getattr(s_saved, "displaced_at", None)
|
||||||
|
name = s_saved.name or ""
|
||||||
|
|
||||||
|
# Apply same removal criteria as cleanup_old_sessions
|
||||||
|
should_expire = False
|
||||||
|
|
||||||
|
# Rule 1: Sessions with no name that are older than 1 minute (no connection assumed for disk sessions)
|
||||||
|
if not name and current_time - created_at > one_minute:
|
||||||
|
should_expire = True
|
||||||
|
logger.info(
|
||||||
|
f"Expiring session {s_saved.id[:8]} during load - no name, older than 1 minute"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rule 2: Displaced sessions unused for 3+ hours (no connection assumed for disk sessions)
|
||||||
|
elif displaced_at is not None and current_time - last_used > three_hours:
|
||||||
|
should_expire = True
|
||||||
|
logger.info(
|
||||||
|
f"Expiring session {s_saved.id[:8]}:{name} during load - displaced and unused for 3+ hours"
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_expire:
|
||||||
|
sessions_expired += 1
|
||||||
|
continue # Skip loading this expired session
|
||||||
|
|
||||||
session = Session(s_saved.id)
|
session = Session(s_saved.id)
|
||||||
session.name = s_saved.name or ""
|
session.name = name
|
||||||
# Load timestamps, with defaults for backward compatibility
|
# Load timestamps, with defaults for backward compatibility
|
||||||
session.created_at = getattr(s_saved, "created_at", time.time())
|
session.created_at = created_at
|
||||||
session.last_used = getattr(s_saved, "last_used", time.time())
|
session.last_used = last_used
|
||||||
session.displaced_at = getattr(s_saved, "displaced_at", None)
|
session.displaced_at = displaced_at
|
||||||
for lobby_saved in s_saved.lobbies:
|
for lobby_saved in s_saved.lobbies:
|
||||||
session.lobbies.append(
|
session.lobbies.append(
|
||||||
Lobby(
|
Lobby(
|
||||||
@ -378,10 +428,15 @@ class Session:
|
|||||||
lobbies[lobby.id] = Lobby(
|
lobbies[lobby.id] = Lobby(
|
||||||
name=lobby.name, id=lobby.id
|
name=lobby.name, id=lobby.id
|
||||||
) # Ensure lobby exists
|
) # Ensure lobby exists
|
||||||
|
sessions_loaded += 1
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded {len(payload.sessions)} sessions and {len(name_passwords)} name passwords from {cls._save_file}"
|
f"Loaded {sessions_loaded} sessions and {len(name_passwords)} name passwords from {cls._save_file}"
|
||||||
)
|
)
|
||||||
|
if sessions_expired > 0:
|
||||||
|
logger.info(f"Expired {sessions_expired} old sessions during load")
|
||||||
|
# Save immediately to persist the cleanup
|
||||||
|
cls.save()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def getSession(cls, id: str) -> Session | None:
|
def getSession(cls, id: str) -> Session | None:
|
||||||
@ -440,7 +495,7 @@ class Session:
|
|||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
one_minute = 60.0
|
one_minute = 60.0
|
||||||
twenty_four_hours = 24 * 60 * 60.0
|
three_hours = 3 * 60 * 60.0
|
||||||
sessions_removed = 0
|
sessions_removed = 0
|
||||||
|
|
||||||
# Make a copy of the list to avoid modifying it while iterating
|
# Make a copy of the list to avoid modifying it while iterating
|
||||||
@ -459,14 +514,14 @@ class Session:
|
|||||||
sessions_to_remove.append(session)
|
sessions_to_remove.append(session)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Rule 2: Delete inactive sessions that had their nick taken over and haven't been used in 24 hours
|
# Rule 2: Delete inactive sessions that had their nick taken over and haven't been used in 3 hours
|
||||||
if (
|
if (
|
||||||
not session.ws
|
not session.ws
|
||||||
and session.displaced_at is not None
|
and session.displaced_at is not None
|
||||||
and current_time - session.last_used > twenty_four_hours
|
and current_time - session.last_used > three_hours
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Removing session {session.getName()} - displaced and unused for 24+ hours"
|
f"Removing session {session.getName()} - displaced and unused for 3+ hours"
|
||||||
)
|
)
|
||||||
sessions_to_remove.append(session)
|
sessions_to_remove.append(session)
|
||||||
continue
|
continue
|
||||||
@ -493,10 +548,23 @@ class Session:
|
|||||||
cls._instances.remove(session)
|
cls._instances.remove(session)
|
||||||
sessions_removed += 1
|
sessions_removed += 1
|
||||||
|
|
||||||
|
# Clean up empty lobbies from global lobbies dict
|
||||||
|
empty_lobbies: list[str] = []
|
||||||
|
for lobby_id, lobby in lobbies.items():
|
||||||
|
if len(lobby.sessions) == 0:
|
||||||
|
empty_lobbies.append(lobby_id)
|
||||||
|
|
||||||
|
for lobby_id in empty_lobbies:
|
||||||
|
del lobbies[lobby_id]
|
||||||
|
logger.info(f"Removed empty lobby {lobby_id}")
|
||||||
|
|
||||||
if sessions_removed > 0:
|
if sessions_removed > 0:
|
||||||
cls.save()
|
cls.save()
|
||||||
logger.info(f"Session cleanup: removed {sessions_removed} old sessions")
|
logger.info(f"Session cleanup: removed {sessions_removed} old sessions")
|
||||||
|
|
||||||
|
if empty_lobbies:
|
||||||
|
logger.info(f"Session cleanup: removed {len(empty_lobbies)} empty lobbies")
|
||||||
|
|
||||||
return sessions_removed
|
return sessions_removed
|
||||||
|
|
||||||
async def join(self, lobby: Lobby):
|
async def join(self, lobby: Lobby):
|
||||||
@ -784,6 +852,175 @@ async def get_chat_messages(
|
|||||||
messages = lobby.get_chat_messages(limit)
|
messages = lobby.get_chat_messages(limit)
|
||||||
|
|
||||||
return ChatMessagesResponse(messages=messages)
|
return ChatMessagesResponse(messages=messages)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Bot Provider API Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(
|
||||||
|
public_url + "api/bots/providers/register",
|
||||||
|
response_model=BotProviderRegisterResponse,
|
||||||
|
)
|
||||||
|
async def register_bot_provider(
|
||||||
|
request: BotProviderRegisterRequest,
|
||||||
|
) -> BotProviderRegisterResponse:
|
||||||
|
"""Register a new bot provider"""
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
provider_id = str(uuid.uuid4())
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
provider = BotProviderModel(
|
||||||
|
provider_id=provider_id,
|
||||||
|
base_url=request.base_url.rstrip("/"),
|
||||||
|
name=request.name,
|
||||||
|
description=request.description,
|
||||||
|
registered_at=now,
|
||||||
|
last_seen=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
bot_providers[provider_id] = provider
|
||||||
|
logger.info(f"Registered bot provider: {request.name} at {request.base_url}")
|
||||||
|
|
||||||
|
return BotProviderRegisterResponse(provider_id=provider_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(public_url + "api/bots/providers", response_model=BotProviderListResponse)
|
||||||
|
async def list_bot_providers() -> BotProviderListResponse:
|
||||||
|
"""List all registered bot providers"""
|
||||||
|
return BotProviderListResponse(providers=list(bot_providers.values()))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get(public_url + "api/bots", response_model=BotListResponse)
|
||||||
|
async def list_available_bots() -> BotListResponse:
|
||||||
|
"""List all available bots from all registered providers"""
|
||||||
|
bots: dict[str, BotInfoModel] = {}
|
||||||
|
providers: dict[str, str] = {}
|
||||||
|
|
||||||
|
# Update last_seen timestamps and fetch bots from each provider
|
||||||
|
for provider_id, provider in bot_providers.items():
|
||||||
|
try:
|
||||||
|
import time
|
||||||
|
|
||||||
|
provider.last_seen = time.time()
|
||||||
|
|
||||||
|
# Make HTTP request to provider's /bots endpoint
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(f"{provider.base_url}/bots", timeout=5.0)
|
||||||
|
if response.status_code == 200:
|
||||||
|
provider_bots = response.json()
|
||||||
|
# provider_bots should be a dict of bot_name -> bot_info
|
||||||
|
for bot_name, bot_info in provider_bots.items():
|
||||||
|
bots[bot_name] = BotInfoModel(**bot_info)
|
||||||
|
providers[bot_name] = provider_id
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to fetch bots from provider {provider.name}: HTTP {response.status_code}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching bots from provider {provider.name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return BotListResponse(bots=bots, providers=providers)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(public_url + "api/bots/{bot_name}/join", response_model=BotJoinLobbyResponse)
|
||||||
|
async def request_bot_join_lobby(
|
||||||
|
bot_name: str, request: BotJoinLobbyRequest
|
||||||
|
) -> BotJoinLobbyResponse:
|
||||||
|
"""Request a bot to join a specific lobby"""
|
||||||
|
|
||||||
|
# Find which provider has this bot
|
||||||
|
target_provider_id = request.provider_id
|
||||||
|
if not target_provider_id:
|
||||||
|
# Auto-discover provider for this bot
|
||||||
|
for provider_id, provider in bot_providers.items():
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{provider.base_url}/bots", timeout=5.0
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
provider_bots = response.json()
|
||||||
|
if bot_name in provider_bots:
|
||||||
|
target_provider_id = provider_id
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not target_provider_id or target_provider_id not in bot_providers:
|
||||||
|
raise HTTPException(status_code=404, detail="Bot or provider not found")
|
||||||
|
|
||||||
|
provider = bot_providers[target_provider_id]
|
||||||
|
|
||||||
|
# Get the lobby to validate it exists
|
||||||
|
try:
|
||||||
|
getLobby(request.lobby_id) # Just validate it exists
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=404, detail="Lobby not found")
|
||||||
|
|
||||||
|
# Create a session for the bot
|
||||||
|
bot_session_id = secrets.token_hex(16)
|
||||||
|
|
||||||
|
# Determine server URL for the bot to connect back to
|
||||||
|
# Use the server's public URL or construct from request
|
||||||
|
server_base_url = os.getenv("PUBLIC_SERVER_URL", "http://localhost:8000")
|
||||||
|
if server_base_url.endswith("/"):
|
||||||
|
server_base_url = server_base_url[:-1]
|
||||||
|
|
||||||
|
bot_nick = request.nick or f"{bot_name}-bot"
|
||||||
|
|
||||||
|
# Prepare the join request for the bot provider
|
||||||
|
bot_join_payload = BotJoinPayload(
|
||||||
|
lobby_id=request.lobby_id,
|
||||||
|
session_id=bot_session_id,
|
||||||
|
nick=bot_nick,
|
||||||
|
server_url=f"{server_base_url}{public_url}".rstrip("/"),
|
||||||
|
insecure=False, # Assume secure by default
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Make request to bot provider
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{provider.base_url}/bots/{bot_name}/join",
|
||||||
|
json=bot_join_payload.model_dump(),
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
run_id = result.get("run_id", "unknown")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Bot {bot_name} requested to join lobby {request.lobby_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return BotJoinLobbyResponse(
|
||||||
|
status="requested",
|
||||||
|
bot_name=bot_name,
|
||||||
|
run_id=run_id,
|
||||||
|
provider_id=target_provider_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Bot provider returned error: HTTP {response.status_code}: {response.text}"
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail=f"Bot provider error: {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise HTTPException(status_code=504, detail="Bot provider timeout")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error requesting bot join: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
# Register websocket endpoint directly on app with full public_url path
|
# Register websocket endpoint directly on app with full public_url path
|
||||||
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
|
@app.websocket(f"{public_url}" + "ws/lobby/{lobby_id}/{session_id}")
|
||||||
async def lobby_join(
|
async def lobby_join(
|
||||||
|
@ -41,7 +41,8 @@ class ParticipantModel(BaseModel):
|
|||||||
"""Represents a participant in a lobby/session"""
|
"""Represents a participant in a lobby/session"""
|
||||||
name: str
|
name: str
|
||||||
session_id: str
|
session_id: str
|
||||||
# Add other participant fields as needed based on actual data structure
|
live: bool
|
||||||
|
protected: bool
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -75,6 +76,15 @@ class HealthResponse(BaseModel):
|
|||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
class ClientStatusResponse(BaseModel):
|
||||||
|
"""Client status response"""
|
||||||
|
|
||||||
|
client_running: bool
|
||||||
|
session_name: str
|
||||||
|
lobby: str
|
||||||
|
server_url: str
|
||||||
|
|
||||||
|
|
||||||
class LobbyListItem(BaseModel):
|
class LobbyListItem(BaseModel):
|
||||||
"""Lobby item for list responses"""
|
"""Lobby item for list responses"""
|
||||||
id: str
|
id: str
|
||||||
@ -268,3 +278,82 @@ class SessionsPayload(BaseModel):
|
|||||||
"""Complete sessions data for persistence"""
|
"""Complete sessions data for persistence"""
|
||||||
sessions: List[SessionSaved] = []
|
sessions: List[SessionSaved] = []
|
||||||
name_passwords: Dict[str, NamePasswordRecord] = {}
|
name_passwords: Dict[str, NamePasswordRecord] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Bot Provider Models
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class BotInfoModel(BaseModel):
|
||||||
|
"""Information about a specific bot"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class BotProviderModel(BaseModel):
|
||||||
|
"""Bot provider registration information"""
|
||||||
|
|
||||||
|
provider_id: str
|
||||||
|
base_url: str
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
registered_at: float
|
||||||
|
last_seen: float
|
||||||
|
|
||||||
|
|
||||||
|
class BotProviderRegisterRequest(BaseModel):
|
||||||
|
"""Request to register a bot provider"""
|
||||||
|
|
||||||
|
base_url: str
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class BotProviderRegisterResponse(BaseModel):
|
||||||
|
"""Response after registering a bot provider"""
|
||||||
|
|
||||||
|
provider_id: str
|
||||||
|
status: str = "registered"
|
||||||
|
|
||||||
|
|
||||||
|
class BotProviderListResponse(BaseModel):
|
||||||
|
"""Response listing all registered bot providers"""
|
||||||
|
|
||||||
|
providers: List[BotProviderModel]
|
||||||
|
|
||||||
|
|
||||||
|
class BotListResponse(BaseModel):
|
||||||
|
"""Response listing all available bots from all providers"""
|
||||||
|
|
||||||
|
bots: Dict[str, BotInfoModel] # bot_name -> bot_info
|
||||||
|
providers: Dict[str, str] # bot_name -> provider_id
|
||||||
|
|
||||||
|
|
||||||
|
class BotJoinLobbyRequest(BaseModel):
|
||||||
|
"""Request to make a bot join a lobby"""
|
||||||
|
|
||||||
|
bot_name: str
|
||||||
|
lobby_id: str
|
||||||
|
nick: str = ""
|
||||||
|
provider_id: Optional[str] = None # Optional: specify which provider to use
|
||||||
|
|
||||||
|
|
||||||
|
class BotJoinPayload(BaseModel):
|
||||||
|
"""Payload sent to bot provider to make a bot join a lobby"""
|
||||||
|
|
||||||
|
lobby_id: str
|
||||||
|
session_id: str
|
||||||
|
nick: str
|
||||||
|
server_url: str
|
||||||
|
insecure: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class BotJoinLobbyResponse(BaseModel):
|
||||||
|
"""Response after requesting a bot to join a lobby"""
|
||||||
|
|
||||||
|
status: str
|
||||||
|
bot_name: str
|
||||||
|
run_id: str
|
||||||
|
provider_id: str
|
||||||
|
@ -0,0 +1,302 @@
|
|||||||
|
# AI Voicebot
|
||||||
|
|
||||||
|
A WebRTC-enabled AI voicebot system with speech recognition and synthetic media capabilities. The voicebot can run in two modes: as a client connecting to lobbies or as a provider serving bots to other applications.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Speech Recognition**: Uses Whisper models for real-time audio transcription
|
||||||
|
- **Synthetic Media**: Generates animated video and audio tracks
|
||||||
|
- **WebRTC Integration**: Real-time peer-to-peer communication
|
||||||
|
- **Bot Provider System**: Can register with a main server to provide bot services
|
||||||
|
- **Flexible Deployment**: Docker-based with development and production modes
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Docker and Docker Compose
|
||||||
|
- Python 3.12+ (if running locally)
|
||||||
|
- Access to a compatible signaling server
|
||||||
|
|
||||||
|
### Running with Docker
|
||||||
|
|
||||||
|
#### 1. Bot Provider Mode (Recommended)
|
||||||
|
|
||||||
|
Run the voicebot as a bot provider that registers with the main server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Development mode with auto-reload
|
||||||
|
VOICEBOT_MODE=provider PRODUCTION=false docker-compose up voicebot
|
||||||
|
|
||||||
|
# Production mode
|
||||||
|
VOICEBOT_MODE=provider PRODUCTION=true docker-compose up voicebot
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Direct Client Mode
|
||||||
|
|
||||||
|
Run the voicebot as a direct client connecting to a lobby:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Development mode
|
||||||
|
VOICEBOT_MODE=client PRODUCTION=false docker-compose up voicebot
|
||||||
|
|
||||||
|
# Production mode
|
||||||
|
VOICEBOT_MODE=client PRODUCTION=true docker-compose up voicebot
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running Locally
|
||||||
|
|
||||||
|
#### 1. Setup Environment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd voicebot/
|
||||||
|
|
||||||
|
# Create virtual environment
|
||||||
|
uv init --python /usr/bin/python3.12 --name "ai-voicebot-agent"
|
||||||
|
uv add -r requirements.txt
|
||||||
|
|
||||||
|
# Activate environment
|
||||||
|
source .venv/bin/activate
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Bot Provider Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Development with auto-reload
|
||||||
|
python main.py --mode provider --server-url https://your-server.com/ai-voicebot --reload --insecure
|
||||||
|
|
||||||
|
# Production
|
||||||
|
python main.py --mode provider --server-url https://your-server.com/ai-voicebot
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Direct Client Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py --mode client \
|
||||||
|
--server-url https://your-server.com/ai-voicebot \
|
||||||
|
--lobby "my-lobby" \
|
||||||
|
--session-name "My Bot" \
|
||||||
|
--insecure
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
| Variable | Description | Default | Example |
|
||||||
|
|----------|-------------|---------|---------|
|
||||||
|
| `VOICEBOT_MODE` | Operating mode: `client` or `provider` | `client` | `provider` |
|
||||||
|
| `PRODUCTION` | Production mode flag | `false` | `true` |
|
||||||
|
|
||||||
|
### Command Line Arguments
|
||||||
|
|
||||||
|
#### Common Arguments
|
||||||
|
- `--mode`: Run as `client` or `provider`
|
||||||
|
- `--server-url`: Main server URL
|
||||||
|
- `--insecure`: Allow insecure SSL connections
|
||||||
|
- `--help`: Show all available options
|
||||||
|
|
||||||
|
#### Provider Mode Arguments
|
||||||
|
- `--host`: Host to bind the provider server (default: `0.0.0.0`)
|
||||||
|
- `--port`: Port for the provider server (default: `8788`)
|
||||||
|
- `--reload`: Enable auto-reload for development
|
||||||
|
|
||||||
|
#### Client Mode Arguments
|
||||||
|
- `--lobby`: Lobby name to join (default: `default`)
|
||||||
|
- `--session-name`: Display name for the bot (default: `Python Bot`)
|
||||||
|
- `--session-id`: Existing session ID to reuse
|
||||||
|
- `--password`: Password for protected names
|
||||||
|
- `--private`: Create/join private lobby
|
||||||
|
|
||||||
|
## Available Bots
|
||||||
|
|
||||||
|
The voicebot system includes the following bot types:
|
||||||
|
|
||||||
|
### 1. Whisper Bot
|
||||||
|
- **Name**: `whisper`
|
||||||
|
- **Description**: Speech recognition agent using OpenAI Whisper models
|
||||||
|
- **Capabilities**: Real-time audio transcription, multiple language support
|
||||||
|
- **Models**: Supports various Whisper and Distil-Whisper models
|
||||||
|
|
||||||
|
### 2. Synthetic Media Bot
|
||||||
|
- **Name**: `synthetic_media`
|
||||||
|
- **Description**: Generates animated video and audio tracks
|
||||||
|
- **Capabilities**: Animated video generation, synthetic audio, edge detection on incoming video
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Bot Provider System
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐
|
||||||
|
│ Main Server │ │ Bot Provider │ │ Client App │
|
||||||
|
│ │◄───┤ (Voicebot) │ │ │
|
||||||
|
│ - Bot Registry │ │ - Whisper Bot │ │ - Bot Manager │
|
||||||
|
│ - Lobby Management │ - Synthetic Bot │ │ - UI Controls │
|
||||||
|
│ - API Endpoints │ │ - API Server │ │ - Lobby View │
|
||||||
|
└─────────────────┘ └──────────────────┘ └─────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Flow
|
||||||
|
1. Voicebot registers as bot provider with main server
|
||||||
|
2. Main server discovers available bots from providers
|
||||||
|
3. Client requests bot to join lobby via main server
|
||||||
|
4. Main server forwards request to appropriate provider
|
||||||
|
5. Provider creates bot instance that connects to the lobby
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Auto-Reload
|
||||||
|
|
||||||
|
In development mode, the bot provider supports auto-reload using uvicorn:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Watches /voicebot and /shared directories for changes
|
||||||
|
python main.py --mode provider --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
### Adding New Bots
|
||||||
|
|
||||||
|
1. Create a new module in `voicebot/bots/`
|
||||||
|
2. Implement required functions:
|
||||||
|
```python
|
||||||
|
def agent_info() -> dict:
|
||||||
|
return {"name": "my_bot", "description": "My custom bot"}
|
||||||
|
|
||||||
|
def create_agent_tracks(session_name: str) -> dict:
|
||||||
|
# Return MediaStreamTrack instances
|
||||||
|
return {"audio": my_audio_track, "video": my_video_track}
|
||||||
|
```
|
||||||
|
3. The bot will be automatically discovered and available
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Test bot discovery
|
||||||
|
python test_bot_api.py
|
||||||
|
|
||||||
|
# Test client connection
|
||||||
|
python main.py --mode client --lobby test --session-name "Test Bot"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Production Deployment
|
||||||
|
|
||||||
|
### Docker Compose
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
version: '3.8'
|
||||||
|
services:
|
||||||
|
voicebot-provider:
|
||||||
|
build: .
|
||||||
|
environment:
|
||||||
|
- VOICEBOT_MODE=provider
|
||||||
|
- PRODUCTION=true
|
||||||
|
ports:
|
||||||
|
- "8788:8788"
|
||||||
|
volumes:
|
||||||
|
- ./cache:/voicebot/cache
|
||||||
|
```
|
||||||
|
|
||||||
|
### Kubernetes
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
name: voicebot-provider
|
||||||
|
spec:
|
||||||
|
replicas: 1
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app: voicebot-provider
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
app: voicebot-provider
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: voicebot
|
||||||
|
image: ai-voicebot:latest
|
||||||
|
env:
|
||||||
|
- name: VOICEBOT_MODE
|
||||||
|
value: "provider"
|
||||||
|
- name: PRODUCTION
|
||||||
|
value: "true"
|
||||||
|
ports:
|
||||||
|
- containerPort: 8788
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
### Bot Provider Endpoints
|
||||||
|
|
||||||
|
The voicebot provider exposes the following HTTP API:
|
||||||
|
|
||||||
|
- `GET /bots` - List available bots
|
||||||
|
- `POST /bots/{bot_name}/join` - Request bot to join lobby
|
||||||
|
- `GET /bots/runs` - List active bot instances
|
||||||
|
- `POST /bots/runs/{run_id}/stop` - Stop a bot instance
|
||||||
|
|
||||||
|
### Example API Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List available bots
|
||||||
|
curl http://localhost:8788/bots
|
||||||
|
|
||||||
|
# Request whisper bot to join lobby
|
||||||
|
curl -X POST http://localhost:8788/bots/whisper/join \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"lobby_id": "lobby-123",
|
||||||
|
"session_id": "session-456",
|
||||||
|
"nick": "Speech Bot",
|
||||||
|
"server_url": "https://server.com/ai-voicebot"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
**Bot provider not registering:**
|
||||||
|
- Check server URL is correct and accessible
|
||||||
|
- Verify network connectivity between provider and server
|
||||||
|
- Check logs for registration errors
|
||||||
|
|
||||||
|
**Auto-reload not working:**
|
||||||
|
- Ensure `--reload` flag is used in development
|
||||||
|
- Check file permissions on watched directories
|
||||||
|
- Verify uvicorn version supports reload functionality
|
||||||
|
|
||||||
|
**WebRTC connection issues:**
|
||||||
|
- Check STUN/TURN server configuration
|
||||||
|
- Verify network ports are not blocked
|
||||||
|
- Check browser console for ICE connection errors
|
||||||
|
|
||||||
|
### Logs
|
||||||
|
|
||||||
|
Logs are written to stdout and include:
|
||||||
|
- Bot registration status
|
||||||
|
- WebRTC connection events
|
||||||
|
- Media track creation/destruction
|
||||||
|
- API request/response details
|
||||||
|
|
||||||
|
### Debug Mode
|
||||||
|
|
||||||
|
Enable verbose logging:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py --mode provider --server-url https://server.com --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
1. Fork the repository
|
||||||
|
2. Create a feature branch
|
||||||
|
3. Make your changes
|
||||||
|
4. Add tests for new functionality
|
||||||
|
5. Submit a pull request
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is licensed under the MIT License - see the LICENSE file for details.
|
@ -1,3 +1,6 @@
|
|||||||
"""Bots package for discoverable agent modules."""
|
"""Bots package for discoverable agent modules."""
|
||||||
|
|
||||||
|
from . import synthetic_media
|
||||||
|
from . import whisper
|
||||||
|
|
||||||
__all__ = ["synthetic_media", "whisper"]
|
__all__ = ["synthetic_media", "whisper"]
|
||||||
|
@ -18,21 +18,58 @@ fi
|
|||||||
export VIRTUAL_ENV=/voicebot/.venv
|
export VIRTUAL_ENV=/voicebot/.venv
|
||||||
export PATH="$VIRTUAL_ENV/bin:$PATH"
|
export PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||||
|
|
||||||
|
# Determine mode - provider or client
|
||||||
|
MODE="${VOICEBOT_MODE:-client}"
|
||||||
|
|
||||||
# Launch voicebot in production or development mode
|
# Launch voicebot in production or development mode
|
||||||
if [ "$PRODUCTION" != "true" ]; then
|
if [ "$PRODUCTION" != "true" ]; then
|
||||||
echo "Starting voicebot in development mode with auto-reload..."
|
echo "Starting voicebot in development mode..."
|
||||||
# Fix: Use single --watch argument with multiple paths instead of multiple --watch arguments
|
if [ "$MODE" = "provider" ]; then
|
||||||
python3 -u scripts/reload_runner.py --delay-restart 3 --watch . /shared --verbose --interval 0.5 -- uv run main.py \
|
echo "Running as bot provider with auto-reload..."
|
||||||
--insecure \
|
export VOICEBOT_MODE=provider
|
||||||
--server-url https://ketrenos.com/ai-voicebot \
|
exec uv run uvicorn main:uvicorn_app \
|
||||||
--lobby default \
|
--host 0.0.0.0 \
|
||||||
--session-name "Python Voicebot" \
|
--port 8788 \
|
||||||
--password "v01c3b0t"
|
--reload \
|
||||||
|
--reload-dir /voicebot \
|
||||||
|
--reload-dir /shared \
|
||||||
|
--log-level info
|
||||||
|
else
|
||||||
|
echo "Running as client (connecting to lobby)..."
|
||||||
|
export VOICEBOT_MODE=client
|
||||||
|
export VOICEBOT_SERVER_URL="https://ketrenos.com/ai-voicebot"
|
||||||
|
export VOICEBOT_LOBBY="default"
|
||||||
|
export VOICEBOT_SESSION_NAME="Python Voicebot"
|
||||||
|
export VOICEBOT_PASSWORD="v01c3b0t"
|
||||||
|
export VOICEBOT_INSECURE="true"
|
||||||
|
exec uv run uvicorn main:uvicorn_app \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 8789 \
|
||||||
|
--reload \
|
||||||
|
--reload-dir /voicebot \
|
||||||
|
--reload-dir /shared \
|
||||||
|
--log-level info
|
||||||
|
fi
|
||||||
else
|
else
|
||||||
echo "Starting voicebot in production mode..."
|
echo "Starting voicebot in production mode..."
|
||||||
exec uv run main.py \
|
if [ "$MODE" = "provider" ]; then
|
||||||
--server-url https://ai-voicebot.ketrenos.com \
|
echo "Running as bot provider..."
|
||||||
--lobby default \
|
export VOICEBOT_MODE=provider
|
||||||
--session-name "Python Voicebot" \
|
exec uv run uvicorn main:uvicorn_app \
|
||||||
--password "v01c3b0t"
|
--host 0.0.0.0 \
|
||||||
|
--port 8788 \
|
||||||
|
--log-level info
|
||||||
|
else
|
||||||
|
echo "Running as client (connecting to lobby)..."
|
||||||
|
export VOICEBOT_MODE=client
|
||||||
|
export VOICEBOT_SERVER_URL="https://ai-voicebot.ketrenos.com"
|
||||||
|
export VOICEBOT_LOBBY="default"
|
||||||
|
export VOICEBOT_SESSION_NAME="Python Voicebot"
|
||||||
|
export VOICEBOT_PASSWORD="v01c3b0t"
|
||||||
|
export VOICEBOT_INSECURE="false"
|
||||||
|
exec uv run uvicorn main:uvicorn_app \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 8789 \
|
||||||
|
--log-level info
|
||||||
|
fi
|
||||||
fi
|
fi
|
||||||
|
347
voicebot/main.py
347
voicebot/main.py
@ -50,6 +50,7 @@ from shared.models import (
|
|||||||
IceCandidateModel,
|
IceCandidateModel,
|
||||||
ICECandidateDictModel,
|
ICECandidateDictModel,
|
||||||
SessionDescriptionTypedModel,
|
SessionDescriptionTypedModel,
|
||||||
|
ClientStatusResponse,
|
||||||
)
|
)
|
||||||
from aiortc import (
|
from aiortc import (
|
||||||
RTCPeerConnection,
|
RTCPeerConnection,
|
||||||
@ -60,6 +61,76 @@ from aiortc import (
|
|||||||
from logger import logger
|
from logger import logger
|
||||||
from voicebot.bots.synthetic_media import create_synthetic_tracks, AnimatedVideoTrack
|
from voicebot.bots.synthetic_media import create_synthetic_tracks, AnimatedVideoTrack
|
||||||
|
|
||||||
|
# Pydantic model for voicebot arguments
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class VoicebotMode(str, Enum):
|
||||||
|
"""Voicebot operation modes."""
|
||||||
|
CLIENT = "client"
|
||||||
|
PROVIDER = "provider"
|
||||||
|
|
||||||
|
class VoicebotArgs(BaseModel):
|
||||||
|
"""Pydantic model for voicebot CLI arguments and configuration."""
|
||||||
|
|
||||||
|
# Mode selection
|
||||||
|
mode: VoicebotMode = Field(default=VoicebotMode.CLIENT, description="Run as client (connect to lobby) or provider (serve bots)")
|
||||||
|
|
||||||
|
# Provider mode arguments
|
||||||
|
host: str = Field(default="0.0.0.0", description="Host for provider mode")
|
||||||
|
port: int = Field(default=8788, description="Port for provider mode", ge=1, le=65535)
|
||||||
|
reload: bool = Field(default=False, description="Enable auto-reload for development")
|
||||||
|
|
||||||
|
# Client mode arguments
|
||||||
|
server_url: str = Field(
|
||||||
|
default="http://localhost:8000/ai-voicebot",
|
||||||
|
description="AI-Voicebot lobby and signaling server base URL (http:// or https://)"
|
||||||
|
)
|
||||||
|
lobby: str = Field(default="default", description="Lobby name to create or join")
|
||||||
|
session_name: str = Field(default="Python Bot", description="Session (user) display name")
|
||||||
|
session_id: Optional[str] = Field(default=None, description="Optional existing session id to reuse")
|
||||||
|
password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
|
||||||
|
private: bool = Field(default=False, description="Create the lobby as private")
|
||||||
|
insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_environment(cls) -> 'VoicebotArgs':
|
||||||
|
"""Create VoicebotArgs from environment variables."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
mode_str = os.getenv('VOICEBOT_MODE', 'client')
|
||||||
|
return cls(
|
||||||
|
mode=VoicebotMode(mode_str),
|
||||||
|
host=os.getenv('VOICEBOT_HOST', '0.0.0.0'),
|
||||||
|
port=int(os.getenv('VOICEBOT_PORT', '8788')),
|
||||||
|
reload=os.getenv('VOICEBOT_RELOAD', 'false').lower() == 'true',
|
||||||
|
server_url=os.getenv('VOICEBOT_SERVER_URL', 'http://localhost:8000/ai-voicebot'),
|
||||||
|
lobby=os.getenv('VOICEBOT_LOBBY', 'default'),
|
||||||
|
session_name=os.getenv('VOICEBOT_SESSION_NAME', 'Python Bot'),
|
||||||
|
session_id=os.getenv('VOICEBOT_SESSION_ID', None),
|
||||||
|
password=os.getenv('VOICEBOT_PASSWORD', None),
|
||||||
|
private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true',
|
||||||
|
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true'
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_argparse(cls, args: 'argparse.Namespace') -> 'VoicebotArgs':
|
||||||
|
"""Create VoicebotArgs from argparse Namespace."""
|
||||||
|
mode_str = getattr(args, 'mode', 'client')
|
||||||
|
return cls(
|
||||||
|
mode=VoicebotMode(mode_str),
|
||||||
|
host=getattr(args, 'host', '0.0.0.0'),
|
||||||
|
port=getattr(args, 'port', 8788),
|
||||||
|
reload=getattr(args, 'reload', False),
|
||||||
|
server_url=getattr(args, 'server_url', 'http://localhost:8000/ai-voicebot'),
|
||||||
|
lobby=getattr(args, 'lobby', 'default'),
|
||||||
|
session_name=getattr(args, 'session_name', 'Python Bot'),
|
||||||
|
session_id=getattr(args, 'session_id', None),
|
||||||
|
password=getattr(args, 'password', None),
|
||||||
|
private=getattr(args, 'private', False),
|
||||||
|
insecure=getattr(args, 'insecure', False)
|
||||||
|
)
|
||||||
|
|
||||||
# Bot orchestration imports
|
# Bot orchestration imports
|
||||||
import importlib
|
import importlib
|
||||||
import pkgutil
|
import pkgutil
|
||||||
@ -1049,6 +1120,14 @@ async def main():
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Convert argparse Namespace to VoicebotArgs for type safety
|
||||||
|
voicebot_args = VoicebotArgs.from_argparse(args)
|
||||||
|
await main_with_args(voicebot_args)
|
||||||
|
|
||||||
|
|
||||||
|
async def main_with_args(args: VoicebotArgs):
|
||||||
|
"""Main voicebot client logic that accepts arguments object."""
|
||||||
|
|
||||||
# Resolve session id (create if needed)
|
# Resolve session id (create if needed)
|
||||||
try:
|
try:
|
||||||
session_id = create_or_get_session(
|
session_id = create_or_get_session(
|
||||||
@ -1131,6 +1210,112 @@ async def main():
|
|||||||
|
|
||||||
app = FastAPI(title="voicebot-bot-orchestrator")
|
app = FastAPI(title="voicebot-bot-orchestrator")
|
||||||
|
|
||||||
|
# Global client app instance for uvicorn import
|
||||||
|
client_app = None
|
||||||
|
|
||||||
|
# Global client arguments storage
|
||||||
|
_client_args: Optional[VoicebotArgs] = None
|
||||||
|
|
||||||
|
def create_client_app(args: VoicebotArgs):
|
||||||
|
"""Create a FastAPI app for client mode that uvicorn can import."""
|
||||||
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import os
|
||||||
|
import fcntl
|
||||||
|
|
||||||
|
global _client_args
|
||||||
|
_client_args = args
|
||||||
|
|
||||||
|
# Store the client task globally so we can manage it
|
||||||
|
client_task = None
|
||||||
|
lock_file = None
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
nonlocal client_task, lock_file
|
||||||
|
# Startup
|
||||||
|
# Use a file lock to prevent multiple instances from starting
|
||||||
|
lock_file_path = "/tmp/voicebot_client.lock"
|
||||||
|
|
||||||
|
try:
|
||||||
|
lock_file = open(lock_file_path, 'w')
|
||||||
|
# Try to acquire an exclusive lock (non-blocking)
|
||||||
|
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||||
|
|
||||||
|
if _client_args is None:
|
||||||
|
logger.error("Client args not initialized")
|
||||||
|
if lock_file:
|
||||||
|
lock_file.close()
|
||||||
|
lock_file = None
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Starting voicebot client...")
|
||||||
|
client_task = asyncio.create_task(main_with_args(_client_args))
|
||||||
|
|
||||||
|
except (IOError, OSError):
|
||||||
|
# Another process already has the lock
|
||||||
|
logger.info("Another instance is already running - skipping client startup")
|
||||||
|
if lock_file:
|
||||||
|
lock_file.close()
|
||||||
|
lock_file = None
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown
|
||||||
|
if client_task and not client_task.done():
|
||||||
|
logger.info("Shutting down voicebot client...")
|
||||||
|
client_task.cancel()
|
||||||
|
try:
|
||||||
|
await client_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if lock_file:
|
||||||
|
try:
|
||||||
|
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||||
|
lock_file.close()
|
||||||
|
os.unlink(lock_file_path)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Create the client FastAPI app
|
||||||
|
app = FastAPI(title="voicebot-client", lifespan=lifespan)
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():# pyright: ignore
|
||||||
|
"""Simple health check endpoint"""
|
||||||
|
return {"status": "running", "mode": "client"}
|
||||||
|
|
||||||
|
@app.get("/status", response_model=ClientStatusResponse)
|
||||||
|
async def client_status() -> ClientStatusResponse:# pyright: ignore
|
||||||
|
"""Get client status"""
|
||||||
|
return ClientStatusResponse(
|
||||||
|
client_running=client_task is not None and not client_task.done(),
|
||||||
|
session_name=_client_args.session_name if _client_args else 'unknown',
|
||||||
|
lobby=_client_args.lobby if _client_args else 'unknown',
|
||||||
|
server_url=_client_args.server_url if _client_args else 'unknown'
|
||||||
|
)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
# Function to get the appropriate app based on environment variable
|
||||||
|
def get_app():
|
||||||
|
"""Get the appropriate FastAPI app based on VOICEBOT_MODE environment variable."""
|
||||||
|
import os
|
||||||
|
mode = os.getenv('VOICEBOT_MODE', 'provider')
|
||||||
|
|
||||||
|
if mode == 'client':
|
||||||
|
# For client mode, we need to create the client app with args from environment
|
||||||
|
args = VoicebotArgs.from_environment()
|
||||||
|
return create_client_app(args)
|
||||||
|
else:
|
||||||
|
# Provider mode - return the main bot orchestration app
|
||||||
|
return app
|
||||||
|
|
||||||
|
# Create app instance for uvicorn import
|
||||||
|
uvicorn_app = get_app()
|
||||||
|
|
||||||
|
|
||||||
class JoinRequest(BaseModel):
|
class JoinRequest(BaseModel):
|
||||||
lobby_id: str
|
lobby_id: str
|
||||||
@ -1251,11 +1436,165 @@ def start_bot_api(host: str = "0.0.0.0", port: int = 8788):
|
|||||||
uvicorn.run(app, host=host, port=port)
|
uvicorn.run(app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
async def register_with_server(server_url: str, voicebot_url: str, insecure: bool = False) -> str:
|
||||||
|
"""Register this voicebot instance as a bot provider with the main server"""
|
||||||
|
try:
|
||||||
|
# Import httpx locally to avoid dependency issues
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"base_url": voicebot_url.rstrip('/'),
|
||||||
|
"name": "voicebot-provider",
|
||||||
|
"description": "AI voicebot provider with speech recognition and synthetic media capabilities"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Prepare SSL context if needed
|
||||||
|
verify = not insecure
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(verify=verify) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{server_url}/api/bots/providers/register",
|
||||||
|
json=payload,
|
||||||
|
timeout=10.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
provider_id = result.get("provider_id")
|
||||||
|
logger.info(f"Successfully registered with server as provider: {provider_id}")
|
||||||
|
return provider_id
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to register with server: HTTP {response.status_code}: {response.text}")
|
||||||
|
raise RuntimeError(f"Registration failed: {response.status_code}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error registering with server: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def start_bot_provider(
|
||||||
|
host: str = "0.0.0.0",
|
||||||
|
port: int = 8788,
|
||||||
|
server_url: str | None = None,
|
||||||
|
insecure: bool = False,
|
||||||
|
reload: bool = False
|
||||||
|
):
|
||||||
|
"""Start the bot provider API server and optionally register with main server"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Start the FastAPI server in a background thread
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# Add reload functionality for development
|
||||||
|
if reload:
|
||||||
|
server_thread = threading.Thread(
|
||||||
|
target=lambda: uvicorn.run(
|
||||||
|
app,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
log_level="info",
|
||||||
|
reload=True,
|
||||||
|
reload_dirs=["/voicebot", "/shared"]
|
||||||
|
),
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
server_thread = threading.Thread(
|
||||||
|
target=lambda: uvicorn.run(app, host=host, port=port, log_level="info"),
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
# If server_url is provided, register with the main server
|
||||||
|
if server_url:
|
||||||
|
# Give the server a moment to start
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# Construct the voicebot URL
|
||||||
|
voicebot_url = f"http://{host}:{port}"
|
||||||
|
if host == "0.0.0.0":
|
||||||
|
# Try to get a better hostname
|
||||||
|
import socket
|
||||||
|
try:
|
||||||
|
hostname = socket.gethostname()
|
||||||
|
voicebot_url = f"http://{hostname}:{port}"
|
||||||
|
except Exception:
|
||||||
|
voicebot_url = f"http://localhost:{port}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(register_with_server(server_url, voicebot_url, insecure))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to register with server: {e}")
|
||||||
|
|
||||||
|
# Keep the main thread alive
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Shutting down bot provider...")
|
||||||
|
|
||||||
|
|
||||||
|
def start_client_with_reload(args: VoicebotArgs):
|
||||||
|
"""Start the client with auto-reload functionality."""
|
||||||
|
global client_app
|
||||||
|
|
||||||
|
logger.info("Creating client app for uvicorn...")
|
||||||
|
client_app = create_client_app(args)
|
||||||
|
|
||||||
|
# Note: This function is called when --reload is specified
|
||||||
|
# The actual uvicorn execution should be handled by the entrypoint script
|
||||||
|
logger.info("Client app created. Uvicorn should be started by entrypoint script.")
|
||||||
|
|
||||||
|
# Fall back to running client directly if not using uvicorn
|
||||||
|
asyncio.run(main_with_args(args))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Install required packages:
|
# Install required packages:
|
||||||
# pip install aiortc websockets opencv-python numpy
|
# pip install aiortc websockets opencv-python numpy
|
||||||
|
|
||||||
asyncio.run(main())
|
import argparse
|
||||||
# test modification
|
|
||||||
# Test comment Mon Sep 1 03:48:19 PM PDT 2025
|
# Check if we're being run as a bot provider or as a client
|
||||||
# Test change at Mon Sep 1 03:52:13 PM PDT 2025
|
parser = argparse.ArgumentParser(description="AI Voicebot - WebRTC client or bot provider")
|
||||||
|
parser.add_argument("--mode", choices=["client", "provider"], default="client",
|
||||||
|
help="Run as client (connect to lobby) or provider (serve bots)")
|
||||||
|
|
||||||
|
# Provider mode arguments
|
||||||
|
parser.add_argument("--host", default="0.0.0.0", help="Host for provider mode")
|
||||||
|
parser.add_argument("--port", type=int, default=8788, help="Port for provider mode")
|
||||||
|
parser.add_argument("--reload", action="store_true",
|
||||||
|
help="Enable auto-reload for development")
|
||||||
|
|
||||||
|
# Client mode arguments
|
||||||
|
parser.add_argument("--server-url",
|
||||||
|
default="http://localhost:8000/ai-voicebot",
|
||||||
|
help="AI-Voicebot lobby and signaling server base URL (for client mode) or provider registration URL (for provider mode)")
|
||||||
|
parser.add_argument("--lobby", default="default", help="Lobby name to create or join (client mode)")
|
||||||
|
parser.add_argument("--session-name", default="Python Bot", help="Session (user) display name (client mode)")
|
||||||
|
parser.add_argument("--session-id", default=None, help="Optional existing session id to reuse (client mode)")
|
||||||
|
parser.add_argument("--password", default=None, help="Optional password to register or takeover a name (client mode)")
|
||||||
|
parser.add_argument("--private", action="store_true", help="Create the lobby as private (client mode)")
|
||||||
|
parser.add_argument("--insecure", action="store_true",
|
||||||
|
help="Allow insecure connections")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Convert argparse Namespace to VoicebotArgs for type safety
|
||||||
|
voicebot_args = VoicebotArgs.from_argparse(args)
|
||||||
|
|
||||||
|
if voicebot_args.mode == VoicebotMode.PROVIDER:
|
||||||
|
start_bot_provider(
|
||||||
|
host=voicebot_args.host,
|
||||||
|
port=voicebot_args.port,
|
||||||
|
server_url=voicebot_args.server_url,
|
||||||
|
insecure=voicebot_args.insecure,
|
||||||
|
reload=voicebot_args.reload
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if voicebot_args.reload:
|
||||||
|
start_client_with_reload(voicebot_args)
|
||||||
|
else:
|
||||||
|
asyncio.run(main_with_args(voicebot_args))
|
||||||
|
|
||||||
|
@ -1,287 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Simple file-watcher that restarts a command when Python source files change.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python scripts/reload_runner.py --watch voicebot -- python voicebot/main.py
|
|
||||||
|
|
||||||
This is intentionally dependency-free so it works in minimal dev environments
|
|
||||||
and inside containers without installing extra packages.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
import signal
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
from types import FrameType
|
|
||||||
|
|
||||||
|
|
||||||
def scan_py_mtimes(paths: List[str]) -> Dict[str, float]:
|
|
||||||
# Directories to skip during scanning
|
|
||||||
SKIP_DIRS = {
|
|
||||||
".venv",
|
|
||||||
"__pycache__",
|
|
||||||
".git",
|
|
||||||
"node_modules",
|
|
||||||
".mypy_cache",
|
|
||||||
".pytest_cache",
|
|
||||||
"build",
|
|
||||||
"dist",
|
|
||||||
}
|
|
||||||
|
|
||||||
mtimes: Dict[str, float] = {}
|
|
||||||
for p in paths:
|
|
||||||
if os.path.isfile(p) and p.endswith('.py'):
|
|
||||||
try:
|
|
||||||
# Use both mtime and ctime to catch more changes in Docker environments
|
|
||||||
stat = os.stat(p)
|
|
||||||
mtimes[p] = max(stat.st_mtime, stat.st_ctime)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
continue
|
|
||||||
|
|
||||||
for root, dirs, files in os.walk(p):
|
|
||||||
# Skip common directories that shouldn't trigger reloads
|
|
||||||
dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
|
|
||||||
|
|
||||||
for f in files:
|
|
||||||
if not f.endswith('.py'):
|
|
||||||
continue
|
|
||||||
fp = os.path.join(root, f)
|
|
||||||
try:
|
|
||||||
# Use both mtime and ctime to catch more changes in Docker environments
|
|
||||||
stat = os.stat(fp)
|
|
||||||
mtimes[fp] = max(stat.st_mtime, stat.st_ctime)
|
|
||||||
except OSError:
|
|
||||||
# file might disappear between walk and stat
|
|
||||||
pass
|
|
||||||
return mtimes
|
|
||||||
|
|
||||||
|
|
||||||
def scan_py_hashes(paths: List[str]) -> Dict[str, str]:
|
|
||||||
"""Fallback method: scan file content hashes for change detection."""
|
|
||||||
# Directories to skip during scanning
|
|
||||||
SKIP_DIRS = {
|
|
||||||
".venv",
|
|
||||||
"__pycache__",
|
|
||||||
".git",
|
|
||||||
"node_modules",
|
|
||||||
".mypy_cache",
|
|
||||||
".pytest_cache",
|
|
||||||
"build",
|
|
||||||
"dist",
|
|
||||||
}
|
|
||||||
|
|
||||||
hashes: Dict[str, str] = {}
|
|
||||||
for p in paths:
|
|
||||||
if os.path.isfile(p) and p.endswith(".py"):
|
|
||||||
try:
|
|
||||||
with open(p, "rb") as f:
|
|
||||||
content = f.read()
|
|
||||||
hashes[p] = hashlib.md5(content).hexdigest()
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
continue
|
|
||||||
|
|
||||||
for root, dirs, files in os.walk(p):
|
|
||||||
# Skip common directories that shouldn't trigger reloads
|
|
||||||
dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
|
|
||||||
|
|
||||||
for f in files:
|
|
||||||
if not f.endswith(".py"):
|
|
||||||
continue
|
|
||||||
fp = os.path.join(root, f)
|
|
||||||
try:
|
|
||||||
with open(fp, "rb") as file:
|
|
||||||
content = file.read()
|
|
||||||
hashes[fp] = hashlib.md5(content).hexdigest()
|
|
||||||
except OSError:
|
|
||||||
# file might disappear between walk and read
|
|
||||||
pass
|
|
||||||
return hashes
|
|
||||||
|
|
||||||
|
|
||||||
def start_process(cmd: List[str]) -> subprocess.Popen[bytes]:
|
|
||||||
print("Starting:", " ".join(cmd))
|
|
||||||
return subprocess.Popen(cmd)
|
|
||||||
|
|
||||||
|
|
||||||
def terminate_process(p: subprocess.Popen[bytes], timeout: float = 5.0) -> None:
|
|
||||||
if p.poll() is not None:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
p.terminate()
|
|
||||||
waited = 0.0
|
|
||||||
while p.poll() is None and waited < timeout:
|
|
||||||
time.sleep(0.1)
|
|
||||||
waited += 0.1
|
|
||||||
if p.poll() is None:
|
|
||||||
p.kill()
|
|
||||||
except Exception as e:
|
|
||||||
print("Error terminating process:", e)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
|
||||||
parser = argparse.ArgumentParser(description="Restart a command when .py files change")
|
|
||||||
parser.add_argument("--watch", "-w", nargs="+", default=["."], help="Directories or files to watch")
|
|
||||||
parser.add_argument(
|
|
||||||
"--interval", "-i", type=float, default=0.5, help="Polling interval in seconds"
|
|
||||||
)
|
|
||||||
parser.add_argument("--delay-restart", type=float, default=0.1, help="Delay after change before restarting")
|
|
||||||
parser.add_argument("--no-restart-on-exit", action="store_true", help="Don't restart if the process exits on its own")
|
|
||||||
parser.add_argument("--pass-sigterm", action="store_true", help="Forward SIGTERM to child and exit when received")
|
|
||||||
parser.add_argument(
|
|
||||||
"--verbose", "-v", action="store_true", help="Enable verbose logging"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-hash-fallback",
|
|
||||||
action="store_true",
|
|
||||||
help="Use content hashing as fallback for Docker environments",
|
|
||||||
)
|
|
||||||
# Accept the command to run as a positional "remainder" so callers can
|
|
||||||
# separate options with `--` and have everything after it treated as the
|
|
||||||
# command. Defining an option named "--" doesn't work reliably with
|
|
||||||
# argparse; use a positional argument instead.
|
|
||||||
parser.add_argument("cmd", nargs=argparse.REMAINDER, help="Command to run (required)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# args.cmd is the remainder of the command-line. Users typically call this
|
|
||||||
# script like: `reload_runner.py --watch . -- mycmd arg1 arg2`.
|
|
||||||
# argparse will include a literal leading '--' in the remainder list, so
|
|
||||||
# strip it if present.
|
|
||||||
raw_cmd = args.cmd
|
|
||||||
if raw_cmd and raw_cmd[0] == "--":
|
|
||||||
cmd = raw_cmd[1:]
|
|
||||||
else:
|
|
||||||
cmd = raw_cmd
|
|
||||||
|
|
||||||
if not cmd:
|
|
||||||
parser.error("Missing command to run. Put `--` before the command. See help.")
|
|
||||||
|
|
||||||
watch_paths = args.watch
|
|
||||||
|
|
||||||
last_mtimes = scan_py_mtimes(watch_paths)
|
|
||||||
last_hashes = scan_py_hashes(watch_paths) if args.use_hash_fallback else {}
|
|
||||||
|
|
||||||
if args.verbose:
|
|
||||||
print(f"Watching {len(last_mtimes)} Python files in paths: {watch_paths}")
|
|
||||||
print(f"Working directory: {os.getcwd()}")
|
|
||||||
print(f"Resolved watch paths: {[os.path.abspath(p) for p in watch_paths]}")
|
|
||||||
print(f"Polling interval: {args.interval}s")
|
|
||||||
if args.use_hash_fallback:
|
|
||||||
print("Using content hash fallback for change detection")
|
|
||||||
print("Sample files being watched:")
|
|
||||||
for fp in sorted(last_mtimes.keys())[:5]:
|
|
||||||
print(f" {fp}")
|
|
||||||
if len(last_mtimes) > 5:
|
|
||||||
print(f" ... and {len(last_mtimes) - 5} more")
|
|
||||||
|
|
||||||
child = start_process(cmd)
|
|
||||||
|
|
||||||
def handle_sigterm(signum: int, frame: Optional[FrameType]) -> None:
|
|
||||||
if args.pass_sigterm:
|
|
||||||
try:
|
|
||||||
child.send_signal(signum)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
print("Received signal, stopping watcher.")
|
|
||||||
try:
|
|
||||||
terminate_process(child)
|
|
||||||
finally:
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, handle_sigterm)
|
|
||||||
signal.signal(signal.SIGTERM, handle_sigterm)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
# Sleep in small increments so Ctrl-C is responsive
|
|
||||||
time.sleep(args.interval)
|
|
||||||
|
|
||||||
# If the child exited on its own
|
|
||||||
if child.poll() is not None:
|
|
||||||
rc = child.returncode
|
|
||||||
print(f"Process exited with code {rc}.")
|
|
||||||
if args.no_restart_on_exit:
|
|
||||||
return rc
|
|
||||||
# else restart immediately
|
|
||||||
child = start_process(cmd)
|
|
||||||
last_mtimes = scan_py_mtimes(watch_paths)
|
|
||||||
last_hashes = (
|
|
||||||
scan_py_hashes(watch_paths) if args.use_hash_fallback else {}
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check for source changes
|
|
||||||
current = scan_py_mtimes(watch_paths)
|
|
||||||
changed = False
|
|
||||||
change_reason = ""
|
|
||||||
|
|
||||||
# Check for new or changed files
|
|
||||||
for fp, m in current.items():
|
|
||||||
if fp not in last_mtimes or last_mtimes.get(fp) != m:
|
|
||||||
print("Detected change in:", fp)
|
|
||||||
if args.verbose:
|
|
||||||
old_mtime = last_mtimes.get(fp, 0)
|
|
||||||
print(f" Old mtime: {old_mtime}, New mtime: {m}")
|
|
||||||
changed = True
|
|
||||||
change_reason = f"mtime change in {fp}"
|
|
||||||
break
|
|
||||||
|
|
||||||
# Hash-based fallback check if mtime didn't detect changes
|
|
||||||
if not changed and args.use_hash_fallback:
|
|
||||||
current_hashes = scan_py_hashes(watch_paths)
|
|
||||||
for fp, h in current_hashes.items():
|
|
||||||
if fp not in last_hashes or last_hashes.get(fp) != h:
|
|
||||||
print("Detected content change in:", fp)
|
|
||||||
if args.verbose:
|
|
||||||
print(
|
|
||||||
f" Hash changed: {last_hashes.get(fp, 'None')} -> {h}"
|
|
||||||
)
|
|
||||||
changed = True
|
|
||||||
change_reason = f"content change in {fp}"
|
|
||||||
break
|
|
||||||
# Update hash cache
|
|
||||||
last_hashes = current_hashes
|
|
||||||
|
|
||||||
# Check for deleted files
|
|
||||||
if not changed:
|
|
||||||
for fp in list(last_mtimes.keys()):
|
|
||||||
if fp not in current:
|
|
||||||
print("Detected deleted file:", fp)
|
|
||||||
changed = True
|
|
||||||
change_reason = f"deleted file {fp}"
|
|
||||||
break
|
|
||||||
|
|
||||||
# Additional debug output
|
|
||||||
if args.verbose and not changed:
|
|
||||||
num_files = len(current)
|
|
||||||
if num_files != len(last_mtimes):
|
|
||||||
print(f"File count changed: {len(last_mtimes)} -> {num_files}")
|
|
||||||
changed = True
|
|
||||||
change_reason = "file count change"
|
|
||||||
|
|
||||||
if changed:
|
|
||||||
if args.verbose:
|
|
||||||
print(f"Restarting due to: {change_reason}")
|
|
||||||
# Small debounce
|
|
||||||
time.sleep(args.delay_restart)
|
|
||||||
terminate_process(child)
|
|
||||||
child = start_process(cmd)
|
|
||||||
last_mtimes = scan_py_mtimes(watch_paths)
|
|
||||||
if args.use_hash_fallback:
|
|
||||||
last_hashes = scan_py_hashes(watch_paths)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("Interrupted, shutting down.")
|
|
||||||
terminate_process(child)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
raise SystemExit(main())
|
|
Loading…
x
Reference in New Issue
Block a user