All seems to be working

James Ketr 2025-05-16 10:41:09 -07:00
parent 0d27239ca6
commit 58dadb76f0
30 changed files with 1714 additions and 314 deletions

View File

@ -77,6 +77,17 @@ RUN apt-get update \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}
# pydub is loaded by torch, which will throw a warning if ffmpeg isn't installed
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y software-properties-common \
&& add-apt-repository -y ppa:kobuk-team/intel-graphics \
&& apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
ffmpeg \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}
# Prerequisite for ze-monitor
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \

cache/.keep (vendored, 0 lines changed): Normal file → Executable file
View File

cache/grafana/.keep (vendored, 0 lines changed): Normal file → Executable file
View File

cache/prometheus/.keep (vendored, 0 lines changed): Normal file → Executable file
View File

File diff suppressed because it is too large

View File

@ -19,6 +19,7 @@
"@types/react": "^19.0.12",
"@types/react-dom": "^19.0.4",
"@uiw/react-json-view": "^2.0.0-alpha.31",
"@uiw/react-markdown-editor": "^6.1.4",
"jsonrepair": "^3.12.0",
"markdown-it": "^14.1.0",
"mermaid": "^11.6.0",

View File

@ -113,6 +113,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
const tunables = await response.json();
serverTunables.system_prompt = tunables.system_prompt;
console.log(tunables);
setSystemPrompt(tunables.system_prompt)
setSnack("System prompt updated", "success");
} catch (error) {
@ -167,11 +168,19 @@ const ControlsPage = (props: BackstoryPageProps) => {
body: JSON.stringify({ "reset": types }),
});
if (response.ok) {
if (!response.ok) {
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
}
if (!response.body) {
throw new Error('Response body is null');
}
const data = await response.json();
if (data.error) {
throw Error()
throw Error(data.error);
}
for (const [key, value] of Object.entries(data)) {
switch (key) {
case "rags":
@ -189,9 +198,6 @@ const ControlsPage = (props: BackstoryPageProps) => {
}
}
setSnack(message, "success");
} else {
throw Error(`${{ status: response.status, message: response.statusText }}`);
}
} catch (error) {
console.error('Fetch error:', error);
setSnack("Unable to restore defaults", "error");
@ -203,20 +209,37 @@ const ControlsPage = (props: BackstoryPageProps) => {
if (systemInfo !== undefined || sessionId === undefined) {
return;
}
fetch(connectionBase + `/api/system-info/${sessionId}`, {
const fetchSystemInfo = async () => {
try {
const response = await fetch(connectionBase + `/api/system-info/${sessionId}`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
},
})
.then(response => response.json())
.then(data => {
if (!response.ok) {
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
}
if (!response.body) {
throw new Error('Response body is null');
}
const data = await response.json();
if (data.error) {
throw Error(data.error);
}
setSystemInfo(data);
})
.catch(error => {
} catch (error) {
console.error('Error obtaining system information:', error);
setSnack("Unable to obtain system information.", "error");
});
};
}
fetchSystemInfo();
}, [systemInfo, setSystemInfo, setSnack, sessionId])
useEffect(() => {
@ -273,6 +296,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
return;
}
const fetchTunables = async () => {
try {
// Make the fetch request with proper headers
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
method: 'GET',
@ -288,10 +312,14 @@ const ControlsPage = (props: BackstoryPageProps) => {
setMessageHistoryLength(data["message_history_length"]);
setTools(data["tools"]);
setRags(data["rags"]);
} catch (error) {
console.error('Fetch error:', error);
setSnack("System prompt update failed", "error");
}
}
fetchTunables();
}, [sessionId, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags]);
}, [sessionId, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);
const toggle = async (type: string, index: number) => {
switch (type) {

View File

@ -285,7 +285,9 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
throw new Error('Response body is null');
}
setConversation([])
setProcessingMessage(undefined);
setStreamingMessage(undefined);
setConversation([]);
setNoInteractions(true);
} catch (e) {
@ -341,13 +343,23 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
// Add a small delay to ensure React has time to update the UI
await new Promise(resolve => setTimeout(resolve, 0));
let data: any = query;
if (type === "job_description") {
data = {
prompt: "",
agent_options: {
job_description: query.prompt,
}
}
}
const response = await fetch(connectionBase + `/api/${type}/${sessionId}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'application/json',
},
body: JSON.stringify(query)
body: JSON.stringify(data)
});
setSnack(`Query sent.`, "info");

View File

@ -77,7 +77,7 @@ interface MessageMetaData {
vector_embedding: number[];
},
origin: string,
rag: any,
rag: any[],
tools?: {
tool_calls: any[],
},
@ -117,8 +117,6 @@ const MessageMeta = (props: MessageMetaProps) => {
} = props.metadata || {};
const message: any = props.messageProps.message;
rag.forEach((r: any) => r.query = message.prompt);
let llm_submission: string = "<|system|>\n"
llm_submission += message.system_prompt + "\n\n"
llm_submission += message.context_prompt
@ -176,7 +174,10 @@ const MessageMeta = (props: MessageMetaProps) => {
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mt: 1, mb: 1, fontWeight: "bold" }}>
{tool.name}
</Box>
<JsonView displayDataTypes={false} objectSortKeys={true} collapsed={1} value={JSON.parse(tool.content)} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }}>
<JsonView
displayDataTypes={false}
objectSortKeys={true}
collapsed={1} value={JSON.parse(tool.content)} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }}>
<JsonView.String
render={({ children, ...reset }) => {
if (typeof (children) === "string" && children.match("\n")) {

View File

@ -151,7 +151,6 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
}, []);
const jobResponse = useCallback(async (message: BackstoryMessage) => {
console.log('onJobResponse', message);
if (message.actions && message.actions.includes("job_description")) {
await jobConversationRef.current.fetchHistory();
}

View File

@ -53,7 +53,7 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
}}
displayDataTypes={false}
objectSortKeys={false}
collapsed={true}
collapsed={1}
shortenTextAfterLength={100}
value={fixed}>
<JsonView.String

View File

@ -10,7 +10,6 @@ import FormControlLabel from '@mui/material/FormControlLabel';
import Switch from '@mui/material/Switch';
import useMediaQuery from '@mui/material/useMediaQuery';
import { SxProps, useTheme } from '@mui/material/styles';
import JsonView from '@uiw/react-json-view';
import Table from '@mui/material/Table';
import TableBody from '@mui/material/TableBody';
import TableCell from '@mui/material/TableCell';
@ -499,13 +498,13 @@ The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '
</Box>
}
<Box sx={{ display: "flex", flexDirection: "column" }}>
<Box sx={{ display: "flex", flexDirection: "column", flexGrow: 1 }}>
{node === null &&
<Paper sx={{ m: 0.5, p: 2, flexGrow: 1 }}>
Click a point in the scatter-graph to see information about that node.
</Paper>
}
{!inline && node !== null && node.full_content &&
{node !== null && node.full_content &&
<Scrollable
autoscroll={false}
sx={{
@ -521,7 +520,7 @@ The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '
node.full_content.split('\n').map((line, index) => {
index += 1 + node.chunk_begin;
const bgColor = (index > node.line_begin && index <= node.line_end) ? '#f0f0f0' : 'auto';
return <Box key={index} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-child': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
return <Box key={index} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-of-type': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
<Box sx={{ fontFamily: 'courier', fontSize: "0.8rem", minWidth: "2rem", pt: "0.1rem", align: "left", verticalAlign: "top" }}>{index}</Box>
<pre style={{ margin: 0, padding: 0, border: "none", minHeight: "1rem" }} >{line || " "}</pre>
</Box>;

View File

@ -118,6 +118,10 @@ const useAutoScrollToBottom = (
let shouldScroll = false;
const scrollTo = scrollToRef.current;
if (isPasteEvent && !scrollTo) {
console.error("Paste Event triggered without scrollTo");
}
if (scrollTo) {
// Get positions
const containerRect = container.getBoundingClientRect();
@ -130,7 +134,7 @@ const useAutoScrollToBottom = (
scrollToRect.top < containerBottom && scrollToRect.bottom > containerTop;
// Scroll on paste or if TextField is visible and user isn't scrolling up
shouldScroll = (isPasteEvent || isTextFieldVisible) && !isUserScrollingUpRef.current;
shouldScroll = isPasteEvent || (isTextFieldVisible && !isUserScrollingUpRef.current);
if (shouldScroll) {
requestAnimationFrame(() => {
debug && console.debug('Scrolling to container bottom:', {
@ -198,10 +202,12 @@ const useAutoScrollToBottom = (
}
const handlePaste = () => {
console.log("handlePaste");
// Delay scroll check to ensure DOM updates
setTimeout(() => {
console.log("scrolling for handlePaste");
requestAnimationFrame(() => checkAndScrollToBottom(true));
}, 0);
}, 100);
};
window.addEventListener('mousemove', pauseScroll);

View File

@ -1,5 +1,5 @@
[tool.black]
line-length = 88
line-length = 120
target-version = ['py312']
include = '\.pyi?$'
exclude = '''

View File

@ -1,7 +1,9 @@
LLM_TIMEOUT = 600
from utils import logger
from pydantic import BaseModel, Field # type: ignore
from pydantic import BaseModel, Field, ValidationError # type: ignore
from pydantic_core import PydanticSerializationError # type: ignore
from typing import List
from typing import AsyncGenerator, Dict, Optional
@ -61,6 +63,7 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore
from utils import (
rag as Rag,
ChromaDBGetResponse,
tools as Tools,
Context,
Conversation,
@ -69,15 +72,17 @@ from utils import (
Metrics,
Tunables,
defines,
check_serializable,
logger,
)
rags = [
{
"name": "JPK",
"enabled": True,
"description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
},
rags : List[ChromaDBGetResponse] = [
ChromaDBGetResponse(
name="JPK",
enabled=True,
description="Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
),
# { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
]
@ -461,10 +466,8 @@ class WebServer:
context = self.upsert_context(context_id)
agent = context.get_agent(agent_type)
if not agent:
return JSONResponse(
{"error": f"{agent_type} is not recognized", "context": context.id},
status_code=404,
)
response = { "history": [] }
return JSONResponse(response)
data = await request.json()
try:
@ -475,8 +478,8 @@ class WebServer:
logger.info(f"Resetting {reset_operation}")
case "rags":
logger.info(f"Resetting {reset_operation}")
context.rags = rags.copy()
response["rags"] = context.rags
context.rags = [ r.model_copy() for r in rags]
response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
case "tools":
logger.info(f"Resetting {reset_operation}")
context.tools = Tools.enabled_tools(Tools.tools)
@ -537,6 +540,7 @@ class WebServer:
data = await request.json()
agent = context.get_agent("chat")
if not agent:
logger.info("chat agent does not exist on this context!")
return JSONResponse(
{"error": f"chat is not recognized", "context": context.id},
status_code=404,
@ -572,20 +576,20 @@ class WebServer:
case "rags":
# { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
rags: list[dict[str, Any]] = data[k]
if not rags:
rag_configs: list[dict[str, Any]] = data[k]
if not rag_configs:
return JSONResponse(
{
"status": "error",
"message": "RAGs can not be empty.",
}
)
for rag in rags:
for config in rag_configs:
for context_rag in context.rags:
if context_rag["name"] == rag["name"]:
context_rag["enabled"] = rag["enabled"]
if context_rag.name == config["name"]:
context_rag.enabled = config["enabled"]
self.save_context(context_id)
return JSONResponse({"rags": context.rags})
return JSONResponse({"rags": [ r.model_dump(mode="json") for r in context.rags]})
case "system_prompt":
system_prompt = data[k].strip()
@ -615,12 +619,10 @@ class WebServer:
@self.app.get("/api/tunables/{context_id}")
async def get_tunables(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id):
logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = self.upsert_context(context_id)
agent = context.get_agent("chat")
if not agent:
logger.info("chat agent does not exist on this context!")
return JSONResponse(
{"error": f"chat is not recognized", "context": context.id},
status_code=404,
@ -629,7 +631,7 @@ class WebServer:
{
"system_prompt": agent.system_prompt,
"message_history_length": context.message_history_length,
"rags": context.rags,
"rags": [ r.model_dump(mode="json") for r in context.rags ],
"tools": [
{
**t["function"],
@ -674,6 +676,7 @@ class WebServer:
error = {
"error": f"Attempt to create agent type: {agent_type} failed: {e}"
}
logger.info(error)
return JSONResponse(error, status_code=404)
try:
@ -887,8 +890,35 @@ class WebServer:
file_path = os.path.join(defines.context_dir, context_id)
# Serialize the data to JSON and write to file
try:
# Check for non-serializable fields before dumping
serialization_errors = check_serializable(context)
if serialization_errors:
for error in serialization_errors:
logger.error(error)
raise ValueError("Found non-serializable fields in the model")
# Dump the model prior to opening file in case there is
# a validation error so it doesn't delete the current
# context session
json_data = context.model_dump_json(by_alias=True)
with open(file_path, "w") as f:
f.write(context.model_dump_json(by_alias=True))
f.write(json_data)
except ValidationError as e:
logger.error(e)
logger.error(traceback.format_exc())
for error in e.errors():
print(f"Field: {error['loc'][0]}, Error: {error['msg']}")
except PydanticSerializationError as e:
logger.error(e)
logger.error(traceback.format_exc())
logger.error(f"Serialization error: {str(e)}")
# Inspect the model to identify problematic fields
for field_name, value in context.__dict__.items():
if isinstance(value, np.ndarray):
logger.error(f"Field '{field_name}' contains non-serializable type: {type(value)}")
except Exception as e:
logger.error(traceback.format_exc())
logger.error(e)
return context_id
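
save_context() now calls check_serializable before dumping the model; the helper is imported from utils but its body is not part of this diff. A minimal sketch of what such a helper could look like, consistent with the NumPy-array inspection in the PydanticSerializationError handler above (module path and exact behavior assumed):

```python
# Hypothetical sketch only; the real helper is utils.check_serializable and may differ.
# It mirrors the fallback inspection above: walk the model's fields and report any that
# hold NumPy arrays, which model_dump_json() cannot serialize.
from typing import List

import numpy as np
from pydantic import BaseModel


def check_serializable(model: BaseModel) -> List[str]:
    errors: List[str] = []
    for field_name, value in model.__dict__.items():
        if isinstance(value, np.ndarray):
            errors.append(
                f"Field '{field_name}' contains non-serializable type: {type(value)}"
            )
    return errors
```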
@ -942,12 +972,8 @@ class WebServer:
self.contexts[context_id] = context
logger.info(f"Successfully loaded context {context_id}")
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON in file: {e}")
except Exception as e:
logger.error(f"Error validating context: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Fallback to creating a new context
self.contexts[context_id] = Context(
@ -985,7 +1011,7 @@ class WebServer:
# context.add_agent(JobDescription(system_prompt = system_job_description))
# context.add_agent(FactCheck(system_prompt = system_fact_check))
context.tools = Tools.enabled_tools(Tools.tools)
context.rags = rags.copy()
context.rags_enabled = [ r.name for r in rags ]
logger.info(f"{context.id} created and added to contexts.")
self.contexts[context.id] = context

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -20,5 +20,5 @@ observer, file_watcher = Rag.start_file_watcher(
)
context = Context(file_watcher=file_watcher)
data = context.model_dump(mode="json")
context = Context.from_json(json.dumps(data), file_watcher=file_watcher)
json_data = context.model_dump(mode="json")
context = Context.model_validate(json_data)

View File

@ -0,0 +1,89 @@
# From /opt/backstory run:
# python -m src.tests.test-embedding
import numpy as np # type: ignore
import logging
import argparse
from ollama import Client # type: ignore
from ..utils import defines
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
def get_embedding(text: str, embedding_model: str, ollama_server: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
llm = Client(host=ollama_server)
# Get embedding
try:
response = llm.embeddings(model=embedding_model, prompt=text)
embedding = np.array(response["embedding"])
except Exception as e:
logging.error(f"Failed to get embedding: {e}")
raise
# Log diagnostics
logging.info(f"Input text: {text}")
logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
# Check for invalid embeddings
if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
raise ValueError("Invalid embedding returned from Ollama.")
# Check normalization
norm = np.linalg.norm(embedding)
is_normalized = np.allclose(norm, 1.0, atol=1e-3)
logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
# Normalize if needed
if not is_normalized:
embedding = embedding / norm
logging.info("Embedding normalized manually.")
return embedding
def main():
"""Main function to generate and normalize an embedding from command-line input."""
parser = argparse.ArgumentParser(description="Generate embeddings for text using mxbai-embed-large.")
parser.add_argument(
"--text",
type=str,
nargs="+", # Allow multiple text inputs
default=["Test sentence."],
help="Text(s) to generate embeddings for (default: 'Test sentence.')",
)
parser.add_argument(
"--ollama-server",
type=str,
default=defines.ollama_api_url,
help=f"Ollama server URL (default: {defines.ollama_api_url})",
)
parser.add_argument(
"--embedding-model",
type=str,
default=defines.embedding_model,
help=f"Embedding model name (default: {defines.embedding_model})",
)
args = parser.parse_args()
# Validate input
for text in args.text:
if not text or not isinstance(text, str):
logging.error("Input text must be a non-empty string.")
raise ValueError("Input text must be a non-empty string.")
# Generate embeddings for each text
embeddings = []
for text in args.text:
embedding = get_embedding(
text=text,
embedding_model=args.embedding_model,
ollama_server=args.ollama_server,
)
embeddings.append(embedding)
if __name__ == "__main__":
main()

View File

@ -2,10 +2,15 @@
# python -m src.tests.test-message
from ..utils import logger
from ..utils import Message
from ..utils import Message, MessageMetaData
from ..utils import ChromaDBGetResponse
import json
prompt = "This is a test"
message = Message(prompt=prompt)
print(message.model_dump(mode="json"))
#message.metadata = MessageMetaData()
rag = ChromaDBGetResponse()
message.metadata.rag = rag
print(message.model_dump(mode="json"))

src/tests/test-rag.py (new file, 81 lines added)
View File

@ -0,0 +1,81 @@
# From /opt/backstory run:
# python -m src.tests.test-rag
from ..utils import logger
from pydantic import BaseModel, field_validator # type: ignore
from prometheus_client import CollectorRegistry # type: ignore
from typing import List, Dict, Any, Optional
import ollama
import numpy as np # type: ignore
from ..utils import (rag as Rag, ChromaDBGetResponse)
from ..utils import Context
from ..utils import defines
import json
chroma_results = {
"ids": ["1", "2"],
"embeddings": np.array([[1.0, 2.0], [3.0, 4.0]]),
"documents": ["doc1", "doc2"],
"metadatas": [{"meta": "data1"}, {"meta": "data2"}],
"query_embedding": np.array([0.1, 0.2, 0.3])
}
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
umap_2d = np.array([0.4, 0.5]) # Example UMAP output
umap_3d = np.array([0.6, 0.7, 0.8]) # Example UMAP output
rag_metadata = ChromaDBGetResponse(
query="test",
query_embedding=query_embedding,
name="JPK",
ids=chroma_results.get("ids", []),
size=2
)
logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
logger.info(f"Assigning type {type(umap_2d)} to rag_metadata.umap_embedding_2d")
rag_metadata.umap_embedding_2d = umap_2d
logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
rag = ChromaDBGetResponse()
rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])
json_str = rag.model_dump(mode="json")
logger.info(json_str)
rag = ChromaDBGetResponse.model_validate(json_str)
llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
prometheus_collector = CollectorRegistry()
observer, file_watcher = Rag.start_file_watcher(
llm=llm,
watch_directory=defines.doc_dir,
recreate=False, # Don't recreate if exists
)
context = Context(
file_watcher=file_watcher,
prometheus_collector=prometheus_collector,
)
skill="Codes in C++"
if context.file_watcher:
chroma_results = context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
if chroma_results:
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
umap_2d = context.file_watcher.umap_model_2d.transform([query_embedding])[0]
umap_3d = context.file_watcher.umap_model_3d.transform([query_embedding])[0]
rag_metadata = ChromaDBGetResponse(
query=skill,
query_embedding=query_embedding,
name="JPK",
ids=chroma_results.get("ids", []),
embeddings=chroma_results.get("embeddings", []),
documents=chroma_results.get("documents", []),
metadatas=chroma_results.get("metadatas", []),
umap_embedding_2d=umap_2d,
umap_embedding_3d=umap_3d,
size=context.file_watcher.collection.count()
)
json_str = context.model_dump(mode="json")
logger.info(json_str)

View File

@ -1,25 +1,34 @@
from __future__ import annotations
from pydantic import BaseModel # type: ignore
from typing import (
Any,
Set
)
import importlib
import json
from . import defines
from .context import Context
from .conversation import Conversation
from .message import Message, Tunables
from .rag import ChromaDBFileWatcher, start_file_watcher
from .message import Message, Tunables, MessageMetaData
from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher
from .setup_logging import setup_logging
from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
from .metrics import Metrics
from .check_serializable import check_serializable
__all__ = [
"Agent",
"Message",
"Tunables",
"MessageMetaData",
"Context",
"Conversation",
"Message",
"Metrics",
"ChromaDBFileWatcher",
'ChromaDBGetResponse',
"start_file_watcher",
"check_serializable",
"logger",
]

View File

@ -165,9 +165,8 @@ class Agent(BaseModel, ABC):
if message.status != "done":
yield message
if "rag" in message.metadata and message.metadata["rag"]:
for rag in message.metadata["rag"]:
for doc in rag["documents"]:
for rag in message.metadata.rag:
for doc in rag.documents:
rag_context += f"{doc}\n"
message.preamble = {}
@ -189,7 +188,7 @@ class Agent(BaseModel, ABC):
llm: Any,
model: str,
message: Message,
tool_message: Any,
tool_message: Any, # llama response message
messages: List[LLMMessage],
) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
@ -199,10 +198,10 @@ class Agent(BaseModel, ABC):
if not self.context:
raise ValueError("Context is not set for this agent.")
if not message.metadata["tools"]:
if not message.metadata.tools:
raise ValueError("tools field not initialized")
tool_metadata = message.metadata["tools"]
tool_metadata = message.metadata.tools
tool_metadata["tool_calls"] = []
message.status = "tooling"
@ -301,8 +300,7 @@ class Agent(BaseModel, ABC):
model=model,
messages=messages,
options={
**message.metadata["options"],
# "temperature": 0.5,
**message.metadata.options,
},
stream=True,
):
@ -316,12 +314,10 @@ class Agent(BaseModel, ABC):
if response.done:
self.collect_metrics(response)
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata[
"prompt_eval_duration"
] += response.prompt_eval_duration
message.metadata.eval_count += response.eval_count
message.metadata.eval_duration += response.eval_duration
message.metadata.prompt_eval_count += response.prompt_eval_count
message.metadata.prompt_eval_duration += response.prompt_eval_duration
self.context_tokens = (
response.prompt_eval_count + response.eval_count
)
@ -329,9 +325,7 @@ class Agent(BaseModel, ABC):
yield message
end_time = time.perf_counter()
message.metadata["timers"][
"llm_with_tools"
] = f"{(end_time - start_time):.4f}"
message.metadata.timers["llm_with_tools"] = end_time - start_time
return
def collect_metrics(self, response):
@ -370,22 +364,22 @@ class Agent(BaseModel, ABC):
LLMMessage(role="user", content=message.context_prompt.strip())
)
# message.metadata["messages"] = messages
message.metadata["options"] = {
# message.messages = messages
message.metadata.options = {
"seed": 8911,
"num_ctx": self.context_size,
"temperature": temperature, # Higher temperature to encourage tool usage
}
# Create a dict for storing various timing stats
message.metadata["timers"] = {}
message.metadata.timers = {}
use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
message.metadata["tools"] = {
message.metadata.tools = {
"available": llm_tools(self.context.tools),
"used": False,
}
tool_metadata = message.metadata["tools"]
tool_metadata = message.metadata.tools
if use_tools:
message.status = "thinking"
@ -408,17 +402,14 @@ class Agent(BaseModel, ABC):
messages=tool_metadata["messages"],
tools=tool_metadata["available"],
options={
**message.metadata["options"],
# "num_predict": 1024, # "Low" token limit to cut off after tool call
**message.metadata.options,
},
stream=False, # No need to stream the probe
)
self.collect_metrics(response)
end_time = time.perf_counter()
message.metadata["timers"][
"tool_check"
] = f"{(end_time - start_time):.4f}"
message.metadata.timers["tool_check"] = end_time - start_time
if not response.message.tool_calls:
logger.info("LLM indicates tools will not be used")
# The LLM will not use tools, so disable use_tools so we can stream the full response
@ -442,16 +433,14 @@ class Agent(BaseModel, ABC):
messages=tool_metadata["messages"], # messages,
tools=tool_metadata["available"],
options={
**message.metadata["options"],
**message.metadata.options,
},
stream=False,
)
self.collect_metrics(response)
end_time = time.perf_counter()
message.metadata["timers"][
"non_streaming"
] = f"{(end_time - start_time):.4f}"
message.metadata.timers["non_streaming"] = end_time - start_time
if not response:
message.status = "error"
@ -475,9 +464,7 @@ class Agent(BaseModel, ABC):
return
yield message
end_time = time.perf_counter()
message.metadata["timers"][
"process_tool_calls"
] = f"{(end_time - start_time):.4f}"
message.metadata.timers["process_tool_calls"] = end_time - start_time
message.status = "done"
return
@ -498,7 +485,7 @@ class Agent(BaseModel, ABC):
model=model,
messages=messages,
options={
**message.metadata["options"],
**message.metadata.options,
},
stream=True,
):
@ -517,12 +504,10 @@ class Agent(BaseModel, ABC):
if response.done:
self.collect_metrics(response)
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata[
"prompt_eval_duration"
] += response.prompt_eval_duration
message.metadata.eval_count += response.eval_count
message.metadata.eval_duration += response.eval_duration
message.metadata.prompt_eval_count += response.prompt_eval_count
message.metadata.prompt_eval_duration += response.prompt_eval_duration
self.context_tokens = (
response.prompt_eval_count + response.eval_count
)
@ -530,7 +515,7 @@ class Agent(BaseModel, ABC):
yield message
end_time = time.perf_counter()
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
message.metadata.timers["streamed"] = end_time - start_time
return
async def process_message(
@ -560,7 +545,7 @@ class Agent(BaseModel, ABC):
self.context.processing = True
message.metadata["system_prompt"] = (
message.system_prompt = (
f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
)
message.context_prompt = ""
@ -575,11 +560,11 @@ class Agent(BaseModel, ABC):
message.status = "thinking"
yield message
message.metadata["context_size"] = self.set_optimal_context_size(
message.context_size = self.set_optimal_context_size(
llm, model, prompt=message.context_prompt
)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.response = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
message.status = "thinking"
yield message

View File

@ -19,9 +19,10 @@ import time
import asyncio
import numpy as np # type: ignore
from .base import Agent, agent_registry, LLMMessage
from ..message import Message
from ..setup_logging import setup_logging
from . base import Agent, agent_registry, LLMMessage
from .. message import Message
from .. rag import ChromaDBGetResponse
from .. setup_logging import setup_logging
logger = setup_logging()
@ -91,8 +92,11 @@ class JobDescription(Agent):
await asyncio.sleep(1) # Allow the event loop to process the write
self.context.processing = True
job_description = message.preamble["job_description"]
resume = message.preamble["resume"]
original_message = message.model_copy()
original_message.prompt = job_description
original_message.response = ""
self.conversation.add(original_message)
@ -100,11 +104,9 @@ class JobDescription(Agent):
self.model = model
self.metrics.generate_count.labels(agent=self.agent_type).inc()
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
job_description = message.preamble["job_description"]
resume = message.preamble["resume"]
try:
async for message in self.generate_factual_tailored_resume(
async for message in self.generate_resume(
message=message, job_description=job_description, resume=resume
):
if message.status != "done":
@ -124,7 +126,7 @@ class JobDescription(Agent):
# Done processing, add message to conversation
self.context.processing = False
resume_generation = message.metadata.get("resume_generation", {})
resume_generation = message.metadata.resume_generation
if not resume_generation:
message.response = (
"Generation did not generate metadata necessary for processing."
@ -341,18 +343,19 @@ Name: {candidate_name}
## OUTPUT FORMAT:
Provide the resume in clean markdown format, ready for the candidate to use.
## REFERENCE (Original Resume):
"""
# ## REFERENCE (Original Resume):
# """
# Add a truncated version of the original resume for reference if it's too long
max_resume_length = 25000 # Characters
if len(original_resume) > max_resume_length:
system_prompt += (
original_resume[:max_resume_length]
+ "...\n[Original resume truncated due to length]"
)
else:
system_prompt += original_resume
# # Add a truncated version of the original resume for reference if it's too long
# max_resume_length = 25000 # Characters
# if len(original_resume) > max_resume_length:
# system_prompt += (
# original_resume[:max_resume_length]
# + "...\n[Original resume truncated due to length]"
# )
# else:
# system_prompt += original_resume
prompt = "Create a tailored professional resume that highlights candidate's skills and experience most relevant to the job requirements. Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no commentary before or after."
return system_prompt, prompt
@ -413,7 +416,7 @@ Provide the resume in clean markdown format, ready for the candidate to use.
yield message
return
def calculate_match_statistics(self, job_requirements, skill_assessment_results):
def calculate_match_statistics(self, job_requirements, skill_assessment_results) -> dict[str, dict[str, Any]]:
"""
Calculate statistics about how well the candidate matches job requirements
@ -594,9 +597,10 @@ a SPECIFIC skill based solely on their resume and supporting evidence.
}}
```
## CANDIDATE RESUME:
{resume}
"""
# ## CANDIDATE RESUME:
# {resume}
# """
# Add RAG content if provided
if rag_content:
@ -821,7 +825,7 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
LLMMessage(role="system", content=system_prompt),
LLMMessage(role="user", content=prompt),
]
message.metadata["options"] = {
message.metadata.options = {
"seed": 8911,
"num_ctx": self.context_size,
"temperature": temperature, # Higher temperature to encourage tool usage
@ -837,7 +841,7 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
model=self.model,
messages=messages,
options={
**message.metadata["options"],
**message.metadata.options,
},
stream=True,
):
@ -860,54 +864,40 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
if response.done:
self.collect_metrics(response)
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata[
"prompt_eval_duration"
] += response.prompt_eval_duration
message.metadata.eval_count += response.eval_count
message.metadata.eval_duration += response.eval_duration
message.metadata.prompt_eval_count += response.prompt_eval_count
message.metadata.prompt_eval_duration += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.chunk = ""
message.status = "done"
yield message
def rag_function(self, skill: str) -> tuple[str, list[Any]]:
def retrieve_rag_content(self, skill: str) -> tuple[str, ChromaDBGetResponse]:
if self.context is None or self.context.file_watcher is None:
raise ValueError("self.context or self.context.file_watcher is None")
try:
rag_results = ""
all_metadata = []
chroma_results = self.context.file_watcher.find_similar(
query=skill, top_k=5, threshold=0.5
)
rag_metadata = ChromaDBGetResponse()
chroma_results = self.context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
if chroma_results:
chroma_embedding = np.array(
chroma_results["query_embedding"]
).flatten() # Ensure correct shape
print(f"Chroma embedding shape: {chroma_embedding.shape}")
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
umap_2d = self.context.file_watcher.umap_model_2d.transform(
[chroma_embedding]
)[0].tolist()
print(
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
) # Debug output
umap_2d = self.context.file_watcher.umap_model_2d.transform([query_embedding])[0]
umap_3d = self.context.file_watcher.umap_model_3d.transform([query_embedding])[0]
umap_3d = self.context.file_watcher.umap_model_3d.transform(
[chroma_embedding]
)[0].tolist()
print(
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
) # Debug output
all_metadata.append(
{
"name": "JPK",
**chroma_results,
"umap_embedding_2d": umap_2d,
"umap_embedding_3d": umap_3d,
}
rag_metadata = ChromaDBGetResponse(
query=skill,
query_embedding=query_embedding.tolist(),
name="JPK",
ids=chroma_results.get("ids", []),
embeddings=chroma_results.get("embeddings", []),
documents=chroma_results.get("documents", []),
metadatas=chroma_results.get("metadatas", []),
umap_embedding_2d=umap_2d.tolist(),
umap_embedding_3d=umap_3d.tolist(),
size=self.context.file_watcher.collection.count()
)
for index, metadata in enumerate(chroma_results["metadatas"]):
@ -919,18 +909,19 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
]
).strip()
rag_results += f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")} lines {metadata.get("line_begin", 0)}-{metadata.get("line_end", 0)}
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results["ids"][index]}
Content: { content }
"""
return rag_results, all_metadata
return rag_results, rag_metadata
except Exception as e:
logger.error(e)
logger.error(traceback.format_exc())
logger.error(e)
exit(0)
async def generate_factual_tailored_resume(
async def generate_resume(
self, message: Message, job_description: str, resume: str
) -> AsyncGenerator[Message, None]:
"""
@ -947,12 +938,8 @@ Content: { content }
if self.context is None:
raise ValueError(f"context is None in {self.agent_type}")
message.status = "thinking"
logger.info(message.response)
yield message
message.metadata["resume_generation"] = {}
metadata = message.metadata["resume_generation"]
message.metadata.resume_generation = {}
metadata = message.metadata.resume_generation
# Stage 1A: Analyze job requirements
streaming_message = Message(prompt="Analyze job requirements")
streaming_message.status = "thinking"
@ -975,8 +962,8 @@ Content: { content }
prompts = self.process_job_requirements(
job_requirements=job_requirements,
resume=resume,
rag_function=self.rag_function,
) # , retrieve_rag_content)
rag_function=self.retrieve_rag_content,
)
# UI should persist this state of the message
partial = message.model_copy()
@ -1051,9 +1038,15 @@ Content: { content }
partial.title = (
f"Skill {index}/{total_prompts}: {description} [{match_level}]"
)
partial.metadata["rag"] = rag
partial.metadata.rag = [rag] # Front-end expects a list of RAG retrievals
if skill_description:
partial.response += f"\n\n{skill_description}"
partial.response = f"""
```json
{json.dumps(skill_assessment_results[skill_name]["skill_assessment"])}
```
{skill_description}
"""
yield partial
self.conversation.add(partial)

View File

@ -7,9 +7,10 @@ import numpy as np # type: ignore
import logging
from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore
import traceback
from .message import Message, Tunables
from .rag import ChromaDBFileWatcher
from .rag import ChromaDBFileWatcher, ChromaDBGetResponse
from . import defines
from . import tools as Tools
from .agents import AnyAgent
@ -35,7 +36,7 @@ class Context(BaseModel):
user_job_description: Optional[str] = None
user_facts: Optional[str] = None
tools: List[dict] = Tools.enabled_tools(Tools.tools)
rags: List[dict] = []
rags: List[ChromaDBGetResponse] = []
message_history_length: int = 5
# Class managed fields
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
@ -82,56 +83,40 @@ class Context(BaseModel):
if not self.file_watcher:
message.response = "No RAG context available."
del message.metadata["rag"]
message.status = "done"
yield message
return
message.metadata["rag"] = []
for rag in self.rags:
if not rag["enabled"]:
if not rag.enabled:
continue
message.response = f"Checking RAG context {rag['name']}..."
message.response = f"Checking RAG context {rag.name}..."
yield message
chroma_results = self.file_watcher.find_similar(
query=message.prompt, top_k=top_k, threshold=threshold
)
if chroma_results:
entries += len(chroma_results["documents"])
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
chroma_embedding = np.array(
chroma_results["query_embedding"]
).flatten() # Ensure correct shape
print(f"Chroma embedding shape: {chroma_embedding.shape}")
umap_2d = self.file_watcher.umap_model_2d.transform([query_embedding])[0]
umap_3d = self.file_watcher.umap_model_3d.transform([query_embedding])[0]
umap_2d = self.file_watcher.umap_model_2d.transform(
[chroma_embedding]
)[0].tolist()
print(
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
) # Debug output
umap_3d = self.file_watcher.umap_model_3d.transform(
[chroma_embedding]
)[0].tolist()
print(
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
) # Debug output
message.metadata["rag"].append(
{
"name": rag["name"],
**chroma_results,
"umap_embedding_2d": umap_2d,
"umap_embedding_3d": umap_3d,
"size": self.file_watcher.collection.count()
}
rag_metadata = ChromaDBGetResponse(
query=message.prompt,
query_embedding=query_embedding.tolist(),
name=rag.name,
ids=chroma_results.get("ids", []),
embeddings=chroma_results.get("embeddings", []),
documents=chroma_results.get("documents", []),
metadatas=chroma_results.get("metadatas", []),
umap_embedding_2d=umap_2d.tolist(),
umap_embedding_3d=umap_3d.tolist(),
size=self.file_watcher.collection.count()
)
message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
yield message
if entries == 0:
del message.metadata["rag"]
message.metadata.rag.append(rag_metadata)
message.response = f"Results from {rag.name} RAG: {len(chroma_results['documents'])} results."
yield message
message.response = (
f"RAG context gathered from results from {entries} documents."
@ -142,7 +127,8 @@ class Context(BaseModel):
except Exception as e:
message.status = "error"
message.response = f"Error generating RAG results: {str(e)}"
logger.error(e)
logger.error(traceback.format_exc())
logger.error(message.response)
yield message
return

View File

@ -1,12 +1,28 @@
from pydantic import BaseModel, Field # type: ignore
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Union, Mapping
from datetime import datetime, timezone
from . rag import ChromaDBGetResponse
from ollama._types import Options # type: ignore
class Tunables(BaseModel):
enable_rag: bool = Field(default=True) # Enable RAG collection chromadb matching
enable_tools: bool = Field(default=True) # Enable LLM to use tools
enable_context: bool = Field(default=True) # Add <|context|> field to message
enable_rag: bool = True # Enable RAG collection chromadb matching
enable_tools: bool = True # Enable LLM to use tools
enable_context: bool = True # Add <|context|> field to message
class MessageMetaData(BaseModel):
rag: List[ChromaDBGetResponse] = Field(default_factory=list)
eval_count: int = 0
eval_duration: int = 0
prompt_eval_count: int = 0
prompt_eval_duration: int = 0
context_size: int = 0
resume_generation: Optional[Dict[str, Any]] = None
options: Optional[Union[Mapping[str, Any], Options]] = None
tools: Optional[Dict[str, Any]] = None
timers: Optional[Dict[str, float]] = None
#resume : str = ""
#match_stats: Optional[Dict[str, Dict[str, Any]]] = Field(default=None)
class Message(BaseModel):
model_config = {"arbitrary_types_allowed": True} # Allow Event
@ -18,35 +34,21 @@ class Message(BaseModel):
# Generated while processing message
status: str = "" # Status of the message
preamble: dict[str, str] = {} # Preamble to be prepended to the prompt
preamble: Dict[str, Any] = Field(default_factory=dict) # Preamble to be prepended to the prompt
system_prompt: str = "" # System prompt provided to the LLM
context_prompt: str = "" # Full content of the message (preamble + prompt)
response: str = "" # LLM response to the preamble + query
metadata: Dict[str, Any] = Field(
default_factory=lambda: {
"rag": [],
"eval_count": 0,
"eval_duration": 0,
"prompt_eval_count": 0,
"prompt_eval_duration": 0,
"context_size": 0,
}
)
metadata: MessageMetaData = Field(default_factory=MessageMetaData)
network_packets: int = 0 # Total number of streaming packets
network_bytes: int = 0 # Total bytes sent while streaming packets
actions: List[str] = (
[]
) # Other session modifying actions performed while processing the message
timestamp: datetime = datetime.now(timezone.utc)
chunk: str = Field(
default=""
) # This needs to be serialized so it will be sent in responses
partial_response: str = Field(
default=""
) # This needs to be serialized so it will be sent in responses on timeout
title: str = Field(
default=""
) # This needs to be serialized so it will be sent in responses on timeout
timestamp: str = str(datetime.now(timezone.utc))
chunk: str = ""
partial_response: str = ""
title: str = ""
context_size: int = 0
def add_action(self, action: str | list[str]) -> None:
"""Add a actions(s) to the message."""

View File

@ -1,5 +1,5 @@
from pydantic import BaseModel # type: ignore
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
from typing import List, Optional, Dict, Any, Union
import os
import glob
from pathlib import Path
@ -7,15 +7,9 @@ import time
import hashlib
import asyncio
import logging
import os
import glob
import time
import hashlib
import asyncio
import json
import numpy as np # type: ignore
import traceback
import os
import chromadb
import ollama
@ -38,18 +32,50 @@ else:
# When imported as a module, use relative imports
from . import defines
__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]
DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100
class ChromaDBGetResponse(BaseModel):
ids: List[str]
embeddings: Optional[List[List[float]]] = None
documents: Optional[List[str]] = None
metadatas: Optional[List[Dict[str, Any]]] = None
name: str = ""
size: int = 0
ids: List[str] = []
embeddings: List[List[float]] = Field(default=[])
documents: List[str] = []
metadatas: List[Dict[str, Any]] = []
query: str = ""
query_embedding: Optional[List[float]] = Field(default=None)
umap_embedding_2d: Optional[List[float]] = Field(default=None)
umap_embedding_3d: Optional[List[float]] = Field(default=None)
enabled: bool = True
class Config:
validate_assignment = True
@field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
@classmethod
def validate_embeddings(cls, value, field):
logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
if value is None:
return value
if isinstance(value, np.ndarray):
if field.field_name == "embeddings":
if value.ndim != 2:
raise ValueError(f"{field.name} must be a 2-dimensional NumPy array")
return [[float(x) for x in row] for row in value.tolist()]
else:
if value.ndim != 1:
raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
return [float(x) for x in value.tolist()]
if field.field_name == "embeddings":
if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
raise ValueError(f"{field.field_name} must be a list of lists of floats")
return [[float(x) for x in sublist] for sublist in value]
else:
if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
raise ValueError(f"{field.field_name} must be a list of floats")
return [float(x) for x in value]
class ChromaDBFileWatcher(FileSystemEventHandler):
def __init__(
@ -323,13 +349,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
n_components=2,
random_state=8911,
metric="cosine",
n_neighbors=15,
n_neighbors=30,
min_dist=0.1,
)
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
logging.info(
f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
) # Should be 2
# logging.info(
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
# ) # Should be 2
logging.info(
f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
@ -338,13 +364,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
n_components=3,
random_state=8911,
metric="cosine",
n_neighbors=15,
min_dist=0.1,
n_neighbors=30,
min_dist=0.01,
)
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
logging.info(
f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
) # Should be 3
# logging.info(
# f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
# ) # Should be 3
def _get_vector_collection(self, recreate=False) -> Collection:
"""Get or create a ChromaDB collection."""
@ -380,14 +406,36 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
"""Split documents into chunks using the text splitter."""
return self.text_splitter.split_documents(docs)
def get_embedding(self, text, normalize=True):
"""Generate embeddings using Ollama."""
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
embedding = response["embedding"]
def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
# Get embedding
try:
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
embedding = np.array(response["embedding"])
except Exception as e:
logging.error(f"Failed to get embedding: {e}")
raise
# Log diagnostics
logging.info(f"Input text: {text}")
logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
# Check for invalid embeddings
if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
raise ValueError("Invalid embedding returned from Ollama.")
# Check normalization
norm = np.linalg.norm(embedding)
is_normalized = np.allclose(norm, 1.0, atol=1e-3)
logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
# Normalize if needed
if not is_normalized:
embedding = embedding / norm
logging.info("Embedding normalized manually.")
if normalize:
normalized = self._normalize_embeddings(embedding)
return normalized
return embedding
def add_embeddings_to_collection(self, chunks: List[Chunk]):