import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback } from 'react';
import TextField from '@mui/material/TextField';
import Typography from '@mui/material/Typography';
import Tooltip from '@mui/material/Tooltip';
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import SendIcon from '@mui/icons-material/Send';
import { SxProps, Theme } from '@mui/material';
import PropagateLoader from "react-spinners/PropagateLoader";

import { Message, MessageList, MessageData } from './Message';
import { SetSnackType } from './Snack';
import { ContextStatus } from './ContextStatus';
import { useAutoScrollToBottom } from './AutoScroll';
import { DeleteConfirmation } from './DeleteConfirmation';

import './Conversation.css';

const loadingMessage: MessageData = { "role": "status", "content": "Establishing connection with server..." };

type ConversationMode = 'chat' | 'job_description' | 'resume' | 'fact_check';

interface ConversationHandle {
    submitQuery: (query: string) => void;
}

interface ConversationProps {
    className?: string, // Override default className
    type: ConversationMode, // Type of Conversation chat
    prompt?: string, // Prompt to display in TextField input
    actionLabel?: string, // Label to put on the primary button
    resetAction?: () => void, // Callback when Reset is pressed
    multiline?: boolean, // Render TextField as multiline or not
    resetLabel?: string, // Label to put on Reset button
    connectionBase: string, // Base URL for fetch() calls
    sessionId?: string, // Session ID for fetch() calls
    setSnack: SetSnackType, // Callback to display snack popups
    defaultPrompts?: React.ReactElement[], // Set of Elements to display after the TextField
    defaultQuery?: string, // Default text to populate the TextField input
    preamble?: MessageList, // Messages to display at start of Conversation until Action has been invoked
    hidePreamble?: boolean, // Whether to hide the preamble after an Action has been invoked
    hideDefaultPrompts?: boolean, // Whether to hide the defaultPrompts after an Action has been invoked
    messageFilter?: ((messages: MessageList) => MessageList) | undefined, // Filter callback to determine which Messages to display in Conversation
    messages?: MessageList, //
    sx?: SxProps<Theme>,
    onResponse?: ((message: MessageData) => MessageData) | undefined, // Event called when a query completes (provides messages)
};
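
/**
 * Conversation renders a chat-style exchange with the server for a given
 * ConversationMode, streaming newline-delimited JSON responses into the
 * message list and exposing submitQuery() through the forwarded ref.
 *
 * Illustrative usage sketch only -- the surrounding component, variable
 * names, and URL below are assumptions, not part of this module:
 *
 *   const chatRef = useRef<ConversationHandle>(null);
 *
 *   <Conversation
 *       ref={chatRef}
 *       type="chat"
 *       prompt="Ask a question..."
 *       connectionBase="http://localhost:8000"
 *       sessionId={sessionId}
 *       setSnack={setSnack}
 *   />
 *
 *   // Later, e.g. from a suggested-prompt button:
 *   chatRef.current?.submitQuery("What can you tell me about this resume?");
 */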
const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
    className,
    type,
    prompt,
    actionLabel,
    resetAction,
    multiline,
    resetLabel,
    connectionBase,
    sessionId,
    setSnack,
    defaultPrompts,
    hideDefaultPrompts,
    defaultQuery,
    preamble,
    hidePreamble,
    messageFilter,
    messages,
    sx,
    onResponse
}: ConversationProps, ref) => {
    const [query, setQuery] = useState<string>("");
    const [contextUsedPercentage, setContextUsedPercentage] = useState<number>(0);
    const [processing, setProcessing] = useState<boolean>(false);
    const [countdown, setCountdown] = useState<number>(0);
    const [conversation, setConversation] = useState<MessageList>([]);
    const [filteredConversation, setFilteredConversation] = useState<MessageList>([]);
    const [processingMessage, setProcessingMessage] = useState<MessageData | undefined>(undefined);
    const timerRef = useRef<any>(null);
    const [lastEvalTPS, setLastEvalTPS] = useState<number>(35);
    const [lastPromptTPS, setLastPromptTPS] = useState<number>(430);
    const [contextStatus, setContextStatus] = useState<ContextStatus>({ context_used: 0, max_context: 0 });
    const [contextWarningShown, setContextWarningShown] = useState<boolean>(false);
    const [noInteractions, setNoInteractions] = useState<boolean>(true);
    const conversationRef = useRef<MessageList>([]);
    const scrollRef = useAutoScrollToBottom();

    // Keep the ref updated whenever the conversation changes
    useEffect(() => {
        conversationRef.current = conversation;
    }, [conversation]);

    // Update the context status
    const updateContextStatus = useCallback(() => {
        const fetchContextStatus = async () => {
            try {
                const response = await fetch(connectionBase + `/api/context-status/${sessionId}/${type}`, {
                    method: 'GET',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                });

                if (!response.ok) {
                    throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
                }

                const data = await response.json();
                setContextStatus(data);
            }
            catch (error) {
                console.error('Error getting context status:', error);
                setSnack("Unable to obtain context status.", "error");
            }
        };
        fetchContextStatus();
    }, [setContextStatus, connectionBase, setSnack, sessionId, type]);

    /* Transform the 'Conversation' by filtering via callback, then adding
     * preamble and messages based on whether the conversation
     * has any elements yet */
    useEffect(() => {
        let filtered: MessageList = [];
        if (messageFilter === undefined) {
            filtered = conversation;
        } else {
            //console.log('Filtering conversation...')
            filtered = messageFilter(conversation); /* Do not copy conversation or useEffect will loop forever */
            //console.log(`${conversation.length - filtered.length} messages filtered out.`);
        }
        if (filtered.length === 0) {
            setFilteredConversation([
                ...(preamble || []),
                ...(messages || []),
            ]);
        } else {
            setFilteredConversation([
                ...(hidePreamble ? [] : (preamble || [])),
                ...(messages || []),
                ...filtered,
            ]);
        }
    }, [conversation, setFilteredConversation, messageFilter, preamble, messages, hidePreamble]);

    // Set the initial chat history to "loading" or the welcome message if loaded.
    useEffect(() => {
        if (sessionId === undefined) {
            setProcessingMessage(loadingMessage);
            return;
        }

        const fetchHistory = async () => {
            try {
                const response = await fetch(connectionBase + `/api/history/${sessionId}/${type}`, {
                    method: 'GET',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                });

                if (!response.ok) {
                    throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
                }

                const data = await response.json();

                console.log(`History returned for ${type} from server with ${data.length} entries`);
                if (data.length === 0) {
                    setConversation([]);
                    setNoInteractions(true);
                } else {
                    setConversation(data);
                    setNoInteractions(false);
                }
                setProcessingMessage(undefined);
                updateContextStatus();
            } catch (error) {
                console.error('Error fetching history:', error);
                setProcessingMessage({ role: "error", content: "Unable to obtain history from server." });
                setTimeout(() => {
                    setProcessingMessage(undefined);
                }, 5000);
                setSnack("Unable to obtain chat history.", "error");
            }
        };

        fetchHistory();
    }, [setConversation, setFilteredConversation, updateContextStatus, connectionBase, setSnack, type, sessionId]);
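
    // Countdown shown while a query is in flight: starts from the estimated
    // response time and ticks down once per second until cleared.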
    const startCountdown = (seconds: number) => {
        if (timerRef.current) clearInterval(timerRef.current);
        setCountdown(seconds);
        timerRef.current = setInterval(() => {
            setCountdown((prev) => {
                if (prev <= 1) {
                    clearInterval(timerRef.current);
                    timerRef.current = null;
                    return 0;
                }
                return prev - 1;
            });
        }, 1000);
    };

    const stopCountdown = () => {
        if (timerRef.current) {
            clearInterval(timerRef.current);
            timerRef.current = null;
            setCountdown(0);
        }
    };
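
    // Submit on Enter; Shift+Enter falls through so multiline input can insert a newline.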
    const handleKeyPress = (event: any) => {
        if (event.key === 'Enter' && !event.shiftKey) {
            sendQuery(query);
        }
    };
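
    // Expose submitQuery() to parent components through the forwarded ref.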
    useImperativeHandle(ref, () => ({
        submitQuery: (query: string) => {
            sendQuery(query);
        }
    }));

    // If context status changes, show a warning if necessary. If it drops
    // back below the threshold, clear the warning trigger.
    useEffect(() => {
        const context_used_percentage = Math.round(100 * contextStatus.context_used / contextStatus.max_context);
        if (context_used_percentage >= 90 && !contextWarningShown) {
            setSnack(`${context_used_percentage}% of context used. You may wish to start a new chat.`, "warning");
            setContextWarningShown(true);
        }
        if (context_used_percentage < 90 && contextWarningShown) {
            setContextWarningShown(false);
        }
        setContextUsedPercentage(context_used_percentage);
    }, [contextStatus, setContextWarningShown, contextWarningShown, setContextUsedPercentage, setSnack]);
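
    // Ask the server to clear the stored history for this session/type, then reset local state.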
    const reset = async () => {
        try {
            const response = await fetch(connectionBase + `/api/reset/${sessionId}/${type}`, {
                method: 'PUT',
                headers: {
                    'Content-Type': 'application/json',
                    'Accept': 'application/json',
                },
                body: JSON.stringify({ reset: ['history'] })
            });

            if (!response.ok) {
                throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
            }

            if (!response.body) {
                throw new Error('Response body is null');
            }

            setConversation([]);
            setNoInteractions(true);

        } catch (e) {
            setSnack("Error resetting history", "error");
            console.error('Error resetting history:', e);
        }
    };
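
    // Send the user's query to the chat endpoint and stream the newline-delimited
    // JSON status updates back into the conversation as they arrive.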
    const sendQuery = async (query: string) => {
        query = query.trim();

        // If the query was empty, a default query was provided,
        // and there is no prompt for the user, send the default query.
        if (!query && defaultQuery && !prompt) {
            query = defaultQuery.trim();
        }

        // If the query is empty, and a prompt was provided, do not
        // send an empty query.
        if (!query && prompt) {
            return;
        }

        setNoInteractions(false);

        if (query) {
            setConversation([
                ...conversationRef.current,
                {
                    role: 'user',
                    origin: type,
                    content: query,
                    disableCopy: true
                }
            ]);
        }

        // Add a small delay to ensure React has time to update the UI
        await new Promise(resolve => setTimeout(resolve, 0));
        console.log(conversation);

        // Clear input
        setQuery('');

        try {
            setProcessing(true);
            // Create a unique ID for the processing message
            const processingId = Date.now().toString();

            // Add initial processing message
            setProcessingMessage(
                { role: 'status', content: 'Submitting request...', id: processingId, isProcessing: true }
            );

            // Add a small delay to ensure React has time to update the UI
            await new Promise(resolve => setTimeout(resolve, 0));

            // Make the fetch request with proper headers
            const response = await fetch(connectionBase + `/api/chat/${sessionId}/${type}`, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                    'Accept': 'application/json',
                },
                body: JSON.stringify({ role: 'user', content: query.trim() }),
            });

            // We'll guess that the response will be around 500 tokens...
            const token_guess = 500;
            const estimate = Math.round(token_guess / lastEvalTPS + contextStatus.context_used / lastPromptTPS);

            setSnack(`Query sent. Response estimated in ${estimate}s.`, "info");
            startCountdown(Math.round(estimate));

            if (!response.ok) {
                throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
            }

            if (!response.body) {
                throw new Error('Response body is null');
            }

            // Set up stream processing with explicit chunking
            const reader = response.body.getReader();
            const decoder = new TextDecoder();
            let buffer = '';

            while (true) {
                const { done, value } = await reader.read();
                if (done) {
                    break;
                }

                const chunk = decoder.decode(value, { stream: true });

                // Process each complete line immediately
                buffer += chunk;
                let lines = buffer.split('\n');
                buffer = lines.pop() || ''; // Keep incomplete line in buffer
                for (const line of lines) {
                    if (!line.trim()) continue;

                    try {
                        const update = JSON.parse(line);

                        // Force an immediate state update based on the message type
                        if (update.status === 'processing') {
                            // Update processing message with immediate re-render
                            setProcessingMessage({ role: 'status', content: update.message });
                            // Add a small delay to ensure React has time to update the UI
                            await new Promise(resolve => setTimeout(resolve, 0));
                        } else if (update.status === 'done') {
                            // Replace processing message with final result
                            if (onResponse) {
                                update.message = onResponse(update.message);
                            }
                            setProcessingMessage(undefined);
                            setConversation([
                                ...conversationRef.current,
                                update.message
                            ]);
                            // Add a small delay to ensure React has time to update the UI
                            await new Promise(resolve => setTimeout(resolve, 0));

                            const metadata = update.message.metadata;
                            if (metadata) {
                                const evalTPS = metadata.eval_count * 10 ** 9 / metadata.eval_duration;
                                const promptTPS = metadata.prompt_eval_count * 10 ** 9 / metadata.prompt_eval_duration;
                                setLastEvalTPS(evalTPS ? evalTPS : 35);
                                setLastPromptTPS(promptTPS ? promptTPS : 35);
                                updateContextStatus();
                            }
                        } else if (update.status === 'error') {
                            // Show error
                            setProcessingMessage({ role: 'error', content: update.message });
                            setTimeout(() => {
                                setProcessingMessage(undefined);
                            }, 5000);

                            // Add a small delay to ensure React has time to update the UI
                            await new Promise(resolve => setTimeout(resolve, 0));
                        }
                    } catch (e) {
                        setSnack("Error processing query", "error");
                        console.error('Error parsing JSON:', e, line);
                    }
                }
            }

            // Process any remaining buffer content
            if (buffer.trim()) {
                try {
                    const update = JSON.parse(buffer);

                    if (update.status === 'done') {
                        if (onResponse) {
                            update.message = onResponse(update.message);
                        }
                        setProcessingMessage(undefined);
                        setConversation([
                            ...conversationRef.current,
                            update.message
                        ]);
                    }
                } catch (e) {
                    setSnack("Error processing query", "error");
                }
            }

            stopCountdown();
            setProcessing(false);
        } catch (error) {
            console.error('Fetch error:', error);
            setSnack("Unable to process query", "error");
            setProcessingMessage({ role: 'error', content: "Unable to process query" });
            setTimeout(() => {
                setProcessingMessage(undefined);
            }, 5000);

            setProcessing(false);
            stopCountdown();
            // Add a small delay to ensure React has time to update the UI
            await new Promise(resolve => setTimeout(resolve, 0));
        }
    };
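
    // Render: the filtered message list, any in-flight status message, a progress
    // indicator with countdown, the query input and action buttons, optional
    // default prompts, and a context-usage footer.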
    return (
        <Box className={className || "Conversation"}
            ref={scrollRef}
            sx={{
                p: 1,
                mt: 0,
                ...sx
            }}>
            {
                filteredConversation.map((message, index) =>
                    <Message key={index} {...{ sendQuery, message, connectionBase, sessionId, setSnack }} />
                )
            }
            {
                processingMessage !== undefined &&
                <Message {...{ sendQuery, connectionBase, sessionId, setSnack, message: processingMessage }} />
            }
            <Box sx={{
                display: "flex",
                flexDirection: "column",
                alignItems: "center",
                justifyContent: "center",
                mb: 1,
            }}>
                <PropagateLoader
                    size="10px"
                    loading={processing}
                    aria-label="Loading Spinner"
                    data-testid="loader"
                />
                {processing === true && countdown > 0 && (
                    <Box
                        sx={{
                            pt: 1,
                            fontSize: "0.7rem",
                            color: "darkgrey"
                        }}
                    >Estimated response time: {countdown}s</Box>
                )}
            </Box>
            <Box className="Query" sx={{ display: "flex", flexDirection: "column", p: 1 }}>
                {prompt &&
                    <TextField
                        variant="outlined"
                        disabled={processing}
                        fullWidth={true}
                        multiline={multiline ? true : false}
                        type="text"
                        value={query}
                        onChange={(e) => setQuery(e.target.value)}
                        onKeyDown={handleKeyPress}
                        placeholder={prompt}
                        id="QueryInput"
                    />
                }

                <Box key="jobActions" sx={{ display: "flex", justifyContent: "center", flexDirection: "row" }}>
                    <DeleteConfirmation
                        label={resetLabel || "all data"}
                        disabled={sessionId === undefined || processingMessage !== undefined || noInteractions}
                        onDelete={() => { reset(); resetAction && resetAction(); }} />
                    <Tooltip title={actionLabel || "Send"}>
                        <span style={{ display: "flex", flexGrow: 1 }}>
                            <Button
                                sx={{ m: 1, gap: 1, flexGrow: 1 }}
                                variant="contained"
                                disabled={sessionId === undefined || processingMessage !== undefined}
                                onClick={() => { sendQuery(query); }}>
                                {actionLabel}<SendIcon />
                            </Button>
                        </span>
                    </Tooltip>
                </Box>
            </Box>
            {(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length > 0 &&
                <Box sx={{ display: "flex", flexDirection: "column" }}>
                    {
                        defaultPrompts.map((element, index) => {
                            return (<Box key={index}>{element}</Box>);
                        })
                    }
                </Box>
            }
            <Box sx={{ ml: "0.25rem", fontSize: "0.6rem", color: "darkgrey", display: "flex", flexShrink: 1, flexDirection: "row", gap: 1, mb: "auto", mt: 1 }}>
                Context used: {contextUsedPercentage}% {contextStatus.context_used}/{contextStatus.max_context}
                {
                    contextUsedPercentage >= 90 ? <Typography sx={{ fontSize: "0.6rem", color: "red" }}>WARNING: Context almost exhausted. You should start a new chat.</Typography>
                        : (contextUsedPercentage >= 50 ? <Typography sx={{ fontSize: "0.6rem", color: "orange" }}>NOTE: Context is getting long. Queries will be slower, and the LLM may stop issuing tool calls.</Typography>
                            : <></>)
                }
            </Box>
            <Box sx={{ display: "flex", flexGrow: 1 }}></Box>
        </Box>
    );
});

export type {
    ConversationProps,
    ConversationHandle
};

export {
    Conversation
};