Reformatted all content to black

parent a1798b58ac
commit e044f9c639
@@ -88,6 +88,10 @@ button {
     flex-grow: 1;
 }
 
+.MessageContent div > p:first-child {
+    margin-top: 0;
+}
+
 .MenuCard.MuiCard-root {
     display: flex;
     flex-direction: column;
@@ -30,7 +30,6 @@ const BackstoryTextField = React.forwardRef<BackstoryTextFieldRef, BackstoryText
     const shadowRef = useRef<HTMLTextAreaElement>(null);
     const [editValue, setEditValue] = useState<string>(value);
 
-    console.log({ value, placeholder, editValue });
     // Sync editValue with prop value if it changes externally
     useEffect(() => {
         setEditValue(value || "");
@@ -158,7 +158,7 @@ function ChatBubble(props: ChatBubbleProps) {
     };
 
     // Render Accordion for expandable content
-    if (expandable || (role === 'content' && title)) {
+    if (expandable || title) {
         // Determine if Accordion is controlled
         const isControlled = typeof expanded === 'boolean' && typeof onExpand === 'function';
 
@@ -1,13 +1,15 @@
 import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback } from 'react';
 import Typography from '@mui/material/Typography';
 import Tooltip from '@mui/material/Tooltip';
+import IconButton from '@mui/material/IconButton';
 import Button from '@mui/material/Button';
 import Box from '@mui/material/Box';
 import SendIcon from '@mui/icons-material/Send';
+import CancelIcon from '@mui/icons-material/Cancel';
 import { SxProps, Theme } from '@mui/material';
 import PropagateLoader from "react-spinners/PropagateLoader";
 
-import { Message, MessageList, MessageData } from './Message';
+import { Message, MessageList, BackstoryMessage } from './Message';
 import { ContextStatus } from './ContextStatus';
 import { Scrollable } from './Scrollable';
 import { DeleteConfirmation } from './DeleteConfirmation';
@@ -17,7 +19,7 @@ import { BackstoryTextField, BackstoryTextFieldRef } from './BackstoryTextField'
 import { BackstoryElementProps } from './BackstoryTab';
 import { connectionBase } from './Global';
 
-const loadingMessage: MessageData = { "role": "status", "content": "Establishing connection with server..." };
+const loadingMessage: BackstoryMessage = { "role": "status", "content": "Establishing connection with server..." };
 
 type ConversationMode = 'chat' | 'job_description' | 'resume' | 'fact_check';
 
@@ -25,24 +27,6 @@ interface ConversationHandle {
     submitQuery: (prompt: string, options?: QueryOptions) => void;
     fetchHistory: () => void;
 }
-interface BackstoryMessage {
-    prompt: string;
-    preamble: {};
-    status: string;
-    full_content: string;
-    response: string; // Set when status === 'done' or 'error'
-    chunk: string; // Used when status === 'streaming'
-    metadata: {
-        rag: { documents: [] };
-        tools: string[];
-        eval_count: number;
-        eval_duration: number;
-        prompt_eval_count: number;
-        prompt_eval_duration: number;
-    };
-    actions: string[];
-    timestamp: string;
-};
 
 interface ConversationProps extends BackstoryElementProps {
     className?: string, // Override default className
@@ -59,7 +43,7 @@ interface ConversationProps extends BackstoryElementProps {
     messageFilter?: ((messages: MessageList) => MessageList) | undefined, // Filter callback to determine which Messages to display in Conversation
     messages?: MessageList, //
     sx?: SxProps<Theme>,
-    onResponse?: ((message: MessageData) => void) | undefined, // Event called when a query completes (provides messages)
+    onResponse?: ((message: BackstoryMessage) => void) | undefined, // Event called when a query completes (provides messages)
 };
 
 const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
@@ -87,8 +71,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
     const [countdown, setCountdown] = useState<number>(0);
     const [conversation, setConversation] = useState<MessageList>([]);
     const [filteredConversation, setFilteredConversation] = useState<MessageList>([]);
-    const [processingMessage, setProcessingMessage] = useState<MessageData | undefined>(undefined);
-    const [streamingMessage, setStreamingMessage] = useState<MessageData | undefined>(undefined);
+    const [processingMessage, setProcessingMessage] = useState<BackstoryMessage | undefined>(undefined);
+    const [streamingMessage, setStreamingMessage] = useState<BackstoryMessage | undefined>(undefined);
     const timerRef = useRef<any>(null);
     const [contextStatus, setContextStatus] = useState<ContextStatus>({ context_used: 0, max_context: 0 });
     const [contextWarningShown, setContextWarningShown] = useState<boolean>(false);
@@ -96,6 +80,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
     const conversationRef = useRef<MessageList>([]);
     const viewableElementRef = useRef<HTMLDivElement>(null);
     const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
+    const stopRef = useRef(false);
 
     // Keep the ref updated whenever items changes
     useEffect(() => {
@@ -181,14 +166,25 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
 
             const backstoryMessages: BackstoryMessage[] = messages;
 
-            setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => [{
+            setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => {
+                if (backstoryMessage.status === "partial") {
+                    return [{
+                        ...backstoryMessage,
+                        role: "assistant",
+                        content: backstoryMessage.response || "",
+                        expanded: false,
+                        expandable: true,
+                    }]
+                }
+                return [{
                     role: 'user',
                     content: backstoryMessage.prompt || "",
                 }, {
                     ...backstoryMessage,
-                    role: backstoryMessage.status === "done" ? "assistant" : backstoryMessage.status,
+                    role: ['done'].includes(backstoryMessage.status || "") ? "assistant" : backstoryMessage.status,
                     content: backstoryMessage.response || "",
-                }] as MessageList));
+                }] as MessageList;
+            }));
             setNoInteractions(false);
         }
         setProcessingMessage(undefined);
@@ -294,6 +290,11 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
         }
     };
 
+    const cancelQuery = () => {
+        console.log("Stop query");
+        stopRef.current = true;
+    };
+
     const sendQuery = async (request: string, options?: QueryOptions) => {
         request = request.trim();
 
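Note: the `cancelQuery`/`stopRef` pair added above is a cooperative-cancellation pattern: the click handler only flips a ref, and the streaming loop (modified in the hunks below) checks that flag on every iteration. A minimal standalone sketch of the same idea, with illustrative names not taken from this codebase:

```tsx
import { useRef, useState } from 'react';

// Sketch: ref-based cancellation for a streaming fetch loop. A ref is used
// instead of state so the in-flight async loop sees the flag immediately,
// without waiting for a re-render.
const useCancellableStream = () => {
    const stopRef = useRef(false);
    const [processing, setProcessing] = useState(false);

    const cancel = () => { stopRef.current = true; };

    const run = async (url: string) => {
        setProcessing(true);
        stopRef.current = false;
        const response = await fetch(url);
        const reader = response.body!.getReader();
        while (!stopRef.current) {
            const { done } = await reader.read();
            if (done) break;
            // ...process the chunk here...
        }
        if (stopRef.current) {
            await reader.cancel(); // release the connection early
        }
        setProcessing(false);
        stopRef.current = false;
    };

    return { run, cancel, processing };
};
```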
@@ -308,6 +309,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
             return;
         }
 
+        stopRef.current = false;
+
         setNoInteractions(false);
 
         setConversation([
@@ -325,12 +328,10 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
 
         try {
             setProcessing(true);
-            // Create a unique ID for the processing message
-            const processingId = Date.now().toString();
 
             // Add initial processing message
             setProcessingMessage(
-                { role: 'status', content: 'Submitting request...', id: processingId, isProcessing: true }
+                { role: 'status', content: 'Submitting request...', disableCopy: true }
             );
 
             // Add a small delay to ensure React has time to update the UI
@@ -379,17 +380,20 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
 
                 switch (update.status) {
                     case 'done':
-                        console.log('Done processing:', update);
-                        stopCountdown();
-                        setStreamingMessage(undefined);
-                        setProcessingMessage(undefined);
+                    case 'partial':
+                        if (update.status === 'done') stopCountdown();
+                        if (update.status === 'done') setStreamingMessage(undefined);
+                        if (update.status === 'done') setProcessingMessage(undefined);
                         const backstoryMessage: BackstoryMessage = update;
                         setConversation([
                             ...conversationRef.current, {
                                 ...backstoryMessage,
                                 role: 'assistant',
                                 origin: type,
+                                prompt: ['done', 'partial'].includes(update.status) ? update.prompt : '',
                                 content: backstoryMessage.response || "",
+                                expanded: update.status === "done" ? true : false,
+                                expandable: true,
                             }] as MessageList);
                         // Add a small delay to ensure React has time to update the UI
                         await new Promise(resolve => setTimeout(resolve, 0));
@@ -424,9 +428,9 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                         // Update processing message with immediate re-render
                         if (update.status === "streaming") {
                             streaming_response += update.chunk
-                            setStreamingMessage({ role: update.status, content: streaming_response });
+                            setStreamingMessage({ role: update.status, content: streaming_response, disableCopy: true });
                         } else {
-                            setProcessingMessage({ role: update.status, content: update.response });
+                            setProcessingMessage({ role: update.status, content: update.response, disableCopy: true });
                             /* Reset stream on non streaming message */
                             streaming_response = ""
                         }
@@ -437,12 +441,11 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                     }
                 }
 
-            while (true) {
+            while (!stopRef.current) {
                 const { done, value } = await reader.read();
                 if (done) {
                     break;
                 }
 
                 const chunk = decoder.decode(value, { stream: true });
-
                 // Process each complete line immediately
@@ -470,26 +473,32 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                 }
             }
 
+            if (stopRef.current) {
+                await reader.cancel();
+                setProcessingMessage(undefined);
+                setStreamingMessage(undefined);
+                setSnack("Processing cancelled", "warning");
+            }
             stopCountdown();
             setProcessing(false);
+            stopRef.current = false;
         } catch (error) {
            console.error('Fetch error:', error);
            setSnack("Unable to process query", "error");
-            setProcessingMessage({ role: 'error', content: "Unable to process query" });
+            setProcessingMessage({ role: 'error', content: "Unable to process query", disableCopy: true });
            setTimeout(() => {
                setProcessingMessage(undefined);
            }, 5000);
+            stopRef.current = false;
            setProcessing(false);
            stopCountdown();
-            // Add a small delay to ensure React has time to update the UI
-            await new Promise(resolve => setTimeout(resolve, 0));
+            return;
        }
    };
 
    return (
        <Scrollable
-            className={className || "Conversation"}
+            className={`${className || ""} Conversation`}
            autoscroll
            textFieldRef={viewableElementRef}
            fallbackThreshold={0.5}
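Note: the loop these hunks modify follows the usual newline-delimited JSON (NDJSON) pattern: decode each chunk in streaming mode, buffer until a complete line arrives, then parse it. A minimal sketch under that assumption (not the codebase's actual helper):

```ts
// Sketch: consume an NDJSON stream, parsing each complete line as it arrives.
async function readNdjson(
    reader: ReadableStreamDefaultReader<Uint8Array>,
    onUpdate: (update: unknown) => void,
) {
    const decoder = new TextDecoder();
    let buffer = '';
    while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split('\n');
        buffer = lines.pop() || ''; // keep the trailing partial line
        for (const line of lines) {
            if (line.trim()) onUpdate(JSON.parse(line));
        }
    }
}
```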
@@ -564,6 +573,20 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                         </Button>
                     </span>
                 </Tooltip>
+                <Tooltip title="Cancel">
+                    <span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
+                        <IconButton
+                            aria-label="cancel"
+                            onClick={() => { cancelQuery(); }}
+                            sx={{ display: "flex", margin: 'auto 0px' }}
+                            size="large"
+                            edge="start"
+                            disabled={stopRef.current || sessionId === undefined || processing === false}
+                        >
+                            <CancelIcon />
+                        </IconButton>
+                    </span>
+                </Tooltip>
             </Box>
         </Box>
         {(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length &&
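Note: the `<span>` wrapper in the added cancel control is the standard MUI workaround the inline comment describes: a disabled button emits no pointer events, so the Tooltip needs an enabled wrapper element to listen on. Reduced to its essentials (illustrative component, not from this codebase):

```tsx
import Tooltip from '@mui/material/Tooltip';
import IconButton from '@mui/material/IconButton';
import CancelIcon from '@mui/icons-material/Cancel';

// The Tooltip listens on the span, so it still shows while the button is disabled.
const CancelButton = (props: { disabled: boolean; onCancel: () => void }) => (
    <Tooltip title="Cancel">
        <span style={{ display: 'flex' }}>
            <IconButton aria-label="cancel" disabled={props.disabled} onClick={props.onCancel}>
                <CancelIcon />
            </IconButton>
        </span>
    </Tooltip>
);
```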
@@ -47,11 +47,18 @@ type MessageRoles =
     'thinking' |
     'user';
 
-type MessageData = {
+type BackstoryMessage = {
+    // Only two required fields
     role: MessageRoles,
     content: string,
-    status?: string, // streaming, done, error...
-    response?: string,
+    // Rest are optional
+    prompt?: string;
+    preamble?: {};
+    status?: string;
+    full_content?: string;
+    response?: string; // Set when status === 'done', 'partial', or 'error'
+    chunk?: string; // Used when status === 'streaming'
+    timestamp?: string;
     disableCopy?: boolean,
     user?: string,
     title?: string,
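Note: with only `role` and `content` required, consumers of `BackstoryMessage` must guard the optional fields. A sketch of what that looks like in practice (hypothetical helper, not part of this commit):

```ts
// Hypothetical helper: pick display text from a BackstoryMessage while
// guarding the optional streaming/completion fields.
const displayText = (message: BackstoryMessage): string => {
    if (message.status === 'streaming') {
        return message.chunk ?? '';
    }
    if (message.status === 'done' || message.status === 'partial') {
        return message.response ?? message.content;
    }
    return message.content;
};
```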
@@ -84,11 +91,11 @@ interface MessageMetaData {
     setSnack: SetSnackType,
 }
 
-type MessageList = MessageData[];
+type MessageList = BackstoryMessage[];
 
 interface MessageProps extends BackstoryElementProps {
     sx?: SxProps<Theme>,
-    message: MessageData,
+    message: BackstoryMessage,
     expanded?: boolean,
     onExpand?: (open: boolean) => void,
     className?: string,
@@ -237,7 +244,7 @@ const MessageMeta = (props: MessageMetaProps) => {
 };
 
 const Message = (props: MessageProps) => {
-    const { message, submitQuery, sx, className, onExpand, expanded, sessionId, setSnack } = props;
+    const { message, submitQuery, sx, className, onExpand, sessionId, setSnack } = props;
     const [metaExpanded, setMetaExpanded] = useState<boolean>(false);
     const textFieldRef = useRef(null);
 
@@ -254,14 +261,16 @@ const Message = (props: MessageProps) => {
         return (<></>);
     }
 
-    const formattedContent = message.content.trim() || "Waiting for LLM to spool up...";
+    const formattedContent = message.content.trim();
+    if (formattedContent === "") {
+        return (<></>);
+    }
 
     return (
         <ChatBubble
-            className={className || "Message"}
+            className={`${className || ""} Message Message-${message.role}`}
             {...message}
             onExpand={onExpand}
-            expanded={expanded}
             sx={{
                 display: "flex",
                 flexDirection: "column",
@@ -273,7 +282,6 @@ const Message = (props: MessageProps) => {
                 ...sx,
             }}>
             <CardContent ref={textFieldRef} sx={{ position: "relative", display: "flex", flexDirection: "column", overflowX: "auto", m: 0, p: 0, paddingBottom: '0px !important' }}>
-                {message.role !== 'user' ?
                 <Scrollable
                     className="MessageContent"
                     autoscroll
@@ -289,18 +297,9 @@ const Message = (props: MessageProps) => {
                 >
                     <StyledMarkdown streaming={message.role === "streaming"} {...{ content: formattedContent, submitQuery, sessionId, setSnack }} />
                 </Scrollable>
-                :
-                <Typography
-                    className="MessageContent"
-                    ref={textFieldRef}
-                    variant="body2"
-                    sx={{ display: "flex", color: 'text.secondary' }}>
-                    {message.content}
-                </Typography>
-                }
             </CardContent>
             <CardActions disableSpacing sx={{ display: "flex", flexDirection: "row", justifyContent: "space-between", alignItems: "center", width: "100%", p: 0, m: 0 }}>
-                {(message.disableCopy === undefined || message.disableCopy === false) && ["assistant", "content"].includes(message.role) && <CopyBubble content={message.content} />}
+                {(message.disableCopy === undefined || message.disableCopy === false) && <CopyBubble content={message.content} />}
                 {message.metadata && (
                     <Box sx={{ display: "flex", alignItems: "center", gap: 1 }}>
                         <Button variant="text" onClick={handleMetaExpandClick} sx={{ color: "darkgrey", p: 0 }}>
@@ -309,7 +308,7 @@ const Message = (props: MessageProps) => {
                             <ExpandMore
                                 expand={metaExpanded}
                                 onClick={handleMetaExpandClick}
-                                aria-expanded={expanded}
+                                aria-expanded={message.expanded}
                                 aria-label="show more"
                             >
                                 <ExpandMoreIcon />
@@ -331,7 +330,8 @@ const Message = (props: MessageProps) => {
 export type {
     MessageProps,
     MessageList,
-    MessageData,
+    BackstoryMessage,
+    MessageMetaData,
     MessageRoles,
 };
 
@@ -7,7 +7,7 @@ import {
 import { SxProps } from '@mui/material';
 
 import { ChatQuery } from './ChatQuery';
-import { MessageList, MessageData } from './Message';
+import { MessageList, BackstoryMessage } from './Message';
 import { Conversation } from './Conversation';
 import { BackstoryPageProps } from './BackstoryTab';
 
@@ -62,11 +62,6 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
             return [];
         }
 
-        if (messages.length > 2) {
-            setHasResume(true);
-            setHasFacts(true);
-        }
-
         if (messages.length > 0) {
             messages[0].role = 'content';
             messages[0].title = 'Job Description';
@@ -74,6 +69,19 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
             messages[0].expandable = true;
         }
 
+        if (-1 !== messages.findIndex(m => m.status === 'done')) {
+            setHasResume(true);
+            setHasFacts(true);
+        }
+
+        return messages;
+
+        if (messages.length > 1) {
+            setHasResume(true);
+            setHasFacts(true);
+        }
+
 
         if (messages.length > 3) {
             // messages[2] is Show job requirements
             messages[3].role = 'job-requirements';
@@ -95,6 +103,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
             return [];
         }
 
+        return messages;
+
         if (messages.length > 1) {
             // messages[0] is Show Qualifications
             messages[1].role = 'qualifications';
@@ -139,7 +149,7 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
         return filtered;
     }, []);
 
-    const jobResponse = useCallback(async (message: MessageData) => {
+    const jobResponse = useCallback(async (message: BackstoryMessage) => {
         console.log('onJobResponse', message);
         if (message.actions && message.actions.includes("job_description")) {
             await jobConversationRef.current.fetchHistory();
@@ -155,12 +165,12 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
         }
     }, [setHasFacts, setHasResume, setActiveTab]);
 
-    const resumeResponse = useCallback((message: MessageData): void => {
+    const resumeResponse = useCallback((message: BackstoryMessage): void => {
         console.log('onResumeResponse', message);
         setHasFacts(true);
     }, [setHasFacts]);
 
-    const factsResponse = useCallback((message: MessageData): void => {
+    const factsResponse = useCallback((message: BackstoryMessage): void => {
         console.log('onFactsResponse', message);
     }, []);
 
@@ -199,7 +209,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
 3. **Mapping Analysis**: Identifies legitimate matches between requirements and qualifications
 3. **Resume Generation**: Uses mapping output to create a tailored resume with evidence-based content
 4. **Verification**: Performs fact-checking to catch any remaining fabrications
-1. **Re-generation**: If verification does not pass, a second attempt is made to correct any issues`
+1. **Re-generation**: If verification does not pass, a second attempt is made to correct any issues`,
+        disableCopy: true
     }];
 
 
@@ -17,3 +17,12 @@ pre:not(.MessageContent) {
 .MuiMarkdown > div {
     width: 100%;
 }
+
+.Message-streaming .MuiMarkdown ul,
+.Message-streaming .MuiMarkdown h1,
+.Message-streaming .MuiMarkdown p,
+.Message-assistant .MuiMarkdown ul,
+.Message-assistant .MuiMarkdown h1,
+.Message-assistant .MuiMarkdown p {
+    color: white;
+}
@@ -37,14 +37,14 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
     }
     if (className === "lang-json" && !streaming) {
         try {
-            const fixed = jsonrepair(content);
+            let fixed = JSON.parse(jsonrepair(content));
             return <Scrollable className="JsonViewScrollable">
                 <JsonView
                     className="JsonView"
                     style={{
                         ...vscodeTheme,
                         fontSize: "0.8rem",
-                        maxHeight: "20rem",
+                        maxHeight: "10rem",
                         padding: "14px 0",
                         overflow: "hidden",
                         width: "100%",
@@ -53,9 +53,9 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
                     }}
                     displayDataTypes={false}
                     objectSortKeys={false}
-                    collapsed={false}
+                    collapsed={true}
                     shortenTextAfterLength={100}
-                    value={JSON.parse(fixed)}>
+                    value={fixed}>
                     <JsonView.String
                         render={({ children, ...reset }) => {
                             if (typeof (children) === "string" && children.match("\n")) {
@@ -66,7 +66,7 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
                 </JsonView>
             </Scrollable>
         } catch (e) {
-            console.log("jsonrepair error", e);
+            return <pre><code className="JsonRaw">{content}</code></pre>
         };
     }
     return <pre><code className={className}>{element.children}</code></pre>;
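Note: parsing the repaired string once up front (`JSON.parse(jsonrepair(content))`) also lets a still-malformed payload fall through to the raw `<pre>` branch in the catch. For reference, jsonrepair turns common LLM-output defects into parseable JSON (sketch with made-up input):

```ts
import { jsonrepair } from 'jsonrepair';

// jsonrepair fixes unquoted keys, single quotes, and trailing commas
// so that JSON.parse can succeed.
const broken = "{ name: 'JPK', enabled: true, }";
const value = JSON.parse(jsonrepair(broken));
console.log(value); // { name: 'JPK', enabled: true }
```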
src/server.py (415 lines changed)
@@ -1,8 +1,9 @@
 LLM_TIMEOUT = 600
 
 from utils import logger
+from pydantic import BaseModel, Field  # type: ignore
 
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict, Optional
 
 # %%
 # Imports [standard]
@@ -26,6 +27,7 @@ from uuid import uuid4
 import time
 import traceback
 
+
 def try_import(module_name, pip_name=None):
     try:
         __import__(module_name)
@@ -33,6 +35,7 @@ def try_import(module_name, pip_name=None):
         print(f"Module '{module_name}' not found. Install it using:")
         print(f"  pip install {pip_name or module_name}")
 
+
 # Third-party modules with import checks
 try_import("ollama")
 try_import("requests")
@@ -63,7 +66,9 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore
 from utils import (
     rag as Rag,
     tools as Tools,
-    Context, Conversation, Message,
+    Context,
+    Conversation,
+    Message,
     Agent,
     Metrics,
     Tunables,
@@ -74,11 +79,22 @@ from utils import (
 CONTEXT_VERSION = 2
 
 rags = [
-    { "name": "JPK", "enabled": True, "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects." },
+    {
+        "name": "JPK",
+        "enabled": True,
+        "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
+    },
     # { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
 ]
 
-REQUEST_TIME = Summary('request_processing_seconds', 'Time spent processing request')
+
+class QueryOptions(BaseModel):
+    prompt: str
+    tunables: Tunables = Field(default_factory=Tunables)
+    agent_options: Dict[str, Any] = Field(default={})
+
+
+REQUEST_TIME = Summary("request_processing_seconds", "Time spent processing request")
 
 system_message_old = f"""
 Launched on {datetime.now().isoformat()}.
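Note: the new `QueryOptions` model defines the request body the agent endpoint validates later in this commit: a required `prompt` plus optional `tunables` and `agent_options`. A client-side mirror of that shape might look like this (TypeScript sketch; the loose `Record` types are assumptions, since `Tunables` is defined elsewhere):

```ts
// Sketch: client-side mirror of the server's QueryOptions pydantic model.
// The server fills defaults for any omitted optional field.
interface QueryOptionsBody {
    prompt: string;
    tunables?: Record<string, unknown>;
    agent_options?: Record<string, unknown>;
}

const body: QueryOptionsBody = {
    prompt: 'Show job requirements',
    agent_options: { job_description: '...' },
};
```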
@@ -107,6 +123,7 @@ You are provided with a <|resume|> which was generated by you, the <|context|> y
 Your task is to answer questions about the <|fact_check|> you generated based on the <|resume|> and <|context>.
 """
 
+
 def get_installed_ram():
     try:
         with open("/proc/meminfo", "r") as f:
@@ -117,21 +134,29 @@ def get_installed_ram():
     except Exception as e:
         return f"Error retrieving RAM: {e}"
 
+
 def get_graphics_cards():
     gpus = []
     try:
         # Run the ze-monitor utility
-        result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
+        result = subprocess.run(
+            ["ze-monitor"], capture_output=True, text=True, check=True
+        )
+
         # Clean up the output (remove leading/trailing whitespace and newlines)
         output = result.stdout.strip()
         for index in range(len(output.splitlines())):
-            result = subprocess.run(["ze-monitor", "--device", f"{index+1}", "--info"], capture_output=True, text=True, check=True)
+            result = subprocess.run(
+                ["ze-monitor", "--device", f"{index+1}", "--info"],
+                capture_output=True,
+                text=True,
+                check=True,
+            )
             gpu_info = result.stdout.strip().splitlines()
             gpu = {
                 "discrete": True,  # Assume it's discrete initially
                 "name": None,
-                "memory": None
+                "memory": None,
             }
             gpus.append(gpu)
             for line in gpu_info:
@@ -154,6 +179,7 @@ def get_graphics_cards():
     except Exception as e:
         return f"Error retrieving GPU info: {e}"
 
+
 def get_cpu_info():
     try:
         with open("/proc/cpuinfo", "r") as f:
@@ -165,6 +191,7 @@ def get_cpu_info():
     except Exception as e:
         return f"Error retrieving CPU info: {e}"
 
+
 def system_info(model):
     return {
         "System RAM": get_installed_ram(),
@@ -172,9 +199,10 @@ def system_info(model):
         "CPU": get_cpu_info(),
         "LLM Model": model,
         "Embedding Model": defines.embedding_model,
-        "Context length": defines.max_context
+        "Context length": defines.max_context,
     }
 
+
 # %%
 # Defaults
 OLLAMA_API_URL = defines.ollama_api_url
@@ -192,31 +220,58 @@ DEFAULT_HISTORY_LENGTH=5
 def create_system_message(prompt):
     return [{"role": "system", "content": prompt}]
 
+
 tool_log = []
 command_log = []
 model = None
 client = None
 web_server = None
 
+
 # %%
 # Cmd line overrides
 def parse_args():
     parser = argparse.ArgumentParser(description="AI is Really Cool")
-    parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
-    parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
-    parser.add_argument("--web-host", type=str, default=WEB_HOST, help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.")
-    parser.add_argument("--web-port", type=str, default=WEB_PORT, help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.")
-    parser.add_argument("--level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
-                        default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
+    parser.add_argument(
+        "--ollama-server",
+        type=str,
+        default=OLLAMA_API_URL,
+        help=f"Ollama API endpoint. default={OLLAMA_API_URL}",
+    )
+    parser.add_argument(
+        "--ollama-model",
+        type=str,
+        default=MODEL_NAME,
+        help=f"LLM model to use. default={MODEL_NAME}",
+    )
+    parser.add_argument(
+        "--web-host",
+        type=str,
+        default=WEB_HOST,
+        help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.",
+    )
+    parser.add_argument(
+        "--web-port",
+        type=str,
+        default=WEB_PORT,
+        help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.",
+    )
+    parser.add_argument(
+        "--level",
+        type=str,
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+        default=LOG_LEVEL,
+        help=f"Set the logging level. default={LOG_LEVEL}",
+    )
     return parser.parse_args()
 
+
+
 # %%
 
+
 # %%
 
+
 # %%
 def is_valid_uuid(value: str) -> bool:
     try:
@@ -226,11 +281,6 @@ def is_valid_uuid(value: str) -> bool:
         return False
 
-
-
-
-
-
 
 # %%
 class WebServer:
     @asynccontextmanager
@@ -239,9 +289,11 @@ class WebServer:
         self.observer, self.file_watcher = Rag.start_file_watcher(
             llm=self.llm,
             watch_directory=defines.doc_dir,
-            recreate=False # Don't recreate if exists
+            recreate=False,  # Don't recreate if exists
+        )
+        logger.info(
+            f"API started with {self.file_watcher.collection.count()} documents in the collection"
         )
-        logger.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
         yield
         if self.observer:
             self.observer.stop()
@@ -271,7 +323,9 @@ class WebServer:
         self.file_watcher = None
         self.observer = None
 
-        self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(defines.cert_path)
+        self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(
+            defines.cert_path
+        )
 
         if self.ssl_enabled:
             allow_origins = ["https://battle-linux.ketrenos.com:3000"]
@@ -307,14 +361,18 @@ class WebServer:
 
             context = self.upsert_context(context_id)
             if not context:
-                return JSONResponse({"error": f"Invalid context: {context_id}"}, status_code=400)
+                return JSONResponse(
+                    {"error": f"Invalid context: {context_id}"}, status_code=400
+                )
 
             data = await request.json()
 
             dimensions = data.get("dimensions", 2)
             result = self.file_watcher.umap_collection
             if not result:
-                return JSONResponse({"error": "No UMAP collection found"}, status_code=404)
+                return JSONResponse(
+                    {"error": "No UMAP collection found"}, status_code=404
+                )
             if dimensions == 2:
                 logger.info("Returning 2D UMAP")
                 umap_embedding = self.file_watcher.umap_embedding_2d
@@ -323,7 +381,9 @@ class WebServer:
                 umap_embedding = self.file_watcher.umap_embedding_3d
 
             if len(umap_embedding) == 0:
-                return JSONResponse({"error": "No UMAP embedding found"}, status_code=404)
+                return JSONResponse(
+                    {"error": "No UMAP embedding found"}, status_code=404
+                )
 
             result["embeddings"] = umap_embedding.tolist()
 
@@ -347,35 +407,56 @@ class WebServer:
             try:
                 data = await request.json()
                 query = data.get("query", "")
+                threshold = data.get("threshold", 0.5)
+                results = data.get("results", 10)
             except:
                 query = ""
+                threshold = 0.5
+                results = 10
             if not query:
-                return JSONResponse({"error": "No query provided for similarity search"}, status_code=400)
+                return JSONResponse(
+                    {"error": "No query provided for similarity search"},
+                    status_code=400,
+                )
             try:
-                chroma_results = self.file_watcher.find_similar(query=query, top_k=10)
+                chroma_results = self.file_watcher.find_similar(
+                    query=query, top_k=results, threshold=threshold
+                )
                 if not chroma_results:
                     return JSONResponse({"error": "No results found"}, status_code=404)
 
-                chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
+                chroma_embedding = np.array(
+                    chroma_results["query_embedding"]
+                ).flatten()  # Ensure correct shape
                 logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
 
-                umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+                umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[
+                    0
+                ].tolist()
+                logger.info(
+                    f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
+                )  # Debug output
 
-                umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+                umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[
+                    0
+                ].tolist()
+                logger.info(
+                    f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
+                )  # Debug output
 
-                return JSONResponse({
+                return JSONResponse(
+                    {
                         **chroma_results,
                         "query": query,
                         "umap_embedding_2d": umap_2d,
-                        "umap_embedding_3d": umap_3d
-                })
+                        "umap_embedding_3d": umap_3d,
+                    }
+                )
 
             except Exception as e:
                 logger.error(e)
-                #return JSONResponse({"error": str(e)}, 500)
+                logging.error(traceback.format_exc())
+                return JSONResponse({"error": str(e)}, 500)
 
         @self.app.put("/api/reset/{context_id}/{agent_type}")
         async def put_reset(context_id: str, agent_type: str, request: Request):
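Note: the similarity handler above now reads optional `threshold` and `results` fields from the request body, defaulting to 0.5 and 10. A hedged client sketch; the route path is not visible in this hunk, so `/api/similarity/{context_id}` is an assumption:

```ts
// Sketch: call the similarity endpoint with the new optional tuning fields.
// The path is assumed; only the body fields come from this hunk.
async function findSimilar(contextId: string, query: string) {
    const response = await fetch(`/api/similarity/${contextId}`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ query, threshold: 0.5, results: 10 }),
    });
    return response.json();
}
```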
@@ -386,7 +467,10 @@ class WebServer:
             context = self.upsert_context(context_id)
             agent = context.get_agent(agent_type)
             if not agent:
-                return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)
+                return JSONResponse(
+                    {"error": f"{agent_type} is not recognized", "context": context.id},
+                    status_code=404,
+                )
 
             data = await request.json()
             try:
@@ -405,20 +489,29 @@ class WebServer:
                         response["tools"] = context.tools
                     case "history":
                         reset_map = {
-                            "job_description": ("job_description", "resume", "fact_check"),
+                            "job_description": (
+                                "job_description",
+                                "resume",
+                                "fact_check",
+                            ),
                             "resume": ("job_description", "resume", "fact_check"),
-                            "fact_check": ("job_description", "resume", "fact_check"),
+                            "fact_check": (
+                                "job_description",
+                                "resume",
+                                "fact_check",
+                            ),
                             "chat": ("chat",),
                         }
                         resets = reset_map.get(agent_type, ())
 
                         for mode in resets:
                             tmp = context.get_agent(mode)
                             if not tmp:
+                                logger.info(
+                                    f"Agent {mode} not found for {context_id}"
+                                )
                                 continue
                             logger.info(f"Resetting {reset_operation} for {mode}")
-                            context.conversation = Conversation()
-                            context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage
+                            tmp.conversation.reset()
                             response["history"] = []
                             response["context_used"] = agent.context_tokens
                     case "message_history_length":
@@ -427,13 +520,19 @@ class WebServer:
                         response["message_history_length"] = DEFAULT_HISTORY_LENGTH
 
                 if not response:
-                    return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system_prompt}"})
+                    return JSONResponse(
+                        {"error": "Usage: { reset: rags|tools|history|system_prompt}"}
+                    )
                 else:
                     self.save_context(context_id)
                     return JSONResponse(response)
 
-            except:
-                return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system_prompt}"})
+            except Exception as e:
+                logger.error(f"Error in reset: {e}")
+                logger.error(traceback.format_exc())
+                return JSONResponse(
+                    {"error": "Usage: { reset: rags|tools|history|system_prompt}"}
+                )
 
         @self.app.put("/api/tunables/{context_id}")
         async def put_tunables(context_id: str, request: Request):
@@ -444,29 +543,49 @@ class WebServer:
             data = await request.json()
             agent = context.get_agent("chat")
             if not agent:
-                return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
+                return JSONResponse(
+                    {"error": f"chat is not recognized", "context": context.id},
+                    status_code=404,
+                )
             for k in data.keys():
                 match k:
                     case "tools":
                         # { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] }
                         tools: list[dict[str, Any]] = data[k]
                         if not tools:
-                            return JSONResponse({ "status": "error", "message": "Tools can not be empty." })
+                            return JSONResponse(
+                                {
+                                    "status": "error",
+                                    "message": "Tools can not be empty.",
+                                }
+                            )
                         for tool in tools:
                             for context_tool in context.tools:
                                 if context_tool["function"]["name"] == tool["name"]:
                                     context_tool["enabled"] = tool["enabled"]
                         self.save_context(context_id)
-                        return JSONResponse({ "tools": [ {
+                        return JSONResponse(
+                            {
+                                "tools": [
+                                    {
                                         **t["function"],
                                         "enabled": t["enabled"],
-                        } for t in context.tools] })
+                                    }
+                                    for t in context.tools
+                                ]
+                            }
+                        )
+
                     case "rags":
                         # { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
                         rags: list[dict[str, Any]] = data[k]
                         if not rags:
-                            return JSONResponse({ "status": "error", "message": "RAGs can not be empty." })
+                            return JSONResponse(
+                                {
+                                    "status": "error",
+                                    "message": "RAGs can not be empty.",
+                                }
+                            )
                         for rag in rags:
                             for context_rag in context.rags:
                                 if context_rag["name"] == rag["name"]:
@@ -477,7 +596,12 @@ class WebServer:
                     case "system_prompt":
                         system_prompt = data[k].strip()
                         if not system_prompt:
-                            return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
+                            return JSONResponse(
+                                {
+                                    "status": "error",
+                                    "message": "System prompt can not be empty.",
+                                }
+                            )
                         agent.system_prompt = system_prompt
                         self.save_context(context_id)
                         return JSONResponse({"system_prompt": system_prompt})
@@ -487,7 +611,9 @@ class WebServer:
                         self.save_context(context_id)
                         return JSONResponse({"message_history_length": value})
                     case _:
-                        return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404)
+                        return JSONResponse(
+                            {"error": f"Unrecognized tunable {k}"}, status_code=404
+                        )
             except Exception as e:
                 logger.error(f"Error in put_tunables: {e}")
                 return JSONResponse({"error": str(e)}, status_code=500)
@@ -501,85 +627,94 @@ class WebServer:
             context = self.upsert_context(context_id)
             agent = context.get_agent("chat")
             if not agent:
-                return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
-            return JSONResponse({
+                return JSONResponse(
+                    {"error": f"chat is not recognized", "context": context.id},
+                    status_code=404,
+                )
+            return JSONResponse(
+                {
                     "system_prompt": agent.system_prompt,
                     "message_history_length": context.message_history_length,
                     "rags": context.rags,
-                    "tools": [ {
+                    "tools": [
+                        {
                             **t["function"],
                             "enabled": t["enabled"],
-                    } for t in context.tools ]
-            })
+                        }
+                        for t in context.tools
+                    ],
+                }
+            )
 
         @self.app.get("/api/system-info/{context_id}")
         async def get_system_info(context_id: str, request: Request):
             logger.info(f"{request.method} {request.url.path}")
             return JSONResponse(system_info(self.model))
 
-        @self.app.post("/api/chat/{context_id}/{agent_type}")
-        async def post_chat_endpoint(context_id: str, agent_type: str, request: Request):
+        @self.app.post("/api/{agent_type}/{context_id}")
+        async def post_agent_endpoint(
+            agent_type: str, context_id: str, request: Request
+        ):
             logger.info(f"{request.method} {request.url.path}")
-            if not is_valid_uuid(context_id):
-                logger.warning(f"Invalid context_id: {context_id}")
-                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
 
             try:
                 context = self.upsert_context(context_id)
-                agent = context.get_agent(agent_type)
             except Exception as e:
-                logger.info(f"Attempt to create agent type: {agent_type} failed", e)
-                return JSONResponse({"error": f"{agent_type} is not recognized or context {context_id} is invalid "}, status_code=404)
+                error = {
+                    "error": f"Unable to create or access context {context_id}: {e}"
+                }
+                logger.info(error)
+                return JSONResponse(error, status_code=404)
 
             try:
-                query = await request.json()
-                prompt = query["prompt"]
-                if not isinstance(prompt, str) or len(prompt) == 0:
-                    logger.info(f"Prompt is empty")
-                    return JSONResponse({"error": "Prompt cannot be empty"}, status_code=400)
+                data = await request.json()
+                query: QueryOptions = QueryOptions(**data)
             except Exception as e:
-                logger.info(f"Attempt to parse request: {str(e)}.")
-                return JSONResponse({"error": f"Attempt to parse request: {str(e)}."}, status_code=400)
+                error = {"error": f"Attempt to parse request: {e}"}
+                logger.info(error)
+                return JSONResponse(error, status_code=400)
 
             try:
-                options = Tunables(**query["options"]) if "options" in query else None
+                agent = context.get_or_create_agent(agent_type, **query.agent_options)
             except Exception as e:
-                logger.info(f"Attempt to set tunables failed: {query['options']}.", e)
-                return JSONResponse({"error": f"Invalid options: {query['options']}"}, status_code=400)
+                error = {
+                    "error": f"Attempt to create agent type: {agent_type} failed: {e}"
+                }
+                return JSONResponse(error, status_code=404)
 
-            if not agent:
-                match agent_type:
-                    case "job_description":
-                        logger.info(f"Agent {agent_type} not found. Returning empty history.")
-                        agent = context.get_or_create_agent("job_description", job_description=prompt)
-                    case _:
-                        logger.info(f"Invalid agent creation sequence for {agent_type}. Returning error.")
-                        return JSONResponse({"error": f"{agent_type} is not recognized", "context": context.id}, status_code=404)
             try:
 
                 async def flush_generator():
                     logger.info(f"{agent.agent_type} - {inspect.stack()[0].function}")
                     try:
                         start_time = time.perf_counter()
-                        async for message in self.generate_response(context=context, agent=agent, prompt=prompt, options=options):
-                            if message.status != "done":
+                        async for message in self.generate_response(
+                            context=context,
+                            agent=agent,
+                            prompt=query.prompt,
+                            options=query.options,
+                        ):
+                            if message.status != "done" and message.status != "partial":
                                 if message.status == "streaming":
                                     result = {
                                         "status": "streaming",
                                         "chunk": message.chunk,
-                                        "remaining_time": LLM_TIMEOUT - (time.perf_counter() - start_time)
+                                        "remaining_time": LLM_TIMEOUT
+                                        - (time.perf_counter() - start_time),
                                     }
                                 else:
                                     start_time = time.perf_counter()
                                     result = {
                                         "status": message.status,
                                         "response": message.response,
-                                        "remaining_time": LLM_TIMEOUT
+                                        "remaining_time": LLM_TIMEOUT,
                                     }
                             else:
-                                logger.info(f"Message complete. Providing full response.")
+                                logger.info(f"Providing {message.status} response.")
                                 try:
-                                    message.response = message.response
+                                    result = message.model_dump(
|
||||||
result = message.model_dump(by_alias=True, mode='json')
|
by_alias=True, mode="json"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result = {"status": "error", "response": str(e)}
|
result = {"status": "error", "response": str(e)}
|
||||||
yield json.dumps(result) + "\n"
|
yield json.dumps(result) + "\n"
|
||||||
@ -589,26 +724,27 @@ class WebServer:
|
|||||||
result = json.dumps(result) + "\n"
|
result = json.dumps(result) + "\n"
|
||||||
message.network_packets += 1
|
message.network_packets += 1
|
||||||
message.network_bytes += len(result)
|
message.network_bytes += len(result)
|
||||||
|
yield result
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
logger.info("Disconnect detected. Aborting generation.")
|
logger.info("Disconnect detected. Aborting generation.")
|
||||||
context.processing = False
|
context.processing = False
|
||||||
# Save context on completion or error
|
# Save context on completion or error
|
||||||
message.prompt = prompt
|
message.prompt = query.prompt
|
||||||
message.status = "error"
|
message.status = "error"
|
||||||
message.response = "Client disconnected during generation."
|
message.response = (
|
||||||
|
"Client disconnected during generation."
|
||||||
|
)
|
||||||
agent.conversation.add(message)
|
agent.conversation.add(message)
|
||||||
self.save_context(context_id)
|
self.save_context(context_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
yield result
|
|
||||||
|
|
||||||
current_time = time.perf_counter()
|
current_time = time.perf_counter()
|
||||||
if current_time - start_time > LLM_TIMEOUT:
|
if current_time - start_time > LLM_TIMEOUT:
|
||||||
message.status = "error"
|
message.status = "error"
|
||||||
message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query."
|
message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query."
|
||||||
message.partial_response = message.response
|
message.partial_response = message.response
|
||||||
logger.info(message.response + " Ending session")
|
logger.info(message.response + " Ending session")
|
||||||
result = message.model_dump(by_alias=True, mode='json')
|
result = message.model_dump(by_alias=True, mode="json")
|
||||||
result = json.dumps(result) + "\n"
|
result = json.dumps(result) + "\n"
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
@ -620,7 +756,7 @@ class WebServer:
|
|||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
context.processing = False
|
context.processing = False
|
||||||
logger.error(f"Error in process_generator: {e}")
|
logger.error(f"Error in generate_response: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
yield json.dumps({"status": "error", "response": str(e)}) + "\n"
|
yield json.dumps({"status": "error", "response": str(e)}) + "\n"
|
||||||
finally:
|
finally:
|
||||||
@ -634,8 +770,8 @@ class WebServer:
|
|||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
"X-Accel-Buffering": "no" # Prevents Nginx buffering if you're using it
|
"X-Accel-Buffering": "no", # Prevents Nginx buffering if you're using it
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
context.processing = False
|
context.processing = False
|
||||||
@ -660,13 +796,18 @@ class WebServer:
|
|||||||
context = self.upsert_context(context_id)
|
context = self.upsert_context(context_id)
|
||||||
agent = context.get_agent(agent_type)
|
agent = context.get_agent(agent_type)
|
||||||
if not agent:
|
if not agent:
|
||||||
logger.info(f"Agent {agent_type} not found. Returning empty history.")
|
logger.info(
|
||||||
|
f"Agent {agent_type} not found. Returning empty history."
|
||||||
|
)
|
||||||
return JSONResponse({"messages": []})
|
return JSONResponse({"messages": []})
|
||||||
logger.info(f"History for {agent_type} contains {len(agent.conversation)} entries.")
|
logger.info(
|
||||||
|
f"History for {agent_type} contains {len(agent.conversation)} entries."
|
||||||
|
)
|
||||||
return agent.conversation
|
return agent.conversation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"get_history error: {str(e)}")
|
logger.error(f"get_history error: {str(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return JSONResponse({"error": str(e)}, status_code=404)
|
return JSONResponse({"error": str(e)}, status_code=404)
|
||||||
|
|
||||||
@ -692,11 +833,12 @@ class WebServer:
|
|||||||
tool["enabled"] = enabled
|
tool["enabled"] = enabled
|
||||||
self.save_context(context_id)
|
self.save_context(context_id)
|
||||||
return JSONResponse(context.tools)
|
return JSONResponse(context.tools)
|
||||||
return JSONResponse({ "status": f"{modify} not found in tools." }, status_code=404)
|
return JSONResponse(
|
||||||
|
{"status": f"{modify} not found in tools."}, status_code=404
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
return JSONResponse({"status": "error"}, 405)
|
return JSONResponse({"status": "error"}, 405)
|
||||||
|
|
||||||
|
|
||||||
@self.app.get("/api/context-status/{context_id}/{agent_type}")
|
@self.app.get("/api/context-status/{context_id}/{agent_type}")
|
||||||
async def get_context_status(context_id, agent_type: str, request: Request):
|
async def get_context_status(context_id, agent_type: str, request: Request):
|
||||||
logger.info(f"{request.method} {request.url.path}")
|
logger.info(f"{request.method} {request.url.path}")
|
||||||
@ -706,8 +848,15 @@ class WebServer:
|
|||||||
context = self.upsert_context(context_id)
|
context = self.upsert_context(context_id)
|
||||||
agent = context.get_agent(agent_type)
|
agent = context.get_agent(agent_type)
|
||||||
if not agent:
|
if not agent:
|
||||||
return JSONResponse({"context_used": 0, "max_context": defines.max_context})
|
return JSONResponse(
|
||||||
return JSONResponse({"context_used": agent.context_tokens, "max_context": defines.max_context})
|
{"context_used": 0, "max_context": defines.max_context}
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"context_used": agent.context_tokens,
|
||||||
|
"max_context": defines.max_context,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@self.app.get("/api/health")
|
@self.app.get("/api/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
@ -770,8 +919,11 @@ class WebServer:
|
|||||||
# Read and deserialize the data
|
# Read and deserialize the data
|
||||||
with open(file_path, "r") as f:
|
with open(file_path, "r") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
logger.info(f"Loading context from {file_path}, content length: {len(content)}")
|
logger.info(
|
||||||
|
f"Loading context from {file_path}, content length: {len(content)}"
|
||||||
|
)
|
||||||
import json
|
import json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Try parsing as JSON first to ensure valid JSON
|
# Try parsing as JSON first to ensure valid JSON
|
||||||
json_data = json.loads(content)
|
json_data = json.loads(content)
|
||||||
@ -787,7 +939,9 @@ class WebServer:
|
|||||||
# Now set context on agents manually
|
# Now set context on agents manually
|
||||||
agent_types = [agent.agent_type for agent in context.agents]
|
agent_types = [agent.agent_type for agent in context.agents]
|
||||||
if len(agent_types) != len(set(agent_types)):
|
if len(agent_types) != len(set(agent_types)):
|
||||||
raise ValueError("Context cannot contain multiple agents of the same agent_type")
|
raise ValueError(
|
||||||
|
"Context cannot contain multiple agents of the same agent_type"
|
||||||
|
)
|
||||||
for agent in context.agents:
|
for agent in context.agents:
|
||||||
agent.set_context(context)
|
agent.set_context(context)
|
||||||
|
|
||||||
@ -799,9 +953,14 @@ class WebServer:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error validating context: {str(e)}")
|
logger.error(f"Error validating context: {str(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
# Fallback to creating a new context
|
# Fallback to creating a new context
|
||||||
self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher, prometheus_collector=self.prometheus_collector)
|
self.contexts[context_id] = Context(
|
||||||
|
id=context_id,
|
||||||
|
file_watcher=self.file_watcher,
|
||||||
|
prometheus_collector=self.prometheus_collector,
|
||||||
|
)
|
||||||
|
|
||||||
return self.contexts[context_id]
|
return self.contexts[context_id]
|
||||||
|
|
||||||
@ -818,7 +977,11 @@ class WebServer:
|
|||||||
if not context_id:
|
if not context_id:
|
||||||
context_id = str(uuid4())
|
context_id = str(uuid4())
|
||||||
logger.info(f"Creating new context with ID: {context_id}")
|
logger.info(f"Creating new context with ID: {context_id}")
|
||||||
context = Context(id=context_id, file_watcher=self.file_watcher, prometheus_collector=self.prometheus_collector)
|
context = Context(
|
||||||
|
id=context_id,
|
||||||
|
file_watcher=self.file_watcher,
|
||||||
|
prometheus_collector=self.prometheus_collector,
|
||||||
|
)
|
||||||
|
|
||||||
if os.path.exists(defines.resume_doc):
|
if os.path.exists(defines.resume_doc):
|
||||||
context.user_resume = open(defines.resume_doc, "r").read()
|
context.user_resume = open(defines.resume_doc, "r").read()
|
||||||
@ -855,7 +1018,9 @@ class WebServer:
|
|||||||
return self.load_or_create_context(context_id)
|
return self.load_or_create_context(context_id)
|
||||||
|
|
||||||
@REQUEST_TIME.time()
|
@REQUEST_TIME.time()
|
||||||
async def generate_response(self, context : Context, agent : Agent, prompt : str, options: Tunables | None) -> AsyncGenerator[Message, None]:
|
async def generate_response(
|
||||||
|
self, context: Context, agent: Agent, prompt: str, options: Tunables | None
|
||||||
|
) -> AsyncGenerator[Message, None]:
|
||||||
if not self.file_watcher:
|
if not self.file_watcher:
|
||||||
raise Exception("File watcher not initialized")
|
raise Exception("File watcher not initialized")
|
||||||
|
|
||||||
@ -880,7 +1045,9 @@ class WebServer:
|
|||||||
if message.status == "error":
|
if message.status == "error":
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
|
logger.info(
|
||||||
|
f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}"
|
||||||
|
)
|
||||||
message.status = "done"
|
message.status = "done"
|
||||||
yield message
|
yield message
|
||||||
return
|
return
|
||||||
@ -895,24 +1062,21 @@ class WebServer:
|
|||||||
port=port,
|
port=port,
|
||||||
log_config=None,
|
log_config=None,
|
||||||
ssl_keyfile=defines.key_path,
|
ssl_keyfile=defines.key_path,
|
||||||
ssl_certfile=defines.cert_path
|
ssl_certfile=defines.cert_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Starting web server at http://{host}:{port}")
|
logger.info(f"Starting web server at http://{host}:{port}")
|
||||||
uvicorn.run(
|
uvicorn.run(self.app, host=host, port=port, log_config=None)
|
||||||
self.app,
|
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
log_config=None
|
|
||||||
)
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
if self.observer:
|
if self.observer:
|
||||||
self.observer.stop()
|
self.observer.stop()
|
||||||
if self.observer:
|
if self.observer:
|
||||||
self.observer.join()
|
self.observer.join()
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
|
|
||||||
# Main function to run everything
|
# Main function to run everything
|
||||||
def main():
|
def main():
|
||||||
global model
|
global model
|
||||||
@ -923,17 +1087,9 @@ def main():
|
|||||||
# Setup logging based on the provided level
|
# Setup logging based on the provided level
|
||||||
logger.setLevel(args.level.upper())
|
logger.setLevel(args.level.upper())
|
||||||
|
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.*")
|
||||||
"ignore",
|
|
||||||
category=FutureWarning,
|
|
||||||
module="sklearn.*"
|
|
||||||
)
|
|
||||||
|
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")
|
||||||
"ignore",
|
|
||||||
category=UserWarning,
|
|
||||||
module="umap.*"
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = ollama.Client(host=args.ollama_server) # type: ignore
|
llm = ollama.Client(host=args.ollama_server) # type: ignore
|
||||||
model = args.ollama_model
|
model = args.ollama_model
|
||||||
@ -942,4 +1098,5 @@ def main():
|
|||||||
|
|
||||||
web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)
|
web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)
|
||||||
|
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
@@ -2,28 +2,24 @@ from .. utils import logger

 import ollama

-from .. utils import (
-rag as Rag,
-Context,
-defines
-)
+from ..utils import rag as Rag, Context, defines

 import json

 llm = ollama.Client(host=defines.ollama_api_url)

 observer, file_watcher = Rag.start_file_watcher(
-llm=llm,
-watch_directory=defines.doc_dir,
-recreate=False # Don't recreate if exists
+llm=llm, watch_directory=defines.doc_dir, recreate=False # Don't recreate if exists
 )

 context = Context(file_watcher=file_watcher)
-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
 context = Context.model_validate_json(json.dumps(data))
 context.file_watcher = file_watcher

-agent = context.get_or_create_agent("chat", system_prompt="You are a helpful assistant.")
+agent = context.get_or_create_agent(
+"chat", system_prompt="You are a helpful assistant."
+)
 # logger.info(f"data: {data}")
 # logger.info(f"agent: {agent}")
 agent_type = agent.get_agent_type()
@@ -32,7 +28,7 @@ logger.info(f"system_prompt: {agent.system_prompt}")

 agent.system_prompt = "Eat more tomatoes."

-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
 context = Context.model_validate_json(json.dumps(data))
 context.file_watcher = file_watcher

@@ -8,12 +8,14 @@ from anyio.to_thread import run_sync # type: ignore

 logger = logging.getLogger(__name__)


 class RedirectToContext(Exception):
 def __init__(self, url: str):
 self.url = url
 logger.info(f"Redirect to Context: {url}")
 super().__init__(f"Redirect to Context: {url}")


 class ContextRouteManager:
 def __init__(self, app: FastAPI):
 self.app = app
@@ -25,20 +27,28 @@ class ContextRouteManager:
 logger.info(f"Handling redirect to {exc.url}")
 return RedirectResponse(url=exc.url, status_code=307)

-def ensure_context(self, route_name: str = "context_id") -> Callable[[Request], Optional[UUID]]:
+def ensure_context(
+self, route_name: str = "context_id"
+) -> Callable[[Request], Optional[UUID]]:
 logger.info(f"Setting up context dependency for route parameter: {route_name}")

 async def _ensure_context_dependency(request: Request) -> Optional[UUID]:
-logger.info(f"Entering ensure_context_dependency, Request URL: {request.url}")
+logger.info(
+f"Entering ensure_context_dependency, Request URL: {request.url}"
+)
 logger.info(f"Path params: {request.path_params}")

 path_params = request.path_params
 route_value = path_params.get(route_name)
 logger.info(f"route_value: {route_value!r}, type: {type(route_value)}")

-if route_value is None or not isinstance(route_value, str) or not route_value.strip():
+if (
+route_value is None
+or not isinstance(route_value, str)
+or not route_value.strip()
+):
 logger.info(f"route_value is invalid, generating new UUID")
-path = request.url.path.rstrip('/')
+path = request.url.path.rstrip("/")
 new_context = await run_sync(uuid4)
 redirect_url = f"{path}/{new_context}"
 logger.info(f"Redirecting to {redirect_url}")
@@ -50,8 +60,10 @@ class ContextRouteManager:
 logger.info(f"Successfully parsed UUID: {route_context}")
 return route_context
 except ValueError as e:
-logger.error(f"Failed to parse UUID from route_value: {route_value!r}, error: {str(e)}")
-path = request.url.path.rstrip('/')
+logger.error(
+f"Failed to parse UUID from route_value: {route_value!r}, error: {str(e)}"
+)
+path = request.url.path.rstrip("/")
 new_context = await run_sync(uuid4)
 redirect_url = f"{path}/{new_context}"
 logger.info(f"Invalid UUID, redirecting to {redirect_url}")
@@ -66,44 +78,62 @@ class ContextRouteManager:
 def decorator(func):
 all_dependencies = list(dependencies)
 all_dependencies.append(Depends(ensure_context))
-logger.info(f"Route {path} registered with dependencies: {all_dependencies}")
+logger.info(
+f"Route {path} registered with dependencies: {all_dependencies}"
+)
 return self.app.get(path, dependencies=all_dependencies, **kwargs)(func)

 return decorator


 app = FastAPI(redirect_slashes=True)


 @app.exception_handler(Exception)
 async def global_exception_handler(request: Request, exc: Exception):
 logger.error(f"Unhandled exception: {str(exc)}")
 logger.error(f"Request URL: {request.url}, Path params: {request.path_params}")
 logger.error(f"Stack trace: {''.join(traceback.format_tb(exc.__traceback__))}")
 return JSONResponse(
-status_code=500,
-content={"error": "Internal server error", "detail": str(exc)}
+status_code=500, content={"error": "Internal server error", "detail": str(exc)}
 )


 @app.middleware("http")
 async def log_requests(request: Request, call_next):
-logger.info(f"Incoming request: {request.method} {request.url}, Path params: {request.path_params}")
+logger.info(
+f"Incoming request: {request.method} {request.url}, Path params: {request.path_params}"
+)
 response = await call_next(request)
 return response


 context_router = ContextRouteManager(app)


 @context_router.route_pattern("/api/history/{context_id}")
-async def get_history(request: Request, context_id: UUID = Depends(context_router.ensure_context()), agent_type: str = Query(..., description="Type of agent to retrieve history for")):
+async def get_history(
+request: Request,
+context_id: UUID = Depends(context_router.ensure_context()),
+agent_type: str = Query(..., description="Type of agent to retrieve history for"),
+):
 logger.info(f"{request.method} {request.url.path} with context_id: {context_id}")
 return {"context_id": str(context_id), "agent_type": agent_type}


 @app.get("/api/history")
-async def redirect_history(request: Request, agent_type: str = Query(..., description="Type of agent to retrieve history for")):
-path = request.url.path.rstrip('/')
+async def redirect_history(
+request: Request,
+agent_type: str = Query(..., description="Type of agent to retrieve history for"),
+):
+path = request.url.path.rstrip("/")
 new_context = uuid4()
 redirect_url = f"{path}/{new_context}?agent_type={agent_type}"
 logger.info(f"Redirecting from /api/history to {redirect_url}")
 return RedirectResponse(url=redirect_url, status_code=307)


 if __name__ == "__main__":
 import uvicorn # type: ignore

 uvicorn.run(app, host="0.0.0.0", port=8900)
@@ -1,28 +1,24 @@
 # From /opt/backstory run:
 # python -m src.tests.test-context
 import os

 os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
 import warnings

 warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")

 import ollama

-from .. utils import (
-rag as Rag,
-Context,
-defines
-)
+from ..utils import rag as Rag, Context, defines

 import json

 llm = ollama.Client(host=defines.ollama_api_url) # type: ignore

 observer, file_watcher = Rag.start_file_watcher(
-llm=llm,
-watch_directory=defines.doc_dir,
-recreate=False # Don't recreate if exists
+llm=llm, watch_directory=defines.doc_dir, recreate=False # Don't recreate if exists
 )

 context = Context(file_watcher=file_watcher)
-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
 context = Context.from_json(json.dumps(data), file_watcher=file_watcher)
@@ -2,12 +2,10 @@
 # python -m src.tests.test-message
 from ..utils import logger

-from .. utils import (
-Message
-)
+from ..utils import Message

 import json

 prompt = "This is a test"
 message = Message(prompt=prompt)
-print(message.model_dump(mode='json'))
+print(message.model_dump(mode="json"))
@@ -1,8 +1,6 @@
 # From /opt/backstory run:
 # python -m src.tests.test-metrics
-from .. utils import (
-Metrics
-)
+from ..utils import Metrics

 import json

@@ -13,7 +11,7 @@ metrics = Metrics()
 metrics.prepare_count.labels(agent="chat").inc()
 metrics.prepare_duration.labels(agent="prepare").observe(0.45)

-json = metrics.model_dump(mode='json')
+json = metrics.model_dump(mode="json")
 metrics = Metrics.model_validate(json)

 print(metrics)
@@ -12,21 +12,22 @@ from . agents import class_registry, AnyAgent, Agent, __all__ as agents_all
 from .metrics import Metrics

 __all__ = [
-'Agent',
-'Tunables',
-'Context',
-'Conversation',
-'Message',
-'Metrics',
-'ChromaDBFileWatcher',
-'start_file_watcher',
-'logger',
+"Agent",
+"Tunables",
+"Context",
+"Conversation",
+"Message",
+"Metrics",
+"ChromaDBFileWatcher",
+"start_file_watcher",
+"logger",
 ]

 __all__.extend(agents_all) # type: ignore

 logger = setup_logging(level=defines.logging_level)


 def rebuild_models():
 for class_name, (module_name, _) in class_registry.items():
 try:
@@ -36,9 +37,15 @@ def rebuild_models():
 logger.debug(f"Checking: {class_name} in module {module_name}")
 logger.debug(f" cls: {True if cls else False}")
 logger.debug(f" isinstance(cls, type): {isinstance(cls, type)}")
-logger.debug(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}")
-logger.debug(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}")
-logger.debug(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}")
+logger.debug(
+f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}"
+)
+logger.debug(
+f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}"
+)
+logger.debug(
+f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}"
+)

 if (
 cls
@@ -50,11 +57,13 @@ def rebuild_models():
 logger.debug(f"Rebuilding {class_name} from {module_name}")
 from .agents import Agent
 from .context import Context

 cls.model_rebuild()
 except ImportError as e:
 logger.error(f"Failed to import module {module_name}: {e}")
 except Exception as e:
 logger.error(f"Error processing {class_name} in {module_name}: {e}")


 # Call this after all modules are imported
 rebuild_models()
@@ -16,7 +16,9 @@ __all__ = [ "AnyAgent", "Agent", "agent_registry", "class_registry" ]
 # Type alias for Agent or any subclass
 AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses

-class_registry: Dict[str, Tuple[str, str]] = {} # Maps class_name to (module_name, class_name)
+class_registry: Dict[str, Tuple[str, str]] = (
+{}
+) # Maps class_name to (module_name, class_name)

 package_dir = pathlib.Path(__file__).parent
 package_name = __name__
@ -1,8 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, PrivateAttr, Field # type: ignore
|
from pydantic import BaseModel, PrivateAttr, Field # type: ignore
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any,
|
Literal,
|
||||||
TypeAlias, Dict, Tuple
|
get_args,
|
||||||
|
List,
|
||||||
|
AsyncGenerator,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Optional,
|
||||||
|
ClassVar,
|
||||||
|
Any,
|
||||||
|
TypeAlias,
|
||||||
|
Dict,
|
||||||
|
Tuple,
|
||||||
)
|
)
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
@ -24,24 +33,26 @@ from . types import agent_registry
|
|||||||
from .. import defines
|
from .. import defines
|
||||||
from ..message import Message, Tunables
|
from ..message import Message, Tunables
|
||||||
from ..metrics import Metrics
|
from ..metrics import Metrics
|
||||||
from .. tools import ( TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools ) # type: ignore -- dynamically added to __all__
|
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__
|
||||||
from ..conversation import Conversation
|
from ..conversation import Conversation
|
||||||
|
|
||||||
|
|
||||||
class LLMMessage(BaseModel):
|
class LLMMessage(BaseModel):
|
||||||
role: str = Field(default="")
|
role: str = Field(default="")
|
||||||
content: str = Field(default="")
|
content: str = Field(default="")
|
||||||
tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
|
tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
|
||||||
|
|
||||||
|
|
||||||
class Agent(BaseModel, ABC):
|
class Agent(BaseModel, ABC):
|
||||||
"""
|
"""
|
||||||
Base class for all agent types.
|
Base class for all agent types.
|
||||||
This class defines the common attributes and methods for all agent types.
|
This class defines the common attributes and methods for all agent types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Agent management with pydantic
|
# Agent management with pydantic
|
||||||
agent_type: Literal["base"] = "base"
|
agent_type: Literal["base"] = "base"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
|
|
||||||
# Tunables (sets default for new Messages attached to this agent)
|
# Tunables (sets default for new Messages attached to this agent)
|
||||||
tunables: Tunables = Field(default_factory=Tunables)
|
tunables: Tunables = Field(default_factory=Tunables)
|
||||||
|
|
||||||
@ -49,11 +60,14 @@ class Agent(BaseModel, ABC):
|
|||||||
system_prompt: str # Mandatory
|
system_prompt: str # Mandatory
|
||||||
conversation: Conversation = Conversation()
|
conversation: Conversation = Conversation()
|
||||||
context_tokens: int = 0
|
context_tokens: int = 0
|
||||||
context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization
|
context: Optional[Context] = Field(
|
||||||
|
default=None, exclude=True
|
||||||
|
) # Avoid circular reference, require as param, and prevent serialization
|
||||||
metrics: Metrics = Field(default_factory=Metrics, exclude=True)
|
metrics: Metrics = Field(default_factory=Metrics, exclude=True)
|
||||||
|
|
||||||
# context_size is shared across all subclasses
|
# context_size is shared across all subclasses
|
||||||
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
|
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def context_size(self) -> int:
|
def context_size(self) -> int:
|
||||||
return Agent._context_size
|
return Agent._context_size
|
||||||
@ -62,7 +76,9 @@ class Agent(BaseModel, ABC):
|
|||||||
def context_size(self, value: int):
|
def context_size(self, value: int):
|
||||||
Agent._context_size = value
|
Agent._context_size = value
|
||||||
|
|
||||||
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
|
def set_optimal_context_size(
|
||||||
|
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
|
||||||
|
) -> int:
|
||||||
# # Get more accurate token count estimate using tiktoken or similar
|
# # Get more accurate token count estimate using tiktoken or similar
|
||||||
# response = llm.generate(
|
# response = llm.generate(
|
||||||
# model=model,
|
# model=model,
|
||||||
@ -83,7 +99,9 @@ class Agent(BaseModel, ABC):
|
|||||||
total_ctx = tokens + ctx_buffer
|
total_ctx = tokens + ctx_buffer
|
||||||
|
|
||||||
if total_ctx > self.context_size:
|
if total_ctx > self.context_size:
|
||||||
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
|
logger.info(
|
||||||
|
f"Increasing context size from {self.context_size} to {total_ctx}"
|
||||||
|
)
|
||||||
|
|
||||||
# Grow the context size if necessary
|
# Grow the context size if necessary
|
||||||
self.context_size = max(self.context_size, total_ctx)
|
self.context_size = max(self.context_size, total_ctx)
|
||||||
@ -95,7 +113,7 @@ class Agent(BaseModel, ABC):
|
|||||||
"""Auto-register subclasses"""
|
"""Auto-register subclasses"""
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
# Register this class if it has an agent_type
|
# Register this class if it has an agent_type
|
||||||
if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type:
|
if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
|
||||||
agent_registry.register(cls.agent_type, cls)
|
agent_registry.register(cls.agent_type, cls)
|
||||||
|
|
||||||
# def __init__(self, *, context=context, **data):
|
# def __init__(self, *, context=context, **data):
|
||||||
@ -166,7 +184,14 @@ class Agent(BaseModel, ABC):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[LLMMessage]) -> AsyncGenerator[Message, None]:
|
async def process_tool_calls(
|
||||||
|
self,
|
||||||
|
llm: Any,
|
||||||
|
model: str,
|
||||||
|
message: Message,
|
||||||
|
tool_message: Any,
|
||||||
|
messages: List[LLMMessage],
|
||||||
|
) -> AsyncGenerator[Message, None]:
|
||||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
self.metrics.tool_count.labels(agent=self.agent_type).inc()
|
self.metrics.tool_count.labels(agent=self.agent_type).inc()
|
||||||
@ -187,7 +212,9 @@ class Agent(BaseModel, ABC):
|
|||||||
tool = tool_call.function.name
|
tool = tool_call.function.name
|
||||||
|
|
||||||
# Yield status update before processing each tool
|
# Yield status update before processing each tool
|
||||||
message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
|
message.response = (
|
||||||
|
f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
|
||||||
|
)
|
||||||
yield message
|
yield message
|
||||||
logger.info(f"LLM - {message.response}")
|
logger.info(f"LLM - {message.response}")
|
||||||
|
|
||||||
@ -202,12 +229,18 @@ class Agent(BaseModel, ABC):
|
|||||||
|
|
||||||
case "AnalyzeSite":
|
case "AnalyzeSite":
|
||||||
url = arguments.get("url")
|
url = arguments.get("url")
|
||||||
question = arguments.get("question", "what is the summary of this content?")
|
question = arguments.get(
|
||||||
|
"question", "what is the summary of this content?"
|
||||||
|
)
|
||||||
|
|
||||||
# Additional status update for long-running operations
|
# Additional status update for long-running operations
|
||||||
message.response = f"Retrieving and summarizing content from {url}..."
|
message.response = (
|
||||||
|
f"Retrieving and summarizing content from {url}..."
|
||||||
|
)
|
||||||
yield message
|
yield message
|
||||||
ret = await AnalyzeSite(llm=llm, model=model, url=url, question=question)
|
ret = await AnalyzeSite(
|
||||||
|
llm=llm, model=model, url=url, question=question
|
||||||
|
)
|
||||||
|
|
||||||
case "DateTime":
|
case "DateTime":
|
||||||
tz = arguments.get("timezone")
|
tz = arguments.get("timezone")
|
||||||
@ -217,7 +250,9 @@ class Agent(BaseModel, ABC):
|
|||||||
city = arguments.get("city")
|
city = arguments.get("city")
|
||||||
state = arguments.get("state")
|
state = arguments.get("state")
|
||||||
|
|
||||||
message.response = f"Fetching weather data for {city}, {state}..."
|
message.response = (
|
||||||
|
f"Fetching weather data for {city}, {state}..."
|
||||||
|
)
|
||||||
yield message
|
yield message
|
||||||
ret = WeatherForecast(city, state)
|
ret = WeatherForecast(city, state)
|
||||||
|
|
||||||
@ -228,7 +263,7 @@ class Agent(BaseModel, ABC):
|
|||||||
tool_response = {
|
tool_response = {
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"content": json.dumps(ret),
|
"content": json.dumps(ret),
|
||||||
"name": tool_call.function.name
|
"name": tool_call.function.name,
|
||||||
}
|
}
|
||||||
|
|
||||||
tool_metadata["tool_calls"].append(tool_response)
|
tool_metadata["tool_calls"].append(tool_response)
|
||||||
@ -241,13 +276,15 @@ class Agent(BaseModel, ABC):
|
|||||||
message_dict = LLMMessage(
|
message_dict = LLMMessage(
|
||||||
role=tool_message.get("role", "assistant"),
|
role=tool_message.get("role", "assistant"),
|
||||||
content=tool_message.get("content", ""),
|
content=tool_message.get("content", ""),
|
||||||
tool_calls=[ {
|
tool_calls=[
|
||||||
|
{
|
||||||
"function": {
|
"function": {
|
||||||
"name": tc["function"]["name"],
|
"name": tc["function"]["name"],
|
||||||
"arguments": tc["function"]["arguments"]
|
"arguments": tc["function"]["arguments"],
|
||||||
}
|
}
|
||||||
} for tc in tool_message.tool_calls
|
}
|
||||||
]
|
for tc in tool_message.tool_calls
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
messages.append(message_dict)
|
messages.append(message_dict)
|
||||||
@ -282,20 +319,30 @@ class Agent(BaseModel, ABC):
|
|||||||
message.metadata["eval_count"] += response.eval_count
|
message.metadata["eval_count"] += response.eval_count
|
||||||
message.metadata["eval_duration"] += response.eval_duration
|
message.metadata["eval_duration"] += response.eval_duration
|
||||||
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
||||||
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
|
message.metadata[
|
||||||
self.context_tokens = response.prompt_eval_count + response.eval_count
|
"prompt_eval_duration"
|
||||||
|
] += response.prompt_eval_duration
|
||||||
|
self.context_tokens = (
|
||||||
|
response.prompt_eval_count + response.eval_count
|
||||||
|
)
|
||||||
message.status = "done"
|
message.status = "done"
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"][
|
||||||
|
"llm_with_tools"
|
||||||
|
] = f"{(end_time - start_time):.4f}"
|
||||||
return
|
return
|
||||||
|
|
||||||
def collect_metrics(self, response):
|
def collect_metrics(self, response):
|
||||||
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(response.prompt_eval_count)
|
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
|
||||||
|
response.prompt_eval_count
|
||||||
|
)
|
||||||
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
|
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
|
||||||
|
|
||||||
async def generate_llm_response(self, llm: Any, model: str, message: Message, temperature = 0.7) -> AsyncGenerator[Message, None]:
|
async def generate_llm_response(
|
||||||
|
self, llm: Any, model: str, message: Message, temperature=0.7
|
||||||
|
) -> AsyncGenerator[Message, None]:
|
||||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||||
@ -305,22 +352,29 @@ class Agent(BaseModel, ABC):
|
|||||||
|
|
||||||
# Create a pruned down message list based purely on the prompt and responses,
|
# Create a pruned down message list based purely on the prompt and responses,
|
||||||
# discarding the full preamble generated by prepare_message
|
# discarding the full preamble generated by prepare_message
|
||||||
messages: List[LLMMessage] = [ LLMMessage(role="system", content=message.system_prompt) ]
|
messages: List[LLMMessage] = [
|
||||||
messages.extend([
|
LLMMessage(role="system", content=message.system_prompt)
|
||||||
item for m in self.conversation
|
]
|
||||||
|
messages.extend(
|
||||||
|
[
|
||||||
|
item
|
||||||
|
for m in self.conversation
|
||||||
for item in [
|
for item in [
|
||||||
LLMMessage(role="user", content=m.prompt.strip()),
|
LLMMessage(role="user", content=m.prompt.strip()),
|
||||||
LLMMessage(role="assistant", content=m.response.strip())
|
LLMMessage(role="assistant", content=m.response.strip()),
|
||||||
]
|
]
|
||||||
])
|
]
|
||||||
|
)
|
||||||
# Only the actual user query is provided with the full context message
|
# Only the actual user query is provided with the full context message
|
||||||
messages.append(LLMMessage(role="user", content=message.context_prompt.strip()))
|
messages.append(
|
||||||
|
LLMMessage(role="user", content=message.context_prompt.strip())
|
||||||
|
)
|
||||||
|
|
||||||
# message.metadata["messages"] = messages
|
# message.metadata["messages"] = messages
|
||||||
message.metadata["options"] = {
|
message.metadata["options"] = {
|
||||||
"seed": 8911,
|
"seed": 8911,
|
||||||
"num_ctx": self.context_size,
|
"num_ctx": self.context_size,
|
||||||
"temperature": temperature # Higher temperature to encourage tool usage
|
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create a dict for storing various timing stats
|
# Create a dict for storing various timing stats
|
||||||
@ -329,7 +383,7 @@ class Agent(BaseModel, ABC):
|
|||||||
use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
||||||
message.metadata["tools"] = {
|
message.metadata["tools"] = {
|
||||||
"available": llm_tools(self.context.tools),
|
"available": llm_tools(self.context.tools),
|
||||||
"used": False
|
"used": False,
|
||||||
}
|
}
|
||||||
tool_metadata = message.metadata["tools"]
|
tool_metadata = message.metadata["tools"]
|
||||||
|
|
||||||
@ -357,12 +411,14 @@ class Agent(BaseModel, ABC):
|
|||||||
**message.metadata["options"],
|
**message.metadata["options"],
|
||||||
# "num_predict": 1024, # "Low" token limit to cut off after tool call
|
# "num_predict": 1024, # "Low" token limit to cut off after tool call
|
||||||
},
|
},
|
||||||
stream=False # No need to stream the probe
|
stream=False, # No need to stream the probe
|
||||||
)
|
)
|
||||||
self.collect_metrics(response)
|
self.collect_metrics(response)
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"][
|
||||||
|
"tool_check"
|
||||||
|
] = f"{(end_time - start_time):.4f}"
|
||||||
if not response.message.tool_calls:
|
if not response.message.tool_calls:
|
||||||
logger.info("LLM indicates tools will not be used")
|
logger.info("LLM indicates tools will not be used")
|
||||||
# The LLM will not use tools, so disable use_tools so we can stream the full response
|
# The LLM will not use tools, so disable use_tools so we can stream the full response
|
||||||
@ -374,7 +430,9 @@ class Agent(BaseModel, ABC):
|
|||||||
logger.info("LLM indicates tools will be used")
|
logger.info("LLM indicates tools will be used")
|
||||||
|
|
||||||
# Tools are enabled and available and the LLM indicated it will use them
|
# Tools are enabled and available and the LLM indicated it will use them
|
||||||
message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
|
message.response = (
|
||||||
|
f"Performing tool analysis step 2/2 (tool use suspected)..."
|
||||||
|
)
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
logger.info(f"Performing LLM call with tools")
|
logger.info(f"Performing LLM call with tools")
|
||||||
@ -386,12 +444,14 @@ class Agent(BaseModel, ABC):
|
|||||||
options={
|
options={
|
||||||
**message.metadata["options"],
|
**message.metadata["options"],
|
||||||
},
|
},
|
||||||
stream=False
|
stream=False,
|
||||||
)
|
)
|
||||||
self.collect_metrics(response)
|
self.collect_metrics(response)
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"][
|
||||||
|
"non_streaming"
|
||||||
|
] = f"{(end_time - start_time):.4f}"
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
message.status = "error"
|
message.status = "error"
|
||||||
@ -403,13 +463,21 @@ class Agent(BaseModel, ABC):
|
|||||||
tool_metadata["used"] = response.message.tool_calls
|
tool_metadata["used"] = response.message.tool_calls
|
||||||
# Process all yielded items from the handler
|
# Process all yielded items from the handler
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
|
async for message in self.process_tool_calls(
|
||||||
|
llm=llm,
|
||||||
|
model=model,
|
||||||
|
message=message,
|
||||||
|
tool_message=response.message,
|
||||||
|
messages=messages,
|
||||||
|
):
|
||||||
if message.status == "error":
|
if message.status == "error":
|
||||||
yield message
|
yield message
|
||||||
return
|
return
|
||||||
yield message
|
yield message
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"][
|
||||||
|
"process_tool_calls"
|
||||||
|
] = f"{(end_time - start_time):.4f}"
|
||||||
message.status = "done"
|
message.status = "done"
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -452,8 +520,12 @@ class Agent(BaseModel, ABC):
|
|||||||
message.metadata["eval_count"] += response.eval_count
|
message.metadata["eval_count"] += response.eval_count
|
||||||
message.metadata["eval_duration"] += response.eval_duration
|
message.metadata["eval_duration"] += response.eval_duration
|
||||||
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
||||||
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
|
message.metadata[
|
||||||
self.context_tokens = response.prompt_eval_count + response.eval_count
|
"prompt_eval_duration"
|
||||||
|
] += response.prompt_eval_duration
|
||||||
|
self.context_tokens = (
|
||||||
|
response.prompt_eval_count + response.eval_count
|
||||||
|
)
|
||||||
message.status = "done"
|
message.status = "done"
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
@ -461,7 +533,9 @@ class Agent(BaseModel, ABC):
|
|||||||
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
|
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
|
||||||
return
|
return
|
||||||
|
|
||||||
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
|
async def process_message(
|
||||||
|
self, llm: Any, model: str, message: Message
|
||||||
|
) -> AsyncGenerator[Message, None]:
|
||||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
self.metrics.process_count.labels(agent=self.agent_type).inc()
|
self.metrics.process_count.labels(agent=self.agent_type).inc()
|
||||||
@ -470,22 +544,30 @@ class Agent(BaseModel, ABC):
        if not self.context:
            raise ValueError("Context is not set for this agent.")

        logger.info(
            "TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time"
        )
        spinner: List[str] = ["\\", "|", "/", "-"]
        tick: int = 0
        while self.context.processing:
            message.status = "waiting"
            message.response = (
                f"Busy processing another request. Please wait. {spinner[tick]}"
            )
            tick = (tick + 1) % len(spinner)
            yield message
            await asyncio.sleep(1)  # Allow the event loop to process the write

        self.context.processing = True

        message.metadata["system_prompt"] = (
            f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
        )
        message.context_prompt = ""
        for p in message.preamble.keys():
            message.context_prompt += (
                f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
            )
        message.context_prompt += f"{message.prompt}"

        # Estimate token length of new messages
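For illustration, with a preamble of {"context": "...", "rules": "..."} the loop above assembles context_prompt along these lines (a sketch; note the closing tag is emitted as </|{p}> without a trailing |, unlike the </|system|> wrapper two statements earlier):

# <|context|>
# ...
# </|context>
#
# <|rules|>
# ...
# </|rules>
#
# <the user prompt>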
@ -493,13 +575,17 @@ class Agent(BaseModel, ABC):
        message.status = "thinking"
        yield message

        message.metadata["context_size"] = self.set_optimal_context_size(
            llm, model, prompt=message.context_prompt
        )

        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
        message.status = "thinking"
        yield message

        async for message in self.generate_llm_response(
            llm=llm, model=model, message=message
        ):
            # logger.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
            if message.status == "error":
                yield message
@ -514,6 +600,6 @@ class Agent(BaseModel, ABC):
        return


# Register the base agent
agent_registry.register(Agent._agent_type, Agent)

@ -6,6 +6,7 @@ import inspect
from .base import Agent, agent_registry
from ..message import Message
from ..setup_logging import setup_logging

logger = setup_logging()

system_message = f"""
@ -26,10 +27,12 @@ When answering queries, follow these steps:
Always use tools, <|resume|>, and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so.
"""


class Chat(Agent):
    """
    Chat Agent
    """

    agent_type: Literal["chat"] = "chat"  # type: ignore
    _agent_type: ClassVar[str] = agent_type  # Add this for registration
@ -46,15 +49,20 @@ class Chat(Agent):
        if message.preamble:
            excluded = {}
            preamble_types = [
                f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
            ]
            preamble_types_AND = " and ".join(preamble_types)
            preamble_types_OR = " or ".join(preamble_types)
            message.preamble[
                "rules"
            ] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
            message.preamble["question"] = "Respond to:"


# Register the base agent
agent_registry.register(Chat._agent_type, Chat)

@ -1,6 +1,13 @@
from __future__ import annotations
from pydantic import model_validator  # type: ignore
from typing import (
    Literal,
    ClassVar,
    Optional,
    Any,
    AsyncGenerator,
    List,
)  # NOTE: You must import Optional for late binding to work
from datetime import datetime
import inspect

@ -8,6 +15,7 @@ from . base import Agent, agent_registry
from ..conversation import Conversation
from ..message import Message
from ..setup_logging import setup_logging

logger = setup_logging()

system_fact_check = f"""
@ -21,6 +29,7 @@ When answering queries, follow these steps:
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|generated-resume|>, or <|resume|> tags.
""".strip()


class FactCheck(Agent):
    agent_type: Literal["fact_check"] = "fact_check"  # type: ignore
    _agent_type: ClassVar[str] = agent_type  # Add this for registration
@ -53,10 +62,14 @@ class FactCheck(Agent):
        message.preamble["discrepancies"] = self.facts

        excluded = {"job_description"}
        preamble_types = [
            f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
        ]
        preamble_types_AND = " and ".join(preamble_types)
        preamble_types_OR = " or ".join(preamble_types)
        message.preamble[
            "rules"
        ] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
@ -66,5 +79,6 @@ class FactCheck(Agent):
        yield message
        return


# Register the base agent
agent_registry.register(FactCheck._agent_type, FactCheck)

File diff suppressed because it is too large
@ -1,12 +1,20 @@
from __future__ import annotations
from pydantic import model_validator  # type: ignore
from typing import (
    Literal,
    ClassVar,
    Optional,
    Any,
    AsyncGenerator,
    List,
)  # NOTE: You must import Optional for late binding to work
from datetime import datetime
import inspect

from .base import Agent, agent_registry
from ..message import Message
from ..setup_logging import setup_logging

logger = setup_logging()

system_fact_check = f"""
@ -36,6 +44,7 @@ When answering queries, follow these steps:
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|job_description|>, <|resume|>, or <|context|> tags.
""".strip()


class Resume(Agent):
    agent_type: Literal["resume"] = "resume"  # type: ignore
    _agent_type: ClassVar[str] = agent_type  # Add this for registration
@ -69,10 +78,14 @@ class Resume(Agent):
        message.preamble["job_description"] = job_description_agent.job_description

        excluded = {}
        preamble_types = [
            f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
        ]
        preamble_types_AND = " and ".join(preamble_types)
        preamble_types_OR = " or ".join(preamble_types)
        message.preamble[
            "rules"
        ] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
@ -81,12 +94,16 @@ class Resume(Agent):
        if fact_check_agent:
            message.preamble["question"] = "Respond to:"
        else:
            message.preamble["question"] = (
                f"Fact check the <|generated-resume|> based on the <|resume|>{' and <|context|>' if 'context' in message.preamble else ''}."
            )

        yield message
        return

    async def process_message(
        self, llm: Any, model: str, message: Message
    ) -> AsyncGenerator[Message, None]:
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
        if not self.context:
            raise ValueError("Context is not set for this agent.")
@ -103,7 +120,9 @@ class Resume(Agent):
        # Instantiate the "fact_check" agent, and seed (or reset) its
        # conversation with this message.
        fact_check_agent = self.context.get_or_create_agent(
            agent_type="fact_check", facts=message.response
        )
        first_fact_check_message = message.copy()
        first_fact_check_message.prompt = "Fact check the generated resume."
        fact_check_agent.conversation.add(first_fact_check_message)
@ -113,5 +132,6 @@ class Resume(Agent):
        yield message
        return


# Register the base agent
agent_registry.register(Resume._agent_type, Resume)

@ -1,9 +1,11 @@
from __future__ import annotations
from typing import List, Dict, Optional, Type


# We'll use a registry pattern rather than hardcoded strings
class AgentRegistry:
    """Registry for agent types and classes"""

    _registry: Dict[str, Type] = {}

    @classmethod
@ -27,5 +29,6 @@ class AgentRegistry:
        """Get all registered agent classes"""
        return cls._registry.copy()


# Create a singleton instance
agent_registry = AgentRegistry()
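A minimal usage sketch of the registry above, mirroring the agent_registry.register(Chat._agent_type, Chat) calls elsewhere in this commit; EchoAgent is a hypothetical stand-in:

class EchoAgent:
    _agent_type = "echo"  # hypothetical agent type

agent_registry.register(EchoAgent._agent_type, EchoAgent)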
@ -1,122 +0,0 @@
import chromadb
from typing import List, Dict, Any, Union
from . import defines
from .chunk import chunk_document
import ollama

def init_chroma_client(persist_directory: str = defines.persist_directory):
    """Initialize and return a ChromaDB client."""
    # return chromadb.PersistentClient(path=persist_directory)
    return chromadb.Client()

def create_or_get_collection(db: chromadb.Client, collection_name: str):
    """Create or get a ChromaDB collection."""
    try:
        return db.get_collection(
            name=collection_name
        )
    except:
        return db.create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )

def process_documents_to_chroma(
    client: ollama.Client,
    documents: List[Dict[str, Any]],
    collection_name: str = "document_collection",
    text_key: str = "text",
    max_tokens: int = 512,
    overlap: int = 50,
    model: str = defines.encoding_model,
    persist_directory: str = defines.persist_directory
):
    """
    Process documents, chunk them, compute embeddings, and store in ChromaDB.

    Args:
        documents: List of document dictionaries
        collection_name: Name for the ChromaDB collection
        text_key: The key containing text content
        max_tokens: Maximum tokens per chunk
        overlap: Token overlap between chunks
        model: Ollama model for embeddings
        persist_directory: Directory to store ChromaDB data
    """
    # Initialize ChromaDB client and collection
    db = init_chroma_client(persist_directory)
    collection = create_or_get_collection(db, collection_name)

    # Process each document
    for doc in documents:
        # Chunk the document
        doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)

        # Prepare data for ChromaDB
        ids = []
        texts = []
        metadatas = []
        embeddings = []

        for chunk in doc_chunks:
            # Create a unique ID for the chunk
            chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"

            # Extract text
            text = chunk[text_key]

            # Create metadata (excluding text and embedding to avoid duplication)
            metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}

            response = client.embed(model=model, input=text)
            embedding = response["embeddings"][0]
            ids.append(chunk_id)
            texts.append(text)
            metadatas.append(metadata)
            embeddings.append(embedding)

        # Add chunks to ChromaDB collection
        collection.add(
            ids=ids,
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas
        )

    return collection

def query_chroma(
    client: ollama.Client,
    query_text: str,
    collection_name: str = "document_collection",
    n_results: int = 5,
    model: str = defines.encoding_model,
    persist_directory: str = defines.persist_directory
):
    """
    Query ChromaDB for similar documents.

    Args:
        query_text: The text to search for
        collection_name: Name of the ChromaDB collection
        n_results: Number of results to return
        model: Ollama model for embedding the query
        persist_directory: Directory where ChromaDB data is stored

    Returns:
        Query results from ChromaDB
    """
    # Initialize ChromaDB client and collection
    db = init_chroma_client(persist_directory)
    collection = create_or_get_collection(db, collection_name)

    query_response = client.embed(model=model, input=query_text)
    query_embeddings = query_response["embeddings"]

    # Query the collection
    results = collection.query(
        query_embeddings=query_embeddings,
        n_results=n_results
    )

    return results
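A minimal usage sketch for the two entry points in this (now deleted) module, assuming a reachable Ollama server with the embedding model pulled and a document in the id/title/text shape that chunk_document expects; the host and collection name are illustrative:

client = ollama.Client(host="http://localhost:11434")
docs = [{"id": "doc1", "title": "Example", "text": "Some text to index."}]
process_documents_to_chroma(client, docs, collection_name="demo")
results = query_chroma(client, "text to index", collection_name="demo", n_results=3)
print(results["documents"])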
@ -1,88 +0,0 @@
import tiktoken  # type: ignore
from . import defines
from typing import List, Dict, Any, Union

def get_encoding(model=defines.model):
    """Get the tokenizer for counting tokens."""
    try:
        return tiktoken.get_encoding("cl100k_base")  # Default encoding used by many embedding models
    except:
        return tiktoken.encoding_for_model(model)

def count_tokens(text: str) -> int:
    """Count the number of tokens in a text string."""
    encoding = get_encoding()
    return len(encoding.encode(text))

def chunk_text(text: str, max_tokens: int = 512, overlap: int = 50) -> List[str]:
    """
    Split a text into chunks based on token count with overlap between chunks.

    Args:
        text: The text to split into chunks
        max_tokens: Maximum number of tokens per chunk
        overlap: Number of tokens to overlap between chunks

    Returns:
        List of text chunks
    """
    if not text or max_tokens <= 0:
        return []

    encoding = get_encoding()
    tokens = encoding.encode(text)
    chunks = []

    i = 0
    while i < len(tokens):
        # Get the current chunk of tokens
        chunk_end = min(i + max_tokens, len(tokens))
        chunk_tokens = tokens[i:chunk_end]
        chunks.append(encoding.decode(chunk_tokens))

        # Move to the next position with overlap
        if chunk_end == len(tokens):
            break
        i += max_tokens - overlap

    return chunks

def chunk_document(document: Dict[str, Any],
                   text_key: str = "text",
                   max_tokens: int = 512,
                   overlap: int = 50) -> List[Dict[str, Any]]:
    """
    Chunk a document dictionary into multiple chunks.

    Args:
        document: Document dictionary with metadata and text
        text_key: The key in the document that contains the text to chunk
        max_tokens: Maximum number of tokens per chunk
        overlap: Number of tokens to overlap between chunks

    Returns:
        List of document dictionaries, each with chunked text and preserved metadata
    """
    if text_key not in document:
        raise Exception(f"{text_key} not in document")

    # Extract text and create chunks
    if "title" in document:
        text = f"{document['title']}: {document[text_key]}"
    else:
        text = document[text_key]
    chunks = chunk_text(text, max_tokens, overlap)

    # Create document chunks with preserved metadata
    chunked_docs = []
    for i, chunk in enumerate(chunks):
        # Create a new doc with all original fields
        doc_chunk = document.copy()
        # Replace text with the chunk
        doc_chunk[text_key] = chunk
        # Add chunk metadata
        doc_chunk["chunk_id"] = i
        doc_chunk["chunk_total"] = len(chunks)
        chunked_docs.append(doc_chunk)

    return chunked_docs
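A worked example of the stride arithmetic in chunk_text above: each iteration advances i by max_tokens - overlap, so with max_tokens=512 and overlap=50 a 1000-token input yields windows covering tokens [0, 512), [462, 974), and [924, 1000):

chunks = chunk_text("word " * 1000, max_tokens=512, overlap=50)
print(len(chunks))  # 3, assuming "word " encodes to one token per repeat
print([count_tokens(c) for c in chunks])  # roughly [512, 512, 76]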
@ -17,16 +17,19 @@ from . agents import AnyAgent
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Context(BaseModel):
    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher
    # Required fields
    file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
    prometheus_collector: Optional[CollectorRegistry] = Field(
        default=None, exclude=True
    )

    # Optional fields
    id: str = Field(
        default_factory=lambda: str(uuid4()),
        pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    )
    user_resume: Optional[str] = None
    user_job_description: Optional[str] = None
@ -53,12 +56,16 @@ class Context(BaseModel):
        logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
        agent_types = [agent.agent_type for agent in self.agents]
        if len(agent_types) != len(set(agent_types)):
            raise ValueError(
                "Context cannot contain multiple agents of the same agent_type"
            )
        # for agent in self.agents:
        #     agent.set_context(self)
        return self

    def generate_rag_results(
        self, message: Message, top_k=10, threshold=0.7
    ) -> Generator[Message, None, None]:
        """
        Generate RAG results for the given query.

@ -86,32 +93,48 @@ class Context(BaseModel):
                continue
            message.response = f"Checking RAG context {rag['name']}..."
            yield message
            chroma_results = self.file_watcher.find_similar(
                query=message.prompt, top_k=top_k, threshold=threshold
            )
            if chroma_results:
                entries += len(chroma_results["documents"])

                chroma_embedding = np.array(
                    chroma_results["query_embedding"]
                ).flatten()  # Ensure correct shape
                print(f"Chroma embedding shape: {chroma_embedding.shape}")

                umap_2d = self.file_watcher.umap_model_2d.transform(
                    [chroma_embedding]
                )[0].tolist()
                print(
                    f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
                )  # Debug output

                umap_3d = self.file_watcher.umap_model_3d.transform(
                    [chroma_embedding]
                )[0].tolist()
                print(
                    f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
                )  # Debug output

                message.metadata["rag"].append(
                    {
                        "name": rag["name"],
                        **chroma_results,
                        "umap_embedding_2d": umap_2d,
                        "umap_embedding_3d": umap_3d,
                    }
                )
                message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
                yield message

        if entries == 0:
            del message.metadata["rag"]

        message.response = (
            f"RAG context gathered from results from {entries} documents."
        )
        message.status = "done"
        yield message
        return
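generate_rag_results streams progress by mutating and re-yielding the same Message; a caller drains it as an ordinary generator (sketch, assuming a constructed Context and Message):

for update in context.generate_rag_results(message, top_k=10, threshold=0.7):
    print(update.status, update.response)
print(message.metadata.get("rag", []))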
@ -156,7 +179,9 @@ class Context(BaseModel):
    def add_agent(self, agent: AnyAgent) -> None:
        """Add an Agent to the context, ensuring no duplicate agent_type."""
        if any(s.agent_type == agent.agent_type for s in self.agents):
            raise ValueError(
                f"An agent with agent_type '{agent.agent_type}' already exists"
            )
        self.agents.append(agent)

    def get_agent(self, agent_type: str) -> Agent | None:
@ -188,5 +213,7 @@ class Context(BaseModel):
            summary += f"\nChat Name: {agent.name}\n"
        return summary


from .agents import Agent

Context.model_rebuild()
@ -2,6 +2,7 @@ from pydantic import BaseModel, Field, PrivateAttr # type: ignore
from typing import List
from .message import Message


class Conversation(BaseModel):
    Conversation_messages: List[Message] = Field(default=[], alias="messages")

@ -17,12 +18,16 @@ class Conversation(BaseModel):
    @property
    def messages(self):
        """Return a copy of messages to prevent modification of the internal list."""
        raise AttributeError(
            "Cannot directly get messages. Use Conversation.add() or .reset()"
        )

    @messages.setter
    def messages(self, value):
        """Control how messages can be set, or prevent setting altogether."""
        raise AttributeError(
            "Cannot directly set messages. Use Conversation.add() or .reset()"
        )

    def add(self, message: Message | List[Message]) -> None:
        """Add Message(s) to the conversation."""
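Both accessors raise on purpose so that all reads and writes funnel through add() and reset(); a sketch of the intended call pattern, assuming Message accepts a prompt keyword:

conversation = Conversation()
conversation.add(Message(prompt="hello"))
# conversation.messages  ->  AttributeError: Cannot directly get messages. ...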
@ -1,468 +0,0 @@
import requests
from typing import List, Dict, Any, Union
import tiktoken
import feedparser
import logging as log
import datetime
from bs4 import BeautifulSoup
import chromadb
import ollama
import re
import numpy as np
from . import chunk

OLLAMA_API_URL = "http://ollama:11434"  # Default Ollama local endpoint
# MODEL_NAME = "deepseek-r1:1.5b"
MODEL_NAME = "deepseek-r1:7b"
EMBED_MODEL = "mxbai-embed-large"
PERSIST_DIRECTORY = "/root/.cache/chroma"

client = ollama.Client(host=OLLAMA_API_URL)

def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, 'xml')  # Use 'xml' parser for XML content
    else:
        soup = BeautifulSoup(content, 'html.parser')  # Default to 'html.parser' for HTML content

    # Extract and return just the text
    return soup.get_text()

class Feed():
    def __init__(self, name, url, poll_limit_min=30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            log.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now

            if len(feed.entries) == 0:
                return

            for i, entry in enumerate(feed.entries[:self.max_articles]):
                content = {}
                content['source'] = self.name
                content['id'] = f"{self.name}{i}"
                title = entry.get("title")
                if title:
                    content['title'] = title
                link = entry.get("link")
                if link:
                    content['link'] = link
                text = entry.get("summary")
                if text:
                    content['text'] = extract_text_from_html_or_xml(text, False)
                else:
                    continue
                published = entry.get("published")
                if published:
                    content['published'] = published

                self.articles.append(content)
        else:
            log.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)}s remain to refresh.")
        return self.articles

# News RSS Feeds
rss_feeds = [
    Feed(name="IGN.com", url="https://feeds.feedburner.com/ign/games-all"),
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    # Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]


def init_chroma_client(persist_directory: str = PERSIST_DIRECTORY):
    """Initialize and return a ChromaDB client."""
    # return chromadb.PersistentClient(path=persist_directory)
    return chromadb.Client()

def create_or_get_collection(client, collection_name: str):
    """Create or get a ChromaDB collection."""
    try:
        return client.get_collection(
            name=collection_name
        )
    except:
        return client.create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )

def process_documents_to_chroma(
    documents: List[Dict[str, Any]],
    collection_name: str = "document_collection",
    text_key: str = "text",
    max_tokens: int = 512,
    overlap: int = 50,
    model: str = EMBED_MODEL,
    persist_directory: str = PERSIST_DIRECTORY
):
    """
    Process documents, chunk them, compute embeddings, and store in ChromaDB.

    Args:
        documents: List of document dictionaries
        collection_name: Name for the ChromaDB collection
        text_key: The key containing text content
        max_tokens: Maximum tokens per chunk
        overlap: Token overlap between chunks
        model: Ollama model for embeddings
        persist_directory: Directory to store ChromaDB data
    """
    # Initialize ChromaDB client and collection
    db = init_chroma_client(persist_directory)
    collection = create_or_get_collection(db, collection_name)

    # Process each document
    for doc in documents:
        # Chunk the document
        doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)

        # Prepare data for ChromaDB
        ids = []
        texts = []
        metadatas = []
        embeddings = []

        for chunk in doc_chunks:
            # Create a unique ID for the chunk
            chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"

            # Extract text
            text = chunk[text_key]

            # Create metadata (excluding text and embedding to avoid duplication)
            metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}

            response = client.embed(model=model, input=text)
            embedding = response["embeddings"][0]
            ids.append(chunk_id)
            texts.append(text)
            metadatas.append(metadata)
            embeddings.append(embedding)

        # Add chunks to ChromaDB collection
        collection.add(
            ids=ids,
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas
        )

    return collection

def query_chroma(
    query_text: str,
    collection_name: str = "document_collection",
    n_results: int = 5,
    model: str = EMBED_MODEL,
    persist_directory: str = PERSIST_DIRECTORY
):
    """
    Query ChromaDB for similar documents.

    Args:
        query_text: The text to search for
        collection_name: Name of the ChromaDB collection
        n_results: Number of results to return
        model: Ollama model for embedding the query
        persist_directory: Directory where ChromaDB data is stored

    Returns:
        Query results from ChromaDB
    """
    # Initialize ChromaDB client and collection
    db = init_chroma_client(persist_directory)
    collection = create_or_get_collection(db, collection_name)

    query_response = client.embed(model=model, input=query_text)
    query_embeddings = query_response["embeddings"]

    # Query the collection
    results = collection.query(
        query_embeddings=query_embeddings,
        n_results=n_results
    )

    return results

def print_top_match(query_results, index=0, documents=None):
    """
    Print detailed information about the top matching document,
    including the full original document content.

    Args:
        query_results: Results from ChromaDB query
        documents: Original documents dictionary to look up full content (optional)
    """
    if not query_results or not query_results["ids"] or len(query_results["ids"][0]) == 0:
        print("No matching documents found.")
        return

    # Get the top result
    top_id = query_results["ids"][0][index]
    top_document_chunk = query_results["documents"][0][index]
    top_metadata = query_results["metadatas"][0][index]
    top_distance = query_results["distances"][0][index]

    print("=" * 50)
    print("MATCHING DOCUMENT")
    print("=" * 50)
    print(f"Chunk ID: {top_id}")
    print(f"Similarity Score: {top_distance:.4f}")  # Convert distance to similarity

    print("\nCHUNK METADATA:")
    for key, value in top_metadata.items():
        print(f"  {key}: {value}")

    print("\nMATCHING CHUNK CONTENT:")
    print(top_document_chunk[:500].strip() + ("..." if len(top_document_chunk) > 500 else ""))

    # Extract the original document ID from the chunk ID
    # Chunk IDs are in format "doc_id_chunk_num"
    original_doc_id = top_id.split('_')[0]

def get_top_match(query_results, index=0, documents=None):
    top_id = query_results["ids"][index][0]
    # Extract the original document ID from the chunk ID
    # Chunk IDs are in format "doc_id_chunk_num"
    original_doc_id = top_id.split('_')[0]

    # Return the full document for further processing if needed
    if documents is not None:
        return next((doc for doc in documents if doc["id"] == original_doc_id), None)

    return None

def show_documents(documents=None):
    if not documents:
        return

    # Print the top matching document
    for i, doc in enumerate(documents):
        print(f"Document {i+1}:")
        print(f"  Title: {doc['title']}")
        print(f"  Text: {doc['text'][:100]}...")
        print()

def show_headlines(documents=None):
    if not documents:
        return

    # Print the top matching document
    for doc in documents:
        print(f"{doc['source']}: {doc['title']}")

def show_help():
    print("""help>
    docs           Show RAG docs
    full           Show last full top match
    headlines      Show the RAG headlines
    prompt         Show the last prompt
    response       Show the last response
    scores         Show last RAG scores
    why|think      Show last response's <think>
    context|match  Show RAG match info to last prompt
""")


# Example usage
if __name__ == "__main__":
    documents = []
    for feed in rss_feeds:
        documents.extend(feed.articles)

    show_documents(documents=documents)

    # Process documents and store in ChromaDB
    collection = process_documents_to_chroma(
        documents=documents,
        collection_name="research_papers",
        max_tokens=256,
        overlap=25,
        model=EMBED_MODEL,
        persist_directory="/root/.cache/chroma"
    )

    last_results = None
    last_prompt = None
    last_system = None
    last_response = None
    last_why = None
    last_messages = []
    while True:
        try:
            search_query = input("> ").strip()
        except KeyboardInterrupt as e:
            print("\nExiting.")
            break

        if search_query == "exit" or search_query == "quit":
            print("\nExiting.")
            break

        if search_query == "docs":
            show_documents(documents)
            continue

        if search_query == "prompt":
            if last_prompt:
                print(f"""last prompt>
{"="*10}system{"="*10}
{last_system}
{"="*10}prompt{"="*10}
{last_prompt}""")
            else:
                print(f"No prompts yet")
            continue

        if search_query == "response":
            if last_response:
                print(f"""last response>
{"="*10}response{"="*10}
{last_response}""")
            else:
                print(f"No responses yet")
            continue

        if search_query == "" or search_query == "help":
            show_help()
            continue

        if search_query == "headlines":
            show_headlines(documents)
            continue

        if search_query == "match" or search_query == "context":
            if last_results:
                print_top_match(last_results, documents=documents)
            else:
                print("No match to give info on")
            continue

        if search_query == "why" or search_query == "think":
            if last_why:
                print(f"""
why>
{last_why}
""")
            else:
                print("No processed prompts")
            continue

        if search_query == "scores":
            if last_results:
                for i, _ in enumerate(last_results):
                    print_top_match(last_results, documents=documents, index=i)
            else:
                print("No match to give info on")
            continue

        if search_query == "full":
            if last_results:
                full = get_top_match(last_results, documents=documents)
                if full:
                    print(f"""Context:
Source: {full["source"]}
Title: {full["title"]}
Link: {full["link"]}
Distance: {last_results.get("distances", [[0]])[0][0]}
Full text:
{full["text"]}""")
            else:
                print("No match to give info on")
            continue

        # Query ChromaDB
        results = query_chroma(
            query_text=search_query,
            collection_name="research_papers",
            n_results=10
        )
        last_results = results

        full = get_top_match(results, documents=documents)

        headlines = ""
        for doc in documents:
            headlines += f"{doc['source']}: {doc['title']}\n"

        system = f"""
You are the assistant. Your name is airc. This application is called airc (pronounced Eric).

Information about the author of this program and the AI model it uses:

* James wrote the python application called airc that is driving this RAG model on top of {MODEL_NAME} using {EMBED_MODEL} and chromadb for vector embedding. Link https://github.com/jketreno/airc.
* James Ketrenos is a software engineer with a history in all levels of the computer stack, from the kernel to full-stack web applications. He dabbles in AI/ML and is familiar with pytorch and ollama.
* James Ketrenos deployed this application locally on an Intel Arc B580 (battlemage) computer using Intel's ipex-llm.
* For Intel GPU metrics, James Ketrenos wrote the "ze-monitor" utility in C++. ze-monitor provides Intel GPU telemetry data for Intel client GPU devices, similar to xpu-smi. Link https://github.com/jketreno/ze-monitor. airc uses ze-monitor.
* James lives in Portland, Oregon and has three kids. Two are attending Oregon State University and one is attending Williamette University.
* airc provides an IRC chat bot as well as a React web frontend available at https://airc.ketrenos.com

You must follow these rules:

* Provide short (less than 100 character) responses.
* Provide a single response.
* Do not prefix it with a word like 'Answer'.
* For information about the AI running this system, include information about the author, including links.
* For information relevant to the current events in the <input></input> tags, use that information and state the source the information comes from.

"""
        context = "Information related to current events\n<input>=["
        for doc in documents:
            item = {'source': doc["source"], 'article': {'title': doc["title"], 'link': doc["link"], 'text': doc["text"]}}
            context += f"{item}"
        context += "\n</input>"

        prompt = f"{search_query}"
        last_prompt = prompt
        last_system = system  # cache it before news context is added
        system = f"{system}{context}"
        if len(last_messages) != 0:
            message_context = f"{last_messages}"
            prompt = f"{message_context}{prompt}"

        print(f"system len: {len(system)}")
        print(f"prompt len: {len(prompt)}")
        output = client.generate(
            model=MODEL_NAME,
            system=system,
            prompt=prompt,
            stream=False,
            options={'num_ctx': 100000}
        )
        # Prune off the <think>...</think>
        matches = re.match(r'^<think>(.*?)</think>(.*)$', output['response'], flags=re.DOTALL)
        if matches:
            last_why = matches[1].strip()
            content = matches[2].strip()
        else:
            print(f"[garbled] response>\n{output['response']}")
        print(f"Response>\n{content}")

        last_response = content
        last_messages.extend(({
            'role': 'user',
            'name': 'james',
            'message': search_query
        }, {
            'role': 'assistant',
            'message': content
        }))
        last_messages = last_messages[:10]
@ -3,11 +3,13 @@ from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from asyncio import Event


class Tunables(BaseModel):
    enable_rag: bool = Field(default=True)  # Enable RAG collection chromadb matching
    enable_tools: bool = Field(default=True)  # Enable LLM to use tools
    enable_context: bool = Field(default=True)  # Add <|context|> field to message


class Message(BaseModel):
    model_config = {"arbitrary_types_allowed": True}  # Allow Event
    # Required
@ -22,20 +24,31 @@ class Message(BaseModel):
    system_prompt: str = ""  # System prompt provided to the LLM
    context_prompt: str = ""  # Full content of the message (preamble + prompt)
    response: str = ""  # LLM response to the preamble + query
    metadata: Dict[str, Any] = Field(
        default_factory=lambda: {
            "rag": [],
            "eval_count": 0,
            "eval_duration": 0,
            "prompt_eval_count": 0,
            "prompt_eval_duration": 0,
            "context_size": 0,
        }
    )
    network_packets: int = 0  # Total number of streaming packets
    network_bytes: int = 0  # Total bytes sent while streaming packets
    actions: List[str] = (
        []
    )  # Other session modifying actions performed while processing the message
    timestamp: datetime = datetime.now(timezone.utc)
    chunk: str = Field(
        default=""
    )  # This needs to be serialized so it will be sent in responses
    partial_response: str = Field(
        default=""
    )  # This needs to be serialized so it will be sent in responses on timeout
    title: str = Field(
        default=""
    )  # This needs to be serialized so it will be sent in responses on timeout

    def add_action(self, action: str | list[str]) -> None:
        """Add action(s) to the message."""
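The default_factory=lambda: {...} pattern that black re-wrapped above is what gives each Message its own metadata dict; a plain dict literal as the default would be a single object shared across instances. A minimal pydantic illustration:

from pydantic import BaseModel, Field

class M(BaseModel):
    meta: dict = Field(default_factory=dict)

a, b = M(), M()
a.meta["x"] = 1
assert b.meta == {}  # each instance gets a fresh dict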
@ -48,7 +61,8 @@ class Message(BaseModel):
        """Return a summary of the message."""
        response_summary = (
            f"Response: {self.response} (Actions: {', '.join(self.actions)})"
            if self.response
            else "No response yet"
        )
        return (
            f"Message at {self.timestamp}:\n"
@ -1,6 +1,7 @@
from prometheus_client import Counter, Gauge, Summary, Histogram, Info, Enum, CollectorRegistry  # type: ignore
from threading import Lock


def singleton(cls):
    instance = None
    lock = Lock()
@ -14,8 +15,9 @@ def singleton(cls):
    return get_instance


@singleton
class Metrics:
    def __init__(self, *args, prometheus_collector, **kwargs):
        super().__init__(*args, **kwargs)
        self.prometheus_collector = prometheus_collector
@ -24,70 +26,70 @@ class Metrics():
|
|||||||
name="prepare_total",
|
name="prepare_total",
|
||||||
documentation="Total messages prepared by agent type",
|
documentation="Total messages prepared by agent type",
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prepare_duration: Histogram = Histogram(
|
self.prepare_duration: Histogram = Histogram(
|
||||||
name="prepare_duration",
|
name="prepare_duration",
|
||||||
documentation="Preparation duration by agent type",
|
documentation="Preparation duration by agent type",
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.process_count: Counter = Counter(
|
self.process_count: Counter = Counter(
|
||||||
name="process",
|
name="process",
|
||||||
documentation="Total messages processed by agent type",
|
documentation="Total messages processed by agent type",
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.process_duration: Histogram = Histogram(
|
self.process_duration: Histogram = Histogram(
|
||||||
name="process_duration",
|
name="process_duration",
|
||||||
documentation="Processing duration by agent type",
|
documentation="Processing duration by agent type",
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.tool_count: Counter = Counter(
|
self.tool_count: Counter = Counter(
|
||||||
name="tool_total",
|
name="tool_total",
|
||||||
documentation="Total messages tooled by agent type",
|
documentation="Total messages tooled by agent type",
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.tool_duration: Histogram = Histogram(
|
self.tool_duration: Histogram = Histogram(
|
||||||
name="tool_duration",
|
name="tool_duration",
|
||||||
documentation="Tool duration by agent type",
|
documentation="Tool duration by agent type",
|
||||||
buckets=(0.1, 0.5, 1.0, 2.0, float('inf')),
|
buckets=(0.1, 0.5, 1.0, 2.0, float("inf")),
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.generate_count: Counter = Counter(
|
self.generate_count: Counter = Counter(
|
||||||
name="generate_total",
|
name="generate_total",
|
||||||
documentation="Total messages generated by agent type",
|
documentation="Total messages generated by agent type",
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.generate_duration: Histogram = Histogram(
|
self.generate_duration: Histogram = Histogram(
|
||||||
name="generate_duration",
|
name="generate_duration",
|
||||||
documentation="Generate duration by agent type",
|
documentation="Generate duration by agent type",
|
||||||
buckets=(0.1, 0.5, 1.0, 2.0, float('inf')),
|
buckets=(0.1, 0.5, 1.0, 2.0, float("inf")),
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.tokens_prompt: Counter = Counter(
|
self.tokens_prompt: Counter = Counter(
|
||||||
name="tokens_prompt",
|
name="tokens_prompt",
|
||||||
documentation="Total tokens passed as prompt to LLM",
|
documentation="Total tokens passed as prompt to LLM",
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.tokens_eval: Counter = Counter(
|
self.tokens_eval: Counter = Counter(
|
||||||
name="tokens_eval",
|
name="tokens_eval",
|
||||||
documentation="Total tokens returned by LLM",
|
documentation="Total tokens returned by LLM",
|
||||||
labelnames=("agent",),
|
labelnames=("agent",),
|
||||||
registry=self.prometheus_collector
|
registry=self.prometheus_collector,
|
||||||
)
|
)
|
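The @singleton decorator above guards instance creation with a Lock so concurrent first callers cannot construct two Metrics objects. The body of get_instance is outside these hunks, so this is a sketch of the pattern implied by the visible instance/lock setup, not the project's exact code:

    from threading import Lock

    def singleton(cls):
        instance = None
        lock = Lock()

        def get_instance(*args, **kwargs):
            nonlocal instance
            # Double-checked locking: only the first caller constructs the instance.
            if instance is None:
                with lock:
                    if instance is None:
                        instance = cls(*args, **kwargs)
            return instance

        return get_instance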
src/utils/rag.py (290 changed lines)
@@ -14,18 +14,22 @@ import hashlib
 import asyncio
 import json
 import numpy as np  # type: ignore
+import traceback
+import os

 import chromadb
 import ollama
-from langchain.text_splitter import CharacterTextSplitter  # type: ignore
-from sentence_transformers import SentenceTransformer  # type: ignore
-from langchain.schema import Document  # type: ignore
 from watchdog.observers import Observer  # type: ignore
 from watchdog.events import FileSystemEventHandler  # type: ignore
 import umap  # type: ignore
 from markitdown import MarkItDown  # type: ignore
 from chromadb.api.models.Collection import Collection  # type: ignore

+from .markdown_chunker import (
+    MarkdownChunker,
+    Chunk,
+)

 # Import your existing modules
 if __name__ == "__main__":
     # When running directly, use absolute imports
@@ -34,23 +38,31 @@ else:
     # When imported as a module, use relative imports
     from . import defines

-__all__ = [
-    'ChromaDBFileWatcher',
-    'start_file_watcher'
-]
+__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]

 DEFAULT_CHUNK_SIZE = 750
 DEFAULT_CHUNK_OVERLAP = 100


 class ChromaDBGetResponse(BaseModel):
     ids: List[str]
     embeddings: Optional[List[List[float]]] = None
     documents: Optional[List[str]] = None
     metadatas: Optional[List[Dict[str, Any]]] = None


 class ChromaDBFileWatcher(FileSystemEventHandler):
-    def __init__(self, llm, watch_directory, loop, persist_directory=None, collection_name="documents",
-                 chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP, recreate=False):
+    def __init__(
+        self,
+        llm,
+        watch_directory,
+        loop,
+        persist_directory=None,
+        collection_name="documents",
+        chunk_size=DEFAULT_CHUNK_SIZE,
+        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
+        recreate=False,
+    ):
         self.llm = llm
         self.watch_directory = watch_directory
         self.persist_directory = persist_directory or defines.persist_directory
@@ -68,23 +80,19 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         # self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

         # Path for storing file hash state
-        self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
+        self.hash_state_path = os.path.join(
+            self.persist_directory, f"{collection_name}_hash_state.json"
+        )

         # Flag to track if this is a new collection
         self.is_new_collection = False

         # Initialize ChromaDB collection
         self._collection: Collection = self._get_vector_collection(recreate=recreate)
+        self._markdown_chunker = MarkdownChunker()
         self._update_umaps()

         # Setup text splitter
-        self.text_splitter = CharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-            separator="\n\n",  # Respect paragraph/section breaks
-            length_function=len
-        )

         # Track file hashes and processing state
         self.file_hashes = self._load_hash_state()
         self.update_lock = asyncio.Lock()
@@ -115,7 +123,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         return self._umap_model_3d

     def _markitdown(self, document: str, markdown: Path):
-        logging.info(f'Converting {document} to {markdown}')
+        logging.info(f"Converting {document} to {markdown}")
         try:
             result = self.md.convert(document)
             markdown.write_text(result.text_content)
@@ -128,7 +136,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             # Create directory if it doesn't exist
             os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)

-            with open(self.hash_state_path, 'w') as f:
+            with open(self.hash_state_path, "w") as f:
                 json.dump(self.file_hashes, f)

             logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
@@ -139,7 +147,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         """Load the file hash state from disk."""
         if os.path.exists(self.hash_state_path):
             try:
-                with open(self.hash_state_path, 'r') as f:
+                with open(self.hash_state_path, "r") as f:
                     hash_state = json.load(f)
                     logging.info(f"Loaded hash state with {len(hash_state)} entries")
                     return hash_state
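The hash-state logic above reduces to: hash every watched file, persist the path-to-hash map as JSON, and reload it on startup so unchanged files are skipped. A minimal standalone sketch of that idea (function names are illustrative, not the class's API):

    import hashlib
    import json
    import os

    def file_md5(path: str) -> str:
        with open(path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest()

    def save_state(state: dict, state_path: str) -> None:
        # Guard against a bare filename with no directory component.
        os.makedirs(os.path.dirname(state_path) or ".", exist_ok=True)
        with open(state_path, "w") as f:
            json.dump(state, f)

    def load_state(state_path: str) -> dict:
        if os.path.exists(state_path):
            with open(state_path, "r") as f:
                return json.load(f)
        return {}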
@@ -156,7 +164,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             process_all: If True, process all files regardless of hash status
         """
         # Check for new or modified files
-        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
+        file_paths = glob.glob(
+            os.path.join(self.watch_directory, "**/*"), recursive=True
+        )
         files_checked = 0
         files_processed = 0
         files_to_process = []
@@ -166,21 +176,29 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         for file_path in file_paths:
             if os.path.isfile(file_path):
                 # Do not put the Resume in RAG as it is provideded with all queries.
-                if file_path == defines.resume_doc:
-                    logging.info(f"Not adding {file_path} to RAG -- primary resume")
-                    continue
+                # if file_path == defines.resume_doc:
+                #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
+                #     continue
                 files_checked += 1
                 current_hash = self._get_file_hash(file_path)
                 if not current_hash:
                     continue

                 # If file is new, changed, or we're processing all files
-                if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
+                if (
+                    process_all
+                    or file_path not in self.file_hashes
+                    or self.file_hashes[file_path] != current_hash
+                ):
                     self.file_hashes[file_path] = current_hash
                     files_to_process.append(file_path)
-                    logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
+                    logging.info(
+                        f"File {'found' if process_all else 'changed'}: {file_path}"
+                    )

-        logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
+        logging.info(
+            f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
+        )

         # Check for deleted files
         deleted_files = []
@@ -188,7 +206,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             if not os.path.exists(file_path):
                 deleted_files.append(file_path)
                 # Schedule removal
-                asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
+                asyncio.run_coroutine_threadsafe(
+                    self.remove_file_from_collection(file_path), self.loop
+                )
                 # Don't block on result, just let it run
                 logging.info(f"File deleted: {file_path}")

@@ -209,7 +229,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         # Save the updated state
         self._save_hash_state()

-        logging.info(f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}")
+        logging.info(
+            f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
+        )
         return files_processed

     async def process_file_update(self, file_path):
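Worth noting in the hunks above: scan_directory runs on the watchdog thread, so deletions are handed to the asyncio loop with run_coroutine_threadsafe rather than awaited. A minimal sketch of that hand-off (names are illustrative):

    import asyncio
    import threading

    async def remove_entry(path: str) -> None:
        print(f"removing {path}")  # stand-in for the collection cleanup

    def watcher_thread(loop: asyncio.AbstractEventLoop) -> None:
        # Called from a non-asyncio thread; schedule work on the loop
        # and deliberately do not block on the returned Future.
        asyncio.run_coroutine_threadsafe(remove_entry("/tmp/gone.md"), loop)

    async def main() -> None:
        loop = asyncio.get_running_loop()
        threading.Thread(target=watcher_thread, args=(loop,)).start()
        await asyncio.sleep(0.1)  # give the scheduled coroutine time to run

    asyncio.run(main())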
@@ -219,9 +241,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             logging.info(f"{file_path} already in queue. Not adding.")
             return

-        if file_path == defines.resume_doc:
-            logging.info(f"Not adding {file_path} to RAG -- primary resume")
-            return
+        # if file_path == defines.resume_doc:
+        #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
+        #     return

         try:
             logging.info(f"{file_path} not in queue. Adding.")
@@ -235,7 +257,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             if not current_hash:  # File might have been deleted or is inaccessible
                 return

-            if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
+            if (
+                file_path in self.file_hashes
+                and self.file_hashes[file_path] == current_hash
+            ):
                 # File hasn't actually changed in content
                 logging.info(f"Hash has not changed for {file_path}")
                 return
@@ -263,13 +288,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         async with self.update_lock:
             try:
                 # Find all documents with the specified path
-                results = self.collection.get(
-                    where={"path": file_path}
-                )
+                results = self.collection.get(where={"path": file_path})

-                if results and 'ids' in results and results['ids']:
-                    self.collection.delete(ids=results['ids'])
-                    logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
+                if results and "ids" in results and results["ids"]:
+                    self.collection.delete(ids=results["ids"])
+                    logging.info(
+                        f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
+                    )

                 # Remove from hash dictionary
                 if file_path in self.file_hashes:
@@ -282,29 +307,51 @@ class ChromaDBFileWatcher(FileSystemEventHandler):

     def _update_umaps(self):
         # Update the UMAP embeddings
-        self._umap_collection = self._collection.get(include=["embeddings", "documents", "metadatas"])
+        self._umap_collection = self._collection.get(
+            include=["embeddings", "documents", "metadatas"]
+        )
         if not self._umap_collection or not len(self._umap_collection["embeddings"]):
             logging.warning("No embeddings found in the collection.")
             return

         # During initialization
-        logging.info(f"Updating 2D UMAP for {len(self._umap_collection['embeddings'])} vectors")
+        logging.info(
+            f"Updating 2D UMAP for {len(self._umap_collection['embeddings'])} vectors"
+        )
         vectors = np.array(self._umap_collection["embeddings"])
-        self._umap_model_2d = umap.UMAP(n_components=2, random_state=8911, metric="cosine", n_neighbors=15, min_dist=0.1)
+        self._umap_model_2d = umap.UMAP(
+            n_components=2,
+            random_state=8911,
+            metric="cosine",
+            n_neighbors=15,
+            min_dist=0.1,
+        )
         self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
-        logging.info(f"2D UMAP model n_components: {self._umap_model_2d.n_components}")  # Should be 2
+        logging.info(
+            f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
+        )  # Should be 2

-        logging.info(f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors")
-        self._umap_model_3d = umap.UMAP(n_components=3, random_state=8911, metric="cosine", n_neighbors=15, min_dist=0.1)
+        logging.info(
+            f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
+        )
+        self._umap_model_3d = umap.UMAP(
+            n_components=3,
+            random_state=8911,
+            metric="cosine",
+            n_neighbors=15,
+            min_dist=0.1,
+        )
         self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
-        logging.info(f"3D UMAP model n_components: {self._umap_model_3d.n_components}")  # Should be 3
+        logging.info(
+            f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
+        )  # Should be 3

     def _get_vector_collection(self, recreate=False) -> Collection:
         """Get or create a ChromaDB collection."""
         # Initialize ChromaDB client
         chroma_client = chromadb.PersistentClient(  # type: ignore
             path=self.persist_directory,
-            settings=chromadb.Settings(anonymized_telemetry=False)  # type: ignore
+            settings=chromadb.Settings(anonymized_telemetry=False),  # type: ignore
         )

         # Check if the collection exists
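_update_umaps above simply fits one 2-D and one 3-D reducer over every stored embedding. A toy sketch of the same calls on random vectors (assumes the umap-learn package; the 384 dimension is only a placeholder):

    import numpy as np
    import umap  # umap-learn

    vectors = np.random.rand(100, 384)  # stand-in for stored embeddings
    reducer_2d = umap.UMAP(
        n_components=2, random_state=8911, metric="cosine", n_neighbors=15, min_dist=0.1
    )
    coords_2d = reducer_2d.fit_transform(vectors)  # shape (100, 2)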
@@ -326,35 +373,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             logging.info(f"Recreating collection: {self.collection_name}")

         return chroma_client.get_or_create_collection(
-            name=self.collection_name,
-            metadata={
-                "hnsw:space": "cosine"
-            })
-
-    def load_text_files(self, directory=None, encoding="utf-8"):
-        """Load all text files from a directory into Document objects."""
-        directory = directory or self.watch_directory
-        file_paths = glob.glob(os.path.join(directory, "**/*"), recursive=True)
-        documents = []
-
-        for file_path in file_paths:
-            if os.path.isfile(file_path):  # Ensure it's a file, not a directory
-                try:
-                    with open(file_path, "r", encoding=encoding) as f:
-                        content = f.read()
-
-                    # Extract top-level directory
-                    rel_path = os.path.relpath(file_path, directory)
-                    top_level_dir = rel_path.split(os.sep)[0]
-
-                    documents.append(Document(
-                        page_content=content,
-                        metadata={"doc_type": top_level_dir, "path": file_path}
-                    ))
-                except Exception as e:
-                    logging.error(f"Failed to load {file_path}: {e}")
-
-        return documents
+            name=self.collection_name, metadata={"hnsw:space": "cosine"}
+        )

     def create_chunks_from_documents(self, docs):
         """Split documents into chunks using the text splitter."""
@@ -364,10 +384,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         """Generate embeddings using Ollama."""
         # response = self.embedding_model.encode(text)  # Outputs 384-dim vectors

-        response = self.llm.embeddings(
-            model=defines.embedding_model,
-            prompt=text)
-        embedding = response['embedding']
+        response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
+        embedding = response["embedding"]

         # response = self.llm.embeddings.create(
         #     model=defines.embedding_model,
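get_embedding is a thin wrapper over the Ollama embeddings endpoint plus the optional L2 normalization handled just below this hunk. A hedged sketch, assuming the ollama Python client; the model name here is illustrative (the project reads it from defines.embedding_model):

    import numpy as np
    import ollama

    def get_embedding(text: str, model: str = "mxbai-embed-large") -> list[float]:
        response = ollama.embeddings(model=model, prompt=text)
        embedding = np.array(response["embedding"], dtype=float)
        norm = np.linalg.norm(embedding)
        # Normalize so cosine distance behaves consistently in the store.
        return (embedding / norm).tolist() if norm else embedding.tolist()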
@@ -379,27 +397,46 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             return normalized
         return embedding

-    def add_embeddings_to_collection(self, chunks):
+    def add_embeddings_to_collection(self, chunks: List[Chunk]):
         """Add embeddings for chunks to the collection."""

         for i, chunk in enumerate(chunks):
-            text = chunk.page_content
-            metadata = chunk.metadata
+            text = chunk["text"]
+            metadata = chunk["metadata"]

             # Generate a more unique ID based on content and metadata
             content_hash = hashlib.md5(text.encode()).hexdigest()
             path_hash = ""
             if "path" in metadata:
-                path_hash = hashlib.md5(metadata["path"].encode()).hexdigest()[:8]
+                path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
+                    :8
+                ]

             chunk_id = f"{path_hash}_{content_hash}_{i}"

             embedding = self.get_embedding(text)
-            self.collection.add(
-                ids=[chunk_id],
-                documents=[text],
-                embeddings=[embedding],
-                metadatas=[metadata]
-            )
+            try:
+                self.collection.add(
+                    ids=[chunk_id],
+                    documents=[text],
+                    embeddings=[embedding],
+                    metadatas=[metadata],
+                )
+            except Exception as e:
+                logging.error(f"Error adding chunk to collection: {e}")
+                logging.error(traceback.format_exc())
+                logging.error(chunk)
+
+    def read_line_range(self, file_path, start, end, buffer=5) -> list[str]:
+        try:
+            with open(file_path, "r") as file:
+                lines = file.readlines()
+                start = max(0, start - buffer)
+                end = min(len(lines), end + buffer)
+                return lines[start:end]
+        except:
+            logging.warning(f"Unable to open {file_path}")
+            return []

     # Cosine Distance   Equivalent Similarity   Retrieval Characteristics
     # 0.2 - 0.3         0.85 - 0.90             Very strict, highly precise results only
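The chunk IDs above combine a short hash of the source path with a hash of the chunk text and the chunk index, so re-ingesting identical content is idempotent. A minimal sketch of the ID scheme on its own:

    import hashlib

    def chunk_id(text: str, source_file: str, index: int) -> str:
        content_hash = hashlib.md5(text.encode()).hexdigest()
        path_hash = hashlib.md5(source_file.encode()).hexdigest()[:8]
        return f"{path_hash}_{content_hash}_{index}"

    print(chunk_id("some chunk text", "docs/resume.md", 0))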
@@ -419,10 +456,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         )

         # Extract results
-        ids = results['ids'][0]
-        documents = results['documents'][0]
-        distances = results['distances'][0]
-        metadatas = results['metadatas'][0]
+        ids = results["ids"][0]
+        documents = results["documents"][0]
+        distances = results["distances"][0]
+        metadatas = results["metadatas"][0]

         filtered_ids = []
         filtered_documents = []
@@ -436,6 +473,14 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
                 filtered_metadatas.append(metadatas[i])
                 filtered_distances.append(distance)

+        for index, meta in enumerate(filtered_metadatas):
+            source_file = meta["source_file"]
+            del meta["source_file"]
+            lines = self.read_line_range(
+                source_file, meta["line_begin"], meta["line_end"]
+            )
+            if len(lines):
+                filtered_documents[index] = "\n".join(lines)
         # Return the filtered results instead of all results
         return {
             "query_embedding": query_embedding,
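Per the cosine-distance table above this hunk, retrieval keeps only results under a distance threshold. A minimal sketch of the filtering loop (the 0.6 threshold is illustrative, not the project's configured value):

    def filter_by_distance(ids, documents, metadatas, distances, max_distance=0.6):
        kept = []
        for i, distance in enumerate(distances):
            if distance <= max_distance:
                kept.append((ids[i], documents[i], metadatas[i], distance))
        return kept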
@@ -448,7 +493,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
     def _get_file_hash(self, file_path):
         """Calculate MD5 hash of a file."""
         try:
-            with open(file_path, 'rb') as f:
+            with open(file_path, "rb") as f:
                 return hashlib.md5(f.read()).hexdigest()
         except Exception as e:
             logging.error(f"Error hashing file {file_path}: {e}")
@@ -480,7 +525,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             return

         file_path = event.src_path
-        asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
+        asyncio.run_coroutine_threadsafe(
+            self.remove_file_from_collection(file_path), self.loop
+        )
         logging.info(f"File deleted: {file_path}")

     def on_moved(self, event):
@@ -508,37 +555,43 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         try:
             # Remove existing entries for this file
             existing_results = self.collection.get(where={"path": file_path})
-            if existing_results and 'ids' in existing_results and existing_results['ids']:
-                self.collection.delete(ids=existing_results['ids'])
+            if (
+                existing_results
+                and "ids" in existing_results
+                and existing_results["ids"]
+            ):
+                self.collection.delete(ids=existing_results["ids"])

             extensions = (".docx", ".xlsx", ".xls", ".pdf")
             if file_path.endswith(extensions):
                 p = Path(file_path)
                 p_as_md = p.with_suffix(".md")
                 if p_as_md.exists():
-                    logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
+                    logging.info(
+                        f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
+                    )

                 # If file_path.md doesn't exist or file_path is newer than file_path.md,
                 # fire off markitdown
-                if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
+                if (not p_as_md.exists()) or (
+                    p.stat().st_mtime > p_as_md.stat().st_mtime
+                ):
                     self._markitdown(file_path, p_as_md)
                 return

-            # Create document object in LangChain format
-            with open(file_path, "r", encoding="utf-8") as f:
-                content = f.read()
+            chunks = self._markdown_chunker.process_file(file_path)
+            if not chunks:
+                return

             # Extract top-level directory
             rel_path = os.path.relpath(file_path, self.watch_directory)
-            top_level_dir = rel_path.split(os.sep)[0]
-
-            document = Document(
-                page_content=content,
-                metadata={"doc_type": top_level_dir, "path": file_path}
-            )
-
-            # Create chunks
-            chunks = self.text_splitter.split_documents([document])
+            path_parts = rel_path.split(os.sep)
+            top_level_dir = path_parts[0]
+            # file_name = path_parts[-1]
+            for i, chunk in enumerate(chunks):
+                chunk["metadata"]["doc_type"] = top_level_dir
+                # with open(f"src/tmp/{file_name}.{i}", "w") as f:
+                #     f.write(json.dumps(chunk, indent=2))

             # Add chunks to collection
             self.add_embeddings_to_collection(chunks)
@@ -547,30 +600,40 @@ class ChromaDBFileWatcher(FileSystemEventHandler):

         except Exception as e:
             logging.error(f"Error updating document in collection: {e}")
+            logging.error(traceback.format_exc())

     async def initialize_collection(self):
         """Initialize the collection with all documents from the watch directory."""
         # Process all files regardless of hash state
        num_processed = await self.scan_directory(process_all=True)

-        logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
+        logging.info(
+            f"Vectorstore initialized with {self.collection.count()} documents"
+        )

         self._update_umaps()

         # Show stats
         try:
-            all_metadata = self.collection.get()['metadatas']
+            all_metadata = self.collection.get()["metadatas"]
             if all_metadata:
-                doc_types = set(m.get('doc_type', 'unknown') for m in all_metadata)
+                doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
                 logging.info(f"Document types: {doc_types}")
         except Exception as e:
             logging.error(f"Error getting document types: {e}")

         return num_processed


 # Function to start the file watcher
-def start_file_watcher(llm, watch_directory, persist_directory=None,
-                       collection_name="documents", initialize=False, recreate=False):
+def start_file_watcher(
+    llm,
+    watch_directory,
+    persist_directory=None,
+    collection_name="documents",
+    initialize=False,
+    recreate=False,
+):
     """
     Start watching a directory for file changes.

@@ -590,7 +653,7 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
         loop=loop,
         persist_directory=persist_directory,
         collection_name=collection_name,
-        recreate=recreate
+        recreate=recreate,
     )

     # Process all files if:
@@ -613,6 +676,7 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
     logging.info(f"Started watching directory: {watch_directory}")
     return observer, file_watcher


 if __name__ == "__main__":
     # When running directly, use absolute imports
     import defines
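Taken together, start_file_watcher wires a ChromaDBFileWatcher into a watchdog Observer and returns both. A hedged usage sketch; the import path and calling convention are assumed from the signature above, not confirmed by this diff:

    import ollama

    from utils.rag import start_file_watcher  # module path assumed

    llm = ollama.Client()
    observer, watcher = start_file_watcher(
        llm,
        watch_directory="./docs",
        collection_name="documents",
        initialize=True,  # process every file on startup
    )
    # ... later, on shutdown:
    observer.stop()
    observer.join()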
@@ -4,9 +4,12 @@ import logging

 from . import defines


 def setup_logging(level=defines.logging_level) -> logging.Logger:
     os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
-    warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
+    warnings.filterwarnings(
+        "ignore", message="Overriding a previously registered kernel"
+    )
     warnings.filterwarnings("ignore", message="Warning only once for all operators")
     warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
     warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
@@ -21,11 +24,17 @@ def setup_logging(level=defines.logging_level) -> logging.Logger:
         level=numeric_level,
         format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
         datefmt="%Y-%m-%d %H:%M:%S",
-        force=True
+        force=True,
     )

     # Now reduce verbosity for FastAPI, Uvicorn, Starlette
-    for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
+    for noisy_logger in (
+        "uvicorn",
+        "uvicorn.error",
+        "uvicorn.access",
+        "fastapi",
+        "starlette",
+    ):
         # for noisy_logger in ("starlette"):
         logging.getLogger(noisy_logger).setLevel(logging.WARNING)
@@ -7,8 +7,8 @@ from .. import defines
 logger = setup_logging(level=defines.logging_level)

 # Dynamically import all names from basetools listed in tools_all
-module = importlib.import_module('.basetools', package=__package__)
+module = importlib.import_module(".basetools", package=__package__)
 for name in tool_functions:
     globals()[name] = getattr(module, name)

-__all__ = [ 'tools', 'llm_tools', 'enabled_tools', 'tool_functions' ]
+__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
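The dynamic import above copies each tool callable from .basetools into the package namespace. A minimal sketch of the same pattern outside a package context (module and function names are illustrative):

    import importlib

    tool_functions = ["DateTime", "WeatherForecast"]  # names assumed to exist
    module = importlib.import_module("basetools")  # absolute form, for illustration
    for name in tool_functions:
        globals()[name] = getattr(module, name)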
@@ -13,6 +13,7 @@ import requests
 import yfinance as yf  # type: ignore
 import logging


 # %%
 def WeatherForecast(city, state, country="USA"):
     """
@@ -42,11 +43,12 @@ def WeatherForecast(city, state, country="USA"):
     # Step 3: Get the forecast data from the grid endpoint
     forecast = get_forecast(grid_endpoint)

-    if not forecast['location']:
-        forecast['location'] = location
+    if not forecast["location"]:
+        forecast["location"] = location

     return forecast


 def get_coordinates(location):
     """Convert a location string to latitude and longitude using Nominatim geocoder."""
     try:
@@ -59,7 +61,7 @@ def get_coordinates(location):
         if location_data:
             return {
                 "latitude": location_data.latitude,
-                "longitude": location_data.longitude
+                "longitude": location_data.longitude,
             }
         else:
             print(f"Location not found: {location}")
@@ -68,6 +70,7 @@ def get_coordinates(location):
         print(f"Error getting coordinates: {e}")
         return None


 def get_grid_endpoint(coordinates):
     """Get the grid endpoint from weather.gov based on coordinates."""
     try:
@@ -77,7 +80,7 @@ def get_grid_endpoint(coordinates):
         # Define headers for the API request
         headers = {
             "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
-            "Accept": "application/geo+json"
+            "Accept": "application/geo+json",
         }

         # Make the request to get the grid endpoint
@@ -94,15 +97,17 @@ def get_grid_endpoint(coordinates):
         print(f"Error in get_grid_endpoint: {e}")
         return None


 # Weather related function


 def get_forecast(grid_endpoint):
     """Get the forecast data from the grid endpoint."""
     try:
         # Define headers for the API request
         headers = {
             "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
-            "Accept": "application/geo+json"
+            "Accept": "application/geo+json",
         }

         # Make the request to get the forecast
@@ -116,21 +121,25 @@ def get_forecast(grid_endpoint):

         # Process the forecast data into a simpler format
         forecast = {
-            "location": data["properties"].get("relativeLocation", {}).get("properties", {}),
+            "location": data["properties"]
+            .get("relativeLocation", {})
+            .get("properties", {}),
             "updated": data["properties"].get("updated", ""),
-            "periods": []
+            "periods": [],
         }

         for period in periods:
-            forecast["periods"].append({
-                "name": period.get("name", ""),
-                "temperature": period.get("temperature", ""),
-                "temperatureUnit": period.get("temperatureUnit", ""),
-                "windSpeed": period.get("windSpeed", ""),
-                "windDirection": period.get("windDirection", ""),
-                "shortForecast": period.get("shortForecast", ""),
-                "detailedForecast": period.get("detailedForecast", "")
-            })
+            forecast["periods"].append(
+                {
+                    "name": period.get("name", ""),
+                    "temperature": period.get("temperature", ""),
+                    "temperatureUnit": period.get("temperatureUnit", ""),
+                    "windSpeed": period.get("windSpeed", ""),
+                    "windDirection": period.get("windDirection", ""),
+                    "shortForecast": period.get("shortForecast", ""),
+                    "detailedForecast": period.get("detailedForecast", ""),
+                }
+            )

         return forecast
     else:
@@ -140,6 +149,7 @@ def get_forecast(grid_endpoint):
         print(f"Error in get_forecast: {e}")
         return {"error": f"Exception: {str(e)}"}


 # Example usage
 # def do_weather():
 #     city = input("Enter city: ")
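WeatherForecast chains three lookups: geocode the place name, resolve the weather.gov grid endpoint for those coordinates, then fetch the forecast it points at. A compressed sketch of the two weather.gov calls (headers mirror the ones above; error handling omitted):

    import requests

    headers = {
        "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
        "Accept": "application/geo+json",
    }

    def forecast_for(latitude: float, longitude: float) -> list[dict]:
        points = requests.get(
            f"https://api.weather.gov/points/{latitude},{longitude}", headers=headers
        ).json()
        grid_endpoint = points["properties"]["forecast"]
        forecast = requests.get(grid_endpoint, headers=headers).json()
        return forecast["properties"]["periods"]

    print(forecast_for(45.5, -122.6)[0]["shortForecast"])  # Portland, OR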
@@ -166,34 +176,31 @@ def get_forecast(grid_endpoint):

 # %%


 def TickerValue(ticker_symbols):
     api_key = os.getenv("TWELVEDATA_API_KEY", "")
     if not api_key:
         return {"error": f"Error fetching data: No API key for TwelveData"}

     results = []
-    for ticker_symbol in ticker_symbols.split(','):
+    for ticker_symbol in ticker_symbols.split(","):
         ticker_symbol = ticker_symbol.strip()
         if ticker_symbol == "":
             continue

-        url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
+        url = (
+            f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
+        )

         response = requests.get(url)
         data = response.json()

         if "price" in data:
             logging.info(f"TwelveData: {ticker_symbol} {data}")
-            results.append({
-                "symbol": ticker_symbol,
-                "price": float(data["price"])
-            })
+            results.append({"symbol": ticker_symbol, "price": float(data["price"])})
         else:
             logging.error(f"TwelveData: {data}")
-            results.append({
-                "symbol": ticker_symbol,
-                "price": "Unavailable"
-            })
+            results.append({"symbol": ticker_symbol, "price": "Unavailable"})

     return results[0] if len(results) == 1 else results
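TickerValue is a straight call to the TwelveData /price endpoint per symbol. A minimal sketch (requires a TWELVEDATA_API_KEY in the environment; raises KeyError if it is missing):

    import os
    import requests

    def price(symbol: str) -> float:
        api_key = os.environ["TWELVEDATA_API_KEY"]
        url = f"https://api.twelvedata.com/price?symbol={symbol}&apikey={api_key}"
        data = requests.get(url).json()
        return float(data["price"])  # a missing "price" key signals an API error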
@@ -210,7 +217,7 @@ def yfTickerValue(ticker_symbols):
         dict: Current stock information including price
     """
     results = []
-    for ticker_symbol in ticker_symbols.split(','):
+    for ticker_symbol in ticker_symbols.split(","):
         ticker_symbol = ticker_symbol.strip()
         if ticker_symbol == "":
             continue
@@ -226,19 +233,23 @@ def yfTickerValue(ticker_symbols):
                 continue

             # Get the latest closing price
-            latest_price = ticker_data['Close'].iloc[-1]
+            latest_price = ticker_data["Close"].iloc[-1]

             # Get some additional info
-            results.append({ 'symbol': ticker_symbol, 'price': latest_price })
+            results.append({"symbol": ticker_symbol, "price": latest_price})

         except Exception as e:
             import traceback

             logging.error(f"Error fetching data for {ticker_symbol}: {e}")
             logging.error(traceback.format_exc())
-            results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
+            results.append(
+                {"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
+            )

     return results[0] if len(results) == 1 else results


 # %%
 def DateTime(timezone="America/Los_Angeles"):
     """
@@ -252,8 +263,8 @@ def DateTime(timezone="America/Los_Angeles"):
         str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM
     """
     try:
-        if timezone == 'system' or timezone == '' or not timezone:
-            timezone = 'America/Los_Angeles'
+        if timezone == "system" or timezone == "" or not timezone:
+            timezone = "America/Los_Angeles"
         # Get current UTC time (timezone-aware)
         local_tz = pytz.timezone("America/Los_Angeles")
         local_now = datetime.now(tz=local_tz)
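DateTime converts the current local time into the requested timezone via pytz and returns an ISO-8601 string. A minimal sketch of that conversion:

    from datetime import datetime

    import pytz

    local_now = datetime.now(tz=pytz.timezone("America/Los_Angeles"))
    target_time = local_now.astimezone(pytz.timezone("Europe/London"))
    print(target_time.isoformat())  # e.g. 2024-01-01T17:00:00+00:00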
@@ -264,7 +275,8 @@ def DateTime(timezone="America/Los_Angeles"):

         return target_time.isoformat()
     except Exception as e:
-        return {'error': f"Invalid timezone {timezone}: {str(e)}"}
+        return {"error": f"Invalid timezone {timezone}: {str(e)}"}


 async def AnalyzeSite(llm, model: str, url: str, question: str):
     """
@@ -310,16 +322,18 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):

         # Generate summary using Ollama
         prompt = f"CONTENTS:\n\n{text}\n\n{question}"
-        response = llm.generate(model=model,
-            system="You are given the contents of {url}. Answer the question about the contents",
-            prompt=prompt)
+        response = llm.generate(
+            model=model,
+            system="You are given the contents of {url}. Answer the question about the contents",
+            prompt=prompt,
+        )

         # logging.info(response["response"])

         return {
             "source": "summarizer-llm",
             "content": response["response"],
-            "metadata": DateTime()
+            "metadata": DateTime(),
         }

     except requests.exceptions.RequestException as e:
@@ -331,7 +345,8 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):


 # %%
-tools = [ {
+tools = [
+    {
         "type": "function",
         "function": {
             "name": "TickerValue",
@@ -345,10 +360,11 @@ tools = [ {
                 },
             },
             "required": ["ticker"],
-            "additionalProperties": False
-        }
-    }
-}, {
+            "additionalProperties": False,
+        },
+    },
+    {
         "type": "function",
         "function": {
             "name": "AnalyzeSite",
@@ -366,27 +382,28 @@ tools = [ {
                 },
             },
             "required": ["url", "question"],
-            "additionalProperties": False
+            "additionalProperties": False,
         },
         "returns": {
             "type": "object",
             "properties": {
                 "source": {
                     "type": "string",
-                    "description": "Identifier for the source LLM"
+                    "description": "Identifier for the source LLM",
                 },
                 "content": {
                     "type": "string",
-                    "description": "The complete response from the second LLM"
+                    "description": "The complete response from the second LLM",
                 },
                 "metadata": {
                     "type": "object",
-                    "description": "Additional information about the response"
-                }
-            }
-        }
-    }
-}, {
+                    "description": "Additional information about the response",
+                },
+            },
+        },
+    },
+    {
         "type": "function",
         "function": {
             "name": "DateTime",
@@ -396,13 +413,14 @@ tools = [ {
             "properties": {
                 "timezone": {
                     "type": "string",
-                    "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'."
+                    "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
                 }
             },
-            "required": []
-        }
-    }
-}, {
+            "required": [],
+        },
+    },
+    {
         "type": "function",
         "function": {
             "name": "WeatherForecast",
@@ -413,27 +431,30 @@ tools = [ {
                 "city": {
                     "type": "string",
                     "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
-                    "minLength": 2
+                    "minLength": 2,
                 },
                 "state": {
                     "type": "string",
                     "description": "State to find the weather forecast (e.g., 'OR', 'WA').",
-                    "minLength": 2
-                }
+                    "minLength": 2,
+                },
             },
             "required": ["city", "state"],
-            "additionalProperties": False
-        }
-    }
-}]
+            "additionalProperties": False,
+        },
+    },
+]


 def llm_tools(tools):
     return [tool for tool in tools if tool.get("enabled", False) == True]


 def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
     return [{**tool, "enabled": True} for tool in tools]

-tool_functions = [ 'DateTime', 'WeatherForecast', 'TickerValue', 'AnalyzeSite' ]
-__all__ = [ 'tools', 'llm_tools', 'enabled_tools', 'tool_functions' ]
-#__all__.extend(__tool_functions__)  # type: ignore
+
+tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"]
+__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
+# __all__.extend(__tool_functions__)  # type: ignore
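llm_tools and enabled_tools mark and select which schema entries are actually handed to the model. A short usage sketch against the tools list above:

    active = enabled_tools(tools)  # every tool copied with "enabled": True
    payload = llm_tools(active)    # keep only the enabled ones
    names = [t["function"]["name"] for t in payload]
    print(names)  # ['TickerValue', 'AnalyzeSite', 'DateTime', 'WeatherForecast']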