435 lines
15 KiB
TypeScript
435 lines
15 KiB
TypeScript
import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback } from 'react';
|
|
import Tooltip from '@mui/material/Tooltip';
|
|
import IconButton from '@mui/material/IconButton';
|
|
import Button from '@mui/material/Button';
|
|
import Box from '@mui/material/Box';
|
|
import SendIcon from '@mui/icons-material/Send';
|
|
import CancelIcon from '@mui/icons-material/Cancel';
|
|
import { SxProps, Theme } from '@mui/material';
|
|
import PropagateLoader from "react-spinners/PropagateLoader";
|
|
|
|
import { Message } from './Message';
|
|
import { DeleteConfirmation } from 'components/DeleteConfirmation';
|
|
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
|
|
import { BackstoryElementProps } from './BackstoryTab';
|
|
import { connectionBase } from 'utils/Global';
|
|
import { useAuth } from "hooks/AuthContext";
|
|
import { StreamingResponse } from 'services/api-client';
|
|
import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery, ChatMessageUser } from 'types/types';
|
|
import { PaginatedResponse } from 'types/conversion';
|
|
|
|
import './Conversation.css';
|
|
import { useSelectedCandidate } from 'hooks/GlobalContext';
|
|
|
|
/**
 * Template for messages synthesized locally by `Conversation` (status lines,
 * errors, the echoed user prompt). Call sites spread it and override `content`
 * and, where needed, `type` / `sender` / `status`.
 * Note: `timestamp` is evaluated once, when this module is first loaded.
 */
const defaultMessage: ChatMessage = {
    type: "preparing",
    status: "done",
    sender: "system",
    sessionId: "",
    timestamp: new Date(),
    content: "",
};
|
|
|
|
/** Processing message assigned while no chat session exists yet. */
const loadingMessage: ChatMessage = {
    ...defaultMessage,
    content: "Establishing connection with server...",
};
|
|
|
|
/** Kind of conversation hosted by a `Conversation` (its `type` prop). */
type ConversationMode =
    | 'chat'
    | 'job_description'
    | 'resume'
    | 'fact_check'
    | 'persona';
|
|
|
|
/** Imperative API exposed through the `ref` forwarded to `Conversation`. */
interface ConversationHandle {
    /**
     * Submit `query` as if the user had sent it. Ignored when a request is
     * already in flight or no chat session exists yet.
     */
    submitQuery: (query: ChatQuery) => void;
    /** Reload the current chat session's message history from the server. */
    fetchHistory: () => void;
}
|
|
|
|
/** Props accepted by `Conversation` (in addition to `BackstoryElementProps`). */
interface ConversationProps extends BackstoryElementProps {
    /** Intended to override the default className. NOTE(review): not read by the current render. */
    className?: string;
    /** Kind of conversation. NOTE(review): currently only referenced by commented-out reset code. */
    type: ConversationMode;
    /** Placeholder for the query input; the input is rendered only when this is set. */
    placeholder?: string;
    /** Label of the primary (send) button; also its tooltip, which falls back to "Send". */
    actionLabel?: string;
    /** Callback intended to run on reset. NOTE(review): the reset handler is currently a no-op. */
    resetAction?: () => void;
    /** Label passed to the reset confirmation (falls back to "all data"). */
    resetLabel?: string;
    /** Elements rendered after the input area. */
    defaultPrompts?: React.ReactElement[];
    /** Default text for the query input. NOTE(review): not read by the component. */
    defaultQuery?: string;
    /** Messages rendered first; always shown while the (filtered) history is empty. */
    preamble?: ChatMessage[];
    /** When true, the preamble is hidden once the (filtered) history has entries. */
    hidePreamble?: boolean;
    /** When true, `defaultPrompts` are hidden once the user has interacted. */
    hideDefaultPrompts?: boolean;
    /** Optional filter applied to the session history before it is rendered. */
    messageFilter?: ((messages: ChatMessage[]) => ChatMessage[]) | undefined;
    /** Extra messages always rendered after the preamble and before the history. */
    messages?: ChatMessage[];
    /** Style overrides merged into the inner message container. */
    sx?: SxProps<Theme>;
    /**
     * Invoked for every message delivered by the stream's `onMessage` callback —
     * intermediate status messages as well as the final "response".
     */
    onResponse?: ((message: ChatMessage) => void) | undefined;
}
|
|
|
|
/**
 * Conversation view bound to a server-side chat session.
 *
 * 1. On mount, creates a "general" chat session via `apiClient.createChatSession`;
 *    nothing is rendered until it exists.
 * 2. Loads that session's history via `apiClient.getChatMessages`.
 * 3. Renders preamble / `messages` / (filtered) history, then the in-flight status
 *    message and partially streamed message (if any), the query input and buttons.
 *
 * Queries are sent with `apiClient.sendMessageStream`; at most one is in flight at a
 * time. The forwarded ref exposes `submitQuery` and `fetchHistory`.
 */
const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: ConversationProps, ref) => {
    const {
        actionLabel,
        defaultPrompts,
        hideDefaultPrompts,
        hidePreamble,
        messageFilter,
        messages,
        onResponse,
        placeholder,
        preamble,
        resetAction,
        resetLabel,
        setSnack,
        submitQuery,
        sx,
        type,
    } = props;
    const { apiClient } = useAuth();
    const [processing, setProcessing] = useState<boolean>(false);
    // NOTE: `setCountdown` and `stopRef` are never updated in this component; they
    // appear to be placeholders for a response-timeout feature.
    const [countdown, setCountdown] = useState<number>(0);
    const [conversation, setConversation] = useState<ChatMessage[]>([]);
    const [filteredConversation, setFilteredConversation] = useState<ChatMessage[]>([]);
    // Transient status line for the in-flight request (progress updates, errors).
    const [processingMessage, setProcessingMessage] = useState<ChatMessage | undefined>(undefined);
    // Partially streamed response of the in-flight request.
    const [streamingMessage, setStreamingMessage] = useState<ChatMessage | undefined>(undefined);
    const [noInteractions, setNoInteractions] = useState<boolean>(true);
    const viewableElementRef = useRef<HTMLDivElement>(null);
    const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
    const stopRef = useRef(false);
    // Non-null while a request is in flight. Typed `| null` so the ref is mutable
    // under both React 18 and React 19 typings (it is reassigned below).
    const controllerRef = useRef<StreamingResponse | null>(null);
    const [chatSession, setChatSession] = useState<ChatSession | null>(null);

    /* Build the rendered list: preamble + `messages` + (optionally filtered) history.
     * While the filtered history is empty the preamble is always shown; afterwards it
     * is dropped when `hidePreamble` is set. */
    useEffect(() => {
        const filtered = messageFilter ? messageFilter(conversation) : conversation;
        const leading = (filtered.length === 0 || !hidePreamble) ? (preamble || []) : [];
        setFilteredConversation([
            ...leading,
            ...(messages || []),
            ...filtered,
        ]);
    }, [conversation, messageFilter, preamble, messages, hidePreamble]);

    // Create the chat session once. `cancelled` discards a result that arrives after
    // this effect was cleaned up (unmount, or StrictMode's double-invoke in dev).
    useEffect(() => {
        if (chatSession) {
            return;
        }
        let cancelled = false;
        const createChatSession = async () => {
            try {
                const chatContext: ChatContext = { type: "general" };
                const response: ChatSession = await apiClient.createChatSession(chatContext);
                if (!cancelled) {
                    setChatSession(response);
                }
            } catch (e) {
                console.error(e);
                if (!cancelled) {
                    setSnack("Unable to create chat session.", "error");
                }
            }
        };

        createChatSession();
        return () => { cancelled = true; };
    }, [chatSession, setChatSession]);

    // Load the session's history, replacing the local conversation. On failure an
    // error line is shown for 3 seconds and a snack is raised.
    // `apiClient`/`setSnack` are intentionally not dependencies: if a parent passes
    // them with unstable identity, this callback (and the history effect that depends
    // on it) would re-run every render.
    const getChatMessages = useCallback(async () => {
        if (!chatSession || !chatSession.id) {
            return;
        }
        try {
            const response: PaginatedResponse<ChatMessage> = await apiClient.getChatMessages(chatSession.id);
            const history: ChatMessage[] = response.data;

            setProcessingMessage(undefined);
            setStreamingMessage(undefined);

            if (history.length === 0) {
                console.log(`History returned with 0 entries`)
                setConversation([])
                setNoInteractions(true);
            } else {
                console.log(`History returned with ${history.length} entries:`, history)
                setConversation(history);
                setNoInteractions(false);
            }
        } catch (error) {
            console.error('Unable to obtain chat history', error);
            setProcessingMessage({ ...defaultMessage, status: "error", content: `Unable to obtain history from server.` });
            setTimeout(() => {
                setProcessingMessage(undefined);
                setNoInteractions(true);
            }, 3000);
            setSnack("Unable to obtain chat history.", "error");
        }
    }, [chatSession]);

    // Show the loading message until a session exists; then reset local state and
    // fetch that session's history.
    useEffect(() => {
        if (!chatSession) {
            setProcessingMessage(loadingMessage);
            return;
        }

        setProcessingMessage(undefined);
        setStreamingMessage(undefined);
        setConversation([]);
        setNoInteractions(true);

        getChatMessages();
    }, [chatSession, getChatMessages]);

    const handleEnter = (value: string) => {
        const query: ChatQuery = {
            prompt: value
        }
        processQuery(query);
    };

    useImperativeHandle(ref, () => ({
        submitQuery: (query: ChatQuery) => {
            processQuery(query);
        },
        fetchHistory: () => { getChatMessages(); }
    }));

    // const reset = async () => {
    //   try {
    //     const response = await fetch(connectionBase + `/api/reset/${sessionId}/${type}`, {
    //       method: 'PUT',
    //       headers: {
    //         'Content-Type': 'application/json',
    //         'Accept': 'application/json',
    //       },
    //       body: JSON.stringify({ reset: ['history'] })
    //     });
    //
    //     if (!response.ok) {
    //       throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
    //     }
    //
    //     if (!response.body) {
    //       throw new Error('Response body is null');
    //     }
    //
    //     setProcessingMessage(undefined);
    //     setStreamingMessage(undefined);
    //     setConversation([]);
    //     setNoInteractions(true);
    //
    //   } catch (e) {
    //     setSnack("Error resetting history", "error")
    //     console.error('Error resetting history:', e);
    //   }
    // };

    // Abort the in-flight request (if any) and return the UI to idle, so the user can
    // submit again even if the stream never delivers a terminal callback.
    const cancelQuery = () => {
        console.log("Stop query");
        if (controllerRef.current) {
            controllerRef.current.cancel();
        }
        controllerRef.current = null;
        setProcessing(false);
        setProcessingMessage(undefined);
        setStreamingMessage(undefined);
    };

    // Echo the prompt into the conversation and stream the server's reply.
    // No-op while another request is in flight or before a session exists.
    const processQuery = (query: ChatQuery) => {
        if (controllerRef.current || !chatSession || !chatSession.id) {
            return;
        }
        const sessionId: string = chatSession.id;

        setNoInteractions(false);
        // Functional updates append to the latest state, so a fast response cannot
        // overwrite the user message appended here.
        setConversation(prev => [
            ...prev,
            {
                ...defaultMessage,
                type: 'user',
                sender: 'user',
                content: query.prompt,
            }
        ]);
        setProcessing(true);
        setProcessingMessage(
            { ...defaultMessage, content: 'Submitting request...' }
        );

        const chatMessage: ChatMessageUser = {
            sessionId,
            content: query.prompt,
            tunables: query.tunables,
            status: "done",
            type: "user",
            sender: "user",
            timestamp: new Date()
        };

        controllerRef.current = apiClient.sendMessageStream(chatMessage, {
            onMessage: (msg: ChatMessageBase) => {
                console.log("onMessage:", msg);
                if (msg.type === "response") {
                    setConversation(prev => [
                        ...prev,
                        msg
                    ]);
                    setStreamingMessage(undefined);
                    setProcessingMessage(undefined);
                    setProcessing(false);
                } else {
                    // Intermediate status update: show it as the processing line.
                    setProcessingMessage(msg);
                }
                if (onResponse) {
                    onResponse(msg);
                }
            },
            onError: (error: string | ChatMessageBase) => {
                console.log("onError:", error);
                // Type-guard to determine if this is a ChatMessageBase or a string
                if (typeof error === "object" && error !== null && "content" in error) {
                    setProcessingMessage(error as ChatMessage);
                    setProcessing(false);
                    controllerRef.current = null;
                } else {
                    // Render string errors with the same "error" status used elsewhere.
                    setProcessingMessage({ ...defaultMessage, status: "error", content: error as string });
                }
            },
            onStreaming: (chunk: ChatMessageBase) => {
                console.log("onStreaming:", chunk);
                setStreamingMessage({ ...defaultMessage, ...chunk });
            },
            onStatusChange: (status: string) => {
                console.log("onStatusChange:", status);
            },
            onComplete: () => {
                console.log("onComplete");
                controllerRef.current = null;
                // Stop the spinner even if the stream ended without a "response" message.
                setProcessing(false);
            }
        });
    };

    if (!chatSession) {
        return (<></>);
    }

    return (
        // <Scrollable
        //   className={`${className || ""} Conversation`}
        //   autoscroll
        //   textFieldRef={viewableElementRef}
        //   fallbackThreshold={0.5}
        //   sx={{
        //     p: 1,
        //     mt: 0,
        //     ...sx
        //   }}
        // >
        <Box className="Conversation" sx={{ flexGrow: 1, minHeight: "max-content", height: "max-content", maxHeight: "max-content", overflow: "hidden" }}>
            <Box sx={{ p: 1, mt: 0, ...sx }}>
                {
                    filteredConversation.map((message, index) =>
                        <Message key={index} {...{ chatSession, sendQuery: processQuery, message, connectionBase, setSnack, submitQuery }} />
                    )
                }
                {
                    processingMessage !== undefined &&
                    <Message {...{ chatSession, sendQuery: processQuery, connectionBase, setSnack, message: processingMessage, submitQuery }} />
                }
                {
                    streamingMessage !== undefined &&
                    <Message {...{ chatSession, sendQuery: processQuery, connectionBase, setSnack, message: streamingMessage, submitQuery }} />
                }
                <Box sx={{
                    display: "flex",
                    flexDirection: "column",
                    alignItems: "center",
                    justifyContent: "center",
                    m: 1,
                }}>
                    <PropagateLoader
                        size="10px"
                        loading={processing}
                        aria-label="Loading Spinner"
                        data-testid="loader"
                    />
                    {processing === true && countdown > 0 && (
                        <Box
                            sx={{
                                pt: 1,
                                fontSize: "0.7rem",
                                color: "darkgrey"
                            }}
                        >Response will be stopped in: {countdown}s</Box>
                    )}
                </Box>
                <Box className="Query" sx={{ display: "flex", flexDirection: "column", p: 1, flexGrow: 1 }}>
                    {placeholder &&
                        <Box sx={{ display: "flex", flexGrow: 1, p: 0, m: 0, flexDirection: "column" }}
                            ref={viewableElementRef}>
                            <BackstoryTextField
                                ref={backstoryTextRef}
                                disabled={processing}
                                onEnter={handleEnter}
                                placeholder={placeholder}
                            />
                        </Box>
                    }

                    <Box key="jobActions" sx={{ display: "flex", justifyContent: "center", flexDirection: "row" }}>
                        <DeleteConfirmation
                            label={resetLabel || "all data"}
                            disabled={!chatSession || processingMessage !== undefined || noInteractions}
                            onDelete={() => { /*reset(); resetAction && resetAction(); */ }} />
                        <Tooltip title={actionLabel || "Send"}>
                            <span style={{ display: "flex", flexGrow: 1 }}>
                                {/* Disabled while a request is in flight (same condition as the text
                                    field), so an error left on screen does not block re-sending. */}
                                <Button
                                    sx={{ m: 1, gap: 1, flexGrow: 1 }}
                                    variant="contained"
                                    disabled={!chatSession || processing}
                                    onClick={() => { processQuery({ prompt: backstoryTextRef.current?.getAndResetValue() || "" }); }}>
                                    {actionLabel}<SendIcon />
                                </Button>
                            </span>
                        </Tooltip>
                        <Tooltip title="Cancel">
                            <span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
                                <IconButton
                                    aria-label="cancel"
                                    onClick={() => { cancelQuery(); }}
                                    sx={{ display: "flex", margin: 'auto 0px' }}
                                    size="large"
                                    edge="start"
                                    disabled={stopRef.current || !chatSession || processing === false}
                                >
                                    <CancelIcon />
                                </IconButton>
                            </span>
                        </Tooltip>
                    </Box>
                </Box>
                {(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length !== 0 &&
                    <Box sx={{ display: "flex", flexDirection: "column" }}>
                        {
                            defaultPrompts.map((element, index) => {
                                return (<Box key={index}>{element}</Box>);
                            })
                        }
                    </Box>
                }
                <Box sx={{ display: "flex", flexGrow: 1 }}></Box>
            </Box>
        </Box>
    );
});
|
|
|
|
// Public surface of this module: the component plus its prop and ref-handle types.
export { Conversation };
export type { ConversationProps, ConversationHandle };