Compare commits


11 Commits

SHA1 Message Date
8a4f94817a Working! 2025-04-30 23:00:16 -07:00
10f28b0e9b Fix package.json with craco 2025-04-30 22:00:38 -07:00
2a3dc56897 Starting to work again 2025-04-30 21:42:30 -07:00
7f24d8870c onion peeling 2025-04-30 16:43:02 -07:00
3094288e46 onion peeling 2025-04-30 16:05:46 -07:00
d1940e18e5 Starting to work again 2025-04-30 15:01:50 -07:00
e607e3a2f2 Starting to work again 2025-04-30 12:57:51 -07:00
4614dbb237 Almost working? 2025-04-29 17:46:10 -07:00
622c33545e Almost working? 2025-04-29 16:48:42 -07:00
c3cf9a9c76 Almost working? 2025-04-29 16:04:43 -07:00
90a83a7313 Almost working? 2025-04-29 15:53:04 -07:00
22 changed files with 1425 additions and 1208 deletions

View File

@@ -293,8 +293,13 @@ RUN { \
     echo ' openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout src/key.pem -out src/cert.pem -subj "/C=US/ST=OR/L=Portland/O=Development/CN=localhost"'; \
     echo ' fi' ; \
     echo ' while true; do'; \
-    echo ' echo "Launching Backstory server..."'; \
-    echo ' python src/server.py "${@}" || echo "Backstory server died. Restarting in 3 seconds."'; \
+    echo ' if [[ ! -e /opt/backstory/block-server ]]; then'; \
+    echo ' echo "Launching Backstory server..."'; \
+    echo ' python src/server.py "${@}" || echo "Backstory server died."'; \
+    echo ' else'; \
+    echo ' echo "block-server file exists. Not launching."'; \
+    echo ' fi' ; \
+    echo ' echo "Sleeping for 3 seconds."'; \
     echo ' sleep 3'; \
     echo ' done' ; \
     echo 'fi'; \
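For readability, the supervision loop these echo lines write into the generated startup script looks like this once emitted (reconstructed directly from the quoted strings above; the /opt/backstory/block-server sentinel file acts as a manual kill switch that keeps the loop from relaunching the server):

    while true; do
      if [[ ! -e /opt/backstory/block-server ]]; then
        echo "Launching Backstory server..."
        python src/server.py "${@}" || echo "Backstory server died."
      else
        echo "block-server file exists. Not launching."
      fi
      echo "Sleeping for 3 seconds."
      sleep 3
    done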

File diff suppressed because it is too large

View File

@@ -18,6 +18,7 @@
     "@types/node": "^16.18.126",
     "@types/react": "^19.0.12",
     "@types/react-dom": "^19.0.4",
+    "@uiw/react-json-view": "^2.0.0-alpha.31",
     "mui-markdown": "^1.2.6",
     "react": "^19.0.0",
     "react-dom": "^19.0.0",
@@ -55,6 +56,7 @@
     ]
   },
   "devDependencies": {
-    "@types/plotly.js": "^2.35.5"
+    "@types/plotly.js": "^2.35.5",
+    "@craco/craco": "^0.0.0"
   }
 }

View File

@@ -26,8 +26,8 @@ interface ConversationHandle {
 interface BackstoryMessage {
   prompt: string;
-  preamble: string;
-  content: string;
+  preamble: {};
+  full_content: string;
   response: string;
   metadata: {
     rag: { documents: [] };
@@ -138,6 +138,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
     let filtered = [];
     if (messageFilter === undefined) {
       filtered = conversation;
+      // console.log('No message filter provided. Using all messages.', filtered);
     } else {
       //console.log('Filtering conversation...')
       filtered = messageFilter(conversation); /* Do not copy conversation or useEffect will loop forever */
@@ -206,8 +207,8 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
     }, {
       role: 'assistant',
       prompt: message.prompt || "",
-      preamble: message.preamble || "",
-      full_content: message.content || "",
+      preamble: message.preamble || {},
+      full_content: message.full_content || "",
       content: message.response || "",
       metadata: message.metadata,
       actions: message.actions,
@@ -403,52 +404,59 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
       try {
         const update = JSON.parse(line);
-        // Force an immediate state update based on the message type
-        if (update.status === 'processing') {
-          // Update processing message with immediate re-render
-          setProcessingMessage({ role: 'status', content: update.message });
-          // Add a small delay to ensure React has time to update the UI
-          await new Promise(resolve => setTimeout(resolve, 0));
-        } else if (update.status === 'done') {
-          // Replace processing message with final result
-          if (onResponse) {
-            update.message = onResponse(update.message);
-          }
-          setProcessingMessage(undefined);
-          const backstoryMessage: BackstoryMessage = update.message;
-          setConversation([
-            ...conversationRef.current, {
-              role: 'user',
-              content: backstoryMessage.prompt || "",
-            }, {
-              role: 'assistant',
-              prompt: backstoryMessage.prompt || "",
-              preamble: backstoryMessage.preamble || "",
-              full_content: backstoryMessage.content || "",
-              content: backstoryMessage.response || "",
-              metadata: backstoryMessage.metadata,
-              actions: backstoryMessage.actions,
-            }] as MessageList);
-          // Add a small delay to ensure React has time to update the UI
-          await new Promise(resolve => setTimeout(resolve, 0));
-          const metadata = update.message.metadata;
-          if (metadata) {
-            const evalTPS = metadata.eval_count * 10 ** 9 / metadata.eval_duration;
-            const promptTPS = metadata.prompt_eval_count * 10 ** 9 / metadata.prompt_eval_duration;
-            setLastEvalTPS(evalTPS ? evalTPS : 35);
-            setLastPromptTPS(promptTPS ? promptTPS : 35);
-            updateContextStatus();
-          }
-        } else if (update.status === 'error') {
-          // Show error
-          setProcessingMessage({ role: 'error', content: update.message });
-          setTimeout(() => {
-            setProcessingMessage(undefined);
-          }, 5000);
-          // Add a small delay to ensure React has time to update the UI
-          await new Promise(resolve => setTimeout(resolve, 0));
-        }
+        switch (update.status) {
+          case 'processing':
+          case 'thinking':
+            // Force an immediate state update based on the message type
+            // Update processing message with immediate re-render
+            setProcessingMessage({ role: 'status', content: update.response });
+            // Add a small delay to ensure React has time to update the UI
+            await new Promise(resolve => setTimeout(resolve, 0));
+            break;
+          case 'done':
+            console.log('Done processing:', update);
+            // Replace processing message with final result
+            if (onResponse) {
+              update.message = onResponse(update);
+            }
+            setProcessingMessage(undefined);
+            const backstoryMessage: BackstoryMessage = update;
+            setConversation([
+              ...conversationRef.current, {
+                // role: 'user',
+                // content: backstoryMessage.prompt || "",
+                // }, {
+                role: 'assistant',
+                origin: type,
+                content: backstoryMessage.response || "",
+                prompt: backstoryMessage.prompt || "",
+                preamble: backstoryMessage.preamble || {},
+                full_content: backstoryMessage.full_content || "",
+                metadata: backstoryMessage.metadata,
+                actions: backstoryMessage.actions,
+              }] as MessageList);
+            // Add a small delay to ensure React has time to update the UI
+            await new Promise(resolve => setTimeout(resolve, 0));
+            const metadata = update.metadata;
+            if (metadata) {
+              const evalTPS = metadata.eval_count * 10 ** 9 / metadata.eval_duration;
+              const promptTPS = metadata.prompt_eval_count * 10 ** 9 / metadata.prompt_eval_duration;
+              setLastEvalTPS(evalTPS ? evalTPS : 35);
+              setLastPromptTPS(promptTPS ? promptTPS : 35);
+              updateContextStatus();
+            }
+            break;
+          case 'error':
+            // Show error
+            setProcessingMessage({ role: 'error', content: update.response });
+            setTimeout(() => {
+              setProcessingMessage(undefined);
+            }, 5000);
+            // Add a small delay to ensure React has time to update the UI
+            await new Promise(resolve => setTimeout(resolve, 0));
+            break;
+        }
       } catch (e) {
         setSnack("Error processing query", "error")
@@ -462,25 +470,44 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>(({
       try {
         const update = JSON.parse(buffer);
-        if (update.status === 'done') {
-          if (onResponse) {
-            update.message = onResponse(update.message);
-          }
-          setProcessingMessage(undefined);
-          const backstoryMessage: BackstoryMessage = update.message;
-          setConversation([
-            ...conversationRef.current, {
-              role: 'user',
-              content: backstoryMessage.prompt || "",
-            }, {
-              role: 'assistant',
-              prompt: backstoryMessage.prompt || "",
-              preamble: backstoryMessage.preamble || "",
-              full_content: backstoryMessage.content || "",
-              content: backstoryMessage.response || "",
-              metadata: backstoryMessage.metadata,
-              actions: backstoryMessage.actions,
-            }] as MessageList);
-        }
+        switch (update.status) {
+          case 'processing':
+          case 'thinking':
+            // Force an immediate state update based on the message type
+            // Update processing message with immediate re-render
+            setProcessingMessage({ role: 'status', content: update.response });
+            // Add a small delay to ensure React has time to update the UI
+            await new Promise(resolve => setTimeout(resolve, 0));
+            break;
+          case 'error':
+            // Show error
+            setProcessingMessage({ role: 'error', content: update.response });
+            setTimeout(() => {
+              setProcessingMessage(undefined);
+            }, 5000);
+            break;
+          case 'done':
+            console.log('Done processing:', update);
+            if (onResponse) {
+              update.message = onResponse(update);
+            }
+            setProcessingMessage(undefined);
+            const backstoryMessage: BackstoryMessage = update;
+            setConversation([
+              ...conversationRef.current, {
+                // role: 'user',
+                // content: backstoryMessage.prompt || "",
+                // }, {
+                role: 'assistant',
+                origin: type,
+                prompt: backstoryMessage.prompt || "",
+                content: backstoryMessage.response || "",
+                preamble: backstoryMessage.preamble || {},
+                full_content: backstoryMessage.full_content || "",
+                metadata: backstoryMessage.metadata,
+                actions: backstoryMessage.actions,
+              }] as MessageList);
+            break;
+        }
       } catch (e) {
         setSnack("Error processing query", "error")
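The switch statements above consume a newline-delimited JSON stream in which every line carries a status of 'processing', 'thinking', 'done', or 'error' plus a response payload (emitted by flush_generator in server.py, shown later in this compare). A minimal sketch of a client for that protocol in Python; the exact payload fields are assumptions taken from this diff:

    import json
    import requests

    def stream_chat(base_url: str, context_id: str, agent_type: str, prompt: str) -> dict:
        """Consume the newline-delimited JSON status stream from the chat endpoint."""
        url = f"{base_url}/api/chat/{context_id}/{agent_type}"
        with requests.post(url, json={"content": prompt}, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line:
                    continue  # skip keep-alive blank lines
                update = json.loads(line)
                if update["status"] in ("processing", "thinking"):
                    print("status:", update.get("response"))    # transient progress
                elif update["status"] == "error":
                    raise RuntimeError(update.get("response"))  # shown as a snack in the UI
                elif update["status"] == "done":
                    return update                               # final message with metadata
        raise RuntimeError("stream ended without a 'done' message")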

View File

@@ -19,6 +19,7 @@ import Typography from '@mui/material/Typography';
 import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
 import { ExpandMore } from './ExpandMore';
 import { SxProps, Theme } from '@mui/material';
+import JsonView from '@uiw/react-json-view';
 import { ChatBubble } from './ChatBubble';
 import { StyledMarkdown } from './StyledMarkdown';
@@ -32,6 +33,8 @@ type MessageRoles = 'info' | 'user' | 'assistant' | 'system' | 'status' | 'error
 type MessageData = {
   role: MessageRoles,
   content: string,
+  full_content?: string,
   disableCopy?: boolean,
   user?: string,
   title?: string,
@@ -48,7 +51,6 @@ interface MessageMetaData {
     vector_embedding: number[];
   },
   origin: string,
-  full_query?: string,
   rag: any,
   tools: any[],
   eval_count: number,
@@ -87,7 +89,6 @@ interface MessageMetaProps {
 const MessageMeta = (props: MessageMetaProps) => {
   const {
     /* MessageData */
-    full_query,
     rag,
     tools,
     eval_count,
@@ -95,7 +96,7 @@ const MessageMeta = (props: MessageMetaProps) => {
     prompt_eval_count,
     prompt_eval_duration,
   } = props.metadata || {};
-  const messageProps = props.messageProps;
+  const message = props.messageProps.message;

   return (<>
     <Box sx={{ fontSize: "0.8rem", mb: 1 }}>
@@ -137,7 +138,7 @@ const MessageMeta = (props: MessageMetaProps) => {
     </TableContainer>
     {
-      full_query !== undefined &&
+      message.full_content !== undefined &&
       <Accordion>
         <AccordionSummary expandIcon={<ExpandMoreIcon />}>
           <Box sx={{ fontSize: "0.8rem" }}>
@@ -145,7 +146,7 @@ const MessageMeta = (props: MessageMetaProps) => {
           </Box>
         </AccordionSummary>
         <AccordionDetails>
-          <pre style={{ "display": "block", "position": "relative" }}><CopyBubble content={full_query?.trim()} />{full_query?.trim()}</pre>
+          <pre style={{ "display": "block", "position": "relative" }}><CopyBubble content={message.full_content?.trim()} />{message.full_content?.trim()}</pre>
         </AccordionDetails>
       </Accordion>
     }
@@ -182,14 +183,18 @@ const MessageMeta = (props: MessageMetaProps) => {
       </Accordion>
     }
     {
-      rag?.name !== undefined && <>
-      <Accordion>
+      rag.map((rag: any) => (
+      <Accordion key={rag.name}>
         <AccordionSummary expandIcon={<ExpandMoreIcon />}>
           <Box sx={{ fontSize: "0.8rem" }}>
             Top RAG {rag.ids.length} matches from '{rag.name}' collection against embedding vector of {rag.query_embedding.length} dimensions
           </Box>
         </AccordionSummary>
         <AccordionDetails>
+          <Box sx={{ fontSize: "0.8rem" }}>
+            UMAP Vector Visualization of '{rag.name}' RAG
+          </Box>
+          <VectorVisualizer inline {...props.messageProps} {...props.metadata} rag={rag} />
           {rag.ids.map((id: number, index: number) => <Box key={index}>
             {index !== 0 && <Divider />}
             <Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "row", mb: 0.5, mt: 0.5 }}>
@@ -205,55 +210,33 @@ const MessageMeta = (props: MessageMetaProps) => {
           )}
         </AccordionDetails>
       </Accordion>
-      <Accordion>
-        <AccordionSummary expandIcon={<ExpandMoreIcon />}>
-          <Box sx={{ fontSize: "0.8rem" }}>
-            UMAP Vector Visualization of RAG
-          </Box>
-        </AccordionSummary>
-        <AccordionDetails>
-          <VectorVisualizer inline {...messageProps} {...props.metadata} rag={rag} />
-        </AccordionDetails>
-      </Accordion>
-      <Accordion>
-        <AccordionSummary expandIcon={<ExpandMoreIcon />}>
-          <Box sx={{ fontSize: "0.8rem" }}>
-            All response fields
-          </Box>
-        </AccordionSummary>
-        <AccordionDetails>
-          {Object.entries(props.messageProps.message)
-            .filter(([key, value]) => key !== undefined && value !== undefined)
-            .map(([key, value]) => (typeof (value) !== "string" || value?.trim() !== "") &&
-              <Accordion key={key}>
-                <AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
-                  {key}
-                </AccordionSummary>
-                <AccordionDetails>
-                  {key === "metadata" &&
-                    Object.entries(value)
-                      .filter(([key, value]) => key !== undefined && value !== undefined)
-                      .map(([key, value]) => (
-                        <Accordion key={`metadata.${key}`}>
-                          <AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
-                            {key}
-                          </AccordionSummary>
-                          <AccordionDetails>
-                            <pre>{`${typeof (value) !== "object" ? value : JSON.stringify(value)}`}</pre>
-                          </AccordionDetails>
-                        </Accordion>
-                      ))}
-                  {key !== "metadata" &&
-                    <pre>{typeof (value) !== "object" ? value : JSON.stringify(value)}</pre>
-                  }
-                </AccordionDetails>
-              </Accordion>
-            )}
-        </AccordionDetails>
-      </Accordion>
-      </>
+      ))
     }
+    <Accordion>
+      <AccordionSummary expandIcon={<ExpandMoreIcon />}>
+        <Box sx={{ fontSize: "0.8rem" }}>
+          All response fields
+        </Box>
+      </AccordionSummary>
+      <AccordionDetails>
+        {Object.entries(message)
+          .filter(([key, value]) => key !== undefined && value !== undefined)
+          .map(([key, value]) => (typeof (value) !== "string" || value?.trim() !== "") &&
+            <Accordion key={key}>
+              <AccordionSummary sx={{ fontSize: "1rem", fontWeight: "bold" }} expandIcon={<ExpandMoreIcon />}>
+                {key}
+              </AccordionSummary>
+              <AccordionDetails>
+                {typeof (value) === "string" ?
                  <pre>{value}</pre> :
+                  <JsonView collapsed={1} value={value as any} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }} />
+                }
+              </AccordionDetails>
+            </Accordion>
+          )}
+      </AccordionDetails>
+    </Accordion>
   </>);
 };

View File

@@ -82,6 +82,7 @@ const emojiMap: Record<string, string> = {
   query: '🔍',
   resume: '📄',
   projects: '📁',
+  jobs: '📁',
   'performance-reviews': '📄',
   news: '📰',
 };
@@ -91,7 +92,8 @@ const colorMap: Record<string, string> = {
   resume: '#4A7A7D', // Dusty Teal — secondary theme color
   projects: '#1A2536', // Midnight Blue — rich and deep
   news: '#D3CDBF', // Warm Gray — soft and neutral
-  'performance-reviews': '#FF0000', // Bright red
+  'performance-reviews': '#FFD0D0', // Light red
+  'jobs': '#F3aD8F', // Warm Gray — soft and neutral
 };
@@ -156,7 +158,7 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = (props: VectorVisualiz
   useEffect(() => {
     if (!result || !result.embeddings) return;
     if (result.embeddings.length === 0) return;
-    console.log('Result:', result);

     const vectors: (number[])[] = [...result.embeddings];
     const documents = [...result.documents || []];
     const metadatas = [...result.metadatas || []];

View File

@@ -2,12 +2,12 @@
 # Ensure input was provided
 if [[ -z "$1" ]]; then
-  echo "Usage: $0 <path/to/python_script.py>"
-  exit 1
+  TARGET=$(readlink -f "src/server.py")
+else
+  TARGET=$(readlink -f "$1")
 fi

 # Resolve user-supplied path to absolute path
-TARGET=$(readlink -f "$1")

 if [[ ! -f "$TARGET" ]]; then
   echo "Target file '$TARGET' not found."

View File

@@ -1,3 +1,7 @@
+from utils import logger
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+
 # %%
 # Imports [standard]
 # Standard library modules (no try-except needed)
@@ -34,6 +38,7 @@ try_import("sklearn")
 import ollama
 import requests
 from bs4 import BeautifulSoup
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, BackgroundTasks
 from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -44,8 +49,10 @@ from sklearn.preprocessing import MinMaxScaler
 from utils import (
     rag as Rag,
-    Context, Conversation, Session, Message, Chat, Resume, JobDescription, FactCheck,
-    defines
+    Context, Conversation, Message,
+    Agent,
+    defines,
+    logger
 )

 from tools import (
@@ -250,25 +257,6 @@ def parse_args():
         default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
     return parser.parse_args()

-def setup_logging(level):
-    global logging
-    numeric_level = getattr(logging, level.upper(), None)
-    if not isinstance(numeric_level, int):
-        raise ValueError(f"Invalid log level: {level}")
-    logging.basicConfig(
-        level=numeric_level,
-        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-        force=True
-    )
-    # Now reduce verbosity for FastAPI, Uvicorn, Starlette
-    for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
-        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
-    logging.info(f"Logging is set to {level} level.")
-
 # %%
@@ -288,10 +276,10 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
         headers = {
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
         }
-        logging.info(f"Fetching {url}")
+        logger.info(f"Fetching {url}")
         response = requests.get(url, headers=headers, timeout=10)
         response.raise_for_status()
-        logging.info(f"{url} returned. Processing...")
+        logger.info(f"{url} returned. Processing...")

         # Parse the HTML
         soup = BeautifulSoup(response.text, "html.parser")
@@ -313,7 +301,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
             text = text[:max_chars] + "..."

         # Create Ollama client
-        # logging.info(f"Requesting summary of: {text}")
+        # logger.info(f"Requesting summary of: {text}")

         # Generate summary using Ollama
         prompt = f"CONTENTS:\n\n{text}\n\n{question}"
@@ -321,7 +309,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
             system="You are given the contents of {url}. Answer the question about the contents",
             prompt=prompt)

-    #logging.info(response["response"])
+    #logger.info(response["response"])

     return {
         "source": "summarizer-llm",
@@ -359,8 +347,23 @@ def llm_tools(tools):
 # %%
 class WebServer:
+    @asynccontextmanager
+    async def lifespan(self, app: FastAPI):
+        # Start the file watcher
+        self.observer, self.file_watcher = Rag.start_file_watcher(
+            llm=self.llm,
+            watch_directory=defines.doc_dir,
+            recreate=False # Don't recreate if exists
+        )
+        logger.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
+        yield
+        if self.observer:
+            self.observer.stop()
+            self.observer.join()
+            logger.info("File watcher stopped")
+
     def __init__(self, llm, model=MODEL_NAME):
-        self.app = FastAPI()
+        self.app = FastAPI(lifespan=self.lifespan)
         self.contexts = {}
         self.llm = llm
         self.model = model
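This hunk moves the file-watcher setup and teardown from the deprecated @app.on_event("startup")/@app.on_event("shutdown") handlers (removed a few hunks below) into FastAPI's lifespan context manager. A minimal standalone sketch of the same pattern, with a placeholder resource standing in for Rag.start_file_watcher:

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Everything before `yield` runs at startup (replaces @app.on_event("startup"));
        # everything after runs at shutdown (replaces @app.on_event("shutdown")).
        app.state.watcher = {"running": True}   # placeholder for the file watcher
        yield
        app.state.watcher["running"] = False    # placeholder for observer.stop()/join()

    app = FastAPI(lifespan=lifespan)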
@@ -375,7 +378,7 @@ class WebServer:
         else:
             allow_origins=["http://battle-linux.ketrenos.com:3000"]

-        logging.info(f"Allowed origins: {allow_origins}")
+        logger.info(f"Allowed origins: {allow_origins}")

         self.app.add_middleware(
             CORSMiddleware,
@@ -385,38 +388,19 @@ class WebServer:
             allow_headers=["*"],
         )

-        @self.app.on_event("startup")
-        async def startup_event():
-            # Start the file watcher
-            self.observer, self.file_watcher = Rag.start_file_watcher(
-                llm=llm,
-                watch_directory=defines.doc_dir,
-                recreate=False # Don't recreate if exists
-            )
-            print(f"API started with {self.file_watcher.collection.count()} documents in the collection")
-
-        @self.app.on_event("shutdown")
-        async def shutdown_event():
-            if self.observer:
-                self.observer.stop()
-                self.observer.join()
-            print("File watcher stopped")
-
         self.setup_routes()

     def setup_routes(self):
         @self.app.get("/")
         async def root():
             context = self.create_context()
-            logging.info(f"Redirecting non-session to {context.id}")
+            logger.info(f"Redirecting non-context to {context.id}")
             return RedirectResponse(url=f"/{context.id}", status_code=307)
             #return JSONResponse({"redirect": f"/{context.id}"})

         @self.app.put("/api/umap/{context_id}")
         async def put_umap(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             try:
                 if not self.file_watcher:
                     raise Exception("File watcher not initialized")
@@ -429,29 +413,36 @@ class WebServer:
                 dimensions = data.get("dimensions", 2)

                 result = self.file_watcher.umap_collection
+                if not result:
+                    return JSONResponse({"error": "No UMAP collection found"}, status_code=404)
+
                 if dimensions == 2:
-                    logging.info("Returning 2D UMAP")
+                    logger.info("Returning 2D UMAP")
                     umap_embedding = self.file_watcher.umap_embedding_2d
                 else:
-                    logging.info("Returning 3D UMAP")
+                    logger.info("Returning 3D UMAP")
                     umap_embedding = self.file_watcher.umap_embedding_3d
+
+                if len(umap_embedding) == 0:
+                    return JSONResponse({"error": "No UMAP embedding found"}, status_code=404)

                 result["embeddings"] = umap_embedding.tolist()

                 return JSONResponse(result)

             except Exception as e:
-                logging.error(e)
+                logger.error(f"put_umap error: {str(e)}")
+                import traceback
+                logger.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, 500)

         @self.app.put("/api/similarity/{context_id}")
         async def put_similarity(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             if not self.file_watcher:
-                return
+                raise Exception("File watcher not initialized")
             if not is_valid_uuid(context_id):
-                logging.warning(f"Invalid context_id: {context_id}")
+                logger.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             try:
@@ -468,13 +459,13 @@ class WebServer:
                     return JSONResponse({"error": "No results found"}, status_code=404)

                 chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-                print(f"Chroma embedding shape: {chroma_embedding.shape}")
+                logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")

                 umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+                logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output

                 umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+                logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output

                 return JSONResponse({
                     **chroma_results,
@@ -484,19 +475,19 @@ class WebServer:
                 })

             except Exception as e:
-                logging.error(e)
+                logger.error(e)
                 #return JSONResponse({"error": str(e)}, 500)

-        @self.app.put("/api/reset/{context_id}/{session_type}")
-        async def put_reset(context_id: str, session_type: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+        @self.app.put("/api/reset/{context_id}/{agent_type}")
+        async def put_reset(context_id: str, agent_type: str, request: Request):
+            logger.info(f"{request.method} {request.url.path}")
             if not is_valid_uuid(context_id):
-                logging.warning(f"Invalid context_id: {context_id}")
+                logger.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
-            session = context.get_session(session_type)
-            if not session:
-                return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)
+            agent = context.get_agent(agent_type)
+            if not agent:
+                return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)

             data = await request.json()
             try:
@@ -504,8 +495,8 @@ class WebServer:
                 for reset_operation in data["reset"]:
                     match reset_operation:
                         case "system_prompt":
-                            logging.info(f"Resetting {reset_operation}")
-                            match session_type:
+                            logger.info(f"Resetting {reset_operation}")
+                            match agent_type:
                                 case "chat":
                                     prompt = system_message
                                 case "job_description":
@@ -515,14 +506,14 @@ class WebServer:
                                 case "fact_check":
                                     prompt = system_message
-                            session.system_prompt = prompt
+                            agent.system_prompt = prompt
                             response["system_prompt"] = { "system_prompt": prompt }
                         case "rags":
-                            logging.info(f"Resetting {reset_operation}")
+                            logger.info(f"Resetting {reset_operation}")
                             context.rags = rags.copy()
                             response["rags"] = context.rags
                         case "tools":
-                            logging.info(f"Resetting {reset_operation}")
+                            logger.info(f"Resetting {reset_operation}")
                             context.tools = default_tools(tools)
                             response["tools"] = context.tools
                         case "history":
@@ -532,19 +523,19 @@ class WebServer:
                                 "fact_check": ("job_description", "resume", "fact_check"),
                                 "chat": ("chat",),
                             }
-                            resets = reset_map.get(session_type, ())
+                            resets = reset_map.get(agent_type, ())
                             for mode in resets:
-                                tmp = context.get_session(mode)
+                                tmp = context.get_agent(mode)
                                 if not tmp:
                                     continue
-                                logging.info(f"Resetting {reset_operation} for {mode}")
+                                logger.info(f"Resetting {reset_operation} for {mode}")
                                 context.conversation = Conversation()
-                                context.context_tokens = round(len(str(session.system_prompt)) * 3 / 4) # Estimate context usage
+                                context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage
                                 response["history"] = []
-                                response["context_used"] = session.context_tokens
+                                response["context_used"] = agent.context_tokens
                         case "message_history_length":
-                            logging.info(f"Resetting {reset_operation}")
+                            logger.info(f"Resetting {reset_operation}")
                             context.message_history_length = DEFAULT_HISTORY_LENGTH
                             response["message_history_length"] = DEFAULT_HISTORY_LENGTH
@@ -559,13 +550,13 @@ class WebServer:
         @self.app.put("/api/tunables/{context_id}")
         async def put_tunables(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             try:
                 context = self.upsert_context(context_id)
                 data = await request.json()
-                session = context.get_session("chat")
-                if not session:
+                agent = context.get_agent("chat")
+                if not agent:
                     return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
                 for k in data.keys():
                     match k:
@@ -600,7 +591,7 @@ class WebServer:
                             system_prompt = data[k].strip()
                             if not system_prompt:
                                 return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
-                            session.system_prompt = system_prompt
+                            agent.system_prompt = system_prompt
                             self.save_context(context_id)
                             return JSONResponse({ "system_prompt": system_prompt })
                         case "message_history_length":
@@ -611,21 +602,21 @@ class WebServer:
                         case _:
                             return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404)
             except Exception as e:
-                logging.error(f"Error in put_tunables: {e}")
+                logger.error(f"Error in put_tunables: {e}")
                 return JSONResponse({"error": str(e)}, status_code=500)

         @self.app.get("/api/tunables/{context_id}")
         async def get_tunables(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             if not is_valid_uuid(context_id):
-                logging.warning(f"Invalid context_id: {context_id}")
+                logger.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
-            session = context.get_session("chat")
-            if not session:
+            agent = context.get_agent("chat")
+            if not agent:
                 return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
             return JSONResponse({
-                "system_prompt": session.system_prompt,
+                "system_prompt": agent.system_prompt,
                 "message_history_length": context.message_history_length,
                 "rags": context.rags,
                 "tools": [ {
@@ -636,35 +627,34 @@ class WebServer:
         @self.app.get("/api/system-info/{context_id}")
         async def get_system_info(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             return JSONResponse(system_info(self.model))

-        @self.app.post("/api/chat/{context_id}/{session_type}")
-        async def post_chat_endpoint(context_id: str, session_type: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+        @self.app.post("/api/chat/{context_id}/{agent_type}")
+        async def post_chat_endpoint(context_id: str, agent_type: str, request: Request):
+            logger.info(f"{request.method} {request.url.path}")
             try:
                 if not is_valid_uuid(context_id):
-                    logging.warning(f"Invalid context_id: {context_id}")
+                    logger.warning(f"Invalid context_id: {context_id}")
                     return JSONResponse({"error": "Invalid context_id"}, status_code=400)
                 context = self.upsert_context(context_id)
                 try:
                     data = await request.json()
-                    session = context.get_session(session_type)
-                    if not session and session_type == "job_description":
-                        logging.info(f"Session {session_type} not found. Returning empty history.")
-                        # Create a new session if it doesn't exist
-                        session = context.get_or_create_session("job_description", system_prompt=system_generate_resume, job_description=data["content"])
+                    agent = context.get_agent(agent_type)
+                    if not agent and agent_type == "job_description":
+                        logger.info(f"Agent {agent_type} not found. Returning empty history.")
+                        # Create a new agent if it doesn't exist
+                        agent = context.get_or_create_agent("job_description", system_prompt=system_generate_resume, job_description=data["content"])
                 except Exception as e:
-                    logging.info(f"Attempt to create session type: {session_type} failed", e)
-                    return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)
+                    logger.info(f"Attempt to create agent type: {agent_type} failed", e)
+                    return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)

                 # Create a custom generator that ensures flushing
                 async def flush_generator():
-                    async for message in self.generate_response(context=context, session=session, content=data["content"]):
+                    async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
                         # Convert to JSON and add newline
-                        yield json.dumps(message) + "\n"
+                        yield json.dumps(message.model_dump(mode='json')) + "\n"
                         # Save the history as its generated
                         self.save_context(context_id)
                         # Explicitly flush after each yield
@@ -681,41 +671,43 @@ class WebServer:
                     }
                 )
             except Exception as e:
-                logging.error(f"Error in post_chat_endpoint: {e}")
+                logger.error(f"Error in post_chat_endpoint: {e}")
                 return JSONResponse({"error": str(e)}, status_code=500)

         @self.app.post("/api/context")
         async def create_context():
             context = self.create_context()
-            logging.info(f"Generated new session as {context.id}")
+            logger.info(f"Generated new agent as {context.id}")
             return JSONResponse({ "id": context.id })

-        @self.app.get("/api/history/{context_id}/{session_type}")
-        async def get_history(context_id: str, session_type: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+        @self.app.get("/api/history/{context_id}/{agent_type}")
+        async def get_history(context_id: str, agent_type: str, request: Request):
+            logger.info(f"{request.method} {request.url.path}")
             try:
                 context = self.upsert_context(context_id)
-                session = context.get_session(session_type)
-                if not session:
-                    logging.info(f"Session {session_type} not found. Returning empty history.")
+                agent = context.get_agent(agent_type)
+                if not agent:
+                    logger.info(f"Agent {agent_type} not found. Returning empty history.")
                     return JSONResponse({ "messages": [] })
-                logging.info(f"History for {session_type} contains {len(session.conversation.messages)} entries.")
-                return session.conversation
+                logger.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
+                return agent.conversation
             except Exception as e:
-                logging.error(f"Error in get_history: {e}")
+                logger.error(f"get_history error: {str(e)}")
+                import traceback
+                logger.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, status_code=404)

         @self.app.get("/api/tools/{context_id}")
         async def get_tools(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             context = self.upsert_context(context_id)
             return JSONResponse(context.tools)

         @self.app.put("/api/tools/{context_id}")
         async def put_tools(context_id: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+            logger.info(f"{request.method} {request.url.path}")
             if not is_valid_uuid(context_id):
-                logging.warning(f"Invalid context_id: {context_id}")
+                logger.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
             try:
@@ -732,17 +724,17 @@ class WebServer:
             return JSONResponse({ "status": "error" }, 405)

-        @self.app.get("/api/context-status/{context_id}/{session_type}")
-        async def get_context_status(context_id, session_type: str, request: Request):
-            logging.info(f"{request.method} {request.url.path}")
+        @self.app.get("/api/context-status/{context_id}/{agent_type}")
+        async def get_context_status(context_id, agent_type: str, request: Request):
+            logger.info(f"{request.method} {request.url.path}")
             if not is_valid_uuid(context_id):
-                logging.warning(f"Invalid context_id: {context_id}")
+                logger.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
-            session = context.get_session(session_type)
-            if not session:
+            agent = context.get_agent(agent_type)
+            if not agent:
                 return JSONResponse({"context_used": 0, "max_context": defines.max_context})
-            return JSONResponse({"context_used": session.context_tokens, "max_context": defines.max_context})
+            return JSONResponse({"context_used": agent.context_tokens, "max_context": defines.max_context})

         @self.app.get("/api/health")
         async def health_check():
@@ -752,57 +744,80 @@ class WebServer:
         async def serve_static(path: str):
             full_path = os.path.join(defines.static_content, path)
             if os.path.exists(full_path) and os.path.isfile(full_path):
-                logging.info(f"Serve static request for {full_path}")
+                logger.info(f"Serve static request for {full_path}")
                 return FileResponse(full_path)
-            logging.info(f"Serve index.html for {path}")
+            logger.info(f"Serve index.html for {path}")
             return FileResponse(os.path.join(defines.static_content, "index.html"))

-    def save_context(self, session_id):
+    def save_context(self, context_id):
         """
-        Serialize a Python dictionary to a file in the sessions directory.
+        Serialize a Python dictionary to a file in the agents directory.

         Args:
-            data: Dictionary containing the session data
-            session_id: UUID string for the context. If it doesn't exist, it is created
+            data: Dictionary containing the agent data
+            context_id: UUID string for the context. If it doesn't exist, it is created

         Returns:
-            The session_id used for the file
+            The context_id used for the file
         """
-        context = self.upsert_context(session_id)
+        context = self.upsert_context(context_id)

-        # Create sessions directory if it doesn't exist
-        if not os.path.exists(defines.session_dir):
-            os.makedirs(defines.session_dir)
+        # Create agents directory if it doesn't exist
+        if not os.path.exists(defines.context_dir):
+            os.makedirs(defines.context_dir)

         # Create the full file path
-        file_path = os.path.join(defines.session_dir, session_id)
+        file_path = os.path.join(defines.context_dir, context_id)

         # Serialize the data to JSON and write to file
         with open(file_path, "w") as f:
             f.write(context.model_dump_json())

-        return session_id
+        return context_id

-    def load_context(self, session_id) -> Context:
+    def load_or_create_context(self, context_id) -> Context:
         """
-        Load a context from a file in the sessions directory.
+        Load a context from a file in the context directory or create a new one if it doesn't exist.

         Args:
-            session_id: UUID string for the context. If it doesn't exist, a new context is created.
+            context_id: UUID string for the context.

         Returns:
             A Context object with the specified ID and default settings.
         """
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")

-        file_path = os.path.join(defines.session_dir, session_id)
+        file_path = os.path.join(defines.context_dir, context_id)

         # Check if the file exists
         if not os.path.exists(file_path):
-            self.contexts[session_id] = self.create_context(session_id)
+            logger.info(f"Context file {file_path} not found. Creating new context.")
+            self.contexts[context_id] = self.create_context(context_id)
         else:
             # Read and deserialize the data
             with open(file_path, "r") as f:
-                self.contexts[session_id] = Context.model_validate_json(f.read())
+                content = f.read()
+                logger.info(f"Loading context from {file_path}, content length: {len(content)}")
+                try:
+                    # Try parsing as JSON first to ensure valid JSON
+                    import json
+                    json_data = json.loads(content)
+                    logger.info("JSON parsed successfully, attempting model validation")
+                    # Now try Pydantic validation
+                    self.contexts[context_id] = Context.model_validate_json(content)
+                    self.contexts[context_id].file_watcher=self.file_watcher
+                    logger.info(f"Successfully loaded context {context_id}")
+                except json.JSONDecodeError as e:
+                    logger.error(f"Invalid JSON in file: {e}")
+                except Exception as e:
+                    logger.error(f"Error validating context: {str(e)}")
+                    import traceback
+                    logger.error(traceback.format_exc())
+                    # Fallback to creating a new context
+                    self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)

-        return self.contexts[session_id]
+        return self.contexts[context_id]

     def create_context(self, context_id = None) -> Context:
         """
@@ -812,18 +827,24 @@ class WebServer:
         Returns:
             A Context object with the specified ID and default settings.
         """
-        context = Context(id=context_id)
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
+        logger.info(f"Creating new context with ID: {context_id}")
+        context = Context(id=context_id, file_watcher=self.file_watcher)
         if os.path.exists(defines.resume_doc):
             context.user_resume = open(defines.resume_doc, "r").read()
-        context.add_session(Chat(system_prompt = system_message))
-        # context.add_session(Resume(system_prompt = system_generate_resume))
-        # context.add_session(JobDescription(system_prompt = system_job_description))
-        # context.add_session(FactCheck(system_prompt = system_fact_check))
+        context.get_or_create_agent(
+            agent_type="chat",
+            system_prompt=system_message)
+        # context.add_agent(Resume(system_prompt = system_generate_resume))
+        # context.add_agent(JobDescription(system_prompt = system_job_description))
+        # context.add_agent(FactCheck(system_prompt = system_fact_check))
         context.tools = default_tools(tools)
         context.rags = rags.copy()
-        logging.info(f"{context.id} created and added to sessions.")
+        logger.info(f"{context.id} created and added to contexts.")
         self.contexts[context.id] = context
         self.save_context(context.id)
         return context
@@ -905,44 +926,42 @@ class WebServer:
         """
         if not context_id:
-            logging.warning("No context ID provided. Creating a new context.")
+            logger.warning("No context ID provided. Creating a new context.")
             return self.create_context()

-        if not is_valid_uuid(context_id):
-            logging.info(f"User requested invalid context_id: {context_id}")
-            raise ValueError("Invalid context_id: {context_id}")
-
         if context_id in self.contexts:
             return self.contexts[context_id]

-        logging.info(f"Context {context_id} not found. Creating new context.")
-        return self.load_context(context_id)
+        logger.info(f"Context {context_id} is not yet loaded.")
+        return self.load_or_create_context(context_id)

     def generate_rag_results(self, context, content):
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
         results_found = False
-        if self.file_watcher:
-            for rag in context.rags:
-                if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now...
-                    yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
-                    chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
-                    if chroma_results:
-                        results_found = True
-                        chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-                        print(f"Chroma embedding shape: {chroma_embedding.shape}")
-                        umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                        print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
-                        umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                        print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
-                        yield {
-                            **chroma_results,
-                            "name": rag["name"],
-                            "umap_embedding_2d": umap_2d,
-                            "umap_embedding_3d": umap_3d
-                        }
+        for rag in context.rags:
+            if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now...
+                yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
+                chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
+                if chroma_results:
+                    results_found = True
+                    chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
+                    logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
+                    umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
+                    logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+                    umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
+                    logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+                    yield {
+                        **chroma_results,
+                        "name": rag["name"],
+                        "umap_embedding_2d": umap_2d,
+                        "umap_embedding_3d": umap_3d
+                    }
         if not results_found:
             yield {"status": "complete", "message": "No RAG context found"}
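Both put_similarity and generate_rag_results project a fresh query embedding into the coordinate space of pre-fitted 2D and 3D UMAP models via transform(), rather than refitting on every request. A minimal sketch of that pattern with umap-learn and toy data (the dimensions and counts here are illustrative assumptions, not the project's actual values):

    import numpy as np
    import umap  # umap-learn

    # Fit once over the stored document embeddings (random stand-in data here).
    stored_embeddings = np.random.rand(200, 384)
    umap_model_2d = umap.UMAP(n_components=2).fit(stored_embeddings)
    umap_model_3d = umap.UMAP(n_components=3).fit(stored_embeddings)

    # Project a new query embedding into the same layout without refitting,
    # mirroring umap_model_2d.transform([chroma_embedding])[0] above.
    query_embedding = np.random.rand(384)
    point_2d = umap_model_2d.transform([query_embedding])[0]  # shape (2,)
    point_3d = umap_model_3d.transform([query_embedding])[0]  # shape (3,)
    print(point_2d.tolist(), point_3d.tolist())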
@@ -956,35 +975,51 @@ class WebServer:
             else:
                 yield {"status": "complete", "message": "RAG processing complete"}

-    # session_type: chat
-    # * Q&A
-    #
-    # session_type: job_description
-    # * First message sets Job Description and generates Resume
-    # * Has content (Job Description)
-    # * Then Q&A of Job Description
-    #
-    # session_type: resume
-    # * First message sets Resume and generates Fact Check
-    # * Has no content
-    # * Then Q&A of Resume
-    #
-    # Fact Check:
-    # * First message sets Fact Check and is Q&A
-    # * Has content
-    # * Then Q&A of Fact Check
-    async def generate_response(self, context : Context, session : Session, content : str):
+    async def generate_response(self, context : Context, agent : Agent, content : str) -> AsyncGenerator[Message, None]:
         if not self.file_watcher:
-            return
+            raise Exception("File watcher not initialized")
+
+        agent_type = agent.get_agent_type()
+        logger.info(f"generate_response: {agent_type}")
+        if agent_type == "chat":
+            message = Message(prompt=content)
+            async for message in agent.prepare_message(message):
+                # logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
+                if message.status == "error":
+                    yield message
+                    return
+                if message.status != "done":
+                    yield message
+            async for message in agent.process_message(self.llm, self.model, message):
+                # logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
+                if message.status == "error":
+                    yield message
+                    return
+                if message.status != "done":
+                    yield message
+            # async for value in agent.generate_llm_response(message):
+            #     logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
+            #     if value.status != "done":
+            #         yield value
+            #     if value.status == "error":
+            #         message.status = "error"
+            #         message.response = value.response
+            #         yield message
+            #         return
+            logger.info("TODO: There is more to do...")
+            yield message
+            return

         if self.processing:
-            logging.info("TODO: Implement delay queing; busy for same session, otherwise return queue size and estimated wait time")
+            logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
             yield {"status": "error", "message": "Busy processing another request."}
             return

         self.processing = True

-        conversation : Conversation = session.conversation
+        conversation : Conversation = agent.conversation
         message = Message(prompt=content)
         del content # Prevent accidental use of content
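The new chat path threads a single Message through a chain of async generators (prepare_message, then process_message), re-yielding every interim update and aborting on error. A minimal self-contained sketch of that forwarding pattern, with class and method names simplified from the diff:

    import asyncio
    from dataclasses import dataclass

    @dataclass
    class Msg:
        prompt: str
        status: str = "processing"
        response: str = ""

    async def prepare(msg: Msg):
        msg.response = "building preamble..."
        yield msg                      # interim update
        msg.status = "done"
        yield msg                      # final update for this stage

    async def generate(msg: Msg):
        msg.status, msg.response = "thinking", "calling LLM..."
        yield msg
        msg.status, msg.response = "done", "final answer"
        yield msg

    async def respond(prompt: str):
        msg = Msg(prompt=prompt)
        for stage in (prepare, generate):
            async for msg in stage(msg):
                if msg.status == "error":
                    yield msg
                    return             # abort the pipeline on error
                if msg.status != "done":
                    yield msg          # forward interim updates only
        yield msg                      # the fully processed message

    async def main():
        async for m in respond("hello"):
            print(m.status, m.response)

    asyncio.run(main())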
@@ -999,36 +1034,36 @@ class WebServer:
         enable_rag = False
         # RAG is disabled when asking questions about the resume
-        if session.session_type == "resume":
+        if agent.get_agent_type() == "resume":
             enable_rag = False

-        # The first time through each session session_type a content_seed may be set for
-        # future chat sessions; use it once, then clear it
-        message.preamble = session.get_and_reset_content_seed()
-        system_prompt = session.system_prompt
+        # The first time through each agent agent_type a content_seed may be set for
+        # future chat agents; use it once, then clear it
+        message.preamble = agent.get_and_reset_content_seed()
+        system_prompt = agent.system_prompt

-        # After the first time a particular session session_type is used, it is handled as a chat.
-        # The number of messages indicating the session is ready for chat varies based on
-        # the session_type of session
-        process_type = session.session_type
+        # After the first time a particular agent agent_type is used, it is handled as a chat.
+        # The number of messages indicating the agent is ready for chat varies based on
+        # the agent_type of agent
+        process_type = agent.get_agent_type()
         match process_type:
             case "job_description":
-                logging.info(f"job_description user_history len: {len(conversation.messages)}")
+                logger.info(f"job_description user_history len: {len(conversation.messages)}")
                 if len(conversation.messages) >= 2: # USER, ASSISTANT
                     process_type = "chat"
             case "resume":
-                logging.info(f"resume user_history len: {len(conversation.messages)}")
+                logger.info(f"resume user_history len: {len(conversation.messages)}")
                 if len(conversation.messages) >= 3: # USER, ASSISTANT, FACT_CHECK
                     process_type = "chat"
             case "fact_check":
-                process_type = "chat" # Fact Check is always a chat session
+                process_type = "chat" # Fact Check is always a chat agent

         match process_type:
             # Normal chat interactions with context history
             case "chat":
                 if not message.prompt:
                     yield {"status": "error", "message": "No query provided for chat."}
-                    logging.info(f"user_history len: {len(conversation.messages)}")
+                    logger.info(f"user_history len: {len(conversation.messages)}")
                     self.processing = False
                     return
@@ -1071,7 +1106,7 @@ class WebServer:
 Use that information to respond to:"""
                 # Use the mode specific system_prompt instead of 'chat'
-                system_prompt = session.system_prompt
+                system_prompt = agent.system_prompt

             # On first entry, a single job_description is provided ("user")
             # Generate a resume to append to RESUME history
@ -1110,10 +1145,10 @@ Use that information to respond to:"""
<|job_description|> <|job_description|>
{message.prompt} {message.prompt}
""" """
tmp = context.get_session("job_description") tmp = context.get_agent("job_description")
if not tmp: if not tmp:
raise Exception(f"Job description session not found.") raise Exception(f"Job description agent not found.")
# Set the content seed for the job_description session # Set the content seed for the job_description agent
tmp.set_content_seed(message.preamble + "<|question|>\nUse the above information to respond to this prompt: ") tmp.set_content_seed(message.preamble + "<|question|>\nUse the above information to respond to this prompt: ")
message.preamble += f""" message.preamble += f"""
@ -1126,7 +1161,7 @@ Use to the above information to respond to this prompt:
""" """
# For all future calls to job_description, use the system_job_description # For all future calls to job_description, use the system_job_description
session.system_prompt = system_job_description agent.system_prompt = system_job_description
# Seed the history for job_description # Seed the history for job_description
stuffingMessage = Message(prompt=message.prompt) stuffingMessage = Message(prompt=message.prompt)
@ -1137,21 +1172,21 @@ Use to the above information to respond to this prompt:
message.add_action("generate_resume") message.add_action("generate_resume")
logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into session'resume'") logger.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'")
logging.info("TODO: For subsequent runs, have the Session handler generate the follow up prompts so they can have correct context preamble") logger.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble")
# Switch to resume session for LLM responses # Switch to resume agent for LLM responses
# message.metadata["origin"] = "resume" # message.metadata["origin"] = "resume"
# session = context.get_or_create_session("resume") # agent = context.get_or_create_agent("resume")
# system_prompt = session.system_prompt # system_prompt = agent.system_prompt
# llm_history = session.llm_history = [] # llm_history = agent.llm_history = []
# user_history = session.user_history = [] # user_history = agent.user_history = []
# Ignore the passed in content and invoke Fact Check # Ignore the passed in content and invoke Fact Check
case "resume": case "resume":
if len(context.get_or_create_session("resume").conversation.messages) < 2: # USER, **ASSISTANT** if len(context.get_or_create_agent("resume").conversation.messages) < 2: # USER, **ASSISTANT**
raise Exception(f"No resume found in user history.") raise Exception(f"No resume found in user history.")
resume = context.get_or_create_session("resume").conversation.messages[1] resume = context.get_or_create_agent("resume").conversation.messages[1]
# Generate RAG content if enabled, based on the content # Generate RAG content if enabled, based on the content
rag_context = "" rag_context = ""
@ -1196,7 +1231,7 @@ Use to the above information to respond to this prompt:
<|question|> <|question|>
""" """
context.get_or_create_session("resume").set_content_seed(f""" context.get_or_create_agent("resume").set_content_seed(f"""
<|resume|> <|resume|>
{resume["content"]} {resume["content"]}
@ -1218,29 +1253,29 @@ Use the above <|resume|> and <|job_description|> to answer this query:
stuffingMessage.metadata["origin"] = "resume" stuffingMessage.metadata["origin"] = "resume"
stuffingMessage.metadata["display"] = "hide" stuffingMessage.metadata["display"] = "hide"
stuffingMessage.actions = [ "fact_check" ] stuffingMessage.actions = [ "fact_check" ]
logging.info("TODO: Switch this to use actions to keep the UI from showingit") logger.info("TODO: Switch this to use actions to keep the UI from showingit")
conversation.add_message(stuffingMessage) conversation.add_message(stuffingMessage)
# For all future calls to job_description, use the system_job_description # For all future calls to job_description, use the system_job_description
logging.info("TODO: Create a system_resume_QA prompt to use for the resume session") logger.info("TODO: Create a system_resume_QA prompt to use for the resume agent")
session.system_prompt = system_prompt agent.system_prompt = system_prompt
# Switch to fact_check session for LLM responses # Switch to fact_check agent for LLM responses
message.metadata["origin"] = "fact_check" message.metadata["origin"] = "fact_check"
session = context.get_or_create_session("fact_check", system_prompt=system_fact_check) agent = context.get_or_create_agent("fact_check", system_prompt=system_fact_check)
llm_history = session.llm_history = [] llm_history = agent.llm_history = []
user_history = session.user_history = [] user_history = agent.user_history = []
case _: case _:
raise Exception(f"Invalid chat session_type: {session_type}") raise Exception(f"Invalid chat agent_type: {agent_type}")
conversation.add_message(message) conversation.add_message(message)
# llm_history.append({"role": "user", "content": message.preamble + content}) # llm_history.append({"role": "user", "content": message.preamble + content})
# user_history.append({"role": "user", "content": content, "origin": message.metadata["origin"]}) # user_history.append({"role": "user", "content": content, "origin": message.metadata["origin"]})
# message.metadata["full_query"] = llm_history[-1]["content"] # message.metadata["full_query"] = llm_history[-1]["content"]
# Uses cached system_prompt as session.system_prompt may have been updated for follow up questions # Uses cached system_prompt as agent.system_prompt may have been updated for follow up questions
messages = create_system_message(system_prompt) messages = create_system_message(system_prompt)
if context.message_history_length: if context.message_history_length:
to_add = conversation.messages[-context.message_history_length:] to_add = conversation.messages[-context.message_history_length:]
@ -1272,12 +1307,12 @@ Use the above <|resume|> and <|job_description|> to answer this query:
{message.prompt}""" {message.prompt}"""
# Estimate token length of new messages # Estimate token length of new messages
ctx_size = self.get_optimal_ctx_size(context.get_or_create_session(process_type).context_tokens, messages=message.prompt) ctx_size = self.get_optimal_ctx_size(context.get_or_create_agent(process_type).context_tokens, messages=message.prompt)
if len(conversation.messages) > 2: if len(conversation.messages) > 2:
processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..." processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..."
else: else:
match session.session_type: match agent.get_agent_type():
case "job_description": case "job_description":
processing_message = f"Generating {'RAG augmented ' if enable_rag else ''}resume..." processing_message = f"Generating {'RAG augmented ' if enable_rag else ''}resume..."
case "resume": case "resume":
@ -1294,7 +1329,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
else: else:
response = self.llm.chat(model=self.model, messages=messages, options={ "num_ctx": ctx_size }) response = self.llm.chat(model=self.model, messages=messages, options={ "num_ctx": ctx_size })
except Exception as e: except Exception as e:
logging.exception({ "model": self.model, "error": str(e) }) logger.exception({ "model": self.model, "error": str(e) })
yield {"status": "error", "message": f"An error occurred communicating with LLM"} yield {"status": "error", "message": f"An error occurred communicating with LLM"}
self.processing = False self.processing = False
return return
@ -1303,7 +1338,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
message.metadata["eval_duration"] += response["eval_duration"] message.metadata["eval_duration"] += response["eval_duration"]
message.metadata["prompt_eval_count"] += response["prompt_eval_count"] message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
session.context_tokens = response["prompt_eval_count"] + response["eval_count"] agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
tools_used = [] tools_used = []
@ -1347,7 +1382,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
message.metadata["tools"] = tools_used message.metadata["tools"] = tools_used
# Estimate token length of new messages # Estimate token length of new messages
ctx_size = self.get_optimal_ctx_size(session.context_tokens, messages=messages[pre_add_index:]) ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size } yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
# Decrease creativity when processing tool call requests # Decrease creativity when processing tool call requests
response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 }) response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
@ -1355,11 +1390,11 @@ Use the above <|resume|> and <|job_description|> to answer this query:
message.metadata["eval_duration"] += response["eval_duration"] message.metadata["eval_duration"] += response["eval_duration"]
message.metadata["prompt_eval_count"] += response["prompt_eval_count"] message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
session.context_tokens = response["prompt_eval_count"] + response["eval_count"] agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
reply = response["message"]["content"] reply = response["message"]["content"]
message.response = reply message.response = reply
message.metadata["origin"] = session.session_type message.metadata["origin"] = agent.get_agent_type()
# final_message = {"role": "assistant", "content": reply } # final_message = {"role": "assistant", "content": reply }
# # history is provided to the LLM and should not have additional metadata # # history is provided to the LLM and should not have additional metadata
@ -1379,7 +1414,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
} }
# except Exception as e: # except Exception as e:
# logging.exception({ "model": self.model, "origin": session_type, "content": content, "error": str(e) }) # logger.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) })
# yield {"status": "error", "message": f"An error occurred: {str(e)}"} # yield {"status": "error", "message": f"An error occurred: {str(e)}"}
# finally: # finally:
@ -1390,7 +1425,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs): def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
try: try:
if self.ssl_enabled: if self.ssl_enabled:
logging.info(f"Starting web server at https://{host}:{port}") logger.info(f"Starting web server at https://{host}:{port}")
uvicorn.run( uvicorn.run(
self.app, self.app,
host=host, host=host,
@ -1400,7 +1435,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
ssl_certfile=defines.cert_path ssl_certfile=defines.cert_path
) )
else: else:
logging.info(f"Starting web server at http://{host}:{port}") logger.info(f"Starting web server at http://{host}:{port}")
uvicorn.run( uvicorn.run(
self.app, self.app,
host=host, host=host,
@ -1423,7 +1458,7 @@ def main():
args = parse_args() args = parse_args()
# Setup logging based on the provided level # Setup logging based on the provided level
setup_logging(args.level) logger.setLevel(args.level.upper())
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",

View File

@ -1,10 +1,67 @@
# Import defines to make `utils.defines` accessible from typing import Optional, Type
from . import defines from . import defines
from . rag import ChromaDBFileWatcher, start_file_watcher
from . message import Message
from . conversation import Conversation
from . context import Context
from . import agents
from . setup_logging import setup_logging
# Import rest as `utils.*` accessible from .agents import Agent, __all__ as agents_all
from .rag import ChromaDBFileWatcher, start_file_watcher
from .message import Message __all__ = [
from .conversation import Conversation 'Agent',
from .session import Session, Chat, Resume, JobDescription, FactCheck 'Context',
from .context import Context 'Conversation',
'Message',
'ChromaDBFileWatcher',
'start_file_watcher',
'logger',
] + agents_all
# Resolve circular dependencies by rebuilding models
# Call model_rebuild() on Agent and Context
Agent.model_rebuild()
Context.model_rebuild()
import importlib
from pydantic import BaseModel
from typing import Type
# Assuming class_registry is available from agents/__init__.py
from .agents import class_registry, AnyAgent
logger = setup_logging(level=defines.logging_level)
def rebuild_models():
for class_name, (module_name, _) in class_registry.items():
try:
module = importlib.import_module(module_name)
cls = getattr(module, class_name, None)
logger.debug(f"Checking: {class_name} in module {module_name}")
logger.debug(f" cls: {True if cls else False}")
logger.debug(f" isinstance(cls, type): {isinstance(cls, type)}")
logger.debug(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}")
logger.debug(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}")
logger.debug(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}")
if (
cls
and isinstance(cls, type)
and issubclass(cls, BaseModel)
and issubclass(cls, AnyAgent)
and cls is not AnyAgent
):
logger.debug(f"Rebuilding {class_name} from {module_name}")
from . agents import Agent
from . context import Context
cls.model_rebuild()
except ImportError as e:
logger.error(f"Failed to import module {module_name}: {e}")
except Exception as e:
logger.error(f"Error processing {class_name} in {module_name}: {e}")
# Call this after all modules are imported
rebuild_models()
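The rebuild dance above exists because Agent and Context refer to each other through forward references, and Pydantic cannot finalize either model until both classes are importable. A minimal sketch of the same pattern, with made-up Node/Tree names standing in for Agent/Context:

from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel

class Node(BaseModel):
    tree: Optional[Tree] = None  # forward reference; Tree is not defined yet

class Tree(BaseModel):
    nodes: List[Node] = []

# Resolve the Node -> Tree forward reference once both classes exist
Node.model_rebuild()
node = Node(tree=Tree())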

View File

@ -0,0 +1,43 @@
from __future__ import annotations
import importlib
import pathlib
import inspect
import logging
from typing import TypeAlias, Dict, Tuple
from pydantic import BaseModel
from . base import Agent
# Type alias for Agent or any subclass
AnyAgent: TypeAlias = Agent # Covers Agent and all of its subclasses
package_dir = pathlib.Path(__file__).parent
package_name = __name__
__all__ = []
class_registry: Dict[str, Tuple[str, str]] = {} # Maps class_name to (module_name, class_name)
for path in package_dir.glob("*.py"):
if path.name in ("__init__.py", "base.py") or path.name.startswith("_"):
continue
module_name = path.stem
full_module_name = f"{package_name}.{module_name}"
try:
module = importlib.import_module(full_module_name)
# Find all Agent subclasses in the module
for name, obj in inspect.getmembers(module, inspect.isclass):
if (
issubclass(obj, AnyAgent)
and obj is not AnyAgent
and obj is not Agent
and name not in class_registry
):
class_registry[name] = (full_module_name, name)
globals()[name] = obj
logging.info(f"Adding agent: {name} from {full_module_name}")
__all__.append(name)
except ImportError as e:
logging.error(f"Failed to import module {full_module_name}: {e}")
__all__.append("AnyAgent")

src/utils/agents/base.py Normal file
View File

@ -0,0 +1,258 @@
from __future__ import annotations
from pydantic import BaseModel, model_validator, PrivateAttr, Field
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef, Any
from abc import ABC, abstractmethod
from typing_extensions import Annotated
from .. setup_logging import setup_logging
logger = setup_logging()
# Only import Context for type checking
if TYPE_CHECKING:
from .. context import Context
from .types import registry
from .. conversation import Conversation
from .. message import Message
class Agent(BaseModel, ABC):
"""
Base class for all agent types.
This class defines the common attributes and methods for all agent types.
"""
# Agent management with pydantic
agent_type: Literal["base"] = "base"
_agent_type: ClassVar[str] = agent_type # Add this for registration
# Agent properties
system_prompt: str # Mandatory
conversation: Conversation = Conversation()
context_tokens: int = 0
context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization
_content_seed: str = PrivateAttr(default="")
# Class and pydantic model management
def __init_subclass__(cls, **kwargs):
"""Auto-register subclasses"""
super().__init_subclass__(**kwargs)
# Register this class if it has an agent_type
if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type:
registry.register(cls.agent_type, cls)
def model_dump(self, *args, **kwargs):
# Ensure context is always excluded, even with exclude_unset=True
kwargs.setdefault("exclude", set())
if isinstance(kwargs["exclude"], set):
kwargs["exclude"].add("context")
elif isinstance(kwargs["exclude"], dict):
kwargs["exclude"]["context"] = True
return super().model_dump(*args, **kwargs)
@classmethod
def valid_agent_types(cls) -> set[str]:
"""Return the set of valid agent_type values."""
return set(get_args(cls.__annotations__["agent_type"]))
def set_context(self, context):
object.__setattr__(self, "context", context)
# Agent methods
def get_agent_type(self):
return self._agent_type
async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
"""
Prepare message with context information in message.preamble
"""
# Generate RAG content if enabled, based on the content
rag_context = ""
if not message.disable_rag:
# Gather RAG results, yielding each result
# as it becomes available
for value in self.context.generate_rag_results(message):
logger.info(f"RAG: {value.status} - {value.response}")
if value.status != "done":
yield value
if value.status == "error":
message.status = "error"
message.response = value.response
yield message
return
if message.metadata["rag"]:
for rag_collection in message.metadata["rag"]:
for doc in rag_collection["documents"]:
rag_context += f"{doc}\n"
if rag_context:
message.preamble["context"] = rag_context
if self.context.user_resume:
message.preamble["resume"] = self.context.user_resume
if message.preamble:
preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
preamble_types_AND = " and ".join(preamble_types)
preamble_types_OR = " or ".join(preamble_types)
message.preamble["rules"] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and referring to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
message.preamble["question"] = "Use that information to respond to:"
else:
message.preamble["question"] = "Respond to:"
message.system_prompt = self.system_prompt
message.status = "done"
yield message
return
async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
if self.context.processing:
logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
message.status = "error"
message.response = "Busy processing another request."
yield message
return
self.context.processing = True
messages = []
for value in llm.chat(
model=model,
messages=messages,
#tools=llm_tools(context.tools) if message.enable_tools else None,
options={ "num_ctx": message.ctx_size }
):
logger.info(f"LLM: {value.status} - {value.response}")
if value.status != "done":
message.status = value.status
message.response = value.response
yield message
if value.status == "error":
return
response = value
message.metadata["eval_count"] += response["eval_count"]
message.metadata["eval_duration"] += response["eval_duration"]
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
tools_used = []
yield {"status": "processing", "message": "Initial response received..."}
if "tool_calls" in response.get("message", {}):
yield {"status": "processing", "message": "Processing tool calls..."}
tool_message = response["message"]
tool_result = None
# Process all yielded items from the handler
async for item in self.handle_tool_calls(tool_message):
if isinstance(item, tuple) and len(item) == 2:
# This is the final result tuple (tool_result, tools_used)
tool_result, tools_used = item
else:
# This is a status update, forward it
yield item
message_dict = {
"role": tool_message.get("role", "assistant"),
"content": tool_message.get("content", "")
}
if "tool_calls" in tool_message:
message_dict["tool_calls"] = [
{"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
for tc in tool_message["tool_calls"]
]
pre_add_index = len(messages)
messages.append(message_dict)
if isinstance(tool_result, list):
messages.extend(tool_result)
else:
if tool_result:
messages.append(tool_result)
message.metadata["tools"] = tools_used
# Estimate token length of new messages
ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=messages[pre_add_index:])
yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
# Decrease creativity when processing tool call requests
response = llm.chat(model=model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
message.metadata["eval_count"] += response["eval_count"]
message.metadata["eval_duration"] += response["eval_duration"]
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
reply = response["message"]["content"]
message.response = reply
message.metadata["origin"] = agent.agent_type
# final_message = {"role": "assistant", "content": reply }
# # history is provided to the LLM and should not have additional metadata
# llm_history.append(final_message)
# user_history is provided to the REST API and does not include CONTEXT
# It does include metadata
# final_message["metadata"] = message.metadata
# user_history.append({**final_message, "origin": message.metadata["origin"]})
# Return the REST API with metadata
yield {
"status": "done",
"message": {
**message.model_dump(mode='json'),
}
}
self.context.processing = False
return
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
message.full_content = ""
for i, p in enumerate(message.preamble.keys()):
message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"
# Estimate token length of new messages
message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.status = "thinking"
yield message
async for value in self.generate_llm_response(llm, model, message):
logger.info(f"LLM: {value.status} - {value.response}")
if value.status != "done":
yield value
if value.status == "error":
return
def get_and_reset_content_seed(self):
tmp = self._content_seed
self._content_seed = ""
return tmp
def set_content_seed(self, content: str) -> None:
"""Set the content seed for the agent."""
self._content_seed = content
def get_content_seed(self) -> str:
"""Get the content seed for the agent."""
return self._content_seed
# Register the base agent
registry.register(Agent._agent_type, Agent)
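Because __init_subclass__ on Agent registers any subclass that declares its own agent_type, defining a new agent class is enough to make it discoverable. A hedged sketch (Summarizer is invented for illustration, and assumes the module scanner in agents/__init__.py picks the file up):

from typing import ClassVar, Literal

class Summarizer(Agent):
    agent_type: Literal["summarizer"] = "summarizer"
    _agent_type: ClassVar[str] = agent_type
    system_prompt: str = "Summarize the supplied documents."

# No explicit registry.register(...) call is needed:
print(registry.get_class("summarizer"))  # expected: <class 'Summarizer'>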

src/utils/agents/chat.py Normal file
View File

@ -0,0 +1,246 @@
from __future__ import annotations
from pydantic import BaseModel, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any
from typing_extensions import Annotated
from abc import ABC, abstractmethod
from typing_extensions import Annotated
import logging
from .base import Agent, registry
from .. conversation import Conversation
from .. message import Message
from .. import defines
class Chat(Agent, ABC):
"""
Chat agent.
Implements the default prepare/generate/process message flow shared by chat-style agents.
"""
agent_type: Literal["chat"] = "chat"
_agent_type: ClassVar[str] = agent_type # Add this for registration
async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
"""
Prepare message with context information in message.preamble
"""
if not self.context:
raise ValueError("Context is not set for this agent.")
# Generate RAG content if enabled, based on the content
rag_context = ""
if not message.disable_rag:
# Gather RAG results, yielding each result
# as it becomes available
for message in self.context.generate_rag_results(message):
logging.info(f"RAG: {message.status} - {message.response}")
if message.status == "error":
yield message
return
if message.status != "done":
yield message
if "rag" in message.metadata and message.metadata["rag"]:
for rag in message.metadata["rag"]:
for doc in rag["documents"]:
rag_context += f"{doc}\n"
message.preamble = {}
if rag_context:
message.preamble["context"] = rag_context
if self.context.user_resume:
message.preamble["resume"] = self.context.user_resume
if message.preamble:
preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
preamble_types_AND = " and ".join(preamble_types)
preamble_types_OR = " or ".join(preamble_types)
message.preamble["rules"] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and referring to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
message.preamble["question"] = "Use that information to respond to:"
else:
message.preamble["question"] = "Respond to:"
message.system_prompt = self.system_prompt
message.status = "done"
yield message
return
async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
if not self.context:
raise ValueError("Context is not set for this agent.")
if self.context.processing:
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
message.status = "error"
message.response = "Busy processing another request."
yield message
return
self.context.processing = True
self.conversation.add_message(message)
messages = [
item for m in self.conversation.messages
for item in [
{"role": "user", "content": m.prompt},
{"role": "assistant", "content": m.response}
]
]
response = None
for value in llm.chat(
model=model,
messages=messages,
#tools=llm_tools(context.tools) if message.enable_tools else None,
options={ "num_ctx": message.metadata["ctx_size"] if message.metadata["ctx_size"] else defines.max_context },
stream=True,
):
logging.info(f"LLM: {'done' if value.done else 'thinking'} - {value.message.content}")
message.response += value.message.content
yield message
if value.done:
response = value
message.status = "done"
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
self.context.processing = False
return
message.metadata["eval_count"] += response["eval_count"]
message.metadata["eval_duration"] += response["eval_duration"]
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
yield message
self.context.processing = False
return
tools_used = []
if "tool_calls" in response.get("message", {}):
message.status = "thinking"
message.response = "Processing tool calls..."
tool_message = response["message"]
tool_result = None
# Process all yielded items from the handler
async for value in self.handle_tool_calls(tool_message):
if isinstance(value, tuple) and len(value) == 2:
# This is the final result tuple (tool_result, tools_used)
tool_result, tools_used = value
else:
# This is a status update, forward it
yield value
message_dict = {
"role": tool_message.get("role", "assistant"),
"content": tool_message.get("content", "")
}
if "tool_calls" in tool_message:
message_dict["tool_calls"] = [
{"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
for tc in tool_message["tool_calls"]
]
pre_add_index = len(messages)
messages.append(message_dict)
if isinstance(tool_result, list):
messages.extend(tool_result)
else:
if tool_result:
messages.append(tool_result)
message.metadata["tools"] = tools_used
# Estimate token length of new messages
ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=messages[pre_add_index:])
yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
# Decrease creativity when processing tool call requests
response = llm.chat(model=model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
message.metadata["eval_count"] += response["eval_count"]
message.metadata["eval_duration"] += response["eval_duration"]
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
reply = response["message"]["content"]
message.response = reply
message.metadata["origin"] = agent.agent_type
# final_message = {"role": "assistant", "content": reply }
# # history is provided to the LLM and should not have additional metadata
# llm_history.append(final_message)
# user_history is provided to the REST API and does not include CONTEXT
# It does include metadata
# final_message["metadata"] = message.metadata
# user_history.append({**final_message, "origin": message.metadata["origin"]})
# Return the REST API with metadata
yield {
"status": "done",
"message": {
**message.model_dump(mode='json'),
}
}
self.context.processing = False
return
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
if not self.context:
raise ValueError("Context is not set for this agent.")
message.full_content = f"<|system|>{self.system_prompt.strip()}\n"
for i, p in enumerate(message.preamble.keys()):
message.full_content += f"\n<|{p}|>\n{message.preamble[p].strip()}\n"
message.full_content += f"{message.prompt}"
# Estimate token length of new messages
message.metadata["ctx_size"] = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.status = "thinking"
yield message
async for message in self.generate_llm_response(llm, model, message):
logging.info(f"LLM: {message.status} - {message.response}")
if message.status == "error":
return
if message.status != "done":
yield message
yield message
return
def get_and_reset_content_seed(self):
tmp = self._content_seed
self._content_seed = ""
return tmp
def set_content_seed(self, content: str) -> None:
"""Set the content seed for the agent."""
self._content_seed = content
def get_content_seed(self) -> str:
"""Get the content seed for the agent."""
return self._content_seed
@classmethod
def valid_agent_types(cls) -> set[str]:
"""Return the set of valid agent_type values."""
return set(get_args(cls.__annotations__["agent_type"]))
# Register the base agent
registry.register(Chat._agent_type, Chat)
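For reference, the full_content that Chat.process_message assembles above lays out roughly as follows; the values are invented, and the tag order follows whatever keys prepare_message placed into message.preamble:

preamble = {
    "context": "RAG documents gathered for the query...",
    "resume": "# Generic resume...",
    "rules": "- Answer the question based on the information provided...",
    "question": "Use that information to respond to:",
}
full_content = "<|system|>You are Backstory.\n"
for p, text in preamble.items():
    full_content += f"\n<|{p}|>\n{text.strip()}\n"
full_content += "What roles has the candidate held?"  # the user prompt goes last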

View File

@ -0,0 +1,24 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from typing_extensions import Annotated
from abc import ABC, abstractmethod
from typing_extensions import Annotated
import logging
from .base import Agent, registry
from .. conversation import Conversation
from .. message import Message
class FactCheck(Agent):
agent_type: Literal["fact_check"] = "fact_check"
_agent_type: ClassVar[str] = agent_type # Add this for registration
facts: str = ""
@model_validator(mode="after")
def validate_facts(self):
if not self.facts.strip():
raise ValueError("Facts cannot be empty")
return self
# Register the base agent
registry.register(FactCheck._agent_type, FactCheck)

View File

@ -0,0 +1,24 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from typing_extensions import Annotated
from abc import ABC, abstractmethod
from typing_extensions import Annotated
import logging
from .base import Agent, registry
from .. conversation import Conversation
from .. message import Message
class JobDescription(Agent):
agent_type: Literal["job_description"] = "job_description"
_agent_type: ClassVar[str] = agent_type # Add this for registration
job_description: str = ""
@model_validator(mode="after")
def validate_job_description(self):
if not self.job_description.strip():
raise ValueError("Job description cannot be empty")
return self
# Register the base agent
registry.register(JobDescription._agent_type, JobDescription)

View File

@ -0,0 +1,32 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from typing_extensions import Annotated
from abc import ABC, abstractmethod
from typing_extensions import Annotated
import logging
from .base import Agent, registry
from .. conversation import Conversation
from .. message import Message
class Resume(Agent):
agent_type: Literal["resume"] = "resume"
_agent_type: ClassVar[str] = agent_type # Add this for registration
resume: str = ""
@model_validator(mode="after")
def validate_resume(self):
if not self.resume.strip():
raise ValueError("Resume content cannot be empty")
return self
def get_resume(self) -> str:
"""Get the resume content."""
return self.resume
def set_resume(self, resume: str) -> None:
"""Set the resume content."""
self.resume = resume
# Register the base agent
registry.register(Resume._agent_type, Resume)

src/utils/agents/types.py Normal file
View File

@ -0,0 +1,38 @@
from __future__ import annotations
from typing import List, Dict, Any, Union, ForwardRef, TypeVar, Optional, TYPE_CHECKING, Type, ClassVar, Literal
from typing_extensions import Annotated
from pydantic import Field, BaseModel
from abc import ABC, abstractmethod
# Forward references
AgentRef = ForwardRef('Agent')
ContextRef = ForwardRef('Context')
# We'll use a registry pattern rather than hardcoded strings
class AgentRegistry:
"""Registry for agent types and classes"""
_registry: Dict[str, Type] = {}
@classmethod
def register(cls, agent_type: str, agent_class: Type) -> Type:
"""Register an agent class with its type"""
cls._registry[agent_type] = agent_class
return agent_class
@classmethod
def get_class(cls, agent_type: str) -> Optional[Type]:
"""Get the class for a given agent type"""
return cls._registry.get(agent_type)
@classmethod
def get_types(cls) -> List[str]:
"""Get all registered agent types"""
return list(cls._registry.keys())
@classmethod
def get_classes(cls) -> Dict[str, Type]:
"""Get all registered agent classes"""
return cls._registry.copy()
# Create a singleton instance
registry = AgentRegistry()
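A quick round-trip through the singleton shows the intended API; EchoAgent here is a stand-in, since real agents subclass Agent and are registered automatically:

class EchoAgent:
    pass

registry.register("echo", EchoAgent)
assert registry.get_class("echo") is EchoAgent
assert "echo" in registry.get_types()
assert registry.get_class("missing") is None  # unknown types return None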

View File

@ -1,19 +1,32 @@
from pydantic import BaseModel, Field, model_validator from __future__ import annotations
from pydantic import BaseModel, Field, model_validator, ValidationError
from uuid import uuid4 from uuid import uuid4
from typing import List, Optional from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
from typing_extensions import Annotated, Union from typing_extensions import Annotated, Union
from .session import AnySession, Session import numpy as np
import logging
from uuid import uuid4
import re
from .message import Message
from .rag import ChromaDBFileWatcher
from . import defines
from .agents import AnyAgent
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Context(BaseModel): class Context(BaseModel):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
# Required fields
file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
# Optional fields
id: str = Field( id: str = Field(
default_factory=lambda: str(uuid4()), default_factory=lambda: str(uuid4()),
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
) )
sessions: List[Annotated[Union[*Session.__subclasses__()], Field(discriminator="session_type")]] = Field(
default_factory=list
)
user_resume: Optional[str] = None user_resume: Optional[str] = None
user_job_description: Optional[str] = None user_job_description: Optional[str] = None
user_facts: Optional[str] = None user_facts: Optional[str] = None
@ -21,78 +34,160 @@ class Context(BaseModel):
rags: List[dict] = [] rags: List[dict] = []
message_history_length: int = 5 message_history_length: int = 5
context_tokens: int = 0 context_tokens: int = 0
# Class managed fields
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
default_factory=list
)
def __init__(self, id: Optional[str] = None, **kwargs): processing: bool = Field(default=False, exclude=True)
super().__init__(id=id if id is not None else str(uuid4()), **kwargs)
# @model_validator(mode="before")
# @classmethod
# def before_model_validator(cls, values: Any):
# logger.info(f"Preparing model data: {cls} {values}")
# return values
@model_validator(mode="after") @model_validator(mode="after")
def validate_unique_session_types(self): def after_model_validator(self):
"""Ensure at most one session per session_type.""" """Ensure at most one agent per agent_type."""
session_types = [session.session_type for session in self.sessions] logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
if len(session_types) != len(set(session_types)): agent_types = [agent.agent_type for agent in self.agents]
raise ValueError("Context cannot contain multiple sessions of the same session_type") if len(agent_types) != len(set(agent_types)):
raise ValueError("Context cannot contain multiple agents of the same agent_type")
for agent in self.agents:
agent.set_context(self)
return self return self
def get_or_create_session(self, session_type: str, **kwargs) -> Session: def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
ctx = round(context + len(str(messages)) * 3 / 4)
return min(defines.max_context, max(2048, ctx + ctx_buffer))
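# Worked example of the clamp above, assuming defines.max_context == 2048*8*2 == 32768:
#   context=1000 tokens and ~2000 chars of new messages -> ctx ≈ 1000 + 1500 = 2500
#   2500 + 4096 buffer = 6596 -> min(32768, max(2048, 6596)) == 6596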
def generate_rag_results(self, message: Message) -> Generator[Message, None, None]:
""" """
Get or create and append a new session of the specified type, ensuring only one session per type exists. Generate RAG results for the given query.
Args: Args:
session_type: The type of session to create (e.g., 'web', 'database'). query: The query string to generate RAG results for.
**kwargs: Additional fields required by the specific session subclass.
Returns: Returns:
The created session instance. A list of dictionaries containing the RAG results.
"""
try:
message.status = "processing"
entries : int = 0
if not self.file_watcher:
message.response = "No RAG context available."
del message.metadata["rag"]
message.status = "done"
yield message
return
message.metadata["rag"] = []
for rag in self.rags:
if not rag["enabled"]:
continue
message.response = f"Checking RAG context {rag['name']}..."
yield message
chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10)
if chroma_results:
entries += len(chroma_results["documents"])
chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
print(f"Chroma embedding shape: {chroma_embedding.shape}")
umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
message.metadata["rag"].append({
"name": rag["name"],
**chroma_results,
"umap_embedding_2d": umap_2d,
"umap_embedding_3d": umap_3d
})
yield message
if entries == 0:
del message.metadata["rag"]
message.response = f"RAG context gathered from results from {entries} documents."
message.status = "done"
yield message
return
except Exception as e:
message.status = "error"
message.response = f"Error generating RAG results: {str(e)}"
logger.error(e)
yield message
return
def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
"""
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
Args:
agent_type: The type of agent to create (e.g., 'web', 'database').
**kwargs: Additional fields required by the specific agent subclass.
Returns:
The created agent instance.
Raises: Raises:
ValueError: If no matching session type is found or if a session of this type already exists. ValueError: If no matching agent type is found or if an agent of this type already exists.
""" """
# Check if a session with the given session_type already exists # Check if a agent with the given agent_type already exists
for session in self.sessions: for agent in self.agents:
if session.session_type == session_type: if agent.agent_type == agent_type:
return session return agent
# Find the matching subclass # Find the matching subclass
for session_cls in Session.__subclasses__(): for agent_cls in Agent.__subclasses__():
if session_cls.model_fields["session_type"].default == session_type: if agent_cls.model_fields["agent_type"].default == agent_type:
# Create the session instance with provided kwargs # Create the agent instance with provided kwargs
session = session_cls(session_type=session_type, **kwargs) agent = agent_cls(agent_type=agent_type, context=self, **kwargs)
self.sessions.append(session) self.agents.append(agent)
return session return agent
raise ValueError(f"No session class found for session_type: {session_type}") raise ValueError(f"No agent class found for agent_type: {agent_type}")
def add_session(self, session: AnySession) -> None: def add_agent(self, agent: AnyAgent) -> None:
"""Add a Session to the context, ensuring no duplicate session_type.""" """Add a Agent to the context, ensuring no duplicate agent_type."""
if any(s.session_type == session.session_type for s in self.sessions): if any(s.agent_type == agent.agent_type for s in self.agents):
raise ValueError(f"A session with session_type '{session.session_type}' already exists") raise ValueError(f"A agent with agent_type '{agent.agent_type}' already exists")
self.sessions.append(session) self.agents.append(agent)
def get_session(self, session_type: str) -> Session | None: def get_agent(self, agent_type: str) -> Agent | None:
"""Return the Session with the given session_type, or None if not found.""" """Return the Agent with the given agent_type, or None if not found."""
for session in self.sessions: for agent in self.agents:
if session.session_type == session_type: if agent.agent_type == agent_type:
return session return agent
return None return None
def is_valid_session_type(self, session_type: str) -> bool: def is_valid_agent_type(self, agent_type: str) -> bool:
"""Check if the given session_type is valid.""" """Check if the given agent_type is valid."""
return session_type in Session.valid_session_types() return agent_type in Agent.valid_agent_types()
def get_summary(self) -> str: def get_summary(self) -> str:
"""Return a summary of the context.""" """Return a summary of the context."""
if not self.sessions: if not self.agents:
return f"Context {self.uuid}: No sessions." return f"Context {self.uuid}: No agents."
summary = f"Context {self.uuid}:\n" summary = f"Context {self.uuid}:\n"
for i, session in enumerate(self.sessions, 1): for i, agent in enumerate(self.agents, 1):
summary += f"\nSession {i} ({session.session_type}):\n" summary += f"\nAgent {i} ({agent.agent_type}):\n"
summary += session.conversation.get_summary() summary += agent.conversation.get_summary()
if session.session_type == "resume": if agent.agent_type == "resume":
summary += f"\nResume: {session.get_resume()}\n" summary += f"\nResume: {agent.get_resume()}\n"
elif session.session_type == "job_description": elif agent.agent_type == "job_description":
summary += f"\nJob Description: {session.job_description}\n" summary += f"\nJob Description: {agent.job_description}\n"
elif session.session_type == "fact_check": elif agent.agent_type == "fact_check":
summary += f"\nFacts: {session.facts}\n" summary += f"\nFacts: {agent.facts}\n"
elif session.session_type == "chat": elif agent.agent_type == "chat":
summary += f"\nChat Name: {session.name}\n" summary += f"\nChat Name: {agent.name}\n"
return summary return summary
from . agents import Agent
Context.model_rebuild()
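A minimal sketch of wiring a Context to its agents through the helpers above; it assumes the Chat agent module has been imported so Agent.__subclasses__() can see it, and the prompts are invented:

context = Context(user_resume="# Generic resume ...")
chat = context.get_or_create_agent("chat", system_prompt="You are Backstory.")
assert context.get_agent("chat") is chat  # second lookup returns the cached agent
assert chat.context is context            # back-reference passed at construction
# add_agent() raises ValueError if an agent of the same agent_type already exists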

View File

@ -9,9 +9,10 @@ embedding_model = os.getenv("EMBEDDING_MODEL_NAME", "mxbai-embed-large")
persist_directory = os.getenv("PERSIST_DIR", "/opt/backstory/chromadb") persist_directory = os.getenv("PERSIST_DIR", "/opt/backstory/chromadb")
max_context = 2048*8*2 max_context = 2048*8*2
doc_dir = "/opt/backstory/docs/" doc_dir = "/opt/backstory/docs/"
session_dir = "/opt/backstory/sessions" context_dir = "/opt/backstory/sessions"
static_content = "/opt/backstory/frontend/deployed" static_content = "/opt/backstory/frontend/deployed"
resume_doc = "/opt/backstory/docs/resume/generic.md" resume_doc = "/opt/backstory/docs/resume/generic.md"
# Only used for testing; backstory-prod will not use this # Only used for testing; backstory-prod will not use this
key_path = "/opt/backstory/keys/key.pem" key_path = "/opt/backstory/keys/key.pem"
cert_path = "/opt/backstory/keys/cert.pem" cert_path = "/opt/backstory/keys/cert.pem"
logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()

View File

@ -3,19 +3,29 @@ from typing import Dict, List, Optional, Any
from datetime import datetime, timezone from datetime import datetime, timezone
class Message(BaseModel): class Message(BaseModel):
prompt: str # Required
preamble: str = "" prompt: str # Query to be answered
content: str = ""
response: str = "" # Tunables
disable_rag: bool = False
disable_tools: bool = False
# Generated while processing message
status: str = "" # Status of the message
preamble: dict[str,str] = {} # Preamble to be prepended to the prompt
system_prompt: str = "" # System prompt provided to the LLM
full_content: str = "" # Full content of the message (preamble + prompt)
response: str = "" # LLM response to the preamble + query
metadata: dict[str, Any] = { metadata: dict[str, Any] = {
"rag": { "documents": [] }, "rag": List[dict[str, Any]],
"tools": [], "tools": [],
"eval_count": 0, "eval_count": 0,
"eval_duration": 0, "eval_duration": 0,
"prompt_eval_count": 0, "prompt_eval_count": 0,
"prompt_eval_duration": 0, "prompt_eval_duration": 0,
"ctx_size": 0,
} }
actions: List[str] = [] actions: List[str] = [] # Other session-modifying actions performed while processing the message
timestamp: datetime = datetime.now(timezone.utc) timestamp: datetime = datetime.now(timezone.utc)
def add_action(self, action: str | list[str]) -> None: def add_action(self, action: str | list[str]) -> None:
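The hunk ends mid-class, but the reshaped fields above are enough to sketch typical construction; the values below are invented:

msg = Message(prompt="What roles has the candidate held?")
msg.preamble["resume"] = "# Generic resume ..."  # normally filled in by prepare_message
msg.metadata["ctx_size"] = 8192
msg.status = "thinking"
msg.add_action("generate_resume")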

View File

@ -1,3 +1,4 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
import os import os
import glob import glob
from pathlib import Path from pathlib import Path
@ -51,8 +52,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap self.chunk_overlap = chunk_overlap
self.loop = loop self.loop = loop
self._umap_collection = None
self._umap_embedding_2d = []
self._umap_embedding_3d = []
self._umap_model_2d = None
self._umap_model_3d = None
self._collection = None
self.md = MarkItDown(enable_plugins=False) # Set to True to enable plugins self.md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
#self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') #self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

View File

@ -1,78 +0,0 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args
from .conversation import Conversation
class Session(BaseModel):
session_type: Literal["resume", "job_description", "fact_check", "chat"]
system_prompt: str # Mandatory
conversation: Conversation = Conversation()
context_tokens: int = 0
_content_seed: str = PrivateAttr(default="")
def get_and_reset_content_seed(self):
tmp = self._content_seed
self._content_seed = ""
return tmp
def set_content_seed(self, content: str) -> None:
"""Set the content seed for the session."""
self._content_seed = content
def get_content_seed(self) -> str:
"""Get the content seed for the session."""
return self._content_seed
@classmethod
def valid_session_types(cls) -> set[str]:
"""Return the set of valid session_type values."""
return set(get_args(cls.__annotations__["session_type"]))
# Type alias for Session or any subclass
AnySession: TypeAlias = Session # BaseModel covers Session and subclasses
class Resume(Session):
session_type: Literal["resume"] = "resume"
resume: str = ""
@model_validator(mode="after")
def validate_resume(self):
if not self.resume.strip():
raise ValueError("Resume content cannot be empty")
return self
def get_resume(self) -> str:
"""Get the resume content."""
return self.resume
def set_resume(self, resume: str) -> None:
"""Set the resume content."""
self.resume = resume
class JobDescription(Session):
session_type: Literal["job_description"] = "job_description"
job_description: str = ""
@model_validator(mode="after")
def validate_job_description(self):
if not self.job_description.strip():
raise ValueError("Job description cannot be empty")
return self
class FactCheck(Session):
session_type: Literal["fact_check"] = "fact_check"
facts: str = ""
@model_validator(mode="after")
def validate_facts(self):
if not self.facts.strip():
raise ValueError("Facts cannot be empty")
return self
class Chat(Session):
session_type: Literal["chat"] = "chat"
@model_validator(mode="after")
def validate_name(self):
return self

View File

@ -0,0 +1,32 @@
import os
import warnings
import logging
from . import defines
def setup_logging(level=defines.logging_level) -> logging.Logger:
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
warnings.filterwarnings("ignore", message="Warning only once for all operators")
warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
warnings.filterwarnings("ignore", message="n_jobs value 1 overridden")
numeric_level = getattr(logging, level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f"Invalid log level: {level}")
logging.basicConfig(
level=numeric_level,
format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True
)
# Now reduce verbosity for FastAPI, Uvicorn, Starlette
for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
logging.getLogger(noisy_logger).setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
return logger
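Typical use at process start, mirroring the logger.setLevel() override that server.py applies for CLI flags; a sketch under assumed import paths, not repo code:

import logging
from utils.setup_logging import setup_logging  # assumed import path

logger = setup_logging(level="INFO")
logger.info("Backstory starting")            # emitted
logging.getLogger("uvicorn").info("boot")    # suppressed; uvicorn is capped at WARNING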