All seems to be working

This commit is contained in:
parent 0d27239ca6
commit 58dadb76f0

11 Dockerfile
@@ -77,6 +77,17 @@ RUN apt-get update \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}
 
+# pydub is loaded by torch, which will throw a warning if ffmpeg isn't installed
+RUN apt-get update \
+    && DEBIAN_FRONTEND=noninteractive apt-get install -y software-properties-common \
+    && add-apt-repository -y ppa:kobuk-team/intel-graphics \
+    && apt-get update \
+    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+        ffmpeg \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}
+
+
 # Prerequisite for ze-monitor
 RUN apt-get update \
     && DEBIAN_FRONTEND=noninteractive apt-get install -y \
0 cache/.keep (vendored) Normal file → Executable file
0 cache/grafana/.keep (vendored) Normal file → Executable file
0 cache/prometheus/.keep (vendored) Normal file → Executable file
1119 frontend/package-lock.json (generated)
File diff suppressed because it is too large.
@@ -19,6 +19,7 @@
     "@types/react": "^19.0.12",
     "@types/react-dom": "^19.0.4",
     "@uiw/react-json-view": "^2.0.0-alpha.31",
+    "@uiw/react-markdown-editor": "^6.1.4",
     "jsonrepair": "^3.12.0",
     "markdown-it": "^14.1.0",
     "mermaid": "^11.6.0",
@@ -113,6 +113,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
 
             const tunables = await response.json();
             serverTunables.system_prompt = tunables.system_prompt;
+            console.log(tunables);
             setSystemPrompt(tunables.system_prompt)
             setSnack("System prompt updated", "success");
         } catch (error) {
@@ -167,31 +168,36 @@ const ControlsPage = (props: BackstoryPageProps) => {
                 body: JSON.stringify({ "reset": types }),
             });
 
-            if (response.ok) {
-                const data = await response.json();
-                if (data.error) {
-                    throw Error()
-                }
-                for (const [key, value] of Object.entries(data)) {
-                    switch (key) {
-                        case "rags":
-                            setRags(value as Tool[]);
-                            break;
-                        case "tools":
-                            setTools(value as Tool[]);
-                            break;
-                        case "system_prompt":
-                            setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
-                            break;
-                        case "history":
-                            console.log('TODO: handle history reset');
-                            break;
-                    }
-                }
-                setSnack(message, "success");
-            } else {
-                throw Error(`${{ status: response.status, message: response.statusText }}`);
+            if (!response.ok) {
+                throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
             }
 
+            if (!response.body) {
+                throw new Error('Response body is null');
+            }
+
+            const data = await response.json();
+            if (data.error) {
+                throw Error(data.error);
+            }
+
+            for (const [key, value] of Object.entries(data)) {
+                switch (key) {
+                    case "rags":
+                        setRags(value as Tool[]);
+                        break;
+                    case "tools":
+                        setTools(value as Tool[]);
+                        break;
+                    case "system_prompt":
+                        setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
+                        break;
+                    case "history":
+                        console.log('TODO: handle history reset');
+                        break;
+                }
+            }
+            setSnack(message, "success");
         } catch (error) {
             console.error('Fetch error:', error);
             setSnack("Unable to restore defaults", "error");
@@ -203,20 +209,37 @@ const ControlsPage = (props: BackstoryPageProps) => {
         if (systemInfo !== undefined || sessionId === undefined) {
             return;
         }
-        fetch(connectionBase + `/api/system-info/${sessionId}`, {
-            method: 'GET',
-            headers: {
-                'Content-Type': 'application/json',
-            },
-        })
-            .then(response => response.json())
-            .then(data => {
+        const fetchSystemInfo = async () => {
+            try {
+                const response = await fetch(connectionBase + `/api/system-info/${sessionId}`, {
+                    method: 'GET',
+                    headers: {
+                        'Content-Type': 'application/json',
+                    },
+                })
+
+                if (!response.ok) {
+                    throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
+                }
+
+                if (!response.body) {
+                    throw new Error('Response body is null');
+                }
+
+                const data = await response.json();
+                if (data.error) {
+                    throw Error(data.error);
+                }
+
                 setSystemInfo(data);
-            })
-            .catch(error => {
+            } catch (error) {
                 console.error('Error obtaining system information:', error);
                 setSnack("Unable to obtain system information.", "error");
-            });
+            };
+        }
+
+        fetchSystemInfo();
+
     }, [systemInfo, setSystemInfo, setSnack, sessionId])
 
     useEffect(() => {
@@ -273,25 +296,30 @@ const ControlsPage = (props: BackstoryPageProps) => {
             return;
         }
         const fetchTunables = async () => {
-            // Make the fetch request with proper headers
-            const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
-                method: 'GET',
-                headers: {
-                    'Content-Type': 'application/json',
-                    'Accept': 'application/json',
-                },
-            });
-            const data = await response.json();
-            // console.log("Server tunables: ", data);
-            setServerTunables(data);
-            setSystemPrompt(data["system_prompt"]);
-            setMessageHistoryLength(data["message_history_length"]);
-            setTools(data["tools"]);
-            setRags(data["rags"]);
+            try {
+                // Make the fetch request with proper headers
+                const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
+                    method: 'GET',
+                    headers: {
+                        'Content-Type': 'application/json',
+                        'Accept': 'application/json',
+                    },
+                });
+                const data = await response.json();
+                // console.log("Server tunables: ", data);
+                setServerTunables(data);
+                setSystemPrompt(data["system_prompt"]);
+                setMessageHistoryLength(data["message_history_length"]);
+                setTools(data["tools"]);
+                setRags(data["rags"]);
+            } catch (error) {
+                console.error('Fetch error:', error);
+                setSnack("System prompt update failed", "error");
+            }
         }
 
         fetchTunables();
-    }, [sessionId, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags]);
+    }, [sessionId, setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);
 
     const toggle = async (type: string, index: number) => {
         switch (type) {
@@ -285,7 +285,9 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
             throw new Error('Response body is null');
         }
 
-        setConversation([])
+        setProcessingMessage(undefined);
+        setStreamingMessage(undefined);
+        setConversation([]);
         setNoInteractions(true);
 
     } catch (e) {
@@ -341,13 +343,23 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
             // Add a small delay to ensure React has time to update the UI
             await new Promise(resolve => setTimeout(resolve, 0));
 
+            let data: any = query;
+            if (type === "job_description") {
+                data = {
+                    prompt: "",
+                    agent_options: {
+                        job_description: query.prompt,
+                    }
+                }
+            }
+
             const response = await fetch(connectionBase + `/api/${type}/${sessionId}`, {
                 method: 'POST',
                 headers: {
                     'Content-Type': 'application/json',
                     'Accept': 'application/json',
                 },
-                body: JSON.stringify(query)
+                body: JSON.stringify(data)
             });
 
             setSnack(`Query sent.`, "info");
@@ -77,7 +77,7 @@ interface MessageMetaData {
         vector_embedding: number[];
     },
     origin: string,
-    rag: any,
+    rag: any[],
    tools?: {
         tool_calls: any[],
     },
@@ -117,8 +117,6 @@ const MessageMeta = (props: MessageMetaProps) => {
     } = props.metadata || {};
     const message: any = props.messageProps.message;
 
-    rag.forEach((r: any) => r.query = message.prompt);
-
     let llm_submission: string = "<|system|>\n"
     llm_submission += message.system_prompt + "\n\n"
     llm_submission += message.context_prompt
@@ -176,7 +174,10 @@ const MessageMeta = (props: MessageMetaProps) => {
             <Box sx={{ fontSize: "0.75rem", display: "flex", flexDirection: "column", mt: 1, mb: 1, fontWeight: "bold" }}>
                 {tool.name}
             </Box>
-            <JsonView displayDataTypes={false} objectSortKeys={true} collapsed={1} value={JSON.parse(tool.content)} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }}>
+            <JsonView
+                displayDataTypes={false}
+                objectSortKeys={true}
+                collapsed={1} value={JSON.parse(tool.content)} style={{ fontSize: "0.8rem", maxHeight: "20rem", overflow: "auto" }}>
                 <JsonView.String
                     render={({ children, ...reset }) => {
                         if (typeof (children) === "string" && children.match("\n")) {
@@ -151,7 +151,6 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
     }, []);
 
     const jobResponse = useCallback(async (message: BackstoryMessage) => {
-        console.log('onJobResponse', message);
         if (message.actions && message.actions.includes("job_description")) {
             await jobConversationRef.current.fetchHistory();
         }
@@ -53,7 +53,7 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
             }}
             displayDataTypes={false}
             objectSortKeys={false}
-            collapsed={true}
+            collapsed={1}
             shortenTextAfterLength={100}
             value={fixed}>
             <JsonView.String
@@ -10,7 +10,6 @@ import FormControlLabel from '@mui/material/FormControlLabel';
 import Switch from '@mui/material/Switch';
 import useMediaQuery from '@mui/material/useMediaQuery';
 import { SxProps, useTheme } from '@mui/material/styles';
-import JsonView from '@uiw/react-json-view';
 import Table from '@mui/material/Table';
 import TableBody from '@mui/material/TableBody';
 import TableCell from '@mui/material/TableCell';
@@ -499,13 +498,13 @@ The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '
                 </Box>
             }
 
-            <Box sx={{ display: "flex", flexDirection: "column" }}>
+            <Box sx={{ display: "flex", flexDirection: "column", flexGrow: 1 }}>
                 {node === null &&
                     <Paper sx={{ m: 0.5, p: 2, flexGrow: 1 }}>
                         Click a point in the scatter-graph to see information about that node.
                     </Paper>
                 }
-                {!inline && node !== null && node.full_content &&
+                {node !== null && node.full_content &&
                     <Scrollable
                         autoscroll={false}
                         sx={{
@@ -521,7 +520,7 @@ The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '
                         node.full_content.split('\n').map((line, index) => {
                             index += 1 + node.chunk_begin;
                             const bgColor = (index > node.line_begin && index <= node.line_end) ? '#f0f0f0' : 'auto';
-                            return <Box key={index} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-child': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
+                            return <Box key={index} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-of-type': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
                                 <Box sx={{ fontFamily: 'courier', fontSize: "0.8rem", minWidth: "2rem", pt: "0.1rem", align: "left", verticalAlign: "top" }}>{index}</Box>
                                 <pre style={{ margin: 0, padding: 0, border: "none", minHeight: "1rem" }} >{line || " "}</pre>
                             </Box>;
@@ -118,6 +118,10 @@ const useAutoScrollToBottom = (
         let shouldScroll = false;
         const scrollTo = scrollToRef.current;
 
+        if (isPasteEvent && !scrollTo) {
+            console.error("Paste Event triggered without scrollTo");
+        }
+
         if (scrollTo) {
             // Get positions
             const containerRect = container.getBoundingClientRect();
@@ -130,7 +134,7 @@ const useAutoScrollToBottom = (
                 scrollToRect.top < containerBottom && scrollToRect.bottom > containerTop;
 
             // Scroll on paste or if TextField is visible and user isn't scrolling up
-            shouldScroll = (isPasteEvent || isTextFieldVisible) && !isUserScrollingUpRef.current;
+            shouldScroll = isPasteEvent || (isTextFieldVisible && !isUserScrollingUpRef.current);
             if (shouldScroll) {
                 requestAnimationFrame(() => {
                     debug && console.debug('Scrolling to container bottom:', {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handlePaste = () => {
|
const handlePaste = () => {
|
||||||
|
console.log("handlePaste");
|
||||||
// Delay scroll check to ensure DOM updates
|
// Delay scroll check to ensure DOM updates
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
|
console.log("scrolling for handlePaste");
|
||||||
requestAnimationFrame(() => checkAndScrollToBottom(true));
|
requestAnimationFrame(() => checkAndScrollToBottom(true));
|
||||||
}, 0);
|
}, 100);
|
||||||
};
|
};
|
||||||
|
|
||||||
window.addEventListener('mousemove', pauseScroll);
|
window.addEventListener('mousemove', pauseScroll);
|
||||||
|
@@ -1,5 +1,5 @@
 [tool.black]
-line-length = 88
+line-length = 120
 target-version = ['py312']
 include = '\.pyi?$'
 exclude = '''
@@ -19,4 +19,4 @@ ignore_decorators = [
     "@model_validator",
     "@override", "@classmethod"
 ]
 exclude = ["tests/", "__pycache__/"]
@@ -1,7 +1,9 @@
 LLM_TIMEOUT = 600
 
 from utils import logger
-from pydantic import BaseModel, Field # type: ignore
+from pydantic import BaseModel, Field, ValidationError # type: ignore
+from pydantic_core import PydanticSerializationError # type: ignore
+from typing import List
 
 from typing import AsyncGenerator, Dict, Optional
 
@@ -61,6 +63,7 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore
 
 from utils import (
     rag as Rag,
+    ChromaDBGetResponse,
     tools as Tools,
     Context,
     Conversation,
@@ -69,15 +72,17 @@ from utils import (
     Metrics,
     Tunables,
     defines,
+    check_serializable,
     logger,
 )
 
-rags = [
-    {
-        "name": "JPK",
-        "enabled": True,
-        "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
-    },
+rags : List[ChromaDBGetResponse] = [
+    ChromaDBGetResponse(
+        name="JPK",
+        enabled=True,
+        description="Expert data about James Ketrenos, including work history, personal hobbies, and projects.",
+    ),
     # { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
 ]
 
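Note on the hunk above: `rags` moves from a list of plain dicts to typed `ChromaDBGetResponse` models. The model definition itself is not part of this diff; the sketch below is only an assumption, with field names inferred from how the class is constructed elsewhere in this commit, and the real class in `src/utils/rag.py` may differ.

```python
# Assumed shape of ChromaDBGetResponse, inferred from usage in this commit.
# Not the actual definition from src/utils/rag.py.
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field


class ChromaDBGetResponse(BaseModel):
    name: str = ""
    enabled: bool = True
    description: str = ""
    query: str = ""
    query_embedding: Optional[List[float]] = None
    ids: List[str] = Field(default_factory=list)
    embeddings: List[Any] = Field(default_factory=list)
    documents: List[str] = Field(default_factory=list)
    metadatas: List[Dict[str, Any]] = Field(default_factory=list)
    umap_embedding_2d: Optional[List[float]] = None
    umap_embedding_3d: Optional[List[float]] = None
    size: int = 0
```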
@@ -461,10 +466,8 @@ class WebServer:
             context = self.upsert_context(context_id)
             agent = context.get_agent(agent_type)
             if not agent:
-                return JSONResponse(
-                    {"error": f"{agent_type} is not recognized", "context": context.id},
-                    status_code=404,
-                )
+                response = { "history": [] }
+                return JSONResponse(response)
 
             data = await request.json()
             try:
@@ -475,8 +478,8 @@ class WebServer:
                         logger.info(f"Resetting {reset_operation}")
                     case "rags":
                         logger.info(f"Resetting {reset_operation}")
-                        context.rags = rags.copy()
-                        response["rags"] = context.rags
+                        context.rags = [ r.model_copy() for r in rags]
+                        response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
                     case "tools":
                         logger.info(f"Resetting {reset_operation}")
                         context.tools = Tools.enabled_tools(Tools.tools)
@@ -537,6 +540,7 @@ class WebServer:
             data = await request.json()
             agent = context.get_agent("chat")
             if not agent:
+                logger.info("chat agent does not exist on this context!")
                 return JSONResponse(
                     {"error": f"chat is not recognized", "context": context.id},
                     status_code=404,
@@ -572,20 +576,20 @@ class WebServer:
 
                     case "rags":
                         # { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
-                        rags: list[dict[str, Any]] = data[k]
-                        if not rags:
+                        rag_configs: list[dict[str, Any]] = data[k]
+                        if not rag_configs:
                             return JSONResponse(
                                 {
                                     "status": "error",
                                     "message": "RAGs can not be empty.",
                                 }
                             )
-                        for rag in rags:
+                        for config in rag_configs:
                             for context_rag in context.rags:
-                                if context_rag["name"] == rag["name"]:
-                                    context_rag["enabled"] = rag["enabled"]
+                                if context_rag.name == config["name"]:
+                                    context_rag.enabled = config["enabled"]
                         self.save_context(context_id)
-                        return JSONResponse({"rags": context.rags})
+                        return JSONResponse({"rags": [ r.model_dump(mode="json") for r in context.rags]})
 
                     case "system_prompt":
                         system_prompt = data[k].strip()
@@ -615,12 +619,10 @@ class WebServer:
         @self.app.get("/api/tunables/{context_id}")
         async def get_tunables(context_id: str, request: Request):
             logger.info(f"{request.method} {request.url.path}")
-            if not is_valid_uuid(context_id):
-                logger.warning(f"Invalid context_id: {context_id}")
-                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
             agent = context.get_agent("chat")
             if not agent:
+                logger.info("chat agent does not exist on this context!")
                 return JSONResponse(
                     {"error": f"chat is not recognized", "context": context.id},
                     status_code=404,
@@ -629,7 +631,7 @@ class WebServer:
                 {
                     "system_prompt": agent.system_prompt,
                     "message_history_length": context.message_history_length,
-                    "rags": context.rags,
+                    "rags": [ r.model_dump(mode="json") for r in context.rags ],
                     "tools": [
                         {
                             **t["function"],
@@ -674,6 +676,7 @@ class WebServer:
                 error = {
                     "error": f"Attempt to create agent type: {agent_type} failed: {e}"
                 }
+                logger.info(error)
                 return JSONResponse(error, status_code=404)
 
             try:
@@ -887,8 +890,35 @@ class WebServer:
         file_path = os.path.join(defines.context_dir, context_id)
 
         # Serialize the data to JSON and write to file
-        with open(file_path, "w") as f:
-            f.write(context.model_dump_json(by_alias=True))
+        try:
+            # Check for non-serializable fields before dumping
+            serialization_errors = check_serializable(context)
+            if serialization_errors:
+                for error in serialization_errors:
+                    logger.error(error)
+                raise ValueError("Found non-serializable fields in the model")
+            # Dump the model prior to opening file in case there is
+            # a validation error so it doesn't delete the current
+            # context session
+            json_data = context.model_dump_json(by_alias=True)
+            with open(file_path, "w") as f:
+                f.write(json_data)
+        except ValidationError as e:
+            logger.error(e)
+            logger.error(traceback.format_exc())
+            for error in e.errors():
+                print(f"Field: {error['loc'][0]}, Error: {error['msg']}")
+        except PydanticSerializationError as e:
+            logger.error(e)
+            logger.error(traceback.format_exc())
+            logger.error(f"Serialization error: {str(e)}")
+            # Inspect the model to identify problematic fields
+            for field_name, value in context.__dict__.items():
+                if isinstance(value, np.ndarray):
+                    logger.error(f"Field '{field_name}' contains non-serializable type: {type(value)}")
+        except Exception as e:
+            logger.error(traceback.format_exc())
+            logger.error(e)
 
         return context_id
 
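Note on the hunk above: `check_serializable` comes from the new `utils/check_serializable.py`, which is not shown in this diff. A minimal sketch of such a helper, purely as an illustration of the pre-flight check the code relies on, might look like this; the actual implementation is not included here and may differ.

```python
# Hypothetical sketch of check_serializable; the real helper in
# src/utils/check_serializable.py is not part of this diff.
import json
from typing import List

from pydantic import BaseModel


def check_serializable(model: BaseModel) -> List[str]:
    """Return error strings for top-level fields that fail JSON serialization."""
    errors: List[str] = []
    for field_name, value in model.__dict__.items():
        if isinstance(value, BaseModel):
            # Nested pydantic models serialize through their own dump methods.
            continue
        try:
            json.dumps(value)
        except (TypeError, ValueError) as e:
            errors.append(f"Field '{field_name}' is not JSON serializable: {e}")
    return errors
```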
@@ -942,12 +972,8 @@ class WebServer:
                 self.contexts[context_id] = context
 
                 logger.info(f"Successfully loaded context {context_id}")
-            except json.JSONDecodeError as e:
-                logger.error(f"Invalid JSON in file: {e}")
             except Exception as e:
                 logger.error(f"Error validating context: {str(e)}")
-                import traceback
-
                 logger.error(traceback.format_exc())
                 # Fallback to creating a new context
                 self.contexts[context_id] = Context(
@@ -985,7 +1011,7 @@ class WebServer:
         # context.add_agent(JobDescription(system_prompt = system_job_description))
         # context.add_agent(FactCheck(system_prompt = system_fact_check))
         context.tools = Tools.enabled_tools(Tools.tools)
-        context.rags = rags.copy()
+        context.rags_enabled = [ r.name for r in rags ]
 
         logger.info(f"{context.id} created and added to contexts.")
         self.contexts[context.id] = context
BIN src/tests/__pycache__/__init__.cpython-312.pyc (new file, binary not shown)
BIN src/tests/__pycache__/test-context.cpython-312.pyc (new file, binary not shown)
BIN src/tests/__pycache__/test-embedding.cpython-312.pyc (new file, binary not shown)
BIN src/tests/__pycache__/test-message.cpython-312.pyc (new file, binary not shown)
BIN src/tests/__pycache__/test-rag.cpython-312.pyc (new file, binary not shown)
@@ -20,5 +20,5 @@ observer, file_watcher = Rag.start_file_watcher(
 )
 
 context = Context(file_watcher=file_watcher)
-data = context.model_dump(mode="json")
-context = Context.from_json(json.dumps(data), file_watcher=file_watcher)
+json_data = context.model_dump(mode="json")
+context = Context.model_validate(json_data)
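Note on the hunk above: the test now round-trips the context through `model_dump(mode="json")` and `Context.model_validate(...)` instead of the old `from_json` helper. A self-contained illustration of that pydantic round trip, using a hypothetical model in place of `Context` (which needs a live file watcher), is:

```python
# Illustration of the dump/validate round trip; Item is a hypothetical
# stand-in for Context.
from pydantic import BaseModel


class Item(BaseModel):
    name: str = "example"
    count: int = 1


item = Item()
json_data = item.model_dump(mode="json")   # JSON-safe dict
restored = Item.model_validate(json_data)  # rebuilds an equivalent model
assert restored == item
```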
89 src/tests/test-embedding.py (new file)
@@ -0,0 +1,89 @@
+# From /opt/backstory run:
+# python -m src.tests.test-embedding
+import numpy as np # type: ignore
+import logging
+import argparse
+from ollama import Client # type: ignore
+from ..utils import defines
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+)
+
+def get_embedding(text: str, embedding_model: str, ollama_server: str) -> np.ndarray:
+    """Generate and normalize an embedding for the given text."""
+    llm = Client(host=ollama_server)
+
+    # Get embedding
+    try:
+        response = llm.embeddings(model=embedding_model, prompt=text)
+        embedding = np.array(response["embedding"])
+    except Exception as e:
+        logging.error(f"Failed to get embedding: {e}")
+        raise
+
+    # Log diagnostics
+    logging.info(f"Input text: {text}")
+    logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
+
+    # Check for invalid embeddings
+    if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
+        logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
+        raise ValueError("Invalid embedding returned from Ollama.")
+
+    # Check normalization
+    norm = np.linalg.norm(embedding)
+    is_normalized = np.allclose(norm, 1.0, atol=1e-3)
+    logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
+
+    # Normalize if needed
+    if not is_normalized:
+        embedding = embedding / norm
+        logging.info("Embedding normalized manually.")
+
+    return embedding
+
+def main():
+    """Main function to generate and normalize an embedding from command-line input."""
+    parser = argparse.ArgumentParser(description="Generate embeddings for text using mxbai-embed-large.")
+    parser.add_argument(
+        "--text",
+        type=str,
+        nargs="+", # Allow multiple text inputs
+        default=["Test sentence."],
+        help="Text(s) to generate embeddings for (default: 'Test sentence.')",
+    )
+    parser.add_argument(
+        "--ollama-server",
+        type=str,
+        default=defines.ollama_api_url,
+        help=f"Ollama server URL (default: {defines.ollama_api_url})",
+    )
+    parser.add_argument(
+        "--embedding-model",
+        type=str,
+        default=defines.embedding_model,
+        help=f"Embedding model name (default: {defines.embedding_model})",
+    )
+    args = parser.parse_args()
+
+    # Validate input
+    for text in args.text:
+        if not text or not isinstance(text, str):
+            logging.error("Input text must be a non-empty string.")
+            raise ValueError("Input text must be a non-empty string.")
+
+    # Generate embeddings for each text
+    embeddings = []
+    for text in args.text:
+        embedding = get_embedding(
+            text=text,
+            embedding_model=args.embedding_model,
+            ollama_server=args.ollama_server,
+        )
+        embeddings.append(embedding)
+
+if __name__ == "__main__":
+    main()
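Note on the new file above: it is meant to be run as a module from the repository root, per its header comments. As a hedged usage sketch, the same Ollama embedding call can be driven directly to compare two normalized embeddings; the server URL and model name below are placeholders for `defines.ollama_api_url` and `defines.embedding_model`.

```python
# Usage sketch following the same pattern as test-embedding.py.
# The host and model are assumptions; substitute the project's defines values.
import numpy as np
from ollama import Client

client = Client(host="http://localhost:11434")
texts = ["Codes in C++", "Writes embedded C and C++ software"]
vectors = []
for text in texts:
    response = client.embeddings(model="mxbai-embed-large", prompt=text)
    v = np.array(response["embedding"])
    vectors.append(v / np.linalg.norm(v))  # normalize, as the test does

# Cosine similarity of normalized vectors reduces to a dot product.
print(float(np.dot(vectors[0], vectors[1])))
```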
@@ -2,10 +2,15 @@
 # python -m src.tests.test-message
 from ..utils import logger
 
-from ..utils import Message
+from ..utils import Message, MessageMetaData
+from ..utils import ChromaDBGetResponse
 
 import json
 
 prompt = "This is a test"
 message = Message(prompt=prompt)
 print(message.model_dump(mode="json"))
+#message.metadata = MessageMetaData()
+rag = ChromaDBGetResponse()
+message.metadata.rag = rag
+print(message.model_dump(mode="json"))
81 src/tests/test-rag.py (new file)
@@ -0,0 +1,81 @@
+# From /opt/backstory run:
+# python -m src.tests.test-rag
+from ..utils import logger
+from pydantic import BaseModel, field_validator # type: ignore
+from prometheus_client import CollectorRegistry # type: ignore
+from typing import List, Dict, Any, Optional
+import ollama
+import numpy as np # type: ignore
+from ..utils import (rag as Rag, ChromaDBGetResponse)
+from ..utils import Context
+from ..utils import defines
+
+import json
+
+chroma_results = {
+    "ids": ["1", "2"],
+    "embeddings": np.array([[1.0, 2.0], [3.0, 4.0]]),
+    "documents": ["doc1", "doc2"],
+    "metadatas": [{"meta": "data1"}, {"meta": "data2"}],
+    "query_embedding": np.array([0.1, 0.2, 0.3])
+}
+
+query_embedding = np.array(chroma_results["query_embedding"]).flatten()
+umap_2d = np.array([0.4, 0.5]) # Example UMAP output
+umap_3d = np.array([0.6, 0.7, 0.8]) # Example UMAP output
+
+rag_metadata = ChromaDBGetResponse(
+    query="test",
+    query_embedding=query_embedding,
+    name="JPK",
+    ids=chroma_results.get("ids", []),
+    size=2
+)
+
+logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
+
+logger.info(f"Assigning type {type(umap_2d)} to rag_metadata.umap_embedding_2d")
+rag_metadata.umap_embedding_2d = umap_2d
+
+logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
+
+rag = ChromaDBGetResponse()
+rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])
+json_str = rag.model_dump(mode="json")
+logger.info(json_str)
+rag = ChromaDBGetResponse.model_validate(json_str)
+llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
+prometheus_collector = CollectorRegistry()
+observer, file_watcher = Rag.start_file_watcher(
+    llm=llm,
+    watch_directory=defines.doc_dir,
+    recreate=False, # Don't recreate if exists
+)
+context = Context(
+    file_watcher=file_watcher,
+    prometheus_collector=prometheus_collector,
+)
+skill="Codes in C++"
+if context.file_watcher:
+    chroma_results = context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
+    if chroma_results:
+        query_embedding = np.array(chroma_results["query_embedding"]).flatten()
+
+        umap_2d = context.file_watcher.umap_model_2d.transform([query_embedding])[0]
+        umap_3d = context.file_watcher.umap_model_3d.transform([query_embedding])[0]
+
+        rag_metadata = ChromaDBGetResponse(
+            query=skill,
+            query_embedding=query_embedding,
+            name="JPK",
+            ids=chroma_results.get("ids", []),
+            embeddings=chroma_results.get("embeddings", []),
+            documents=chroma_results.get("documents", []),
+            metadatas=chroma_results.get("metadatas", []),
+            umap_embedding_2d=umap_2d,
+            umap_embedding_3d=umap_3d,
+            size=context.file_watcher.collection.count()
+        )
+
+json_str = context.model_dump(mode="json")
+logger.info(json_str)
@@ -1,25 +1,34 @@
 from __future__ import annotations
 from pydantic import BaseModel # type: ignore
+from typing import (
+    Any,
+    Set
+)
 import importlib
+import json
 
 from . import defines
 from .context import Context
 from .conversation import Conversation
-from .message import Message, Tunables
-from .rag import ChromaDBFileWatcher, start_file_watcher
+from .message import Message, Tunables, MessageMetaData
+from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher
 from .setup_logging import setup_logging
 from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
 from .metrics import Metrics
+from .check_serializable import check_serializable
 
 __all__ = [
     "Agent",
+    "Message",
     "Tunables",
+    "MessageMetaData",
     "Context",
     "Conversation",
-    "Message",
     "Metrics",
     "ChromaDBFileWatcher",
+    'ChromaDBGetResponse',
     "start_file_watcher",
+    "check_serializable",
     "logger",
 ]
 
@@ -165,10 +165,9 @@ class Agent(BaseModel, ABC):
             if message.status != "done":
                 yield message
 
-        if "rag" in message.metadata and message.metadata["rag"]:
-            for rag in message.metadata["rag"]:
-                for doc in rag["documents"]:
-                    rag_context += f"{doc}\n"
+        for rag in message.metadata.rag:
+            for doc in rag.documents:
+                rag_context += f"{doc}\n"
 
         message.preamble = {}
 
@@ -189,7 +188,7 @@ class Agent(BaseModel, ABC):
         llm: Any,
         model: str,
         message: Message,
-        tool_message: Any,
+        tool_message: Any, # llama response message
         messages: List[LLMMessage],
     ) -> AsyncGenerator[Message, None]:
         logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
@@ -199,10 +198,10 @@ class Agent(BaseModel, ABC):
 
         if not self.context:
             raise ValueError("Context is not set for this agent.")
-        if not message.metadata["tools"]:
+        if not message.metadata.tools:
             raise ValueError("tools field not initialized")
 
-        tool_metadata = message.metadata["tools"]
+        tool_metadata = message.metadata.tools
         tool_metadata["tool_calls"] = []
 
         message.status = "tooling"
@@ -301,8 +300,7 @@ class Agent(BaseModel, ABC):
             model=model,
             messages=messages,
             options={
-                **message.metadata["options"],
-                # "temperature": 0.5,
+                **message.metadata.options,
             },
             stream=True,
         ):
@@ -316,12 +314,10 @@ class Agent(BaseModel, ABC):
 
             if response.done:
                 self.collect_metrics(response)
-                message.metadata["eval_count"] += response.eval_count
-                message.metadata["eval_duration"] += response.eval_duration
-                message.metadata["prompt_eval_count"] += response.prompt_eval_count
-                message.metadata[
-                    "prompt_eval_duration"
-                ] += response.prompt_eval_duration
+                message.metadata.eval_count += response.eval_count
+                message.metadata.eval_duration += response.eval_duration
+                message.metadata.prompt_eval_count += response.prompt_eval_count
+                message.metadata.prompt_eval_duration += response.prompt_eval_duration
                 self.context_tokens = (
                     response.prompt_eval_count + response.eval_count
                 )
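Note on the metadata hunks in this file: `message.metadata` changes from dict indexing to attribute access, which matches `MessageMetaData` now being a structured model (it is also newly exported from `utils`). Its definition is not in this diff; the sketch below only lists the fields this commit touches and is an assumption about their types.

```python
# Assumed shape of MessageMetaData, inferred from attribute accesses in this
# commit; the real model lives in src/utils/message.py and may differ.
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field


class MessageMetaData(BaseModel):
    eval_count: int = 0
    eval_duration: int = 0
    prompt_eval_count: int = 0
    prompt_eval_duration: int = 0
    options: Dict[str, Any] = Field(default_factory=dict)
    timers: Dict[str, float] = Field(default_factory=dict)
    tools: Optional[Dict[str, Any]] = None
    rag: List[Any] = Field(default_factory=list)
```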
@@ -329,9 +325,7 @@ class Agent(BaseModel, ABC):
                 yield message
 
         end_time = time.perf_counter()
-        message.metadata["timers"][
-            "llm_with_tools"
-        ] = f"{(end_time - start_time):.4f}"
+        message.metadata.timers["llm_with_tools"] = end_time - start_time
         return
 
     def collect_metrics(self, response):
@@ -370,22 +364,22 @@ class Agent(BaseModel, ABC):
             LLMMessage(role="user", content=message.context_prompt.strip())
         )
 
-        # message.metadata["messages"] = messages
-        message.metadata["options"] = {
+        # message.messages = messages
+        message.metadata.options = {
             "seed": 8911,
             "num_ctx": self.context_size,
             "temperature": temperature, # Higher temperature to encourage tool usage
         }
 
         # Create a dict for storing various timing stats
-        message.metadata["timers"] = {}
+        message.metadata.timers = {}
 
         use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
-        message.metadata["tools"] = {
+        message.metadata.tools = {
             "available": llm_tools(self.context.tools),
             "used": False,
         }
-        tool_metadata = message.metadata["tools"]
+        tool_metadata = message.metadata.tools
 
         if use_tools:
             message.status = "thinking"
@@ -408,17 +402,14 @@ class Agent(BaseModel, ABC):
                 messages=tool_metadata["messages"],
                 tools=tool_metadata["available"],
                 options={
-                    **message.metadata["options"],
-                    # "num_predict": 1024, # "Low" token limit to cut off after tool call
+                    **message.metadata.options,
                 },
                 stream=False, # No need to stream the probe
             )
             self.collect_metrics(response)
 
             end_time = time.perf_counter()
-            message.metadata["timers"][
-                "tool_check"
-            ] = f"{(end_time - start_time):.4f}"
+            message.metadata.timers["tool_check"] = end_time - start_time
             if not response.message.tool_calls:
                 logger.info("LLM indicates tools will not be used")
                 # The LLM will not use tools, so disable use_tools so we can stream the full response
@@ -442,16 +433,14 @@ class Agent(BaseModel, ABC):
                 messages=tool_metadata["messages"], # messages,
                 tools=tool_metadata["available"],
                 options={
-                    **message.metadata["options"],
+                    **message.metadata.options,
                 },
                 stream=False,
             )
             self.collect_metrics(response)
 
             end_time = time.perf_counter()
-            message.metadata["timers"][
-                "non_streaming"
-            ] = f"{(end_time - start_time):.4f}"
+            message.metadata.timers["non_streaming"] = end_time - start_time
 
             if not response:
                 message.status = "error"
@@ -475,9 +464,7 @@ class Agent(BaseModel, ABC):
                 return
             yield message
             end_time = time.perf_counter()
-            message.metadata["timers"][
-                "process_tool_calls"
-            ] = f"{(end_time - start_time):.4f}"
+            message.metadata.timers["process_tool_calls"] = end_time - start_time
             message.status = "done"
             return
 
@@ -498,7 +485,7 @@ class Agent(BaseModel, ABC):
             model=model,
             messages=messages,
             options={
-                **message.metadata["options"],
+                **message.metadata.options,
             },
             stream=True,
         ):
@@ -517,12 +504,10 @@ class Agent(BaseModel, ABC):
 
             if response.done:
                 self.collect_metrics(response)
-                message.metadata["eval_count"] += response.eval_count
-                message.metadata["eval_duration"] += response.eval_duration
-                message.metadata["prompt_eval_count"] += response.prompt_eval_count
-                message.metadata[
-                    "prompt_eval_duration"
-                ] += response.prompt_eval_duration
+                message.metadata.eval_count += response.eval_count
+                message.metadata.eval_duration += response.eval_duration
+                message.metadata.prompt_eval_count += response.prompt_eval_count
+                message.metadata.prompt_eval_duration += response.prompt_eval_duration
                 self.context_tokens = (
                     response.prompt_eval_count + response.eval_count
                 )
@@ -530,7 +515,7 @@ class Agent(BaseModel, ABC):
                 yield message
 
         end_time = time.perf_counter()
-        message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
+        message.metadata.timers["streamed"] = end_time - start_time
         return
 
     async def process_message(
@@ -560,7 +545,7 @@ class Agent(BaseModel, ABC):
 
         self.context.processing = True
 
-        message.metadata["system_prompt"] = (
+        message.system_prompt = (
             f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
         )
         message.context_prompt = ""
@@ -575,11 +560,11 @@ class Agent(BaseModel, ABC):
         message.status = "thinking"
         yield message
 
-        message.metadata["context_size"] = self.set_optimal_context_size(
+        message.context_size = self.set_optimal_context_size(
             llm, model, prompt=message.context_prompt
         )
 
-        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
+        message.response = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
         message.status = "thinking"
         yield message
 
@@ -19,9 +19,10 @@ import time
 import asyncio
 import numpy as np # type: ignore
 
-from .base import Agent, agent_registry, LLMMessage
-from ..message import Message
-from ..setup_logging import setup_logging
+from . base import Agent, agent_registry, LLMMessage
+from .. message import Message
+from .. rag import ChromaDBGetResponse
+from .. setup_logging import setup_logging
 
 logger = setup_logging()
 
@@ -91,8 +92,11 @@ class JobDescription(Agent):
         await asyncio.sleep(1) # Allow the event loop to process the write
 
         self.context.processing = True
+        job_description = message.preamble["job_description"]
+        resume = message.preamble["resume"]
+
         original_message = message.model_copy()
+        original_message.prompt = job_description
         original_message.response = ""
         self.conversation.add(original_message)
 
@@ -100,11 +104,9 @@ class JobDescription(Agent):
         self.model = model
         self.metrics.generate_count.labels(agent=self.agent_type).inc()
         with self.metrics.generate_duration.labels(agent=self.agent_type).time():
-            job_description = message.preamble["job_description"]
-            resume = message.preamble["resume"]
 
             try:
-                async for message in self.generate_factual_tailored_resume(
+                async for message in self.generate_resume(
                     message=message, job_description=job_description, resume=resume
                 ):
                     if message.status != "done":
@@ -124,7 +126,7 @@ class JobDescription(Agent):
             # Done processing, add message to conversation
             self.context.processing = False
 
-            resume_generation = message.metadata.get("resume_generation", {})
+            resume_generation = message.metadata.resume_generation
             if not resume_generation:
                 message.response = (
                     "Generation did not generate metadata necessary for processing."
@@ -341,18 +343,19 @@ Name: {candidate_name}
 ## OUTPUT FORMAT:
 Provide the resume in clean markdown format, ready for the candidate to use.
 
-## REFERENCE (Original Resume):
 """
+        # ## REFERENCE (Original Resume):
+        # """
 
-        # Add a truncated version of the original resume for reference if it's too long
-        max_resume_length = 25000 # Characters
-        if len(original_resume) > max_resume_length:
-            system_prompt += (
-                original_resume[:max_resume_length]
-                + "...\n[Original resume truncated due to length]"
-            )
-        else:
-            system_prompt += original_resume
+        # # Add a truncated version of the original resume for reference if it's too long
+        # max_resume_length = 25000 # Characters
+        # if len(original_resume) > max_resume_length:
+        #     system_prompt += (
+        #         original_resume[:max_resume_length]
+        #         + "...\n[Original resume truncated due to length]"
+        #     )
+        # else:
+        #     system_prompt += original_resume
 
         prompt = "Create a tailored professional resume that highlights candidate's skills and experience most relevant to the job requirements. Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no commentary before or after."
         return system_prompt, prompt
@@ -413,7 +416,7 @@ Provide the resume in clean markdown format, ready for the candidate to use.
             yield message
             return
 
-    def calculate_match_statistics(self, job_requirements, skill_assessment_results):
+    def calculate_match_statistics(self, job_requirements, skill_assessment_results) -> dict[str, dict[str, Any]]:
         """
         Calculate statistics about how well the candidate matches job requirements
 
@@ -594,9 +597,10 @@ a SPECIFIC skill based solely on their resume and supporting evidence.
 }}
 ```
 
-## CANDIDATE RESUME:
-{resume}
 """
+        # ## CANDIDATE RESUME:
+        # {resume}
+        # """
 
         # Add RAG content if provided
         if rag_content:
|
|||||||
LLMMessage(role="system", content=system_prompt),
|
LLMMessage(role="system", content=system_prompt),
|
||||||
LLMMessage(role="user", content=prompt),
|
LLMMessage(role="user", content=prompt),
|
||||||
]
|
]
|
||||||
message.metadata["options"] = {
|
message.metadata.options = {
|
||||||
"seed": 8911,
|
"seed": 8911,
|
||||||
"num_ctx": self.context_size,
|
"num_ctx": self.context_size,
|
||||||
"temperature": temperature, # Higher temperature to encourage tool usage
|
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||||
@ -837,7 +841,7 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
options={
|
options={
|
||||||
**message.metadata["options"],
|
**message.metadata.options,
|
||||||
},
|
},
|
||||||
stream=True,
|
stream=True,
|
||||||
):
|
):
|
||||||
@@ -860,54 +864,40 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
 
 if response.done:
     self.collect_metrics(response)
-    message.metadata["eval_count"] += response.eval_count
-    message.metadata["eval_duration"] += response.eval_duration
-    message.metadata["prompt_eval_count"] += response.prompt_eval_count
-    message.metadata[
-        "prompt_eval_duration"
-    ] += response.prompt_eval_duration
+    message.metadata.eval_count += response.eval_count
+    message.metadata.eval_duration += response.eval_duration
+    message.metadata.prompt_eval_count += response.prompt_eval_count
+    message.metadata.prompt_eval_duration += response.prompt_eval_duration
     self.context_tokens = response.prompt_eval_count + response.eval_count
     message.chunk = ""
     message.status = "done"
     yield message
 
-def rag_function(self, skill: str) -> tuple[str, list[Any]]:
+def retrieve_rag_content(self, skill: str) -> tuple[str, ChromaDBGetResponse]:
     if self.context is None or self.context.file_watcher is None:
         raise ValueError("self.context or self.context.file_watcher is None")
 
     try:
         rag_results = ""
-        all_metadata = []
-        chroma_results = self.context.file_watcher.find_similar(
-            query=skill, top_k=5, threshold=0.5
-        )
+        rag_metadata = ChromaDBGetResponse()
+        chroma_results = self.context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
        if chroma_results:
-            chroma_embedding = np.array(
-                chroma_results["query_embedding"]
-            ).flatten()  # Ensure correct shape
-            print(f"Chroma embedding shape: {chroma_embedding.shape}")
+            query_embedding = np.array(chroma_results["query_embedding"]).flatten()
 
-            umap_2d = self.context.file_watcher.umap_model_2d.transform(
-                [chroma_embedding]
-            )[0].tolist()
-            print(
-                f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
-            )  # Debug output
+            umap_2d = self.context.file_watcher.umap_model_2d.transform([query_embedding])[0]
+            umap_3d = self.context.file_watcher.umap_model_3d.transform([query_embedding])[0]
 
-            umap_3d = self.context.file_watcher.umap_model_3d.transform(
-                [chroma_embedding]
-            )[0].tolist()
-            print(
-                f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
-            )  # Debug output
-
-            all_metadata.append(
-                {
-                    "name": "JPK",
-                    **chroma_results,
-                    "umap_embedding_2d": umap_2d,
-                    "umap_embedding_3d": umap_3d,
-                }
+            rag_metadata = ChromaDBGetResponse(
+                query=skill,
+                query_embedding=query_embedding.tolist(),
+                name="JPK",
+                ids=chroma_results.get("ids", []),
+                embeddings=chroma_results.get("embeddings", []),
+                documents=chroma_results.get("documents", []),
+                metadatas=chroma_results.get("metadatas", []),
+                umap_embedding_2d=umap_2d.tolist(),
+                umap_embedding_3d=umap_3d.tolist(),
+                size=self.context.file_watcher.collection.count()
             )
 
            for index, metadata in enumerate(chroma_results["metadatas"]):
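The rewritten retrieve_rag_content projects the ChromaDB query embedding through the pre-fitted 2D and 3D UMAP reducers before packing everything into a ChromaDBGetResponse. A minimal sketch of just that projection step, assuming umap-learn models fitted elsewhere and an illustrative 384-dimensional embedding (both are stand-ins, not values from this repository):

```python
# Sketch only: umap_model_2d / umap_model_3d stand in for the watcher's
# pre-fitted reducers; the 384-dim embeddings are randomly generated.
import numpy as np
import umap  # umap-learn

rng = np.random.default_rng(8911)
corpus_vectors = rng.normal(size=(200, 384))          # pretend document embeddings
umap_model_2d = umap.UMAP(n_components=2, metric="cosine").fit(corpus_vectors)
umap_model_3d = umap.UMAP(n_components=3, metric="cosine").fit(corpus_vectors)

query_embedding = rng.normal(size=384).flatten()      # what find_similar() would hand back
umap_2d = umap_model_2d.transform([query_embedding])[0]
umap_3d = umap_model_3d.transform([query_embedding])[0]
print(umap_2d.tolist(), umap_3d.tolist())             # JSON-safe lists for the response model
```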
@@ -919,18 +909,19 @@ IMPORTANT: Be factual and precise. If you cannot find strong evidence for this s
     ]
 ).strip()
 rag_results += f"""
-Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")} lines {metadata.get("line_begin", 0)}-{metadata.get("line_end", 0)}
+Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
+Document reference: {chroma_results["ids"][index]}
 Content: { content }
 
 """
-return rag_results, all_metadata
+return rag_results, rag_metadata
 
 except Exception as e:
-    logger.error(e)
     logger.error(traceback.format_exc())
+    logger.error(e)
     exit(0)
 
-async def generate_factual_tailored_resume(
+async def generate_resume(
     self, message: Message, job_description: str, resume: str
 ) -> AsyncGenerator[Message, None]:
     """
@@ -947,12 +938,8 @@ Content: { content }
 if self.context is None:
     raise ValueError(f"context is None in {self.agent_type}")
 
-message.status = "thinking"
-logger.info(message.response)
-yield message
-
-message.metadata["resume_generation"] = {}
-metadata = message.metadata["resume_generation"]
+message.metadata.resume_generation = {}
+metadata = message.metadata.resume_generation
 # Stage 1A: Analyze job requirements
 streaming_message = Message(prompt="Analyze job requirements")
 streaming_message.status = "thinking"
@@ -975,8 +962,8 @@ Content: { content }
 prompts = self.process_job_requirements(
     job_requirements=job_requirements,
     resume=resume,
-    rag_function=self.rag_function,
-)  # , retrieve_rag_content)
+    rag_function=self.retrieve_rag_content,
+)
 
 # UI should persist this state of the message
 partial = message.model_copy()
@@ -1051,9 +1038,15 @@ Content: { content }
 partial.title = (
     f"Skill {index}/{total_prompts}: {description} [{match_level}]"
 )
-partial.metadata["rag"] = rag
+partial.metadata.rag = [rag]  # Front-end expects a list of RAG retrievals
 if skill_description:
-    partial.response += f"\n\n{skill_description}"
+    partial.response = f"""
+```json
+{json.dumps(skill_assessment_results[skill_name]["skill_assessment"])}
+```
+
+{skill_description}
+"""
 yield partial
 self.conversation.add(partial)
 
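Throughout this agent the dict-style metadata["..."] accesses become attribute access on the MessageMetaData model introduced later in this commit. A minimal sketch of the accumulation pattern, using a stand-in for the final streamed Ollama chunk rather than a live response (the field names mirror what this diff accumulates, not a guaranteed Ollama contract):

```python
# Sketch: FakeDone stands in for the final streamed chunk; MessageMetaData is a
# trimmed copy of the model added in the message.py hunk below.
from dataclasses import dataclass
from pydantic import BaseModel

class MessageMetaData(BaseModel):
    eval_count: int = 0
    eval_duration: int = 0
    prompt_eval_count: int = 0
    prompt_eval_duration: int = 0

@dataclass
class FakeDone:
    eval_count: int = 120
    eval_duration: int = 450_000_000
    prompt_eval_count: int = 800
    prompt_eval_duration: int = 90_000_000

meta, response = MessageMetaData(), FakeDone()
meta.eval_count += response.eval_count
meta.eval_duration += response.eval_duration
meta.prompt_eval_count += response.prompt_eval_count
meta.prompt_eval_duration += response.prompt_eval_duration
print(meta.model_dump())
```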
@@ -7,9 +7,10 @@ import numpy as np  # type: ignore
 import logging
 from uuid import uuid4
 from prometheus_client import CollectorRegistry, Counter  # type: ignore
+import traceback
 
 from .message import Message, Tunables
-from .rag import ChromaDBFileWatcher
+from .rag import ChromaDBFileWatcher, ChromaDBGetResponse
 from . import defines
 from . import tools as Tools
 from .agents import AnyAgent
@@ -35,7 +36,7 @@ class Context(BaseModel):
 user_job_description: Optional[str] = None
 user_facts: Optional[str] = None
 tools: List[dict] = Tools.enabled_tools(Tools.tools)
-rags: List[dict] = []
+rags: List[ChromaDBGetResponse] = []
 message_history_length: int = 5
 # Class managed fields
 agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(  # type: ignore
@@ -82,56 +83,40 @@ class Context(BaseModel):
 
 if not self.file_watcher:
     message.response = "No RAG context available."
-    del message.metadata["rag"]
     message.status = "done"
     yield message
     return
 
-message.metadata["rag"] = []
 for rag in self.rags:
-    if not rag["enabled"]:
+    if not rag.enabled:
         continue
-    message.response = f"Checking RAG context {rag['name']}..."
+    message.response = f"Checking RAG context {rag.name}..."
     yield message
     chroma_results = self.file_watcher.find_similar(
         query=message.prompt, top_k=top_k, threshold=threshold
     )
     if chroma_results:
-        entries += len(chroma_results["documents"])
+        query_embedding = np.array(chroma_results["query_embedding"]).flatten()
 
-        chroma_embedding = np.array(
-            chroma_results["query_embedding"]
-        ).flatten()  # Ensure correct shape
-        print(f"Chroma embedding shape: {chroma_embedding.shape}")
+        umap_2d = self.file_watcher.umap_model_2d.transform([query_embedding])[0]
+        umap_3d = self.file_watcher.umap_model_3d.transform([query_embedding])[0]
 
-        umap_2d = self.file_watcher.umap_model_2d.transform(
-            [chroma_embedding]
-        )[0].tolist()
-        print(
-            f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
-        )  # Debug output
-
-        umap_3d = self.file_watcher.umap_model_3d.transform(
-            [chroma_embedding]
-        )[0].tolist()
-        print(
-            f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
-        )  # Debug output
-
-        message.metadata["rag"].append(
-            {
-                "name": rag["name"],
-                **chroma_results,
-                "umap_embedding_2d": umap_2d,
-                "umap_embedding_3d": umap_3d,
-                "size": self.file_watcher.collection.count()
-            }
+        rag_metadata = ChromaDBGetResponse(
+            query=message.prompt,
+            query_embedding=query_embedding.tolist(),
+            name=rag.name,
+            ids=chroma_results.get("ids", []),
+            embeddings=chroma_results.get("embeddings", []),
+            documents=chroma_results.get("documents", []),
+            metadatas=chroma_results.get("metadatas", []),
+            umap_embedding_2d=umap_2d.tolist(),
+            umap_embedding_3d=umap_3d.tolist(),
+            size=self.file_watcher.collection.count()
         )
-        message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
-        yield message
 
-if entries == 0:
-    del message.metadata["rag"]
+        message.metadata.rag.append(rag_metadata)
+        message.response = f"Results from {rag.name} RAG: {len(chroma_results['documents'])} results."
+        yield message
 
 message.response = (
     f"RAG context gathered from results from {entries} documents."
@@ -142,7 +127,8 @@ class Context(BaseModel):
 except Exception as e:
     message.status = "error"
     message.response = f"Error generating RAG results: {str(e)}"
-    logger.error(e)
+    logger.error(traceback.format_exc())
+    logger.error(message.response)
     yield message
     return
 
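With Context.rags now typed as List[ChromaDBGetResponse], the enabled/name lookups in the loop above become plain attribute access. A minimal sketch of that shape, using a trimmed stand-in model with only the fields this loop touches (the real model lives in src/utils/rag.py; the instance values here are illustrative):

```python
# Sketch: RagSource is an illustrative stand-in, not the real ChromaDBGetResponse.
from typing import List
from pydantic import BaseModel

class RagSource(BaseModel):
    name: str = ""
    enabled: bool = True

rags: List[RagSource] = [RagSource(name="JPK"), RagSource(name="archive", enabled=False)]
for rag in rags:
    if not rag.enabled:
        continue
    print(f"Checking RAG context {rag.name}...")
```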
@@ -1,12 +1,28 @@
 from pydantic import BaseModel, Field  # type: ignore
-from typing import Dict, List, Optional, Any
+from typing import Dict, List, Optional, Any, Union, Mapping
 from datetime import datetime, timezone
+from . rag import ChromaDBGetResponse
+from ollama._types import Options  # type: ignore
 
 class Tunables(BaseModel):
-    enable_rag: bool = Field(default=True)  # Enable RAG collection chromadb matching
-    enable_tools: bool = Field(default=True)  # Enable LLM to use tools
-    enable_context: bool = Field(default=True)  # Add <|context|> field to message
+    enable_rag: bool = True  # Enable RAG collection chromadb matching
+    enable_tools: bool = True  # Enable LLM to use tools
+    enable_context: bool = True  # Add <|context|> field to message
 
+class MessageMetaData(BaseModel):
+    rag: List[ChromaDBGetResponse] = Field(default_factory=list)
+    eval_count: int = 0
+    eval_duration: int = 0
+    prompt_eval_count: int = 0
+    prompt_eval_duration: int = 0
+    context_size: int = 0
+    resume_generation: Optional[Dict[str, Any]] = None
+    options: Optional[Union[Mapping[str, Any], Options]] = None
+    tools: Optional[Dict[str, Any]] = None
+    timers: Optional[Dict[str, float]] = None
+
+    #resume : str = ""
+    #match_stats: Optional[Dict[str, Dict[str, Any]]] = Field(default=None)
 
 class Message(BaseModel):
     model_config = {"arbitrary_types_allowed": True}  # Allow Event
@@ -18,35 +34,21 @@ class Message(BaseModel):
 
 # Generated while processing message
 status: str = ""  # Status of the message
-preamble: dict[str, str] = {}  # Preamble to be prepended to the prompt
+preamble: Dict[str, Any] = Field(default_factory=dict)  # Preamble to be prepended to the prompt
 system_prompt: str = ""  # System prompt provided to the LLM
 context_prompt: str = ""  # Full content of the message (preamble + prompt)
 response: str = ""  # LLM response to the preamble + query
-metadata: Dict[str, Any] = Field(
-    default_factory=lambda: {
-        "rag": [],
-        "eval_count": 0,
-        "eval_duration": 0,
-        "prompt_eval_count": 0,
-        "prompt_eval_duration": 0,
-        "context_size": 0,
-    }
-)
+metadata: MessageMetaData = Field(default_factory=MessageMetaData)
 network_packets: int = 0  # Total number of streaming packets
 network_bytes: int = 0  # Total bytes sent while streaming packets
 actions: List[str] = (
     []
 )  # Other session modifying actions performed while processing the message
-timestamp: datetime = datetime.now(timezone.utc)
-chunk: str = Field(
-    default=""
-)  # This needs to be serialized so it will be sent in responses
-partial_response: str = Field(
-    default=""
-)  # This needs to be serialized so it will be sent in responses on timeout
-title: str = Field(
-    default=""
-)  # This needs to be serialized so it will be sent in responses on timeout
+timestamp: str = str(datetime.now(timezone.utc))
+chunk: str = ""
+partial_response: str = ""
+title: str = ""
+context_size: int = 0
 
 def add_action(self, action: str | list[str]) -> None:
     """Add a actions(s) to the message."""
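With metadata promoted to a nested model, a Message still serializes to JSON for the streaming responses, just with typed sub-fields instead of a free-form dict. A minimal sketch, using trimmed stand-ins that carry only a couple of the fields shown in this diff:

```python
# Sketch: cut-down Message / MessageMetaData, enough to show that nested
# models serialize cleanly and that attribute access replaces dict access.
from pydantic import BaseModel, Field

class MessageMetaData(BaseModel):
    eval_count: int = 0
    prompt_eval_count: int = 0

class Message(BaseModel):
    prompt: str = ""
    response: str = ""
    metadata: MessageMetaData = Field(default_factory=MessageMetaData)

msg = Message(prompt="hello")
msg.metadata.eval_count += 42        # replaces msg.metadata["eval_count"] += 42
print(msg.model_dump_json())         # {"prompt":"hello","response":"","metadata":{...}}
```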
110 src/utils/rag.py

@@ -1,5 +1,5 @@
-from pydantic import BaseModel  # type: ignore
-from typing import List, Optional, Dict, Any
+from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field  # type: ignore
+from typing import List, Optional, Dict, Any, Union
 import os
 import glob
 from pathlib import Path
@@ -7,15 +7,9 @@ import time
 import hashlib
 import asyncio
 import logging
-import os
-import glob
-import time
-import hashlib
-import asyncio
 import json
 import numpy as np  # type: ignore
 import traceback
-import os
 
 import chromadb
 import ollama
@@ -38,19 +32,51 @@ else:
 # When imported as a module, use relative imports
 from . import defines
 
-__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
+__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]
 
 DEFAULT_CHUNK_SIZE = 750
 DEFAULT_CHUNK_OVERLAP = 100
 
 
 class ChromaDBGetResponse(BaseModel):
-    ids: List[str]
-    embeddings: Optional[List[List[float]]] = None
-    documents: Optional[List[str]] = None
-    metadatas: Optional[List[Dict[str, Any]]] = None
+    name: str = ""
+    size: int = 0
+    ids: List[str] = []
+    embeddings: List[List[float]] = Field(default=[])
+    documents: List[str] = []
+    metadatas: List[Dict[str, Any]] = []
+    query: str = ""
+    query_embedding: Optional[List[float]] = Field(default=None)
+    umap_embedding_2d: Optional[List[float]] = Field(default=None)
+    umap_embedding_3d: Optional[List[float]] = Field(default=None)
+    enabled: bool = True
+
+    class Config:
+        validate_assignment = True
+
+    @field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
+    @classmethod
+    def validate_embeddings(cls, value, field):
+        logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
+        if value is None:
+            return value
+        if isinstance(value, np.ndarray):
+            if field.field_name == "embeddings":
+                if value.ndim != 2:
+                    raise ValueError(f"{field.name} must be a 2-dimensional NumPy array")
+                return [[float(x) for x in row] for row in value.tolist()]
+            else:
+                if value.ndim != 1:
+                    raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
+                return [float(x) for x in value.tolist()]
+        if field.field_name == "embeddings":
+            if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
+                raise ValueError(f"{field.field_name} must be a list of lists of floats")
+            return [[float(x) for x in sublist] for sublist in value]
+        else:
+            if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
+                raise ValueError(f"{field.field_name} must be a list of floats")
+            return [float(x) for x in value]
 
 class ChromaDBFileWatcher(FileSystemEventHandler):
     def __init__(
         self,
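The expanded ChromaDBGetResponse accepts NumPy arrays for its embedding fields and coerces them to plain float lists, so the model stays JSON-serializable even with validate_assignment enabled. A trimmed sketch of the same coercion pattern (not the full class; only one field is modeled and the validator runs in "before" mode here):

```python
# Sketch: cut-down model showing the NumPy-to-list coercion idea used by
# ChromaDBGetResponse; only query_embedding is modeled.
from typing import List, Optional
import numpy as np
from pydantic import BaseModel, ConfigDict, field_validator

class EmbeddingHolder(BaseModel):
    model_config = ConfigDict(validate_assignment=True)
    query_embedding: Optional[List[float]] = None

    @field_validator("query_embedding", mode="before")
    @classmethod
    def coerce_array(cls, value):
        # Flatten ndarrays into plain Python floats before list validation runs
        if isinstance(value, np.ndarray):
            return [float(x) for x in value.flatten().tolist()]
        return value

holder = EmbeddingHolder(query_embedding=np.array([0.1, 0.2, 0.3]))
holder.query_embedding = np.array([0.4, 0.5])   # validate_assignment re-runs the validator
print(holder.model_dump())                      # {'query_embedding': [0.4, 0.5]}
```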
@@ -323,13 +349,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
 n_components=2,
 random_state=8911,
 metric="cosine",
-n_neighbors=15,
+n_neighbors=30,
 min_dist=0.1,
 )
 self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
-logging.info(
-    f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
-)  # Should be 2
+# logging.info(
+#     f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
+# )  # Should be 2
 
 logging.info(
     f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
@@ -338,13 +364,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
 n_components=3,
 random_state=8911,
 metric="cosine",
-n_neighbors=15,
-min_dist=0.1,
+n_neighbors=30,
+min_dist=0.01,
 )
 self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
-logging.info(
-    f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
-)  # Should be 3
+# logging.info(
+#     f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
+# )  # Should be 3
 
 def _get_vector_collection(self, recreate=False) -> Collection:
     """Get or create a ChromaDB collection."""
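Both reducers move to n_neighbors=30, and the 3D one to min_dist=0.01, which trades local detail for a more global layout of the collection. A minimal sketch of the fit with the new settings, using random stand-in vectors in place of the collection embeddings (shape is illustrative):

```python
# Sketch: `vectors` stands in for the collection embeddings the watcher fits on.
import numpy as np
import umap  # umap-learn

vectors = np.random.default_rng(8911).normal(size=(500, 384))
umap_model_2d = umap.UMAP(n_components=2, random_state=8911, metric="cosine",
                          n_neighbors=30, min_dist=0.1)
umap_model_3d = umap.UMAP(n_components=3, random_state=8911, metric="cosine",
                          n_neighbors=30, min_dist=0.01)
embedding_2d = umap_model_2d.fit_transform(vectors)   # shape (500, 2)
embedding_3d = umap_model_3d.fit_transform(vectors)   # shape (500, 3)
print(embedding_2d.shape, embedding_3d.shape)
```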
@@ -380,14 +406,36 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
 """Split documents into chunks using the text splitter."""
 return self.text_splitter.split_documents(docs)
 
-def get_embedding(self, text, normalize=True):
-    """Generate embeddings using Ollama."""
-    response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
-    embedding = response["embedding"]
+def get_embedding(self, text: str) -> np.ndarray:
+    """Generate and normalize an embedding for the given text."""
+    # Get embedding
+    try:
+        response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
+        embedding = np.array(response["embedding"])
+    except Exception as e:
+        logging.error(f"Failed to get embedding: {e}")
+        raise
 
-    if normalize:
-        normalized = self._normalize_embeddings(embedding)
-        return normalized
+    # Log diagnostics
+    logging.info(f"Input text: {text}")
+    logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
+
+    # Check for invalid embeddings
+    if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
+        logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
+        raise ValueError("Invalid embedding returned from Ollama.")
+
+    # Check normalization
+    norm = np.linalg.norm(embedding)
+    is_normalized = np.allclose(norm, 1.0, atol=1e-3)
+    logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
+
+    # Normalize if needed
+    if not is_normalized:
+        embedding = embedding / norm
+        logging.info("Embedding normalized manually.")
+
     return embedding
 
 def add_embeddings_to_collection(self, chunks: List[Chunk]):
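The rewritten get_embedding validates and L2-normalizes the Ollama embedding itself instead of delegating to _normalize_embeddings. The norm check in isolation, applied to an arbitrary stand-in vector rather than a live Ollama response:

```python
# Sketch: the validity and norm checks from the new get_embedding, on a stand-in vector.
import numpy as np

embedding = np.array([0.3, -1.2, 0.7, 2.0])
if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
    raise ValueError("Invalid embedding")
norm = np.linalg.norm(embedding)
if not np.allclose(norm, 1.0, atol=1e-3):
    embedding = embedding / norm      # L2-normalize manually
print(np.linalg.norm(embedding))      # ~1.0
```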