Rework in progress
This commit is contained in:
parent
33d9c1d28a
commit
dc55196311
@ -257,7 +257,7 @@ FROM llm-base AS backstory
|
||||
|
||||
COPY /src/requirements.txt /opt/backstory/src/requirements.txt
|
||||
RUN pip install -r /opt/backstory/src/requirements.txt
|
||||
RUN pip install 'markitdown[all]'
|
||||
RUN pip install 'markitdown[all]' pydantic
|
||||
|
||||
SHELL [ "/bin/bash", "-c" ]
|
||||
|
||||
|
@ -281,6 +281,7 @@ const App = () => {
|
||||
throw Error("Server is temporarily down.");
|
||||
}
|
||||
const data = await response.json();
|
||||
console.log(`Session created: ${data.id}`);
|
||||
setSessionId(data.id);
|
||||
|
||||
const newPath = `/${data.id}`;
|
||||
|
@ -16,17 +16,20 @@ import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
||||
|
||||
import { SetSnackType } from './Snack';
|
||||
|
||||
interface ServerTunables {
|
||||
system_prompt: string,
|
||||
message_history_length: number,
|
||||
tools: Tool[],
|
||||
rags: Tool[]
|
||||
};
|
||||
|
||||
type Tool = {
|
||||
type: string,
|
||||
function?: {
|
||||
name: string,
|
||||
description: string,
|
||||
parameters?: any,
|
||||
returns?: any
|
||||
},
|
||||
name?: string,
|
||||
description?: string,
|
||||
enabled: boolean
|
||||
name: string,
|
||||
description: string,
|
||||
parameters?: any,
|
||||
returns?: any
|
||||
};
|
||||
|
||||
interface ControlsParams {
|
||||
@ -41,7 +44,6 @@ type GPUInfo = {
|
||||
discrete: boolean
|
||||
}
|
||||
|
||||
|
||||
type SystemInfo = {
|
||||
"Installed RAM": string,
|
||||
"Graphics Card": GPUInfo[],
|
||||
@ -94,116 +96,111 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
||||
const [tools, setTools] = useState<Tool[]>([]);
|
||||
const [rags, setRags] = useState<Tool[]>([]);
|
||||
const [systemPrompt, setSystemPrompt] = useState<string>("");
|
||||
const [serverSystemPrompt, setServerSystemPrompt] = useState<string>("");
|
||||
const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5);
|
||||
const [serverTunables, setServerTunables] = useState<ServerTunables | undefined>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
if (systemPrompt === serverSystemPrompt || !systemPrompt.trim() || sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const sendSystemPrompt = async (prompt: string) => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "system_prompt": prompt }),
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
const newPrompt = data["system_prompt"];
|
||||
if (newPrompt !== serverSystemPrompt) {
|
||||
setServerSystemPrompt(newPrompt);
|
||||
setSystemPrompt(newPrompt)
|
||||
setSnack("System prompt updated", "success");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("System prompt update failed", "error");
|
||||
}
|
||||
};
|
||||
|
||||
sendSystemPrompt(systemPrompt);
|
||||
|
||||
}, [systemPrompt, setServerSystemPrompt, serverSystemPrompt, connectionBase, sessionId, setSnack]);
|
||||
|
||||
useEffect(() => {
|
||||
if (sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const sendMessageHistoryLength = async (length: number) => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "message_history_length": length }),
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
const newLength = data["message_history_length"];
|
||||
if (newLength !== messageHistoryLength) {
|
||||
setMessageHistoryLength(newLength);
|
||||
setSnack("Message history length updated", "success");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("Message history length update failed", "error");
|
||||
}
|
||||
};
|
||||
|
||||
sendMessageHistoryLength(messageHistoryLength);
|
||||
|
||||
}, [messageHistoryLength, setMessageHistoryLength, connectionBase, sessionId, setSnack]);
|
||||
|
||||
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt" | "message_history_length")[], message: string = "Update successful.") => {
|
||||
useEffect(() => {
|
||||
if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim() || sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const sendSystemPrompt = async (prompt: string) => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/reset/${sessionId}`, {
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "reset": types }),
|
||||
body: JSON.stringify({ "system_prompt": prompt }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
if (data.error) {
|
||||
throw Error()
|
||||
}
|
||||
for (const [key, value] of Object.entries(data)) {
|
||||
switch (key) {
|
||||
case "rags":
|
||||
setRags(value as Tool[]);
|
||||
break;
|
||||
case "tools":
|
||||
setTools(value as Tool[]);
|
||||
break;
|
||||
case "system_prompt":
|
||||
setServerSystemPrompt((value as any)["system_prompt"].trim());
|
||||
setSystemPrompt((value as any)["system_prompt"].trim());
|
||||
break;
|
||||
case "history":
|
||||
console.log('TODO: handle history reset');
|
||||
break;
|
||||
}
|
||||
}
|
||||
setSnack(message, "success");
|
||||
} else {
|
||||
throw Error(`${{ status: response.status, message: response.statusText }}`);
|
||||
|
||||
const tunables = await response.json();
|
||||
serverTunables.system_prompt = tunables.system_prompt;
|
||||
setSystemPrompt(tunables.system_prompt)
|
||||
setSnack("System prompt updated", "success");
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("System prompt update failed", "error");
|
||||
}
|
||||
};
|
||||
|
||||
sendSystemPrompt(systemPrompt);
|
||||
|
||||
}, [systemPrompt, connectionBase, sessionId, setSnack, serverTunables]);
|
||||
|
||||
useEffect(() => {
|
||||
if (serverTunables === undefined || messageHistoryLength === serverTunables.message_history_length || !messageHistoryLength || sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const sendMessageHistoryLength = async (length: number) => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "message_history_length": length }),
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
const newLength = data["message_history_length"];
|
||||
if (newLength !== messageHistoryLength) {
|
||||
setMessageHistoryLength(newLength);
|
||||
setSnack("Message history length updated", "success");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("Unable to restore defaults", "error");
|
||||
setSnack("Message history length update failed", "error");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
sendMessageHistoryLength(messageHistoryLength);
|
||||
|
||||
}, [messageHistoryLength, setMessageHistoryLength, connectionBase, sessionId, setSnack, serverTunables]);
|
||||
|
||||
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt" | "message_history_length")[], message: string = "Update successful.") => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/reset/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "reset": types }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
if (data.error) {
|
||||
throw Error()
|
||||
}
|
||||
for (const [key, value] of Object.entries(data)) {
|
||||
switch (key) {
|
||||
case "rags":
|
||||
setRags(value as Tool[]);
|
||||
break;
|
||||
case "tools":
|
||||
setTools(value as Tool[]);
|
||||
break;
|
||||
case "system_prompt":
|
||||
setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
|
||||
break;
|
||||
case "history":
|
||||
console.log('TODO: handle history reset');
|
||||
break;
|
||||
}
|
||||
}
|
||||
setSnack(message, "success");
|
||||
} else {
|
||||
throw Error(`${{ status: response.status, message: response.statusText }}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("Unable to restore defaults", "error");
|
||||
}
|
||||
};
|
||||
|
||||
// Get the system information
|
||||
useEffect(() => {
|
||||
if (systemInfo !== undefined || sessionId === undefined) {
|
||||
@ -229,21 +226,20 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
||||
setEditSystemPrompt(systemPrompt);
|
||||
}, [systemPrompt, setEditSystemPrompt]);
|
||||
|
||||
|
||||
const toggleRag = async (tool: Tool) => {
|
||||
tool.enabled = !tool.enabled
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/rags/${sessionId}`, {
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "tool": tool?.name, "enabled": tool.enabled }),
|
||||
body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }),
|
||||
});
|
||||
|
||||
const rags = await response.json();
|
||||
setRags([...rags])
|
||||
const tunables: ServerTunables = await response.json();
|
||||
setRags(tunables.rags)
|
||||
setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
@ -255,117 +251,63 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
||||
const toggleTool = async (tool: Tool) => {
|
||||
tool.enabled = !tool.enabled
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/tools/${sessionId}`, {
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "tool": tool?.function?.name, "enabled": tool.enabled }),
|
||||
body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }),
|
||||
});
|
||||
|
||||
const tools = await response.json();
|
||||
setTools([...tools])
|
||||
setSnack(`${tool?.function?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||
const tunables: ServerTunables = await response.json();
|
||||
setTools(tunables.tools)
|
||||
setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack(`${tool?.function?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||
setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||
tool.enabled = !tool.enabled
|
||||
}
|
||||
};
|
||||
|
||||
// If the tools have not been set, fetch them from the server
|
||||
// If the systemPrompt has not been set, fetch it from the server
|
||||
useEffect(() => {
|
||||
if (tools.length || sessionId === undefined) {
|
||||
if (serverTunables !== undefined || sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const fetchTools = async () => {
|
||||
try {
|
||||
// Make the fetch request with proper headers
|
||||
const response = await fetch(connectionBase + `/api/tools/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw Error();
|
||||
}
|
||||
const tools = await response.json();
|
||||
setTools(tools);
|
||||
} catch (error: any) {
|
||||
setSnack("Unable to fetch tools", "error");
|
||||
console.error(error);
|
||||
}
|
||||
const fetchTunables = async () => {
|
||||
// Make the fetch request with proper headers
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
});
|
||||
const data = await response.json();
|
||||
console.log("Server tunables: ", data);
|
||||
setServerTunables(data);
|
||||
setSystemPrompt(data["system_prompt"]);
|
||||
setMessageHistoryLength(data["message_history_length"]);
|
||||
setTools(data["tools"]);
|
||||
setRags(data["rags"]);
|
||||
}
|
||||
|
||||
fetchTools();
|
||||
}, [sessionId, tools, setTools, setSnack, connectionBase]);
|
||||
|
||||
// If the RAGs have not been set, fetch them from the server
|
||||
useEffect(() => {
|
||||
if (rags.length || sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const fetchRags = async () => {
|
||||
try {
|
||||
// Make the fetch request with proper headers
|
||||
const response = await fetch(connectionBase + `/api/rags/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw Error();
|
||||
}
|
||||
const rags = await response.json();
|
||||
setRags(rags);
|
||||
} catch (error: any) {
|
||||
setSnack("Unable to fetch RAGs", "error");
|
||||
console.error(error);
|
||||
}
|
||||
}
|
||||
|
||||
fetchRags();
|
||||
}, [sessionId, rags, setRags, setSnack, connectionBase]);
|
||||
|
||||
// If the systemPrompt has not been set, fetch it from the server
|
||||
useEffect(() => {
|
||||
if (serverSystemPrompt !== "" || sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const fetchTunables = async () => {
|
||||
// Make the fetch request with proper headers
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
});
|
||||
const data = await response.json();
|
||||
const serverSystemPrompt = data["system_prompt"].trim();
|
||||
setServerSystemPrompt(serverSystemPrompt);
|
||||
setSystemPrompt(serverSystemPrompt);
|
||||
setMessageHistoryLength(data["message_history_length"]);
|
||||
}
|
||||
|
||||
fetchTunables();
|
||||
}, [sessionId, serverSystemPrompt, setServerSystemPrompt, connectionBase]);
|
||||
|
||||
|
||||
|
||||
|
||||
fetchTunables();
|
||||
}, [sessionId, connectionBase, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags]);
|
||||
|
||||
const toggle = async (type: string, index: number) => {
|
||||
switch (type) {
|
||||
case "rag":
|
||||
if (rags === undefined) {
|
||||
return;
|
||||
}
|
||||
toggleRag(rags[index])
|
||||
break;
|
||||
case "tool":
|
||||
if (tools === undefined) {
|
||||
return;
|
||||
}
|
||||
toggleTool(tools[index]);
|
||||
}
|
||||
};
|
||||
@ -442,11 +384,11 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
||||
<AccordionActions>
|
||||
<FormGroup sx={{ p: 1 }}>
|
||||
{
|
||||
tools.map((tool, index) =>
|
||||
(tools || []).map((tool, index) =>
|
||||
<Box key={index}>
|
||||
<Divider />
|
||||
<FormControlLabel control={<Switch checked={tool.enabled} />} onChange={() => toggle("tool", index)} label={tool?.function?.name} />
|
||||
<Typography sx={{ fontSize: "0.8rem", mb: 1 }}>{tool?.function?.description}</Typography>
|
||||
<FormControlLabel control={<Switch checked={tool.enabled} />} onChange={() => toggle("tool", index)} label={tool.name} />
|
||||
<Typography sx={{ fontSize: "0.8rem", mb: 1 }}>{tool.description}</Typography>
|
||||
</Box>
|
||||
)
|
||||
}</FormGroup>
|
||||
@ -463,14 +405,14 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
||||
<AccordionActions>
|
||||
<FormGroup sx={{ p: 1, flexGrow: 1, justifyContent: "flex-start" }}>
|
||||
{
|
||||
rags.map((rag, index) =>
|
||||
(rags || []).map((rag, index) =>
|
||||
<Box key={index} sx={{ display: "flex", flexGrow: 1, flexDirection: "column" }}>
|
||||
<Divider />
|
||||
<FormControlLabel
|
||||
control={<Switch checked={rag.enabled} />}
|
||||
onChange={() => toggle("rag", index)} label={rag?.name}
|
||||
onChange={() => toggle("rag", index)} label={rag.name}
|
||||
/>
|
||||
<Typography>{rag?.description}</Typography>
|
||||
<Typography>{rag.description}</Typography>
|
||||
</Box>
|
||||
)
|
||||
}</FormGroup>
|
||||
|
@ -160,14 +160,15 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const { messages } = await response.json();
|
||||
|
||||
console.log(`History returned for ${type} from server with ${data.length} entries`)
|
||||
if (data.length === 0) {
|
||||
if (messages === undefined || messages.length === 0) {
|
||||
console.log(`History returned for ${type} from server with 0 entries`)
|
||||
setConversation([])
|
||||
setNoInteractions(true);
|
||||
} else {
|
||||
setConversation(data);
|
||||
console.log(`History returned for ${type} from server with ${messages.length} entries:`, messages)
|
||||
setConversation(messages);
|
||||
setNoInteractions(false);
|
||||
}
|
||||
setProcessingMessage(undefined);
|
||||
|
@ -42,17 +42,19 @@ const DeleteConfirmation = (props : DeleteConfirmationProps) => {
|
||||
return (
|
||||
<>
|
||||
<Tooltip title={label ? `Reset ${label}` : "Reset"} >
|
||||
<IconButton
|
||||
aria-label="reset"
|
||||
onClick={handleClickOpen}
|
||||
color={ color || "inherit" }
|
||||
sx={{ display: "flex", margin: 'auto 0px' }}
|
||||
size="large"
|
||||
edge="start"
|
||||
disabled={disabled}
|
||||
>
|
||||
<ResetIcon />
|
||||
</IconButton>
|
||||
<span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
|
||||
<IconButton
|
||||
aria-label="reset"
|
||||
onClick={handleClickOpen}
|
||||
color={color || "inherit"}
|
||||
sx={{ display: "flex", margin: 'auto 0px' }}
|
||||
size="large"
|
||||
edge="start"
|
||||
disabled={disabled}
|
||||
>
|
||||
<ResetIcon />
|
||||
</IconButton>
|
||||
</span>
|
||||
</Tooltip>
|
||||
|
||||
<Dialog
|
||||
|
@ -39,10 +39,10 @@ type MessageData = {
|
||||
display?: string, /* Messages generated on the server for filler should not be shown */
|
||||
id?: string,
|
||||
isProcessing?: boolean,
|
||||
metadata?: MessageMetaProps
|
||||
metadata?: MessageMetaData
|
||||
};
|
||||
|
||||
interface MessageMetaProps {
|
||||
interface MessageMetaData {
|
||||
query?: {
|
||||
query_embedding: number[];
|
||||
vector_embedding: number[];
|
||||
@ -64,7 +64,7 @@ type MessageList = MessageData[];
|
||||
|
||||
interface MessageProps {
|
||||
sx?: SxProps<Theme>,
|
||||
message?: MessageData,
|
||||
message: MessageData,
|
||||
isFullWidth?: boolean,
|
||||
submitQuery?: (text: string) => void,
|
||||
sessionId?: string,
|
||||
@ -78,8 +78,25 @@ interface ChatQueryInterface {
|
||||
submitQuery?: (text: string) => void
|
||||
}
|
||||
|
||||
interface MessageMetaProps {
|
||||
metadata: MessageMetaData,
|
||||
messageProps: MessageProps
|
||||
};
|
||||
|
||||
|
||||
const MessageMeta = (props: MessageMetaProps) => {
|
||||
const {
|
||||
/* MessageData */
|
||||
full_query,
|
||||
rag,
|
||||
tools,
|
||||
eval_count,
|
||||
eval_duration,
|
||||
prompt_eval_count,
|
||||
prompt_eval_duration,
|
||||
} = props.metadata || {};
|
||||
const messageProps = props.messageProps;
|
||||
|
||||
const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
||||
return (<>
|
||||
<Box sx={{ fontSize: "0.8rem", mb: 1 }}>
|
||||
Below is the LLM performance of this query. Note that if tools are called, the
|
||||
@ -99,28 +116,28 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
||||
<TableBody>
|
||||
<TableRow key="prompt" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||
<TableCell component="th" scope="row">Prompt</TableCell>
|
||||
<TableCell align="right">{props.prompt_eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round(props.prompt_eval_duration / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round(props.prompt_eval_count * 10 ** 9 / props.prompt_eval_duration)}</TableCell>
|
||||
<TableCell align="right">{prompt_eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round(prompt_eval_duration / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round(prompt_eval_count * 10 ** 9 / prompt_eval_duration)}</TableCell>
|
||||
</TableRow>
|
||||
<TableRow key="response" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||
<TableCell component="th" scope="row">Response</TableCell>
|
||||
<TableCell align="right">{props.eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round(props.eval_duration / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round(props.eval_count * 10 ** 9 / props.eval_duration)}</TableCell>
|
||||
<TableCell align="right">{eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round(eval_duration / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round(eval_count * 10 ** 9 / eval_duration)}</TableCell>
|
||||
</TableRow>
|
||||
<TableRow key="total" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||
<TableCell component="th" scope="row">Total</TableCell>
|
||||
<TableCell align="right">{props.prompt_eval_count + props.eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round((props.prompt_eval_duration + props.eval_duration) / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round((props.prompt_eval_count + props.eval_count) * 10 ** 9 / (props.prompt_eval_duration + props.eval_duration))}</TableCell>
|
||||
<TableCell align="right">{prompt_eval_count + eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round((prompt_eval_duration + eval_duration) / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round((prompt_eval_count + eval_count) * 10 ** 9 / (prompt_eval_duration + eval_duration))}</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
|
||||
{
|
||||
props?.full_query !== undefined &&
|
||||
full_query !== undefined &&
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Box sx={{ fontSize: "0.8rem" }}>
|
||||
@ -128,12 +145,12 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
||||
</Box>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
<pre style={{ "display": "block", "position": "relative" }}><CopyBubble content={props.full_query.trim()} />{props.full_query.trim()}</pre>
|
||||
<pre style={{ "display": "block", "position": "relative" }}><CopyBubble content={full_query?.trim()} />{full_query?.trim()}</pre>
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
}
|
||||
{
|
||||
props.tools !== undefined && props.tools.length !== 0 &&
|
||||
tools !== undefined && tools.length !== 0 &&
|
||||
<Accordion sx={{ boxSizing: "border-box" }}>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Box sx={{ fontSize: "0.8rem" }}>
|
||||
@ -141,7 +158,7 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
||||
</Box>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
{props.tools.map((tool: any, index: number) => <Box key={index}>
|
||||
{tools.map((tool: any, index: number) => <Box key={index}>
|
||||
{index !== 0 && <Divider />}
|
||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mt: 0.5 }}>
|
||||
<div style={{ display: "flex", paddingRight: "1rem", whiteSpace: "nowrap" }}>
|
||||
@ -165,24 +182,24 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
||||
</Accordion>
|
||||
}
|
||||
{
|
||||
props?.rag?.name !== undefined && <>
|
||||
rag?.name !== undefined && <>
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Box sx={{ fontSize: "0.8rem" }}>
|
||||
Top RAG {props.rag.ids.length} matches from '{props.rag.name}' collection against embedding vector of {props.rag.query_embedding.length} dimensions
|
||||
Top RAG {rag.ids.length} matches from '{rag.name}' collection against embedding vector of {rag.query_embedding.length} dimensions
|
||||
</Box>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
{props.rag.ids.map((id: number, index: number) => <Box key={index}>
|
||||
{rag.ids.map((id: number, index: number) => <Box key={index}>
|
||||
{index !== 0 && <Divider />}
|
||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "row", mb: 0.5, mt: 0.5 }}>
|
||||
<div style={{ display: "flex", flexDirection: "column", paddingRight: "1rem", minWidth: "10rem" }}>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Doc ID: {props.rag.ids[index].slice(-10)}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Similarity: {Math.round(props.rag.distances[index] * 100) / 100}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Type: {props.rag.metadatas[index].doc_type}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Chunk Len: {props.rag.documents[index].length}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Doc ID: {rag.ids[index].slice(-10)}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Similarity: {Math.round(rag.distances[index] * 100) / 100}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Type: {rag.metadatas[index].doc_type}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Chunk Len: {rag.documents[index].length}</div>
|
||||
</div>
|
||||
<div style={{ display: "flex", padding: "3px", flexGrow: 1, border: "1px solid #E0E0E0", maxHeight: "5rem", overflow: "auto" }}>{props.rag.documents[index]}</div>
|
||||
<div style={{ display: "flex", padding: "3px", flexGrow: 1, border: "1px solid #E0E0E0", maxHeight: "5rem", overflow: "auto" }}>{rag.documents[index]}</div>
|
||||
</Box>
|
||||
</Box>
|
||||
)}
|
||||
@ -195,13 +212,53 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
||||
</Box>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
<VectorVisualizer inline {...props} rag={props?.rag} />
|
||||
<VectorVisualizer inline {...messageProps} {...props.metadata} rag={rag} />
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Box sx={{ fontSize: "0.8rem" }}>
|
||||
All response fields
|
||||
</Box>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
{Object.entries(props.messageProps.message)
|
||||
.filter(([key, value]) => key !== undefined && value !== undefined)
|
||||
.map(([key, value]) => (<>
|
||||
{(typeof (value) !== "string" || value?.trim() !== "") &&
|
||||
<Accordion key={key}>
|
||||
<AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
|
||||
{key}
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
{key === "metadata" &&
|
||||
Object.entries(value)
|
||||
.filter(([key, value]) => key !== undefined && value !== undefined)
|
||||
.map(([key, value]) => (
|
||||
<Accordion key={`${key}-metadata`}>
|
||||
<AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
|
||||
{key}
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
<pre>{`${typeof (value) !== "object" ? value : JSON.stringify(value)}`}</pre>
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
))}
|
||||
{key !== "metadata" &&
|
||||
<pre>{typeof (value) !== "object" ? value : JSON.stringify(value)}</pre>
|
||||
}
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
}
|
||||
</>
|
||||
)
|
||||
)}
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
|
||||
</>
|
||||
}
|
||||
</>
|
||||
);
|
||||
</>);
|
||||
};
|
||||
|
||||
const ChatQuery = ({ text, submitQuery }: ChatQueryInterface) => {
|
||||
@ -221,7 +278,7 @@ const ChatQuery = ({ text, submitQuery }: ChatQueryInterface) => {
|
||||
}
|
||||
|
||||
const Message = (props: MessageProps) => {
|
||||
const { message, submitQuery, isFullWidth, sessionId, setSnack, connectionBase, sx, className } = props;
|
||||
const { message, submitQuery, isFullWidth, sx, className } = props;
|
||||
const [expanded, setExpanded] = useState<boolean>(false);
|
||||
const textFieldRef = useRef(null);
|
||||
|
||||
@ -293,7 +350,7 @@ const Message = (props: MessageProps) => {
|
||||
{message.metadata && <>
|
||||
<Collapse in={expanded} timeout="auto" unmountOnExit>
|
||||
<CardContent>
|
||||
<MessageMeta {...{ ...message.metadata, sessionId, connectionBase, setSnack }} />
|
||||
<MessageMeta messageProps={props} metadata={message.metadata} />
|
||||
</CardContent>
|
||||
</Collapse>
|
||||
</>}
|
||||
@ -305,7 +362,6 @@ export type {
|
||||
MessageProps,
|
||||
MessageList,
|
||||
ChatQueryInterface,
|
||||
MessageMetaProps,
|
||||
MessageData,
|
||||
MessageRoles
|
||||
};
|
||||
|
@ -112,6 +112,12 @@ const ResumeBuilder: React.FC<ResumeBuilderProps> = ({
|
||||
return keep;
|
||||
});
|
||||
|
||||
/* If Resume hasn't occurred yet and there is still more than one message,
|
||||
* resume has been generated. */
|
||||
if (!hasResume && reduced.length > 1) {
|
||||
setHasResume(true);
|
||||
}
|
||||
|
||||
if (reduced.length > 0) {
|
||||
// First message is always 'content'
|
||||
reduced[0].title = 'Job Description';
|
||||
@ -123,7 +129,7 @@ const ResumeBuilder: React.FC<ResumeBuilderProps> = ({
|
||||
reduced = reduced.filter(m => m.display !== "hide");
|
||||
|
||||
return reduced;
|
||||
}, [setHasJobDescription, setHasResume]);
|
||||
}, [setHasJobDescription, setHasResume, hasResume]);
|
||||
|
||||
const filterResumeMessages = useCallback((messages: MessageList): MessageList => {
|
||||
if (messages === undefined || messages.length === 0) {
|
||||
@ -135,11 +141,11 @@ const ResumeBuilder: React.FC<ResumeBuilderProps> = ({
|
||||
if ((m.metadata?.origin || m.origin || "no origin") === 'fact_check') {
|
||||
setHasFacts(true);
|
||||
}
|
||||
// if (!keep) {
|
||||
// console.log(`filterResumeMessages: ${i + 1} filtered:`, m);
|
||||
// } else {
|
||||
// console.log(`filterResumeMessages: ${i + 1}:`, m);
|
||||
// }
|
||||
if (!keep) {
|
||||
console.log(`filterResumeMessages: ${i + 1} filtered:`, m);
|
||||
} else {
|
||||
console.log(`filterResumeMessages: ${i + 1}:`, m);
|
||||
}
|
||||
return keep;
|
||||
});
|
||||
|
||||
|
1236
src/server.py
1236
src/server.py
File diff suppressed because it is too large
Load Diff
@ -2,7 +2,9 @@
|
||||
from . import defines
|
||||
|
||||
# Import rest as `utils.*` accessible
|
||||
from .rag import *
|
||||
from .rag import ChromaDBFileWatcher, start_file_watcher
|
||||
|
||||
# Expose only public names (avoid importing hidden/internal names)
|
||||
__all__ = [name for name in dir() if not name.startswith("_")]
|
||||
from .message import Message
|
||||
from .conversation import Conversation
|
||||
from .session import Session, Chat, Resume, JobDescription, FactCheck
|
||||
from .context import Context
|
98
src/utils/context.py
Normal file
98
src/utils/context.py
Normal file
@ -0,0 +1,98 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from uuid import uuid4
|
||||
from typing import List, Optional
|
||||
from typing_extensions import Annotated, Union
|
||||
from .session import AnySession, Session
|
||||
|
||||
class Context(BaseModel):
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid4()),
|
||||
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
|
||||
)
|
||||
|
||||
sessions: List[Annotated[Union[*Session.__subclasses__()], Field(discriminator="session_type")]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
user_resume: Optional[str] = None
|
||||
user_job_description: Optional[str] = None
|
||||
user_facts: Optional[str] = None
|
||||
tools: List[dict] = []
|
||||
rags: List[dict] = []
|
||||
message_history_length: int = 5
|
||||
context_tokens: int = 0
|
||||
|
||||
def __init__(self, id: Optional[str] = None, **kwargs):
|
||||
super().__init__(id=id if id is not None else str(uuid4()), **kwargs)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_unique_session_types(self):
|
||||
"""Ensure at most one session per session_type."""
|
||||
session_types = [session.session_type for session in self.sessions]
|
||||
if len(session_types) != len(set(session_types)):
|
||||
raise ValueError("Context cannot contain multiple sessions of the same session_type")
|
||||
return self
|
||||
|
||||
def get_or_create_session(self, session_type: str, **kwargs) -> Session:
|
||||
"""
|
||||
Get or create and append a new session of the specified type, ensuring only one session per type exists.
|
||||
|
||||
Args:
|
||||
session_type: The type of session to create (e.g., 'web', 'database').
|
||||
**kwargs: Additional fields required by the specific session subclass.
|
||||
|
||||
Returns:
|
||||
The created session instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching session type is found or if a session of this type already exists.
|
||||
"""
|
||||
# Check if a session with the given session_type already exists
|
||||
for session in self.sessions:
|
||||
if session.session_type == session_type:
|
||||
return session
|
||||
|
||||
# Find the matching subclass
|
||||
for session_cls in Session.__subclasses__():
|
||||
if session_cls.__fields__["session_type"].default == session_type:
|
||||
# Create the session instance with provided kwargs
|
||||
session = session_cls(session_type=session_type, **kwargs)
|
||||
self.sessions.append(session)
|
||||
return session
|
||||
|
||||
raise ValueError(f"No session class found for session_type: {session_type}")
|
||||
|
||||
def add_session(self, session: AnySession) -> None:
|
||||
"""Add a Session to the context, ensuring no duplicate session_type."""
|
||||
if any(s.session_type == session.session_type for s in self.sessions):
|
||||
raise ValueError(f"A session with session_type '{session.session_type}' already exists")
|
||||
self.sessions.append(session)
|
||||
|
||||
def get_session(self, session_type: str) -> Session | None:
|
||||
"""Return the Session with the given session_type, or None if not found."""
|
||||
for session in self.sessions:
|
||||
if session.session_type == session_type:
|
||||
return session
|
||||
return None
|
||||
|
||||
def is_valid_session_type(self, session_type: str) -> bool:
|
||||
"""Check if the given session_type is valid."""
|
||||
return session_type in Session.valid_session_types()
|
||||
|
||||
def get_summary(self) -> str:
|
||||
"""Return a summary of the context."""
|
||||
if not self.sessions:
|
||||
return f"Context {self.uuid}: No sessions."
|
||||
summary = f"Context {self.uuid}:\n"
|
||||
for i, session in enumerate(self.sessions, 1):
|
||||
summary += f"\nSession {i} ({session.session_type}):\n"
|
||||
summary += session.conversation.get_summary()
|
||||
if session.session_type == "resume":
|
||||
summary += f"\nResume: {session.get_resume()}\n"
|
||||
elif session.session_type == "job_description":
|
||||
summary += f"\nJob Description: {session.job_description}\n"
|
||||
elif session.session_type == "fact_check":
|
||||
summary += f"\nFacts: {session.facts}\n"
|
||||
elif session.session_type == "chat":
|
||||
summary += f"\nChat Name: {session.name}\n"
|
||||
return summary
|
23
src/utils/conversation.py
Normal file
23
src/utils/conversation.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from datetime import datetime, timezone
|
||||
from .message import Message
|
||||
|
||||
class Conversation(BaseModel):
|
||||
messages: List[Message] = []
|
||||
|
||||
def add_message(self, message: Message | List[Message]) -> None:
|
||||
"""Add a Message(s) to the conversation."""
|
||||
if isinstance(message, Message):
|
||||
self.messages.append(message)
|
||||
else:
|
||||
self.messages.extend(message)
|
||||
|
||||
def get_summary(self) -> str:
|
||||
"""Return a summary of the conversation."""
|
||||
if not self.messages:
|
||||
return "Conversation is empty."
|
||||
summary = f"Conversation:\n"
|
||||
for i, message in enumerate(self.messages, 1):
|
||||
summary += f"\nMessage {i}:\n{message.get_summary()}\n"
|
||||
return summary
|
31
src/utils/message.py
Normal file
31
src/utils/message.py
Normal file
@ -0,0 +1,31 @@
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class Message(BaseModel):
|
||||
prompt: str
|
||||
preamble: str = ""
|
||||
content: str = ""
|
||||
response: str = ""
|
||||
metadata: dict[str, Any] = {
|
||||
"rag": { "documents": [] },
|
||||
"tools": [],
|
||||
"eval_count": 0,
|
||||
"eval_duration": 0,
|
||||
"prompt_eval_count": 0,
|
||||
"prompt_eval_duration": 0,
|
||||
}
|
||||
actions: List[str] = []
|
||||
timestamp: datetime = datetime.now(timezone.utc)
|
||||
|
||||
def get_summary(self) -> str:
|
||||
"""Return a summary of the message."""
|
||||
response_summary = (
|
||||
f"Response: {self.response} (Actions: {', '.join(self.actions)})"
|
||||
if self.response else "No response yet"
|
||||
)
|
||||
return (
|
||||
f"Message at {self.timestamp}:\n"
|
||||
f"Query: {self.preamble}{self.content}\n"
|
||||
f"{response_summary}"
|
||||
)
|
78
src/utils/session.py
Normal file
78
src/utils/session.py
Normal file
@ -0,0 +1,78 @@
|
||||
from pydantic import BaseModel, Field, model_validator, PrivateAttr
|
||||
from typing import Literal, TypeAlias, get_args
|
||||
from .conversation import Conversation
|
||||
|
||||
class Session(BaseModel):
|
||||
session_type: Literal["resume", "job_description", "fact_check", "chat"]
|
||||
system_prompt: str = "You are a helpful assistant."
|
||||
conversation: Conversation = Conversation()
|
||||
context_tokens: int = 0
|
||||
|
||||
_content_seed: str = PrivateAttr(default="")
|
||||
|
||||
def get_and_reset_content_seed(self):
|
||||
tmp = self._content_seed
|
||||
self._content_seed = ""
|
||||
return tmp
|
||||
|
||||
def set_content_seed(self, content: str) -> None:
|
||||
"""Set the content seed for the session."""
|
||||
self._content_seed = content
|
||||
|
||||
def get_content_seed(self) -> str:
|
||||
"""Get the content seed for the session."""
|
||||
return self._content_seed
|
||||
|
||||
@classmethod
|
||||
def valid_session_types(cls) -> set[str]:
|
||||
"""Return the set of valid session_type values."""
|
||||
return set(get_args(cls.__annotations__["session_type"]))
|
||||
|
||||
|
||||
# Type alias for Session or any subclass
|
||||
AnySession: TypeAlias = Session # BaseModel covers Session and subclasses
|
||||
|
||||
class Resume(Session):
|
||||
session_type: Literal["resume"] = "resume"
|
||||
resume: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_resume(self):
|
||||
if not self.resume.strip():
|
||||
raise ValueError("Resume content cannot be empty")
|
||||
return self
|
||||
|
||||
def get_resume(self) -> str:
|
||||
"""Get the resume content."""
|
||||
return self.resume
|
||||
|
||||
def set_resume(self, resume: str) -> None:
|
||||
"""Set the resume content."""
|
||||
self.resume = resume
|
||||
|
||||
class JobDescription(Session):
|
||||
session_type: Literal["job_description"] = "job_description"
|
||||
job_description: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_job_description(self):
|
||||
if not self.job_description.strip():
|
||||
raise ValueError("Job description cannot be empty")
|
||||
return self
|
||||
|
||||
class FactCheck(Session):
|
||||
session_type: Literal["fact_check"] = "fact_check"
|
||||
facts: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_facts(self):
|
||||
if not self.facts.strip():
|
||||
raise ValueError("Facts cannot be empty")
|
||||
return self
|
||||
|
||||
class Chat(Session):
|
||||
session_type: Literal["chat"] = "chat"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name(self):
|
||||
return self
|
Loading…
x
Reference in New Issue
Block a user