RAG is being generated (again); however, the LLM is not using it.
parent 4a80004363
commit 77440a9d6b
@@ -15,7 +15,7 @@ import { BackstoryElementProps } from './BackstoryTab';
 import { connectionBase } from 'utils/Global';
 import { useAuth } from "hooks/AuthContext";
 import { StreamingResponse } from 'services/api-client';
-import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery } from 'types/types';
+import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery, ChatMessageUser } from 'types/types';
 import { PaginatedResponse } from 'types/conversion';

 import './Conversation.css';
@@ -259,7 +259,17 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
             { ...defaultMessage, content: 'Submitting request...' }
         );

-        controllerRef.current = apiClient.sendMessageStream(sessionId, query, {
+        const chatMessage: ChatMessageUser = {
+            sessionId: chatSession.id,
+            content: query.prompt,
+            tunables: query.tunables,
+            status: "done",
+            type: "user",
+            sender: "user",
+            timestamp: new Date()
+        };
+
+        controllerRef.current = apiClient.sendMessageStream(chatMessage, {
             onMessage: (msg: ChatMessageBase) => {
                 console.log("onMessage:", msg);
                 if (msg.type === "response") {
@@ -32,7 +32,7 @@ import { SetSnackType } from './Snack';
 import { CopyBubble } from './CopyBubble';
 import { Scrollable } from './Scrollable';
 import { BackstoryElementProps } from './BackstoryTab';
-import { ChatMessage, ChatSession, ChatMessageType } from 'types/types';
+import { ChatMessage, ChatSession, ChatMessageType, ChatMessageMetaData, ChromaDBGetResponse } from 'types/types';

 const getStyle = (theme: Theme, type: ChatMessageType): any => {
     const defaultRadius = '16px';
@@ -184,19 +184,19 @@ interface MessageProps extends BackstoryElementProps {
 };

 interface MessageMetaProps {
-    metadata: Record<string, any>,
+    metadata: ChatMessageMetaData,
     messageProps: MessageProps
 };

 const MessageMeta = (props: MessageMetaProps) => {
     const {
         /* MessageData */
-        rag,
-        tools,
-        eval_count,
-        eval_duration,
-        prompt_eval_count,
-        prompt_eval_duration,
+        ragResults = [],
+        tools = null,
+        evalCount = 0,
+        evalDuration = 0,
+        promptEvalCount = 0,
+        promptEvalDuration = 0,
     } = props.metadata || {};
     const message: any = props.messageProps.message;
@@ -206,7 +206,7 @@ const MessageMeta = (props: MessageMetaProps) => {

     return (<>
         {
-            prompt_eval_duration !== 0 && eval_duration !== 0 && <>
+            promptEvalDuration !== 0 && evalDuration !== 0 && <>
                 <TableContainer component={Card} className="PromptStats" sx={{ mb: 1 }}>
                     <Table aria-label="prompt stats" size="small">
                         <TableHead>
@@ -220,21 +220,21 @@ const MessageMeta = (props: MessageMetaProps) => {
                         <TableBody>
                             <TableRow key="prompt" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
                                 <TableCell component="th" scope="row">Prompt</TableCell>
-                                <TableCell align="right">{prompt_eval_count}</TableCell>
-                                <TableCell align="right">{Math.round(prompt_eval_duration / 10 ** 7) / 100}</TableCell>
-                                <TableCell align="right">{Math.round(prompt_eval_count * 10 ** 9 / prompt_eval_duration)}</TableCell>
+                                <TableCell align="right">{promptEvalCount}</TableCell>
+                                <TableCell align="right">{Math.round(promptEvalDuration / 10 ** 7) / 100}</TableCell>
+                                <TableCell align="right">{Math.round(promptEvalCount * 10 ** 9 / promptEvalDuration)}</TableCell>
                             </TableRow>
                             <TableRow key="response" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
                                 <TableCell component="th" scope="row">Response</TableCell>
-                                <TableCell align="right">{eval_count}</TableCell>
-                                <TableCell align="right">{Math.round(eval_duration / 10 ** 7) / 100}</TableCell>
-                                <TableCell align="right">{Math.round(eval_count * 10 ** 9 / eval_duration)}</TableCell>
+                                <TableCell align="right">{evalCount}</TableCell>
+                                <TableCell align="right">{Math.round(evalDuration / 10 ** 7) / 100}</TableCell>
+                                <TableCell align="right">{Math.round(evalCount * 10 ** 9 / evalDuration)}</TableCell>
                             </TableRow>
                             <TableRow key="total" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
                                 <TableCell component="th" scope="row">Total</TableCell>
-                                <TableCell align="right">{prompt_eval_count + eval_count}</TableCell>
-                                <TableCell align="right">{Math.round((prompt_eval_duration + eval_duration) / 10 ** 7) / 100}</TableCell>
-                                <TableCell align="right">{Math.round((prompt_eval_count + eval_count) * 10 ** 9 / (prompt_eval_duration + eval_duration))}</TableCell>
+                                <TableCell align="right">{promptEvalCount + evalCount}</TableCell>
+                                <TableCell align="right">{Math.round((promptEvalDuration + evalDuration) / 10 ** 7) / 100}</TableCell>
+                                <TableCell align="right">{Math.round((promptEvalCount + evalCount) * 10 ** 9 / (promptEvalDuration + evalDuration))}</TableCell>
                             </TableRow>
                         </TableBody>
                     </Table>
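A note on the arithmetic in the stats table above: the duration counters are in nanoseconds (Ollama-style), so `duration / 10**7 / 100` yields seconds to two decimal places and `count * 10**9 / duration` yields tokens per second. A quick Python sketch of the same conversion, with illustrative values (not from the commit):

```python
# Illustrative values only: 512 tokens generated over 4.2e9 ns.
eval_count = 512
eval_duration = 4_200_000_000  # nanoseconds

seconds = round(eval_duration / 10**7) / 100              # -> 4.2 (seconds, 2 dp)
tokens_per_sec = round(eval_count * 10**9 / eval_duration)  # -> 122 tokens/sec
print(seconds, tokens_per_sec)
```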
@@ -278,11 +278,11 @@ const MessageMeta = (props: MessageMetaProps) => {
                 </Accordion>
             }
             {
-                rag.map((collection: any) => (
+                ragResults.map((collection: ChromaDBGetResponse) => (
                     <Accordion key={collection.name}>
                         <AccordionSummary expandIcon={<ExpandMoreIcon />}>
                             <Box sx={{ fontSize: "0.8rem" }}>
-                                Top {collection.ids.length} RAG matches from {collection.size} entries using an embedding vector of {collection.query_embedding.length} dimensions
+                                Top {collection.ids?.length} RAG matches from {collection.size} entries using an embedding vector of {collection.queryEmbedding?.length} dimensions
                             </Box>
                         </AccordionSummary>
                         <AccordionDetails>
@@ -26,7 +26,7 @@ import {
     Chat as ChatIcon
 } from '@mui/icons-material';
 import { useAuth } from 'hooks/AuthContext';
-import { ChatMessageBase, ChatMessage, ChatSession, ChatStatusType } from 'types/types';
+import { ChatMessageBase, ChatMessage, ChatSession, ChatStatusType, ChatMessageType, ChatMessageUser } from 'types/types';
 import { ConversationHandle } from 'components/Conversation';
 import { BackstoryPageProps } from 'components/BackstoryTab';
 import { Message } from 'components/Message';
@@ -35,16 +35,23 @@ import { CandidateSessionsResponse } from 'services/api-client';
 import { CandidateInfo } from 'components/CandidateInfo';
 import { useNavigate } from 'react-router-dom';
 import { useSelectedCandidate } from 'hooks/GlobalContext';
+import PropagateLoader from 'react-spinners/PropagateLoader';

 const DRAWER_WIDTH = 300;
 const HANDLE_WIDTH = 48;

+const defaultMessage: ChatMessage = {
+    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+};
+
 const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
     const { apiClient } = useAuth();
     const { selectedCandidate } = useSelectedCandidate()
     const navigate = useNavigate();
     const theme = useTheme();
     const isMdUp = useMediaQuery(theme.breakpoints.up('md'));
+    const [processingMessage, setProcessingMessage] = useState<ChatMessage | null>(null);
+    const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);

     const {
         setSnack,
@@ -53,10 +60,10 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr

     const [sessions, setSessions] = useState<CandidateSessionsResponse | null>(null);
     const [chatSession, setChatSession] = useState<ChatSession | null>(null);
-    const [messages, setMessages] = useState([]);
-    const [newMessage, setNewMessage] = useState('');
-    const [loading, setLoading] = useState(false);
-    const [streaming, setStreaming] = useState(false);
+    const [messages, setMessages] = useState<ChatMessage[]>([]);
+    const [newMessage, setNewMessage] = useState<string>('');
+    const [loading, setLoading] = useState<boolean>(false);
+    const [streaming, setStreaming] = useState<boolean>(false);
     const messagesEndRef = useRef(null);

     // Drawer state - defaults to open on md+ screens or when no session is selected
@@ -88,7 +95,11 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr

         try {
             const result = await apiClient.getChatMessages(chatSession.id);
-            setMessages(result.data as any);
+            const chatMessages: ChatMessage[] = result.data;
+            setMessages(chatMessages);
+            setProcessingMessage(null);
+            setStreamingMessage(null);
+            console.log(`getChatMessages returned ${chatMessages.length} messages.`, chatMessages);
         } catch (error) {
             console.error('Failed to load messages:', error);
         }
@@ -106,6 +117,8 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
             );
             setChatSession(newSession);
             setMessages([]);
+            setProcessingMessage(null);
+            setStreamingMessage(null);
             await loadSessions(); // Refresh sessions list
         } catch (error) {
             console.error('Failed to create session:', error);
@@ -122,35 +135,56 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
         setNewMessage('');
         setStreaming(true);

+        const chatMessage: ChatMessageUser = {
+            sessionId: chatSession.id,
+            content: messageContent,
+            status: "done",
+            type: "user",
+            sender: "user",
+            timestamp: new Date()
+        };
+
+        setMessages(prev => {
+            const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
+            return [...filtered, chatMessage] as any;
+        });
+
         try {
-            await apiClient.sendMessageStream(
-                chatSession.id,
-                { prompt: messageContent }, {
+            await apiClient.sendMessageStream(chatMessage, {
                 onMessage: (msg: ChatMessage) => {
-                    console.log("onMessage:", msg);
+                    console.log(`onMessage: ${msg.type} ${msg.content}`, msg);
                     if (msg.type === "response") {
-                        setMessages(prev => {
-                            const filtered = prev.filter((m: any) => m.id !== msg.id);
-                            return [...filtered, msg].sort((a, b) =>
-                                new Date(a.timestamp).getTime() - new Date(b.timestamp).getTime()
-                            ) as any;
-                        });
+                        setMessages(prev => {
+                            const filtered = prev.filter((m: any) => m.id !== msg.id);
+                            return [...filtered, msg] as any;
+                        });
+                        setStreamingMessage(null);
+                        setProcessingMessage(null);
                     } else {
-                        console.log(msg);
+                        setProcessingMessage(msg);
                     }
                 },
                 onError: (error: string | ChatMessageBase) => {
                     console.log("onError:", error);
+                    // Type-guard to determine if this is a ChatMessageBase or a string
+                    if (typeof error === "object" && error !== null && "content" in error) {
+                        setProcessingMessage(error as ChatMessage);
+                    } else {
+                        setProcessingMessage({ ...defaultMessage, content: error as string });
+                    }
                     setStreaming(false);
                 },
                 onStreaming: (chunk: ChatMessageBase) => {
-                    console.log("onStreaming:", chunk);
+                    // console.log("onStreaming:", chunk);
+                    setStreamingMessage({ ...defaultMessage, ...chunk });
                 },
-                onStatusChange: (status: ChatStatusType) => {
-                    console.log("onStatusChange:", status);
+                onStatusChange: (status: string) => {
+                    console.log(`onStatusChange: ${status}`);
                 },
                 onComplete: () => {
                     console.log("onComplete");
+                    setStreamingMessage(null);
+                    setProcessingMessage(null);
                     setStreaming(false);
                 }
             });
@@ -321,27 +355,37 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
                 >
                     {chatSession?.id ? (
                         <>
                             {/* Messages Area */}
                             <Box sx={{ flexGrow: 1, overflow: 'auto', p: 2 }}>
-                                {messages.map((message: ChatMessageBase) => (
-                                    <Message key={message.id} {...{ chatSession, message, setSnack, submitQuery }} />
-                                ))}
-
-                                {streaming && (
-                                    <Box sx={{ display: 'flex', alignItems: 'center', mb: 2 }}>
-                                        <Avatar sx={{ mr: 1, bgcolor: 'primary.main' }}>
-                                            🤖
-                                        </Avatar>
-                                        <Card>
-                                            <CardContent sx={{ display: 'flex', alignItems: 'center', p: 2 }}>
-                                                <CircularProgress size={16} sx={{ mr: 1 }} />
-                                                <Typography variant="body2">AI is typing...</Typography>
-                                            </CardContent>
-                                        </Card>
-                                    </Box>
-                                )}
-
+                                {
+                                    messages.map((message: ChatMessageBase) => (
+                                        <Message key={message.id} {...{ chatSession, message, setSnack, submitQuery }} />
+                                    ))
+                                }
+                                {
+                                    processingMessage !== null &&
+                                    <Message {...{ chatSession, message: processingMessage, setSnack, submitQuery }} />
+                                }
+                                {
+                                    streamingMessage !== null &&
+                                    <Message {...{ chatSession, message: streamingMessage, setSnack, submitQuery }} />
+                                }
+                                {streaming && <Box sx={{
+                                    display: "flex",
+                                    flexDirection: "column",
+                                    alignItems: "center",
+                                    justifyContent: "center",
+                                    m: 1,
+                                }}>
+                                    <PropagateLoader
+                                        size="10px"
+                                        loading={streaming}
+                                        aria-label="Loading Spinner"
+                                        data-testid="loader"
+                                    />
+                                </Box>
+                                }
                                 <div ref={messagesEndRef} />

                             </Box>

                             <Divider />
@@ -354,7 +398,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
                             placeholder="Type your message about the candidate..."
                             value={newMessage}
                             onChange={(e) => setNewMessage(e.target.value)}
-                            onKeyPress={(e) => {
+                            onKeyDown={(e) => {
                                 if (e.key === 'Enter' && !e.shiftKey) {
                                     e.preventDefault();
                                     sendMessage();
@@ -384,10 +428,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
                                 flexDirection: 'column',
                                 gap: 2
                             }}
                         >
-                            <Typography variant="h1" sx={{ fontSize: 64, color: 'text.secondary' }}>
-                                🤖
-                            </Typography>
                             <Typography variant="h6" color="text.secondary">
                                 Select a session to start chatting
                             </Typography>
@@ -19,10 +19,11 @@ import { StyledMarkdown } from 'components/StyledMarkdown';
 import { Scrollable } from '../components/Scrollable';
 import { Pulse } from 'components/Pulse';
 import { StreamingResponse } from 'services/api-client';
-import { ChatContext, ChatMessage, ChatMessageBase, ChatSession, ChatQuery } from 'types/types';
+import { ChatContext, ChatMessage, ChatMessageUser, ChatMessageBase, ChatSession, ChatQuery } from 'types/types';
+import { useAuth } from 'hooks/AuthContext';

 const emptyUser: Candidate = {
     userType: "candidate",
     description: "[blank]",
     username: "[blank]",
     firstName: "[blank]",
@@ -101,7 +102,17 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
         setCanGenImage(false);
         setShouldGenerateProfile(false); // Reset the flag

-        const streamResponse = apiClient.sendMessageStream(sessionId, query, {
+        const chatMessage: ChatMessageUser = {
+            sessionId: chatSession.id,
+            content: query.prompt,
+            tunables: query.tunables,
+            status: "done",
+            type: "user",
+            sender: "user",
+            timestamp: new Date()
+        };
+
+        const streamResponse = apiClient.sendMessageStream(chatMessage, {
             onMessage: (chatMessage: ChatMessage) => {
                 console.log('Message:', chatMessage);
                 // Update UI with partial content
@@ -505,25 +505,11 @@ class ApiClient {
         return this.handleApiResponseWithConversion<Types.ChatSession>(response, 'ChatSession');
     }

-    /**
-     * Send message with standard response (non-streaming)
-     */
-    async sendMessage(sessionId: string, query: Types.ChatQuery): Promise<Types.ChatMessage> {
-        const response = await fetch(`${this.baseUrl}/chat/sessions/${sessionId}/messages`, {
-            method: 'POST',
-            headers: this.defaultHeaders,
-            body: JSON.stringify(formatApiRequest({query}))
-        });
-
-        return this.handleApiResponseWithConversion<Types.ChatMessage>(response, 'ChatMessage');
-    }
-
     /**
      * Send message with streaming response support and date conversion
      */
     sendMessageStream(
-        sessionId: string,
-        query: Types.ChatQuery,
+        chatMessage: Types.ChatMessageUser,
         options: StreamingOptions = {}
     ): StreamingResponse {
         const abortController = new AbortController();
@@ -533,14 +519,14 @@ class ApiClient {

         const promise = new Promise<Types.ChatMessage[]>(async (resolve, reject) => {
             try {
-                const response = await fetch(`${this.baseUrl}/chat/sessions/${sessionId}/messages/stream`, {
+                const response = await fetch(`${this.baseUrl}/chat/sessions/${chatMessage.sessionId}/messages/stream`, {
                     method: 'POST',
                     headers: {
                         ...this.defaultHeaders,
                         'Accept': 'text/event-stream',
                         'Cache-Control': 'no-cache'
                     },
-                    body: JSON.stringify(formatApiRequest({ query })),
+                    body: JSON.stringify(formatApiRequest({ chatMessage })),
                     signal
                 });

@@ -555,8 +541,8 @@ class ApiClient {

             const decoder = new TextDecoder();
             let buffer = '';
-            let chatMessage: Types.ChatMessage | null = null;
-            const chatMessageList: Types.ChatMessage[] = [];
+            let incomingMessage: Types.ChatMessage | null = null;
+            const incomingMessageList: Types.ChatMessage[] = [];

             try {
                 while (true) {
@@ -585,20 +571,20 @@ class ApiClient {
                             const convertedIncoming = convertChatMessageFromApi(incoming);

                             // Trigger callbacks based on status
-                            if (convertedIncoming.status !== chatMessage?.status) {
+                            if (convertedIncoming.status !== incomingMessage?.status) {
                                 options.onStatusChange?.(convertedIncoming.status);
                             }

                             // Handle different status types
                             switch (convertedIncoming.status) {
                                 case 'streaming':
-                                    if (chatMessage === null) {
-                                        chatMessage = {...convertedIncoming};
+                                    if (incomingMessage === null) {
+                                        incomingMessage = {...convertedIncoming};
                                     } else {
                                         // Can't do a simple += as typescript thinks .content might not be there
-                                        chatMessage.content = (chatMessage?.content || '') + convertedIncoming.content;
+                                        incomingMessage.content = (incomingMessage?.content || '') + convertedIncoming.content;
                                         // Update timestamp to latest
-                                        chatMessage.timestamp = convertedIncoming.timestamp;
+                                        incomingMessage.timestamp = convertedIncoming.timestamp;
                                     }
                                     options.onStreaming?.(convertedIncoming);
                                     break;
@@ -608,7 +594,7 @@ class ApiClient {
                                     break;

                                 default:
-                                    chatMessageList.push(convertedIncoming);
+                                    incomingMessageList.push(convertedIncoming);
                                     options.onMessage?.(convertedIncoming);
                                     break;
                             }
@@ -627,7 +613,7 @@ class ApiClient {
                 }

                 options.onComplete?.();
-                resolve(chatMessageList);
+                resolve(incomingMessageList);
             } catch (error) {
                 if (signal.aborted) {
                     options.onComplete?.();
@@ -647,24 +633,6 @@ class ApiClient {
         };
     }

-    /**
-     * Send message with automatic streaming detection
-     */
-    async sendMessageAuto(
-        sessionId: string,
-        query: Types.ChatQuery,
-        options?: StreamingOptions
-    ): Promise<Types.ChatMessage[]> {
-        // If streaming options are provided, use streaming
-        if (options && (options.onMessage || options.onStreaming || options.onStatusChange)) {
-            const streamResponse = this.sendMessageStream(sessionId, query, options);
-            return streamResponse.promise;
-        }
-
-        // Otherwise, use standard response
-        return [await this.sendMessage(sessionId, query)];
-    }
-
     /**
      * Get persisted chat messages for a session with date conversion
      */
@@ -1,6 +1,6 @@
 // Generated TypeScript types from Pydantic models
 // Source: src/backend/models.py
-// Generated on: 2025-05-30T18:07:14.923475
+// Generated on: 2025-05-31T18:20:52.253576
 // DO NOT EDIT MANUALLY - This file is auto-generated

 // ============================
@@ -176,7 +176,7 @@ export interface Candidate {
     lastLogin?: Date;
     profileImage?: string;
     status: "active" | "inactive" | "pending" | "banned";
-    userType?: "candidate";
+    userType: "candidate";
     username: string;
     description?: string;
     resume?: string;
@@ -192,6 +192,8 @@ export interface Candidate {
     certifications?: Array<Certification>;
     jobApplications?: Array<JobApplication>;
     hasProfile?: boolean;
+    rags?: Array<RagEntry>;
+    ragContentSize?: number;
     age?: number;
     gender?: "female" | "male";
     ethnicity?: string;
@@ -246,6 +248,7 @@ export interface ChatMessage {
     type: "error" | "generating" | "info" | "preparing" | "processing" | "response" | "searching" | "system" | "thinking" | "tooling" | "user";
     sender: "user" | "assistant" | "system";
     timestamp: Date;
     tunables?: Tunables;
     content?: string;
+    metadata?: ChatMessageMetaData;
 }
@@ -258,19 +261,20 @@ export interface ChatMessageBase {
     type: "error" | "generating" | "info" | "preparing" | "processing" | "response" | "searching" | "system" | "thinking" | "tooling" | "user";
     sender: "user" | "assistant" | "system";
     timestamp: Date;
     tunables?: Tunables;
     content?: string;
 }

 export interface ChatMessageMetaData {
-    model?: "qwen2.5" | "flux-schnell";
+    model: "qwen2.5";
     temperature?: number;
     maxTokens?: number;
     topP?: number;
     frequencyPenalty?: number;
     presencePenalty?: number;
     stopSequences?: Array<string>;
     tunables?: Tunables;
-    rag?: Array<ChromaDBGetResponse>;
+    ragResults?: Array<ChromaDBGetResponse>;
+    llmHistory?: Array<LLMMessage>;
     evalCount?: number;
     evalDuration?: number;
     promptEvalCount?: number;
@@ -284,10 +288,11 @@ export interface ChatMessageUser {
     id?: string;
     sessionId: string;
     senderId?: string;
-    status: "initializing" | "streaming" | "done" | "error";
-    type?: "error" | "generating" | "info" | "preparing" | "processing" | "response" | "searching" | "system" | "thinking" | "tooling" | "user";
-    sender: "user" | "assistant" | "system";
+    status: "done";
+    type: "user";
+    sender: "user";
     timestamp: Date;
     tunables?: Tunables;
     content?: string;
 }

@@ -386,7 +391,7 @@ export interface Employer {
     lastLogin?: Date;
     profileImage?: string;
     status: "active" | "inactive" | "pending" | "banned";
-    userType?: "employer";
+    userType: "employer";
     companyName: string;
     industry: string;
     description?: string;
@@ -507,6 +512,12 @@ export interface JobResponse {
     meta?: Record<string, any>;
 }

+export interface LLMMessage {
+    role?: string;
+    content?: string;
+    toolCalls?: Array<Record<string, any>>;
+}
+
 export interface Language {
     language: string;
     proficiency: "basic" | "conversational" | "fluent" | "native";
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import traceback
 from pydantic import BaseModel, Field # type: ignore
 from typing import (
     Literal,
@@ -18,18 +19,18 @@ import pathlib
 import inspect
 from prometheus_client import CollectorRegistry # type: ignore

-from database import RedisDatabase
 from . base import Agent
 from logger import logger
+from models import Candidate

 _agents: List[Agent] = []

-def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry, database: RedisDatabase, **kwargs) -> Agent:
+def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[Candidate]=None, **kwargs) -> Agent:
     """
     Get or create and append a new agent of the specified type, ensuring only one agent per type exists.

     Args:
-        agent_type: The type of agent to create (e.g., 'web', 'database').
+        agent_type: The type of agent to create (e.g., 'general', 'candidate_chat', 'image_generation').
         **kwargs: Additional fields required by the specific agent subclass.

     Returns:
@@ -47,7 +48,7 @@ def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry
     for agent_cls in Agent.__subclasses__():
         if agent_cls.model_fields["agent_type"].default == agent_type:
             # Create the agent instance with provided kwargs
-            agent = agent_cls(agent_type=agent_type, prometheus_collector=prometheus_collector, database=database, **kwargs)
+            agent = agent_cls(agent_type=agent_type, user=user, prometheus_collector=prometheus_collector, **kwargs)
             # if agent.agent_persist: # If an agent is not set to persist, do not add it to the list
             _agents.append(agent)
             return agent
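For reference, a minimal sketch of how the reworked factory is called after this change; the agent_type value comes from the docstring above, while the surrounding setup is assumed:

```python
# Minimal sketch, assuming this package is importable as `agents`.
from prometheus_client import CollectorRegistry
from agents import get_or_create_agent

registry = CollectorRegistry()
# `user` replaces the old `database` argument; pass a Candidate to scope the
# agent to that user's RAG data, or None for a user-less agent.
agent = get_or_create_agent(
    agent_type="candidate_chat",
    prometheus_collector=registry,
    user=None,
)
```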
@@ -90,8 +91,10 @@ for path in package_dir.glob("*.py"):
             logger.info(f"Adding agent: {name}")
             __all__.append(name) # type: ignore
     except ImportError as e:
+        logger.error(traceback.format_exc())
         logger.error(f"Error importing {full_module_name}: {e}")
         raise e
     except Exception as e:
+        logger.error(traceback.format_exc())
         logger.error(f"Error processing {full_module_name}: {e}")
         raise e
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import traceback
 from pydantic import BaseModel, Field, model_validator # type: ignore
 from typing import (
     Literal,
@@ -20,19 +21,16 @@ from abc import ABC
 import asyncio
 from datetime import datetime, UTC
 from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
+import numpy as np # type: ignore

-from models import ( ChatQuery, ChatMessage, ChatOptions, ChatMessageBase, ChatMessageUser, Tunables, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData)
+from models import ( LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageBase, ChatMessageUser, Tunables, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData, Candidate)
 from logger import logger
 import defines
 from .registry import agent_registry
 from metrics import Metrics
-from database import RedisDatabase # type: ignore
-import model_cast

-class LLMMessage(BaseModel):
-    role: str = Field(default="")
-    content: str = Field(default="")
-    tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
+from rag import ( ChromaDBGetResponse )

 class Agent(BaseModel, ABC):
     """
@@ -45,12 +43,8 @@ class Agent(BaseModel, ABC):
     # Agent management with pydantic
     agent_type: Literal["base"] = "base"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
     agent_persist: bool = True # Whether this agent will persist in the database

-    database: RedisDatabase = Field(
-        ...,
-        description="Database connection for this agent, used to store and retrieve data."
-    )
+    user: Optional[Candidate] = None
     prometheus_collector: CollectorRegistry = Field(..., description="Prometheus collector for this agent, used to track metrics.", exclude=True)

     # Tunables (sets default for new Messages attached to this agent)
|
||||
def get_agent_type(self):
|
||||
return self._agent_type
|
||||
|
||||
# async def prepare_message(self, message: ChatMessage) -> AsyncGenerator[ChatMessage, None]:
|
||||
# """
|
||||
# Prepare message with context information in message.preamble
|
||||
# """
|
||||
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
|
||||
# self.metrics.prepare_count.labels(agent=self.agent_type).inc()
|
||||
# with self.metrics.prepare_duration.labels(agent=self.agent_type).time():
|
||||
# if not self.context:
|
||||
# raise ValueError("Context is not set for this agent.")
|
||||
|
||||
# # Generate RAG content if enabled, based on the content
|
||||
# rag_context = ""
|
||||
# if message.tunables.enable_rag and message.prompt:
|
||||
# # Gather RAG results, yielding each result
|
||||
# # as it becomes available
|
||||
# for message in self.context.user.generate_rag_results(message):
|
||||
# logger.info(f"RAG: {message.status} - {message.content}")
|
||||
# if message.status == "error":
|
||||
# yield message
|
||||
# return
|
||||
# if message.status != "done":
|
||||
# yield message
|
||||
|
||||
# # for rag in message.metadata.rag:
|
||||
# # for doc in rag.documents:
|
||||
# # rag_context += f"{doc}\n"
|
||||
|
||||
# message.preamble = {}
|
||||
|
||||
# if rag_context:
|
||||
# message.preamble["context"] = f"The following is context information about {self.context.user.full_name}:\n{rag_context}"
|
||||
|
||||
# if message.tunables.enable_context and self.context.user_resume:
|
||||
# message.preamble["resume"] = self.context.user_resume
|
||||
|
||||
# message.system_prompt = self.system_prompt
|
||||
# message.status = ChatStatusType.DONE
|
||||
# yield message
|
||||
|
||||
# return
|
||||
|
||||
# async def process_tool_calls(
|
||||
# self,
|
||||
# llm: Any,
|
||||
@@ -341,46 +293,168 @@ class Agent(BaseModel, ABC):
             )
             self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)

+    async def generate_rag_results(
+        self,
+        chat_message: ChatMessage,
+        top_k: int=defines.default_rag_top_k,
+        threshold: float=defines.default_rag_threshold
+    ) -> AsyncGenerator[ChatMessage, None]:
+        """
+        Generate RAG results for the given query.
+
+        Args:
+            query: The query string to generate RAG results for.
+
+        Returns:
+            A list of dictionaries containing the RAG results.
+        """
+        rag_message = ChatMessage(
+            session_id=chat_message.session_id,
+            tunables=chat_message.tunables,
+            status=ChatStatusType.INITIALIZING,
+            type=ChatMessageType.PREPARING,
+            sender=ChatSenderType.ASSISTANT,
+            content="",
+            timestamp=datetime.now(UTC),
+            metadata=ChatMessageMetaData()
+        )
+
+        if not self.user:
+            rag_message.status = ChatStatusType.DONE
+            rag_message.content = "No user connected to this chat, so no RAG content."
+            yield rag_message
+            return
+
+        try:
+            entries: int = 0
+            user: Candidate = self.user
+            for rag in user.rags:
+                if not rag.enabled:
+                    continue
+                rag_message.type = ChatMessageType.SEARCHING
+                rag_message.status = ChatStatusType.INITIALIZING
+                rag_message.content = f"Checking RAG context {rag.name}..."
+                yield rag_message
+
+                chroma_results = user.file_watcher.find_similar(
+                    query=rag_message.content, top_k=top_k, threshold=threshold
+                )
+                if chroma_results:
+                    query_embedding = np.array(chroma_results["query_embedding"]).flatten()
+
+                    umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
+                    umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]
+
+                    rag_metadata = ChromaDBGetResponse(
+                        query=chat_message.content,
+                        query_embedding=query_embedding.tolist(),
+                        name=rag.name,
+                        ids=chroma_results.get("ids", []),
+                        embeddings=chroma_results.get("embeddings", []),
+                        documents=chroma_results.get("documents", []),
+                        metadatas=chroma_results.get("metadatas", []),
+                        umap_embedding_2d=umap_2d.tolist(),
+                        umap_embedding_3d=umap_3d.tolist(),
+                        size=user.file_watcher.collection.count()
+                    )
+
+                    rag_message.metadata.rag_results.append(rag_metadata)
+                    rag_message.content = f"Results from {rag.name} RAG: {len(chroma_results['documents'])} results."
+                    yield rag_message
+
+            rag_message.content = (
+                f"RAG context gathered from results from {entries} documents."
+            )
+            rag_message.status = ChatStatusType.DONE
+            yield rag_message
+            return
+        except Exception as e:
+            rag_message.status = ChatStatusType.ERROR
+            rag_message.content = f"Error generating RAG results: {str(e)}"
+            logger.error(traceback.format_exc())
+            logger.error(rag_message.content)
+            yield rag_message
+            return
+
     async def generate(
-        self, llm: Any, model: str, query: ChatQuery, user_message: ChatMessageUser, user_id: str, temperature=0.7
+        self, llm: Any, model: str, user_message: ChatMessageUser, user_id: str, temperature=0.7
     ) -> AsyncGenerator[ChatMessage | ChatMessageBase, None]:
         logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

         chat_message = ChatMessage(
             session_id=user_message.session_id,
-            tunables=query.tunables,
+            tunables=user_message.tunables,
             status=ChatStatusType.INITIALIZING,
             type=ChatMessageType.PREPARING,
             sender=ChatSenderType.ASSISTANT,
             content="",
             timestamp=datetime.now(UTC)
         )

+        chat_message.metadata = ChatMessageMetaData()
+        chat_message.metadata.options = ChatOptions(
+            seed=8911,
+            num_ctx=self.context_size,
+            temperature=temperature, # Higher temperature to encourage tool usage
+        )
+
+        # Create a dict for storing various timing stats
+        chat_message.metadata.timers = {}
+
         self.metrics.generate_count.labels(agent=self.agent_type).inc()
         with self.metrics.generate_duration.labels(agent=self.agent_type).time():

+            rag_message : Optional[ChatMessage] = None
+            async for rag_message in self.generate_rag_results(chat_message=user_message):
+                if rag_message.status == ChatStatusType.ERROR:
+                    chat_message.status = rag_message.status
+                    chat_message.content = rag_message.content
+                    yield chat_message
+                    return
+                yield rag_message
+
+            rag_context = ""
+            if rag_message:
+                rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results
+                chat_message.metadata.rag_results = rag_results
+                for chroma_results in rag_results:
+                    for index, metadata in enumerate(chroma_results.metadatas):
+                        content = "\n".join([
+                            line.strip()
+                            for line in chroma_results.documents[index].split("\n")
+                            if line
+                        ]).strip()
+                        rag_context += f"""
+Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
+Document reference: {chroma_results.ids[index]}
+Content: { content }
+
+"""
+
             # Create a pruned down message list based purely on the prompt and responses,
             # discarding the full preamble generated by prepare_message
             messages: List[LLMMessage] = [
                 LLMMessage(role="system", content=self.system_prompt)
             ]
             # Add the conversation history to the messages
             messages.extend([
                 LLMMessage(role=m.sender, content=m.content.strip())
                 for m in self.conversation
             ])
+            # Add the RAG context to the messages if available
+            if rag_context:
+                messages.append(
+                    LLMMessage(
+                        role="system",
+                        content=f"<|context|>\n{rag_context.strip()}\n</|context|>"
+                    )
+                )
             # Only the actual user query is provided with the full context message
             messages.append(
                 LLMMessage(role=user_message.sender, content=user_message.content.strip())
             )

             # message.messages = messages
             chat_message.metadata = ChatMessageMetaData()
             chat_message.metadata.options = ChatOptions(
                 seed=8911,
                 num_ctx=self.context_size,
                 temperature=temperature, # Higher temperature to encourage tool usage
             )

             # Create a dict for storing various timing stats
             chat_message.metadata.timers = {}
+            chat_message.metadata.llm_history = messages

             # use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
             # message.metadata.tools = {
@@ -497,7 +571,7 @@ class Agent(BaseModel, ABC):
                 model=model,
                 messages=messages,
                 options={
-                    **chat_message.metadata.model_dump(exclude_unset=True),
+                    **chat_message.metadata.options.model_dump(exclude_unset=True),
                 },
                 stream=True,
             ):
src/backend/entities/__init__.py (new file)
@@ -0,0 +1,3 @@
from .candidate_entity import CandidateEntity
from .entity_manager import entity_manager, get_candidate_entity
src/backend/entities/candidate_entity.py (new file)
@@ -0,0 +1,226 @@
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal

from typing_extensions import Annotated, Union
import numpy as np # type: ignore

from prometheus_client import CollectorRegistry, Counter # type: ignore
import traceback
import os
import json
import re
from pathlib import Path

from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
import defines
from logger import logger
import agents as agents
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageType, ChatMessageMetaData, ChatStatusType, Candidate, ChatContextType)
from llm_manager import llm_manager

class CandidateEntity(Candidate):
    model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc

    # Internal instance members
    CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
    CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
    CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
        default=None, exclude=True
    )

    def __init__(self, candidate=None):
        if candidate is not None:
            # Copy attributes from the candidate instance
            super().__init__(**vars(candidate))
        else:
            super().__init__()

    @classmethod
    def exists(cls, username: str):
        # Validate username format (only allow safe characters)
        if not re.match(r'^[a-zA-Z0-9_-]+$', username):
            return False # Invalid username characters

        # Check for minimum and maximum length
        if not (3 <= len(username) <= 32):
            return False # Invalid username length

        # Use Path for safe path handling and normalization
        user_dir = Path(defines.user_dir) / username
        user_info_path = user_dir / defines.user_info_file

        # Ensure the final path is actually within the intended parent directory
        # to help prevent directory traversal attacks
        try:
            if not user_dir.resolve().is_relative_to(Path(defines.user_dir).resolve()):
                return False # Path traversal attempt detected
        except (ValueError, RuntimeError): # Potential exceptions from resolve()
            return False

        # Check if file exists
        return user_info_path.is_file()

    def get_or_create_agent(self, agent_type: ChatContextType, **kwargs) -> agents.Agent:
        """
        Get or create an agent of the specified type for this candidate.

        Args:
            agent_type: The type of agent to create (default is 'candidate_chat').
            **kwargs: Additional fields required by the specific agent subclass.

        Returns:
            The created agent instance.
        """
        return agents.get_or_create_agent(
            agent_type=agent_type,
            user=self,
            prometheus_collector=self.prometheus_collector,
            **kwargs)

    # Wrapper properties that map into file_watcher
    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        if not self.CandidateEntity__file_watcher:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher.umap_collection

    # Fields managed by initialize()
    CandidateEntity__initialized: bool = Field(default=False, exclude=True)
    @property
    def file_watcher(self) -> ChromaDBFileWatcher:
        if not self.CandidateEntity__file_watcher:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher

    @property
    def prometheus_collector(self) -> CollectorRegistry:
        if not self.CandidateEntity__prometheus_collector:
            raise ValueError("initialize() has not been called with a prometheus_collector.")
        return self.CandidateEntity__prometheus_collector

    @property
    def observer(self) -> Any:
        if not self.CandidateEntity__observer:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__observer

    @classmethod
    def sanitize(cls, user: Dict[str, Any]):
        sanitized : Dict[str, Any] = {}
        sanitized["username"] = user.get("username", "default")
        sanitized["first_name"] = user.get("first_name", sanitized["username"])
        sanitized["last_name"] = user.get("last_name", "")
        sanitized["title"] = user.get("title", "")
        sanitized["phone"] = user.get("phone", "")
        sanitized["location"] = user.get("location", "")
        sanitized["email"] = user.get("email", "")
        sanitized["full_name"] = user.get("full_name", f"{sanitized['first_name']} {sanitized['last_name']}")
        sanitized["description"] = user.get("description", "")
        profile_image = os.path.join(defines.user_dir, sanitized["username"], "profile.png")
        sanitized["has_profile"] = os.path.exists(profile_image)
        contact_info = user.get("contact_info", {})
        sanitized["contact_info"] = {}
        for key in contact_info:
            if not isinstance(contact_info[key], (str, int, float, complex)):
                continue
            sanitized["contact_info"][key] = contact_info[key]
        questions = user.get("questions", [ f"Tell me about {sanitized['first_name']}.", f"What are {sanitized['first_name']}'s professional strengths?"])
        sanitized["user_questions"] = []
        for question in questions:
            if type(question) == str:
                sanitized["user_questions"].append({"question": question})
            else:
                try:
                    tmp = CandidateQuestion.model_validate(question)
                    sanitized["user_questions"].append({"question": tmp.question})
                except Exception as e:
                    continue
        return sanitized

    @classmethod
    def get_users(cls):
        # Initialize an empty list to store parsed JSON data
        user_data = []

        # Define the users directory path
        users_dir = os.path.join(defines.user_dir)

        # Check if the users directory exists
        if not os.path.exists(users_dir):
            return user_data

        # Iterate through all items in the users directory
        for item in os.listdir(users_dir):
            # Construct the full path to the item
            item_path = os.path.join(users_dir, item)

            # Check if the item is a directory
            if os.path.isdir(item_path):
                # Construct the path to info.json
                info_path = os.path.join(item_path, "info.json")

                # Check if info.json exists
                if os.path.exists(info_path):
                    try:
                        # Read and parse the JSON file
                        with open(info_path, 'r') as file:
                            data = json.load(file)
                            data["username"] = item
                            profile_image = os.path.join(defines.user_dir, item, "profile.png")
                            data["has_profile"] = os.path.exists(profile_image)
                            user_data.append(data)
                    except json.JSONDecodeError as e:
                        # Skip files that aren't valid JSON
                        logger.info(f"Invalid JSON for {info_path}: {str(e)}")
                        continue
                    except Exception as e:
                        # Skip files that can't be read
                        logger.info(f"Exception processing {info_path}: {str(e)}")
                        continue

        return user_data

    async def initialize(self, prometheus_collector: CollectorRegistry):
        if self.CandidateEntity__initialized:
            # Initialization can only be attempted once; if there are multiple attempts, it means
            # a subsystem is failing or there is a logic bug in the code.
            #
            # NOTE: It is intentional that self.CandidateEntity__initialize = True regardless of whether it
            # succeeded. This prevents server loops on failure
            raise ValueError("initialize can only be attempted once")
        self.CandidateEntity__initialized = True

        if not self.username:
            raise ValueError("username can not be empty")

        user_dir = os.path.join(defines.user_dir, self.username)
        vector_db_dir=os.path.join(user_dir, defines.persist_directory)
        rag_content_dir=os.path.join(user_dir, defines.rag_content_dir)

        logger.info(f"CandidateEntity(username={self.username}, user_dir={user_dir} persist_directory={vector_db_dir}, rag_content_dir={rag_content_dir}")

        os.makedirs(vector_db_dir, exist_ok=True)
        os.makedirs(rag_content_dir, exist_ok=True)

        if prometheus_collector:
            self.CandidateEntity__prometheus_collector = prometheus_collector

        self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
            llm=llm_manager.get_llm(),
            collection_name=self.username,
            persist_directory=vector_db_dir,
            watch_directory=rag_content_dir,
            recreate=False, # Don't recreate if exists
        )
        has_username_rag = any(item["name"] == self.username for item in self.rags)
        if not has_username_rag:
            self.rags.append(RagEntry(
                name=self.username,
                description=f"Expert data about {self.full_name}.",
            ))
        self.rag_content_size = self.file_watcher.collection.count()

CandidateEntity.model_rebuild()
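A sketch of the construction sequence this class expects; it mirrors what EntityManager.get_entity does in the next file, with the Candidate assumed to be loaded elsewhere:

```python
from prometheus_client import CollectorRegistry
from entities import CandidateEntity
from models import Candidate

async def build_entity(candidate: Candidate) -> CandidateEntity:
    # initialize() is one-shot by design: a second call raises ValueError.
    entity = CandidateEntity(candidate=candidate)
    await entity.initialize(prometheus_collector=CollectorRegistry())
    return entity
```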
src/backend/entities/entity_manager.py (new file)
@@ -0,0 +1,122 @@
import asyncio
import weakref
from datetime import datetime, timedelta
from typing import Dict, Optional, Any
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field # type: ignore

from models import ( Candidate )
from .candidate_entity import CandidateEntity
from prometheus_client import CollectorRegistry # type: ignore

class EntityManager:
    """Manages lifecycle of CandidateEntity instances"""

    def __init__(self, default_ttl_minutes: int = 30):
        self._entities: Dict[str, CandidateEntity] = {}
        self._weak_refs: Dict[str, weakref.ReferenceType] = {}
        self._ttl_minutes = default_ttl_minutes
        self._cleanup_task: Optional[asyncio.Task] = None
        self._prometheus_collector: Optional[CollectorRegistry] = None

    async def start_cleanup_task(self):
        """Start background cleanup task"""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    async def stop_cleanup_task(self):
        """Stop background cleanup task"""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    def initialize(self, prometheus_collector: CollectorRegistry):
        """Initialize the EntityManager with Prometheus collector"""
        self._prometheus_collector = prometheus_collector

    async def get_entity(self, candidate: Candidate) -> CandidateEntity:
        """Get or create CandidateEntity with proper reference tracking"""

        # Check if entity exists and is still valid
        if candidate.id in self._entities:
            entity = self._entities[candidate.id]
            entity._last_accessed = datetime.now()
            entity._reference_count += 1
            return entity

        entity = CandidateEntity(candidate=candidate)
        await entity.initialize(prometheus_collector=self._prometheus_collector)

        # Store with reference tracking
        self._entities[candidate.id] = entity
        self._weak_refs[candidate.id] = weakref.ref(entity, self._on_entity_deleted(candidate.id))

        entity._reference_count = 1
        entity._last_accessed = datetime.now()

        return entity

    def _on_entity_deleted(self, user_id: str):
        """Callback when entity is garbage collected"""
        def cleanup_callback(weak_ref):
            self._entities.pop(user_id, None)
            self._weak_refs.pop(user_id, None)
            print(f"Entity {user_id} garbage collected")
        return cleanup_callback

    async def release_entity(self, user_id: str):
        """Explicitly release reference to entity"""
        if user_id in self._entities:
            entity = self._entities[user_id]
            entity._reference_count = max(0, entity._reference_count - 1)
            entity._last_accessed = datetime.now()

    async def _periodic_cleanup(self):
        """Background task to clean up expired entities"""
        while True:
            try:
                await asyncio.sleep(60) # Check every minute
                await self._cleanup_expired_entities()
            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"Error in cleanup task: {e}")

    async def _cleanup_expired_entities(self):
        """Remove entities that have expired based on TTL and reference count"""
        current_time = datetime.now()
        expired_entities = []

        for user_id, entity in list(self._entities.items()):
            time_since_access = current_time - entity._last_accessed

            # Remove if TTL exceeded and no active references
            if (time_since_access > timedelta(minutes=self._ttl_minutes)
                    and entity._reference_count == 0):
                expired_entities.append(user_id)

        for user_id in expired_entities:
            entity = self._entities.pop(user_id, None)
            self._weak_refs.pop(user_id, None)
            if entity:
                await entity.cleanup()
                print(f"Cleaned up expired entity {user_id}")


# Global entity manager instance
entity_manager = EntityManager(default_ttl_minutes=30)

@asynccontextmanager
async def get_candidate_entity(candidate: Candidate):
    """Context manager for safe entity access with automatic reference management"""
    if not entity_manager._prometheus_collector:
        raise ValueError("EntityManager has not been initialized with a Prometheus collector.")
    entity = await entity_manager.get_entity(candidate=candidate)
    try:
        yield entity
    finally:
        await entity_manager.release_entity(candidate.id)
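A usage sketch for the context manager above; candidate loading and startup wiring are assumed, and the agent_type value comes from this diff:

```python
from prometheus_client import CollectorRegistry
from entities import entity_manager, get_candidate_entity

# Once, at application startup:
entity_manager.initialize(prometheus_collector=CollectorRegistry())

async def handle_chat(candidate, prompt: str):
    # The reference count is incremented on entry and released on exit,
    # which keeps the entity alive across the TTL cleanup task.
    async with get_candidate_entity(candidate=candidate) as entity:
        agent = entity.get_or_create_agent(agent_type="candidate_chat")
        # ... drive agent.generate(...) with `prompt` here
```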
@@ -2,7 +2,7 @@
 """
 Enhanced Type Generator - Generate TypeScript types from Pydantic models
 Now with command line parameters, pre-test validation, TypeScript compilation,
-and automatic date field conversion functions
+automatic date field conversion functions, and proper enum default handling
 """

 import sys
@@ -138,8 +138,55 @@ def is_date_type(python_type: Any) -> bool:

     return False

-def python_type_to_typescript(python_type: Any, debug: bool = False) -> str:
-    """Convert a Python type to TypeScript type string"""
+def get_default_enum_value(field_info: Any, debug: bool = False) -> Optional[Any]:
+    """Extract the specific enum value from a field's default, if it exists"""
+    if not hasattr(field_info, 'default'):
+        return None
+
+    default_val = field_info.default
+
+    if debug:
+        print(f"  🔍 Checking default value: {repr(default_val)} (type: {type(default_val)})")
+
+    # Check for different types of "no default" markers
+    if default_val is ... or default_val is None:
+        if debug:
+            print(f"  └─ Default is undefined marker")
+        return None
+
+    # Check for Pydantic's internal "PydanticUndefined" or similar markers
+    default_str = str(default_val)
+    default_type_str = str(type(default_val))
+
+    undefined_patterns = [
+        'PydanticUndefined',
+        'Undefined',
+        '_Unset',
+        'UNSET',
+        'NotSet',
+        '_MISSING'
+    ]
+
+    is_undefined_marker = any(pattern in default_str or pattern in default_type_str
+                              for pattern in undefined_patterns)
+
+    if is_undefined_marker:
+        if debug:
+            print(f"  └─ Default is undefined marker pattern")
+        return None
+
+    # Check if it's an enum instance
+    if isinstance(default_val, Enum):
+        if debug:
+            print(f"  └─ Default is enum instance: {default_val.value}")
+        return default_val
+
+    if debug:
+        print(f"  └─ Default is not an enum instance")
+    return None
+
+def python_type_to_typescript(python_type: Any, field_info: Any = None, debug: bool = False) -> str:
+    """Convert a Python type to TypeScript type string, considering field defaults"""

     if debug:
         print(f"  🔍 Converting type: {python_type} (type: {type(python_type)})")
@@ -151,6 +198,14 @@ def python_type_to_typescript(python_type: Any, field_info: Any = None, debug: bool = False) -> str:
     if debug and original_type != python_type:
         print(f"  🔄 Unwrapped: {original_type} -> {python_type}")

+    # Check if this field has a specific enum default value
+    if field_info:
+        default_enum = get_default_enum_value(field_info, debug)
+        if default_enum is not None:
+            if debug:
+                print(f"  🎯 Field has specific enum default: {default_enum.value}")
+            return f'"{default_enum.value}"'
+
     # Handle None/null
     if python_type is type(None):
         return "null"
@@ -180,22 +235,22 @@ def python_type_to_typescript(python_type: Any, field_info: Any = None, debug: bool = False) -> str:
         # Handle Optional (Union[T, None])
         if len(args) == 2 and type(None) in args:
             non_none_type = next(arg for arg in args if arg is not type(None))
-            return python_type_to_typescript(non_none_type, debug)
+            return python_type_to_typescript(non_none_type, field_info, debug)

         # Handle other unions
-        union_types = [python_type_to_typescript(arg, debug) for arg in args if arg is not type(None)]
+        union_types = [python_type_to_typescript(arg, None, debug) for arg in args if arg is not type(None)]
         return " | ".join(union_types)

     elif origin is list or origin is List:
         if args:
-            item_type = python_type_to_typescript(args[0], debug)
+            item_type = python_type_to_typescript(args[0], None, debug)
             return f"Array<{item_type}>"
         return "Array<any>"

     elif origin is dict or origin is Dict:
         if len(args) == 2:
-            key_type = python_type_to_typescript(args[0], debug)
-            value_type = python_type_to_typescript(args[1], debug)
+            key_type = python_type_to_typescript(args[0], None, debug)
+            value_type = python_type_to_typescript(args[1], None, debug)
             return f"Record<{key_type}, {value_type}>"
         return "Record<string, any>"

@ -319,6 +374,13 @@ def is_field_optional(field_info: Any, field_type: Any, debug: bool = False) ->
|
||||
if debug:
|
||||
print(f" └─ RESULT: Required (default is undefined marker)")
|
||||
return False
|
||||
|
||||
# Special case: if field has a specific default value (like enum), it's required
|
||||
# because it will always have a value, just not optional for the consumer
|
||||
if isinstance(default_val, Enum):
|
||||
if debug:
|
||||
print(f" └─ RESULT: Required (has specific enum default: {default_val.value})")
|
||||
return False
|
||||
|
||||
# Any other actual default value makes it optional
|
||||
if debug:
|
||||
@ -398,7 +460,8 @@ def process_pydantic_model(model_class, debug: bool = False) -> Dict[str, Any]:
|
||||
elif debug and ('date' in str(field_type).lower() or 'time' in str(field_type).lower()):
|
||||
print(f" ⚠️ Field {ts_name} contains 'date'/'time' but not detected as date type: {field_type}")
|
||||
|
||||
ts_type = python_type_to_typescript(field_type, debug)
|
||||
# Pass field_info to the type converter for default enum handling
|
||||
ts_type = python_type_to_typescript(field_type, field_info, debug)
|
||||
|
||||
# Check if optional
|
||||
is_optional = is_field_optional(field_info, field_type, debug)
|
||||
@ -449,30 +512,16 @@ def process_pydantic_model(model_class, debug: bool = False) -> Dict[str, Any]:
|
||||
elif debug and ('date' in str(field_type).lower() or 'time' in str(field_type).lower()):
|
||||
print(f" ⚠️ Field {ts_name} contains 'date'/'time' but not detected as date type: {field_type}")
|
||||
|
||||
ts_type = python_type_to_typescript(field_type, debug)
|
||||
# Pass field_info to the type converter for default enum handling
|
||||
ts_type = python_type_to_typescript(field_type, field_info, debug)
|
||||
|
||||
# For Pydantic v1, check required and default
|
||||
is_optional = is_field_optional(field_info, field_type)
|
||||
is_optional = is_field_optional(field_info, field_type, debug)
|
||||
|
||||
if debug:
|
||||
print(f" TS name: {ts_name}")
|
||||
print(f" TS type: {ts_type}")
|
||||
print(f" Optional: {is_optional}")
|
||||
|
||||
# Debug the optional logic
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
is_union_with_none = origin is Union and type(None) in args
|
||||
has_default = hasattr(field_info, 'default')
|
||||
has_default_factory = hasattr(field_info, 'default_factory') and field_info.default_factory is not None
|
||||
|
||||
print(f" └─ Type is Optional: {is_union_with_none}")
|
||||
if has_default:
|
||||
default_val = field_info.default
|
||||
print(f" └─ Has default: {default_val} (is ...? {default_val is ...})")
|
||||
else:
|
||||
print(f" └─ No default attribute")
|
||||
print(f" └─ Has default factory: {has_default_factory}")
|
||||
print()
|
||||
|
||||
properties.append({
|
||||
@ -730,7 +779,7 @@ def compile_typescript(ts_file: str) -> bool:
|
||||
def main():
|
||||
"""Main function with command line argument parsing"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Generate TypeScript types from Pydantic models with date conversion functions',
|
||||
description='Generate TypeScript types from Pydantic models with date conversion functions and proper enum default handling',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
@ -744,6 +793,9 @@ Examples:
|
||||
Generated conversion functions can be used like:
|
||||
const candidate = convertCandidateFromApi(apiResponse);
|
||||
const jobs = convertArrayFromApi<Job>(apiResponse, 'Job');
|
||||
|
||||
Enum defaults are now properly handled:
|
||||
status: ChatStatusType = ChatStatusType.DONE -> status: "done"
|
||||
"""
|
||||
)
|
||||
|
||||
@ -780,12 +832,12 @@ Generated conversion functions can be used like:
|
||||
parser.add_argument(
|
||||
'--version', '-v',
|
||||
action='version',
|
||||
version='TypeScript Generator 3.0 (with Date Conversion)'
|
||||
version='TypeScript Generator 3.1 (with Enum Default Handling)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🚀 Enhanced TypeScript Type Generator with Date Conversion")
|
||||
print("🚀 Enhanced TypeScript Type Generator with Enum Default Handling")
|
||||
print("=" * 60)
|
||||
print(f"📁 Source file: {args.source}")
|
||||
print(f"📁 Output file: {args.output}")
|
||||
@ -830,25 +882,30 @@ Generated conversion functions can be used like:
|
||||
|
||||
# Count conversion functions and provide detailed feedback
|
||||
conversion_count = ts_content.count('export function convert') - ts_content.count('convertFromApi') - ts_content.count('convertArrayFromApi')
|
||||
enum_specific_count = ts_content.count(': "') - ts_content.count('export type')
|
||||
|
||||
if conversion_count > 0:
|
||||
print(f"🗓️ Generated {conversion_count} date conversion functions")
|
||||
if args.debug:
|
||||
# Show which models have date conversion
|
||||
models_with_dates = []
|
||||
for line in ts_content.split('\n'):
|
||||
if line.startswith('export function convert') and 'FromApi' in line and 'convertFromApi' not in line:
|
||||
model_name = line.split('convert')[1].split('FromApi')[0]
|
||||
models_with_dates.append(model_name)
|
||||
if models_with_dates:
|
||||
print(f" Models with date conversion: {', '.join(models_with_dates)}")
|
||||
if enum_specific_count > 0:
|
||||
print(f"🎯 Generated {enum_specific_count} specific enum default types")
|
||||
|
||||
if args.debug:
|
||||
# Show which models have date conversion
|
||||
models_with_dates = []
|
||||
for line in ts_content.split('\n'):
|
||||
if line.startswith('export function convert') and 'FromApi' in line and 'convertFromApi' not in line:
|
||||
model_name = line.split('convert')[1].split('FromApi')[0]
|
||||
models_with_dates.append(model_name)
|
||||
if models_with_dates:
|
||||
print(f" Models with date conversion: {', '.join(models_with_dates)}")
|
||||
|
||||
# Provide troubleshooting info if debug mode
|
||||
if args.debug:
|
||||
print(f"\n🐛 Debug mode was enabled. If you see incorrect date conversions:")
|
||||
print(f" 1. Check the debug output above for '📅 Date type check' lines")
|
||||
print(f" 2. Look for '⚠️' warnings about false positives")
|
||||
print(f" 3. Verify your Pydantic model field types are correct")
|
||||
print(f" 4. Re-run with --debug to see detailed type analysis")
|
||||
print(f"\n🐛 Debug mode was enabled. If you see incorrect type conversions:")
|
||||
print(f" 1. Check the debug output above for '🎯 Field has specific enum default' lines")
|
||||
print(f" 2. Look for '📅 Date type check' lines for date handling")
|
||||
print(f" 3. Look for '⚠️' warnings about fallback types")
|
||||
print(f" 4. Verify your Pydantic model field types and defaults are correct")
|
||||
|
||||
# Step 5: Compile TypeScript (unless skipped)
|
||||
if not args.skip_compile:
|
||||
@ -866,6 +923,8 @@ Generated conversion functions can be used like:
|
||||
print(f"✅ File size: {file_size} characters")
|
||||
if conversion_count > 0:
|
||||
print(f"✅ Date conversion functions: {conversion_count}")
|
||||
if enum_specific_count > 0:
|
||||
print(f"✅ Specific enum default types: {enum_specific_count}")
|
||||
if not args.skip_test:
|
||||
print("✅ Model validation passed")
|
||||
if not args.skip_compile:
|
||||
|
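For reference, a minimal sketch of the mapping this generator now performs. Only `ChatStatusType` and the `status: "done"` output are taken from the epilog above; the other member and model names are illustrative:

```python
from enum import Enum
from pydantic import BaseModel

class ChatStatusType(str, Enum):
    DONE = "done"
    STREAMING = "streaming"  # assumed member, for illustration only

class ExampleMessage(BaseModel):
    status: ChatStatusType = ChatStatusType.DONE  # concrete enum default

# Per get_default_enum_value/python_type_to_typescript above, the generated
# TypeScript narrows this field to its default literal and keeps it required:
#
#   status: "done"
```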
src/backend/helpers/__init__.py (new file, 0 lines)

src/backend/helpers/check_serializable.py (new file, 60 lines)
@ -0,0 +1,60 @@
from pydantic import BaseModel, Field  # type: ignore
import json
from typing import Any, List, Optional, Set

def check_serializable(obj: Any, path: str = "", errors: Optional[List[str]] = None, visited: Optional[Set[int]] = None) -> List[str]:
    """
    Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion.
    Skips fields in Pydantic models marked with Field(..., exclude=True).

    Args:
        obj: The object to inspect (Pydantic model, dict, list, or other).
        path: The current field path (e.g., 'field1.nested_field').
        errors: List to collect error messages.
        visited: Set of object IDs to track visited objects and prevent infinite recursion.

    Returns:
        List of error messages for non-serializable fields.
    """
    # Create fresh containers per top-level call; mutable default arguments
    # would be shared across calls and accumulate stale state.
    if errors is None:
        errors = []
    if visited is None:
        visited = set()

    # Check for circular reference by object ID
    obj_id = id(obj)
    if obj_id in visited:
        errors.append(f"Field '{path}' contains a circular reference, skipping further inspection")
        return errors

    # Add current object to visited set
    visited.add(obj_id)

    try:
        # Handle Pydantic models
        if isinstance(obj, BaseModel):
            for field_name, field_info in obj.model_fields.items():
                # Skip fields marked with exclude=True
                if field_info.exclude:
                    continue
                value = getattr(obj, field_name)
                new_path = f"{path}.{field_name}" if path else field_name
                check_serializable(value, new_path, errors, visited)

        # Handle dictionaries
        elif isinstance(obj, dict):
            for key, value in obj.items():
                new_path = f"{path}[{key}]" if path else str(key)
                check_serializable(value, new_path, errors, visited)

        # Handle lists, tuples, or other iterables
        elif isinstance(obj, (list, tuple)):
            for i, value in enumerate(obj):
                new_path = f"{path}[{i}]" if path else str(i)
                check_serializable(value, new_path, errors, visited)

        # Handle other types (check for JSON serializability)
        else:
            try:
                json.dumps(obj)
            except (TypeError, OverflowError, ValueError) as e:
                errors.append(f"Field '{path}' contains non-serializable type: {type(obj)} ({str(e)})")

    finally:
        # Remove the current object from visited to allow processing in other branches
        visited.discard(obj_id)

    return errors
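A quick usage sketch for the helper above (the `Example` model and values are illustrative; with the None-default guard, fresh containers are created per call):

```python
from typing import Any
from pydantic import BaseModel

class Example(BaseModel):
    name: str = "demo"
    handler: Any = None

problems = check_serializable(Example(handler=lambda x: x))
for msg in problems:
    print(msg)
# -> Field 'handler' contains non-serializable type: <class 'function'> (...)
```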
@ -21,7 +21,7 @@ import uuid
import logging
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional
from pydantic import BaseModel, EmailStr, validator # type: ignore
from pydantic import BaseModel, EmailStr, field_validator # type: ignore
# Prometheus
from prometheus_client import Summary # type: ignore
from prometheus_fastapi_instrumentator import Instrumentator # type: ignore
@ -38,11 +38,11 @@ from auth_utils import (
)
import model_cast
import defines
import agents
from logger import logger
from database import RedisDatabase, redis_manager, DatabaseManager
from metrics import Metrics
from llm_manager import llm_manager
import entities

# =============================
# Import Pydantic models
@ -149,11 +149,11 @@ class LoginRequest(BaseModel):
    login: str  # Can be email or username
    password: str

    @validator('login')
    @field_validator('login')
    def sanitize_login(cls, v):
        return sanitize_login_input(v)

    @validator('password')
    @field_validator('password')
    def validate_password_not_empty(cls, v):
        if not v or not v.strip():
            raise ValueError('Password cannot be empty')
@ -168,13 +168,13 @@ class CreateCandidateRequest(BaseModel):
    # Add other required candidate fields as needed
    phone: Optional[str] = None

    @validator('username')
    @field_validator('username')
    def validate_username(cls, v):
        if not v or len(v.strip()) < 3:
            raise ValueError('Username must be at least 3 characters long')
        return v.strip().lower()

    @validator('password')
    @field_validator('password')
    def validate_password_strength(cls, v):
        is_valid, issues = validate_password_strength(v)
        if not is_valid:
@ -194,13 +194,13 @@ class CreateEmployerRequest(BaseModel):
    websiteUrl: Optional[str] = None
    phone: Optional[str] = None

    @validator('username')
    @field_validator('username')
    def validate_username(cls, v):
        if not v or len(v.strip()) < 3:
            raise ValueError('Username must be at least 3 characters long')
        return v.strip().lower()

    @validator('password')
    @field_validator('password')
    def validate_password_strength(cls, v):
        is_valid, issues = validate_password_strength(v)
        if not is_valid:
@ -1004,7 +1004,7 @@ class PasswordResetConfirm(BaseModel):
    token: str
    new_password: str

    @validator('new_password')
    @field_validator('new_password')
    def validate_password_strength(cls, v):
        is_valid, issues = validate_password_strength(v)
        if not is_valid:
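The `@validator` to `@field_validator` swaps above track the Pydantic v2 rename of the decorator. A minimal sketch of the v2 form (the model is illustrative; Pydantic's docs also stack `@classmethod` under the decorator):

```python
from pydantic import BaseModel, field_validator

class LoginForm(BaseModel):
    login: str

    @field_validator('login')
    @classmethod
    def normalize_login(cls, v: str) -> str:
        return v.strip().lower()

assert LoginForm(login='  User@Example.com ').login == 'user@example.com'
```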
@ -1401,7 +1401,6 @@ async def create_chat_session(

@api_router.post("/chat/sessions/{session_id}/messages/stream")
async def post_chat_session_message_stream(
    session_id: str = Path(...),
    data: Dict[str, Any] = Body(...),
    current_user = Depends(get_current_user),
    database: RedisDatabase = Depends(get_database),
@ -1409,113 +1408,112 @@ async def post_chat_session_message_stream(
):
    """Post a message to a chat session and stream the response with persistence"""
    try:
        chat_session_data = await database.get_chat_session(session_id)
        user_message_data = data.get("chatMessage")
        if not user_message_data:
            return JSONResponse(
                status_code=400,
                content=create_error_response("INVALID_CHAT_MESSAGE", "chatMessage cannot be empty")
            )
        user_message = ChatMessageUser.model_validate(user_message_data)

        chat_session_data = await database.get_chat_session(user_message.session_id)
        if not chat_session_data:
            return JSONResponse(
                status_code=404,
                content=create_error_response("NOT_FOUND", "Chat session not found")
            )
        chat_session = ChatSession.model_validate(chat_session_data)
        chat_type = chat_session.context.type
        candidate_info = chat_session.context.additional_context.get("candidateInfo", {})

        chat_type = chat_session_data.get("context", {}).get("type", "general")

        # Get candidate info if this chat is about a specific candidate
        candidate_info = chat_session_data.get("context", {}).get("additionalContext", {}).get("candidateInfo")
        if candidate_info:
            logger.info(f"🔗 Chat session {session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
            logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
        else:
            logger.info(f"🔗 Chat session {session_id} type {chat_type} accessed by user {current_user.id}")

        query = data.get("query")
        if not query:
            logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}")
            return JSONResponse(
                status_code=400,
                content=create_error_response("INVALID_QUERY", "Query cannot be empty")
                content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association")
            )

        chat_query = ChatQuery.model_validate(query)
        chat_agent = agents.get_or_create_agent(
            agent_type=chat_type,
            prometheus_collector=prometheus_collector,
            database=database
        )

        if not chat_agent:
        candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
        candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
        if not candidate:
            return JSONResponse(
                status_code=400,
                content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
                status_code=404,
                content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session")
            )
        logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query: {user_message.content}")

        # Store the user's message first
        user_message = ChatMessageUser(
            session_id=session_id,
            type=ChatMessageType.USER,
            status=ChatStatusType.DONE,
            sender=ChatSenderType.USER,
            content=chat_query.prompt,
            timestamp=datetime.now(UTC)
        )
        async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
            # Entity automatically released when done
            chat_agent = candidate_entity.get_or_create_agent(agent_type=chat_type)
            if not chat_agent:
                return JSONResponse(
                    status_code=400,
                    content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
                )

        # Persist user message to database
        await database.add_chat_message(session_id, user_message.model_dump())
        logger.info(f"💬 User message saved to database for session {session_id}")

        # Update session last activity
        chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
        await database.set_chat_session(session_id, chat_session_data)

        async def message_stream_generator():
            """Generator to stream messages with persistence"""
            last_log = None
            ai_message = None
            # Persist user message to database
            await database.add_chat_message(user_message.session_id, user_message.model_dump())
            logger.info(f"💬 User message saved to database for session {user_message.session_id}")

            async for chat_message in chat_agent.generate(
                llm=llm_manager.get_llm(),
                model=defines.model,
                query=chat_query,
                user_message=user_message,
                user_id=current_user.id,
            ):
                # Store reference to the complete AI message
                if chat_message.status == ChatStatusType.DONE:
                    ai_message = chat_message

                # If the message is not done, convert it to a ChatMessageBase to remove
                # metadata and other unnecessary fields for streaming
                if chat_message.status != ChatStatusType.DONE:
                    chat_message = model_cast.cast_to_model(ChatMessageBase, chat_message)
            # Update session last activity
            chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
            await database.set_chat_session(user_message.session_id, chat_session_data)

                json_data = chat_message.model_dump(mode='json', by_alias=True, exclude_unset=True)
                json_str = json.dumps(json_data)
            async def message_stream_generator():
                """Generator to stream messages with persistence"""
                last_log = None
                final_message = None

                log = f"🔗 Message status={chat_message.status}, sender={getattr(chat_message, 'sender', 'unknown')}"
                if last_log != log:
                    last_log = log
                    logger.info(log)

                yield f"data: {json_str}\n\n"

            # After streaming is complete, persist the final AI message to database
            if ai_message and ai_message.status == ChatStatusType.DONE:
                try:
                    await database.add_chat_message(session_id, ai_message.model_dump())
                    logger.info(f"🤖 AI message saved to database for session {session_id}")
                async for generated_message in chat_agent.generate(
                    llm=llm_manager.get_llm(),
                    model=defines.model,
                    user_message=user_message,
                    user_id=current_user.id,
                ):
                    # Store reference to the complete AI message
                    if generated_message.status == ChatStatusType.DONE:
                        final_message = generated_message

                    # Update session last activity again
                    chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
                    await database.set_chat_session(session_id, chat_session_data)

                except Exception as e:
                    logger.error(f"Failed to save AI message to database: {e}")
                    # If the message is not done, convert it to a ChatMessageBase to remove
                    # metadata and other unnecessary fields for streaming
                    if generated_message.status != ChatStatusType.DONE:
                        generated_message = model_cast.cast_to_model(ChatMessageBase, generated_message)

        return StreamingResponse(
            message_stream_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
                    json_data = generated_message.model_dump(mode='json', by_alias=True, exclude_unset=True)
                    json_str = json.dumps(json_data)

                    log = f"🔗 Message status={generated_message.status}, sender={getattr(generated_message, 'sender', 'unknown')}"
                    if last_log != log:
                        last_log = log
                        logger.info(log)

                    yield f"data: {json_str}\n\n"

                # After streaming is complete, persist the final AI message to database
                if final_message and final_message.status == ChatStatusType.DONE:
                    try:
                        await database.add_chat_message(final_message.session_id, final_message.model_dump())
                        logger.info(f"🤖 AI message saved to database for session {final_message.session_id}")

                        # Update session last activity again
                        chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
                        await database.set_chat_session(final_message.session_id, chat_session_data)

                    except Exception as e:
                        logger.error(f"Failed to save AI message to database: {e}")

            return StreamingResponse(
                message_stream_generator(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

    except Exception as e:
        logger.error(traceback.format_exc())
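For reference, a minimal client sketch for the reworked endpoint above: the request body now carries the user message under a `chatMessage` key (matching `data.get("chatMessage")`), and the reply is a server-sent-event stream of `data:` lines. The base URL, route prefix, auth header, and timestamp value are illustrative assumptions:

```python
import json
import requests  # any HTTP client with response streaming works

def stream_chat(base_url: str, session_id: str, token: str, prompt: str):
    payload = {
        "chatMessage": {
            "sessionId": session_id,
            "content": prompt,
            "status": "done",
            "type": "user",
            "sender": "user",
            "timestamp": "2025-01-01T00:00:00Z",  # illustrative value
        }
    }
    url = f"{base_url}/api/chat/sessions/{session_id}/messages/stream"  # prefix assumed
    with requests.post(url, json=payload, stream=True,
                       headers={"Authorization": f"Bearer {token}"}) as resp:
        for line in resp.iter_lines():
            if line.startswith(b"data: "):
                yield json.loads(line[len(b"data: "):])
```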
@ -1779,6 +1777,7 @@ async def track_requests(request, call_next):
# FastAPI Metrics
# ============================
prometheus_collector = CollectorRegistry()
entities.entity_manager.initialize(prometheus_collector)

# Keep the Instrumentator instance alive
instrumentator = Instrumentator(
@ -1835,4 +1834,4 @@ if __name__ == "__main__":
    )
else:
    logger.info(f"Starting web server at http://{host}:{port}")
    uvicorn.run(app=app, host=host, port=port, log_config=None)
    uvicorn.run(app="main:app", host=host, port=port, log_config=None)
@ -388,6 +388,11 @@ class BaseUser(BaseModel):
class BaseUserWithType(BaseUser):
    user_type: UserType = Field(..., alias="userType")

class RagEntry(BaseModel):
    name: str
    description: str = ""
    enabled: bool = True

class Candidate(BaseUser):
    user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType")
    username: str
@ -405,6 +410,8 @@ class Candidate(BaseUser):
    certifications: Optional[List[Certification]] = None
    job_applications: Optional[List["JobApplication"]] = Field(None, alias="jobApplications")
    has_profile: bool = Field(default=False, alias="hasProfile")
    rags: List[RagEntry] = Field(default_factory=list)
    rag_content_size: int = 0
    # Used for AI generated personas
    age: Optional[int] = None
    gender: Optional[UserGender] = None
@ -540,11 +547,6 @@ class JobApplication(BaseModel):
        "populate_by_name": True  # Allow both field names and aliases
    }

class RagEntry(BaseModel):
    name: str
    description: str = ""
    enabled: bool = True

class ChromaDBGetResponse(BaseModel):
    # Chroma fields
    ids: List[str] = []
@ -573,6 +575,26 @@ class ChatOptions(BaseModel):
    num_ctx: Optional[int] = Field(default=None, alias="numCtx")  # Number of context tokens
    temperature: Optional[float] = Field(default=0.7)  # Higher temperature to encourage tool usage

class LLMMessage(BaseModel):
    role: str = Field(default="")
    content: str = Field(default="")
    tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)


class ChatMessageBase(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    session_id: str = Field(..., alias="sessionId")
    sender_id: Optional[str] = Field(None, alias="senderId")
    status: ChatStatusType
    type: ChatMessageType
    sender: ChatSenderType
    timestamp: datetime
    tunables: Optional[Tunables] = None
    content: str = ""
    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
    }

class ChatMessageMetaData(BaseModel):
    model: AIModelType = AIModelType.QWEN2_5
    temperature: float = 0.7
@ -581,8 +603,8 @@ class ChatMessageMetaData(BaseModel):
    frequency_penalty: Optional[float] = Field(None, alias="frequencyPenalty")
    presence_penalty: Optional[float] = Field(None, alias="presencePenalty")
    stop_sequences: Optional[List[str]] = Field(None, alias="stopSequences")
    tunables: Optional[Tunables] = None
    rag: List[ChromaDBGetResponse] = Field(default_factory=list)
    rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias="ragResults")
    llm_history: List[LLMMessage] = Field(default_factory=list, alias="llmHistory")
    eval_count: int = 0
    eval_duration: int = 0
    prompt_eval_count: int = 0
@ -594,28 +616,17 @@ class ChatMessageMetaData(BaseModel):
        "populate_by_name": True  # Allow both field names and aliases
    }

class ChatMessageBase(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    session_id: str = Field(..., alias="sessionId")
    sender_id: Optional[str] = Field(None, alias="senderId")
    status: ChatStatusType
    type: ChatMessageType
    sender: ChatSenderType
    timestamp: datetime
    content: str = ""
    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
    }

class ChatMessageUser(ChatMessageBase):
    status: ChatStatusType = ChatStatusType.DONE
    type: ChatMessageType = ChatMessageType.USER
    sender: ChatSenderType = ChatSenderType.USER

class ChatMessage(ChatMessageBase):
    # attachments: Optional[List[Attachment]] = None
    # reactions: Optional[List[MessageReaction]] = None
    # is_edited: bool = Field(False, alias="isEdited")
    # edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
    metadata: ChatMessageMetaData = Field(None)
    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)

class ChatSession(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
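With the models above, a user message needs only a session id, content, and timestamp, since `status`, `type`, and `sender` now carry enum defaults, and `ChatMessage.metadata` is always present via `default_factory`. A sketch (field values are illustrative):

```python
from datetime import datetime, timezone

user_msg = ChatMessageUser(
    sessionId="1234",  # alias works because populate_by_name is enabled
    content="What is in my RAG store?",
    timestamp=datetime.now(timezone.utc),
)
assert user_msg.status == ChatStatusType.DONE
assert user_msg.sender == ChatSenderType.USER

reply = ChatMessage(
    sessionId=user_msg.session_id,
    status=ChatStatusType.DONE,
    type=ChatMessageType.RESPONSE,    # assumed enum member, for illustration
    sender=ChatSenderType.ASSISTANT,  # assumed enum member, for illustration
    timestamp=datetime.now(timezone.utc),
)
assert reply.metadata.rag_results == []  # metadata exists without being passed
```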
src/backend/rag/__init__.py (new file, 8 lines)
@ -0,0 +1,8 @@
from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher, RagEntry

__all__ = [
    "ChromaDBFileWatcher",
    "ChromaDBGetResponse",
    "start_file_watcher",
    "RagEntry"
]
src/backend/rag/markdown_chunker.py (new file, 223 lines)
@ -0,0 +1,223 @@
from __future__ import annotations
from typing import List, Dict, Any, Optional, TypedDict, Tuple
from markdown_it import MarkdownIt
from markdown_it.tree import SyntaxTreeNode
import traceback
import logging
import json

import defines

class Chunk(TypedDict):
    text: str
    metadata: Dict[str, Any]

def clear_chunk(chunk: Chunk):
    chunk["text"] = ""
    chunk["metadata"] = {
        "doc_type": "unknown",
        "source_file": chunk["metadata"]["source_file"],
        "lines": chunk["metadata"]["lines"],
        "path": "",  # This will be updated during processing
        "level": 0,
    }
    return chunk

class MarkdownChunker:
    def __init__(self):
        # Initialize the Markdown parser
        self.md_parser = MarkdownIt("commonmark")

    def process_file(self, file_path: str) -> Optional[List[Chunk]]:
        """
        Process a single markdown file and return chunks.

        Args:
            file_path: Path to the markdown file

        Returns:
            List of chunks with metadata or None if file can't be processed
        """
        try:
            # Read the markdown file
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()

            # Parse the markdown
            tokens = self.md_parser.parse(content)
            logging.info(f"Found {len(tokens)} tokens in {file_path}")

            ast = SyntaxTreeNode(tokens)

            # Extract chunks with metadata
            chunks = self.extract_chunks(ast, file_path, len(content.splitlines()))

            return chunks

        except Exception as e:
            logging.error(f"Error processing {file_path}: {str(e)}")
            logging.error(traceback.format_exc())

            return None

    def extract_chunks(self, ast: SyntaxTreeNode, file_path: str, total_lines: int) -> List[Chunk]:
        """
        Extract logical chunks from the AST with appropriate metadata.

        Args:
            ast: The Abstract Syntax Tree from markdown-it-py
            file_path: Path to the source file
            total_lines: Total number of lines in the source file

        Returns:
            List of chunks with metadata
        """
        chunks: List[Chunk] = []
        current_headings: List[str] = [""] * 6  # Track h1-h6 headings

        # Initialize a chunk structure
        chunk: Chunk = {
            "text": "",
            "metadata": {
                "source_file": file_path,
                "lines": total_lines
            },
        }
        clear_chunk(chunk)

        # Process the AST recursively
        self._process_node(ast, current_headings, chunks, chunk, level=0)

        return chunks

    def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        return {
            k: ("" if v is None else v) for k, v in metadata.items() if v is not None
        }

    def _extract_text_from_children(self, node: SyntaxTreeNode) -> str:
        lines = []
        for child in node.children:
            if child.type == "list_item":
                lines.append(f"- {self._extract_text_from_children(child).strip()}")
            elif child.type == "fence":
                info = f"{child.info}" if hasattr(child, "info") and child.info else ""
                lines.append(f"\n```{info}\n{child.content.strip()}\n```\n")
            elif child.children:
                lines.append(self._extract_text_from_children(child).strip())
            elif hasattr(child, "content"):
                lines.append(child.content.strip())
        return "\n".join(lines)

    def _process_node(
        self,
        node: SyntaxTreeNode,
        current_headings: list[str],
        chunks: List[Chunk],
        chunk: Chunk,
        level: int,
        buffer: int = defines.chunk_buffer
    ) -> int:
        is_list = False
        # Handle heading nodes
        if node.type == "heading":
            # print(f'{" " * level}{chunk["metadata"]["path"]} {node.type}')
            if node.is_nested and node.nester_tokens:
                opening, closing = node.nester_tokens
                level = int(opening.tag[1:]) - 1

            heading_text = self._extract_text_from_children(node)

            # Update lower heading states
            current_headings[level] = heading_text
            for i in range(level, 6):
                if i != level:
                    current_headings[i] = ""

            # Store previous chunk if it has text
            if chunk["text"].strip():
                path = " > ".join([h for h in current_headings if h])
                chunk["text"] = chunk["text"].strip()
                chunk["metadata"]["path"] = path
                chunk["metadata"]["level"] = level
                if node.nester_tokens:
                    opening, closing = node.nester_tokens
                    if opening and opening.map:
                        (begin, end) = opening.map
                        metadata = chunk["metadata"]
                        metadata["chunk_begin"] = max(0, begin - buffer)
                        metadata["chunk_end"] = min(metadata["lines"], end + buffer)
                        metadata["line_begin"] = begin
                        metadata["line_end"] = end

                chunks.append(chunk.copy())
                clear_chunk(chunk)

        # Add code block directly to current chunk
        elif node.type == "fence":
            # print(f'{" " * level}{chunk["metadata"]["path"]} {node.type}')
            language = node.info.strip() if hasattr(node, "info") and node.info else ""
            code_block = f"\n```{language}\n{node.content.strip()}\n```\n"
            chunk["text"] += code_block

        # Handle list structures
        # elif node.type in ["list", "list_item"]:
        #     # print(node.nester_tokens)
        #     # print(node.pretty(show_text=True))
        #     is_list = True
        #     list_chunk = []
        #     for child in node.children:
        #         level = self._process_node(child, current_headings, chunks, chunk, level=level)

        # Handle paragraph
        elif node.type in ["paragraph"]:  # , "list", "list_item"]:
            text = self._extract_text_from_children(node)
            # indented_text = "\n".join([f'{"-" * (level + 2)}{line}' for line in text.split('\n')])
            # print(f'{"-" * level}{chunk["metadata"]["path"]} {node.type}:\n{indented_text} {len(node.children)}\n')
            text = text.strip()
            if text:
                # indented_text = "\n".join([f'{"-" * (level + 2)}{line}' for line in text.split('\n')])
                # print(f'{" " * level}{chunk["metadata"]["path"]} {node.type}:\n{indented_text}\n')
                chunk["text"] += f"\n{text}\n"
                chunk["text"] = chunk["text"].strip()
                if chunk["text"]:
                    path = " > ".join([h for h in current_headings if h])
                    chunk["text"] = chunk["text"]
                    chunk["metadata"]["path"] = path
                    chunk["metadata"]["level"] = level
                    if node.nester_tokens:
                        opening, closing = node.nester_tokens
                        if opening and opening.map:
                            (begin, end) = opening.map
                            metadata = chunk["metadata"]
                            metadata["chunk_begin"] = max(0, begin - buffer)
                            metadata["chunk_end"] = min(metadata["lines"], end + buffer)
                            metadata["line_begin"] = begin
                            metadata["line_end"] = end
                    chunks.append(chunk.copy())
                    clear_chunk(chunk)

        # Recursively process children
        if not is_list:
            for child in node.children:
                level = self._process_node(
                    child, current_headings, chunks, chunk, level=level
                )

        # After root-level recursion, finalize any remaining chunk
        if node.type == "document":
            # print(node.type)
            if chunk["text"].strip():
                path = " > ".join([h for h in current_headings if h])
                chunk["metadata"]["path"] = path
                if node.nester_tokens:
                    opening, closing = node.nester_tokens
                    if opening and opening.map:
                        (begin, end) = opening.map
                        metadata = chunk["metadata"]
                        metadata["chunk_begin"] = max(0, begin - buffer)
                        metadata["chunk_end"] = min(metadata["lines"], end + buffer)
                        metadata["line_begin"] = begin
                        metadata["line_end"] = end
                chunks.append(chunk.copy())

        return level
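A usage sketch for the chunker above (the input path is illustrative):

```python
chunker = MarkdownChunker()
chunks = chunker.process_file("docs/projects.md")  # hypothetical input file
for c in chunks or []:
    meta = c["metadata"]
    print(f'{meta["path"]!r} lines {meta.get("line_begin")}-{meta.get("line_end")}: '
          f'{len(c["text"])} chars')
```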
src/backend/rag/rag.py (new file, 727 lines)
@ -0,0 +1,727 @@
from __future__ import annotations
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field  # type: ignore
from typing import List, Optional, Dict, Any, Union
import os
import glob
from pathlib import Path
import time
import hashlib
import asyncio
import logging
import json
import numpy as np  # type: ignore
import traceback

import chromadb  # type: ignore
import ollama
from watchdog.observers import Observer  # type: ignore
from watchdog.events import FileSystemEventHandler  # type: ignore
import umap  # type: ignore
from markitdown import MarkItDown  # type: ignore
from chromadb.api.models.Collection import Collection  # type: ignore

from .markdown_chunker import (
    MarkdownChunker,
    Chunk,
)

# When imported as a module, use relative imports
import defines

__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]

DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100

class RagEntry(BaseModel):
    name: str
    description: str = ""
    enabled: bool = True

class ChromaDBGetResponse(BaseModel):
    name: str = ""
    size: int = 0
    ids: List[str] = []
    embeddings: List[List[float]] = Field(default=[])
    documents: List[str] = []
    metadatas: List[Dict[str, Any]] = []
    query: str = ""
    query_embedding: Optional[List[float]] = Field(default=None)
    umap_embedding_2d: Optional[List[float]] = Field(default=None)
    umap_embedding_3d: Optional[List[float]] = Field(default=None)
    enabled: bool = True

    class Config:
        validate_assignment = True

    @field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
    @classmethod
    def validate_embeddings(cls, value, field):
        # logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
        if value is None:
            return value
        if isinstance(value, np.ndarray):
            if field.field_name == "embeddings":
                if value.ndim != 2:
                    raise ValueError(f"{field.field_name} must be a 2-dimensional NumPy array")
                return [[float(x) for x in row] for row in value.tolist()]
            else:
                if value.ndim != 1:
                    raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
                return [float(x) for x in value.tolist()]
        if field.field_name == "embeddings":
            if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
                raise ValueError(f"{field.field_name} must be a list of lists of floats")
            return [[float(x) for x in sublist] for sublist in value]
        else:
            if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
                raise ValueError(f"{field.field_name} must be a list of floats")
            return [float(x) for x in value]

class ChromaDBFileWatcher(FileSystemEventHandler):
    def __init__(
        self,
        llm,
        watch_directory,
        loop,
        persist_directory,
        collection_name,
        chunk_size=DEFAULT_CHUNK_SIZE,
        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
        recreate=False,
    ):
        self.llm = llm
        self.watch_directory = watch_directory
        self.persist_directory = persist_directory or defines.persist_directory
        self.collection_name = collection_name
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.loop = loop
        self._umap_collection: ChromaDBGetResponse | None = None
        self._umap_embedding_2d: np.ndarray = []
        self._umap_embedding_3d: np.ndarray = []
        self._umap_model_2d: umap.UMAP = None
        self._umap_model_3d: umap.UMAP = None
        self.md = MarkItDown(enable_plugins=False)  # Set to True to enable plugins

        # self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Path for storing file hash state
        self.hash_state_path = os.path.join(
            self.persist_directory, f"{collection_name}_hash_state.json"
        )

        # Flag to track if this is a new collection
        self.is_new_collection = False

        # Initialize ChromaDB collection
        self._collection: Collection = self._get_vector_collection(recreate=recreate)
        self._markdown_chunker = MarkdownChunker()
        self._update_umaps()

        # Track file hashes and processing state
        self.file_hashes = self._load_hash_state()
        self.update_lock = asyncio.Lock()
        self.processing_files = set()

    @property
    def collection(self) -> Collection:
        return self._collection

    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        if not self._umap_collection:
            raise ValueError("initialize_collection has not been called")
        return self._umap_collection

    @property
    def umap_embedding_2d(self) -> np.ndarray:
        return self._umap_embedding_2d

    @property
    def umap_embedding_3d(self) -> np.ndarray:
        return self._umap_embedding_3d

    @property
    def umap_model_2d(self):
        return self._umap_model_2d

    @property
    def umap_model_3d(self):
        return self._umap_model_3d

    def _markitdown(self, document: str, markdown: Path):
        logging.info(f"Converting {document} to {markdown}")
        try:
            result = self.md.convert(document)
            markdown.write_text(result.text_content)
        except Exception as e:
            logging.error(f"Error converting via markitdown: {e}")

    def _save_hash_state(self):
        """Save the current file hash state to disk."""
        try:
            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)

            with open(self.hash_state_path, "w") as f:
                json.dump(self.file_hashes, f)

            logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
        except Exception as e:
            logging.error(f"Error saving hash state: {e}")

    def _load_hash_state(self):
        """Load the file hash state from disk."""
        if os.path.exists(self.hash_state_path):
            try:
                with open(self.hash_state_path, "r") as f:
                    hash_state = json.load(f)
                logging.info(f"Loaded hash state with {len(hash_state)} entries")
                return hash_state
            except Exception as e:
                logging.error(f"Error loading hash state: {e}")

        return {}

    async def scan_directory(self, process_all=False):
        """
        Scan directory for new, modified, or deleted files and update collection.

        Args:
            process_all: If True, process all files regardless of hash status
        """
        # Check for new or modified files
        file_paths = glob.glob(
            os.path.join(self.watch_directory, "**/*"), recursive=True
        )
        files_checked = 0
        files_processed = 0
        files_to_process = []

        logging.info(f"Starting directory scan. Found {len(file_paths)} total paths.")

        for file_path in file_paths:
            if os.path.isfile(file_path):
                # Do not put the resume in RAG as it is provided with all queries.
                # if file_path == defines.resume_doc:
                #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
                #     continue
                files_checked += 1
                current_hash = self._get_file_hash(file_path)
                if not current_hash:
                    logging.info(f"Unable to obtain hash of {file_path}")
                    continue

                # If file is new, changed, or we're processing all files
                if (
                    process_all
                    or file_path not in self.file_hashes
                    or self.file_hashes[file_path] != current_hash
                ):
                    self.file_hashes[file_path] = current_hash
                    files_to_process.append(file_path)
                    logging.info(
                        f"File {'found' if process_all else 'changed'}: {file_path}"
                    )

        logging.info(
            f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
        )

        # Check for deleted files
        deleted_files = []
        for file_path in self.file_hashes:
            if not os.path.exists(file_path):
                deleted_files.append(file_path)
                # Schedule removal
                asyncio.run_coroutine_threadsafe(
                    self.remove_file_from_collection(file_path), self.loop
                )
                # Don't block on result, just let it run
                logging.info(f"File deleted: {file_path}")

        # Remove deleted files from hash state
        for file_path in deleted_files:
            del self.file_hashes[file_path]

        # Process all discovered files sequentially on the existing loop
        if files_to_process:
            logging.info(f"Starting to process {len(files_to_process)} files")

            for file_path in files_to_process:
                async with self.update_lock:
                    files_processed += 1
                    await self._update_document_in_collection(file_path)
        else:
            logging.info("No files to process")

        # Save the updated state
        self._save_hash_state()

        logging.info(
            f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
        )
        return files_processed

    async def process_file_update(self, file_path):
        """Process a file update event."""
        # Skip if already being processed
        if file_path in self.processing_files:
            logging.info(f"{file_path} already in queue. Not adding.")
            return

        # if file_path == defines.resume_doc:
        #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
        #     return

        try:
            logging.info(f"{file_path} not in queue. Adding.")
            self.processing_files.add(file_path)

            # Wait a moment to ensure the file write is complete
            await asyncio.sleep(0.5)

            # Check if content changed via hash
            current_hash = self._get_file_hash(file_path)
            if not current_hash:  # File might have been deleted or is inaccessible
                return

            if (
                file_path in self.file_hashes
                and self.file_hashes[file_path] == current_hash
            ):
                # File hasn't actually changed in content
                logging.info(f"Hash has not changed for {file_path}")
                return

            # Update file hash
            self.file_hashes[file_path] = current_hash

            # Process and update the file in ChromaDB
            async with self.update_lock:
                await self._update_document_in_collection(file_path)

            # Save the hash state after successful update
            self._save_hash_state()

            # Re-fit the UMAP for the new content
            self._update_umaps()

        except Exception as e:
            logging.error(f"Error processing update for {file_path}: {e}")
        finally:
            self.processing_files.discard(file_path)

    async def remove_file_from_collection(self, file_path):
        """Remove all chunks related to a deleted file."""
        async with self.update_lock:
            try:
                # Find all documents with the specified path
                results = self.collection.get(where={"path": file_path})

                if results and "ids" in results and results["ids"]:
                    self.collection.delete(ids=results["ids"])
                    logging.info(
                        f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
                    )

                # Remove from hash dictionary
                if file_path in self.file_hashes:
                    del self.file_hashes[file_path]
                    # Save the updated hash state
                    self._save_hash_state()

            except Exception as e:
                logging.error(f"Error removing file from collection: {e}")

    def _update_umaps(self):
        # Update the UMAP embeddings
        self._umap_collection = self._collection.get(
            include=["embeddings", "documents", "metadatas"]
        )
        if not self._umap_collection or not len(self._umap_collection["embeddings"]):
            logging.warning("No embeddings found in the collection.")
            return

        # During initialization
        logging.info(
            f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
        )
        vectors = np.array(self._umap_collection["embeddings"])
        self._umap_model_2d = umap.UMAP(
            n_components=2,
            random_state=8911,
            metric="cosine",
            n_neighbors=30,
            min_dist=0.1,
        )
        self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
        # logging.info(
        #     f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
        # )  # Should be 2

        logging.info(
            f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
        )
        self._umap_model_3d = umap.UMAP(
            n_components=3,
            random_state=8911,
            metric="cosine",
            n_neighbors=30,
            min_dist=0.01,
        )
        self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
        # logging.info(
        #     f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
        # )  # Should be 3

    def _get_vector_collection(self, recreate=False) -> Collection:
        """Get or create a ChromaDB collection."""
        # Create the directory if it doesn't exist
        if not os.path.exists(self.persist_directory):
            os.makedirs(self.persist_directory)

        # Initialize ChromaDB client
        chroma_client = chromadb.PersistentClient(  # type: ignore
            path=self.persist_directory,
            settings=chromadb.Settings(anonymized_telemetry=False),  # type: ignore
        )

        # Check if the collection exists
        try:
            chroma_client.get_collection(self.collection_name)
            collection_exists = True
        except Exception:
            collection_exists = False

        # If collection doesn't exist, mark it as new
        if not collection_exists:
            self.is_new_collection = True
            logging.info(f"Creating new collection: {self.collection_name}")

        # Delete if recreate is True
        if recreate and collection_exists:
            chroma_client.delete_collection(name=self.collection_name)
            self.is_new_collection = True
            logging.info(f"Recreating collection: {self.collection_name}")

        return chroma_client.get_or_create_collection(
            name=self.collection_name, metadata={"hnsw:space": "cosine"}
        )

    def get_embedding(self, text: str) -> np.ndarray:
        """Generate and normalize an embedding for the given text."""

        # Get embedding
        try:
            response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
            embedding = np.array(response["embedding"])
        except Exception as e:
            logging.error(f"Failed to get embedding: {e}")
            raise

        # Log diagnostics
        logging.info(f"Input text: {text}")
        logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")

        # Check for invalid embeddings
        if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
            logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
            raise ValueError("Invalid embedding returned from Ollama.")

        # Check normalization
        norm = np.linalg.norm(embedding)
        is_normalized = np.allclose(norm, 1.0, atol=1e-3)
        logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")

        # Normalize if needed
        if not is_normalized:
            embedding = embedding / norm
            logging.info("Embedding normalized manually.")

        return embedding

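    # Note: get_embedding() returns unit-normalized vectors and the collection
    # is created with hnsw:space=cosine, so the distance Chroma reports is
    # 1 - (a . b) for unit vectors a and b. A quick illustrative sanity check
    # (not part of the commit):
    #
    #     import numpy as np
    #     a = np.array([1.0, 0.0])
    #     b = np.array([1.0, 1.0]) / np.linalg.norm([1.0, 1.0])
    #     assert abs((1.0 - np.dot(a, b)) - 0.2929) < 1e-3  # 45 degrees apart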
    def add_embeddings_to_collection(self, chunks: List[Chunk]):
        """Add embeddings for chunks to the collection."""

        for i, chunk in enumerate(chunks):
            text = chunk["text"]
            metadata = chunk["metadata"]

            # Generate a more unique ID based on content and metadata
            path_hash = ""
            if "path" in metadata:
                path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]
            content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
            chunk_id = f"{path_hash}_{i}_{content_hash}"

            embedding = self.get_embedding(text)
            try:
                self.collection.add(
                    ids=[chunk_id],
                    documents=[text],
                    embeddings=[embedding],
                    metadatas=[metadata],
                )
            except Exception as e:
                logging.error(f"Error adding chunk to collection: {e}")
                logging.error(traceback.format_exc())
                logging.error(chunk)

    def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer) -> str | None:
        try:
            source_file = meta["source_file"]
            path_parts = source_file.split(os.sep)
            file_name = path_parts[-1]
            meta["source_file"] = file_name
            with open(source_file, "r") as file:
                lines = file.readlines()
            meta["file_lines"] = len(lines)
            start = max(0, meta["line_begin"] - buffer)
            meta["chunk_begin"] = start
            end = min(meta["lines"], meta["line_end"] + buffer)
            meta["chunk_end"] = end
            return "".join(lines[start:end])
        except Exception:
            logging.warning(f"Unable to open {meta['source_file']}")
            return None

    # Cosine Distance   Equivalent Similarity   Retrieval Characteristics
    # 0.2 - 0.3         0.85 - 0.90             Very strict, highly precise results only
    # 0.3 - 0.5         0.75 - 0.85             Strong relevance, good precision
    # 0.5 - 0.7         0.65 - 0.75             Balanced precision/recall
    # 0.7 - 0.9         0.55 - 0.65             Higher recall, more inclusive
    # 0.9 - 1.2         0.40 - 0.55             Very inclusive, may include tangential content
    def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
        """Find similar documents to the query."""

        # collection is configured with hnsw:space cosine
        query_embedding = self.get_embedding(query)
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"],
        )

        # Extract results
        ids = results["ids"][0]
        documents = results["documents"][0]
        distances = results["distances"][0]
        metadatas = results["metadatas"][0]

        filtered_ids = []
        filtered_documents = []
        filtered_distances = []
        filtered_metadatas = []

        for i, distance in enumerate(distances):
            if distance <= threshold:  # For cosine distance, smaller is better
                filtered_ids.append(ids[i])
                filtered_documents.append(documents[i])
                filtered_metadatas.append(metadatas[i])
                filtered_distances.append(distance)

        for index, meta in enumerate(filtered_metadatas):
            content = self.prepare_metadata(meta)
            if content is not None:
                filtered_documents[index] = content

        # Return the filtered results instead of all results
        return {
            "query_embedding": query_embedding,
            "ids": filtered_ids,
            "documents": filtered_documents,
            "distances": filtered_distances,
            "metadatas": filtered_metadatas,
        }
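    # Illustrative usage (not part of the commit; watcher construction and the
    # query string are assumptions). Per the table above, threshold=0.7 sits at
    # the balanced precision/recall point, and smaller distances are stricter:
    #
    #     results = watcher.find_similar("kubernetes deployment experience",
    #                                    top_k=5, threshold=0.7)
    #     for doc, dist, meta in zip(results["documents"],
    #                                results["distances"],
    #                                results["metadatas"]):
    #         print(f"{meta['source_file']} (distance {dist:.3f}): {doc[:60]}")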
    def _get_file_hash(self, file_path):
        """Calculate MD5 hash of a file."""
        try:
            with open(file_path, "rb") as f:
                return hashlib.md5(f.read()).hexdigest()
        except Exception as e:
            logging.error(f"Error hashing file {file_path}: {e}")
            return None

    def on_modified(self, event):
        """Handle file modification events."""
        if event.is_directory:
            return

        file_path = event.src_path
        # Schedule the update using asyncio
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File modified: {file_path}")

    def on_created(self, event):
        """Handle file creation events."""
        if event.is_directory:
            return

        file_path = event.src_path
        # Schedule the update using asyncio
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File created: {file_path}")

    def on_deleted(self, event):
        """Handle file deletion events."""
        if event.is_directory:
            return

        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(
            self.remove_file_from_collection(file_path), self.loop
        )
        logging.info(f"File deleted: {file_path}")

    def on_moved(self, event):
        """Handle file move events."""
        if event.is_directory:
            return

        file_path = event.src_path
        logging.info(f"TODO: on_moved: {file_path}")

def _normalize_embeddings(self, embeddings):
    """Normalize the embeddings to unit length."""
    # Handle both single vector and array of vectors
    if isinstance(embeddings[0], (int, float)):
        # Single vector
        norm = np.linalg.norm(embeddings)
        return [e / norm for e in embeddings] if norm > 0 else embeddings
    else:
        # Array of vectors
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1  # Avoid division by zero for all-zero rows
        return embeddings / norms
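# Illustrative check of both branches (not part of this commit; "watcher"
# stands in for a ChromaDBFileWatcher instance):
#
#   watcher._normalize_embeddings([3.0, 4.0])              # -> [0.6, 0.8]
#   watcher._normalize_embeddings(np.array([[3.0, 4.0],
#                                           [0.0, 2.0]]))  # rows scaled to unit length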
async def _update_document_in_collection(self, file_path):
    """Update a document in the ChromaDB collection."""
    try:
        # Remove existing entries for this file
        existing_results = self.collection.get(where={"path": file_path})
        if (
            existing_results
            and "ids" in existing_results
            and existing_results["ids"]
        ):
            self.collection.delete(ids=existing_results["ids"])

        extensions = (".docx", ".xlsx", ".xls", ".pdf")
        if file_path.endswith(extensions):
            p = Path(file_path)
            p_as_md = p.with_suffix(".md")
            if p_as_md.exists():
                logging.info(
                    f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
                )

            # If file_path.md doesn't exist or file_path is newer than file_path.md,
            # fire off markitdown
            if (not p_as_md.exists()) or (
                p.stat().st_mtime > p_as_md.stat().st_mtime
            ):
                self._markitdown(file_path, p_as_md)
            return

        chunks = self._markdown_chunker.process_file(file_path)
        if not chunks:
            logging.info(f"No chunks found in markdown: {file_path}")
            return

        # Extract top-level directory
        rel_path = os.path.relpath(file_path, self.watch_directory)
        path_parts = rel_path.split(os.sep)
        top_level_dir = path_parts[0]
        # file_name = path_parts[-1]
        for i, chunk in enumerate(chunks):
            chunk["metadata"]["doc_type"] = top_level_dir
            # with open(f"src/tmp/{file_name}.{i}", "w") as f:
            #     f.write(json.dumps(chunk, indent=2))

        # Add chunks to collection
        self.add_embeddings_to_collection(chunks)

        logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")

    except Exception as e:
        logging.error(f"Error updating document in collection: {e}")
        logging.error(traceback.format_exc())
async def initialize_collection(self):
    """Initialize the collection with all documents from the watch directory."""
    # Process all files regardless of hash state
    num_processed = await self.scan_directory(process_all=True)

    logging.info(
        f"Vectorstore initialized with {self.collection.count()} documents"
    )

    self._update_umaps()

    # Show stats
    try:
        all_metadata = self.collection.get()["metadatas"]
        if all_metadata:
            doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
            logging.info(f"Document types: {doc_types}")
    except Exception as e:
        logging.error(f"Error getting document types: {e}")

    return num_processed

# Function to start the file watcher
def start_file_watcher(
    llm,
    watch_directory,
    persist_directory,
    collection_name,
    initialize=False,
    recreate=False,
):
    """
    Start watching a directory for file changes.

    Args:
        llm: The language model client
        watch_directory: Directory to watch for changes
        persist_directory: Directory to persist ChromaDB and hash state
        collection_name: Name of the ChromaDB collection
        initialize: Whether to forcibly initialize the collection with all documents
        recreate: Whether to recreate the collection (will delete existing)
    """
    loop = asyncio.get_event_loop()

    file_watcher = ChromaDBFileWatcher(
        llm,
        watch_directory=watch_directory,
        loop=loop,
        persist_directory=persist_directory,
        collection_name=collection_name,
        recreate=recreate,
    )

    # Process all files if:
    # 1. initialize=True was passed (explicit request to initialize)
    # 2. This is a new collection (doesn't exist yet)
    # 3. There's no hash state (first run)
    if initialize or file_watcher.is_new_collection or not file_watcher.file_hashes:
        logging.info("Initializing collection with all documents")
        asyncio.run_coroutine_threadsafe(file_watcher.initialize_collection(), loop)
    else:
        # Only process new/changed files
        logging.info("Scanning for new/changed documents")
        asyncio.run_coroutine_threadsafe(file_watcher.scan_directory(), loop)

    # Start observer
    observer = Observer()
    observer.schedule(file_watcher, watch_directory, recursive=True)
    observer.start()

    logging.info(f"Started watching directory: {watch_directory}")
    return observer, file_watcher
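# Illustrative usage (not part of this commit; the paths and collection name
# are assumptions, and the coroutines scheduled above only run once the
# captured event loop is actually running somewhere):
#
#   observer, watcher = start_file_watcher(
#       llm,
#       watch_directory="docs/",
#       persist_directory="db/",
#       collection_name="backstory",
#   )
#   try:
#       observer.join()  # block until interrupted
#   except KeyboardInterrupt:
#       observer.stop()
#       observer.join()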
14
src/backend/tools/__init__.py
Normal file
@ -0,0 +1,14 @@
import importlib

from .basetools import all_tools, ToolEntry, llm_tools, enabled_tools, tool_functions
from ..setup_logging import setup_logging
from .. import defines

logger = setup_logging(level=defines.logging_level)

# Dynamically import all names from basetools listed in tool_functions
module = importlib.import_module(".basetools", package=__package__)
for name in tool_functions:
    globals()[name] = getattr(module, name)

__all__ = ["all_tools", "ToolEntry", "llm_tools", "enabled_tools", "tool_functions"]
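# Illustrative usage (not part of this commit; the exact import path depends
# on how src/backend is packaged): the loop above re-exports each name in
# tool_functions at the package level, so callers can import tools directly:
#
#   from backend.tools import DateTime, all_tools, llm_tools
#   print(DateTime("UTC"))
#   schemas = llm_tools(all_tools())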
526
src/backend/tools/basetools.py
Normal file
@ -0,0 +1,526 @@
import os
from pydantic import BaseModel  # type: ignore
from typing import List, Optional, Any, Dict
from datetime import datetime

from bs4 import BeautifulSoup  # type: ignore

from geopy.geocoders import Nominatim  # type: ignore
import pytz  # type: ignore
import requests
import yfinance as yf  # type: ignore
import logging

# %%
def WeatherForecast(city, state, country="USA"):
    """
    Get weather information from weather.gov based on city, state, and country.

    Args:
        city (str): City name
        state (str): State name or abbreviation
        country (str): Country name (defaults to "USA" as weather.gov is for US locations)

    Returns:
        dict: Weather forecast information
    """
    # Step 1: Get coordinates for the location using geocoding
    location = f"{city}, {state}, {country}"
    coordinates = get_coordinates(location)

    if not coordinates:
        return {"error": f"Could not find coordinates for {location}"}

    # Step 2: Get the forecast grid endpoint for the coordinates
    grid_endpoint = get_grid_endpoint(coordinates)

    if not grid_endpoint:
        return {"error": f"Could not find weather grid for coordinates {coordinates}"}

    # Step 3: Get the forecast data from the grid endpoint
    forecast = get_forecast(grid_endpoint)

    if not forecast["location"]:
        forecast["location"] = location

    return forecast


def get_coordinates(location):
    """Convert a location string to latitude and longitude using Nominatim geocoder."""
    try:
        # Create a geocoder with a meaningful user agent
        geolocator = Nominatim(user_agent="weather_app_example")

        # Get the location
        location_data = geolocator.geocode(location)

        if location_data:
            return {
                "latitude": location_data.latitude,
                "longitude": location_data.longitude,
            }
        else:
            print(f"Location not found: {location}")
            return None
    except Exception as e:
        print(f"Error getting coordinates: {e}")
        return None


def get_grid_endpoint(coordinates):
    """Get the grid endpoint from weather.gov based on coordinates."""
    try:
        lat = coordinates["latitude"]
        lon = coordinates["longitude"]

        # Define headers for the API request
        headers = {
            "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
            "Accept": "application/geo+json",
        }

        # Make the request to get the grid endpoint
        url = f"https://api.weather.gov/points/{lat},{lon}"
        response = requests.get(url, headers=headers)

        if response.status_code == 200:
            data = response.json()
            return data["properties"]["forecast"]
        else:
            print(f"Error getting grid: {response.status_code} - {response.text}")
            return None
    except Exception as e:
        print(f"Error in get_grid_endpoint: {e}")
        return None


# Weather related function
def get_forecast(grid_endpoint):
    """Get the forecast data from the grid endpoint."""
    try:
        # Define headers for the API request
        headers = {
            "User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
            "Accept": "application/geo+json",
        }

        # Make the request to get the forecast
        response = requests.get(grid_endpoint, headers=headers)

        if response.status_code == 200:
            data = response.json()

            # Extract the relevant forecast information
            periods = data["properties"]["periods"]

            # Process the forecast data into a simpler format
            forecast = {
                "location": data["properties"]
                .get("relativeLocation", {})
                .get("properties", {}),
                "updated": data["properties"].get("updated", ""),
                "periods": [],
            }

            for period in periods:
                forecast["periods"].append(
                    {
                        "name": period.get("name", ""),
                        "temperature": period.get("temperature", ""),
                        "temperatureUnit": period.get("temperatureUnit", ""),
                        "windSpeed": period.get("windSpeed", ""),
                        "windDirection": period.get("windDirection", ""),
                        "shortForecast": period.get("shortForecast", ""),
                        "detailedForecast": period.get("detailedForecast", ""),
                    }
                )

            return forecast
        else:
            print(f"Error getting forecast: {response.status_code} - {response.text}")
            return {"error": f"API Error: {response.status_code}"}
    except Exception as e:
        print(f"Error in get_forecast: {e}")
        return {"error": f"Exception: {str(e)}"}

# Example usage
# def do_weather():
#     city = input("Enter city: ")
#     state = input("Enter state: ")
#     country = input("Enter country (default USA): ") or "USA"

#     print(f"Getting weather for {city}, {state}, {country}...")
#     weather_data = WeatherForecast(city, state, country)

#     if "error" in weather_data:
#         print(f"Error: {weather_data['error']}")
#     else:
#         print("\nWeather Forecast:")
#         print(f"Location: {weather_data.get('location', {}).get('city', city)}, {weather_data.get('location', {}).get('state', state)}")
#         print(f"Last Updated: {weather_data.get('updated', 'N/A')}")
#         print("\nForecast Periods:")

#         for period in weather_data.get("periods", []):
#             print(f"\n{period['name']}:")
#             print(f"  Temperature: {period['temperature']}{period['temperatureUnit']}")
#             print(f"  Wind: {period['windSpeed']} {period['windDirection']}")
#             print(f"  Forecast: {period['shortForecast']}")
#             print(f"  Details: {period['detailedForecast']}")


# %%
def TickerValue(ticker_symbols):
    """Look up current prices for one or more comma-separated ticker symbols via TwelveData."""
    api_key = os.getenv("TWELVEDATA_API_KEY", "")
    if not api_key:
        return {"error": "Error fetching data: No API key for TwelveData"}

    results = []
    for ticker_symbol in ticker_symbols.split(","):
        ticker_symbol = ticker_symbol.strip()
        if ticker_symbol == "":
            continue

        url = (
            f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
        )

        response = requests.get(url)
        data = response.json()

        if "price" in data:
            logging.info(f"TwelveData: {ticker_symbol} {data}")
            results.append({"symbol": ticker_symbol, "price": float(data["price"])})
        else:
            logging.error(f"TwelveData: {data}")
            results.append({"symbol": ticker_symbol, "price": "Unavailable"})

    return results[0] if len(results) == 1 else results

# Stock related function
def yfTickerValue(ticker_symbols):
    """
    Look up the current price of a stock using its ticker symbol.

    Args:
        ticker_symbols (str): The stock ticker symbol(s), comma-separated (e.g., 'AAPL' for Apple)

    Returns:
        dict: Current stock information including price
    """
    results = []
    for ticker_symbol in ticker_symbols.split(","):
        ticker_symbol = ticker_symbol.strip()
        if ticker_symbol == "":
            continue
        # Create a Ticker object
        try:
            logging.info(f"Looking up {ticker_symbol}")
            ticker = yf.Ticker(ticker_symbol)
            # Get the latest market data
            ticker_data = ticker.history(period="1d")

            if ticker_data.empty:
                results.append({"error": f"No data found for ticker {ticker_symbol}"})
                continue

            # Get the latest closing price
            latest_price = ticker_data["Close"].iloc[-1]

            # Get some additional info
            results.append({"symbol": ticker_symbol, "price": latest_price})

        except Exception as e:
            import traceback

            logging.error(f"Error fetching data for {ticker_symbol}: {e}")
            logging.error(traceback.format_exc())
            results.append(
                {"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
            )

    return results[0] if len(results) == 1 else results

# %%
def DateTime(timezone="America/Los_Angeles"):
    """
    Returns the current date and time in the specified timezone in ISO 8601 format.

    Args:
        timezone (str): Timezone name (e.g., "UTC", "America/New_York", "Europe/London")
            Default is "America/Los_Angeles"

    Returns:
        str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM
    """
    try:
        if timezone == "system" or timezone == "" or not timezone:
            timezone = "America/Los_Angeles"
        # Get the current time as a timezone-aware datetime
        local_tz = pytz.timezone("America/Los_Angeles")
        local_now = datetime.now(tz=local_tz)

        # Convert to target timezone
        target_tz = pytz.timezone(timezone)
        target_time = local_now.astimezone(target_tz)

        return target_time.isoformat()
    except Exception as e:
        return {"error": f"Invalid timezone {timezone}: {str(e)}"}

async def GenerateImage(llm, model: str, prompt: str):
    """Stub: returns a fixed placeholder image id rather than generating an image."""
    return {"image_id": "image-a830a83-bd831"}

async def AnalyzeSite(llm, model: str, url: str, question: str):
    """
    Fetches content from a URL, extracts the text, and uses Ollama to summarize it.

    Args:
        url (str): The URL of the website to summarize
        question (str): The question to ask the LLM about the page content

    Returns:
        dict: The LLM's answer about the website content (or an error string on failure)
    """
    try:
        # Fetch the webpage
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        logging.info(f"Fetching {url}")
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        logging.info(f"{url} returned. Processing...")
        # Parse the HTML
        soup = BeautifulSoup(response.text, "html.parser")

        # Remove script and style elements
        for script in soup(["script", "style"]):
            script.extract()

        # Get text content
        text = soup.get_text(separator=" ", strip=True)

        # Clean up text (remove extra whitespace)
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = " ".join(chunk for chunk in chunks if chunk)

        # Limit text length if needed (Ollama may have token limits)
        max_chars = 100000
        if len(text) > max_chars:
            text = text[:max_chars] + "..."

        # logging.info(f"Requesting summary of: {text}")

        # Generate summary using Ollama
        prompt = f"CONTENTS:\n\n{text}\n\n{question}"
        response = llm.generate(
            model=model,
            system=f"You are given the contents of {url}. Answer the question about the contents",
            prompt=prompt,
        )

        # logging.info(response["response"])

        return {
            "source": "summarizer-llm",
            "content": response["response"],
            "metadata": DateTime(),
        }

    except requests.exceptions.RequestException as e:
        logging.error(f"Error fetching the URL: {e}")
        return f"Error fetching the URL: {str(e)}"
    except Exception as e:
        logging.error(f"Error processing the website content: {e}")
        return f"Error processing the website content: {str(e)}"

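# Illustrative call (not part of this commit; the client and model name are
# assumptions, and llm.generate is expected to return an Ollama-style dict
# with a "response" field):
#
#   import asyncio
#   import ollama
#   llm = ollama.Client()
#   result = asyncio.run(
#       AnalyzeSite(llm, "llama3.2", "https://example.com", "What is this page about?")
#   )
#   print(result["content"])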
# %%
class Function(BaseModel):
    name: str
    description: str
    parameters: Dict[str, Any]
    returns: Optional[Dict[str, Any]] = {}


class Tool(BaseModel):
    type: str
    function: Function

tools: List[Tool] = [
    # Tool.model_validate({
    #     "type": "function",
    #     "function": {
    #         "name": "GenerateImage",
    #         "description": """\
    # CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
    #
    # 1. Call this tool when users request images, drawings, or visual content
    # 2. This tool returns an image_id (e.g., "img_abc123")
    # 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
    # 4. FORBIDDEN: DO NOT use markdown image syntax ![alt text](url)
    # 5. FORBIDDEN: DO NOT create fake URLs or file paths
    # 6. FORBIDDEN: DO NOT use any other image embedding format
    #
    # CORRECT EXAMPLE:
    # User: "Draw a cat"
    # Tool returns: {"image_id": "img_xyz789"}
    # Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
    #
    # WRONG EXAMPLES (DO NOT DO THIS):
    # - ![alt text](url)
    # - ![alt text](image_id)
    # - <img src="...">
    #
    # The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
    # """,
    #         "parameters": {
    #             "type": "object",
    #             "properties": {
    #                 "prompt": {
    #                     "type": "string",
    #                     "description": "Detailed image description including style, colors, subject, composition"
    #                 }
    #             },
    #             "required": ["prompt"]
    #         },
    #         "returns": {
    #             "type": "object",
    #             "properties": {
    #                 "image_id": {
    #                     "type": "string",
    #                     "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
    #                 }
    #             }
    #         }
    #     }
    # }),
    Tool.model_validate({
        "type": "function",
        "function": {
            "name": "TickerValue",
            "description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
            "parameters": {
                "type": "object",
                "properties": {
                    "ticker": {
                        "type": "string",
                        "description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
                    },
                },
                "required": ["ticker"],
                "additionalProperties": False,
            },
        },
    }),
    Tool.model_validate({
        "type": "function",
        "function": {
            "name": "AnalyzeSite",
            "description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
            "parameters": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "description": "The website URL to download and process",
                    },
                    "question": {
                        "type": "string",
                        "description": "The question to ask the second LLM about the content",
                    },
                },
                "required": ["url", "question"],
                "additionalProperties": False,
            },
            "returns": {
                "type": "object",
                "properties": {
                    "source": {
                        "type": "string",
                        "description": "Identifier for the source LLM",
                    },
                    "content": {
                        "type": "string",
                        "description": "The complete response from the second LLM",
                    },
                    "metadata": {
                        "type": "object",
                        "description": "Additional information about the response",
                    },
                },
            },
        },
    }),
    Tool.model_validate({
        "type": "function",
        "function": {
            "name": "DateTime",
            "description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
            "parameters": {
                "type": "object",
                "properties": {
                    "timezone": {
                        "type": "string",
                        "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
                    }
                },
                "required": [],
            },
        },
    }),
    Tool.model_validate({
        "type": "function",
        "function": {
            "name": "WeatherForecast",
            "description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
                        "minLength": 2,
                    },
                    "state": {
                        "type": "string",
                        "description": "State to find the weather forecast (e.g., 'OR', 'WA').",
                        "minLength": 2,
                    },
                },
                "required": ["city", "state"],
                "additionalProperties": False,
            },
        },
    }),
]

class ToolEntry(BaseModel):
    enabled: bool = True
    tool: Tool


def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
    return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled]


def all_tools() -> List[ToolEntry]:
    return [ToolEntry(tool=tool) for tool in tools]


def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
    return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled]


tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]

__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
# __all__.extend(tool_functions)  # type: ignore
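# Illustrative wiring (not part of this commit; the Ollama client and model
# name are assumptions): pass the enabled tool schemas to the LLM, then
# dispatch any returned tool calls to the matching function from this module.
# Async tools (AnalyzeSite, GenerateImage) would need to be awaited.
#
#   import ollama
#   entries = all_tools()
#   response = ollama.chat(
#       model="llama3.2",
#       messages=[{"role": "user", "content": "What time is it in Poland?"}],
#       tools=llm_tools(entries),
#   )
#   for call in response["message"].get("tool_calls", []):
#       fn = globals()[call["function"]["name"]]  # e.g. DateTime
#       print(fn(**call["function"]["arguments"]))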