Commit: All seems to be working

parent 0d27239ca6
commit 58dadb76f0

Dockerfile (11 lines changed)
@@ -77,6 +77,17 @@ RUN apt-get update \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

# pydub is loaded by torch, which will throw a warning if ffmpeg isn't installed
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y software-properties-common \
    && add-apt-repository -y ppa:kobuk-team/intel-graphics \
    && apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
        ffmpeg \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}


# Prerequisite for ze-monitor
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
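As a quick sanity check for the ffmpeg comment above, a minimal sketch (illustrative only, not part of this commit) that confirms pydub can locate the ffmpeg binary installed in the image; pydub.utils.which is assumed to be the helper pydub uses to find external binaries:

# Illustrative only -- run inside the container to verify ffmpeg is on PATH.
from pydub.utils import which  # pydub's helper for locating external binaries

ffmpeg_path = which("ffmpeg")
if ffmpeg_path is None:
    raise RuntimeError("ffmpeg not found on PATH; pydub/torch will warn at import time")
print(f"ffmpeg found at: {ffmpeg_path}")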
cache/.keep (vendored): Normal file → Executable file
cache/grafana/.keep (vendored): Normal file → Executable file
cache/prometheus/.keep (vendored): Normal file → Executable file

frontend/package-lock.json (generated, 1119 lines changed): file diff suppressed because it is too large
@@ -19,6 +19,7 @@
    "@types/react": "^19.0.12",
    "@types/react-dom": "^19.0.4",
    "@uiw/react-json-view": "^2.0.0-alpha.31",
    "@uiw/react-markdown-editor": "^6.1.4",
    "jsonrepair": "^3.12.0",
    "markdown-it": "^14.1.0",
    "mermaid": "^11.6.0",
@@ -113,6 +113,7 @@ const ControlsPage = (props: BackstoryPageProps) => {

            const tunables = await response.json();
            serverTunables.system_prompt = tunables.system_prompt;
            console.log(tunables);
            setSystemPrompt(tunables.system_prompt)
            setSnack("System prompt updated", "success");
        } catch (error) {
@ -167,31 +168,36 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
||||
body: JSON.stringify({ "reset": types }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
if (data.error) {
|
||||
throw Error()
|
||||
}
|
||||
for (const [key, value] of Object.entries(data)) {
|
||||
switch (key) {
|
||||
case "rags":
|
||||
setRags(value as Tool[]);
|
||||
break;
|
||||
case "tools":
|
||||
setTools(value as Tool[]);
|
||||
break;
|
||||
case "system_prompt":
|
||||
setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
|
||||
break;
|
||||
case "history":
|
||||
console.log('TODO: handle history reset');
|
||||
break;
|
||||
}
|
||||
}
|
||||
setSnack(message, "success");
|
||||
} else {
|
||||
throw Error(`${{ status: response.status, message: response.statusText }}`);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('Response body is null');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (data.error) {
|
||||
throw Error(data.error);
|
||||
}
|
||||
|
||||
for (const [key, value] of Object.entries(data)) {
|
||||
switch (key) {
|
||||
case "rags":
|
||||
setRags(value as Tool[]);
|
||||
break;
|
||||
case "tools":
|
||||
setTools(value as Tool[]);
|
||||
break;
|
||||
case "system_prompt":
|
||||
setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
|
||||
break;
|
||||
case "history":
|
||||
console.log('TODO: handle history reset');
|
||||
break;
|
||||
}
|
||||
}
|
||||
setSnack(message, "success");
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("Unable to restore defaults", "error");
|
||||
@ -203,20 +209,37 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
||||
if (systemInfo !== undefined || sessionId === undefined) {
|
||||
return;
|
||||
}
|
||||
fetch(connectionBase + `/api/system-info/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
const fetchSystemInfo = async () => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/system-info/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('Response body is null');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (data.error) {
|
||||
throw Error(data.error);
|
||||
}
|
||||
|
||||
setSystemInfo(data);
|
||||
})
|
||||
.catch(error => {
|
||||
} catch (error) {
|
||||
console.error('Error obtaining system information:', error);
|
||||
setSnack("Unable to obtain system information.", "error");
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
fetchSystemInfo();
|
||||
|
||||
}, [systemInfo, setSystemInfo, setSnack, sessionId])
|
||||
|
||||
useEffect(() => {
|
||||
@ -273,25 +296,30 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
||||
return;
|
||||
}
|
||||
const fetchTunables = async () => {
|
||||
// Make the fetch request with proper headers
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
});
|
||||
const data = await response.json();
|
||||
// console.log("Server tunables: ", data);
|
||||
setServerTunables(data);
|
||||
setSystemPrompt(data["system_prompt"]);
|
||||
setMessageHistoryLength(data["message_history_length"]);
|
||||
setTools(data["tools"]);
|
||||
setRags(data["rags"]);
|
||||
try {
|
||||
// Make the fetch request with proper headers
|
||||
const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
});
|
||||
const data = await response.json();
|
||||
// console.log("Server tunables: ", data);
|
||||
setServerTunables(data);
|
||||
setSystemPrompt(data["system_prompt"]);
|
||||
setMessageHistoryLength(data["message_history_length"]);
|
||||
setTools(data["tools"]);
|
||||
setRags(data["rags"]);
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("System prompt update failed", "error");
|
||||
}
|
||||
}
|
||||
|
||||
fetchTunables();
|
||||
}, [sessionId, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags]);
|
||||
}, [sessionId, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);
|
||||
|
||||
const toggle = async (type: string, index: number) => {
|
||||
switch (type) {
|
||||
|
@@ -285,7 +285,9 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
                throw new Error('Response body is null');
            }

            setConversation([])
            setProcessingMessage(undefined);
            setStreamingMessage(undefined);
            setConversation([]);
            setNoInteractions(true);

        } catch (e) {
@ -341,13 +343,23 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
|
||||
// Add a small delay to ensure React has time to update the UI
|
||||
await new Promise(resolve => setTimeout(resolve, 0));
|
||||
|
||||
let data: any = query;
|
||||
if (type === "job_description") {
|
||||
data = {
|
||||
prompt: "",
|
||||
agent_options: {
|
||||
job_description: query.prompt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch(connectionBase + `/api/${type}/${sessionId}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(query)
|
||||
body: JSON.stringify(data)
|
||||
});
|
||||
|
||||
setSnack(`Query sent.`, "info");
|
||||
|
@@ -77,7 +77,7 @@ interface MessageMetaData {
        vector_embedding: number[];
    },
    origin: string,
    rag: any,
    rag: any[],
    tools?: {
        tool_calls: any[],
    },
@@ -117,8 +117,6 @@ const MessageMeta = (props: MessageMetaProps) => {
    } = props.metadata || {};
    const message: any = props.messageProps.message;

    rag.forEach((r: any) => r.query = message.prompt);

    let llm_submission: string = "<|system|>\n"
    llm_submission += message.system_prompt + "\n\n"
    llm_submission += message.context_prompt
@ -176,7 +174,10 @@ const MessageMeta = (props: MessageMetaProps) => {
|
||||
<Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mt: 1, mb: 1, fontWeight: "bold" }}>
|
||||
{tool.name}
|
||||
</Box>
|
||||
<JsonView displayDataTypes={false} objectSortKeys={true} collapsed={1} value={JSON.parse(tool.content)} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }}>
|
||||
<JsonView
|
||||
displayDataTypes={false}
|
||||
objectSortKeys={true}
|
||||
collapsed={1} value={JSON.parse(tool.content)} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }}>
|
||||
<JsonView.String
|
||||
render={({ children, ...reset }) => {
|
||||
if (typeof (children) === "string" && children.match("\n")) {
|
||||
|
@@ -151,7 +151,6 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
    }, []);

    const jobResponse = useCallback(async (message: BackstoryMessage) => {
        console.log('onJobResponse', message);
        if (message.actions && message.actions.includes("job_description")) {
            await jobConversationRef.current.fetchHistory();
        }
@@ -53,7 +53,7 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
            }}
            displayDataTypes={false}
            objectSortKeys={false}
            collapsed={true}
            collapsed={1}
            shortenTextAfterLength={100}
            value={fixed}>
            <JsonView.String
@ -10,7 +10,6 @@ import FormControlLabel from '@mui/material/FormControlLabel';
|
||||
import Switch from '@mui/material/Switch';
|
||||
import useMediaQuery from '@mui/material/useMediaQuery';
|
||||
import { SxProps, useTheme } from '@mui/material/styles';
|
||||
import JsonView from '@uiw/react-json-view';
|
||||
import Table from '@mui/material/Table';
|
||||
import TableBody from '@mui/material/TableBody';
|
||||
import TableCell from '@mui/material/TableCell';
|
||||
@ -499,13 +498,13 @@ The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '
|
||||
</Box>
|
||||
}
|
||||
|
||||
<Box sx={{ display: "flex", flexDirection: "column" }}>
|
||||
<Box sx={{ display: "flex", flexDirection: "column", flexGrow: 1 }}>
|
||||
{node === null &&
|
||||
<Paper sx={{ m: 0.5, p: 2, flexGrow: 1 }}>
|
||||
Click a point in the scatter-graph to see information about that node.
|
||||
</Paper>
|
||||
}
|
||||
{!inline && node !== null && node.full_content &&
|
||||
{node !== null && node.full_content &&
|
||||
<Scrollable
|
||||
autoscroll={false}
|
||||
sx={{
|
||||
@ -521,7 +520,7 @@ The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '
|
||||
node.full_content.split('\n').map((line, index) => {
|
||||
index += 1 + node.chunk_begin;
|
||||
const bgColor = (index > node.line_begin && index <= node.line_end) ? '#f0f0f0' : 'auto';
|
||||
return <Box key={index} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-child': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
|
||||
return <Box key={index} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-of-type': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
|
||||
<Box sx={{ fontFamily: 'courier', fontSize: "0.8rem", minWidth: "2rem", pt: "0.1rem", align: "left", verticalAlign: "top" }}>{index}</Box>
|
||||
<pre style={{ margin: 0, padding: 0, border: "none", minHeight: "1rem" }} >{line || " "}</pre>
|
||||
</Box>;
|
||||
|
@ -118,6 +118,10 @@ const useAutoScrollToBottom = (
|
||||
let shouldScroll = false;
|
||||
const scrollTo = scrollToRef.current;
|
||||
|
||||
if (isPasteEvent && !scrollTo) {
|
||||
console.error("Paste Event triggered without scrollTo");
|
||||
}
|
||||
|
||||
if (scrollTo) {
|
||||
// Get positions
|
||||
const containerRect = container.getBoundingClientRect();
|
||||
@ -130,7 +134,7 @@ const useAutoScrollToBottom = (
|
||||
scrollToRect.top < containerBottom && scrollToRect.bottom > containerTop;
|
||||
|
||||
// Scroll on paste or if TextField is visible and user isn't scrolling up
|
||||
shouldScroll = (isPasteEvent || isTextFieldVisible) && !isUserScrollingUpRef.current;
|
||||
shouldScroll = isPasteEvent || (isTextFieldVisible && !isUserScrollingUpRef.current);
|
||||
if (shouldScroll) {
|
||||
requestAnimationFrame(() => {
|
||||
debug && console.debug('Scrolling to container bottom:', {
|
||||
@ -198,10 +202,12 @@ const useAutoScrollToBottom = (
|
||||
}
|
||||
|
||||
const handlePaste = () => {
|
||||
console.log("handlePaste");
|
||||
// Delay scroll check to ensure DOM updates
|
||||
setTimeout(() => {
|
||||
console.log("scrolling for handlePaste");
|
||||
requestAnimationFrame(() => checkAndScrollToBottom(true));
|
||||
}, 0);
|
||||
}, 100);
|
||||
};
|
||||
|
||||
window.addEventListener('mousemove', pauseScroll);
|
||||
|
@@ -1,5 +1,5 @@
[tool.black]
line-length = 88
line-length = 120
target-version = ['py312']
include = '\.pyi?$'
exclude = '''
@@ -1,7 +1,9 @@
LLM_TIMEOUT = 600

from utils import logger
from pydantic import BaseModel, Field # type: ignore
from pydantic import BaseModel, Field, ValidationError # type: ignore
from pydantic_core import PydanticSerializationError # type: ignore
from typing import List

from typing import AsyncGenerator, Dict, Optional
@ -61,6 +63,7 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
|
||||
from utils import (
|
||||
rag as Rag,
|
||||
ChromaDBGetResponse,
|
||||
tools as Tools,
|
||||
Context,
|
||||
Conversation,
|
||||
@ -69,15 +72,17 @@ from utils import (
|
||||
Metrics,
|
||||
Tunables,
|
||||
defines,
|
||||
check_serializable,
|
||||
logger,
|
||||
)
|
||||
|
||||
rags = [
|
||||
{
|
||||
"name": "JPK",
|
||||
"enabled": True,
|
||||
"description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
|
||||
},
|
||||
|
||||
rags : List[ChromaDBGetResponse] = [
|
||||
ChromaDBGetResponse(
|
||||
name="JPK",
|
||||
enabled=True,
|
||||
description="Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
|
||||
),
|
||||
# { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
|
||||
]
|
||||
|
||||
@ -461,10 +466,8 @@ class WebServer:
|
||||
context = self.upsert_context(context_id)
|
||||
agent = context.get_agent(agent_type)
|
||||
if not agent:
|
||||
return JSONResponse(
|
||||
{"error": f"{agent_type} is not recognized", "context": context.id},
|
||||
status_code=404,
|
||||
)
|
||||
response = { "history": [] }
|
||||
return JSONResponse(response)
|
||||
|
||||
data = await request.json()
|
||||
try:
|
||||
@ -475,8 +478,8 @@ class WebServer:
|
||||
logger.info(f"Resetting {reset_operation}")
|
||||
case "rags":
|
||||
logger.info(f"Resetting {reset_operation}")
|
||||
context.rags = rags.copy()
|
||||
response["rags"] = context.rags
|
||||
context.rags = [ r.model_copy() for r in rags]
|
||||
response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
|
||||
case "tools":
|
||||
logger.info(f"Resetting {reset_operation}")
|
||||
context.tools = Tools.enabled_tools(Tools.tools)
|
||||
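The switch above from rags.copy() to a per-item model_copy() matters because list.copy() is shallow: every context would share the same ChromaDBGetResponse instances, so toggling enabled in one context would leak into the module-level defaults. A minimal sketch of the difference (illustrative; the import path is assumed from this repo's layout):

# Illustrative sketch -- not part of the commit.
from src.utils.rag import ChromaDBGetResponse  # path assumed from src/utils/rag.py

defaults = [ChromaDBGetResponse(name="JPK", enabled=True)]

shallow = defaults.copy()                   # new list, same objects
shallow[0].enabled = False
print(defaults[0].enabled)                  # False -- the shared default was mutated

defaults[0].enabled = True
deep = [r.model_copy() for r in defaults]   # independent copies, as the commit now does
deep[0].enabled = False
print(defaults[0].enabled)                  # True -- defaults untouched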
@ -537,6 +540,7 @@ class WebServer:
|
||||
data = await request.json()
|
||||
agent = context.get_agent("chat")
|
||||
if not agent:
|
||||
logger.info("chat agent does not exist on this context!")
|
||||
return JSONResponse(
|
||||
{"error": f"chat is not recognized", "context": context.id},
|
||||
status_code=404,
|
||||
@ -572,20 +576,20 @@ class WebServer:
|
||||
|
||||
case "rags":
|
||||
# { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
||||
rags: list[dict[str, Any]] = data[k]
|
||||
if not rags:
|
||||
rag_configs: list[dict[str, Any]] = data[k]
|
||||
if not rag_configs:
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "RAGs can not be empty.",
|
||||
}
|
||||
)
|
||||
for rag in rags:
|
||||
for config in rag_configs:
|
||||
for context_rag in context.rags:
|
||||
if context_rag["name"] == rag["name"]:
|
||||
context_rag["enabled"] = rag["enabled"]
|
||||
if context_rag.name == config["name"]:
|
||||
context_rag.enabled = config["enabled"]
|
||||
self.save_context(context_id)
|
||||
return JSONResponse({"rags": context.rags})
|
||||
return JSONResponse({"rags": [ r.model_dump(mode="json") for r in context.rags]})
|
||||
|
||||
case "system_prompt":
|
||||
system_prompt = data[k].strip()
|
||||
@ -615,12 +619,10 @@ class WebServer:
|
||||
@self.app.get("/api/tunables/{context_id}")
|
||||
async def get_tunables(context_id: str, request: Request):
|
||||
logger.info(f"{request.method} {request.url.path}")
|
||||
if not is_valid_uuid(context_id):
|
||||
logger.warning(f"Invalid context_id: {context_id}")
|
||||
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
||||
context = self.upsert_context(context_id)
|
||||
agent = context.get_agent("chat")
|
||||
if not agent:
|
||||
logger.info("chat agent does not exist on this context!")
|
||||
return JSONResponse(
|
||||
{"error": f"chat is not recognized", "context": context.id},
|
||||
status_code=404,
|
||||
@ -629,7 +631,7 @@ class WebServer:
|
||||
{
|
||||
"system_prompt": agent.system_prompt,
|
||||
"message_history_length": context.message_history_length,
|
||||
"rags": context.rags,
|
||||
"rags": [ r.model_dump(mode="json") for r in context.rags ],
|
||||
"tools": [
|
||||
{
|
||||
**t["function"],
|
||||
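Returning r.model_dump(mode="json") rather than the model objects themselves matters here because an explicit JSONResponse serializes with the standard json encoder, which does not understand Pydantic models. A minimal sketch of the failure and the fix (illustrative, with a stand-in model):

# Illustrative: JSONResponse uses json.dumps under the hood, which rejects Pydantic models.
import json
from pydantic import BaseModel

class Rag(BaseModel):          # stand-in for ChromaDBGetResponse
    name: str = "JPK"
    enabled: bool = True

rags = [Rag()]
try:
    json.dumps({"rags": rags})                                           # TypeError
except TypeError as exc:
    print(exc)

print(json.dumps({"rags": [r.model_dump(mode="json") for r in rags]}))  # works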
@ -674,6 +676,7 @@ class WebServer:
|
||||
error = {
|
||||
"error": f"Attempt to create agent type: {agent_type} failed: {e}"
|
||||
}
|
||||
logger.info(error)
|
||||
return JSONResponse(error, status_code=404)
|
||||
|
||||
try:
|
||||
@ -887,8 +890,35 @@ class WebServer:
|
||||
file_path = os.path.join(defines.context_dir, context_id)
|
||||
|
||||
# Serialize the data to JSON and write to file
|
||||
with open(file_path, "w") as f:
|
||||
f.write(context.model_dump_json(by_alias=True))
|
||||
try:
|
||||
# Check for non-serializable fields before dumping
|
||||
serialization_errors = check_serializable(context)
|
||||
if serialization_errors:
|
||||
for error in serialization_errors:
|
||||
logger.error(error)
|
||||
raise ValueError("Found non-serializable fields in the model")
|
||||
# Dump the model prior to opening file in case there is
|
||||
# a validation error so it doesn't delete the current
|
||||
# context session
|
||||
json_data = context.model_dump_json(by_alias=True)
|
||||
with open(file_path, "w") as f:
|
||||
f.write(json_data)
|
||||
except ValidationError as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
for error in e.errors():
|
||||
print(f"Field: {error['loc'][0]}, Error: {error['msg']}")
|
||||
except PydanticSerializationError as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Serialization error: {str(e)}")
|
||||
# Inspect the model to identify problematic fields
|
||||
for field_name, value in context.__dict__.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
logger.error(f"Field '{field_name}' contains non-serializable type: {type(value)}")
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
|
||||
return context_id
|
||||
|
||||
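The diff imports check_serializable from utils but its implementation is not shown here. A plausible minimal sketch of such a pre-flight check (an assumption about its behavior, not the actual helper) would walk the model's fields and flag values that model_dump_json cannot handle, mirroring the NumPy inspection loop in the except branch above:

# Hypothetical sketch of a check_serializable helper -- the real one is not shown in this diff.
import numpy as np
from pydantic import BaseModel
from typing import List

def check_serializable(model: BaseModel) -> List[str]:
    """Return human-readable errors for fields that would break model_dump_json()."""
    errors: List[str] = []
    for field_name, value in model.__dict__.items():
        if isinstance(value, np.ndarray):
            errors.append(f"Field '{field_name}' contains non-serializable type: {type(value)}")
    return errors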
@ -942,12 +972,8 @@ class WebServer:
|
||||
self.contexts[context_id] = context
|
||||
|
||||
logger.info(f"Successfully loaded context {context_id}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid JSON in file: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating context: {str(e)}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
# Fallback to creating a new context
|
||||
self.contexts[context_id] = Context(
|
||||
@ -985,7 +1011,7 @@ class WebServer:
|
||||
# context.add_agent(JobDescription(system_prompt = system_job_description))
|
||||
# context.add_agent(FactCheck(system_prompt = system_fact_check))
|
||||
context.tools = Tools.enabled_tools(Tools.tools)
|
||||
context.rags = rags.copy()
|
||||
context.rags_enabled = [ r.name for r in rags ]
|
||||
|
||||
logger.info(f"{context.id} created and added to contexts.")
|
||||
self.contexts[context.id] = context
|
||||
|
BIN src/tests/__pycache__/__init__.cpython-312.pyc (new file, binary not shown)
BIN src/tests/__pycache__/test-context.cpython-312.pyc (new file, binary not shown)
BIN src/tests/__pycache__/test-embedding.cpython-312.pyc (new file, binary not shown)
BIN src/tests/__pycache__/test-message.cpython-312.pyc (new file, binary not shown)
BIN src/tests/__pycache__/test-rag.cpython-312.pyc (new file, binary not shown)
@@ -20,5 +20,5 @@ observer, file_watcher = Rag.start_file_watcher(
)

context = Context(file_watcher=file_watcher)
data = context.model_dump(mode="json")
context = Context.from_json(json.dumps(data), file_watcher=file_watcher)
json_data = context.model_dump(mode="json")
context = Context.model_validate(json_data)
src/tests/test-embedding.py (new file, 89 lines)
@@ -0,0 +1,89 @@
|
||||
# From /opt/backstory run:
|
||||
# python -m src.tests.test-embedding
|
||||
import numpy as np # type: ignore
|
||||
import logging
|
||||
import argparse
|
||||
from ollama import Client # type: ignore
|
||||
from ..utils import defines
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
def get_embedding(text: str, embedding_model: str, ollama_server: str) -> np.ndarray:
|
||||
"""Generate and normalize an embedding for the given text."""
|
||||
llm = Client(host=ollama_server)
|
||||
|
||||
# Get embedding
|
||||
try:
|
||||
response = llm.embeddings(model=embedding_model, prompt=text)
|
||||
embedding = np.array(response["embedding"])
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get embedding: {e}")
|
||||
raise
|
||||
|
||||
# Log diagnostics
|
||||
logging.info(f"Input text: {text}")
|
||||
logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
|
||||
|
||||
# Check for invalid embeddings
|
||||
if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
|
||||
logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
|
||||
raise ValueError("Invalid embedding returned from Ollama.")
|
||||
|
||||
# Check normalization
|
||||
norm = np.linalg.norm(embedding)
|
||||
is_normalized = np.allclose(norm, 1.0, atol=1e-3)
|
||||
logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
|
||||
|
||||
# Normalize if needed
|
||||
if not is_normalized:
|
||||
embedding = embedding / norm
|
||||
logging.info("Embedding normalized manually.")
|
||||
|
||||
return embedding
|
||||
|
||||
def main():
|
||||
"""Main function to generate and normalize an embedding from command-line input."""
|
||||
parser = argparse.ArgumentParser(description="Generate embeddings for text using mxbai-embed-large.")
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
nargs="+", # Allow multiple text inputs
|
||||
default=["Test sentence."],
|
||||
help="Text(s) to generate embeddings for (default: 'Test sentence.')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ollama-server",
|
||||
type=str,
|
||||
default=defines.ollama_api_url,
|
||||
help=f"Ollama server URL (default: {defines.ollama_api_url})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default=defines.embedding_model,
|
||||
help=f"Embedding model name (default: {defines.embedding_model})",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate input
|
||||
for text in args.text:
|
||||
if not text or not isinstance(text, str):
|
||||
logging.error("Input text must be a non-empty string.")
|
||||
raise ValueError("Input text must be a non-empty string.")
|
||||
|
||||
# Generate embeddings for each text
|
||||
embeddings = []
|
||||
for text in args.text:
|
||||
embedding = get_embedding(
|
||||
text=text,
|
||||
embedding_model=args.embedding_model,
|
||||
ollama_server=args.ollama_server,
|
||||
)
|
||||
embeddings.append(embedding)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -2,10 +2,15 @@
# python -m src.tests.test-message
from ..utils import logger

from ..utils import Message
from ..utils import Message, MessageMetaData
from ..utils import ChromaDBGetResponse

import json

prompt = "This is a test"
message = Message(prompt=prompt)
print(message.model_dump(mode="json"))
#message.metadata = MessageMetaData()
rag = ChromaDBGetResponse()
message.metadata.rag = rag
print(message.model_dump(mode="json"))
src/tests/test-rag.py (new file, 81 lines)
@@ -0,0 +1,81 @@
|
||||
# From /opt/backstory run:
|
||||
# python -m src.tests.test-rag
|
||||
from ..utils import logger
|
||||
from pydantic import BaseModel, field_validator # type: ignore
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
from typing import List, Dict, Any, Optional
|
||||
import ollama
|
||||
import numpy as np # type: ignore
|
||||
from ..utils import (rag as Rag, ChromaDBGetResponse)
|
||||
from ..utils import Context
|
||||
from ..utils import defines
|
||||
|
||||
import json
|
||||
|
||||
chroma_results = {
|
||||
"ids": ["1", "2"],
|
||||
"embeddings": np.array([[1.0, 2.0], [3.0, 4.0]]),
|
||||
"documents": ["doc1", "doc2"],
|
||||
"metadatas": [{"meta": "data1"}, {"meta": "data2"}],
|
||||
"query_embedding": np.array([0.1, 0.2, 0.3])
|
||||
}
|
||||
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
|
||||
umap_2d = np.array([0.4, 0.5]) # Example UMAP output
|
||||
umap_3d = np.array([0.6, 0.7, 0.8]) # Example UMAP output
|
||||
|
||||
rag_metadata = ChromaDBGetResponse(
|
||||
query="test",
|
||||
query_embedding=query_embedding,
|
||||
name="JPK",
|
||||
ids=chroma_results.get("ids", []),
|
||||
size=2
|
||||
)
|
||||
|
||||
logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
|
||||
|
||||
logger.info(f"Assigning type {type(umap_2d)} to rag_metadata.umap_embedding_2d")
|
||||
rag_metadata.umap_embedding_2d = umap_2d
|
||||
|
||||
logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
|
||||
|
||||
rag = ChromaDBGetResponse()
|
||||
rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])
|
||||
json_str = rag.model_dump(mode="json")
|
||||
logger.info(json_str)
|
||||
rag = ChromaDBGetResponse.model_validate(json_str)
|
||||
llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
|
||||
prometheus_collector = CollectorRegistry()
|
||||
observer, file_watcher = Rag.start_file_watcher(
|
||||
llm=llm,
|
||||
watch_directory=defines.doc_dir,
|
||||
recreate=False, # Don't recreate if exists
|
||||
)
|
||||
context = Context(
|
||||
file_watcher=file_watcher,
|
||||
prometheus_collector=prometheus_collector,
|
||||
)
|
||||
skill="Codes in C++"
|
||||
if context.file_watcher:
|
||||
chroma_results = context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
|
||||
if chroma_results:
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
|
||||
|
||||
umap_2d = context.file_watcher.umap_model_2d.transform([query_embedding])[0]
|
||||
umap_3d = context.file_watcher.umap_model_3d.transform([query_embedding])[0]
|
||||
|
||||
rag_metadata = ChromaDBGetResponse(
|
||||
query=skill,
|
||||
query_embedding=query_embedding,
|
||||
name="JPK",
|
||||
ids=chroma_results.get("ids", []),
|
||||
embeddings=chroma_results.get("embeddings", []),
|
||||
documents=chroma_results.get("documents", []),
|
||||
metadatas=chroma_results.get("metadatas", []),
|
||||
umap_embedding_2d=umap_2d,
|
||||
umap_embedding_3d=umap_3d,
|
||||
size=context.file_watcher.collection.count()
|
||||
)
|
||||
|
||||
json_str = context.model_dump(mode="json")
|
||||
logger.info(json_str)
|
@ -1,25 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from typing import (
|
||||
Any,
|
||||
Set
|
||||
)
|
||||
import importlib
|
||||
import json
|
||||
|
||||
from . import defines
|
||||
from .context import Context
|
||||
from .conversation import Conversation
|
||||
from .message import Message, Tunables
|
||||
from .rag import ChromaDBFileWatcher, start_file_watcher
|
||||
from .message import Message, Tunables, MessageMetaData
|
||||
from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher
|
||||
from .setup_logging import setup_logging
|
||||
from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
|
||||
from .metrics import Metrics
|
||||
from .check_serializable import check_serializable
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"Message",
|
||||
"Tunables",
|
||||
"MessageMetaData",
|
||||
"Context",
|
||||
"Conversation",
|
||||
"Message",
|
||||
"Metrics",
|
||||
"ChromaDBFileWatcher",
|
||||
'ChromaDBGetResponse',
|
||||
"start_file_watcher",
|
||||
"check_serializable",
|
||||
"logger",
|
||||
]
|
||||
|
||||
|
@ -165,10 +165,9 @@ class Agent(BaseModel, ABC):
|
||||
if message.status != "done":
|
||||
yield message
|
||||
|
||||
if "rag" in message.metadata and message.metadata["rag"]:
|
||||
for rag in message.metadata["rag"]:
|
||||
for doc in rag["documents"]:
|
||||
rag_context += f"{doc}\n"
|
||||
for rag in message.metadata.rag:
|
||||
for doc in rag.documents:
|
||||
rag_context += f"{doc}\n"
|
||||
|
||||
message.preamble = {}
|
||||
|
||||
@ -189,7 +188,7 @@ class Agent(BaseModel, ABC):
|
||||
llm: Any,
|
||||
model: str,
|
||||
message: Message,
|
||||
tool_message: Any,
|
||||
tool_message: Any, # llama response message
|
||||
messages: List[LLMMessage],
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
@ -199,10 +198,10 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
if not self.context:
|
||||
raise ValueError("Context is not set for this agent.")
|
||||
if not message.metadata["tools"]:
|
||||
if not message.metadata.tools:
|
||||
raise ValueError("tools field not initialized")
|
||||
|
||||
tool_metadata = message.metadata["tools"]
|
||||
tool_metadata = message.metadata.tools
|
||||
tool_metadata["tool_calls"] = []
|
||||
|
||||
message.status = "tooling"
|
||||
@ -301,8 +300,7 @@ class Agent(BaseModel, ABC):
|
||||
model=model,
|
||||
messages=messages,
|
||||
options={
|
||||
**message.metadata["options"],
|
||||
# "temperature": 0.5,
|
||||
**message.metadata.options,
|
||||
},
|
||||
stream=True,
|
||||
):
|
||||
@ -316,12 +314,10 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
if response.done:
|
||||
self.collect_metrics(response)
|
||||
message.metadata["eval_count"] += response.eval_count
|
||||
message.metadata["eval_duration"] += response.eval_duration
|
||||
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
||||
message.metadata[
|
||||
"prompt_eval_duration"
|
||||
] += response.prompt_eval_duration
|
||||
message.metadata.eval_count += response.eval_count
|
||||
message.metadata.eval_duration += response.eval_duration
|
||||
message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||
message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||
self.context_tokens = (
|
||||
response.prompt_eval_count + response.eval_count
|
||||
)
|
||||
@ -329,9 +325,7 @@ class Agent(BaseModel, ABC):
|
||||
yield message
|
||||
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"][
|
||||
"llm_with_tools"
|
||||
] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata.timers["llm_with_tools"] = end_time - start_time
|
||||
return
|
||||
|
||||
def collect_metrics(self, response):
|
||||
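Storing timers as raw floats (seconds from time.perf_counter()) instead of pre-formatted strings, as the change above does, keeps arithmetic and aggregation possible; formatting can move to the display layer. A tiny sketch of the pattern (illustrative only):

# Illustrative: keep timers numeric, format only when presenting.
import time

timers: dict[str, float] = {}
start_time = time.perf_counter()
time.sleep(0.05)                          # stand-in for the LLM call
timers["llm_with_tools"] = time.perf_counter() - start_time

total = sum(timers.values())              # possible now; strings like "0.0500" would break this
print({name: f"{seconds:.4f}" for name, seconds in timers.items()}, f"total={total:.4f}")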
@ -370,22 +364,22 @@ class Agent(BaseModel, ABC):
|
||||
LLMMessage(role="user", content=message.context_prompt.strip())
|
||||
)
|
||||
|
||||
# message.metadata["messages"] = messages
|
||||
message.metadata["options"] = {
|
||||
# message.messages = messages
|
||||
message.metadata.options = {
|
||||
"seed": 8911,
|
||||
"num_ctx": self.context_size,
|
||||
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||
}
|
||||
|
||||
# Create a dict for storing various timing stats
|
||||
message.metadata["timers"] = {}
|
||||
message.metadata.timers = {}
|
||||
|
||||
use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
||||
message.metadata["tools"] = {
|
||||
message.metadata.tools = {
|
||||
"available": llm_tools(self.context.tools),
|
||||
"used": False,
|
||||
}
|
||||
tool_metadata = message.metadata["tools"]
|
||||
tool_metadata = message.metadata.tools
|
||||
|
||||
if use_tools:
|
||||
message.status = "thinking"
|
||||
@ -408,17 +402,14 @@ class Agent(BaseModel, ABC):
|
||||
messages=tool_metadata["messages"],
|
||||
tools=tool_metadata["available"],
|
||||
options={
|
||||
**message.metadata["options"],
|
||||
# "num_predict": 1024, # "Low" token limit to cut off after tool call
|
||||
**message.metadata.options,
|
||||
},
|
||||
stream=False, # No need to stream the probe
|
||||
)
|
||||
self.collect_metrics(response)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"][
|
||||
"tool_check"
|
||||
] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata.timers["tool_check"] = end_time - start_time
|
||||
if not response.message.tool_calls:
|
||||
logger.info("LLM indicates tools will not be used")
|
||||
# The LLM will not use tools, so disable use_tools so we can stream the full response
|
||||
@ -442,16 +433,14 @@ class Agent(BaseModel, ABC):
|
||||
messages=tool_metadata["messages"], # messages,
|
||||
tools=tool_metadata["available"],
|
||||
options={
|
||||
**message.metadata["options"],
|
||||
**message.metadata.options,
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
self.collect_metrics(response)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"][
|
||||
"non_streaming"
|
||||
] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata.timers["non_streaming"] = end_time - start_time
|
||||
|
||||
if not response:
|
||||
message.status = "error"
|
||||
@ -475,9 +464,7 @@ class Agent(BaseModel, ABC):
|
||||
return
|
||||
yield message
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"][
|
||||
"process_tool_calls"
|
||||
] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata.timers["process_tool_calls"] = end_time - start_time
|
||||
message.status = "done"
|
||||
return
|
||||
|
||||
@ -498,7 +485,7 @@ class Agent(BaseModel, ABC):
|
||||
model=model,
|
||||
messages=messages,
|
||||
options={
|
||||
**message.metadata["options"],
|
||||
**message.metadata.options,
|
||||
},
|
||||
stream=True,
|
||||
):
|
||||
@ -517,12 +504,10 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
if response.done:
|
||||
self.collect_metrics(response)
|
||||
message.metadata["eval_count"] += response.eval_count
|
||||
message.metadata["eval_duration"] += response.eval_duration
|
||||
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
||||
message.metadata[
|
||||
"prompt_eval_duration"
|
||||
] += response.prompt_eval_duration
|
||||
message.metadata.eval_count += response.eval_count
|
||||
message.metadata.eval_duration += response.eval_duration
|
||||
message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||
message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||
self.context_tokens = (
|
||||
response.prompt_eval_count + response.eval_count
|
||||
)
|
||||
@ -530,7 +515,7 @@ class Agent(BaseModel, ABC):
|
||||
yield message
|
||||
|
||||
end_time = time.perf_counter()
|
||||
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
|
||||
message.metadata.timers["streamed"] = end_time - start_time
|
||||
return
|
||||
|
||||
async def process_message(
|
||||
@ -560,7 +545,7 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
self.context.processing = True
|
||||
|
||||
message.metadata["system_prompt"] = (
|
||||
message.system_prompt = (
|
||||
f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
|
||||
)
|
||||
message.context_prompt = ""
|
||||
@ -575,11 +560,11 @@ class Agent(BaseModel, ABC):
|
||||
message.status = "thinking"
|
||||
yield message
|
||||
|
||||
message.metadata["context_size"] = self.set_optimal_context_size(
|
||||
message.context_size = self.set_optimal_context_size(
|
||||
llm, model, prompt=message.context_prompt
|
||||
)
|
||||
|
||||
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
|
||||
message.response = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
|
||||
message.status = "thinking"
|
||||
yield message
|
||||
|
||||
|
@ -19,9 +19,10 @@ import time
|
||||
import asyncio
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from .base import Agent, agent_registry, LLMMessage
|
||||
from ..message import Message
|
||||
from ..setup_logging import setup_logging
|
||||
from . base import Agent, agent_registry, LLMMessage
|
||||
from .. message import Message
|
||||
from .. rag import ChromaDBGetResponse
|
||||
from .. setup_logging import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
@ -91,8 +92,11 @@ class JobDescription(Agent):
|
||||
await asyncio.sleep(1) # Allow the event loop to process the write
|
||||
|
||||
self.context.processing = True
|
||||
job_description = message.preamble["job_description"]
|
||||
resume = message.preamble["resume"]
|
||||
|
||||
original_message = message.model_copy()
|
||||
original_message.prompt = job_description
|
||||
original_message.response = ""
|
||||
self.conversation.add(original_message)
|
||||
|
||||
@ -100,11 +104,9 @@ class JobDescription(Agent):
|
||||
self.model = model
|
||||
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
|
||||
job_description = message.preamble["job_description"]
|
||||
resume = message.preamble["resume"]
|
||||
|
||||
try:
|
||||
async for message in self.generate_factual_tailored_resume(
|
||||
async for message in self.generate_resume(
|
||||
message=message, job_description=job_description, resume=resume
|
||||
):
|
||||
if message.status != "done":
|
||||
@ -124,7 +126,7 @@ class JobDescription(Agent):
|
||||
# Done processing, add message to conversation
|
||||
self.context.processing = False
|
||||
|
||||
resume_generation = message.metadata.get("resume_generation", {})
|
||||
resume_generation = message.metadata.resume_generation
|
||||
if not resume_generation:
|
||||
message.response = (
|
||||
"Generation did not generate metadata necessary for processing."
|
||||
@ -341,18 +343,19 @@ Name: {candidate_name}
|
||||
## OUTPUT FORMAT:
|
||||
Provide the resume in clean markdown format, ready for the candidate to use.
|
||||
|
||||
## REFERENCE (Original Resume):
|
||||
"""
|
||||
# ## REFERENCE (Original Resume):
|
||||
# """
|
||||
|
||||
# Add a truncated version of the original resume for reference if it's too long
|
||||
max_resume_length = 25000 # Characters
|
||||
if len(original_resume) > max_resume_length:
|
||||
system_prompt += (
|
||||
original_resume[:max_resume_length]
|
||||
+ "...\n[Original resume truncated due to length]"
|
||||
)
|
||||
else:
|
||||
system_prompt += original_resume
|
||||
# # Add a truncated version of the original resume for reference if it's too long
|
||||
# max_resume_length = 25000 # Characters
|
||||
# if len(original_resume) > max_resume_length:
|
||||
# system_prompt += (
|
||||
# original_resume[:max_resume_length]
|
||||
# + "...\n[Original resume truncated due to length]"
|
||||
# )
|
||||
# else:
|
||||
# system_prompt += original_resume
|
||||
|
||||
prompt = "Create a tailored professional resume that highlights candidate's skills and experience most relevant to the job requirements. Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no commentary before or after."
|
||||
return system_prompt, prompt
|
||||
@ -413,7 +416,7 @@ Provide the resume in clean markdown format, ready for the candidate to use.
|
||||
yield message
|
||||
return
|
||||
|
||||
def calculate_match_statistics(self, job_requirements, skill_assessment_results):
|
||||
def calculate_match_statistics(self, job_requirements, skill_assessment_results) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Calculate statistics about how well the candidate matches job requirements
|
||||
|
||||
@ -594,9 +597,10 @@ a SPECIFIC skill based solely on their resume and supporting evidence.
|
||||
}}
|
||||
```
|
||||
|
||||
## CANDIDATE RESUME:
|
||||
{resume}
|
||||
"""
|
||||
# ## CANDIDATE RESUME:
|
||||
# {resume}
|
||||
# """
|
||||
|
||||
# Add RAG content if provided
|
||||
if rag_content:
|
||||
@ -821,7 +825,7 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
|
||||
LLMMessage(role="system", content=system_prompt),
|
||||
LLMMessage(role="user", content=prompt),
|
||||
]
|
||||
message.metadata["options"] = {
|
||||
message.metadata.options = {
|
||||
"seed": 8911,
|
||||
"num_ctx": self.context_size,
|
||||
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||
@ -837,7 +841,7 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
options={
|
||||
**message.metadata["options"],
|
||||
**message.metadata.options,
|
||||
},
|
||||
stream=True,
|
||||
):
|
||||
@ -860,54 +864,40 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
|
||||
|
||||
if response.done:
|
||||
self.collect_metrics(response)
|
||||
message.metadata["eval_count"] += response.eval_count
|
||||
message.metadata["eval_duration"] += response.eval_duration
|
||||
message.metadata["prompt_eval_count"] += response.prompt_eval_count
|
||||
message.metadata[
|
||||
"prompt_eval_duration"
|
||||
] += response.prompt_eval_duration
|
||||
message.metadata.eval_count += response.eval_count
|
||||
message.metadata.eval_duration += response.eval_duration
|
||||
message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||
message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||
self.context_tokens = response.prompt_eval_count + response.eval_count
|
||||
message.chunk = ""
|
||||
message.status = "done"
|
||||
yield message
|
||||
|
||||
def rag_function(self, skill: str) -> tuple[str, list[Any]]:
|
||||
def retrieve_rag_content(self, skill: str) -> tuple[str, ChromaDBGetResponse]:
|
||||
if self.context is None or self.context.file_watcher is None:
|
||||
raise ValueError("self.context or self.context.file_watcher is None")
|
||||
|
||||
try:
|
||||
rag_results = ""
|
||||
all_metadata = []
|
||||
chroma_results = self.context.file_watcher.find_similar(
|
||||
query=skill, top_k=5, threshold=0.5
|
||||
)
|
||||
rag_metadata = ChromaDBGetResponse()
|
||||
chroma_results = self.context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
|
||||
if chroma_results:
|
||||
chroma_embedding = np.array(
|
||||
chroma_results["query_embedding"]
|
||||
).flatten() # Ensure correct shape
|
||||
print(f"Chroma embedding shape: {chroma_embedding.shape}")
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
|
||||
|
||||
umap_2d = self.context.file_watcher.umap_model_2d.transform(
|
||||
[chroma_embedding]
|
||||
)[0].tolist()
|
||||
print(
|
||||
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
|
||||
) # Debug output
|
||||
umap_2d = self.context.file_watcher.umap_model_2d.transform([query_embedding])[0]
|
||||
umap_3d = self.context.file_watcher.umap_model_3d.transform([query_embedding])[0]
|
||||
|
||||
umap_3d = self.context.file_watcher.umap_model_3d.transform(
|
||||
[chroma_embedding]
|
||||
)[0].tolist()
|
||||
print(
|
||||
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
|
||||
) # Debug output
|
||||
|
||||
all_metadata.append(
|
||||
{
|
||||
"name": "JPK",
|
||||
**chroma_results,
|
||||
"umap_embedding_2d": umap_2d,
|
||||
"umap_embedding_3d": umap_3d,
|
||||
}
|
||||
rag_metadata = ChromaDBGetResponse(
|
||||
query=skill,
|
||||
query_embedding=query_embedding.tolist(),
|
||||
name="JPK",
|
||||
ids=chroma_results.get("ids", []),
|
||||
embeddings=chroma_results.get("embeddings", []),
|
||||
documents=chroma_results.get("documents", []),
|
||||
metadatas=chroma_results.get("metadatas", []),
|
||||
umap_embedding_2d=umap_2d.tolist(),
|
||||
umap_embedding_3d=umap_3d.tolist(),
|
||||
size=self.context.file_watcher.collection.count()
|
||||
)
|
||||
|
||||
for index, metadata in enumerate(chroma_results["metadatas"]):
|
||||
@ -919,18 +909,19 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
|
||||
]
|
||||
).strip()
|
||||
rag_results += f"""
|
||||
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")} lines {metadata.get("line_begin", 0)}-{metadata.get("line_end", 0)}
|
||||
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
|
||||
Document reference: {chroma_results["ids"][index]}
|
||||
Content: { content }
|
||||
|
||||
"""
|
||||
return rag_results, all_metadata
|
||||
return rag_results, rag_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
exit(0)
|
||||
|
||||
async def generate_factual_tailored_resume(
|
||||
async def generate_resume(
|
||||
self, message: Message, job_description: str, resume: str
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
"""
|
||||
@ -947,12 +938,8 @@ Content: { content }
|
||||
if self.context is None:
|
||||
raise ValueError(f"context is None in {self.agent_type}")
|
||||
|
||||
message.status = "thinking"
|
||||
logger.info(message.response)
|
||||
yield message
|
||||
|
||||
message.metadata["resume_generation"] = {}
|
||||
metadata = message.metadata["resume_generation"]
|
||||
message.metadata.resume_generation = {}
|
||||
metadata = message.metadata.resume_generation
|
||||
# Stage 1A: Analyze job requirements
|
||||
streaming_message = Message(prompt="Analyze job requirements")
|
||||
streaming_message.status = "thinking"
|
||||
@ -975,8 +962,8 @@ Content: { content }
|
||||
prompts = self.process_job_requirements(
|
||||
job_requirements=job_requirements,
|
||||
resume=resume,
|
||||
rag_function=self.rag_function,
|
||||
) # , retrieve_rag_content)
|
||||
rag_function=self.retrieve_rag_content,
|
||||
)
|
||||
|
||||
# UI should persist this state of the message
|
||||
partial = message.model_copy()
|
||||
@ -1051,9 +1038,15 @@ Content: { content }
|
||||
partial.title = (
|
||||
f"Skill {index}/{total_prompts}: {description} [{match_level}]"
|
||||
)
|
||||
partial.metadata["rag"] = rag
|
||||
partial.metadata.rag = [rag] # Front-end expects a list of RAG retrievals
|
||||
if skill_description:
|
||||
partial.response += f"\n\n{skill_description}"
|
||||
partial.response = f"""
|
||||
```json
|
||||
{json.dumps(skill_assessment_results[skill_name]["skill_assessment"])}
|
||||
```
|
||||
|
||||
{skill_description}
|
||||
"""
|
||||
yield partial
|
||||
self.conversation.add(partial)
|
||||
|
||||
|
@ -7,9 +7,10 @@ import numpy as np # type: ignore
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
import traceback
|
||||
|
||||
from .message import Message, Tunables
|
||||
from .rag import ChromaDBFileWatcher
|
||||
from .rag import ChromaDBFileWatcher, ChromaDBGetResponse
|
||||
from . import defines
|
||||
from . import tools as Tools
|
||||
from .agents import AnyAgent
|
||||
@ -35,7 +36,7 @@ class Context(BaseModel):
|
||||
user_job_description: Optional[str] = None
|
||||
user_facts: Optional[str] = None
|
||||
tools: List[dict] = Tools.enabled_tools(Tools.tools)
|
||||
rags: List[dict] = []
|
||||
rags: List[ChromaDBGetResponse] = []
|
||||
message_history_length: int = 5
|
||||
# Class managed fields
|
||||
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
|
||||
@ -82,56 +83,40 @@ class Context(BaseModel):
|
||||
|
||||
if not self.file_watcher:
|
||||
message.response = "No RAG context available."
|
||||
del message.metadata["rag"]
|
||||
message.status = "done"
|
||||
yield message
|
||||
return
|
||||
|
||||
message.metadata["rag"] = []
|
||||
for rag in self.rags:
|
||||
if not rag["enabled"]:
|
||||
if not rag.enabled:
|
||||
continue
|
||||
message.response = f"Checking RAG context {rag['name']}..."
|
||||
message.response = f"Checking RAG context {rag.name}..."
|
||||
yield message
|
||||
chroma_results = self.file_watcher.find_similar(
|
||||
query=message.prompt, top_k=top_k, threshold=threshold
|
||||
)
|
||||
if chroma_results:
|
||||
entries += len(chroma_results["documents"])
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
|
||||
|
||||
chroma_embedding = np.array(
|
||||
chroma_results["query_embedding"]
|
||||
).flatten() # Ensure correct shape
|
||||
print(f"Chroma embedding shape: {chroma_embedding.shape}")
|
||||
umap_2d = self.file_watcher.umap_model_2d.transform([query_embedding])[0]
|
||||
umap_3d = self.file_watcher.umap_model_3d.transform([query_embedding])[0]
|
||||
|
||||
umap_2d = self.file_watcher.umap_model_2d.transform(
|
||||
[chroma_embedding]
|
||||
)[0].tolist()
|
||||
print(
|
||||
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
|
||||
) # Debug output
|
||||
|
||||
umap_3d = self.file_watcher.umap_model_3d.transform(
|
||||
[chroma_embedding]
|
||||
)[0].tolist()
|
||||
print(
|
||||
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
|
||||
) # Debug output
|
||||
|
||||
message.metadata["rag"].append(
|
||||
{
|
||||
"name": rag["name"],
|
||||
**chroma_results,
|
||||
"umap_embedding_2d": umap_2d,
|
||||
"umap_embedding_3d": umap_3d,
|
||||
"size": self.file_watcher.collection.count()
|
||||
}
|
||||
rag_metadata = ChromaDBGetResponse(
|
||||
query=message.prompt,
|
||||
query_embedding=query_embedding.tolist(),
|
||||
name=rag.name,
|
||||
ids=chroma_results.get("ids", []),
|
||||
embeddings=chroma_results.get("embeddings", []),
|
||||
documents=chroma_results.get("documents", []),
|
||||
metadatas=chroma_results.get("metadatas", []),
|
||||
umap_embedding_2d=umap_2d.tolist(),
|
||||
umap_embedding_3d=umap_3d.tolist(),
|
||||
size=self.file_watcher.collection.count()
|
||||
)
|
||||
message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
|
||||
yield message
|
||||
|
||||
if entries == 0:
|
||||
del message.metadata["rag"]
|
||||
message.metadata.rag.append(rag_metadata)
|
||||
message.response = f"Results from {rag.name} RAG: {len(chroma_results['documents'])} results."
|
||||
yield message
|
||||
|
||||
message.response = (
|
||||
f"RAG context gathered from results from {entries} documents."
|
||||
@ -142,7 +127,8 @@ class Context(BaseModel):
|
||||
except Exception as e:
|
||||
message.status = "error"
|
||||
message.response = f"Error generating RAG results: {str(e)}"
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(message.response)
|
||||
yield message
|
||||
return
|
||||
|
||||
|
@@ -1,12 +1,28 @@
from pydantic import BaseModel, Field # type: ignore
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Union, Mapping
from datetime import datetime, timezone
from . rag import ChromaDBGetResponse
from ollama._types import Options # type: ignore

class Tunables(BaseModel):
    enable_rag: bool = Field(default=True) # Enable RAG collection chromadb matching
    enable_tools: bool = Field(default=True) # Enable LLM to use tools
    enable_context: bool = Field(default=True) # Add <|context|> field to message
    enable_rag: bool = True # Enable RAG collection chromadb matching
    enable_tools: bool = True # Enable LLM to use tools
    enable_context: bool = True # Add <|context|> field to message

class MessageMetaData(BaseModel):
    rag: List[ChromaDBGetResponse] = Field(default_factory=list)
    eval_count: int = 0
    eval_duration: int = 0
    prompt_eval_count: int = 0
    prompt_eval_duration: int = 0
    context_size: int = 0
    resume_generation: Optional[Dict[str, Any]] = None
    options: Optional[Union[Mapping[str, Any], Options]] = None
    tools: Optional[Dict[str, Any]] = None
    timers: Optional[Dict[str, float]] = None

    #resume : str = ""
    #match_stats: Optional[Dict[str, Dict[str, Any]]] = Field(default=None)

class Message(BaseModel):
    model_config = {"arbitrary_types_allowed": True} # Allow Event
@ -18,35 +34,21 @@ class Message(BaseModel):
|
||||
|
||||
# Generated while processing message
|
||||
status: str = "" # Status of the message
|
||||
preamble: dict[str, str] = {} # Preamble to be prepended to the prompt
|
||||
preamble: Dict[str, Any] = Field(default_factory=dict) # Preamble to be prepended to the prompt
|
||||
system_prompt: str = "" # System prompt provided to the LLM
|
||||
context_prompt: str = "" # Full content of the message (preamble + prompt)
|
||||
response: str = "" # LLM response to the preamble + query
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"rag": [],
|
||||
"eval_count": 0,
|
||||
"eval_duration": 0,
|
||||
"prompt_eval_count": 0,
|
||||
"prompt_eval_duration": 0,
|
||||
"context_size": 0,
|
||||
}
|
||||
)
|
||||
metadata: MessageMetaData = Field(default_factory=MessageMetaData)
|
||||
network_packets: int = 0 # Total number of streaming packets
|
||||
network_bytes: int = 0 # Total bytes sent while streaming packets
|
||||
actions: List[str] = (
|
||||
[]
|
||||
) # Other session modifying actions performed while processing the message
|
||||
timestamp: datetime = datetime.now(timezone.utc)
|
||||
chunk: str = Field(
|
||||
default=""
|
||||
) # This needs to be serialized so it will be sent in responses
|
||||
partial_response: str = Field(
|
||||
default=""
|
||||
) # This needs to be serialized so it will be sent in responses on timeout
|
||||
title: str = Field(
|
||||
default=""
|
||||
) # This needs to be serialized so it will be sent in responses on timeout
|
||||
timestamp: str = str(datetime.now(timezone.utc))
|
||||
chunk: str = ""
|
||||
partial_response: str = ""
|
||||
title: str = ""
|
||||
context_size: int = 0
|
||||
|
||||
def add_action(self, action: str | list[str]) -> None:
|
||||
"""Add a actions(s) to the message."""
|
||||
|
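With metadata now a typed MessageMetaData model instead of a free-form dict, call sites across this commit switch from message.metadata["eval_count"] to attribute access. A small usage sketch (illustrative; imports mirror what the tests in this commit use):

# Illustrative sketch -- mirrors the attribute-style access the rest of this commit migrates to.
from src.utils import Message, ChromaDBGetResponse  # import path assumed from src/tests/*

message = Message(prompt="This is a test")
message.metadata.eval_count += 10                   # was: message.metadata["eval_count"] += 10
message.metadata.timers = {"streamed": 0.42}        # timers now hold floats, not formatted strings
message.metadata.rag.append(ChromaDBGetResponse(name="JPK"))
print(message.model_dump(mode="json")["metadata"]["eval_count"])  # 10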
src/utils/rag.py (108 lines changed)
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
import os
|
||||
import glob
|
||||
from pathlib import Path
|
||||
@ -7,15 +7,9 @@ import time
|
||||
import hashlib
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import glob
|
||||
import time
|
||||
import hashlib
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np # type: ignore
|
||||
import traceback
|
||||
import os
|
||||
|
||||
import chromadb
|
||||
import ollama
|
||||
@@ -38,18 +32,50 @@
    # When imported as a module, use relative imports
    from . import defines

__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]

DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100


class ChromaDBGetResponse(BaseModel):
    ids: List[str]
    embeddings: Optional[List[List[float]]] = None
    documents: Optional[List[str]] = None
    metadatas: Optional[List[Dict[str, Any]]] = None
    name: str = ""
    size: int = 0
    ids: List[str] = []
    embeddings: List[List[float]] = Field(default=[])
    documents: List[str] = []
    metadatas: List[Dict[str, Any]] = []
    query: str = ""
    query_embedding: Optional[List[float]] = Field(default=None)
    umap_embedding_2d: Optional[List[float]] = Field(default=None)
    umap_embedding_3d: Optional[List[float]] = Field(default=None)
    enabled: bool = True

    class Config:
        validate_assignment = True

    @field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
    @classmethod
    def validate_embeddings(cls, value, field):
        logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
        if value is None:
            return value
        if isinstance(value, np.ndarray):
            if field.field_name == "embeddings":
                if value.ndim != 2:
                    raise ValueError(f"{field.name} must be a 2-dimensional NumPy array")
                return [[float(x) for x in row] for row in value.tolist()]
            else:
                if value.ndim != 1:
                    raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
                return [float(x) for x in value.tolist()]
        if field.field_name == "embeddings":
            if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
                raise ValueError(f"{field.field_name} must be a list of lists of floats")
            return [[float(x) for x in sublist] for sublist in value]
        else:
            if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
                raise ValueError(f"{field.field_name} must be a list of floats")
            return [float(x) for x in value]

class ChromaDBFileWatcher(FileSystemEventHandler):
    def __init__(
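The validator above lets callers assign NumPy arrays directly and still get JSON-safe lists back out, and validate_assignment = True means the coercion runs on attribute assignment, not just construction. A minimal sketch (illustrative; it mirrors what src/tests/test-rag.py in this commit exercises):

# Illustrative sketch -- mirrors what src/tests/test-rag.py exercises.
import numpy as np
from src.utils.rag import ChromaDBGetResponse  # module path as in this repo

rag = ChromaDBGetResponse()
rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])   # assignment runs the validator,
                                                       # coercing the 2-D array to list[list[float]]
json_data = rag.model_dump(mode="json")                # now plain JSON-safe lists
rag2 = ChromaDBGetResponse.model_validate(json_data)   # round-trips cleanly
print(json_data["embeddings"], rag2.enabled)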
@ -323,13 +349,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
n_components=2,
|
||||
random_state=8911,
|
||||
metric="cosine",
|
||||
n_neighbors=15,
|
||||
n_neighbors=30,
|
||||
min_dist=0.1,
|
||||
)
|
||||
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
|
||||
logging.info(
|
||||
f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
|
||||
) # Should be 2
|
||||
# logging.info(
|
||||
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
|
||||
# ) # Should be 2
|
||||
|
||||
logging.info(
|
||||
f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
|
||||
@ -338,13 +364,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
n_components=3,
|
||||
random_state=8911,
|
||||
metric="cosine",
|
||||
n_neighbors=15,
|
||||
min_dist=0.1,
|
||||
n_neighbors=30,
|
||||
min_dist=0.01,
|
||||
)
|
||||
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
|
||||
logging.info(
|
||||
f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
|
||||
) # Should be 3
|
||||
# logging.info(
|
||||
# f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
|
||||
# ) # Should be 3
|
||||
|
||||
def _get_vector_collection(self, recreate=False) -> Collection:
|
||||
"""Get or create a ChromaDB collection."""
|
||||
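The 2-D and 3-D reducers above now use n_neighbors=30 (and min_dist=0.01 for 3-D), trading some local detail for a more global layout. For reference, a standalone sketch of the same umap-learn call pattern (illustrative, with made-up data):

# Illustrative umap-learn usage matching the parameters chosen above.
import numpy as np
import umap  # umap-learn

vectors = np.random.rand(200, 384)                       # stand-in for document embeddings
reducer = umap.UMAP(n_components=2, random_state=8911,
                    metric="cosine", n_neighbors=30, min_dist=0.1)
embedding_2d = reducer.fit_transform(vectors)            # shape (200, 2)
query_2d = reducer.transform(np.random.rand(1, 384))[0]  # project a new query into the same space
print(embedding_2d.shape, query_2d.shape)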
@ -380,14 +406,36 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
"""Split documents into chunks using the text splitter."""
|
||||
return self.text_splitter.split_documents(docs)
|
||||
|
||||
def get_embedding(self, text, normalize=True):
|
||||
"""Generate embeddings using Ollama."""
|
||||
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
|
||||
embedding = response["embedding"]
|
||||
def get_embedding(self, text: str) -> np.ndarray:
|
||||
"""Generate and normalize an embedding for the given text."""
|
||||
|
||||
# Get embedding
|
||||
try:
|
||||
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
|
||||
embedding = np.array(response["embedding"])
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get embedding: {e}")
|
||||
raise
|
||||
|
||||
# Log diagnostics
|
||||
logging.info(f"Input text: {text}")
|
||||
logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
|
||||
|
||||
# Check for invalid embeddings
|
||||
if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
|
||||
logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
|
||||
raise ValueError("Invalid embedding returned from Ollama.")
|
||||
|
||||
# Check normalization
|
||||
norm = np.linalg.norm(embedding)
|
||||
is_normalized = np.allclose(norm, 1.0, atol=1e-3)
|
||||
logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
|
||||
|
||||
# Normalize if needed
|
||||
if not is_normalized:
|
||||
embedding = embedding / norm
|
||||
logging.info("Embedding normalized manually.")
|
||||
|
||||
if normalize:
|
||||
normalized = self._normalize_embeddings(embedding)
|
||||
return normalized
|
||||
return embedding
|
||||
|
||||
def add_embeddings_to_collection(self, chunks: List[Chunk]):
|
||||
|