Reformatted all content to black

parent a1798b58ac
commit e044f9c639
@@ -88,6 +88,10 @@ button {
   flex-grow: 1;
 }
 
+.MessageContent div > p:first-child {
+  margin-top: 0;
+}
+
 .MenuCard.MuiCard-root {
   display: flex;
   flex-direction: column;
@@ -30,7 +30,6 @@ const BackstoryTextField = React.forwardRef<BackstoryTextFieldRef, BackstoryText
   const shadowRef = useRef<HTMLTextAreaElement>(null);
   const [editValue, setEditValue] = useState<string>(value);
 
-  console.log({ value, placeholder, editValue });
   // Sync editValue with prop value if it changes externally
   useEffect(() => {
     setEditValue(value || "");
@@ -158,7 +158,7 @@ function ChatBubble(props: ChatBubbleProps) {
   };
 
   // Render Accordion for expandable content
-  if (expandable || (role === 'content' && title)) {
+  if (expandable || title) {
     // Determine if Accordion is controlled
     const isControlled = typeof expanded === 'boolean' && typeof onExpand === 'function';
 
@@ -1,13 +1,15 @@
 import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback } from 'react';
 import Typography from '@mui/material/Typography';
 import Tooltip from '@mui/material/Tooltip';
+import IconButton from '@mui/material/IconButton';
 import Button from '@mui/material/Button';
 import Box from '@mui/material/Box';
 import SendIcon from '@mui/icons-material/Send';
+import CancelIcon from '@mui/icons-material/Cancel';
 import { SxProps, Theme } from '@mui/material';
 import PropagateLoader from "react-spinners/PropagateLoader";
 
-import { Message, MessageList, MessageData } from './Message';
+import { Message, MessageList, BackstoryMessage } from './Message';
 import { ContextStatus } from './ContextStatus';
 import { Scrollable } from './Scrollable';
 import { DeleteConfirmation } from './DeleteConfirmation';
@@ -17,7 +19,7 @@ import { BackstoryTextField, BackstoryTextFieldRef } from './BackstoryTextField'
 import { BackstoryElementProps } from './BackstoryTab';
 import { connectionBase } from './Global';
 
-const loadingMessage: MessageData = { "role": "status", "content": "Establishing connection with server..." };
+const loadingMessage: BackstoryMessage = { "role": "status", "content": "Establishing connection with server..." };
 
 type ConversationMode = 'chat' | 'job_description' | 'resume' | 'fact_check';
 
@@ -25,24 +27,6 @@ interface ConversationHandle {
   submitQuery: (prompt: string, options?: QueryOptions) => void;
   fetchHistory: () => void;
 }
-interface BackstoryMessage {
-  prompt: string;
-  preamble: {};
-  status: string;
-  full_content: string;
-  response: string; // Set when status === 'done' or 'error'
-  chunk: string; // Used when status === 'streaming'
-  metadata: {
-    rag: { documents: [] };
-    tools: string[];
-    eval_count: number;
-    eval_duration: number;
-    prompt_eval_count: number;
-    prompt_eval_duration: number;
-  };
-  actions: string[];
-  timestamp: string;
-};
 
 interface ConversationProps extends BackstoryElementProps {
   className?: string, // Override default className
@@ -59,7 +43,7 @@ interface ConversationProps extends BackstoryElementProps {
   messageFilter?: ((messages: MessageList) => MessageList) | undefined, // Filter callback to determine which Messages to display in Conversation
   messages?: MessageList, //
   sx?: SxProps<Theme>,
-  onResponse?: ((message: MessageData) => void) | undefined, // Event called when a query completes (provides messages)
+  onResponse?: ((message: BackstoryMessage) => void) | undefined, // Event called when a query completes (provides messages)
 };
 
 const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
@@ -87,8 +71,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
   const [countdown, setCountdown] = useState<number>(0);
   const [conversation, setConversation] = useState<MessageList>([]);
   const [filteredConversation, setFilteredConversation] = useState<MessageList>([]);
-  const [processingMessage, setProcessingMessage] = useState<MessageData | undefined>(undefined);
-  const [streamingMessage, setStreamingMessage] = useState<MessageData | undefined>(undefined);
+  const [processingMessage, setProcessingMessage] = useState<BackstoryMessage | undefined>(undefined);
+  const [streamingMessage, setStreamingMessage] = useState<BackstoryMessage | undefined>(undefined);
   const timerRef = useRef<any>(null);
   const [contextStatus, setContextStatus] = useState<ContextStatus>({ context_used: 0, max_context: 0 });
   const [contextWarningShown, setContextWarningShown] = useState<boolean>(false);
@@ -96,6 +80,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
   const conversationRef = useRef<MessageList>([]);
   const viewableElementRef = useRef<HTMLDivElement>(null);
   const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
+  const stopRef = useRef(false);
 
   // Keep the ref updated whenever items changes
   useEffect(() => {
@@ -181,14 +166,25 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
 
       const backstoryMessages: BackstoryMessage[] = messages;
 
-      setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => [{
-        role: 'user',
-        content: backstoryMessage.prompt || "",
-      }, {
-        ...backstoryMessage,
-        role: backstoryMessage.status === "done" ? "assistant" : backstoryMessage.status,
+      setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => {
+        if (backstoryMessage.status === "partial") {
+          return [{
+            ...backstoryMessage,
+            role: "assistant",
+            content: backstoryMessage.response || "",
+            expanded: false,
+            expandable: true,
+          }]
+        }
+        return [{
+          role: 'user',
+          content: backstoryMessage.prompt || "",
+        }, {
+          ...backstoryMessage,
+          role: ['done'].includes(backstoryMessage.status || "") ? "assistant" : backstoryMessage.status,
           content: backstoryMessage.response || "",
-      }] as MessageList));
+        }] as MessageList;
+      }));
       setNoInteractions(false);
     }
     setProcessingMessage(undefined);
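The history expansion above is the crux of this hunk: each persisted exchange becomes a user/assistant pair, except `partial` results, which lack a completed pair and collapse into a single expandable assistant bubble. A standalone sketch of that mapping, assuming the `BackstoryMessage` shape defined in Message.tsx (the committed code keeps the raw status as the role for non-`done` messages; the sketch falls back to the original role to keep the types honest):

```tsx
import { BackstoryMessage, MessageList } from './Message';

// Sketch: expand persisted history into renderable bubbles.
// "partial" entries map to one expandable assistant message;
// everything else becomes a user prompt plus an assistant reply.
const toMessageList = (history: BackstoryMessage[]): MessageList =>
  history.flatMap((m) => {
    if (m.status === "partial") {
      return [{
        ...m,
        role: "assistant",
        content: m.response || "",
        expanded: false,
        expandable: true,
      }] as MessageList;
    }
    return [
      { role: 'user', content: m.prompt || "" },
      {
        ...m,
        role: ['done'].includes(m.status || "") ? "assistant" : m.role,
        content: m.response || "",
      },
    ] as MessageList;
  });
```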
@@ -294,6 +290,11 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
     }
   };
 
+  const cancelQuery = () => {
+    console.log("Stop query");
+    stopRef.current = true;
+  };
+
   const sendQuery = async (request: string, options?: QueryOptions) => {
     request = request.trim();
 
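The new `cancelQuery` writes to a ref rather than state so the streaming loop can observe the flag without triggering a re-render. A minimal sketch of the pattern in isolation (a hypothetical hook, not part of this commit):

```tsx
import { useRef } from 'react';

// Sketch: a ref-based cancel flag. Unlike useState, writing to
// ref.current does not schedule a re-render, so an async loop can
// poll it on every chunk without tearing the component.
function useCancelFlag() {
  const stopRef = useRef(false);
  const cancel = () => { stopRef.current = true; };
  const reset = () => { stopRef.current = false; };
  return { stopRef, cancel, reset };
}
```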
@@ -308,6 +309,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
       return;
     }
 
+    stopRef.current = false;
+
     setNoInteractions(false);
 
     setConversation([
@@ -325,12 +328,10 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
 
     try {
       setProcessing(true);
-      // Create a unique ID for the processing message
-      const processingId = Date.now().toString();
 
       // Add initial processing message
       setProcessingMessage(
-        { role: 'status', content: 'Submitting request...', id: processingId, isProcessing: true }
+        { role: 'status', content: 'Submitting request...', disableCopy: true }
       );
 
       // Add a small delay to ensure React has time to update the UI
@@ -379,17 +380,20 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
 
             switch (update.status) {
               case 'done':
-                console.log('Done processing:', update);
-                stopCountdown();
-                setStreamingMessage(undefined);
-                setProcessingMessage(undefined);
+              case 'partial':
+                if (update.status === 'done') stopCountdown();
+                if (update.status === 'done') setStreamingMessage(undefined);
+                if (update.status === 'done') setProcessingMessage(undefined);
                 const backstoryMessage: BackstoryMessage = update;
                 setConversation([
                   ...conversationRef.current, {
                     ...backstoryMessage,
                     role: 'assistant',
                     origin: type,
+                    prompt: ['done', 'partial'].includes(update.status) ? update.prompt : '',
                     content: backstoryMessage.response || "",
+                    expanded: update.status === "done" ? true : false,
+                    expandable: true,
                 }] as MessageList);
                 // Add a small delay to ensure React has time to update the UI
                 await new Promise(resolve => setTimeout(resolve, 0));
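Note the deliberate fall-through above: `done` and `partial` now share the conversation-append logic, with the teardown of the transient bubbles guarded to `done` only. A reduced sketch of that control flow, with hypothetical helper names standing in for the setters used in Conversation.tsx:

```tsx
// Sketch: 'done' and 'partial' share one handler; only 'done'
// tears down the transient status/streaming bubbles.
function handleUpdate(
  status: string,
  teardown: () => void,   // hypothetical: clears streaming/processing state
  append: () => void,     // hypothetical: appends the message to conversation
) {
  switch (status) {
    case 'done':
    case 'partial':
      if (status === 'done') teardown();
      append();
      break;
    default:
      break;
  }
}
```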
@@ -424,9 +428,9 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
             // Update processing message with immediate re-render
             if (update.status === "streaming") {
               streaming_response += update.chunk
-              setStreamingMessage({ role: update.status, content: streaming_response });
+              setStreamingMessage({ role: update.status, content: streaming_response, disableCopy: true });
             } else {
-              setProcessingMessage({ role: update.status, content: update.response });
+              setProcessingMessage({ role: update.status, content: update.response, disableCopy: true });
               /* Reset stream on non streaming message */
               streaming_response = ""
             }
@@ -437,12 +441,11 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
             }
           }
 
-          while (true) {
+          while (!stopRef.current) {
             const { done, value } = await reader.read();
             if (done) {
               break;
             }
-
             const chunk = decoder.decode(value, { stream: true });
 
             // Process each complete line immediately
@@ -470,26 +473,32 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
             }
           }
 
+          if (stopRef.current) {
+            await reader.cancel();
+            setProcessingMessage(undefined);
+            setStreamingMessage(undefined);
+            setSnack("Processing cancelled", "warning");
+          }
         stopCountdown();
         setProcessing(false);
+        stopRef.current = false;
       } catch (error) {
         console.error('Fetch error:', error);
         setSnack("Unable to process query", "error");
-        setProcessingMessage({ role: 'error', content: "Unable to process query" });
+        setProcessingMessage({ role: 'error', content: "Unable to process query", disableCopy: true });
         setTimeout(() => {
           setProcessingMessage(undefined);
         }, 5000);
+        stopRef.current = false;
         setProcessing(false);
         stopCountdown();
-        // Add a small delay to ensure React has time to update the UI
-        await new Promise(resolve => setTimeout(resolve, 0));
+        return;
       }
     };
 
     return (
       <Scrollable
-        className={className || "Conversation"}
+        className={`${className || ""} Conversation`}
         autoscroll
         textFieldRef={viewableElementRef}
         fallbackThreshold={0.5}
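The reworked read loop exits as soon as the flag is set and cancels the reader so the server-side stream is torn down promptly. The same shape in a standalone sketch (names are hypothetical; the real loop also parses each line as a JSON update):

```tsx
// Sketch: consume a fetch() body stream line-by-line, honoring a
// cancel flag checked between chunks. reader.cancel() tells the
// browser and server that no more data will be read.
async function readStream(
  response: Response,
  stopRef: { current: boolean },
  onLine: (line: string) => void,
) {
  const reader = response.body!.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  while (!stopRef.current) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split("\n");
    buffer = lines.pop() || "";  // keep the trailing partial line
    lines.forEach(onLine);
  }
  if (stopRef.current) await reader.cancel();
}
```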
@@ -564,6 +573,20 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
               </Button>
             </span>
           </Tooltip>
+          <Tooltip title="Cancel">
+            <span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
+              <IconButton
+                aria-label="cancel"
+                onClick={() => { cancelQuery(); }}
+                sx={{ display: "flex", margin: 'auto 0px' }}
+                size="large"
+                edge="start"
+                disabled={stopRef.current || sessionId === undefined || processing === false}
+              >
+                <CancelIcon />
+              </IconButton>
+            </span>
+          </Tooltip>
         </Box>
       </Box>
       {(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length &&
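The `<span>` wrapper added above is the standard MUI workaround: a disabled button emits no pointer events, so the Tooltip must listen on an enabled wrapper element instead. Reduced to its essentials:

```tsx
import Tooltip from '@mui/material/Tooltip';
import IconButton from '@mui/material/IconButton';
import CancelIcon from '@mui/icons-material/Cancel';

// Sketch: disabled elements swallow pointer events, so the Tooltip
// listens on the wrapping <span> rather than the button itself.
const CancelButton = ({ disabled, onCancel }: { disabled: boolean; onCancel: () => void }) => (
  <Tooltip title="Cancel">
    <span style={{ display: "flex" }}>
      <IconButton aria-label="cancel" disabled={disabled} onClick={onCancel}>
        <CancelIcon />
      </IconButton>
    </span>
  </Tooltip>
);
```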
|
@ -47,11 +47,18 @@ type MessageRoles =
|
|||||||
'thinking' |
|
'thinking' |
|
||||||
'user';
|
'user';
|
||||||
|
|
||||||
type MessageData = {
|
type BackstoryMessage = {
|
||||||
|
// Only two required fields
|
||||||
role: MessageRoles,
|
role: MessageRoles,
|
||||||
content: string,
|
content: string,
|
||||||
status?: string, // streaming, done, error...
|
// Rest are optional
|
||||||
response?: string,
|
prompt?: string;
|
||||||
|
preamble?: {};
|
||||||
|
status?: string;
|
||||||
|
full_content?: string;
|
||||||
|
response?: string; // Set when status === 'done', 'partial', or 'error'
|
||||||
|
chunk?: string; // Used when status === 'streaming'
|
||||||
|
timestamp?: string;
|
||||||
disableCopy?: boolean,
|
disableCopy?: boolean,
|
||||||
user?: string,
|
user?: string,
|
||||||
title?: string,
|
title?: string,
|
||||||
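With `MessageData` folded into `BackstoryMessage`, one shape now flows from the wire format through to the renderer, and only `role` and `content` are required. A small usage sketch under that assumption:

```tsx
import { BackstoryMessage } from './Message';

// Sketch: only role and content are required; the streaming and
// back-end fields (status, response, chunk, ...) stay optional, so
// UI-only messages remain cheap to construct.
const statusMsg: BackstoryMessage = {
  role: 'status',
  content: 'Connecting...',
  disableCopy: true,
};

const doneMsg: BackstoryMessage = {
  role: 'assistant',
  content: 'Here is your resume.',
  status: 'done',
  response: 'Here is your resume.',
  timestamp: new Date().toISOString(),
};
```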
@@ -84,11 +91,11 @@ interface MessageMetaData {
   setSnack: SetSnackType,
 }
 
-type MessageList = MessageData[];
+type MessageList = BackstoryMessage[];
 
 interface MessageProps extends BackstoryElementProps {
   sx?: SxProps<Theme>,
-  message: MessageData,
+  message: BackstoryMessage,
   expanded?: boolean,
   onExpand?: (open: boolean) => void,
   className?: string,
@@ -237,7 +244,7 @@ const MessageMeta = (props: MessageMetaProps) => {
 };
 
 const Message = (props: MessageProps) => {
-  const { message, submitQuery, sx, className, onExpand, expanded, sessionId, setSnack } = props;
+  const { message, submitQuery, sx, className, onExpand, sessionId, setSnack } = props;
   const [metaExpanded, setMetaExpanded] = useState<boolean>(false);
   const textFieldRef = useRef(null);
 
@@ -254,14 +261,16 @@ const Message = (props: MessageProps) => {
     return (<></>);
   }
 
-  const formattedContent = message.content.trim() || "Waiting for LLM to spool up...";
+  const formattedContent = message.content.trim();
+  if (formattedContent === "") {
+    return (<></>);
+  }
 
   return (
     <ChatBubble
-      className={className || "Message"}
+      className={`${className || ""} Message Message-${message.role}`}
       {...message}
       onExpand={onExpand}
-      expanded={expanded}
       sx={{
         display: "flex",
         flexDirection: "column",
@@ -273,34 +282,24 @@ const Message = (props: MessageProps) => {
         ...sx,
       }}>
       <CardContent ref={textFieldRef} sx={{ position: "relative", display: "flex", flexDirection: "column", overflowX: "auto", m: 0, p: 0, paddingBottom: '0px !important' }}>
-        {message.role !== 'user' ?
-          <Scrollable
-            className="MessageContent"
-            autoscroll
-            fallbackThreshold={0.5}
-            sx={{
-              p: 0,
-              m: 0,
-              maxHeight: (message.role === "streaming") ? "20rem" : "unset",
-              display: "flex",
-              flexGrow: 1,
-              overflow: "auto", /* Handles scrolling for the div */
-            }}
-          >
-            <StyledMarkdown streaming={message.role === "streaming"} {...{ content: formattedContent, submitQuery, sessionId, setSnack }} />
-          </Scrollable>
-          :
-          <Typography
-            className="MessageContent"
-            ref={textFieldRef}
-            variant="body2"
-            sx={{ display: "flex", color: 'text.secondary' }}>
-            {message.content}
-          </Typography>
-        }
+        <Scrollable
+          className="MessageContent"
+          autoscroll
+          fallbackThreshold={0.5}
+          sx={{
+            p: 0,
+            m: 0,
+            maxHeight: (message.role === "streaming") ? "20rem" : "unset",
+            display: "flex",
+            flexGrow: 1,
+            overflow: "auto", /* Handles scrolling for the div */
+          }}
+        >
+          <StyledMarkdown streaming={message.role === "streaming"} {...{ content: formattedContent, submitQuery, sessionId, setSnack }} />
+        </Scrollable>
       </CardContent>
       <CardActions disableSpacing sx={{ display: "flex", flexDirection: "row", justifyContent: "space-between", alignItems: "center", width: "100%", p: 0, m: 0 }}>
-        {(message.disableCopy === undefined || message.disableCopy === false) && ["assistant", "content"].includes(message.role) && <CopyBubble content={message.content} />}
+        {(message.disableCopy === undefined || message.disableCopy === false) && <CopyBubble content={message.content} />}
         {message.metadata && (
           <Box sx={{ display: "flex", alignItems: "center", gap: 1 }}>
             <Button variant="text" onClick={handleMetaExpandClick} sx={{ color: "darkgrey", p: 0 }}>
@@ -309,7 +308,7 @@ const Message = (props: MessageProps) => {
             <ExpandMore
               expand={metaExpanded}
               onClick={handleMetaExpandClick}
-              aria-expanded={expanded}
+              aria-expanded={message.expanded}
               aria-label="show more"
             >
               <ExpandMoreIcon />
@@ -331,7 +330,8 @@ const Message = (props: MessageProps) => {
 export type {
   MessageProps,
   MessageList,
-  MessageData,
+  BackstoryMessage,
+  MessageMetaData,
   MessageRoles,
 };
 
@@ -7,7 +7,7 @@ import {
 import { SxProps } from '@mui/material';
 
 import { ChatQuery } from './ChatQuery';
-import { MessageList, MessageData } from './Message';
+import { MessageList, BackstoryMessage } from './Message';
 import { Conversation } from './Conversation';
 import { BackstoryPageProps } from './BackstoryTab';
 
@@ -62,11 +62,6 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
       return [];
     }
 
-    if (messages.length > 2) {
-      setHasResume(true);
-      setHasFacts(true);
-    }
-
     if (messages.length > 0) {
       messages[0].role = 'content';
       messages[0].title = 'Job Description';
@@ -74,6 +69,19 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
       messages[0].expandable = true;
     }
 
+    if (-1 !== messages.findIndex(m => m.status === 'done')) {
+      setHasResume(true);
+      setHasFacts(true);
+    }
+
+    return messages;
+
+    if (messages.length > 1) {
+      setHasResume(true);
+      setHasFacts(true);
+    }
+
+
     if (messages.length > 3) {
       // messages[2] is Show job requirements
       messages[3].role = 'job-requirements';
@@ -95,6 +103,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
       return [];
     }
 
+    return messages;
+
     if (messages.length > 1) {
       // messages[0] is Show Qualifications
       messages[1].role = 'qualifications';
@@ -139,7 +149,7 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
     return filtered;
   }, []);
 
-  const jobResponse = useCallback(async (message: MessageData) => {
+  const jobResponse = useCallback(async (message: BackstoryMessage) => {
     console.log('onJobResponse', message);
     if (message.actions && message.actions.includes("job_description")) {
       await jobConversationRef.current.fetchHistory();
@@ -155,12 +165,12 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
     }
   }, [setHasFacts, setHasResume, setActiveTab]);
 
-  const resumeResponse = useCallback((message: MessageData): void => {
+  const resumeResponse = useCallback((message: BackstoryMessage): void => {
     console.log('onResumeResponse', message);
     setHasFacts(true);
   }, [setHasFacts]);
 
-  const factsResponse = useCallback((message: MessageData): void => {
+  const factsResponse = useCallback((message: BackstoryMessage): void => {
     console.log('onFactsResponse', message);
   }, []);
 
@@ -199,7 +209,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
 3. **Mapping Analysis**: Identifies legitimate matches between requirements and qualifications
 3. **Resume Generation**: Uses mapping output to create a tailored resume with evidence-based content
 4. **Verification**: Performs fact-checking to catch any remaining fabrications
-1. **Re-generation**: If verification does not pass, a second attempt is made to correct any issues`
+1. **Re-generation**: If verification does not pass, a second attempt is made to correct any issues`,
+    disableCopy: true
   }];
 
 
@@ -17,3 +17,12 @@ pre:not(.MessageContent) {
 .MuiMarkdown > div {
   width: 100%;
 }
+
+.Message-streaming .MuiMarkdown ul,
+.Message-streaming .MuiMarkdown h1,
+.Message-streaming .MuiMarkdown p,
+.Message-assistant .MuiMarkdown ul,
+.Message-assistant .MuiMarkdown h1,
+.Message-assistant .MuiMarkdown p {
+  color: white;
+}
@@ -37,14 +37,14 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
     }
     if (className === "lang-json" && !streaming) {
       try {
-        const fixed = jsonrepair(content);
+        let fixed = JSON.parse(jsonrepair(content));
         return <Scrollable className="JsonViewScrollable">
           <JsonView
             className="JsonView"
             style={{
               ...vscodeTheme,
               fontSize: "0.8rem",
-              maxHeight: "20rem",
+              maxHeight: "10rem",
               padding: "14px 0",
               overflow: "hidden",
               width: "100%",
@@ -53,9 +53,9 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
             }}
             displayDataTypes={false}
             objectSortKeys={false}
-            collapsed={false}
+            collapsed={true}
             shortenTextAfterLength={100}
-            value={JSON.parse(fixed)}>
+            value={fixed}>
             <JsonView.String
               render={({ children, ...reset }) => {
                 if (typeof (children) === "string" && children.match("\n")) {
@@ -66,7 +66,7 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
           </JsonView>
         </Scrollable>
       } catch (e) {
-        console.log("jsonrepair error", e);
+        return <pre><code className="JsonRaw">{content}</code></pre>
       };
     }
     return <pre><code className={className}>{element.children}</code></pre>;
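The `lang-json` branch now repairs and parses the payload once, and the catch path renders the raw text instead of only logging. Since `jsonrepair` throws on unrecoverable input, the try/catch is load-bearing; a minimal sketch of the same strategy:

```tsx
import { jsonrepair } from 'jsonrepair';

// Sketch: repair near-JSON (trailing commas, single quotes, ...)
// and parse once; return null when repair fails so the caller can
// render the raw <pre><code> fallback instead.
function tryParseJson(content: string): unknown | null {
  try {
    return JSON.parse(jsonrepair(content));
  } catch (e) {
    console.log("jsonrepair error", e);
    return null;
  }
}
```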
src/server.py (499 changed lines)
@@ -1,8 +1,9 @@
-LLM_TIMEOUT=600
+LLM_TIMEOUT = 600
 
 from utils import logger
+from pydantic import BaseModel, Field  # type: ignore
 
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict, Optional
 
 # %%
 # Imports [standard]
@@ -26,6 +27,7 @@ from uuid import uuid4
 import time
 import traceback
 
+
 def try_import(module_name, pip_name=None):
     try:
         __import__(module_name)
@@ -33,6 +35,7 @@ def try_import(module_name, pip_name=None):
         print(f"Module '{module_name}' not found. Install it using:")
         print(f"  pip install {pip_name or module_name}")
 
+
 # Third-party modules with import checks
 try_import("ollama")
 try_import("requests")
@@ -47,23 +50,25 @@ try_import("prometheus_fastapi_instrumentator")
 import ollama
 import requests
 from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, BackgroundTasks  # type: ignore
 from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse  # type: ignore
 from fastapi.middleware.cors import CORSMiddleware  # type: ignore
 import uvicorn  # type: ignore
 import numpy as np  # type: ignore
 import umap  # type: ignore
 from sklearn.preprocessing import MinMaxScaler  # type: ignore
 
 # Prometheus
 from prometheus_client import Summary  # type: ignore
 from prometheus_fastapi_instrumentator import Instrumentator  # type: ignore
 from prometheus_client import CollectorRegistry, Counter  # type: ignore
 
 from utils import (
     rag as Rag,
     tools as Tools,
-    Context, Conversation, Message,
+    Context,
+    Conversation,
+    Message,
     Agent,
     Metrics,
     Tunables,
@@ -71,14 +76,25 @@ from utils import (
     logger,
 )
 
-CONTEXT_VERSION=2
+CONTEXT_VERSION = 2
 
 rags = [
-    { "name": "JPK", "enabled": True, "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects." },
+    {
+        "name": "JPK",
+        "enabled": True,
+        "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
+    },
     # { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
 ]
 
-REQUEST_TIME = Summary('request_processing_seconds', 'Time spent processing request')
+class QueryOptions(BaseModel):
+    prompt: str
+    tunables: Tunables = Field(default_factory=Tunables)
+    agent_options: Dict[str, Any] = Field(default={})
+
+
+REQUEST_TIME = Summary("request_processing_seconds", "Time spent processing request")
 
 system_message_old = f"""
 Launched on {datetime.now().isoformat()}.
@@ -107,6 +123,7 @@ You are provided with a <|resume|> which was generated by you, the <|context|> y
 Your task is to answer questions about the <|fact_check|> you generated based on the <|resume|> and <|context>.
 """
 
+
 def get_installed_ram():
     try:
         with open("/proc/meminfo", "r") as f:
@@ -117,21 +134,29 @@ def get_installed_ram():
     except Exception as e:
         return f"Error retrieving RAM: {e}"
 
+
 def get_graphics_cards():
     gpus = []
     try:
         # Run the ze-monitor utility
-        result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
+        result = subprocess.run(
+            ["ze-monitor"], capture_output=True, text=True, check=True
+        )
+
         # Clean up the output (remove leading/trailing whitespace and newlines)
         output = result.stdout.strip()
         for index in range(len(output.splitlines())):
-            result = subprocess.run(["ze-monitor", "--device", f"{index+1}", "--info"], capture_output=True, text=True, check=True)
+            result = subprocess.run(
+                ["ze-monitor", "--device", f"{index+1}", "--info"],
+                capture_output=True,
+                text=True,
+                check=True,
+            )
             gpu_info = result.stdout.strip().splitlines()
             gpu = {
                 "discrete": True,  # Assume it's discrete initially
                 "name": None,
-                "memory": None
+                "memory": None,
             }
             gpus.append(gpu)
             for line in gpu_info:
@@ -154,6 +179,7 @@ def get_graphics_cards():
     except Exception as e:
         return f"Error retrieving GPU info: {e}"
 
+
 def get_cpu_info():
     try:
         with open("/proc/cpuinfo", "r") as f:
@@ -165,6 +191,7 @@ def get_cpu_info():
     except Exception as e:
         return f"Error retrieving CPU info: {e}"
 
+
 def system_info(model):
     return {
         "System RAM": get_installed_ram(),
@@ -172,18 +199,19 @@ def system_info(model):
         "CPU": get_cpu_info(),
         "LLM Model": model,
         "Embedding Model": defines.embedding_model,
-        "Context length": defines.max_context
+        "Context length": defines.max_context,
     }
 
+
 # %%
 # Defaults
 OLLAMA_API_URL = defines.ollama_api_url
 MODEL_NAME = defines.model
-LOG_LEVEL="info"
-USE_TLS=False
-WEB_HOST="0.0.0.0"
-WEB_PORT=8911
-DEFAULT_HISTORY_LENGTH=5
+LOG_LEVEL = "info"
+USE_TLS = False
+WEB_HOST = "0.0.0.0"
+WEB_PORT = 8911
+DEFAULT_HISTORY_LENGTH = 5
 
 # %%
 # Globals
@@ -192,31 +220,58 @@ DEFAULT_HISTORY_LENGTH=5
 def create_system_message(prompt):
     return [{"role": "system", "content": prompt}]
 
+
 tool_log = []
 command_log = []
 model = None
 client = None
 web_server = None
 
+
 # %%
 # Cmd line overrides
 def parse_args():
     parser = argparse.ArgumentParser(description="AI is Really Cool")
-    parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
-    parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
-    parser.add_argument("--web-host", type=str, default=WEB_HOST, help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.")
-    parser.add_argument("--web-port", type=str, default=WEB_PORT, help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.")
-    parser.add_argument("--level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
-                        default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
+    parser.add_argument(
+        "--ollama-server",
+        type=str,
+        default=OLLAMA_API_URL,
+        help=f"Ollama API endpoint. default={OLLAMA_API_URL}",
+    )
+    parser.add_argument(
+        "--ollama-model",
+        type=str,
+        default=MODEL_NAME,
+        help=f"LLM model to use. default={MODEL_NAME}",
+    )
+    parser.add_argument(
+        "--web-host",
+        type=str,
+        default=WEB_HOST,
+        help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.",
+    )
+    parser.add_argument(
+        "--web-port",
+        type=str,
+        default=WEB_PORT,
+        help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.",
+    )
+    parser.add_argument(
+        "--level",
+        type=str,
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+        default=LOG_LEVEL,
+        help=f"Set the logging level. default={LOG_LEVEL}",
+    )
     return parser.parse_args()
 
+
+
 # %%
 
 
 # %%
 
 
 # %%
 def is_valid_uuid(value: str) -> bool:
     try:
@@ -226,11 +281,6 @@ def is_valid_uuid(value: str) -> bool:
         return False
 
-
-
-
-
-
 
 # %%
 class WebServer:
     @asynccontextmanager
@@ -239,9 +289,11 @@ class WebServer:
         self.observer, self.file_watcher = Rag.start_file_watcher(
             llm=self.llm,
             watch_directory=defines.doc_dir,
-            recreate=False # Don't recreate if exists
+            recreate=False,  # Don't recreate if exists
+        )
+        logger.info(
+            f"API started with {self.file_watcher.collection.count()} documents in the collection"
         )
-        logger.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
         yield
         if self.observer:
             self.observer.stop()
@@ -271,12 +323,14 @@ class WebServer:
         self.file_watcher = None
         self.observer = None
 
-        self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(defines.cert_path)
+        self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(
+            defines.cert_path
+        )
 
         if self.ssl_enabled:
-            allow_origins=["https://battle-linux.ketrenos.com:3000"]
+            allow_origins = ["https://battle-linux.ketrenos.com:3000"]
         else:
-            allow_origins=["http://battle-linux.ketrenos.com:3000"]
+            allow_origins = ["http://battle-linux.ketrenos.com:3000"]
 
         logger.info(f"Allowed origins: {allow_origins}")
 
@@ -296,7 +350,7 @@ class WebServer:
             context = self.create_context()
             logger.info(f"Redirecting non-context to {context.id}")
             return RedirectResponse(url=f"/{context.id}", status_code=307)
-            #return JSONResponse({"redirect": f"/{context.id}"})
+            # return JSONResponse({"redirect": f"/{context.id}"})
 
         @self.app.put("/api/umap/{context_id}")
         async def put_umap(context_id: str, request: Request):
@@ -307,14 +361,18 @@ class WebServer:
 
             context = self.upsert_context(context_id)
             if not context:
-                return JSONResponse({"error": f"Invalid context: {context_id}"}, status_code=400)
+                return JSONResponse(
+                    {"error": f"Invalid context: {context_id}"}, status_code=400
+                )
+
             data = await request.json()
 
             dimensions = data.get("dimensions", 2)
             result = self.file_watcher.umap_collection
             if not result:
-                return JSONResponse({"error": "No UMAP collection found"}, status_code=404)
+                return JSONResponse(
+                    {"error": "No UMAP collection found"}, status_code=404
+                )
             if dimensions == 2:
                 logger.info("Returning 2D UMAP")
                 umap_embedding = self.file_watcher.umap_embedding_2d
@@ -323,7 +381,9 @@ class WebServer:
                 umap_embedding = self.file_watcher.umap_embedding_3d
 
             if len(umap_embedding) == 0:
-                return JSONResponse({"error": "No UMAP embedding found"}, status_code=404)
+                return JSONResponse(
+                    {"error": "No UMAP embedding found"}, status_code=404
+                )
 
             result["embeddings"] = umap_embedding.tolist()
 
@@ -347,35 +407,56 @@ class WebServer:
             try:
                 data = await request.json()
                 query = data.get("query", "")
+                threshold = data.get("threshold", 0.5)
+                results = data.get("results", 10)
             except:
                 query = ""
+                threshold = 0.5
+                results = 10
             if not query:
-                return JSONResponse({"error": "No query provided for similarity search"}, status_code=400)
+                return JSONResponse(
+                    {"error": "No query provided for similarity search"},
+                    status_code=400,
+                )
             try:
-                chroma_results = self.file_watcher.find_similar(query=query, top_k=10)
+                chroma_results = self.file_watcher.find_similar(
+                    query=query, top_k=results, threshold=threshold
+                )
                 if not chroma_results:
                     return JSONResponse({"error": "No results found"}, status_code=404)
 
-                chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
+                chroma_embedding = np.array(
+                    chroma_results["query_embedding"]
+                ).flatten()  # Ensure correct shape
                 logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
 
-                umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+                umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[
+                    0
+                ].tolist()
+                logger.info(
+                    f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
+                )  # Debug output
 
-                umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+                umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[
+                    0
+                ].tolist()
+                logger.info(
+                    f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
+                )  # Debug output
 
-                return JSONResponse({
-                    **chroma_results,
-                    "query": query,
-                    "umap_embedding_2d": umap_2d,
-                    "umap_embedding_3d": umap_3d
-                })
+                return JSONResponse(
+                    {
+                        **chroma_results,
+                        "query": query,
+                        "umap_embedding_2d": umap_2d,
+                        "umap_embedding_3d": umap_3d,
+                    }
+                )
+
             except Exception as e:
                 logger.error(e)
-                #return JSONResponse({"error": str(e)}, 500)
+                logging.error(traceback.format_exc())
+                return JSONResponse({"error": str(e)}, 500)
 
         @self.app.put("/api/reset/{context_id}/{agent_type}")
         async def put_reset(context_id: str, agent_type: str, request: Request):
@@ -386,7 +467,10 @@ class WebServer:
             context = self.upsert_context(context_id)
             agent = context.get_agent(agent_type)
             if not agent:
-                return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)
+                return JSONResponse(
+                    {"error": f"{agent_type} is not recognized", "context": context.id},
+                    status_code=404,
+                )
 
             data = await request.json()
             try:
@@ -405,20 +489,29 @@ class WebServer:
                         response["tools"] = context.tools
                     case "history":
                         reset_map = {
-                            "job_description": ("job_description", "resume", "fact_check"),
+                            "job_description": (
+                                "job_description",
+                                "resume",
+                                "fact_check",
+                            ),
                             "resume": ("job_description", "resume", "fact_check"),
-                            "fact_check": ("job_description", "resume", "fact_check"),
+                            "fact_check": (
+                                "job_description",
+                                "resume",
+                                "fact_check",
+                            ),
                             "chat": ("chat",),
                         }
                         resets = reset_map.get(agent_type, ())
 
                         for mode in resets:
                             tmp = context.get_agent(mode)
                             if not tmp:
+                                logger.info(
+                                    f"Agent {mode} not found for {context_id}"
+                                )
                                 continue
                             logger.info(f"Resetting {reset_operation} for {mode}")
-                            context.conversation = Conversation()
-                            context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage
+                            tmp.conversation.reset()
                         response["history"] = []
                         response["context_used"] = agent.context_tokens
                     case "message_history_length":
@@ -427,13 +520,19 @@ class WebServer:
                         response["message_history_length"] = DEFAULT_HISTORY_LENGTH
 
                 if not response:
-                    return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system_prompt}"})
+                    return JSONResponse(
+                        {"error": "Usage: { reset: rags|tools|history|system_prompt}"}
+                    )
                 else:
                     self.save_context(context_id)
                     return JSONResponse(response)
 
-            except:
-                return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system_prompt}"})
+            except Exception as e:
+                logger.error(f"Error in reset: {e}")
+                logger.error(traceback.format_exc())
+                return JSONResponse(
+                    {"error": "Usage: { reset: rags|tools|history|system_prompt}"}
+                )
 
         @self.app.put("/api/tunables/{context_id}")
         async def put_tunables(context_id: str, request: Request):
@ -444,50 +543,77 @@ class WebServer:
|
|||||||
data = await request.json()
|
data = await request.json()
|
||||||
agent = context.get_agent("chat")
|
agent = context.get_agent("chat")
|
||||||
if not agent:
|
if not agent:
|
||||||
return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
|
return JSONResponse(
|
||||||
|
{"error": f"chat is not recognized", "context": context.id},
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
for k in data.keys():
|
for k in data.keys():
|
||||||
match k:
|
match k:
|
||||||
case "tools":
|
case "tools":
|
||||||
# { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
# { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
||||||
tools : list[dict[str, Any]] = data[k]
|
tools: list[dict[str, Any]] = data[k]
|
||||||
if not tools:
|
if not tools:
|
||||||
return JSONResponse({ "status": "error", "message": "Tools can not be empty." })
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
"message": "Tools can not be empty.",
|
||||||
|
}
|
||||||
|
)
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
for context_tool in context.tools:
|
for context_tool in context.tools:
|
||||||
if context_tool["function"]["name"] == tool["name"]:
|
if context_tool["function"]["name"] == tool["name"]:
|
||||||
context_tool["enabled"] = tool["enabled"]
|
context_tool["enabled"] = tool["enabled"]
|
||||||
self.save_context(context_id)
|
self.save_context(context_id)
|
||||||
return JSONResponse({ "tools": [ {
|
return JSONResponse(
|
||||||
**t["function"],
|
{
|
||||||
"enabled": t["enabled"],
|
"tools": [
|
||||||
} for t in context.tools] })
|
{
|
||||||
|
**t["function"],
|
||||||
|
"enabled": t["enabled"],
|
||||||
|
}
|
||||||
|
for t in context.tools
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
case "rags":
|
case "rags":
|
||||||
# { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
# { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
||||||
rags : list[dict[str, Any]] = data[k]
|
rags: list[dict[str, Any]] = data[k]
|
||||||
if not rags:
|
if not rags:
|
||||||
return JSONResponse({ "status": "error", "message": "RAGs can not be empty." })
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
"message": "RAGs can not be empty.",
|
||||||
|
}
|
||||||
|
)
|
||||||
for rag in rags:
|
for rag in rags:
|
||||||
for context_rag in context.rags:
|
for context_rag in context.rags:
|
||||||
if context_rag["name"] == rag["name"]:
|
if context_rag["name"] == rag["name"]:
|
||||||
context_rag["enabled"] = rag["enabled"]
|
context_rag["enabled"] = rag["enabled"]
|
||||||
self.save_context(context_id)
|
self.save_context(context_id)
|
||||||
return JSONResponse({ "rags": context.rags })
|
return JSONResponse({"rags": context.rags})
|
||||||
|
|
||||||
case "system_prompt":
|
case "system_prompt":
|
||||||
system_prompt = data[k].strip()
|
system_prompt = data[k].strip()
|
||||||
if not system_prompt:
|
if not system_prompt:
|
||||||
return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
|
return JSONResponse(
|
||||||
agent.system_prompt = system_prompt
|
{
|
||||||
|
"status": "error",
|
||||||
|
"message": "System prompt can not be empty.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
agent.system_prompt = system_prompt
|
||||||
self.save_context(context_id)
|
self.save_context(context_id)
|
||||||
return JSONResponse({ "system_prompt": system_prompt })
|
return JSONResponse({"system_prompt": system_prompt})
|
||||||
case "message_history_length":
|
case "message_history_length":
|
||||||
value = max(0, int(data[k]))
|
value = max(0, int(data[k]))
|
||||||
context.message_history_length = value
|
context.message_history_length = value
|
||||||
self.save_context(context_id)
|
self.save_context(context_id)
|
||||||
return JSONResponse({ "message_history_length": value })
|
return JSONResponse({"message_history_length": value})
|
||||||
case _:
|
case _:
|
||||||
return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404)
|
return JSONResponse(
|
||||||
|
{"error": f"Unrecognized tunable {k}"}, status_code=404
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in put_tunables: {e}")
|
logger.error(f"Error in put_tunables: {e}")
|
||||||
return JSONResponse({"error": str(e)}, status_code=500)
|
return JSONResponse({"error": str(e)}, status_code=500)
|
||||||
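
For orientation, a minimal client sketch of this handler's contract: one tunable key per request body, since the loop returns inside the first matched case. The route path, base URL, and context id below are assumptions for illustration only; they are not taken from this commit.

# Hypothetical client for the tunables handler (assumed route:
# PUT /api/tunables/{context_id}; base_url and context_id are placeholders).
import requests

base_url = "http://localhost:8000"
context_id = "00000000-0000-0000-0000-000000000000"

resp = requests.put(
    f"{base_url}/api/tunables/{context_id}",
    json={"system_prompt": "You are a helpful assistant."},
)
print(resp.json())  # {"system_prompt": "..."} on success

resp = requests.put(
    f"{base_url}/api/tunables/{context_id}",
    json={"message_history_length": 10},
)
print(resp.json())  # {"message_history_length": 10}
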
@@ -501,85 +627,94 @@ class WebServer:
            context = self.upsert_context(context_id)
            agent = context.get_agent("chat")
            if not agent:
-                return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
-            return JSONResponse({
-                "system_prompt": agent.system_prompt,
-                "message_history_length": context.message_history_length,
-                "rags": context.rags,
-                "tools": [ {
-                    **t["function"],
-                    "enabled": t["enabled"],
-                } for t in context.tools ]
-            })
+                return JSONResponse(
+                    {"error": f"chat is not recognized", "context": context.id},
+                    status_code=404,
+                )
+            return JSONResponse(
+                {
+                    "system_prompt": agent.system_prompt,
+                    "message_history_length": context.message_history_length,
+                    "rags": context.rags,
+                    "tools": [
+                        {
+                            **t["function"],
+                            "enabled": t["enabled"],
+                        }
+                        for t in context.tools
+                    ],
+                }
+            )

        @self.app.get("/api/system-info/{context_id}")
        async def get_system_info(context_id: str, request: Request):
            logger.info(f"{request.method} {request.url.path}")
            return JSONResponse(system_info(self.model))

-        @self.app.post("/api/chat/{context_id}/{agent_type}")
-        async def post_chat_endpoint(context_id: str, agent_type: str, request: Request):
+        @self.app.post("/api/{agent_type}/{context_id}")
+        async def post_agent_endpoint(
+            agent_type: str, context_id: str, request: Request
+        ):
            logger.info(f"{request.method} {request.url.path}")
-            if not is_valid_uuid(context_id):
-                logger.warning(f"Invalid context_id: {context_id}")
-                return JSONResponse({"error": "Invalid context_id"}, status_code=400)

            try:
                context = self.upsert_context(context_id)
-                agent = context.get_agent(agent_type)
            except Exception as e:
-                logger.info(f"Attempt to create agent type: {agent_type} failed", e)
-                return JSONResponse({"error": f"{agent_type} is not recognized or context {context_id} is invalid "}, status_code=404)
+                error = {
+                    "error": f"Unable to create or access context {context_id}: {e}"
+                }
+                logger.info(error)
+                return JSONResponse(error, status_code=404)

            try:
-                query = await request.json()
-                prompt = query["prompt"]
-                if not isinstance(prompt, str) or len(prompt) == 0:
-                    logger.info(f"Prompt is empty")
-                    return JSONResponse({"error": "Prompt cannot be empty"}, status_code=400)
+                data = await request.json()
+                query: QueryOptions = QueryOptions(**data)
            except Exception as e:
-                logger.info(f"Attempt to parse request: {str(e)}.")
-                return JSONResponse({"error": f"Attempt to parse request: {str(e)}."}, status_code=400)
+                error = {"error": f"Attempt to parse request: {e}"}
+                logger.info(error)
+                return JSONResponse(error, status_code=400)

            try:
-                options = Tunables(**query["options"]) if "options" in query else None
+                agent = context.get_or_create_agent(agent_type, **query.agent_options)
            except Exception as e:
-                logger.info(f"Attempt to set tunables failed: {query['options']}.", e)
-                return JSONResponse({"error": f"Invalid options: {query['options']}"}, status_code=400)
+                error = {
+                    "error": f"Attempt to create agent type: {agent_type} failed: {e}"
+                }
+                return JSONResponse(error, status_code=404)

-            if not agent:
-                match agent_type:
-                    case "job_description":
-                        logger.info(f"Agent {agent_type} not found. Returning empty history.")
-                        agent = context.get_or_create_agent("job_description", job_description=prompt)
-                    case _:
-                        logger.info(f"Invalid agent creation sequence for {agent_type}. Returning error.")
-                        return JSONResponse({"error": f"{agent_type} is not recognized", "context": context.id}, status_code=404)
            try:

                async def flush_generator():
                    logger.info(f"{agent.agent_type} - {inspect.stack()[0].function}")
                    try:
                        start_time = time.perf_counter()
-                        async for message in self.generate_response(context=context, agent=agent, prompt=prompt, options=options):
-                            if message.status != "done":
+                        async for message in self.generate_response(
+                            context=context,
+                            agent=agent,
+                            prompt=query.prompt,
+                            options=query.options,
+                        ):
+                            if message.status != "done" and message.status != "partial":
                                if message.status == "streaming":
                                    result = {
                                        "status": "streaming",
                                        "chunk": message.chunk,
-                                        "remaining_time": LLM_TIMEOUT - (time.perf_counter() - start_time)
+                                        "remaining_time": LLM_TIMEOUT
+                                        - (time.perf_counter() - start_time),
                                    }
                                else:
                                    start_time = time.perf_counter()
                                    result = {
                                        "status": message.status,
                                        "response": message.response,
-                                        "remaining_time": LLM_TIMEOUT
+                                        "remaining_time": LLM_TIMEOUT,
                                    }
                            else:
-                                logger.info(f"Message complete. Providing full response.")
+                                logger.info(f"Providing {message.status} response.")
                                try:
-                                    message.response = message.response
-                                    result = message.model_dump(by_alias=True, mode='json')
+                                    result = message.model_dump(
+                                        by_alias=True, mode="json"
+                                    )
                                except Exception as e:
                                    result = {"status": "error", "response": str(e)}
                                yield json.dumps(result) + "\n"
@@ -589,26 +724,27 @@ class WebServer:
                            result = json.dumps(result) + "\n"
                            message.network_packets += 1
                            message.network_bytes += len(result)
+                            yield result
                            if await request.is_disconnected():
                                logger.info("Disconnect detected. Aborting generation.")
                                context.processing = False
                                # Save context on completion or error
-                                message.prompt = prompt
+                                message.prompt = query.prompt
                                message.status = "error"
-                                message.response = "Client disconnected during generation."
+                                message.response = (
+                                    "Client disconnected during generation."
+                                )
                                agent.conversation.add(message)
                                self.save_context(context_id)
                                return

-                            yield result

                            current_time = time.perf_counter()
                            if current_time - start_time > LLM_TIMEOUT:
                                message.status = "error"
                                message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query."
                                message.partial_response = message.response
                                logger.info(message.response + " Ending session")
-                                result = message.model_dump(by_alias=True, mode='json')
+                                result = message.model_dump(by_alias=True, mode="json")
                                result = json.dumps(result) + "\n"
                                yield result

@@ -620,7 +756,7 @@ class WebServer:
                            await asyncio.sleep(0)
                    except Exception as e:
                        context.processing = False
-                        logger.error(f"Error in process_generator: {e}")
+                        logger.error(f"Error in generate_response: {e}")
                        logger.error(traceback.format_exc())
                        yield json.dumps({"status": "error", "response": str(e)}) + "\n"
                    finally:
@@ -634,8 +770,8 @@ class WebServer:
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
-                        "X-Accel-Buffering": "no" # Prevents Nginx buffering if you're using it
-                    }
+                        "X-Accel-Buffering": "no",  # Prevents Nginx buffering if you're using it
+                    },
                )
            except Exception as e:
                context.processing = False
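
The reworked endpoint streams newline-delimited JSON: per-line packets with status "streaming" (carrying a chunk and remaining time), then a full model dump once the status reaches "done" or "partial". A minimal consumer sketch follows, assuming a "chat" agent type and a placeholder base URL; the exact QueryOptions schema is outside this hunk, so the request body here is an assumption beyond the prompt field the server reads as query.prompt.

# Hypothetical NDJSON consumer for POST /api/{agent_type}/{context_id}
# (base_url and context_id are placeholders, not from this commit).
import json
import requests

base_url = "http://localhost:8000"
context_id = "00000000-0000-0000-0000-000000000000"

with requests.post(
    f"{base_url}/api/chat/{context_id}",
    json={"prompt": "Hello"},  # parsed into QueryOptions server-side
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        packet = json.loads(line)
        if packet["status"] == "streaming":
            print(packet["chunk"], end="", flush=True)
        elif packet["status"] == "error":
            print(packet["response"])
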
@@ -647,7 +783,7 @@ class WebServer:
            try:
                context = self.create_context()
                logger.info(f"Generated new agent as {context.id}")
-                return JSONResponse({ "id": context.id })
+                return JSONResponse({"id": context.id})
            except Exception as e:
                logger.error(f"get_history error: {str(e)}")
                logger.error(traceback.format_exc())
@@ -660,13 +796,18 @@ class WebServer:
                context = self.upsert_context(context_id)
                agent = context.get_agent(agent_type)
                if not agent:
-                    logger.info(f"Agent {agent_type} not found. Returning empty history.")
-                    return JSONResponse({ "messages": [] })
-                logger.info(f"History for {agent_type} contains {len(agent.conversation)} entries.")
+                    logger.info(
+                        f"Agent {agent_type} not found. Returning empty history."
+                    )
+                    return JSONResponse({"messages": []})
+                logger.info(
+                    f"History for {agent_type} contains {len(agent.conversation)} entries."
+                )
                return agent.conversation
            except Exception as e:
                logger.error(f"get_history error: {str(e)}")
                import traceback

                logger.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, status_code=404)

@@ -692,10 +833,11 @@ class WebServer:
                        tool["enabled"] = enabled
                        self.save_context(context_id)
                        return JSONResponse(context.tools)
-                return JSONResponse({ "status": f"{modify} not found in tools." }, status_code=404)
+                return JSONResponse(
+                    {"status": f"{modify} not found in tools."}, status_code=404
+                )
            except:
-                return JSONResponse({ "status": "error" }, 405)
+                return JSONResponse({"status": "error"}, 405)


        @self.app.get("/api/context-status/{context_id}/{agent_type}")
        async def get_context_status(context_id, agent_type: str, request: Request):
@@ -706,8 +848,15 @@ class WebServer:
            context = self.upsert_context(context_id)
            agent = context.get_agent(agent_type)
            if not agent:
-                return JSONResponse({"context_used": 0, "max_context": defines.max_context})
-            return JSONResponse({"context_used": agent.context_tokens, "max_context": defines.max_context})
+                return JSONResponse(
+                    {"context_used": 0, "max_context": defines.max_context}
+                )
+            return JSONResponse(
+                {
+                    "context_used": agent.context_tokens,
+                    "max_context": defines.max_context,
+                }
+            )

        @self.app.get("/api/health")
        async def health_check():
@@ -770,8 +919,11 @@ class WebServer:
        # Read and deserialize the data
        with open(file_path, "r") as f:
            content = f.read()
-            logger.info(f"Loading context from {file_path}, content length: {len(content)}")
+            logger.info(
+                f"Loading context from {file_path}, content length: {len(content)}"
+            )
            import json

            try:
                # Try parsing as JSON first to ensure valid JSON
                json_data = json.loads(content)
@@ -787,7 +939,9 @@ class WebServer:
                # Now set context on agents manually
                agent_types = [agent.agent_type for agent in context.agents]
                if len(agent_types) != len(set(agent_types)):
-                    raise ValueError("Context cannot contain multiple agents of the same agent_type")
+                    raise ValueError(
+                        "Context cannot contain multiple agents of the same agent_type"
+                    )
                for agent in context.agents:
                    agent.set_context(context)

@@ -799,13 +953,18 @@ class WebServer:
            except Exception as e:
                logger.error(f"Error validating context: {str(e)}")
                import traceback

                logger.error(traceback.format_exc())
                # Fallback to creating a new context
-                self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher, prometheus_collector=self.prometheus_collector)
+                self.contexts[context_id] = Context(
+                    id=context_id,
+                    file_watcher=self.file_watcher,
+                    prometheus_collector=self.prometheus_collector,
+                )

        return self.contexts[context_id]

-    def create_context(self, context_id = None) -> Context:
+    def create_context(self, context_id=None) -> Context:
        """
        Create a new context with a unique ID and default settings.
        Args:
@@ -818,7 +977,11 @@ class WebServer:
        if not context_id:
            context_id = str(uuid4())
        logger.info(f"Creating new context with ID: {context_id}")
-        context = Context(id=context_id, file_watcher=self.file_watcher, prometheus_collector=self.prometheus_collector)
+        context = Context(
+            id=context_id,
+            file_watcher=self.file_watcher,
+            prometheus_collector=self.prometheus_collector,
+        )

        if os.path.exists(defines.resume_doc):
            context.user_resume = open(defines.resume_doc, "r").read()
@@ -835,7 +998,7 @@ class WebServer:
        self.save_context(context.id)
        return context

-    def upsert_context(self, context_id = None) -> Context:
+    def upsert_context(self, context_id=None) -> Context:
        """
        Upsert a context based on the provided context_id.
        Args:
@@ -855,7 +1018,9 @@ class WebServer:
        return self.load_or_create_context(context_id)

    @REQUEST_TIME.time()
-    async def generate_response(self, context : Context, agent : Agent, prompt : str, options: Tunables | None) -> AsyncGenerator[Message, None]:
+    async def generate_response(
+        self, context: Context, agent: Agent, prompt: str, options: Tunables | None
+    ) -> AsyncGenerator[Message, None]:
        if not self.file_watcher:
            raise Exception("File watcher not initialized")

@@ -880,7 +1045,9 @@ class WebServer:
            if message.status == "error":
                return

-            logger.info(f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
+            logger.info(
+                f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}"
+            )
            message.status = "done"
            yield message
            return
@@ -895,24 +1062,21 @@ class WebServer:
                    port=port,
                    log_config=None,
                    ssl_keyfile=defines.key_path,
-                    ssl_certfile=defines.cert_path
+                    ssl_certfile=defines.cert_path,
                )
            else:
                logger.info(f"Starting web server at http://{host}:{port}")
-                uvicorn.run(
-                    self.app,
-                    host=host,
-                    port=port,
-                    log_config=None
-                )
+                uvicorn.run(self.app, host=host, port=port, log_config=None)
        except KeyboardInterrupt:
            if self.observer:
                self.observer.stop()
            if self.observer:
                self.observer.join()


# %%


# Main function to run everything
def main():
    global model
@@ -923,23 +1087,16 @@ def main():
    # Setup logging based on the provided level
    logger.setLevel(args.level.upper())

-    warnings.filterwarnings(
-        "ignore",
-        category=FutureWarning,
-        module="sklearn.*"
-    )
+    warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.*")

-    warnings.filterwarnings(
-        "ignore",
-        category=UserWarning,
-        module="umap.*"
-    )
+    warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")

    llm = ollama.Client(host=args.ollama_server)  # type: ignore
    model = args.ollama_model

    web_server = WebServer(llm, model)

    web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)


main()
@@ -1,29 +1,25 @@
-from .. utils import logger
+from ..utils import logger

import ollama

-from .. utils import (
-    rag as Rag,
-    Context,
-    defines
-)
+from ..utils import rag as Rag, Context, defines

import json

llm = ollama.Client(host=defines.ollama_api_url)

observer, file_watcher = Rag.start_file_watcher(
-    llm=llm,
-    watch_directory=defines.doc_dir,
-    recreate=False # Don't recreate if exists
+    llm=llm, watch_directory=defines.doc_dir, recreate=False  # Don't recreate if exists
)

context = Context(file_watcher=file_watcher)
-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
context = Context.model_validate_json(json.dumps(data))
context.file_watcher = file_watcher

-agent = context.get_or_create_agent("chat", system_prompt="You are a helpful assistant.")
+agent = context.get_or_create_agent(
+    "chat", system_prompt="You are a helpful assistant."
+)
# logger.info(f"data: {data}")
# logger.info(f"agent: {agent}")
agent_type = agent.get_agent_type()
@@ -32,7 +28,7 @@ logger.info(f"system_prompt: {agent.system_prompt}")

agent.system_prompt = "Eat more tomatoes."

-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
context = Context.model_validate_json(json.dumps(data))
context.file_watcher = file_watcher
@@ -1,19 +1,21 @@
from fastapi import FastAPI, Request, Depends, Query  # type: ignore
from fastapi.responses import RedirectResponse, JSONResponse  # type: ignore
from uuid import UUID, uuid4
import logging
import traceback
from typing import Callable, Optional
from anyio.to_thread import run_sync  # type: ignore

logger = logging.getLogger(__name__)


class RedirectToContext(Exception):
    def __init__(self, url: str):
        self.url = url
        logger.info(f"Redirect to Context: {url}")
        super().__init__(f"Redirect to Context: {url}")


class ContextRouteManager:
    def __init__(self, app: FastAPI):
        self.app = app
@@ -25,20 +27,28 @@ class ContextRouteManager:
        logger.info(f"Handling redirect to {exc.url}")
        return RedirectResponse(url=exc.url, status_code=307)

-    def ensure_context(self, route_name: str = "context_id") -> Callable[[Request], Optional[UUID]]:
+    def ensure_context(
+        self, route_name: str = "context_id"
+    ) -> Callable[[Request], Optional[UUID]]:
        logger.info(f"Setting up context dependency for route parameter: {route_name}")

        async def _ensure_context_dependency(request: Request) -> Optional[UUID]:
-            logger.info(f"Entering ensure_context_dependency, Request URL: {request.url}")
+            logger.info(
+                f"Entering ensure_context_dependency, Request URL: {request.url}"
+            )
            logger.info(f"Path params: {request.path_params}")

            path_params = request.path_params
            route_value = path_params.get(route_name)
            logger.info(f"route_value: {route_value!r}, type: {type(route_value)}")

-            if route_value is None or not isinstance(route_value, str) or not route_value.strip():
+            if (
+                route_value is None
+                or not isinstance(route_value, str)
+                or not route_value.strip()
+            ):
                logger.info(f"route_value is invalid, generating new UUID")
-                path = request.url.path.rstrip('/')
+                path = request.url.path.rstrip("/")
                new_context = await run_sync(uuid4)
                redirect_url = f"{path}/{new_context}"
                logger.info(f"Redirecting to {redirect_url}")
@@ -50,14 +60,16 @@ class ContextRouteManager:
                logger.info(f"Successfully parsed UUID: {route_context}")
                return route_context
            except ValueError as e:
-                logger.error(f"Failed to parse UUID from route_value: {route_value!r}, error: {str(e)}")
-                path = request.url.path.rstrip('/')
+                logger.error(
+                    f"Failed to parse UUID from route_value: {route_value!r}, error: {str(e)}"
+                )
+                path = request.url.path.rstrip("/")
                new_context = await run_sync(uuid4)
                redirect_url = f"{path}/{new_context}"
                logger.info(f"Invalid UUID, redirecting to {redirect_url}")
                raise RedirectToContext(redirect_url)

        return _ensure_context_dependency  # type: ignore

    def route_pattern(self, path: str, *dependencies, **kwargs):
        logger.info(f"Registering route: {path}")
@@ -66,44 +78,62 @@ class ContextRouteManager:
        def decorator(func):
            all_dependencies = list(dependencies)
            all_dependencies.append(Depends(ensure_context))
-            logger.info(f"Route {path} registered with dependencies: {all_dependencies}")
+            logger.info(
+                f"Route {path} registered with dependencies: {all_dependencies}"
+            )
            return self.app.get(path, dependencies=all_dependencies, **kwargs)(func)

        return decorator


app = FastAPI(redirect_slashes=True)


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    logger.error(f"Unhandled exception: {str(exc)}")
    logger.error(f"Request URL: {request.url}, Path params: {request.path_params}")
    logger.error(f"Stack trace: {''.join(traceback.format_tb(exc.__traceback__))}")
    return JSONResponse(
-        status_code=500,
-        content={"error": "Internal server error", "detail": str(exc)}
+        status_code=500, content={"error": "Internal server error", "detail": str(exc)}
    )


@app.middleware("http")
async def log_requests(request: Request, call_next):
-    logger.info(f"Incoming request: {request.method} {request.url}, Path params: {request.path_params}")
+    logger.info(
+        f"Incoming request: {request.method} {request.url}, Path params: {request.path_params}"
+    )
    response = await call_next(request)
    return response


context_router = ContextRouteManager(app)


@context_router.route_pattern("/api/history/{context_id}")
-async def get_history(request: Request, context_id: UUID = Depends(context_router.ensure_context()), agent_type: str = Query(..., description="Type of agent to retrieve history for")):
+async def get_history(
+    request: Request,
+    context_id: UUID = Depends(context_router.ensure_context()),
+    agent_type: str = Query(..., description="Type of agent to retrieve history for"),
+):
    logger.info(f"{request.method} {request.url.path} with context_id: {context_id}")
    return {"context_id": str(context_id), "agent_type": agent_type}


@app.get("/api/history")
-async def redirect_history(request: Request, agent_type: str = Query(..., description="Type of agent to retrieve history for")):
-    path = request.url.path.rstrip('/')
+async def redirect_history(
+    request: Request,
+    agent_type: str = Query(..., description="Type of agent to retrieve history for"),
+):
+    path = request.url.path.rstrip("/")
    new_context = uuid4()
    redirect_url = f"{path}/{new_context}?agent_type={agent_type}"
    logger.info(f"Redirecting from /api/history to {redirect_url}")
    return RedirectResponse(url=redirect_url, status_code=307)


if __name__ == "__main__":
    import uvicorn  # type: ignore

    uvicorn.run(app, host="0.0.0.0", port=8900)
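
A short sketch of the redirect contract these routes implement: GET /api/history without a context id 307s to a freshly minted UUID while preserving agent_type. It uses FastAPI's TestClient against the app above; follow_redirects is the newer Starlette spelling (older releases used allow_redirects), and importing this module as `app` is an assumption.

# Sketch of the 307-redirect behavior (assumes this module exposes `app`).
from fastapi.testclient import TestClient

client = TestClient(app)

# No context_id: redirect_history mints a UUID and 307s to /api/history/{uuid}.
resp = client.get(
    "/api/history", params={"agent_type": "chat"}, follow_redirects=False
)
assert resp.status_code == 307
print(resp.headers["location"])  # /api/history/<new-uuid>?agent_type=chat

# Following the redirect reaches get_history with a valid UUID.
resp = client.get(resp.headers["location"])
assert resp.json()["agent_type"] == "chat"
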
@@ -1,28 +1,24 @@
# From /opt/backstory run:
# python -m src.tests.test-context
import os

os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
import warnings

warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")

import ollama

-from .. utils import (
-    rag as Rag,
-    Context,
-    defines
-)
+from ..utils import rag as Rag, Context, defines

import json

llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore

observer, file_watcher = Rag.start_file_watcher(
-    llm=llm,
-    watch_directory=defines.doc_dir,
-    recreate=False # Don't recreate if exists
+    llm=llm, watch_directory=defines.doc_dir, recreate=False  # Don't recreate if exists
)

context = Context(file_watcher=file_watcher)
-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
context = Context.from_json(json.dumps(data), file_watcher=file_watcher)
@@ -1,13 +1,11 @@
# From /opt/backstory run:
# python -m src.tests.test-message
-from .. utils import logger
+from ..utils import logger

-from .. utils import (
-    Message
-)
+from ..utils import Message

import json

prompt = "This is a test"
message = Message(prompt=prompt)
-print(message.model_dump(mode='json'))
+print(message.model_dump(mode="json"))
@@ -1,8 +1,6 @@
# From /opt/backstory run:
# python -m src.tests.test-metrics
-from .. utils import (
-    Metrics
-)
+from ..utils import Metrics

import json

@@ -13,7 +11,7 @@ metrics = Metrics()
metrics.prepare_count.labels(agent="chat").inc()
metrics.prepare_duration.labels(agent="prepare").observe(0.45)

-json = metrics.model_dump(mode='json')
+json = metrics.model_dump(mode="json")
metrics = Metrics.model_validate(json)

print(metrics)
@@ -1,32 +1,33 @@
from __future__ import annotations
from pydantic import BaseModel  # type: ignore
import importlib

from . import defines
-from . context import Context
-from . conversation import Conversation
-from . message import Message, Tunables
-from . rag import ChromaDBFileWatcher, start_file_watcher
-from . setup_logging import setup_logging
-from . agents import class_registry, AnyAgent, Agent, __all__ as agents_all
-from . metrics import Metrics
+from .context import Context
+from .conversation import Conversation
+from .message import Message, Tunables
+from .rag import ChromaDBFileWatcher, start_file_watcher
+from .setup_logging import setup_logging
+from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
+from .metrics import Metrics

__all__ = [
-    'Agent',
-    'Tunables',
-    'Context',
-    'Conversation',
-    'Message',
-    'Metrics',
-    'ChromaDBFileWatcher',
-    'start_file_watcher',
-    'logger',
+    "Agent",
+    "Tunables",
+    "Context",
+    "Conversation",
+    "Message",
+    "Metrics",
+    "ChromaDBFileWatcher",
+    "start_file_watcher",
+    "logger",
]

__all__.extend(agents_all)  # type: ignore

logger = setup_logging(level=defines.logging_level)


def rebuild_models():
    for class_name, (module_name, _) in class_registry.items():
        try:
@@ -36,9 +37,15 @@ def rebuild_models():
            logger.debug(f"Checking: {class_name} in module {module_name}")
            logger.debug(f"  cls: {True if cls else False}")
            logger.debug(f"  isinstance(cls, type): {isinstance(cls, type)}")
-            logger.debug(f"  issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}")
-            logger.debug(f"  issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}")
-            logger.debug(f"  cls is not AnyAgent: {cls is not AnyAgent if cls else True}")
+            logger.debug(
+                f"  issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}"
+            )
+            logger.debug(
+                f"  issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}"
+            )
+            logger.debug(
+                f"  cls is not AnyAgent: {cls is not AnyAgent if cls else True}"
+            )

            if (
                cls
@@ -48,13 +55,15 @@ def rebuild_models():
                and cls is not AnyAgent
            ):
                logger.debug(f"Rebuilding {class_name} from {module_name}")
-                from . agents import Agent
-                from . context import Context
+                from .agents import Agent
+                from .context import Context

                cls.model_rebuild()
        except ImportError as e:
            logger.error(f"Failed to import module {module_name}: {e}")
        except Exception as e:
            logger.error(f"Error processing {class_name} in {module_name}: {e}")


# Call this after all modules are imported
rebuild_models()
@@ -4,19 +4,21 @@ import importlib
import pathlib
import inspect

-from . types import agent_registry
-from .. setup_logging import setup_logging
+from .types import agent_registry
+from ..setup_logging import setup_logging
from .. import defines
-from . base import Agent
+from .base import Agent

logger = setup_logging(defines.logging_level)

-__all__ = [ "AnyAgent", "Agent", "agent_registry", "class_registry" ]
+__all__ = ["AnyAgent", "Agent", "agent_registry", "class_registry"]

# Type alias for Agent or any subclass
AnyAgent: TypeAlias = Agent  # BaseModel covers Agent and subclasses

-class_registry: Dict[str, Tuple[str, str]] = {} # Maps class_name to (module_name, class_name)
+class_registry: Dict[str, Tuple[str, str]] = (
+    {}
+)  # Maps class_name to (module_name, class_name)

package_dir = pathlib.Path(__file__).parent
package_name = __name__
@@ -42,7 +44,7 @@ for path in package_dir.glob("*.py"):
                class_registry[name] = (full_module_name, name)
                globals()[name] = obj
                logger.info(f"Adding agent: {name}")
                __all__.append(name)  # type: ignore
    except ImportError as e:
        logger.error(f"Error importing {full_module_name}: {e}")
        raise e
|
@ -1,8 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, PrivateAttr, Field # type: ignore
|
from pydantic import BaseModel, PrivateAttr, Field # type: ignore
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any,
|
Literal,
|
||||||
TypeAlias, Dict, Tuple
|
get_args,
|
||||||
|
List,
|
||||||
|
AsyncGenerator,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Optional,
|
||||||
|
ClassVar,
|
||||||
|
Any,
|
||||||
|
TypeAlias,
|
||||||
|
Dict,
|
||||||
|
Tuple,
|
||||||
)
|
)
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
@ -10,50 +19,55 @@ import inspect
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||||
|
|
||||||
from .. setup_logging import setup_logging
|
from ..setup_logging import setup_logging
|
||||||
|
|
||||||
logger = setup_logging()
|
logger = setup_logging()
|
||||||
|
|
||||||
# Only import Context for type checking
|
# Only import Context for type checking
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .. context import Context
|
from ..context import Context
|
||||||
|
|
||||||
from . types import agent_registry
|
from .types import agent_registry
|
||||||
from .. import defines
|
from .. import defines
|
||||||
from .. message import Message, Tunables
|
from ..message import Message, Tunables
|
||||||
from .. metrics import Metrics
|
from ..metrics import Metrics
|
||||||
from .. tools import ( TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools ) # type: ignore -- dynamically added to __all__
|
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__
|
||||||
from .. conversation import Conversation
|
from ..conversation import Conversation
|
||||||
|
|
||||||
|
|
||||||
class LLMMessage(BaseModel):
|
class LLMMessage(BaseModel):
|
||||||
role : str = Field(default="")
|
role: str = Field(default="")
|
||||||
content : str = Field(default="")
|
content: str = Field(default="")
|
||||||
tool_calls : Optional[List[Dict]] = Field(default={}, exclude=True)
|
tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
|
||||||
|
|
||||||
|
|
||||||
class Agent(BaseModel, ABC):
|
class Agent(BaseModel, ABC):
|
||||||
"""
|
"""
|
||||||
Base class for all agent types.
|
Base class for all agent types.
|
||||||
This class defines the common attributes and methods for all agent types.
|
This class defines the common attributes and methods for all agent types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Agent management with pydantic
|
# Agent management with pydantic
|
||||||
agent_type: Literal["base"] = "base"
|
agent_type: Literal["base"] = "base"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
|
|
||||||
# Tunables (sets default for new Messages attached to this agent)
|
# Tunables (sets default for new Messages attached to this agent)
|
||||||
tunables: Tunables = Field(default_factory=Tunables)
|
tunables: Tunables = Field(default_factory=Tunables)
|
||||||
|
|
||||||
# Agent properties
|
# Agent properties
|
||||||
system_prompt: str # Mandatory
|
system_prompt: str # Mandatory
|
||||||
conversation: Conversation = Conversation()
|
conversation: Conversation = Conversation()
|
||||||
context_tokens: int = 0
|
context_tokens: int = 0
|
||||||
context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization
|
context: Optional[Context] = Field(
|
||||||
|
default=None, exclude=True
|
||||||
|
) # Avoid circular reference, require as param, and prevent serialization
|
||||||
metrics: Metrics = Field(default_factory=Metrics, exclude=True)
|
metrics: Metrics = Field(default_factory=Metrics, exclude=True)
|
||||||
|
|
||||||
# context_size is shared across all subclasses
|
# context_size is shared across all subclasses
|
||||||
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
|
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def context_size(self) -> int:
|
def context_size(self) -> int:
|
||||||
return Agent._context_size
|
return Agent._context_size
|
||||||
@ -62,7 +76,9 @@ class Agent(BaseModel, ABC):
|
|||||||
def context_size(self, value: int):
|
def context_size(self, value: int):
|
||||||
Agent._context_size = value
|
Agent._context_size = value
|
||||||
|
|
||||||
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
|
def set_optimal_context_size(
|
||||||
|
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
|
||||||
|
) -> int:
|
||||||
# # Get more accurate token count estimate using tiktoken or similar
|
# # Get more accurate token count estimate using tiktoken or similar
|
||||||
# response = llm.generate(
|
# response = llm.generate(
|
||||||
# model=model,
|
# model=model,
|
||||||
@ -83,7 +99,9 @@ class Agent(BaseModel, ABC):
|
|||||||
total_ctx = tokens + ctx_buffer
|
total_ctx = tokens + ctx_buffer
|
||||||
|
|
||||||
if total_ctx > self.context_size:
|
if total_ctx > self.context_size:
|
||||||
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
|
logger.info(
|
||||||
|
f"Increasing context size from {self.context_size} to {total_ctx}"
|
||||||
|
)
|
||||||
|
|
||||||
# Grow the context size if necessary
|
# Grow the context size if necessary
|
||||||
self.context_size = max(self.context_size, total_ctx)
|
self.context_size = max(self.context_size, total_ctx)
|
||||||
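
Worth noting while reading this hunk: _context_size is a ClassVar behind the context_size property, so the grow-only update above is shared by every agent instance rather than tracked per agent. A standalone toy sketch of that pattern follows (stub class, not the real Agent):

# Toy illustration of the shared-ClassVar grow-only pattern (not the real Agent).
from typing import ClassVar


class Sized:
    _context_size: ClassVar[int] = 2048

    @property
    def context_size(self) -> int:
        return Sized._context_size

    @context_size.setter
    def context_size(self, value: int) -> None:
        Sized._context_size = value


a, b = Sized(), Sized()
a.context_size = max(a.context_size, 4096)  # same grow-only update as above
print(b.context_size)  # 4096 -- every instance sees the bump
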
@@ -95,7 +113,7 @@ class Agent(BaseModel, ABC):
        """Auto-register subclasses"""
        super().__init_subclass__(**kwargs)
        # Register this class if it has an agent_type
-        if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type:
+        if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
            agent_registry.register(cls.agent_type, cls)

    # def __init__(self, *, context=context, **data):
@@ -123,7 +141,7 @@ class Agent(BaseModel, ABC):
    def get_agent_type(self):
        return self._agent_type

-    async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
+    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
        """
        Prepare message with context information in message.preamble
        """
@@ -166,7 +184,14 @@ class Agent(BaseModel, ABC):

        return

-    async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[LLMMessage]) -> AsyncGenerator[Message, None]:
+    async def process_tool_calls(
+        self,
+        llm: Any,
+        model: str,
+        message: Message,
+        tool_message: Any,
+        messages: List[LLMMessage],
+    ) -> AsyncGenerator[Message, None]:
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        self.metrics.tool_count.labels(agent=self.agent_type).inc()
@@ -187,7 +212,9 @@ class Agent(BaseModel, ABC):
            tool = tool_call.function.name

            # Yield status update before processing each tool
-            message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
+            message.response = (
+                f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
+            )
            yield message
            logger.info(f"LLM - {message.response}")

@@ -202,12 +229,18 @@ class Agent(BaseModel, ABC):

                case "AnalyzeSite":
                    url = arguments.get("url")
-                    question = arguments.get("question", "what is the summary of this content?")
+                    question = arguments.get(
+                        "question", "what is the summary of this content?"
+                    )

                    # Additional status update for long-running operations
-                    message.response = f"Retrieving and summarizing content from {url}..."
+                    message.response = (
+                        f"Retrieving and summarizing content from {url}..."
+                    )
                    yield message
-                    ret = await AnalyzeSite(llm=llm, model=model, url=url, question=question)
+                    ret = await AnalyzeSite(
+                        llm=llm, model=model, url=url, question=question
+                    )

                case "DateTime":
                    tz = arguments.get("timezone")
@@ -217,7 +250,9 @@ class Agent(BaseModel, ABC):
                    city = arguments.get("city")
                    state = arguments.get("state")

-                    message.response = f"Fetching weather data for {city}, {state}..."
+                    message.response = (
+                        f"Fetching weather data for {city}, {state}..."
+                    )
                    yield message
                    ret = WeatherForecast(city, state)

@@ -228,7 +263,7 @@ class Agent(BaseModel, ABC):
            tool_response = {
                "role": "tool",
                "content": json.dumps(ret),
-                "name": tool_call.function.name
+                "name": tool_call.function.name,
            }

            tool_metadata["tool_calls"].append(tool_response)
@@ -241,13 +276,15 @@ class Agent(BaseModel, ABC):
        message_dict = LLMMessage(
            role=tool_message.get("role", "assistant"),
            content=tool_message.get("content", ""),
-            tool_calls=[ {
-                "function": {
-                    "name": tc["function"]["name"],
-                    "arguments": tc["function"]["arguments"]
-                }
-            } for tc in tool_message.tool_calls
-            ]
+            tool_calls=[
+                {
+                    "function": {
+                        "name": tc["function"]["name"],
+                        "arguments": tc["function"]["arguments"],
+                    }
+                }
+                for tc in tool_message.tool_calls
+            ],
        )

        messages.append(message_dict)
@@ -282,20 +319,30 @@ class Agent(BaseModel, ABC):
                    message.metadata["eval_count"] += response.eval_count
                    message.metadata["eval_duration"] += response.eval_duration
                    message.metadata["prompt_eval_count"] += response.prompt_eval_count
-                    message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
-                    self.context_tokens = response.prompt_eval_count + response.eval_count
+                    message.metadata[
+                        "prompt_eval_duration"
+                    ] += response.prompt_eval_duration
+                    self.context_tokens = (
+                        response.prompt_eval_count + response.eval_count
+                    )
                    message.status = "done"
                    yield message

        end_time = time.perf_counter()
-        message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
+        message.metadata["timers"][
+            "llm_with_tools"
+        ] = f"{(end_time - start_time):.4f}"
        return

    def collect_metrics(self, response):
-        self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(response.prompt_eval_count)
+        self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
+            response.prompt_eval_count
+        )
        self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)

-    async def generate_llm_response(self, llm: Any, model: str, message: Message, temperature = 0.7) -> AsyncGenerator[Message, None]:
+    async def generate_llm_response(
+        self, llm: Any, model: str, message: Message, temperature=0.7
+    ) -> AsyncGenerator[Message, None]:
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        self.metrics.generate_count.labels(agent=self.agent_type).inc()
@@ -305,22 +352,29 @@ class Agent(BaseModel, ABC):

        # Create a pruned down message list based purely on the prompt and responses,
        # discarding the full preamble generated by prepare_message
-        messages: List[LLMMessage] = [ LLMMessage(role="system", content=message.system_prompt) ]
-        messages.extend([
-            item for m in self.conversation
-            for item in [
-                LLMMessage(role="user", content=m.prompt.strip()),
-                LLMMessage(role="assistant", content=m.response.strip())
-            ]
-        ])
+        messages: List[LLMMessage] = [
+            LLMMessage(role="system", content=message.system_prompt)
+        ]
+        messages.extend(
+            [
+                item
+                for m in self.conversation
+                for item in [
+                    LLMMessage(role="user", content=m.prompt.strip()),
+                    LLMMessage(role="assistant", content=m.response.strip()),
+                ]
+            ]
+        )
        # Only the actual user query is provided with the full context message
-        messages.append(LLMMessage(role="user", content=message.context_prompt.strip()))
+        messages.append(
+            LLMMessage(role="user", content=message.context_prompt.strip())
+        )

-        #message.metadata["messages"] = messages
-        message.metadata["options"]={
+        # message.metadata["messages"] = messages
+        message.metadata["options"] = {
            "seed": 8911,
            "num_ctx": self.context_size,
-            "temperature": temperature # Higher temperature to encourage tool usage
+            "temperature": temperature,  # Higher temperature to encourage tool usage
        }

        # Create a dict for storing various timing stats
@ -329,7 +383,7 @@ class Agent(BaseModel, ABC):
|
|||||||
use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
||||||
message.metadata["tools"] = {
|
message.metadata["tools"] = {
|
||||||
"available": llm_tools(self.context.tools),
|
"available": llm_tools(self.context.tools),
|
||||||
"used": False
|
"used": False,
|
||||||
}
|
}
|
||||||
tool_metadata = message.metadata["tools"]
|
tool_metadata = message.metadata["tools"]
|
||||||
|
|
||||||
@ -355,14 +409,16 @@ class Agent(BaseModel, ABC):
|
|||||||
tools=tool_metadata["available"],
|
tools=tool_metadata["available"],
|
||||||
options={
|
options={
|
||||||
**message.metadata["options"],
|
**message.metadata["options"],
|
||||||
#"num_predict": 1024, # "Low" token limit to cut off after tool call
|
# "num_predict": 1024, # "Low" token limit to cut off after tool call
|
||||||
},
|
},
|
||||||
stream=False # No need to stream the probe
|
stream=False, # No need to stream the probe
|
||||||
)
|
)
|
||||||
self.collect_metrics(response)
|
self.collect_metrics(response)
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"][
|
||||||
|
"tool_check"
|
||||||
|
] = f"{(end_time - start_time):.4f}"
|
||||||
if not response.message.tool_calls:
|
if not response.message.tool_calls:
|
||||||
logger.info("LLM indicates tools will not be used")
|
logger.info("LLM indicates tools will not be used")
|
||||||
# The LLM will not use tools, so disable use_tools so we can stream the full response
|
# The LLM will not use tools, so disable use_tools so we can stream the full response
|
||||||
@ -374,7 +430,9 @@ class Agent(BaseModel, ABC):
|
|||||||
logger.info("LLM indicates tools will be used")
|
logger.info("LLM indicates tools will be used")
|
||||||
|
|
||||||
# Tools are enabled and available and the LLM indicated it will use them
|
# Tools are enabled and available and the LLM indicated it will use them
|
||||||
message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
|
message.response = (
|
||||||
|
f"Performing tool analysis step 2/2 (tool use suspected)..."
|
||||||
|
)
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
logger.info(f"Performing LLM call with tools")
|
logger.info(f"Performing LLM call with tools")
|
||||||
@ -384,14 +442,16 @@ class Agent(BaseModel, ABC):
|
|||||||
messages=tool_metadata["messages"], # messages,
|
messages=tool_metadata["messages"], # messages,
|
||||||
tools=tool_metadata["available"],
|
tools=tool_metadata["available"],
|
||||||
options={
|
options={
|
||||||
**message.metadata["options"],
|
**message.metadata["options"],
|
||||||
},
|
},
|
||||||
stream=False
|
stream=False,
|
||||||
)
|
)
|
||||||
self.collect_metrics(response)
|
self.collect_metrics(response)
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"][
|
||||||
|
"non_streaming"
|
||||||
|
] = f"{(end_time - start_time):.4f}"
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
message.status = "error"
|
message.status = "error"
|
||||||
@ -403,13 +463,21 @@ class Agent(BaseModel, ABC):
|
|||||||
tool_metadata["used"] = response.message.tool_calls
|
tool_metadata["used"] = response.message.tool_calls
|
||||||
# Process all yielded items from the handler
|
# Process all yielded items from the handler
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
|
async for message in self.process_tool_calls(
|
||||||
|
llm=llm,
|
||||||
|
model=model,
|
||||||
|
message=message,
|
||||||
|
tool_message=response.message,
|
||||||
|
messages=messages,
|
||||||
|
):
|
||||||
if message.status == "error":
|
if message.status == "error":
|
||||||
yield message
|
yield message
|
||||||
return
|
return
|
||||||
yield message
|
yield message
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"][
|
||||||
|
"process_tool_calls"
|
||||||
|
] = f"{(end_time - start_time):.4f}"
|
||||||
message.status = "done"
|
message.status = "done"
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -452,8 +520,12 @@ class Agent(BaseModel, ABC):
|
|||||||
message.metadata["eval_count"] += response.eval_count
|
message.metadata["eval_count"] += response.eval_count
|
||||||
message.metadata["eval_duration"] += response.eval_duration
|
message.metadata["eval_duration"] += response.eval_duration
|
||||||
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
||||||
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
|
message.metadata[
|
||||||
self.context_tokens = response.prompt_eval_count + response.eval_count
|
"prompt_eval_duration"
|
||||||
|
] += response.prompt_eval_duration
|
||||||
|
self.context_tokens = (
|
||||||
|
response.prompt_eval_count + response.eval_count
|
||||||
|
)
|
||||||
message.status = "done"
|
message.status = "done"
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
@ -461,7 +533,9 @@ class Agent(BaseModel, ABC):
|
|||||||
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
|
||||||
return
|
return
|
||||||
|
|
||||||
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
|
async def process_message(
|
||||||
|
self, llm: Any, model: str, message: Message
|
||||||
|
) -> AsyncGenerator[Message, None]:
|
||||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
self.metrics.process_count.labels(agent=self.agent_type).inc()
|
self.metrics.process_count.labels(agent=self.agent_type).inc()
|
||||||
@ -470,22 +544,30 @@ class Agent(BaseModel, ABC):
|
|||||||
if not self.context:
|
if not self.context:
|
||||||
raise ValueError("Context is not set for this agent.")
|
raise ValueError("Context is not set for this agent.")
|
||||||
|
|
||||||
logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
|
logger.info(
|
||||||
spinner: List[str] = ['\\', '|', '/', '-']
|
"TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
|
||||||
tick : int = 0
|
)
|
||||||
|
spinner: List[str] = ["\\", "|", "/", "-"]
|
||||||
|
tick: int = 0
|
||||||
while self.context.processing:
|
while self.context.processing:
|
||||||
message.status = "waiting"
|
message.status = "waiting"
|
||||||
message.response = f"Busy processing another request. Please wait. {spinner[tick]}"
|
message.response = (
|
||||||
|
f"Busy processing another request. Please wait. {spinner[tick]}"
|
||||||
|
)
|
||||||
tick = (tick + 1) % len(spinner)
|
tick = (tick + 1) % len(spinner)
|
||||||
yield message
|
yield message
|
||||||
await asyncio.sleep(1) # Allow the event loop to process the write
|
await asyncio.sleep(1) # Allow the event loop to process the write
|
||||||
|
|
||||||
self.context.processing = True
|
self.context.processing = True
|
||||||
|
|
||||||
message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
|
message.metadata["system_prompt"] = (
|
||||||
|
f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
|
||||||
|
)
|
||||||
message.context_prompt = ""
|
message.context_prompt = ""
|
||||||
for p in message.preamble.keys():
|
for p in message.preamble.keys():
|
||||||
message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
|
message.context_prompt += (
|
||||||
|
f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
|
||||||
|
)
|
||||||
message.context_prompt += f"{message.prompt}"
|
message.context_prompt += f"{message.prompt}"
|
||||||
|
|
||||||
# Estimate token length of new messages
|
# Estimate token length of new messages
|
||||||
@ -493,13 +575,17 @@ class Agent(BaseModel, ABC):
|
|||||||
message.status = "thinking"
|
message.status = "thinking"
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)
|
message.metadata["context_size"] = self.set_optimal_context_size(
|
||||||
|
llm, model, prompt=message.context_prompt
|
||||||
|
)
|
||||||
|
|
||||||
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
|
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
|
||||||
message.status = "thinking"
|
message.status = "thinking"
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
async for message in self.generate_llm_response(llm=llm, model=model, message=message):
|
async for message in self.generate_llm_response(
|
||||||
|
llm=llm, model=model, message=message
|
||||||
|
):
|
||||||
# logger.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
|
# logger.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
|
||||||
if message.status == "error":
|
if message.status == "error":
|
||||||
yield message
|
yield message
|
||||||
@ -514,6 +600,6 @@ class Agent(BaseModel, ABC):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
# Register the base agent
|
# Register the base agent
|
||||||
agent_registry.register(Agent._agent_type, Agent)
|
agent_registry.register(Agent._agent_type, Agent)
|
||||||
|
|
||||||
|
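Aside: nearly every hunk in this file is the same mechanical transformation. black splits any statement that overflows its line-length limit (88 columns by default; this repository's configured limit is not visible in the diff) and adds trailing commas to multi-line literals. The recurring subscript pattern, taken verbatim from the timer hunks above:

    # Before black: one long assignment through a chained subscript.
    message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"

    # After black: the final subscript is exploded so every line fits the limit.
    message.metadata["timers"][
        "llm_with_tools"
    ] = f"{(end_time - start_time):.4f}"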
@ -3,9 +3,10 @@ from typing import Literal, AsyncGenerator, ClassVar, Optional, Any
 from datetime import datetime
 import inspect

-from . base import Agent, agent_registry
-from .. message import Message
-from .. setup_logging import setup_logging
+from .base import Agent, agent_registry
+from ..message import Message
+from ..setup_logging import setup_logging

 logger = setup_logging()

 system_message = f"""
@ -26,35 +27,42 @@ When answering queries, follow these steps:
 Always use tools, <|resume|>, and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so.
 """


 class Chat(Agent):
     """
     Chat Agent
     """
-    agent_type: Literal["chat"] = "chat" # type: ignore
-    _agent_type: ClassVar[str] = agent_type # Add this for registration
+
+    agent_type: Literal["chat"] = "chat"  # type: ignore
+    _agent_type: ClassVar[str] = agent_type  # Add this for registration

     system_prompt: str = system_message

-    async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
+    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
         logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
         if not self.context:
             raise ValueError("Context is not set for this agent.")

         async for message in super().prepare_message(message):
             if message.status != "done":
                 yield message

         if message.preamble:
             excluded = {}
-            preamble_types = [f"<|{p}|>" for p in message.preamble.keys() if p not in excluded]
+            preamble_types = [
+                f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
+            ]
             preamble_types_AND = " and ".join(preamble_types)
             preamble_types_OR = " or ".join(preamble_types)
-            message.preamble["rules"] = f"""\
+            message.preamble[
+                "rules"
+            ] = f"""\
 - Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
 - If there is no information in these sections, answer based on your knowledge, or use any available tools.
 - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
 """
             message.preamble["question"] = "Respond to:"


 # Register the base agent
 agent_registry.register(Chat._agent_type, Chat)
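Aside: to make the <|...|> plumbing concrete, `process_message` in the base-class diff above renders each preamble key as a tagged section before appending the user prompt. A hand-worked example, assuming a preamble containing just the `rules` and `question` entries built here and the prompt `Hi!` (note the closing tag is emitted as `</|rules>`, without the trailing pipe, exactly as the f-string writes it):

    <|rules|>
    - Answer the question based on the information provided in the ...
    </|rules>

    <|question|>
    Respond to:
    </|question>

    Hi!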
@ -1,13 +1,21 @@
 from __future__ import annotations
 from pydantic import model_validator # type: ignore
-from typing import Literal, ClassVar, Optional, Any, AsyncGenerator, List # NOTE: You must import Optional for late binding to work
+from typing import (
+    Literal,
+    ClassVar,
+    Optional,
+    Any,
+    AsyncGenerator,
+    List,
+)  # NOTE: You must import Optional for late binding to work
 from datetime import datetime
 import inspect

-from . base import Agent, agent_registry
-from .. conversation import Conversation
-from .. message import Message
-from .. setup_logging import setup_logging
+from .base import Agent, agent_registry
+from ..conversation import Conversation
+from ..message import Message
+from ..setup_logging import setup_logging

 logger = setup_logging()

 system_fact_check = f"""
@ -21,50 +29,56 @@ When answering queries, follow these steps:
 - Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|generated-resume|>, or <|resume|> tags.
 """.strip()


 class FactCheck(Agent):
     agent_type: Literal["fact_check"] = "fact_check" # type: ignore
     _agent_type: ClassVar[str] = agent_type # Add this for registration

-    system_prompt:str = system_fact_check
+    system_prompt: str = system_fact_check
     facts: str

     @model_validator(mode="after")
     def validate_facts(self):
         if not self.facts.strip():
             raise ValueError("Facts cannot be empty")
         return self

-    async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
+    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
         logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
         if not self.context:
             raise ValueError("Context is not set for this agent.")

         resume_agent = self.context.get_agent("resume")
         if not resume_agent:
             raise ValueError("resume agent does not exist")

         message.tunables.enable_tools = False

         async for message in super().prepare_message(message):
             if message.status != "done":
                 yield message

         message.preamble["generated-resume"] = resume_agent.resume
         message.preamble["discrepancies"] = self.facts

         excluded = {"job_description"}
-        preamble_types = [f"<|{p}|>" for p in message.preamble.keys() if p not in excluded]
+        preamble_types = [
+            f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
+        ]
         preamble_types_AND = " and ".join(preamble_types)
         preamble_types_OR = " or ".join(preamble_types)
-        message.preamble["rules"] = f"""\
+        message.preamble[
+            "rules"
+        ] = f"""\
 - Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
 - If there is no information in these sections, answer based on your knowledge, or use any available tools.
 - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
 """
         message.preamble["question"] = "Respond to:"

-
-        yield message
-        return
+        yield message
+        return
+

 # Register the base agent
 agent_registry.register(FactCheck._agent_type, FactCheck)
File diff suppressed because it is too large
@ -1,12 +1,20 @@
 from __future__ import annotations
 from pydantic import model_validator # type: ignore
-from typing import Literal, ClassVar, Optional, Any, AsyncGenerator, List # NOTE: You must import Optional for late binding to work
+from typing import (
+    Literal,
+    ClassVar,
+    Optional,
+    Any,
+    AsyncGenerator,
+    List,
+)  # NOTE: You must import Optional for late binding to work
 from datetime import datetime
 import inspect

-from . base import Agent, agent_registry
-from .. message import Message
-from .. setup_logging import setup_logging
+from .base import Agent, agent_registry
+from ..message import Message
+from ..setup_logging import setup_logging

 logger = setup_logging()

 system_fact_check = f"""
@ -36,82 +44,94 @@ When answering queries, follow these steps:
 - Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|job_description|>, <|resume|>, or <|context|> tags.
 """.strip()


 class Resume(Agent):
     agent_type: Literal["resume"] = "resume" # type: ignore
     _agent_type: ClassVar[str] = agent_type # Add this for registration

-    system_prompt:str = system_fact_check
+    system_prompt: str = system_fact_check
     resume: str

     @model_validator(mode="after")
     def validate_resume(self):
         if not self.resume.strip():
             raise ValueError("Resume content cannot be empty")
         return self

-    async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
+    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
         logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
         if not self.context:
             raise ValueError("Context is not set for this agent.")

         # Generating fact check or resume should not use any tools
         message.tunables.enable_tools = False

         async for message in super().prepare_message(message):
             if message.status != "done":
                 yield message

         message.preamble["generated-resume"] = self.resume
         job_description_agent = self.context.get_agent("job_description")
         if not job_description_agent:
             raise ValueError("job_description agent does not exist")

         message.preamble["job_description"] = job_description_agent.job_description

         excluded = {}
-        preamble_types = [f"<|{p}|>" for p in message.preamble.keys() if p not in excluded]
+        preamble_types = [
+            f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
+        ]
         preamble_types_AND = " and ".join(preamble_types)
         preamble_types_OR = " or ".join(preamble_types)
-        message.preamble["rules"] = f"""\
+        message.preamble[
+            "rules"
+        ] = f"""\
 - Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
 - If there is no information in these sections, answer based on your knowledge, or use any available tools.
 - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
 """
         fact_check_agent = self.context.get_agent(agent_type="fact_check")
         if fact_check_agent:
             message.preamble["question"] = "Respond to:"
         else:
-            message.preamble["question"] = f"Fact check the <|generated-resume|> based on the <|resume|>{' and <|context|>' if 'context' in message.preamble else ''}."
+            message.preamble["question"] = (
+                f"Fact check the <|generated-resume|> based on the <|resume|>{' and <|context|>' if 'context' in message.preamble else ''}."
+            )

         yield message
         return

-    async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
+    async def process_message(
+        self, llm: Any, model: str, message: Message
+    ) -> AsyncGenerator[Message, None]:
         logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
         if not self.context:
             raise ValueError("Context is not set for this agent.")

         async for message in super().process_message(llm, model, message):
             if message.status != "done":
                 yield message

         fact_check_agent = self.context.get_agent(agent_type="fact_check")
         if not fact_check_agent:
             # Switch agent from "Fact Check Generated Resume" mode
             # to "Answer Questions about Generated Resume"
             self.system_prompt = system_resume

             # Instantiate the "resume" agent, and seed (or reset) its conversation
             # with this message.
-            fact_check_agent = self.context.get_or_create_agent(agent_type="fact_check", facts=message.response)
+            fact_check_agent = self.context.get_or_create_agent(
+                agent_type="fact_check", facts=message.response
+            )
             first_fact_check_message = message.copy()
             first_fact_check_message.prompt = "Fact check the generated resume."
             fact_check_agent.conversation.add(first_fact_check_message)
             message.response = "Resume fact checked."

         # Return the final message
         yield message
         return


 # Register the base agent
 agent_registry.register(Resume._agent_type, Resume)
@ -1,31 +1,34 @@
 from __future__ import annotations
 from typing import List, Dict, Optional, Type


 # We'll use a registry pattern rather than hardcoded strings
 class AgentRegistry:
     """Registry for agent types and classes"""
+
     _registry: Dict[str, Type] = {}

     @classmethod
     def register(cls, agent_type: str, agent_class: Type) -> Type:
         """Register an agent class with its type"""
         cls._registry[agent_type] = agent_class
         return agent_class

     @classmethod
     def get_class(cls, agent_type: str) -> Optional[Type]:
         """Get the class for a given agent type"""
         return cls._registry.get(agent_type)

     @classmethod
     def get_types(cls) -> List[str]:
         """Get all registered agent types"""
         return list(cls._registry.keys())

     @classmethod
     def get_classes(cls) -> Dict[str, Type]:
         """Get all registered agent classes"""
         return cls._registry.copy()


 # Create a singleton instance
 agent_registry = AgentRegistry()
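Aside: a minimal sketch of how the registry above is used. The `register` side appears at the bottom of each agent module in this diff; the lookup side is an assumption about calling code, which is not shown in the commit:

    # Each agent module registers itself at import time:
    #     agent_registry.register(Chat._agent_type, Chat)
    ChatCls = agent_registry.get_class("chat")
    if ChatCls is None:
        raise ValueError(f"unknown agent type; known: {agent_registry.get_types()}")
    # Instantiating ChatCls requires whatever fields the agent declares
    # (context, system_prompt, ...); those arguments are hypothetical here.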
@ -1,122 +0,0 @@
-import chromadb
-from typing import List, Dict, Any, Union
-from . import defines
-from .chunk import chunk_document
-import ollama
-
-def init_chroma_client(persist_directory: str = defines.persist_directory):
-    """Initialize and return a ChromaDB client."""
-    # return chromadb.PersistentClient(path=persist_directory)
-    return chromadb.Client()
-
-def create_or_get_collection(db: chromadb.Client, collection_name: str):
-    """Create or get a ChromaDB collection."""
-    try:
-        return db.get_collection(
-            name=collection_name
-        )
-    except:
-        return db.create_collection(
-            name=collection_name,
-            metadata={"hnsw:space": "cosine"}
-        )
-
-def process_documents_to_chroma(
-    client: ollama.Client,
-    documents: List[Dict[str, Any]],
-    collection_name: str = "document_collection",
-    text_key: str = "text",
-    max_tokens: int = 512,
-    overlap: int = 50,
-    model: str = defines.encoding_model,
-    persist_directory: str = defines.persist_directory
-):
-    """
-    Process documents, chunk them, compute embeddings, and store in ChromaDB.
-
-    Args:
-        documents: List of document dictionaries
-        collection_name: Name for the ChromaDB collection
-        text_key: The key containing text content
-        max_tokens: Maximum tokens per chunk
-        overlap: Token overlap between chunks
-        model: Ollama model for embeddings
-        persist_directory: Directory to store ChromaDB data
-    """
-    # Initialize ChromaDB client and collection
-    db = init_chroma_client(persist_directory)
-    collection = create_or_get_collection(db, collection_name)
-
-    # Process each document
-    for doc in documents:
-        # Chunk the document
-        doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)
-
-        # Prepare data for ChromaDB
-        ids = []
-        texts = []
-        metadatas = []
-        embeddings = []
-
-        for chunk in doc_chunks:
-            # Create a unique ID for the chunk
-            chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"
-
-            # Extract text
-            text = chunk[text_key]
-
-            # Create metadata (excluding text and embedding to avoid duplication)
-            metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}
-
-            response = client.embed(model=model, input=text)
-            embedding = response["embeddings"][0]
-            ids.append(chunk_id)
-            texts.append(text)
-            metadatas.append(metadata)
-            embeddings.append(embedding)
-
-        # Add chunks to ChromaDB collection
-        collection.add(
-            ids=ids,
-            documents=texts,
-            embeddings=embeddings,
-            metadatas=metadatas
-        )
-
-    return collection
-
-def query_chroma(
-    client: ollama.Client,
-    query_text: str,
-    collection_name: str = "document_collection",
-    n_results: int = 5,
-    model: str = defines.encoding_model,
-    persist_directory: str = defines.persist_directory
-):
-    """
-    Query ChromaDB for similar documents.
-
-    Args:
-        query_text: The text to search for
-        collection_name: Name of the ChromaDB collection
-        n_results: Number of results to return
-        model: Ollama model for embedding the query
-        persist_directory: Directory where ChromaDB data is stored
-
-    Returns:
-        Query results from ChromaDB
-    """
-    # Initialize ChromaDB client and collection
-    db = init_chroma_client(persist_directory)
-    collection = create_or_get_collection(db, collection_name)
-
-    query_response = client.embed(model=model, input=query_text)
-    query_embeddings = query_response["embeddings"]
-
-    # Query the collection
-    results = collection.query(
-        query_embeddings=query_embeddings,
-        n_results=n_results
-    )
-
-    return results
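Aside: a sketch of how the helpers in this removed file would have been called, matching the signatures above. The document contents are illustrative only, and `defines.ollama_api_url` comes from the defines.py diff further down:

    # Hypothetical call sequence; values are placeholders.
    client = ollama.Client(host=defines.ollama_api_url)
    docs = [{"id": "resume", "text": "full document text here"}]
    collection = process_documents_to_chroma(client, documents=docs)
    results = query_chroma(client, query_text="job history", n_results=5)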
@ -1,88 +0,0 @@
-import tiktoken # type: ignore
-from . import defines
-from typing import List, Dict, Any, Union
-
-def get_encoding(model=defines.model):
-    """Get the tokenizer for counting tokens."""
-    try:
-        return tiktoken.get_encoding("cl100k_base") # Default encoding used by many embedding models
-    except:
-        return tiktoken.encoding_for_model(model)
-
-def count_tokens(text: str) -> int:
-    """Count the number of tokens in a text string."""
-    encoding = get_encoding()
-    return len(encoding.encode(text))
-
-def chunk_text(text: str, max_tokens: int = 512, overlap: int = 50) -> List[str]:
-    """
-    Split a text into chunks based on token count with overlap between chunks.
-
-    Args:
-        text: The text to split into chunks
-        max_tokens: Maximum number of tokens per chunk
-        overlap: Number of tokens to overlap between chunks
-
-    Returns:
-        List of text chunks
-    """
-    if not text or max_tokens <= 0:
-        return []
-
-    encoding = get_encoding()
-    tokens = encoding.encode(text)
-    chunks = []
-
-    i = 0
-    while i < len(tokens):
-        # Get the current chunk of tokens
-        chunk_end = min(i + max_tokens, len(tokens))
-        chunk_tokens = tokens[i:chunk_end]
-        chunks.append(encoding.decode(chunk_tokens))
-
-        # Move to the next position with overlap
-        if chunk_end == len(tokens):
-            break
-        i += max_tokens - overlap
-
-    return chunks
-
-def chunk_document(document: Dict[str, Any],
-                   text_key: str = "text",
-                   max_tokens: int = 512,
-                   overlap: int = 50) -> List[Dict[str, Any]]:
-    """
-    Chunk a document dictionary into multiple chunks.
-
-    Args:
-        document: Document dictionary with metadata and text
-        text_key: The key in the document that contains the text to chunk
-        max_tokens: Maximum number of tokens per chunk
-        overlap: Number of tokens to overlap between chunks
-
-    Returns:
-        List of document dictionaries, each with chunked text and preserved metadata
-    """
-    if text_key not in document:
-        raise Exception(f"{text_key} not in document")
-
-    # Extract text and create chunks
-    if "title" in document:
-        text = f"{document["title"]}: {document[text_key]}"
-    else:
-        text = document[text_key]
-    chunks = chunk_text(text, max_tokens, overlap)
-
-    # Create document chunks with preserved metadata
-    chunked_docs = []
-    for i, chunk in enumerate(chunks):
-        # Create a new doc with all original fields
-        doc_chunk = document.copy()
-        # Replace text with the chunk
-        doc_chunk[text_key] = chunk
-        # Add chunk metadata
-        doc_chunk["chunk_id"] = i
-        doc_chunk["chunk_total"] = len(chunks)
-        chunked_docs.append(doc_chunk)
-
-    return chunked_docs
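Aside: the stride arithmetic in the removed `chunk_text` is worth spelling out. With the defaults `max_tokens=512` and `overlap=50`, the window advances by `512 - 50 = 462` tokens per iteration, so consecutive chunks share 50 tokens. For a 1000-token text the boundaries work out to (computed by hand from the loop above):

    # chunk 0: tokens [0, 512)
    # chunk 1: tokens [462, 974)
    # chunk 2: tokens [924, 1000)  -> chunk_end == len(tokens), loop breaks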
@ -1,32 +1,35 @@
 from __future__ import annotations
-from pydantic import BaseModel, Field, model_validator# type: ignore
+from pydantic import BaseModel, Field, model_validator  # type: ignore
 from uuid import uuid4
 from typing import List, Optional, Generator, ClassVar, Any
 from typing_extensions import Annotated, Union
 import numpy as np # type: ignore
 import logging
 from uuid import uuid4
 from prometheus_client import CollectorRegistry, Counter # type: ignore

-from . message import Message, Tunables
-from . rag import ChromaDBFileWatcher
+from .message import Message, Tunables
+from .rag import ChromaDBFileWatcher
 from . import defines
 from . import tools as Tools
-from . agents import AnyAgent
+from .agents import AnyAgent

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


 class Context(BaseModel):
     model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
     # Required fields
     file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
-    prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
+    prometheus_collector: Optional[CollectorRegistry] = Field(
+        default=None, exclude=True
+    )

     # Optional fields
     id: str = Field(
         default_factory=lambda: str(uuid4()),
-        pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
+        pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
     )
     user_resume: Optional[str] = None
     user_job_description: Optional[str] = None
@ -35,7 +38,7 @@ class Context(BaseModel):
     rags: List[dict] = []
     message_history_length: int = 5
     # Class managed fields
     agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
         default_factory=list
     )

@ -53,12 +56,16 @@ class Context(BaseModel):
         logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
         agent_types = [agent.agent_type for agent in self.agents]
         if len(agent_types) != len(set(agent_types)):
-            raise ValueError("Context cannot contain multiple agents of the same agent_type")
+            raise ValueError(
+                "Context cannot contain multiple agents of the same agent_type"
+            )
         # for agent in self.agents:
         #     agent.set_context(self)
         return self

-    def generate_rag_results(self, message: Message) -> Generator[Message, None, None]:
+    def generate_rag_results(
+        self, message: Message, top_k=10, threshold=0.7
+    ) -> Generator[Message, None, None]:
         """
         Generate RAG results for the given query.

@ -71,7 +78,7 @@ class Context(BaseModel):
         try:
             message.status = "processing"

-            entries : int = 0
+            entries: int = 0

             if not self.file_watcher:
                 message.response = "No RAG context available."
@ -86,41 +93,57 @@ class Context(BaseModel):
                     continue
                 message.response = f"Checking RAG context {rag['name']}..."
                 yield message
-                chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10, threshold=0.7)
+                chroma_results = self.file_watcher.find_similar(
+                    query=message.prompt, top_k=top_k, threshold=threshold
+                )
                 if chroma_results:
                     entries += len(chroma_results["documents"])

-                    chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
+                    chroma_embedding = np.array(
+                        chroma_results["query_embedding"]
+                    ).flatten()  # Ensure correct shape
                     print(f"Chroma embedding shape: {chroma_embedding.shape}")

-                    umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                    print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+                    umap_2d = self.file_watcher.umap_model_2d.transform(
+                        [chroma_embedding]
+                    )[0].tolist()
+                    print(
+                        f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
+                    )  # Debug output

-                    umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                    print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+                    umap_3d = self.file_watcher.umap_model_3d.transform(
+                        [chroma_embedding]
+                    )[0].tolist()
+                    print(
+                        f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
+                    )  # Debug output

-                    message.metadata["rag"].append({
-                        "name": rag["name"],
-                        **chroma_results,
-                        "umap_embedding_2d": umap_2d,
-                        "umap_embedding_3d": umap_3d
-                    })
+                    message.metadata["rag"].append(
+                        {
+                            "name": rag["name"],
+                            **chroma_results,
+                            "umap_embedding_2d": umap_2d,
+                            "umap_embedding_3d": umap_3d,
+                        }
+                    )
                     message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
                     yield message

             if entries == 0:
                 del message.metadata["rag"]

-            message.response = f"RAG context gathered from results from {entries} documents."
+            message.response = (
+                f"RAG context gathered from results from {entries} documents."
+            )
             message.status = "done"
             yield message
             return
         except Exception as e:
             message.status = "error"
             message.response = f"Error generating RAG results: {str(e)}"
             logger.error(e)
             yield message
             return

     def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
         """
@ -156,7 +179,9 @@ class Context(BaseModel):
     def add_agent(self, agent: AnyAgent) -> None:
         """Add a Agent to the context, ensuring no duplicate agent_type."""
         if any(s.agent_type == agent.agent_type for s in self.agents):
-            raise ValueError(f"A agent with agent_type '{agent.agent_type}' already exists")
+            raise ValueError(
+                f"A agent with agent_type '{agent.agent_type}' already exists"
+            )
         self.agents.append(agent)

     def get_agent(self, agent_type: str) -> Agent | None:
@ -188,5 +213,7 @@ class Context(BaseModel):
             summary += f"\nChat Name: {agent.name}\n"
         return summary

-from . agents import Agent
+
+from .agents import Agent

 Context.model_rebuild()
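Aside: the `generate_rag_results` hunk is not purely cosmetic. The hard-coded `top_k=10, threshold=0.7` become keyword parameters with the same defaults, so callers can now tune retrieval per query. A sketch, assuming a `Message` constructed with just a prompt (its other fields are not shown in this commit):

    msg = Message(prompt="What databases has the candidate used?")
    for update in context.generate_rag_results(msg, top_k=5, threshold=0.8):
        print(update.status, update.response)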
@ -1,7 +1,8 @@
 from pydantic import BaseModel, Field, PrivateAttr # type: ignore
 from typing import List
 from .message import Message


 class Conversation(BaseModel):
     Conversation_messages: List[Message] = Field(default=[], alias="messages")

@ -12,24 +13,28 @@ class Conversation(BaseModel):
         return iter(self.Conversation_messages)

     def reset(self):
         self.Conversation_messages = []

     @property
     def messages(self):
         """Return a copy of messages to prevent modification of the internal list."""
-        raise AttributeError("Cannot directly get messages. Use Conversation.add() or .reset()")
+        raise AttributeError(
+            "Cannot directly get messages. Use Conversation.add() or .reset()"
+        )

     @messages.setter
     def messages(self, value):
         """Control how messages can be set, or prevent setting altogether."""
-        raise AttributeError("Cannot directly set messages. Use Conversation.add() or .reset()")
+        raise AttributeError(
+            "Cannot directly set messages. Use Conversation.add() or .reset()"
+        )

     def add(self, message: Message | List[Message]) -> None:
         """Add a Message(s) to the conversation."""
         if isinstance(message, Message):
             self.Conversation_messages.append(message)
         else:
             self.Conversation_messages.extend(message)

     def get_summary(self) -> str:
         """Return a summary of the conversation."""
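Aside: the property/setter pair above makes direct access to `messages` fail loudly; the supported mutations are `add()` and `reset()`, and reads go through iteration. A sketch (the `Message` values are placeholders):

    conv = Conversation()
    conv.add(message)          # a single Message is appended
    conv.add([msg_a, msg_b])   # a list of Messages is extended
    for m in conv:             # reading works via __iter__
        ...
    conv.messages              # raises AttributeError by design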
@ -1,15 +1,15 @@
 import os

-ollama_api_url="http://ollama:11434" # Default Ollama local endpoint
-#model = "deepseek-r1:7b" # Tool calls don"t work
-#model="mistral:7b" # Tool calls don"t work
-#model = "llama3.2"
-#model = "qwen3:8b" # Requires newer ollama
-#model = "gemma3:4b" # Requires newer ollama
+ollama_api_url = "http://ollama:11434"  # Default Ollama local endpoint
+# model = "deepseek-r1:7b" # Tool calls don"t work
+# model="mistral:7b" # Tool calls don"t work
+# model = "llama3.2"
+# model = "qwen3:8b" # Requires newer ollama
+# model = "gemma3:4b" # Requires newer ollama
 model = os.getenv("MODEL_NAME", "qwen2.5:7b")
 embedding_model = os.getenv("EMBEDDING_MODEL_NAME", "mxbai-embed-large")
 persist_directory = os.getenv("PERSIST_DIR", "/opt/backstory/chromadb")
-max_context = 2048*8*2
+max_context = 2048 * 8 * 2
 doc_dir = "/opt/backstory/docs/"
 context_dir = "/opt/backstory/sessions"
 static_content = "/opt/backstory/frontend/deployed"
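Aside: `max_context = 2048 * 8 * 2` still evaluates to 32768 tokens; black only inserts the spaces around the operators, the value is unchanged.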
@ -1,468 +0,0 @@
-import requests
-from typing import List, Dict, Any, Union
-import tiktoken
-import feedparser
-import logging as log
-import datetime
-from bs4 import BeautifulSoup
-import chromadb
-import ollama
-import re
-import numpy as np
-from . import chunk
-
-OLLAMA_API_URL = "http://ollama:11434" # Default Ollama local endpoint
-#MODEL_NAME = "deepseek-r1:1.5b"
-MODEL_NAME = "deepseek-r1:7b"
-EMBED_MODEL = "mxbai-embed-large"
-PERSIST_DIRECTORY = "/root/.cache/chroma"
-
-client = ollama.Client(host=OLLAMA_API_URL)
-
-def extract_text_from_html_or_xml(content, is_xml=False):
-    # Parse the content
-    if is_xml:
-        soup = BeautifulSoup(content, 'xml') # Use 'xml' parser for XML content
-    else:
-        soup = BeautifulSoup(content, 'html.parser') # Default to 'html.parser' for HTML content
-
-    # Extract and return just the text
-    return soup.get_text()
-
-class Feed():
-    def __init__(self, name, url, poll_limit_min = 30, max_articles=5):
-        self.name = name
-        self.url = url
-        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
-        self.last_poll = None
-        self.articles = []
-        self.max_articles = max_articles
-        self.update()
-
-    def update(self):
-        now = datetime.datetime.now()
-        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
-            log.info(f"Updating {self.name}")
-            feed = feedparser.parse(self.url)
-            self.articles = []
-            self.last_poll = now
-
-            if len(feed.entries) == 0:
-                return
-
-            for i, entry in enumerate(feed.entries[:self.max_articles]):
-                content = {}
-                content['source'] = self.name
-                content['id'] = f"{self.name}{i}"
-                title = entry.get("title")
-                if title:
-                    content['title'] = title
-                link = entry.get("link")
-                if link:
-                    content['link'] = link
-                text = entry.get("summary")
-                if text:
-                    content['text'] = extract_text_from_html_or_xml(text, False)
-                else:
-                    continue
-                published = entry.get("published")
-                if published:
-                    content['published'] = published
-
-                self.articles.append(content)
-        else:
-            log.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)}s remain to refresh.")
-        return self.articles
-
-# News RSS Feeds
-rss_feeds = [
-    Feed(name="IGN.com", url="https://feeds.feedburner.com/ign/games-all"),
-    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
-    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
-    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
-    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
-    Feed(name="Time", url="https://time.com/feed/"),
-    Feed(name="Euronews", url="https://www.euronews.com/rss"),
-    # Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
-]
-
-
-def init_chroma_client(persist_directory: str = PERSIST_DIRECTORY):
-    """Initialize and return a ChromaDB client."""
-    # return chromadb.PersistentClient(path=persist_directory)
-    return chromadb.Client()
-
-def create_or_get_collection(client, collection_name: str):
-    """Create or get a ChromaDB collection."""
-    try:
-        return client.get_collection(
-            name=collection_name
-        )
-    except:
-        return client.create_collection(
-            name=collection_name,
-            metadata={"hnsw:space": "cosine"}
-        )
-
-def process_documents_to_chroma(
-    documents: List[Dict[str, Any]],
-    collection_name: str = "document_collection",
-    text_key: str = "text",
-    max_tokens: int = 512,
-    overlap: int = 50,
-    model: str = EMBED_MODEL,
-    persist_directory: str = PERSIST_DIRECTORY
-):
-    """
-    Process documents, chunk them, compute embeddings, and store in ChromaDB.
-
-    Args:
-        documents: List of document dictionaries
-        collection_name: Name for the ChromaDB collection
-        text_key: The key containing text content
-        max_tokens: Maximum tokens per chunk
-        overlap: Token overlap between chunks
-        model: Ollama model for embeddings
-        persist_directory: Directory to store ChromaDB data
-    """
-    # Initialize ChromaDB client and collection
-    db = init_chroma_client(persist_directory)
-    collection = create_or_get_collection(db, collection_name)
-
-    # Process each document
-    for doc in documents:
-        # Chunk the document
-        doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)
-
-        # Prepare data for ChromaDB
-        ids = []
-        texts = []
-        metadatas = []
-        embeddings = []
-
-        for chunk in doc_chunks:
-            # Create a unique ID for the chunk
-            chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"
-
-            # Extract text
-            text = chunk[text_key]
-
-            # Create metadata (excluding text and embedding to avoid duplication)
-            metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}
-
-            response = client.embed(model=model, input=text)
-            embedding = response["embeddings"][0]
-            ids.append(chunk_id)
-            texts.append(text)
-            metadatas.append(metadata)
-            embeddings.append(embedding)
-
-        # Add chunks to ChromaDB collection
-        collection.add(
-            ids=ids,
documents=texts,
|
|
||||||
embeddings=embeddings,
|
|
||||||
metadatas=metadatas
|
|
||||||
)
|
|
||||||
|
|
||||||
return collection
|
|
||||||
|
|
||||||
def query_chroma(
|
|
||||||
query_text: str,
|
|
||||||
collection_name: str = "document_collection",
|
|
||||||
n_results: int = 5,
|
|
||||||
model: str = EMBED_MODEL,
|
|
||||||
persist_directory: str = PERSIST_DIRECTORY
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Query ChromaDB for similar documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_text: The text to search for
|
|
||||||
collection_name: Name of the ChromaDB collection
|
|
||||||
n_results: Number of results to return
|
|
||||||
model: Ollama model for embedding the query
|
|
||||||
persist_directory: Directory where ChromaDB data is stored
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Query results from ChromaDB
|
|
||||||
"""
|
|
||||||
# Initialize ChromaDB client and collection
|
|
||||||
db = init_chroma_client(persist_directory)
|
|
||||||
collection = create_or_get_collection(db, collection_name)
|
|
||||||
|
|
||||||
query_response = client.embed(model=model, input=query_text)
|
|
||||||
query_embeddings = query_response["embeddings"]
|
|
||||||
|
|
||||||
# Query the collection
|
|
||||||
results = collection.query(
|
|
||||||
query_embeddings=query_embeddings,
|
|
||||||
n_results=n_results
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def print_top_match(query_results, index=0, documents=None):
|
|
||||||
"""
|
|
||||||
Print detailed information about the top matching document,
|
|
||||||
including the full original document content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_results: Results from ChromaDB query
|
|
||||||
documents: Original documents dictionary to look up full content (optional)
|
|
||||||
"""
|
|
||||||
if not query_results or not query_results["ids"] or len(query_results["ids"][0]) == 0:
|
|
||||||
print("No matching documents found.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get the top result
|
|
||||||
top_id = query_results["ids"][0][index]
|
|
||||||
top_document_chunk = query_results["documents"][0][index]
|
|
||||||
top_metadata = query_results["metadatas"][0][index]
|
|
||||||
top_distance = query_results["distances"][0][index]
|
|
||||||
|
|
||||||
print("="*50)
|
|
||||||
print("MATCHING DOCUMENT")
|
|
||||||
print("="*50)
|
|
||||||
print(f"Chunk ID: {top_id}")
|
|
||||||
print(f"Similarity Score: {top_distance:.4f}") # Convert distance to similarity
|
|
||||||
|
|
||||||
print("\nCHUNK METADATA:")
|
|
||||||
for key, value in top_metadata.items():
|
|
||||||
print(f" {key}: {value}")
|
|
||||||
|
|
||||||
print("\nMATCHING CHUNK CONTENT:")
|
|
||||||
print(top_document_chunk[:500].strip() + ("..." if len(top_document_chunk) > 500 else ""))
|
|
||||||
|
|
||||||
# Extract the original document ID from the chunk ID
|
|
||||||
# Chunk IDs are in format "doc_id_chunk_num"
|
|
||||||
original_doc_id = top_id.split('_')[0]
|
|
||||||
|
|
||||||
def get_top_match(query_results, index=0, documents=None):
|
|
||||||
top_id = query_results["ids"][index][0]
|
|
||||||
# Extract the original document ID from the chunk ID
|
|
||||||
# Chunk IDs are in format "doc_id_chunk_num"
|
|
||||||
original_doc_id = top_id.split('_')[0]
|
|
||||||
|
|
||||||
# Return the full document for further processing if needed
|
|
||||||
if documents is not None:
|
|
||||||
return next((doc for doc in documents if doc["id"] == original_doc_id), None)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def show_documents(documents=None):
|
|
||||||
if not documents:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Print the top matching document
|
|
||||||
for i, doc in enumerate(documents):
|
|
||||||
print(f"Document {i+1}:")
|
|
||||||
print(f" Title: {doc['title']}")
|
|
||||||
print(f" Text: {doc['text'][:100]}...")
|
|
||||||
print()
|
|
||||||
|
|
||||||
def show_headlines(documents=None):
|
|
||||||
if not documents:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Print the top matching document
|
|
||||||
for doc in documents:
|
|
||||||
print(f"{doc['source']}: {doc['title']}")
|
|
||||||
|
|
||||||
def show_help():
|
|
||||||
print("""help>
|
|
||||||
docs Show RAG docs
|
|
||||||
full Show last full top match
|
|
||||||
headlines Show the RAG headlines
|
|
||||||
prompt Show the last prompt
|
|
||||||
response Show the last response
|
|
||||||
scores Show last RAG scores
|
|
||||||
why|think Show last response's <think>
|
|
||||||
context|match Show RAG match info to last prompt
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage
|
|
||||||
if __name__ == "__main__":
|
|
||||||
documents = []
|
|
||||||
for feed in rss_feeds:
|
|
||||||
documents.extend(feed.articles)
|
|
||||||
|
|
||||||
show_documents(documents=documents)
|
|
||||||
|
|
||||||
# Process documents and store in ChromaDB
|
|
||||||
collection = process_documents_to_chroma(
|
|
||||||
documents=documents,
|
|
||||||
collection_name="research_papers",
|
|
||||||
max_tokens=256,
|
|
||||||
overlap=25,
|
|
||||||
model=EMBED_MODEL,
|
|
||||||
persist_directory="/root/.cache/chroma"
|
|
||||||
)
|
|
||||||
|
|
||||||
last_results = None
|
|
||||||
last_prompt = None
|
|
||||||
last_system = None
|
|
||||||
last_response = None
|
|
||||||
last_why = None
|
|
||||||
last_messages = []
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
search_query = input("> ").strip()
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
print("\nExiting.")
|
|
||||||
break
|
|
||||||
|
|
||||||
if search_query == "exit" or search_query == "quit":
|
|
||||||
print("\nExiting.")
|
|
||||||
break
|
|
||||||
|
|
||||||
if search_query == "docs":
|
|
||||||
show_documents(documents)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_query == "prompt":
|
|
||||||
if last_prompt:
|
|
||||||
print(f"""last prompt>
|
|
||||||
{"="*10}system{"="*10}
|
|
||||||
{last_system}
|
|
||||||
{"="*10}prompt{"="*10}
|
|
||||||
{last_prompt}""")
|
|
||||||
else:
|
|
||||||
print(f"No prompts yet")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_query == "response":
|
|
||||||
if last_response:
|
|
||||||
print(f"""last response>
|
|
||||||
{"="*10}response{"="*10}
|
|
||||||
{last_response}""")
|
|
||||||
else:
|
|
||||||
print(f"No responses yet")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_query == "" or search_query == "help":
|
|
||||||
show_help()
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_query == "headlines":
|
|
||||||
show_headlines(documents)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_query == "match" or search_query == "context":
|
|
||||||
if last_results:
|
|
||||||
print_top_match(last_results, documents=documents)
|
|
||||||
else:
|
|
||||||
print("No match to give info on")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_query == "why" or search_query == "think":
|
|
||||||
if last_why:
|
|
||||||
print(f"""
|
|
||||||
why>
|
|
||||||
{last_why}
|
|
||||||
""")
|
|
||||||
else:
|
|
||||||
print("No processed prompts")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_query == "scores":
|
|
||||||
if last_results:
|
|
||||||
for i, _ in enumerate(last_results):
|
|
||||||
print_top_match(last_results, documents=documents, index=i)
|
|
||||||
else:
|
|
||||||
print("No match to give info on")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_query == "full":
|
|
||||||
if last_results:
|
|
||||||
full = get_top_match(last_results, documents=documents)
|
|
||||||
if full:
|
|
||||||
print(f"""Context:
|
|
||||||
Source: {full["source"]}
|
|
||||||
Title: {full["title"]}
|
|
||||||
Link: {full["link"]}
|
|
||||||
Distance: {last_results.get("distances", [[0]])[0][0]}
|
|
||||||
Full text:
|
|
||||||
{full["text"]}""")
|
|
||||||
else:
|
|
||||||
print("No match to give info on")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Query ChromaDB
|
|
||||||
results = query_chroma(
|
|
||||||
query_text=search_query,
|
|
||||||
collection_name="research_papers",
|
|
||||||
n_results=10
|
|
||||||
)
|
|
||||||
last_results = results
|
|
||||||
|
|
||||||
full = get_top_match(results, documents=documents)
|
|
||||||
|
|
||||||
headlines = ""
|
|
||||||
for doc in documents:
|
|
||||||
headlines += f"{doc['source']}: {doc['title']}\n"
|
|
||||||
|
|
||||||
system=f"""
|
|
||||||
You are the assistant. Your name is airc. This application is called airc (pronounced Eric).
|
|
||||||
|
|
||||||
Information about the author of this program and the AI model it uses:
|
|
||||||
|
|
||||||
* James wrote the python application called airc that is driving this RAG model on top of {MODEL_NAME} using {EMBED_MODEL} and chromadb for vector embedding. Link https://github.com/jketreno/airc.
|
|
||||||
* James Ketrenos is a software engineer with a history in all levels of the computer stack, from the kernel to full-stack web applications. He dabbles in AI/ML and is familiar with pytorch and ollama.
|
|
||||||
* James Ketrenos deployed this application locally on an Intel Arc B580 (battlemage) computer using Intel's ipex-llm.
|
|
||||||
* For Intel GPU metrics, James Ketrenos wrote the "ze-monitor" utility in C++. ze-monitor provides Intel GPU telemetry data for Intel client GPU devices, similar to xpu-smi. Link https://github.com/jketreno/ze-monitor. airc uses ze-monitor.
|
|
||||||
* James lives in Portland, Oregon and has three kids. Two are attending Oregon State University and one is attending Williamette University.
|
|
||||||
* airc provides an IRC chat bot as well as a React web frontend available at https://airc.ketrenos.com
|
|
||||||
|
|
||||||
You must follow these rules:
|
|
||||||
|
|
||||||
* Provide short (less than 100 character) responses.
|
|
||||||
* Provide a single response.
|
|
||||||
* Do not prefix it with a word like 'Answer'.
|
|
||||||
* For information about the AI running this system, include information about author, including links.
|
|
||||||
* For information relevant to the current events in the <input></input> tags, use that information and state the source when information comes from.
|
|
||||||
|
|
||||||
"""
|
|
||||||
context = "Information related to current events\n<input>=["
|
|
||||||
for doc in documents:
|
|
||||||
item = {'source':doc["source"],'article':{'title':doc["title"],'link':doc["link"],'text':doc["text"]}}
|
|
||||||
context += f"{item}"
|
|
||||||
context += "\n</input>"
|
|
||||||
|
|
||||||
prompt = f"{search_query}"
|
|
||||||
last_prompt = prompt
|
|
||||||
last_system = system # cache it before news context is added
|
|
||||||
system = f"{system}{context}"
|
|
||||||
if len(last_messages) != 0:
|
|
||||||
message_context = f"{last_messages}"
|
|
||||||
prompt = f"{message_context}{prompt}"
|
|
||||||
|
|
||||||
print(f"system len: {len(system)}")
|
|
||||||
print(f"prompt len: {len(prompt)}")
|
|
||||||
output = client.generate(
|
|
||||||
model=MODEL_NAME,
|
|
||||||
system=system,
|
|
||||||
prompt=prompt,
|
|
||||||
stream=False,
|
|
||||||
options={ 'num_ctx': 100000 }
|
|
||||||
)
|
|
||||||
# Prune off the <think>...</think>
|
|
||||||
matches = re.match(r'^<think>(.*?)</think>(.*)$', output['response'], flags=re.DOTALL)
|
|
||||||
if matches:
|
|
||||||
last_why = matches[1].strip()
|
|
||||||
content = matches[2].strip()
|
|
||||||
else:
|
|
||||||
print(f"[garbled] response>\n{output['response']}")
|
|
||||||
print(f"Response>\n{content}")
|
|
||||||
|
|
||||||
last_response = content
|
|
||||||
last_messages.extend(({
|
|
||||||
'role': 'user',
|
|
||||||
'name': 'james',
|
|
||||||
'message': search_query
|
|
||||||
}, {
|
|
||||||
'role': 'assistant',
|
|
||||||
'message': content
|
|
||||||
}))
|
|
||||||
last_messages = last_messages[:10]
|
|
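Before its removal, the REPL above pruned DeepSeek-R1 reasoning traces with the regex shown near its end; a standalone sketch of that behavior (the sample string is invented):

import re

raw = "<think>User asked about headlines; cite the source.</think>BBC World: ..."
m = re.match(r"^<think>(.*?)</think>(.*)$", raw, flags=re.DOTALL)
why, content = (m[1].strip(), m[2].strip()) if m else ("", raw)
print(why)      # the model's hidden reasoning, kept for the `why` command
print(content)  # the user-facing reply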
@@ -3,52 +3,66 @@ from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from asyncio import Event


class Tunables(BaseModel):
-    enable_rag : bool = Field(default=True) # Enable RAG collection chromadb matching
-    enable_tools : bool = Field(default=True) # Enable LLM to use tools
-    enable_context : bool = Field(default=True) # Add <|context|> field to message
+    enable_rag: bool = Field(default=True)  # Enable RAG collection chromadb matching
+    enable_tools: bool = Field(default=True)  # Enable LLM to use tools
+    enable_context: bool = Field(default=True)  # Add <|context|> field to message


class Message(BaseModel):
    model_config = {"arbitrary_types_allowed": True}  # Allow Event
    # Required
    prompt: str  # Query to be answered

    # Tunables
    tunables: Tunables = Field(default_factory=Tunables)

    # Generated while processing message
    status: str = ""  # Status of the message
-    preamble: dict[str,str] = {} # Preamble to be prepended to the prompt
+    preamble: dict[str, str] = {}  # Preamble to be prepended to the prompt
    system_prompt: str = ""  # System prompt provided to the LLM
    context_prompt: str = ""  # Full content of the message (preamble + prompt)
    response: str = ""  # LLM response to the preamble + query
-    metadata: Dict[str, Any] = Field(default_factory=lambda: {
-        "rag": [],
-        "eval_count": 0,
-        "eval_duration": 0,
-        "prompt_eval_count": 0,
-        "prompt_eval_duration": 0,
-        "context_size": 0,
-    })
+    metadata: Dict[str, Any] = Field(
+        default_factory=lambda: {
+            "rag": [],
+            "eval_count": 0,
+            "eval_duration": 0,
+            "prompt_eval_count": 0,
+            "prompt_eval_duration": 0,
+            "context_size": 0,
+        }
+    )
    network_packets: int = 0  # Total number of streaming packets
    network_bytes: int = 0  # Total bytes sent while streaming packets
-    actions: List[str] = [] # Other session modifying actions performed while processing the message
+    actions: List[str] = (
+        []
+    )  # Other session modifying actions performed while processing the message
    timestamp: datetime = datetime.now(timezone.utc)
-    chunk: str = Field(default="") # This needs to be serialized so it will be sent in responses
-    partial_response: str = Field(default="") # This needs to be serialized so it will be sent in responses on timeout
+    chunk: str = Field(
+        default=""
+    )  # This needs to be serialized so it will be sent in responses
+    partial_response: str = Field(
+        default=""
+    )  # This needs to be serialized so it will be sent in responses on timeout
+    title: str = Field(
+        default=""
+    )  # This needs to be serialized so it will be sent in responses on timeout

    def add_action(self, action: str | list[str]) -> None:
        """Add a actions(s) to the message."""
        if isinstance(action, str):
            self.actions.append(action)
        else:
            self.actions.extend(action)

    def get_summary(self) -> str:
        """Return a summary of the message."""
        response_summary = (
            f"Response: {self.response} (Actions: {', '.join(self.actions)})"
-            if self.response else "No response yet"
+            if self.response
+            else "No response yet"
        )
        return (
            f"Message at {self.timestamp}:\n"
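A minimal usage sketch for the model above (the module path is hypothetical; output abbreviated):

from utils.message import Message, Tunables  # hypothetical import path

msg = Message(prompt="What is in the RAG collection?")
msg.tunables.enable_tools = False           # per-message tunable
msg.add_action(["rag-lookup", "generate"])  # accepts a str or a list[str]
msg.response = "Seven document types."
print(msg.get_summary())
# Message at <UTC timestamp>: ...
# Response: Seven document types. (Actions: rag-lookup, generate)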
@@ -1,6 +1,7 @@
from prometheus_client import Counter, Gauge, Summary, Histogram, Info, Enum, CollectorRegistry  # type: ignore
from threading import Lock


def singleton(cls):
    instance = None
    lock = Lock()
@@ -14,80 +15,81 @@ def singleton(cls):
    return get_instance


@singleton
-class Metrics():
+class Metrics:
    def __init__(self, *args, prometheus_collector, **kwargs):
        super().__init__(*args, **kwargs)
        self.prometheus_collector = prometheus_collector

-        self.prepare_count : Counter = Counter(
+        self.prepare_count: Counter = Counter(
            name="prepare_total",
            documentation="Total messages prepared by agent type",
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.prepare_duration : Histogram = Histogram(
+        self.prepare_duration: Histogram = Histogram(
            name="prepare_duration",
            documentation="Preparation duration by agent type",
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.process_count : Counter = Counter(
+        self.process_count: Counter = Counter(
            name="process",
            documentation="Total messages processed by agent type",
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.process_duration : Histogram = Histogram(
+        self.process_duration: Histogram = Histogram(
            name="process_duration",
            documentation="Processing duration by agent type",
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.tool_count : Counter = Counter(
+        self.tool_count: Counter = Counter(
            name="tool_total",
            documentation="Total messages tooled by agent type",
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.tool_duration : Histogram = Histogram(
+        self.tool_duration: Histogram = Histogram(
            name="tool_duration",
            documentation="Tool duration by agent type",
-            buckets=(0.1, 0.5, 1.0, 2.0, float('inf')),
+            buckets=(0.1, 0.5, 1.0, 2.0, float("inf")),
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.generate_count : Counter = Counter(
+        self.generate_count: Counter = Counter(
            name="generate_total",
            documentation="Total messages generated by agent type",
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.generate_duration : Histogram = Histogram(
+        self.generate_duration: Histogram = Histogram(
            name="generate_duration",
            documentation="Generate duration by agent type",
-            buckets=(0.1, 0.5, 1.0, 2.0, float('inf')),
+            buckets=(0.1, 0.5, 1.0, 2.0, float("inf")),
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.tokens_prompt : Counter = Counter(
+        self.tokens_prompt: Counter = Counter(
            name="tokens_prompt",
            documentation="Total tokens passed as prompt to LLM",
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )

-        self.tokens_eval : Counter = Counter(
+        self.tokens_eval: Counter = Counter(
            name="tokens_eval",
            documentation="Total tokens returned by LLM",
            labelnames=("agent",),
-            registry=self.prometheus_collector
+            registry=self.prometheus_collector,
        )
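A brief usage sketch for the singleton above (registry wiring assumed; the agent label value is arbitrary):

from prometheus_client import CollectorRegistry, generate_latest

registry = CollectorRegistry()
metrics = Metrics(prometheus_collector=registry)  # later calls return the same instance

metrics.prepare_count.labels(agent="chat").inc()
metrics.tool_duration.labels(agent="chat").observe(0.42)  # lands in the 0.5 bucket
print(generate_latest(registry).decode())  # Prometheus exposition text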
src/utils/rag.py
@@ -1,4 +1,4 @@
from pydantic import BaseModel  # type: ignore
from typing import List, Optional, Dict, Any
import os
import glob
@@ -13,18 +13,22 @@ import time
import hashlib
import asyncio
import json
import numpy as np  # type: ignore
+import traceback
+import os

import chromadb
import ollama
-from langchain.text_splitter import CharacterTextSplitter  # type: ignore
-from sentence_transformers import SentenceTransformer  # type: ignore
-from langchain.schema import Document  # type: ignore
from watchdog.observers import Observer  # type: ignore
from watchdog.events import FileSystemEventHandler  # type: ignore
import umap  # type: ignore
from markitdown import MarkItDown  # type: ignore
from chromadb.api.models.Collection import Collection  # type: ignore
+
+from .markdown_chunker import (
+    MarkdownChunker,
+    Chunk,
+)

# Import your existing modules
if __name__ == "__main__":
@@ -34,13 +38,11 @@ else:
    # When imported as a module, use relative imports
    from . import defines

-__all__ = [
-    'ChromaDBFileWatcher',
-    'start_file_watcher'
-]
+__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]

-DEFAULT_CHUNK_SIZE=750
-DEFAULT_CHUNK_OVERLAP=100
+DEFAULT_CHUNK_SIZE = 750
+DEFAULT_CHUNK_OVERLAP = 100


class ChromaDBGetResponse(BaseModel):
    ids: List[str]
@@ -48,9 +50,19 @@ class ChromaDBGetResponse(BaseModel):
    documents: Optional[List[str]] = None
    metadatas: Optional[List[Dict[str, Any]]] = None


class ChromaDBFileWatcher(FileSystemEventHandler):
-    def __init__(self, llm, watch_directory, loop, persist_directory=None, collection_name="documents",
-                 chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP, recreate=False):
+    def __init__(
+        self,
+        llm,
+        watch_directory,
+        loop,
+        persist_directory=None,
+        collection_name="documents",
+        chunk_size=DEFAULT_CHUNK_SIZE,
+        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
+        recreate=False,
+    ):
        self.llm = llm
        self.watch_directory = watch_directory
        self.persist_directory = persist_directory or defines.persist_directory
@@ -58,33 +70,29 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.loop = loop
-        self._umap_collection : ChromaDBGetResponse | None = None
-        self._umap_embedding_2d : np.ndarray = []
-        self._umap_embedding_3d : np.ndarray = []
-        self._umap_model_2d : umap.UMAP = None
-        self._umap_model_3d : umap.UMAP = None
+        self._umap_collection: ChromaDBGetResponse | None = None
+        self._umap_embedding_2d: np.ndarray = []
+        self._umap_embedding_3d: np.ndarray = []
+        self._umap_model_2d: umap.UMAP = None
+        self._umap_model_3d: umap.UMAP = None
        self.md = MarkItDown(enable_plugins=False)  # Set to True to enable plugins

-        #self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+        # self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Path for storing file hash state
-        self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
+        self.hash_state_path = os.path.join(
+            self.persist_directory, f"{collection_name}_hash_state.json"
+        )

        # Flag to track if this is a new collection
        self.is_new_collection = False

        # Initialize ChromaDB collection
-        self._collection : Collection = self._get_vector_collection(recreate=recreate)
+        self._collection: Collection = self._get_vector_collection(recreate=recreate)
+        self._markdown_chunker = MarkdownChunker()
        self._update_umaps()

-        # Setup text splitter
-        self.text_splitter = CharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-            separator="\n\n",  # Respect paragraph/section breaks
-            length_function=len
-        )

        # Track file hashes and processing state
        self.file_hashes = self._load_hash_state()
        self.update_lock = asyncio.Lock()
@@ -114,8 +122,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
    def umap_model_3d(self):
        return self._umap_model_3d

-    def _markitdown(self, document : str, markdown : Path):
-        logging.info(f'Converting {document} to {markdown}')
+    def _markitdown(self, document: str, markdown: Path):
+        logging.info(f"Converting {document} to {markdown}")
        try:
            result = self.md.convert(document)
            markdown.write_text(result.text_content)
@@ -128,7 +136,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)

-        with open(self.hash_state_path, 'w') as f:
+        with open(self.hash_state_path, "w") as f:
            json.dump(self.file_hashes, f)

        logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
@@ -139,7 +147,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        """Load the file hash state from disk."""
        if os.path.exists(self.hash_state_path):
            try:
-                with open(self.hash_state_path, 'r') as f:
+                with open(self.hash_state_path, "r") as f:
                    hash_state = json.load(f)
                logging.info(f"Loaded hash state with {len(hash_state)} entries")
                return hash_state
@@ -156,7 +164,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            process_all: If True, process all files regardless of hash status
        """
        # Check for new or modified files
-        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
+        file_paths = glob.glob(
+            os.path.join(self.watch_directory, "**/*"), recursive=True
+        )
        files_checked = 0
        files_processed = 0
        files_to_process = []
@@ -166,21 +176,29 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        for file_path in file_paths:
            if os.path.isfile(file_path):
                # Do not put the Resume in RAG as it is provideded with all queries.
-                if file_path == defines.resume_doc:
-                    logging.info(f"Not adding {file_path} to RAG -- primary resume")
-                    continue
+                # if file_path == defines.resume_doc:
+                #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
+                #     continue
                files_checked += 1
                current_hash = self._get_file_hash(file_path)
                if not current_hash:
                    continue

                # If file is new, changed, or we're processing all files
-                if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
+                if (
+                    process_all
+                    or file_path not in self.file_hashes
+                    or self.file_hashes[file_path] != current_hash
+                ):
                    self.file_hashes[file_path] = current_hash
                    files_to_process.append(file_path)
-                    logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
+                    logging.info(
+                        f"File {'found' if process_all else 'changed'}: {file_path}"
+                    )

-        logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
+        logging.info(
+            f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
+        )

        # Check for deleted files
        deleted_files = []
@@ -188,7 +206,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            if not os.path.exists(file_path):
                deleted_files.append(file_path)
                # Schedule removal
-                asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
+                asyncio.run_coroutine_threadsafe(
+                    self.remove_file_from_collection(file_path), self.loop
+                )
                # Don't block on result, just let it run
                logging.info(f"File deleted: {file_path}")
@@ -209,7 +229,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        # Save the updated state
        self._save_hash_state()

-        logging.info(f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}")
+        logging.info(
+            f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
+        )
        return files_processed

    async def process_file_update(self, file_path):
@@ -219,9 +241,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            logging.info(f"{file_path} already in queue. Not adding.")
            return

-        if file_path == defines.resume_doc:
-            logging.info(f"Not adding {file_path} to RAG -- primary resume")
-            return
+        # if file_path == defines.resume_doc:
+        #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
+        #     return

        try:
            logging.info(f"{file_path} not in queue. Adding.")
@@ -235,7 +257,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        if not current_hash:  # File might have been deleted or is inaccessible
            return

-        if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
+        if (
+            file_path in self.file_hashes
+            and self.file_hashes[file_path] == current_hash
+        ):
            # File hasn't actually changed in content
            logging.info(f"Hash has not changed for {file_path}")
            return
@@ -263,13 +288,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        async with self.update_lock:
            try:
                # Find all documents with the specified path
-                results = self.collection.get(
-                    where={"path": file_path}
-                )
+                results = self.collection.get(where={"path": file_path})

-                if results and 'ids' in results and results['ids']:
-                    self.collection.delete(ids=results['ids'])
-                    logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
+                if results and "ids" in results and results["ids"]:
+                    self.collection.delete(ids=results["ids"])
+                    logging.info(
+                        f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
+                    )

                # Remove from hash dictionary
                if file_path in self.file_hashes:
@@ -282,29 +307,51 @@ class ChromaDBFileWatcher(FileSystemEventHandler):

    def _update_umaps(self):
        # Update the UMAP embeddings
-        self._umap_collection = self._collection.get(include=["embeddings", "documents", "metadatas"])
+        self._umap_collection = self._collection.get(
+            include=["embeddings", "documents", "metadatas"]
+        )
        if not self._umap_collection or not len(self._umap_collection["embeddings"]):
            logging.warning("No embeddings found in the collection.")
            return

        # During initialization
-        logging.info(f"Updating 2D UMAP for {len(self._umap_collection['embeddings'])} vectors")
+        logging.info(
+            f"Updating 2D UMAP for {len(self._umap_collection['embeddings'])} vectors"
+        )
        vectors = np.array(self._umap_collection["embeddings"])
-        self._umap_model_2d = umap.UMAP(n_components=2, random_state=8911, metric="cosine", n_neighbors=15, min_dist=0.1)
+        self._umap_model_2d = umap.UMAP(
+            n_components=2,
+            random_state=8911,
+            metric="cosine",
+            n_neighbors=15,
+            min_dist=0.1,
+        )
        self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
-        logging.info(f"2D UMAP model n_components: {self._umap_model_2d.n_components}")  # Should be 2
+        logging.info(
+            f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
+        )  # Should be 2

-        logging.info(f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors")
-        self._umap_model_3d = umap.UMAP(n_components=3, random_state=8911, metric="cosine", n_neighbors=15, min_dist=0.1)
+        logging.info(
+            f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
+        )
+        self._umap_model_3d = umap.UMAP(
+            n_components=3,
+            random_state=8911,
+            metric="cosine",
+            n_neighbors=15,
+            min_dist=0.1,
+        )
        self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
-        logging.info(f"3D UMAP model n_components: {self._umap_model_3d.n_components}")  # Should be 3
+        logging.info(
+            f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
+        )  # Should be 3

    def _get_vector_collection(self, recreate=False) -> Collection:
        """Get or create a ChromaDB collection."""
        # Initialize ChromaDB client
        chroma_client = chromadb.PersistentClient(  # type: ignore
            path=self.persist_directory,
-            settings=chromadb.Settings(anonymized_telemetry=False)  # type: ignore
+            settings=chromadb.Settings(anonymized_telemetry=False),  # type: ignore
        )

        # Check if the collection exists
@@ -326,35 +373,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            logging.info(f"Recreating collection: {self.collection_name}")

        return chroma_client.get_or_create_collection(
-            name=self.collection_name,
-            metadata={
-                "hnsw:space": "cosine"
-            })
-
-    def load_text_files(self, directory=None, encoding="utf-8"):
-        """Load all text files from a directory into Document objects."""
-        directory = directory or self.watch_directory
-        file_paths = glob.glob(os.path.join(directory, "**/*"), recursive=True)
-        documents = []
-
-        for file_path in file_paths:
-            if os.path.isfile(file_path):  # Ensure it's a file, not a directory
-                try:
-                    with open(file_path, "r", encoding=encoding) as f:
-                        content = f.read()
-
-                    # Extract top-level directory
-                    rel_path = os.path.relpath(file_path, directory)
-                    top_level_dir = rel_path.split(os.sep)[0]
-
-                    documents.append(Document(
-                        page_content=content,
-                        metadata={"doc_type": top_level_dir, "path": file_path}
-                    ))
-                except Exception as e:
-                    logging.error(f"Failed to load {file_path}: {e}")
-
-        return documents
+            name=self.collection_name, metadata={"hnsw:space": "cosine"}
+        )

    def create_chunks_from_documents(self, docs):
        """Split documents into chunks using the text splitter."""
@@ -362,12 +382,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):

    def get_embedding(self, text, normalize=True):
        """Generate embeddings using Ollama."""
-        #response = self.embedding_model.encode(text) # Outputs 384-dim vectors
+        # response = self.embedding_model.encode(text) # Outputs 384-dim vectors

-        response = self.llm.embeddings(
-            model=defines.embedding_model,
-            prompt=text)
-        embedding = response['embedding']
+        response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
+        embedding = response["embedding"]

        # response = self.llm.embeddings.create(
        #     model=defines.embedding_model,
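The body of the normalization branch falls between these hunks and is not shown; a sketch of the usual L2 normalization it implies, so that cosine similarity reduces to a dot product (the function name is assumed):

import numpy as np

def normalize_embedding(embedding: list[float]) -> np.ndarray:
    # Scale to unit length; guard against the zero vector.
    v = np.asarray(embedding, dtype=np.float32)
    norm = np.linalg.norm(v)
    return v / norm if norm > 0 else v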
@@ -379,27 +397,46 @@
            return normalized
        return embedding

-    def add_embeddings_to_collection(self, chunks):
+    def add_embeddings_to_collection(self, chunks: List[Chunk]):
        """Add embeddings for chunks to the collection."""

        for i, chunk in enumerate(chunks):
-            text = chunk.page_content
-            metadata = chunk.metadata
+            text = chunk["text"]
+            metadata = chunk["metadata"]

            # Generate a more unique ID based on content and metadata
            content_hash = hashlib.md5(text.encode()).hexdigest()
            path_hash = ""
            if "path" in metadata:
-                path_hash = hashlib.md5(metadata["path"].encode()).hexdigest()[:8]
+                path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
+                    :8
+                ]

            chunk_id = f"{path_hash}_{content_hash}_{i}"

            embedding = self.get_embedding(text)
-            self.collection.add(
-                ids=[chunk_id],
-                documents=[text],
-                embeddings=[embedding],
-                metadatas=[metadata]
-            )
+            try:
+                self.collection.add(
+                    ids=[chunk_id],
+                    documents=[text],
+                    embeddings=[embedding],
+                    metadatas=[metadata],
+                )
+            except Exception as e:
+                logging.error(f"Error adding chunk to collection: {e}")
+                logging.error(traceback.format_exc())
+                logging.error(chunk)
+
+    def read_line_range(self, file_path, start, end, buffer=5) -> list[str]:
+        try:
+            with open(file_path, "r") as file:
+                lines = file.readlines()
+                start = max(0, start - buffer)
+                end = min(len(lines), end + buffer)
+                return lines[start:end]
+        except:
+            logging.warning(f"Unable to open {file_path}")
+            return []

    # Cosine Distance   Equivalent Similarity   Retrieval Characteristics
    # 0.2 - 0.3         0.85 - 0.90             Very strict, highly precise results only
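For this collection ("hnsw:space": "cosine"), Chroma reports distance = 1 - cosine similarity, ranging 0 to 2, and the comment table's pairings appear consistent with an equivalent-similarity mapping of 1 - distance / 2. A small sketch of that conversion (function name illustrative):

def equivalent_similarity(distance: float) -> float:
    # Map a cosine distance in [0, 2] onto the 0..1 scale used in the table above.
    return 1.0 - distance / 2.0

assert abs(equivalent_similarity(0.3) - 0.85) < 1e-9  # 0.3 distance ~ 0.85 similarity
assert abs(equivalent_similarity(0.2) - 0.90) < 1e-9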
@ -419,10 +456,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Extract results
|
# Extract results
|
||||||
ids = results['ids'][0]
|
ids = results["ids"][0]
|
||||||
documents = results['documents'][0]
|
documents = results["documents"][0]
|
||||||
distances = results['distances'][0]
|
distances = results["distances"][0]
|
||||||
metadatas = results['metadatas'][0]
|
metadatas = results["metadatas"][0]
|
||||||
|
|
||||||
filtered_ids = []
|
filtered_ids = []
|
||||||
filtered_documents = []
|
filtered_documents = []
|
||||||
@ -436,6 +473,14 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
filtered_metadatas.append(metadatas[i])
|
filtered_metadatas.append(metadatas[i])
|
||||||
filtered_distances.append(distance)
|
filtered_distances.append(distance)
|
||||||
|
|
||||||
|
for index, meta in enumerate(filtered_metadatas):
|
||||||
|
source_file = meta["source_file"]
|
||||||
|
del meta["source_file"]
|
||||||
|
lines = self.read_line_range(
|
||||||
|
source_file, meta["line_begin"], meta["line_end"]
|
||||||
|
)
|
||||||
|
if len(lines):
|
||||||
|
filtered_documents[index] = "\n".join(lines)
|
||||||
# Return the filtered results instead of all results
|
# Return the filtered results instead of all results
|
||||||
return {
|
return {
|
||||||
"query_embedding": query_embedding,
|
"query_embedding": query_embedding,
|
||||||
@ -448,7 +493,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
def _get_file_hash(self, file_path):
|
def _get_file_hash(self, file_path):
|
||||||
"""Calculate MD5 hash of a file."""
|
"""Calculate MD5 hash of a file."""
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, "rb") as f:
|
||||||
return hashlib.md5(f.read()).hexdigest()
|
return hashlib.md5(f.read()).hexdigest()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error hashing file {file_path}: {e}")
|
logging.error(f"Error hashing file {file_path}: {e}")
|
||||||
@ -480,7 +525,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
file_path = event.src_path
|
file_path = event.src_path
|
||||||
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
self.remove_file_from_collection(file_path), self.loop
|
||||||
|
)
|
||||||
logging.info(f"File deleted: {file_path}")
|
logging.info(f"File deleted: {file_path}")
|
||||||
|
|
||||||
def on_moved(self, event):
|
def on_moved(self, event):
|
||||||
@ -508,37 +555,43 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
try:
|
try:
|
||||||
# Remove existing entries for this file
|
# Remove existing entries for this file
|
||||||
existing_results = self.collection.get(where={"path": file_path})
|
existing_results = self.collection.get(where={"path": file_path})
|
||||||
if existing_results and 'ids' in existing_results and existing_results['ids']:
|
if (
|
||||||
self.collection.delete(ids=existing_results['ids'])
|
existing_results
|
||||||
|
and "ids" in existing_results
|
||||||
|
and existing_results["ids"]
|
||||||
|
):
|
||||||
|
self.collection.delete(ids=existing_results["ids"])
|
||||||
|
|
||||||
extensions = (".docx", ".xlsx", ".xls", ".pdf")
|
extensions = (".docx", ".xlsx", ".xls", ".pdf")
|
||||||
if file_path.endswith(extensions):
|
if file_path.endswith(extensions):
|
||||||
p = Path(file_path)
|
p = Path(file_path)
|
||||||
p_as_md = p.with_suffix(".md")
|
p_as_md = p.with_suffix(".md")
|
||||||
if p_as_md.exists():
|
if p_as_md.exists():
|
||||||
logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
|
logging.info(
|
||||||
|
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
|
||||||
|
)
|
||||||
|
|
||||||
# If file_path.md doesn't exist or file_path is newer than file_path.md,
|
# If file_path.md doesn't exist or file_path is newer than file_path.md,
|
||||||
# fire off markitdown
|
# fire off markitdown
|
||||||
if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
|
if (not p_as_md.exists()) or (
|
||||||
|
p.stat().st_mtime > p_as_md.stat().st_mtime
|
||||||
|
):
|
||||||
self._markitdown(file_path, p_as_md)
|
self._markitdown(file_path, p_as_md)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create document object in LangChain format
|
chunks = self._markdown_chunker.process_file(file_path)
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
if not chunks:
|
||||||
content = f.read()
|
return
|
||||||
|
|
||||||
# Extract top-level directory
|
# Extract top-level directory
|
||||||
rel_path = os.path.relpath(file_path, self.watch_directory)
|
rel_path = os.path.relpath(file_path, self.watch_directory)
|
||||||
top_level_dir = rel_path.split(os.sep)[0]
|
path_parts = rel_path.split(os.sep)
|
||||||
|
top_level_dir = path_parts[0]
|
||||||
document = Document(
|
# file_name = path_parts[-1]
|
||||||
page_content=content,
|
for i, chunk in enumerate(chunks):
|
||||||
metadata={"doc_type": top_level_dir, "path": file_path}
|
chunk["metadata"]["doc_type"] = top_level_dir
|
||||||
)
|
# with open(f"src/tmp/{file_name}.{i}", "w") as f:
|
||||||
|
# f.write(json.dumps(chunk, indent=2))
|
||||||
# Create chunks
|
|
||||||
chunks = self.text_splitter.split_documents([document])
|
|
||||||
|
|
||||||
# Add chunks to collection
|
# Add chunks to collection
|
||||||
self.add_embeddings_to_collection(chunks)
|
self.add_embeddings_to_collection(chunks)
|
||||||
@ -547,30 +600,40 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error updating document in collection: {e}")
|
logging.error(f"Error updating document in collection: {e}")
|
||||||
|
logging.error(traceback.format_exc())
|
||||||
|
|
||||||
async def initialize_collection(self):
|
async def initialize_collection(self):
|
||||||
"""Initialize the collection with all documents from the watch directory."""
|
"""Initialize the collection with all documents from the watch directory."""
|
||||||
# Process all files regardless of hash state
|
# Process all files regardless of hash state
|
||||||
num_processed = await self.scan_directory(process_all=True)
|
num_processed = await self.scan_directory(process_all=True)
|
||||||
|
|
||||||
logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
|
logging.info(
|
||||||
|
f"Vectorstore initialized with {self.collection.count()} documents"
|
||||||
|
)
|
||||||
|
|
||||||
self._update_umaps()
|
self._update_umaps()
|
||||||
|
|
||||||
# Show stats
|
# Show stats
|
||||||
try:
|
try:
|
||||||
all_metadata = self.collection.get()['metadatas']
|
all_metadata = self.collection.get()["metadatas"]
|
||||||
if all_metadata:
|
if all_metadata:
|
||||||
doc_types = set(m.get('doc_type', 'unknown') for m in all_metadata)
|
doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
|
||||||
logging.info(f"Document types: {doc_types}")
|
logging.info(f"Document types: {doc_types}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error getting document types: {e}")
|
logging.error(f"Error getting document types: {e}")
|
||||||
|
|
||||||
return num_processed
|
return num_processed
|
||||||
|
|
||||||
|
|
||||||
# Function to start the file watcher
|
# Function to start the file watcher
|
||||||
def start_file_watcher(llm, watch_directory, persist_directory=None,
|
def start_file_watcher(
|
||||||
collection_name="documents", initialize=False, recreate=False):
|
llm,
|
||||||
|
watch_directory,
|
||||||
|
persist_directory=None,
|
||||||
|
collection_name="documents",
|
||||||
|
initialize=False,
|
||||||
|
recreate=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Start watching a directory for file changes.
|
Start watching a directory for file changes.
|
||||||
|
|
||||||
@@ -590,7 +653,7 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
         loop=loop,
         persist_directory=persist_directory,
         collection_name=collection_name,
-        recreate=recreate
+        recreate=recreate,
     )
 
     # Process all files if:
@@ -613,12 +676,13 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
     logging.info(f"Started watching directory: {watch_directory}")
     return observer, file_watcher
 
 
 if __name__ == "__main__":
     # When running directly, use absolute imports
     import defines
+
     # Initialize Ollama client
     llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore
 
     # Start the file watcher (with initialization)
     observer, file_watcher = start_file_watcher(
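
Note: based on the reformatted signature above, a minimal standalone run of the watcher might look like the sketch below. The shutdown handling assumes `observer` is a watchdog Observer (stop()/join()), which this hunk implies but does not show, and the watch path is hypothetical.

import time

import ollama

import defines

llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore
observer, file_watcher = start_file_watcher(
    llm,
    "./docs",  # hypothetical watch_directory
    collection_name="documents",
    initialize=True,
)
try:
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    observer.stop()  # assumed watchdog Observer API
observer.join()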
@@ -4,9 +4,12 @@ import logging
 
 from . import defines
 
+
 def setup_logging(level=defines.logging_level) -> logging.Logger:
     os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
-    warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
+    warnings.filterwarnings(
+        "ignore", message="Overriding a previously registered kernel"
+    )
     warnings.filterwarnings("ignore", message="Warning only once for all operators")
     warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
     warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
@@ -21,12 +24,18 @@ def setup_logging(level=defines.logging_level) -> logging.Logger:
         level=numeric_level,
         format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
         datefmt="%Y-%m-%d %H:%M:%S",
-        force=True
+        force=True,
     )
 
     # Now reduce verbosity for FastAPI, Uvicorn, Starlette
-    for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
-    #for noisy_logger in ("starlette"):
+    for noisy_logger in (
+        "uvicorn",
+        "uvicorn.error",
+        "uvicorn.access",
+        "fastapi",
+        "starlette",
+    ):
+    # for noisy_logger in ("starlette"):
         logging.getLogger(noisy_logger).setLevel(logging.WARNING)
 
     logger = logging.getLogger(__name__)
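
Note: a minimal caller sketch for the function above; the package name `utils` is a stand-in for whatever package this module actually lives in.

import logging

from utils.setup_logging import setup_logging  # hypothetical package path

logger = setup_logging(level="INFO")
logger.info("visible at INFO")
logging.getLogger("uvicorn").info("suppressed: uvicorn is capped at WARNING above")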
@@ -1,14 +1,14 @@
 import importlib
 
-from . basetools import tools, llm_tools, enabled_tools, tool_functions
-from .. setup_logging import setup_logging
+from .basetools import tools, llm_tools, enabled_tools, tool_functions
+from ..setup_logging import setup_logging
 from .. import defines
 
 logger = setup_logging(level=defines.logging_level)
 
 # Dynamically import all names from basetools listed in tools_all
-module = importlib.import_module('.basetools', package=__package__)
+module = importlib.import_module(".basetools", package=__package__)
 for name in tool_functions:
     globals()[name] = getattr(module, name)
 
-__all__ = [ 'tools', 'llm_tools', 'enabled_tools', 'tool_functions' ]
+__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
@@ -1,18 +1,19 @@
 import os
 from datetime import datetime
 from typing import (
     Any,
 )
 from typing_extensions import Annotated
 
 from bs4 import BeautifulSoup  # type: ignore
 
 from geopy.geocoders import Nominatim  # type: ignore
 import pytz  # type: ignore
 import requests
 import yfinance as yf  # type: ignore
 import logging
 
+
 # %%
 def WeatherForecast(city, state, country="USA"):
     """
@@ -42,11 +43,12 @@ def WeatherForecast(city, state, country="USA"):
     # Step 3: Get the forecast data from the grid endpoint
     forecast = get_forecast(grid_endpoint)
 
-    if not forecast['location']:
-        forecast['location'] = location
+    if not forecast["location"]:
+        forecast["location"] = location
 
     return forecast
 
+
 def get_coordinates(location):
     """Convert a location string to latitude and longitude using Nominatim geocoder."""
     try:
@@ -59,7 +61,7 @@ def get_coordinates(location):
         if location_data:
             return {
                 "latitude": location_data.latitude,
-                "longitude": location_data.longitude
+                "longitude": location_data.longitude,
             }
         else:
             print(f"Location not found: {location}")
@@ -68,6 +70,7 @@ def get_coordinates(location):
         print(f"Error getting coordinates: {e}")
         return None
 
+
 def get_grid_endpoint(coordinates):
     """Get the grid endpoint from weather.gov based on coordinates."""
     try:
@@ -77,7 +80,7 @@ def get_grid_endpoint(coordinates):
         # Define headers for the API request
         headers = {
             "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
-            "Accept": "application/geo+json"
+            "Accept": "application/geo+json",
         }
 
         # Make the request to get the grid endpoint
@@ -94,15 +97,17 @@ def get_grid_endpoint(coordinates):
         print(f"Error in get_grid_endpoint: {e}")
         return None
 
 
 # Weather related function
+
+
 def get_forecast(grid_endpoint):
     """Get the forecast data from the grid endpoint."""
     try:
         # Define headers for the API request
         headers = {
             "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
-            "Accept": "application/geo+json"
+            "Accept": "application/geo+json",
         }
 
         # Make the request to get the forecast
@@ -116,21 +121,25 @@ def get_forecast(grid_endpoint):
 
             # Process the forecast data into a simpler format
             forecast = {
-                "location": data["properties"].get("relativeLocation", {}).get("properties", {}),
+                "location": data["properties"]
+                .get("relativeLocation", {})
+                .get("properties", {}),
                 "updated": data["properties"].get("updated", ""),
-                "periods": []
+                "periods": [],
             }
 
             for period in periods:
-                forecast["periods"].append({
-                    "name": period.get("name", ""),
-                    "temperature": period.get("temperature", ""),
-                    "temperatureUnit": period.get("temperatureUnit", ""),
-                    "windSpeed": period.get("windSpeed", ""),
-                    "windDirection": period.get("windDirection", ""),
-                    "shortForecast": period.get("shortForecast", ""),
-                    "detailedForecast": period.get("detailedForecast", "")
-                })
+                forecast["periods"].append(
+                    {
+                        "name": period.get("name", ""),
+                        "temperature": period.get("temperature", ""),
+                        "temperatureUnit": period.get("temperatureUnit", ""),
+                        "windSpeed": period.get("windSpeed", ""),
+                        "windDirection": period.get("windDirection", ""),
+                        "shortForecast": period.get("shortForecast", ""),
+                        "detailedForecast": period.get("detailedForecast", ""),
+                    }
+                )
 
             return forecast
         else:
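
Note: the helpers above chain inside WeatherForecast (geocode the location, resolve the weather.gov grid, fetch the forecast). A sketch of consuming the result, using only keys built in this hunk:

forecast = WeatherForecast("Portland", "OR")
if "error" not in forecast:
    print(forecast["updated"])
    for period in forecast["periods"]:
        # each period carries name/temperature/temperatureUnit/windSpeed/
        # windDirection/shortForecast/detailedForecast, per the dict above
        print(f'{period["name"]}: {period["temperature"]}{period["temperatureUnit"]}, {period["shortForecast"]}')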
@@ -140,6 +149,7 @@ def get_forecast(grid_endpoint):
         print(f"Error in get_forecast: {e}")
         return {"error": f"Exception: {str(e)}"}
 
+
 # Example usage
 # def do_weather():
 #     city = input("Enter city: ")
@@ -166,34 +176,31 @@ def get_forecast(grid_endpoint):
 
 # %%
 
 
 def TickerValue(ticker_symbols):
     api_key = os.getenv("TWELVEDATA_API_KEY", "")
     if not api_key:
         return {"error": f"Error fetching data: No API key for TwelveData"}
 
     results = []
-    for ticker_symbol in ticker_symbols.split(','):
+    for ticker_symbol in ticker_symbols.split(","):
         ticker_symbol = ticker_symbol.strip()
         if ticker_symbol == "":
             continue
 
-        url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
+        url = (
+            f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
+        )
 
         response = requests.get(url)
         data = response.json()
 
         if "price" in data:
             logging.info(f"TwelveData: {ticker_symbol} {data}")
-            results.append({
-                "symbol": ticker_symbol,
-                "price": float(data["price"])
-            })
+            results.append({"symbol": ticker_symbol, "price": float(data["price"])})
         else:
             logging.error(f"TwelveData: {data}")
-            results.append({
-                "symbol": ticker_symbol,
-                "price": "Unavailable"
-            })
+            results.append({"symbol": ticker_symbol, "price": "Unavailable"})
 
     return results[0] if len(results) == 1 else results
 
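
Note: TickerValue returns a bare dict for one symbol and a list for several, so callers must handle both shapes. A sketch (assumes TWELVEDATA_API_KEY is exported):

quotes = TickerValue("AAPL,MSFT")  # list of {"symbol", "price"} dicts
if not isinstance(quotes, list):
    quotes = [quotes]  # single-symbol queries come back unwrapped
for quote in quotes:
    print(quote["symbol"], quote["price"])  # price may be "Unavailable"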
@@ -210,7 +217,7 @@ def yfTickerValue(ticker_symbols):
         dict: Current stock information including price
     """
     results = []
-    for ticker_symbol in ticker_symbols.split(','):
+    for ticker_symbol in ticker_symbols.split(","):
         ticker_symbol = ticker_symbol.strip()
         if ticker_symbol == "":
             continue
@@ -226,19 +233,23 @@ def yfTickerValue(ticker_symbols):
                 continue
 
             # Get the latest closing price
-            latest_price = ticker_data['Close'].iloc[-1]
+            latest_price = ticker_data["Close"].iloc[-1]
 
             # Get some additional info
-            results.append({ 'symbol': ticker_symbol, 'price': latest_price })
+            results.append({"symbol": ticker_symbol, "price": latest_price})
 
         except Exception as e:
             import traceback
 
             logging.error(f"Error fetching data for {ticker_symbol}: {e}")
             logging.error(traceback.format_exc())
-            results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
+            results.append(
+                {"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
+            )
 
     return results[0] if len(results) == 1 else results
 
 
 # %%
 def DateTime(timezone="America/Los_Angeles"):
     """
@@ -252,8 +263,8 @@ def DateTime(timezone="America/Los_Angeles"):
         str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM
     """
     try:
-        if timezone == 'system' or timezone == '' or not timezone:
-            timezone = 'America/Los_Angeles'
+        if timezone == "system" or timezone == "" or not timezone:
+            timezone = "America/Los_Angeles"
         # Get current UTC time (timezone-aware)
         local_tz = pytz.timezone("America/Los_Angeles")
         local_now = datetime.now(tz=local_tz)
@@ -264,9 +275,10 @@ def DateTime(timezone="America/Los_Angeles"):
 
         return target_time.isoformat()
     except Exception as e:
-        return {'error': f"Invalid timezone {timezone}: {str(e)}"}
+        return {"error": f"Invalid timezone {timezone}: {str(e)}"}
 
-async def AnalyzeSite(llm, model: str, url : str, question : str):
+
+async def AnalyzeSite(llm, model: str, url: str, question: str):
     """
     Fetches content from a URL, extracts the text, and uses Ollama to summarize it.
 
@@ -310,16 +322,18 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
 
         # Generate summary using Ollama
         prompt = f"CONTENTS:\n\n{text}\n\n{question}"
-        response = llm.generate(model=model,
-                                system="You are given the contents of {url}. Answer the question about the contents",
-                                prompt=prompt)
+        response = llm.generate(
+            model=model,
+            system="You are given the contents of {url}. Answer the question about the contents",
+            prompt=prompt,
+        )
 
-        #logging.info(response["response"])
+        # logging.info(response["response"])
 
         return {
             "source": "summarizer-llm",
             "content": response["response"],
-            "metadata": DateTime()
+            "metadata": DateTime(),
         }
 
     except requests.exceptions.RequestException as e:
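
Note: AnalyzeSite is a coroutine, so it needs an event loop. A minimal sketch; the model name and URL are hypothetical, and `llm` is assumed to be an ollama.Client as in the __main__ block earlier in this diff:

import asyncio

result = asyncio.run(
    AnalyzeSite(llm, "llama3", "https://example.com", "What is this page about?")
)
print(result["content"])  # keys per the return dict above: source, content, metadata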
@@ -331,109 +345,116 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
 
 
 # %%
-tools = [ {
-    "type": "function",
-    "function": {
-        "name": "TickerValue",
-        "description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "ticker": {
-                    "type": "string",
-                    "description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
-                },
-            },
-            "required": ["ticker"],
-            "additionalProperties": False
-        }
-    }
-}, {
-    "type": "function",
-    "function": {
-        "name": "AnalyzeSite",
-        "description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "url": {
-                    "type": "string",
-                    "description": "The website URL to download and process",
-                },
-                "question": {
-                    "type": "string",
-                    "description": "The question to ask the second LLM about the content",
-                },
-            },
-            "required": ["url", "question"],
-            "additionalProperties": False
-        },
-        "returns": {
-            "type": "object",
-            "properties": {
-                "source": {
-                    "type": "string",
-                    "description": "Identifier for the source LLM"
-                },
-                "content": {
-                    "type": "string",
-                    "description": "The complete response from the second LLM"
-                },
-                "metadata": {
-                    "type": "object",
-                    "description": "Additional information about the response"
-                }
-            }
-        }
-    }
-}, {
-    "type": "function",
-    "function": {
-        "name": "DateTime",
-        "description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "timezone": {
-                    "type": "string",
-                    "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'."
-                }
-            },
-            "required": []
-        }
-    }
-}, {
-    "type": "function",
-    "function": {
-        "name": "WeatherForecast",
-        "description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "city": {
-                    "type": "string",
-                    "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
-                    "minLength": 2
-                },
-                "state": {
-                    "type": "string",
-                    "description": "State to find the weather forecast (e.g., 'OR', 'WA').",
-                    "minLength": 2
-                }
-            },
-            "required": [ "city", "state" ],
-            "additionalProperties": False
-        }
-    }
-}]
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "TickerValue",
+            "description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "ticker": {
+                        "type": "string",
+                        "description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
+                    },
+                },
+                "required": ["ticker"],
+                "additionalProperties": False,
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "AnalyzeSite",
+            "description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "url": {
+                        "type": "string",
+                        "description": "The website URL to download and process",
+                    },
+                    "question": {
+                        "type": "string",
+                        "description": "The question to ask the second LLM about the content",
+                    },
+                },
+                "required": ["url", "question"],
+                "additionalProperties": False,
+            },
+            "returns": {
+                "type": "object",
+                "properties": {
+                    "source": {
+                        "type": "string",
+                        "description": "Identifier for the source LLM",
+                    },
+                    "content": {
+                        "type": "string",
+                        "description": "The complete response from the second LLM",
+                    },
+                    "metadata": {
+                        "type": "object",
+                        "description": "Additional information about the response",
+                    },
+                },
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "DateTime",
+            "description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "timezone": {
+                        "type": "string",
+                        "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
+                    }
+                },
+                "required": [],
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "WeatherForecast",
+            "description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
+                        "minLength": 2,
+                    },
+                    "state": {
+                        "type": "string",
+                        "description": "State to find the weather forecast (e.g., 'OR', 'WA').",
+                        "minLength": 2,
+                    },
+                },
+                "required": ["city", "state"],
+                "additionalProperties": False,
+            },
+        },
+    },
+]
 
 
 def llm_tools(tools):
     return [tool for tool in tools if tool.get("enabled", False) == True]
 
 
 def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
     return [{**tool, "enabled": True} for tool in tools]
 
-tool_functions = [ 'DateTime', 'WeatherForecast', 'TickerValue', 'AnalyzeSite' ]
-__all__ = [ 'tools', 'llm_tools', 'enabled_tools', 'tool_functions' ]
-#__all__.extend(__tool_functions__)  # type: ignore
+
+tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"]
+__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
+# __all__.extend(__tool_functions__)  # type: ignore
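
Note: tying the pieces together: the schemas in `tools` ship without an "enabled" flag, enabled_tools() stamps one on, and llm_tools() filters on it before the list is handed to the model. The chat-call shape below is an assumption for illustration, not part of this diff:

active = llm_tools(enabled_tools(tools))  # all four schemas, flagged enabled
response = llm.chat(
    model="llama3",  # hypothetical model name
    messages=[{"role": "user", "content": "What time is it in Poland?"}],
    tools=active,
)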