All seems to be working

James Ketr 2025-05-16 10:41:09 -07:00
parent 0d27239ca6
commit 58dadb76f0
30 changed files with 1714 additions and 314 deletions

View File

@@ -77,6 +77,17 @@ RUN apt-get update \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}
+
+# pydub is loaded by torch, which will throw a warning if ffmpeg isn't installed
+RUN apt-get update \
+    && DEBIAN_FRONTEND=noninteractive apt-get install -y software-properties-common \
+    && add-apt-repository -y ppa:kobuk-team/intel-graphics \
+    && apt-get update \
+    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+        ffmpeg \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}
+
 # Prerequisite for ze-monitor
 RUN apt-get update \
     && DEBIAN_FRONTEND=noninteractive apt-get install -y \

cache/.keep vendored Normal file → Executable file
View File

cache/grafana/.keep vendored Normal file → Executable file
View File

cache/prometheus/.keep vendored Normal file → Executable file
View File

File diff suppressed because it is too large

View File

@@ -19,6 +19,7 @@
     "@types/react": "^19.0.12",
     "@types/react-dom": "^19.0.4",
     "@uiw/react-json-view": "^2.0.0-alpha.31",
+    "@uiw/react-markdown-editor": "^6.1.4",
     "jsonrepair": "^3.12.0",
     "markdown-it": "^14.1.0",
     "mermaid": "^11.6.0",

View File

@@ -113,6 +113,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
             const tunables = await response.json();
             serverTunables.system_prompt = tunables.system_prompt;
+            console.log(tunables);
             setSystemPrompt(tunables.system_prompt)
             setSnack("System prompt updated", "success");
         } catch (error) {

@@ -167,31 +168,36 @@ const ControlsPage = (props: BackstoryPageProps) => {
                 body: JSON.stringify({ "reset": types }),
             });
-            if (response.ok) {
-                const data = await response.json();
-                if (data.error) {
-                    throw Error()
-                }
-                for (const [key, value] of Object.entries(data)) {
-                    switch (key) {
-                        case "rags":
-                            setRags(value as Tool[]);
-                            break;
-                        case "tools":
-                            setTools(value as Tool[]);
-                            break;
-                        case "system_prompt":
-                            setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
-                            break;
-                        case "history":
-                            console.log('TODO: handle history reset');
-                            break;
-                    }
-                }
-                setSnack(message, "success");
-            } else {
-                throw Error(`${{ status: response.status, message: response.statusText }}`);
+            if (!response.ok) {
+                throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
             }
+            if (!response.body) {
+                throw new Error('Response body is null');
+            }
+            const data = await response.json();
+            if (data.error) {
+                throw Error(data.error);
+            }
+            for (const [key, value] of Object.entries(data)) {
+                switch (key) {
+                    case "rags":
+                        setRags(value as Tool[]);
+                        break;
+                    case "tools":
+                        setTools(value as Tool[]);
+                        break;
+                    case "system_prompt":
+                        setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
+                        break;
+                    case "history":
+                        console.log('TODO: handle history reset');
+                        break;
+                }
+            }
+            setSnack(message, "success");
         } catch (error) {
             console.error('Fetch error:', error);
             setSnack("Unable to restore defaults", "error");

@@ -203,20 +209,37 @@ const ControlsPage = (props: BackstoryPageProps) => {
         if (systemInfo !== undefined || sessionId === undefined) {
             return;
         }
-        fetch(connectionBase + `/api/system-info/${sessionId}`, {
-            method: 'GET',
-            headers: {
-                'Content-Type': 'application/json',
-            },
-        })
-            .then(response => response.json())
-            .then(data => {
+        const fetchSystemInfo = async () => {
+            try {
+                const response = await fetch(connectionBase + `/api/system-info/${sessionId}`, {
+                    method: 'GET',
+                    headers: {
+                        'Content-Type': 'application/json',
+                    },
+                })
+                if (!response.ok) {
+                    throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
+                }
+                if (!response.body) {
+                    throw new Error('Response body is null');
+                }
+                const data = await response.json();
+                if (data.error) {
+                    throw Error(data.error);
+                }
                 setSystemInfo(data);
-            })
-            .catch(error => {
+            } catch (error) {
                 console.error('Error obtaining system information:', error);
                 setSnack("Unable to obtain system information.", "error");
-            });
+            };
+        }
+        fetchSystemInfo();
     }, [systemInfo, setSystemInfo, setSnack, sessionId])

     useEffect(() => {

@@ -273,25 +296,30 @@ const ControlsPage = (props: BackstoryPageProps) => {
             return;
         }
         const fetchTunables = async () => {
-            // Make the fetch request with proper headers
-            const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
-                method: 'GET',
-                headers: {
-                    'Content-Type': 'application/json',
-                    'Accept': 'application/json',
-                },
-            });
-            const data = await response.json();
-            // console.log("Server tunables: ", data);
-            setServerTunables(data);
-            setSystemPrompt(data["system_prompt"]);
-            setMessageHistoryLength(data["message_history_length"]);
-            setTools(data["tools"]);
-            setRags(data["rags"]);
+            try {
+                // Make the fetch request with proper headers
+                const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
+                    method: 'GET',
+                    headers: {
+                        'Content-Type': 'application/json',
+                        'Accept': 'application/json',
+                    },
+                });
+                const data = await response.json();
+                // console.log("Server tunables: ", data);
+                setServerTunables(data);
+                setSystemPrompt(data["system_prompt"]);
+                setMessageHistoryLength(data["message_history_length"]);
+                setTools(data["tools"]);
+                setRags(data["rags"]);
+            } catch (error) {
+                console.error('Fetch error:', error);
+                setSnack("System prompt update failed", "error");
+            }
         }
         fetchTunables();
-    }, [sessionId, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags]);
+    }, [sessionId, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);

     const toggle = async (type: string, index: number) => {
         switch (type) {

View File

@@ -285,7 +285,9 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
                 throw new Error('Response body is null');
             }
-            setConversation([])
+            setProcessingMessage(undefined);
+            setStreamingMessage(undefined);
+            setConversation([]);
             setNoInteractions(true);
         } catch (e) {

@@ -341,13 +343,23 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
             // Add a small delay to ensure React has time to update the UI
             await new Promise(resolve => setTimeout(resolve, 0));
+            let data: any = query;
+            if (type === "job_description") {
+                data = {
+                    prompt: "",
+                    agent_options: {
+                        job_description: query.prompt,
+                    }
+                }
+            }
             const response = await fetch(connectionBase + `/api/${type}/${sessionId}`, {
                 method: 'POST',
                 headers: {
                     'Content-Type': 'application/json',
                     'Accept': 'application/json',
                 },
-                body: JSON.stringify(query)
+                body: JSON.stringify(data)
             });
             setSnack(`Query sent.`, "info");

View File

@@ -77,7 +77,7 @@ interface MessageMetaData {
         vector_embedding: number[];
     },
     origin: string,
-    rag: any,
+    rag: any[],
     tools?: {
         tool_calls: any[],
     },

@@ -117,8 +117,6 @@ const MessageMeta = (props: MessageMetaProps) => {
     } = props.metadata || {};
     const message: any = props.messageProps.message;
-    rag.forEach((r: any) => r.query = message.prompt);
-
     let llm_submission: string = "<|system|>\n"
     llm_submission += message.system_prompt + "\n\n"
     llm_submission += message.context_prompt

@@ -176,7 +174,10 @@
                 <Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mt: 1, mb: 1, fontWeight: "bold" }}>
                     {tool.name}
                 </Box>
-                <JsonView displayDataTypes={false} objectSortKeys={true} collapsed={1} value={JSON.parse(tool.content)} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }}>
+                <JsonView
+                    displayDataTypes={false}
+                    objectSortKeys={true}
+                    collapsed={1} value={JSON.parse(tool.content)} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }}>
                     <JsonView.String
                         render={({ children, ...reset }) => {
                             if (typeof (children) === "string" && children.match("\n")) {

View File

@@ -151,7 +151,6 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
     }, []);
     const jobResponse = useCallback(async (message: BackstoryMessage) => {
-        console.log('onJobResponse', message);
         if (message.actions && message.actions.includes("job_description")) {
             await jobConversationRef.current.fetchHistory();
         }

View File

@@ -53,7 +53,7 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
             }}
             displayDataTypes={false}
             objectSortKeys={false}
-            collapsed={true}
+            collapsed={1}
             shortenTextAfterLength={100}
             value={fixed}>
             <JsonView.String

View File

@@ -10,7 +10,6 @@ import FormControlLabel from '@mui/material/FormControlLabel';
 import Switch from '@mui/material/Switch';
 import useMediaQuery from '@mui/material/useMediaQuery';
 import { SxProps, useTheme } from '@mui/material/styles';
-import JsonView from '@uiw/react-json-view';
 import Table from '@mui/material/Table';
 import TableBody from '@mui/material/TableBody';
 import TableCell from '@mui/material/TableCell';

@@ -499,13 +498,13 @@ The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '
                 </Box>
             }
-            <Box sx={{ display: "flex", flexDirection: "column" }}>
+            <Box sx={{ display: "flex", flexDirection: "column", flexGrow: 1 }}>
                 {node === null &&
                     <Paper sx={{ m: 0.5, p: 2, flexGrow: 1 }}>
                         Click a point in the scatter-graph to see information about that node.
                     </Paper>
                 }
-                {!inline && node !== null && node.full_content &&
+                {node !== null && node.full_content &&
                     <Scrollable
                         autoscroll={false}
                         sx={{

@@ -521,7 +520,7 @@ The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '
                         node.full_content.split('\n').map((line, index) => {
                             index += 1 + node.chunk_begin;
                             const bgColor = (index > node.line_begin && index <= node.line_end) ? '#f0f0f0' : 'auto';
-                            return <Box key={index} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-child': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
+                            return <Box key={index} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-of-type': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
                                 <Box sx={{ fontFamily: 'courier', fontSize: "0.8rem", minWidth: "2rem", pt: "0.1rem", align: "left", verticalAlign: "top" }}>{index}</Box>
                                 <pre style={{ margin: 0, padding: 0, border: "none", minHeight: "1rem" }} >{line || " "}</pre>
                             </Box>;

View File

@@ -118,6 +118,10 @@ const useAutoScrollToBottom = (
         let shouldScroll = false;
         const scrollTo = scrollToRef.current;
+        if (isPasteEvent && !scrollTo) {
+            console.error("Paste Event triggered without scrollTo");
+        }
         if (scrollTo) {
             // Get positions
             const containerRect = container.getBoundingClientRect();

@@ -130,7 +134,7 @@ const useAutoScrollToBottom = (
                 scrollToRect.top < containerBottom && scrollToRect.bottom > containerTop;
             // Scroll on paste or if TextField is visible and user isn't scrolling up
-            shouldScroll = (isPasteEvent || isTextFieldVisible) && !isUserScrollingUpRef.current;
+            shouldScroll = isPasteEvent || (isTextFieldVisible && !isUserScrollingUpRef.current);
             if (shouldScroll) {
                 requestAnimationFrame(() => {
                     debug && console.debug('Scrolling to container bottom:', {

@@ -198,10 +202,12 @@ const useAutoScrollToBottom = (
         }

         const handlePaste = () => {
+            console.log("handlePaste");
             // Delay scroll check to ensure DOM updates
             setTimeout(() => {
+                console.log("scrolling for handlePaste");
                 requestAnimationFrame(() => checkAndScrollToBottom(true));
-            }, 0);
+            }, 100);
         };

         window.addEventListener('mousemove', pauseScroll);

View File

@@ -1,5 +1,5 @@
 [tool.black]
-line-length = 88
+line-length = 120
 target-version = ['py312']
 include = '\.pyi?$'
 exclude = '''

@@ -19,4 +19,4 @@ ignore_decorators = [
     "@model_validator",
     "@override", "@classmethod"
 ]
 exclude = ["tests/", "__pycache__/"]

View File

@@ -1,7 +1,9 @@
 LLM_TIMEOUT = 600

 from utils import logger
-from pydantic import BaseModel, Field # type: ignore
+from pydantic import BaseModel, Field, ValidationError # type: ignore
+from pydantic_core import PydanticSerializationError # type: ignore
+from typing import List

 from typing import AsyncGenerator, Dict, Optional

@@ -61,6 +63,7 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore
 from utils import (
     rag as Rag,
+    ChromaDBGetResponse,
     tools as Tools,
     Context,
     Conversation,
@@ -69,15 +72,17 @@ from utils import (
     Metrics,
     Tunables,
     defines,
+    check_serializable,
     logger,
 )

-rags = [
-    {
-        "name": "JPK",
-        "enabled": True,
-        "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
-    },
+rags : List[ChromaDBGetResponse] = [
+    ChromaDBGetResponse(
+        name="JPK",
+        enabled=True,
+        description="Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
+    ),
     # { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
 ]
@@ -461,10 +466,8 @@ class WebServer:
             context = self.upsert_context(context_id)
             agent = context.get_agent(agent_type)
             if not agent:
-                return JSONResponse(
-                    {"error": f"{agent_type} is not recognized", "context": context.id},
-                    status_code=404,
-                )
+                response = { "history": [] }
+                return JSONResponse(response)

             data = await request.json()
             try:

@@ -475,8 +478,8 @@ class WebServer:
                         logger.info(f"Resetting {reset_operation}")
                     case "rags":
                         logger.info(f"Resetting {reset_operation}")
-                        context.rags = rags.copy()
-                        response["rags"] = context.rags
+                        context.rags = [ r.model_copy() for r in rags]
+                        response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
                     case "tools":
                         logger.info(f"Resetting {reset_operation}")
                         context.tools = Tools.enabled_tools(Tools.tools)

@@ -537,6 +540,7 @@ class WebServer:
             data = await request.json()
             agent = context.get_agent("chat")
             if not agent:
+                logger.info("chat agent does not exist on this context!")
                 return JSONResponse(
                     {"error": f"chat is not recognized", "context": context.id},
                     status_code=404,

@@ -572,20 +576,20 @@ class WebServer:
                     case "rags":
                         # { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
-                        rags: list[dict[str, Any]] = data[k]
-                        if not rags:
+                        rag_configs: list[dict[str, Any]] = data[k]
+                        if not rag_configs:
                             return JSONResponse(
                                 {
                                     "status": "error",
                                     "message": "RAGs can not be empty.",
                                 }
                             )
-                        for rag in rags:
+                        for config in rag_configs:
                             for context_rag in context.rags:
-                                if context_rag["name"] == rag["name"]:
-                                    context_rag["enabled"] = rag["enabled"]
+                                if context_rag.name == config["name"]:
+                                    context_rag.enabled = config["enabled"]
                         self.save_context(context_id)
-                        return JSONResponse({"rags": context.rags})
+                        return JSONResponse({"rags": [ r.model_dump(mode="json") for r in context.rags]})

                     case "system_prompt":
                         system_prompt = data[k].strip()
@@ -615,12 +619,10 @@ class WebServer:
         @self.app.get("/api/tunables/{context_id}")
         async def get_tunables(context_id: str, request: Request):
             logger.info(f"{request.method} {request.url.path}")
-            if not is_valid_uuid(context_id):
-                logger.warning(f"Invalid context_id: {context_id}")
-                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
             agent = context.get_agent("chat")
             if not agent:
+                logger.info("chat agent does not exist on this context!")
                 return JSONResponse(
                     {"error": f"chat is not recognized", "context": context.id},
                     status_code=404,

@@ -629,7 +631,7 @@ class WebServer:
                 {
                     "system_prompt": agent.system_prompt,
                     "message_history_length": context.message_history_length,
-                    "rags": context.rags,
+                    "rags": [ r.model_dump(mode="json") for r in context.rags ],
                     "tools": [
                         {
                             **t["function"],

@@ -674,6 +676,7 @@ class WebServer:
                 error = {
                     "error": f"Attempt to create agent type: {agent_type} failed: {e}"
                 }
+                logger.info(error)
                 return JSONResponse(error, status_code=404)

             try:
@@ -887,8 +890,35 @@ class WebServer:
         file_path = os.path.join(defines.context_dir, context_id)

         # Serialize the data to JSON and write to file
-        with open(file_path, "w") as f:
-            f.write(context.model_dump_json(by_alias=True))
+        try:
+            # Check for non-serializable fields before dumping
+            serialization_errors = check_serializable(context)
+            if serialization_errors:
+                for error in serialization_errors:
+                    logger.error(error)
+                raise ValueError("Found non-serializable fields in the model")
+            # Dump the model prior to opening file in case there is
+            # a validation error so it doesn't delete the current
+            # context session
+            json_data = context.model_dump_json(by_alias=True)
+            with open(file_path, "w") as f:
+                f.write(json_data)
+        except ValidationError as e:
+            logger.error(e)
+            logger.error(traceback.format_exc())
+            for error in e.errors():
+                print(f"Field: {error['loc'][0]}, Error: {error['msg']}")
+        except PydanticSerializationError as e:
+            logger.error(e)
+            logger.error(traceback.format_exc())
+            logger.error(f"Serialization error: {str(e)}")
+            # Inspect the model to identify problematic fields
+            for field_name, value in context.__dict__.items():
+                if isinstance(value, np.ndarray):
+                    logger.error(f"Field '{field_name}' contains non-serializable type: {type(value)}")
+        except Exception as e:
+            logger.error(traceback.format_exc())
+            logger.error(e)

         return context_id

@@ -942,12 +972,8 @@ class WebServer:
                 self.contexts[context_id] = context
                 logger.info(f"Successfully loaded context {context_id}")
-            except json.JSONDecodeError as e:
-                logger.error(f"Invalid JSON in file: {e}")
             except Exception as e:
                 logger.error(f"Error validating context: {str(e)}")
-                import traceback
                 logger.error(traceback.format_exc())
                 # Fallback to creating a new context
                 self.contexts[context_id] = Context(

@@ -985,7 +1011,7 @@ class WebServer:
         # context.add_agent(JobDescription(system_prompt = system_job_description))
         # context.add_agent(FactCheck(system_prompt = system_fact_check))
         context.tools = Tools.enabled_tools(Tools.tools)
-        context.rags = rags.copy()
+        context.rags_enabled = [ r.name for r in rags ]
         logger.info(f"{context.id} created and added to contexts.")
         self.contexts[context.id] = context
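The save_context change above guards the write with a `check_serializable` helper imported from `utils`; its implementation is not part of this commit. A minimal sketch of such a helper, assuming it only needs to walk pydantic models, dicts, and lists and report numpy arrays by field path, might look like this:

```python
# Hypothetical sketch of the check_serializable helper imported from utils.
# The real implementation is not shown in this diff; this version only walks
# pydantic models, dicts, and sequences, flagging numpy arrays by path.
from typing import List
import numpy as np
from pydantic import BaseModel


def check_serializable(value, path: str = "context") -> List[str]:
    errors: List[str] = []
    if isinstance(value, np.ndarray):
        errors.append(f"{path}: numpy.ndarray is not JSON serializable")
    elif isinstance(value, BaseModel):
        for name, field_value in value.__dict__.items():
            errors.extend(check_serializable(field_value, f"{path}.{name}"))
    elif isinstance(value, dict):
        for key, item in value.items():
            errors.extend(check_serializable(item, f"{path}[{key!r}]"))
    elif isinstance(value, (list, tuple)):
        for index, item in enumerate(value):
            errors.extend(check_serializable(item, f"{path}[{index}]"))
    return errors
```

Returning a list of path strings matches how save_context logs each entry before raising ValueError, so the failing field can be located without losing the existing context file.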

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -20,5 +20,5 @@ observer, file_watcher = Rag.start_file_watcher(
 )

 context = Context(file_watcher=file_watcher)
-data = context.model_dump(mode="json")
-context = Context.from_json(json.dumps(data), file_watcher=file_watcher)
+json_data = context.model_dump(mode="json")
+context = Context.model_validate(json_data)

View File

@@ -0,0 +1,89 @@
# From /opt/backstory run:
# python -m src.tests.test-embedding
import numpy as np  # type: ignore
import logging
import argparse
from ollama import Client  # type: ignore

from ..utils import defines

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

def get_embedding(text: str, embedding_model: str, ollama_server: str) -> np.ndarray:
    """Generate and normalize an embedding for the given text."""
    llm = Client(host=ollama_server)

    # Get embedding
    try:
        response = llm.embeddings(model=embedding_model, prompt=text)
        embedding = np.array(response["embedding"])
    except Exception as e:
        logging.error(f"Failed to get embedding: {e}")
        raise

    # Log diagnostics
    logging.info(f"Input text: {text}")
    logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")

    # Check for invalid embeddings
    if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
        logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
        raise ValueError("Invalid embedding returned from Ollama.")

    # Check normalization
    norm = np.linalg.norm(embedding)
    is_normalized = np.allclose(norm, 1.0, atol=1e-3)
    logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")

    # Normalize if needed
    if not is_normalized:
        embedding = embedding / norm
        logging.info("Embedding normalized manually.")

    return embedding

def main():
    """Main function to generate and normalize an embedding from command-line input."""
    parser = argparse.ArgumentParser(description="Generate embeddings for text using mxbai-embed-large.")
    parser.add_argument(
        "--text",
        type=str,
        nargs="+",  # Allow multiple text inputs
        default=["Test sentence."],
        help="Text(s) to generate embeddings for (default: 'Test sentence.')",
    )
    parser.add_argument(
        "--ollama-server",
        type=str,
        default=defines.ollama_api_url,
        help=f"Ollama server URL (default: {defines.ollama_api_url})",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default=defines.embedding_model,
        help=f"Embedding model name (default: {defines.embedding_model})",
    )
    args = parser.parse_args()

    # Validate input
    for text in args.text:
        if not text or not isinstance(text, str):
            logging.error("Input text must be a non-empty string.")
            raise ValueError("Input text must be a non-empty string.")

    # Generate embeddings for each text
    embeddings = []
    for text in args.text:
        embedding = get_embedding(
            text=text,
            embedding_model=args.embedding_model,
            ollama_server=args.ollama_server,
        )
        embeddings.append(embedding)

if __name__ == "__main__":
    main()
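The normalization branch in get_embedding is the part the rest of the RAG pipeline relies on. As a self-contained illustration of what that step does, detached from the Ollama call and using a made-up vector:

```python
# Stand-alone illustration of the normalization check used in get_embedding;
# the vector here is invented, standing in for a raw Ollama embedding.
import numpy as np

embedding = np.array([3.0, 4.0])           # raw embedding, norm == 5.0
norm = np.linalg.norm(embedding)
if not np.allclose(norm, 1.0, atol=1e-3):  # same tolerance as the script
    embedding = embedding / norm
print(embedding, np.linalg.norm(embedding))  # [0.6 0.8] 1.0
```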

View File

@@ -2,10 +2,15 @@
 # python -m src.tests.test-message
 from ..utils import logger
-from ..utils import Message
+from ..utils import Message, MessageMetaData
+from ..utils import ChromaDBGetResponse
 import json

 prompt = "This is a test"
 message = Message(prompt=prompt)
 print(message.model_dump(mode="json"))
+
+#message.metadata = MessageMetaData()
+rag = ChromaDBGetResponse()
+message.metadata.rag = rag
+print(message.model_dump(mode="json"))

src/tests/test-rag.py Normal file
View File

@@ -0,0 +1,81 @@
# From /opt/backstory run:
# python -m src.tests.test-rag
from ..utils import logger
from pydantic import BaseModel, field_validator  # type: ignore
from prometheus_client import CollectorRegistry  # type: ignore
from typing import List, Dict, Any, Optional
import ollama
import numpy as np  # type: ignore
from ..utils import (rag as Rag, ChromaDBGetResponse)
from ..utils import Context
from ..utils import defines
import json

chroma_results = {
    "ids": ["1", "2"],
    "embeddings": np.array([[1.0, 2.0], [3.0, 4.0]]),
    "documents": ["doc1", "doc2"],
    "metadatas": [{"meta": "data1"}, {"meta": "data2"}],
    "query_embedding": np.array([0.1, 0.2, 0.3])
}

query_embedding = np.array(chroma_results["query_embedding"]).flatten()
umap_2d = np.array([0.4, 0.5])  # Example UMAP output
umap_3d = np.array([0.6, 0.7, 0.8])  # Example UMAP output

rag_metadata = ChromaDBGetResponse(
    query="test",
    query_embedding=query_embedding,
    name="JPK",
    ids=chroma_results.get("ids", []),
    size=2
)

logger.info(json.dumps(rag_metadata.model_dump(mode="json")))

logger.info(f"Assigning type {type(umap_2d)} to rag_metadata.umap_embedding_2d")
rag_metadata.umap_embedding_2d = umap_2d
logger.info(json.dumps(rag_metadata.model_dump(mode="json")))

rag = ChromaDBGetResponse()
rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])
json_str = rag.model_dump(mode="json")
logger.info(json_str)
rag = ChromaDBGetResponse.model_validate(json_str)

llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore
prometheus_collector = CollectorRegistry()

observer, file_watcher = Rag.start_file_watcher(
    llm=llm,
    watch_directory=defines.doc_dir,
    recreate=False,  # Don't recreate if exists
)

context = Context(
    file_watcher=file_watcher,
    prometheus_collector=prometheus_collector,
)

skill = "Codes in C++"
if context.file_watcher:
    chroma_results = context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
    if chroma_results:
        query_embedding = np.array(chroma_results["query_embedding"]).flatten()

        umap_2d = context.file_watcher.umap_model_2d.transform([query_embedding])[0]
        umap_3d = context.file_watcher.umap_model_3d.transform([query_embedding])[0]

        rag_metadata = ChromaDBGetResponse(
            query=skill,
            query_embedding=query_embedding,
            name="JPK",
            ids=chroma_results.get("ids", []),
            embeddings=chroma_results.get("embeddings", []),
            documents=chroma_results.get("documents", []),
            metadatas=chroma_results.get("metadatas", []),
            umap_embedding_2d=umap_2d,
            umap_embedding_3d=umap_3d,
            size=context.file_watcher.collection.count()
        )

json_str = context.model_dump(mode="json")
logger.info(json_str)

View File

@@ -1,25 +1,34 @@
 from __future__ import annotations
 from pydantic import BaseModel # type: ignore
+from typing import (
+    Any,
+    Set
+)
 import importlib
+import json

 from . import defines
 from .context import Context
 from .conversation import Conversation
-from .message import Message, Tunables
-from .rag import ChromaDBFileWatcher, start_file_watcher
+from .message import Message, Tunables, MessageMetaData
+from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher
 from .setup_logging import setup_logging
 from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
 from .metrics import Metrics
+from .check_serializable import check_serializable

 __all__ = [
     "Agent",
-    "Message",
     "Tunables",
+    "MessageMetaData",
     "Context",
     "Conversation",
+    "Message",
     "Metrics",
     "ChromaDBFileWatcher",
+    'ChromaDBGetResponse',
     "start_file_watcher",
+    "check_serializable",
     "logger",
 ]

View File

@@ -165,10 +165,9 @@ class Agent(BaseModel, ABC):
             if message.status != "done":
                 yield message

-        if "rag" in message.metadata and message.metadata["rag"]:
-            for rag in message.metadata["rag"]:
-                for doc in rag["documents"]:
-                    rag_context += f"{doc}\n"
+        for rag in message.metadata.rag:
+            for doc in rag.documents:
+                rag_context += f"{doc}\n"

         message.preamble = {}

@@ -189,7 +188,7 @@ class Agent(BaseModel, ABC):
         llm: Any,
         model: str,
         message: Message,
-        tool_message: Any,
+        tool_message: Any, # llama response message
         messages: List[LLMMessage],
     ) -> AsyncGenerator[Message, None]:
         logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

@@ -199,10 +198,10 @@ class Agent(BaseModel, ABC):
         if not self.context:
             raise ValueError("Context is not set for this agent.")

-        if not message.metadata["tools"]:
+        if not message.metadata.tools:
             raise ValueError("tools field not initialized")

-        tool_metadata = message.metadata["tools"]
+        tool_metadata = message.metadata.tools
         tool_metadata["tool_calls"] = []

         message.status = "tooling"

@@ -301,8 +300,7 @@ class Agent(BaseModel, ABC):
             model=model,
             messages=messages,
             options={
-                **message.metadata["options"],
-                # "temperature": 0.5,
+                **message.metadata.options,
             },
             stream=True,
         ):

@@ -316,12 +314,10 @@ class Agent(BaseModel, ABC):
             if response.done:
                 self.collect_metrics(response)
-                message.metadata["eval_count"] += response.eval_count
-                message.metadata["eval_duration"] += response.eval_duration
-                message.metadata["prompt_eval_count"] += response.prompt_eval_count
-                message.metadata[
-                    "prompt_eval_duration"
-                ] += response.prompt_eval_duration
+                message.metadata.eval_count += response.eval_count
+                message.metadata.eval_duration += response.eval_duration
+                message.metadata.prompt_eval_count += response.prompt_eval_count
+                message.metadata.prompt_eval_duration += response.prompt_eval_duration
                 self.context_tokens = (
                     response.prompt_eval_count + response.eval_count
                 )

@@ -329,9 +325,7 @@ class Agent(BaseModel, ABC):
             yield message

         end_time = time.perf_counter()
-        message.metadata["timers"][
-            "llm_with_tools"
-        ] = f"{(end_time - start_time):.4f}"
+        message.metadata.timers["llm_with_tools"] = end_time - start_time
         return

     def collect_metrics(self, response):

@@ -370,22 +364,22 @@ class Agent(BaseModel, ABC):
             LLMMessage(role="user", content=message.context_prompt.strip())
         )
-        # message.metadata["messages"] = messages
-        message.metadata["options"] = {
+        # message.messages = messages
+        message.metadata.options = {
             "seed": 8911,
             "num_ctx": self.context_size,
             "temperature": temperature, # Higher temperature to encourage tool usage
         }

         # Create a dict for storing various timing stats
-        message.metadata["timers"] = {}
+        message.metadata.timers = {}

         use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
-        message.metadata["tools"] = {
+        message.metadata.tools = {
             "available": llm_tools(self.context.tools),
             "used": False,
         }
-        tool_metadata = message.metadata["tools"]
+        tool_metadata = message.metadata.tools

         if use_tools:
             message.status = "thinking"
@@ -408,17 +402,14 @@ class Agent(BaseModel, ABC):
                 messages=tool_metadata["messages"],
                 tools=tool_metadata["available"],
                 options={
-                    **message.metadata["options"],
-                    # "num_predict": 1024, # "Low" token limit to cut off after tool call
+                    **message.metadata.options,
                 },
                 stream=False, # No need to stream the probe
             )
             self.collect_metrics(response)
             end_time = time.perf_counter()
-            message.metadata["timers"][
-                "tool_check"
-            ] = f"{(end_time - start_time):.4f}"
+            message.metadata.timers["tool_check"] = end_time - start_time
             if not response.message.tool_calls:
                 logger.info("LLM indicates tools will not be used")
                 # The LLM will not use tools, so disable use_tools so we can stream the full response

@@ -442,16 +433,14 @@ class Agent(BaseModel, ABC):
                 messages=tool_metadata["messages"], # messages,
                 tools=tool_metadata["available"],
                 options={
-                    **message.metadata["options"],
+                    **message.metadata.options,
                 },
                 stream=False,
             )
             self.collect_metrics(response)
             end_time = time.perf_counter()
-            message.metadata["timers"][
-                "non_streaming"
-            ] = f"{(end_time - start_time):.4f}"
+            message.metadata.timers["non_streaming"] = end_time - start_time
             if not response:
                 message.status = "error"

@@ -475,9 +464,7 @@ class Agent(BaseModel, ABC):
                     return
                 yield message
             end_time = time.perf_counter()
-            message.metadata["timers"][
-                "process_tool_calls"
-            ] = f"{(end_time - start_time):.4f}"
+            message.metadata.timers["process_tool_calls"] = end_time - start_time
             message.status = "done"
             return

@@ -498,7 +485,7 @@ class Agent(BaseModel, ABC):
             model=model,
             messages=messages,
             options={
-                **message.metadata["options"],
+                **message.metadata.options,
             },
             stream=True,
         ):

@@ -517,12 +504,10 @@ class Agent(BaseModel, ABC):
             if response.done:
                 self.collect_metrics(response)
-                message.metadata["eval_count"] += response.eval_count
-                message.metadata["eval_duration"] += response.eval_duration
-                message.metadata["prompt_eval_count"] += response.prompt_eval_count
-                message.metadata[
-                    "prompt_eval_duration"
-                ] += response.prompt_eval_duration
+                message.metadata.eval_count += response.eval_count
+                message.metadata.eval_duration += response.eval_duration
+                message.metadata.prompt_eval_count += response.prompt_eval_count
+                message.metadata.prompt_eval_duration += response.prompt_eval_duration
                 self.context_tokens = (
                     response.prompt_eval_count + response.eval_count
                 )

@@ -530,7 +515,7 @@ class Agent(BaseModel, ABC):
             yield message

         end_time = time.perf_counter()
-        message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
+        message.metadata.timers["streamed"] = end_time - start_time
         return

     async def process_message(

@@ -560,7 +545,7 @@ class Agent(BaseModel, ABC):
         self.context.processing = True

-        message.metadata["system_prompt"] = (
+        message.system_prompt = (
             f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
         )
         message.context_prompt = ""

@@ -575,11 +560,11 @@ class Agent(BaseModel, ABC):
         message.status = "thinking"
         yield message

-        message.metadata["context_size"] = self.set_optimal_context_size(
+        message.context_size = self.set_optimal_context_size(
             llm, model, prompt=message.context_prompt
         )

-        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
+        message.response = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
         message.status = "thinking"
         yield message

View File

@@ -19,9 +19,10 @@ import time
 import asyncio
 import numpy as np # type: ignore

-from .base import Agent, agent_registry, LLMMessage
-from ..message import Message
-from ..setup_logging import setup_logging
+from . base import Agent, agent_registry, LLMMessage
+from .. message import Message
+from .. rag import ChromaDBGetResponse
+from .. setup_logging import setup_logging

 logger = setup_logging()

@@ -91,8 +92,11 @@ class JobDescription(Agent):
         await asyncio.sleep(1) # Allow the event loop to process the write
         self.context.processing = True

+        job_description = message.preamble["job_description"]
+        resume = message.preamble["resume"]
+
         original_message = message.model_copy()
+        original_message.prompt = job_description
         original_message.response = ""
         self.conversation.add(original_message)

@@ -100,11 +104,9 @@ class JobDescription(Agent):
         self.model = model
         self.metrics.generate_count.labels(agent=self.agent_type).inc()
         with self.metrics.generate_duration.labels(agent=self.agent_type).time():
-            job_description = message.preamble["job_description"]
-            resume = message.preamble["resume"]
             try:
-                async for message in self.generate_factual_tailored_resume(
+                async for message in self.generate_resume(
                     message=message, job_description=job_description, resume=resume
                 ):
                     if message.status != "done":

@@ -124,7 +126,7 @@ class JobDescription(Agent):
         # Done processing, add message to conversation
         self.context.processing = False

-        resume_generation = message.metadata.get("resume_generation", {})
+        resume_generation = message.metadata.resume_generation
         if not resume_generation:
             message.response = (
                 "Generation did not generate metadata necessary for processing."

@@ -341,18 +343,19 @@ Name: {candidate_name}
 ## OUTPUT FORMAT:
 Provide the resume in clean markdown format, ready for the candidate to use.
-
-## REFERENCE (Original Resume):
 """
+        # ## REFERENCE (Original Resume):
+        # """

-        # Add a truncated version of the original resume for reference if it's too long
-        max_resume_length = 25000 # Characters
-        if len(original_resume) > max_resume_length:
-            system_prompt += (
-                original_resume[:max_resume_length]
-                + "...\n[Original resume truncated due to length]"
-            )
-        else:
-            system_prompt += original_resume
+        # # Add a truncated version of the original resume for reference if it's too long
+        # max_resume_length = 25000 # Characters
+        # if len(original_resume) > max_resume_length:
+        #     system_prompt += (
+        #         original_resume[:max_resume_length]
+        #         + "...\n[Original resume truncated due to length]"
+        #     )
+        # else:
+        #     system_prompt += original_resume

         prompt = "Create a tailored professional resume that highlights candidate's skills and experience most relevant to the job requirements. Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no commentary before or after."
         return system_prompt, prompt
@@ -413,7 +416,7 @@ Provide the resume in clean markdown format, ready for the candidate to use.
             yield message
             return

-    def calculate_match_statistics(self, job_requirements, skill_assessment_results):
+    def calculate_match_statistics(self, job_requirements, skill_assessment_results) -> dict[str, dict[str, Any]]:
         """
         Calculate statistics about how well the candidate matches job requirements

@@ -594,9 +597,10 @@ a SPECIFIC skill based solely on their resume and supporting evidence.
 }}
 ```
-
-## CANDIDATE RESUME:
-{resume}
 """
+        # ## CANDIDATE RESUME:
+        # {resume}
+        # """

         # Add RAG content if provided
         if rag_content:

@@ -821,7 +825,7 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
             LLMMessage(role="system", content=system_prompt),
             LLMMessage(role="user", content=prompt),
         ]
-        message.metadata["options"] = {
+        message.metadata.options = {
             "seed": 8911,
             "num_ctx": self.context_size,
             "temperature": temperature, # Higher temperature to encourage tool usage

@@ -837,7 +841,7 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
             model=self.model,
             messages=messages,
             options={
-                **message.metadata["options"],
+                **message.metadata.options,
             },
             stream=True,
         ):

@@ -860,54 +864,40 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
             if response.done:
                 self.collect_metrics(response)
-                message.metadata["eval_count"] += response.eval_count
-                message.metadata["eval_duration"] += response.eval_duration
-                message.metadata["prompt_eval_count"] += response.prompt_eval_count
-                message.metadata[
-                    "prompt_eval_duration"
-                ] += response.prompt_eval_duration
+                message.metadata.eval_count += response.eval_count
+                message.metadata.eval_duration += response.eval_duration
+                message.metadata.prompt_eval_count += response.prompt_eval_count
+                message.metadata.prompt_eval_duration += response.prompt_eval_duration
                 self.context_tokens = response.prompt_eval_count + response.eval_count

         message.chunk = ""
         message.status = "done"
         yield message

-    def rag_function(self, skill: str) -> tuple[str, list[Any]]:
+    def retrieve_rag_content(self, skill: str) -> tuple[str, ChromaDBGetResponse]:
         if self.context is None or self.context.file_watcher is None:
             raise ValueError("self.context or self.context.file_watcher is None")
         try:
             rag_results = ""
-            all_metadata = []
-            chroma_results = self.context.file_watcher.find_similar(
-                query=skill, top_k=5, threshold=0.5
-            )
+            rag_metadata = ChromaDBGetResponse()
+            chroma_results = self.context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
             if chroma_results:
-                chroma_embedding = np.array(
-                    chroma_results["query_embedding"]
-                ).flatten() # Ensure correct shape
-                print(f"Chroma embedding shape: {chroma_embedding.shape}")
-
-                umap_2d = self.context.file_watcher.umap_model_2d.transform(
-                    [chroma_embedding]
-                )[0].tolist()
-                print(
-                    f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
-                ) # Debug output
-
-                umap_3d = self.context.file_watcher.umap_model_3d.transform(
-                    [chroma_embedding]
-                )[0].tolist()
-                print(
-                    f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
-                ) # Debug output
-
-                all_metadata.append(
-                    {
-                        "name": "JPK",
-                        **chroma_results,
-                        "umap_embedding_2d": umap_2d,
-                        "umap_embedding_3d": umap_3d,
-                    }
+                query_embedding = np.array(chroma_results["query_embedding"]).flatten()
+
+                umap_2d = self.context.file_watcher.umap_model_2d.transform([query_embedding])[0]
+                umap_3d = self.context.file_watcher.umap_model_3d.transform([query_embedding])[0]
+
+                rag_metadata = ChromaDBGetResponse(
+                    query=skill,
+                    query_embedding=query_embedding.tolist(),
+                    name="JPK",
+                    ids=chroma_results.get("ids", []),
+                    embeddings=chroma_results.get("embeddings", []),
+                    documents=chroma_results.get("documents", []),
+                    metadatas=chroma_results.get("metadatas", []),
+                    umap_embedding_2d=umap_2d.tolist(),
+                    umap_embedding_3d=umap_3d.tolist(),
+                    size=self.context.file_watcher.collection.count()
                 )

                 for index, metadata in enumerate(chroma_results["metadatas"]):
@@ -919,18 +909,19 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
                         ]
                     ).strip()
                     rag_results += f"""
-Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")} lines {metadata.get("line_begin", 0)}-{metadata.get("line_end", 0)}
+Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
+Document reference: {chroma_results["ids"][index]}
 Content: { content }
 """
-            return rag_results, all_metadata
+            return rag_results, rag_metadata
         except Exception as e:
-            logger.error(e)
             logger.error(traceback.format_exc())
+            logger.error(e)
             exit(0)

-    async def generate_factual_tailored_resume(
+    async def generate_resume(
         self, message: Message, job_description: str, resume: str
     ) -> AsyncGenerator[Message, None]:
         """

@@ -947,12 +938,8 @@ Content: { content }
         if self.context is None:
             raise ValueError(f"context is None in {self.agent_type}")

-        message.status = "thinking"
-        logger.info(message.response)
-        yield message
-
-        message.metadata["resume_generation"] = {}
-        metadata = message.metadata["resume_generation"]
+        message.metadata.resume_generation = {}
+        metadata = message.metadata.resume_generation

         # Stage 1A: Analyze job requirements
         streaming_message = Message(prompt="Analyze job requirements")
         streaming_message.status = "thinking"

@@ -975,8 +962,8 @@ Content: { content }
             prompts = self.process_job_requirements(
                 job_requirements=job_requirements,
                 resume=resume,
-                rag_function=self.rag_function,
-            ) # , retrieve_rag_content)
+                rag_function=self.retrieve_rag_content,
+            )

             # UI should persist this state of the message
             partial = message.model_copy()

@@ -1051,9 +1038,15 @@ Content: { content }
                 partial.title = (
                     f"Skill {index}/{total_prompts}: {description} [{match_level}]"
                 )
-                partial.metadata["rag"] = rag
+                partial.metadata.rag = [rag] # Front-end expects a list of RAG retrievals
                 if skill_description:
-                    partial.response += f"\n\n{skill_description}"
+                    partial.response = f"""
+```json
+{json.dumps(skill_assessment_results[skill_name]["skill_assessment"])}
+```
+{skill_description}
+"""
                 yield partial
                 self.conversation.add(partial)

View File

@@ -7,9 +7,10 @@ import numpy as np # type: ignore
 import logging
 from uuid import uuid4
 from prometheus_client import CollectorRegistry, Counter # type: ignore
+import traceback

 from .message import Message, Tunables
-from .rag import ChromaDBFileWatcher
+from .rag import ChromaDBFileWatcher, ChromaDBGetResponse
 from . import defines
 from . import tools as Tools
 from .agents import AnyAgent

@@ -35,7 +36,7 @@ class Context(BaseModel):
     user_job_description: Optional[str] = None
     user_facts: Optional[str] = None
     tools: List[dict] = Tools.enabled_tools(Tools.tools)
-    rags: List[dict] = []
+    rags: List[ChromaDBGetResponse] = []
     message_history_length: int = 5
     # Class managed fields
     agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore

@@ -82,56 +83,40 @@ class Context(BaseModel):
             if not self.file_watcher:
                 message.response = "No RAG context available."
-                del message.metadata["rag"]
                 message.status = "done"
                 yield message
                 return

-            message.metadata["rag"] = []
             for rag in self.rags:
-                if not rag["enabled"]:
+                if not rag.enabled:
                     continue
-                message.response = f"Checking RAG context {rag['name']}..."
+                message.response = f"Checking RAG context {rag.name}..."
                 yield message

                 chroma_results = self.file_watcher.find_similar(
                     query=message.prompt, top_k=top_k, threshold=threshold
                 )
                 if chroma_results:
-                    entries += len(chroma_results["documents"])
-                    chroma_embedding = np.array(
-                        chroma_results["query_embedding"]
-                    ).flatten() # Ensure correct shape
-                    print(f"Chroma embedding shape: {chroma_embedding.shape}")
-
-                    umap_2d = self.file_watcher.umap_model_2d.transform(
-                        [chroma_embedding]
-                    )[0].tolist()
-                    print(
-                        f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
-                    ) # Debug output
-
-                    umap_3d = self.file_watcher.umap_model_3d.transform(
-                        [chroma_embedding]
-                    )[0].tolist()
-                    print(
-                        f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
-                    ) # Debug output
-
-                    message.metadata["rag"].append(
-                        {
-                            "name": rag["name"],
-                            **chroma_results,
-                            "umap_embedding_2d": umap_2d,
-                            "umap_embedding_3d": umap_3d,
-                            "size": self.file_watcher.collection.count()
-                        }
+                    query_embedding = np.array(chroma_results["query_embedding"]).flatten()
+
+                    umap_2d = self.file_watcher.umap_model_2d.transform([query_embedding])[0]
+                    umap_3d = self.file_watcher.umap_model_3d.transform([query_embedding])[0]
+
+                    rag_metadata = ChromaDBGetResponse(
+                        query=message.prompt,
+                        query_embedding=query_embedding.tolist(),
+                        name=rag.name,
+                        ids=chroma_results.get("ids", []),
+                        embeddings=chroma_results.get("embeddings", []),
+                        documents=chroma_results.get("documents", []),
+                        metadatas=chroma_results.get("metadatas", []),
+                        umap_embedding_2d=umap_2d.tolist(),
+                        umap_embedding_3d=umap_3d.tolist(),
+                        size=self.file_watcher.collection.count()
                     )
-                    message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
-                    yield message
-
-            if entries == 0:
-                del message.metadata["rag"]
+                    message.metadata.rag.append(rag_metadata)
+                    message.response = f"Results from {rag.name} RAG: {len(chroma_results['documents'])} results."
+                    yield message

             message.response = (
                 f"RAG context gathered from results from {entries} documents."

@@ -142,7 +127,8 @@ class Context(BaseModel):
         except Exception as e:
             message.status = "error"
             message.response = f"Error generating RAG results: {str(e)}"
-            logger.error(e)
+            logger.error(traceback.format_exc())
+            logger.error(message.response)
             yield message
             return

View File

@@ -1,12 +1,28 @@
 from pydantic import BaseModel, Field # type: ignore
-from typing import Dict, List, Optional, Any
+from typing import Dict, List, Optional, Any, Union, Mapping
 from datetime import datetime, timezone
+from . rag import ChromaDBGetResponse
+from ollama._types import Options # type: ignore

 class Tunables(BaseModel):
-    enable_rag: bool = Field(default=True) # Enable RAG collection chromadb matching
-    enable_tools: bool = Field(default=True) # Enable LLM to use tools
-    enable_context: bool = Field(default=True) # Add <|context|> field to message
+    enable_rag: bool = True # Enable RAG collection chromadb matching
+    enable_tools: bool = True # Enable LLM to use tools
+    enable_context: bool = True # Add <|context|> field to message
+
+class MessageMetaData(BaseModel):
+    rag: List[ChromaDBGetResponse] = Field(default_factory=list)
+    eval_count: int = 0
+    eval_duration: int = 0
+    prompt_eval_count: int = 0
+    prompt_eval_duration: int = 0
+    context_size: int = 0
+    resume_generation: Optional[Dict[str, Any]] = None
+    options: Optional[Union[Mapping[str, Any], Options]] = None
+    tools: Optional[Dict[str, Any]] = None
+    timers: Optional[Dict[str, float]] = None
+    #resume : str = ""
+    #match_stats: Optional[Dict[str, Dict[str, Any]]] = Field(default=None)

 class Message(BaseModel):
     model_config = {"arbitrary_types_allowed": True} # Allow Event
@ -18,35 +34,21 @@ class Message(BaseModel):
# Generated while processing message # Generated while processing message
status: str = "" # Status of the message status: str = "" # Status of the message
preamble: dict[str, str] = {} # Preamble to be prepended to the prompt preamble: Dict[str, Any] = Field(default_factory=dict) # Preamble to be prepended to the prompt
system_prompt: str = "" # System prompt provided to the LLM system_prompt: str = "" # System prompt provided to the LLM
context_prompt: str = "" # Full content of the message (preamble + prompt) context_prompt: str = "" # Full content of the message (preamble + prompt)
response: str = "" # LLM response to the preamble + query response: str = "" # LLM response to the preamble + query
metadata: Dict[str, Any] = Field( metadata: MessageMetaData = Field(default_factory=MessageMetaData)
default_factory=lambda: {
"rag": [],
"eval_count": 0,
"eval_duration": 0,
"prompt_eval_count": 0,
"prompt_eval_duration": 0,
"context_size": 0,
}
)
network_packets: int = 0 # Total number of streaming packets network_packets: int = 0 # Total number of streaming packets
network_bytes: int = 0 # Total bytes sent while streaming packets network_bytes: int = 0 # Total bytes sent while streaming packets
actions: List[str] = ( actions: List[str] = (
[] []
) # Other session modifying actions performed while processing the message ) # Other session modifying actions performed while processing the message
timestamp: datetime = datetime.now(timezone.utc) timestamp: str = str(datetime.now(timezone.utc))
chunk: str = Field( chunk: str = ""
default="" partial_response: str = ""
) # This needs to be serialized so it will be sent in responses title: str = ""
partial_response: str = Field( context_size: int = 0
default=""
) # This needs to be serialized so it will be sent in responses on timeout
title: str = Field(
default=""
) # This needs to be serialized so it will be sent in responses on timeout
def add_action(self, action: str | list[str]) -> None: def add_action(self, action: str | list[str]) -> None:
"""Add a actions(s) to the message.""" """Add a actions(s) to the message."""


@@ -1,5 +1,5 @@
-from pydantic import BaseModel  # type: ignore
+from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field  # type: ignore
-from typing import List, Optional, Dict, Any
+from typing import List, Optional, Dict, Any, Union
 import os
 import glob
 from pathlib import Path
@@ -7,15 +7,9 @@ import time
 import hashlib
 import asyncio
 import logging
-import os
-import glob
-import time
-import hashlib
-import asyncio
 import json
 import numpy as np  # type: ignore
 import traceback
-import os
 import chromadb
 import ollama
@@ -38,19 +32,51 @@ else:
     # When imported as a module, use relative imports
     from . import defines

-__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
+__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]

 DEFAULT_CHUNK_SIZE = 750
 DEFAULT_CHUNK_OVERLAP = 100


 class ChromaDBGetResponse(BaseModel):
-    ids: List[str]
-    embeddings: Optional[List[List[float]]] = None
-    documents: Optional[List[str]] = None
-    metadatas: Optional[List[Dict[str, Any]]] = None
+    name: str = ""
+    size: int = 0
+    ids: List[str] = []
+    embeddings: List[List[float]] = Field(default=[])
+    documents: List[str] = []
+    metadatas: List[Dict[str, Any]] = []
+    query: str = ""
+    query_embedding: Optional[List[float]] = Field(default=None)
+    umap_embedding_2d: Optional[List[float]] = Field(default=None)
+    umap_embedding_3d: Optional[List[float]] = Field(default=None)
+    enabled: bool = True
+
+    class Config:
+        validate_assignment = True
+
+    @field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
+    @classmethod
+    def validate_embeddings(cls, value, field):
+        logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
+        if value is None:
+            return value
+        if isinstance(value, np.ndarray):
+            if field.field_name == "embeddings":
+                if value.ndim != 2:
+                    raise ValueError(f"{field.name} must be a 2-dimensional NumPy array")
+                return [[float(x) for x in row] for row in value.tolist()]
+            else:
+                if value.ndim != 1:
+                    raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
+                return [float(x) for x in value.tolist()]
+        if field.field_name == "embeddings":
+            if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
+                raise ValueError(f"{field.field_name} must be a list of lists of floats")
+            return [[float(x) for x in sublist] for sublist in value]
+        else:
+            if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
+                raise ValueError(f"{field.field_name} must be a list of floats")
+            return [float(x) for x in value]


 class ChromaDBFileWatcher(FileSystemEventHandler):
     def __init__(
         self,
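For readers unfamiliar with the validator pattern above: with validate_assignment enabled, field validators run on every assignment, so NumPy arrays handed to the model are coerced to plain lists of floats before serialization. A condensed, self-contained sketch of the same idea using Pydantic v2 conventions (illustrative only, not the project's exact class):

import numpy as np
from pydantic import BaseModel, ConfigDict, field_validator

class EmbeddingRecord(BaseModel):
    model_config = ConfigDict(validate_assignment=True)

    query_embedding: list[float] | None = None

    @field_validator("query_embedding", mode="before")
    @classmethod
    def coerce_to_list(cls, value, info):
        # Accept a 1-D NumPy array and store it as a plain list of floats
        if value is None:
            return value
        if isinstance(value, np.ndarray):
            if value.ndim != 1:
                raise ValueError(f"{info.field_name} must be a 1-dimensional array")
            return [float(x) for x in value.tolist()]
        return [float(x) for x in value]

record = EmbeddingRecord()
record.query_embedding = np.random.rand(8)   # validator runs on assignment
print(type(record.query_embedding[0]))       # <class 'float'>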
@@ -323,13 +349,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             n_components=2,
             random_state=8911,
             metric="cosine",
-            n_neighbors=15,
+            n_neighbors=30,
             min_dist=0.1,
         )
         self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
-        logging.info(
-            f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
-        )  # Should be 2
+        # logging.info(
+        #     f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
+        # )  # Should be 2

         logging.info(
             f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
@@ -338,13 +364,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             n_components=3,
             random_state=8911,
             metric="cosine",
-            n_neighbors=15,
+            n_neighbors=30,
-            min_dist=0.1,
+            min_dist=0.01,
         )
         self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
-        logging.info(
-            f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
-        )  # Should be 3
+        # logging.info(
+        #     f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
+        # )  # Should be 3

     def _get_vector_collection(self, recreate=False) -> Collection:
         """Get or create a ChromaDB collection."""
@@ -380,14 +406,36 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         """Split documents into chunks using the text splitter."""
         return self.text_splitter.split_documents(docs)

-    def get_embedding(self, text, normalize=True):
-        """Generate embeddings using Ollama."""
-        response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
-        embedding = response["embedding"]
+    def get_embedding(self, text: str) -> np.ndarray:
+        """Generate and normalize an embedding for the given text."""
+        # Get embedding
+        try:
+            response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
+            embedding = np.array(response["embedding"])
+        except Exception as e:
+            logging.error(f"Failed to get embedding: {e}")
+            raise
+
+        # Log diagnostics
+        logging.info(f"Input text: {text}")
+        logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
+
+        # Check for invalid embeddings
+        if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
+            logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
+            raise ValueError("Invalid embedding returned from Ollama.")
+
+        # Check normalization
+        norm = np.linalg.norm(embedding)
+        is_normalized = np.allclose(norm, 1.0, atol=1e-3)
+        logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
+
+        # Normalize if needed
+        if not is_normalized:
+            embedding = embedding / norm
+            logging.info("Embedding normalized manually.")

-        if normalize:
-            normalized = self._normalize_embeddings(embedding)
-            return normalized
         return embedding

     def add_embeddings_to_collection(self, chunks: List[Chunk]):
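Since cosine similarity against the vector store assumes unit-length vectors, the rewritten get_embedding checks the norm and rescales when the model returns an unnormalized embedding. A standalone sketch of that check in plain NumPy; the tolerance mirrors the diff and the sample vector is fabricated:

import numpy as np

def normalize_if_needed(embedding: np.ndarray, atol: float = 1e-3) -> np.ndarray:
    """Return a unit-length copy of the embedding unless it already is one."""
    if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
        raise ValueError("Invalid embedding: empty, NaN, or infinite values.")
    norm = np.linalg.norm(embedding)
    if np.allclose(norm, 1.0, atol=atol):
        return embedding
    return embedding / norm

vec = np.array([3.0, 4.0])             # norm is 5.0, so it gets rescaled
unit = normalize_if_needed(vec)
print(unit, np.linalg.norm(unit))       # [0.6 0.8] 1.0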