// backstory/frontend/src/components/Conversation.tsx
// (source-viewer metadata: 435 lines, 15 KiB, TypeScript)

import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback, useMemo } from 'react';
import Tooltip from '@mui/material/Tooltip';
import IconButton from '@mui/material/IconButton';
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import SendIcon from '@mui/icons-material/Send';
import CancelIcon from '@mui/icons-material/Cancel';
import { SxProps, Theme } from '@mui/material';
import PropagateLoader from "react-spinners/PropagateLoader";
import { Message } from './Message';
import { DeleteConfirmation } from 'components/DeleteConfirmation';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
import { BackstoryElementProps } from './BackstoryTab';
import { connectionBase } from 'utils/Global';
import { useAuth } from "hooks/AuthContext";
import { StreamingResponse } from 'services/api-client';
import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery, ChatMessageUser } from 'types/types';
import { PaginatedResponse } from 'types/conversion';
import './Conversation.css';
import { useSelectedCandidate } from 'hooks/GlobalContext';
/**
 * Template for locally synthesized system/status messages.
 * Spread it and override whichever fields differ.
 *
 * NOTE: `timestamp` is evaluated once, when this module loads. Callers that
 * need the current time must override it explicitly.
 */
const defaultMessage: ChatMessage = {
  type: "preparing",
  status: "done",
  sender: "system",
  sessionId: "",
  timestamp: new Date(),
  content: "",
};

/** Status message used as the processing message while no chat session exists yet. */
const loadingMessage: ChatMessage = {
  ...defaultMessage,
  content: "Establishing connection with server...",
};
/** Kind of conversation a `Conversation` instance hosts (its `type` prop). */
type ConversationMode =
  | 'chat'
  | 'job_description'
  | 'resume'
  | 'fact_check'
  | 'persona';
/**
 * Imperative API that `Conversation` exposes through its `ref`.
 */
interface ConversationHandle {
  /** Submit a query through the same path used when the user enters a prompt. */
  submitQuery: (query: ChatQuery) => void;
  /** Reload the current chat session's message history from the server. */
  fetchHistory: () => void;
}
/** Props accepted by the `Conversation` component. */
interface ConversationProps extends BackstoryElementProps {
  /** Optional class name, intended to override the default class name. */
  className?: string;
  /** Which kind of conversation this instance hosts. */
  type: ConversationMode;
  /** Placeholder for the prompt input; the input is only rendered when this is non-empty. */
  placeholder?: string;
  /** Label for the primary (send) button; its tooltip falls back to "Send". */
  actionLabel?: string;
  /**
   * Callback intended to run when Reset is pressed.
   * NOTE(review): the reset handler in Conversation is currently commented out,
   * so this callback is not invoked.
   */
  resetAction?: () => void;
  /** Label for the reset/delete confirmation; falls back to "all data". */
  resetLabel?: string;
  /** Elements rendered after the input area. */
  defaultPrompts?: React.ReactElement[];
  /**
   * Default text for the prompt input.
   * NOTE(review): not currently read by Conversation.
   */
  defaultQuery?: string;
  /** Messages shown at the start of the conversation; always shown while the (filtered) conversation is empty. */
  preamble?: ChatMessage[];
  /** When true, hide `preamble` once the (filtered) conversation contains messages. */
  hidePreamble?: boolean;
  /** When true, hide `defaultPrompts` once there has been an interaction (a query or loaded history). */
  hideDefaultPrompts?: boolean;
  /** Chooses which conversation messages are displayed; all messages are shown when omitted. */
  messageFilter?: ((messages: ChatMessage[]) => ChatMessage[]) | undefined;
  /** Extra messages displayed after the preamble and before the conversation itself. */
  messages?: ChatMessage[];
  /** Style overrides applied to the inner content container. */
  sx?: SxProps<Theme>;
  /** Called for every message received from the server stream (status updates and the final response). */
  onResponse?: ((message: ChatMessage) => void) | undefined;
}
/**
 * Chat-style conversation view bound to a server-side chat session.
 *
 * Lifecycle:
 *  1. On mount, creates a chat session via `apiClient.createChatSession`.
 *  2. Once a session exists, loads its history via `apiClient.getChatMessages`.
 *  3. Prompts (typed, sent with the button, or submitted through the imperative
 *     handle) are streamed via `apiClient.sendMessageStream`:
 *     - intermediate messages are shown as a transient "processing" message;
 *     - streamed chunks are shown as a transient "streaming" message;
 *     - the final `response` message is appended to the conversation.
 *
 * Renders nothing until the chat session has been created.
 */
const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: ConversationProps, ref) => {
  const {
    actionLabel,
    defaultPrompts,
    hideDefaultPrompts,
    hidePreamble,
    messageFilter,
    messages,
    onResponse,
    placeholder,
    preamble,
    resetAction, // only referenced by the disabled reset code below
    resetLabel,
    setSnack,
    submitQuery,
    sx,
    type, // only referenced by the disabled reset code below
  } = props;
  const { apiClient } = useAuth();
  const [processing, setProcessing] = useState<boolean>(false);
  // NOTE(review): setCountdown and stopRef are never updated in this component,
  // so the countdown text and the stopRef check below are currently inert.
  const [countdown, setCountdown] = useState<number>(0);
  const [conversation, setConversation] = useState<ChatMessage[]>([]);
  const [processingMessage, setProcessingMessage] = useState<ChatMessage | undefined>(undefined);
  const [streamingMessage, setStreamingMessage] = useState<ChatMessage | undefined>(undefined);
  const [noInteractions, setNoInteractions] = useState<boolean>(true);
  const viewableElementRef = useRef<HTMLDivElement>(null);
  const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
  const stopRef = useRef(false);
  // Handle of the in-flight stream, or null when idle. Typed `| null` so the ref is mutable.
  const controllerRef = useRef<StreamingResponse | null>(null);
  const [chatSession, setChatSession] = useState<ChatSession | null>(null);

  /* The displayed list is derived from the conversation plus the preamble and
   * messages props. It is computed during render rather than mirrored into
   * state by an effect, which saves a render pass.
   * Before any (filtered) conversation exists, the preamble is always shown;
   * afterwards it is shown unless hidePreamble is set. */
  const filteredConversation = useMemo<ChatMessage[]>(() => {
    const filtered = messageFilter === undefined ? conversation : messageFilter(conversation);
    if (filtered.length === 0) {
      return [...(preamble ?? []), ...(messages ?? [])];
    }
    return [
      ...(hidePreamble ? [] : (preamble ?? [])),
      ...(messages ?? []),
      ...filtered,
    ];
  }, [conversation, messageFilter, preamble, messages, hidePreamble]);

  // Create the chat session once; the effect is a no-op once a session exists.
  useEffect(() => {
    if (chatSession) {
      return undefined;
    }
    // Drop the result if the effect is torn down (unmount / re-run) before the request settles.
    let cancelled = false;
    const createChatSession = async () => {
      try {
        const chatContext: ChatContext = { type: "general" };
        const response: ChatSession = await apiClient.createChatSession(chatContext);
        if (!cancelled) {
          setChatSession(response);
        }
      } catch (e) {
        console.error(e);
        if (!cancelled) {
          setSnack("Unable to create chat session.", "error");
        }
      }
    };
    void createChatSession();
    return () => {
      cancelled = true;
    };
  }, [chatSession, setChatSession]);

  /** Load the current session's history into `conversation`; failures are reported via a transient error message and a snack. */
  const getChatMessages = useCallback(async () => {
    if (!chatSession || !chatSession.id) {
      return;
    }
    try {
      const response: PaginatedResponse<ChatMessage> = await apiClient.getChatMessages(chatSession.id);
      // Named `history` so it does not shadow the `messages` prop.
      const history: ChatMessage[] = response.data;
      setProcessingMessage(undefined);
      setStreamingMessage(undefined);
      if (history.length === 0) {
        console.log(`History returned with 0 entries`);
      } else {
        console.log(`History returned with ${history.length} entries:`, history);
      }
      setConversation(history);
      setNoInteractions(history.length === 0);
    } catch (error) {
      console.error('Unable to obtain chat history', error);
      setProcessingMessage({
        ...defaultMessage,
        timestamp: new Date(),
        status: "error",
        content: `Unable to obtain history from server.`,
      });
      setTimeout(() => {
        setProcessingMessage(undefined);
        setNoInteractions(true);
      }, 3000);
      setSnack("Unable to obtain chat history.", "error");
    }
  }, [chatSession]);

  // While there is no session, use the loading placeholder as the processing message.
  // Whenever the session changes, clear local state and reload its history.
  useEffect(() => {
    if (!chatSession) {
      setProcessingMessage(loadingMessage);
      return;
    }
    setProcessingMessage(undefined);
    setStreamingMessage(undefined);
    setConversation([]);
    setNoInteractions(true);
    void getChatMessages();
  }, [chatSession]);

  const handleEnter = (value: string) => {
    const query: ChatQuery = {
      prompt: value,
    };
    processQuery(query);
  };

  useImperativeHandle(ref, () => ({
    submitQuery: (query: ChatQuery) => {
      processQuery(query);
    },
    fetchHistory: () => {
      void getChatMessages();
    },
  }));

  // const reset = async () => {
  //   try {
  //     const response = await fetch(connectionBase + `/api/reset/${sessionId}/${type}`, {
  //       method: 'PUT',
  //       headers: {
  //         'Content-Type': 'application/json',
  //         'Accept': 'application/json',
  //       },
  //       body: JSON.stringify({ reset: ['history'] })
  //     });
  //     if (!response.ok) {
  //       throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
  //     }
  //     if (!response.body) {
  //       throw new Error('Response body is null');
  //     }
  //     setProcessingMessage(undefined);
  //     setStreamingMessage(undefined);
  //     setConversation([]);
  //     setNoInteractions(true);
  //   } catch (e) {
  //     setSnack("Error resetting history", "error")
  //     console.error('Error resetting history:', e);
  //   }
  // };

  /** Cancel the in-flight stream (if any) and return the UI to its idle state. */
  const cancelQuery = () => {
    console.log("Stop query");
    controllerRef.current?.cancel();
    controllerRef.current = null;
    // Callbacks of the cancelled stream are ignored (see processQuery), so nothing
    // else would clear the spinner, the disabled input, or the transient messages.
    setProcessing(false);
    setProcessingMessage(undefined);
    setStreamingMessage(undefined);
  };

  /**
   * Append the user's prompt to the conversation and stream the server's reply.
   * Ignored while another query is in flight or before a session exists.
   */
  const processQuery = (query: ChatQuery) => {
    if (controllerRef.current || !chatSession || !chatSession.id) {
      return;
    }
    const sessionId: string = chatSession.id;
    const submittedAt = new Date();

    setNoInteractions(false);
    // Functional update: append to the latest state, not a possibly stale snapshot.
    setConversation(previous => [
      ...previous,
      {
        ...defaultMessage,
        type: 'user',
        sender: 'user',
        sessionId,
        // defaultMessage.timestamp is fixed at module load; stamp the real submission time.
        timestamp: submittedAt,
        content: query.prompt,
      },
    ]);
    setProcessing(true);
    setProcessingMessage({ ...defaultMessage, timestamp: submittedAt, content: 'Submitting request...' });

    const chatMessage: ChatMessageUser = {
      sessionId,
      content: query.prompt,
      tunables: query.tunables,
      status: "done",
      type: "user",
      sender: "user",
      timestamp: submittedAt,
    };

    // Callbacks from a stream that has been cancelled or superseded must not touch
    // state. Otherwise, for example, a late onComplete could null controllerRef
    // while a newer query is running.
    let stream: StreamingResponse | null = null;
    const isActive = () => controllerRef.current === stream;

    stream = apiClient.sendMessageStream(chatMessage, {
      onMessage: (msg: ChatMessageBase) => {
        console.log("onMessage:", msg);
        if (!isActive()) {
          return;
        }
        if (msg.type === "response") {
          setConversation(previous => [...previous, msg]);
          setStreamingMessage(undefined);
          setProcessingMessage(undefined);
          setProcessing(false);
        } else {
          setProcessingMessage(msg);
        }
        if (onResponse) {
          onResponse(msg);
        }
      },
      onError: (error: string | ChatMessageBase) => {
        console.log("onError:", error);
        if (!isActive()) {
          return;
        }
        // Type-guard: a structured ChatMessageBase error ends the query; a plain
        // string is only displayed.
        if (typeof error === "object" && error !== null && "content" in error) {
          setProcessingMessage(error as ChatMessage);
          setProcessing(false);
          controllerRef.current = null;
        } else {
          setProcessingMessage({ ...defaultMessage, timestamp: new Date(), content: error as string });
        }
      },
      onStreaming: (chunk: ChatMessageBase) => {
        console.log("onStreaming:", chunk);
        if (!isActive()) {
          return;
        }
        setStreamingMessage({ ...defaultMessage, ...chunk });
      },
      onStatusChange: (status: string) => {
        console.log("onStatusChange:", status);
      },
      onComplete: () => {
        console.log("onComplete");
        if (!isActive()) {
          return;
        }
        controllerRef.current = null;
        // If the stream ended without a final "response" (e.g. after a string
        // error), nothing else would stop the spinner or re-enable the input.
        setProcessing(false);
      },
    });
    controllerRef.current = stream;
  };

  // Nothing to render until the chat session exists.
  if (!chatSession) {
    return (<></>);
  }

  return (
    // <Scrollable
    //   className={`${className || ""} Conversation`}
    //   autoscroll
    //   textFieldRef={viewableElementRef}
    //   fallbackThreshold={0.5}
    //   sx={{
    //     p: 1,
    //     mt: 0,
    //     ...sx
    //   }}
    // >
    <Box className="Conversation" sx={{ flexGrow: 1, minHeight: "max-content", height: "max-content", maxHeight: "max-content", overflow: "hidden" }}>
      <Box sx={{ p: 1, mt: 0, ...sx }}>
        {
          filteredConversation.map((message, index) =>
            <Message key={index} {...{ chatSession, sendQuery: processQuery, message, connectionBase, setSnack, submitQuery }} />
          )
        }
        {
          processingMessage !== undefined &&
          <Message {...{ chatSession, sendQuery: processQuery, connectionBase, setSnack, message: processingMessage, submitQuery }} />
        }
        {
          streamingMessage !== undefined &&
          <Message {...{ chatSession, sendQuery: processQuery, connectionBase, setSnack, message: streamingMessage, submitQuery }} />
        }
        <Box sx={{
          display: "flex",
          flexDirection: "column",
          alignItems: "center",
          justifyContent: "center",
          m: 1,
        }}>
          <PropagateLoader
            size="10px"
            loading={processing}
            aria-label="Loading Spinner"
            data-testid="loader"
          />
          {processing === true && countdown > 0 && (
            <Box
              sx={{
                pt: 1,
                fontSize: "0.7rem",
                color: "darkgrey"
              }}
            >Response will be stopped in: {countdown}s</Box>
          )}
        </Box>
        <Box className="Query" sx={{ display: "flex", flexDirection: "column", p: 1, flexGrow: 1 }}>
          {placeholder &&
            <Box sx={{ display: "flex", flexGrow: 1, p: 0, m: 0, flexDirection: "column" }}
              ref={viewableElementRef}>
              <BackstoryTextField
                ref={backstoryTextRef}
                disabled={processing}
                onEnter={handleEnter}
                placeholder={placeholder}
              />
            </Box>
          }
          <Box key="jobActions" sx={{ display: "flex", justifyContent: "center", flexDirection: "row" }}>
            <DeleteConfirmation
              label={resetLabel || "all data"}
              disabled={!chatSession || processingMessage !== undefined || noInteractions}
              onDelete={() => { /*reset(); resetAction && resetAction(); */ }} />
            <Tooltip title={actionLabel || "Send"}>
              <span style={{ display: "flex", flexGrow: 1 }}>
                {/* Gate on `processing` (not processingMessage): an error message may stay
                    on screen after a failed query and must not lock the button. */}
                <Button
                  sx={{ m: 1, gap: 1, flexGrow: 1 }}
                  variant="contained"
                  disabled={!chatSession || processing}
                  onClick={() => { processQuery({ prompt: backstoryTextRef.current?.getAndResetValue() || "" }); }}>
                  {actionLabel}<SendIcon />
                </Button>
              </span>
            </Tooltip>
            <Tooltip title="Cancel">
              <span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
                <IconButton
                  aria-label="cancel"
                  onClick={() => { cancelQuery(); }}
                  sx={{ display: "flex", margin: 'auto 0px' }}
                  size="large"
                  edge="start"
                  disabled={stopRef.current || !chatSession || processing === false}
                >
                  <CancelIcon />
                </IconButton>
              </span>
            </Tooltip>
          </Box>
        </Box>
        {(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length !== 0 &&
          <Box sx={{ display: "flex", flexDirection: "column" }}>
            {
              defaultPrompts.map((element, index) => {
                return (<Box key={index}>{element}</Box>);
              })
            }
          </Box>
        }
        <Box sx={{ display: "flex", flexGrow: 1 }}></Box>
      </Box>
    </Box>
  );
});
// Public surface of this module: the component plus its prop and handle types.
export { Conversation };
export type { ConversationHandle, ConversationProps };