Lots of improvements to tunables and feedback
This commit is contained in:
parent
078c6ae183
commit
e159239ade
@ -192,3 +192,12 @@ div {
|
||||
margin-bottom: 0;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.PromptStats .MuiTableCell-root {
|
||||
font-size: 0.8rem;
|
||||
}
|
||||
|
||||
#SystemPromptInput {
|
||||
font-size: 0.9rem;
|
||||
line-height: 1.25rem;
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
import React, { useState, useEffect, useRef, useCallback, ReactElement } from 'react';
|
||||
import FormGroup from '@mui/material/FormGroup';
|
||||
import FormControlLabel from '@mui/material/FormControlLabel';
|
||||
import { useTheme } from '@mui/material';
|
||||
import { styled } from '@mui/material/styles';
|
||||
import Switch from '@mui/material/Switch';
|
||||
import Divider from '@mui/material/Divider';
|
||||
@ -18,35 +17,43 @@ import Button from '@mui/material/Button';
|
||||
import AppBar from '@mui/material/AppBar';
|
||||
import Drawer from '@mui/material/Drawer';
|
||||
import Toolbar from '@mui/material/Toolbar';
|
||||
import MenuIcon from '@mui/icons-material/Menu';
|
||||
import SettingsIcon from '@mui/icons-material/Settings';
|
||||
import IconButton, { IconButtonProps } from '@mui/material/IconButton';
|
||||
import Box from '@mui/material/Box';
|
||||
import CssBaseline from '@mui/material/CssBaseline';
|
||||
import AddIcon from '@mui/icons-material/AddCircle';
|
||||
import ResetIcon from '@mui/icons-material/History';
|
||||
import SendIcon from '@mui/icons-material/Send';
|
||||
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
||||
import MoreVertIcon from '@mui/icons-material/MoreVert';
|
||||
import Card from '@mui/material/Card';
|
||||
import CardHeader from '@mui/material/CardHeader';
|
||||
import CardMedia from '@mui/material/CardMedia';
|
||||
import CardContent from '@mui/material/CardContent';
|
||||
import CardActions from '@mui/material/CardActions';
|
||||
import Collapse from '@mui/material/Collapse';
|
||||
import Table from '@mui/material/Table';
|
||||
import TableBody from '@mui/material/TableBody';
|
||||
import TableCell from '@mui/material/TableCell';
|
||||
import TableContainer from '@mui/material/TableContainer';
|
||||
import TableHead from '@mui/material/TableHead';
|
||||
import TableRow from '@mui/material/TableRow';
|
||||
|
||||
import PropagateLoader from "react-spinners/PropagateLoader";
|
||||
// import Markdown from 'react-markdown';
|
||||
import { MuiMarkdown as Markdown } from "mui-markdown";
|
||||
import './App.css';
|
||||
|
||||
import { MuiMarkdown } from "mui-markdown";
|
||||
import ReactMarkdown from 'react-markdown';
|
||||
import rehypeKatex from 'rehype-katex'
|
||||
import remarkMath from 'remark-math'
|
||||
import 'katex/dist/katex.min.css' // `rehype-katex` does not import the CSS for you
|
||||
|
||||
import './App.css';
|
||||
|
||||
import '@fontsource/roboto/300.css';
|
||||
import '@fontsource/roboto/400.css';
|
||||
import '@fontsource/roboto/500.css';
|
||||
import '@fontsource/roboto/700.css';
|
||||
|
||||
//const use_mui_markdown = true
|
||||
const use_mui_markdown = true
|
||||
|
||||
|
||||
const welcomeMarkdown = `
|
||||
# Welcome to Ketr-Chat.
|
||||
|
||||
@ -54,12 +61,12 @@ This LLM agent was built by James Ketrenos in order to provide answers to any qu
|
||||
|
||||
In addition to being a RAG enabled expert system, the LLM is configured with real-time access to weather, stocks, the current time, and can answer questions about the contents of a website.
|
||||
|
||||
Ask things like:
|
||||
* What are the headlines from CNBC?
|
||||
* What is the weather in Portland, OR?
|
||||
* What is James Ketrenos' work history?
|
||||
* What are the stock value of the most traded companies?
|
||||
* What programming languages has James used?
|
||||
You can ask things like: (or just click the text to submit the query)
|
||||
* <ChatQuery text="What are the headlines from CNBC?"/>
|
||||
* <ChatQuery text="What is the weather in Portland, OR?"/>
|
||||
* <ChatQuery text="What is James Ketrenos' work history?"/>
|
||||
* <ChatQuery text="What are the stock value of the most traded companies?"/>
|
||||
* <ChatQuery text="What programming languages has James used?"/>
|
||||
`;
|
||||
|
||||
const welcomeMessage = {
|
||||
@ -90,7 +97,9 @@ interface ControlsParams {
|
||||
toggleTool: (tool: Tool) => void,
|
||||
toggleRag: (tool: Tool) => void,
|
||||
setSystemPrompt: (prompt: string) => void,
|
||||
reset: (types: ("rags" | "tools" | "history" | "system-prompt")[], message: string) => Promise<void>
|
||||
reset: (types: ("rags" | "tools" | "history" | "system-prompt" | "message-history-length")[], message: string) => Promise<void>
|
||||
messageHistoryLength: number,
|
||||
setMessageHistoryLength: (messageHistoryLength: number) => void,
|
||||
};
|
||||
|
||||
type GPUInfo = {
|
||||
@ -106,7 +115,11 @@ type SystemInfo = {
|
||||
|
||||
type MessageMetadata = {
|
||||
rag: any,
|
||||
tools: any[]
|
||||
tools: any[],
|
||||
eval_count: number,
|
||||
eval_duration: number,
|
||||
prompt_eval_count: number,
|
||||
prompt_eval_duration: number
|
||||
};
|
||||
|
||||
type MessageData = {
|
||||
@ -167,7 +180,7 @@ const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo }> = ({ systemInfo
|
||||
return <div className="SystemInfo">{systemElements}</div>;
|
||||
};
|
||||
|
||||
const Controls = ({ tools, rags, systemPrompt, toggleTool, toggleRag, setSystemPrompt, reset, systemInfo }: ControlsParams) => {
|
||||
const Controls = ({ tools, rags, systemPrompt, toggleTool, toggleRag, messageHistoryLength, setMessageHistoryLength, setSystemPrompt, reset, systemInfo }: ControlsParams) => {
|
||||
const [editSystemPrompt, setEditSystemPrompt] = useState<string>(systemPrompt);
|
||||
|
||||
useEffect(() => {
|
||||
@ -222,6 +235,29 @@ const Controls = ({ tools, rags, systemPrompt, toggleTool, toggleRag, setSystemP
|
||||
</div>
|
||||
</AccordionActions>
|
||||
</Accordion>
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Typography component="span">Tunables</Typography>
|
||||
</AccordionSummary>
|
||||
<AccordionActions style={{ flexDirection: "column" }}>
|
||||
<TextField
|
||||
id="outlined-number"
|
||||
label="Message history"
|
||||
type="number"
|
||||
helperText="Only use this many messages as context. 0 = All. Keeping this low will reduce context growth and improve performance."
|
||||
value={messageHistoryLength}
|
||||
onChange={(e: any) => setMessageHistoryLength(e.target.value)}
|
||||
slotProps={{
|
||||
htmlInput: {
|
||||
min: 0
|
||||
},
|
||||
inputLabel: {
|
||||
shrink: true,
|
||||
},
|
||||
}}
|
||||
/>
|
||||
</AccordionActions>
|
||||
</Accordion>
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Typography component="span">Tools</Typography>
|
||||
@ -236,7 +272,7 @@ const Controls = ({ tools, rags, systemPrompt, toggleTool, toggleRag, setSystemP
|
||||
<Box key={index}>
|
||||
<Divider />
|
||||
<FormControlLabel control={<Switch checked={tool.enabled} />} onChange={() => toggle("tool", index)} label={tool?.function?.name} />
|
||||
<Typography>{tool?.function?.description}</Typography>
|
||||
<Typography sx={{ fontSize: "0.8rem", mb: 1 }}>{tool?.function?.description}</Typography>
|
||||
</Box>
|
||||
)
|
||||
}</FormGroup>
|
||||
@ -274,7 +310,7 @@ const Controls = ({ tools, rags, systemPrompt, toggleTool, toggleRag, setSystemP
|
||||
</AccordionActions>
|
||||
</Accordion>
|
||||
<Button onClick={() => { reset(["history"], "History cleared."); }}>Clear Chat History</Button>
|
||||
<Button onClick={() => { reset(["rags", "tools", "system-prompt"], "Default settings restored.") }}>Reset to defaults</Button>
|
||||
<Button onClick={() => { reset(["rags", "tools", "system-prompt", "message-history-length"], "Default settings restored.") }}>Reset to defaults</Button>
|
||||
</div>);
|
||||
}
|
||||
|
||||
@ -307,7 +343,8 @@ const ExpandMore = styled((props: ExpandMoreProps) => {
|
||||
}));
|
||||
|
||||
interface MessageInterface {
|
||||
message: MessageData
|
||||
message: MessageData,
|
||||
submitQuery: (text: string) => void
|
||||
};
|
||||
|
||||
interface MessageMetaInterface {
|
||||
@ -319,29 +356,74 @@ const MessageMeta = ({ metadata }: MessageMetaInterface) => {
|
||||
}
|
||||
|
||||
return (<>
|
||||
<Box sx={{ fontSize: "0.8rem", mb: 1 }}>
|
||||
Below is the LLM performance of this query. Note that if tools are called, the entire context is processed for each separate tool request by the LLM. This can dramatically increase the total time for a response.
|
||||
</Box>
|
||||
<TableContainer component={Card} className="PromptStats" sx={{ mb: 1 }}>
|
||||
<Table aria-label="prompt stats" size="small">
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableCell></TableCell>
|
||||
<TableCell align="right" >Tokens</TableCell>
|
||||
<TableCell align="right">Time (s)</TableCell>
|
||||
<TableCell align="right">TPS</TableCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
<TableRow key="prompt" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||
<TableCell component="th" scope="row">Prompt</TableCell>
|
||||
<TableCell align="right">{metadata.prompt_eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round(metadata.prompt_eval_duration / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round(metadata.prompt_eval_count * 10 ** 9 / metadata.prompt_eval_duration)}</TableCell>
|
||||
</TableRow>
|
||||
<TableRow key="response" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||
<TableCell component="th" scope="row">Response</TableCell>
|
||||
<TableCell align="right">{metadata.eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round(metadata.eval_duration / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round(metadata.eval_count * 10 ** 9 / metadata.eval_duration)}</TableCell>
|
||||
</TableRow>
|
||||
<TableRow key="total" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||
<TableCell component="th" scope="row">Total</TableCell>
|
||||
<TableCell align="right">{metadata.prompt_eval_count + metadata.eval_count}</TableCell>
|
||||
<TableCell align="right">{Math.round((metadata.prompt_eval_duration + metadata.eval_duration) / 10 ** 7) / 100}</TableCell>
|
||||
<TableCell align="right">{Math.round((metadata.prompt_eval_count + metadata.eval_count) * 10 ** 9 / (metadata.prompt_eval_duration + metadata.eval_duration))}</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
{
|
||||
metadata.tools !== undefined && metadata.tools.length !== 0 &&
|
||||
<Typography sx={{ marginBottom: 2 }}>
|
||||
<p>Tools queried:</p>
|
||||
{metadata.tools.map((tool: any, index: number) => <>
|
||||
<Divider />
|
||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mb: 0.5, mt: 0.5 }} key={index}>
|
||||
<div style={{ display: "flex", flexDirection: "column", paddingRight: "1rem", minWidth: "10rem" }}>
|
||||
<div style={{ whiteSpace: "nowrap" }}>{tool.tool}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Result Len: {JSON.stringify(tool.result).length}</div>
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Box sx={{ fontSize: "0.8rem" }}>
|
||||
Tools queried
|
||||
</Box>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
{metadata.tools.map((tool: any, index: number) => <Box key={index}>
|
||||
{index !== 0 && <Divider />}
|
||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mt: 0.5 }}>
|
||||
<div style={{ display: "flex", paddingRight: "1rem", minWidth: "10rem", whiteSpace: "nowrap" }}>
|
||||
{tool.tool}
|
||||
</div>
|
||||
<div style={{ display: "flex", padding: "3px", whiteSpace: "pre-wrap", flexGrow: 1, border: "1px solid #E0E0E0", maxHeight: "5rem", overflow: "auto" }}>{JSON.stringify(tool.result, null, 2)}</div>
|
||||
</Box>
|
||||
</>)}
|
||||
</Typography>
|
||||
</Box>)}
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
}
|
||||
{
|
||||
metadata.rag.name !== undefined &&
|
||||
<Typography sx={{ marginBottom: 2 }}>
|
||||
<p>Top RAG {metadata.rag.ids.length} matches from '{metadata.rag.name}' collection against embedding vector of {metadata.rag.query_embedding.length} dimensions:</p>
|
||||
{metadata.rag.ids.map((id: number, index: number) => <>
|
||||
<Divider />
|
||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "row", mb: 0.5, mt: 0.5 }} key={index}>
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Box sx={{ fontSize: "0.8rem" }}>
|
||||
Top RAG {metadata.rag.ids.length} matches from '{metadata.rag.name}' collection against embedding vector of {metadata.rag.query_embedding.length} dimensions
|
||||
</Box>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
{metadata.rag.ids.map((id: number, index: number) => <Box key={index}>
|
||||
{index !== 0 && <Divider />}
|
||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "row", mb: 0.5, mt: 0.5 }}>
|
||||
<div style={{ display: "flex", flexDirection: "column", paddingRight: "1rem", minWidth: "10rem" }}>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Doc ID: {metadata.rag.ids[index]}</div>
|
||||
<div style={{ whiteSpace: "nowrap" }}>Similarity: {Math.round(metadata.rag.distances[index] * 100) / 100}</div>
|
||||
@ -350,15 +432,25 @@ const MessageMeta = ({ metadata }: MessageMetaInterface) => {
|
||||
</div>
|
||||
<div style={{ display: "flex", padding: "3px", flexGrow: 1, border: "1px solid #E0E0E0", maxHeight: "5rem", overflow: "auto" }}>{metadata.rag.documents[index]}</div>
|
||||
</Box>
|
||||
</>
|
||||
</Box>
|
||||
)}
|
||||
</Typography >
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const Message = ({ message }: MessageInterface) => {
|
||||
interface ChatQueryInterface {
|
||||
text: string,
|
||||
submitQuery: (text: string) => void
|
||||
}
|
||||
|
||||
const ChatQuery = ({ text, submitQuery }: ChatQueryInterface) => {
|
||||
return (<Button onClick={(e: any) => { console.log(text); submitQuery(text); }}>{text}</Button>);
|
||||
}
|
||||
|
||||
const Message = ({ message, submitQuery }: MessageInterface) => {
|
||||
const [expanded, setExpanded] = React.useState(false);
|
||||
|
||||
const handleExpandClick = () => {
|
||||
@ -371,7 +463,15 @@ const Message = ({ message }: MessageInterface) => {
|
||||
<Card sx={{ flexGrow: 1, pb: message.metadata ? 0 : "8px" }} className={(message.role === 'user' ? 'user-message' : 'assistant-message')}>
|
||||
<CardContent>
|
||||
{message.role === 'assistant' ?
|
||||
<Markdown children={formattedContent} />
|
||||
use_mui_markdown ? <MuiMarkdown children={formattedContent} overrides={{
|
||||
ChatQuery: {
|
||||
component: ChatQuery,
|
||||
props: {
|
||||
submitQuery
|
||||
}, // Optional: pass default props if needed
|
||||
},
|
||||
}} /> : <ReactMarkdown remarkPlugins={[remarkMath]}
|
||||
rehypePlugins={[rehypeKatex]} children={formattedContent} />
|
||||
:
|
||||
<Typography variant="body2" sx={{ color: 'text.secondary' }}>
|
||||
{message.content}
|
||||
@ -423,6 +523,37 @@ const App = () => {
|
||||
const [serverSystemPrompt, setServerSystemPrompt] = useState<string>("");
|
||||
const [systemInfo, setSystemInfo] = useState<SystemInfo | undefined>(undefined);
|
||||
const [contextStatus, setContextStatus] = useState<ContextStatus>({ context_used: 0, max_context: 0 });
|
||||
const [contextWarningShown, setContextWarningShown] = useState<boolean>(false);
|
||||
const [contextUsedPercentage, setContextUsedPercentage] = useState<number>(0);
|
||||
const [lastEvalTPS, setLastEvalTPS] = useState<number>(35);
|
||||
const [lastPromptTPS, setLastPromptTPS] = useState<number>(430);
|
||||
const [countdown, setCountdown] = useState<number>(0);
|
||||
const [messageHistoryLength, setMessageHistoryLength] = useState<number>(0);
|
||||
|
||||
const timerRef = useRef<any>(null);
|
||||
|
||||
const startCountdown = (seconds: number) => {
|
||||
if (timerRef.current) clearInterval(timerRef.current);
|
||||
setCountdown(seconds);
|
||||
timerRef.current = setInterval(() => {
|
||||
setCountdown((prev) => {
|
||||
if (prev <= 1) {
|
||||
clearInterval(timerRef.current);
|
||||
timerRef.current = null;
|
||||
return 0;
|
||||
}
|
||||
return prev - 1;
|
||||
});
|
||||
}, 1000);
|
||||
};
|
||||
|
||||
const stopCountdown = () => {
|
||||
if (timerRef.current) {
|
||||
clearInterval(timerRef.current);
|
||||
timerRef.current = null;
|
||||
setCountdown(0);
|
||||
}
|
||||
};
|
||||
|
||||
// Scroll to bottom of conversation when conversation updates
|
||||
useEffect(() => {
|
||||
@ -460,6 +591,7 @@ const App = () => {
|
||||
});
|
||||
}, [systemInfo, setSystemInfo, loc, setSnack, sessionId])
|
||||
|
||||
// Update the context status
|
||||
const updateContextStatus = useCallback(() => {
|
||||
fetch(getConnectionBase(loc) + `/api/context-status/${sessionId}`, {
|
||||
method: 'GET',
|
||||
@ -469,7 +601,6 @@ const App = () => {
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log(`Session id: ${sessionId} -- history returned from server with ${data.length} entries`)
|
||||
setContextStatus(data);
|
||||
})
|
||||
.catch(error => {
|
||||
@ -538,9 +669,9 @@ const App = () => {
|
||||
if (serverSystemPrompt !== "" || sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const fetchSystemPrompt = async () => {
|
||||
const fetchTunables = async () => {
|
||||
// Make the fetch request with proper headers
|
||||
const response = await fetch(getConnectionBase(loc) + `/api/system-prompt/${sessionId}`, {
|
||||
const response = await fetch(getConnectionBase(loc) + `/api/tunables/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@ -549,12 +680,12 @@ const App = () => {
|
||||
});
|
||||
const data = await response.json();
|
||||
const serverSystemPrompt = data["system-prompt"].trim();
|
||||
console.log("System prompt initialized to:", serverSystemPrompt);
|
||||
setServerSystemPrompt(serverSystemPrompt);
|
||||
setSystemPrompt(serverSystemPrompt);
|
||||
setMessageHistoryLength(data["message-history-length"]);
|
||||
}
|
||||
|
||||
fetchSystemPrompt();
|
||||
fetchTunables();
|
||||
}, [sessionId, serverSystemPrompt, setServerSystemPrompt, loc]);
|
||||
|
||||
// If the tools have not been set, fetch them from the server
|
||||
@ -615,6 +746,21 @@ const App = () => {
|
||||
fetchRags();
|
||||
}, [sessionId, rags, setRags, setSnack, loc]);
|
||||
|
||||
// If context status changes, show a warning if necessary. If it drops
|
||||
// back below the threshold, clear the warning trigger
|
||||
useEffect(() => {
|
||||
const context_used_percentage = Math.round(100 * contextStatus.context_used / contextStatus.max_context);
|
||||
if (context_used_percentage >= 90 && !contextWarningShown) {
|
||||
setSnack(`${context_used_percentage}% of context used. You may wish to start a new chat.`, "warning");
|
||||
setContextWarningShown(true);
|
||||
}
|
||||
if (context_used_percentage < 90 && contextWarningShown) {
|
||||
setContextWarningShown(false);
|
||||
}
|
||||
setContextUsedPercentage(context_used_percentage)
|
||||
}, [contextStatus, setContextWarningShown, contextWarningShown, setContextUsedPercentage, setSnack]);
|
||||
|
||||
|
||||
const toggleRag = async (tool: Tool) => {
|
||||
tool.enabled = !tool.enabled
|
||||
try {
|
||||
@ -665,7 +811,7 @@ const App = () => {
|
||||
}
|
||||
const sendSystemPrompt = async (prompt: string) => {
|
||||
try {
|
||||
const response = await fetch(getConnectionBase(loc) + `/api/system-prompt/${sessionId}`, {
|
||||
const response = await fetch(getConnectionBase(loc) + `/api/tunables/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@ -691,7 +837,38 @@ const App = () => {
|
||||
|
||||
}, [systemPrompt, setServerSystemPrompt, serverSystemPrompt, loc, sessionId, setSnack]);
|
||||
|
||||
const reset = async (types: ("rags" | "tools" | "history" | "system-prompt")[], message: string = "Update successful.") => {
|
||||
useEffect(() => {
|
||||
if (sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
const sendMessageHistoryLength = async (length: number) => {
|
||||
try {
|
||||
const response = await fetch(getConnectionBase(loc) + `/api/tunables/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "message-history-length": length }),
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
const newLength = data["message-history-length"];
|
||||
if (newLength !== messageHistoryLength) {
|
||||
setMessageHistoryLength(newLength);
|
||||
setSnack("Message history length updated", "success");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("Message history length update failed", "error");
|
||||
}
|
||||
};
|
||||
|
||||
sendMessageHistoryLength(messageHistoryLength);
|
||||
|
||||
}, [messageHistoryLength, setMessageHistoryLength, loc, sessionId, setSnack]);
|
||||
|
||||
const reset = async (types: ("rags" | "tools" | "history" | "system-prompt" | "message-history-length")[], message: string = "Update successful.") => {
|
||||
try {
|
||||
const response = await fetch(getConnectionBase(loc) + `/api/reset/${sessionId}`, {
|
||||
method: 'PUT',
|
||||
@ -751,15 +928,20 @@ const App = () => {
|
||||
|
||||
const drawer = (
|
||||
<>
|
||||
{sessionId !== undefined && systemInfo !== undefined && <Controls {...{ tools, rags, reset, systemPrompt, toggleTool, toggleRag, setSystemPrompt, systemInfo }} />}
|
||||
{sessionId !== undefined && systemInfo !== undefined &&
|
||||
<Controls {...{ messageHistoryLength, setMessageHistoryLength, tools, rags, reset, systemPrompt, toggleTool, toggleRag, setSystemPrompt, systemInfo }} />}
|
||||
</>
|
||||
);
|
||||
|
||||
const submitQuery = (text: string) => {
|
||||
sendQuery(text);
|
||||
}
|
||||
|
||||
const handleKeyPress = (event: any) => {
|
||||
if (event.key === 'Enter') {
|
||||
switch (event.target.id) {
|
||||
case 'QueryInput':
|
||||
sendQuery();
|
||||
sendQuery(query);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -769,11 +951,9 @@ const App = () => {
|
||||
reset(["history"], "New chat started.");
|
||||
}
|
||||
|
||||
const sendQuery = async () => {
|
||||
const sendQuery = async (query: string) => {
|
||||
if (!query.trim()) return;
|
||||
|
||||
setSnack("Query sent", "info");
|
||||
|
||||
const userMessage = [{ role: 'user', content: query }];
|
||||
|
||||
// Add user message to conversation
|
||||
@ -811,6 +991,12 @@ const App = () => {
|
||||
body: JSON.stringify({ role: 'user', content: query.trim() }),
|
||||
});
|
||||
|
||||
// We'll guess that the response will be around 500 tokens...
|
||||
const token_guess = 500;
|
||||
const estimate = Math.round(token_guess / lastEvalTPS + contextStatus.context_used / lastPromptTPS);
|
||||
setSnack(`Query sent. Response estimated in ${estimate}s.`, "info");
|
||||
startCountdown(Math.round(estimate));
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
@ -860,6 +1046,11 @@ const App = () => {
|
||||
...prev.filter(msg => msg.id !== processingId),
|
||||
update.message
|
||||
]);
|
||||
const metadata = update.message.metadata;
|
||||
const evalTPS = metadata.eval_count * 10 ** 9 / metadata.eval_duration;
|
||||
const promptTPS = metadata.prompt_eval_count * 10 ** 9 / metadata.prompt_eval_duration;
|
||||
setLastEvalTPS(evalTPS ? evalTPS : 35);
|
||||
setLastPromptTPS(promptTPS ? promptTPS : 35);
|
||||
updateContextStatus();
|
||||
} else if (update.status === 'error') {
|
||||
// Show error
|
||||
@ -891,6 +1082,7 @@ const App = () => {
|
||||
}
|
||||
}
|
||||
|
||||
stopCountdown();
|
||||
setProcessing(false);
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
@ -900,6 +1092,7 @@ const App = () => {
|
||||
{ role: 'assistant', type: 'error', content: `Error: ${error}` }
|
||||
]);
|
||||
setProcessing(false);
|
||||
stopCountdown();
|
||||
}
|
||||
};
|
||||
|
||||
@ -946,7 +1139,7 @@ const App = () => {
|
||||
onClick={onNew}
|
||||
sx={{ mr: 2 }}
|
||||
>
|
||||
<AddIcon />
|
||||
<ResetIcon />
|
||||
</IconButton>
|
||||
</Tooltip>
|
||||
<Typography variant="h6" noWrap component="div">
|
||||
@ -985,17 +1178,39 @@ const App = () => {
|
||||
</Box>
|
||||
<Box component="main" sx={{ flexGrow: 1, overflow: 'auto' }} className="ChatBox" ref={conversationRef}>
|
||||
<Box className="Conversation" sx={{ flexGrow: 2, p: 1 }}>
|
||||
{conversation.map((message, index) => <Message key={index} message={message} />)}
|
||||
<div style={{ justifyContent: "center", display: "flex", paddingBottom: "0.5rem" }}>
|
||||
{conversation.map((message, index) => <Message key={index} submitQuery={submitQuery} message={message} />)}
|
||||
<Box sx={{
|
||||
display: "flex",
|
||||
flexDirection: "column",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
mb: 1
|
||||
}}>
|
||||
<PropagateLoader
|
||||
size="10px"
|
||||
loading={processing}
|
||||
aria-label="Loading Spinner"
|
||||
data-testid="loader"
|
||||
/>
|
||||
</div>
|
||||
{processing === true && countdown > 0 && (
|
||||
<Box
|
||||
sx={{
|
||||
pt: 1,
|
||||
fontSize: "0.7rem",
|
||||
color: "darkgrey"
|
||||
}}
|
||||
>Estimated response time: {countdown}s</Box>
|
||||
)}
|
||||
</Box>
|
||||
<Box sx={{ ml: "0.25rem", fontSize: "0.6rem", color: "darkgrey", display: "flex", flexDirection: "row", gap: 1, mt: "auto" }}>
|
||||
Context used: {contextUsedPercentage}% {contextStatus.context_used}/{contextStatus.max_context}
|
||||
{
|
||||
contextUsedPercentage >= 90 ? <Typography sx={{ fontSize: "0.6rem", color: "red" }}>WARNING: Context almost exhausted. You should start a new chat.</Typography>
|
||||
: (contextUsedPercentage >= 50 ? <Typography sx={{ fontSize: "0.6rem", color: "orange" }}>NOTE: Context is getting long. Queries will be slower, and the LLM may stop issuing tool calls.</Typography>
|
||||
: <></>)
|
||||
}
|
||||
</Box>
|
||||
</Box>
|
||||
{/* <Box sx={{ mt: "-1rem", ml: "0.25rem", fontSize: "0.6rem", color: "darkgrey", position: "sticky" }}>Context used: {Math.round(100 * contextStatus.context_used / contextStatus.max_context)}% {contextStatus.context_used}/{contextStatus.max_context}</Box> */}
|
||||
<Box className="Query" sx={{ display: "flex", flexDirection: "row", p: 1 }}>
|
||||
<TextField
|
||||
variant="outlined"
|
||||
@ -1011,7 +1226,7 @@ const App = () => {
|
||||
/>
|
||||
<AccordionActions>
|
||||
<Tooltip title="Send">
|
||||
<Button sx={{ m: 0 }} variant="contained" onClick={sendQuery}><SendIcon /></Button>
|
||||
<Button sx={{ m: 0 }} variant="contained" onClick={() => { sendQuery(query); }}><SendIcon /></Button>
|
||||
</Tooltip>
|
||||
</AccordionActions>
|
||||
</Box>
|
||||
|
@ -161,6 +161,7 @@ When answering queries, follow these steps:
|
||||
6. If [{context_tag}] and tool outputs contain conflicting information, prefer the tool outputs as they likely represent more current data
|
||||
|
||||
Always use tools and [{context_tag}] when possible. Be concise, and never make up information. If you do not know the answer, say so.
|
||||
|
||||
""".strip()
|
||||
|
||||
tool_log = []
|
||||
@ -387,6 +388,12 @@ class WebServer:
|
||||
context["llm_history"] = []
|
||||
context["user_history"] = []
|
||||
response["history"] = []
|
||||
context["context_tokens"] = round(len(str(context["system"])) * 3 / 4) # Estimate context usage
|
||||
response["context_used"] = context["context_tokens"]
|
||||
case "message-history-length":
|
||||
context["message_history_length"] = 5
|
||||
response["message-history-length"] = 5
|
||||
|
||||
if not response:
|
||||
return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system-prompt}"})
|
||||
else:
|
||||
@ -396,25 +403,40 @@ class WebServer:
|
||||
except:
|
||||
return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system-prompt}"})
|
||||
|
||||
@self.app.put('/api/system-prompt/{context_id}')
|
||||
async def put_system_prompt(context_id: str, request: Request):
|
||||
@self.app.put('/api/tunables/{context_id}')
|
||||
async def put_tunables(context_id: str, request: Request):
|
||||
if not is_valid_uuid(context_id):
|
||||
logging.warning(f"Invalid context_id: {context_id}")
|
||||
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
||||
context = self.upsert_context(context_id)
|
||||
data = await request.json()
|
||||
system_prompt = data["system-prompt"].strip()
|
||||
for k in data.keys():
|
||||
match k:
|
||||
case "system-prompt":
|
||||
system_prompt = data[k].strip()
|
||||
if not system_prompt:
|
||||
return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
|
||||
context["system"] = [{"role": "system", "content": system_prompt}]
|
||||
self.save_context(context_id)
|
||||
return JSONResponse({ "system-prompt": system_prompt })
|
||||
case "message-history-length":
|
||||
value = max(0, int(data[k]))
|
||||
context["message_history_length"] = value
|
||||
self.save_context(context_id)
|
||||
return JSONResponse({ "message-history-length": value })
|
||||
case _:
|
||||
return JSONResponse({ "error": f"Unrecognized tunable {k}"}, 404)
|
||||
|
||||
@self.app.get('/api/system-prompt/{context_id}')
|
||||
async def get_system_prompt(context_id: str):
|
||||
@self.app.get('/api/tunables/{context_id}')
|
||||
async def get_tunables(context_id: str):
|
||||
if not is_valid_uuid(context_id):
|
||||
logging.warning(f"Invalid context_id: {context_id}")
|
||||
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
||||
context = self.upsert_context(context_id)
|
||||
system_prompt = context["system"][0]["content"];
|
||||
return JSONResponse({ "system-prompt": system_prompt })
|
||||
return JSONResponse({
|
||||
"system-prompt": context["system"][0]["content"],
|
||||
"message-history-length": context["message_history_length"]
|
||||
})
|
||||
|
||||
@self.app.get('/api/system-info/{context_id}')
|
||||
async def get_system_info(context_id: str):
|
||||
@ -510,12 +532,8 @@ class WebServer:
|
||||
if not is_valid_uuid(context_id):
|
||||
logging.warning(f"Invalid context_id: {context_id}")
|
||||
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
||||
context_used = 0
|
||||
context = self.upsert_context(context_id)
|
||||
# TODO: Switch this to use the tokenizer values instead of 75% of character length
|
||||
for message in context["llm_history"]:
|
||||
context_used += round((len(message["role"]) + len(message["content"])) * 3 / 4)
|
||||
return JSONResponse({"context_used": context_used, "max_context": defines.max_context})
|
||||
return JSONResponse({"context_used": context["context_tokens"], "max_context": defines.max_context})
|
||||
|
||||
@self.app.get('/api/health')
|
||||
async def health_check():
|
||||
@ -530,6 +548,8 @@ class WebServer:
|
||||
self.logging.info(f"Serve index.html for {path}")
|
||||
return FileResponse('/opt/airc/src/ketr-chat/build/index.html')
|
||||
|
||||
import requests
|
||||
|
||||
def save_context(self, session_id):
|
||||
"""
|
||||
Serialize a Python dictionary to a file in the sessions directory.
|
||||
@ -581,18 +601,25 @@ class WebServer:
|
||||
def create_context(self, context_id = None):
|
||||
if not context_id:
|
||||
context_id = str(uuid.uuid4())
|
||||
system_context = [{"role": "system", "content": system_message}];
|
||||
context = {
|
||||
"id": context_id,
|
||||
"system": [{"role": "system", "content": system_message}],
|
||||
"system": system_context,
|
||||
"llm_history": [],
|
||||
"user_history": [],
|
||||
"tools": default_tools(tools),
|
||||
"rags": rags.copy()
|
||||
"rags": rags.copy(),
|
||||
"context_tokens": round(len(str(system_context)) * 3 / 4), # Estimate context usage
|
||||
"message_history_length": 5 # Number of messages to supply in context
|
||||
}
|
||||
logging.info(f"{context_id} created and added to sessions.")
|
||||
self.contexts[context_id] = context
|
||||
return context
|
||||
|
||||
def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
|
||||
ctx = round(context + len(str(messages)) * 3 / 4)
|
||||
return max(defines.max_context, min(2048, ctx + ctx_buffer))
|
||||
|
||||
def upsert_context(self, context_id):
|
||||
if not context_id:
|
||||
logging.warning("No context ID provided. Creating a new context.")
|
||||
@ -619,7 +646,11 @@ class WebServer:
|
||||
user_history = context["user_history"]
|
||||
metadata = {
|
||||
"rag": {},
|
||||
"tools": []
|
||||
"tools": [],
|
||||
"eval_count": 0,
|
||||
"eval_duration": 0,
|
||||
"prompt_eval_count": 0,
|
||||
"prompt_eval_duration": 0,
|
||||
}
|
||||
rag_docs = []
|
||||
for rag in context["rags"]:
|
||||
@ -643,13 +674,24 @@ class WebServer:
|
||||
llm_history.append({"role": "user", "content": preamble + content})
|
||||
user_history.append({"role": "user", "content": content})
|
||||
|
||||
if context["message_history_length"]:
|
||||
messages = context["system"] + llm_history[-context["message_history_length"]:]
|
||||
else:
|
||||
messages = context["system"] + llm_history
|
||||
|
||||
try:
|
||||
yield {"status": "processing", "message": "Processing request..."}
|
||||
|
||||
# Estimate token length of new messages
|
||||
ctx_size = self.get_optimal_ctx_size(context["context_tokens"], messages=llm_history[-1]["content"])
|
||||
|
||||
# Use the async generator in an async for loop
|
||||
response = self.client.chat(model=self.model, messages=messages, tools=llm_tools(context["tools"]), options={ 'num_ctx': defines.max_context })
|
||||
response = self.client.chat(model=self.model, messages=messages, tools=llm_tools(context["tools"]), options={ 'num_ctx': ctx_size })
|
||||
metadata["eval_count"] += response['eval_count']
|
||||
metadata["eval_duration"] += response['eval_duration']
|
||||
metadata["prompt_eval_count"] += response['prompt_eval_count']
|
||||
metadata["prompt_eval_duration"] += response['prompt_eval_duration']
|
||||
context["context_tokens"] = response['prompt_eval_count'] + response['eval_count']
|
||||
|
||||
tools_used = []
|
||||
|
||||
@ -680,6 +722,8 @@ class WebServer:
|
||||
{'function': {'name': tc['function']['name'], 'arguments': tc['function']['arguments']}}
|
||||
for tc in message['tool_calls']
|
||||
]
|
||||
|
||||
pre_add_index = len(messages)
|
||||
messages.append(message_dict)
|
||||
|
||||
if isinstance(tool_result, list):
|
||||
@ -690,7 +734,15 @@ class WebServer:
|
||||
metadata["tools"] = tools_used
|
||||
|
||||
yield {"status": "processing", "message": "Generating final response..."}
|
||||
response = self.client.chat(model=self.model, messages=messages, stream=False, options={ 'num_ctx': defines.max_context })
|
||||
# Estimate token length of new messages
|
||||
ctx_size = self.get_optimal_ctx_size(context["context_tokens"], messages=messages[pre_add_index:])
|
||||
# Decrease creativity when processing tool call requests
|
||||
response = self.client.chat(model=self.model, messages=messages, stream=False, options={ 'num_ctx': ctx_size }) #, "temperature": 0.5 })
|
||||
metadata["eval_count"] += response['eval_count']
|
||||
metadata["eval_duration"] += response['eval_duration']
|
||||
metadata["prompt_eval_count"] += response['prompt_eval_count']
|
||||
metadata["prompt_eval_duration"] += response['prompt_eval_duration']
|
||||
context["context_tokens"] = response['prompt_eval_count'] + response['eval_count']
|
||||
|
||||
reply = response['message']['content']
|
||||
final_message = {"role": "assistant", "content": reply }
|
||||
|
Loading…
x
Reference in New Issue
Block a user