Rework in progress
This commit is contained in:
parent
33d9c1d28a
commit
dc55196311
@ -257,7 +257,7 @@ FROM llm-base AS backstory
|
|||||||
|
|
||||||
COPY /src/requirements.txt /opt/backstory/src/requirements.txt
|
COPY /src/requirements.txt /opt/backstory/src/requirements.txt
|
||||||
RUN pip install -r /opt/backstory/src/requirements.txt
|
RUN pip install -r /opt/backstory/src/requirements.txt
|
||||||
RUN pip install 'markitdown[all]'
|
RUN pip install 'markitdown[all]' pydantic
|
||||||
|
|
||||||
SHELL [ "/bin/bash", "-c" ]
|
SHELL [ "/bin/bash", "-c" ]
|
||||||
|
|
||||||
|
@ -281,6 +281,7 @@ const App = () => {
|
|||||||
throw Error("Server is temporarily down.");
|
throw Error("Server is temporarily down.");
|
||||||
}
|
}
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
console.log(`Session created: ${data.id}`);
|
||||||
setSessionId(data.id);
|
setSessionId(data.id);
|
||||||
|
|
||||||
const newPath = `/${data.id}`;
|
const newPath = `/${data.id}`;
|
||||||
|
@ -16,17 +16,20 @@ import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
|||||||
|
|
||||||
import { SetSnackType } from './Snack';
|
import { SetSnackType } from './Snack';
|
||||||
|
|
||||||
|
interface ServerTunables {
|
||||||
|
system_prompt: string,
|
||||||
|
message_history_length: number,
|
||||||
|
tools: Tool[],
|
||||||
|
rags: Tool[]
|
||||||
|
};
|
||||||
|
|
||||||
type Tool = {
|
type Tool = {
|
||||||
type: string,
|
type: string,
|
||||||
function?: {
|
enabled: boolean
|
||||||
name: string,
|
name: string,
|
||||||
description: string,
|
description: string,
|
||||||
parameters?: any,
|
parameters?: any,
|
||||||
returns?: any
|
returns?: any
|
||||||
},
|
|
||||||
name?: string,
|
|
||||||
description?: string,
|
|
||||||
enabled: boolean
|
|
||||||
};
|
};
|
||||||
|
|
||||||
interface ControlsParams {
|
interface ControlsParams {
|
||||||
@ -41,7 +44,6 @@ type GPUInfo = {
|
|||||||
discrete: boolean
|
discrete: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
type SystemInfo = {
|
type SystemInfo = {
|
||||||
"Installed RAM": string,
|
"Installed RAM": string,
|
||||||
"Graphics Card": GPUInfo[],
|
"Graphics Card": GPUInfo[],
|
||||||
@ -94,11 +96,11 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
const [tools, setTools] = useState<Tool[]>([]);
|
const [tools, setTools] = useState<Tool[]>([]);
|
||||||
const [rags, setRags] = useState<Tool[]>([]);
|
const [rags, setRags] = useState<Tool[]>([]);
|
||||||
const [systemPrompt, setSystemPrompt] = useState<string>("");
|
const [systemPrompt, setSystemPrompt] = useState<string>("");
|
||||||
const [serverSystemPrompt, setServerSystemPrompt] = useState<string>("");
|
|
||||||
const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5);
|
const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5);
|
||||||
|
const [serverTunables, setServerTunables] = useState<ServerTunables | undefined>(undefined);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (systemPrompt === serverSystemPrompt || !systemPrompt.trim() || sessionId === undefined) {
|
if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim() || sessionId === undefined) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const sendSystemPrompt = async (prompt: string) => {
|
const sendSystemPrompt = async (prompt: string) => {
|
||||||
@ -112,13 +114,10 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
body: JSON.stringify({ "system_prompt": prompt }),
|
body: JSON.stringify({ "system_prompt": prompt }),
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = await response.json();
|
const tunables = await response.json();
|
||||||
const newPrompt = data["system_prompt"];
|
serverTunables.system_prompt = tunables.system_prompt;
|
||||||
if (newPrompt !== serverSystemPrompt) {
|
setSystemPrompt(tunables.system_prompt)
|
||||||
setServerSystemPrompt(newPrompt);
|
|
||||||
setSystemPrompt(newPrompt)
|
|
||||||
setSnack("System prompt updated", "success");
|
setSnack("System prompt updated", "success");
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Fetch error:', error);
|
console.error('Fetch error:', error);
|
||||||
setSnack("System prompt update failed", "error");
|
setSnack("System prompt update failed", "error");
|
||||||
@ -127,10 +126,10 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
|
|
||||||
sendSystemPrompt(systemPrompt);
|
sendSystemPrompt(systemPrompt);
|
||||||
|
|
||||||
}, [systemPrompt, setServerSystemPrompt, serverSystemPrompt, connectionBase, sessionId, setSnack]);
|
}, [systemPrompt, connectionBase, sessionId, setSnack, serverTunables]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (sessionId === undefined) {
|
if (serverTunables === undefined || messageHistoryLength === serverTunables.message_history_length || !messageHistoryLength || sessionId === undefined) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const sendMessageHistoryLength = async (length: number) => {
|
const sendMessageHistoryLength = async (length: number) => {
|
||||||
@ -158,7 +157,7 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
|
|
||||||
sendMessageHistoryLength(messageHistoryLength);
|
sendMessageHistoryLength(messageHistoryLength);
|
||||||
|
|
||||||
}, [messageHistoryLength, setMessageHistoryLength, connectionBase, sessionId, setSnack]);
|
}, [messageHistoryLength, setMessageHistoryLength, connectionBase, sessionId, setSnack, serverTunables]);
|
||||||
|
|
||||||
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt" | "message_history_length")[], message: string = "Update successful.") => {
|
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt" | "message_history_length")[], message: string = "Update successful.") => {
|
||||||
try {
|
try {
|
||||||
@ -185,8 +184,7 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
setTools(value as Tool[]);
|
setTools(value as Tool[]);
|
||||||
break;
|
break;
|
||||||
case "system_prompt":
|
case "system_prompt":
|
||||||
setServerSystemPrompt((value as any)["system_prompt"].trim());
|
setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
|
||||||
setSystemPrompt((value as any)["system_prompt"].trim());
|
|
||||||
break;
|
break;
|
||||||
case "history":
|
case "history":
|
||||||
console.log('TODO: handle history reset');
|
console.log('TODO: handle history reset');
|
||||||
@ -203,7 +201,6 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Get the system information
|
// Get the system information
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (systemInfo !== undefined || sessionId === undefined) {
|
if (systemInfo !== undefined || sessionId === undefined) {
|
||||||
@ -229,21 +226,20 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
setEditSystemPrompt(systemPrompt);
|
setEditSystemPrompt(systemPrompt);
|
||||||
}, [systemPrompt, setEditSystemPrompt]);
|
}, [systemPrompt, setEditSystemPrompt]);
|
||||||
|
|
||||||
|
|
||||||
const toggleRag = async (tool: Tool) => {
|
const toggleRag = async (tool: Tool) => {
|
||||||
tool.enabled = !tool.enabled
|
tool.enabled = !tool.enabled
|
||||||
try {
|
try {
|
||||||
const response = await fetch(connectionBase + `/api/rags/${sessionId}`, {
|
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json',
|
'Accept': 'application/json',
|
||||||
},
|
},
|
||||||
body: JSON.stringify({ "tool": tool?.name, "enabled": tool.enabled }),
|
body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }),
|
||||||
});
|
});
|
||||||
|
|
||||||
const rags = await response.json();
|
const tunables: ServerTunables = await response.json();
|
||||||
setRags([...rags])
|
setRags(tunables.rags)
|
||||||
setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Fetch error:', error);
|
console.error('Fetch error:', error);
|
||||||
@ -255,86 +251,28 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
const toggleTool = async (tool: Tool) => {
|
const toggleTool = async (tool: Tool) => {
|
||||||
tool.enabled = !tool.enabled
|
tool.enabled = !tool.enabled
|
||||||
try {
|
try {
|
||||||
const response = await fetch(connectionBase + `/api/tools/${sessionId}`, {
|
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json',
|
'Accept': 'application/json',
|
||||||
},
|
},
|
||||||
body: JSON.stringify({ "tool": tool?.function?.name, "enabled": tool.enabled }),
|
body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }),
|
||||||
});
|
});
|
||||||
|
|
||||||
const tools = await response.json();
|
const tunables: ServerTunables = await response.json();
|
||||||
setTools([...tools])
|
setTools(tunables.tools)
|
||||||
setSnack(`${tool?.function?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Fetch error:', error);
|
console.error('Fetch error:', error);
|
||||||
setSnack(`${tool?.function?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||||
tool.enabled = !tool.enabled
|
tool.enabled = !tool.enabled
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// If the tools have not been set, fetch them from the server
|
|
||||||
useEffect(() => {
|
|
||||||
if (tools.length || sessionId === undefined) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const fetchTools = async () => {
|
|
||||||
try {
|
|
||||||
// Make the fetch request with proper headers
|
|
||||||
const response = await fetch(connectionBase + `/api/tools/${sessionId}`, {
|
|
||||||
method: 'GET',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Accept': 'application/json',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
if (!response.ok) {
|
|
||||||
throw Error();
|
|
||||||
}
|
|
||||||
const tools = await response.json();
|
|
||||||
setTools(tools);
|
|
||||||
} catch (error: any) {
|
|
||||||
setSnack("Unable to fetch tools", "error");
|
|
||||||
console.error(error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fetchTools();
|
|
||||||
}, [sessionId, tools, setTools, setSnack, connectionBase]);
|
|
||||||
|
|
||||||
// If the RAGs have not been set, fetch them from the server
|
|
||||||
useEffect(() => {
|
|
||||||
if (rags.length || sessionId === undefined) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const fetchRags = async () => {
|
|
||||||
try {
|
|
||||||
// Make the fetch request with proper headers
|
|
||||||
const response = await fetch(connectionBase + `/api/rags/${sessionId}`, {
|
|
||||||
method: 'GET',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Accept': 'application/json',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
if (!response.ok) {
|
|
||||||
throw Error();
|
|
||||||
}
|
|
||||||
const rags = await response.json();
|
|
||||||
setRags(rags);
|
|
||||||
} catch (error: any) {
|
|
||||||
setSnack("Unable to fetch RAGs", "error");
|
|
||||||
console.error(error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fetchRags();
|
|
||||||
}, [sessionId, rags, setRags, setSnack, connectionBase]);
|
|
||||||
|
|
||||||
// If the systemPrompt has not been set, fetch it from the server
|
// If the systemPrompt has not been set, fetch it from the server
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (serverSystemPrompt !== "" || sessionId === undefined) {
|
if (serverTunables !== undefined || sessionId === undefined) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const fetchTunables = async () => {
|
const fetchTunables = async () => {
|
||||||
@ -347,25 +285,29 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
const serverSystemPrompt = data["system_prompt"].trim();
|
console.log("Server tunables: ", data);
|
||||||
setServerSystemPrompt(serverSystemPrompt);
|
setServerTunables(data);
|
||||||
setSystemPrompt(serverSystemPrompt);
|
setSystemPrompt(data["system_prompt"]);
|
||||||
setMessageHistoryLength(data["message_history_length"]);
|
setMessageHistoryLength(data["message_history_length"]);
|
||||||
|
setTools(data["tools"]);
|
||||||
|
setRags(data["rags"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
fetchTunables();
|
fetchTunables();
|
||||||
}, [sessionId, serverSystemPrompt, setServerSystemPrompt, connectionBase]);
|
}, [sessionId, connectionBase, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags]);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const toggle = async (type: string, index: number) => {
|
const toggle = async (type: string, index: number) => {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case "rag":
|
case "rag":
|
||||||
|
if (rags === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
toggleRag(rags[index])
|
toggleRag(rags[index])
|
||||||
break;
|
break;
|
||||||
case "tool":
|
case "tool":
|
||||||
|
if (tools === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
toggleTool(tools[index]);
|
toggleTool(tools[index]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -442,11 +384,11 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
<AccordionActions>
|
<AccordionActions>
|
||||||
<FormGroup sx={{ p: 1 }}>
|
<FormGroup sx={{ p: 1 }}>
|
||||||
{
|
{
|
||||||
tools.map((tool, index) =>
|
(tools || []).map((tool, index) =>
|
||||||
<Box key={index}>
|
<Box key={index}>
|
||||||
<Divider />
|
<Divider />
|
||||||
<FormControlLabel control={<Switch checked={tool.enabled} />} onChange={() => toggle("tool", index)} label={tool?.function?.name} />
|
<FormControlLabel control={<Switch checked={tool.enabled} />} onChange={() => toggle("tool", index)} label={tool.name} />
|
||||||
<Typography sx={{ fontSize: "0.8rem", mb: 1 }}>{tool?.function?.description}</Typography>
|
<Typography sx={{ fontSize: "0.8rem", mb: 1 }}>{tool.description}</Typography>
|
||||||
</Box>
|
</Box>
|
||||||
)
|
)
|
||||||
}</FormGroup>
|
}</FormGroup>
|
||||||
@ -463,14 +405,14 @@ const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
|
|||||||
<AccordionActions>
|
<AccordionActions>
|
||||||
<FormGroup sx={{ p: 1, flexGrow: 1, justifyContent: "flex-start" }}>
|
<FormGroup sx={{ p: 1, flexGrow: 1, justifyContent: "flex-start" }}>
|
||||||
{
|
{
|
||||||
rags.map((rag, index) =>
|
(rags || []).map((rag, index) =>
|
||||||
<Box key={index} sx={{ display: "flex", flexGrow: 1, flexDirection: "column" }}>
|
<Box key={index} sx={{ display: "flex", flexGrow: 1, flexDirection: "column" }}>
|
||||||
<Divider />
|
<Divider />
|
||||||
<FormControlLabel
|
<FormControlLabel
|
||||||
control={<Switch checked={rag.enabled} />}
|
control={<Switch checked={rag.enabled} />}
|
||||||
onChange={() => toggle("rag", index)} label={rag?.name}
|
onChange={() => toggle("rag", index)} label={rag.name}
|
||||||
/>
|
/>
|
||||||
<Typography>{rag?.description}</Typography>
|
<Typography>{rag.description}</Typography>
|
||||||
</Box>
|
</Box>
|
||||||
)
|
)
|
||||||
}</FormGroup>
|
}</FormGroup>
|
||||||
|
@ -160,14 +160,15 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
|||||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
const { messages } = await response.json();
|
||||||
|
|
||||||
console.log(`History returned for ${type} from server with ${data.length} entries`)
|
if (messages === undefined || messages.length === 0) {
|
||||||
if (data.length === 0) {
|
console.log(`History returned for ${type} from server with 0 entries`)
|
||||||
setConversation([])
|
setConversation([])
|
||||||
setNoInteractions(true);
|
setNoInteractions(true);
|
||||||
} else {
|
} else {
|
||||||
setConversation(data);
|
console.log(`History returned for ${type} from server with ${messages.length} entries:`, messages)
|
||||||
|
setConversation(messages);
|
||||||
setNoInteractions(false);
|
setNoInteractions(false);
|
||||||
}
|
}
|
||||||
setProcessingMessage(undefined);
|
setProcessingMessage(undefined);
|
||||||
|
@ -42,6 +42,7 @@ const DeleteConfirmation = (props : DeleteConfirmationProps) => {
|
|||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Tooltip title={label ? `Reset ${label}` : "Reset"} >
|
<Tooltip title={label ? `Reset ${label}` : "Reset"} >
|
||||||
|
<span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
|
||||||
<IconButton
|
<IconButton
|
||||||
aria-label="reset"
|
aria-label="reset"
|
||||||
onClick={handleClickOpen}
|
onClick={handleClickOpen}
|
||||||
@ -53,6 +54,7 @@ const DeleteConfirmation = (props : DeleteConfirmationProps) => {
|
|||||||
>
|
>
|
||||||
<ResetIcon />
|
<ResetIcon />
|
||||||
</IconButton>
|
</IconButton>
|
||||||
|
</span>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
|
||||||
<Dialog
|
<Dialog
|
||||||
|
@ -39,10 +39,10 @@ type MessageData = {
|
|||||||
display?: string, /* Messages generated on the server for filler should not be shown */
|
display?: string, /* Messages generated on the server for filler should not be shown */
|
||||||
id?: string,
|
id?: string,
|
||||||
isProcessing?: boolean,
|
isProcessing?: boolean,
|
||||||
metadata?: MessageMetaProps
|
metadata?: MessageMetaData
|
||||||
};
|
};
|
||||||
|
|
||||||
interface MessageMetaProps {
|
interface MessageMetaData {
|
||||||
query?: {
|
query?: {
|
||||||
query_embedding: number[];
|
query_embedding: number[];
|
||||||
vector_embedding: number[];
|
vector_embedding: number[];
|
||||||
@ -64,7 +64,7 @@ type MessageList = MessageData[];
|
|||||||
|
|
||||||
interface MessageProps {
|
interface MessageProps {
|
||||||
sx?: SxProps<Theme>,
|
sx?: SxProps<Theme>,
|
||||||
message?: MessageData,
|
message: MessageData,
|
||||||
isFullWidth?: boolean,
|
isFullWidth?: boolean,
|
||||||
submitQuery?: (text: string) => void,
|
submitQuery?: (text: string) => void,
|
||||||
sessionId?: string,
|
sessionId?: string,
|
||||||
@ -78,8 +78,25 @@ interface ChatQueryInterface {
|
|||||||
submitQuery?: (text: string) => void
|
submitQuery?: (text: string) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface MessageMetaProps {
|
||||||
|
metadata: MessageMetaData,
|
||||||
|
messageProps: MessageProps
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
const MessageMeta = (props: MessageMetaProps) => {
|
||||||
|
const {
|
||||||
|
/* MessageData */
|
||||||
|
full_query,
|
||||||
|
rag,
|
||||||
|
tools,
|
||||||
|
eval_count,
|
||||||
|
eval_duration,
|
||||||
|
prompt_eval_count,
|
||||||
|
prompt_eval_duration,
|
||||||
|
} = props.metadata || {};
|
||||||
|
const messageProps = props.messageProps;
|
||||||
|
|
||||||
const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
|
||||||
return (<>
|
return (<>
|
||||||
<Box sx={{ fontSize: "0.8rem", mb: 1 }}>
|
<Box sx={{ fontSize: "0.8rem", mb: 1 }}>
|
||||||
Below is the LLM performance of this query. Note that if tools are called, the
|
Below is the LLM performance of this query. Note that if tools are called, the
|
||||||
@ -99,28 +116,28 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
|||||||
<TableBody>
|
<TableBody>
|
||||||
<TableRow key="prompt" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
<TableRow key="prompt" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||||
<TableCell component="th" scope="row">Prompt</TableCell>
|
<TableCell component="th" scope="row">Prompt</TableCell>
|
||||||
<TableCell align="right">{props.prompt_eval_count}</TableCell>
|
<TableCell align="right">{prompt_eval_count}</TableCell>
|
||||||
<TableCell align="right">{Math.round(props.prompt_eval_duration / 10 ** 7) / 100}</TableCell>
|
<TableCell align="right">{Math.round(prompt_eval_duration / 10 ** 7) / 100}</TableCell>
|
||||||
<TableCell align="right">{Math.round(props.prompt_eval_count * 10 ** 9 / props.prompt_eval_duration)}</TableCell>
|
<TableCell align="right">{Math.round(prompt_eval_count * 10 ** 9 / prompt_eval_duration)}</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
<TableRow key="response" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
<TableRow key="response" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||||
<TableCell component="th" scope="row">Response</TableCell>
|
<TableCell component="th" scope="row">Response</TableCell>
|
||||||
<TableCell align="right">{props.eval_count}</TableCell>
|
<TableCell align="right">{eval_count}</TableCell>
|
||||||
<TableCell align="right">{Math.round(props.eval_duration / 10 ** 7) / 100}</TableCell>
|
<TableCell align="right">{Math.round(eval_duration / 10 ** 7) / 100}</TableCell>
|
||||||
<TableCell align="right">{Math.round(props.eval_count * 10 ** 9 / props.eval_duration)}</TableCell>
|
<TableCell align="right">{Math.round(eval_count * 10 ** 9 / eval_duration)}</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
<TableRow key="total" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
<TableRow key="total" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
|
||||||
<TableCell component="th" scope="row">Total</TableCell>
|
<TableCell component="th" scope="row">Total</TableCell>
|
||||||
<TableCell align="right">{props.prompt_eval_count + props.eval_count}</TableCell>
|
<TableCell align="right">{prompt_eval_count + eval_count}</TableCell>
|
||||||
<TableCell align="right">{Math.round((props.prompt_eval_duration + props.eval_duration) / 10 ** 7) / 100}</TableCell>
|
<TableCell align="right">{Math.round((prompt_eval_duration + eval_duration) / 10 ** 7) / 100}</TableCell>
|
||||||
<TableCell align="right">{Math.round((props.prompt_eval_count + props.eval_count) * 10 ** 9 / (props.prompt_eval_duration + props.eval_duration))}</TableCell>
|
<TableCell align="right">{Math.round((prompt_eval_count + eval_count) * 10 ** 9 / (prompt_eval_duration + eval_duration))}</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
</TableBody>
|
</TableBody>
|
||||||
</Table>
|
</Table>
|
||||||
</TableContainer>
|
</TableContainer>
|
||||||
|
|
||||||
{
|
{
|
||||||
props?.full_query !== undefined &&
|
full_query !== undefined &&
|
||||||
<Accordion>
|
<Accordion>
|
||||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||||
<Box sx={{ fontSize: "0.8rem" }}>
|
<Box sx={{ fontSize: "0.8rem" }}>
|
||||||
@ -128,12 +145,12 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
|||||||
</Box>
|
</Box>
|
||||||
</AccordionSummary>
|
</AccordionSummary>
|
||||||
<AccordionDetails>
|
<AccordionDetails>
|
||||||
<pre style={{ "display": "block", "position": "relative" }}><CopyBubble content={props.full_query.trim()} />{props.full_query.trim()}</pre>
|
<pre style={{ "display": "block", "position": "relative" }}><CopyBubble content={full_query?.trim()} />{full_query?.trim()}</pre>
|
||||||
</AccordionDetails>
|
</AccordionDetails>
|
||||||
</Accordion>
|
</Accordion>
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
props.tools !== undefined && props.tools.length !== 0 &&
|
tools !== undefined && tools.length !== 0 &&
|
||||||
<Accordion sx={{ boxSizing: "border-box" }}>
|
<Accordion sx={{ boxSizing: "border-box" }}>
|
||||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||||
<Box sx={{ fontSize: "0.8rem" }}>
|
<Box sx={{ fontSize: "0.8rem" }}>
|
||||||
@ -141,7 +158,7 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
|||||||
</Box>
|
</Box>
|
||||||
</AccordionSummary>
|
</AccordionSummary>
|
||||||
<AccordionDetails>
|
<AccordionDetails>
|
||||||
{props.tools.map((tool: any, index: number) => <Box key={index}>
|
{tools.map((tool: any, index: number) => <Box key={index}>
|
||||||
{index !== 0 && <Divider />}
|
{index !== 0 && <Divider />}
|
||||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mt: 0.5 }}>
|
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mt: 0.5 }}>
|
||||||
<div style={{ display: "flex", paddingRight: "1rem", whiteSpace: "nowrap" }}>
|
<div style={{ display: "flex", paddingRight: "1rem", whiteSpace: "nowrap" }}>
|
||||||
@ -165,24 +182,24 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
|||||||
</Accordion>
|
</Accordion>
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
props?.rag?.name !== undefined && <>
|
rag?.name !== undefined && <>
|
||||||
<Accordion>
|
<Accordion>
|
||||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||||
<Box sx={{ fontSize: "0.8rem" }}>
|
<Box sx={{ fontSize: "0.8rem" }}>
|
||||||
Top RAG {props.rag.ids.length} matches from '{props.rag.name}' collection against embedding vector of {props.rag.query_embedding.length} dimensions
|
Top RAG {rag.ids.length} matches from '{rag.name}' collection against embedding vector of {rag.query_embedding.length} dimensions
|
||||||
</Box>
|
</Box>
|
||||||
</AccordionSummary>
|
</AccordionSummary>
|
||||||
<AccordionDetails>
|
<AccordionDetails>
|
||||||
{props.rag.ids.map((id: number, index: number) => <Box key={index}>
|
{rag.ids.map((id: number, index: number) => <Box key={index}>
|
||||||
{index !== 0 && <Divider />}
|
{index !== 0 && <Divider />}
|
||||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "row", mb: 0.5, mt: 0.5 }}>
|
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "row", mb: 0.5, mt: 0.5 }}>
|
||||||
<div style={{ display: "flex", flexDirection: "column", paddingRight: "1rem", minWidth: "10rem" }}>
|
<div style={{ display: "flex", flexDirection: "column", paddingRight: "1rem", minWidth: "10rem" }}>
|
||||||
<div style={{ whiteSpace: "nowrap" }}>Doc ID: {props.rag.ids[index].slice(-10)}</div>
|
<div style={{ whiteSpace: "nowrap" }}>Doc ID: {rag.ids[index].slice(-10)}</div>
|
||||||
<div style={{ whiteSpace: "nowrap" }}>Similarity: {Math.round(props.rag.distances[index] * 100) / 100}</div>
|
<div style={{ whiteSpace: "nowrap" }}>Similarity: {Math.round(rag.distances[index] * 100) / 100}</div>
|
||||||
<div style={{ whiteSpace: "nowrap" }}>Type: {props.rag.metadatas[index].doc_type}</div>
|
<div style={{ whiteSpace: "nowrap" }}>Type: {rag.metadatas[index].doc_type}</div>
|
||||||
<div style={{ whiteSpace: "nowrap" }}>Chunk Len: {props.rag.documents[index].length}</div>
|
<div style={{ whiteSpace: "nowrap" }}>Chunk Len: {rag.documents[index].length}</div>
|
||||||
</div>
|
</div>
|
||||||
<div style={{ display: "flex", padding: "3px", flexGrow: 1, border: "1px solid #E0E0E0", maxHeight: "5rem", overflow: "auto" }}>{props.rag.documents[index]}</div>
|
<div style={{ display: "flex", padding: "3px", flexGrow: 1, border: "1px solid #E0E0E0", maxHeight: "5rem", overflow: "auto" }}>{rag.documents[index]}</div>
|
||||||
</Box>
|
</Box>
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
@ -195,13 +212,53 @@ const MessageMeta = ({ ...props }: MessageMetaProps) => {
|
|||||||
</Box>
|
</Box>
|
||||||
</AccordionSummary>
|
</AccordionSummary>
|
||||||
<AccordionDetails>
|
<AccordionDetails>
|
||||||
<VectorVisualizer inline {...props} rag={props?.rag} />
|
<VectorVisualizer inline {...messageProps} {...props.metadata} rag={rag} />
|
||||||
|
</AccordionDetails>
|
||||||
|
</Accordion>
|
||||||
|
<Accordion>
|
||||||
|
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||||
|
<Box sx={{ fontSize: "0.8rem" }}>
|
||||||
|
All response fields
|
||||||
|
</Box>
|
||||||
|
</AccordionSummary>
|
||||||
|
<AccordionDetails>
|
||||||
|
{Object.entries(props.messageProps.message)
|
||||||
|
.filter(([key, value]) => key !== undefined && value !== undefined)
|
||||||
|
.map(([key, value]) => (<>
|
||||||
|
{(typeof (value) !== "string" || value?.trim() !== "") &&
|
||||||
|
<Accordion key={key}>
|
||||||
|
<AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
|
||||||
|
{key}
|
||||||
|
</AccordionSummary>
|
||||||
|
<AccordionDetails>
|
||||||
|
{key === "metadata" &&
|
||||||
|
Object.entries(value)
|
||||||
|
.filter(([key, value]) => key !== undefined && value !== undefined)
|
||||||
|
.map(([key, value]) => (
|
||||||
|
<Accordion key={`${key}-metadata`}>
|
||||||
|
<AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
|
||||||
|
{key}
|
||||||
|
</AccordionSummary>
|
||||||
|
<AccordionDetails>
|
||||||
|
<pre>{`${typeof (value) !== "object" ? value : JSON.stringify(value)}`}</pre>
|
||||||
|
</AccordionDetails>
|
||||||
|
</Accordion>
|
||||||
|
))}
|
||||||
|
{key !== "metadata" &&
|
||||||
|
<pre>{typeof (value) !== "object" ? value : JSON.stringify(value)}</pre>
|
||||||
|
}
|
||||||
</AccordionDetails>
|
</AccordionDetails>
|
||||||
</Accordion>
|
</Accordion>
|
||||||
</>
|
|
||||||
}
|
}
|
||||||
</>
|
</>
|
||||||
);
|
)
|
||||||
|
)}
|
||||||
|
</AccordionDetails>
|
||||||
|
</Accordion>
|
||||||
|
|
||||||
|
</>
|
||||||
|
}
|
||||||
|
</>);
|
||||||
};
|
};
|
||||||
|
|
||||||
const ChatQuery = ({ text, submitQuery }: ChatQueryInterface) => {
|
const ChatQuery = ({ text, submitQuery }: ChatQueryInterface) => {
|
||||||
@ -221,7 +278,7 @@ const ChatQuery = ({ text, submitQuery }: ChatQueryInterface) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const Message = (props: MessageProps) => {
|
const Message = (props: MessageProps) => {
|
||||||
const { message, submitQuery, isFullWidth, sessionId, setSnack, connectionBase, sx, className } = props;
|
const { message, submitQuery, isFullWidth, sx, className } = props;
|
||||||
const [expanded, setExpanded] = useState<boolean>(false);
|
const [expanded, setExpanded] = useState<boolean>(false);
|
||||||
const textFieldRef = useRef(null);
|
const textFieldRef = useRef(null);
|
||||||
|
|
||||||
@ -293,7 +350,7 @@ const Message = (props: MessageProps) => {
|
|||||||
{message.metadata && <>
|
{message.metadata && <>
|
||||||
<Collapse in={expanded} timeout="auto" unmountOnExit>
|
<Collapse in={expanded} timeout="auto" unmountOnExit>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
<MessageMeta {...{ ...message.metadata, sessionId, connectionBase, setSnack }} />
|
<MessageMeta messageProps={props} metadata={message.metadata} />
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Collapse>
|
</Collapse>
|
||||||
</>}
|
</>}
|
||||||
@ -305,7 +362,6 @@ export type {
|
|||||||
MessageProps,
|
MessageProps,
|
||||||
MessageList,
|
MessageList,
|
||||||
ChatQueryInterface,
|
ChatQueryInterface,
|
||||||
MessageMetaProps,
|
|
||||||
MessageData,
|
MessageData,
|
||||||
MessageRoles
|
MessageRoles
|
||||||
};
|
};
|
||||||
|
@ -112,6 +112,12 @@ const ResumeBuilder: React.FC<ResumeBuilderProps> = ({
|
|||||||
return keep;
|
return keep;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/* If Resume hasn't occurred yet and there is still more than one message,
|
||||||
|
* resume has been generated. */
|
||||||
|
if (!hasResume && reduced.length > 1) {
|
||||||
|
setHasResume(true);
|
||||||
|
}
|
||||||
|
|
||||||
if (reduced.length > 0) {
|
if (reduced.length > 0) {
|
||||||
// First message is always 'content'
|
// First message is always 'content'
|
||||||
reduced[0].title = 'Job Description';
|
reduced[0].title = 'Job Description';
|
||||||
@ -123,7 +129,7 @@ const ResumeBuilder: React.FC<ResumeBuilderProps> = ({
|
|||||||
reduced = reduced.filter(m => m.display !== "hide");
|
reduced = reduced.filter(m => m.display !== "hide");
|
||||||
|
|
||||||
return reduced;
|
return reduced;
|
||||||
}, [setHasJobDescription, setHasResume]);
|
}, [setHasJobDescription, setHasResume, hasResume]);
|
||||||
|
|
||||||
const filterResumeMessages = useCallback((messages: MessageList): MessageList => {
|
const filterResumeMessages = useCallback((messages: MessageList): MessageList => {
|
||||||
if (messages === undefined || messages.length === 0) {
|
if (messages === undefined || messages.length === 0) {
|
||||||
@ -135,11 +141,11 @@ const ResumeBuilder: React.FC<ResumeBuilderProps> = ({
|
|||||||
if ((m.metadata?.origin || m.origin || "no origin") === 'fact_check') {
|
if ((m.metadata?.origin || m.origin || "no origin") === 'fact_check') {
|
||||||
setHasFacts(true);
|
setHasFacts(true);
|
||||||
}
|
}
|
||||||
// if (!keep) {
|
if (!keep) {
|
||||||
// console.log(`filterResumeMessages: ${i + 1} filtered:`, m);
|
console.log(`filterResumeMessages: ${i + 1} filtered:`, m);
|
||||||
// } else {
|
} else {
|
||||||
// console.log(`filterResumeMessages: ${i + 1}:`, m);
|
console.log(`filterResumeMessages: ${i + 1}:`, m);
|
||||||
// }
|
}
|
||||||
return keep;
|
return keep;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
890
src/server.py
890
src/server.py
File diff suppressed because it is too large
Load Diff
@ -2,7 +2,9 @@
|
|||||||
from . import defines
|
from . import defines
|
||||||
|
|
||||||
# Import rest as `utils.*` accessible
|
# Import rest as `utils.*` accessible
|
||||||
from .rag import *
|
from .rag import ChromaDBFileWatcher, start_file_watcher
|
||||||
|
|
||||||
# Expose only public names (avoid importing hidden/internal names)
|
from .message import Message
|
||||||
__all__ = [name for name in dir() if not name.startswith("_")]
|
from .conversation import Conversation
|
||||||
|
from .session import Session, Chat, Resume, JobDescription, FactCheck
|
||||||
|
from .context import Context
|
98
src/utils/context.py
Normal file
98
src/utils/context.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
from uuid import uuid4
|
||||||
|
from typing import List, Optional
|
||||||
|
from typing_extensions import Annotated, Union
|
||||||
|
from .session import AnySession, Session
|
||||||
|
|
||||||
|
class Context(BaseModel):
|
||||||
|
id: str = Field(
|
||||||
|
default_factory=lambda: str(uuid4()),
|
||||||
|
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
sessions: List[Annotated[Union[*Session.__subclasses__()], Field(discriminator="session_type")]] = Field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
|
user_resume: Optional[str] = None
|
||||||
|
user_job_description: Optional[str] = None
|
||||||
|
user_facts: Optional[str] = None
|
||||||
|
tools: List[dict] = []
|
||||||
|
rags: List[dict] = []
|
||||||
|
message_history_length: int = 5
|
||||||
|
context_tokens: int = 0
|
||||||
|
|
||||||
|
def __init__(self, id: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(id=id if id is not None else str(uuid4()), **kwargs)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_unique_session_types(self):
|
||||||
|
"""Ensure at most one session per session_type."""
|
||||||
|
session_types = [session.session_type for session in self.sessions]
|
||||||
|
if len(session_types) != len(set(session_types)):
|
||||||
|
raise ValueError("Context cannot contain multiple sessions of the same session_type")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_or_create_session(self, session_type: str, **kwargs) -> Session:
|
||||||
|
"""
|
||||||
|
Get or create and append a new session of the specified type, ensuring only one session per type exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_type: The type of session to create (e.g., 'web', 'database').
|
||||||
|
**kwargs: Additional fields required by the specific session subclass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created session instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no matching session type is found or if a session of this type already exists.
|
||||||
|
"""
|
||||||
|
# Check if a session with the given session_type already exists
|
||||||
|
for session in self.sessions:
|
||||||
|
if session.session_type == session_type:
|
||||||
|
return session
|
||||||
|
|
||||||
|
# Find the matching subclass
|
||||||
|
for session_cls in Session.__subclasses__():
|
||||||
|
if session_cls.__fields__["session_type"].default == session_type:
|
||||||
|
# Create the session instance with provided kwargs
|
||||||
|
session = session_cls(session_type=session_type, **kwargs)
|
||||||
|
self.sessions.append(session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
raise ValueError(f"No session class found for session_type: {session_type}")
|
||||||
|
|
||||||
|
def add_session(self, session: AnySession) -> None:
|
||||||
|
"""Add a Session to the context, ensuring no duplicate session_type."""
|
||||||
|
if any(s.session_type == session.session_type for s in self.sessions):
|
||||||
|
raise ValueError(f"A session with session_type '{session.session_type}' already exists")
|
||||||
|
self.sessions.append(session)
|
||||||
|
|
||||||
|
def get_session(self, session_type: str) -> Session | None:
|
||||||
|
"""Return the Session with the given session_type, or None if not found."""
|
||||||
|
for session in self.sessions:
|
||||||
|
if session.session_type == session_type:
|
||||||
|
return session
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_valid_session_type(self, session_type: str) -> bool:
|
||||||
|
"""Check if the given session_type is valid."""
|
||||||
|
return session_type in Session.valid_session_types()
|
||||||
|
|
||||||
|
def get_summary(self) -> str:
|
||||||
|
"""Return a summary of the context."""
|
||||||
|
if not self.sessions:
|
||||||
|
return f"Context {self.uuid}: No sessions."
|
||||||
|
summary = f"Context {self.uuid}:\n"
|
||||||
|
for i, session in enumerate(self.sessions, 1):
|
||||||
|
summary += f"\nSession {i} ({session.session_type}):\n"
|
||||||
|
summary += session.conversation.get_summary()
|
||||||
|
if session.session_type == "resume":
|
||||||
|
summary += f"\nResume: {session.get_resume()}\n"
|
||||||
|
elif session.session_type == "job_description":
|
||||||
|
summary += f"\nJob Description: {session.job_description}\n"
|
||||||
|
elif session.session_type == "fact_check":
|
||||||
|
summary += f"\nFacts: {session.facts}\n"
|
||||||
|
elif session.session_type == "chat":
|
||||||
|
summary += f"\nChat Name: {session.name}\n"
|
||||||
|
return summary
|
23
src/utils/conversation.py
Normal file
23
src/utils/conversation.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from .message import Message
|
||||||
|
|
||||||
|
class Conversation(BaseModel):
|
||||||
|
messages: List[Message] = []
|
||||||
|
|
||||||
|
def add_message(self, message: Message | List[Message]) -> None:
|
||||||
|
"""Add a Message(s) to the conversation."""
|
||||||
|
if isinstance(message, Message):
|
||||||
|
self.messages.append(message)
|
||||||
|
else:
|
||||||
|
self.messages.extend(message)
|
||||||
|
|
||||||
|
def get_summary(self) -> str:
|
||||||
|
"""Return a summary of the conversation."""
|
||||||
|
if not self.messages:
|
||||||
|
return "Conversation is empty."
|
||||||
|
summary = f"Conversation:\n"
|
||||||
|
for i, message in enumerate(self.messages, 1):
|
||||||
|
summary += f"\nMessage {i}:\n{message.get_summary()}\n"
|
||||||
|
return summary
|
31
src/utils/message.py
Normal file
31
src/utils/message.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
prompt: str
|
||||||
|
preamble: str = ""
|
||||||
|
content: str = ""
|
||||||
|
response: str = ""
|
||||||
|
metadata: dict[str, Any] = {
|
||||||
|
"rag": { "documents": [] },
|
||||||
|
"tools": [],
|
||||||
|
"eval_count": 0,
|
||||||
|
"eval_duration": 0,
|
||||||
|
"prompt_eval_count": 0,
|
||||||
|
"prompt_eval_duration": 0,
|
||||||
|
}
|
||||||
|
actions: List[str] = []
|
||||||
|
timestamp: datetime = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
def get_summary(self) -> str:
|
||||||
|
"""Return a summary of the message."""
|
||||||
|
response_summary = (
|
||||||
|
f"Response: {self.response} (Actions: {', '.join(self.actions)})"
|
||||||
|
if self.response else "No response yet"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"Message at {self.timestamp}:\n"
|
||||||
|
f"Query: {self.preamble}{self.content}\n"
|
||||||
|
f"{response_summary}"
|
||||||
|
)
|
78
src/utils/session.py
Normal file
78
src/utils/session.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
from pydantic import BaseModel, Field, model_validator, PrivateAttr
|
||||||
|
from typing import Literal, TypeAlias, get_args
|
||||||
|
from .conversation import Conversation
|
||||||
|
|
||||||
|
class Session(BaseModel):
|
||||||
|
session_type: Literal["resume", "job_description", "fact_check", "chat"]
|
||||||
|
system_prompt: str = "You are a helpful assistant."
|
||||||
|
conversation: Conversation = Conversation()
|
||||||
|
context_tokens: int = 0
|
||||||
|
|
||||||
|
_content_seed: str = PrivateAttr(default="")
|
||||||
|
|
||||||
|
def get_and_reset_content_seed(self):
|
||||||
|
tmp = self._content_seed
|
||||||
|
self._content_seed = ""
|
||||||
|
return tmp
|
||||||
|
|
||||||
|
def set_content_seed(self, content: str) -> None:
|
||||||
|
"""Set the content seed for the session."""
|
||||||
|
self._content_seed = content
|
||||||
|
|
||||||
|
def get_content_seed(self) -> str:
|
||||||
|
"""Get the content seed for the session."""
|
||||||
|
return self._content_seed
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def valid_session_types(cls) -> set[str]:
|
||||||
|
"""Return the set of valid session_type values."""
|
||||||
|
return set(get_args(cls.__annotations__["session_type"]))
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for Session or any subclass
|
||||||
|
AnySession: TypeAlias = Session # BaseModel covers Session and subclasses
|
||||||
|
|
||||||
|
class Resume(Session):
|
||||||
|
session_type: Literal["resume"] = "resume"
|
||||||
|
resume: str = ""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_resume(self):
|
||||||
|
if not self.resume.strip():
|
||||||
|
raise ValueError("Resume content cannot be empty")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_resume(self) -> str:
|
||||||
|
"""Get the resume content."""
|
||||||
|
return self.resume
|
||||||
|
|
||||||
|
def set_resume(self, resume: str) -> None:
|
||||||
|
"""Set the resume content."""
|
||||||
|
self.resume = resume
|
||||||
|
|
||||||
|
class JobDescription(Session):
|
||||||
|
session_type: Literal["job_description"] = "job_description"
|
||||||
|
job_description: str = ""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_job_description(self):
|
||||||
|
if not self.job_description.strip():
|
||||||
|
raise ValueError("Job description cannot be empty")
|
||||||
|
return self
|
||||||
|
|
||||||
|
class FactCheck(Session):
|
||||||
|
session_type: Literal["fact_check"] = "fact_check"
|
||||||
|
facts: str = ""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_facts(self):
|
||||||
|
if not self.facts.strip():
|
||||||
|
raise ValueError("Facts cannot be empty")
|
||||||
|
return self
|
||||||
|
|
||||||
|
class Chat(Session):
|
||||||
|
session_type: Literal["chat"] = "chat"
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_name(self):
|
||||||
|
return self
|
Loading…
x
Reference in New Issue
Block a user