Reformatted all content to black
This commit is contained in: parent a1798b58ac · commit e044f9c639
@@ -88,6 +88,10 @@ button {
    flex-grow: 1;
}

.MessageContent div > p:first-child {
    margin-top: 0;
}

.MenuCard.MuiCard-root {
    display: flex;
    flex-direction: column;
@@ -30,7 +30,6 @@ const BackstoryTextField = React.forwardRef<BackstoryTextFieldRef, BackstoryText
    const shadowRef = useRef<HTMLTextAreaElement>(null);
    const [editValue, setEditValue] = useState<string>(value);

    console.log({ value, placeholder, editValue });
    // Sync editValue with prop value if it changes externally
    useEffect(() => {
        setEditValue(value || "");
@@ -158,7 +158,7 @@ function ChatBubble(props: ChatBubbleProps) {
    };

    // Render Accordion for expandable content
    if (expandable || (role === 'content' && title)) {
    if (expandable || title) {
        // Determine if Accordion is controlled
        const isControlled = typeof expanded === 'boolean' && typeof onExpand === 'function';
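The isControlled check above follows the standard controlled/uncontrolled React pattern: the Accordion manages its own open state only when the parent supplies neither expanded nor onExpand. A minimal usage sketch, assuming the prop names visible in this hunk (the wrapper components and imports are illustrative, not part of the commit):

    // Controlled usage: the parent owns the open state
    const ControlledExample = () => {
        const [open, setOpen] = useState<boolean>(false);
        return <ChatBubble role="content" title="Job Description" expandable expanded={open} onExpand={(o: boolean) => setOpen(o)} />;
    };

    // Uncontrolled usage: ChatBubble tracks expansion internally
    const UncontrolledExample = () => <ChatBubble role="content" title="Job Description" expandable />;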
@@ -1,13 +1,15 @@
import React, { useState, useImperativeHandle, forwardRef, useEffect, useRef, useCallback } from 'react';
import Typography from '@mui/material/Typography';
import Tooltip from '@mui/material/Tooltip';
import IconButton from '@mui/material/IconButton';
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import SendIcon from '@mui/icons-material/Send';
import CancelIcon from '@mui/icons-material/Cancel';
import { SxProps, Theme } from '@mui/material';
import PropagateLoader from "react-spinners/PropagateLoader";

import { Message, MessageList, MessageData } from './Message';
import { Message, MessageList, BackstoryMessage } from './Message';
import { ContextStatus } from './ContextStatus';
import { Scrollable } from './Scrollable';
import { DeleteConfirmation } from './DeleteConfirmation';
@@ -17,7 +19,7 @@ import { BackstoryTextField, BackstoryTextFieldRef } from './BackstoryTextField'
import { BackstoryElementProps } from './BackstoryTab';
import { connectionBase } from './Global';

const loadingMessage: MessageData = { "role": "status", "content": "Establishing connection with server..." };
const loadingMessage: BackstoryMessage = { "role": "status", "content": "Establishing connection with server..." };

type ConversationMode = 'chat' | 'job_description' | 'resume' | 'fact_check';
@@ -25,24 +27,6 @@ interface ConversationHandle {
    submitQuery: (prompt: string, options?: QueryOptions) => void;
    fetchHistory: () => void;
}
interface BackstoryMessage {
    prompt: string;
    preamble: {};
    status: string;
    full_content: string;
    response: string; // Set when status === 'done' or 'error'
    chunk: string; // Used when status === 'streaming'
    metadata: {
        rag: { documents: [] };
        tools: string[];
        eval_count: number;
        eval_duration: number;
        prompt_eval_count: number;
        prompt_eval_duration: number;
    };
    actions: string[];
    timestamp: string;
};

interface ConversationProps extends BackstoryElementProps {
    className?: string, // Override default className
@ -59,7 +43,7 @@ interface ConversationProps extends BackstoryElementProps {
|
||||
messageFilter?: ((messages: MessageList) => MessageList) | undefined, // Filter callback to determine which Messages to display in Conversation
|
||||
messages?: MessageList, //
|
||||
sx?: SxProps<Theme>,
|
||||
onResponse?: ((message: MessageData) => void) | undefined, // Event called when a query completes (provides messages)
|
||||
onResponse?: ((message: BackstoryMessage) => void) | undefined, // Event called when a query completes (provides messages)
|
||||
};
|
||||
|
||||
const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
@ -87,8 +71,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
const [countdown, setCountdown] = useState<number>(0);
|
||||
const [conversation, setConversation] = useState<MessageList>([]);
|
||||
const [filteredConversation, setFilteredConversation] = useState<MessageList>([]);
|
||||
const [processingMessage, setProcessingMessage] = useState<MessageData | undefined>(undefined);
|
||||
const [streamingMessage, setStreamingMessage] = useState<MessageData | undefined>(undefined);
|
||||
const [processingMessage, setProcessingMessage] = useState<BackstoryMessage | undefined>(undefined);
|
||||
const [streamingMessage, setStreamingMessage] = useState<BackstoryMessage | undefined>(undefined);
|
||||
const timerRef = useRef<any>(null);
|
||||
const [contextStatus, setContextStatus] = useState<ContextStatus>({ context_used: 0, max_context: 0 });
|
||||
const [contextWarningShown, setContextWarningShown] = useState<boolean>(false);
|
||||
@ -96,6 +80,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
const conversationRef = useRef<MessageList>([]);
|
||||
const viewableElementRef = useRef<HTMLDivElement>(null);
|
||||
const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
|
||||
const stopRef = useRef(false);
|
||||
|
||||
// Keep the ref updated whenever items changes
|
||||
useEffect(() => {
|
||||
@ -181,14 +166,25 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
|
||||
const backstoryMessages: BackstoryMessage[] = messages;
|
||||
|
||||
setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => [{
|
||||
setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => {
|
||||
if (backstoryMessage.status === "partial") {
|
||||
return [{
|
||||
...backstoryMessage,
|
||||
role: "assistant",
|
||||
content: backstoryMessage.response || "",
|
||||
expanded: false,
|
||||
expandable: true,
|
||||
}]
|
||||
}
|
||||
return [{
|
||||
role: 'user',
|
||||
content: backstoryMessage.prompt || "",
|
||||
}, {
|
||||
...backstoryMessage,
|
||||
role: backstoryMessage.status === "done" ? "assistant" : backstoryMessage.status,
|
||||
role: ['done'].includes(backstoryMessage.status || "") ? "assistant" : backstoryMessage.status,
|
||||
content: backstoryMessage.response || "",
|
||||
}] as MessageList));
|
||||
}] as MessageList;
|
||||
}));
|
||||
setNoInteractions(false);
|
||||
}
|
||||
setProcessingMessage(undefined);
|
||||
@ -294,6 +290,11 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
}
|
||||
};
|
||||
|
||||
const cancelQuery = () => {
|
||||
console.log("Stop query");
|
||||
stopRef.current = true;
|
||||
};
|
||||
|
||||
const sendQuery = async (request: string, options?: QueryOptions) => {
|
||||
request = request.trim();
|
||||
|
||||
@ -308,6 +309,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
return;
|
||||
}
|
||||
|
||||
stopRef.current = false;
|
||||
|
||||
setNoInteractions(false);
|
||||
|
||||
setConversation([
|
||||
@ -325,12 +328,10 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
|
||||
try {
|
||||
setProcessing(true);
|
||||
// Create a unique ID for the processing message
|
||||
const processingId = Date.now().toString();
|
||||
|
||||
// Add initial processing message
|
||||
setProcessingMessage(
|
||||
{ role: 'status', content: 'Submitting request...', id: processingId, isProcessing: true }
|
||||
{ role: 'status', content: 'Submitting request...', disableCopy: true }
|
||||
);
|
||||
|
||||
// Add a small delay to ensure React has time to update the UI
|
||||
@ -379,17 +380,20 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
|
||||
switch (update.status) {
|
||||
case 'done':
|
||||
console.log('Done processing:', update);
|
||||
stopCountdown();
|
||||
setStreamingMessage(undefined);
|
||||
setProcessingMessage(undefined);
|
||||
case 'partial':
|
||||
if (update.status === 'done') stopCountdown();
|
||||
if (update.status === 'done') setStreamingMessage(undefined);
|
||||
if (update.status === 'done') setProcessingMessage(undefined);
|
||||
const backstoryMessage: BackstoryMessage = update;
|
||||
setConversation([
|
||||
...conversationRef.current, {
|
||||
...backstoryMessage,
|
||||
role: 'assistant',
|
||||
origin: type,
|
||||
prompt: ['done', 'partial'].includes(update.status) ? update.prompt : '',
|
||||
content: backstoryMessage.response || "",
|
||||
expanded: update.status === "done" ? true : false,
|
||||
expandable: true,
|
||||
}] as MessageList);
|
||||
// Add a small delay to ensure React has time to update the UI
|
||||
await new Promise(resolve => setTimeout(resolve, 0));
|
||||
@ -424,9 +428,9 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
// Update processing message with immediate re-render
|
||||
if (update.status === "streaming") {
|
||||
streaming_response += update.chunk
|
||||
setStreamingMessage({ role: update.status, content: streaming_response });
|
||||
setStreamingMessage({ role: update.status, content: streaming_response, disableCopy: true });
|
||||
} else {
|
||||
setProcessingMessage({ role: update.status, content: update.response });
|
||||
setProcessingMessage({ role: update.status, content: update.response, disableCopy: true });
|
||||
/* Reset stream on non streaming message */
|
||||
streaming_response = ""
|
||||
}
|
||||
@ -437,12 +441,11 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
}
|
||||
}
|
||||
|
||||
while (true) {
|
||||
while (!stopRef.current) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
|
||||
// Process each complete line immediately
|
||||
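This hunk replaces the unconditional while (true) read loop with one gated on stopRef.current, which cancelQuery() sets; the following hunk cancels the reader once the loop exits. A condensed sketch of that flow under the same assumptions (a fetch response streaming newline-delimited JSON; only stopRef and cancelQuery come from the diff, the other names are illustrative):

    const stopRef = useRef(false);
    const cancelQuery = () => { stopRef.current = true; };

    const readStream = async (response: Response, onUpdate: (update: any) => void) => {
        const reader = response.body!.getReader();
        const decoder = new TextDecoder();
        let buffer = "";
        while (!stopRef.current) {
            const { done, value } = await reader.read();
            if (done) break;
            buffer += decoder.decode(value, { stream: true });
            const lines = buffer.split("\n");
            buffer = lines.pop() || "";                      // keep any partial trailing line
            for (const line of lines) {
                if (line.trim()) onUpdate(JSON.parse(line)); // one JSON status update per line
            }
        }
        if (stopRef.current) await reader.cancel();          // abandon the in-flight generation
        stopRef.current = false;
    };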
@@ -470,26 +473,32 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
        }
    }

    if (stopRef.current) {
        await reader.cancel();
        setProcessingMessage(undefined);
        setStreamingMessage(undefined);
        setSnack("Processing cancelled", "warning");
    }
    stopCountdown();
    setProcessing(false);
    stopRef.current = false;
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("Unable to process query", "error");
|
||||
setProcessingMessage({ role: 'error', content: "Unable to process query" });
|
||||
setProcessingMessage({ role: 'error', content: "Unable to process query", disableCopy: true });
|
||||
setTimeout(() => {
|
||||
setProcessingMessage(undefined);
|
||||
}, 5000);
|
||||
|
||||
stopRef.current = false;
|
||||
setProcessing(false);
|
||||
stopCountdown();
|
||||
// Add a small delay to ensure React has time to update the UI
|
||||
await new Promise(resolve => setTimeout(resolve, 0));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Scrollable
|
||||
className={className || "Conversation"}
|
||||
className={`${className || ""} Conversation`}
|
||||
autoscroll
|
||||
textFieldRef={viewableElementRef}
|
||||
fallbackThreshold={0.5}
|
||||
@ -564,6 +573,20 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
<Tooltip title="Cancel">
|
||||
<span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
|
||||
<IconButton
|
||||
aria-label="cancel"
|
||||
onClick={() => { cancelQuery(); }}
|
||||
sx={{ display: "flex", margin: 'auto 0px' }}
|
||||
size="large"
|
||||
edge="start"
|
||||
disabled={stopRef.current || sessionId === undefined || processing === false}
|
||||
>
|
||||
<CancelIcon />
|
||||
</IconButton>
|
||||
</span>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
</Box>
|
||||
{(noInteractions || !hideDefaultPrompts) && defaultPrompts !== undefined && defaultPrompts.length &&
|
||||
|
@@ -47,11 +47,18 @@ type MessageRoles =
    'thinking' |
    'user';

type MessageData = {
type BackstoryMessage = {
    // Only two required fields
    role: MessageRoles,
    content: string,
    status?: string, // streaming, done, error...
    response?: string,
    // Rest are optional
    prompt?: string;
    preamble?: {};
    status?: string;
    full_content?: string;
    response?: string; // Set when status === 'done', 'partial', or 'error'
    chunk?: string; // Used when status === 'streaming'
    timestamp?: string;
    disableCopy?: boolean,
    user?: string,
    title?: string,
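With the old MessageData and the inline interface from Conversation.tsx folded into this single BackstoryMessage type, only role and content are required; the streaming fields are filled in as server updates arrive. A brief sketch of both shapes, assuming the field names above (the example values are illustrative):

    // A transient status message, like the loadingMessage constant in Conversation.tsx
    const statusMessage: BackstoryMessage = {
        role: "status",
        content: "Establishing connection with server...",
        disableCopy: true,
    };

    // A completed assistant turn assembled from a streamed update
    const doneMessage: BackstoryMessage = {
        role: "assistant",
        content: "Here is the tailored resume...",
        status: "done",
        prompt: "Generate a resume for this job description",
        response: "Here is the tailored resume...",
        timestamp: new Date().toISOString(),
    };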
@ -84,11 +91,11 @@ interface MessageMetaData {
|
||||
setSnack: SetSnackType,
|
||||
}
|
||||
|
||||
type MessageList = MessageData[];
|
||||
type MessageList = BackstoryMessage[];
|
||||
|
||||
interface MessageProps extends BackstoryElementProps {
|
||||
sx?: SxProps<Theme>,
|
||||
message: MessageData,
|
||||
message: BackstoryMessage,
|
||||
expanded?: boolean,
|
||||
onExpand?: (open: boolean) => void,
|
||||
className?: string,
|
||||
@ -237,7 +244,7 @@ const MessageMeta = (props: MessageMetaProps) => {
|
||||
};
|
||||
|
||||
const Message = (props: MessageProps) => {
|
||||
const { message, submitQuery, sx, className, onExpand, expanded, sessionId, setSnack } = props;
|
||||
const { message, submitQuery, sx, className, onExpand, sessionId, setSnack } = props;
|
||||
const [metaExpanded, setMetaExpanded] = useState<boolean>(false);
|
||||
const textFieldRef = useRef(null);
|
||||
|
||||
@ -254,14 +261,16 @@ const Message = (props: MessageProps) => {
|
||||
return (<></>);
|
||||
}
|
||||
|
||||
const formattedContent = message.content.trim() || "Waiting for LLM to spool up...";
|
||||
const formattedContent = message.content.trim();
|
||||
if (formattedContent === "") {
|
||||
return (<></>);
|
||||
}
|
||||
|
||||
return (
|
||||
<ChatBubble
|
||||
className={className || "Message"}
|
||||
className={`${className || ""} Message Message-${message.role}`}
|
||||
{...message}
|
||||
onExpand={onExpand}
|
||||
expanded={expanded}
|
||||
sx={{
|
||||
display: "flex",
|
||||
flexDirection: "column",
|
||||
@ -273,7 +282,6 @@ const Message = (props: MessageProps) => {
|
||||
...sx,
|
||||
}}>
|
||||
<CardContent ref={textFieldRef} sx={{ position: "relative", display: "flex", flexDirection: "column", overflowX: "auto", m: 0, p: 0, paddingBottom: '0px !important' }}>
|
||||
{message.role !== 'user' ?
|
||||
<Scrollable
|
||||
className="MessageContent"
|
||||
autoscroll
|
||||
@ -289,18 +297,9 @@ const Message = (props: MessageProps) => {
|
||||
>
|
||||
<StyledMarkdown streaming={message.role === "streaming"} {...{ content: formattedContent, submitQuery, sessionId, setSnack }} />
|
||||
</Scrollable>
|
||||
:
|
||||
<Typography
|
||||
className="MessageContent"
|
||||
ref={textFieldRef}
|
||||
variant="body2"
|
||||
sx={{ display: "flex", color: 'text.secondary' }}>
|
||||
{message.content}
|
||||
</Typography>
|
||||
}
|
||||
</CardContent>
|
||||
<CardActions disableSpacing sx={{ display: "flex", flexDirection: "row", justifyContent: "space-between", alignItems: "center", width: "100%", p: 0, m: 0 }}>
|
||||
{(message.disableCopy === undefined || message.disableCopy === false) && ["assistant", "content"].includes(message.role) && <CopyBubble content={message.content} />}
|
||||
{(message.disableCopy === undefined || message.disableCopy === false) && <CopyBubble content={message.content} />}
|
||||
{message.metadata && (
|
||||
<Box sx={{ display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<Button variant="text" onClick={handleMetaExpandClick} sx={{ color: "darkgrey", p: 0 }}>
|
||||
@ -309,7 +308,7 @@ const Message = (props: MessageProps) => {
|
||||
<ExpandMore
|
||||
expand={metaExpanded}
|
||||
onClick={handleMetaExpandClick}
|
||||
aria-expanded={expanded}
|
||||
aria-expanded={message.expanded}
|
||||
aria-label="show more"
|
||||
>
|
||||
<ExpandMoreIcon />
|
||||
@ -331,7 +330,8 @@ const Message = (props: MessageProps) => {
|
||||
export type {
|
||||
MessageProps,
|
||||
MessageList,
|
||||
MessageData,
|
||||
BackstoryMessage,
|
||||
MessageMetaData,
|
||||
MessageRoles,
|
||||
};
|
||||
|
||||
|
@ -7,7 +7,7 @@ import {
|
||||
import { SxProps } from '@mui/material';
|
||||
|
||||
import { ChatQuery } from './ChatQuery';
|
||||
import { MessageList, MessageData } from './Message';
|
||||
import { MessageList, BackstoryMessage } from './Message';
|
||||
import { Conversation } from './Conversation';
|
||||
import { BackstoryPageProps } from './BackstoryTab';
|
||||
|
||||
@ -62,11 +62,6 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
|
||||
return [];
|
||||
}
|
||||
|
||||
if (messages.length > 2) {
|
||||
setHasResume(true);
|
||||
setHasFacts(true);
|
||||
}
|
||||
|
||||
if (messages.length > 0) {
|
||||
messages[0].role = 'content';
|
||||
messages[0].title = 'Job Description';
|
||||
@ -74,6 +69,19 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
|
||||
messages[0].expandable = true;
|
||||
}
|
||||
|
||||
if (-1 !== messages.findIndex(m => m.status === 'done')) {
|
||||
setHasResume(true);
|
||||
setHasFacts(true);
|
||||
}
|
||||
|
||||
return messages;
|
||||
|
||||
if (messages.length > 1) {
|
||||
setHasResume(true);
|
||||
setHasFacts(true);
|
||||
}
|
||||
|
||||
|
||||
if (messages.length > 3) {
|
||||
// messages[2] is Show job requirements
|
||||
messages[3].role = 'job-requirements';
|
||||
@ -95,6 +103,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
|
||||
return [];
|
||||
}
|
||||
|
||||
return messages;
|
||||
|
||||
if (messages.length > 1) {
|
||||
// messages[0] is Show Qualifications
|
||||
messages[1].role = 'qualifications';
|
||||
@ -139,7 +149,7 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
|
||||
return filtered;
|
||||
}, []);
|
||||
|
||||
const jobResponse = useCallback(async (message: MessageData) => {
|
||||
const jobResponse = useCallback(async (message: BackstoryMessage) => {
|
||||
console.log('onJobResponse', message);
|
||||
if (message.actions && message.actions.includes("job_description")) {
|
||||
await jobConversationRef.current.fetchHistory();
|
||||
@ -155,12 +165,12 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
|
||||
}
|
||||
}, [setHasFacts, setHasResume, setActiveTab]);
|
||||
|
||||
const resumeResponse = useCallback((message: MessageData): void => {
|
||||
const resumeResponse = useCallback((message: BackstoryMessage): void => {
|
||||
console.log('onResumeResponse', message);
|
||||
setHasFacts(true);
|
||||
}, [setHasFacts]);
|
||||
|
||||
const factsResponse = useCallback((message: MessageData): void => {
|
||||
const factsResponse = useCallback((message: BackstoryMessage): void => {
|
||||
console.log('onFactsResponse', message);
|
||||
}, []);
|
||||
|
||||
@ -199,7 +209,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = ({
|
||||
3. **Mapping Analysis**: Identifies legitimate matches between requirements and qualifications
|
||||
3. **Resume Generation**: Uses mapping output to create a tailored resume with evidence-based content
|
||||
4. **Verification**: Performs fact-checking to catch any remaining fabrications
|
||||
1. **Re-generation**: If verification does not pass, a second attempt is made to correct any issues`
|
||||
1. **Re-generation**: If verification does not pass, a second attempt is made to correct any issues`,
|
||||
disableCopy: true
|
||||
}];
|
||||
|
||||
|
||||
|
@ -17,3 +17,12 @@ pre:not(.MessageContent) {
|
||||
.MuiMarkdown > div {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.Message-streaming .MuiMarkdown ul,
|
||||
.Message-streaming .MuiMarkdown h1,
|
||||
.Message-streaming .MuiMarkdown p,
|
||||
.Message-assistant .MuiMarkdown ul,
|
||||
.Message-assistant .MuiMarkdown h1,
|
||||
.Message-assistant .MuiMarkdown p {
|
||||
color: white;
|
||||
}
|
@ -37,14 +37,14 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
|
||||
}
|
||||
if (className === "lang-json" && !streaming) {
|
||||
try {
|
||||
const fixed = jsonrepair(content);
|
||||
let fixed = JSON.parse(jsonrepair(content));
|
||||
return <Scrollable className="JsonViewScrollable">
|
||||
<JsonView
|
||||
className="JsonView"
|
||||
style={{
|
||||
...vscodeTheme,
|
||||
fontSize: "0.8rem",
|
||||
maxHeight: "20rem",
|
||||
maxHeight: "10rem",
|
||||
padding: "14px 0",
|
||||
overflow: "hidden",
|
||||
width: "100%",
|
||||
@ -53,9 +53,9 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
|
||||
}}
|
||||
displayDataTypes={false}
|
||||
objectSortKeys={false}
|
||||
collapsed={false}
|
||||
collapsed={true}
|
||||
shortenTextAfterLength={100}
|
||||
value={JSON.parse(fixed)}>
|
||||
value={fixed}>
|
||||
<JsonView.String
|
||||
render={({ children, ...reset }) => {
|
||||
if (typeof (children) === "string" && children.match("\n")) {
|
||||
@ -66,7 +66,7 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
|
||||
</JsonView>
|
||||
</Scrollable>
|
||||
} catch (e) {
|
||||
console.log("jsonrepair error", e);
|
||||
return <pre><code className="JsonRaw">{content}</code></pre>
|
||||
};
|
||||
}
|
||||
return <pre><code className={className}>{element.children}</code></pre>;
|
||||
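The lang-json branch now repairs and parses the payload once (let fixed = JSON.parse(jsonrepair(content))) instead of parsing inside the JSX, and falls back to a raw code block when repair fails. A minimal sketch of that guard, assuming the jsonrepair package already used by StyledMarkdown; the helper name is illustrative:

    import { jsonrepair } from "jsonrepair";

    const tryParseJson = (content: string): unknown | null => {
        try {
            // Repair common issues (trailing commas, single quotes) before parsing
            return JSON.parse(jsonrepair(content));
        } catch (e) {
            console.log("jsonrepair error", e);
            return null; // caller renders the raw <pre><code> block instead
        }
    };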
src/server.py (415 changed lines)
@ -1,8 +1,9 @@
|
||||
LLM_TIMEOUT = 600
|
||||
|
||||
from utils import logger
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Dict, Optional
|
||||
|
||||
# %%
|
||||
# Imports [standard]
|
||||
@ -26,6 +27,7 @@ from uuid import uuid4
|
||||
import time
|
||||
import traceback
|
||||
|
||||
|
||||
def try_import(module_name, pip_name=None):
|
||||
try:
|
||||
__import__(module_name)
|
||||
@ -33,6 +35,7 @@ def try_import(module_name, pip_name=None):
|
||||
print(f"Module '{module_name}' not found. Install it using:")
|
||||
print(f" pip install {pip_name or module_name}")
|
||||
|
||||
|
||||
# Third-party modules with import checks
|
||||
try_import("ollama")
|
||||
try_import("requests")
|
||||
@ -63,7 +66,9 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
from utils import (
|
||||
rag as Rag,
|
||||
tools as Tools,
|
||||
Context, Conversation, Message,
|
||||
Context,
|
||||
Conversation,
|
||||
Message,
|
||||
Agent,
|
||||
Metrics,
|
||||
Tunables,
|
||||
@ -74,11 +79,22 @@ from utils import (
|
||||
CONTEXT_VERSION = 2
|
||||
|
||||
rags = [
|
||||
{ "name": "JPK", "enabled": True, "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects." },
|
||||
{
|
||||
"name": "JPK",
|
||||
"enabled": True,
|
||||
"description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
|
||||
},
|
||||
# { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
|
||||
]
|
||||
|
||||
REQUEST_TIME = Summary('request_processing_seconds', 'Time spent processing request')

class QueryOptions(BaseModel):
    prompt: str
    tunables: Tunables = Field(default_factory=Tunables)
    agent_options: Dict[str, Any] = Field(default={})


REQUEST_TIME = Summary("request_processing_seconds", "Time spent processing request")
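QueryOptions is the new request body parsed by the agent endpoints later in this file; the frontend's sendQuery builds the matching JSON. A hedged sketch of the client-side call, assuming the field names above and the /api/{agent_type}/{context_id} route introduced further down in this diff (the payload values are illustrative and tunables must match the server's Tunables model):

    const payload = {
        prompt: "Summarize my experience with embedded Linux",
        tunables: {},        // optional overrides for the agent's Tunables
        agent_options: {},   // extra keyword arguments passed to get_or_create_agent()
    };

    const response = await fetch(`${connectionBase}/api/chat/${sessionId}`, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(payload),
    });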
system_message_old = f"""
|
||||
Launched on {datetime.now().isoformat()}.
|
||||
@ -107,6 +123,7 @@ You are provided with a <|resume|> which was generated by you, the <|context|> y
|
||||
Your task is to answer questions about the <|fact_check|> you generated based on the <|resume|> and <|context>.
|
||||
"""
|
||||
|
||||
|
||||
def get_installed_ram():
|
||||
try:
|
||||
with open("/proc/meminfo", "r") as f:
|
||||
@ -117,21 +134,29 @@ def get_installed_ram():
|
||||
except Exception as e:
|
||||
return f"Error retrieving RAM: {e}"
|
||||
|
||||
|
||||
def get_graphics_cards():
|
||||
gpus = []
|
||||
try:
|
||||
# Run the ze-monitor utility
|
||||
result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
|
||||
result = subprocess.run(
|
||||
["ze-monitor"], capture_output=True, text=True, check=True
|
||||
)
|
||||
|
||||
# Clean up the output (remove leading/trailing whitespace and newlines)
|
||||
output = result.stdout.strip()
|
||||
for index in range(len(output.splitlines())):
|
||||
result = subprocess.run(["ze-monitor", "--device", f"{index+1}", "--info"], capture_output=True, text=True, check=True)
|
||||
result = subprocess.run(
|
||||
["ze-monitor", "--device", f"{index+1}", "--info"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
gpu_info = result.stdout.strip().splitlines()
|
||||
gpu = {
|
||||
"discrete": True, # Assume it's discrete initially
|
||||
"name": None,
|
||||
"memory": None
|
||||
"memory": None,
|
||||
}
|
||||
gpus.append(gpu)
|
||||
for line in gpu_info:
|
||||
@ -154,6 +179,7 @@ def get_graphics_cards():
|
||||
except Exception as e:
|
||||
return f"Error retrieving GPU info: {e}"
|
||||
|
||||
|
||||
def get_cpu_info():
|
||||
try:
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
@ -165,6 +191,7 @@ def get_cpu_info():
|
||||
except Exception as e:
|
||||
return f"Error retrieving CPU info: {e}"
|
||||
|
||||
|
||||
def system_info(model):
|
||||
return {
|
||||
"System RAM": get_installed_ram(),
|
||||
@ -172,9 +199,10 @@ def system_info(model):
|
||||
"CPU": get_cpu_info(),
|
||||
"LLM Model": model,
|
||||
"Embedding Model": defines.embedding_model,
|
||||
"Context length": defines.max_context
|
||||
"Context length": defines.max_context,
|
||||
}
|
||||
|
||||
|
||||
# %%
|
||||
# Defaults
|
||||
OLLAMA_API_URL = defines.ollama_api_url
|
||||
@ -192,31 +220,58 @@ DEFAULT_HISTORY_LENGTH=5
|
||||
def create_system_message(prompt):
|
||||
return [{"role": "system", "content": prompt}]
|
||||
|
||||
|
||||
tool_log = []
|
||||
command_log = []
|
||||
model = None
|
||||
client = None
|
||||
web_server = None
|
||||
|
||||
|
||||
# %%
|
||||
# Cmd line overrides
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="AI is Really Cool")
|
||||
parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
|
||||
parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
|
||||
parser.add_argument("--web-host", type=str, default=WEB_HOST, help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.")
|
||||
parser.add_argument("--web-port", type=str, default=WEB_PORT, help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.")
|
||||
parser.add_argument("--level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
|
||||
parser.add_argument(
|
||||
"--ollama-server",
|
||||
type=str,
|
||||
default=OLLAMA_API_URL,
|
||||
help=f"Ollama API endpoint. default={OLLAMA_API_URL}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ollama-model",
|
||||
type=str,
|
||||
default=MODEL_NAME,
|
||||
help=f"LLM model to use. default={MODEL_NAME}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--web-host",
|
||||
type=str,
|
||||
default=WEB_HOST,
|
||||
help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--web-port",
|
||||
type=str,
|
||||
default=WEB_PORT,
|
||||
help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--level",
|
||||
type=str,
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
default=LOG_LEVEL,
|
||||
help=f"Set the logging level. default={LOG_LEVEL}",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
|
||||
# %%
|
||||
def is_valid_uuid(value: str) -> bool:
|
||||
try:
|
||||
@ -226,11 +281,6 @@ def is_valid_uuid(value: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
class WebServer:
|
||||
@asynccontextmanager
|
||||
@ -239,9 +289,11 @@ class WebServer:
|
||||
self.observer, self.file_watcher = Rag.start_file_watcher(
|
||||
llm=self.llm,
|
||||
watch_directory=defines.doc_dir,
|
||||
recreate=False # Don't recreate if exists
|
||||
recreate=False, # Don't recreate if exists
|
||||
)
|
||||
logger.info(
|
||||
f"API started with {self.file_watcher.collection.count()} documents in the collection"
|
||||
)
|
||||
logger.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
|
||||
yield
|
||||
if self.observer:
|
||||
self.observer.stop()
|
||||
@ -271,7 +323,9 @@ class WebServer:
|
||||
self.file_watcher = None
|
||||
self.observer = None
|
||||
|
||||
self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(defines.cert_path)
|
||||
self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(
|
||||
defines.cert_path
|
||||
)
|
||||
|
||||
if self.ssl_enabled:
|
||||
allow_origins = ["https://battle-linux.ketrenos.com:3000"]
|
||||
@ -307,14 +361,18 @@ class WebServer:
|
||||
|
||||
context = self.upsert_context(context_id)
|
||||
if not context:
|
||||
return JSONResponse({"error": f"Invalid context: {context_id}"}, status_code=400)
|
||||
return JSONResponse(
|
||||
{"error": f"Invalid context: {context_id}"}, status_code=400
|
||||
)
|
||||
|
||||
data = await request.json()
|
||||
|
||||
dimensions = data.get("dimensions", 2)
|
||||
result = self.file_watcher.umap_collection
|
||||
if not result:
|
||||
return JSONResponse({"error": "No UMAP collection found"}, status_code=404)
|
||||
return JSONResponse(
|
||||
{"error": "No UMAP collection found"}, status_code=404
|
||||
)
|
||||
if dimensions == 2:
|
||||
logger.info("Returning 2D UMAP")
|
||||
umap_embedding = self.file_watcher.umap_embedding_2d
|
||||
@ -323,7 +381,9 @@ class WebServer:
|
||||
umap_embedding = self.file_watcher.umap_embedding_3d
|
||||
|
||||
if len(umap_embedding) == 0:
|
||||
return JSONResponse({"error": "No UMAP embedding found"}, status_code=404)
|
||||
return JSONResponse(
|
||||
{"error": "No UMAP embedding found"}, status_code=404
|
||||
)
|
||||
|
||||
result["embeddings"] = umap_embedding.tolist()
|
||||
|
||||
@ -347,35 +407,56 @@ class WebServer:
|
||||
try:
|
||||
data = await request.json()
|
||||
query = data.get("query", "")
|
||||
threshold = data.get("threshold", 0.5)
|
||||
results = data.get("results", 10)
|
||||
except:
|
||||
query = ""
|
||||
threshold = 0.5
|
||||
results = 10
|
||||
if not query:
|
||||
return JSONResponse({"error": "No query provided for similarity search"}, status_code=400)
|
||||
|
||||
return JSONResponse(
|
||||
{"error": "No query provided for similarity search"},
|
||||
status_code=400,
|
||||
)
|
||||
try:
|
||||
chroma_results = self.file_watcher.find_similar(query=query, top_k=10)
|
||||
chroma_results = self.file_watcher.find_similar(
|
||||
query=query, top_k=results, threshold=threshold
|
||||
)
|
||||
if not chroma_results:
|
||||
return JSONResponse({"error": "No results found"}, status_code=404)
|
||||
|
||||
chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
|
||||
chroma_embedding = np.array(
|
||||
chroma_results["query_embedding"]
|
||||
).flatten() # Ensure correct shape
|
||||
logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
|
||||
|
||||
umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
|
||||
logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
|
||||
umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[
|
||||
0
|
||||
].tolist()
|
||||
logger.info(
|
||||
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
|
||||
) # Debug output
|
||||
|
||||
umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
|
||||
logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
|
||||
umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[
|
||||
0
|
||||
].tolist()
|
||||
logger.info(
|
||||
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
|
||||
) # Debug output
|
||||
|
||||
return JSONResponse({
|
||||
return JSONResponse(
|
||||
{
|
||||
**chroma_results,
|
||||
"query": query,
|
||||
"umap_embedding_2d": umap_2d,
|
||||
"umap_embedding_3d": umap_3d
|
||||
})
|
||||
"umap_embedding_3d": umap_3d,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
#return JSONResponse({"error": str(e)}, 500)
|
||||
logging.error(traceback.format_exc())
|
||||
return JSONResponse({"error": str(e)}, 500)
|
||||
|
||||
@self.app.put("/api/reset/{context_id}/{agent_type}")
|
||||
async def put_reset(context_id: str, agent_type: str, request: Request):
|
||||
@ -386,7 +467,10 @@ class WebServer:
|
||||
context = self.upsert_context(context_id)
|
||||
agent = context.get_agent(agent_type)
|
||||
if not agent:
|
||||
return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)
|
||||
return JSONResponse(
|
||||
{"error": f"{agent_type} is not recognized", "context": context.id},
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
data = await request.json()
|
||||
try:
|
||||
@ -405,20 +489,29 @@ class WebServer:
|
||||
response["tools"] = context.tools
|
||||
case "history":
|
||||
reset_map = {
|
||||
"job_description": ("job_description", "resume", "fact_check"),
|
||||
"job_description": (
|
||||
"job_description",
|
||||
"resume",
|
||||
"fact_check",
|
||||
),
|
||||
"resume": ("job_description", "resume", "fact_check"),
|
||||
"fact_check": ("job_description", "resume", "fact_check"),
|
||||
"fact_check": (
|
||||
"job_description",
|
||||
"resume",
|
||||
"fact_check",
|
||||
),
|
||||
"chat": ("chat",),
|
||||
}
|
||||
resets = reset_map.get(agent_type, ())
|
||||
|
||||
for mode in resets:
|
||||
tmp = context.get_agent(mode)
|
||||
if not tmp:
|
||||
logger.info(
|
||||
f"Agent {mode} not found for {context_id}"
|
||||
)
|
||||
continue
|
||||
logger.info(f"Resetting {reset_operation} for {mode}")
|
||||
context.conversation = Conversation()
|
||||
context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage
|
||||
tmp.conversation.reset()
|
||||
response["history"] = []
|
||||
response["context_used"] = agent.context_tokens
|
||||
case "message_history_length":
|
||||
@ -427,13 +520,19 @@ class WebServer:
|
||||
response["message_history_length"] = DEFAULT_HISTORY_LENGTH
|
||||
|
||||
if not response:
|
||||
return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system_prompt}"})
|
||||
return JSONResponse(
|
||||
{"error": "Usage: { reset: rags|tools|history|system_prompt}"}
|
||||
)
|
||||
else:
|
||||
self.save_context(context_id)
|
||||
return JSONResponse(response)
|
||||
|
||||
except:
|
||||
return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system_prompt}"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in reset: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return JSONResponse(
|
||||
{"error": "Usage: { reset: rags|tools|history|system_prompt}"}
|
||||
)
|
||||
|
||||
@self.app.put("/api/tunables/{context_id}")
|
||||
async def put_tunables(context_id: str, request: Request):
|
||||
@ -444,29 +543,49 @@ class WebServer:
|
||||
data = await request.json()
|
||||
agent = context.get_agent("chat")
|
||||
if not agent:
|
||||
return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
|
||||
return JSONResponse(
|
||||
{"error": f"chat is not recognized", "context": context.id},
|
||||
status_code=404,
|
||||
)
|
||||
for k in data.keys():
|
||||
match k:
|
||||
case "tools":
|
||||
# { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
||||
tools: list[dict[str, Any]] = data[k]
|
||||
if not tools:
|
||||
return JSONResponse({ "status": "error", "message": "Tools can not be empty." })
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Tools can not be empty.",
|
||||
}
|
||||
)
|
||||
for tool in tools:
|
||||
for context_tool in context.tools:
|
||||
if context_tool["function"]["name"] == tool["name"]:
|
||||
context_tool["enabled"] = tool["enabled"]
|
||||
self.save_context(context_id)
|
||||
return JSONResponse({ "tools": [ {
|
||||
return JSONResponse(
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
**t["function"],
|
||||
"enabled": t["enabled"],
|
||||
} for t in context.tools] })
|
||||
}
|
||||
for t in context.tools
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
case "rags":
|
||||
# { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
||||
rags: list[dict[str, Any]] = data[k]
|
||||
if not rags:
|
||||
return JSONResponse({ "status": "error", "message": "RAGs can not be empty." })
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "RAGs can not be empty.",
|
||||
}
|
||||
)
|
||||
for rag in rags:
|
||||
for context_rag in context.rags:
|
||||
if context_rag["name"] == rag["name"]:
|
||||
@ -477,7 +596,12 @@ class WebServer:
|
||||
case "system_prompt":
|
||||
system_prompt = data[k].strip()
|
||||
if not system_prompt:
|
||||
return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "System prompt can not be empty.",
|
||||
}
|
||||
)
|
||||
agent.system_prompt = system_prompt
|
||||
self.save_context(context_id)
|
||||
return JSONResponse({"system_prompt": system_prompt})
|
||||
@ -487,7 +611,9 @@ class WebServer:
|
||||
self.save_context(context_id)
|
||||
return JSONResponse({"message_history_length": value})
|
||||
case _:
|
||||
return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404)
|
||||
return JSONResponse(
|
||||
{"error": f"Unrecognized tunable {k}"}, status_code=404
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in put_tunables: {e}")
|
||||
return JSONResponse({"error": str(e)}, status_code=500)
|
||||
@ -501,85 +627,94 @@ class WebServer:
|
||||
context = self.upsert_context(context_id)
|
||||
agent = context.get_agent("chat")
|
||||
if not agent:
|
||||
return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
|
||||
return JSONResponse({
|
||||
return JSONResponse(
|
||||
{"error": f"chat is not recognized", "context": context.id},
|
||||
status_code=404,
|
||||
)
|
||||
return JSONResponse(
|
||||
{
|
||||
"system_prompt": agent.system_prompt,
|
||||
"message_history_length": context.message_history_length,
|
||||
"rags": context.rags,
|
||||
"tools": [ {
|
||||
"tools": [
|
||||
{
|
||||
**t["function"],
|
||||
"enabled": t["enabled"],
|
||||
} for t in context.tools ]
|
||||
})
|
||||
}
|
||||
for t in context.tools
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@self.app.get("/api/system-info/{context_id}")
|
||||
async def get_system_info(context_id: str, request: Request):
|
||||
logger.info(f"{request.method} {request.url.path}")
|
||||
return JSONResponse(system_info(self.model))
|
||||
|
||||
@self.app.post("/api/chat/{context_id}/{agent_type}")
|
||||
async def post_chat_endpoint(context_id: str, agent_type: str, request: Request):
|
||||
@self.app.post("/api/{agent_type}/{context_id}")
|
||||
async def post_agent_endpoint(
|
||||
agent_type: str, context_id: str, request: Request
|
||||
):
|
||||
logger.info(f"{request.method} {request.url.path}")
|
||||
if not is_valid_uuid(context_id):
|
||||
logger.warning(f"Invalid context_id: {context_id}")
|
||||
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
||||
|
||||
try:
|
||||
context = self.upsert_context(context_id)
|
||||
agent = context.get_agent(agent_type)
|
||||
except Exception as e:
|
||||
logger.info(f"Attempt to create agent type: {agent_type} failed", e)
|
||||
return JSONResponse({"error": f"{agent_type} is not recognized or context {context_id} is invalid "}, status_code=404)
|
||||
error = {
|
||||
"error": f"Unable to create or access context {context_id}: {e}"
|
||||
}
|
||||
logger.info(error)
|
||||
return JSONResponse(error, status_code=404)
|
||||
|
||||
try:
|
||||
query = await request.json()
|
||||
prompt = query["prompt"]
|
||||
if not isinstance(prompt, str) or len(prompt) == 0:
|
||||
logger.info(f"Prompt is empty")
|
||||
return JSONResponse({"error": "Prompt cannot be empty"}, status_code=400)
|
||||
data = await request.json()
|
||||
query: QueryOptions = QueryOptions(**data)
|
||||
except Exception as e:
|
||||
logger.info(f"Attempt to parse request: {str(e)}.")
|
||||
return JSONResponse({"error": f"Attempt to parse request: {str(e)}."}, status_code=400)
|
||||
error = {"error": f"Attempt to parse request: {e}"}
|
||||
logger.info(error)
|
||||
return JSONResponse(error, status_code=400)
|
||||
|
||||
try:
|
||||
options = Tunables(**query["options"]) if "options" in query else None
|
||||
agent = context.get_or_create_agent(agent_type, **query.agent_options)
|
||||
except Exception as e:
|
||||
logger.info(f"Attempt to set tunables failed: {query['options']}.", e)
|
||||
return JSONResponse({"error": f"Invalid options: {query['options']}"}, status_code=400)
|
||||
error = {
|
||||
"error": f"Attempt to create agent type: {agent_type} failed: {e}"
|
||||
}
|
||||
return JSONResponse(error, status_code=404)
|
||||
|
||||
if not agent:
|
||||
match agent_type:
|
||||
case "job_description":
|
||||
logger.info(f"Agent {agent_type} not found. Returning empty history.")
|
||||
agent = context.get_or_create_agent("job_description", job_description=prompt)
|
||||
case _:
|
||||
logger.info(f"Invalid agent creation sequence for {agent_type}. Returning error.")
|
||||
return JSONResponse({"error": f"{agent_type} is not recognized", "context": context.id}, status_code=404)
|
||||
try:
|
||||
|
||||
async def flush_generator():
|
||||
logger.info(f"{agent.agent_type} - {inspect.stack()[0].function}")
|
||||
try:
|
||||
start_time = time.perf_counter()
|
||||
async for message in self.generate_response(context=context, agent=agent, prompt=prompt, options=options):
|
||||
if message.status != "done":
|
||||
async for message in self.generate_response(
|
||||
context=context,
|
||||
agent=agent,
|
||||
prompt=query.prompt,
|
||||
options=query.options,
|
||||
):
|
||||
if message.status != "done" and message.status != "partial":
|
||||
if message.status == "streaming":
|
||||
result = {
|
||||
"status": "streaming",
|
||||
"chunk": message.chunk,
|
||||
"remaining_time": LLM_TIMEOUT - (time.perf_counter() - start_time)
|
||||
"remaining_time": LLM_TIMEOUT
|
||||
- (time.perf_counter() - start_time),
|
||||
}
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
result = {
|
||||
"status": message.status,
|
||||
"response": message.response,
|
||||
"remaining_time": LLM_TIMEOUT
|
||||
"remaining_time": LLM_TIMEOUT,
|
||||
}
|
||||
else:
|
||||
logger.info(f"Message complete. Providing full response.")
|
||||
logger.info(f"Providing {message.status} response.")
|
||||
try:
|
||||
message.response = message.response
|
||||
result = message.model_dump(by_alias=True, mode='json')
|
||||
result = message.model_dump(
|
||||
by_alias=True, mode="json"
|
||||
)
|
||||
except Exception as e:
|
||||
result = {"status": "error", "response": str(e)}
|
||||
yield json.dumps(result) + "\n"
|
||||
@ -589,26 +724,27 @@ class WebServer:
|
||||
result = json.dumps(result) + "\n"
|
||||
message.network_packets += 1
|
||||
message.network_bytes += len(result)
|
||||
yield result
|
||||
if await request.is_disconnected():
|
||||
logger.info("Disconnect detected. Aborting generation.")
|
||||
context.processing = False
|
||||
# Save context on completion or error
|
||||
message.prompt = prompt
|
||||
message.prompt = query.prompt
|
||||
message.status = "error"
|
||||
message.response = "Client disconnected during generation."
|
||||
message.response = (
|
||||
"Client disconnected during generation."
|
||||
)
|
||||
agent.conversation.add(message)
|
||||
self.save_context(context_id)
|
||||
return
|
||||
|
||||
yield result
|
||||
|
||||
current_time = time.perf_counter()
|
||||
if current_time - start_time > LLM_TIMEOUT:
|
||||
message.status = "error"
|
||||
message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query."
|
||||
message.partial_response = message.response
|
||||
logger.info(message.response + " Ending session")
|
||||
result = message.model_dump(by_alias=True, mode='json')
|
||||
result = message.model_dump(by_alias=True, mode="json")
|
||||
result = json.dumps(result) + "\n"
|
||||
yield result
|
||||
|
||||
@ -620,7 +756,7 @@ class WebServer:
|
||||
await asyncio.sleep(0)
|
||||
except Exception as e:
|
||||
context.processing = False
|
||||
logger.error(f"Error in process_generator: {e}")
|
||||
logger.error(f"Error in generate_response: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
yield json.dumps({"status": "error", "response": str(e)}) + "\n"
|
||||
finally:
|
||||
@ -634,8 +770,8 @@ class WebServer:
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # Prevents Nginx buffering if you're using it
|
||||
}
|
||||
"X-Accel-Buffering": "no", # Prevents Nginx buffering if you're using it
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
context.processing = False
|
||||
@ -660,13 +796,18 @@ class WebServer:
|
||||
context = self.upsert_context(context_id)
|
||||
agent = context.get_agent(agent_type)
|
||||
if not agent:
|
||||
logger.info(f"Agent {agent_type} not found. Returning empty history.")
|
||||
logger.info(
|
||||
f"Agent {agent_type} not found. Returning empty history."
|
||||
)
|
||||
return JSONResponse({"messages": []})
|
||||
logger.info(f"History for {agent_type} contains {len(agent.conversation)} entries.")
|
||||
logger.info(
|
||||
f"History for {agent_type} contains {len(agent.conversation)} entries."
|
||||
)
|
||||
return agent.conversation
|
||||
except Exception as e:
|
||||
logger.error(f"get_history error: {str(e)}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return JSONResponse({"error": str(e)}, status_code=404)
|
||||
|
||||
@ -692,11 +833,12 @@ class WebServer:
|
||||
tool["enabled"] = enabled
|
||||
self.save_context(context_id)
|
||||
return JSONResponse(context.tools)
|
||||
return JSONResponse({ "status": f"{modify} not found in tools." }, status_code=404)
|
||||
return JSONResponse(
|
||||
{"status": f"{modify} not found in tools."}, status_code=404
|
||||
)
|
||||
except:
|
||||
return JSONResponse({"status": "error"}, 405)
|
||||
|
||||
|
||||
@self.app.get("/api/context-status/{context_id}/{agent_type}")
|
||||
async def get_context_status(context_id, agent_type: str, request: Request):
|
||||
logger.info(f"{request.method} {request.url.path}")
|
||||
@ -706,8 +848,15 @@ class WebServer:
|
||||
context = self.upsert_context(context_id)
|
||||
agent = context.get_agent(agent_type)
|
||||
if not agent:
|
||||
return JSONResponse({"context_used": 0, "max_context": defines.max_context})
|
||||
return JSONResponse({"context_used": agent.context_tokens, "max_context": defines.max_context})
|
||||
return JSONResponse(
|
||||
{"context_used": 0, "max_context": defines.max_context}
|
||||
)
|
||||
return JSONResponse(
|
||||
{
|
||||
"context_used": agent.context_tokens,
|
||||
"max_context": defines.max_context,
|
||||
}
|
||||
)
|
||||
|
||||
@self.app.get("/api/health")
|
||||
async def health_check():
|
||||
@ -770,8 +919,11 @@ class WebServer:
|
||||
# Read and deserialize the data
|
||||
with open(file_path, "r") as f:
|
||||
content = f.read()
|
||||
logger.info(f"Loading context from {file_path}, content length: {len(content)}")
|
||||
logger.info(
|
||||
f"Loading context from {file_path}, content length: {len(content)}"
|
||||
)
|
||||
import json
|
||||
|
||||
try:
|
||||
# Try parsing as JSON first to ensure valid JSON
|
||||
json_data = json.loads(content)
|
||||
@ -787,7 +939,9 @@ class WebServer:
|
||||
# Now set context on agents manually
|
||||
agent_types = [agent.agent_type for agent in context.agents]
|
||||
if len(agent_types) != len(set(agent_types)):
|
||||
raise ValueError("Context cannot contain multiple agents of the same agent_type")
|
||||
raise ValueError(
|
||||
"Context cannot contain multiple agents of the same agent_type"
|
||||
)
|
||||
for agent in context.agents:
|
||||
agent.set_context(context)
|
||||
|
||||
@ -799,9 +953,14 @@ class WebServer:
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating context: {str(e)}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
# Fallback to creating a new context
|
||||
self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher, prometheus_collector=self.prometheus_collector)
|
||||
self.contexts[context_id] = Context(
|
||||
id=context_id,
|
||||
file_watcher=self.file_watcher,
|
||||
prometheus_collector=self.prometheus_collector,
|
||||
)
|
||||
|
||||
return self.contexts[context_id]
|
||||
|
||||
@ -818,7 +977,11 @@ class WebServer:
|
||||
if not context_id:
|
||||
context_id = str(uuid4())
|
||||
logger.info(f"Creating new context with ID: {context_id}")
|
||||
context = Context(id=context_id, file_watcher=self.file_watcher, prometheus_collector=self.prometheus_collector)
|
||||
context = Context(
|
||||
id=context_id,
|
||||
file_watcher=self.file_watcher,
|
||||
prometheus_collector=self.prometheus_collector,
|
||||
)
|
||||
|
||||
if os.path.exists(defines.resume_doc):
|
||||
context.user_resume = open(defines.resume_doc, "r").read()
|
||||
@ -855,7 +1018,9 @@ class WebServer:
|
||||
return self.load_or_create_context(context_id)
|
||||
|
||||
@REQUEST_TIME.time()
|
||||
async def generate_response(self, context : Context, agent : Agent, prompt : str, options: Tunables | None) -> AsyncGenerator[Message, None]:
|
||||
async def generate_response(
|
||||
self, context: Context, agent: Agent, prompt: str, options: Tunables | None
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
if not self.file_watcher:
|
||||
raise Exception("File watcher not initialized")
|
||||
|
||||
@ -880,7 +1045,9 @@ class WebServer:
|
||||
if message.status == "error":
|
||||
return
|
||||
|
||||
logger.info(f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
|
||||
logger.info(
|
||||
f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}"
|
||||
)
|
||||
message.status = "done"
|
||||
yield message
|
||||
return
|
||||
@ -895,24 +1062,21 @@ class WebServer:
|
||||
port=port,
|
||||
log_config=None,
|
||||
ssl_keyfile=defines.key_path,
|
||||
ssl_certfile=defines.cert_path
|
||||
ssl_certfile=defines.cert_path,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Starting web server at http://{host}:{port}")
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_config=None
|
||||
)
|
||||
uvicorn.run(self.app, host=host, port=port, log_config=None)
|
||||
except KeyboardInterrupt:
|
||||
if self.observer:
|
||||
self.observer.stop()
|
||||
if self.observer:
|
||||
self.observer.join()
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
|
||||
# Main function to run everything
|
||||
def main():
|
||||
global model
|
||||
@ -923,17 +1087,9 @@ def main():
|
||||
# Setup logging based on the provided level
|
||||
logger.setLevel(args.level.upper())
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=FutureWarning,
|
||||
module="sklearn.*"
|
||||
)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.*")
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=UserWarning,
|
||||
module="umap.*"
|
||||
)
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")
|
||||
|
||||
llm = ollama.Client(host=args.ollama_server) # type: ignore
|
||||
model = args.ollama_model
|
||||
@ -942,4 +1098,5 @@ def main():
|
||||
|
||||
web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)
|
||||
|
||||
|
||||
main()
|
||||
|
@ -2,28 +2,24 @@ from .. utils import logger
|
||||
|
||||
import ollama
|
||||
|
||||
from .. utils import (
|
||||
rag as Rag,
|
||||
Context,
|
||||
defines
|
||||
)
|
||||
from ..utils import rag as Rag, Context, defines
|
||||
|
||||
import json
|
||||
|
||||
llm = ollama.Client(host=defines.ollama_api_url)
|
||||
|
||||
observer, file_watcher = Rag.start_file_watcher(
|
||||
llm=llm,
|
||||
watch_directory=defines.doc_dir,
|
||||
recreate=False # Don't recreate if exists
|
||||
llm=llm, watch_directory=defines.doc_dir, recreate=False # Don't recreate if exists
|
||||
)
|
||||
|
||||
context = Context(file_watcher=file_watcher)
|
||||
data = context.model_dump(mode='json')
|
||||
data = context.model_dump(mode="json")
|
||||
context = Context.model_validate_json(json.dumps(data))
|
||||
context.file_watcher = file_watcher
|
||||
|
||||
agent = context.get_or_create_agent("chat", system_prompt="You are a helpful assistant.")
|
||||
agent = context.get_or_create_agent(
|
||||
"chat", system_prompt="You are a helpful assistant."
|
||||
)
|
||||
# logger.info(f"data: {data}")
|
||||
# logger.info(f"agent: {agent}")
|
||||
agent_type = agent.get_agent_type()
|
||||
@ -32,7 +28,7 @@ logger.info(f"system_prompt: {agent.system_prompt}")
|
||||
|
||||
agent.system_prompt = "Eat more tomatoes."
|
||||
|
||||
data = context.model_dump(mode='json')
|
||||
data = context.model_dump(mode="json")
|
||||
context = Context.model_validate_json(json.dumps(data))
|
||||
context.file_watcher = file_watcher
|
||||
|
||||
|
@ -8,12 +8,14 @@ from anyio.to_thread import run_sync # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedirectToContext(Exception):
|
||||
def __init__(self, url: str):
|
||||
self.url = url
|
||||
logger.info(f"Redirect to Context: {url}")
|
||||
super().__init__(f"Redirect to Context: {url}")
|
||||
|
||||
|
||||
class ContextRouteManager:
|
||||
def __init__(self, app: FastAPI):
|
||||
self.app = app
|
||||
@ -25,20 +27,28 @@ class ContextRouteManager:
|
||||
logger.info(f"Handling redirect to {exc.url}")
|
||||
return RedirectResponse(url=exc.url, status_code=307)
|
||||
|
||||
def ensure_context(self, route_name: str = "context_id") -> Callable[[Request], Optional[UUID]]:
|
||||
def ensure_context(
|
||||
self, route_name: str = "context_id"
|
||||
) -> Callable[[Request], Optional[UUID]]:
|
||||
logger.info(f"Setting up context dependency for route parameter: {route_name}")
|
||||
|
||||
async def _ensure_context_dependency(request: Request) -> Optional[UUID]:
|
||||
logger.info(f"Entering ensure_context_dependency, Request URL: {request.url}")
|
||||
logger.info(
|
||||
f"Entering ensure_context_dependency, Request URL: {request.url}"
|
||||
)
|
||||
logger.info(f"Path params: {request.path_params}")
|
||||
|
||||
path_params = request.path_params
|
||||
route_value = path_params.get(route_name)
|
||||
logger.info(f"route_value: {route_value!r}, type: {type(route_value)}")
|
||||
|
||||
if route_value is None or not isinstance(route_value, str) or not route_value.strip():
|
||||
if (
|
||||
route_value is None
|
||||
or not isinstance(route_value, str)
|
||||
or not route_value.strip()
|
||||
):
|
||||
logger.info(f"route_value is invalid, generating new UUID")
|
||||
path = request.url.path.rstrip('/')
|
||||
path = request.url.path.rstrip("/")
|
||||
new_context = await run_sync(uuid4)
|
||||
redirect_url = f"{path}/{new_context}"
|
||||
logger.info(f"Redirecting to {redirect_url}")
|
||||
@ -50,8 +60,10 @@ class ContextRouteManager:
|
||||
logger.info(f"Successfully parsed UUID: {route_context}")
|
||||
return route_context
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse UUID from route_value: {route_value!r}, error: {str(e)}")
|
||||
path = request.url.path.rstrip('/')
|
||||
logger.error(
|
||||
f"Failed to parse UUID from route_value: {route_value!r}, error: {str(e)}"
|
||||
)
|
||||
path = request.url.path.rstrip("/")
|
||||
new_context = await run_sync(uuid4)
|
||||
redirect_url = f"{path}/{new_context}"
|
||||
logger.info(f"Invalid UUID, redirecting to {redirect_url}")
|
||||
@ -66,44 +78,62 @@ class ContextRouteManager:
|
||||
def decorator(func):
|
||||
all_dependencies = list(dependencies)
|
||||
all_dependencies.append(Depends(ensure_context))
|
||||
logger.info(f"Route {path} registered with dependencies: {all_dependencies}")
|
||||
logger.info(
|
||||
f"Route {path} registered with dependencies: {all_dependencies}"
|
||||
)
|
||||
return self.app.get(path, dependencies=all_dependencies, **kwargs)(func)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
app = FastAPI(redirect_slashes=True)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Unhandled exception: {str(exc)}")
|
||||
logger.error(f"Request URL: {request.url}, Path params: {request.path_params}")
|
||||
logger.error(f"Stack trace: {''.join(traceback.format_tb(exc.__traceback__))}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "Internal server error", "detail": str(exc)}
|
||||
status_code=500, content={"error": "Internal server error", "detail": str(exc)}
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
logger.info(f"Incoming request: {request.method} {request.url}, Path params: {request.path_params}")
|
||||
logger.info(
|
||||
f"Incoming request: {request.method} {request.url}, Path params: {request.path_params}"
|
||||
)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
context_router = ContextRouteManager(app)
|
||||
|
||||
|
||||
@context_router.route_pattern("/api/history/{context_id}")
|
||||
async def get_history(request: Request, context_id: UUID = Depends(context_router.ensure_context()), agent_type: str = Query(..., description="Type of agent to retrieve history for")):
|
||||
async def get_history(
|
||||
request: Request,
|
||||
context_id: UUID = Depends(context_router.ensure_context()),
|
||||
agent_type: str = Query(..., description="Type of agent to retrieve history for"),
|
||||
):
|
||||
logger.info(f"{request.method} {request.url.path} with context_id: {context_id}")
|
||||
return {"context_id": str(context_id), "agent_type": agent_type}
|
||||
|
||||
|
||||
@app.get("/api/history")
|
||||
async def redirect_history(request: Request, agent_type: str = Query(..., description="Type of agent to retrieve history for")):
|
||||
path = request.url.path.rstrip('/')
|
||||
async def redirect_history(
|
||||
request: Request,
|
||||
agent_type: str = Query(..., description="Type of agent to retrieve history for"),
|
||||
):
|
||||
path = request.url.path.rstrip("/")
|
||||
new_context = uuid4()
|
||||
redirect_url = f"{path}/{new_context}?agent_type={agent_type}"
|
||||
logger.info(f"Redirecting from /api/history to {redirect_url}")
|
||||
return RedirectResponse(url=redirect_url, status_code=307)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn # type: ignore
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8900)
|
@ -1,28 +1,24 @@
|
||||
# From /opt/backstory run:
|
||||
# python -m src.tests.test-context
|
||||
import os
|
||||
|
||||
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")
|
||||
|
||||
import ollama
|
||||
|
||||
from .. utils import (
|
||||
rag as Rag,
|
||||
Context,
|
||||
defines
|
||||
)
|
||||
from ..utils import rag as Rag, Context, defines
|
||||
|
||||
import json
|
||||
|
||||
llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
|
||||
|
||||
observer, file_watcher = Rag.start_file_watcher(
|
||||
llm=llm,
|
||||
watch_directory=defines.doc_dir,
|
||||
recreate=False # Don't recreate if exists
|
||||
llm=llm, watch_directory=defines.doc_dir, recreate=False # Don't recreate if exists
|
||||
)
|
||||
|
||||
context = Context(file_watcher=file_watcher)
|
||||
data = context.model_dump(mode='json')
|
||||
data = context.model_dump(mode="json")
|
||||
context = Context.from_json(json.dumps(data), file_watcher=file_watcher)
|
||||
|
@ -2,12 +2,10 @@
|
||||
# python -m src.tests.test-message
|
||||
from ..utils import logger
|
||||
|
||||
from .. utils import (
|
||||
Message
|
||||
)
|
||||
from ..utils import Message
|
||||
|
||||
import json
|
||||
|
||||
prompt = "This is a test"
|
||||
message = Message(prompt=prompt)
|
||||
print(message.model_dump(mode='json'))
|
||||
print(message.model_dump(mode="json"))
|
||||
|
@ -1,8 +1,6 @@
|
||||
# From /opt/backstory run:
|
||||
# python -m src.tests.test-metrics
|
||||
from .. utils import (
|
||||
Metrics
|
||||
)
|
||||
from ..utils import Metrics
|
||||
|
||||
import json
|
||||
|
||||
@ -13,7 +11,7 @@ metrics = Metrics()
|
||||
metrics.prepare_count.labels(agent="chat").inc()
|
||||
metrics.prepare_duration.labels(agent="prepare").observe(0.45)
|
||||
|
||||
json = metrics.model_dump(mode='json')
|
||||
json = metrics.model_dump(mode="json")
|
||||
metrics = Metrics.model_validate(json)
|
||||
|
||||
print(metrics)
|
@ -12,21 +12,22 @@ from . agents import class_registry, AnyAgent, Agent, __all__ as agents_all
|
||||
from .metrics import Metrics
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'Tunables',
|
||||
'Context',
|
||||
'Conversation',
|
||||
'Message',
|
||||
'Metrics',
|
||||
'ChromaDBFileWatcher',
|
||||
'start_file_watcher',
|
||||
'logger',
|
||||
"Agent",
|
||||
"Tunables",
|
||||
"Context",
|
||||
"Conversation",
|
||||
"Message",
|
||||
"Metrics",
|
||||
"ChromaDBFileWatcher",
|
||||
"start_file_watcher",
|
||||
"logger",
|
||||
]
|
||||
|
||||
__all__.extend(agents_all) # type: ignore
|
||||
|
||||
logger = setup_logging(level=defines.logging_level)
|
||||
|
||||
|
||||
def rebuild_models():
|
||||
for class_name, (module_name, _) in class_registry.items():
|
||||
try:
|
||||
@ -36,9 +37,15 @@ def rebuild_models():
|
||||
logger.debug(f"Checking: {class_name} in module {module_name}")
|
||||
logger.debug(f" cls: {True if cls else False}")
|
||||
logger.debug(f" isinstance(cls, type): {isinstance(cls, type)}")
|
||||
logger.debug(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}")
|
||||
logger.debug(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}")
|
||||
logger.debug(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}")
|
||||
logger.debug(
|
||||
f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}"
|
||||
)
|
||||
logger.debug(
|
||||
f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}"
|
||||
)
|
||||
logger.debug(
|
||||
f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}"
|
||||
)
|
||||
|
||||
if (
|
||||
cls
|
||||
@ -50,11 +57,13 @@ def rebuild_models():
|
||||
logger.debug(f"Rebuilding {class_name} from {module_name}")
|
||||
from .agents import Agent
|
||||
from .context import Context
|
||||
|
||||
cls.model_rebuild()
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import module {module_name}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {class_name} in {module_name}: {e}")
|
||||
|
||||
|
||||
# Call this after all modules are imported
|
||||
rebuild_models()
|
||||
|
@ -16,7 +16,9 @@ __all__ = [ "AnyAgent", "Agent", "agent_registry", "class_registry" ]
|
||||
# Type alias for Agent or any subclass
|
||||
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
|
||||
|
||||
class_registry: Dict[str, Tuple[str, str]] = {} # Maps class_name to (module_name, class_name)
|
||||
class_registry: Dict[str, Tuple[str, str]] = (
|
||||
{}
|
||||
) # Maps class_name to (module_name, class_name)
|
||||
|
||||
package_dir = pathlib.Path(__file__).parent
|
||||
package_name = __name__
|
||||
|
@ -1,8 +1,17 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, PrivateAttr, Field # type: ignore
|
||||
from typing import (
|
||||
Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any,
|
||||
TypeAlias, Dict, Tuple
|
||||
Literal,
|
||||
get_args,
|
||||
List,
|
||||
AsyncGenerator,
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
ClassVar,
|
||||
Any,
|
||||
TypeAlias,
|
||||
Dict,
|
||||
Tuple,
|
||||
)
|
||||
import json
|
||||
import time
|
||||
@ -24,24 +33,26 @@ from . types import agent_registry
|
||||
from .. import defines
|
||||
from ..message import Message, Tunables
|
||||
from ..metrics import Metrics
|
||||
from .. tools import ( TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools ) # type: ignore -- dynamically added to __all__
|
||||
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__
|
||||
from ..conversation import Conversation
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
role: str = Field(default="")
|
||||
content: str = Field(default="")
|
||||
tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
|
||||
|
||||
|
||||
class Agent(BaseModel, ABC):
|
||||
"""
|
||||
Base class for all agent types.
|
||||
This class defines the common attributes and methods for all agent types.
|
||||
"""
|
||||
|
||||
# Agent management with pydantic
|
||||
agent_type: Literal["base"] = "base"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
|
||||
# Tunables (sets default for new Messages attached to this agent)
|
||||
tunables: Tunables = Field(default_factory=Tunables)
|
||||
|
||||
@ -49,11 +60,14 @@ class Agent(BaseModel, ABC):
|
||||
system_prompt: str # Mandatory
|
||||
conversation: Conversation = Conversation()
|
||||
context_tokens: int = 0
|
||||
context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization
|
||||
context: Optional[Context] = Field(
|
||||
default=None, exclude=True
|
||||
) # Avoid circular reference, require as param, and prevent serialization
|
||||
metrics: Metrics = Field(default_factory=Metrics, exclude=True)
|
||||
|
||||
# context_size is shared across all subclasses
|
||||
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
|
||||
|
||||
@property
|
||||
def context_size(self) -> int:
|
||||
return Agent._context_size
|
||||
@ -62,7 +76,9 @@ class Agent(BaseModel, ABC):
|
||||
def context_size(self, value: int):
|
||||
Agent._context_size = value
|
||||
|
||||
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
|
||||
def set_optimal_context_size(
|
||||
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
|
||||
) -> int:
|
||||
# # Get more accurate token count estimate using tiktoken or similar
|
||||
# response = llm.generate(
|
||||
# model=model,
|
||||
@ -83,7 +99,9 @@ class Agent(BaseModel, ABC):
|
||||
total_ctx = tokens + ctx_buffer
|
||||
|
||||
if total_ctx > self.context_size:
|
||||
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
|
||||
logger.info(
|
||||
f"Increasing context size from {self.context_size} to {total_ctx}"
|
||||
)
|
||||
|
||||
# Grow the context size if necessary
|
||||
self.context_size = max(self.context_size, total_ctx)
|
||||
@ -95,7 +113,7 @@ class Agent(BaseModel, ABC):
|
||||
"""Auto-register subclasses"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Register this class if it has an agent_type
|
||||
if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type:
|
||||
if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
|
||||
agent_registry.register(cls.agent_type, cls)
|
||||
|
||||
# def __init__(self, *, context=context, **data):
|
||||
@ -166,7 +184,14 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
return
|
||||
|
||||
async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[LLMMessage]) -> AsyncGenerator[Message, None]:
|
||||
async def process_tool_calls(
|
||||
self,
|
||||
llm: Any,
|
||||
model: str,
|
||||
message: Message,
|
||||
tool_message: Any,
|
||||
messages: List[LLMMessage],
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
|
||||
self.metrics.tool_count.labels(agent=self.agent_type).inc()
|
||||
@ -187,7 +212,9 @@ class Agent(BaseModel, ABC):
|
||||
tool = tool_call.function.name
|
||||
|
||||
# Yield status update before processing each tool
|
||||
message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
|
||||
message.response = (
|
||||
f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
|
||||
)
|
||||
yield message
|
||||
logger.info(f"LLM - {message.response}")
|
||||
|
||||
@ -202,12 +229,18 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
case "AnalyzeSite":
|
||||
url = arguments.get("url")
|
||||
question = arguments.get("question", "what is the summary of this content?")
|
||||
question = arguments.get(
|
||||
"question", "what is the summary of this content?"
|
||||
)
|
||||
|
||||
# Additional status update for long-running operations
|
||||
message.response = f"Retrieving and summarizing content from {url}..."
|
||||
message.response = (
|
||||
f"Retrieving and summarizing content from {url}..."
|
||||
)
|
||||
yield message
|
||||
ret = await AnalyzeSite(llm=llm, model=model, url=url, question=question)
|
||||
ret = await AnalyzeSite(
|
||||
llm=llm, model=model, url=url, question=question
|
||||
)
|
||||
|
||||
case "DateTime":
|
||||
tz = arguments.get("timezone")
|
||||
@ -217,7 +250,9 @@ class Agent(BaseModel, ABC):
|
||||
city = arguments.get("city")
|
||||
state = arguments.get("state")
|
||||
|
||||
message.response = f"Fetching weather data for {city}, {state}..."
|
||||
message.response = (
|
||||
f"Fetching weather data for {city}, {state}..."
|
||||
)
|
||||
yield message
|
||||
ret = WeatherForecast(city, state)
|
||||
|
||||
@ -228,7 +263,7 @@ class Agent(BaseModel, ABC):
|
||||
tool_response = {
|
||||
"role": "tool",
|
||||
"content": json.dumps(ret),
|
||||
"name": tool_call.function.name
|
||||
"name": tool_call.function.name,
|
||||
}
|
||||
|
||||
tool_metadata["tool_calls"].append(tool_response)
|
||||
@ -241,13 +276,15 @@ class Agent(BaseModel, ABC):
|
||||
message_dict = LLMMessage(
|
||||
role=tool_message.get("role", "assistant"),
|
||||
content=tool_message.get("content", ""),
|
||||
tool_calls=[ {
|
||||
tool_calls=[
|
||||
{
|
||||
"function": {
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"]
|
||||
"arguments": tc["function"]["arguments"],
|
||||
}
|
||||
} for tc in tool_message.tool_calls
|
||||
]
|
||||
}
|
||||
for tc in tool_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
messages.append(message_dict)
|
||||
@ -282,20 +319,30 @@ class Agent(BaseModel, ABC):
|
||||
message.metadata["eval_count"] += response.eval_count
|
||||
message.metadata["eval_duration"] += response.eval_duration
|
||||
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
||||
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
|
||||
self.context_tokens = response.prompt_eval_count + response.eval_count
|
||||
message.metadata[
|
||||
"prompt_eval_duration"
|
||||
] += response.prompt_eval_duration
|
||||
self.context_tokens = (
|
||||
response.prompt_eval_count + response.eval_count
|
||||
)
|
||||
message.status = "done"
|
||||
yield message
|
||||
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata["timers"][
|
||||
"llm_with_tools"
|
||||
] = f"{(end_time - start_time):.4f}"
|
||||
return
|
||||
|
||||
def collect_metrics(self, response):
|
||||
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(response.prompt_eval_count)
|
||||
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
|
||||
response.prompt_eval_count
|
||||
)
|
||||
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
|
||||
|
||||
async def generate_llm_response(self, llm: Any, model: str, message: Message, temperature = 0.7) -> AsyncGenerator[Message, None]:
|
||||
async def generate_llm_response(
|
||||
self, llm: Any, model: str, message: Message, temperature=0.7
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
|
||||
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||
@ -305,22 +352,29 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
# Create a pruned down message list based purely on the prompt and responses,
|
||||
# discarding the full preamble generated by prepare_message
|
||||
messages: List[LLMMessage] = [ LLMMessage(role="system", content=message.system_prompt) ]
|
||||
messages.extend([
|
||||
item for m in self.conversation
|
||||
messages: List[LLMMessage] = [
|
||||
LLMMessage(role="system", content=message.system_prompt)
|
||||
]
|
||||
messages.extend(
|
||||
[
|
||||
item
|
||||
for m in self.conversation
|
||||
for item in [
|
||||
LLMMessage(role="user", content=m.prompt.strip()),
|
||||
LLMMessage(role="assistant", content=m.response.strip())
|
||||
LLMMessage(role="assistant", content=m.response.strip()),
|
||||
]
|
||||
])
|
||||
]
|
||||
)
|
||||
# Only the actual user query is provided with the full context message
|
||||
messages.append(LLMMessage(role="user", content=message.context_prompt.strip()))
|
||||
messages.append(
|
||||
LLMMessage(role="user", content=message.context_prompt.strip())
|
||||
)
|
||||
|
||||
# message.metadata["messages"] = messages
|
||||
message.metadata["options"] = {
|
||||
"seed": 8911,
|
||||
"num_ctx": self.context_size,
|
||||
"temperature": temperature # Higher temperature to encourage tool usage
|
||||
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||
}
|
||||
|
||||
# Create a dict for storing various timing stats
|
||||
@ -329,7 +383,7 @@ class Agent(BaseModel, ABC):
|
||||
use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
||||
message.metadata["tools"] = {
|
||||
"available": llm_tools(self.context.tools),
|
||||
"used": False
|
||||
"used": False,
|
||||
}
|
||||
tool_metadata = message.metadata["tools"]
|
||||
|
||||
@ -357,12 +411,14 @@ class Agent(BaseModel, ABC):
|
||||
**message.metadata["options"],
|
||||
# "num_predict": 1024, # "Low" token limit to cut off after tool call
|
||||
},
|
||||
stream=False # No need to stream the probe
|
||||
stream=False, # No need to stream the probe
|
||||
)
|
||||
self.collect_metrics(response)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata["timers"][
|
||||
"tool_check"
|
||||
] = f"{(end_time - start_time):.4f}"
|
||||
if not response.message.tool_calls:
|
||||
logger.info("LLM indicates tools will not be used")
|
||||
# The LLM will not use tools, so disable use_tools so we can stream the full response
|
||||
@ -374,7 +430,9 @@ class Agent(BaseModel, ABC):
|
||||
logger.info("LLM indicates tools will be used")
|
||||
|
||||
# Tools are enabled and available and the LLM indicated it will use them
|
||||
message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
|
||||
message.response = (
|
||||
f"Performing tool analysis step 2/2 (tool use suspected)..."
|
||||
)
|
||||
yield message
|
||||
|
||||
logger.info(f"Performing LLM call with tools")
|
||||
@ -386,12 +444,14 @@ class Agent(BaseModel, ABC):
|
||||
options={
|
||||
**message.metadata["options"],
|
||||
},
|
||||
stream=False
|
||||
stream=False,
|
||||
)
|
||||
self.collect_metrics(response)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata["timers"][
|
||||
"non_streaming"
|
||||
] = f"{(end_time - start_time):.4f}"
|
||||
|
||||
if not response:
|
||||
message.status = "error"
|
||||
@ -403,13 +463,21 @@ class Agent(BaseModel, ABC):
|
||||
tool_metadata["used"] = response.message.tool_calls
|
||||
# Process all yielded items from the handler
|
||||
start_time = time.perf_counter()
|
||||
async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
|
||||
async for message in self.process_tool_calls(
|
||||
llm=llm,
|
||||
model=model,
|
||||
message=message,
|
||||
tool_message=response.message,
|
||||
messages=messages,
|
||||
):
|
||||
if message.status == "error":
|
||||
yield message
|
||||
return
|
||||
yield message
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata["timers"][
|
||||
"process_tool_calls"
|
||||
] = f"{(end_time - start_time):.4f}"
|
||||
message.status = "done"
|
||||
return
|
||||
|
||||
@ -452,8 +520,12 @@ class Agent(BaseModel, ABC):
|
||||
message.metadata["eval_count"] += response.eval_count
|
||||
message.metadata["eval_duration"] += response.eval_duration
|
||||
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
||||
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
|
||||
self.context_tokens = response.prompt_eval_count + response.eval_count
|
||||
message.metadata[
|
||||
"prompt_eval_duration"
|
||||
] += response.prompt_eval_duration
|
||||
self.context_tokens = (
|
||||
response.prompt_eval_count + response.eval_count
|
||||
)
|
||||
message.status = "done"
|
||||
yield message
|
||||
|
||||
@ -461,7 +533,9 @@ class Agent(BaseModel, ABC):
|
||||
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
|
||||
return
|
||||
|
||||
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
|
||||
async def process_message(
|
||||
self, llm: Any, model: str, message: Message
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
|
||||
self.metrics.process_count.labels(agent=self.agent_type).inc()
|
||||
@ -470,22 +544,30 @@ class Agent(BaseModel, ABC):
|
||||
if not self.context:
|
||||
raise ValueError("Context is not set for this agent.")
|
||||
|
||||
logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
|
||||
spinner: List[str] = ['\\', '|', '/', '-']
|
||||
logger.info(
|
||||
"TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
|
||||
)
|
||||
spinner: List[str] = ["\\", "|", "/", "-"]
|
||||
tick: int = 0
|
||||
while self.context.processing:
|
||||
message.status = "waiting"
|
||||
message.response = f"Busy processing another request. Please wait. {spinner[tick]}"
|
||||
message.response = (
|
||||
f"Busy processing another request. Please wait. {spinner[tick]}"
|
||||
)
|
||||
tick = (tick + 1) % len(spinner)
|
||||
yield message
|
||||
await asyncio.sleep(1) # Allow the event loop to process the write
|
||||
|
||||
self.context.processing = True
|
||||
|
||||
message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
|
||||
message.metadata["system_prompt"] = (
|
||||
f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
|
||||
)
|
||||
message.context_prompt = ""
|
||||
for p in message.preamble.keys():
|
||||
message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
|
||||
message.context_prompt += (
|
||||
f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
|
||||
)
|
||||
message.context_prompt += f"{message.prompt}"
|
||||
|
||||
# Estimate token length of new messages
|
||||
@ -493,13 +575,17 @@ class Agent(BaseModel, ABC):
|
||||
message.status = "thinking"
|
||||
yield message
|
||||
|
||||
message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)
|
||||
message.metadata["context_size"] = self.set_optimal_context_size(
|
||||
llm, model, prompt=message.context_prompt
|
||||
)
|
||||
|
||||
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
|
||||
message.status = "thinking"
|
||||
yield message
|
||||
|
||||
async for message in self.generate_llm_response(llm=llm, model=model, message=message):
|
||||
async for message in self.generate_llm_response(
|
||||
llm=llm, model=model, message=message
|
||||
):
|
||||
# logger.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
|
||||
if message.status == "error":
|
||||
yield message
|
||||
@ -514,6 +600,6 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(Agent._agent_type, Agent)
|
||||
|
||||
|
@ -6,6 +6,7 @@ import inspect
|
||||
from .base import Agent, agent_registry
|
||||
from ..message import Message
|
||||
from ..setup_logging import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
system_message = f"""
|
||||
@ -26,10 +27,12 @@ When answering queries, follow these steps:
|
||||
Always use tools, <|resume|>, and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so.
|
||||
"""
|
||||
|
||||
|
||||
class Chat(Agent):
|
||||
"""
|
||||
Chat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["chat"] = "chat" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
@ -46,15 +49,20 @@ class Chat(Agent):
|
||||
|
||||
if message.preamble:
|
||||
excluded = {}
|
||||
preamble_types = [f"<|{p}|>" for p in message.preamble.keys() if p not in excluded]
|
||||
preamble_types = [
|
||||
f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
|
||||
]
|
||||
preamble_types_AND = " and ".join(preamble_types)
|
||||
preamble_types_OR = " or ".join(preamble_types)
|
||||
message.preamble["rules"] = f"""\
|
||||
message.preamble[
|
||||
"rules"
|
||||
] = f"""\
|
||||
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
|
||||
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
|
||||
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
|
||||
"""
|
||||
message.preamble["question"] = "Respond to:"
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(Chat._agent_type, Chat)
|
||||
|
@ -1,6 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator # type: ignore
|
||||
from typing import Literal, ClassVar, Optional, Any, AsyncGenerator, List # NOTE: You must import Optional for late binding to work
|
||||
from typing import (
|
||||
Literal,
|
||||
ClassVar,
|
||||
Optional,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
|
||||
@ -8,6 +15,7 @@ from . base import Agent, agent_registry
|
||||
from ..conversation import Conversation
|
||||
from ..message import Message
|
||||
from ..setup_logging import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
system_fact_check = f"""
|
||||
@ -21,6 +29,7 @@ When answering queries, follow these steps:
|
||||
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|generated-resume|>, or <|resume|> tags.
|
||||
""".strip()
|
||||
|
||||
|
||||
class FactCheck(Agent):
|
||||
agent_type: Literal["fact_check"] = "fact_check" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
@ -53,10 +62,14 @@ class FactCheck(Agent):
|
||||
message.preamble["discrepancies"] = self.facts
|
||||
|
||||
excluded = {"job_description"}
|
||||
preamble_types = [f"<|{p}|>" for p in message.preamble.keys() if p not in excluded]
|
||||
preamble_types = [
|
||||
f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
|
||||
]
|
||||
preamble_types_AND = " and ".join(preamble_types)
|
||||
preamble_types_OR = " or ".join(preamble_types)
|
||||
message.preamble["rules"] = f"""\
|
||||
message.preamble[
|
||||
"rules"
|
||||
] = f"""\
|
||||
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
|
||||
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
|
||||
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
|
||||
@ -66,5 +79,6 @@ class FactCheck(Agent):
|
||||
yield message
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(FactCheck._agent_type, FactCheck)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,12 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator # type: ignore
|
||||
from typing import Literal, ClassVar, Optional, Any, AsyncGenerator, List # NOTE: You must import Optional for late binding to work
|
||||
from typing import (
|
||||
Literal,
|
||||
ClassVar,
|
||||
Optional,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from ..message import Message
|
||||
from ..setup_logging import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
system_fact_check = f"""
|
||||
@ -36,6 +44,7 @@ When answering queries, follow these steps:
|
||||
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|job_description|>, <|resume|>, or <|context|> tags.
|
||||
""".strip()
|
||||
|
||||
|
||||
class Resume(Agent):
|
||||
agent_type: Literal["resume"] = "resume" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
@ -69,10 +78,14 @@ class Resume(Agent):
|
||||
message.preamble["job_description"] = job_description_agent.job_description
|
||||
|
||||
excluded = {}
|
||||
preamble_types = [f"<|{p}|>" for p in message.preamble.keys() if p not in excluded]
|
||||
preamble_types = [
|
||||
f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
|
||||
]
|
||||
preamble_types_AND = " and ".join(preamble_types)
|
||||
preamble_types_OR = " or ".join(preamble_types)
|
||||
message.preamble["rules"] = f"""\
|
||||
message.preamble[
|
||||
"rules"
|
||||
] = f"""\
|
||||
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
|
||||
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
|
||||
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
|
||||
@ -81,12 +94,16 @@ class Resume(Agent):
|
||||
if fact_check_agent:
|
||||
message.preamble["question"] = "Respond to:"
|
||||
else:
|
||||
message.preamble["question"] = f"Fact check the <|generated-resume|> based on the <|resume|>{' and <|context|>' if 'context' in message.preamble else ''}."
|
||||
message.preamble["question"] = (
|
||||
f"Fact check the <|generated-resume|> based on the <|resume|>{' and <|context|>' if 'context' in message.preamble else ''}."
|
||||
)
|
||||
|
||||
yield message
|
||||
return
|
||||
|
||||
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
|
||||
async def process_message(
|
||||
self, llm: Any, model: str, message: Message
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
if not self.context:
|
||||
raise ValueError("Context is not set for this agent.")
|
||||
@ -103,7 +120,9 @@ class Resume(Agent):
|
||||
|
||||
# Instantiate the "resume" agent, and seed (or reset) its conversation
|
||||
# with this message.
|
||||
fact_check_agent = self.context.get_or_create_agent(agent_type="fact_check", facts=message.response)
|
||||
fact_check_agent = self.context.get_or_create_agent(
|
||||
agent_type="fact_check", facts=message.response
|
||||
)
|
||||
first_fact_check_message = message.copy()
|
||||
first_fact_check_message.prompt = "Fact check the generated resume."
|
||||
fact_check_agent.conversation.add(first_fact_check_message)
|
||||
@ -113,5 +132,6 @@ class Resume(Agent):
|
||||
yield message
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(Resume._agent_type, Resume)
|
||||
|
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Optional, Type
|
||||
|
||||
|
||||
# We'll use a registry pattern rather than hardcoded strings
|
||||
class AgentRegistry:
|
||||
"""Registry for agent types and classes"""
|
||||
|
||||
_registry: Dict[str, Type] = {}
|
||||
|
||||
@classmethod
|
||||
@ -27,5 +29,6 @@ class AgentRegistry:
|
||||
"""Get all registered agent classes"""
|
||||
return cls._registry.copy()
|
||||
|
||||
|
||||
# Create a singleton instance
|
||||
agent_registry = AgentRegistry()
|
@ -1,122 +0,0 @@
|
||||
import chromadb
|
||||
from typing import List, Dict, Any, Union
|
||||
from . import defines
|
||||
from .chunk import chunk_document
|
||||
import ollama
|
||||
|
||||
def init_chroma_client(persist_directory: str = defines.persist_directory):
|
||||
"""Initialize and return a ChromaDB client."""
|
||||
# return chromadb.PersistentClient(path=persist_directory)
|
||||
return chromadb.Client()
|
||||
|
||||
def create_or_get_collection(db: chromadb.Client, collection_name: str):
|
||||
"""Create or get a ChromaDB collection."""
|
||||
try:
|
||||
return db.get_collection(
|
||||
name=collection_name
|
||||
)
|
||||
except:
|
||||
return db.create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
def process_documents_to_chroma(
|
||||
client: ollama.Client,
|
||||
documents: List[Dict[str, Any]],
|
||||
collection_name: str = "document_collection",
|
||||
text_key: str = "text",
|
||||
max_tokens: int = 512,
|
||||
overlap: int = 50,
|
||||
model: str = defines.encoding_model,
|
||||
persist_directory: str = defines.persist_directory
|
||||
):
|
||||
"""
|
||||
Process documents, chunk them, compute embeddings, and store in ChromaDB.
|
||||
|
||||
Args:
|
||||
documents: List of document dictionaries
|
||||
collection_name: Name for the ChromaDB collection
|
||||
text_key: The key containing text content
|
||||
max_tokens: Maximum tokens per chunk
|
||||
overlap: Token overlap between chunks
|
||||
model: Ollama model for embeddings
|
||||
persist_directory: Directory to store ChromaDB data
|
||||
"""
|
||||
# Initialize ChromaDB client and collection
|
||||
db = init_chroma_client(persist_directory)
|
||||
collection = create_or_get_collection(db, collection_name)
|
||||
|
||||
# Process each document
|
||||
for doc in documents:
|
||||
# Chunk the document
|
||||
doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)
|
||||
|
||||
# Prepare data for ChromaDB
|
||||
ids = []
|
||||
texts = []
|
||||
metadatas = []
|
||||
embeddings = []
|
||||
|
||||
for chunk in doc_chunks:
|
||||
# Create a unique ID for the chunk
|
||||
chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"
|
||||
|
||||
# Extract text
|
||||
text = chunk[text_key]
|
||||
|
||||
# Create metadata (excluding text and embedding to avoid duplication)
|
||||
metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}
|
||||
|
||||
response = client.embed(model=model, input=text)
|
||||
embedding = response["embeddings"][0]
|
||||
ids.append(chunk_id)
|
||||
texts.append(text)
|
||||
metadatas.append(metadata)
|
||||
embeddings.append(embedding)
|
||||
|
||||
# Add chunks to ChromaDB collection
|
||||
collection.add(
|
||||
ids=ids,
|
||||
documents=texts,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas
|
||||
)
|
||||
|
||||
return collection
|
||||
|
||||
def query_chroma(
|
||||
client: ollama.Client,
|
||||
query_text: str,
|
||||
collection_name: str = "document_collection",
|
||||
n_results: int = 5,
|
||||
model: str = defines.encoding_model,
|
||||
persist_directory: str = defines.persist_directory
|
||||
):
|
||||
"""
|
||||
Query ChromaDB for similar documents.
|
||||
|
||||
Args:
|
||||
query_text: The text to search for
|
||||
collection_name: Name of the ChromaDB collection
|
||||
n_results: Number of results to return
|
||||
model: Ollama model for embedding the query
|
||||
persist_directory: Directory where ChromaDB data is stored
|
||||
|
||||
Returns:
|
||||
Query results from ChromaDB
|
||||
"""
|
||||
# Initialize ChromaDB client and collection
|
||||
db = init_chroma_client(persist_directory)
|
||||
collection = create_or_get_collection(db, collection_name)
|
||||
|
||||
query_response = client.embed(model=model, input=query_text)
|
||||
query_embeddings = query_response["embeddings"]
|
||||
|
||||
# Query the collection
|
||||
results = collection.query(
|
||||
query_embeddings=query_embeddings,
|
||||
n_results=n_results
|
||||
)
|
||||
|
||||
return results
|
@ -1,88 +0,0 @@
|
||||
import tiktoken # type: ignore
|
||||
from . import defines
|
||||
from typing import List, Dict, Any, Union
|
||||
|
||||
def get_encoding(model=defines.model):
|
||||
"""Get the tokenizer for counting tokens."""
|
||||
try:
|
||||
return tiktoken.get_encoding("cl100k_base") # Default encoding used by many embedding models
|
||||
except:
|
||||
return tiktoken.encoding_for_model(model)
|
||||
|
||||
def count_tokens(text: str) -> int:
|
||||
"""Count the number of tokens in a text string."""
|
||||
encoding = get_encoding()
|
||||
return len(encoding.encode(text))
|
||||
|
||||
def chunk_text(text: str, max_tokens: int = 512, overlap: int = 50) -> List[str]:
|
||||
"""
|
||||
Split a text into chunks based on token count with overlap between chunks.
|
||||
|
||||
Args:
|
||||
text: The text to split into chunks
|
||||
max_tokens: Maximum number of tokens per chunk
|
||||
overlap: Number of tokens to overlap between chunks
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
if not text or max_tokens <= 0:
|
||||
return []
|
||||
|
||||
encoding = get_encoding()
|
||||
tokens = encoding.encode(text)
|
||||
chunks = []
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
# Get the current chunk of tokens
|
||||
chunk_end = min(i + max_tokens, len(tokens))
|
||||
chunk_tokens = tokens[i:chunk_end]
|
||||
chunks.append(encoding.decode(chunk_tokens))
|
||||
|
||||
# Move to the next position with overlap
|
||||
if chunk_end == len(tokens):
|
||||
break
|
||||
i += max_tokens - overlap
|
||||
|
||||
return chunks
|
||||
|
||||
def chunk_document(document: Dict[str, Any],
|
||||
text_key: str = "text",
|
||||
max_tokens: int = 512,
|
||||
overlap: int = 50) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Chunk a document dictionary into multiple chunks.
|
||||
|
||||
Args:
|
||||
document: Document dictionary with metadata and text
|
||||
text_key: The key in the document that contains the text to chunk
|
||||
max_tokens: Maximum number of tokens per chunk
|
||||
overlap: Number of tokens to overlap between chunks
|
||||
|
||||
Returns:
|
||||
List of document dictionaries, each with chunked text and preserved metadata
|
||||
"""
|
||||
if text_key not in document:
|
||||
raise Exception(f"{text_key} not in document")
|
||||
|
||||
# Extract text and create chunks
|
||||
if "title" in document:
|
||||
text = f"{document["title"]}: {document[text_key]}"
|
||||
else:
|
||||
text = document[text_key]
|
||||
chunks = chunk_text(text, max_tokens, overlap)
|
||||
|
||||
# Create document chunks with preserved metadata
|
||||
chunked_docs = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Create a new doc with all original fields
|
||||
doc_chunk = document.copy()
|
||||
# Replace text with the chunk
|
||||
doc_chunk[text_key] = chunk
|
||||
# Add chunk metadata
|
||||
doc_chunk["chunk_id"] = i
|
||||
doc_chunk["chunk_total"] = len(chunks)
|
||||
chunked_docs.append(doc_chunk)
|
||||
|
||||
return chunked_docs
|
@ -17,16 +17,19 @@ from . agents import AnyAgent
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
|
||||
# Required fields
|
||||
file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
|
||||
prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
|
||||
prometheus_collector: Optional[CollectorRegistry] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
|
||||
# Optional fields
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid4()),
|
||||
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
|
||||
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
)
|
||||
user_resume: Optional[str] = None
|
||||
user_job_description: Optional[str] = None
|
||||
@ -53,12 +56,16 @@ class Context(BaseModel):
|
||||
logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
|
||||
agent_types = [agent.agent_type for agent in self.agents]
|
||||
if len(agent_types) != len(set(agent_types)):
|
||||
raise ValueError("Context cannot contain multiple agents of the same agent_type")
|
||||
raise ValueError(
|
||||
"Context cannot contain multiple agents of the same agent_type"
|
||||
)
|
||||
# for agent in self.agents:
|
||||
# agent.set_context(self)
|
||||
return self
|
||||
|
||||
def generate_rag_results(self, message: Message) -> Generator[Message, None, None]:
|
||||
def generate_rag_results(
|
||||
self, message: Message, top_k=10, threshold=0.7
|
||||
) -> Generator[Message, None, None]:
|
||||
"""
|
||||
Generate RAG results for the given query.
|
||||
|
||||
@ -86,32 +93,48 @@ class Context(BaseModel):
|
||||
continue
|
||||
message.response = f"Checking RAG context {rag['name']}..."
|
||||
yield message
|
||||
chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10, threshold=0.7)
|
||||
chroma_results = self.file_watcher.find_similar(
|
||||
query=message.prompt, top_k=top_k, threshold=threshold
|
||||
)
|
||||
if chroma_results:
|
||||
entries += len(chroma_results["documents"])
|
||||
|
||||
chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
|
||||
chroma_embedding = np.array(
|
||||
chroma_results["query_embedding"]
|
||||
).flatten() # Ensure correct shape
|
||||
print(f"Chroma embedding shape: {chroma_embedding.shape}")
|
||||
|
||||
umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
|
||||
print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
|
||||
umap_2d = self.file_watcher.umap_model_2d.transform(
|
||||
[chroma_embedding]
|
||||
)[0].tolist()
|
||||
print(
|
||||
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
|
||||
) # Debug output
|
||||
|
||||
umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
|
||||
print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
|
||||
umap_3d = self.file_watcher.umap_model_3d.transform(
|
||||
[chroma_embedding]
|
||||
)[0].tolist()
|
||||
print(
|
||||
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
|
||||
) # Debug output
|
||||
|
||||
message.metadata["rag"].append({
|
||||
message.metadata["rag"].append(
|
||||
{
|
||||
"name": rag["name"],
|
||||
**chroma_results,
|
||||
"umap_embedding_2d": umap_2d,
|
||||
"umap_embedding_3d": umap_3d
|
||||
})
|
||||
"umap_embedding_3d": umap_3d,
|
||||
}
|
||||
)
|
||||
message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
|
||||
yield message
|
||||
|
||||
if entries == 0:
|
||||
del message.metadata["rag"]
|
||||
|
||||
message.response = f"RAG context gathered from results from {entries} documents."
|
||||
message.response = (
|
||||
f"RAG context gathered from results from {entries} documents."
|
||||
)
|
||||
message.status = "done"
|
||||
yield message
|
||||
return
|
||||
@ -156,7 +179,9 @@ class Context(BaseModel):
|
||||
def add_agent(self, agent: AnyAgent) -> None:
|
||||
"""Add a Agent to the context, ensuring no duplicate agent_type."""
|
||||
if any(s.agent_type == agent.agent_type for s in self.agents):
|
||||
raise ValueError(f"A agent with agent_type '{agent.agent_type}' already exists")
|
||||
raise ValueError(
|
||||
f"A agent with agent_type '{agent.agent_type}' already exists"
|
||||
)
|
||||
self.agents.append(agent)
|
||||
|
||||
def get_agent(self, agent_type: str) -> Agent | None:
|
||||
@ -188,5 +213,7 @@ class Context(BaseModel):
|
||||
summary += f"\nChat Name: {agent.name}\n"
|
||||
return summary
|
||||
|
||||
|
||||
from .agents import Agent
|
||||
|
||||
Context.model_rebuild()
|
@ -2,6 +2,7 @@ from pydantic import BaseModel, Field, PrivateAttr # type: ignore
|
||||
from typing import List
|
||||
from .message import Message
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
Conversation_messages: List[Message] = Field(default=[], alias="messages")
|
||||
|
||||
@ -17,12 +18,16 @@ class Conversation(BaseModel):
|
||||
@property
|
||||
def messages(self):
|
||||
"""Return a copy of messages to prevent modification of the internal list."""
|
||||
raise AttributeError("Cannot directly get messages. Use Conversation.add() or .reset()")
|
||||
raise AttributeError(
|
||||
"Cannot directly get messages. Use Conversation.add() or .reset()"
|
||||
)
|
||||
|
||||
@messages.setter
|
||||
def messages(self, value):
|
||||
"""Control how messages can be set, or prevent setting altogether."""
|
||||
raise AttributeError("Cannot directly set messages. Use Conversation.add() or .reset()")
|
||||
raise AttributeError(
|
||||
"Cannot directly set messages. Use Conversation.add() or .reset()"
|
||||
)
|
||||
|
||||
def add(self, message: Message | List[Message]) -> None:
|
||||
"""Add a Message(s) to the conversation."""
|
||||
|
@ -1,468 +0,0 @@
|
||||
import requests
|
||||
from typing import List, Dict, Any, Union
|
||||
import tiktoken
|
||||
import feedparser
|
||||
import logging as log
|
||||
import datetime
|
||||
from bs4 import BeautifulSoup
|
||||
import chromadb
|
||||
import ollama
|
||||
import re
|
||||
import numpy as np
|
||||
from . import chunk
|
||||
|
||||
OLLAMA_API_URL = "http://ollama:11434" # Default Ollama local endpoint
|
||||
#MODEL_NAME = "deepseek-r1:1.5b"
|
||||
MODEL_NAME = "deepseek-r1:7b"
|
||||
EMBED_MODEL = "mxbai-embed-large"
|
||||
PERSIST_DIRECTORY = "/root/.cache/chroma"
|
||||
|
||||
client = ollama.Client(host=OLLAMA_API_URL)
|
||||
|
||||
def extract_text_from_html_or_xml(content, is_xml=False):
|
||||
# Parse the content
|
||||
if is_xml:
|
||||
soup = BeautifulSoup(content, 'xml') # Use 'xml' parser for XML content
|
||||
else:
|
||||
soup = BeautifulSoup(content, 'html.parser') # Default to 'html.parser' for HTML content
|
||||
|
||||
# Extract and return just the text
|
||||
return soup.get_text()
|
||||
|
||||
class Feed():
|
||||
def __init__(self, name, url, poll_limit_min = 30, max_articles=5):
|
||||
self.name = name
|
||||
self.url = url
|
||||
self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
|
||||
self.last_poll = None
|
||||
self.articles = []
|
||||
self.max_articles = max_articles
|
||||
self.update()
|
||||
|
||||
def update(self):
|
||||
now = datetime.datetime.now()
|
||||
if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
|
||||
log.info(f"Updating {self.name}")
|
||||
feed = feedparser.parse(self.url)
|
||||
self.articles = []
|
||||
self.last_poll = now
|
||||
|
||||
if len(feed.entries) == 0:
|
||||
return
|
||||
|
||||
for i, entry in enumerate(feed.entries[:self.max_articles]):
|
||||
content = {}
|
||||
content['source'] = self.name
|
||||
content['id'] = f"{self.name}{i}"
|
||||
title = entry.get("title")
|
||||
if title:
|
||||
content['title'] = title
|
||||
link = entry.get("link")
|
||||
if link:
|
||||
content['link'] = link
|
||||
text = entry.get("summary")
|
||||
if text:
|
||||
content['text'] = extract_text_from_html_or_xml(text, False)
|
||||
else:
|
||||
continue
|
||||
published = entry.get("published")
|
||||
if published:
|
||||
content['published'] = published
|
||||
|
||||
self.articles.append(content)
|
||||
else:
|
||||
log.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)}s remain to refresh.")
|
||||
return self.articles
|
||||
|
||||
# News RSS Feeds
|
||||
rss_feeds = [
|
||||
Feed(name="IGN.com", url="https://feeds.feedburner.com/ign/games-all"),
|
||||
Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
|
||||
Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
|
||||
Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
|
||||
Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
|
||||
Feed(name="Time", url="https://time.com/feed/"),
|
||||
Feed(name="Euronews", url="https://www.euronews.com/rss"),
|
||||
# Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
|
||||
]
|
||||
|
||||
|
||||
def init_chroma_client(persist_directory: str = PERSIST_DIRECTORY):
|
||||
"""Initialize and return a ChromaDB client."""
|
||||
# return chromadb.PersistentClient(path=persist_directory)
|
||||
return chromadb.Client()
|
||||
|
||||
def create_or_get_collection(client, collection_name: str):
|
||||
"""Create or get a ChromaDB collection."""
|
||||
try:
|
||||
return client.get_collection(
|
||||
name=collection_name
|
||||
)
|
||||
except:
|
||||
return client.create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
def process_documents_to_chroma(
|
||||
documents: List[Dict[str, Any]],
|
||||
collection_name: str = "document_collection",
|
||||
text_key: str = "text",
|
||||
max_tokens: int = 512,
|
||||
overlap: int = 50,
|
||||
model: str = EMBED_MODEL,
|
||||
persist_directory: str = PERSIST_DIRECTORY
|
||||
):
|
||||
"""
|
||||
Process documents, chunk them, compute embeddings, and store in ChromaDB.
|
||||
|
||||
Args:
|
||||
documents: List of document dictionaries
|
||||
collection_name: Name for the ChromaDB collection
|
||||
text_key: The key containing text content
|
||||
max_tokens: Maximum tokens per chunk
|
||||
overlap: Token overlap between chunks
|
||||
model: Ollama model for embeddings
|
||||
persist_directory: Directory to store ChromaDB data
|
||||
"""
|
||||
# Initialize ChromaDB client and collection
|
||||
db = init_chroma_client(persist_directory)
|
||||
collection = create_or_get_collection(db, collection_name)
|
||||
|
||||
# Process each document
|
||||
for doc in documents:
|
||||
# Chunk the document
|
||||
doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)
|
||||
|
||||
# Prepare data for ChromaDB
|
||||
ids = []
|
||||
texts = []
|
||||
metadatas = []
|
||||
embeddings = []
|
||||
|
||||
for chunk in doc_chunks:
|
||||
# Create a unique ID for the chunk
|
||||
chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"
|
||||
|
||||
# Extract text
|
||||
text = chunk[text_key]
|
||||
|
||||
# Create metadata (excluding text and embedding to avoid duplication)
|
||||
metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}
|
||||
|
||||
response = client.embed(model=model, input=text)
|
||||
embedding = response["embeddings"][0]
|
||||
ids.append(chunk_id)
|
||||
texts.append(text)
|
||||
metadatas.append(metadata)
|
||||
embeddings.append(embedding)
|
||||
|
||||
# Add chunks to ChromaDB collection
|
||||
collection.add(
|
||||
ids=ids,
|
||||
documents=texts,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas
|
||||
)
|
||||
|
||||
return collection
|
||||
|
||||
def query_chroma(
|
||||
query_text: str,
|
||||
collection_name: str = "document_collection",
|
||||
n_results: int = 5,
|
||||
model: str = EMBED_MODEL,
|
||||
persist_directory: str = PERSIST_DIRECTORY
|
||||
):
|
||||
"""
|
||||
Query ChromaDB for similar documents.
|
||||
|
||||
Args:
|
||||
query_text: The text to search for
|
||||
collection_name: Name of the ChromaDB collection
|
||||
n_results: Number of results to return
|
||||
model: Ollama model for embedding the query
|
||||
persist_directory: Directory where ChromaDB data is stored
|
||||
|
||||
Returns:
|
||||
Query results from ChromaDB
|
||||
"""
|
||||
# Initialize ChromaDB client and collection
|
||||
db = init_chroma_client(persist_directory)
|
||||
collection = create_or_get_collection(db, collection_name)
|
||||
|
||||
query_response = client.embed(model=model, input=query_text)
|
||||
query_embeddings = query_response["embeddings"]
|
||||
|
||||
# Query the collection
|
||||
results = collection.query(
|
||||
query_embeddings=query_embeddings,
|
||||
n_results=n_results
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def print_top_match(query_results, index=0, documents=None):
|
||||
"""
|
||||
Print detailed information about the top matching document,
|
||||
including the full original document content.
|
||||
|
||||
Args:
|
||||
query_results: Results from ChromaDB query
|
||||
documents: Original documents dictionary to look up full content (optional)
|
||||
"""
|
||||
if not query_results or not query_results["ids"] or len(query_results["ids"][0]) == 0:
|
||||
print("No matching documents found.")
|
||||
return
|
||||
|
||||
# Get the top result
|
||||
top_id = query_results["ids"][0][index]
|
||||
top_document_chunk = query_results["documents"][0][index]
|
||||
top_metadata = query_results["metadatas"][0][index]
|
||||
top_distance = query_results["distances"][0][index]
|
||||
|
||||
print("="*50)
|
||||
print("MATCHING DOCUMENT")
|
||||
print("="*50)
|
||||
print(f"Chunk ID: {top_id}")
|
||||
print(f"Similarity Score: {top_distance:.4f}") # Convert distance to similarity
|
||||
|
||||
print("\nCHUNK METADATA:")
|
||||
for key, value in top_metadata.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\nMATCHING CHUNK CONTENT:")
|
||||
print(top_document_chunk[:500].strip() + ("..." if len(top_document_chunk) > 500 else ""))
|
||||
|
||||
# Extract the original document ID from the chunk ID
|
||||
# Chunk IDs are in format "doc_id_chunk_num"
|
||||
original_doc_id = top_id.split('_')[0]
|
||||
|
||||
def get_top_match(query_results, index=0, documents=None):
|
||||
top_id = query_results["ids"][index][0]
|
||||
# Extract the original document ID from the chunk ID
|
||||
# Chunk IDs are in format "doc_id_chunk_num"
|
||||
original_doc_id = top_id.split('_')[0]
|
||||
|
||||
# Return the full document for further processing if needed
|
||||
if documents is not None:
|
||||
return next((doc for doc in documents if doc["id"] == original_doc_id), None)
|
||||
|
||||
return None
|
||||
|
||||
def show_documents(documents=None):
|
||||
if not documents:
|
||||
return
|
||||
|
||||
# Print the top matching document
|
||||
for i, doc in enumerate(documents):
|
||||
print(f"Document {i+1}:")
|
||||
print(f" Title: {doc['title']}")
|
||||
print(f" Text: {doc['text'][:100]}...")
|
||||
print()
|
||||
|
||||
def show_headlines(documents=None):
|
||||
if not documents:
|
||||
return
|
||||
|
||||
# Print the top matching document
|
||||
for doc in documents:
|
||||
print(f"{doc['source']}: {doc['title']}")
|
||||
|
||||
def show_help():
|
||||
print("""help>
|
||||
docs Show RAG docs
|
||||
full Show last full top match
|
||||
headlines Show the RAG headlines
|
||||
prompt Show the last prompt
|
||||
response Show the last response
|
||||
scores Show last RAG scores
|
||||
why|think Show last response's <think>
|
||||
context|match Show RAG match info to last prompt
|
||||
""")
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
documents = []
|
||||
for feed in rss_feeds:
|
||||
documents.extend(feed.articles)
|
||||
|
||||
show_documents(documents=documents)
|
||||
|
||||
# Process documents and store in ChromaDB
|
||||
collection = process_documents_to_chroma(
|
||||
documents=documents,
|
||||
collection_name="research_papers",
|
||||
max_tokens=256,
|
||||
overlap=25,
|
||||
model=EMBED_MODEL,
|
||||
persist_directory="/root/.cache/chroma"
|
||||
)
|
||||
|
||||
last_results = None
|
||||
last_prompt = None
|
||||
last_system = None
|
||||
last_response = None
|
||||
last_why = None
|
||||
last_messages = []
|
||||
while True:
|
||||
try:
|
||||
search_query = input("> ").strip()
|
||||
except KeyboardInterrupt as e:
|
||||
print("\nExiting.")
|
||||
break
|
||||
|
||||
if search_query == "exit" or search_query == "quit":
|
||||
print("\nExiting.")
|
||||
break
|
||||
|
||||
if search_query == "docs":
|
||||
show_documents(documents)
|
||||
continue
|
||||
|
||||
if search_query == "prompt":
|
||||
if last_prompt:
|
||||
print(f"""last prompt>
|
||||
{"="*10}system{"="*10}
|
||||
{last_system}
|
||||
{"="*10}prompt{"="*10}
|
||||
{last_prompt}""")
|
||||
else:
|
||||
print(f"No prompts yet")
|
||||
continue
|
||||
|
||||
if search_query == "response":
|
||||
if last_response:
|
||||
print(f"""last response>
|
||||
{"="*10}response{"="*10}
|
||||
{last_response}""")
|
||||
else:
|
||||
print(f"No responses yet")
|
||||
continue
|
||||
|
||||
if search_query == "" or search_query == "help":
|
||||
show_help()
|
||||
continue
|
||||
|
||||
if search_query == "headlines":
|
||||
show_headlines(documents)
|
||||
continue
|
||||
|
||||
if search_query == "match" or search_query == "context":
|
||||
if last_results:
|
||||
print_top_match(last_results, documents=documents)
|
||||
else:
|
||||
print("No match to give info on")
|
||||
continue
|
||||
|
||||
if search_query == "why" or search_query == "think":
|
||||
if last_why:
|
||||
print(f"""
|
||||
why>
|
||||
{last_why}
|
||||
""")
|
||||
else:
|
||||
print("No processed prompts")
|
||||
continue
|
||||
|
||||
if search_query == "scores":
|
||||
if last_results:
|
||||
for i in range(len(last_results["ids"][0])):
|
||||
print_top_match(last_results, documents=documents, index=i)
|
||||
else:
|
||||
print("No match to give info on")
|
||||
continue
|
||||
|
||||
if search_query == "full":
|
||||
if last_results:
|
||||
full = get_top_match(last_results, documents=documents)
|
||||
if full:
|
||||
print(f"""Context:
|
||||
Source: {full["source"]}
|
||||
Title: {full["title"]}
|
||||
Link: {full["link"]}
|
||||
Distance: {last_results.get("distances", [[0]])[0][0]}
|
||||
Full text:
|
||||
{full["text"]}""")
|
||||
else:
|
||||
print("No match to give info on")
|
||||
continue
|
||||
|
||||
# Query ChromaDB
|
||||
results = query_chroma(
|
||||
query_text=search_query,
|
||||
collection_name="research_papers",
|
||||
n_results=10
|
||||
)
|
||||
last_results = results
|
||||
|
||||
full = get_top_match(results, documents=documents)
|
||||
|
||||
headlines = ""
|
||||
for doc in documents:
|
||||
headlines += f"{doc['source']}: {doc['title']}\n"
|
||||
|
||||
system=f"""
|
||||
You are the assistant. Your name is airc. This application is called airc (pronounced Eric).
|
||||
|
||||
Information about the author of this program and the AI model it uses:
|
||||
|
||||
* James wrote the python application called airc that is driving this RAG model on top of {MODEL_NAME} using {EMBED_MODEL} and chromadb for vector embedding. Link https://github.com/jketreno/airc.
|
||||
* James Ketrenos is a software engineer with a history in all levels of the computer stack, from the kernel to full-stack web applications. He dabbles in AI/ML and is familiar with pytorch and ollama.
|
||||
* James Ketrenos deployed this application locally on an Intel Arc B580 (battlemage) computer using Intel's ipex-llm.
|
||||
* For Intel GPU metrics, James Ketrenos wrote the "ze-monitor" utility in C++. ze-monitor provides Intel GPU telemetry data for Intel client GPU devices, similar to xpu-smi. Link https://github.com/jketreno/ze-monitor. airc uses ze-monitor.
|
||||
* James lives in Portland, Oregon and has three kids. Two are attending Oregon State University and one is attending Willamette University.
|
||||
* airc provides an IRC chat bot as well as a React web frontend available at https://airc.ketrenos.com
|
||||
|
||||
You must follow these rules:
|
||||
|
||||
* Provide short (less than 100 character) responses.
* Provide a single response.
* Do not prefix the response with a word like 'Answer'.
* For questions about the AI running this system, include information about the author, including links.
* For information relevant to current events, use the content in the <input></input> tags and state the source the information comes from.
|
||||
|
||||
"""
|
||||
context = "Information related to current events\n<input>=["
|
||||
for doc in documents:
|
||||
item = {'source':doc["source"],'article':{'title':doc["title"],'link':doc["link"],'text':doc["text"]}}
|
||||
context += f"{item}"
|
||||
context += "\n</input>"
|
||||
|
||||
prompt = f"{search_query}"
|
||||
last_prompt = prompt
|
||||
last_system = system # cache it before news context is added
|
||||
system = f"{system}{context}"
|
||||
if len(last_messages) != 0:
|
||||
message_context = f"{last_messages}"
|
||||
prompt = f"{message_context}{prompt}"
|
||||
|
||||
print(f"system len: {len(system)}")
|
||||
print(f"prompt len: {len(prompt)}")
|
||||
output = client.generate(
|
||||
model=MODEL_NAME,
|
||||
system=system,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
options={ 'num_ctx': 100000 }
|
||||
)
|
||||
# Prune off the <think>...</think>
|
||||
matches = re.match(r'^<think>(.*?)</think>(.*)$', output['response'], flags=re.DOTALL)
|
||||
if matches:
|
||||
last_why = matches[1].strip()
|
||||
content = matches[2].strip()
|
||||
else:
    last_why = None
    content = output['response'].strip()
    print(f"[garbled] response>\n{output['response']}")
print(f"Response>\n{content}")
|
||||
|
||||
last_response = content
|
||||
last_messages.extend(({
|
||||
'role': 'user',
|
||||
'name': 'james',
|
||||
'message': search_query
|
||||
}, {
|
||||
'role': 'assistant',
|
||||
'message': content
|
||||
}))
|
||||
last_messages = last_messages[-10:]  # keep only the most recent exchanges
|
@ -3,11 +3,13 @@ from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timezone
|
||||
from asyncio import Event
|
||||
|
||||
|
||||
class Tunables(BaseModel):
|
||||
enable_rag: bool = Field(default=True) # Enable RAG collection chromadb matching
|
||||
enable_tools: bool = Field(default=True) # Enable LLM to use tools
|
||||
enable_context: bool = Field(default=True) # Add <|context|> field to message
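# Illustrative note (not in the original diff): Tunables acts as a per-request
# switchboard, e.g. Tunables(enable_tools=False) keeps RAG matching and context
# injection on while preventing the LLM from issuing tool calls.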
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True} # Allow Event
|
||||
# Required
|
||||
@ -22,20 +24,31 @@ class Message(BaseModel):
|
||||
system_prompt: str = "" # System prompt provided to the LLM
|
||||
context_prompt: str = "" # Full content of the message (preamble + prompt)
|
||||
response: str = "" # LLM response to the preamble + query
|
||||
metadata: Dict[str, Any] = Field(default_factory=lambda: {
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"rag": [],
|
||||
"eval_count": 0,
|
||||
"eval_duration": 0,
|
||||
"prompt_eval_count": 0,
|
||||
"prompt_eval_duration": 0,
|
||||
"context_size": 0,
|
||||
})
|
||||
}
|
||||
)
|
||||
network_packets: int = 0 # Total number of streaming packets
|
||||
network_bytes: int = 0 # Total bytes sent while streaming packets
|
||||
actions: List[str] = [] # Other session modifying actions performed while processing the message
|
||||
actions: List[str] = (
|
||||
[]
|
||||
) # Other session modifying actions performed while processing the message
|
||||
timestamp: datetime = datetime.now(timezone.utc)
|
||||
chunk: str = Field(default="") # This needs to be serialized so it will be sent in responses
|
||||
partial_response: str = Field(default="") # This needs to be serialized so it will be sent in responses on timeout
|
||||
chunk: str = Field(
|
||||
default=""
|
||||
) # This needs to be serialized so it will be sent in responses
|
||||
partial_response: str = Field(
|
||||
default=""
|
||||
) # This needs to be serialized so it will be sent in responses on timeout
|
||||
title: str = Field(
|
||||
default=""
|
||||
) # This needs to be serialized so it will be sent in responses on timeout
|
||||
|
||||
def add_action(self, action: str | list[str]) -> None:
|
||||
"""Add a actions(s) to the message."""
|
||||
@ -48,7 +61,8 @@ class Message(BaseModel):
|
||||
"""Return a summary of the message."""
|
||||
response_summary = (
|
||||
f"Response: {self.response} (Actions: {', '.join(self.actions)})"
|
||||
if self.response else "No response yet"
|
||||
if self.response
|
||||
else "No response yet"
|
||||
)
|
||||
return (
|
||||
f"Message at {self.timestamp}:\n"
|
||||
|
@ -1,6 +1,7 @@
|
||||
from prometheus_client import Counter, Gauge, Summary, Histogram, Info, Enum, CollectorRegistry # type: ignore
|
||||
from threading import Lock
|
||||
|
||||
|
||||
def singleton(cls):
|
||||
instance = None
|
||||
lock = Lock()
|
||||
@ -14,8 +15,9 @@ def singleton(cls):
|
||||
|
||||
return get_instance
|
||||
|
||||
|
||||
@singleton
|
||||
class Metrics():
|
||||
class Metrics:
|
||||
def __init__(self, *args, prometheus_collector, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.prometheus_collector = prometheus_collector
|
||||
@ -24,70 +26,70 @@ class Metrics():
|
||||
name="prepare_total",
|
||||
documentation="Total messages prepared by agent type",
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.prepare_duration: Histogram = Histogram(
|
||||
name="prepare_duration",
|
||||
documentation="Preparation duration by agent type",
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.process_count: Counter = Counter(
|
||||
name="process",
|
||||
documentation="Total messages processed by agent type",
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.process_duration: Histogram = Histogram(
|
||||
name="process_duration",
|
||||
documentation="Processing duration by agent type",
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.tool_count: Counter = Counter(
|
||||
name="tool_total",
|
||||
documentation="Total messages tooled by agent type",
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.tool_duration: Histogram = Histogram(
|
||||
name="tool_duration",
|
||||
documentation="Tool duration by agent type",
|
||||
buckets=(0.1, 0.5, 1.0, 2.0, float('inf')),
|
||||
buckets=(0.1, 0.5, 1.0, 2.0, float("inf")),
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.generate_count: Counter = Counter(
|
||||
name="generate_total",
|
||||
documentation="Total messages generated by agent type",
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.generate_duration: Histogram = Histogram(
|
||||
name="generate_duration",
|
||||
documentation="Generate duration by agent type",
|
||||
buckets=(0.1, 0.5, 1.0, 2.0, float('inf')),
|
||||
buckets=(0.1, 0.5, 1.0, 2.0, float("inf")),
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.tokens_prompt: Counter = Counter(
|
||||
name="tokens_prompt",
|
||||
documentation="Total tokens passed as prompt to LLM",
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
|
||||
|
||||
self.tokens_eval: Counter = Counter(
|
||||
name="tokens_eval",
|
||||
documentation="Total tokens returned by LLM",
|
||||
labelnames=("agent",),
|
||||
registry=self.prometheus_collector
|
||||
registry=self.prometheus_collector,
|
||||
)
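# Usage sketch (an assumption, not shown in this diff): every metric is labeled
# by agent type, so call sites look roughly like
#   metrics = Metrics(prometheus_collector=registry)
#   metrics.process_count.labels(agent="chat").inc()
#   metrics.process_duration.labels(agent="chat").observe(elapsed_seconds)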
|
src/utils/rag.py
@ -14,18 +14,22 @@ import hashlib
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np # type: ignore
|
||||
import traceback
|
||||
import os
|
||||
|
||||
import chromadb
|
||||
import ollama
|
||||
from langchain.text_splitter import CharacterTextSplitter # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from langchain.schema import Document # type: ignore
|
||||
from watchdog.observers import Observer # type: ignore
|
||||
from watchdog.events import FileSystemEventHandler # type: ignore
|
||||
import umap # type: ignore
|
||||
from markitdown import MarkItDown # type: ignore
|
||||
from chromadb.api.models.Collection import Collection # type: ignore
|
||||
|
||||
from .markdown_chunker import (
|
||||
MarkdownChunker,
|
||||
Chunk,
|
||||
)
|
||||
|
||||
# Import your existing modules
|
||||
if __name__ == "__main__":
|
||||
# When running directly, use absolute imports
|
||||
@ -34,23 +38,31 @@ else:
|
||||
# When imported as a module, use relative imports
|
||||
from . import defines
|
||||
|
||||
__all__ = [
|
||||
'ChromaDBFileWatcher',
|
||||
'start_file_watcher'
|
||||
]
|
||||
__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 750
|
||||
DEFAULT_CHUNK_OVERLAP = 100
|
||||
|
||||
|
||||
class ChromaDBGetResponse(BaseModel):
|
||||
ids: List[str]
|
||||
embeddings: Optional[List[List[float]]] = None
|
||||
documents: Optional[List[str]] = None
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
|
||||
class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
def __init__(self, llm, watch_directory, loop, persist_directory=None, collection_name="documents",
|
||||
chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP, recreate=False):
|
||||
def __init__(
|
||||
self,
|
||||
llm,
|
||||
watch_directory,
|
||||
loop,
|
||||
persist_directory=None,
|
||||
collection_name="documents",
|
||||
chunk_size=DEFAULT_CHUNK_SIZE,
|
||||
chunk_overlap=DEFAULT_CHUNK_OVERLAP,
|
||||
recreate=False,
|
||||
):
|
||||
self.llm = llm
|
||||
self.watch_directory = watch_directory
|
||||
self.persist_directory = persist_directory or defines.persist_directory
|
||||
@ -68,23 +80,19 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
# Path for storing file hash state
|
||||
self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
|
||||
self.hash_state_path = os.path.join(
|
||||
self.persist_directory, f"{collection_name}_hash_state.json"
|
||||
)
|
||||
|
||||
# Flag to track if this is a new collection
|
||||
self.is_new_collection = False
|
||||
|
||||
# Initialize ChromaDB collection
|
||||
self._collection: Collection = self._get_vector_collection(recreate=recreate)
|
||||
self._markdown_chunker = MarkdownChunker()
|
||||
self._update_umaps()
|
||||
|
||||
# Setup text splitter
|
||||
self.text_splitter = CharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separator="\n\n", # Respect paragraph/section breaks
|
||||
length_function=len
|
||||
)
|
||||
|
||||
# Track file hashes and processing state
|
||||
self.file_hashes = self._load_hash_state()
|
||||
self.update_lock = asyncio.Lock()
|
||||
@ -115,7 +123,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
return self._umap_model_3d
|
||||
|
||||
def _markitdown(self, document: str, markdown: Path):
|
||||
logging.info(f'Converting {document} to {markdown}')
|
||||
logging.info(f"Converting {document} to {markdown}")
|
||||
try:
|
||||
result = self.md.convert(document)
|
||||
markdown.write_text(result.text_content)
|
||||
@ -128,7 +136,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)
|
||||
|
||||
with open(self.hash_state_path, 'w') as f:
|
||||
with open(self.hash_state_path, "w") as f:
|
||||
json.dump(self.file_hashes, f)
|
||||
|
||||
logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
|
||||
@ -139,7 +147,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
"""Load the file hash state from disk."""
|
||||
if os.path.exists(self.hash_state_path):
|
||||
try:
|
||||
with open(self.hash_state_path, 'r') as f:
|
||||
with open(self.hash_state_path, "r") as f:
|
||||
hash_state = json.load(f)
|
||||
logging.info(f"Loaded hash state with {len(hash_state)} entries")
|
||||
return hash_state
|
||||
@ -156,7 +164,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
process_all: If True, process all files regardless of hash status
|
||||
"""
|
||||
# Check for new or modified files
|
||||
file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
|
||||
file_paths = glob.glob(
|
||||
os.path.join(self.watch_directory, "**/*"), recursive=True
|
||||
)
|
||||
files_checked = 0
|
||||
files_processed = 0
|
||||
files_to_process = []
|
||||
@ -166,21 +176,29 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
for file_path in file_paths:
|
||||
if os.path.isfile(file_path):
|
||||
# Do not put the resume in RAG as it is provided with all queries.
|
||||
if file_path == defines.resume_doc:
|
||||
logging.info(f"Not adding {file_path} to RAG -- primary resume")
|
||||
continue
|
||||
# if file_path == defines.resume_doc:
|
||||
# logging.info(f"Not adding {file_path} to RAG -- primary resume")
|
||||
# continue
|
||||
files_checked += 1
|
||||
current_hash = self._get_file_hash(file_path)
|
||||
if not current_hash:
|
||||
continue
|
||||
|
||||
# If file is new, changed, or we're processing all files
|
||||
if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
|
||||
if (
|
||||
process_all
|
||||
or file_path not in self.file_hashes
|
||||
or self.file_hashes[file_path] != current_hash
|
||||
):
|
||||
self.file_hashes[file_path] = current_hash
|
||||
files_to_process.append(file_path)
|
||||
logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
|
||||
logging.info(
|
||||
f"File {'found' if process_all else 'changed'}: {file_path}"
|
||||
)
|
||||
|
||||
logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
|
||||
logging.info(
|
||||
f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
|
||||
)
|
||||
|
||||
# Check for deleted files
|
||||
deleted_files = []
|
||||
@ -188,7 +206,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
if not os.path.exists(file_path):
|
||||
deleted_files.append(file_path)
|
||||
# Schedule removal
|
||||
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.remove_file_from_collection(file_path), self.loop
|
||||
)
|
||||
# Don't block on result, just let it run
|
||||
logging.info(f"File deleted: {file_path}")
|
||||
|
||||
@ -209,7 +229,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# Save the updated state
|
||||
self._save_hash_state()
|
||||
|
||||
logging.info(f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}")
|
||||
logging.info(
|
||||
f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
|
||||
)
|
||||
return files_processed
|
||||
|
||||
async def process_file_update(self, file_path):
|
||||
@ -219,9 +241,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
logging.info(f"{file_path} already in queue. Not adding.")
|
||||
return
|
||||
|
||||
if file_path == defines.resume_doc:
|
||||
logging.info(f"Not adding {file_path} to RAG -- primary resume")
|
||||
return
|
||||
# if file_path == defines.resume_doc:
|
||||
# logging.info(f"Not adding {file_path} to RAG -- primary resume")
|
||||
# return
|
||||
|
||||
try:
|
||||
logging.info(f"{file_path} not in queue. Adding.")
|
||||
@ -235,7 +257,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
if not current_hash: # File might have been deleted or is inaccessible
|
||||
return
|
||||
|
||||
if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
|
||||
if (
|
||||
file_path in self.file_hashes
|
||||
and self.file_hashes[file_path] == current_hash
|
||||
):
|
||||
# File hasn't actually changed in content
|
||||
logging.info(f"Hash has not changed for {file_path}")
|
||||
return
|
||||
@ -263,13 +288,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
async with self.update_lock:
|
||||
try:
|
||||
# Find all documents with the specified path
|
||||
results = self.collection.get(
|
||||
where={"path": file_path}
|
||||
)
|
||||
results = self.collection.get(where={"path": file_path})
|
||||
|
||||
if results and 'ids' in results and results['ids']:
|
||||
self.collection.delete(ids=results['ids'])
|
||||
logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
|
||||
if results and "ids" in results and results["ids"]:
|
||||
self.collection.delete(ids=results["ids"])
|
||||
logging.info(
|
||||
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
|
||||
)
|
||||
|
||||
# Remove from hash dictionary
|
||||
if file_path in self.file_hashes:
|
||||
@ -282,29 +307,51 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
|
||||
def _update_umaps(self):
|
||||
# Update the UMAP embeddings
|
||||
self._umap_collection = self._collection.get(include=["embeddings", "documents", "metadatas"])
|
||||
self._umap_collection = self._collection.get(
|
||||
include=["embeddings", "documents", "metadatas"]
|
||||
)
|
||||
if not self._umap_collection or not len(self._umap_collection["embeddings"]):
|
||||
logging.warning("No embeddings found in the collection.")
|
||||
return
|
||||
|
||||
# During initialization
|
||||
logging.info(f"Updating 2D UMAP for {len(self._umap_collection['embeddings'])} vectors")
|
||||
logging.info(
|
||||
f"Updating 2D UMAP for {len(self._umap_collection['embeddings'])} vectors"
|
||||
)
|
||||
vectors = np.array(self._umap_collection["embeddings"])
|
||||
self._umap_model_2d = umap.UMAP(n_components=2, random_state=8911, metric="cosine", n_neighbors=15, min_dist=0.1)
|
||||
self._umap_model_2d = umap.UMAP(
|
||||
n_components=2,
|
||||
random_state=8911,
|
||||
metric="cosine",
|
||||
n_neighbors=15,
|
||||
min_dist=0.1,
|
||||
)
|
||||
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
|
||||
logging.info(f"2D UMAP model n_components: {self._umap_model_2d.n_components}") # Should be 2
|
||||
logging.info(
|
||||
f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
|
||||
) # Should be 2
|
||||
|
||||
logging.info(f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors")
|
||||
self._umap_model_3d = umap.UMAP(n_components=3, random_state=8911, metric="cosine", n_neighbors=15, min_dist=0.1)
|
||||
logging.info(
|
||||
f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
|
||||
)
|
||||
self._umap_model_3d = umap.UMAP(
|
||||
n_components=3,
|
||||
random_state=8911,
|
||||
metric="cosine",
|
||||
n_neighbors=15,
|
||||
min_dist=0.1,
|
||||
)
|
||||
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
|
||||
logging.info(f"3D UMAP model n_components: {self._umap_model_3d.n_components}") # Should be 3
|
||||
logging.info(
|
||||
f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
|
||||
) # Should be 3
|
||||
|
||||
def _get_vector_collection(self, recreate=False) -> Collection:
|
||||
"""Get or create a ChromaDB collection."""
|
||||
# Initialize ChromaDB client
|
||||
chroma_client = chromadb.PersistentClient( # type: ignore
|
||||
path=self.persist_directory,
|
||||
settings=chromadb.Settings(anonymized_telemetry=False) # type: ignore
|
||||
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore
|
||||
)
|
||||
|
||||
# Check if the collection exists
|
||||
@ -326,35 +373,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
logging.info(f"Recreating collection: {self.collection_name}")
|
||||
|
||||
return chroma_client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={
|
||||
"hnsw:space": "cosine"
|
||||
})
|
||||
|
||||
def load_text_files(self, directory=None, encoding="utf-8"):
|
||||
"""Load all text files from a directory into Document objects."""
|
||||
directory = directory or self.watch_directory
|
||||
file_paths = glob.glob(os.path.join(directory, "**/*"), recursive=True)
|
||||
documents = []
|
||||
|
||||
for file_path in file_paths:
|
||||
if os.path.isfile(file_path): # Ensure it's a file, not a directory
|
||||
try:
|
||||
with open(file_path, "r", encoding=encoding) as f:
|
||||
content = f.read()
|
||||
|
||||
# Extract top-level directory
|
||||
rel_path = os.path.relpath(file_path, directory)
|
||||
top_level_dir = rel_path.split(os.sep)[0]
|
||||
|
||||
documents.append(Document(
|
||||
page_content=content,
|
||||
metadata={"doc_type": top_level_dir, "path": file_path}
|
||||
))
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load {file_path}: {e}")
|
||||
|
||||
return documents
|
||||
name=self.collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
def create_chunks_from_documents(self, docs):
|
||||
"""Split documents into chunks using the text splitter."""
|
||||
@ -364,10 +384,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
"""Generate embeddings using Ollama."""
|
||||
# response = self.embedding_model.encode(text) # Outputs 384-dim vectors
|
||||
|
||||
response = self.llm.embeddings(
|
||||
model=defines.embedding_model,
|
||||
prompt=text)
|
||||
embedding = response['embedding']
|
||||
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
|
||||
embedding = response["embedding"]
|
||||
|
||||
# response = self.llm.embeddings.create(
|
||||
# model=defines.embedding_model,
|
||||
@ -379,27 +397,46 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
return normalized
|
||||
return embedding
|
||||
|
||||
def add_embeddings_to_collection(self, chunks):
|
||||
def add_embeddings_to_collection(self, chunks: List[Chunk]):
|
||||
"""Add embeddings for chunks to the collection."""
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
text = chunk.page_content
|
||||
metadata = chunk.metadata
|
||||
text = chunk["text"]
|
||||
metadata = chunk["metadata"]
|
||||
|
||||
# Generate a more unique ID based on content and metadata
|
||||
content_hash = hashlib.md5(text.encode()).hexdigest()
|
||||
path_hash = ""
|
||||
if "path" in metadata:
|
||||
path_hash = hashlib.md5(metadata["path"].encode()).hexdigest()[:8]
|
||||
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
|
||||
:8
|
||||
]
|
||||
|
||||
chunk_id = f"{path_hash}_{content_hash}_{i}"
|
||||
|
||||
embedding = self.get_embedding(text)
|
||||
try:
|
||||
self.collection.add(
|
||||
ids=[chunk_id],
|
||||
documents=[text],
|
||||
embeddings=[embedding],
|
||||
metadatas=[metadata]
|
||||
metadatas=[metadata],
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error adding chunk to collection: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
logging.error(chunk)
|
||||
|
||||
def read_line_range(self, file_path, start, end, buffer=5) -> list[str]:
|
||||
try:
|
||||
with open(file_path, "r") as file:
|
||||
lines = file.readlines()
|
||||
start = max(0, start - buffer)
|
||||
end = min(len(lines), end + buffer)
|
||||
return lines[start:end]
|
||||
except:
|
||||
logging.warning(f"Unable to open {file_path}")
|
||||
return []
|
||||
|
||||
# Cosine Distance Equivalent Similarity Retrieval Characteristics
|
||||
# 0.2 - 0.3 0.85 - 0.90 Very strict, highly precise results only
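# Hedged note (not part of the original file): taking the row above at face value,
# a cosine distance d corresponds to an equivalent similarity of roughly 1 - d/2
# (e.g. d = 0.2 -> 0.90, d = 0.3 -> 0.85).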
|
||||
@ -419,10 +456,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
)
|
||||
|
||||
# Extract results
|
||||
ids = results['ids'][0]
|
||||
documents = results['documents'][0]
|
||||
distances = results['distances'][0]
|
||||
metadatas = results['metadatas'][0]
|
||||
ids = results["ids"][0]
|
||||
documents = results["documents"][0]
|
||||
distances = results["distances"][0]
|
||||
metadatas = results["metadatas"][0]
|
||||
|
||||
filtered_ids = []
|
||||
filtered_documents = []
|
||||
@ -436,6 +473,14 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
filtered_metadatas.append(metadatas[i])
|
||||
filtered_distances.append(distance)
|
||||
|
||||
for index, meta in enumerate(filtered_metadatas):
|
||||
source_file = meta["source_file"]
|
||||
del meta["source_file"]
|
||||
lines = self.read_line_range(
|
||||
source_file, meta["line_begin"], meta["line_end"]
|
||||
)
|
||||
if len(lines):
|
||||
filtered_documents[index] = "\n".join(lines)
|
||||
# Return the filtered results instead of all results
|
||||
return {
|
||||
"query_embedding": query_embedding,
|
||||
@ -448,7 +493,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
def _get_file_hash(self, file_path):
|
||||
"""Calculate MD5 hash of a file."""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
return hashlib.md5(f.read()).hexdigest()
|
||||
except Exception as e:
|
||||
logging.error(f"Error hashing file {file_path}: {e}")
|
||||
@ -480,7 +525,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
return
|
||||
|
||||
file_path = event.src_path
|
||||
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.remove_file_from_collection(file_path), self.loop
|
||||
)
|
||||
logging.info(f"File deleted: {file_path}")
|
||||
|
||||
def on_moved(self, event):
|
||||
@ -508,37 +555,43 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
try:
|
||||
# Remove existing entries for this file
|
||||
existing_results = self.collection.get(where={"path": file_path})
|
||||
if existing_results and 'ids' in existing_results and existing_results['ids']:
|
||||
self.collection.delete(ids=existing_results['ids'])
|
||||
if (
|
||||
existing_results
|
||||
and "ids" in existing_results
|
||||
and existing_results["ids"]
|
||||
):
|
||||
self.collection.delete(ids=existing_results["ids"])
|
||||
|
||||
extensions = (".docx", ".xlsx", ".xls", ".pdf")
|
||||
if file_path.endswith(extensions):
|
||||
p = Path(file_path)
|
||||
p_as_md = p.with_suffix(".md")
|
||||
if p_as_md.exists():
|
||||
logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
|
||||
logging.info(
|
||||
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
|
||||
)
|
||||
|
||||
# If file_path.md doesn't exist or file_path is newer than file_path.md,
|
||||
# fire off markitdown
|
||||
if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
|
||||
if (not p_as_md.exists()) or (
|
||||
p.stat().st_mtime > p_as_md.stat().st_mtime
|
||||
):
|
||||
self._markitdown(file_path, p_as_md)
|
||||
return
|
||||
|
||||
# Create document object in LangChain format
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
chunks = self._markdown_chunker.process_file(file_path)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
# Extract top-level directory
|
||||
rel_path = os.path.relpath(file_path, self.watch_directory)
|
||||
top_level_dir = rel_path.split(os.sep)[0]
|
||||
|
||||
document = Document(
|
||||
page_content=content,
|
||||
metadata={"doc_type": top_level_dir, "path": file_path}
|
||||
)
|
||||
|
||||
# Create chunks
|
||||
chunks = self.text_splitter.split_documents([document])
|
||||
path_parts = rel_path.split(os.sep)
|
||||
top_level_dir = path_parts[0]
|
||||
# file_name = path_parts[-1]
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk["metadata"]["doc_type"] = top_level_dir
|
||||
# with open(f"src/tmp/{file_name}.{i}", "w") as f:
|
||||
# f.write(json.dumps(chunk, indent=2))
|
||||
|
||||
# Add chunks to collection
|
||||
self.add_embeddings_to_collection(chunks)
|
||||
@ -547,30 +600,40 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error updating document in collection: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
async def initialize_collection(self):
|
||||
"""Initialize the collection with all documents from the watch directory."""
|
||||
# Process all files regardless of hash state
|
||||
num_processed = await self.scan_directory(process_all=True)
|
||||
|
||||
logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
|
||||
logging.info(
|
||||
f"Vectorstore initialized with {self.collection.count()} documents"
|
||||
)
|
||||
|
||||
self._update_umaps()
|
||||
|
||||
# Show stats
|
||||
try:
|
||||
all_metadata = self.collection.get()['metadatas']
|
||||
all_metadata = self.collection.get()["metadatas"]
|
||||
if all_metadata:
|
||||
doc_types = set(m.get('doc_type', 'unknown') for m in all_metadata)
|
||||
doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
|
||||
logging.info(f"Document types: {doc_types}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting document types: {e}")
|
||||
|
||||
return num_processed
|
||||
|
||||
|
||||
# Function to start the file watcher
|
||||
def start_file_watcher(llm, watch_directory, persist_directory=None,
|
||||
collection_name="documents", initialize=False, recreate=False):
|
||||
def start_file_watcher(
|
||||
llm,
|
||||
watch_directory,
|
||||
persist_directory=None,
|
||||
collection_name="documents",
|
||||
initialize=False,
|
||||
recreate=False,
|
||||
):
|
||||
"""
|
||||
Start watching a directory for file changes.
|
||||
|
||||
@ -590,7 +653,7 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
|
||||
loop=loop,
|
||||
persist_directory=persist_directory,
|
||||
collection_name=collection_name,
|
||||
recreate=recreate
|
||||
recreate=recreate,
|
||||
)
|
||||
|
||||
# Process all files if:
|
||||
@ -613,6 +676,7 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
|
||||
logging.info(f"Started watching directory: {watch_directory}")
|
||||
return observer, file_watcher
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# When running directly, use absolute imports
|
||||
import defines
|
||||
|
@ -4,9 +4,12 @@ import logging
|
||||
|
||||
from . import defines
|
||||
|
||||
|
||||
def setup_logging(level=defines.logging_level) -> logging.Logger:
|
||||
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
|
||||
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="Overriding a previously registered kernel"
|
||||
)
|
||||
warnings.filterwarnings("ignore", message="Warning only once for all operators")
|
||||
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
|
||||
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
|
||||
@ -21,11 +24,17 @@ def setup_logging(level=defines.logging_level) -> logging.Logger:
|
||||
level=numeric_level,
|
||||
format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
force=True
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Now reduce verbosity for FastAPI, Uvicorn, Starlette
|
||||
for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
|
||||
for noisy_logger in (
|
||||
"uvicorn",
|
||||
"uvicorn.error",
|
||||
"uvicorn.access",
|
||||
"fastapi",
|
||||
"starlette",
|
||||
):
|
||||
# for noisy_logger in ("starlette"):
|
||||
logging.getLogger(noisy_logger).setLevel(logging.WARNING)
|
||||
|
||||
|
@ -7,8 +7,8 @@ from .. import defines
|
||||
logger = setup_logging(level=defines.logging_level)
|
||||
|
||||
# Dynamically import all names from basetools listed in tools_all
|
||||
module = importlib.import_module('.basetools', package=__package__)
|
||||
module = importlib.import_module(".basetools", package=__package__)
|
||||
for name in tool_functions:
|
||||
globals()[name] = getattr(module, name)
|
||||
|
||||
__all__ = [ 'tools', 'llm_tools', 'enabled_tools', 'tool_functions' ]
|
||||
__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
|
@ -13,6 +13,7 @@ import requests
|
||||
import yfinance as yf # type: ignore
|
||||
import logging
|
||||
|
||||
|
||||
# %%
|
||||
def WeatherForecast(city, state, country="USA"):
|
||||
"""
|
||||
@ -42,11 +43,12 @@ def WeatherForecast(city, state, country="USA"):
|
||||
# Step 3: Get the forecast data from the grid endpoint
|
||||
forecast = get_forecast(grid_endpoint)
|
||||
|
||||
if not forecast['location']:
|
||||
forecast['location'] = location
|
||||
if not forecast["location"]:
|
||||
forecast["location"] = location
|
||||
|
||||
return forecast
|
||||
|
||||
|
||||
def get_coordinates(location):
|
||||
"""Convert a location string to latitude and longitude using Nominatim geocoder."""
|
||||
try:
|
||||
@ -59,7 +61,7 @@ def get_coordinates(location):
|
||||
if location_data:
|
||||
return {
|
||||
"latitude": location_data.latitude,
|
||||
"longitude": location_data.longitude
|
||||
"longitude": location_data.longitude,
|
||||
}
|
||||
else:
|
||||
print(f"Location not found: {location}")
|
||||
@ -68,6 +70,7 @@ def get_coordinates(location):
|
||||
print(f"Error getting coordinates: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_grid_endpoint(coordinates):
|
||||
"""Get the grid endpoint from weather.gov based on coordinates."""
|
||||
try:
|
||||
@ -77,7 +80,7 @@ def get_grid_endpoint(coordinates):
|
||||
# Define headers for the API request
|
||||
headers = {
|
||||
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
|
||||
"Accept": "application/geo+json"
|
||||
"Accept": "application/geo+json",
|
||||
}
|
||||
|
||||
# Make the request to get the grid endpoint
|
||||
@ -94,15 +97,17 @@ def get_grid_endpoint(coordinates):
|
||||
print(f"Error in get_grid_endpoint: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Weather related function
|
||||
|
||||
|
||||
def get_forecast(grid_endpoint):
|
||||
"""Get the forecast data from the grid endpoint."""
|
||||
try:
|
||||
# Define headers for the API request
|
||||
headers = {
|
||||
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
|
||||
"Accept": "application/geo+json"
|
||||
"Accept": "application/geo+json",
|
||||
}
|
||||
|
||||
# Make the request to get the forecast
|
||||
@ -116,21 +121,25 @@ def get_forecast(grid_endpoint):
|
||||
|
||||
# Process the forecast data into a simpler format
|
||||
forecast = {
|
||||
"location": data["properties"].get("relativeLocation", {}).get("properties", {}),
|
||||
"location": data["properties"]
|
||||
.get("relativeLocation", {})
|
||||
.get("properties", {}),
|
||||
"updated": data["properties"].get("updated", ""),
|
||||
"periods": []
|
||||
"periods": [],
|
||||
}
|
||||
|
||||
for period in periods:
|
||||
forecast["periods"].append({
|
||||
forecast["periods"].append(
|
||||
{
|
||||
"name": period.get("name", ""),
|
||||
"temperature": period.get("temperature", ""),
|
||||
"temperatureUnit": period.get("temperatureUnit", ""),
|
||||
"windSpeed": period.get("windSpeed", ""),
|
||||
"windDirection": period.get("windDirection", ""),
|
||||
"shortForecast": period.get("shortForecast", ""),
|
||||
"detailedForecast": period.get("detailedForecast", "")
|
||||
})
|
||||
"detailedForecast": period.get("detailedForecast", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return forecast
|
||||
else:
|
||||
@ -140,6 +149,7 @@ def get_forecast(grid_endpoint):
|
||||
print(f"Error in get_forecast: {e}")
|
||||
return {"error": f"Exception: {str(e)}"}
|
||||
|
||||
|
||||
# Example usage
|
||||
# def do_weather():
|
||||
# city = input("Enter city: ")
|
||||
@ -166,34 +176,31 @@ def get_forecast(grid_endpoint):
|
||||
|
||||
# %%
|
||||
|
||||
|
||||
def TickerValue(ticker_symbols):
|
||||
api_key = os.getenv("TWELVEDATA_API_KEY", "")
|
||||
if not api_key:
|
||||
return {"error": f"Error fetching data: No API key for TwelveData"}
|
||||
|
||||
results = []
|
||||
for ticker_symbol in ticker_symbols.split(','):
|
||||
for ticker_symbol in ticker_symbols.split(","):
|
||||
ticker_symbol = ticker_symbol.strip()
|
||||
if ticker_symbol == "":
|
||||
continue
|
||||
|
||||
url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
|
||||
url = (
|
||||
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
|
||||
)
|
||||
|
||||
response = requests.get(url)
|
||||
data = response.json()
|
||||
|
||||
if "price" in data:
|
||||
logging.info(f"TwelveData: {ticker_symbol} {data}")
|
||||
results.append({
|
||||
"symbol": ticker_symbol,
|
||||
"price": float(data["price"])
|
||||
})
|
||||
results.append({"symbol": ticker_symbol, "price": float(data["price"])})
|
||||
else:
|
||||
logging.error(f"TwelveData: {data}")
|
||||
results.append({
|
||||
"symbol": ticker_symbol,
|
||||
"price": "Unavailable"
|
||||
})
|
||||
results.append({"symbol": ticker_symbol, "price": "Unavailable"})
|
||||
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
@ -210,7 +217,7 @@ def yfTickerValue(ticker_symbols):
|
||||
dict: Current stock information including price
|
||||
"""
|
||||
results = []
|
||||
for ticker_symbol in ticker_symbols.split(','):
|
||||
for ticker_symbol in ticker_symbols.split(","):
|
||||
ticker_symbol = ticker_symbol.strip()
|
||||
if ticker_symbol == "":
|
||||
continue
|
||||
@ -226,19 +233,23 @@ def yfTickerValue(ticker_symbols):
|
||||
continue
|
||||
|
||||
# Get the latest closing price
|
||||
latest_price = ticker_data['Close'].iloc[-1]
|
||||
latest_price = ticker_data["Close"].iloc[-1]
|
||||
|
||||
# Get some additional info
|
||||
results.append({ 'symbol': ticker_symbol, 'price': latest_price })
|
||||
results.append({"symbol": ticker_symbol, "price": latest_price})
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logging.error(f"Error fetching data for {ticker_symbol}: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
|
||||
results.append(
|
||||
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
|
||||
)
|
||||
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
|
||||
# %%
|
||||
def DateTime(timezone="America/Los_Angeles"):
|
||||
"""
|
||||
@ -252,8 +263,8 @@ def DateTime(timezone="America/Los_Angeles"):
|
||||
str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM
|
||||
"""
|
||||
try:
|
||||
if timezone == 'system' or timezone == '' or not timezone:
|
||||
timezone = 'America/Los_Angeles'
|
||||
if timezone == "system" or timezone == "" or not timezone:
|
||||
timezone = "America/Los_Angeles"
|
||||
# Get current UTC time (timezone-aware)
|
||||
local_tz = pytz.timezone("America/Los_Angeles")
|
||||
local_now = datetime.now(tz=local_tz)
|
||||
@ -264,7 +275,8 @@ def DateTime(timezone="America/Los_Angeles"):
|
||||
|
||||
return target_time.isoformat()
|
||||
except Exception as e:
|
||||
return {'error': f"Invalid timezone {timezone}: {str(e)}"}
|
||||
return {"error": f"Invalid timezone {timezone}: {str(e)}"}
|
||||
|
||||
|
||||
async def AnalyzeSite(llm, model: str, url: str, question: str):
|
||||
"""
|
||||
@ -310,16 +322,18 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
|
||||
|
||||
# Generate summary using Ollama
|
||||
prompt = f"CONTENTS:\n\n{text}\n\n{question}"
|
||||
response = llm.generate(model=model,
|
||||
response = llm.generate(
|
||||
model=model,
|
||||
system="You are given the contents of {url}. Answer the question about the contents",
|
||||
prompt=prompt)
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
# logging.info(response["response"])
|
||||
|
||||
return {
|
||||
"source": "summarizer-llm",
|
||||
"content": response["response"],
|
||||
"metadata": DateTime()
|
||||
"metadata": DateTime(),
|
||||
}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
@ -331,7 +345,8 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
|
||||
|
||||
|
||||
# %%
|
||||
tools = [ {
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "TickerValue",
|
||||
@ -345,10 +360,11 @@ tools = [ {
|
||||
},
|
||||
},
|
||||
"required": ["ticker"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "AnalyzeSite",
|
||||
@ -366,27 +382,28 @@ tools = [ {
|
||||
},
|
||||
},
|
||||
"required": ["url", "question"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"returns": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "Identifier for the source LLM"
|
||||
"description": "Identifier for the source LLM",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The complete response from the second LLM"
|
||||
"description": "The complete response from the second LLM",
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"description": "Additional information about the response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"description": "Additional information about the response",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "DateTime",
|
||||
@ -396,13 +413,14 @@ tools = [ {
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'."
|
||||
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WeatherForecast",
|
||||
@ -413,27 +431,30 @@ tools = [ {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
|
||||
"minLength": 2
|
||||
"minLength": 2,
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "State to find the weather forecast (e.g., 'OR', 'WA').",
|
||||
"minLength": 2
|
||||
}
|
||||
"minLength": 2,
|
||||
},
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
}]
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def llm_tools(tools):
|
||||
return [tool for tool in tools if tool.get("enabled", False) == True]
|
||||
|
||||
|
||||
def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
return [{**tool, "enabled": True} for tool in tools]
|
||||
|
||||
tool_functions = [ 'DateTime', 'WeatherForecast', 'TickerValue', 'AnalyzeSite' ]
|
||||
__all__ = [ 'tools', 'llm_tools', 'enabled_tools', 'tool_functions' ]
|
||||
#__all__.extend(__tool_functions__) # type: ignore
|
||||
|
||||
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"]
|
||||
__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
# __all__.extend(__tool_functions__) # type: ignore
|
||||
|