472 lines
16 KiB
TypeScript

import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback } from 'react';
import Typography from '@mui/material/Typography';
import Tooltip from '@mui/material/Tooltip';
import IconButton from '@mui/material/IconButton';
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import SendIcon from '@mui/icons-material/Send';
import CancelIcon from '@mui/icons-material/Cancel';
import { SxProps, Theme } from '@mui/material';
import PropagateLoader from "react-spinners/PropagateLoader";
import { Message, MessageList, BackstoryMessage, MessageRoles } from '../../Components/Message';
import { DeleteConfirmation } from '../../Components/DeleteConfirmation';
import { Query } from '../../Components/ChatQuery';
import { BackstoryTextField, BackstoryTextFieldRef } from '../../Components/BackstoryTextField';
import { BackstoryElementProps } from '../../Components/BackstoryTab';
import { connectionBase } from '../../Global';
import { useUser } from "../Components/UserContext";
import { streamQueryResponse, StreamQueryController } from './streamQueryResponse';
import './Conversation.css';
/** Status message shown while no session ID is available yet. */
const loadingMessage: BackstoryMessage = {
    role: "status",
    content: "Establishing connection with server...",
};
/** Conversation kinds; the value is used as the `{type}` path segment of the history/reset API calls. */
type ConversationMode =
    | 'chat'
    | 'job_description'
    | 'resume'
    | 'fact_check'
    | 'persona';
/** Imperative API exposed through the ref forwarded to `Conversation`. */
interface ConversationHandle {
    /** Send `query` through the same path as a query typed into the text field. */
    submitQuery: (query: Query) => void;
    /** Reload this conversation's history from the server. */
    fetchHistory: () => void;
}
/** Props accepted by `Conversation` in addition to the shared `BackstoryElementProps`. */
interface ConversationProps extends BackstoryElementProps {
    /** Overrides the default className. */
    className?: string;
    /** Which kind of conversation this is (selects the server-side history). */
    type: ConversationMode;
    /** Prompt shown in the TextField input; the input is only rendered when this is set. */
    placeholder?: string;
    /** Label on the primary (send) button. */
    actionLabel?: string;
    /** Invoked when Reset is pressed. */
    resetAction?: () => void;
    /** Label on the Reset button. */
    resetLabel?: string;
    /** Elements displayed after the TextField. */
    defaultPrompts?: React.ReactElement[];
    /** Default text to populate the TextField input. */
    defaultQuery?: string;
    /** Messages displayed at the start of the conversation until an action has been invoked. */
    preamble?: MessageList;
    /** Whether to hide the preamble once an action has been invoked. */
    hidePreamble?: boolean;
    /** Whether to hide the defaultPrompts once an action has been invoked. */
    hideDefaultPrompts?: boolean;
    /** Filter deciding which conversation messages are displayed. */
    messageFilter?: ((messages: MessageList) => MessageList) | undefined;
    /** Extra messages displayed ahead of the conversation. */
    messages?: MessageList;
    sx?: SxProps<Theme>;
    /** Called with each completed (or partial) response message. */
    onResponse?: ((message: BackstoryMessage) => void) | undefined;
}
/**
 * Chat-style conversation panel for one (`sessionId`, `type`) pair.
 *
 * Loads history from `/api/history/{sessionId}/{type}` whenever the session or
 * user changes. It renders that history, passed through `messageFilter` when
 * given and preceded by `preamble` and `messages`. New queries are streamed
 * through `streamQueryResponse`. Parents can call `submitQuery` and
 * `fetchHistory` through the forwarded `ConversationHandle` ref.
 */
const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: ConversationProps, ref) => {
    const {
        sessionId,
        actionLabel,
        defaultPrompts,
        hideDefaultPrompts,
        hidePreamble,
        messageFilter,
        messages,
        onResponse,
        placeholder,
        preamble,
        resetAction,
        resetLabel,
        setSnack,
        submitQuery,
        sx,
        type,
    } = props;
    // NOTE(review): `className` and `defaultQuery` are part of ConversationProps but are not
    // currently applied anywhere in this component.
    const { user } = useUser();
    const [processing, setProcessing] = useState<boolean>(false);
    const [countdown, setCountdown] = useState<number>(0);
    const [conversation, setConversation] = useState<MessageList>([]);
    const [filteredConversation, setFilteredConversation] = useState<MessageList>([]);
    const [processingMessage, setProcessingMessage] = useState<BackstoryMessage | undefined>(undefined);
    const [streamingMessage, setStreamingMessage] = useState<BackstoryMessage | undefined>(undefined);
    const [noInteractions, setNoInteractions] = useState<boolean>(true);
    // Interval driving the visible countdown; null when no countdown is running.
    const timerRef = useRef<ReturnType<typeof setInterval> | null>(null);
    const viewableElementRef = useRef<HTMLDivElement>(null);
    const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
    // Controller of the in-flight streamed query; null when no query is running.
    const controllerRef = useRef<StreamQueryController | null>(null);

    /* Transform the 'Conversation' by filtering via callback, then adding
     * preamble and messages based on whether the conversation
     * has any elements yet */
    useEffect(() => {
        /* Do not copy conversation before filtering or useEffect will loop forever */
        const filtered = messageFilter === undefined ? conversation : messageFilter(conversation);
        if (filtered.length === 0) {
            setFilteredConversation([
                ...(preamble || []),
                ...(messages || []),
            ]);
        } else {
            setFilteredConversation([
                ...(hidePreamble ? [] : (preamble || [])),
                ...(messages || []),
                ...filtered,
            ]);
        }
    }, [conversation, setFilteredConversation, messageFilter, preamble, messages, hidePreamble]);

    /**
     * Load the conversation history from the server, retrying on failure.
     * Makes up to four attempts, three seconds apart. A snack is shown only
     * after the final attempt fails.
     */
    const fetchHistory = useCallback(async () => {
        const maxAttempts = 4;
        const retryDelaySeconds = 3;
        for (let attempt = 1; attempt <= maxAttempts; attempt++) {
            try {
                const response = await fetch(connectionBase + `/api/history/${sessionId}/${type}`, {
                    method: 'GET',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                });
                if (!response.ok) {
                    throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
                }
                // Renamed so it does not shadow the `messages` prop.
                const { messages: historyMessages } = await response.json();
                if (historyMessages === undefined || historyMessages === null || historyMessages.length === 0) {
                    console.log(`History returned for ${type} from server with 0 entries`);
                    setConversation([]);
                    setNoInteractions(true);
                } else {
                    console.log(`History returned for ${type} from server with ${historyMessages.length} entries:`, historyMessages);
                    const backstoryMessages: BackstoryMessage[] = historyMessages;
                    setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => {
                        if (backstoryMessage.status === "partial") {
                            // A partial entry renders as a single collapsed assistant message with no user row.
                            return [{
                                ...backstoryMessage,
                                role: "assistant",
                                content: backstoryMessage.response || "",
                                expanded: false,
                                expandable: true,
                            }];
                        }
                        // Otherwise emit the user's prompt followed by the response;
                        // non-"done" statuses are used as the response's role.
                        return [{
                            role: 'user',
                            content: backstoryMessage.prompt || "",
                        }, {
                            ...backstoryMessage,
                            role: backstoryMessage.status === 'done' ? "assistant" : backstoryMessage.status,
                            content: backstoryMessage.response || "",
                        }] as MessageList;
                    }));
                    setNoInteractions(false);
                }
                setProcessingMessage(undefined);
                setStreamingMessage(undefined);
                return;
            } catch (error) {
                console.error('Error fetching history:', error);
                const remaining = maxAttempts - attempt;
                if (remaining === 0) {
                    setSnack("Unable to obtain chat history.", "error");
                    return;
                }
                setProcessingMessage({ role: "error", content: `Unable to obtain history from server. Retrying in ${retryDelaySeconds} seconds (${remaining} remain.)` });
                await new Promise(resolve => setTimeout(resolve, retryDelaySeconds * 1000));
                setProcessingMessage(undefined);
            }
        }
    }, [setConversation, setSnack, type, sessionId]);

    // Set the initial chat history to "loading" until a session exists, then load it.
    useEffect(() => {
        if (sessionId === undefined) {
            setProcessingMessage(loadingMessage);
            return;
        }
        setProcessingMessage(undefined);
        setStreamingMessage(undefined);
        setConversation([]);
        setNoInteractions(true);
        if (user) {
            fetchHistory();
        }
    }, [fetchHistory, sessionId, user]);

    // Stop any running countdown interval when the component unmounts.
    useEffect(() => {
        return () => {
            if (timerRef.current !== null) {
                clearInterval(timerRef.current);
                timerRef.current = null;
            }
        };
    }, []);

    /** Start (or restart) a once-per-second countdown from `seconds` down to 0. */
    const startCountdown = (seconds: number) => {
        if (timerRef.current !== null) {
            clearInterval(timerRef.current);
        }
        setCountdown(seconds);
        const intervalId = setInterval(() => {
            setCountdown((prev) => {
                if (prev <= 1) {
                    // Clear this specific interval: timerRef may already hold a newer one.
                    clearInterval(intervalId);
                    if (timerRef.current === intervalId) {
                        timerRef.current = null;
                    }
                    return 0;
                }
                return prev - 1;
            });
        }, 1000);
        timerRef.current = intervalId;
    };

    /** Stop a running countdown and reset the displayed value to 0. */
    const stopCountdown = () => {
        if (timerRef.current !== null) {
            clearInterval(timerRef.current);
            timerRef.current = null;
            setCountdown(0);
        }
    };

    /**
     * Submit `query`: append it to the conversation and stream the response.
     * Ignored while another query is in flight.
     */
    const processQuery = (query: Query) => {
        if (controllerRef.current) {
            return;
        }
        setNoInteractions(false);
        // Functional updates: several callbacks may fire before a re-render.
        setConversation(prev => [
            ...prev,
            {
                role: 'user',
                origin: type,
                content: query.prompt,
                disableCopy: true
            }
        ]);
        setProcessing(true);
        setProcessingMessage(
            { role: 'status', content: 'Submitting request...', disableCopy: true }
        );

        let controller: StreamQueryController | null = null;
        // False once this query has been cancelled or replaced by a newer one, so its
        // late callbacks cannot clobber the state of the current query.
        const isActive = () => controllerRef.current === controller;

        controller = streamQueryResponse({
            query,
            type,
            sessionId,
            connectionBase,
            onComplete: (msg) => {
                if (!isActive()) {
                    return;
                }
                console.log(msg);
                switch (msg.status) {
                    case "done":
                    case "partial":
                        setConversation(prev => [
                            ...prev, {
                                ...msg,
                                role: 'assistant',
                                origin: type,
                                content: msg.response || "",
                                expanded: msg.status === "done",
                                expandable: msg.status !== "done",
                            }] as MessageList);
                        if (msg.status === "done") {
                            stopCountdown();
                            setStreamingMessage(undefined);
                            setProcessingMessage(undefined);
                            setProcessing(false);
                            controllerRef.current = null;
                        } else {
                            // More output is coming; show the server's remaining-time estimate.
                            startCountdown(Math.ceil(msg.remaining_time || 0));
                        }
                        if (onResponse) {
                            onResponse(msg);
                        }
                        break;
                    case "error":
                        setConversation(prev => [
                            ...prev, {
                                ...msg,
                                role: 'error',
                                origin: type,
                                content: msg.response || "",
                            }] as MessageList);
                        // The error is now in the conversation; clear the transient messages
                        // so it is not shown twice and the Send/Reset buttons re-enable.
                        setProcessingMessage(undefined);
                        setStreamingMessage(undefined);
                        setProcessing(false);
                        stopCountdown();
                        controllerRef.current = null;
                        break;
                    default:
                        setProcessingMessage({ role: (msg.status || "error") as MessageRoles, content: msg.response || "", disableCopy: true });
                        break;
                }
            },
            onStreaming: (chunk) => {
                if (!isActive()) {
                    return;
                }
                setStreamingMessage({ role: "streaming", content: chunk, disableCopy: true });
            }
        });
        controllerRef.current = controller;
    };

    const handleEnter = (value: string) => {
        const query: Query = {
            prompt: value
        };
        processQuery(query);
    };

    useImperativeHandle(ref, () => ({
        submitQuery: (query: Query) => {
            processQuery(query);
        },
        fetchHistory: () => { return fetchHistory(); }
    }));

    /** Ask the server to clear this conversation's history, then clear the local view. */
    const reset = async () => {
        try {
            const response = await fetch(connectionBase + `/api/reset/${sessionId}/${type}`, {
                method: 'PUT',
                headers: {
                    'Content-Type': 'application/json',
                    'Accept': 'application/json',
                },
                body: JSON.stringify({ reset: ['history'] })
            });
            if (!response.ok) {
                throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
            }
            if (!response.body) {
                throw new Error('Response body is null');
            }
            setProcessingMessage(undefined);
            setStreamingMessage(undefined);
            setConversation([]);
            setNoInteractions(true);
        } catch (e) {
            setSnack("Error resetting history", "error");
            console.error('Error resetting history:', e);
        }
    };

    /** Abort the in-flight query, if any, and return the UI to its idle state. */
    const cancelQuery = () => {
        console.log("Stop query");
        if (controllerRef.current) {
            controllerRef.current.abort();
        }
        controllerRef.current = null;
        // Nothing here guarantees that abort() delivers an onComplete, so reset the UI directly.
        stopCountdown();
        setProcessing(false);
        setProcessingMessage(undefined);
        setStreamingMessage(undefined);
    };

    return (
        // <Scrollable
        //     className={`${className || ""} Conversation`}
        //     autoscroll
        //     textFieldRef={viewableElementRef}
        //     fallbackThreshold={0.5}
        //     sx={{
        //         p: 1,
        //         mt: 0,
        //         ...sx
        //     }}
        // >
        <Box sx={{ p: 1, mt: 0, overflow: "hidden", ...sx }}>
            {
                filteredConversation.map((message, index) =>
                    <Message key={index} expanded={message.expanded === undefined ? true : message.expanded} {...{ sendQuery: processQuery, message, connectionBase, sessionId, setSnack, submitQuery }} />
                )
            }
            {
                processingMessage !== undefined &&
                <Message {...{ sendQuery: processQuery, connectionBase, sessionId, setSnack, message: processingMessage, submitQuery }} />
            }
            {
                streamingMessage !== undefined &&
                <Message {...{ sendQuery: processQuery, connectionBase, sessionId, setSnack, message: streamingMessage, submitQuery }} />
            }
            <Box sx={{
                display: "flex",
                flexDirection: "column",
                alignItems: "center",
                justifyContent: "center",
                m: 1,
            }}>
                <PropagateLoader
                    size="10px"
                    loading={processing}
                    aria-label="Loading Spinner"
                    data-testid="loader"
                />
                {processing === true && countdown > 0 && (
                    <Box
                        sx={{
                            pt: 1,
                            fontSize: "0.7rem",
                            color: "darkgrey"
                        }}
                    >Response will be stopped in: {countdown}s</Box>
                )}
            </Box>
            <Box className="Query" sx={{ display: "flex", flexDirection: "column", p: 1, flexGrow: 1 }}>
                {placeholder &&
                    <Box sx={{ display: "flex", flexGrow: 1, p: 0, m: 0, flexDirection: "column" }}
                        ref={viewableElementRef}>
                        <BackstoryTextField
                            ref={backstoryTextRef}
                            disabled={processing}
                            onEnter={handleEnter}
                            placeholder={placeholder}
                        />
                    </Box>
                }
                <Box key="jobActions" sx={{ display: "flex", justifyContent: "center", flexDirection: "row" }}>
                    <DeleteConfirmation
                        label={resetLabel || "all data"}
                        disabled={sessionId === undefined || processingMessage !== undefined || noInteractions}
                        onDelete={() => { reset(); resetAction && resetAction(); }} />
                    <Tooltip title={actionLabel || "Send"}>
                        <span style={{ display: "flex", flexGrow: 1 }}>
                            <Button
                                sx={{ m: 1, gap: 1, flexGrow: 1 }}
                                variant="contained"
                                disabled={sessionId === undefined || processingMessage !== undefined}
                                onClick={() => { processQuery({ prompt: (backstoryTextRef.current && backstoryTextRef.current.getAndResetValue()) || "" }); }}>
                                {actionLabel}<SendIcon />
                            </Button>
                        </span>
                    </Tooltip>
                    <Tooltip title="Cancel">
                        <span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
                            <IconButton
                                aria-label="cancel"
                                onClick={() => { cancelQuery(); }}
                                sx={{ display: "flex", margin: 'auto 0px' }}
                                size="large"
                                edge="start"
                                disabled={sessionId === undefined || !processing}
                            >
                                <CancelIcon />
                            </IconButton>
                        </span>
                    </Tooltip>
                </Box>
            </Box>
            {(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length !== 0 &&
                <Box sx={{ display: "flex", flexDirection: "column" }}>
                    {
                        defaultPrompts.map((element, index) => {
                            return (<Box key={index}>{element}</Box>);
                        })
                    }
                </Box>
            }
            <Box sx={{ display: "flex", flexGrow: 1 }}></Box>
        </Box >
    );
});
// Public surface of this module: the component plus its prop and ref-handle types.
export { Conversation };
export type { ConversationProps, ConversationHandle };