RAG is being generated (again); however, the LLM is not using it.

This commit is contained in:
James Ketr 2025-05-31 11:31:36 -07:00
parent 4a80004363
commit 77440a9d6b
21 changed files with 2447 additions and 351 deletions
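
The crux of the fix is in the agent's generate() method (diffed below in agents/base.py): RAG results produced by generate_rag_results() are now folded into the message list sent to the LLM as a dedicated system context block, instead of being generated and then dropped. A minimal sketch of the resulting prompt assembly, using the LLMMessage model from this commit:

# Sketch of the prompt assembly introduced by this commit (see agents/base.py below)
messages = [LLMMessage(role="system", content=self.system_prompt)]
# Prior conversation turns, pruned to sender/content pairs
messages.extend(
    LLMMessage(role=m.sender, content=m.content.strip())
    for m in self.conversation
)
if rag_context:
    # The gathered RAG context now actually reaches the LLM as its own system message
    messages.append(
        LLMMessage(
            role="system",
            content=f"<|context|>\n{rag_context.strip()}\n</|context|>",
        )
    )
# Only the actual user query follows the context block
messages.append(
    LLMMessage(role=user_message.sender, content=user_message.content.strip())
)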

View File

@ -15,7 +15,7 @@ import { BackstoryElementProps } from './BackstoryTab';
import { connectionBase } from 'utils/Global';
import { useAuth } from "hooks/AuthContext";
import { StreamingResponse } from 'services/api-client';
import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery } from 'types/types';
import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery, ChatMessageUser } from 'types/types';
import { PaginatedResponse } from 'types/conversion';
import './Conversation.css';
@ -259,7 +259,17 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
{ ...defaultMessage, content: 'Submitting request...' }
);
controllerRef.current = apiClient.sendMessageStream(sessionId, query, {
const chatMessage: ChatMessageUser = {
sessionId: chatSession.id,
content: query.prompt,
tunables: query.tunables,
status: "done",
type: "user",
sender: "user",
timestamp: new Date()
};
controllerRef.current = apiClient.sendMessageStream(chatMessage, {
onMessage: (msg: ChatMessageBase) => {
console.log("onMessage:", msg);
if (msg.type === "response") {

View File

@ -32,7 +32,7 @@ import { SetSnackType } from './Snack';
import { CopyBubble } from './CopyBubble';
import { Scrollable } from './Scrollable';
import { BackstoryElementProps } from './BackstoryTab';
import { ChatMessage, ChatSession, ChatMessageType } from 'types/types';
import { ChatMessage, ChatSession, ChatMessageType, ChatMessageMetaData, ChromaDBGetResponse } from 'types/types';
const getStyle = (theme: Theme, type: ChatMessageType): any => {
const defaultRadius = '16px';
@ -184,19 +184,19 @@ interface MessageProps extends BackstoryElementProps {
};
interface MessageMetaProps {
metadata: Record<string, any>,
metadata: ChatMessageMetaData,
messageProps: MessageProps
};
const MessageMeta = (props: MessageMetaProps) => {
const {
/* MessageData */
rag,
tools,
eval_count,
eval_duration,
prompt_eval_count,
prompt_eval_duration,
ragResults = [],
tools = null,
evalCount = 0,
evalDuration = 0,
promptEvalCount = 0,
promptEvalDuration = 0,
} = props.metadata || {};
const message: any = props.messageProps.message;
@ -206,7 +206,7 @@ const MessageMeta = (props: MessageMetaProps) => {
return (<>
{
prompt_eval_duration !== 0 && eval_duration !== 0 && <>
promptEvalDuration !== 0 && evalDuration !== 0 && <>
<TableContainer component={Card} className="PromptStats" sx={{ mb: 1 }}>
<Table aria-label="prompt stats" size="small">
<TableHead>
@ -220,21 +220,21 @@ const MessageMeta = (props: MessageMetaProps) => {
<TableBody>
<TableRow key="prompt" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
<TableCell component="th" scope="row">Prompt</TableCell>
<TableCell align="right">{prompt_eval_count}</TableCell>
<TableCell align="right">{Math.round(prompt_eval_duration / 10 ** 7) / 100}</TableCell>
<TableCell align="right">{Math.round(prompt_eval_count * 10 ** 9 / prompt_eval_duration)}</TableCell>
<TableCell align="right">{promptEvalCount}</TableCell>
<TableCell align="right">{Math.round(promptEvalDuration / 10 ** 7) / 100}</TableCell>
<TableCell align="right">{Math.round(promptEvalCount * 10 ** 9 / promptEvalDuration)}</TableCell>
</TableRow>
<TableRow key="response" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
<TableCell component="th" scope="row">Response</TableCell>
<TableCell align="right">{eval_count}</TableCell>
<TableCell align="right">{Math.round(eval_duration / 10 ** 7) / 100}</TableCell>
<TableCell align="right">{Math.round(eval_count * 10 ** 9 / eval_duration)}</TableCell>
<TableCell align="right">{evalCount}</TableCell>
<TableCell align="right">{Math.round(evalDuration / 10 ** 7) / 100}</TableCell>
<TableCell align="right">{Math.round(evalCount * 10 ** 9 / evalDuration)}</TableCell>
</TableRow>
<TableRow key="total" sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
<TableCell component="th" scope="row">Total</TableCell>
<TableCell align="right">{prompt_eval_count + eval_count}</TableCell>
<TableCell align="right">{Math.round((prompt_eval_duration + eval_duration) / 10 ** 7) / 100}</TableCell>
<TableCell align="right">{Math.round((prompt_eval_count + eval_count) * 10 ** 9 / (prompt_eval_duration + eval_duration))}</TableCell>
<TableCell align="right">{promptEvalCount + evalCount}</TableCell>
<TableCell align="right">{Math.round((promptEvalDuration + evalDuration) / 10 ** 7) / 100}</TableCell>
<TableCell align="right">{Math.round((promptEvalCount + evalCount) * 10 ** 9 / (promptEvalDuration + evalDuration))}</TableCell>
</TableRow>
</TableBody>
</Table>
@ -278,11 +278,11 @@ const MessageMeta = (props: MessageMetaProps) => {
</Accordion>
}
{
rag.map((collection: any) => (
ragResults.map((collection: ChromaDBGetResponse) => (
<Accordion key={collection.name}>
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
<Box sx={{ fontSize: "0.8rem" }}>
Top {collection.ids.length} RAG matches from {collection.size} entries using an embedding vector of {collection.query_embedding.length} dimensions
Top {collection.ids?.length} RAG matches from {collection.size} entries using an embedding vector of {collection.queryEmbedding?.length} dimensions
</Box>
</AccordionSummary>
<AccordionDetails>

View File

@ -26,7 +26,7 @@ import {
Chat as ChatIcon
} from '@mui/icons-material';
import { useAuth } from 'hooks/AuthContext';
import { ChatMessageBase, ChatMessage, ChatSession, ChatStatusType } from 'types/types';
import { ChatMessageBase, ChatMessage, ChatSession, ChatStatusType, ChatMessageType, ChatMessageUser } from 'types/types';
import { ConversationHandle } from 'components/Conversation';
import { BackstoryPageProps } from 'components/BackstoryTab';
import { Message } from 'components/Message';
@ -35,16 +35,23 @@ import { CandidateSessionsResponse } from 'services/api-client';
import { CandidateInfo } from 'components/CandidateInfo';
import { useNavigate } from 'react-router-dom';
import { useSelectedCandidate } from 'hooks/GlobalContext';
import PropagateLoader from 'react-spinners/PropagateLoader';
const DRAWER_WIDTH = 300;
const HANDLE_WIDTH = 48;
const defaultMessage: ChatMessage = {
type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
};
const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
const { apiClient } = useAuth();
const { selectedCandidate } = useSelectedCandidate()
const navigate = useNavigate();
const theme = useTheme();
const isMdUp = useMediaQuery(theme.breakpoints.up('md'));
const [processingMessage, setProcessingMessage] = useState<ChatMessage | null>(null);
const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
const {
setSnack,
@ -53,10 +60,10 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
const [sessions, setSessions] = useState<CandidateSessionsResponse | null>(null);
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
const [messages, setMessages] = useState([]);
const [newMessage, setNewMessage] = useState('');
const [loading, setLoading] = useState(false);
const [streaming, setStreaming] = useState(false);
const [messages, setMessages] = useState<ChatMessage[]>([]);
const [newMessage, setNewMessage] = useState<string>('');
const [loading, setLoading] = useState<boolean>(false);
const [streaming, setStreaming] = useState<boolean>(false);
const messagesEndRef = useRef(null);
// Drawer state - defaults to open on md+ screens or when no session is selected
@ -88,7 +95,11 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
try {
const result = await apiClient.getChatMessages(chatSession.id);
setMessages(result.data as any);
const chatMessages: ChatMessage[] = result.data;
setMessages(chatMessages);
setProcessingMessage(null);
setStreamingMessage(null);
console.log(`getChatMessages returned ${chatMessages.length} messages.`, chatMessages);
} catch (error) {
console.error('Failed to load messages:', error);
}
@ -106,6 +117,8 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
);
setChatSession(newSession);
setMessages([]);
setProcessingMessage(null);
setStreamingMessage(null);
await loadSessions(); // Refresh sessions list
} catch (error) {
console.error('Failed to create session:', error);
@ -122,35 +135,56 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
setNewMessage('');
setStreaming(true);
const chatMessage: ChatMessageUser = {
sessionId: chatSession.id,
content: messageContent,
status: "done",
type: "user",
sender: "user",
timestamp: new Date()
};
setMessages(prev => {
const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
return [...filtered, chatMessage] as any;
});
try {
await apiClient.sendMessageStream(
chatSession.id,
{ prompt: messageContent }, {
await apiClient.sendMessageStream(chatMessage, {
onMessage: (msg: ChatMessage) => {
console.log("onMessage:", msg);
console.log(`onMessage: ${msg.type} ${msg.content}`, msg);
if (msg.type === "response") {
setMessages(prev => {
const filtered = prev.filter((m: any) => m.id !== msg.id);
return [...filtered, msg].sort((a, b) =>
new Date(a.timestamp).getTime() - new Date(b.timestamp).getTime()
) as any;
});
setMessages(prev => {
const filtered = prev.filter((m: any) => m.id !== msg.id);
return [...filtered, msg] as any;
});
setStreamingMessage(null);
setProcessingMessage(null);
} else {
console.log(msg);
setProcessingMessage(msg);
}
},
onError: (error: string | ChatMessageBase) => {
console.log("onError:", error);
// Type-guard to determine if this is a ChatMessageBase or a string
if (typeof error === "object" && error !== null && "content" in error) {
setProcessingMessage(error as ChatMessage);
} else {
setProcessingMessage({ ...defaultMessage, content: error as string });
}
setStreaming(false);
},
onStreaming: (chunk: ChatMessageBase) => {
console.log("onStreaming:", chunk);
// console.log("onStreaming:", chunk);
setStreamingMessage({ ...defaultMessage, ...chunk });
},
onStatusChange: (status: ChatStatusType) => {
console.log("onStatusChange:", status);
onStatusChange: (status: string) => {
console.log(`onStatusChange: ${status}`);
},
onComplete: () => {
console.log("onComplete");
setStreamingMessage(null);
setProcessingMessage(null);
setStreaming(false);
}
});
@ -321,27 +355,37 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
>
{chatSession?.id ? (
<>
{/* Messages Area */}
<Box sx={{ flexGrow: 1, overflow: 'auto', p: 2 }}>
{messages.map((message: ChatMessageBase) => (
<Message key={message.id} {...{ chatSession, message, setSnack, submitQuery }} />
))}
{streaming && (
<Box sx={{ display: 'flex', alignItems: 'center', mb: 2 }}>
<Avatar sx={{ mr: 1, bgcolor: 'primary.main' }}>
🤖
</Avatar>
<Card>
<CardContent sx={{ display: 'flex', alignItems: 'center', p: 2 }}>
<CircularProgress size={16} sx={{ mr: 1 }} />
<Typography variant="body2">AI is typing...</Typography>
</CardContent>
</Card>
</Box>
)}
{
messages.map((message: ChatMessageBase) => (
<Message key={message.id} {...{ chatSession, message, setSnack, submitQuery }} />
))
}
{
processingMessage !== null &&
<Message {...{ chatSession, message: processingMessage, setSnack, submitQuery }} />
}
{
streamingMessage !== null &&
<Message {...{ chatSession, message: streamingMessage, setSnack, submitQuery }} />
}
{streaming && <Box sx={{
display: "flex",
flexDirection: "column",
alignItems: "center",
justifyContent: "center",
m: 1,
}}>
<PropagateLoader
size="10px"
loading={streaming}
aria-label="Loading Spinner"
data-testid="loader"
/>
</Box>
}
<div ref={messagesEndRef} />
</Box>
<Divider />
@ -354,7 +398,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
placeholder="Type your message about the candidate..."
value={newMessage}
onChange={(e) => setNewMessage(e.target.value)}
onKeyPress={(e) => {
onKeyDown={(e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendMessage();
@ -384,10 +428,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
flexDirection: 'column',
gap: 2
}}
>
<Typography variant="h1" sx={{ fontSize: 64, color: 'text.secondary' }}>
🤖
</Typography>
>
<Typography variant="h6" color="text.secondary">
Select a session to start chatting
</Typography>

View File

@ -19,10 +19,11 @@ import { StyledMarkdown } from 'components/StyledMarkdown';
import { Scrollable } from '../components/Scrollable';
import { Pulse } from 'components/Pulse';
import { StreamingResponse } from 'services/api-client';
import { ChatContext, ChatMessage, ChatMessageBase, ChatSession, ChatQuery } from 'types/types';
import { ChatContext, ChatMessage, ChatMessageUser, ChatMessageBase, ChatSession, ChatQuery } from 'types/types';
import { useAuth } from 'hooks/AuthContext';
const emptyUser: Candidate = {
userType: "candidate",
description: "[blank]",
username: "[blank]",
firstName: "[blank]",
@ -101,7 +102,17 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
setCanGenImage(false);
setShouldGenerateProfile(false); // Reset the flag
const streamResponse = apiClient.sendMessageStream(sessionId, query, {
const chatMessage: ChatMessageUser = {
sessionId: chatSession.id,
content: query.prompt,
tunables: query.tunables,
status: "done",
type: "user",
sender: "user",
timestamp: new Date()
};
const streamResponse = apiClient.sendMessageStream(chatMessage, {
onMessage: (chatMessage: ChatMessage) => {
console.log('Message:', chatMessage);
// Update UI with partial content

View File

@ -505,25 +505,11 @@ class ApiClient {
return this.handleApiResponseWithConversion<Types.ChatSession>(response, 'ChatSession');
}
/**
* Send message with standard response (non-streaming)
*/
async sendMessage(sessionId: string, query: Types.ChatQuery): Promise<Types.ChatMessage> {
const response = await fetch(`${this.baseUrl}/chat/sessions/${sessionId}/messages`, {
method: 'POST',
headers: this.defaultHeaders,
body: JSON.stringify(formatApiRequest({query}))
});
return this.handleApiResponseWithConversion<Types.ChatMessage>(response, 'ChatMessage');
}
/**
* Send message with streaming response support and date conversion
*/
sendMessageStream(
sessionId: string,
query: Types.ChatQuery,
chatMessage: Types.ChatMessageUser,
options: StreamingOptions = {}
): StreamingResponse {
const abortController = new AbortController();
@ -533,14 +519,14 @@ class ApiClient {
const promise = new Promise<Types.ChatMessage[]>(async (resolve, reject) => {
try {
const response = await fetch(`${this.baseUrl}/chat/sessions/${sessionId}/messages/stream`, {
const response = await fetch(`${this.baseUrl}/chat/sessions/${chatMessage.sessionId}/messages/stream`, {
method: 'POST',
headers: {
...this.defaultHeaders,
'Accept': 'text/event-stream',
'Cache-Control': 'no-cache'
},
body: JSON.stringify(formatApiRequest({ query })),
body: JSON.stringify(formatApiRequest({ chatMessage })),
signal
});
@ -555,8 +541,8 @@ class ApiClient {
const decoder = new TextDecoder();
let buffer = '';
let chatMessage: Types.ChatMessage | null = null;
const chatMessageList: Types.ChatMessage[] = [];
let incomingMessage: Types.ChatMessage | null = null;
const incomingMessageList: Types.ChatMessage[] = [];
try {
while (true) {
@ -585,20 +571,20 @@ class ApiClient {
const convertedIncoming = convertChatMessageFromApi(incoming);
// Trigger callbacks based on status
if (convertedIncoming.status !== chatMessage?.status) {
if (convertedIncoming.status !== incomingMessage?.status) {
options.onStatusChange?.(convertedIncoming.status);
}
// Handle different status types
switch (convertedIncoming.status) {
case 'streaming':
if (chatMessage === null) {
chatMessage = {...convertedIncoming};
if (incomingMessage === null) {
incomingMessage = {...convertedIncoming};
} else {
// Can't do a simple += as TypeScript thinks .content might be undefined
chatMessage.content = (chatMessage?.content || '') + convertedIncoming.content;
incomingMessage.content = (incomingMessage?.content || '') + convertedIncoming.content;
// Update timestamp to latest
chatMessage.timestamp = convertedIncoming.timestamp;
incomingMessage.timestamp = convertedIncoming.timestamp;
}
options.onStreaming?.(convertedIncoming);
break;
@ -608,7 +594,7 @@ class ApiClient {
break;
default:
chatMessageList.push(convertedIncoming);
incomingMessageList.push(convertedIncoming);
options.onMessage?.(convertedIncoming);
break;
}
@ -627,7 +613,7 @@ class ApiClient {
}
options.onComplete?.();
resolve(chatMessageList);
resolve(incomingMessageList);
} catch (error) {
if (signal.aborted) {
options.onComplete?.();
@ -647,24 +633,6 @@ class ApiClient {
};
}
/**
* Send message with automatic streaming detection
*/
async sendMessageAuto(
sessionId: string,
query: Types.ChatQuery,
options?: StreamingOptions
): Promise<Types.ChatMessage[]> {
// If streaming options are provided, use streaming
if (options && (options.onMessage || options.onStreaming || options.onStatusChange)) {
const streamResponse = this.sendMessageStream(sessionId, query, options);
return streamResponse.promise;
}
// Otherwise, use standard response
return [await this.sendMessage(sessionId, query)];
}
/**
* Get persisted chat messages for a session with date conversion
*/

View File

@ -1,6 +1,6 @@
// Generated TypeScript types from Pydantic models
// Source: src/backend/models.py
// Generated on: 2025-05-30T18:07:14.923475
// Generated on: 2025-05-31T18:20:52.253576
// DO NOT EDIT MANUALLY - This file is auto-generated
// ============================
@ -176,7 +176,7 @@ export interface Candidate {
lastLogin?: Date;
profileImage?: string;
status: "active" | "inactive" | "pending" | "banned";
userType?: "candidate";
userType: "candidate";
username: string;
description?: string;
resume?: string;
@ -192,6 +192,8 @@ export interface Candidate {
certifications?: Array<Certification>;
jobApplications?: Array<JobApplication>;
hasProfile?: boolean;
rags?: Array<RagEntry>;
ragContentSize?: number;
age?: number;
gender?: "female" | "male";
ethnicity?: string;
@ -246,6 +248,7 @@ export interface ChatMessage {
type: "error" | "generating" | "info" | "preparing" | "processing" | "response" | "searching" | "system" | "thinking" | "tooling" | "user";
sender: "user" | "assistant" | "system";
timestamp: Date;
tunables?: Tunables;
content?: string;
metadata?: ChatMessageMetaData;
}
@ -258,19 +261,20 @@ export interface ChatMessageBase {
type: "error" | "generating" | "info" | "preparing" | "processing" | "response" | "searching" | "system" | "thinking" | "tooling" | "user";
sender: "user" | "assistant" | "system";
timestamp: Date;
tunables?: Tunables;
content?: string;
}
export interface ChatMessageMetaData {
model?: "qwen2.5" | "flux-schnell";
model: "qwen2.5";
temperature?: number;
maxTokens?: number;
topP?: number;
frequencyPenalty?: number;
presencePenalty?: number;
stopSequences?: Array<string>;
tunables?: Tunables;
rag?: Array<ChromaDBGetResponse>;
ragResults?: Array<ChromaDBGetResponse>;
llmHistory?: Array<LLMMessage>;
evalCount?: number;
evalDuration?: number;
promptEvalCount?: number;
@ -284,10 +288,11 @@ export interface ChatMessageUser {
id?: string;
sessionId: string;
senderId?: string;
status: "initializing" | "streaming" | "done" | "error";
type?: "error" | "generating" | "info" | "preparing" | "processing" | "response" | "searching" | "system" | "thinking" | "tooling" | "user";
sender: "user" | "assistant" | "system";
status: "done";
type: "user";
sender: "user";
timestamp: Date;
tunables?: Tunables;
content?: string;
}
@ -386,7 +391,7 @@ export interface Employer {
lastLogin?: Date;
profileImage?: string;
status: "active" | "inactive" | "pending" | "banned";
userType?: "employer";
userType: "employer";
companyName: string;
industry: string;
description?: string;
@ -507,6 +512,12 @@ export interface JobResponse {
meta?: Record<string, any>;
}
export interface LLMMessage {
role?: string;
content?: string;
toolCalls?: Array<Record<string, any>>;
}
export interface Language {
language: string;
proficiency: "basic" | "conversational" | "fluent" | "native";

View File

@ -1,4 +1,5 @@
from __future__ import annotations
import traceback
from pydantic import BaseModel, Field # type: ignore
from typing import (
Literal,
@ -18,18 +19,18 @@ import pathlib
import inspect
from prometheus_client import CollectorRegistry # type: ignore
from database import RedisDatabase
from . base import Agent
from logger import logger
from models import Candidate
_agents: List[Agent] = []
def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry, database: RedisDatabase, **kwargs) -> Agent:
def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[Candidate]=None, **kwargs) -> Agent:
"""
Get an existing agent of the specified type, or create and register a new one, ensuring only one agent per type exists.
Args:
agent_type: The type of agent to create (e.g., 'web', 'database').
agent_type: The type of agent to create (e.g., 'general', 'candidate_chat', 'image_generation').
**kwargs: Additional fields required by the specific agent subclass.
Returns:
@ -47,7 +48,7 @@ def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry
for agent_cls in Agent.__subclasses__():
if agent_cls.model_fields["agent_type"].default == agent_type:
# Create the agent instance with provided kwargs
agent = agent_cls(agent_type=agent_type, prometheus_collector=prometheus_collector, database=database, **kwargs)
agent = agent_cls(agent_type=agent_type, user=user, prometheus_collector=prometheus_collector, **kwargs)
# if agent.agent_persist: # If an agent is not set to persist, do not add it to the list
_agents.append(agent)
return agent
@ -90,8 +91,10 @@ for path in package_dir.glob("*.py"):
logger.info(f"Adding agent: {name}")
__all__.append(name) # type: ignore
except ImportError as e:
logger.error(traceback.format_exc())
logger.error(f"Error importing {full_module_name}: {e}")
raise e
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Error processing {full_module_name}: {e}")
raise e

View File

@ -1,4 +1,5 @@
from __future__ import annotations
import traceback
from pydantic import BaseModel, Field, model_validator # type: ignore
from typing import (
Literal,
@ -20,19 +21,16 @@ from abc import ABC
import asyncio
from datetime import datetime, UTC
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
import numpy as np # type: ignore
from models import ( ChatQuery, ChatMessage, ChatOptions, ChatMessageBase, ChatMessageUser, Tunables, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData)
from models import ( LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageBase, ChatMessageUser, Tunables, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData, Candidate)
from logger import logger
import defines
from .registry import agent_registry
from metrics import Metrics
from database import RedisDatabase # type: ignore
import model_cast
class LLMMessage(BaseModel):
role: str = Field(default="")
content: str = Field(default="")
tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
from rag import ( ChromaDBGetResponse )
class Agent(BaseModel, ABC):
"""
@ -45,12 +43,8 @@ class Agent(BaseModel, ABC):
# Agent management with pydantic
agent_type: Literal["base"] = "base"
_agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = True # Whether this agent will persist in the database
database: RedisDatabase = Field(
...,
description="Database connection for this agent, used to store and retrieve data."
)
user: Optional[Candidate] = None
prometheus_collector: CollectorRegistry = Field(..., description="Prometheus collector for this agent, used to track metrics.", exclude=True)
# Tunables (sets default for new Messages attached to this agent)
@ -131,48 +125,6 @@ class Agent(BaseModel, ABC):
def get_agent_type(self):
return self._agent_type
# async def prepare_message(self, message: ChatMessage) -> AsyncGenerator[ChatMessage, None]:
# """
# Prepare message with context information in message.preamble
# """
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# self.metrics.prepare_count.labels(agent=self.agent_type).inc()
# with self.metrics.prepare_duration.labels(agent=self.agent_type).time():
# if not self.context:
# raise ValueError("Context is not set for this agent.")
# # Generate RAG content if enabled, based on the content
# rag_context = ""
# if message.tunables.enable_rag and message.prompt:
# # Gather RAG results, yielding each result
# # as it becomes available
# for message in self.context.user.generate_rag_results(message):
# logger.info(f"RAG: {message.status} - {message.content}")
# if message.status == "error":
# yield message
# return
# if message.status != "done":
# yield message
# # for rag in message.metadata.rag:
# # for doc in rag.documents:
# # rag_context += f"{doc}\n"
# message.preamble = {}
# if rag_context:
# message.preamble["context"] = f"The following is context information about {self.context.user.full_name}:\n{rag_context}"
# if message.tunables.enable_context and self.context.user_resume:
# message.preamble["resume"] = self.context.user_resume
# message.system_prompt = self.system_prompt
# message.status = ChatStatusType.DONE
# yield message
# return
# async def process_tool_calls(
# self,
# llm: Any,
@ -341,46 +293,168 @@ class Agent(BaseModel, ABC):
)
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
async def generate_rag_results(
self,
chat_message: ChatMessage,
top_k: int=defines.default_rag_top_k,
threshold: float=defines.default_rag_threshold
) -> AsyncGenerator[ChatMessage, None]:
"""
Generate RAG results for the given query.
Args:
query: The query string to generate RAG results for.
Returns:
A list of dictionaries containing the RAG results.
"""
rag_message = ChatMessage(
session_id=chat_message.session_id,
tunables=chat_message.tunables,
status=ChatStatusType.INITIALIZING,
type=ChatMessageType.PREPARING,
sender=ChatSenderType.ASSISTANT,
content="",
timestamp=datetime.now(UTC),
metadata=ChatMessageMetaData()
)
if not self.user:
rag_message.status = ChatStatusType.DONE
rag_message.content = "No user connected to this chat, so no RAG content."
yield rag_message
return
try:
entries: int = 0
user: Candidate = self.user
for rag in user.rags:
if not rag.enabled:
continue
rag_message.type = ChatMessageType.SEARCHING
rag_message.status = ChatStatusType.INITIALIZING
rag_message.content = f"Checking RAG context {rag.name}..."
yield rag_message
chroma_results = user.file_watcher.find_similar(
query=chat_message.content, top_k=top_k, threshold=threshold
)
if chroma_results:
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]
rag_metadata = ChromaDBGetResponse(
query=chat_message.content,
query_embedding=query_embedding.tolist(),
name=rag.name,
ids=chroma_results.get("ids", []),
embeddings=chroma_results.get("embeddings", []),
documents=chroma_results.get("documents", []),
metadatas=chroma_results.get("metadatas", []),
umap_embedding_2d=umap_2d.tolist(),
umap_embedding_3d=umap_3d.tolist(),
size=user.file_watcher.collection.count()
)
rag_message.metadata.rag_results.append(rag_metadata)
entries += len(chroma_results["documents"])
rag_message.content = f"Results from {rag.name} RAG: {len(chroma_results['documents'])} results."
yield rag_message
rag_message.content = (
f"RAG context gathered from {entries} documents."
)
rag_message.status = ChatStatusType.DONE
yield rag_message
return
except Exception as e:
rag_message.status = ChatStatusType.ERROR
rag_message.content = f"Error generating RAG results: {str(e)}"
logger.error(traceback.format_exc())
logger.error(rag_message.content)
yield rag_message
return
async def generate(
self, llm: Any, model: str, query: ChatQuery, user_message: ChatMessageUser, user_id: str, temperature=0.7
self, llm: Any, model: str, user_message: ChatMessageUser, user_id: str, temperature=0.7
) -> AsyncGenerator[ChatMessage | ChatMessageBase, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
chat_message = ChatMessage(
session_id=user_message.session_id,
tunables=query.tunables,
tunables=user_message.tunables,
status=ChatStatusType.INITIALIZING,
type=ChatMessageType.PREPARING,
sender=ChatSenderType.ASSISTANT,
content="",
timestamp=datetime.now(UTC)
)
chat_message.metadata = ChatMessageMetaData()
chat_message.metadata.options = ChatOptions(
seed=8911,
num_ctx=self.context_size,
temperature=temperature, # Higher temperature to encourage tool usage
)
# Create a dict for storing various timing stats
chat_message.metadata.timers = {}
self.metrics.generate_count.labels(agent=self.agent_type).inc()
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
rag_message : Optional[ChatMessage] = None
async for rag_message in self.generate_rag_results(chat_message=user_message):
if rag_message.status == ChatStatusType.ERROR:
chat_message.status = rag_message.status
chat_message.content = rag_message.content
yield chat_message
return
yield rag_message
rag_context = ""
if rag_message:
rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results
chat_message.metadata.rag_results = rag_results
for chroma_results in rag_results:
for index, metadata in enumerate(chroma_results.metadatas):
content = "\n".join([
line.strip()
for line in chroma_results.documents[index].split("\n")
if line
]).strip()
rag_context += f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: { content }
"""
# Create a pruned down message list based purely on the prompt and responses,
# discarding the full preamble generated by prepare_message
messages: List[LLMMessage] = [
LLMMessage(role="system", content=self.system_prompt)
]
# Add the conversation history to the messages
messages.extend([
LLMMessage(role=m.sender, content=m.content.strip())
for m in self.conversation
])
# Add the RAG context to the messages if available
if rag_context:
messages.append(
LLMMessage(
role="system",
content=f"<|context|>\n{rag_context.strip()}\n</|context|>"
)
)
# Only the actual user query is provided with the full context message
messages.append(
LLMMessage(role=user_message.sender, content=user_message.content.strip())
)
# message.messages = messages
chat_message.metadata = ChatMessageMetaData()
chat_message.metadata.options = ChatOptions(
seed=8911,
num_ctx=self.context_size,
temperature=temperature, # Higher temperature to encourage tool usage
)
# Create a dict for storing various timing stats
chat_message.metadata.timers = {}
chat_message.metadata.llm_history = messages
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
# message.metadata.tools = {
@ -497,7 +571,7 @@ class Agent(BaseModel, ABC):
model=model,
messages=messages,
options={
**chat_message.metadata.model_dump(exclude_unset=True),
**chat_message.metadata.options.model_dump(exclude_unset=True),
},
stream=True,
):

View File

@ -0,0 +1,3 @@
from .candidate_entity import CandidateEntity
from .entity_manager import entity_manager, get_candidate_entity

View File

@ -0,0 +1,226 @@
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
from typing_extensions import Annotated, Union
import numpy as np # type: ignore
from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore
import traceback
import os
import json
import re
from pathlib import Path
from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
import defines
from logger import logger
import agents as agents
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageType, ChatMessageMetaData, ChatStatusType, Candidate, ChatContextType)
from llm_manager import llm_manager
class CandidateEntity(Candidate):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
# Internal instance members
CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
default=None, exclude=True
)
def __init__(self, candidate=None):
if candidate is not None:
# Copy attributes from the candidate instance
super().__init__(**vars(candidate))
else:
super().__init__()
@classmethod
def exists(cls, username: str):
# Validate username format (only allow safe characters)
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
return False # Invalid username characters
# Check for minimum and maximum length
if not (3 <= len(username) <= 32):
return False # Invalid username length
# Use Path for safe path handling and normalization
user_dir = Path(defines.user_dir) / username
user_info_path = user_dir / defines.user_info_file
# Ensure the final path is actually within the intended parent directory
# to help prevent directory traversal attacks
try:
if not user_dir.resolve().is_relative_to(Path(defines.user_dir).resolve()):
return False # Path traversal attempt detected
except (ValueError, RuntimeError): # Potential exceptions from resolve()
return False
# Check if file exists
return user_info_path.is_file()
def get_or_create_agent(self, agent_type: ChatContextType, **kwargs) -> agents.Agent:
"""
Get or create an agent of the specified type for this candidate.
Args:
agent_type: The type of agent to create (default is 'candidate_chat').
**kwargs: Additional fields required by the specific agent subclass.
Returns:
The created agent instance.
"""
return agents.get_or_create_agent(
agent_type=agent_type,
user=self,
prometheus_collector=self.prometheus_collector,
**kwargs)
# Wrapper properties that map into file_watcher
@property
def umap_collection(self) -> ChromaDBGetResponse:
if not self.CandidateEntity__file_watcher:
raise ValueError("initialize() has not been called.")
return self.CandidateEntity__file_watcher.umap_collection
# Fields managed by initialize()
CandidateEntity__initialized: bool = Field(default=False, exclude=True)
@property
def file_watcher(self) -> ChromaDBFileWatcher:
if not self.CandidateEntity__file_watcher:
raise ValueError("initialize() has not been called.")
return self.CandidateEntity__file_watcher
@property
def prometheus_collector(self) -> CollectorRegistry:
if not self.CandidateEntity__prometheus_collector:
raise ValueError("initialize() has not been called with a prometheus_collector.")
return self.CandidateEntity__prometheus_collector
@property
def observer(self) -> Any:
if not self.CandidateEntity__observer:
raise ValueError("initialize() has not been called.")
return self.CandidateEntity__observer
@classmethod
def sanitize(cls, user: Dict[str, Any]):
sanitized : Dict[str, Any] = {}
sanitized["username"] = user.get("username", "default")
sanitized["first_name"] = user.get("first_name", sanitized["username"])
sanitized["last_name"] = user.get("last_name", "")
sanitized["title"] = user.get("title", "")
sanitized["phone"] = user.get("phone", "")
sanitized["location"] = user.get("location", "")
sanitized["email"] = user.get("email", "")
sanitized["full_name"] = user.get("full_name", f"{sanitized['first_name']} {sanitized['last_name']}")
sanitized["description"] = user.get("description", "")
profile_image = os.path.join(defines.user_dir, sanitized["username"], "profile.png")
sanitized["has_profile"] = os.path.exists(profile_image)
contact_info = user.get("contact_info", {})
sanitized["contact_info"] = {}
for key in contact_info:
if not isinstance(contact_info[key], (str, int, float, complex)):
continue
sanitized["contact_info"][key] = contact_info[key]
questions = user.get("questions", [ f"Tell me about {sanitized['first_name']}.", f"What are {sanitized['first_name']}'s professional strengths?"])
sanitized["user_questions"] = []
for question in questions:
if type(question) == str:
sanitized["user_questions"].append({"question": question})
else:
try:
tmp = CandidateQuestion.model_validate(question)
sanitized["user_questions"].append({"question": tmp.question})
except Exception as e:
continue
return sanitized
@classmethod
def get_users(cls):
# Initialize an empty list to store parsed JSON data
user_data = []
# Define the users directory path
users_dir = os.path.join(defines.user_dir)
# Check if the users directory exists
if not os.path.exists(users_dir):
return user_data
# Iterate through all items in the users directory
for item in os.listdir(users_dir):
# Construct the full path to the item
item_path = os.path.join(users_dir, item)
# Check if the item is a directory
if os.path.isdir(item_path):
# Construct the path to info.json
info_path = os.path.join(item_path, "info.json")
# Check if info.json exists
if os.path.exists(info_path):
try:
# Read and parse the JSON file
with open(info_path, 'r') as file:
data = json.load(file)
data["username"] = item
profile_image = os.path.join(defines.user_dir, item, "profile.png")
data["has_profile"] = os.path.exists(profile_image)
user_data.append(data)
except json.JSONDecodeError as e:
# Skip files that aren't valid JSON
logger.info(f"Invalid JSON for {info_path}: {str(e)}")
continue
except Exception as e:
# Skip files that can't be read
logger.info(f"Exception processing {info_path}: {str(e)}")
continue
return user_data
async def initialize(self, prometheus_collector: CollectorRegistry):
if self.CandidateEntity__initialized:
# Initialization can only be attempted once; if there are multiple attempts, it means
# a subsystem is failing or there is a logic bug in the code.
#
# NOTE: It is intentional that self.CandidateEntity__initialized = True regardless of whether
# initialization succeeded. This prevents server loops on failure.
raise ValueError("initialize can only be attempted once")
self.CandidateEntity__initialized = True
if not self.username:
raise ValueError("username can not be empty")
user_dir = os.path.join(defines.user_dir, self.username)
vector_db_dir=os.path.join(user_dir, defines.persist_directory)
rag_content_dir=os.path.join(user_dir, defines.rag_content_dir)
logger.info(f"CandidateEntity(username={self.username}, user_dir={user_dir}, persist_directory={vector_db_dir}, rag_content_dir={rag_content_dir})")
os.makedirs(vector_db_dir, exist_ok=True)
os.makedirs(rag_content_dir, exist_ok=True)
if prometheus_collector:
self.CandidateEntity__prometheus_collector = prometheus_collector
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
llm=llm_manager.get_llm(),
collection_name=self.username,
persist_directory=vector_db_dir,
watch_directory=rag_content_dir,
recreate=False, # Don't recreate if exists
)
has_username_rag = any(item["name"] == self.username for item in self.rags)
if not has_username_rag:
self.rags.append(RagEntry(
name=self.username,
description=f"Expert data about {self.full_name}.",
))
self.rag_content_size = self.file_watcher.collection.count()
CandidateEntity.model_rebuild()
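
For reference, a hypothetical end-to-end use of the entity; the EntityManager in the next file wraps exactly this flow, and candidate and collector are stand-ins for a validated Candidate and a CollectorRegistry:

# Hypothetical usage sketch (EntityManager below performs this internally)
entity = CandidateEntity(candidate=candidate)  # copies fields from the Candidate
await entity.initialize(prometheus_collector=collector)  # starts the RAG file watcher
agent = entity.get_or_create_agent(agent_type="candidate_chat")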

View File

@ -0,0 +1,122 @@
import asyncio
import weakref
from datetime import datetime, timedelta
from typing import Dict, Optional, Any
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field # type: ignore
from models import ( Candidate )
from .candidate_entity import CandidateEntity
from prometheus_client import CollectorRegistry # type: ignore
class EntityManager:
"""Manages lifecycle of CandidateEntity instances"""
def __init__(self, default_ttl_minutes: int = 30):
self._entities: Dict[str, CandidateEntity] = {}
self._weak_refs: Dict[str, weakref.ReferenceType] = {}
self._ttl_minutes = default_ttl_minutes
self._cleanup_task: Optional[asyncio.Task] = None
self._prometheus_collector: Optional[CollectorRegistry] = None
async def start_cleanup_task(self):
"""Start background cleanup task"""
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
async def stop_cleanup_task(self):
"""Stop background cleanup task"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
def initialize(self, prometheus_collector: CollectorRegistry):
"""Initialize the EntityManager with Prometheus collector"""
self._prometheus_collector = prometheus_collector
async def get_entity(self, candidate: Candidate) -> CandidateEntity:
"""Get or create CandidateEntity with proper reference tracking"""
# Check if entity exists and is still valid
if candidate.id in self._entities:
entity = self._entities[candidate.id]
entity._last_accessed = datetime.now()
entity._reference_count += 1
return entity
entity = CandidateEntity(candidate=candidate)
await entity.initialize(prometheus_collector=self._prometheus_collector)
# Store with reference tracking
self._entities[candidate.id] = entity
self._weak_refs[candidate.id] = weakref.ref(entity, self._on_entity_deleted(candidate.id))
entity._reference_count = 1
entity._last_accessed = datetime.now()
return entity
def _on_entity_deleted(self, user_id: str):
"""Callback when entity is garbage collected"""
def cleanup_callback(weak_ref):
self._entities.pop(user_id, None)
self._weak_refs.pop(user_id, None)
print(f"Entity {user_id} garbage collected")
return cleanup_callback
async def release_entity(self, user_id: str):
"""Explicitly release reference to entity"""
if user_id in self._entities:
entity = self._entities[user_id]
entity._reference_count = max(0, entity._reference_count - 1)
entity._last_accessed = datetime.now()
async def _periodic_cleanup(self):
"""Background task to clean up expired entities"""
while True:
try:
await asyncio.sleep(60) # Check every minute
await self._cleanup_expired_entities()
except asyncio.CancelledError:
break
except Exception as e:
print(f"Error in cleanup task: {e}")
async def _cleanup_expired_entities(self):
"""Remove entities that have expired based on TTL and reference count"""
current_time = datetime.now()
expired_entities = []
for user_id, entity in list(self._entities.items()):
time_since_access = current_time - entity._last_accessed
# Remove if TTL exceeded and no active references
if (time_since_access > timedelta(minutes=self._ttl_minutes)
and entity._reference_count == 0):
expired_entities.append(user_id)
for user_id in expired_entities:
entity = self._entities.pop(user_id, None)
self._weak_refs.pop(user_id, None)
if entity:
await entity.cleanup()
print(f"Cleaned up expired entity {user_id}")
# Global entity manager instance
entity_manager = EntityManager(default_ttl_minutes=30)
@asynccontextmanager
async def get_candidate_entity(candidate: Candidate):
"""Context manager for safe entity access with automatic reference management"""
if not entity_manager._prometheus_collector:
raise ValueError("EntityManager has not been initialized with a Prometheus collector.")
entity = await entity_manager.get_entity(candidate=candidate)
try:
yield entity
finally:
await entity_manager.release_entity(candidate.id)
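
A typical call site might look like the following sketch; the startup wiring (collector, lifespan hook) and candidate are assumptions:

# Startup (e.g., in the FastAPI lifespan) -- hypothetical wiring
entity_manager.initialize(prometheus_collector=collector)
await entity_manager.start_cleanup_task()

# Per-request access with automatic reference counting
async with get_candidate_entity(candidate=candidate) as entity:
    agent = entity.get_or_create_agent(agent_type="candidate_chat")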

View File

@ -2,7 +2,7 @@
"""
Enhanced Type Generator - Generate TypeScript types from Pydantic models
Now with command line parameters, pre-test validation, TypeScript compilation,
and automatic date field conversion functions
automatic date field conversion functions, and proper enum default handling
"""
import sys
@ -138,8 +138,55 @@ def is_date_type(python_type: Any) -> bool:
return False
def python_type_to_typescript(python_type: Any, debug: bool = False) -> str:
"""Convert a Python type to TypeScript type string"""
def get_default_enum_value(field_info: Any, debug: bool = False) -> Optional[Any]:
"""Extract the specific enum value from a field's default, if it exists"""
if not hasattr(field_info, 'default'):
return None
default_val = field_info.default
if debug:
print(f" 🔍 Checking default value: {repr(default_val)} (type: {type(default_val)})")
# Check for different types of "no default" markers
if default_val is ... or default_val is None:
if debug:
print(f" └─ Default is undefined marker")
return None
# Check for Pydantic's internal "PydanticUndefined" or similar markers
default_str = str(default_val)
default_type_str = str(type(default_val))
undefined_patterns = [
'PydanticUndefined',
'Undefined',
'_Unset',
'UNSET',
'NotSet',
'_MISSING'
]
is_undefined_marker = any(pattern in default_str or pattern in default_type_str
for pattern in undefined_patterns)
if is_undefined_marker:
if debug:
print(f" └─ Default is undefined marker pattern")
return None
# Check if it's an enum instance
if isinstance(default_val, Enum):
if debug:
print(f" └─ Default is enum instance: {default_val.value}")
return default_val
if debug:
print(f" └─ Default is not an enum instance")
return None
def python_type_to_typescript(python_type: Any, field_info: Any = None, debug: bool = False) -> str:
"""Convert a Python type to TypeScript type string, considering field defaults"""
if debug:
print(f" 🔍 Converting type: {python_type} (type: {type(python_type)})")
@ -151,6 +198,14 @@ def python_type_to_typescript(python_type: Any, debug: bool = False) -> str:
if debug and original_type != python_type:
print(f" 🔄 Unwrapped: {original_type} -> {python_type}")
# Check if this field has a specific enum default value
if field_info:
default_enum = get_default_enum_value(field_info, debug)
if default_enum is not None:
if debug:
print(f" 🎯 Field has specific enum default: {default_enum.value}")
return f'"{default_enum.value}"'
# Handle None/null
if python_type is type(None):
return "null"
@ -180,22 +235,22 @@ def python_type_to_typescript(python_type: Any, debug: bool = False) -> str:
# Handle Optional (Union[T, None])
if len(args) == 2 and type(None) in args:
non_none_type = next(arg for arg in args if arg is not type(None))
return python_type_to_typescript(non_none_type, debug)
return python_type_to_typescript(non_none_type, field_info, debug)
# Handle other unions
union_types = [python_type_to_typescript(arg, debug) for arg in args if arg is not type(None)]
union_types = [python_type_to_typescript(arg, None, debug) for arg in args if arg is not type(None)]
return " | ".join(union_types)
elif origin is list or origin is List:
if args:
item_type = python_type_to_typescript(args[0], debug)
item_type = python_type_to_typescript(args[0], None, debug)
return f"Array<{item_type}>"
return "Array<any>"
elif origin is dict or origin is Dict:
if len(args) == 2:
key_type = python_type_to_typescript(args[0], debug)
value_type = python_type_to_typescript(args[1], debug)
key_type = python_type_to_typescript(args[0], None, debug)
value_type = python_type_to_typescript(args[1], None, debug)
return f"Record<{key_type}, {value_type}>"
return "Record<string, any>"
@ -319,6 +374,13 @@ def is_field_optional(field_info: Any, field_type: Any, debug: bool = False) ->
if debug:
print(f" └─ RESULT: Required (default is undefined marker)")
return False
# Special case: if field has a specific default value (like enum), it's required
# because it will always have a value, just not optional for the consumer
if isinstance(default_val, Enum):
if debug:
print(f" └─ RESULT: Required (has specific enum default: {default_val.value})")
return False
# Any other actual default value makes it optional
if debug:
@ -398,7 +460,8 @@ def process_pydantic_model(model_class, debug: bool = False) -> Dict[str, Any]:
elif debug and ('date' in str(field_type).lower() or 'time' in str(field_type).lower()):
print(f" ⚠️ Field {ts_name} contains 'date'/'time' but not detected as date type: {field_type}")
ts_type = python_type_to_typescript(field_type, debug)
# Pass field_info to the type converter for default enum handling
ts_type = python_type_to_typescript(field_type, field_info, debug)
# Check if optional
is_optional = is_field_optional(field_info, field_type, debug)
@ -449,30 +512,16 @@ def process_pydantic_model(model_class, debug: bool = False) -> Dict[str, Any]:
elif debug and ('date' in str(field_type).lower() or 'time' in str(field_type).lower()):
print(f" ⚠️ Field {ts_name} contains 'date'/'time' but not detected as date type: {field_type}")
ts_type = python_type_to_typescript(field_type, debug)
# Pass field_info to the type converter for default enum handling
ts_type = python_type_to_typescript(field_type, field_info, debug)
# For Pydantic v1, check required and default
is_optional = is_field_optional(field_info, field_type)
is_optional = is_field_optional(field_info, field_type, debug)
if debug:
print(f" TS name: {ts_name}")
print(f" TS type: {ts_type}")
print(f" Optional: {is_optional}")
# Debug the optional logic
origin = get_origin(field_type)
args = get_args(field_type)
is_union_with_none = origin is Union and type(None) in args
has_default = hasattr(field_info, 'default')
has_default_factory = hasattr(field_info, 'default_factory') and field_info.default_factory is not None
print(f" └─ Type is Optional: {is_union_with_none}")
if has_default:
default_val = field_info.default
print(f" └─ Has default: {default_val} (is ...? {default_val is ...})")
else:
print(f" └─ No default attribute")
print(f" └─ Has default factory: {has_default_factory}")
print()
properties.append({
@ -730,7 +779,7 @@ def compile_typescript(ts_file: str) -> bool:
def main():
"""Main function with command line argument parsing"""
parser = argparse.ArgumentParser(
description='Generate TypeScript types from Pydantic models with date conversion functions',
description='Generate TypeScript types from Pydantic models with date conversion functions and proper enum default handling',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
@ -744,6 +793,9 @@ Examples:
Generated conversion functions can be used like:
const candidate = convertCandidateFromApi(apiResponse);
const jobs = convertArrayFromApi<Job>(apiResponse, 'Job');
Enum defaults are now properly handled:
status: ChatStatusType = ChatStatusType.DONE -> status: "done"
"""
)
@ -780,12 +832,12 @@ Generated conversion functions can be used like:
parser.add_argument(
'--version', '-v',
action='version',
version='TypeScript Generator 3.0 (with Date Conversion)'
version='TypeScript Generator 3.1 (with Enum Default Handling)'
)
args = parser.parse_args()
print("🚀 Enhanced TypeScript Type Generator with Date Conversion")
print("🚀 Enhanced TypeScript Type Generator with Enum Default Handling")
print("=" * 60)
print(f"📁 Source file: {args.source}")
print(f"📁 Output file: {args.output}")
@ -830,25 +882,30 @@ Generated conversion functions can be used like:
# Count conversion functions and provide detailed feedback
conversion_count = ts_content.count('export function convert') - ts_content.count('convertFromApi') - ts_content.count('convertArrayFromApi')
enum_specific_count = ts_content.count(': "') - ts_content.count('export type')
if conversion_count > 0:
print(f"🗓️ Generated {conversion_count} date conversion functions")
if args.debug:
# Show which models have date conversion
models_with_dates = []
for line in ts_content.split('\n'):
if line.startswith('export function convert') and 'FromApi' in line and 'convertFromApi' not in line:
model_name = line.split('convert')[1].split('FromApi')[0]
models_with_dates.append(model_name)
if models_with_dates:
print(f" Models with date conversion: {', '.join(models_with_dates)}")
if enum_specific_count > 0:
print(f"🎯 Generated {enum_specific_count} specific enum default types")
if args.debug:
# Show which models have date conversion
models_with_dates = []
for line in ts_content.split('\n'):
if line.startswith('export function convert') and 'FromApi' in line and 'convertFromApi' not in line:
model_name = line.split('convert')[1].split('FromApi')[0]
models_with_dates.append(model_name)
if models_with_dates:
print(f" Models with date conversion: {', '.join(models_with_dates)}")
# Provide troubleshooting info if debug mode
if args.debug:
print(f"\n🐛 Debug mode was enabled. If you see incorrect date conversions:")
print(f" 1. Check the debug output above for '📅 Date type check' lines")
print(f" 2. Look for '⚠️' warnings about false positives")
print(f" 3. Verify your Pydantic model field types are correct")
print(f" 4. Re-run with --debug to see detailed type analysis")
print(f"\n🐛 Debug mode was enabled. If you see incorrect type conversions:")
print(f" 1. Check the debug output above for '🎯 Field has specific enum default' lines")
print(f" 2. Look for '📅 Date type check' lines for date handling")
print(f" 3. Look for '⚠️' warnings about fallback types")
print(f" 4. Verify your Pydantic model field types and defaults are correct")
# Step 5: Compile TypeScript (unless skipped)
if not args.skip_compile:
@ -866,6 +923,8 @@ Generated conversion functions can be used like:
print(f"✅ File size: {file_size} characters")
if conversion_count > 0:
print(f"✅ Date conversion functions: {conversion_count}")
if enum_specific_count > 0:
print(f"✅ Specific enum default types: {enum_specific_count}")
if not args.skip_test:
print("✅ Model validation passed")
if not args.skip_compile:

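To make the generator change concrete: a Pydantic field whose default is a specific enum member is now emitted as a required literal type instead of an optional union. A small sketch; the base class and enum member names are assumptions, simplified from models.py:

# Input (Python), simplified from models.py
class ChatMessageUser(ChatMessageBase):
    status: ChatStatusType = ChatStatusType.DONE
    type: ChatMessageType = ChatMessageType.USER
    sender: ChatSenderType = ChatSenderType.USER

# Output (generated types.ts, matching the diff above):
#   status: "done";
#   type: "user";
#   sender: "user";
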
View File

@ -0,0 +1,60 @@
from pydantic import BaseModel, Field # type: ignore
import json
from typing import Any, List, Optional, Set
def check_serializable(obj: Any, path: str = "", errors: Optional[List[str]] = None, visited: Optional[Set[int]] = None) -> List[str]:
"""
Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion.
Skips fields in Pydantic models marked with Field(..., exclude=True).
Args:
obj: The object to inspect (Pydantic model, dict, list, or other).
path: The current field path (e.g., 'field1.nested_field').
errors: List to collect error messages (a fresh list is created per call when omitted).
visited: Set of visited object IDs used to prevent infinite recursion (fresh per call when omitted).
Returns:
List of error messages for non-serializable fields.
"""
# Create fresh mutable defaults per call so state never leaks across invocations
if errors is None:
    errors = []
if visited is None:
    visited = set()
# Check for circular reference by object ID
obj_id = id(obj)
if obj_id in visited:
errors.append(f"Field '{path}' contains a circular reference, skipping further inspection")
return errors
# Add current object to visited set
visited.add(obj_id)
try:
# Handle Pydantic models
if isinstance(obj, BaseModel):
for field_name, field_info in obj.model_fields.items():
# Skip fields marked with exclude=True
if field_info.exclude:
continue
value = getattr(obj, field_name)
new_path = f"{path}.{field_name}" if path else field_name
check_serializable(value, new_path, errors, visited)
# Handle dictionaries
elif isinstance(obj, dict):
for key, value in obj.items():
new_path = f"{path}[{key}]" if path else str(key)
check_serializable(value, new_path, errors, visited)
# Handle lists, tuples, or other iterables
elif isinstance(obj, (list, tuple)):
for i, value in enumerate(obj):
new_path = f"{path}[{i}]" if path else str(i)
check_serializable(value, new_path, errors, visited)
# Handle other types (check for JSON serializability)
else:
try:
json.dumps(obj)
except (TypeError, OverflowError, ValueError) as e:
errors.append(f"Field '{path}' contains non-serializable type: {type(obj)} ({str(e)})")
finally:
# Remove the current object from visited to allow processing in other branches
visited.discard(obj_id)
return errors
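
A quick usage sketch, e.g. vetting a message before it is handed to json.dumps or the SSE stream; the chat_message target and logger wiring are assumptions, since the helper is generic:

# Hypothetical call site
problems = check_serializable(chat_message)
for problem in problems:
    # e.g. "Field 'metadata.umap_embedding_2d' contains non-serializable type: <class 'numpy.ndarray'> ..."
    logger.error(problem)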

View File

@ -21,7 +21,7 @@ import uuid
import logging
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional
from pydantic import BaseModel, EmailStr, validator # type: ignore
from pydantic import BaseModel, EmailStr, field_validator # type: ignore
# Prometheus
from prometheus_client import Summary # type: ignore
from prometheus_fastapi_instrumentator import Instrumentator # type: ignore
@ -38,11 +38,11 @@ from auth_utils import (
)
import model_cast
import defines
import agents
from logger import logger
from database import RedisDatabase, redis_manager, DatabaseManager
from metrics import Metrics
from llm_manager import llm_manager
import entities
# =============================
# Import Pydantic models
@ -149,11 +149,11 @@ class LoginRequest(BaseModel):
login: str # Can be email or username
password: str
@validator('login')
@field_validator('login')
def sanitize_login(cls, v):
return sanitize_login_input(v)
@validator('password')
@field_validator('password')
def validate_password_not_empty(cls, v):
if not v or not v.strip():
raise ValueError('Password cannot be empty')
@ -168,13 +168,13 @@ class CreateCandidateRequest(BaseModel):
# Add other required candidate fields as needed
phone: Optional[str] = None
@validator('username')
@field_validator('username')
def validate_username(cls, v):
if not v or len(v.strip()) < 3:
raise ValueError('Username must be at least 3 characters long')
return v.strip().lower()
@validator('password')
@field_validator('password')
def validate_password_strength(cls, v):
is_valid, issues = validate_password_strength(v)
if not is_valid:
@ -194,13 +194,13 @@ class CreateEmployerRequest(BaseModel):
websiteUrl: Optional[str] = None
phone: Optional[str] = None
@validator('username')
@field_validator('username')
def validate_username(cls, v):
if not v or len(v.strip()) < 3:
raise ValueError('Username must be at least 3 characters long')
return v.strip().lower()
@validator('password')
@field_validator('password')
def validate_password_strength(cls, v):
is_valid, issues = validate_password_strength(v)
if not is_valid:
@ -1004,7 +1004,7 @@ class PasswordResetConfirm(BaseModel):
token: str
new_password: str
@validator('new_password')
@field_validator('new_password')
def validate_password_strength(cls, v):
is_valid, issues = validate_password_strength(v)
if not is_valid:
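The hunks above are a mechanical Pydantic v1 to v2 migration of the decorator name; the pattern, sketched on a hypothetical model:

from pydantic import BaseModel, field_validator

class Example(BaseModel):
    login: str

    # v1 used @validator('login'); v2 renames it to @field_validator('login')
    @field_validator('login')
    @classmethod
    def sanitize(cls, v: str) -> str:
        return v.strip().lower()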
@ -1401,7 +1401,6 @@ async def create_chat_session(
@api_router.post("/chat/sessions/{session_id}/messages/stream")
async def post_chat_session_message_stream(
session_id: str = Path(...),
data: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
@ -1409,113 +1408,112 @@ async def post_chat_session_message_stream(
):
"""Post a message to a chat session and stream the response with persistence"""
try:
chat_session_data = await database.get_chat_session(session_id)
user_message_data = data.get("chatMessage")
if not user_message_data:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_CHAT_MESSAGE", "chatMessage cannot be empty")
)
user_message = ChatMessageUser.model_validate(user_message_data)
chat_session_data = await database.get_chat_session(user_message.session_id)
if not chat_session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
chat_session = ChatSession.model_validate(chat_session_data)
chat_type = chat_session.context.type
candidate_info = chat_session.context.additional_context.get("candidateInfo", {})
chat_type = chat_session_data.get("context", {}).get("type", "general")
# Get candidate info if this chat is about a specific candidate
candidate_info = chat_session_data.get("context", {}).get("additionalContext", {}).get("candidateInfo")
if candidate_info:
logger.info(f"🔗 Chat session {session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
else:
logger.info(f"🔗 Chat session {session_id} type {chat_type} accessed by user {current_user.id}")
query = data.get("query")
if not query:
logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}")
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_QUERY", "Query cannot be empty")
content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association")
)
chat_query = ChatQuery.model_validate(query)
chat_agent = agents.get_or_create_agent(
agent_type=chat_type,
prometheus_collector=prometheus_collector,
database=database
)
if not chat_agent:
candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
candidate : Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
if not candidate:
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session")
)
logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query: {user_message.content}")
# Store the user's message first
user_message = ChatMessageUser(
session_id=session_id,
type=ChatMessageType.USER,
status=ChatStatusType.DONE,
sender=ChatSenderType.USER,
content=chat_query.prompt,
timestamp=datetime.now(UTC)
)
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
# Entity automatically released when done
chat_agent = candidate_entity.get_or_create_agent(agent_type=chat_type)
if not chat_agent:
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
)
# Persist user message to database
await database.add_chat_message(session_id, user_message.model_dump())
logger.info(f"💬 User message saved to database for session {session_id}")
# Update session last activity
chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
await database.set_chat_session(session_id, chat_session_data)
async def message_stream_generator():
"""Generator to stream messages with persistence"""
last_log = None
ai_message = None
# Persist user message to database
await database.add_chat_message(user_message.session_id, user_message.model_dump())
logger.info(f"💬 User message saved to database for session {user_message.session_id}")
async for chat_message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
query=chat_query,
user_message=user_message,
user_id=current_user.id,
):
# Store reference to the complete AI message
if chat_message.status == ChatStatusType.DONE:
ai_message = chat_message
# If the message is not done, convert it to a ChatMessageBase to remove
# metadata and other unnecessary fields for streaming
if chat_message.status != ChatStatusType.DONE:
chat_message = model_cast.cast_to_model(ChatMessageBase, chat_message)
# Update session last activity
chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
await database.set_chat_session(user_message.session_id, chat_session_data)
json_data = chat_message.model_dump(mode='json', by_alias=True, exclude_unset=True)
json_str = json.dumps(json_data)
async def message_stream_generator():
"""Generator to stream messages with persistence"""
last_log = None
final_message = None
log = f"🔗 Message status={chat_message.status}, sender={getattr(chat_message, 'sender', 'unknown')}"
if last_log != log:
last_log = log
logger.info(log)
yield f"data: {json_str}\n\n"
# After streaming is complete, persist the final AI message to database
if ai_message and ai_message.status == ChatStatusType.DONE:
try:
await database.add_chat_message(session_id, ai_message.model_dump())
logger.info(f"🤖 AI message saved to database for session {session_id}")
async for generated_message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
user_message=user_message,
user_id=current_user.id,
):
# Store reference to the complete AI message
if generated_message.status == ChatStatusType.DONE:
final_message = generated_message
# Update session last activity again
chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
await database.set_chat_session(session_id, chat_session_data)
except Exception as e:
logger.error(f"Failed to save AI message to database: {e}")
# If the message is not done, convert it to a ChatMessageBase to remove
# metadata and other unnecessary fields for streaming
if generated_message.status != ChatStatusType.DONE:
generated_message = model_cast.cast_to_model(ChatMessageBase, generated_message)
return StreamingResponse(
message_stream_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
json_data = generated_message.model_dump(mode='json', by_alias=True, exclude_unset=True)
json_str = json.dumps(json_data)
log = f"🔗 Message status={generated_message.status}, sender={getattr(generated_message, 'sender', 'unknown')}"
if last_log != log:
last_log = log
logger.info(log)
yield f"data: {json_str}\n\n"
# After streaming is complete, persist the final AI message to database
if final_message and final_message.status == ChatStatusType.DONE:
try:
await database.add_chat_message(final_message.session_id, final_message.model_dump())
logger.info(f"🤖 AI message saved to database for session {final_message.session_id}")
# Update session last activity again
chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
await database.set_chat_session(final_message.session_id, chat_session_data)
except Exception as e:
logger.error(f"Failed to save AI message to database: {e}")
return StreamingResponse(
message_stream_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
except Exception as e:
logger.error(traceback.format_exc())
@ -1779,6 +1777,7 @@ async def track_requests(request, call_next):
# FastAPI Metrics
# ============================
prometheus_collector = CollectorRegistry()
entities.entity_manager.initialize(prometheus_collector)
# Keep the Instrumentator instance alive
instrumentator = Instrumentator(
@ -1835,4 +1834,4 @@ if __name__ == "__main__":
)
else:
logger.info(f"Starting web server at http://{host}:{port}")
uvicorn.run(app=app, host=host, port=port, log_config=None)
uvicorn.run(app="main:app", host=host, port=port, log_config=None)
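Since the endpoint emits text/event-stream frames of the form `data: {json}\n\n`, a client can consume it line by line. A minimal sketch with the requests library; the URL, session id, and absence of auth headers are assumptions for illustration:

import json
import requests

# Illustrative client for the streaming endpoint above
url = "http://localhost:8000/api/chat/sessions/abc-123/messages/stream"
body = {"chatMessage": {
    "sessionId": "abc-123",
    "content": "Tell me about this candidate",
    "timestamp": "2025-05-31T12:00:00Z",
}}
with requests.post(url, json=body, stream=True) as response:
    for line in response.iter_lines():
        if line.startswith(b"data: "):
            message = json.loads(line[len(b"data: "):])
            print(message.get("status"), message.get("content", ""))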

View File

@ -388,6 +388,11 @@ class BaseUser(BaseModel):
class BaseUserWithType(BaseUser):
user_type: UserType = Field(..., alias="userType")
class RagEntry(BaseModel):
name: str
description: str = ""
enabled: bool = True
class Candidate(BaseUser):
user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType")
username: str
@ -405,6 +410,8 @@ class Candidate(BaseUser):
certifications: Optional[List[Certification]] = None
job_applications: Optional[List["JobApplication"]] = Field(None, alias="jobApplications")
has_profile: bool = Field(default=False, alias="hasProfile")
rags: List[RagEntry] = Field(default_factory=list)
rag_content_size : int = 0
# Used for AI generated personas
age: Optional[int] = None
gender: Optional[UserGender] = None
@ -540,11 +547,6 @@ class JobApplication(BaseModel):
"populate_by_name": True # Allow both field names and aliases
}
class RagEntry(BaseModel):
name: str
description: str = ""
enabled: bool = True
class ChromaDBGetResponse(BaseModel):
# Chroma fields
ids: List[str] = []
@ -573,6 +575,26 @@ class ChatOptions(BaseModel):
num_ctx: Optional[int] = Field(default=None, alias="numCtx") # Number of context tokens
temperature: Optional[float] = Field(default=0.7) # Higher temperature to encourage tool usage
class LLMMessage(BaseModel):
role: str = Field(default="")
content: str = Field(default="")
tool_calls: Optional[List[Dict]] = Field(default=None, exclude=True)
class ChatMessageBase(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
session_id: str = Field(..., alias="sessionId")
sender_id: Optional[str] = Field(None, alias="senderId")
status: ChatStatusType
type: ChatMessageType
sender: ChatSenderType
timestamp: datetime
tunables: Optional[Tunables] = None
content: str = ""
model_config = {
"populate_by_name": True # Allow both field names and aliases
}
class ChatMessageMetaData(BaseModel):
model: AIModelType = AIModelType.QWEN2_5
temperature: float = 0.7
@ -581,8 +603,8 @@ class ChatMessageMetaData(BaseModel):
frequency_penalty: Optional[float] = Field(None, alias="frequencyPenalty")
presence_penalty: Optional[float] = Field(None, alias="presencePenalty")
stop_sequences: Optional[List[str]] = Field(None, alias="stopSequences")
tunables: Optional[Tunables] = None
rag: List[ChromaDBGetResponse] = Field(default_factory=list)
rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias="ragResults")
llm_history: List[LLMMessage] = Field(default_factory=list, alias="llmHistory")
eval_count: int = 0
eval_duration: int = 0
prompt_eval_count: int = 0
@ -594,28 +616,17 @@ class ChatMessageMetaData(BaseModel):
"populate_by_name": True # Allow both field names and aliases
}
class ChatMessageBase(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
session_id: str = Field(..., alias="sessionId")
sender_id: Optional[str] = Field(None, alias="senderId")
status: ChatStatusType
type: ChatMessageType
sender: ChatSenderType
timestamp: datetime
content: str = ""
model_config = {
"populate_by_name": True # Allow both field names and aliases
}
class ChatMessageUser(ChatMessageBase):
status: ChatStatusType = ChatStatusType.DONE
type: ChatMessageType = ChatMessageType.USER
sender: ChatSenderType = ChatSenderType.USER
class ChatMessage(ChatMessageBase):
#attachments: Optional[List[Attachment]] = None
#reactions: Optional[List[MessageReaction]] = None
#is_edited: bool = Field(False, alias="isEdited")
#edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
metadata: ChatMessageMetaData = Field(None)
metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)
class ChatSession(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
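A quick sketch of what `populate_by_name` buys here: camelCase aliases are accepted on input and re-emitted with `by_alias=True` (values are illustrative):

from datetime import datetime, timezone

msg = ChatMessageUser.model_validate({
    "sessionId": "abc-123",   # alias accepted on input
    "content": "hello",
    "timestamp": datetime.now(timezone.utc).isoformat(),
})
print(msg.model_dump(mode="json", by_alias=True)["sessionId"])  # camelCase on output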

View File

@ -0,0 +1,8 @@
from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher, RagEntry
__all__ = [
"ChromaDBFileWatcher",
"ChromaDBGetResponse",
"start_file_watcher",
"RagEntry"
]

View File

@ -0,0 +1,223 @@
from __future__ import annotations
from typing import List, Dict, Any, Optional, TypedDict, Tuple
from markdown_it import MarkdownIt
from markdown_it.tree import SyntaxTreeNode
import traceback
import logging
import json
import defines
class Chunk(TypedDict):
text: str
metadata: Dict[str, Any]
def clear_chunk(chunk: Chunk):
chunk["text"] = ""
chunk["metadata"] = {
"doc_type": "unknown",
"source_file": chunk["metadata"]["source_file"],
"lines": chunk["metadata"]["lines"],
"path": "", # This will be updated during processing
"level": 0,
}
return chunk
class MarkdownChunker:
def __init__(self):
# Initialize the Markdown parser
self.md_parser = MarkdownIt("commonmark")
def process_file(self, file_path: str) -> Optional[List[Chunk]]:
"""
Process a single markdown file and return chunks.
Args:
file_path: Path to the markdown file
Returns:
List of chunks with metadata or None if file can't be processed
"""
try:
# Read the markdown file
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Parse the markdown
tokens = self.md_parser.parse(content)
logging.info(f"Found {len(tokens)} in {file_path}")
ast = SyntaxTreeNode(tokens)
# Extract chunks with metadata
chunks = self.extract_chunks(ast, file_path, len(content.splitlines()))
return chunks
except Exception as e:
logging.error(f"Error processing {file_path}: {str(e)}")
logging.error(traceback.format_exc())
return None
def extract_chunks(self, ast: SyntaxTreeNode, file_path: str, total_lines: int) -> List[Chunk]:
"""
Extract logical chunks from the AST with appropriate metadata.
Args:
ast: The Abstract Syntax Tree from markdown-it-py
file_path: Path to the source file
Returns:
List of chunks with metadata
"""
chunks: List[Chunk] = []
current_headings: List[str] = [""] * 6 # Track h1-h6 headings
# Initialize a chunk structure
chunk: Chunk = {
"text": "",
"metadata": {
"source_file": file_path,
"lines": total_lines
},
}
clear_chunk(chunk)
# Process the AST recursively
self._process_node(ast, current_headings, chunks, chunk, level=0)
return chunks
def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
return {
k: ("" if v is None else v) for k, v in metadata.items() if v is not None
}
def _extract_text_from_children(self, node: SyntaxTreeNode) -> str:
lines = []
for child in node.children:
if child.type == "list_item":
lines.append(f"- {self._extract_text_from_children(child).strip()}")
elif child.type == "fence":
info = f"{child.info}" if hasattr(child, "info") and child.info else ""
lines.append(f"\n```{info}\n{child.content.strip()}\n```\n")
elif child.children:
lines.append(self._extract_text_from_children(child).strip())
elif hasattr(child, "content"):
lines.append(child.content.strip())
return "\n".join(lines)
def _process_node(
self,
node: SyntaxTreeNode,
current_headings: list[str],
chunks: List[Chunk],
chunk: Chunk,
level: int,
buffer: int = defines.chunk_buffer
) -> int:
is_list = False
# Handle heading nodes
if node.type == "heading":
# print(f'{" " * level}{chunk["metadata"]["path"]} {node.type}')
if node.is_nested and node.nester_tokens:
opening, closing = node.nester_tokens
level = int(opening.tag[1:]) - 1
heading_text = self._extract_text_from_children(node)
# Record this heading and clear any deeper heading levels
current_headings[level] = heading_text
for i in range(level, 6):
if i != level:
current_headings[i] = ""
# Store previous chunk if it has text
if chunk["text"].strip():
path = " > ".join([h for h in current_headings if h])
chunk["text"] = chunk["text"].strip()
chunk["metadata"]["path"] = path
chunk["metadata"]["level"] = level
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
metadata["line_begin"] = begin
metadata["line_end"] = end
chunks.append(chunk.copy())
clear_chunk(chunk)
# Add code block directly to current chunk
elif node.type == "fence":
# print(f'{" " * level}{chunk["metadata"]["path"]} {node.type}')
language = node.info.strip() if hasattr(node, "info") and node.info else ""
code_block = f"\n```{language}\n{node.content.strip()}\n```\n"
chunk["text"] += code_block
# Handle list structures
# elif node.type in ["list", "list_item"]:
# # print(node.nester_tokens)
# # print(node.pretty(show_text=True))
# is_list = True
# list_chunk = []
# for child in node.children:
# level = self._process_node(child, current_headings, chunks, chunk, level=level)
# Handle paragraph
elif node.type in ["paragraph"]: # , "list", "list_item"]:
text = self._extract_text_from_children(node)
# indented_text = "\n".join([f'{"-" * (level + 2)}{line}' for line in text.split('\n')])
# print(f'{"-" * level}{chunk["metadata"]["path"]} {node.type}:\n{indented_text} {len(node.children)}\n')
text = text.strip()
if text:
# indented_text = "\n".join([f'{"-" * (level + 2)}{line}' for line in text.split('\n')])
# print(f'{" " * level}{chunk["metadata"]["path"]} {node.type}:\n{indented_text}\n')
chunk["text"] += f"\n{text}\n"
chunk["text"] = chunk["text"].strip()
if chunk["text"]:
path = " > ".join([h for h in current_headings if h])
chunk["text"] = chunk["text"]
chunk["metadata"]["path"] = path
chunk["metadata"]["level"] = level
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
metadata["line_begin"] = begin
metadata["line_end"] = end
chunks.append(chunk.copy())
clear_chunk(chunk)
# Recursively process children
if not is_list:
for child in node.children:
level = self._process_node(
child, current_headings, chunks, chunk, level=level
)
# After root-level recursion, finalize any remaining chunk
if node.type == "document":
# print(node.type)
if chunk["text"].strip():
path = " > ".join([h for h in current_headings if h])
chunk["metadata"]["path"] = path
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
metadata["line_begin"] = begin
metadata["line_end"] = end
chunks.append(chunk.copy())
return level
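A minimal usage sketch for the chunker (the file path is hypothetical):

chunker = MarkdownChunker()
chunks = chunker.process_file("docs/resume.md")  # hypothetical path
for chunk in chunks or []:
    meta = chunk["metadata"]
    print(f"{meta['path']} (lines {meta.get('line_begin')}-{meta.get('line_end')}): {len(chunk['text'])} chars")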

src/backend/rag/rag.py Normal file
View File

@ -0,0 +1,727 @@
from __future__ import annotations
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
from typing import List, Optional, Dict, Any, Union
import os
import glob
from pathlib import Path
import time
import hashlib
import asyncio
import logging
import json
import numpy as np # type: ignore
import traceback
import chromadb # type: ignore
import ollama
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
from markitdown import MarkItDown # type: ignore
from chromadb.api.models.Collection import Collection # type: ignore
from .markdown_chunker import (
MarkdownChunker,
Chunk,
)
# When imported as a module, use relative imports
import defines
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]
DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100
class RagEntry(BaseModel):
name: str
description: str = ""
enabled: bool = True
class ChromaDBGetResponse(BaseModel):
name: str = ""
size: int = 0
ids: List[str] = []
embeddings: List[List[float]] = Field(default=[])
documents: List[str] = []
metadatas: List[Dict[str, Any]] = []
query: str = ""
query_embedding: Optional[List[float]] = Field(default=None)
umap_embedding_2d: Optional[List[float]] = Field(default=None)
umap_embedding_3d: Optional[List[float]] = Field(default=None)
enabled: bool = True
class Config:
validate_assignment = True
@field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
@classmethod
def validate_embeddings(cls, value, field):
# logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
if value is None:
return value
if isinstance(value, np.ndarray):
if field.field_name == "embeddings":
if value.ndim != 2:
raise ValueError(f"{field.name} must be a 2-dimensional NumPy array")
return [[float(x) for x in row] for row in value.tolist()]
else:
if value.ndim != 1:
raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
return [float(x) for x in value.tolist()]
if field.field_name == "embeddings":
if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
raise ValueError(f"{field.field_name} must be a list of lists of floats")
return [[float(x) for x in sublist] for sublist in value]
else:
if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
raise ValueError(f"{field.field_name} must be a list of floats")
return [float(x) for x in value]
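A small sketch of the normalization this validator performs on assignment (enabled by `validate_assignment = True`); NumPy inputs go through the ndarray branches above, while plain lists are coerced to floats:

response = ChromaDBGetResponse(name="resume")
response.query_embedding = [1, 0, 0, 0]         # ints are normalized to floats
response.embeddings = [[0.1, 0.2], [0.3, 0.4]]  # list of lists of floats
print(response.model_dump_json())               # stays JSON-serializable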
class ChromaDBFileWatcher(FileSystemEventHandler):
def __init__(
self,
llm,
watch_directory,
loop,
persist_directory,
collection_name,
chunk_size=DEFAULT_CHUNK_SIZE,
chunk_overlap=DEFAULT_CHUNK_OVERLAP,
recreate=False,
):
self.llm = llm
self.watch_directory = watch_directory
self.persist_directory = persist_directory or defines.persist_directory
self.collection_name = collection_name
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.loop = loop
self._umap_collection: ChromaDBGetResponse | None = None
self._umap_embedding_2d: Optional[np.ndarray] = None
self._umap_embedding_3d: Optional[np.ndarray] = None
self._umap_model_2d: Optional[umap.UMAP] = None
self._umap_model_3d: Optional[umap.UMAP] = None
self.md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
# self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Path for storing file hash state
self.hash_state_path = os.path.join(
self.persist_directory, f"{collection_name}_hash_state.json"
)
# Flag to track if this is a new collection
self.is_new_collection = False
# Initialize ChromaDB collection
self._collection: Collection = self._get_vector_collection(recreate=recreate)
self._markdown_chunker = MarkdownChunker()
self._update_umaps()
# Setup text splitter
# Track file hashes and processing state
self.file_hashes = self._load_hash_state()
self.update_lock = asyncio.Lock()
self.processing_files = set()
@property
def collection(self) -> Collection:
return self._collection
@property
def umap_collection(self) -> ChromaDBGetResponse:
if not self._umap_collection:
raise ValueError("initialize_collection has not been called")
return self._umap_collection
@property
def umap_embedding_2d(self) -> np.ndarray:
return self._umap_embedding_2d
@property
def umap_embedding_3d(self) -> np.ndarray:
return self._umap_embedding_3d
@property
def umap_model_2d(self):
return self._umap_model_2d
@property
def umap_model_3d(self):
return self._umap_model_3d
def _markitdown(self, document: str, markdown: Path):
logging.info(f"Converting {document} to {markdown}")
try:
result = self.md.convert(document)
markdown.write_text(result.text_content)
except Exception as e:
logging.error(f"Error convering via markdownit: {e}")
def _save_hash_state(self):
"""Save the current file hash state to disk."""
try:
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)
with open(self.hash_state_path, "w") as f:
json.dump(self.file_hashes, f)
logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
except Exception as e:
logging.error(f"Error saving hash state: {e}")
def _load_hash_state(self):
"""Load the file hash state from disk."""
if os.path.exists(self.hash_state_path):
try:
with open(self.hash_state_path, "r") as f:
hash_state = json.load(f)
logging.info(f"Loaded hash state with {len(hash_state)} entries")
return hash_state
except Exception as e:
logging.error(f"Error loading hash state: {e}")
return {}
async def scan_directory(self, process_all=False):
"""
Scan directory for new, modified, or deleted files and update collection.
Args:
process_all: If True, process all files regardless of hash status
"""
# Check for new or modified files
file_paths = glob.glob(
os.path.join(self.watch_directory, "**/*"), recursive=True
)
files_checked = 0
files_processed = 0
files_to_process = []
logging.info(f"Starting directory scan. Found {len(file_paths)} total paths.")
for file_path in file_paths:
if os.path.isfile(file_path):
# Do not put the Resume in RAG as it is provided with all queries.
# if file_path == defines.resume_doc:
# logging.info(f"Not adding {file_path} to RAG -- primary resume")
# continue
files_checked += 1
current_hash = self._get_file_hash(file_path)
if not current_hash:
logging.info(f"Unable to obtain hash of {file_path}")
continue
# If file is new, changed, or we're processing all files
if (
process_all
or file_path not in self.file_hashes
or self.file_hashes[file_path] != current_hash
):
self.file_hashes[file_path] = current_hash
files_to_process.append(file_path)
logging.info(
f"File {'found' if process_all else 'changed'}: {file_path}"
)
logging.info(
f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
)
# Check for deleted files
deleted_files = []
for file_path in self.file_hashes:
if not os.path.exists(file_path):
deleted_files.append(file_path)
# Schedule removal
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
# Don't block on result, just let it run
logging.info(f"File deleted: {file_path}")
# Remove deleted files from hash state
for file_path in deleted_files:
del self.file_hashes[file_path]
# Process all discovered files using asyncio.gather with the existing loop
if files_to_process:
logging.info(f"Starting to process {len(files_to_process)} files")
for file_path in files_to_process:
async with self.update_lock:
files_processed += 1
await self._update_document_in_collection(file_path)
else:
logging.info("No files to process")
# Save the updated state
self._save_hash_state()
logging.info(
f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
)
return files_processed
async def process_file_update(self, file_path):
"""Process a file update event."""
# Skip if already being processed
if file_path in self.processing_files:
logging.info(f"{file_path} already in queue. Not adding.")
return
# if file_path == defines.resume_doc:
# logging.info(f"Not adding {file_path} to RAG -- primary resume")
# return
try:
logging.info(f"{file_path} not in queue. Adding.")
self.processing_files.add(file_path)
# Wait a moment to ensure the file write is complete
await asyncio.sleep(0.5)
# Check if content changed via hash
current_hash = self._get_file_hash(file_path)
if not current_hash: # File might have been deleted or is inaccessible
return
if (
file_path in self.file_hashes
and self.file_hashes[file_path] == current_hash
):
# File hasn't actually changed in content
logging.info(f"Hash has not changed for {file_path}")
return
# Update file hash
self.file_hashes[file_path] = current_hash
# Process and update the file in ChromaDB
async with self.update_lock:
await self._update_document_in_collection(file_path)
# Save the hash state after successful update
self._save_hash_state()
# Re-fit the UMAP for the new content
self._update_umaps()
except Exception as e:
logging.error(f"Error processing update for {file_path}: {e}")
finally:
self.processing_files.discard(file_path)
async def remove_file_from_collection(self, file_path):
"""Remove all chunks related to a deleted file."""
async with self.update_lock:
try:
# Find all documents with the specified path
results = self.collection.get(where={"path": file_path})
if results and "ids" in results and results["ids"]:
self.collection.delete(ids=results["ids"])
logging.info(
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
)
# Remove from hash dictionary
if file_path in self.file_hashes:
del self.file_hashes[file_path]
# Save the updated hash state
self._save_hash_state()
except Exception as e:
logging.error(f"Error removing file from collection: {e}")
def _update_umaps(self):
# Update the UMAP embeddings
self._umap_collection = self._collection.get(
include=["embeddings", "documents", "metadatas"]
)
if not self._umap_collection or not len(self._umap_collection["embeddings"]):
logging.warning("No embeddings found in the collection.")
return
# During initialization
logging.info(
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
)
vectors = np.array(self._umap_collection["embeddings"])
self._umap_model_2d = umap.UMAP(
n_components=2,
random_state=8911,
metric="cosine",
n_neighbors=30,
min_dist=0.1,
)
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
# logging.info(
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
# ) # Should be 2
logging.info(
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
)
self._umap_model_3d = umap.UMAP(
n_components=3,
random_state=8911,
metric="cosine",
n_neighbors=30,
min_dist=0.01,
)
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
# logging.info(
# f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
# ) # Should be 3
def _get_vector_collection(self, recreate=False) -> Collection:
"""Get or create a ChromaDB collection."""
# Create the directory if it doesn't exist
if not os.path.exists(self.persist_directory):
os.makedirs(self.persist_directory)
# Initialize ChromaDB client
chroma_client = chromadb.PersistentClient( # type: ignore
path=self.persist_directory,
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore
)
# Check if the collection exists
try:
chroma_client.get_collection(self.collection_name)
collection_exists = True
except Exception:
collection_exists = False
# If collection doesn't exist, mark it as new
if not collection_exists:
self.is_new_collection = True
logging.info(f"Creating new collection: {self.collection_name}")
# Delete if recreate is True
if recreate and collection_exists:
chroma_client.delete_collection(name=self.collection_name)
self.is_new_collection = True
logging.info(f"Recreating collection: {self.collection_name}")
return chroma_client.get_or_create_collection(
name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
# Get embedding
try:
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
embedding = np.array(response["embedding"])
except Exception as e:
logging.error(f"Failed to get embedding: {e}")
raise
# Log diagnostics
logging.info(f"Input text: {text}")
logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
# Check for invalid embeddings
if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
raise ValueError("Invalid embedding returned from Ollama.")
# Check normalization
norm = np.linalg.norm(embedding)
is_normalized = np.allclose(norm, 1.0, atol=1e-3)
logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
# Normalize if needed
if not is_normalized:
embedding = embedding / norm
logging.info("Embedding normalized manually.")
return embedding
def add_embeddings_to_collection(self, chunks: List[Chunk]):
"""Add embeddings for chunks to the collection."""
for i, chunk in enumerate(chunks):
text = chunk["text"]
metadata = chunk["metadata"]
# Generate a more unique ID based on content and metadata
path_hash = ""
if "path" in metadata:
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
:8
]
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{i}_{content_hash}"
embedding = self.get_embedding(text)
try:
self.collection.add(
ids=[chunk_id],
documents=[text],
embeddings=[embedding],
metadatas=[metadata],
)
except Exception as e:
logging.error(f"Error adding chunk to collection: {e}")
logging.error(traceback.format_exc())
logging.error(chunk)
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer)-> str | None:
try:
source_file = meta["source_file"]
path_parts = source_file.split(os.sep)
file_name = path_parts[-1]
meta["source_file"] = file_name
with open(source_file, "r") as file:
lines = file.readlines()
meta["file_lines"] = len(lines)
start = max(0, meta["line_begin"] - buffer)
meta["chunk_begin"] = start
end = min(meta["lines"], meta["line_end"] + buffer)
meta["chunk_end"] = end
return "".join(lines[start:end])
except Exception:
logging.warning(f"Unable to open {meta['source_file']}")
return None
# Cosine Distance Equivalent Similarity Retrieval Characteristics
# 0.2 - 0.3 0.85 - 0.90 Very strict, highly precise results only
# 0.3 - 0.5 0.75 - 0.85 Strong relevance, good precision
# 0.5 - 0.7 0.65 - 0.75 Balanced precision/recall
# 0.7 - 0.9 0.55 - 0.65 Higher recall, more inclusive
# 0.9 - 1.2 0.40 - 0.55 Very inclusive, may include tangential content
def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
"""Find similar documents to the query."""
# collection is configured with hnsw:space cosine
query_embedding = self.get_embedding(query)
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
include=["documents", "metadatas", "distances"],
)
# Extract results
ids = results["ids"][0]
documents = results["documents"][0]
distances = results["distances"][0]
metadatas = results["metadatas"][0]
filtered_ids = []
filtered_documents = []
filtered_distances = []
filtered_metadatas = []
for i, distance in enumerate(distances):
if distance <= threshold: # For cosine distance, smaller is better
filtered_ids.append(ids[i])
filtered_documents.append(documents[i])
filtered_metadatas.append(metadatas[i])
filtered_distances.append(distance)
for index, meta in enumerate(filtered_metadatas):
content = self.prepare_metadata(meta)
if content is not None:
filtered_documents[index] = content
# Return the filtered results instead of all results
return {
"query_embedding": query_embedding,
"ids": filtered_ids,
"documents": filtered_documents,
"distances": filtered_distances,
"metadatas": filtered_metadatas,
}
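Given the cosine-distance table above, a tighter threshold trades recall for precision. A usage sketch, assuming `file_watcher` is a running ChromaDBFileWatcher (query text is illustrative):

results = file_watcher.find_similar("experience with embedded Linux", top_k=5, threshold=0.7)
for doc, distance, meta in zip(results["documents"], results["distances"], results["metadatas"]):
    print(f"{distance:.3f} {meta.get('source_file')}: {doc[:80]}")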
def _get_file_hash(self, file_path):
"""Calculate MD5 hash of a file."""
try:
with open(file_path, "rb") as f:
return hashlib.md5(f.read()).hexdigest()
except Exception as e:
logging.error(f"Error hashing file {file_path}: {e}")
return None
def on_modified(self, event):
"""Handle file modification events."""
if event.is_directory:
return
file_path = event.src_path
# Schedule the update using asyncio
asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
logging.info(f"File modified: {file_path}")
def on_created(self, event):
"""Handle file creation events."""
if event.is_directory:
return
file_path = event.src_path
# Schedule the update using asyncio
asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
logging.info(f"File created: {file_path}")
def on_deleted(self, event):
"""Handle file deletion events."""
if event.is_directory:
return
file_path = event.src_path
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
logging.info(f"File deleted: {file_path}")
def on_moved(self, event):
"""Handle move deletion events."""
if event.is_directory:
return
file_path = event.src_path
logging.info(f"TODO: on_moved: ${file_path}")
def _normalize_embeddings(self, embeddings):
"""Normalize the embeddings to unit length."""
# Handle both single vector and array of vectors
if isinstance(embeddings[0], (int, float)):
# Single vector
norm = np.linalg.norm(embeddings)
return [e / norm for e in embeddings] if norm > 0 else embeddings
else:
# Array of vectors
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
return embeddings / norms
async def _update_document_in_collection(self, file_path):
"""Update a document in the ChromaDB collection."""
try:
# Remove existing entries for this file
existing_results = self.collection.get(where={"path": file_path})
if (
existing_results
and "ids" in existing_results
and existing_results["ids"]
):
self.collection.delete(ids=existing_results["ids"])
extensions = (".docx", ".xlsx", ".xls", ".pdf")
if file_path.endswith(extensions):
p = Path(file_path)
p_as_md = p.with_suffix(".md")
if p_as_md.exists():
logging.info(
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
)
# If file_path.md doesn't exist or file_path is newer than file_path.md,
# fire off markitdown
if (not p_as_md.exists()) or (
p.stat().st_mtime > p_as_md.stat().st_mtime
):
self._markitdown(file_path, p_as_md)
return
chunks = self._markdown_chunker.process_file(file_path)
if not chunks:
logging.info(f"No chunks found in markdown: {file_path}")
return
# Extract top-level directory
rel_path = os.path.relpath(file_path, self.watch_directory)
path_parts = rel_path.split(os.sep)
top_level_dir = path_parts[0]
# file_name = path_parts[-1]
for i, chunk in enumerate(chunks):
chunk["metadata"]["doc_type"] = top_level_dir
# with open(f"src/tmp/{file_name}.{i}", "w") as f:
# f.write(json.dumps(chunk, indent=2))
# Add chunks to collection
self.add_embeddings_to_collection(chunks)
logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
except Exception as e:
logging.error(f"Error updating document in collection: {e}")
logging.error(traceback.format_exc())
async def initialize_collection(self):
"""Initialize the collection with all documents from the watch directory."""
# Process all files regardless of hash state
num_processed = await self.scan_directory(process_all=True)
logging.info(
f"Vectorstore initialized with {self.collection.count()} documents"
)
self._update_umaps()
# Show stats
try:
all_metadata = self.collection.get()["metadatas"]
if all_metadata:
doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
logging.info(f"Document types: {doc_types}")
except Exception as e:
logging.error(f"Error getting document types: {e}")
return num_processed
# Function to start the file watcher
def start_file_watcher(
llm,
watch_directory,
persist_directory,
collection_name,
initialize=False,
recreate=False,
):
"""
Start watching a directory for file changes.
Args:
llm: The language model client
watch_directory: Directory to watch for changes
persist_directory: Directory to persist ChromaDB and hash state
collection_name: Name of the ChromaDB collection
initialize: Whether to forcibly initialize the collection with all documents
recreate: Whether to recreate the collection (will delete existing)
"""
loop = asyncio.get_event_loop()
file_watcher = ChromaDBFileWatcher(
llm,
watch_directory=watch_directory,
loop=loop,
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate,
)
# Process all files if:
# 1. initialize=True was passed (explicit request to initialize)
# 2. This is a new collection (doesn't exist yet)
# 3. There's no hash state (first run)
if initialize or file_watcher.is_new_collection or not file_watcher.file_hashes:
logging.info("Initializing collection with all documents")
asyncio.run_coroutine_threadsafe(file_watcher.initialize_collection(), loop)
else:
# Only process new/changed files
logging.info("Scanning for new/changed documents")
asyncio.run_coroutine_threadsafe(file_watcher.scan_directory(), loop)
# Start observer
observer = Observer()
observer.schedule(file_watcher, watch_directory, recursive=True)
observer.start()
logging.info(f"Started watching directory: {watch_directory}")
return observer, file_watcher
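A sketch of wiring this up, assuming an Ollama client and that the current thread will run the event loop (paths and collection name are illustrative):

import asyncio
import ollama

llm = ollama.Client()
observer, watcher = start_file_watcher(
    llm,
    watch_directory="data/rag",          # hypothetical paths
    persist_directory="data/chromadb",
    collection_name="candidate_docs",
    initialize=True,
)
try:
    asyncio.get_event_loop().run_forever()  # scheduled scans run once the loop starts
finally:
    observer.stop()
    observer.join()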

View File

@ -0,0 +1,14 @@
import importlib
from .basetools import all_tools, ToolEntry, llm_tools, enabled_tools, tool_functions
from ..setup_logging import setup_logging
from .. import defines
logger = setup_logging(level=defines.logging_level)
# Dynamically import all names from basetools listed in tools_all
module = importlib.import_module(".basetools", package=__package__)
for name in tool_functions:
globals()[name] = getattr(module, name)
__all__ = ["all_tools", "ToolEntry", "llm_tools", "enabled_tools", "tool_functions"]

View File

@ -0,0 +1,526 @@
import os
from pydantic import BaseModel, Field, model_validator # type: ignore
from typing import List, Optional, Generator, ClassVar, Any, Dict
from datetime import datetime
from typing import (
Any,
)
from typing_extensions import Annotated
from bs4 import BeautifulSoup # type: ignore
from geopy.geocoders import Nominatim # type: ignore
import pytz # type: ignore
import requests
import yfinance as yf # type: ignore
import logging
# %%
def WeatherForecast(city, state, country="USA"):
"""
Get weather information from weather.gov based on city, state, and country.
Args:
city (str): City name
state (str): State name or abbreviation
country (str): Country name (defaults to "USA" as weather.gov is for US locations)
Returns:
dict: Weather forecast information
"""
# Step 1: Get coordinates for the location using geocoding
location = f"{city}, {state}, {country}"
coordinates = get_coordinates(location)
if not coordinates:
return {"error": f"Could not find coordinates for {location}"}
# Step 2: Get the forecast grid endpoint for the coordinates
grid_endpoint = get_grid_endpoint(coordinates)
if not grid_endpoint:
return {"error": f"Could not find weather grid for coordinates {coordinates}"}
# Step 3: Get the forecast data from the grid endpoint
forecast = get_forecast(grid_endpoint)
if not forecast["location"]:
forecast["location"] = location
return forecast
def get_coordinates(location):
"""Convert a location string to latitude and longitude using Nominatim geocoder."""
try:
# Create a geocoder with a meaningful user agent
geolocator = Nominatim(user_agent="weather_app_example")
# Get the location
location_data = geolocator.geocode(location)
if location_data:
return {
"latitude": location_data.latitude,
"longitude": location_data.longitude,
}
else:
print(f"Location not found: {location}")
return None
except Exception as e:
print(f"Error getting coordinates: {e}")
return None
def get_grid_endpoint(coordinates):
"""Get the grid endpoint from weather.gov based on coordinates."""
try:
lat = coordinates["latitude"]
lon = coordinates["longitude"]
# Define headers for the API request
headers = {
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
"Accept": "application/geo+json",
}
# Make the request to get the grid endpoint
url = f"https://api.weather.gov/points/{lat},{lon}"
response = requests.get(url, headers=headers)
if response.status_code == 200:
data = response.json()
return data["properties"]["forecast"]
else:
print(f"Error getting grid: {response.status_code} - {response.text}")
return None
except Exception as e:
print(f"Error in get_grid_endpoint: {e}")
return None
# Weather related function
def get_forecast(grid_endpoint):
"""Get the forecast data from the grid endpoint."""
try:
# Define headers for the API request
headers = {
"User-Agent": "WeatherAppExample/1.0 (your_email@example.com)",
"Accept": "application/geo+json",
}
# Make the request to get the forecast
response = requests.get(grid_endpoint, headers=headers)
if response.status_code == 200:
data = response.json()
# Extract the relevant forecast information
periods = data["properties"]["periods"]
# Process the forecast data into a simpler format
forecast = {
"location": data["properties"]
.get("relativeLocation", {})
.get("properties", {}),
"updated": data["properties"].get("updated", ""),
"periods": [],
}
for period in periods:
forecast["periods"].append(
{
"name": period.get("name", ""),
"temperature": period.get("temperature", ""),
"temperatureUnit": period.get("temperatureUnit", ""),
"windSpeed": period.get("windSpeed", ""),
"windDirection": period.get("windDirection", ""),
"shortForecast": period.get("shortForecast", ""),
"detailedForecast": period.get("detailedForecast", ""),
}
)
return forecast
else:
print(f"Error getting forecast: {response.status_code} - {response.text}")
return {"error": f"API Error: {response.status_code}"}
except Exception as e:
print(f"Error in get_forecast: {e}")
return {"error": f"Exception: {str(e)}"}
# Example usage
# def do_weather():
# city = input("Enter city: ")
# state = input("Enter state: ")
# country = input("Enter country (default USA): ") or "USA"
# print(f"Getting weather for {city}, {state}, {country}...")
# weather_data = WeatherForecast(city, state, country)
# if "error" in weather_data:
# print(f"Error: {weather_data['error']}")
# else:
# print("\nWeather Forecast:")
# print(f"Location: {weather_data.get('location', {}).get('city', city)}, {weather_data.get('location', {}).get('state', state)}")
# print(f"Last Updated: {weather_data.get('updated', 'N/A')}")
# print("\nForecast Periods:")
# for period in weather_data.get("periods", []):
# print(f"\n{period['name']}:")
# print(f" Temperature: {period['temperature']}{period['temperatureUnit']}")
# print(f" Wind: {period['windSpeed']} {period['windDirection']}")
# print(f" Forecast: {period['shortForecast']}")
# print(f" Details: {period['detailedForecast']}")
# %%
def TickerValue(ticker_symbols):
api_key = os.getenv("TWELVEDATA_API_KEY", "")
if not api_key:
return {"error": f"Error fetching data: No API key for TwelveData"}
results = []
for ticker_symbol in ticker_symbols.split(","):
ticker_symbol = ticker_symbol.strip()
if ticker_symbol == "":
continue
url = (
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
)
response = requests.get(url)
data = response.json()
if "price" in data:
logging.info(f"TwelveData: {ticker_symbol} {data}")
results.append({"symbol": ticker_symbol, "price": float(data["price"])})
else:
logging.error(f"TwelveData: {data}")
results.append({"symbol": ticker_symbol, "price": "Unavailable"})
return results[0] if len(results) == 1 else results
# Stock related function
def yfTickerValue(ticker_symbols):
"""
Look up the current price of a stock using its ticker symbol.
Args:
ticker_symbols (str): Comma-separated stock ticker symbols (e.g., 'AAPL' or 'AAPL,MSFT')
Returns:
dict: Current stock information including price
"""
results = []
for ticker_symbol in ticker_symbols.split(","):
ticker_symbol = ticker_symbol.strip()
if ticker_symbol == "":
continue
# Create a Ticker object
try:
logging.info(f"Looking up {ticker_symbol}")
ticker = yf.Ticker(ticker_symbol)
# Get the latest market data
ticker_data = ticker.history(period="1d")
if ticker_data.empty:
results.append({"error": f"No data found for ticker {ticker_symbol}"})
continue
# Get the latest closing price
latest_price = ticker_data["Close"].iloc[-1]
# Get some additional info
results.append({"symbol": ticker_symbol, "price": latest_price})
except Exception as e:
import traceback
logging.error(f"Error fetching data for {ticker_symbol}: {e}")
logging.error(traceback.format_exc())
results.append(
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
)
return results[0] if len(results) == 1 else results
# %%
def DateTime(timezone="America/Los_Angeles"):
"""
Returns the current date and time in the specified timezone in ISO 8601 format.
Args:
timezone (str): Timezone name (e.g., "UTC", "America/New_York", "Europe/London")
Default is "America/Los_Angeles"
Returns:
str: Current date and time with timezone in the format YYYY-MM-DDTHH:MM:SS+HH:MM
"""
try:
if not timezone or timezone == "system":
timezone = "America/Los_Angeles"
# Get the current time as a timezone-aware datetime (anchored to the server's zone)
local_tz = pytz.timezone("America/Los_Angeles")
local_now = datetime.now(tz=local_tz)
# Convert to target timezone
target_tz = pytz.timezone(timezone)
target_time = local_now.astimezone(target_tz)
return target_time.isoformat()
except Exception as e:
return {"error": f"Invalid timezone {timezone}: {str(e)}"}
async def GenerateImage(llm, model: str, prompt: str):
# Placeholder: image generation is currently stubbed out (its tool definition below is commented out)
return { "image_id": "image-a830a83-bd831" }
async def AnalyzeSite(llm, model: str, url: str, question: str):
"""
Fetches content from a URL, extracts the text, and uses Ollama to summarize it.
Args:
llm: LLM client used to generate the answer
model (str): Name of the model to use
url (str): The URL of the website to analyze
question (str): The question to answer about the site content
Returns:
dict: The second LLM's answer with source metadata, or an error string on failure
"""
try:
# Fetch the webpage
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
logging.info(f"Fetching {url}")
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
logging.info(f"{url} returned. Processing...")
# Parse the HTML
soup = BeautifulSoup(response.text, "html.parser")
# Remove script and style elements
for script in soup(["script", "style"]):
script.extract()
# Get text content
text = soup.get_text(separator=" ", strip=True)
# Clean up text (remove extra whitespace)
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = " ".join(chunk for chunk in chunks if chunk)
# Limit text length if needed (Ollama may have token limits)
max_chars = 100000
if len(text) > max_chars:
text = text[:max_chars] + "..."
# Create Ollama client
# logging.info(f"Requesting summary of: {text}")
# Generate summary using Ollama
prompt = f"CONTENTS:\n\n{text}\n\n{question}"
response = llm.generate(
model=model,
system="You are given the contents of {url}. Answer the question about the contents",
prompt=prompt,
)
# logging.info(response["response"])
return {
"source": "summarizer-llm",
"content": response["response"],
"metadata": DateTime(),
}
except requests.exceptions.RequestException as e:
logging.error(f"Error fetching the URL: {e}")
return f"Error fetching the URL: {str(e)}"
except Exception as e:
logging.error(f"Error processing the website content: {e}")
return f"Error processing the website content: {str(e)}"
# %%
class Function(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
returns: Optional[Dict[str, Any]] = {}
class Tool(BaseModel):
type: str
function: Function
tools : List[Tool] = [
# Tool.model_validate({
# "type": "function",
# "function": {
# "name": "GenerateImage",
# "description": """\
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
# 1. Call this tool when users request images, drawings, or visual content
# 2. This tool returns an image_id (e.g., "img_abc123")
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url)
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
# 6. FORBIDDEN: DO NOT use any other image embedding format
# CORRECT EXAMPLE:
# User: "Draw a cat"
# Tool returns: {"image_id": "img_xyz789"}
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
# WRONG EXAMPLES (DO NOT DO THIS):
# - ![](https://example.com/...)
# - ![Cat image](any_url)
# - <img src="...">
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
# """,
# "parameters": {
# "type": "object",
# "properties": {
# "prompt": {
# "type": "string",
# "description": "Detailed image description including style, colors, subject, composition"
# }
# },
# "required": ["prompt"]
# },
# "returns": {
# "type": "object",
# "properties": {
# "image_id": {
# "type": "string",
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
# }
# }
# }
# }
# }),
Tool.model_validate({
"type": "function",
"function": {
"name": "TickerValue",
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
"parameters": {
"type": "object",
"properties": {
"ticker": {
"type": "string",
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
},
},
"required": ["ticker"],
"additionalProperties": False,
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "AnalyzeSite",
"description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The website URL to download and process",
},
"question": {
"type": "string",
"description": "The question to ask the second LLM about the content",
},
},
"required": ["url", "question"],
"additionalProperties": False,
},
"returns": {
"type": "object",
"properties": {
"source": {
"type": "string",
"description": "Identifier for the source LLM",
},
"content": {
"type": "string",
"description": "The complete response from the second LLM",
},
"metadata": {
"type": "object",
"description": "Additional information about the response",
},
},
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "DateTime",
"description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
}
},
"required": [],
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "WeatherForecast",
"description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
"minLength": 2,
},
"state": {
"type": "string",
"description": "State to find the weather forecast (e.g., 'OR', 'WA').",
"minLength": 2,
},
},
"required": ["city", "state"],
"additionalProperties": False,
},
},
}),
]
class ToolEntry(BaseModel):
enabled: bool = True
tool: Tool
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled]
def all_tools() -> List[ToolEntry]:
return [ToolEntry(tool=tool) for tool in tools]
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled]
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
# __all__.extend(__tool_functions__) # type: ignore
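A sketch of handing these definitions to the model, assuming a recent ollama-python client (the model name is illustrative):

import ollama

response = ollama.chat(
    model="qwen2.5",                      # illustrative model name
    messages=[{"role": "user", "content": "What time is it in Poland?"}],
    tools=llm_tools(all_tools()),         # only enabled tool definitions are sent
)
for call in response.message.tool_calls or []:
    print(call.function.name, call.function.arguments)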