Reformatted all content to black

James Ketr 2025-05-14 11:31:31 -07:00
parent a1798b58ac
commit e044f9c639
34 changed files with 3223 additions and 2753 deletions

View File

@@ -88,6 +88,10 @@ button {
   flex-grow: 1;
 }
+.MessageContent div > p:first-child {
+  margin-top: 0;
+}
 .MenuCard.MuiCard-root {
   display: flex;
   flex-direction: column;

View File

@@ -30,7 +30,6 @@ const BackstoryTextField = React.forwardRef<BackstoryTextFieldRef, BackstoryText
   const shadowRef = useRef<HTMLTextAreaElement>(null);
   const [editValue, setEditValue] = useState<string>(value);
-  console.log({ value, placeholder, editValue });
   // Sync editValue with prop value if it changes externally
   useEffect(() => {
     setEditValue(value || "");

View File

@@ -158,7 +158,7 @@ function ChatBubble(props: ChatBubbleProps) {
   };
   // Render Accordion for expandable content
-  if (expandable || (role === 'content' && title)) {
+  if (expandable || title) {
     // Determine if Accordion is controlled
     const isControlled = typeof expanded === 'boolean' && typeof onExpand === 'function';

View File

@@ -1,13 +1,15 @@
 import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback } from 'react';
 import Typography from '@mui/material/Typography';
 import Tooltip from '@mui/material/Tooltip';
+import IconButton from '@mui/material/IconButton';
 import Button from '@mui/material/Button';
 import Box from '@mui/material/Box';
 import SendIcon from '@mui/icons-material/Send';
+import CancelIcon from '@mui/icons-material/Cancel';
 import { SxProps, Theme } from '@mui/material';
 import PropagateLoader from "react-spinners/PropagateLoader";
-import { Message, MessageList, MessageData } from './Message';
+import { Message, MessageList, BackstoryMessage } from './Message';
 import { ContextStatus } from './ContextStatus';
 import { Scrollable } from './Scrollable';
 import { DeleteConfirmation } from './DeleteConfirmation';
@@ -17,7 +19,7 @@ import { BackstoryTextField, BackstoryTextFieldRef } from './BackstoryTextField'
 import { BackstoryElementProps } from './BackstoryTab';
 import { connectionBase } from './Global';

-const loadingMessage: MessageData = { "role": "status", "content": "Establishing connection with server..." };
+const loadingMessage: BackstoryMessage = { "role": "status", "content": "Establishing connection with server..." };

 type ConversationMode = 'chat' | 'job_description' | 'resume' | 'fact_check';
@@ -25,24 +27,6 @@ interface ConversationHandle {
     submitQuery: (prompt: string, options?: QueryOptions) => void;
     fetchHistory: () => void;
 }

-interface BackstoryMessage {
-    prompt: string;
-    preamble: {};
-    status: string;
-    full_content: string;
-    response: string; // Set when status === 'done' or 'error'
-    chunk: string; // Used when status === 'streaming'
-    metadata: {
-        rag: { documents: [] };
-        tools: string[];
-        eval_count: number;
-        eval_duration: number;
-        prompt_eval_count: number;
-        prompt_eval_duration: number;
-    };
-    actions: string[];
-    timestamp: string;
-};

 interface ConversationProps extends BackstoryElementProps {
     className?: string, // Override default className
@@ -59,7 +43,7 @@ interface ConversationProps extends BackstoryElementProps {
     messageFilter?: ((messages: MessageList) => MessageList) | undefined, // Filter callback to determine which Messages to display in Conversation
     messages?: MessageList, //
     sx?: SxProps<Theme>,
-    onResponse?: ((message: MessageData) => void) | undefined, // Event called when a query completes (provides messages)
+    onResponse?: ((message: BackstoryMessage) => void) | undefined, // Event called when a query completes (provides messages)
 };

 const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
@@ -87,8 +71,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
     const [countdown, setCountdown] = useState<number>(0);
     const [conversation, setConversation] = useState<MessageList>([]);
     const [filteredConversation, setFilteredConversation] = useState<MessageList>([]);
-    const [processingMessage, setProcessingMessage] = useState<MessageData | undefined>(undefined);
-    const [streamingMessage, setStreamingMessage] = useState<MessageData | undefined>(undefined);
+    const [processingMessage, setProcessingMessage] = useState<BackstoryMessage | undefined>(undefined);
+    const [streamingMessage, setStreamingMessage] = useState<BackstoryMessage | undefined>(undefined);
     const timerRef = useRef<any>(null);
     const [contextStatus, setContextStatus] = useState<ContextStatus>({ context_used: 0, max_context: 0 });
     const [contextWarningShown, setContextWarningShown] = useState<boolean>(false);
@@ -96,6 +80,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
     const conversationRef = useRef<MessageList>([]);
     const viewableElementRef = useRef<HTMLDivElement>(null);
     const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
+    const stopRef = useRef(false);

     // Keep the ref updated whenever items changes
     useEffect(() => {
@@ -181,14 +166,25 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
             const backstoryMessages: BackstoryMessage[] = messages;

-            setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => [{
-                role: 'user',
-                content: backstoryMessage.prompt || "",
-            }, {
-                ...backstoryMessage,
-                role: backstoryMessage.status === "done" ? "assistant" : backstoryMessage.status,
-                content: backstoryMessage.response || "",
-            }] as MessageList));
+            setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => {
+                if (backstoryMessage.status === "partial") {
+                    return [{
+                        ...backstoryMessage,
+                        role: "assistant",
+                        content: backstoryMessage.response || "",
+                        expanded: false,
+                        expandable: true,
+                    }]
+                }
+                return [{
+                    role: 'user',
+                    content: backstoryMessage.prompt || "",
+                }, {
+                    ...backstoryMessage,
+                    role: ['done'].includes(backstoryMessage.status || "") ? "assistant" : backstoryMessage.status,
+                    content: backstoryMessage.response || "",
+                }] as MessageList;
+            }));
             setNoInteractions(false);
         }
         setProcessingMessage(undefined);
@@ -294,6 +290,11 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
         }
     };

+    const cancelQuery = () => {
+        console.log("Stop query");
+        stopRef.current = true;
+    };
+
     const sendQuery = async (request: string, options?: QueryOptions) => {
         request = request.trim();
@@ -308,6 +309,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
             return;
         }

+        stopRef.current = false;
+
         setNoInteractions(false);

         setConversation([
@@ -325,12 +328,10 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
         try {
             setProcessing(true);

-            // Create a unique ID for the processing message
-            const processingId = Date.now().toString();
-
             // Add initial processing message
             setProcessingMessage(
-                { role: 'status', content: 'Submitting request...', id: processingId, isProcessing: true }
+                { role: 'status', content: 'Submitting request...', disableCopy: true }
             );

             // Add a small delay to ensure React has time to update the UI
@@ -379,17 +380,20 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                 switch (update.status) {
                     case 'done':
-                        console.log('Done processing:', update);
-                        stopCountdown();
-                        setStreamingMessage(undefined);
-                        setProcessingMessage(undefined);
+                    case 'partial':
+                        if (update.status === 'done') stopCountdown();
+                        if (update.status === 'done') setStreamingMessage(undefined);
+                        if (update.status === 'done') setProcessingMessage(undefined);
                         const backstoryMessage: BackstoryMessage = update;
                         setConversation([
                             ...conversationRef.current, {
                                 ...backstoryMessage,
                                 role: 'assistant',
                                 origin: type,
+                                prompt: ['done', 'partial'].includes(update.status) ? update.prompt : '',
                                 content: backstoryMessage.response || "",
+                                expanded: update.status === "done" ? true : false,
+                                expandable: true,
                             }] as MessageList);
                         // Add a small delay to ensure React has time to update the UI
                         await new Promise(resolve => setTimeout(resolve, 0));
@@ -424,9 +428,9 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                         // Update processing message with immediate re-render
                         if (update.status === "streaming") {
                             streaming_response += update.chunk
-                            setStreamingMessage({ role: update.status, content: streaming_response });
+                            setStreamingMessage({ role: update.status, content: streaming_response, disableCopy: true });
                         } else {
-                            setProcessingMessage({ role: update.status, content: update.response });
+                            setProcessingMessage({ role: update.status, content: update.response, disableCopy: true });
                             /* Reset stream on non streaming message */
                             streaming_response = ""
                         }
@@ -437,12 +441,11 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                 }
             }

-            while (true) {
+            while (!stopRef.current) {
                 const { done, value } = await reader.read();
                 if (done) {
                     break;
                 }
                 const chunk = decoder.decode(value, { stream: true });

                 // Process each complete line immediately
@@ -470,26 +473,32 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                 }
             }

+            if (stopRef.current) {
+                await reader.cancel();
+                setProcessingMessage(undefined);
+                setStreamingMessage(undefined);
+                setSnack("Processing cancelled", "warning");
+            }
+
             stopCountdown();
             setProcessing(false);
+            stopRef.current = false;
         } catch (error) {
             console.error('Fetch error:', error);
             setSnack("Unable to process query", "error");
-            setProcessingMessage({ role: 'error', content: "Unable to process query" });
+            setProcessingMessage({ role: 'error', content: "Unable to process query", disableCopy: true });
             setTimeout(() => {
                 setProcessingMessage(undefined);
             }, 5000);
+            stopRef.current = false;
             setProcessing(false);
             stopCountdown();
-            // Add a small delay to ensure React has time to update the UI
-            await new Promise(resolve => setTimeout(resolve, 0));
+            return;
         }
     };

     return (
         <Scrollable
-            className={className || "Conversation"}
+            className={`${className || ""} Conversation`}
             autoscroll
             textFieldRef={viewableElementRef}
             fallbackThreshold={0.5}
@@ -564,6 +573,20 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
                         </Button>
                     </span>
                 </Tooltip>
+                <Tooltip title="Cancel">
+                    <span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
+                        <IconButton
+                            aria-label="cancel"
+                            onClick={() => { cancelQuery(); }}
+                            sx={{ display: "flex", margin: 'auto 0px' }}
+                            size="large"
+                            edge="start"
+                            disabled={stopRef.current || sessionId === undefined || processing === false}
+                        >
+                            <CancelIcon />
+                        </IconButton>
+                    </span>
+                </Tooltip>
             </Box>
         </Box>
         {(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length &&
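Note on the cancellation change above: the pattern is simply a mutable flag (here stopRef.current) that the streaming read loop tests on every iteration, with reader.cancel() releasing the connection once the flag is set. A minimal, framework-free sketch of the same idea; the names readStream and stopFlag are illustrative, not part of the Backstory API:

// Sketch of flag-based cancellation for a streamed fetch response.
type StopFlag = { current: boolean };

async function readStream(
  response: Response,
  stopFlag: StopFlag,
  onChunk: (text: string) => void,
): Promise<void> {
  if (!response.body) throw new Error("Response has no body");
  const reader = response.body.getReader();
  const decoder = new TextDecoder();

  // Loop until the server finishes or the caller flips the stop flag.
  while (!stopFlag.current) {
    const { done, value } = await reader.read();
    if (done) break;
    onChunk(decoder.decode(value, { stream: true }));
  }

  // If we exited because of cancellation, release the connection.
  if (stopFlag.current) {
    await reader.cancel();
  }
}

// Usage: a Cancel button handler sets stopFlag.current = true.
// const stopFlag: StopFlag = { current: false };
// readStream(await fetch("/api/chat"), stopFlag, (t) => console.log(t));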

View File

@@ -47,11 +47,18 @@ type MessageRoles =
     'thinking' |
     'user';

-type MessageData = {
+type BackstoryMessage = {
+    // Only two required fields
     role: MessageRoles,
     content: string,
-    status?: string, // streaming, done, error...
-    response?: string,
+    // Rest are optional
+    prompt?: string;
+    preamble?: {};
+    status?: string;
+    full_content?: string;
+    response?: string; // Set when status === 'done', 'partial', or 'error'
+    chunk?: string; // Used when status === 'streaming'
+    timestamp?: string;
     disableCopy?: boolean,
     user?: string,
     title?: string,
@@ -84,11 +91,11 @@ interface MessageMetaData {
     setSnack: SetSnackType,
 }

-type MessageList = MessageData[];
+type MessageList = BackstoryMessage[];

 interface MessageProps extends BackstoryElementProps {
     sx?: SxProps<Theme>,
-    message: MessageData,
+    message: BackstoryMessage,
     expanded?: boolean,
     onExpand?: (open: boolean) => void,
     className?: string,
@@ -237,7 +244,7 @@ const MessageMeta = (props: MessageMetaProps) => {
 };

 const Message = (props: MessageProps) => {
-    const { message, submitQuery, sx, className, onExpand, expanded, sessionId, setSnack } = props;
+    const { message, submitQuery, sx, className, onExpand, sessionId, setSnack } = props;
     const [metaExpanded, setMetaExpanded] = useState<boolean>(false);
     const textFieldRef = useRef(null);
@@ -254,14 +261,16 @@ const Message = (props: MessageProps) => {
         return (<></>);
     }

-    const formattedContent = message.content.trim() || "Waiting for LLM to spool up...";
+    const formattedContent = message.content.trim();
+    if (formattedContent === "") {
+        return (<></>);
+    }

     return (
         <ChatBubble
-            className={className || "Message"}
+            className={`${className || ""} Message Message-${message.role}`}
             {...message}
             onExpand={onExpand}
-            expanded={expanded}
             sx={{
                 display: "flex",
                 flexDirection: "column",
@@ -273,34 +282,24 @@ const Message = (props: MessageProps) => {
                 ...sx,
             }}>
             <CardContent ref={textFieldRef} sx={{ position: "relative", display: "flex", flexDirection: "column", overflowX: "auto", m: 0, p: 0, paddingBottom: '0px !important' }}>
-                {message.role !== 'user' ?
-                    <Scrollable
-                        className="MessageContent"
-                        autoscroll
-                        fallbackThreshold={0.5}
-                        sx={{
-                            p: 0,
-                            m: 0,
-                            maxHeight: (message.role === "streaming") ? "20rem" : "unset",
-                            display: "flex",
-                            flexGrow: 1,
-                            overflow: "auto", /* Handles scrolling for the div */
-                        }}
-                    >
-                        <StyledMarkdown streaming={message.role === "streaming"} {...{ content: formattedContent, submitQuery, sessionId, setSnack }} />
-                    </Scrollable>
-                    :
-                    <Typography
-                        className="MessageContent"
-                        ref={textFieldRef}
-                        variant="body2"
-                        sx={{ display: "flex", color: 'text.secondary' }}>
-                        {message.content}
-                    </Typography>
-                }
+                <Scrollable
+                    className="MessageContent"
+                    autoscroll
+                    fallbackThreshold={0.5}
+                    sx={{
+                        p: 0,
+                        m: 0,
+                        maxHeight: (message.role === "streaming") ? "20rem" : "unset",
+                        display: "flex",
+                        flexGrow: 1,
+                        overflow: "auto", /* Handles scrolling for the div */
+                    }}
+                >
+                    <StyledMarkdown streaming={message.role === "streaming"} {...{ content: formattedContent, submitQuery, sessionId, setSnack }} />
+                </Scrollable>
             </CardContent>
             <CardActions disableSpacing sx={{ display: "flex", flexDirection: "row", justifyContent: "space-between", alignItems: "center", width: "100%", p: 0, m: 0 }}>
-                {(message.disableCopy === undefined || message.disableCopy === false) && ["assistant", "content"].includes(message.role) && <CopyBubble content={message.content} />}
+                {(message.disableCopy === undefined || message.disableCopy === false) && <CopyBubble content={message.content} />}
                 {message.metadata && (
                     <Box sx={{ display: "flex", alignItems: "center", gap: 1 }}>
                         <Button variant="text" onClick={handleMetaExpandClick} sx={{ color: "darkgrey", p: 0 }}>
@@ -309,7 +308,7 @@ const Message = (props: MessageProps) => {
                             <ExpandMore
                                 expand={metaExpanded}
                                 onClick={handleMetaExpandClick}
-                                aria-expanded={expanded}
+                                aria-expanded={message.expanded}
                                 aria-label="show more"
                             >
                                 <ExpandMoreIcon />
@@ -331,7 +330,8 @@ const Message = (props: MessageProps) => {
 export type {
     MessageProps,
     MessageList,
-    MessageData,
+    BackstoryMessage,
+    MessageMetaData,
     MessageRoles,
 };
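The consolidated BackstoryMessage type above keeps only role and content required, so one shape can represent both an in-flight streaming update and a finished assistant reply. Roughly, the two cases look like this; values are made up for illustration, field names follow the type definition above:

// Illustrative only: two BackstoryMessage-shaped objects, one mid-stream, one finished.
const streamingUpdate = {
  role: "streaming" as const,
  content: "Partial text so fa",   // accumulated chunks rendered live
  status: "streaming",
  chunk: "r",                      // latest chunk received from the server
  disableCopy: true,               // copy button suppressed while streaming
};

const finishedReply = {
  role: "assistant" as const,
  content: "Partial text so far, now complete.",
  status: "done",
  response: "Partial text so far, now complete.",
  timestamp: new Date().toISOString(),
};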

View File

@@ -7,7 +7,7 @@ import {
 import { SxProps } from '@mui/material';
 import { ChatQuery } from './ChatQuery';
-import { MessageList, MessageData } from './Message';
+import { MessageList, BackstoryMessage } from './Message';
 import { Conversation } from './Conversation';
 import { BackstoryPageProps } from './BackstoryTab';
@@ -62,11 +62,6 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
             return [];
         }

-        if (messages.length > 2) {
-            setHasResume(true);
-            setHasFacts(true);
-        }
-
         if (messages.length > 0) {
             messages[0].role = 'content';
             messages[0].title = 'Job Description';
@@ -74,6 +69,19 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
             messages[0].expandable = true;
         }

+        if (-1 !== messages.findIndex(m => m.status === 'done')) {
+            setHasResume(true);
+            setHasFacts(true);
+        }
+
+        return messages;
+
+        if (messages.length > 1) {
+            setHasResume(true);
+            setHasFacts(true);
+        }
+
         if (messages.length > 3) {
             // messages[2] is Show job requirements
             messages[3].role = 'job-requirements';
@@ -95,6 +103,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
             return [];
         }

+        return messages;
+
         if (messages.length > 1) {
             // messages[0] is Show Qualifications
             messages[1].role = 'qualifications';
@@ -139,7 +149,7 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
         return filtered;
     }, []);

-    const jobResponse = useCallback(async (message: MessageData) => {
+    const jobResponse = useCallback(async (message: BackstoryMessage) => {
         console.log('onJobResponse', message);
         if (message.actions && message.actions.includes("job_description")) {
             await jobConversationRef.current.fetchHistory();
@@ -155,12 +165,12 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
         }
     }, [setHasFacts, setHasResume, setActiveTab]);

-    const resumeResponse = useCallback((message: MessageData): void => {
+    const resumeResponse = useCallback((message: BackstoryMessage): void => {
         console.log('onResumeResponse', message);
         setHasFacts(true);
     }, [setHasFacts]);

-    const factsResponse = useCallback((message: MessageData): void => {
+    const factsResponse = useCallback((message: BackstoryMessage): void => {
         console.log('onFactsResponse', message);
     }, []);
@@ -199,7 +209,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
 3. **Mapping Analysis**: Identifies legitimate matches between requirements and qualifications
 3. **Resume Generation**: Uses mapping output to create a tailored resume with evidence-based content
 4. **Verification**: Performs fact-checking to catch any remaining fabrications
-1. **Re-generation**: If verification does not pass, a second attempt is made to correct any issues`
+1. **Re-generation**: If verification does not pass, a second attempt is made to correct any issues`,
+        disableCopy: true
     }];

View File

@@ -16,4 +16,13 @@ pre:not(.MessageContent) {
 .MuiMarkdown > div {
   width: 100%;
+}
+
+.Message-streaming .MuiMarkdown ul,
+.Message-streaming .MuiMarkdown h1,
+.Message-streaming .MuiMarkdown p,
+.Message-assistant .MuiMarkdown ul,
+.Message-assistant .MuiMarkdown h1,
+.Message-assistant .MuiMarkdown p {
+  color: white;
 }

View File

@@ -37,14 +37,14 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
         }
         if (className === "lang-json" && !streaming) {
             try {
-                const fixed = jsonrepair(content);
+                let fixed = JSON.parse(jsonrepair(content));
                 return <Scrollable className="JsonViewScrollable">
                     <JsonView
                         className="JsonView"
                         style={{
                             ...vscodeTheme,
                             fontSize: "0.8rem",
-                            maxHeight: "20rem",
+                            maxHeight: "10rem",
                             padding: "14px 0",
                             overflow: "hidden",
                             width: "100%",
@@ -53,9 +53,9 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
                         }}
                         displayDataTypes={false}
                         objectSortKeys={false}
-                        collapsed={false}
+                        collapsed={true}
                         shortenTextAfterLength={100}
-                        value={JSON.parse(fixed)}>
+                        value={fixed}>
                         <JsonView.String
                             render={({ children, ...reset }) => {
                                 if (typeof (children) === "string" && children.match("\n")) {
@@ -66,7 +66,7 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
                     </JsonView>
                 </Scrollable>
             } catch (e) {
-                console.log("jsonrepair error", e);
+                return <pre><code className="JsonRaw">{content}</code></pre>
             };
         }
         return <pre><code className={className}>{element.children}</code></pre>;
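The StyledMarkdown change above amounts to a repair-then-parse-with-fallback step: jsonrepair attempts to fix malformed JSON (as LLM output often is), JSON.parse validates it, and anything that still fails is rendered as raw preformatted text instead of being silently logged. A stripped-down sketch of that flow, assuming the jsonrepair npm package and omitting the JSX rendering:

import { jsonrepair } from "jsonrepair";

// Try to repair and parse possibly-malformed JSON.
// Returns the parsed value, or null if even the repaired text is invalid,
// so the caller can fall back to showing the raw string.
function tryParseJson(content: string): unknown | null {
  try {
    return JSON.parse(jsonrepair(content));
  } catch {
    return null;
  }
}

// const value = tryParseJson(raw);
// value === null ? renderRawCode(raw) : renderJsonView(value);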

File diff suppressed because it is too large.

View File

@@ -1,29 +1,25 @@
-from .. utils import logger
+from ..utils import logger
 import ollama
-from .. utils import (
-    rag as Rag,
-    Context,
-    defines
-)
+from ..utils import rag as Rag, Context, defines
 import json

 llm = ollama.Client(host=defines.ollama_api_url)

 observer, file_watcher = Rag.start_file_watcher(
-    llm=llm,
-    watch_directory=defines.doc_dir,
-    recreate=False # Don't recreate if exists
+    llm=llm, watch_directory=defines.doc_dir, recreate=False  # Don't recreate if exists
 )

 context = Context(file_watcher=file_watcher)

-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
 context = Context.model_validate_json(json.dumps(data))
 context.file_watcher = file_watcher

-agent = context.get_or_create_agent("chat", system_prompt="You are a helpful assistant.")
+agent = context.get_or_create_agent(
+    "chat", system_prompt="You are a helpful assistant."
+)

 # logger.info(f"data: {data}")
 # logger.info(f"agent: {agent}")
 agent_type = agent.get_agent_type()
@@ -32,7 +28,7 @@ logger.info(f"system_prompt: {agent.system_prompt}")
 agent.system_prompt = "Eat more tomatoes."

-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
 context = Context.model_validate_json(json.dumps(data))
 context.file_watcher = file_watcher

View File

@ -1,109 +1,139 @@
from fastapi import FastAPI, Request, Depends, Query # type: ignore from fastapi import FastAPI, Request, Depends, Query # type: ignore
from fastapi.responses import RedirectResponse, JSONResponse # type: ignore from fastapi.responses import RedirectResponse, JSONResponse # type: ignore
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import logging import logging
import traceback import traceback
from typing import Callable, Optional from typing import Callable, Optional
from anyio.to_thread import run_sync # type: ignore from anyio.to_thread import run_sync # type: ignore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RedirectToContext(Exception): class RedirectToContext(Exception):
def __init__(self, url: str): def __init__(self, url: str):
self.url = url self.url = url
logger.info(f"Redirect to Context: {url}") logger.info(f"Redirect to Context: {url}")
super().__init__(f"Redirect to Context: {url}") super().__init__(f"Redirect to Context: {url}")
class ContextRouteManager: class ContextRouteManager:
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.app = app self.app = app
self.setup_handlers() self.setup_handlers()
def setup_handlers(self): def setup_handlers(self):
@self.app.exception_handler(RedirectToContext) @self.app.exception_handler(RedirectToContext)
async def handle_context_redirect(request: Request, exc: RedirectToContext): async def handle_context_redirect(request: Request, exc: RedirectToContext):
logger.info(f"Handling redirect to {exc.url}") logger.info(f"Handling redirect to {exc.url}")
return RedirectResponse(url=exc.url, status_code=307) return RedirectResponse(url=exc.url, status_code=307)
def ensure_context(self, route_name: str = "context_id") -> Callable[[Request], Optional[UUID]]: def ensure_context(
self, route_name: str = "context_id"
) -> Callable[[Request], Optional[UUID]]:
logger.info(f"Setting up context dependency for route parameter: {route_name}") logger.info(f"Setting up context dependency for route parameter: {route_name}")
async def _ensure_context_dependency(request: Request) -> Optional[UUID]: async def _ensure_context_dependency(request: Request) -> Optional[UUID]:
logger.info(f"Entering ensure_context_dependency, Request URL: {request.url}") logger.info(
f"Entering ensure_context_dependency, Request URL: {request.url}"
)
logger.info(f"Path params: {request.path_params}") logger.info(f"Path params: {request.path_params}")
path_params = request.path_params path_params = request.path_params
route_value = path_params.get(route_name) route_value = path_params.get(route_name)
logger.info(f"route_value: {route_value!r}, type: {type(route_value)}") logger.info(f"route_value: {route_value!r}, type: {type(route_value)}")
if route_value is None or not isinstance(route_value, str) or not route_value.strip(): if (
route_value is None
or not isinstance(route_value, str)
or not route_value.strip()
):
logger.info(f"route_value is invalid, generating new UUID") logger.info(f"route_value is invalid, generating new UUID")
path = request.url.path.rstrip('/') path = request.url.path.rstrip("/")
new_context = await run_sync(uuid4) new_context = await run_sync(uuid4)
redirect_url = f"{path}/{new_context}" redirect_url = f"{path}/{new_context}"
logger.info(f"Redirecting to {redirect_url}") logger.info(f"Redirecting to {redirect_url}")
raise RedirectToContext(redirect_url) raise RedirectToContext(redirect_url)
logger.info(f"Attempting to parse route_value as UUID: {route_value}") logger.info(f"Attempting to parse route_value as UUID: {route_value}")
try: try:
route_context = await run_sync(UUID, route_value) route_context = await run_sync(UUID, route_value)
logger.info(f"Successfully parsed UUID: {route_context}") logger.info(f"Successfully parsed UUID: {route_context}")
return route_context return route_context
except ValueError as e: except ValueError as e:
logger.error(f"Failed to parse UUID from route_value: {route_value!r}, error: {str(e)}") logger.error(
path = request.url.path.rstrip('/') f"Failed to parse UUID from route_value: {route_value!r}, error: {str(e)}"
)
path = request.url.path.rstrip("/")
new_context = await run_sync(uuid4) new_context = await run_sync(uuid4)
redirect_url = f"{path}/{new_context}" redirect_url = f"{path}/{new_context}"
logger.info(f"Invalid UUID, redirecting to {redirect_url}") logger.info(f"Invalid UUID, redirecting to {redirect_url}")
raise RedirectToContext(redirect_url) raise RedirectToContext(redirect_url)
return _ensure_context_dependency # type: ignore return _ensure_context_dependency # type: ignore
def route_pattern(self, path: str, *dependencies, **kwargs): def route_pattern(self, path: str, *dependencies, **kwargs):
logger.info(f"Registering route: {path}") logger.info(f"Registering route: {path}")
ensure_context = self.ensure_context() ensure_context = self.ensure_context()
def decorator(func): def decorator(func):
all_dependencies = list(dependencies) all_dependencies = list(dependencies)
all_dependencies.append(Depends(ensure_context)) all_dependencies.append(Depends(ensure_context))
logger.info(f"Route {path} registered with dependencies: {all_dependencies}") logger.info(
f"Route {path} registered with dependencies: {all_dependencies}"
)
return self.app.get(path, dependencies=all_dependencies, **kwargs)(func) return self.app.get(path, dependencies=all_dependencies, **kwargs)(func)
return decorator return decorator
app = FastAPI(redirect_slashes=True) app = FastAPI(redirect_slashes=True)
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception): async def global_exception_handler(request: Request, exc: Exception):
logger.error(f"Unhandled exception: {str(exc)}") logger.error(f"Unhandled exception: {str(exc)}")
logger.error(f"Request URL: {request.url}, Path params: {request.path_params}") logger.error(f"Request URL: {request.url}, Path params: {request.path_params}")
logger.error(f"Stack trace: {''.join(traceback.format_tb(exc.__traceback__))}") logger.error(f"Stack trace: {''.join(traceback.format_tb(exc.__traceback__))}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content={"error": "Internal server error", "detail": str(exc)}
content={"error": "Internal server error", "detail": str(exc)}
) )
@app.middleware("http") @app.middleware("http")
async def log_requests(request: Request, call_next): async def log_requests(request: Request, call_next):
logger.info(f"Incoming request: {request.method} {request.url}, Path params: {request.path_params}") logger.info(
f"Incoming request: {request.method} {request.url}, Path params: {request.path_params}"
)
response = await call_next(request) response = await call_next(request)
return response return response
context_router = ContextRouteManager(app) context_router = ContextRouteManager(app)
@context_router.route_pattern("/api/history/{context_id}") @context_router.route_pattern("/api/history/{context_id}")
async def get_history(request: Request, context_id: UUID = Depends(context_router.ensure_context()), agent_type: str = Query(..., description="Type of agent to retrieve history for")): async def get_history(
request: Request,
context_id: UUID = Depends(context_router.ensure_context()),
agent_type: str = Query(..., description="Type of agent to retrieve history for"),
):
logger.info(f"{request.method} {request.url.path} with context_id: {context_id}") logger.info(f"{request.method} {request.url.path} with context_id: {context_id}")
return {"context_id": str(context_id), "agent_type": agent_type} return {"context_id": str(context_id), "agent_type": agent_type}
@app.get("/api/history") @app.get("/api/history")
async def redirect_history(request: Request, agent_type: str = Query(..., description="Type of agent to retrieve history for")): async def redirect_history(
path = request.url.path.rstrip('/') request: Request,
agent_type: str = Query(..., description="Type of agent to retrieve history for"),
):
path = request.url.path.rstrip("/")
new_context = uuid4() new_context = uuid4()
redirect_url = f"{path}/{new_context}?agent_type={agent_type}" redirect_url = f"{path}/{new_context}?agent_type={agent_type}"
logger.info(f"Redirecting from /api/history to {redirect_url}") logger.info(f"Redirecting from /api/history to {redirect_url}")
return RedirectResponse(url=redirect_url, status_code=307) return RedirectResponse(url=redirect_url, status_code=307)
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn # type: ignore import uvicorn # type: ignore
uvicorn.run(app, host="0.0.0.0", port=8900)
uvicorn.run(app, host="0.0.0.0", port=8900)

View File

@@ -1,28 +1,24 @@
 # From /opt/backstory run:
 # python -m src.tests.test-context
 import os

 os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"

 import warnings

 warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")

 import ollama
-from .. utils import (
-    rag as Rag,
-    Context,
-    defines
-)
+from ..utils import rag as Rag, Context, defines
 import json

 llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore

 observer, file_watcher = Rag.start_file_watcher(
-    llm=llm,
-    watch_directory=defines.doc_dir,
-    recreate=False # Don't recreate if exists
+    llm=llm, watch_directory=defines.doc_dir, recreate=False  # Don't recreate if exists
 )

 context = Context(file_watcher=file_watcher)

-data = context.model_dump(mode='json')
+data = context.model_dump(mode="json")
 context = Context.from_json(json.dumps(data), file_watcher=file_watcher)

View File

@@ -1,13 +1,11 @@
 # From /opt/backstory run:
 # python -m src.tests.test-message
-from .. utils import logger
-from .. utils import (
-    Message
-)
+from ..utils import logger
+from ..utils import Message
 import json

 prompt = "This is a test"
 message = Message(prompt=prompt)
-print(message.model_dump(mode='json'))
+print(message.model_dump(mode="json"))

View File

@@ -1,8 +1,6 @@
 # From /opt/backstory run:
 # python -m src.tests.test-metrics
-from .. utils import (
-    Metrics
-)
+from ..utils import Metrics
 import json
@@ -13,7 +11,7 @@ metrics = Metrics()
 metrics.prepare_count.labels(agent="chat").inc()
 metrics.prepare_duration.labels(agent="prepare").observe(0.45)

-json = metrics.model_dump(mode='json')
+json = metrics.model_dump(mode="json")
 metrics = Metrics.model_validate(json)
 print(metrics)

View File

@ -1,44 +1,51 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel # type: ignore from pydantic import BaseModel # type: ignore
import importlib import importlib
from . import defines from . import defines
from . context import Context from .context import Context
from . conversation import Conversation from .conversation import Conversation
from . message import Message, Tunables from .message import Message, Tunables
from . rag import ChromaDBFileWatcher, start_file_watcher from .rag import ChromaDBFileWatcher, start_file_watcher
from . setup_logging import setup_logging from .setup_logging import setup_logging
from . agents import class_registry, AnyAgent, Agent, __all__ as agents_all from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
from . metrics import Metrics from .metrics import Metrics
__all__ = [ __all__ = [
'Agent', "Agent",
'Tunables', "Tunables",
'Context', "Context",
'Conversation', "Conversation",
'Message', "Message",
'Metrics', "Metrics",
'ChromaDBFileWatcher', "ChromaDBFileWatcher",
'start_file_watcher', "start_file_watcher",
'logger', "logger",
] ]
__all__.extend(agents_all) # type: ignore __all__.extend(agents_all) # type: ignore
logger = setup_logging(level=defines.logging_level) logger = setup_logging(level=defines.logging_level)
def rebuild_models(): def rebuild_models():
for class_name, (module_name, _) in class_registry.items(): for class_name, (module_name, _) in class_registry.items():
try: try:
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
cls = getattr(module, class_name, None) cls = getattr(module, class_name, None)
logger.debug(f"Checking: {class_name} in module {module_name}") logger.debug(f"Checking: {class_name} in module {module_name}")
logger.debug(f" cls: {True if cls else False}") logger.debug(f" cls: {True if cls else False}")
logger.debug(f" isinstance(cls, type): {isinstance(cls, type)}") logger.debug(f" isinstance(cls, type): {isinstance(cls, type)}")
logger.debug(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}") logger.debug(
logger.debug(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}") f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}"
logger.debug(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}") )
logger.debug(
f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}"
)
logger.debug(
f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}"
)
if ( if (
cls cls
@ -48,13 +55,15 @@ def rebuild_models():
and cls is not AnyAgent and cls is not AnyAgent
): ):
logger.debug(f"Rebuilding {class_name} from {module_name}") logger.debug(f"Rebuilding {class_name} from {module_name}")
from . agents import Agent from .agents import Agent
from . context import Context from .context import Context
cls.model_rebuild() cls.model_rebuild()
except ImportError as e: except ImportError as e:
logger.error(f"Failed to import module {module_name}: {e}") logger.error(f"Failed to import module {module_name}: {e}")
except Exception as e: except Exception as e:
logger.error(f"Error processing {class_name} in {module_name}: {e}") logger.error(f"Error processing {class_name} in {module_name}: {e}")
# Call this after all modules are imported # Call this after all modules are imported
rebuild_models() rebuild_models()

View File

@ -4,19 +4,21 @@ import importlib
import pathlib import pathlib
import inspect import inspect
from . types import agent_registry from .types import agent_registry
from .. setup_logging import setup_logging from ..setup_logging import setup_logging
from .. import defines from .. import defines
from . base import Agent from .base import Agent
logger = setup_logging(defines.logging_level) logger = setup_logging(defines.logging_level)
__all__ = [ "AnyAgent", "Agent", "agent_registry", "class_registry" ] __all__ = ["AnyAgent", "Agent", "agent_registry", "class_registry"]
# Type alias for Agent or any subclass # Type alias for Agent or any subclass
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
class_registry: Dict[str, Tuple[str, str]] = {} # Maps class_name to (module_name, class_name) class_registry: Dict[str, Tuple[str, str]] = (
{}
) # Maps class_name to (module_name, class_name)
package_dir = pathlib.Path(__file__).parent package_dir = pathlib.Path(__file__).parent
package_name = __name__ package_name = __name__
@ -42,7 +44,7 @@ for path in package_dir.glob("*.py"):
class_registry[name] = (full_module_name, name) class_registry[name] = (full_module_name, name)
globals()[name] = obj globals()[name] = obj
logger.info(f"Adding agent: {name}") logger.info(f"Adding agent: {name}")
__all__.append(name) # type: ignore __all__.append(name) # type: ignore
except ImportError as e: except ImportError as e:
logger.error(f"Error importing {full_module_name}: {e}") logger.error(f"Error importing {full_module_name}: {e}")
raise e raise e

View File

@ -1,8 +1,17 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, PrivateAttr, Field # type: ignore from pydantic import BaseModel, PrivateAttr, Field # type: ignore
from typing import ( from typing import (
Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any, Literal,
TypeAlias, Dict, Tuple get_args,
List,
AsyncGenerator,
TYPE_CHECKING,
Optional,
ClassVar,
Any,
TypeAlias,
Dict,
Tuple,
) )
import json import json
import time import time
@ -10,59 +19,66 @@ import inspect
from abc import ABC from abc import ABC
import asyncio import asyncio
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
from .. setup_logging import setup_logging from ..setup_logging import setup_logging
logger = setup_logging() logger = setup_logging()
# Only import Context for type checking # Only import Context for type checking
if TYPE_CHECKING: if TYPE_CHECKING:
from .. context import Context from ..context import Context
from . types import agent_registry from .types import agent_registry
from .. import defines from .. import defines
from .. message import Message, Tunables from ..message import Message, Tunables
from .. metrics import Metrics from ..metrics import Metrics
from .. tools import ( TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools ) # type: ignore -- dynamically added to __all__ from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__
from .. conversation import Conversation from ..conversation import Conversation
class LLMMessage(BaseModel): class LLMMessage(BaseModel):
role : str = Field(default="") role: str = Field(default="")
content : str = Field(default="") content: str = Field(default="")
tool_calls : Optional[List[Dict]] = Field(default={}, exclude=True) tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
class Agent(BaseModel, ABC): class Agent(BaseModel, ABC):
""" """
Base class for all agent types. Base class for all agent types.
This class defines the common attributes and methods for all agent types. This class defines the common attributes and methods for all agent types.
""" """
# Agent management with pydantic # Agent management with pydantic
agent_type: Literal["base"] = "base" agent_type: Literal["base"] = "base"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
# Tunables (sets default for new Messages attached to this agent) # Tunables (sets default for new Messages attached to this agent)
tunables: Tunables = Field(default_factory=Tunables) tunables: Tunables = Field(default_factory=Tunables)
# Agent properties # Agent properties
system_prompt: str # Mandatory system_prompt: str # Mandatory
conversation: Conversation = Conversation() conversation: Conversation = Conversation()
context_tokens: int = 0 context_tokens: int = 0
context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization context: Optional[Context] = Field(
default=None, exclude=True
) # Avoid circular reference, require as param, and prevent serialization
metrics: Metrics = Field(default_factory=Metrics, exclude=True) metrics: Metrics = Field(default_factory=Metrics, exclude=True)
# context_size is shared across all subclasses # context_size is shared across all subclasses
_context_size: ClassVar[int] = int(defines.max_context * 0.5) _context_size: ClassVar[int] = int(defines.max_context * 0.5)
@property @property
def context_size(self) -> int: def context_size(self) -> int:
return Agent._context_size return Agent._context_size
@context_size.setter @context_size.setter
def context_size(self, value: int): def context_size(self, value: int):
Agent._context_size = value Agent._context_size = value
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int: def set_optimal_context_size(
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
) -> int:
# # Get more accurate token count estimate using tiktoken or similar # # Get more accurate token count estimate using tiktoken or similar
# response = llm.generate( # response = llm.generate(
# model=model, # model=model,
@ -73,31 +89,33 @@ class Agent(BaseModel, ABC):
# } # Don't generate any tokens, just tokenize # } # Don't generate any tokens, just tokenize
# ) # )
# # The prompt_eval_count gives you the token count of your input # # The prompt_eval_count gives you the token count of your input
# tokens = response.get("prompt_eval_count", 0) # tokens = response.get("prompt_eval_count", 0)
# Most models average 1.3-1.5 tokens per word # Most models average 1.3-1.5 tokens per word
word_count = len(prompt.split()) word_count = len(prompt.split())
tokens = int(word_count * 1.4) tokens = int(word_count * 1.4)
# Add buffer for safety # Add buffer for safety
total_ctx = tokens + ctx_buffer total_ctx = tokens + ctx_buffer
if total_ctx > self.context_size: if total_ctx > self.context_size:
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}") logger.info(
f"Increasing context size from {self.context_size} to {total_ctx}"
)
# Grow the context size if necessary # Grow the context size if necessary
self.context_size = max(self.context_size, total_ctx) self.context_size = max(self.context_size, total_ctx)
# Use actual model maximum context size # Use actual model maximum context size
return self.context_size return self.context_size
# Class and pydantic model management # Class and pydantic model management
def __init_subclass__(cls, **kwargs) -> None: def __init_subclass__(cls, **kwargs) -> None:
"""Auto-register subclasses""" """Auto-register subclasses"""
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
# Register this class if it has an agent_type # Register this class if it has an agent_type
if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type: if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
agent_registry.register(cls.agent_type, cls) agent_registry.register(cls.agent_type, cls)
# def __init__(self, *, context=context, **data): # def __init__(self, *, context=context, **data):
# super().__init__(**data) # super().__init__(**data)
# self.set_context(context) # self.set_context(context)
@ -118,12 +136,12 @@ class Agent(BaseModel, ABC):
def set_context(self, context: Context): def set_context(self, context: Context):
object.__setattr__(self, "context", context) object.__setattr__(self, "context", context)
# Agent methods # Agent methods
def get_agent_type(self): def get_agent_type(self):
return self._agent_type return self._agent_type
async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
""" """
Prepare message with context information in message.preamble Prepare message with context information in message.preamble
""" """
@ -133,7 +151,7 @@ class Agent(BaseModel, ABC):
with self.metrics.prepare_duration.labels(agent=self.agent_type).time(): with self.metrics.prepare_duration.labels(agent=self.agent_type).time():
if not self.context: if not self.context:
raise ValueError("Context is not set for this agent.") raise ValueError("Context is not set for this agent.")
# Generate RAG content if enabled, based on the content # Generate RAG content if enabled, based on the content
rag_context = "" rag_context = ""
if message.tunables.enable_rag and message.prompt: if message.tunables.enable_rag and message.prompt:
@ -146,7 +164,7 @@ class Agent(BaseModel, ABC):
return return
if message.status != "done": if message.status != "done":
yield message yield message
if "rag" in message.metadata and message.metadata["rag"]: if "rag" in message.metadata and message.metadata["rag"]:
for rag in message.metadata["rag"]: for rag in message.metadata["rag"]:
for doc in rag["documents"]: for doc in rag["documents"]:
@ -159,16 +177,23 @@ class Agent(BaseModel, ABC):
if message.tunables.enable_context and self.context.user_resume: if message.tunables.enable_context and self.context.user_resume:
message.preamble["resume"] = self.context.user_resume message.preamble["resume"] = self.context.user_resume
message.system_prompt = self.system_prompt message.system_prompt = self.system_prompt
message.status = "done" message.status = "done"
yield message yield message
return return
async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[LLMMessage]) -> AsyncGenerator[Message, None]: async def process_tool_calls(
self,
llm: Any,
model: str,
message: Message,
tool_message: Any,
messages: List[LLMMessage],
) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
self.metrics.tool_count.labels(agent=self.agent_type).inc() self.metrics.tool_count.labels(agent=self.agent_type).inc()
with self.metrics.tool_duration.labels(agent=self.agent_type).time(): with self.metrics.tool_duration.labels(agent=self.agent_type).time():
@ -185,12 +210,14 @@ class Agent(BaseModel, ABC):
for i, tool_call in enumerate(tool_message.tool_calls): for i, tool_call in enumerate(tool_message.tool_calls):
arguments = tool_call.function.arguments arguments = tool_call.function.arguments
tool = tool_call.function.name tool = tool_call.function.name
# Yield status update before processing each tool # Yield status update before processing each tool
message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..." message.response = (
f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
)
yield message yield message
logger.info(f"LLM - {message.response}") logger.info(f"LLM - {message.response}")
# Process the tool based on its type # Process the tool based on its type
match tool: match tool:
case "TickerValue": case "TickerValue":
@ -199,40 +226,48 @@ class Agent(BaseModel, ABC):
ret = None ret = None
else: else:
ret = TickerValue(ticker) ret = TickerValue(ticker)
case "AnalyzeSite": case "AnalyzeSite":
url = arguments.get("url") url = arguments.get("url")
question = arguments.get("question", "what is the summary of this content?") question = arguments.get(
"question", "what is the summary of this content?"
)
# Additional status update for long-running operations # Additional status update for long-running operations
message.response = f"Retrieving and summarizing content from {url}..." message.response = (
f"Retrieving and summarizing content from {url}..."
)
yield message yield message
ret = await AnalyzeSite(llm=llm, model=model, url=url, question=question) ret = await AnalyzeSite(
llm=llm, model=model, url=url, question=question
)
case "DateTime": case "DateTime":
tz = arguments.get("timezone") tz = arguments.get("timezone")
ret = DateTime(tz) ret = DateTime(tz)
case "WeatherForecast": case "WeatherForecast":
city = arguments.get("city") city = arguments.get("city")
state = arguments.get("state") state = arguments.get("state")
message.response = f"Fetching weather data for {city}, {state}..." message.response = (
yield message f"Fetching weather data for {city}, {state}..."
)
yield message
ret = WeatherForecast(city, state) ret = WeatherForecast(city, state)
case _: case _:
ret = None ret = None
# Build response for this tool # Build response for this tool
tool_response = { tool_response = {
"role": "tool", "role": "tool",
"content": json.dumps(ret), "content": json.dumps(ret),
"name": tool_call.function.name "name": tool_call.function.name,
} }
tool_metadata["tool_calls"].append(tool_response) tool_metadata["tool_calls"].append(tool_response)
if len(tool_metadata["tool_calls"]) == 0: if len(tool_metadata["tool_calls"]) == 0:
message.status = "done" message.status = "done"
yield message yield message
@ -241,13 +276,15 @@ class Agent(BaseModel, ABC):
message_dict = LLMMessage( message_dict = LLMMessage(
role=tool_message.get("role", "assistant"), role=tool_message.get("role", "assistant"),
content=tool_message.get("content", ""), content=tool_message.get("content", ""),
tool_calls=[ { tool_calls=[
"function": { {
"name": tc["function"]["name"], "function": {
"arguments": tc["function"]["arguments"] "name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
} }
} for tc in tool_message.tool_calls for tc in tool_message.tool_calls
] ],
) )
messages.append(message_dict) messages.append(message_dict)
@ -261,8 +298,8 @@ class Agent(BaseModel, ABC):
message.response = "" message.response = ""
start_time = time.perf_counter() start_time = time.perf_counter()
for response in llm.chat( for response in llm.chat(
model=model, model=model,
messages=messages, messages=messages,
options={ options={
**message.metadata["options"], **message.metadata["options"],
# "temperature": 0.5, # "temperature": 0.5,
@ -282,45 +319,62 @@ class Agent(BaseModel, ABC):
message.metadata["eval_count"] += response.eval_count message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata[
"prompt_eval_duration"
] += response.prompt_eval_duration
self.context_tokens = (
response.prompt_eval_count + response.eval_count
)
message.status = "done" message.status = "done"
yield message yield message
end_time = time.perf_counter() end_time = time.perf_counter()
message.metadata["timers"][
"llm_with_tools"
] = f"{(end_time - start_time):.4f}"
return return
def collect_metrics(self, response): def collect_metrics(self, response):
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
response.prompt_eval_count
)
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count) self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
async def generate_llm_response(
self, llm: Any, model: str, message: Message, temperature=0.7
) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
self.metrics.generate_count.labels(agent=self.agent_type).inc() self.metrics.generate_count.labels(agent=self.agent_type).inc()
with self.metrics.generate_duration.labels(agent=self.agent_type).time(): with self.metrics.generate_duration.labels(agent=self.agent_type).time():
if not self.context: if not self.context:
raise ValueError("Context is not set for this agent.") raise ValueError("Context is not set for this agent.")
# Create a pruned down message list based purely on the prompt and responses, # Create a pruned down message list based purely on the prompt and responses,
# discarding the full preamble generated by prepare_message # discarding the full preamble generated by prepare_message
messages: List[LLMMessage] = [
LLMMessage(role="system", content=message.system_prompt)
]
messages.extend(
[
item
for m in self.conversation
for item in [
LLMMessage(role="user", content=m.prompt.strip()),
LLMMessage(role="assistant", content=m.response.strip()),
]
]
)
# Only the actual user query is provided with the full context message # Only the actual user query is provided with the full context message
messages.append(
LLMMessage(role="user", content=message.context_prompt.strip())
)
#message.metadata["messages"] = messages # message.metadata["messages"] = messages
message.metadata["options"]={ message.metadata["options"] = {
"seed": 8911, "seed": 8911,
"num_ctx": self.context_size, "num_ctx": self.context_size,
"temperature": temperature # Higher temperature to encourage tool usage "temperature": temperature, # Higher temperature to encourage tool usage
} }
# Create a dict for storing various timing stats # Create a dict for storing various timing stats
@ -329,7 +383,7 @@ class Agent(BaseModel, ABC):
use_tools = message.tunables.enable_tools and len(self.context.tools) > 0 use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
message.metadata["tools"] = { message.metadata["tools"] = {
"available": llm_tools(self.context.tools), "available": llm_tools(self.context.tools),
"used": False "used": False,
} }
tool_metadata = message.metadata["tools"] tool_metadata = message.metadata["tools"]
@ -355,14 +409,16 @@ class Agent(BaseModel, ABC):
tools=tool_metadata["available"], tools=tool_metadata["available"],
options={ options={
**message.metadata["options"], **message.metadata["options"],
#"num_predict": 1024, # "Low" token limit to cut off after tool call # "num_predict": 1024, # "Low" token limit to cut off after tool call
}, },
stream=False # No need to stream the probe stream=False, # No need to stream the probe
) )
self.collect_metrics(response) self.collect_metrics(response)
end_time = time.perf_counter() end_time = time.perf_counter()
message.metadata["timers"][
"tool_check"
] = f"{(end_time - start_time):.4f}"
if not response.message.tool_calls: if not response.message.tool_calls:
logger.info("LLM indicates tools will not be used") logger.info("LLM indicates tools will not be used")
# The LLM will not use tools, so disable use_tools so we can stream the full response # The LLM will not use tools, so disable use_tools so we can stream the full response
@ -374,7 +430,9 @@ class Agent(BaseModel, ABC):
logger.info("LLM indicates tools will be used") logger.info("LLM indicates tools will be used")
# Tools are enabled and available and the LLM indicated it will use them # Tools are enabled and available and the LLM indicated it will use them
message.response = (
f"Performing tool analysis step 2/2 (tool use suspected)..."
)
yield message yield message
logger.info(f"Performing LLM call with tools") logger.info(f"Performing LLM call with tools")
@ -384,14 +442,16 @@ class Agent(BaseModel, ABC):
messages=tool_metadata["messages"], # messages, messages=tool_metadata["messages"], # messages,
tools=tool_metadata["available"], tools=tool_metadata["available"],
options={ options={
**message.metadata["options"], **message.metadata["options"],
}, },
stream=False stream=False,
) )
self.collect_metrics(response) self.collect_metrics(response)
end_time = time.perf_counter() end_time = time.perf_counter()
message.metadata["timers"][
"non_streaming"
] = f"{(end_time - start_time):.4f}"
if not response: if not response:
message.status = "error" message.status = "error"
@ -403,16 +463,24 @@ class Agent(BaseModel, ABC):
tool_metadata["used"] = response.message.tool_calls tool_metadata["used"] = response.message.tool_calls
# Process all yielded items from the handler # Process all yielded items from the handler
start_time = time.perf_counter() start_time = time.perf_counter()
async for message in self.process_tool_calls(
llm=llm,
model=model,
message=message,
tool_message=response.message,
messages=messages,
):
if message.status == "error": if message.status == "error":
yield message yield message
return return
yield message yield message
end_time = time.perf_counter() end_time = time.perf_counter()
message.metadata["timers"][
"process_tool_calls"
] = f"{(end_time - start_time):.4f}"
message.status = "done" message.status = "done"
return return
logger.info("LLM indicated tools will be used, and then they weren't") logger.info("LLM indicated tools will be used, and then they weren't")
message.response = response.message.content message.response = response.message.content
message.status = "done" message.status = "done"
@ -427,8 +495,8 @@ class Agent(BaseModel, ABC):
message.response = "" message.response = ""
start_time = time.perf_counter() start_time = time.perf_counter()
for response in llm.chat( for response in llm.chat(
model=model, model=model,
messages=messages, messages=messages,
options={ options={
**message.metadata["options"], **message.metadata["options"],
}, },
@ -452,16 +520,22 @@ class Agent(BaseModel, ABC):
message.metadata["eval_count"] += response.eval_count message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata[
"prompt_eval_duration"
] += response.prompt_eval_duration
self.context_tokens = (
response.prompt_eval_count + response.eval_count
)
message.status = "done" message.status = "done"
yield message yield message
end_time = time.perf_counter() end_time = time.perf_counter()
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}" message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
return return
async def process_message(
self, llm: Any, model: str, message: Message
) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
self.metrics.process_count.labels(agent=self.agent_type).inc() self.metrics.process_count.labels(agent=self.agent_type).inc()
@ -469,23 +543,31 @@ class Agent(BaseModel, ABC):
if not self.context: if not self.context:
raise ValueError("Context is not set for this agent.") raise ValueError("Context is not set for this agent.")
logger.info(
"TODO: Implement delay queueing; busy for same agent, otherwise return queue size and estimated wait time"
)
spinner: List[str] = ["\\", "|", "/", "-"]
tick: int = 0
while self.context.processing: while self.context.processing:
message.status = "waiting" message.status = "waiting"
message.response = (
f"Busy processing another request. Please wait. {spinner[tick]}"
)
tick = (tick + 1) % len(spinner) tick = (tick + 1) % len(spinner)
yield message yield message
await asyncio.sleep(1) # Allow the event loop to process the write await asyncio.sleep(1) # Allow the event loop to process the write
self.context.processing = True self.context.processing = True
message.metadata["system_prompt"] = (
f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
)
message.context_prompt = "" message.context_prompt = ""
for p in message.preamble.keys(): for p in message.preamble.keys():
message.context_prompt += (
f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
)
message.context_prompt += f"{message.prompt}" message.context_prompt += f"{message.prompt}"
# Estimate token length of new messages # Estimate token length of new messages
@ -493,20 +575,24 @@ class Agent(BaseModel, ABC):
message.status = "thinking" message.status = "thinking"
yield message yield message
message.metadata["context_size"] = self.set_optimal_context_size(
llm, model, prompt=message.context_prompt
)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..." message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.status = "thinking" message.status = "thinking"
yield message yield message
async for message in self.generate_llm_response(
llm=llm, model=model, message=message
):
# logger.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}") # logger.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
if message.status == "error": if message.status == "error":
yield message yield message
self.context.processing = False self.context.processing = False
return return
yield message yield message
# Done processing, add message to conversation # Done processing, add message to conversation
message.status = "done" message.status = "done"
self.conversation.add(message) self.conversation.add(message)
@ -514,6 +600,6 @@ class Agent(BaseModel, ABC):
return return
# Register the base agent # Register the base agent
agent_registry.register(Agent._agent_type, Agent) agent_registry.register(Agent._agent_type, Agent)
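Illustrative sketch (not part of this commit): draining the async generator exposed by process_message above. The agent, llm, model, and message objects are assumed to be constructed elsewhere.

import asyncio

async def drain(agent, llm, model, message):
    # Each yielded Message carries a status such as "waiting", "thinking",
    # "done", or "error", plus a (possibly partial) response string.
    async for update in agent.process_message(llm=llm, model=model, message=message):
        print(update.status, update.response)
        if update.status in ("done", "error"):
            break

# asyncio.run(drain(agent, llm, model, message))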
View File
@ -3,9 +3,10 @@ from typing import Literal, AsyncGenerator, ClassVar, Optional, Any
from datetime import datetime from datetime import datetime
import inspect import inspect
from . base import Agent, agent_registry from .base import Agent, agent_registry
from .. message import Message from ..message import Message
from .. setup_logging import setup_logging from ..setup_logging import setup_logging
logger = setup_logging() logger = setup_logging()
system_message = f""" system_message = f"""
@ -26,35 +27,42 @@ When answering queries, follow these steps:
Always use tools, <|resume|>, and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so. Always use tools, <|resume|>, and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so.
""" """
class Chat(Agent): class Chat(Agent):
""" """
Chat Agent Chat Agent
""" """
agent_type: Literal["chat"] = "chat"  # type: ignore
_agent_type: ClassVar[str] = agent_type  # Add this for registration

system_prompt: str = system_message

async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")

async for message in super().prepare_message(message):
if message.status != "done":
yield message

if message.preamble:
excluded = {}
preamble_types = [
f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
]
preamble_types_AND = " and ".join(preamble_types)
preamble_types_OR = " or ".join(preamble_types)
message.preamble[
"rules"
] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
message.preamble["question"] = "Respond to:" message.preamble["question"] = "Respond to:"
# Register the base agent # Register the base agent
agent_registry.register(Chat._agent_type, Chat) agent_registry.register(Chat._agent_type, Chat)
View File
@ -1,13 +1,21 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator # type: ignore from pydantic import model_validator # type: ignore
from typing import (
Literal,
ClassVar,
Optional,
Any,
AsyncGenerator,
List,
)  # NOTE: You must import Optional for late binding to work
from datetime import datetime from datetime import datetime
import inspect import inspect
from . base import Agent, agent_registry from .base import Agent, agent_registry
from .. conversation import Conversation from ..conversation import Conversation
from .. message import Message from ..message import Message
from .. setup_logging import setup_logging from ..setup_logging import setup_logging
logger = setup_logging() logger = setup_logging()
system_fact_check = f""" system_fact_check = f"""
@ -21,50 +29,56 @@ When answering queries, follow these steps:
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|generated-resume|>, or <|resume|> tags. - Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|generated-resume|>, or <|resume|> tags.
""".strip() """.strip()
class FactCheck(Agent): class FactCheck(Agent):
agent_type: Literal["fact_check"] = "fact_check" # type: ignore agent_type: Literal["fact_check"] = "fact_check" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt:str = system_fact_check system_prompt: str = system_fact_check
facts: str facts: str
@model_validator(mode="after") @model_validator(mode="after")
def validate_facts(self): def validate_facts(self):
if not self.facts.strip(): if not self.facts.strip():
raise ValueError("Facts cannot be empty") raise ValueError("Facts cannot be empty")
return self return self
async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
if not self.context: if not self.context:
raise ValueError("Context is not set for this agent.") raise ValueError("Context is not set for this agent.")
resume_agent = self.context.get_agent("resume") resume_agent = self.context.get_agent("resume")
if not resume_agent: if not resume_agent:
raise ValueError("resume agent does not exist") raise ValueError("resume agent does not exist")
message.tunables.enable_tools = False message.tunables.enable_tools = False
async for message in super().prepare_message(message): async for message in super().prepare_message(message):
if message.status != "done": if message.status != "done":
yield message yield message
message.preamble["generated-resume"] = resume_agent.resume message.preamble["generated-resume"] = resume_agent.resume
message.preamble["discrepancies"] = self.facts message.preamble["discrepancies"] = self.facts
excluded = {"job_description"} excluded = {"job_description"}
preamble_types = [
f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
]
preamble_types_AND = " and ".join(preamble_types)
preamble_types_OR = " or ".join(preamble_types)
message.preamble[
"rules"
] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
message.preamble["question"] = "Respond to:"

yield message
return
# Register the base agent # Register the base agent
agent_registry.register(FactCheck._agent_type, FactCheck) agent_registry.register(FactCheck._agent_type, FactCheck)
File diff suppressed because it is too large

View File
@ -1,12 +1,20 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator # type: ignore from pydantic import model_validator # type: ignore
from typing import (
Literal,
ClassVar,
Optional,
Any,
AsyncGenerator,
List,
)  # NOTE: You must import Optional for late binding to work
from datetime import datetime from datetime import datetime
import inspect import inspect
from . base import Agent, agent_registry from .base import Agent, agent_registry
from .. message import Message from ..message import Message
from .. setup_logging import setup_logging from ..setup_logging import setup_logging
logger = setup_logging() logger = setup_logging()
system_fact_check = f""" system_fact_check = f"""
@ -36,82 +44,94 @@ When answering queries, follow these steps:
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|job_description|>, <|resume|>, or <|context|> tags. - Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|job_description|>, <|resume|>, or <|context|> tags.
""".strip() """.strip()
class Resume(Agent): class Resume(Agent):
agent_type: Literal["resume"] = "resume" # type: ignore agent_type: Literal["resume"] = "resume" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt:str = system_fact_check system_prompt: str = system_fact_check
resume: str resume: str
@model_validator(mode="after") @model_validator(mode="after")
def validate_resume(self): def validate_resume(self):
if not self.resume.strip(): if not self.resume.strip():
raise ValueError("Resume content cannot be empty") raise ValueError("Resume content cannot be empty")
return self return self
async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
if not self.context: if not self.context:
raise ValueError("Context is not set for this agent.") raise ValueError("Context is not set for this agent.")
# Generating fact check or resume should not use any tools
message.tunables.enable_tools = False

async for message in super().prepare_message(message):
if message.status != "done":
yield message

message.preamble["generated-resume"] = self.resume
job_description_agent = self.context.get_agent("job_description")
if not job_description_agent:
raise ValueError("job_description agent does not exist")

message.preamble["job_description"] = job_description_agent.job_description
excluded = {}
preamble_types = [
f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
]
preamble_types_AND = " and ".join(preamble_types)
preamble_types_OR = " or ".join(preamble_types)
message.preamble[
"rules"
] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
fact_check_agent = self.context.get_agent(agent_type="fact_check")
if fact_check_agent:
message.preamble["question"] = "Respond to:"
else:
message.preamble["question"] = (
f"Fact check the <|generated-resume|> based on the <|resume|>{' and <|context|>' if 'context' in message.preamble else ''}."
)
yield message
return
async def process_message(
self, llm: Any, model: str, message: Message
) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")

async for message in super().process_message(llm, model, message):
if message.status != "done":
yield message

fact_check_agent = self.context.get_agent(agent_type="fact_check")
if not fact_check_agent:
# Switch agent from "Fact Check Generated Resume" mode
# to "Answer Questions about Generated Resume"
self.system_prompt = system_resume

# Instantiate the "resume" agent, and seed (or reset) its conversation
# with this message.
fact_check_agent = self.context.get_or_create_agent(
agent_type="fact_check", facts=message.response
)
first_fact_check_message = message.copy()
first_fact_check_message.prompt = "Fact check the generated resume."
fact_check_agent.conversation.add(first_fact_check_message)
message.response = "Resume fact checked."

# Return the final message
yield message
return
# Register the base agent # Register the base agent
agent_registry.register(Resume._agent_type, Resume) agent_registry.register(Resume._agent_type, Resume)

View File

@ -1,31 +1,34 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Dict, Optional, Type from typing import List, Dict, Optional, Type
# We'll use a registry pattern rather than hardcoded strings # We'll use a registry pattern rather than hardcoded strings
class AgentRegistry: class AgentRegistry:
"""Registry for agent types and classes""" """Registry for agent types and classes"""
_registry: Dict[str, Type] = {}

@classmethod
def register(cls, agent_type: str, agent_class: Type) -> Type:
"""Register an agent class with its type"""
cls._registry[agent_type] = agent_class
return agent_class

@classmethod
def get_class(cls, agent_type: str) -> Optional[Type]:
"""Get the class for a given agent type"""
return cls._registry.get(agent_type)

@classmethod
def get_types(cls) -> List[str]:
"""Get all registered agent types"""
return list(cls._registry.keys())

@classmethod
def get_classes(cls) -> Dict[str, Type]:
"""Get all registered agent classes"""
return cls._registry.copy()
return cls._registry.copy() """Get all registered agent classes"""
return cls._registry.copy()
# Create a singleton instance # Create a singleton instance
agent_registry = AgentRegistry() agent_registry = AgentRegistry()
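Illustrative sketch (not part of this commit): exercising the registry singleton defined above. PlaceholderAgent is a hypothetical class used only for this example.

class PlaceholderAgent:
    pass

agent_registry.register("placeholder", PlaceholderAgent)
assert agent_registry.get_class("placeholder") is PlaceholderAgent
assert "placeholder" in agent_registry.get_types()
assert "placeholder" in agent_registry.get_classes()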

View File

@ -1,122 +0,0 @@
import chromadb
from typing import List, Dict, Any, Union
from . import defines
from .chunk import chunk_document
import ollama
def init_chroma_client(persist_directory: str = defines.persist_directory):
"""Initialize and return a ChromaDB client."""
# return chromadb.PersistentClient(path=persist_directory)
return chromadb.Client()
def create_or_get_collection(db: chromadb.Client, collection_name: str):
"""Create or get a ChromaDB collection."""
try:
return db.get_collection(
name=collection_name
)
except:
return db.create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"}
)
def process_documents_to_chroma(
client: ollama.Client,
documents: List[Dict[str, Any]],
collection_name: str = "document_collection",
text_key: str = "text",
max_tokens: int = 512,
overlap: int = 50,
model: str = defines.encoding_model,
persist_directory: str = defines.persist_directory
):
"""
Process documents, chunk them, compute embeddings, and store in ChromaDB.
Args:
documents: List of document dictionaries
collection_name: Name for the ChromaDB collection
text_key: The key containing text content
max_tokens: Maximum tokens per chunk
overlap: Token overlap between chunks
model: Ollama model for embeddings
persist_directory: Directory to store ChromaDB data
"""
# Initialize ChromaDB client and collection
db = init_chroma_client(persist_directory)
collection = create_or_get_collection(db, collection_name)
# Process each document
for doc in documents:
# Chunk the document
doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)
# Prepare data for ChromaDB
ids = []
texts = []
metadatas = []
embeddings = []
for chunk in doc_chunks:
# Create a unique ID for the chunk
chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"
# Extract text
text = chunk[text_key]
# Create metadata (excluding text and embedding to avoid duplication)
metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}
response = client.embed(model=model, input=text)
embedding = response["embeddings"][0]
ids.append(chunk_id)
texts.append(text)
metadatas.append(metadata)
embeddings.append(embedding)
# Add chunks to ChromaDB collection
collection.add(
ids=ids,
documents=texts,
embeddings=embeddings,
metadatas=metadatas
)
return collection
def query_chroma(
client: ollama.Client,
query_text: str,
collection_name: str = "document_collection",
n_results: int = 5,
model: str = defines.encoding_model,
persist_directory: str = defines.persist_directory
):
"""
Query ChromaDB for similar documents.
Args:
query_text: The text to search for
collection_name: Name of the ChromaDB collection
n_results: Number of results to return
model: Ollama model for embedding the query
persist_directory: Directory where ChromaDB data is stored
Returns:
Query results from ChromaDB
"""
# Initialize ChromaDB client and collection
db = init_chroma_client(persist_directory)
collection = create_or_get_collection(db, collection_name)
query_response = client.embed(model=model, input=query_text)
query_embeddings = query_response["embeddings"]
# Query the collection
results = collection.query(
query_embeddings=query_embeddings,
n_results=n_results
)
return results
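Illustrative sketch (not part of this commit, written against the deleted helpers above): indexing a toy document and querying it back. The document contents and collection name are made up for the example.

import ollama

client = ollama.Client(host="http://ollama:11434")
documents = [{"id": "doc1", "title": "Example", "text": "Some text to index."}]

collection = process_documents_to_chroma(
    client=client,
    documents=documents,
    collection_name="example_collection",
)
results = query_chroma(
    client=client,
    query_text="What does the example text say?",
    collection_name="example_collection",
    n_results=1,
)
print(results["documents"])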

View File

@ -1,88 +0,0 @@
import tiktoken # type: ignore
from . import defines
from typing import List, Dict, Any, Union
def get_encoding(model=defines.model):
"""Get the tokenizer for counting tokens."""
try:
return tiktoken.get_encoding("cl100k_base") # Default encoding used by many embedding models
except:
return tiktoken.encoding_for_model(model)
def count_tokens(text: str) -> int:
"""Count the number of tokens in a text string."""
encoding = get_encoding()
return len(encoding.encode(text))
def chunk_text(text: str, max_tokens: int = 512, overlap: int = 50) -> List[str]:
"""
Split a text into chunks based on token count with overlap between chunks.
Args:
text: The text to split into chunks
max_tokens: Maximum number of tokens per chunk
overlap: Number of tokens to overlap between chunks
Returns:
List of text chunks
"""
if not text or max_tokens <= 0:
return []
encoding = get_encoding()
tokens = encoding.encode(text)
chunks = []
i = 0
while i < len(tokens):
# Get the current chunk of tokens
chunk_end = min(i + max_tokens, len(tokens))
chunk_tokens = tokens[i:chunk_end]
chunks.append(encoding.decode(chunk_tokens))
# Move to the next position with overlap
if chunk_end == len(tokens):
break
i += max_tokens - overlap
return chunks
def chunk_document(document: Dict[str, Any],
text_key: str = "text",
max_tokens: int = 512,
overlap: int = 50) -> List[Dict[str, Any]]:
"""
Chunk a document dictionary into multiple chunks.
Args:
document: Document dictionary with metadata and text
text_key: The key in the document that contains the text to chunk
max_tokens: Maximum number of tokens per chunk
overlap: Number of tokens to overlap between chunks
Returns:
List of document dictionaries, each with chunked text and preserved metadata
"""
if text_key not in document:
raise Exception(f"{text_key} not in document")
# Extract text and create chunks
if "title" in document:
text = f"{document['title']}: {document[text_key]}"
else:
text = document[text_key]
chunks = chunk_text(text, max_tokens, overlap)
# Create document chunks with preserved metadata
chunked_docs = []
for i, chunk in enumerate(chunks):
# Create a new doc with all original fields
doc_chunk = document.copy()
# Replace text with the chunk
doc_chunk[text_key] = chunk
# Add chunk metadata
doc_chunk["chunk_id"] = i
doc_chunk["chunk_total"] = len(chunks)
chunked_docs.append(doc_chunk)
return chunked_docs
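Illustrative sketch (not part of this commit): chunking a toy document with the deleted helpers above; the token sizes are arbitrary.

doc = {"id": "a1", "title": "Sample", "text": "word " * 2000}
pieces = chunk_document(doc, text_key="text", max_tokens=128, overlap=16)
print(len(pieces), pieces[0]["chunk_id"], pieces[0]["chunk_total"])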

View File

@ -1,32 +1,35 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field, model_validator# type: ignore from pydantic import BaseModel, Field, model_validator # type: ignore
from uuid import uuid4 from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any from typing import List, Optional, Generator, ClassVar, Any
from typing_extensions import Annotated, Union from typing_extensions import Annotated, Union
import numpy as np # type: ignore import numpy as np # type: ignore
import logging import logging
from uuid import uuid4 from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore from prometheus_client import CollectorRegistry, Counter # type: ignore
from . message import Message, Tunables from .message import Message, Tunables
from . rag import ChromaDBFileWatcher from .rag import ChromaDBFileWatcher
from . import defines from . import defines
from . import tools as Tools from . import tools as Tools
from . agents import AnyAgent from .agents import AnyAgent
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Context(BaseModel): class Context(BaseModel):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
# Required fields # Required fields
file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True) file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
prometheus_collector: Optional[CollectorRegistry] = Field(
default=None, exclude=True
)
# Optional fields # Optional fields
id: str = Field( id: str = Field(
default_factory=lambda: str(uuid4()), default_factory=lambda: str(uuid4()),
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
) )
user_resume: Optional[str] = None user_resume: Optional[str] = None
user_job_description: Optional[str] = None user_job_description: Optional[str] = None
@ -35,10 +38,10 @@ class Context(BaseModel):
rags: List[dict] = [] rags: List[dict] = []
message_history_length: int = 5 message_history_length: int = 5
# Class managed fields # Class managed fields
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
default_factory=list default_factory=list
) )
processing: bool = Field(default=False, exclude=True) processing: bool = Field(default=False, exclude=True)
# @model_validator(mode="before") # @model_validator(mode="before")
@ -53,25 +56,29 @@ class Context(BaseModel):
logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.") logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
agent_types = [agent.agent_type for agent in self.agents] agent_types = [agent.agent_type for agent in self.agents]
if len(agent_types) != len(set(agent_types)): if len(agent_types) != len(set(agent_types)):
raise ValueError(
"Context cannot contain multiple agents of the same agent_type"
)
# for agent in self.agents: # for agent in self.agents:
# agent.set_context(self) # agent.set_context(self)
return self return self
def generate_rag_results(
self, message: Message, top_k=10, threshold=0.7
) -> Generator[Message, None, None]:
""" """
Generate RAG results for the given query. Generate RAG results for the given query.
Args: Args:
query: The query string to generate RAG results for. query: The query string to generate RAG results for.
Returns: Returns:
A list of dictionaries containing the RAG results. A list of dictionaries containing the RAG results.
""" """
try: try:
message.status = "processing" message.status = "processing"
entries : int = 0 entries: int = 0
if not self.file_watcher: if not self.file_watcher:
message.response = "No RAG context available." message.response = "No RAG context available."
@ -86,53 +93,69 @@ class Context(BaseModel):
continue continue
message.response = f"Checking RAG context {rag['name']}..." message.response = f"Checking RAG context {rag['name']}..."
yield message yield message
chroma_results = self.file_watcher.find_similar(
query=message.prompt, top_k=top_k, threshold=threshold
)
if chroma_results:
entries += len(chroma_results["documents"])

chroma_embedding = np.array(
chroma_results["query_embedding"]
).flatten()  # Ensure correct shape
print(f"Chroma embedding shape: {chroma_embedding.shape}")

umap_2d = self.file_watcher.umap_model_2d.transform(
[chroma_embedding]
)[0].tolist()
print(
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
)  # Debug output

umap_3d = self.file_watcher.umap_model_3d.transform(
[chroma_embedding]
)[0].tolist()
print(
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
)  # Debug output

message.metadata["rag"].append(
{
"name": rag["name"],
**chroma_results,
"umap_embedding_2d": umap_2d,
"umap_embedding_3d": umap_3d,
}
)
message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results." message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
yield message yield message
if entries == 0: if entries == 0:
del message.metadata["rag"] del message.metadata["rag"]
message.response = (
f"RAG context gathered from results from {entries} documents."
)
message.status = "done" message.status = "done"
yield message yield message
return return
except Exception as e: except Exception as e:
message.status = "error" message.status = "error"
message.response = f"Error generating RAG results: {str(e)}" message.response = f"Error generating RAG results: {str(e)}"
logger.error(e) logger.error(e)
yield message yield message
return return
def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent: def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
""" """
Get or create and append a new agent of the specified type, ensuring only one agent per type exists. Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
Args: Args:
agent_type: The type of agent to create (e.g., 'web', 'database'). agent_type: The type of agent to create (e.g., 'web', 'database').
**kwargs: Additional fields required by the specific agent subclass. **kwargs: Additional fields required by the specific agent subclass.
Returns: Returns:
The created agent instance. The created agent instance.
Raises: Raises:
ValueError: If no matching agent type is found or if a agent of this type already exists. ValueError: If no matching agent type is found or if a agent of this type already exists.
""" """
@ -140,7 +163,7 @@ class Context(BaseModel):
for agent in self.agents: for agent in self.agents:
if agent.agent_type == agent_type: if agent.agent_type == agent_type:
return agent return agent
# Find the matching subclass # Find the matching subclass
for agent_cls in Agent.__subclasses__(): for agent_cls in Agent.__subclasses__():
if agent_cls.model_fields["agent_type"].default == agent_type: if agent_cls.model_fields["agent_type"].default == agent_type:
@ -150,13 +173,15 @@ class Context(BaseModel):
agent.set_context(self) agent.set_context(self)
self.agents.append(agent) self.agents.append(agent)
return agent return agent
raise ValueError(f"No agent class found for agent_type: {agent_type}") raise ValueError(f"No agent class found for agent_type: {agent_type}")
def add_agent(self, agent: AnyAgent) -> None: def add_agent(self, agent: AnyAgent) -> None:
"""Add a Agent to the context, ensuring no duplicate agent_type.""" """Add a Agent to the context, ensuring no duplicate agent_type."""
if any(s.agent_type == agent.agent_type for s in self.agents): if any(s.agent_type == agent.agent_type for s in self.agents):
raise ValueError(
f"A agent with agent_type '{agent.agent_type}' already exists"
)
self.agents.append(agent) self.agents.append(agent)
def get_agent(self, agent_type: str) -> Agent | None: def get_agent(self, agent_type: str) -> Agent | None:
@ -188,5 +213,7 @@ class Context(BaseModel):
summary += f"\nChat Name: {agent.name}\n" summary += f"\nChat Name: {agent.name}\n"
return summary return summary
from .agents import Agent

Context.model_rebuild()
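Illustrative sketch (not part of this commit): creating a Context, attaching an agent, and iterating RAG results. The constructor arguments and the message object are assumptions for the example, not a verified call.

context = Context(user_resume="...")
chat_agent = context.get_or_create_agent("chat")

for update in context.generate_rag_results(message, top_k=5, threshold=0.6):
    print(update.status, update.response)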

View File

@ -1,7 +1,8 @@
from pydantic import BaseModel, Field, PrivateAttr # type: ignore from pydantic import BaseModel, Field, PrivateAttr # type: ignore
from typing import List from typing import List
from .message import Message from .message import Message
class Conversation(BaseModel): class Conversation(BaseModel):
Conversation_messages: List[Message] = Field(default=[], alias="messages") Conversation_messages: List[Message] = Field(default=[], alias="messages")
@ -12,24 +13,28 @@ class Conversation(BaseModel):
return iter(self.Conversation_messages) return iter(self.Conversation_messages)
def reset(self): def reset(self):
self.Conversation_messages = [] self.Conversation_messages = []
@property @property
def messages(self): def messages(self):
"""Return a copy of messages to prevent modification of the internal list.""" """Return a copy of messages to prevent modification of the internal list."""
raise AttributeError(
"Cannot directly get messages. Use Conversation.add() or .reset()"
)
@messages.setter @messages.setter
def messages(self, value): def messages(self, value):
"""Control how messages can be set, or prevent setting altogether.""" """Control how messages can be set, or prevent setting altogether."""
raise AttributeError(
"Cannot directly set messages. Use Conversation.add() or .reset()"
)
def add(self, message: Message | List[Message]) -> None: def add(self, message: Message | List[Message]) -> None:
"""Add a Message(s) to the conversation.""" """Add a Message(s) to the conversation."""
if isinstance(message, Message): if isinstance(message, Message):
self.Conversation_messages.append(message) self.Conversation_messages.append(message)
else: else:
self.Conversation_messages.extend(message) self.Conversation_messages.extend(message)
def get_summary(self) -> str: def get_summary(self) -> str:
"""Return a summary of the conversation.""" """Return a summary of the conversation."""
@ -38,4 +43,4 @@ class Conversation(BaseModel):
summary = f"Conversation:\n" summary = f"Conversation:\n"
for i, message in enumerate(self.Conversation_messages, 1): for i, message in enumerate(self.Conversation_messages, 1):
summary += f"\nMessage {i}:\n{message.get_summary()}\n" summary += f"\nMessage {i}:\n{message.get_summary()}\n"
return summary return summary
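Illustrative sketch (not part of this commit): basic Conversation usage; the Message constructor arguments here are simplified assumptions.

conversation = Conversation()
conversation.add(Message(prompt="Hello", response="Hi there."))
print(conversation.get_summary())
conversation.reset()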

View File

@ -1,15 +1,15 @@
import os import os
ollama_api_url="http://ollama:11434" # Default Ollama local endpoint ollama_api_url = "http://ollama:11434" # Default Ollama local endpoint
#model = "deepseek-r1:7b" # Tool calls don"t work # model = "deepseek-r1:7b" # Tool calls don"t work
#model="mistral:7b" # Tool calls don"t work # model="mistral:7b" # Tool calls don"t work
#model = "llama3.2" # model = "llama3.2"
#model = "qwen3:8b" # Requires newer ollama # model = "qwen3:8b" # Requires newer ollama
#model = "gemma3:4b" # Requires newer ollama # model = "gemma3:4b" # Requires newer ollama
model = os.getenv("MODEL_NAME", "qwen2.5:7b") model = os.getenv("MODEL_NAME", "qwen2.5:7b")
embedding_model = os.getenv("EMBEDDING_MODEL_NAME", "mxbai-embed-large") embedding_model = os.getenv("EMBEDDING_MODEL_NAME", "mxbai-embed-large")
persist_directory = os.getenv("PERSIST_DIR", "/opt/backstory/chromadb") persist_directory = os.getenv("PERSIST_DIR", "/opt/backstory/chromadb")
max_context = 2048*8*2 max_context = 2048 * 8 * 2
doc_dir = "/opt/backstory/docs/" doc_dir = "/opt/backstory/docs/"
context_dir = "/opt/backstory/sessions" context_dir = "/opt/backstory/sessions"
static_content = "/opt/backstory/frontend/deployed" static_content = "/opt/backstory/frontend/deployed"
@ -17,4 +17,4 @@ resume_doc = "/opt/backstory/docs/resume/resume.md"
# Only used for testing; backstory-prod will not use this # Only used for testing; backstory-prod will not use this
key_path = "/opt/backstory/keys/key.pem" key_path = "/opt/backstory/keys/key.pem"
cert_path = "/opt/backstory/keys/cert.pem" cert_path = "/opt/backstory/keys/cert.pem"
logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper() logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()

View File

@ -1,468 +0,0 @@
import requests
from typing import List, Dict, Any, Union
import tiktoken
import feedparser
import logging as log
import datetime
from bs4 import BeautifulSoup
import chromadb
import ollama
import re
import numpy as np
from . import chunk
OLLAMA_API_URL = "http://ollama:11434" # Default Ollama local endpoint
#MODEL_NAME = "deepseek-r1:1.5b"
MODEL_NAME = "deepseek-r1:7b"
EMBED_MODEL = "mxbai-embed-large"
PERSIST_DIRECTORY = "/root/.cache/chroma"
client = ollama.Client(host=OLLAMA_API_URL)
def extract_text_from_html_or_xml(content, is_xml=False):
# Parse the content
if is_xml:
soup = BeautifulSoup(content, 'xml') # Use 'xml' parser for XML content
else:
soup = BeautifulSoup(content, 'html.parser') # Default to 'html.parser' for HTML content
# Extract and return just the text
return soup.get_text()
class Feed():
def __init__(self, name, url, poll_limit_min = 30, max_articles=5):
self.name = name
self.url = url
self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
self.last_poll = None
self.articles = []
self.max_articles = max_articles
self.update()
def update(self):
now = datetime.datetime.now()
if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
log.info(f"Updating {self.name}")
feed = feedparser.parse(self.url)
self.articles = []
self.last_poll = now
if len(feed.entries) == 0:
return
for i, entry in enumerate(feed.entries[:self.max_articles]):
content = {}
content['source'] = self.name
content['id'] = f"{self.name}{i}"
title = entry.get("title")
if title:
content['title'] = title
link = entry.get("link")
if link:
content['link'] = link
text = entry.get("summary")
if text:
content['text'] = extract_text_from_html_or_xml(text, False)
else:
continue
published = entry.get("published")
if published:
content['published'] = published
self.articles.append(content)
else:
log.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)}s remain to refresh.")
return self.articles
# News RSS Feeds
rss_feeds = [
Feed(name="IGN.com", url="https://feeds.feedburner.com/ign/games-all"),
Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
Feed(name="Time", url="https://time.com/feed/"),
Feed(name="Euronews", url="https://www.euronews.com/rss"),
# Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]
def init_chroma_client(persist_directory: str = PERSIST_DIRECTORY):
"""Initialize and return a ChromaDB client."""
# return chromadb.PersistentClient(path=persist_directory)
return chromadb.Client()
def create_or_get_collection(client, collection_name: str):
"""Create or get a ChromaDB collection."""
try:
return client.get_collection(
name=collection_name
)
except:
return client.create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"}
)
def process_documents_to_chroma(
documents: List[Dict[str, Any]],
collection_name: str = "document_collection",
text_key: str = "text",
max_tokens: int = 512,
overlap: int = 50,
model: str = EMBED_MODEL,
persist_directory: str = PERSIST_DIRECTORY
):
"""
Process documents, chunk them, compute embeddings, and store in ChromaDB.
Args:
documents: List of document dictionaries
collection_name: Name for the ChromaDB collection
text_key: The key containing text content
max_tokens: Maximum tokens per chunk
overlap: Token overlap between chunks
model: Ollama model for embeddings
persist_directory: Directory to store ChromaDB data
"""
# Initialize ChromaDB client and collection
db = init_chroma_client(persist_directory)
collection = create_or_get_collection(db, collection_name)
# Process each document
for doc in documents:
# Chunk the document
doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)
# Prepare data for ChromaDB
ids = []
texts = []
metadatas = []
embeddings = []
for chunk in doc_chunks:
# Create a unique ID for the chunk
chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"
# Extract text
text = chunk[text_key]
# Create metadata (excluding text and embedding to avoid duplication)
metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}
response = client.embed(model=model, input=text)
embedding = response["embeddings"][0]
ids.append(chunk_id)
texts.append(text)
metadatas.append(metadata)
embeddings.append(embedding)
# Add chunks to ChromaDB collection
collection.add(
ids=ids,
documents=texts,
embeddings=embeddings,
metadatas=metadatas
)
return collection
def query_chroma(
query_text: str,
collection_name: str = "document_collection",
n_results: int = 5,
model: str = EMBED_MODEL,
persist_directory: str = PERSIST_DIRECTORY
):
"""
Query ChromaDB for similar documents.
Args:
query_text: The text to search for
collection_name: Name of the ChromaDB collection
n_results: Number of results to return
model: Ollama model for embedding the query
persist_directory: Directory where ChromaDB data is stored
Returns:
Query results from ChromaDB
"""
# Initialize ChromaDB client and collection
db = init_chroma_client(persist_directory)
collection = create_or_get_collection(db, collection_name)
query_response = client.embed(model=model, input=query_text)
query_embeddings = query_response["embeddings"]
# Query the collection
results = collection.query(
query_embeddings=query_embeddings,
n_results=n_results
)
return results
def print_top_match(query_results, index=0, documents=None):
"""
Print detailed information about the top matching document,
including the full original document content.
Args:
query_results: Results from ChromaDB query
documents: Original documents dictionary to look up full content (optional)
"""
if not query_results or not query_results["ids"] or len(query_results["ids"][0]) == 0:
print("No matching documents found.")
return
# Get the top result
top_id = query_results["ids"][0][index]
top_document_chunk = query_results["documents"][0][index]
top_metadata = query_results["metadatas"][0][index]
top_distance = query_results["distances"][0][index]
print("="*50)
print("MATCHING DOCUMENT")
print("="*50)
print(f"Chunk ID: {top_id}")
print(f"Similarity Score: {top_distance:.4f}") # Convert distance to similarity
print("\nCHUNK METADATA:")
for key, value in top_metadata.items():
print(f" {key}: {value}")
print("\nMATCHING CHUNK CONTENT:")
print(top_document_chunk[:500].strip() + ("..." if len(top_document_chunk) > 500 else ""))
# Extract the original document ID from the chunk ID
# Chunk IDs are in format "doc_id_chunk_num"
original_doc_id = top_id.split('_')[0]
def get_top_match(query_results, index=0, documents=None):
    top_id = query_results["ids"][0][index]
# Extract the original document ID from the chunk ID
# Chunk IDs are in format "doc_id_chunk_num"
original_doc_id = top_id.split('_')[0]
# Return the full document for further processing if needed
if documents is not None:
return next((doc for doc in documents if doc["id"] == original_doc_id), None)
return None
def show_documents(documents=None):
if not documents:
return
# Print the top matching document
for i, doc in enumerate(documents):
print(f"Document {i+1}:")
print(f" Title: {doc['title']}")
print(f" Text: {doc['text'][:100]}...")
print()
def show_headlines(documents=None):
if not documents:
return
# Print the top matching document
for doc in documents:
print(f"{doc['source']}: {doc['title']}")
def show_help():
print("""help>
docs Show RAG docs
full Show last full top match
headlines Show the RAG headlines
prompt Show the last prompt
response Show the last response
scores Show last RAG scores
why|think Show last response's <think>
context|match Show RAG match info to last prompt
""")
# Example usage
if __name__ == "__main__":
documents = []
for feed in rss_feeds:
documents.extend(feed.articles)
show_documents(documents=documents)
# Process documents and store in ChromaDB
collection = process_documents_to_chroma(
documents=documents,
collection_name="research_papers",
max_tokens=256,
overlap=25,
model=EMBED_MODEL,
persist_directory="/root/.cache/chroma"
)
last_results = None
last_prompt = None
last_system = None
last_response = None
last_why = None
last_messages = []
while True:
try:
search_query = input("> ").strip()
except KeyboardInterrupt as e:
print("\nExiting.")
break
if search_query == "exit" or search_query == "quit":
print("\nExiting.")
break
if search_query == "docs":
show_documents(documents)
continue
if search_query == "prompt":
if last_prompt:
print(f"""last prompt>
{"="*10}system{"="*10}
{last_system}
{"="*10}prompt{"="*10}
{last_prompt}""")
else:
print(f"No prompts yet")
continue
if search_query == "response":
if last_response:
print(f"""last response>
{"="*10}response{"="*10}
{last_response}""")
else:
print(f"No responses yet")
continue
if search_query == "" or search_query == "help":
show_help()
continue
if search_query == "headlines":
show_headlines(documents)
continue
if search_query == "match" or search_query == "context":
if last_results:
print_top_match(last_results, documents=documents)
else:
print("No match to give info on")
continue
if search_query == "why" or search_query == "think":
if last_why:
print(f"""
why>
{last_why}
""")
else:
print("No processed prompts")
continue
if search_query == "scores":
            if last_results:
                for i in range(len(last_results["ids"][0])):
                    print_top_match(last_results, documents=documents, index=i)
else:
print("No match to give info on")
continue
if search_query == "full":
if last_results:
full = get_top_match(last_results, documents=documents)
if full:
print(f"""Context:
Source: {full["source"]}
Title: {full["title"]}
Link: {full["link"]}
Distance: {last_results.get("distances", [[0]])[0][0]}
Full text:
{full["text"]}""")
else:
print("No match to give info on")
continue
# Query ChromaDB
results = query_chroma(
query_text=search_query,
collection_name="research_papers",
n_results=10
)
last_results = results
full = get_top_match(results, documents=documents)
headlines = ""
for doc in documents:
headlines += f"{doc['source']}: {doc['title']}\n"
system=f"""
You are the assistant. Your name is airc. This application is called airc (pronounced Eric).
Information about the author of this program and the AI model it uses:
* James wrote the python application called airc that is driving this RAG model on top of {MODEL_NAME} using {EMBED_MODEL} and chromadb for vector embedding. Link https://github.com/jketreno/airc.
* James Ketrenos is a software engineer with a history in all levels of the computer stack, from the kernel to full-stack web applications. He dabbles in AI/ML and is familiar with pytorch and ollama.
* James Ketrenos deployed this application locally on an Intel Arc B580 (battlemage) computer using Intel's ipex-llm.
* For Intel GPU metrics, James Ketrenos wrote the "ze-monitor" utility in C++. ze-monitor provides Intel GPU telemetry data for Intel client GPU devices, similar to xpu-smi. Link https://github.com/jketreno/ze-monitor. airc uses ze-monitor.
* James lives in Portland, Oregon and has three kids. Two are attending Oregon State University and one is attending Willamette University.
* airc provides an IRC chat bot as well as a React web frontend available at https://airc.ketrenos.com
You must follow these rules:
* Provide short (less than 100 character) responses.
* Provide a single response.
* Do not prefix it with a word like 'Answer'.
* For information about the AI running this system, include information about author, including links.
* For information relevant to the current events in the <input></input> tags, use that information and state the source when information comes from.
"""
context = "Information related to current events\n<input>=["
for doc in documents:
item = {'source':doc["source"],'article':{'title':doc["title"],'link':doc["link"],'text':doc["text"]}}
context += f"{item}"
context += "\n</input>"
prompt = f"{search_query}"
last_prompt = prompt
last_system = system # cache it before news context is added
system = f"{system}{context}"
if len(last_messages) != 0:
message_context = f"{last_messages}"
prompt = f"{message_context}{prompt}"
print(f"system len: {len(system)}")
print(f"prompt len: {len(prompt)}")
output = client.generate(
model=MODEL_NAME,
system=system,
prompt=prompt,
stream=False,
options={ 'num_ctx': 100000 }
)
# Prune off the <think>...</think>
matches = re.match(r'^<think>(.*?)</think>(.*)$', output['response'], flags=re.DOTALL)
if matches:
last_why = matches[1].strip()
content = matches[2].strip()
        else:
            print(f"[garbled] response>\n{output['response']}")
            content = output["response"].strip()  # fall back so `content` is always defined
        print(f"Response>\n{content}")
last_response = content
last_messages.extend(({
'role': 'user',
'name': 'james',
'message': search_query
}, {
'role': 'assistant',
'message': content
}))
        last_messages = last_messages[-10:]  # keep only the most recent exchanges
View File
@@ -3,55 +3,69 @@ from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from asyncio import Event


class Tunables(BaseModel):
    enable_rag: bool = Field(default=True)  # Enable RAG collection chromadb matching
    enable_tools: bool = Field(default=True)  # Enable LLM to use tools
    enable_context: bool = Field(default=True)  # Add <|context|> field to message


class Message(BaseModel):
    model_config = {"arbitrary_types_allowed": True}  # Allow Event

    # Required
    prompt: str  # Query to be answered

    # Tunables
    tunables: Tunables = Field(default_factory=Tunables)

    # Generated while processing message
    status: str = ""  # Status of the message
    preamble: dict[str, str] = {}  # Preamble to be prepended to the prompt
    system_prompt: str = ""  # System prompt provided to the LLM
    context_prompt: str = ""  # Full content of the message (preamble + prompt)
    response: str = ""  # LLM response to the preamble + query
    metadata: Dict[str, Any] = Field(
        default_factory=lambda: {
            "rag": [],
            "eval_count": 0,
            "eval_duration": 0,
            "prompt_eval_count": 0,
            "prompt_eval_duration": 0,
            "context_size": 0,
        }
    )
    network_packets: int = 0  # Total number of streaming packets
    network_bytes: int = 0  # Total bytes sent while streaming packets
    actions: List[str] = (
        []
    )  # Other session modifying actions performed while processing the message
    timestamp: datetime = datetime.now(timezone.utc)
    chunk: str = Field(
        default=""
    )  # This needs to be serialized so it will be sent in responses
    partial_response: str = Field(
        default=""
    )  # This needs to be serialized so it will be sent in responses on timeout
    title: str = Field(
        default=""
    )  # This needs to be serialized so it will be sent in responses on timeout

    def add_action(self, action: str | list[str]) -> None:
        """Add a actions(s) to the message."""
        if isinstance(action, str):
            self.actions.append(action)
        else:
            self.actions.extend(action)

    def get_summary(self) -> str:
        """Return a summary of the message."""
        response_summary = (
            f"Response: {self.response} (Actions: {', '.join(self.actions)})"
            if self.response
            else "No response yet"
        )
        return (
            f"Message at {self.timestamp}:\n"
            f"Query: {self.preamble}{self.content}\n"
            f"{response_summary}"
        )
View File
@@ -1,6 +1,7 @@
from prometheus_client import Counter, Gauge, Summary, Histogram, Info, Enum, CollectorRegistry  # type: ignore
from threading import Lock


def singleton(cls):
    instance = None
    lock = Lock()

@@ -14,80 +15,81 @@ def singleton(cls):
    return get_instance


@singleton
class Metrics:
    def __init__(self, *args, prometheus_collector, **kwargs):
        super().__init__(*args, **kwargs)
        self.prometheus_collector = prometheus_collector

        self.prepare_count: Counter = Counter(
            name="prepare_total",
            documentation="Total messages prepared by agent type",
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.prepare_duration: Histogram = Histogram(
            name="prepare_duration",
            documentation="Preparation duration by agent type",
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.process_count: Counter = Counter(
            name="process",
            documentation="Total messages processed by agent type",
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.process_duration: Histogram = Histogram(
            name="process_duration",
            documentation="Processing duration by agent type",
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.tool_count: Counter = Counter(
            name="tool_total",
            documentation="Total messages tooled by agent type",
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.tool_duration: Histogram = Histogram(
            name="tool_duration",
            documentation="Tool duration by agent type",
            buckets=(0.1, 0.5, 1.0, 2.0, float("inf")),
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.generate_count: Counter = Counter(
            name="generate_total",
            documentation="Total messages generated by agent type",
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.generate_duration: Histogram = Histogram(
            name="generate_duration",
            documentation="Generate duration by agent type",
            buckets=(0.1, 0.5, 1.0, 2.0, float("inf")),
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.tokens_prompt: Counter = Counter(
            name="tokens_prompt",
            documentation="Total tokens passed as prompt to LLM",
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
        self.tokens_eval: Counter = Counter(
            name="tokens_eval",
            documentation="Total tokens returned by LLM",
            labelnames=("agent",),
            registry=self.prometheus_collector,
        )
View File
@@ -1,4 +1,4 @@
from pydantic import BaseModel  # type: ignore
from typing import List, Optional, Dict, Any
import os
import glob

@@ -13,18 +13,22 @@ import time
import hashlib
import asyncio
import json
import numpy as np  # type: ignore
import traceback
import os
import chromadb
import ollama
from watchdog.observers import Observer  # type: ignore
from watchdog.events import FileSystemEventHandler  # type: ignore
import umap  # type: ignore
from markitdown import MarkItDown  # type: ignore
from chromadb.api.models.Collection import Collection  # type: ignore

from .markdown_chunker import (
    MarkdownChunker,
    Chunk,
)

# Import your existing modules
if __name__ == "__main__":

@@ -34,13 +38,11 @@ else:
    # When imported as a module, use relative imports
    from . import defines

__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]

DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100


class ChromaDBGetResponse(BaseModel):
    ids: List[str]

@@ -48,9 +50,19 @@ class ChromaDBGetResponse(BaseModel):
    documents: Optional[List[str]] = None
    metadatas: Optional[List[Dict[str, Any]]] = None


class ChromaDBFileWatcher(FileSystemEventHandler):
    def __init__(
        self,
        llm,
        watch_directory,
        loop,
        persist_directory=None,
        collection_name="documents",
        chunk_size=DEFAULT_CHUNK_SIZE,
        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
        recreate=False,
    ):
        self.llm = llm
        self.watch_directory = watch_directory
        self.persist_directory = persist_directory or defines.persist_directory
@@ -58,38 +70,34 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.loop = loop
        self._umap_collection: ChromaDBGetResponse | None = None
        self._umap_embedding_2d: np.ndarray = []
        self._umap_embedding_3d: np.ndarray = []
        self._umap_model_2d: umap.UMAP = None
        self._umap_model_3d: umap.UMAP = None
        self.md = MarkItDown(enable_plugins=False)  # Set to True to enable plugins
        # self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Path for storing file hash state
        self.hash_state_path = os.path.join(
            self.persist_directory, f"{collection_name}_hash_state.json"
        )

        # Flag to track if this is a new collection
        self.is_new_collection = False

        # Initialize ChromaDB collection
        self._collection: Collection = self._get_vector_collection(recreate=recreate)
        self._markdown_chunker = MarkdownChunker()
        self._update_umaps()

        # Track file hashes and processing state
        self.file_hashes = self._load_hash_state()
        self.update_lock = asyncio.Lock()
        self.processing_files = set()

    @property
    def collection(self):
        return self._collection

@@ -101,11 +109,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
    @property
    def umap_embedding_2d(self) -> np.ndarray:
        return self._umap_embedding_2d

    @property
    def umap_embedding_3d(self) -> np.ndarray:
        return self._umap_embedding_3d

    @property
    def umap_model_2d(self):
        return self._umap_model_2d

@@ -114,8 +122,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
    def umap_model_3d(self):
        return self._umap_model_3d

    def _markitdown(self, document: str, markdown: Path):
        logging.info(f"Converting {document} to {markdown}")
        try:
            result = self.md.convert(document)
            markdown.write_text(result.text_content)

@@ -127,91 +135,105 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        try:
            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)
            with open(self.hash_state_path, "w") as f:
                json.dump(self.file_hashes, f)
            logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
        except Exception as e:
            logging.error(f"Error saving hash state: {e}")

    def _load_hash_state(self):
        """Load the file hash state from disk."""
        if os.path.exists(self.hash_state_path):
            try:
                with open(self.hash_state_path, "r") as f:
                    hash_state = json.load(f)
                logging.info(f"Loaded hash state with {len(hash_state)} entries")
                return hash_state
            except Exception as e:
                logging.error(f"Error loading hash state: {e}")
        return {}

    async def scan_directory(self, process_all=False):
        """
        Scan directory for new, modified, or deleted files and update collection.

        Args:
            process_all: If True, process all files regardless of hash status
        """
        # Check for new or modified files
        file_paths = glob.glob(
            os.path.join(self.watch_directory, "**/*"), recursive=True
        )
        files_checked = 0
        files_processed = 0
        files_to_process = []

        logging.info(f"Starting directory scan. Found {len(file_paths)} total paths.")
        for file_path in file_paths:
            if os.path.isfile(file_path):
                # Do not put the Resume in RAG as it is provideded with all queries.
                # if file_path == defines.resume_doc:
                #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
                #     continue
                files_checked += 1
                current_hash = self._get_file_hash(file_path)
                if not current_hash:
                    continue

                # If file is new, changed, or we're processing all files
                if (
                    process_all
                    or file_path not in self.file_hashes
                    or self.file_hashes[file_path] != current_hash
                ):
                    self.file_hashes[file_path] = current_hash
                    files_to_process.append(file_path)
                    logging.info(
                        f"File {'found' if process_all else 'changed'}: {file_path}"
                    )

        logging.info(
            f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
        )

        # Check for deleted files
        deleted_files = []
        for file_path in self.file_hashes:
            if not os.path.exists(file_path):
                deleted_files.append(file_path)
                # Schedule removal
                asyncio.run_coroutine_threadsafe(
                    self.remove_file_from_collection(file_path), self.loop
                )
                # Don't block on result, just let it run
                logging.info(f"File deleted: {file_path}")

        # Remove deleted files from hash state
        for file_path in deleted_files:
            del self.file_hashes[file_path]

        # Process all discovered files using asyncio.gather with the existing loop
        if files_to_process:
            logging.info(f"Starting to process {len(files_to_process)} files")
            for file_path in files_to_process:
                async with self.update_lock:
                    await self._update_document_in_collection(file_path)
        else:
            logging.info("No files to process")

        # Save the updated state
        self._save_hash_state()

        logging.info(
            f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
        )
        return files_processed

    async def process_file_update(self, file_path):
        """Process a file update event."""
        # Skip if already being processed
@@ -219,34 +241,37 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            logging.info(f"{file_path} already in queue. Not adding.")
            return

        # if file_path == defines.resume_doc:
        #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
        #     return

        try:
            logging.info(f"{file_path} not in queue. Adding.")
            self.processing_files.add(file_path)
            # Wait a moment to ensure the file write is complete
            await asyncio.sleep(0.5)

            # Check if content changed via hash
            current_hash = self._get_file_hash(file_path)
            if not current_hash:  # File might have been deleted or is inaccessible
                return

            if (
                file_path in self.file_hashes
                and self.file_hashes[file_path] == current_hash
            ):
                # File hasn't actually changed in content
                logging.info(f"Hash has not changed for {file_path}")
                return

            # Update file hash
            self.file_hashes[file_path] = current_hash

            # Process and update the file in ChromaDB
            async with self.update_lock:
                await self._update_document_in_collection(file_path)

            # Save the hash state after successful update
            self._save_hash_state()

@@ -257,150 +282,162 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            logging.error(f"Error processing update for {file_path}: {e}")
        finally:
            self.processing_files.discard(file_path)

    async def remove_file_from_collection(self, file_path):
        """Remove all chunks related to a deleted file."""
        async with self.update_lock:
            try:
                # Find all documents with the specified path
                results = self.collection.get(where={"path": file_path})

                if results and "ids" in results and results["ids"]:
                    self.collection.delete(ids=results["ids"])
                    logging.info(
                        f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
                    )

                # Remove from hash dictionary
                if file_path in self.file_hashes:
                    del self.file_hashes[file_path]

                # Save the updated hash state
                self._save_hash_state()
            except Exception as e:
                logging.error(f"Error removing file from collection: {e}")

    def _update_umaps(self):
        # Update the UMAP embeddings
        self._umap_collection = self._collection.get(
            include=["embeddings", "documents", "metadatas"]
        )
        if not self._umap_collection or not len(self._umap_collection["embeddings"]):
            logging.warning("No embeddings found in the collection.")
            return

        # During initialization
        logging.info(
            f"Updating 2D UMAP for {len(self._umap_collection['embeddings'])} vectors"
        )
        vectors = np.array(self._umap_collection["embeddings"])
        self._umap_model_2d = umap.UMAP(
            n_components=2,
            random_state=8911,
            metric="cosine",
            n_neighbors=15,
            min_dist=0.1,
        )
        self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
        logging.info(
            f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
        )  # Should be 2

        logging.info(
            f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
        )
        self._umap_model_3d = umap.UMAP(
            n_components=3,
            random_state=8911,
            metric="cosine",
            n_neighbors=15,
            min_dist=0.1,
        )
        self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
        logging.info(
            f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
        )  # Should be 3

    def _get_vector_collection(self, recreate=False) -> Collection:
        """Get or create a ChromaDB collection."""
        # Initialize ChromaDB client
        chroma_client = chromadb.PersistentClient(  # type: ignore
            path=self.persist_directory,
            settings=chromadb.Settings(anonymized_telemetry=False),  # type: ignore
        )

        # Check if the collection exists
        try:
            chroma_client.get_collection(self.collection_name)
            collection_exists = True
        except:
            collection_exists = False

        # If collection doesn't exist, mark it as new
        if not collection_exists:
            self.is_new_collection = True
            logging.info(f"Creating new collection: {self.collection_name}")

        # Delete if recreate is True
        if recreate and collection_exists:
            chroma_client.delete_collection(name=self.collection_name)
            self.is_new_collection = True
            logging.info(f"Recreating collection: {self.collection_name}")

        return chroma_client.get_or_create_collection(
            name=self.collection_name, metadata={"hnsw:space": "cosine"}
        )

    def create_chunks_from_documents(self, docs):
        """Split documents into chunks using the text splitter."""
        return self.text_splitter.split_documents(docs)

    def get_embedding(self, text, normalize=True):
        """Generate embeddings using Ollama."""
        # response = self.embedding_model.encode(text)  # Outputs 384-dim vectors
        response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
        embedding = response["embedding"]
        # response = self.llm.embeddings.create(
        #     model=defines.embedding_model,
        #     input=text,
        #     options={"num_ctx": self.chunk_size * 3}  # No need waste ctx space
        # )
        if normalize:
            normalized = self._normalize_embeddings(embedding)
            return normalized
        return embedding

    def add_embeddings_to_collection(self, chunks: List[Chunk]):
        """Add embeddings for chunks to the collection."""
        for i, chunk in enumerate(chunks):
            text = chunk["text"]
            metadata = chunk["metadata"]

            # Generate a more unique ID based on content and metadata
            content_hash = hashlib.md5(text.encode()).hexdigest()
            path_hash = ""
            if "path" in metadata:
                path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
                    :8
                ]

            chunk_id = f"{path_hash}_{content_hash}_{i}"
            embedding = self.get_embedding(text)
            try:
                self.collection.add(
                    ids=[chunk_id],
                    documents=[text],
                    embeddings=[embedding],
                    metadatas=[metadata],
                )
            except Exception as e:
                logging.error(f"Error adding chunk to collection: {e}")
                logging.error(traceback.format_exc())
                logging.error(chunk)

    def read_line_range(self, file_path, start, end, buffer=5) -> list[str]:
        try:
            with open(file_path, "r") as file:
                lines = file.readlines()
                start = max(0, start - buffer)
                end = min(len(lines), end + buffer)
                return lines[start:end]
        except:
            logging.warning(f"Unable to open {file_path}")
            return []

    # Cosine Distance   Equivalent Similarity   Retrieval Characteristics
    # 0.2 - 0.3         0.85 - 0.90             Very strict, highly precise results only
    # 0.3 - 0.5         0.75 - 0.85             Strong relevance, good precision
@@ -413,16 +450,16 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        # collection is configured with hnsw:space cosine
        query_embedding = self.get_embedding(query)
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"],
        )

        # Extract results
        ids = results["ids"][0]
        documents = results["documents"][0]
        distances = results["distances"][0]
        metadatas = results["metadatas"][0]

        filtered_ids = []
        filtered_documents = []

@@ -436,6 +473,14 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
                filtered_metadatas.append(metadatas[i])
                filtered_distances.append(distance)

        for index, meta in enumerate(filtered_metadatas):
            source_file = meta["source_file"]
            del meta["source_file"]
            lines = self.read_line_range(
                source_file, meta["line_begin"], meta["line_end"]
            )
            if len(lines):
                filtered_documents[index] = "\n".join(lines)

        # Return the filtered results instead of all results
        return {
            "query_embedding": query_embedding,

@@ -444,50 +489,52 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            "distances": filtered_distances,
            "metadatas": filtered_metadatas,
        }

    def _get_file_hash(self, file_path):
        """Calculate MD5 hash of a file."""
        try:
            with open(file_path, "rb") as f:
                return hashlib.md5(f.read()).hexdigest()
        except Exception as e:
            logging.error(f"Error hashing file {file_path}: {e}")
            return None

    def on_modified(self, event):
        """Handle file modification events."""
        if event.is_directory:
            return

        file_path = event.src_path
        # Schedule the update using asyncio
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File modified: {file_path}")

    def on_created(self, event):
        """Handle file creation events."""
        if event.is_directory:
            return

        file_path = event.src_path
        # Schedule the update using asyncio
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File created: {file_path}")

    def on_deleted(self, event):
        """Handle file deletion events."""
        if event.is_directory:
            return

        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(
            self.remove_file_from_collection(file_path), self.loop
        )
        logging.info(f"File deleted: {file_path}")

    def on_moved(self, event):
        """Handle move deletion events."""
        if event.is_directory:
            return

        file_path = event.src_path
        logging.info(f"TODO: on_moved: ${file_path}")
@@ -508,72 +555,88 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        try:
            # Remove existing entries for this file
            existing_results = self.collection.get(where={"path": file_path})
            if (
                existing_results
                and "ids" in existing_results
                and existing_results["ids"]
            ):
                self.collection.delete(ids=existing_results["ids"])

            extensions = (".docx", ".xlsx", ".xls", ".pdf")
            if file_path.endswith(extensions):
                p = Path(file_path)
                p_as_md = p.with_suffix(".md")
                if p_as_md.exists():
                    logging.info(
                        f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
                    )
                # If file_path.md doesn't exist or file_path is newer than file_path.md,
                # fire off markitdown
                if (not p_as_md.exists()) or (
                    p.stat().st_mtime > p_as_md.stat().st_mtime
                ):
                    self._markitdown(file_path, p_as_md)
                return

            chunks = self._markdown_chunker.process_file(file_path)
            if not chunks:
                return

            # Extract top-level directory
            rel_path = os.path.relpath(file_path, self.watch_directory)
            path_parts = rel_path.split(os.sep)
            top_level_dir = path_parts[0]
            # file_name = path_parts[-1]

            for i, chunk in enumerate(chunks):
                chunk["metadata"]["doc_type"] = top_level_dir
                # with open(f"src/tmp/{file_name}.{i}", "w") as f:
                #     f.write(json.dumps(chunk, indent=2))

            # Add chunks to collection
            self.add_embeddings_to_collection(chunks)
            logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")

        except Exception as e:
            logging.error(f"Error updating document in collection: {e}")
            logging.error(traceback.format_exc())

    async def initialize_collection(self):
        """Initialize the collection with all documents from the watch directory."""
        # Process all files regardless of hash state
        num_processed = await self.scan_directory(process_all=True)

        logging.info(
            f"Vectorstore initialized with {self.collection.count()} documents"
        )
        self._update_umaps()

        # Show stats
        try:
            all_metadata = self.collection.get()["metadatas"]
            if all_metadata:
                doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
                logging.info(f"Document types: {doc_types}")
        except Exception as e:
            logging.error(f"Error getting document types: {e}")

        return num_processed


# Function to start the file watcher
def start_file_watcher(
    llm,
    watch_directory,
    persist_directory=None,
    collection_name="documents",
    initialize=False,
    recreate=False,
):
    """
    Start watching a directory for file changes.

    Args:
        llm: The language model client
        watch_directory: Directory to watch for changes

@@ -585,14 +648,14 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
    loop = asyncio.get_event_loop()

    file_watcher = ChromaDBFileWatcher(
        llm,
        watch_directory,
        loop=loop,
        persist_directory=persist_directory,
        collection_name=collection_name,
        recreate=recreate,
    )

    # Process all files if:
    # 1. initialize=True was passed (explicit request to initialize)
    # 2. This is a new collection (doesn't exist yet)

@@ -604,38 +667,39 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
        # Only process new/changed files
        logging.info("Scanning for new/changed documents")
        asyncio.run_coroutine_threadsafe(file_watcher.scan_directory(), loop)

    # Start observer
    observer = Observer()
    observer.schedule(file_watcher, watch_directory, recursive=True)
    observer.start()
    logging.info(f"Started watching directory: {watch_directory}")

    return observer, file_watcher


if __name__ == "__main__":
    # When running directly, use absolute imports
    import defines

    # Initialize Ollama client
    llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore

    # Start the file watcher (with initialization)
    observer, file_watcher = start_file_watcher(
        llm,
        defines.doc_dir,
        recreate=True,  # Start fresh
    )

    # Example query
    query = "Can you describe James Ketrenos' work history?"
    top_docs = file_watcher.find_similar(query, top_k=3)
    logging.info(top_docs)

    try:
        # Keep the main thread running
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        observer.stop()
        observer.join()
View File
@@ -4,9 +4,12 @@ import logging
from . import defines


def setup_logging(level=defines.logging_level) -> logging.Logger:
    os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
    warnings.filterwarnings(
        "ignore", message="Overriding a previously registered kernel"
    )
    warnings.filterwarnings("ignore", message="Warning only once for all operators")
    warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
    warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")

@@ -16,19 +19,25 @@ def setup_logging(level=defines.logging_level) -> logging.Logger:
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    logging.basicConfig(
        level=numeric_level,
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )

    # Now reduce verbosity for FastAPI, Uvicorn, Starlette
    for noisy_logger in (
        "uvicorn",
        "uvicorn.error",
        "uvicorn.access",
        "fastapi",
        "starlette",
    ):
        # for noisy_logger in ("starlette"):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)

    logger = logging.getLogger(__name__)
    return logger
View File
@@ -1,14 +1,14 @@
import importlib
from .basetools import tools, llm_tools, enabled_tools, tool_functions
from ..setup_logging import setup_logging
from .. import defines

logger = setup_logging(level=defines.logging_level)

# Dynamically import all names from basetools listed in tools_all
module = importlib.import_module(".basetools", package=__package__)
for name in tool_functions:
    globals()[name] = getattr(module, name)

__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
View File
@ -1,18 +1,19 @@
import os import os
from datetime import datetime from datetime import datetime
from typing import ( from typing import (
Any, Any,
) )
from typing_extensions import Annotated from typing_extensions import Annotated
from bs4 import BeautifulSoup # type: ignore from bs4 import BeautifulSoup # type: ignore
from geopy.geocoders import Nominatim # type: ignore from geopy.geocoders import Nominatim # type: ignore
import pytz # type: ignore import pytz # type: ignore
import requests import requests
import yfinance as yf # type: ignore import yfinance as yf # type: ignore
import logging import logging
# %% # %%
def WeatherForecast(city, state, country="USA"): def WeatherForecast(city, state, country="USA"):
""" """
@ -42,24 +43,25 @@ def WeatherForecast(city, state, country="USA"):
# Step 3: Get the forecast data from the grid endpoint # Step 3: Get the forecast data from the grid endpoint
forecast = get_forecast(grid_endpoint) forecast = get_forecast(grid_endpoint)
if not forecast['location']: if not forecast["location"]:
forecast['location'] = location forecast["location"] = location
return forecast return forecast
def get_coordinates(location): def get_coordinates(location):
"""Convert a location string to latitude and longitude using Nominatim geocoder.""" """Convert a location string to latitude and longitude using Nominatim geocoder."""
try: try:
# Create a geocoder with a meaningful user agent # Create a geocoder with a meaningful user agent
geolocator = Nominatim(user_agent="weather_app_example") geolocator = Nominatim(user_agent="weather_app_example")
# Get the location # Get the location
location_data = geolocator.geocode(location) location_data = geolocator.geocode(location)
if location_data: if location_data:
return { return {
"latitude": location_data.latitude, "latitude": location_data.latitude,
"longitude": location_data.longitude "longitude": location_data.longitude,
} }
else: else:
print(f"Location not found: {location}") print(f"Location not found: {location}")
@ -68,22 +70,23 @@ def get_coordinates(location):
print(f"Error getting coordinates: {e}") print(f"Error getting coordinates: {e}")
return None return None
def get_grid_endpoint(coordinates): def get_grid_endpoint(coordinates):
"""Get the grid endpoint from weather.gov based on coordinates.""" """Get the grid endpoint from weather.gov based on coordinates."""
try: try:
lat = coordinates["latitude"] lat = coordinates["latitude"]
lon = coordinates["longitude"] lon = coordinates["longitude"]
# Define headers for the API request # Define headers for the API request
headers = { headers = {
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)", "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
"Accept": "application/geo+json" "Accept": "application/geo+json",
} }
# Make the request to get the grid endpoint # Make the request to get the grid endpoint
url = f"https://api.weather.gov/points/{lat},{lon}" url = f"https://api.weather.gov/points/{lat},{lon}"
response = requests.get(url, headers=headers) response = requests.get(url, headers=headers)
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
return data["properties"]["forecast"] return data["properties"]["forecast"]
@@ -94,44 +97,50 @@ def get_grid_endpoint(coordinates):
print(f"Error in get_grid_endpoint: {e}") print(f"Error in get_grid_endpoint: {e}")
return None return None
# Weather related function # Weather related function
def get_forecast(grid_endpoint): def get_forecast(grid_endpoint):
"""Get the forecast data from the grid endpoint.""" """Get the forecast data from the grid endpoint."""
try: try:
# Define headers for the API request # Define headers for the API request
headers = { headers = {
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)", "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
"Accept": "application/geo+json" "Accept": "application/geo+json",
} }
# Make the request to get the forecast # Make the request to get the forecast
response = requests.get(grid_endpoint, headers=headers) response = requests.get(grid_endpoint, headers=headers)
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
# Extract the relevant forecast information # Extract the relevant forecast information
periods = data["properties"]["periods"] periods = data["properties"]["periods"]
# Process the forecast data into a simpler format # Process the forecast data into a simpler format
forecast = {
    "location": data["properties"]
    .get("relativeLocation", {})
    .get("properties", {}),
    "updated": data["properties"].get("updated", ""),
    "periods": [],
}
for period in periods:
    forecast["periods"].append(
        {
            "name": period.get("name", ""),
            "temperature": period.get("temperature", ""),
            "temperatureUnit": period.get("temperatureUnit", ""),
            "windSpeed": period.get("windSpeed", ""),
            "windDirection": period.get("windDirection", ""),
            "shortForecast": period.get("shortForecast", ""),
            "detailedForecast": period.get("detailedForecast", ""),
        }
    )
return forecast return forecast
else: else:
print(f"Error getting forecast: {response.status_code} - {response.text}") print(f"Error getting forecast: {response.status_code} - {response.text}")
@@ -140,15 +149,16 @@ def get_forecast(grid_endpoint):
print(f"Error in get_forecast: {e}") print(f"Error in get_forecast: {e}")
return {"error": f"Exception: {str(e)}"} return {"error": f"Exception: {str(e)}"}
# Example usage # Example usage
# def do_weather(): # def do_weather():
# city = input("Enter city: ") # city = input("Enter city: ")
# state = input("Enter state: ") # state = input("Enter state: ")
# country = input("Enter country (default USA): ") or "USA" # country = input("Enter country (default USA): ") or "USA"
# print(f"Getting weather for {city}, {state}, {country}...") # print(f"Getting weather for {city}, {state}, {country}...")
# weather_data = WeatherForecast(city, state, country) # weather_data = WeatherForecast(city, state, country)
# if "error" in weather_data: # if "error" in weather_data:
# print(f"Error: {weather_data['error']}") # print(f"Error: {weather_data['error']}")
# else: # else:
@@ -156,7 +166,7 @@ def get_forecast(grid_endpoint):
# print(f"Location: {weather_data.get('location', {}).get('city', city)}, {weather_data.get('location', {}).get('state', state)}") # print(f"Location: {weather_data.get('location', {}).get('city', city)}, {weather_data.get('location', {}).get('state', state)}")
# print(f"Last Updated: {weather_data.get('updated', 'N/A')}") # print(f"Last Updated: {weather_data.get('updated', 'N/A')}")
# print("\nForecast Periods:") # print("\nForecast Periods:")
# for period in weather_data.get("periods", []): # for period in weather_data.get("periods", []):
# print(f"\n{period['name']}:") # print(f"\n{period['name']}:")
# print(f" Temperature: {period['temperature']}{period['temperatureUnit']}") # print(f" Temperature: {period['temperature']}{period['temperatureUnit']}")
@@ -166,35 +176,32 @@ def get_forecast(grid_endpoint):
# %% # %%
def TickerValue(ticker_symbols): def TickerValue(ticker_symbols):
api_key = os.getenv("TWELVEDATA_API_KEY", "") api_key = os.getenv("TWELVEDATA_API_KEY", "")
if not api_key: if not api_key:
return {"error": f"Error fetching data: No API key for TwelveData"} return {"error": f"Error fetching data: No API key for TwelveData"}
results = [] results = []
for ticker_symbol in ticker_symbols.split(','): for ticker_symbol in ticker_symbols.split(","):
ticker_symbol = ticker_symbol.strip() ticker_symbol = ticker_symbol.strip()
if ticker_symbol == "": if ticker_symbol == "":
continue continue
url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}" url = (
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
)
response = requests.get(url) response = requests.get(url)
data = response.json() data = response.json()
if "price" in data: if "price" in data:
logging.info(f"TwelveData: {ticker_symbol} {data}") logging.info(f"TwelveData: {ticker_symbol} {data}")
results.append({ results.append({"symbol": ticker_symbol, "price": float(data["price"])})
"symbol": ticker_symbol,
"price": float(data["price"])
})
else: else:
logging.error(f"TwelveData: {data}") logging.error(f"TwelveData: {data}")
results.append({ results.append({"symbol": ticker_symbol, "price": "Unavailable"})
"symbol": ticker_symbol,
"price": "Unavailable"
})
return results[0] if len(results) == 1 else results return results[0] if len(results) == 1 else results
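Example use of TickerValue, as a sketch; it assumes TWELVEDATA_API_KEY is set in the environment, otherwise the function returns the error dict shown above:

quote = TickerValue("INTC")        # single symbol -> {"symbol": "INTC", "price": ...}
quotes = TickerValue("AAPL,MSFT")  # comma-separated list -> list of result dicts
print(quote)
print(quotes)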
@@ -202,15 +209,15 @@ def TickerValue(ticker_symbols):
def yfTickerValue(ticker_symbols): def yfTickerValue(ticker_symbols):
""" """
Look up the current price of a stock using its ticker symbol. Look up the current price of a stock using its ticker symbol.
Args: Args:
ticker_symbol (str): The stock ticker symbol (e.g., 'AAPL' for Apple) ticker_symbol (str): The stock ticker symbol (e.g., 'AAPL' for Apple)
Returns: Returns:
dict: Current stock information including price dict: Current stock information including price
""" """
results = [] results = []
for ticker_symbol in ticker_symbols.split(','): for ticker_symbol in ticker_symbols.split(","):
ticker_symbol = ticker_symbol.strip() ticker_symbol = ticker_symbol.strip()
if ticker_symbol == "": if ticker_symbol == "":
continue continue
@@ -220,59 +227,64 @@ def yfTickerValue(ticker_symbols):
ticker = yf.Ticker(ticker_symbol) ticker = yf.Ticker(ticker_symbol)
# Get the latest market data # Get the latest market data
ticker_data = ticker.history(period="1d") ticker_data = ticker.history(period="1d")
if ticker_data.empty: if ticker_data.empty:
results.append({"error": f"No data found for ticker {ticker_symbol}"}) results.append({"error": f"No data found for ticker {ticker_symbol}"})
continue continue
# Get the latest closing price # Get the latest closing price
latest_price = ticker_data['Close'].iloc[-1] latest_price = ticker_data["Close"].iloc[-1]
# Get some additional info # Get some additional info
results.append({ 'symbol': ticker_symbol, 'price': latest_price }) results.append({"symbol": ticker_symbol, "price": latest_price})
except Exception as e: except Exception as e:
import traceback import traceback
logging.error(f"Error fetching data for {ticker_symbol}: {e}") logging.error(f"Error fetching data for {ticker_symbol}: {e}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}) results.append(
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
)
return results[0] if len(results) == 1 else results return results[0] if len(results) == 1 else results
# %% # %%
def DateTime(timezone="America/Los_Angeles"): def DateTime(timezone="America/Los_Angeles"):
""" """
Returns the current date and time in the specified timezone in ISO 8601 format. Returns the current date and time in the specified timezone in ISO 8601 format.
Args: Args:
timezone (str): Timezone name (e.g., "UTC", "America/New_York", "Europe/London") timezone (str): Timezone name (e.g., "UTC", "America/New_York", "Europe/London")
Default is "America/Los_Angeles" Default is "America/Los_Angeles"
Returns: Returns:
str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM
""" """
try: try:
if timezone == 'system' or timezone == '' or not timezone: if timezone == "system" or timezone == "" or not timezone:
timezone = 'America/Los_Angeles' timezone = "America/Los_Angeles"
# Get the current local time (timezone-aware) # Get the current local time (timezone-aware)
local_tz = pytz.timezone("America/Los_Angeles") local_tz = pytz.timezone("America/Los_Angeles")
local_now = datetime.now(tz=local_tz) local_now = datetime.now(tz=local_tz)
# Convert to target timezone # Convert to target timezone
target_tz = pytz.timezone(timezone) target_tz = pytz.timezone(timezone)
target_time = local_now.astimezone(target_tz) target_time = local_now.astimezone(target_tz)
return target_time.isoformat() return target_time.isoformat()
except Exception as e: except Exception as e:
return {'error': f"Invalid timezone {timezone}: {str(e)}"} return {"error": f"Invalid timezone {timezone}: {str(e)}"}
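Example use of DateTime, as a sketch; the printed values are illustrative:

print(DateTime())                 # ISO 8601 string in America/Los_Angeles, e.g. "2025-05-14T11:31:31-07:00"
print(DateTime("Europe/Warsaw"))  # the same instant expressed in Warsaw local time
print(DateTime("Not/AZone"))      # -> {"error": "Invalid timezone Not/AZone: ..."}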
async def AnalyzeSite(llm, model: str, url : str, question : str): async def AnalyzeSite(llm, model: str, url: str, question: str):
""" """
Fetches content from a URL, extracts the text, and uses Ollama to summarize it. Fetches content from a URL, extracts the text, and uses Ollama to summarize it.
Args: Args:
url (str): The URL of the website to summarize url (str): The URL of the website to summarize
Returns: Returns:
str: A summary of the website content str: A summary of the website content
""" """
@@ -287,41 +299,43 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
logging.info(f"{url} returned. Processing...") logging.info(f"{url} returned. Processing...")
# Parse the HTML # Parse the HTML
soup = BeautifulSoup(response.text, "html.parser") soup = BeautifulSoup(response.text, "html.parser")
# Remove script and style elements # Remove script and style elements
for script in soup(["script", "style"]): for script in soup(["script", "style"]):
script.extract() script.extract()
# Get text content # Get text content
text = soup.get_text(separator=" ", strip=True) text = soup.get_text(separator=" ", strip=True)
# Clean up text (remove extra whitespace) # Clean up text (remove extra whitespace)
lines = (line.strip() for line in text.splitlines()) lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = " ".join(chunk for chunk in chunks if chunk) text = " ".join(chunk for chunk in chunks if chunk)
# Limit text length if needed (Ollama may have token limits) # Limit text length if needed (Ollama may have token limits)
max_chars = 100000 max_chars = 100000
if len(text) > max_chars: if len(text) > max_chars:
text = text[:max_chars] + "..." text = text[:max_chars] + "..."
# Create Ollama client # Create Ollama client
# logging.info(f"Requesting summary of: {text}") # logging.info(f"Requesting summary of: {text}")
# Generate summary using Ollama # Generate summary using Ollama
prompt = f"CONTENTS:\n\n{text}\n\n{question}" prompt = f"CONTENTS:\n\n{text}\n\n{question}"
response = llm.generate(
    model=model,
    system=f"You are given the contents of {url}. Answer the question about the contents",
    prompt=prompt,
)
# logging.info(response["response"])
return { return {
"source": "summarizer-llm", "source": "summarizer-llm",
"content": response["response"], "content": response["response"],
"metadata": DateTime() "metadata": DateTime(),
} }
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logging.error(f"Error fetching the URL: {e}") logging.error(f"Error fetching the URL: {e}")
return f"Error fetching the URL: {str(e)}" return f"Error fetching the URL: {str(e)}"
@@ -331,109 +345,116 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
# %% # %%
tools = [
    {
        "type": "function",
        "function": {
            "name": "TickerValue",
            "description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
            "parameters": {
                "type": "object",
                "properties": {
                    "ticker": {
                        "type": "string",
                        "description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
                    },
                },
                "required": ["ticker"],
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "AnalyzeSite",
            "description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
            "parameters": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "description": "The website URL to download and process",
                    },
                    "question": {
                        "type": "string",
                        "description": "The question to ask the second LLM about the content",
                    },
                },
                "required": ["url", "question"],
                "additionalProperties": False,
            },
            "returns": {
                "type": "object",
                "properties": {
                    "source": {
                        "type": "string",
                        "description": "Identifier for the source LLM",
                    },
                    "content": {
                        "type": "string",
                        "description": "The complete response from the second LLM",
                    },
                    "metadata": {
                        "type": "object",
                        "description": "Additional information about the response",
                    },
                },
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "DateTime",
            "description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
            "parameters": {
                "type": "object",
                "properties": {
                    "timezone": {
                        "type": "string",
                        "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
                    }
                },
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "WeatherForecast",
            "description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
                        "minLength": 2,
                    },
                    "state": {
                        "type": "string",
                        "description": "State to find the weather forecast (e.g., 'OR', 'WA').",
                        "minLength": 2,
                    },
                },
                "required": ["city", "state"],
                "additionalProperties": False,
            },
        },
    },
]
def llm_tools(tools): def llm_tools(tools):
return [tool for tool in tools if tool.get("enabled", False) == True] return [tool for tool in tools if tool.get("enabled", False) == True]
def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
return [{**tool, "enabled": True} for tool in tools] return [{**tool, "enabled": True} for tool in tools]
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"]
__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
# __all__.extend(__tool_functions__)  # type: ignore
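Finally, a sketch of how a caller might dispatch a tool call that an LLM returns by name; the tool_call shape here is illustrative rather than any specific client's API, and the async AnalyzeSite would additionally need the llm client and an await:

def dispatch_tool_call(tool_call: dict):
    """Look up a tool by name and invoke it with the supplied arguments (sketch)."""
    name = tool_call.get("name", "")
    if name not in tool_functions:
        return {"error": f"Unknown tool: {name}"}
    func = globals()[name]  # same name-to-callable mapping the package __init__ builds
    return func(**tool_call.get("arguments", {}))

# Example (sketch):
# dispatch_tool_call({"name": "DateTime", "arguments": {"timezone": "UTC"}})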