Compare commits

11 commits: 5806563777 ... 8a4f94817a

8a4f94817a
10f28b0e9b
2a3dc56897
7f24d8870c
3094288e46
d1940e18e5
e607e3a2f2
4614dbb237
622c33545e
c3cf9a9c76
90a83a7313
@@ -293,8 +293,13 @@ RUN { \
     echo ' openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout src/key.pem -out src/cert.pem -subj "/C=US/ST=OR/L=Portland/O=Development/CN=localhost"'; \
     echo ' fi' ; \
     echo ' while true; do'; \
-    echo ' echo "Launching Backstory server..."'; \
-    echo ' python src/server.py "${@}" || echo "Backstory server died. Restarting in 3 seconds."'; \
+    echo ' if [[ ! -e /opt/backstory/block-server ]]; then'; \
+    echo ' echo "Launching Backstory server..."'; \
+    echo ' python src/server.py "${@}" || echo "Backstory server died."'; \
+    echo ' else'; \
+    echo ' echo "block-server file exists. Not launching."'; \
+    echo ' fi' ; \
+    echo ' echo "Sleeping for 3 seconds."'; \
     echo ' sleep 3'; \
     echo ' done' ; \
     echo 'fi'; \
frontend/package-lock.json: 738 lines changed (generated). File diff suppressed because it is too large.
@@ -18,6 +18,7 @@
     "@types/node": "^16.18.126",
     "@types/react": "^19.0.12",
     "@types/react-dom": "^19.0.4",
+    "@uiw/react-json-view": "^2.0.0-alpha.31",
     "mui-markdown": "^1.2.6",
     "react": "^19.0.0",
     "react-dom": "^19.0.0",
@@ -55,6 +56,7 @@
     ]
   },
   "devDependencies": {
-    "@types/plotly.js": "^2.35.5"
+    "@types/plotly.js": "^2.35.5",
+    "@craco/craco": "^0.0.0"
   }
 }
@@ -26,8 +26,8 @@ interface ConversationHandle {
 
 interface BackstoryMessage {
     prompt: string;
-    preamble: string;
-    content: string;
+    preamble: {};
+    full_content: string;
     response: string;
     metadata: {
         rag: { documents: [] };
@@ -138,6 +138,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
         let filtered = [];
         if (messageFilter === undefined) {
             filtered = conversation;
+            // console.log('No message filter provided. Using all messages.', filtered);
         } else {
             //console.log('Filtering conversation...')
             filtered = messageFilter(conversation); /* Do not copy conversation or useEffect will loop forever */
@@ -206,8 +207,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
         }, {
             role: 'assistant',
             prompt: message.prompt || "",
-            preamble: message.preamble || "",
-            full_content: message.content || "",
+            preamble: message.preamble || {},
+            full_content: message.full_content || "",
             content: message.response || "",
             metadata: message.metadata,
             actions: message.actions,
@@ -403,52 +404,59 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                 try {
                     const update = JSON.parse(line);
 
-                    // Force an immediate state update based on the message type
-                    if (update.status === 'processing') {
-                        // Update processing message with immediate re-render
-                        setProcessingMessage({ role: 'status', content: update.message });
-                        // Add a small delay to ensure React has time to update the UI
-                        await new Promise(resolve => setTimeout(resolve, 0));
-                    } else if (update.status === 'done') {
-                        // Replace processing message with final result
-                        if (onResponse) {
-                            update.message = onResponse(update.message);
-                        }
-                        setProcessingMessage(undefined);
-                        const backstoryMessage: BackstoryMessage = update.message;
-                        setConversation([
-                            ...conversationRef.current, {
-                                role: 'user',
-                                content: backstoryMessage.prompt || "",
-                            }, {
-                                role: 'assistant',
-                                prompt: backstoryMessage.prompt || "",
-                                preamble: backstoryMessage.preamble || "",
-                                full_content: backstoryMessage.content || "",
-                                content: backstoryMessage.response || "",
-                                metadata: backstoryMessage.metadata,
-                                actions: backstoryMessage.actions,
-                            }] as MessageList);
-                        // Add a small delay to ensure React has time to update the UI
-                        await new Promise(resolve => setTimeout(resolve, 0));
-
-                        const metadata = update.message.metadata;
-                        if (metadata) {
-                            const evalTPS = metadata.eval_count * 10 ** 9 / metadata.eval_duration;
-                            const promptTPS = metadata.prompt_eval_count * 10 ** 9 / metadata.prompt_eval_duration;
-                            setLastEvalTPS(evalTPS ? evalTPS : 35);
-                            setLastPromptTPS(promptTPS ? promptTPS : 35);
-                            updateContextStatus();
-                        }
-                    } else if (update.status === 'error') {
-                        // Show error
-                        setProcessingMessage({ role: 'error', content: update.message });
-                        setTimeout(() => {
-                            setProcessingMessage(undefined);
-                        }, 5000);
-                        // Add a small delay to ensure React has time to update the UI
-                        await new Promise(resolve => setTimeout(resolve, 0));
+                    switch (update.status) {
+                        case 'processing':
+                        case 'thinking':
+                            // Force an immediate state update based on the message type
+                            // Update processing message with immediate re-render
+                            setProcessingMessage({ role: 'status', content: update.response });
+                            // Add a small delay to ensure React has time to update the UI
+                            await new Promise(resolve => setTimeout(resolve, 0));
+                            break;
+                        case 'done':
+                            console.log('Done processing:', update);
+                            // Replace processing message with final result
+                            if (onResponse) {
+                                update.message = onResponse(update);
+                            }
+                            setProcessingMessage(undefined);
+                            const backstoryMessage: BackstoryMessage = update;
+                            setConversation([
+                                ...conversationRef.current, {
+                                    // role: 'user',
+                                    // content: backstoryMessage.prompt || "",
+                                    // }, {
+                                    role: 'assistant',
+                                    origin: type,
+                                    content: backstoryMessage.response || "",
+                                    prompt: backstoryMessage.prompt || "",
+                                    preamble: backstoryMessage.preamble || {},
+                                    full_content: backstoryMessage.full_content || "",
+                                    metadata: backstoryMessage.metadata,
+                                    actions: backstoryMessage.actions,
+                                }] as MessageList);
+                            // Add a small delay to ensure React has time to update the UI
+                            await new Promise(resolve => setTimeout(resolve, 0));
+
+                            const metadata = update.metadata;
+                            if (metadata) {
+                                const evalTPS = metadata.eval_count * 10 ** 9 / metadata.eval_duration;
+                                const promptTPS = metadata.prompt_eval_count * 10 ** 9 / metadata.prompt_eval_duration;
+                                setLastEvalTPS(evalTPS ? evalTPS : 35);
+                                setLastPromptTPS(promptTPS ? promptTPS : 35);
+                                updateContextStatus();
+                            }
+                            break;
+                        case 'error':
+                            // Show error
+                            setProcessingMessage({ role: 'error', content: update.response });
+                            setTimeout(() => {
+                                setProcessingMessage(undefined);
+                            }, 5000);
+
+                            // Add a small delay to ensure React has time to update the UI
+                            await new Promise(resolve => setTimeout(resolve, 0));
+                            break;
                     }
                 } catch (e) {
                     setSnack("Error processing query", "error")
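The switch above consumes a newline-delimited JSON (NDJSON) stream: each line is one JSON object whose status field selects a branch, with 'processing'/'thinking' updates carried in update.response and the final message fields arriving on 'done'. For reference, a minimal Python sketch of a producer for this shape, with placeholder payload values (the real stream comes from generate_response in src/server.py, shown later in this diff):

import json
from typing import AsyncGenerator

async def status_stream() -> AsyncGenerator[str, None]:
    # One JSON object per line (NDJSON); the client calls JSON.parse(line)
    # and switches on the "status" field.
    yield json.dumps({"status": "processing", "response": "Checking RAG context..."}) + "\n"
    yield json.dumps({"status": "thinking", "response": "Generating response..."}) + "\n"
    # 'done' carries the message fields the client copies into the conversation;
    # the values here are illustrative placeholders.
    yield json.dumps({
        "status": "done",
        "prompt": "user question",
        "response": "assistant answer",
        "preamble": {},
        "full_content": "full prompt sent to the model",
        "metadata": {"eval_count": 100, "eval_duration": 2_000_000_000},
    }) + "\n"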
@@ -462,25 +470,44 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
             try {
                 const update = JSON.parse(buffer);
 
-                if (update.status === 'done') {
-                    if (onResponse) {
-                        update.message = onResponse(update.message);
-                    }
-                    setProcessingMessage(undefined);
-                    const backstoryMessage: BackstoryMessage = update.message;
-                    setConversation([
-                        ...conversationRef.current, {
-                            role: 'user',
-                            content: backstoryMessage.prompt || "",
-                        }, {
-                            role: 'assistant',
-                            prompt: backstoryMessage.prompt || "",
-                            preamble: backstoryMessage.preamble || "",
-                            full_content: backstoryMessage.content || "",
-                            content: backstoryMessage.response || "",
-                            metadata: backstoryMessage.metadata,
-                            actions: backstoryMessage.actions,
-                        }] as MessageList);
+                switch (update.status) {
+                    case 'processing':
+                    case 'thinking':
+                        // Force an immediate state update based on the message type
+                        // Update processing message with immediate re-render
+                        setProcessingMessage({ role: 'status', content: update.response });
+                        // Add a small delay to ensure React has time to update the UI
+                        await new Promise(resolve => setTimeout(resolve, 0));
+                        break;
+                    case 'error':
+                        // Show error
+                        setProcessingMessage({ role: 'error', content: update.response });
+                        setTimeout(() => {
+                            setProcessingMessage(undefined);
+                        }, 5000);
+                        break;
+                    case 'done':
+                        console.log('Done processing:', update);
+                        if (onResponse) {
+                            update.message = onResponse(update);
+                        }
+                        setProcessingMessage(undefined);
+                        const backstoryMessage: BackstoryMessage = update;
+                        setConversation([
+                            ...conversationRef.current, {
+                                // role: 'user',
+                                // content: backstoryMessage.prompt || "",
+                                // }, {
+                                role: 'assistant',
+                                origin: type,
+                                prompt: backstoryMessage.prompt || "",
+                                content: backstoryMessage.response || "",
+                                preamble: backstoryMessage.preamble || {},
+                                full_content: backstoryMessage.full_content || "",
+                                metadata: backstoryMessage.metadata,
+                                actions: backstoryMessage.actions,
+                            }] as MessageList);
+                        break;
                 }
             } catch (e) {
                 setSnack("Error processing query", "error")
@@ -19,6 +19,7 @@ import Typography from '@mui/material/Typography';
 import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
 import { ExpandMore } from './ExpandMore';
 import { SxProps, Theme } from '@mui/material';
+import JsonView from '@uiw/react-json-view';
 
 import { ChatBubble } from './ChatBubble';
 import { StyledMarkdown } from './StyledMarkdown';
@@ -32,6 +33,8 @@ type MessageRoles = 'info' | 'user' | 'assistant' | 'system' | 'status' | 'error
 type MessageData = {
     role: MessageRoles,
     content: string,
+    full_content?: string,
+
     disableCopy?: boolean,
     user?: string,
     title?: string,
@@ -48,7 +51,6 @@ interface MessageMetaData {
         vector_embedding: number[];
     },
     origin: string,
-    full_query?: string,
     rag: any,
     tools: any[],
     eval_count: number,
@@ -87,7 +89,6 @@ interface MessageMetaProps {
 const MessageMeta = (props: MessageMetaProps) => {
     const {
         /* MessageData */
-        full_query,
         rag,
         tools,
         eval_count,
@@ -95,7 +96,7 @@ const MessageMeta = (props: MessageMetaProps) => {
         prompt_eval_count,
         prompt_eval_duration,
     } = props.metadata || {};
-    const messageProps = props.messageProps;
+    const message = props.messageProps.message;
 
     return (<>
         <Box sx={{ fontSize: "0.8rem", mb: 1 }}>
@@ -137,7 +138,7 @@ const MessageMeta = (props: MessageMetaProps) => {
         </TableContainer>
 
         {
-            full_query !== undefined &&
+            message.full_content !== undefined &&
             <Accordion>
                 <AccordionSummary expandIcon={<ExpandMoreIcon />}>
                     <Box sx={{ fontSize: "0.8rem" }}>
@@ -145,7 +146,7 @@ const MessageMeta = (props: MessageMetaProps) => {
                     </Box>
                 </AccordionSummary>
                 <AccordionDetails>
-                    <pre style={{ "display": "block", "position": "relative" }}><CopyBubble content={full_query?.trim()} />{full_query?.trim()}</pre>
+                    <pre style={{ "display": "block", "position": "relative" }}><CopyBubble content={message.full_content?.trim()} />{message.full_content?.trim()}</pre>
                 </AccordionDetails>
             </Accordion>
         }
@@ -182,14 +183,18 @@ const MessageMeta = (props: MessageMetaProps) => {
             </Accordion>
         }
         {
-            rag?.name !== undefined && <>
-            <Accordion>
+            rag.map((rag: any) => (
+            <Accordion key={rag.name}>
                 <AccordionSummary expandIcon={<ExpandMoreIcon />}>
                     <Box sx={{ fontSize: "0.8rem" }}>
                         Top RAG {rag.ids.length} matches from '{rag.name}' collection against embedding vector of {rag.query_embedding.length} dimensions
                     </Box>
                 </AccordionSummary>
                 <AccordionDetails>
+                    <Box sx={{ fontSize: "0.8rem" }}>
+                        UMAP Vector Visualization of '{rag.name}' RAG
+                    </Box>
+                    <VectorVisualizer inline {...props.messageProps} {...props.metadata} rag={rag} />
                     {rag.ids.map((id: number, index: number) => <Box key={index}>
                         {index !== 0 && <Divider />}
                         <Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "row", mb: 0.5, mt: 0.5 }}>
@@ -205,55 +210,33 @@ const MessageMeta = (props: MessageMetaProps) => {
                     )}
                 </AccordionDetails>
             </Accordion>
-            <Accordion>
-                <AccordionSummary expandIcon={<ExpandMoreIcon />}>
-                    <Box sx={{ fontSize: "0.8rem" }}>
-                        UMAP Vector Visualization of RAG
-                    </Box>
-                </AccordionSummary>
-                <AccordionDetails>
-                    <VectorVisualizer inline {...messageProps} {...props.metadata} rag={rag} />
-                </AccordionDetails>
-            </Accordion>
-            <Accordion>
-                <AccordionSummary expandIcon={<ExpandMoreIcon />}>
-                    <Box sx={{ fontSize: "0.8rem" }}>
-                        All response fields
-                    </Box>
-                </AccordionSummary>
-                <AccordionDetails>
-                    {Object.entries(props.messageProps.message)
-                        .filter(([key, value]) => key !== undefined && value !== undefined)
-                        .map(([key, value]) => (typeof (value) !== "string" || value?.trim() !== "") &&
-                            <Accordion key={key}>
-                                <AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
-                                    {key}
-                                </AccordionSummary>
-                                <AccordionDetails>
-                                    {key === "metadata" &&
-                                        Object.entries(value)
-                                            .filter(([key, value]) => key !== undefined && value !== undefined)
-                                            .map(([key, value]) => (
-                                                <Accordion key={`metadata.${key}`}>
-                                                    <AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
-                                                        {key}
-                                                    </AccordionSummary>
-                                                    <AccordionDetails>
-                                                        <pre>{`${typeof (value) !== "object" ? value : JSON.stringify(value)}`}</pre>
-                                                    </AccordionDetails>
-                                                </Accordion>
-                                            ))}
-                                    {key !== "metadata" &&
-                                        <pre>{typeof (value) !== "object" ? value : JSON.stringify(value)}</pre>
-                                    }
-                                </AccordionDetails>
-                            </Accordion>
-                        )}
-                </AccordionDetails>
-            </Accordion>
-
-            </>
+            ))
         }
 
+        <Accordion>
+            <AccordionSummary expandIcon={<ExpandMoreIcon />}>
+                <Box sx={{ fontSize: "0.8rem" }}>
+                    All response fields
+                </Box>
+            </AccordionSummary>
+            <AccordionDetails>
+                {Object.entries(message)
+                    .filter(([key, value]) => key !== undefined && value !== undefined)
+                    .map(([key, value]) => (typeof (value) !== "string" || value?.trim() !== "") &&
+                        <Accordion key={key}>
+                            <AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
+                                {key}
+                            </AccordionSummary>
+                            <AccordionDetails>
+                                {typeof (value) === "string" ?
+                                    <pre>{value}</pre> :
+                                    <JsonView collapsed={1} value={value as any} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }} />
+                                }
+                            </AccordionDetails>
+                        </Accordion>
+                    )}
+            </AccordionDetails>
+        </Accordion>
     </>);
 };
@@ -82,6 +82,7 @@ const emojiMap: Record<string, string> = {
     query: '🔍',
     resume: '📄',
     projects: '📁',
+    jobs: '📁',
     'performance-reviews': '📄',
     news: '📰',
 };
@@ -91,7 +92,8 @@ const colorMap: Record<string, string> = {
     resume: '#4A7A7D', // Dusty Teal — secondary theme color
     projects: '#1A2536', // Midnight Blue — rich and deep
     news: '#D3CDBF', // Warm Gray — soft and neutral
-    'performance-reviews': '#FF0000', // Bright red
+    'performance-reviews': '#FFD0D0', // Light red
+    'jobs': '#F3aD8F', // Warm Gray — soft and neutral
 };
 
 const sizeMap: Record<string, number> = {
@@ -156,7 +158,7 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = (props: VectorVisualiz
     useEffect(() => {
         if (!result || !result.embeddings) return;
         if (result.embeddings.length === 0) return;
+        console.log('Result:', result);
         const vectors: (number[])[] = [...result.embeddings];
         const documents = [...result.documents || []];
         const metadatas = [...result.metadatas || []];
@@ -2,12 +2,12 @@
 
 # Ensure input was provided
 if [[ -z "$1" ]]; then
-    echo "Usage: $0 <path/to/python_script.py>"
-    exit 1
+    TARGET=$(readlink -f "src/server.py")
+else
+    TARGET=$(readlink -f "$1")
 fi
 
 # Resolve user-supplied path to absolute path
-TARGET=$(readlink -f "$1")
 
 if [[ ! -f "$TARGET" ]]; then
     echo "Target file '$TARGET' not found."
src/server.py: 519 lines changed.
@@ -1,3 +1,7 @@
+from utils import logger
+
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+
 # %%
 # Imports [standard]
 # Standard library modules (no try-except needed)
@@ -34,6 +38,7 @@ try_import("sklearn")
 import ollama
 import requests
 from bs4 import BeautifulSoup
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, BackgroundTasks
 from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -44,8 +49,10 @@ from sklearn.preprocessing import MinMaxScaler
 
 from utils import (
     rag as Rag,
-    Context, Conversation, Session, Message, Chat, Resume, JobDescription, FactCheck,
-    defines
+    Context, Conversation, Message,
+    Agent,
+    defines,
+    logger
 )
 
 from tools import (
@@ -250,25 +257,6 @@ def parse_args():
                         default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
     return parser.parse_args()
 
-def setup_logging(level):
-    global logging
-
-    numeric_level = getattr(logging, level.upper(), None)
-    if not isinstance(numeric_level, int):
-        raise ValueError(f"Invalid log level: {level}")
-
-    logging.basicConfig(
-        level=numeric_level,
-        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-        force=True
-    )
-
-    # Now reduce verbosity for FastAPI, Uvicorn, Starlette
-    for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
-        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
-
-    logging.info(f"Logging is set to {level} level.")
-
-
 # %%
@@ -288,10 +276,10 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
         headers = {
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
         }
-        logging.info(f"Fetching {url}")
+        logger.info(f"Fetching {url}")
         response = requests.get(url, headers=headers, timeout=10)
         response.raise_for_status()
-        logging.info(f"{url} returned. Processing...")
+        logger.info(f"{url} returned. Processing...")
         # Parse the HTML
         soup = BeautifulSoup(response.text, "html.parser")
 
@@ -313,7 +301,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
             text = text[:max_chars] + "..."
 
         # Create Ollama client
-        # logging.info(f"Requesting summary of: {text}")
+        # logger.info(f"Requesting summary of: {text}")
 
         # Generate summary using Ollama
         prompt = f"CONTENTS:\n\n{text}\n\n{question}"
@@ -321,7 +309,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
             system="You are given the contents of {url}. Answer the question about the contents",
             prompt=prompt)
 
-    #logging.info(response["response"])
+    #logger.info(response["response"])
 
     return {
         "source": "summarizer-llm",
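The hunks above (and those that follow) migrate module-level logging calls to a shared logger imported from utils. That module's definition is not part of this diff; a plausible minimal sketch, assuming it centralizes the configuration the removed setup_logging() used to perform (hypothetical file and names):

# utils/logger.py (hypothetical sketch; the real module is not shown in this diff)
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Quiet the framework loggers, as the removed setup_logging() did.
for noisy in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
    logging.getLogger(noisy).setLevel(logging.WARNING)

logger = logging.getLogger("backstory")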
@@ -359,8 +347,23 @@ def llm_tools(tools):
 
 # %%
 class WebServer:
+    @asynccontextmanager
+    async def lifespan(self, app: FastAPI):
+        # Start the file watcher
+        self.observer, self.file_watcher = Rag.start_file_watcher(
+            llm=self.llm,
+            watch_directory=defines.doc_dir,
+            recreate=False # Don't recreate if exists
+        )
+        logger.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
+        yield
+        if self.observer:
+            self.observer.stop()
+            self.observer.join()
+            logger.info("File watcher stopped")
+
     def __init__(self, llm, model=MODEL_NAME):
-        self.app = FastAPI()
+        self.app = FastAPI(lifespan=self.lifespan)
        self.contexts = {}
        self.llm = llm
        self.model = model
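The new lifespan context manager moves the file-watcher startup and shutdown out of the @app.on_event handlers (removed in a later hunk) and into FastAPI's lifespan protocol: everything before the yield runs once at startup, everything after runs once at shutdown. A self-contained sketch of the pattern, with a placeholder resource standing in for the RAG file watcher:

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: acquire resources (stands in for Rag.start_file_watcher(...)).
    app.state.watcher = {"running": True}
    yield
    # Shutdown: release them (stands in for observer.stop()/observer.join()).
    app.state.watcher["running"] = False

app = FastAPI(lifespan=lifespan)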
@@ -375,7 +378,7 @@ class WebServer:
         else:
             allow_origins=["http://battle-linux.ketrenos.com:3000"]
 
-        logging.info(f"Allowed origins: {allow_origins}")
+        logger.info(f"Allowed origins: {allow_origins}")
 
         self.app.add_middleware(
             CORSMiddleware,
@@ -385,38 +388,19 @@ class WebServer:
             allow_headers=["*"],
         )
 
-        @self.app.on_event("startup")
-        async def startup_event():
-            # Start the file watcher
-            self.observer, self.file_watcher = Rag.start_file_watcher(
-                llm=llm,
-                watch_directory=defines.doc_dir,
-                recreate=False # Don't recreate if exists
-            )
-
-            print(f"API started with {self.file_watcher.collection.count()} documents in the collection")
-
-        @self.app.on_event("shutdown")
-        async def shutdown_event():
-            if self.observer:
-                self.observer.stop()
-                self.observer.join()
-                print("File watcher stopped")
-
         self.setup_routes()
 
     def setup_routes(self):
         @self.app.get("/")
         async def root():
             context = self.create_context()
-            logging.info(f"Redirecting non-session to {context.id}")
+            logger.info(f"Redirecting non-context to {context.id}")
             return RedirectResponse(url=f"/{context.id}", status_code=307)
             #return JSONResponse({"redirect": f"/{context.id}"})
 
 
         @self.app.put("/api/umap/{context_id}")
         async def put_umap(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             try:
                 if not self.file_watcher:
                     raise Exception("File watcher not initialized")
@@ -429,29 +413,36 @@ class WebServer:
 
                 dimensions = data.get("dimensions", 2)
                 result = self.file_watcher.umap_collection
+                if not result:
+                    return JSONResponse({"error": "No UMAP collection found"}, status_code=404)
                 if dimensions == 2:
-                    logging.info("Returning 2D UMAP")
+                    logger.info("Returning 2D UMAP")
                     umap_embedding = self.file_watcher.umap_embedding_2d
                 else:
-                    logging.info("Returning 3D UMAP")
+                    logger.info("Returning 3D UMAP")
                     umap_embedding = self.file_watcher.umap_embedding_3d
 
+                if len(umap_embedding) == 0:
+                    return JSONResponse({"error": "No UMAP embedding found"}, status_code=404)
+
                 result["embeddings"] = umap_embedding.tolist()
 
                 return JSONResponse(result)
 
             except Exception as e:
-                logging.error(e)
+                logger.error(f"put_umap error: {str(e)}")
+                import traceback
+                logger.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, 500)
 
         @self.app.put("/api/similarity/{context_id}")
         async def put_similarity(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             if not self.file_watcher:
-                return
+                raise Exception("File watcher not initialized")
 
             if not is_valid_uuid(context_id):
-                logging.warning(f"Invalid context_id: {context_id}")
+                logger.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
 
             try:
@@ -468,13 +459,13 @@ class WebServer:
                     return JSONResponse({"error": "No results found"}, status_code=404)
 
                 chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-                print(f"Chroma embedding shape: {chroma_embedding.shape}")
+                logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
 
                 umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+                logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
 
                 umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+                logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
 
                 return JSONResponse({
                     **chroma_results,
@@ -484,19 +475,19 @@ class WebServer:
                 })
 
             except Exception as e:
-                logging.error(e)
+                logger.error(e)
                 #return JSONResponse({"error": str(e)}, 500)
 
-        @self.app.put("/api/reset/{context_id}/{session_type}")
-        async def put_reset(context_id: str, session_type: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+        @self.app.put("/api/reset/{context_id}/{agent_type}")
+        async def put_reset(context_id: str, agent_type: str, request: Request):
+            logger.info(f"{request.method} {request.url.path}")
             if not is_valid_uuid(context_id):
-                logging.warning(f"Invalid context_id: {context_id}")
+                logger.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
-            session = context.get_session(session_type)
-            if not session:
-                return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)
+            agent = context.get_agent(agent_type)
+            if not agent:
+                return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)
 
             data = await request.json()
             try:
@@ -504,8 +495,8 @@ class WebServer:
                 for reset_operation in data["reset"]:
                     match reset_operation:
                         case "system_prompt":
-                            logging.info(f"Resetting {reset_operation}")
-                            match session_type:
+                            logger.info(f"Resetting {reset_operation}")
+                            match agent_type:
                                 case "chat":
                                     prompt = system_message
                                 case "job_description":
|
|||||||
case "fact_check":
|
case "fact_check":
|
||||||
prompt = system_message
|
prompt = system_message
|
||||||
|
|
||||||
session.system_prompt = prompt
|
agent.system_prompt = prompt
|
||||||
response["system_prompt"] = { "system_prompt": prompt }
|
response["system_prompt"] = { "system_prompt": prompt }
|
||||||
case "rags":
|
case "rags":
|
||||||
logging.info(f"Resetting {reset_operation}")
|
logger.info(f"Resetting {reset_operation}")
|
||||||
context.rags = rags.copy()
|
context.rags = rags.copy()
|
||||||
response["rags"] = context.rags
|
response["rags"] = context.rags
|
||||||
case "tools":
|
case "tools":
|
||||||
logging.info(f"Resetting {reset_operation}")
|
logger.info(f"Resetting {reset_operation}")
|
||||||
context.tools = default_tools(tools)
|
context.tools = default_tools(tools)
|
||||||
response["tools"] = context.tools
|
response["tools"] = context.tools
|
||||||
case "history":
|
case "history":
|
||||||
@ -532,19 +523,19 @@ class WebServer:
|
|||||||
"fact_check": ("job_description", "resume", "fact_check"),
|
"fact_check": ("job_description", "resume", "fact_check"),
|
||||||
"chat": ("chat",),
|
"chat": ("chat",),
|
||||||
}
|
}
|
||||||
resets = reset_map.get(session_type, ())
|
resets = reset_map.get(agent_type, ())
|
||||||
|
|
||||||
for mode in resets:
|
for mode in resets:
|
||||||
tmp = context.get_session(mode)
|
tmp = context.get_agent(mode)
|
||||||
if not tmp:
|
if not tmp:
|
||||||
continue
|
continue
|
||||||
logging.info(f"Resetting {reset_operation} for {mode}")
|
logger.info(f"Resetting {reset_operation} for {mode}")
|
||||||
context.conversation = Conversation()
|
context.conversation = Conversation()
|
||||||
context.context_tokens = round(len(str(session.system_prompt)) * 3 / 4) # Estimate context usage
|
context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage
|
||||||
response["history"] = []
|
response["history"] = []
|
||||||
response["context_used"] = session.context_tokens
|
response["context_used"] = agent.context_tokens
|
||||||
case "message_history_length":
|
case "message_history_length":
|
||||||
logging.info(f"Resetting {reset_operation}")
|
logger.info(f"Resetting {reset_operation}")
|
||||||
context.message_history_length = DEFAULT_HISTORY_LENGTH
|
context.message_history_length = DEFAULT_HISTORY_LENGTH
|
||||||
response["message_history_length"] = DEFAULT_HISTORY_LENGTH
|
response["message_history_length"] = DEFAULT_HISTORY_LENGTH
|
||||||
|
|
||||||
@ -559,13 +550,13 @@ class WebServer:
|
|||||||
|
|
||||||
@self.app.put("/api/tunables/{context_id}")
|
@self.app.put("/api/tunables/{context_id}")
|
||||||
async def put_tunables(context_id: str, request: Request):
|
async def put_tunables(context_id: str, request: Request):
|
||||||
logging.info(f"{request.method} {request.url.path}")
|
logger.info(f"{request.method} {request.url.path}")
|
||||||
try:
|
try:
|
||||||
context = self.upsert_context(context_id)
|
context = self.upsert_context(context_id)
|
||||||
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
session = context.get_session("chat")
|
agent = context.get_agent("chat")
|
||||||
if not session:
|
if not agent:
|
||||||
return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
|
return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
|
||||||
for k in data.keys():
|
for k in data.keys():
|
||||||
match k:
|
match k:
|
||||||
@@ -600,7 +591,7 @@ class WebServer:
                             system_prompt = data[k].strip()
                             if not system_prompt:
                                 return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
-                            session.system_prompt = system_prompt
+                            agent.system_prompt = system_prompt
                             self.save_context(context_id)
                             return JSONResponse({ "system_prompt": system_prompt })
                         case "message_history_length":
|
|||||||
case _:
|
case _:
|
||||||
return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404)
|
return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in put_tunables: {e}")
|
logger.error(f"Error in put_tunables: {e}")
|
||||||
return JSONResponse({"error": str(e)}, status_code=500)
|
return JSONResponse({"error": str(e)}, status_code=500)
|
||||||
|
|
||||||
@self.app.get("/api/tunables/{context_id}")
|
@self.app.get("/api/tunables/{context_id}")
|
||||||
async def get_tunables(context_id: str, request: Request):
|
async def get_tunables(context_id: str, request: Request):
|
||||||
logging.info(f"{request.method} {request.url.path}")
|
logger.info(f"{request.method} {request.url.path}")
|
||||||
if not is_valid_uuid(context_id):
|
if not is_valid_uuid(context_id):
|
||||||
logging.warning(f"Invalid context_id: {context_id}")
|
logger.warning(f"Invalid context_id: {context_id}")
|
||||||
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
||||||
context = self.upsert_context(context_id)
|
context = self.upsert_context(context_id)
|
||||||
session = context.get_session("chat")
|
agent = context.get_agent("chat")
|
||||||
if not session:
|
if not agent:
|
||||||
return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
|
return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
"system_prompt": session.system_prompt,
|
"system_prompt": agent.system_prompt,
|
||||||
"message_history_length": context.message_history_length,
|
"message_history_length": context.message_history_length,
|
||||||
"rags": context.rags,
|
"rags": context.rags,
|
||||||
"tools": [ {
|
"tools": [ {
|
||||||
@ -636,35 +627,34 @@ class WebServer:
|
|||||||
|
|
||||||
@self.app.get("/api/system-info/{context_id}")
|
@self.app.get("/api/system-info/{context_id}")
|
||||||
async def get_system_info(context_id: str, request: Request):
|
async def get_system_info(context_id: str, request: Request):
|
||||||
logging.info(f"{request.method} {request.url.path}")
|
logger.info(f"{request.method} {request.url.path}")
|
||||||
return JSONResponse(system_info(self.model))
|
return JSONResponse(system_info(self.model))
|
||||||
|
|
||||||
@self.app.post("/api/chat/{context_id}/{session_type}")
|
@self.app.post("/api/chat/{context_id}/{agent_type}")
|
||||||
async def post_chat_endpoint(context_id: str, session_type: str, request: Request):
|
async def post_chat_endpoint(context_id: str, agent_type: str, request: Request):
|
||||||
logging.info(f"{request.method} {request.url.path}")
|
logger.info(f"{request.method} {request.url.path}")
|
||||||
try:
|
try:
|
||||||
if not is_valid_uuid(context_id):
|
if not is_valid_uuid(context_id):
|
||||||
logging.warning(f"Invalid context_id: {context_id}")
|
logger.warning(f"Invalid context_id: {context_id}")
|
||||||
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
||||||
context = self.upsert_context(context_id)
|
context = self.upsert_context(context_id)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
session = context.get_session(session_type)
|
agent = context.get_agent(agent_type)
|
||||||
if not session and session_type == "job_description":
|
if not agent and agent_type == "job_description":
|
||||||
logging.info(f"Session {session_type} not found. Returning empty history.")
|
logger.info(f"Agent {agent_type} not found. Returning empty history.")
|
||||||
# Create a new session if it doesn't exist
|
# Create a new agent if it doesn't exist
|
||||||
session = context.get_or_create_session("job_description", system_prompt=system_generate_resume, job_description=data["content"])
|
agent = context.get_or_create_agent("job_description", system_prompt=system_generate_resume, job_description=data["content"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(f"Attempt to create session type: {session_type} failed", e)
|
logger.info(f"Attempt to create agent type: {agent_type} failed", e)
|
||||||
return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)
|
return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)
|
||||||
|
|
||||||
# Create a custom generator that ensures flushing
|
# Create a custom generator that ensures flushing
|
||||||
async def flush_generator():
|
async def flush_generator():
|
||||||
async for message in self.generate_response(context=context, session=session, content=data["content"]):
|
async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
|
||||||
# Convert to JSON and add newline
|
# Convert to JSON and add newline
|
||||||
yield json.dumps(message) + "\n"
|
yield json.dumps(message.model_dump(mode='json')) + "\n"
|
||||||
# Save the history as its generated
|
# Save the history as its generated
|
||||||
self.save_context(context_id)
|
self.save_context(context_id)
|
||||||
# Explicitly flush after each yield
|
# Explicitly flush after each yield
|
||||||
@@ -681,41 +671,43 @@ class WebServer:
                     }
                 )
             except Exception as e:
-                logging.error(f"Error in post_chat_endpoint: {e}")
+                logger.error(f"Error in post_chat_endpoint: {e}")
                 return JSONResponse({"error": str(e)}, status_code=500)
 
         @self.app.post("/api/context")
         async def create_context():
             context = self.create_context()
-            logging.info(f"Generated new session as {context.id}")
+            logger.info(f"Generated new agent as {context.id}")
             return JSONResponse({ "id": context.id })
 
-        @self.app.get("/api/history/{context_id}/{session_type}")
-        async def get_history(context_id: str, session_type: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+        @self.app.get("/api/history/{context_id}/{agent_type}")
+        async def get_history(context_id: str, agent_type: str, request: Request):
+            logger.info(f"{request.method} {request.url.path}")
             try:
                 context = self.upsert_context(context_id)
-                session = context.get_session(session_type)
-                if not session:
-                    logging.info(f"Session {session_type} not found. Returning empty history.")
+                agent = context.get_agent(agent_type)
+                if not agent:
+                    logger.info(f"Agent {agent_type} not found. Returning empty history.")
                     return JSONResponse({ "messages": [] })
-                logging.info(f"History for {session_type} contains {len(session.conversation.messages)} entries.")
-                return session.conversation
+                logger.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
+                return agent.conversation
             except Exception as e:
-                logging.error(f"Error in get_history: {e}")
+                logger.error(f"get_history error: {str(e)}")
+                import traceback
+                logger.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, status_code=404)
 
         @self.app.get("/api/tools/{context_id}")
         async def get_tools(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             context = self.upsert_context(context_id)
             return JSONResponse(context.tools)
 
         @self.app.put("/api/tools/{context_id}")
         async def put_tools(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             if not is_valid_uuid(context_id):
-                logging.warning(f"Invalid context_id: {context_id}")
+                logger.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
             try:
|
|||||||
return JSONResponse({ "status": "error" }, 405)
|
return JSONResponse({ "status": "error" }, 405)
|
||||||
|
|
||||||
|
|
||||||
@self.app.get("/api/context-status/{context_id}/{session_type}")
|
@self.app.get("/api/context-status/{context_id}/{agent_type}")
|
||||||
async def get_context_status(context_id, session_type: str, request: Request):
|
async def get_context_status(context_id, agent_type: str, request: Request):
|
||||||
logging.info(f"{request.method} {request.url.path}")
|
logger.info(f"{request.method} {request.url.path}")
|
||||||
if not is_valid_uuid(context_id):
|
if not is_valid_uuid(context_id):
|
||||||
logging.warning(f"Invalid context_id: {context_id}")
|
logger.warning(f"Invalid context_id: {context_id}")
|
||||||
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
||||||
context = self.upsert_context(context_id)
|
context = self.upsert_context(context_id)
|
||||||
session = context.get_session(session_type)
|
agent = context.get_agent(agent_type)
|
||||||
if not session:
|
if not agent:
|
||||||
return JSONResponse({"context_used": 0, "max_context": defines.max_context})
|
return JSONResponse({"context_used": 0, "max_context": defines.max_context})
|
||||||
return JSONResponse({"context_used": session.context_tokens, "max_context": defines.max_context})
|
return JSONResponse({"context_used": agent.context_tokens, "max_context": defines.max_context})
|
||||||
|
|
||||||
@self.app.get("/api/health")
|
@self.app.get("/api/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
@@ -752,57 +744,80 @@ class WebServer:
         async def serve_static(path: str):
             full_path = os.path.join(defines.static_content, path)
             if os.path.exists(full_path) and os.path.isfile(full_path):
-                logging.info(f"Serve static request for {full_path}")
+                logger.info(f"Serve static request for {full_path}")
                 return FileResponse(full_path)
-            logging.info(f"Serve index.html for {path}")
+            logger.info(f"Serve index.html for {path}")
             return FileResponse(os.path.join(defines.static_content, "index.html"))
 
-    def save_context(self, session_id):
+    def save_context(self, context_id):
         """
-        Serialize a Python dictionary to a file in the sessions directory.
+        Serialize a Python dictionary to a file in the agents directory.
 
         Args:
-            data: Dictionary containing the session data
-            session_id: UUID string for the context. If it doesn't exist, it is created
+            data: Dictionary containing the agent data
+            context_id: UUID string for the context. If it doesn't exist, it is created
 
         Returns:
-            The session_id used for the file
+            The context_id used for the file
         """
-        context = self.upsert_context(session_id)
+        context = self.upsert_context(context_id)
 
-        # Create sessions directory if it doesn't exist
-        if not os.path.exists(defines.session_dir):
-            os.makedirs(defines.session_dir)
+        # Create agents directory if it doesn't exist
+        if not os.path.exists(defines.context_dir):
+            os.makedirs(defines.context_dir)
 
         # Create the full file path
-        file_path = os.path.join(defines.session_dir, session_id)
+        file_path = os.path.join(defines.context_dir, context_id)
 
         # Serialize the data to JSON and write to file
         with open(file_path, "w") as f:
             f.write(context.model_dump_json())
 
-        return session_id
+        return context_id
 
-    def load_context(self, session_id) -> Context:
+    def load_or_create_context(self, context_id) -> Context:
         """
-        Load a context from a file in the sessions directory.
+        Load a context from a file in the context directory or create a new one if it doesn't exist.
         Args:
-            session_id: UUID string for the context. If it doesn't exist, a new context is created.
+            context_id: UUID string for the context.
         Returns:
             A Context object with the specified ID and default settings.
         """
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
-        file_path = os.path.join(defines.session_dir, session_id)
+        file_path = os.path.join(defines.context_dir, context_id)
 
         # Check if the file exists
         if not os.path.exists(file_path):
-            self.contexts[session_id] = self.create_context(session_id)
+            logger.info(f"Context file {file_path} not found. Creating new context.")
+            self.contexts[context_id] = self.create_context(context_id)
         else:
             # Read and deserialize the data
             with open(file_path, "r") as f:
-                self.contexts[session_id] = Context.model_validate_json(f.read())
+                content = f.read()
+                logger.info(f"Loading context from {file_path}, content length: {len(content)}")
+                try:
+                    # Try parsing as JSON first to ensure valid JSON
+                    import json
+                    json_data = json.loads(content)
+                    logger.info("JSON parsed successfully, attempting model validation")
+
+                    # Now try Pydantic validation
+                    self.contexts[context_id] = Context.model_validate_json(content)
+                    self.contexts[context_id].file_watcher=self.file_watcher
+
+                    logger.info(f"Successfully loaded context {context_id}")
+                except json.JSONDecodeError as e:
+                    logger.error(f"Invalid JSON in file: {e}")
+                except Exception as e:
+                    logger.error(f"Error validating context: {str(e)}")
+                    import traceback
+                    logger.error(traceback.format_exc())
+                    # Fallback to creating a new context
+                    self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)
 
-        return self.contexts[session_id]
+        return self.contexts[context_id]
 
     def create_context(self, context_id = None) -> Context:
         """
|
|||||||
Returns:
|
Returns:
|
||||||
A Context object with the specified ID and default settings.
|
A Context object with the specified ID and default settings.
|
||||||
"""
|
"""
|
||||||
context = Context(id=context_id)
|
if not self.file_watcher:
|
||||||
|
raise Exception("File watcher not initialized")
|
||||||
|
|
||||||
|
logger.info(f"Creating new context with ID: {context_id}")
|
||||||
|
context = Context(id=context_id, file_watcher=self.file_watcher)
|
||||||
|
|
||||||
if os.path.exists(defines.resume_doc):
|
if os.path.exists(defines.resume_doc):
|
||||||
context.user_resume = open(defines.resume_doc, "r").read()
|
context.user_resume = open(defines.resume_doc, "r").read()
|
||||||
context.add_session(Chat(system_prompt = system_message))
|
context.get_or_create_agent(
|
||||||
# context.add_session(Resume(system_prompt = system_generate_resume))
|
agent_type="chat",
|
||||||
# context.add_session(JobDescription(system_prompt = system_job_description))
|
system_prompt=system_message)
|
||||||
# context.add_session(FactCheck(system_prompt = system_fact_check))
|
# context.add_agent(Resume(system_prompt = system_generate_resume))
|
||||||
|
# context.add_agent(JobDescription(system_prompt = system_job_description))
|
||||||
|
# context.add_agent(FactCheck(system_prompt = system_fact_check))
|
||||||
context.tools = default_tools(tools)
|
context.tools = default_tools(tools)
|
||||||
context.rags = rags.copy()
|
context.rags = rags.copy()
|
||||||
|
|
||||||
logging.info(f"{context.id} created and added to sessions.")
|
logger.info(f"{context.id} created and added to contexts.")
|
||||||
self.contexts[context.id] = context
|
self.contexts[context.id] = context
|
||||||
self.save_context(context.id)
|
self.save_context(context.id)
|
||||||
return context
|
return context
|
||||||
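create_context now routes agent construction through get_or_create_agent rather than building Chat directly, so instantiation flows through the agent registry. That helper is not shown in this diff; a plausible shape, assuming Context keeps a list of attached agents (the agents attribute is an assumption for illustration):

def get_or_create_agent(self, agent_type: str, **kwargs):
    # Return the existing agent of this type, if any
    for agent in self.agents:
        if agent.get_agent_type() == agent_type:
            return agent
    # Otherwise resolve the class through the registry and attach a new instance
    cls = registry.get_class(agent_type)
    if cls is None:
        raise ValueError(f"Unknown agent type: {agent_type}")
    agent = cls(**kwargs)
    agent.set_context(self)
    self.agents.append(agent)
    return agent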
@@ -905,44 +926,42 @@ class WebServer:
         """

         if not context_id:
-            logging.warning("No context ID provided. Creating a new context.")
+            logger.warning("No context ID provided. Creating a new context.")
            return self.create_context()

-        if not is_valid_uuid(context_id):
-            logging.info(f"User requested invalid context_id: {context_id}")
-            raise ValueError("Invalid context_id: {context_id}")
-
         if context_id in self.contexts:
             return self.contexts[context_id]

-        logging.info(f"Context {context_id} not found. Creating new context.")
-        return self.load_context(context_id)
+        logger.info(f"Context {context_id} is not yet loaded.")
+        return self.load_or_create_context(context_id)

     def generate_rag_results(self, context, content):
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
         results_found = False

-        if self.file_watcher:
-            for rag in context.rags:
-                if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now...
-                    yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
-                    chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
-                    if chroma_results:
-                        results_found = True
-                        chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-                        print(f"Chroma embedding shape: {chroma_embedding.shape}")
+        for rag in context.rags:
+            if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now...
+                yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
+                chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
+                if chroma_results:
+                    results_found = True
+                    chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
+                    logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")

-                        umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                        print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+                    umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
+                    logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output

-                        umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                        print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+                    umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
+                    logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output

-                        yield {
-                            **chroma_results,
-                            "name": rag["name"],
-                            "umap_embedding_2d": umap_2d,
-                            "umap_embedding_3d": umap_3d
-                        }
+                    yield {
+                        **chroma_results,
+                        "name": rag["name"],
+                        "umap_embedding_2d": umap_2d,
+                        "umap_embedding_3d": umap_3d
+                    }

         if not results_found:
             yield {"status": "complete", "message": "No RAG context found"}
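generate_rag_results interleaves {"status": "processing"} updates with the actual RAG payloads and closes with a "complete" status, so callers stream the status dicts to the client and keep the rest. A minimal consumer sketch (the collect_rag name is illustrative):

def collect_rag(server, context, query: str):
    payloads = []
    for update in server.generate_rag_results(context, query):
        if update.get("status") in ("processing", "complete"):
            print(update["message"])  # forward progress to the client
        else:
            payloads.append(update)  # chroma results plus UMAP embeddings
    return payloads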
@@ -956,35 +975,51 @@ class WebServer:
             else:
                 yield {"status": "complete", "message": "RAG processing complete"}

-    # session_type: chat
-    # * Q&A
-    #
-    # session_type: job_description
-    # * First message sets Job Description and generates Resume
-    # * Has content (Job Description)
-    # * Then Q&A of Job Description
-    #
-    # session_type: resume
-    # * First message sets Resume and generates Fact Check
-    # * Has no content
-    # * Then Q&A of Resume
-    #
-    # Fact Check:
-    # * First message sets Fact Check and is Q&A
-    # * Has content
-    # * Then Q&A of Fact Check
-    async def generate_response(self, context : Context, session : Session, content : str):
-        if not self.file_watcher:
-            return
+    async def generate_response(self, context : Context, agent : Agent, content : str) -> AsyncGenerator[Message, None]:
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
+        agent_type = agent.get_agent_type()
+        logger.info(f"generate_response: {agent_type}")
+        if agent_type == "chat":
+            message = Message(prompt=content)
+            async for message in agent.prepare_message(message):
+                # logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
+                if message.status == "error":
+                    yield message
+                    return
+                if message.status != "done":
+                    yield message
+            async for message in agent.process_message(self.llm, self.model, message):
+                # logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
+                if message.status == "error":
+                    yield message
+                    return
+                if message.status != "done":
+                    yield message
+            # async for value in agent.generate_llm_response(message):
+            #     logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
+            #     if value.status != "done":
+            #         yield value
+            #     if value.status == "error":
+            #         message.status = "error"
+            #         message.response = value.response
+            #         yield message
+            #         return
+            logger.info("TODO: There is more to do...")
+            yield message
+            return
+
+        return

         if self.processing:
-            logging.info("TODO: Implement delay queing; busy for same session, otherwise return queue size and estimated wait time")
+            logger.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
             yield {"status": "error", "message": "Busy processing another request."}
             return

         self.processing = True

-        conversation : Conversation = session.conversation
+        conversation : Conversation = agent.conversation

         message = Message(prompt=content)
         del content # Prevent accidental use of content
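For the "chat" agent type, generate_response is now a thin driver over two async generators: prepare_message assembles the preamble while yielding progress, and process_message runs the LLM while yielding partial responses. The consumption pattern, stripped to a skeleton (error handling and intermediate forwarding elided):

async def drive(agent, llm, model, prompt: str):
    message = Message(prompt=prompt)
    async for message in agent.prepare_message(message):
        if message.status == "error":
            return message  # surface the failure immediately
    async for message in agent.process_message(llm, model, message):
        if message.status == "error":
            return message
    return message  # final message carries response and metadata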
@@ -999,36 +1034,36 @@ class WebServer:
         enable_rag = False

         # RAG is disabled when asking questions about the resume
-        if session.session_type == "resume":
+        if agent.get_agent_type() == "resume":
             enable_rag = False

-        # The first time through each session session_type a content_seed may be set for
-        # future chat sessions; use it once, then clear it
-        message.preamble = session.get_and_reset_content_seed()
-        system_prompt = session.system_prompt
+        # The first time through each agent agent_type a content_seed may be set for
+        # future chat agents; use it once, then clear it
+        message.preamble = agent.get_and_reset_content_seed()
+        system_prompt = agent.system_prompt

-        # After the first time a particular session session_type is used, it is handled as a chat.
-        # The number of messages indicating the session is ready for chat varies based on
-        # the session_type of session
-        process_type = session.session_type
+        # After the first time a particular agent agent_type is used, it is handled as a chat.
+        # The number of messages indicating the agent is ready for chat varies based on
+        # the agent_type of agent
+        process_type = agent.get_agent_type()
         match process_type:
             case "job_description":
-                logging.info(f"job_description user_history len: {len(conversation.messages)}")
+                logger.info(f"job_description user_history len: {len(conversation.messages)}")
                 if len(conversation.messages) >= 2: # USER, ASSISTANT
                     process_type = "chat"
             case "resume":
-                logging.info(f"resume user_history len: {len(conversation.messages)}")
+                logger.info(f"resume user_history len: {len(conversation.messages)}")
                 if len(conversation.messages) >= 3: # USER, ASSISTANT, FACT_CHECK
                     process_type = "chat"
             case "fact_check":
-                process_type = "chat" # Fact Check is always a chat session
+                process_type = "chat" # Fact Check is always a chat agent

         match process_type:
             # Normal chat interactions with context history
             case "chat":
                 if not message.prompt:
                     yield {"status": "error", "message": "No query provided for chat."}
-                    logging.info(f"user_history len: {len(conversation.messages)}")
+                    logger.info(f"user_history len: {len(conversation.messages)}")
                     self.processing = False
                     return
@@ -1071,7 +1106,7 @@ class WebServer:
 Use that information to respond to:"""

                 # Use the mode specific system_prompt instead of 'chat'
-                system_prompt = session.system_prompt
+                system_prompt = agent.system_prompt

             # On first entry, a single job_description is provided ("user")
             # Generate a resume to append to RESUME history
@@ -1110,10 +1145,10 @@ Use that information to respond to:"""
 <|job_description|>
 {message.prompt}
 """
-                tmp = context.get_session("job_description")
+                tmp = context.get_agent("job_description")
                 if not tmp:
-                    raise Exception(f"Job description session not found.")
-                # Set the content seed for the job_description session
+                    raise Exception(f"Job description agent not found.")
+                # Set the content seed for the job_description agent
                 tmp.set_content_seed(message.preamble + "<|question|>\nUse the above information to respond to this prompt: ")

                 message.preamble += f"""
@@ -1126,7 +1161,7 @@ Use to the above information to respond to this prompt:
 """

                 # For all future calls to job_description, use the system_job_description
-                session.system_prompt = system_job_description
+                agent.system_prompt = system_job_description

                 # Seed the history for job_description
                 stuffingMessage = Message(prompt=message.prompt)
@@ -1137,21 +1172,21 @@ Use to the above information to respond to this prompt:

                 message.add_action("generate_resume")

-                logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into session 'resume'")
-                logging.info("TODO: For subsequent runs, have the Session handler generate the follow up prompts so they can have correct context preamble")
+                logger.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent 'resume'")
+                logger.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble")

-                # Switch to resume session for LLM responses
+                # Switch to resume agent for LLM responses
                 # message.metadata["origin"] = "resume"
-                # session = context.get_or_create_session("resume")
-                # system_prompt = session.system_prompt
-                # llm_history = session.llm_history = []
-                # user_history = session.user_history = []
+                # agent = context.get_or_create_agent("resume")
+                # system_prompt = agent.system_prompt
+                # llm_history = agent.llm_history = []
+                # user_history = agent.user_history = []

             # Ignore the passed in content and invoke Fact Check
             case "resume":
-                if len(context.get_or_create_session("resume").conversation.messages) < 2: # USER, **ASSISTANT**
+                if len(context.get_or_create_agent("resume").conversation.messages) < 2: # USER, **ASSISTANT**
                     raise Exception(f"No resume found in user history.")
-                resume = context.get_or_create_session("resume").conversation.messages[1]
+                resume = context.get_or_create_agent("resume").conversation.messages[1]

                 # Generate RAG content if enabled, based on the content
                 rag_context = ""
@@ -1196,7 +1231,7 @@ Use to the above information to respond to this prompt:
 <|question|>
 """

-                context.get_or_create_session("resume").set_content_seed(f"""
+                context.get_or_create_agent("resume").set_content_seed(f"""
 <|resume|>
 {resume["content"]}

@@ -1218,29 +1253,29 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                 stuffingMessage.metadata["origin"] = "resume"
                 stuffingMessage.metadata["display"] = "hide"
                 stuffingMessage.actions = [ "fact_check" ]
-                logging.info("TODO: Switch this to use actions to keep the UI from showing it")
+                logger.info("TODO: Switch this to use actions to keep the UI from showing it")
                 conversation.add_message(stuffingMessage)

                 # For all future calls to job_description, use the system_job_description
-                logging.info("TODO: Create a system_resume_QA prompt to use for the resume session")
-                session.system_prompt = system_prompt
+                logger.info("TODO: Create a system_resume_QA prompt to use for the resume agent")
+                agent.system_prompt = system_prompt

-                # Switch to fact_check session for LLM responses
+                # Switch to fact_check agent for LLM responses
                 message.metadata["origin"] = "fact_check"
-                session = context.get_or_create_session("fact_check", system_prompt=system_fact_check)
+                agent = context.get_or_create_agent("fact_check", system_prompt=system_fact_check)

-                llm_history = session.llm_history = []
-                user_history = session.user_history = []
+                llm_history = agent.llm_history = []
+                user_history = agent.user_history = []

             case _:
-                raise Exception(f"Invalid chat session_type: {session_type}")
+                raise Exception(f"Invalid chat agent_type: {agent_type}")

         conversation.add_message(message)
         # llm_history.append({"role": "user", "content": message.preamble + content})
         # user_history.append({"role": "user", "content": content, "origin": message.metadata["origin"]})
         # message.metadata["full_query"] = llm_history[-1]["content"]

-        # Uses cached system_prompt as session.system_prompt may have been updated for follow up questions
+        # Uses cached system_prompt as agent.system_prompt may have been updated for follow up questions
         messages = create_system_message(system_prompt)
         if context.message_history_length:
             to_add = conversation.messages[-context.message_history_length:]
@@ -1272,12 +1307,12 @@ Use the above <|resume|> and <|job_description|> to answer this query:
 {message.prompt}"""

         # Estimate token length of new messages
-        ctx_size = self.get_optimal_ctx_size(context.get_or_create_session(process_type).context_tokens, messages=message.prompt)
+        ctx_size = self.get_optimal_ctx_size(context.get_or_create_agent(process_type).context_tokens, messages=message.prompt)

         if len(conversation.messages) > 2:
             processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..."
         else:
-            match session.session_type:
+            match agent.get_agent_type():
                 case "job_description":
                     processing_message = f"Generating {'RAG augmented ' if enable_rag else ''}resume..."
                 case "resume":
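get_optimal_ctx_size is used throughout this file but its definition is not part of this diff; it evidently grows num_ctx from the last known token count plus an estimate for the new text. A plausible sketch, where the chars-per-token ratio and step size are assumptions rather than the project's actual constants:

def get_optimal_ctx_size(self, current_tokens: int, messages, step: int = 2048, max_ctx: int = 32768) -> int:
    text = messages if isinstance(messages, str) else "".join(str(m) for m in messages)
    estimated = current_tokens + int(len(text) / 3.5)  # rough token estimate from character count
    # Round up to a step boundary so the model's context window isn't resized on every call
    return min(max_ctx, ((estimated // step) + 1) * step)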
@@ -1294,7 +1329,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
             else:
                 response = self.llm.chat(model=self.model, messages=messages, options={ "num_ctx": ctx_size })
         except Exception as e:
-            logging.exception({ "model": self.model, "error": str(e) })
+            logger.exception({ "model": self.model, "error": str(e) })
             yield {"status": "error", "message": f"An error occurred communicating with LLM"}
             self.processing = False
             return
||||||
@ -1303,7 +1338,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
|
|||||||
message.metadata["eval_duration"] += response["eval_duration"]
|
message.metadata["eval_duration"] += response["eval_duration"]
|
||||||
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
|
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
|
||||||
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
|
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
|
||||||
session.context_tokens = response["prompt_eval_count"] + response["eval_count"]
|
agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
|
||||||
|
|
||||||
tools_used = []
|
tools_used = []
|
||||||
|
|
||||||
@@ -1347,7 +1382,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
             message.metadata["tools"] = tools_used

             # Estimate token length of new messages
-            ctx_size = self.get_optimal_ctx_size(session.context_tokens, messages=messages[pre_add_index:])
+            ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
             yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
             # Decrease creativity when processing tool call requests
             response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
@@ -1355,11 +1390,11 @@ Use the above <|resume|> and <|job_description|> to answer this query:
             message.metadata["eval_duration"] += response["eval_duration"]
             message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
             message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-            session.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+            agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]

         reply = response["message"]["content"]
         message.response = reply
-        message.metadata["origin"] = session.session_type
+        message.metadata["origin"] = agent.get_agent_type()
         # final_message = {"role": "assistant", "content": reply }

         # # history is provided to the LLM and should not have additional metadata
@@ -1379,7 +1414,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
         }

         # except Exception as e:
-        #     logging.exception({ "model": self.model, "origin": session_type, "content": content, "error": str(e) })
+        #     logger.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) })
         #     yield {"status": "error", "message": f"An error occurred: {str(e)}"}

         # finally:
@@ -1390,7 +1425,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
     def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
         try:
             if self.ssl_enabled:
-                logging.info(f"Starting web server at https://{host}:{port}")
+                logger.info(f"Starting web server at https://{host}:{port}")
                 uvicorn.run(
                     self.app,
                     host=host,
@@ -1400,7 +1435,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                     ssl_certfile=defines.cert_path
                 )
             else:
-                logging.info(f"Starting web server at http://{host}:{port}")
+                logger.info(f"Starting web server at http://{host}:{port}")
                 uvicorn.run(
                     self.app,
                     host=host,
@@ -1423,7 +1458,7 @@ def main():
     args = parse_args()

     # Setup logging based on the provided level
-    setup_logging(args.level)
+    logger.setLevel(args.level.upper())

     warnings.filterwarnings(
         "ignore",
src/utils/__init__.py
@@ -1,10 +1,67 @@
-# Import defines to make `utils.defines` accessible
+from typing import Optional, Type

 from . import defines
+from . rag import ChromaDBFileWatcher, start_file_watcher
+from . message import Message
+from . conversation import Conversation
+from . context import Context
+from . import agents
+from . setup_logging import setup_logging

-# Import rest as `utils.*` accessible
-from .rag import ChromaDBFileWatcher, start_file_watcher
-from .message import Message
-from .conversation import Conversation
-from .session import Session, Chat, Resume, JobDescription, FactCheck
-from .context import Context
+from .agents import Agent, __all__ as agents_all
+
+__all__ = [
+    'Agent',
+    'Context',
+    'Conversation',
+    'Message',
+    'ChromaDBFileWatcher',
+    'start_file_watcher',
+    'logger',
+] + agents_all
+
+# Resolve circular dependencies by rebuilding models
+# Call model_rebuild() on Agent and Context
+Agent.model_rebuild()
+Context.model_rebuild()
+
+import importlib
+from pydantic import BaseModel
+from typing import Type
+
+# Assuming class_registry is available from agents/__init__.py
+from .agents import class_registry, AnyAgent
+
+logger = setup_logging(level=defines.logging_level)
+
+def rebuild_models():
+    for class_name, (module_name, _) in class_registry.items():
+        try:
+            module = importlib.import_module(module_name)
+            cls = getattr(module, class_name, None)
+
+            logger.debug(f"Checking: {class_name} in module {module_name}")
+            logger.debug(f"  cls: {True if cls else False}")
+            logger.debug(f"  isinstance(cls, type): {isinstance(cls, type)}")
+            logger.debug(f"  issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}")
+            logger.debug(f"  issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}")
+            logger.debug(f"  cls is not AnyAgent: {cls is not AnyAgent if cls else True}")
+
+            if (
+                cls
+                and isinstance(cls, type)
+                and issubclass(cls, BaseModel)
+                and issubclass(cls, AnyAgent)
+                and cls is not AnyAgent
+            ):
+                logger.debug(f"Rebuilding {class_name} from {module_name}")
+                from . agents import Agent
+                from . context import Context
+                cls.model_rebuild()
+        except ImportError as e:
+            logger.error(f"Failed to import module {module_name}: {e}")
+        except Exception as e:
+            logger.error(f"Error processing {class_name} in {module_name}: {e}")

+# Call this after all modules are imported
+rebuild_models()
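model_rebuild() is Pydantic v2's hook for resolving forward references once all of the mutually referencing models exist; the loop above simply applies it to every registered agent class. The mechanism in isolation:

from __future__ import annotations
from typing import Optional
from pydantic import BaseModel

class Node(BaseModel):
    value: int
    tree: Optional[Tree] = None  # forward reference; Tree is defined below

class Tree(BaseModel):
    root: Node

Node.model_rebuild()  # resolves the Tree annotation now that Tree exists
Tree.model_rebuild()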
src/utils/agents/__init__.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+from __future__ import annotations
+import importlib
+import pathlib
+import inspect
+import logging
+from typing import TypeAlias, Dict, Tuple
+from pydantic import BaseModel
+from . base import Agent
+
+# Type alias for Agent or any subclass
+AnyAgent: TypeAlias = Agent  # BaseModel covers Agent and subclasses
+
+package_dir = pathlib.Path(__file__).parent
+package_name = __name__
+__all__ = []
+class_registry: Dict[str, Tuple[str, str]] = {}  # Maps class_name to (module_name, class_name)
+
+for path in package_dir.glob("*.py"):
+    if path.name in ("__init__.py", "base.py") or path.name.startswith("_"):
+        continue
+
+    module_name = path.stem
+    full_module_name = f"{package_name}.{module_name}"
+
+    try:
+        module = importlib.import_module(full_module_name)
+
+        # Find all Agent subclasses in the module
+        for name, obj in inspect.getmembers(module, inspect.isclass):
+            if (
+                issubclass(obj, AnyAgent)
+                and obj is not AnyAgent
+                and obj is not Agent
+                and name not in class_registry
+            ):
+                class_registry[name] = (full_module_name, name)
+                globals()[name] = obj
+                logging.info(f"Adding agent: {name} from {full_module_name}")
+                __all__.append(name)
+    except ImportError as e:
+        logging.error(f"Failed to import module {full_module_name}: {e}")
+
+__all__.append("AnyAgent")
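The package __init__ above turns src/utils/agents into a plugin directory: any module placed beside base.py whose classes subclass Agent is imported, recorded in class_registry, and re-exported from the package. Consumption then looks roughly like this (the printed list is illustrative):

from utils import agents

print(agents.__all__)  # e.g. ['Chat', 'FactCheck', 'JobDescription', 'Resume', 'AnyAgent']
module_name, class_name = agents.class_registry["Chat"]
ChatAgent = getattr(agents, class_name)
chat = ChatAgent(system_prompt="You are a helpful assistant.")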
src/utils/agents/base.py (new file, 258 lines)
@@ -0,0 +1,258 @@
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr, Field
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef, Any
+from abc import ABC, abstractmethod
+from typing_extensions import Annotated
+from .. setup_logging import setup_logging
+
+logger = setup_logging()
+
+# Only import Context for type checking
+if TYPE_CHECKING:
+    from .. context import Context
+
+from .types import registry
+
+from .. conversation import Conversation
+from .. message import Message
+
+class Agent(BaseModel, ABC):
+    """
+    Base class for all agent types.
+    This class defines the common attributes and methods for all agent types.
+    """
+
+    # Agent management with pydantic
+    agent_type: Literal["base"] = "base"
+    _agent_type: ClassVar[str] = agent_type  # Add this for registration
+
+    # Agent properties
+    system_prompt: str  # Mandatory
+    conversation: Conversation = Conversation()
+    context_tokens: int = 0
+    context: Optional[Context] = Field(default=None, exclude=True)  # Avoid circular reference, require as param, and prevent serialization
+
+    _content_seed: str = PrivateAttr(default="")
+
+    # Class and pydantic model management
+    def __init_subclass__(cls, **kwargs):
+        """Auto-register subclasses"""
+        super().__init_subclass__(**kwargs)
+        # Register this class if it has an agent_type
+        if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type:
+            registry.register(cls.agent_type, cls)
+
+    def model_dump(self, *args, **kwargs):
+        # Ensure context is always excluded, even with exclude_unset=True
+        kwargs.setdefault("exclude", set())
+        if isinstance(kwargs["exclude"], set):
+            kwargs["exclude"].add("context")
+        elif isinstance(kwargs["exclude"], dict):
+            kwargs["exclude"]["context"] = True
+        return super().model_dump(*args, **kwargs)
+
+    @classmethod
+    def valid_agent_types(cls) -> set[str]:
+        """Return the set of valid agent_type values."""
+        return set(get_args(cls.__annotations__["agent_type"]))
+
+    def set_context(self, context):
+        object.__setattr__(self, "context", context)
+
+    # Agent methods
+    def get_agent_type(self):
+        return self._agent_type
+
+    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
+        """
+        Prepare message with context information in message.preamble
+        """
+        # Generate RAG content if enabled, based on the content
+        rag_context = ""
+        if not message.disable_rag:
+            # Gather RAG results, yielding each result
+            # as it becomes available
+            for value in self.context.generate_rag_results(message):
+                logger.info(f"RAG: {value.status} - {value.response}")
+                if value.status != "done":
+                    yield value
+                if value.status == "error":
+                    message.status = "error"
+                    message.response = value.response
+                    yield message
+                    return
+
+            if message.metadata["rag"]:
+                for rag_collection in message.metadata["rag"]:
+                    for doc in rag_collection["documents"]:
+                        rag_context += f"{doc}\n"
+
+        if rag_context:
+            message.preamble["context"] = rag_context
+
+        if self.context.user_resume:
+            message.preamble["resume"] = self.context.user_resume
+
+        if message.preamble:
+            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
+            preamble_types_AND = " and ".join(preamble_types)
+            preamble_types_OR = " or ".join(preamble_types)
+            message.preamble["rules"] = f"""\
+- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
+- If there is no information in these sections, answer based on your knowledge.
+- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
+"""
+            message.preamble["question"] = "Use that information to respond to:"
+        else:
+            message.preamble["question"] = "Respond to:"
+
+        message.system_prompt = self.system_prompt
+        message.status = "done"
+        yield message
+        return
+
+    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
+        if self.context.processing:
+            logger.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
+            message.status = "error"
+            message.response = "Busy processing another request."
+            yield message
+            return
+
+        self.context.processing = True
+
+        messages = []
+
+        for value in self.llm.chat(
+            model=self.model,
+            messages=messages,
+            #tools=llm_tools(context.tools) if message.enable_tools else None,
+            options={ "num_ctx": message.ctx_size }
+        ):
+            logger.info(f"LLM: {value.status} - {value.response}")
+            if value.status != "done":
+                message.status = value.status
+                message.response = value.response
+                yield message
+            if value.status == "error":
+                return
+            response = value
+
+        message.metadata["eval_count"] += response["eval_count"]
+        message.metadata["eval_duration"] += response["eval_duration"]
+        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
+        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
+        self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+
+        tools_used = []
+
+        yield {"status": "processing", "message": "Initial response received..."}
+
+        if "tool_calls" in response.get("message", {}):
+            yield {"status": "processing", "message": "Processing tool calls..."}
+
+            tool_message = response["message"]
+            tool_result = None
+
+            # Process all yielded items from the handler
+            async for item in self.handle_tool_calls(tool_message):
+                if isinstance(item, tuple) and len(item) == 2:
+                    # This is the final result tuple (tool_result, tools_used)
+                    tool_result, tools_used = item
+                else:
+                    # This is a status update, forward it
+                    yield item
+
+            message_dict = {
+                "role": tool_message.get("role", "assistant"),
+                "content": tool_message.get("content", "")
+            }
+
+            if "tool_calls" in tool_message:
+                message_dict["tool_calls"] = [
+                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
+                    for tc in tool_message["tool_calls"]
+                ]
+
+            pre_add_index = len(messages)
+            messages.append(message_dict)
+
+            if isinstance(tool_result, list):
+                messages.extend(tool_result)
+            else:
+                if tool_result:
+                    messages.append(tool_result)
+
+            message.metadata["tools"] = tools_used
+
+            # Estimate token length of new messages
+            ctx_size = self.get_optimal_ctx_size(self.context_tokens, messages=messages[pre_add_index:])
+            yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
+            # Decrease creativity when processing tool call requests
+            response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
+            message.metadata["eval_count"] += response["eval_count"]
+            message.metadata["eval_duration"] += response["eval_duration"]
+            message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
+            message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
+            self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+
+        reply = response["message"]["content"]
+        message.response = reply
+        message.metadata["origin"] = self.agent_type
+        # final_message = {"role": "assistant", "content": reply }
+
+        # # history is provided to the LLM and should not have additional metadata
+        # llm_history.append(final_message)
+
+        # user_history is provided to the REST API and does not include CONTEXT
+        # It does include metadata
+        # final_message["metadata"] = message.metadata
+        # user_history.append({**final_message, "origin": message.metadata["origin"]})
+
+        # Return the REST API with metadata
+        yield {
+            "status": "done",
+            "message": {
+                **message.model_dump(mode='json'),
+            }
+        }
+
+        self.context.processing = False
+        return
+
+    async def process_message(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
+        message.full_content = ""
+        for i, p in enumerate(message.preamble.keys()):
+            message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"
+
+        # Estimate token length of new messages
+        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
+
+        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
+        message.status = "thinking"
+        yield message
+
+        async for value in self.generate_llm_response(message):
+            logger.info(f"LLM: {value.status} - {value.response}")
+            if value.status != "done":
+                yield value
+            if value.status == "error":
+                return
+
+    def get_and_reset_content_seed(self):
+        tmp = self._content_seed
+        self._content_seed = ""
+        return tmp
+
+    def set_content_seed(self, content: str) -> None:
+        """Set the content seed for the agent."""
+        self._content_seed = content
+
+    def get_content_seed(self) -> str:
+        """Get the content seed for the agent."""
+        return self._content_seed
+
+# Register the base agent
+registry.register(Agent._agent_type, Agent)
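Two registration paths coexist in base.py: __init_subclass__ registers every subclass as a side effect of the class statement, and the explicit registry.register call at the bottom covers the base class itself. The subclass hook in miniature:

from typing import Dict, Type

class Base:
    kind: str = "base"
    _registry: Dict[str, Type["Base"]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls.kind != Base.kind:  # skip classes that don't override the discriminator
            Base._registry[cls.kind] = cls

class Widget(Base):
    kind = "widget"

assert Base._registry["widget"] is Widget  # registered when the class statement ran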
src/utils/agents/chat.py (new file, 246 lines)
@@ -0,0 +1,246 @@
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any
+from typing_extensions import Annotated
+from abc import ABC, abstractmethod
+import logging
+from .base import Agent, registry
+from .. conversation import Conversation
+from .. message import Message
+from .. import defines
+
+class Chat(Agent, ABC):
+    """
+    Base class for all agent types.
+    This class defines the common attributes and methods for all agent types.
+    """
+    agent_type: Literal["chat"] = "chat"
+    _agent_type: ClassVar[str] = agent_type  # Add this for registration
+
+    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
+        """
+        Prepare message with context information in message.preamble
+        """
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
+        # Generate RAG content if enabled, based on the content
+        rag_context = ""
+        if not message.disable_rag:
+            # Gather RAG results, yielding each result
+            # as it becomes available
+            for message in self.context.generate_rag_results(message):
+                logging.info(f"RAG: {message.status} - {message.response}")
+                if message.status == "error":
+                    yield message
+                    return
+                if message.status != "done":
+                    yield message
+
+            if "rag" in message.metadata and message.metadata["rag"]:
+                for rag in message.metadata["rag"]:
+                    for doc in rag["documents"]:
+                        rag_context += f"{doc}\n"
+
+        message.preamble = {}
+
+        if rag_context:
+            message.preamble["context"] = rag_context
+
+        if self.context.user_resume:
+            message.preamble["resume"] = self.context.user_resume
+
+        if message.preamble:
+            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
+            preamble_types_AND = " and ".join(preamble_types)
+            preamble_types_OR = " or ".join(preamble_types)
+            message.preamble["rules"] = f"""\
+- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
+- If there is no information in these sections, answer based on your knowledge.
+- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
+"""
+            message.preamble["question"] = "Use that information to respond to:"
+        else:
+            message.preamble["question"] = "Respond to:"
+
+        message.system_prompt = self.system_prompt
+        message.status = "done"
+        yield message
+        return
+
+    async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
+        if self.context.processing:
+            logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
+            message.status = "error"
+            message.response = "Busy processing another request."
+            yield message
+            return
+
+        self.context.processing = True
+
+        self.conversation.add_message(message)
+
+        messages = [
+            item for m in self.conversation.messages
+            for item in [
+                {"role": "user", "content": m.prompt},
+                {"role": "assistant", "content": m.response}
+            ]
+        ]
+
+        response = None  # guard in case the stream yields nothing
+        for value in llm.chat(
+            model=model,
+            messages=messages,
+            #tools=llm_tools(context.tools) if message.enable_tools else None,
+            options={ "num_ctx": message.metadata["ctx_size"] if message.metadata["ctx_size"] else defines.max_context },
+            stream=True,
+        ):
+            logging.info(f"LLM: {'done' if value.done else 'thinking'} - {value.message.content}")
+            message.response += value.message.content
+            yield message
+            if value.done:
+                response = value
+                message.status = "done"
+
+        if not response:
+            message.status = "error"
+            message.response = "No response from LLM."
+            yield message
+            self.context.processing = False
+            return
+
+        message.metadata["eval_count"] += response["eval_count"]
+        message.metadata["eval_duration"] += response["eval_duration"]
+        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
+        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
+        self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+
+        yield message
+        self.context.processing = False
+        return
+
+        # Tool-call handling below is unreachable after the return above
+        tools_used = []
+
+        if "tool_calls" in response.get("message", {}):
+            message.status = "thinking"
+            message.response = "Processing tool calls..."
+
+            tool_message = response["message"]
+            tool_result = None
+
+            # Process all yielded items from the handler
+            async for value in self.handle_tool_calls(tool_message):
+                if isinstance(value, tuple) and len(value) == 2:
+                    # This is the final result tuple (tool_result, tools_used)
+                    tool_result, tools_used = value
+                else:
+                    # This is a status update, forward it
+                    yield value
+
+            message_dict = {
+                "role": tool_message.get("role", "assistant"),
+                "content": tool_message.get("content", "")
+            }
+
+            if "tool_calls" in tool_message:
+                message_dict["tool_calls"] = [
+                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
+                    for tc in tool_message["tool_calls"]
+                ]
+
+            pre_add_index = len(messages)
+            messages.append(message_dict)
+
+            if isinstance(tool_result, list):
+                messages.extend(tool_result)
+            else:
+                if tool_result:
+                    messages.append(tool_result)
+
+            message.metadata["tools"] = tools_used
+
+            # Estimate token length of new messages
+            ctx_size = self.get_optimal_ctx_size(self.context_tokens, messages=messages[pre_add_index:])
+            yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
+            # Decrease creativity when processing tool call requests
+            response = llm.chat(model=model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
+            message.metadata["eval_count"] += response["eval_count"]
+            message.metadata["eval_duration"] += response["eval_duration"]
+            message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
+            message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
+            self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+
+        reply = response["message"]["content"]
+        message.response = reply
+        message.metadata["origin"] = self.agent_type
+        # final_message = {"role": "assistant", "content": reply }
+
+        # # history is provided to the LLM and should not have additional metadata
+        # llm_history.append(final_message)
+
+        # user_history is provided to the REST API and does not include CONTEXT
+        # It does include metadata
+        # final_message["metadata"] = message.metadata
+        # user_history.append({**final_message, "origin": message.metadata["origin"]})
+
+        # Return the REST API with metadata
+        yield {
+            "status": "done",
+            "message": {
+                **message.model_dump(mode='json'),
+            }
+        }
+
+        self.context.processing = False
+        return
+
+    async def process_message(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
+        message.full_content = f"<|system|>{self.system_prompt.strip()}\n"
+        for i, p in enumerate(message.preamble.keys()):
+            message.full_content += f"\n<|{p}|>\n{message.preamble[p].strip()}\n"
+        message.full_content += f"{message.prompt}"
+
+        # Estimate token length of new messages
+        message.metadata["ctx_size"] = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
+
+        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
+        message.status = "thinking"
+        yield message
+
+        async for message in self.generate_llm_response(llm, model, message):
+            logging.info(f"LLM: {message.status} - {message.response}")
+            if message.status == "error":
+                return
+            if message.status != "done":
+                yield message
+        yield message
+        return
+
+    def get_and_reset_content_seed(self):
+        tmp = self._content_seed
+        self._content_seed = ""
+        return tmp
+
+    def set_content_seed(self, content: str) -> None:
+        """Set the content seed for the agent."""
+        self._content_seed = content
+
+    def get_content_seed(self) -> str:
+        """Get the content seed for the agent."""
+        return self._content_seed
+
+    @classmethod
+    def valid_agent_types(cls) -> set[str]:
+        """Return the set of valid agent_type values."""
+        return set(get_args(cls.__annotations__["agent_type"]))
+
+# Register the base agent
+registry.register(Chat._agent_type, Chat)
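Chat.generate_llm_response streams with stream=True and repeatedly yields the same Message object as its response field grows, so a consumer sees a monotonically lengthening string rather than deltas. Recovering the deltas is the consumer's job, roughly (a sketch assuming the agent's context is already attached):

async def print_stream(agent, llm, model, message):
    seen = 0
    async for msg in agent.generate_llm_response(llm, model, message):
        delta = msg.response[seen:]  # only the newly streamed characters
        seen = len(msg.response)
        print(delta, end="", flush=True)
    print()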
src/utils/agents/fact_check.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+from typing_extensions import Annotated
+from abc import ABC, abstractmethod
+import logging
+from .base import Agent, registry
+from .. conversation import Conversation
+from .. message import Message
+
+class FactCheck(Agent):
+    agent_type: Literal["fact_check"] = "fact_check"
+    _agent_type: ClassVar[str] = agent_type  # Add this for registration
+
+    facts: str = ""
+
+    @model_validator(mode="after")
+    def validate_facts(self):
+        if not self.facts.strip():
+            raise ValueError("Facts cannot be empty")
+        return self
+
+# Register the base agent
+registry.register(FactCheck._agent_type, FactCheck)
src/utils/agents/job_description.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+from typing_extensions import Annotated
+from abc import ABC, abstractmethod
+import logging
+from .base import Agent, registry
+from .. conversation import Conversation
+from .. message import Message
+
+class JobDescription(Agent):
+    agent_type: Literal["job_description"] = "job_description"
+    _agent_type: ClassVar[str] = agent_type  # Add this for registration
+
+    job_description: str = ""
+
+    @model_validator(mode="after")
+    def validate_job_description(self):
+        if not self.job_description.strip():
+            raise ValueError("Job description cannot be empty")
+        return self
+
+# Register the base agent
+registry.register(JobDescription._agent_type, JobDescription)
src/utils/agents/resume.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+from typing_extensions import Annotated
+from abc import ABC, abstractmethod
+import logging
+from .base import Agent, registry
+from .. conversation import Conversation
+from .. message import Message
+
+class Resume(Agent):
+    agent_type: Literal["resume"] = "resume"
+    _agent_type: ClassVar[str] = agent_type  # Add this for registration
+
+    resume: str = ""
+
+    @model_validator(mode="after")
+    def validate_resume(self):
+        if not self.resume.strip():
+            raise ValueError("Resume content cannot be empty")
+        return self
+
+    def get_resume(self) -> str:
+        """Get the resume content."""
+        return self.resume
+
+    def set_resume(self, resume: str) -> None:
+        """Set the resume content."""
+        self.resume = resume
+
+# Register the base agent
+registry.register(Resume._agent_type, Resume)
38 src/utils/agents/types.py Normal file
@@ -0,0 +1,38 @@
+from __future__ import annotations
+from typing import List, Dict, Any, Union, ForwardRef, TypeVar, Optional, TYPE_CHECKING, Type, ClassVar, Literal
+from typing_extensions import Annotated
+from pydantic import Field, BaseModel
+from abc import ABC, abstractmethod
+
+# Forward references
+AgentRef = ForwardRef('Agent')
+ContextRef = ForwardRef('Context')
+
+# We'll use a registry pattern rather than hardcoded strings
+class AgentRegistry:
+    """Registry for agent types and classes"""
+    _registry: Dict[str, Type] = {}
+
+    @classmethod
+    def register(cls, agent_type: str, agent_class: Type) -> Type:
+        """Register an agent class with its type"""
+        cls._registry[agent_type] = agent_class
+        return agent_class
+
+    @classmethod
+    def get_class(cls, agent_type: str) -> Optional[Type]:
+        """Get the class for a given agent type"""
+        return cls._registry.get(agent_type)
+
+    @classmethod
+    def get_types(cls) -> List[str]:
+        """Get all registered agent types"""
+        return list(cls._registry.keys())
+
+    @classmethod
+    def get_classes(cls) -> Dict[str, Type]:
+        """Get all registered agent classes"""
+        return cls._registry.copy()
+
+# Create a singleton instance
+registry = AgentRegistry()
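A short usage sketch of the registry singleton — `EchoAgent` is hypothetical and the import path is assumed from the diff layout; registration mirrors the explicit `registry.register(...)` calls in the agent modules above:

# Hypothetical usage of AgentRegistry; EchoAgent is illustrative only.
from utils.agents.types import registry  # path assumed from src/utils/agents/types.py

class EchoAgent:
    pass

# Explicit registration, as the agent modules in this change do...
registry.register("echo", EchoAgent)

# ...then lookup by type string, e.g. when rehydrating a saved Context.
cls = registry.get_class("echo")
assert cls is EchoAgent
print(registry.get_types())  # ['echo', ...plus any agents already imported]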
@@ -1,19 +1,32 @@
-from pydantic import BaseModel, Field, model_validator
+from __future__ import annotations
+from pydantic import BaseModel, Field, model_validator, ValidationError
 from uuid import uuid4
-from typing import List, Optional
+from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
 from typing_extensions import Annotated, Union
-from .session import AnySession, Session
+import numpy as np
+import logging
+from uuid import uuid4
+import re
+
+from .message import Message
+from .rag import ChromaDBFileWatcher
+from . import defines
+
+from .agents import AnyAgent
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 class Context(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher
+    # Required fields
+    file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
+
+    # Optional fields
     id: str = Field(
         default_factory=lambda: str(uuid4()),
         pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
     )
-    sessions: List[Annotated[Union[*Session.__subclasses__()], Field(discriminator="session_type")]] = Field(
-        default_factory=list
-    )
 
     user_resume: Optional[str] = None
     user_job_description: Optional[str] = None
     user_facts: Optional[str] = None
@@ -21,78 +34,160 @@ class Context(BaseModel):
     rags: List[dict] = []
     message_history_length: int = 5
     context_tokens: int = 0
+    # Class managed fields
+    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
+        default_factory=list
+    )
 
-    def __init__(self, id: Optional[str] = None, **kwargs):
-        super().__init__(id=id if id is not None else str(uuid4()), **kwargs)
+    processing: bool = Field(default=False, exclude=True)
+
+    # @model_validator(mode="before")
+    # @classmethod
+    # def before_model_validator(cls, values: Any):
+    #     logger.info(f"Preparing model data: {cls} {values}")
+    #     return values
 
     @model_validator(mode="after")
-    def validate_unique_session_types(self):
-        """Ensure at most one session per session_type."""
-        session_types = [session.session_type for session in self.sessions]
-        if len(session_types) != len(set(session_types)):
-            raise ValueError("Context cannot contain multiple sessions of the same session_type")
+    def after_model_validator(self):
+        """Ensure at most one agent per agent_type."""
+        logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
+        agent_types = [agent.agent_type for agent in self.agents]
+        if len(agent_types) != len(set(agent_types)):
+            raise ValueError("Context cannot contain multiple agents of the same agent_type")
+        for agent in self.agents:
+            agent.set_context(self)
         return self
 
-    def get_or_create_session(self, session_type: str, **kwargs) -> Session:
+    def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
+        ctx = round(context + len(str(messages)) * 3 / 4)
+        return min(defines.max_context, max(2048, ctx + ctx_buffer))  # clamp to [2048, defines.max_context]
+
+    def generate_rag_results(self, message: Message) -> Generator[Message, None, None]:
         """
-        Get or create and append a new session of the specified type, ensuring only one session per type exists.
+        Generate RAG results for the given query.
 
         Args:
-            session_type: The type of session to create (e.g., 'web', 'database').
-            **kwargs: Additional fields required by the specific session subclass.
+            message: The message whose prompt is used to generate RAG results.
 
         Returns:
-            The created session instance.
+            Yields the message with progress updates; results are stored in message.metadata["rag"].
+        """
+        try:
+            message.status = "processing"
+
+            entries: int = 0
+
+            if not self.file_watcher:
+                message.response = "No RAG context available."
+                del message.metadata["rag"]
+                message.status = "done"
+                yield message
+                return
+
+            message.metadata["rag"] = []
+            for rag in self.rags:
+                if not rag["enabled"]:
+                    continue
+                message.response = f"Checking RAG context {rag['name']}..."
+                yield message
+                chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10)
+                if chroma_results:
+                    entries += len(chroma_results["documents"])
+
+                    chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()  # Ensure correct shape
+                    print(f"Chroma embedding shape: {chroma_embedding.shape}")
+
+                    umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
+                    print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output
+
+                    umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
+                    print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output
+
+                    message.metadata["rag"].append({
+                        "name": rag["name"],
+                        **chroma_results,
+                        "umap_embedding_2d": umap_2d,
+                        "umap_embedding_3d": umap_3d
+                    })
+                    yield message
+
+            if entries == 0:
+                del message.metadata["rag"]
+
+            message.response = f"RAG context gathered from {entries} documents."
+            message.status = "done"
+            yield message
+            return
+        except Exception as e:
+            message.status = "error"
+            message.response = f"Error generating RAG results: {str(e)}"
+            logger.error(e)
+            yield message
+            return
+
+    def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
+        """
+        Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
+
+        Args:
+            agent_type: The type of agent to create (e.g., 'web', 'database').
+            **kwargs: Additional fields required by the specific agent subclass.
+
+        Returns:
+            The created agent instance.
 
         Raises:
-            ValueError: If no matching session type is found or if a session of this type already exists.
+            ValueError: If no matching agent type is found or if an agent of this type already exists.
         """
-        # Check if a session with the given session_type already exists
-        for session in self.sessions:
-            if session.session_type == session_type:
-                return session
+        # Check if an agent with the given agent_type already exists
+        for agent in self.agents:
+            if agent.agent_type == agent_type:
+                return agent
 
         # Find the matching subclass
-        for session_cls in Session.__subclasses__():
-            if session_cls.model_fields["session_type"].default == session_type:
-                # Create the session instance with provided kwargs
-                session = session_cls(session_type=session_type, **kwargs)
-                self.sessions.append(session)
-                return session
+        for agent_cls in Agent.__subclasses__():
+            if agent_cls.model_fields["agent_type"].default == agent_type:
+                # Create the agent instance with provided kwargs
+                agent = agent_cls(agent_type=agent_type, context=self, **kwargs)
+                self.agents.append(agent)
+                return agent
 
-        raise ValueError(f"No session class found for session_type: {session_type}")
+        raise ValueError(f"No agent class found for agent_type: {agent_type}")
 
-    def add_session(self, session: AnySession) -> None:
-        """Add a Session to the context, ensuring no duplicate session_type."""
-        if any(s.session_type == session.session_type for s in self.sessions):
-            raise ValueError(f"A session with session_type '{session.session_type}' already exists")
-        self.sessions.append(session)
+    def add_agent(self, agent: AnyAgent) -> None:
+        """Add an Agent to the context, ensuring no duplicate agent_type."""
+        if any(s.agent_type == agent.agent_type for s in self.agents):
+            raise ValueError(f"An agent with agent_type '{agent.agent_type}' already exists")
+        self.agents.append(agent)
 
-    def get_session(self, session_type: str) -> Session | None:
-        """Return the Session with the given session_type, or None if not found."""
-        for session in self.sessions:
-            if session.session_type == session_type:
-                return session
+    def get_agent(self, agent_type: str) -> Agent | None:
+        """Return the Agent with the given agent_type, or None if not found."""
+        for agent in self.agents:
+            if agent.agent_type == agent_type:
+                return agent
         return None
 
-    def is_valid_session_type(self, session_type: str) -> bool:
-        """Check if the given session_type is valid."""
-        return session_type in Session.valid_session_types()
+    def is_valid_agent_type(self, agent_type: str) -> bool:
+        """Check if the given agent_type is valid."""
+        return agent_type in Agent.valid_agent_types()
 
     def get_summary(self) -> str:
         """Return a summary of the context."""
-        if not self.sessions:
-            return f"Context {self.uuid}: No sessions."
+        if not self.agents:
+            return f"Context {self.uuid}: No agents."
         summary = f"Context {self.uuid}:\n"
-        for i, session in enumerate(self.sessions, 1):
-            summary += f"\nSession {i} ({session.session_type}):\n"
-            summary += session.conversation.get_summary()
-            if session.session_type == "resume":
-                summary += f"\nResume: {session.get_resume()}\n"
-            elif session.session_type == "job_description":
-                summary += f"\nJob Description: {session.job_description}\n"
-            elif session.session_type == "fact_check":
-                summary += f"\nFacts: {session.facts}\n"
-            elif session.session_type == "chat":
-                summary += f"\nChat Name: {session.name}\n"
+        for i, agent in enumerate(self.agents, 1):
+            summary += f"\nAgent {i} ({agent.agent_type}):\n"
+            summary += agent.conversation.get_summary()
+            if agent.agent_type == "resume":
+                summary += f"\nResume: {agent.get_resume()}\n"
+            elif agent.agent_type == "job_description":
+                summary += f"\nJob Description: {agent.job_description}\n"
+            elif agent.agent_type == "fact_check":
+                summary += f"\nFacts: {agent.facts}\n"
+            elif agent.agent_type == "chat":
+                summary += f"\nChat Name: {agent.name}\n"
         return summary
 
+
+from . agents import Agent
+Context.model_rebuild()
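A short sketch of how the reworked Context API fits together. This is illustrative only: import paths are assumed from the diff layout, the values are made up, and it assumes the `Agent` base class fields have defaults so the subclasses above can be constructed this way:

# Hypothetical usage of the agent-based Context; paths and values are illustrative.
from utils.context import Context
from utils.message import Message
import utils.agents  # ensure Agent subclasses are imported so __subclasses__() finds them

ctx = Context(rags=[{"name": "resume-docs", "enabled": True}])

# One agent per agent_type; a second call returns the existing instance.
resume_agent = ctx.get_or_create_agent("resume", resume="# My Resume\n...")
assert ctx.get_or_create_agent("resume") is resume_agent

# generate_rag_results() is a generator that yields status updates on the
# same Message object until message.status is "done" (or "error").
msg = Message(prompt="What did the candidate work on in 2023?")
for update in ctx.generate_rag_results(msg):
    print(update.status, update.response)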
@@ -9,9 +9,10 @@ embedding_model = os.getenv("EMBEDDING_MODEL_NAME", "mxbai-embed-large")
 persist_directory = os.getenv("PERSIST_DIR", "/opt/backstory/chromadb")
 max_context = 2048*8*2
 doc_dir = "/opt/backstory/docs/"
-session_dir = "/opt/backstory/sessions"
+context_dir = "/opt/backstory/sessions"
 static_content = "/opt/backstory/frontend/deployed"
 resume_doc = "/opt/backstory/docs/resume/generic.md"
 # Only used for testing; backstory-prod will not use this
 key_path = "/opt/backstory/keys/key.pem"
 cert_path = "/opt/backstory/keys/cert.pem"
+logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()
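The new `logging_level` setting follows the same `os.getenv` pattern as the other defaults, so it can be overridden per container without a code change. A small sketch of the behavior (import path assumed; the variable name is from the diff):

# LOGGING_LEVEL is read once at import time, so set it before the process starts,
# e.g.:  LOGGING_LEVEL=debug python src/server.py
import os
os.environ.setdefault("LOGGING_LEVEL", "debug")

from utils import defines  # import path assumed
print(defines.logging_level)  # "DEBUG" -- .upper() normalizes the value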
@@ -3,19 +3,29 @@ from typing import Dict, List, Optional, Any
 from datetime import datetime, timezone
 
 class Message(BaseModel):
-    prompt: str
-    preamble: str = ""
-    content: str = ""
-    response: str = ""
+    # Required
+    prompt: str  # Query to be answered
+
+    # Tunables
+    disable_rag: bool = False
+    disable_tools: bool = False
+
+    # Generated while processing message
+    status: str = ""  # Status of the message
+    preamble: dict[str, str] = {}  # Preamble to be prepended to the prompt
+    system_prompt: str = ""  # System prompt provided to the LLM
+    full_content: str = ""  # Full content of the message (preamble + prompt)
+    response: str = ""  # LLM response to the preamble + query
     metadata: dict[str, Any] = {
-        "rag": { "documents": [] },
+        "rag": [],  # list[dict[str, Any]] of RAG results, populated by generate_rag_results
        "tools": [],
         "eval_count": 0,
         "eval_duration": 0,
         "prompt_eval_count": 0,
         "prompt_eval_duration": 0,
+        "ctx_size": 0,
     }
-    actions: List[str] = []
+    actions: List[str] = []  # Other session modifying actions performed while processing the message
     timestamp: datetime = datetime.now(timezone.utc)
 
     def add_action(self, action: str | list[str]) -> None:
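A minimal sketch of the reworked Message lifecycle as the new fields suggest it. Field names come from the diff; the processing flow and values are illustrative, not the server's exact behavior:

# Illustrative lifecycle of a Message through processing.
from utils.message import Message  # import path assumed

msg = Message(prompt="Summarize my strengths.", disable_rag=False)

# During processing the server fills in the generated fields:
msg.status = "processing"
msg.preamble = {"context": "<retrieved RAG snippets>", "resume": "<resume text>"}
msg.full_content = "\n".join(msg.preamble.values()) + "\n" + msg.prompt
msg.response = "<LLM output>"
msg.metadata["ctx_size"] = 8192
msg.add_action("rag_searched")
msg.status = "done"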
@@ -1,3 +1,4 @@
+from pydantic import BaseModel, Field, model_validator, PrivateAttr
 import os
 import glob
 from pathlib import Path
@@ -51,8 +52,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         self.chunk_size = chunk_size
         self.chunk_overlap = chunk_overlap
         self.loop = loop
+        self._umap_collection = None
+        self._umap_embedding_2d = []
+        self._umap_embedding_3d = []
+        self._umap_model_2d = None
+        self._umap_model_3d = None
+        self._collection = None
         self.md = MarkItDown(enable_plugins=False)  # Set to True to enable plugins
 
         #self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
src/utils/session.py
@@ -1,78 +0,0 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args
-from .conversation import Conversation
-
-class Session(BaseModel):
-    session_type: Literal["resume", "job_description", "fact_check", "chat"]
-    system_prompt: str  # Mandatory
-    conversation: Conversation = Conversation()
-    context_tokens: int = 0
-
-    _content_seed: str = PrivateAttr(default="")
-
-    def get_and_reset_content_seed(self):
-        tmp = self._content_seed
-        self._content_seed = ""
-        return tmp
-
-    def set_content_seed(self, content: str) -> None:
-        """Set the content seed for the session."""
-        self._content_seed = content
-
-    def get_content_seed(self) -> str:
-        """Get the content seed for the session."""
-        return self._content_seed
-
-    @classmethod
-    def valid_session_types(cls) -> set[str]:
-        """Return the set of valid session_type values."""
-        return set(get_args(cls.__annotations__["session_type"]))
-
-
-# Type alias for Session or any subclass
-AnySession: TypeAlias = Session  # BaseModel covers Session and subclasses
-
-
-class Resume(Session):
-    session_type: Literal["resume"] = "resume"
-    resume: str = ""
-
-    @model_validator(mode="after")
-    def validate_resume(self):
-        if not self.resume.strip():
-            raise ValueError("Resume content cannot be empty")
-        return self
-
-    def get_resume(self) -> str:
-        """Get the resume content."""
-        return self.resume
-
-    def set_resume(self, resume: str) -> None:
-        """Set the resume content."""
-        self.resume = resume
-
-
-class JobDescription(Session):
-    session_type: Literal["job_description"] = "job_description"
-    job_description: str = ""
-
-    @model_validator(mode="after")
-    def validate_job_description(self):
-        if not self.job_description.strip():
-            raise ValueError("Job description cannot be empty")
-        return self
-
-
-class FactCheck(Session):
-    session_type: Literal["fact_check"] = "fact_check"
-    facts: str = ""
-
-    @model_validator(mode="after")
-    def validate_facts(self):
-        if not self.facts.strip():
-            raise ValueError("Facts cannot be empty")
-        return self
-
-
-class Chat(Session):
-    session_type: Literal["chat"] = "chat"
-
-    @model_validator(mode="after")
-    def validate_name(self):
-        return self
32 src/utils/setup_logging.py Normal file
@@ -0,0 +1,32 @@
+import os
+import warnings
+import logging
+
+from . import defines
+
+def setup_logging(level=defines.logging_level) -> logging.Logger:
+    os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
+    warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
+    warnings.filterwarnings("ignore", message="Warning only once for all operators")
+    warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")
+    warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
+    warnings.filterwarnings("ignore", message="n_jobs value 1 overridden")
+
+    numeric_level = getattr(logging, level.upper(), None)
+    if not isinstance(numeric_level, int):
+        raise ValueError(f"Invalid log level: {level}")
+
+    logging.basicConfig(
+        level=numeric_level,
+        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        force=True
+    )
+
+    # Now reduce verbosity for FastAPI, Uvicorn, Starlette
+    for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
+        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
+
+    logger = logging.getLogger(__name__)
+
+    return logger
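One detail worth noting in the helper above: the returned logger is named after the `setup_logging` module itself, so callers that want their own module name in log lines should still call `logging.getLogger(__name__)` afterwards. A minimal usage sketch, assuming the module lives at `utils/setup_logging.py` as the diff shows:

# Minimal usage sketch of setup_logging (import path assumed from the diff).
import logging
from utils.setup_logging import setup_logging

setup_logging("DEBUG")                 # configures the root logger via basicConfig(force=True)
logger = logging.getLogger(__name__)   # per-module logger inherits the root config
logger.debug("server starting")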