backstory/frontend/src/Conversation.tsx

import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback } from 'react';
import TextField from '@mui/material/TextField';
import Typography from '@mui/material/Typography';
import Tooltip from '@mui/material/Tooltip';
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import SendIcon from '@mui/icons-material/Send';
import PropagateLoader from "react-spinners/PropagateLoader";
import { Message, MessageList, MessageData } from './Message';
import { SeverityType } from './Snack';
import { ContextStatus } from './ContextStatus';
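
// Conversation renders a streaming chat session: it loads history for a
// session, streams assistant responses as newline-delimited JSON updates,
// tracks context-window usage, and shows an estimated-response-time countdown.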
const loadingMessage: MessageData = { role: 'assistant', content: 'Establishing connection with server...' };
type ConversationMode = 'chat' | 'fact-check' | 'system';
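
// Methods exposed to parent components via ref.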
interface ConversationHandle {
    submitQuery: (query: string) => void;
}

interface ConversationProps {
    type: ConversationMode;
    prompt: string;
    connectionBase: string;
    sessionId?: string;
    setSnack: (message: string, severity: SeverityType) => void;
    defaultPrompts?: React.ReactElement[];
    preamble?: MessageList;
    hideDefaultPrompts?: boolean;
}
const Conversation = forwardRef<ConversationHandle, ConversationProps>(({ prompt, type, preamble, hideDefaultPrompts, defaultPrompts, sessionId, setSnack, connectionBase }: ConversationProps, ref) => {
    const [query, setQuery] = useState<string>("");
    const [contextUsedPercentage, setContextUsedPercentage] = useState<number>(0);
    const [processing, setProcessing] = useState<boolean>(false);
    const [countdown, setCountdown] = useState<number>(0);
    const [conversation, setConversation] = useState<MessageList>([]);
    const timerRef = useRef<ReturnType<typeof setInterval> | null>(null);
    // Rolling tokens-per-second estimates used to predict response time,
    // seeded with rough defaults until real metadata arrives.
    const [lastEvalTPS, setLastEvalTPS] = useState<number>(35);
    const [lastPromptTPS, setLastPromptTPS] = useState<number>(430);
    const [contextStatus, setContextStatus] = useState<ContextStatus>({ context_used: 0, max_context: 0 });
    const [contextWarningShown, setContextWarningShown] = useState<boolean>(false);
    const [noInteractions, setNoInteractions] = useState<boolean>(true);
    // Update the context status
    const updateContextStatus = useCallback(() => {
        const fetchContextStatus = async () => {
            try {
                const response = await fetch(connectionBase + `/api/context-status/${sessionId}`, {
                    method: 'GET',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                });
                if (!response.ok) {
                    throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
                }
                const data = await response.json();
                setContextStatus(data);
            } catch (error) {
                console.error('Error getting context status:', error);
                setSnack("Unable to obtain context status.", "error");
            }
        };
        fetchContextStatus();
    }, [setContextStatus, connectionBase, setSnack, sessionId]);
    // Set the initial chat history to "loading" or the welcome message if loaded.
    useEffect(() => {
        if (sessionId === undefined) {
            setConversation([loadingMessage]);
            return;
        }
        const fetchHistory = async () => {
            try {
                const response = await fetch(connectionBase + `/api/history/${sessionId}`, {
                    method: 'GET',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                });
                if (!response.ok) {
                    throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
                }
                const data = await response.json();
                console.log(`Session id: ${sessionId} -- history returned from server with ${data.length} entries`);
                if (data.length === 0) {
                    setConversation(preamble || []);
                    setNoInteractions(true);
                } else {
                    setConversation(data);
                    setNoInteractions(false);
                }
                updateContextStatus();
            } catch (error) {
                console.error('Error fetching chat history:', error);
                setSnack("Unable to obtain chat history.", "error");
            }
        };
        // sessionId is known to be defined here thanks to the early return above.
        fetchHistory();
    }, [sessionId, setConversation, updateContextStatus, connectionBase, setSnack, preamble]);
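
    // Scroll helpers: the UI only auto-scrolls when the user is already at
    // the bottom, so it never interrupts manual scrollback.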
    const isScrolledToBottom = useCallback(() => {
        // Current vertical scroll position
        const scrollTop = window.scrollY || document.documentElement.scrollTop;
        // Total height of the page content
        const scrollHeight = document.documentElement.scrollHeight;
        // Height of the visible window
        const clientHeight = document.documentElement.clientHeight;
        // We're at the bottom if within a small buffer (16px) of the end
        return scrollTop + clientHeight >= scrollHeight - 16;
    }, []);
    const scrollToBottom = useCallback(() => {
        console.log("Scroll to bottom");
        window.scrollTo({
            top: document.body.scrollHeight,
        });
    }, []);
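
    // Countdown ticks once per second while a response is pending; when it
    // expires, re-pin the view to the bottom if the user was already there.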
    const startCountdown = (seconds: number) => {
        if (timerRef.current) clearInterval(timerRef.current);
        setCountdown(seconds);
        timerRef.current = setInterval(() => {
            setCountdown((prev) => {
                if (prev <= 1) {
                    clearInterval(timerRef.current!);
                    timerRef.current = null;
                    if (isScrolledToBottom()) {
                        setTimeout(() => {
                            scrollToBottom();
                        }, 50);
                    }
                    return 0;
                }
                return prev - 1;
            });
        }, 1000);
    };

    const stopCountdown = () => {
        if (timerRef.current) {
            clearInterval(timerRef.current);
            timerRef.current = null;
            setCountdown(0);
        }
    };
    const handleKeyPress = (event: React.KeyboardEvent) => {
        if (event.key === 'Enter') {
            switch ((event.target as HTMLElement).id) {
                case 'QueryInput':
                    sendQuery(query);
                    break;
            }
        }
    };
    useImperativeHandle(ref, () => ({
        submitQuery: (query: string) => {
            sendQuery(query);
        }
    }));
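
    // Plain-function wrapper passed down to <Message> so rendered prompts
    // can resubmit queries.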
    const submitQuery = (query: string) => {
        sendQuery(query);
    };
    // If context status changes, show a warning if necessary. If it drops
    // back below the threshold, clear the warning trigger.
    useEffect(() => {
        // Guard against division by zero before the first status fetch.
        const context_used_percentage = contextStatus.max_context > 0
            ? Math.round(100 * contextStatus.context_used / contextStatus.max_context)
            : 0;
        if (context_used_percentage >= 90 && !contextWarningShown) {
            setSnack(`${context_used_percentage}% of context used. You may wish to start a new chat.`, "warning");
            setContextWarningShown(true);
        }
        if (context_used_percentage < 90 && contextWarningShown) {
            setContextWarningShown(false);
        }
        setContextUsedPercentage(context_used_percentage);
    }, [contextStatus, setContextWarningShown, contextWarningShown, setContextUsedPercentage, setSnack]);
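
    // Send the query and stream the response. The server replies with
    // newline-delimited JSON updates of the form { status, message }, where
    // status is 'processing', 'done', or 'error'.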
    const sendQuery = async (query: string) => {
        setNoInteractions(false);
        if (!query.trim()) return;
        //setTab(0);
        const userMessage: MessageData[] = [{ role: 'user', content: query }];
        let scrolledToBottom: boolean;
        // Add user message to conversation
        const newConversation: MessageList = [
            ...conversation,
            ...userMessage
        ];
        setConversation(newConversation);
        scrollToBottom();
        // Clear input
        setQuery('');
        try {
            scrolledToBottom = isScrolledToBottom();
            setProcessing(true);
            // Create a unique ID for the processing message
            const processingId = Date.now().toString();
            // Add initial processing message
            setConversation(prev => [
                ...prev,
                { role: 'assistant', content: 'Processing request...', id: processingId, isProcessing: true }
            ]);
            if (scrolledToBottom) {
                setTimeout(() => { scrollToBottom(); }, 50);
            }
            // Make the fetch request with proper headers
            const response = await fetch(connectionBase + `/api/chat/${sessionId}`, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                    'Accept': 'application/json',
                },
                body: JSON.stringify({ role: 'user', content: query.trim() }),
            });
            // Estimate the response time: assume ~500 generated tokens at the
            // last observed eval rate, plus prompt processing of the current
            // context at the last observed prompt rate.
            const token_guess = 500;
            const estimate = Math.round(token_guess / lastEvalTPS + contextStatus.context_used / lastPromptTPS);
            scrolledToBottom = isScrolledToBottom();
            setSnack(`Query sent. Response estimated in ${estimate}s.`, "info");
            startCountdown(estimate);
            if (scrolledToBottom) {
                setTimeout(() => { scrollToBottom(); }, 50);
            }
            if (!response.ok) {
                throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
            }
            if (!response.body) {
                throw new Error('Response body is null');
            }
            // Set up stream processing with explicit chunking
            const reader = response.body.getReader();
            const decoder = new TextDecoder();
            let buffer = '';
            while (true) {
                const { done, value } = await reader.read();
                if (done) {
                    break;
                }
                const chunk = decoder.decode(value, { stream: true });
                // Process each complete line immediately
                buffer += chunk;
                const lines = buffer.split('\n');
                buffer = lines.pop() || ''; // Keep incomplete line in buffer
                for (const line of lines) {
                    if (!line.trim()) continue;
                    try {
                        const update = JSON.parse(line);
                        // Force an immediate state update based on the message type
                        if (update.status === 'processing') {
                            scrolledToBottom = isScrolledToBottom();
                            // Update processing message with immediate re-render
                            setConversation(prev => prev.map(msg =>
                                msg.id === processingId
                                    ? { ...msg, content: update.message }
                                    : msg
                            ));
                            if (scrolledToBottom) {
                                setTimeout(() => { scrollToBottom(); }, 50);
                            }
                            // Add a small delay to ensure React has time to update the UI
                            await new Promise(resolve => setTimeout(resolve, 0));
                        } else if (update.status === 'done') {
                            // Replace processing message with final result
                            scrolledToBottom = isScrolledToBottom();
                            setConversation(prev => [
                                ...prev.filter(msg => msg.id !== processingId),
                                update.message
                            ]);
                            // Durations in the metadata are reported in
                            // nanoseconds, so scale by 1e9 to get tokens/sec.
                            const metadata = update.message.metadata;
                            const evalTPS = metadata.eval_count * 10 ** 9 / metadata.eval_duration;
                            const promptTPS = metadata.prompt_eval_count * 10 ** 9 / metadata.prompt_eval_duration;
                            // Fall back to the initial seed estimates if metadata is missing.
                            setLastEvalTPS(evalTPS ? evalTPS : 35);
                            setLastPromptTPS(promptTPS ? promptTPS : 430);
                            updateContextStatus();
                            if (scrolledToBottom) {
                                setTimeout(() => { scrollToBottom(); }, 50);
                            }
                        } else if (update.status === 'error') {
                            // Show error
                            scrolledToBottom = isScrolledToBottom();
                            setConversation(prev => [
                                ...prev.filter(msg => msg.id !== processingId),
                                { role: 'assistant', type: 'error', content: update.message }
                            ]);
                            if (scrolledToBottom) {
                                setTimeout(() => { scrollToBottom(); }, 50);
                            }
                        }
                    } catch (e) {
                        setSnack("Error processing query", "error");
                        console.error('Error parsing JSON:', e, line);
                    }
                }
            }
            // Process any remaining buffer content
            if (buffer.trim()) {
                try {
                    const update = JSON.parse(buffer);
                    if (update.status === 'done') {
                        scrolledToBottom = isScrolledToBottom();
                        setConversation(prev => [
                            ...prev.filter(msg => msg.id !== processingId),
                            update.message
                        ]);
                        if (scrolledToBottom) {
                            setTimeout(() => { scrollToBottom(); }, 500);
                        }
                    }
                } catch (e) {
                    setSnack("Error processing query", "error");
                }
            }
            scrolledToBottom = isScrolledToBottom();
            stopCountdown();
            setProcessing(false);
            if (scrolledToBottom) {
                setTimeout(() => { scrollToBottom(); }, 50);
            }
        } catch (error) {
            console.error('Fetch error:', error);
            setSnack("Unable to process query", "error");
            scrolledToBottom = isScrolledToBottom();
            setConversation(prev => [
                ...prev.filter(msg => !msg.isProcessing),
                { role: 'assistant', type: 'error', content: `Error: ${error}` }
            ]);
            setProcessing(false);
            stopCountdown();
            if (scrolledToBottom) {
                setTimeout(() => { scrollToBottom(); }, 50);
            }
        }
    };
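
    // Render: the message list, a spinner and countdown while processing,
    // the query input, optional default prompts, and a context-usage footer.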
    return (
        <Box className="Conversation" sx={{ display: "flex", flexDirection: "column", overflowY: "auto" }}>
            {conversation.map((message, index) => <Message key={index} {...{ submitQuery, message, connectionBase, sessionId, setSnack }} />)}
            <Box sx={{
                display: "flex",
                flexDirection: "column",
                alignItems: "center",
                justifyContent: "center",
                mb: 1
            }}>
                <PropagateLoader
                    size="10px"
                    loading={processing}
                    aria-label="Loading Spinner"
                    data-testid="loader"
                />
                {processing && countdown > 0 && (
                    <Box
                        sx={{
                            pt: 1,
                            fontSize: "0.7rem",
                            color: "darkgrey"
                        }}
                    >Estimated response time: {countdown}s</Box>
                )}
            </Box>
            <Box className="Query" sx={{ display: "flex", flexDirection: "row", p: 1 }}>
                <TextField
                    variant="outlined"
                    disabled={processing}
                    fullWidth
                    type="text"
                    value={query}
                    onChange={(e) => setQuery(e.target.value)}
                    onKeyDown={handleKeyPress}
                    placeholder={prompt}
                    id="QueryInput"
                />
                <Tooltip title="Send">
                    <Button sx={{ m: 1 }} variant="contained" onClick={() => { sendQuery(query); }}><SendIcon /></Button>
                </Tooltip>
            </Box>
            {/* Use an explicit length check so a literal 0 is never rendered. */}
            {(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length > 0 &&
                <Box sx={{ display: "flex", flexDirection: "column" }}>
                    {
                        defaultPrompts.map((element, index) => {
                            return (<Box key={index}>{element}</Box>);
                        })
                    }
                </Box>
            }
            <Box sx={{ ml: "0.25rem", fontSize: "0.6rem", color: "darkgrey", display: "flex", flexShrink: 1, flexDirection: "row", gap: 1, mb: "auto", mt: 1 }}>
                Context used: {contextUsedPercentage}% {contextStatus.context_used}/{contextStatus.max_context}
                {
                    contextUsedPercentage >= 90 ? <Typography sx={{ fontSize: "0.6rem", color: "red" }}>WARNING: Context almost exhausted. You should start a new chat.</Typography>
                        : (contextUsedPercentage >= 50 ? <Typography sx={{ fontSize: "0.6rem", color: "orange" }}>NOTE: Context is getting long. Queries will be slower, and the LLM may stop issuing tool calls.</Typography>
                            : <></>)
                }
            </Box>
            <Box sx={{ display: "flex", flexGrow: 1 }}></Box>
        </Box>
    );
});
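
// Example usage (hypothetical parent component; the connectionBase URL and
// session wiring below are illustrative, not part of this module):
//
//   const chatRef = useRef<ConversationHandle>(null);
//   <Conversation
//       ref={chatRef}
//       type="chat"
//       prompt="Ask me anything..."
//       connectionBase="http://localhost:8000"
//       sessionId={sessionId}
//       setSnack={setSnack}
//   />
//   // Programmatic submission through the imperative handle:
//   chatRef.current?.submitQuery("Tell me about this project.");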
export type {
    ConversationProps,
    ConversationHandle
};

export {
    Conversation
};