Updated deployed instance with visualization
parent 1ad2638277
commit 1dd5ae115d

Dockerfile (11 changed lines)
@@ -257,7 +257,6 @@ FROM llm-base AS backstory
 COPY /src/requirements.txt /opt/backstory/src/requirements.txt
 RUN pip install -r /opt/backstory/src/requirements.txt
-COPY /src/ /opt/backstory/src/

 SHELL [ "/bin/bash", "-c" ]

@@ -303,6 +302,8 @@ ENV SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
 ENV SYCL_CACHE_PERSISTENT=1
 ENV PATH=/opt/backstory:$PATH

+COPY /src/ /opt/backstory/src/
+
 ENTRYPOINT [ "/entrypoint.sh" ]

 FROM ubuntu:oracular AS ollama

@@ -406,6 +407,14 @@ RUN { \
 && chmod +x /fetch-models.sh

+ENV PYTHONUNBUFFERED=1
+# Enable ext_intel_free_memory
+ENV ZES_ENABLE_SYSMAN=1
+# Use all GPUs
+ENV OLLAMA_NUM_GPU=999
+# Use immediate command lists
+ENV SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+# Use persistent cache
+ENV SYCL_CACHE_PERSISTENT=1

 VOLUME [ "/root/.ollama" ]

docker-compose.yml
@@ -4,6 +4,7 @@ services:
       context: .
       dockerfile: Dockerfile
       target: backstory
+    container_name: backstory
     image: backstory
     restart: "no"
     env_file:
@@ -70,6 +71,7 @@ services:
       dockerfile: Dockerfile
       target: ollama
     image: ollama
+    container_name: ollama
     restart: "always"
     env_file:
       - .env
@@ -133,6 +135,7 @@ services:
       dockerfile: Dockerfile
      target: miniircd
     image: miniircd
+    container_name: miniircd
     restart: "no"
     env_file:
       - .env

App.tsx
@@ -36,7 +36,7 @@ import { Message, MessageList } from './Message';
 import { MessageData } from './MessageMeta';
 import { SeverityType } from './Snack';
 import { ContextStatus } from './ContextStatus';
-import { VectorVisualizer, ResultData } from './VectorVisualizer';
+import { VectorVisualizer } from './VectorVisualizer';

 import './App.css';
@@ -343,7 +343,6 @@ const App = () => {
   const [resume, setResume] = useState<MessageData | undefined>(undefined);
   const [facts, setFacts] = useState<MessageData | undefined>(undefined);
   const timerRef = useRef<any>(null);
-  const [result, setResult] = useState<ResultData | undefined>(undefined);

   const startCountdown = (seconds: number) => {
     if (timerRef.current) clearInterval(timerRef.current);
@@ -422,31 +421,6 @@ const App = () => {
     });
   }, [systemInfo, setSystemInfo, connectionBase, setSnack, sessionId])

-  // Get the collection to visualize
-  useEffect(() => {
-    if (result !== undefined || sessionId === undefined) {
-      return;
-    }
-    const fetchCollection = async () => {
-      try {
-        const response = await fetch(connectionBase + `/api/tsne/${sessionId}`, {
-          method: 'PUT',
-          headers: {
-            'Content-Type': 'application/json',
-          },
-          body: JSON.stringify({ dimensions: 3 }),
-        });
-        const data = await response.json();
-        setResult(data);
-      } catch (error) {
-        console.error('Error obtaining collection information:', error);
-        setSnack("Unable to obtain collection information.", "error");
-      };
-    };
-
-    fetchCollection();
-  }, [result, setResult, connectionBase, setSnack, sessionId])
-
   // Get the About markdown
   useEffect(() => {
     if (about !== "") {
@@ -1101,7 +1075,7 @@ const App = () => {
         onChange={handleTabChange} aria-label="Backstory navigation">
         <Tab label="Backstory" icon={<Avatar sx={{ width: 24, height: 24 }} variant="rounded" alt="Backstory logo" src="/logo192.png" />} iconPosition="start" />
         <Tab label="Resume Builder"/>
-        <Tab label="Visualizer" />
+        <Tab label="Context Visualizer" />
         <Tab label="About"/>
       </Tabs>
     </Box>}
@@ -1215,7 +1189,7 @@ const App = () => {
       <CustomTabPanel tab={tab} index={2}>
        <Box className="ChatBox">
          <Box className="Conversation">
-            {result !== undefined && <VectorVisualizer {...{ result, connectionBase, sessionId }} />}
+            <VectorVisualizer {...{ connectionBase, sessionId, setSnack }} />
          </Box>
        </Box>
      </CustomTabPanel>

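With this change App.tsx no longer issues the visualization request itself; VectorVisualizer.tsx (next file) PUTs to /api/umap on its own. For reference, a minimal Python sketch of the request/response contract the relocated effect assumes (URL path, body, and field names come from the diff; the base URL and session id below are placeholders):

import requests

connection_base = "http://localhost:3000"  # placeholder; the real app derives its connectionBase at runtime
session_id = "example-session-id"          # placeholder; issued by the backend

response = requests.put(
    f"{connection_base}/api/umap/{session_id}",
    json={"dimensions": 3},  # same body the React effect sends
    timeout=30,
)
result = response.json()
# Per the ResultData interface, the payload carries parallel arrays:
# result["embeddings"], result["documents"], result["metadatas"]
print(len(result.get("embeddings", [])), "projected vectors")
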
VectorVisualizer.tsx
@@ -1,13 +1,17 @@
 import React, { useEffect, useState } from 'react';
 import Box from '@mui/material/Box';
 import Card from '@mui/material/Card';
 import Typography from '@mui/material/Typography';
 import Plot from 'react-plotly.js';
 import TextField from '@mui/material/TextField';
 import Tooltip from '@mui/material/Tooltip';
 import Button from '@mui/material/Button';
 import SendIcon from '@mui/icons-material/Send';

 import { SeverityType } from './Snack';

 interface Metadata {
   type?: string;
   doc_type?: string;
   [key: string]: any;
 }

@@ -18,19 +22,23 @@ interface ResultData {
 }

 interface PlotData {
-  x: number[];
-  y: number[];
-  z?: number[];
-  colors: string[];
-  text: string[];
-  sizes: number[];
-  symbols: string[];
+  data: {
+    x: number[];
+    y: number[];
+    z?: number[];
+    colors: string[];
+    text: string[];
+    sizes: number[];
+    symbols: string[];
+    doc_types: string[];
+  };
+  layout: Partial<Plotly.Layout>;
 }

 interface VectorVisualizerProps {
-  result: ResultData;
   connectionBase: string;
   sessionId?: string;
+  setSnack: (message: string, severity: SeverityType) => void;
 }

 interface ChromaResult {
@@ -43,10 +51,82 @@ interface ChromaResult {
   vector_embedding?: number[];
 }

-const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionBase, sessionId }) => {
+const normalizeDimension = (arr: number[]): number[] => {
+  const min = Math.min(...arr);
+  const max = Math.max(...arr);
+  const range = max - min;
+  if (range === 0) return arr.map(() => 0.5); // flat dimension
+  return arr.map(v => (v - min) / range);
+};
+
+const getTextColorForBackground = (bgColor: string): string => {
+  const r = parseInt(bgColor.slice(1, 3), 16);
+  const g = parseInt(bgColor.slice(3, 5), 16);
+  const b = parseInt(bgColor.slice(5, 7), 16);
+  const luminance = 0.299 * r + 0.587 * g + 0.114 * b;
+  return luminance > 186 ? '#2E2E2E' : '#FFFFFF'; // Charcoal or white from your theme
+};

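An aside on getTextColorForBackground: it picks a readable tooltip text color from the marker color using perceived luminance (the Rec. 601 weights) with 186 as the light/dark cutoff. The same rule as a standalone Python sketch, with the constants taken from the diff and the assertions using theme colors from colorMap below:

def text_color_for_background(bg_color: str) -> str:
    # Perceived luminance (Rec. 601 weights), as in getTextColorForBackground
    r, g, b = (int(bg_color[i:i + 2], 16) for i in (1, 3, 5))
    luminance = 0.299 * r + 0.587 * g + 0.114 * b
    return '#2E2E2E' if luminance > 186 else '#FFFFFF'  # charcoal on light, white on dark

assert text_color_for_background('#D3CDBF') == '#2E2E2E'  # Warm Gray -> charcoal text
assert text_color_for_background('#1A2536') == '#FFFFFF'  # Midnight Blue -> white text
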
+const emojiMap: Record<string, string> = {
+  query: '🔍',
+  resume: '📄',
+  projects: '📁',
+  news: '📰',
+};
+
+const colorMap: Record<string, string> = {
+  query: '#D4A017', // Golden Ochre — strong highlight
+  resume: '#4A7A7D', // Dusty Teal — secondary theme color
+  projects: '#1A2536', // Midnight Blue — rich and deep
+  news: '#D3CDBF', // Warm Gray — soft and neutral
+};
+
+const sizeMap: Record<string, number> = {
+  'query': 10,
+};
+
+const symbolMap: Record<string, string> = {
+  'query': 'circle',
+};
+
+const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ setSnack, connectionBase, sessionId }) => {
   const [plotData, setPlotData] = useState<PlotData | null>(null);
   const [query, setQuery] = useState<string>('');
   const [queryEmbedding, setQueryEmbedding] = useState<ChromaResult | undefined>(undefined);
+  const [result, setResult] = useState<ResultData | undefined>(undefined);
+  const [tooltip, setTooltip] = useState<{
+    visible: boolean,
+    // x: number,
+    // y: number,
+    content: string,
+    background: string,
+    color: string,
+  } | null>(null);
+
+  // Get the collection to visualize
+  useEffect(() => {
+    if (result !== undefined || sessionId === undefined) {
+      return;
+    }
+    const fetchCollection = async () => {
+      try {
+        const response = await fetch(connectionBase + `/api/umap/${sessionId}`, {
+          method: 'PUT',
+          headers: {
+            'Content-Type': 'application/json',
+          },
+          body: JSON.stringify({ dimensions: 3 }),
+        });
+        const data = await response.json();
+        setResult(data);
+      } catch (error) {
+        console.error('Error obtaining collection information:', error);
+        setSnack("Unable to obtain collection information.", "error");
+      };
+    };
+
+    fetchCollection();
+  }, [result, setResult, connectionBase, setSnack, sessionId])

   useEffect(() => {
     if (!result || !result.embeddings) return;
@@ -57,7 +137,7 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionB
     const metadatas = [...result.metadatas || []];

     if (queryEmbedding !== undefined && queryEmbedding.vector_embedding !== undefined) {
-      metadatas.unshift({ type: 'query' });
+      metadatas.unshift({ doc_type: 'query' });
       documents.unshift(queryEmbedding.query || '');
       vectors.unshift(queryEmbedding.vector_embedding);
     }
@@ -67,19 +147,8 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionB
       console.error('Vectors are neither 2D nor 3D');
       return;
     }
-    console.log('Vectors:', vectors);
-    // Placeholder color assignment
-    const colorMap: Record<string, string> = {
-      'query': '#00ff00',
-    };
-    const sizeMap: Record<string, number> = {
-      'query': 10,
-    };
-    const symbolMap: Record<string, string> = {
-      'query': 'circle',
-    };

-    const doc_types = metadatas.map(m => m.type || 'unknown');
+    const doc_types = metadatas.map(m => m.doc_type || 'unknown');
     const sizes = doc_types.map(type => {
       if (!sizeMap[type]) {
         sizeMap[type] = 5;
@@ -98,31 +167,51 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionB
       }
       return colorMap[type];
     });
+    const customdata = metadatas.map((m, index) => {
+      return { doc: documents[index], type: m.doc_type || 'unknown' };
+    });
+    const x = normalizeDimension(vectors.map((v: number[]) => v[0]));
+    const y = normalizeDimension(vectors.map((v: number[]) => v[1]));
+    const z = is3D ? normalizeDimension(vectors.map((v: number[]) => v[2])) : undefined
+
+    const layout: Partial<Plotly.Layout> = {
+      autosize: true,
+      paper_bgcolor: '#FFFFFF', // white
+      plot_bgcolor: '#FFFFFF', // white plot background
+      font: {
+        family: 'Roboto, sans-serif',
+        color: '#2E2E2E', // charcoal black
+      },
+      hovermode: 'closest',
+      scene: {
+        bgcolor: '#FFFFFF', // 3D plot background
+        zaxis: { title: 'Z', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
+      },
+      xaxis: { title: 'X', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
+      yaxis: { title: 'Y', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
+      margin: { r: 20, b: 10, l: 10, t: 40 },
+    };
+
+    const data: any = {
+      x: x,
+      y: y,
+      mode: 'markers',
+      marker: {
+        size: sizes,
+        symbol: symbols,
+        color: colors,
+        opacity: 0.8,
+      },
+      customdata: customdata,
+      type: z?.length ? 'scatter3d' : 'scatter',
+    };

-    const x = vectors.map((v: number[]) => v[0]);
-    const y = vectors.map((v: number[]) => v[1]);
     const text = documents.map((doc, i) => `Type: ${doc_types[i]}<br>Text: ${doc.slice(0, 100)}...`);
     if (is3D) {
-      const z = vectors.map((v: number[]) => v[2]);
-      setPlotData({
-        x: x,
-        y: y,
-        z: z,
-        colors: colors,
-        sizes: sizes,
-        symbols: symbols,
-        text: text
-      });
-    } else {
-      setPlotData({
-        x: x,
-        y: y,
-        colors: colors,
-        sizes: sizes,
-        symbols: symbols,
-        text: text
-      });
+      data.z = z;
     }

+    setPlotData({ data, layout });
+
   }, [result, queryEmbedding]);

   const handleKeyPress = (event: any) => {
@@ -134,19 +223,23 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionB
   const sendQuery = async (query: string) => {
     if (!query.trim()) return;
     setQuery('');

-    const response = await fetch(`${connectionBase}/api/similarity/${sessionId}`, {
-      method: 'PUT',
-      headers: {
-        'Content-Type': 'application/json',
-      },
-      body: JSON.stringify({
-        query: query,
-      })
-    });
-    const chroma: ChromaResult = await response.json();
-    console.log('Chroma:', chroma);
-    setQueryEmbedding(chroma);
+    try {
+      const response = await fetch(`${connectionBase}/api/similarity/${sessionId}`, {
+        method: 'PUT',
+        headers: {
+          'Content-Type': 'application/json',
+        },
+        body: JSON.stringify({
+          query: query,
+        })
+      });
+      const chroma: ChromaResult = await response.json();
+      console.log('Chroma:', chroma);
+      setQueryEmbedding(chroma);
+    } catch (error) {
+      console.error('Error obtaining query similarity information:', error);
+      setSnack("Unable to obtain query similarity information.", "error");
+    };
   };

   if (!plotData || sessionId === undefined) return (
@@ -157,45 +250,65 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionB

   return (
     <>
+      <Card sx={{ display: 'flex', flexDirection: 'column', justifyContent: 'center', alignItems: 'center', mb: 1, pt: 0 }}>
+        <Typography variant="h6" sx={{ p: 1, pt: 0 }}>
+          Similarity Visualization via Uniform Manifold Approximation and Projection (UMAP)
+        </Typography>
+      </Card>
       <Box sx={{ display: 'flex', flexGrow: 1, justifyContent: 'center', alignItems: 'center' }}>
         <Plot
-          data={[
-            {
-              x: plotData.x,
-              y: plotData.y,
-              z: plotData.z,
-              mode: 'markers',
-              marker: {
-                size: plotData.sizes,
-                symbol: plotData.symbols,
-                color: plotData.colors,
-                opacity: 0.8,
-              },
-              text: plotData.text,
-              hoverinfo: 'text',
-              type: plotData.z?.length ? 'scatter3d' : 'scatter',
-            },
-          ]}
+          onHover={(event: any) => {
+            const point = event.points[0];
+            console.log('Point:', point);
+            const type = point.customdata.type;
+            const text = point.customdata.doc;
+            const emoji = emojiMap[type] || '❓';
+            setTooltip({
+              visible: true,
+              background: point['marker.color'],
+              color: getTextColorForBackground(point['marker.color']),
+              content: `${emoji} ${type.toUpperCase()}\n${text}`,
+            });
+          }}
+          data={[plotData.data]}
           useResizeHandler={true}
-          config={{ responsive: true }}
+          config={{
+            responsive: true,
+            displayModeBar: false,
+            displaylogo: false,
+            showSendToCloud: false,
+            staticPlot: false,
+          }}
           style={{ width: '100%', height: '100%' }}
-          layout={{
-            autosize: true,
-            title: 'Vector Store Visualization',
-            xaxis: { title: 'x' },
-            yaxis: { title: 'y' },
-            zaxis: { title: 'z' },
-            margin: { r: 20, b: 10, l: 10, t: 40 },
-          }}
+          layout={plotData.layout}
         />
       </Box>

+      <Card sx={{
+        display: 'flex',
+        flexDirection: 'column',
+        flexGrow: 1,
+        mt: 1,
+        p: 0.5,
+        color: tooltip?.color || '#2E2E2E',
+        background: tooltip?.background || '#FFFFFF',
+        whiteSpace: 'pre-line',
+        zIndex: 1000,
+        overflow: 'auto',
+        maxHeight: '20vh',
+        minHeight: '20vh',
+      }}
+      >
+        <Typography variant="body2" sx={{ p: 1, pt: 0 }}>
+          {tooltip?.content}
+        </Typography>
+      </Card>
       { queryEmbedding !== undefined &&
-        <Box sx={{ display: 'flex', flexDirection: 'column', p: 1 }}>
-          <Box sx={{ fontSize: '0.8rem', mb: 1 }}>
+        <Card sx={{ display: 'flex', flexDirection: 'column', justifyContent: 'center', alignItems: 'center', mt: 1, pb: 0 }}>
+          <Typography variant="h6" sx={{ p: 1, pt: 0 }}>
             Query: {queryEmbedding.query}
-          </Box>
-        </Box>
+          </Typography>
+        </Card>
       }

       <Box className="Query" sx={{ display: "flex", flexDirection: "row", p: 1 }}>

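The component now builds a single Plotly trace whose per-point arrays (size, color) carry the styling and whose customdata carries the hover payload. A rough Python analogue using plotly.graph_objects, the same engine react-plotly.js wraps (coordinates and documents here are toy values):

import plotly.graph_objects as go

trace = go.Scatter3d(
    x=[0.1, 0.5, 0.9], y=[0.2, 0.4, 0.8], z=[0.3, 0.6, 0.7],  # toy normalized coordinates
    mode='markers',
    marker=dict(
        size=[10, 5, 5],                          # query points render larger, per sizeMap
        color=['#D4A017', '#4A7A7D', '#1A2536'],  # colorMap values from the diff
        opacity=0.8,
    ),
    customdata=[{'doc': 'query text', 'type': 'query'},
                {'doc': 'resume chunk', 'type': 'resume'},
                {'doc': 'project chunk', 'type': 'projects'}],
)
fig = go.Figure(data=[trace], layout=go.Layout(hovermode='closest'))
fig.show()  # in the app, hover is handled in React via the onHover callback
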
src/requirements.txt
@@ -48,7 +48,6 @@ Deprecated==1.2.18
 diffusers==0.33.1
 dill==0.3.8
 distro==1.9.0
-dpcpp-cpp-rt==2025.0.4
 durationpy==0.9
 einops==0.8.1
 emoji==2.14.1
@@ -104,15 +103,6 @@ impi-devel==2021.14.1
 impi-rt==2021.14.1
 importlib_metadata==8.6.1
 importlib_resources==6.5.2
-intel-cmplr-lib-rt==2025.0.2
-intel-cmplr-lib-ur==2025.0.2
-intel-cmplr-lic-rt==2025.0.2
-intel-opencl-rt==2025.0.4
-intel-openmp==2025.0.4
-intel-pti==0.10.0
-intel-sycl-rt==2025.0.2
-intel_extension_for_pytorch==2.6.10+xpu
-ipex-llm @ file:///opt/wheels/ipex_llm-2.2.0.dev0-py3-none-any.whl#sha256=5023ff4dc9799838486b4d160d5f3dcd5f6d3bb9ac8a2c6cabaf90034b540ba3
 ipykernel==6.29.5
 ipython==9.1.0
 ipython_pygments_lexers==1.1.1
@@ -160,8 +150,6 @@ matplotlib==3.10.1
 matplotlib-inline==0.1.7
 mdurl==0.1.2
 mistune==3.1.3
-mkl==2025.0.1
-mkl-dpcpp==2025.0.1
 mmh3==5.1.0
 modal==0.74.4
 monotonic==1.6

@@ -25,9 +25,9 @@ try_import('requests')
 try_import('bs4', 'beautifulsoup4')
 try_import('fastapi')
 try_import('uvicorn')
-try_import('sklearn')
 try_import('numpy')
+try_import('umap')
+try_import('sklearn')

 import ollama
 import requests
@@ -37,8 +37,8 @@ from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, Red
 from fastapi.middleware.cors import CORSMiddleware
 import uvicorn
 import numpy as np
 #from sklearn.manifold import TSNE
+import umap
 from sklearn.preprocessing import MinMaxScaler

 from utils import (
     rag as Rag,
@@ -46,7 +46,7 @@ from utils import (
 )

 from tools import (
     DateTime,
     WeatherForecast,
     TickerValue,
     tools
@@ -381,8 +381,7 @@ class WebServer:
         )

-        @self.app.on_event("startup")
-        async def startup_event():
+        async def startup_event():
             # Start the file watcher
             self.observer, self.file_watcher = Rag.start_file_watcher(
                 llm=client,
@@ -475,8 +474,8 @@ class WebServer:
         #     "document_count": file_watcher.collection.count()
         # }

-        @self.app.put('/api/tsne/{context_id}')
-        async def put_tsne(context_id: str, request: Request):
+        @self.app.put('/api/umap/{context_id}')
+        async def put_umap(context_id: str, request: Request):
             if not self.file_watcher:
                 return

@@ -495,7 +494,7 @@ class WebServer:
             try:
                 result = self.file_watcher.collection.get(include=['embeddings', 'documents', 'metadatas'])
                 vectors = np.array(result['embeddings'])
-                umap_model = umap.UMAP(n_components=dimensions, random_state=42)
+                umap_model = umap.UMAP(n_components=dimensions, random_state=42) #, n_neighbors=15, min_dist=0.1)
                 embedding = umap_model.fit_transform(vectors)
                 context['umap_model'] = umap_model
                 result['embeddings'] = embedding.tolist()
@@ -531,9 +530,8 @@ class WebServer:
                 if not chroma_results:
                     return JSONResponse({"error": "No results found"}, status_code=404)
                 chroma_embedding = chroma_results["query_embedding"]
-                normalized = (chroma_embedding - chroma_embedding.min()) / (chroma_embedding.max() - chroma_embedding.min())
-                vector_embedding = context["umap_model"].transform([normalized])[0].tolist()
-                return JSONResponse({ **chroma_results, "query": query, "vector_embedding": vector_embedding })
+                umap_embedding = context["umap_model"].transform([chroma_embedding])[0].tolist()
+                return JSONResponse({ **chroma_results, "query": query, "vector_embedding": umap_embedding })

             except Exception as e:
                 logging.error(e)

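The endpoint rename tracks what the code actually does: the collection's embeddings are projected with UMAP, the fitted model is kept on the session context, and /api/similarity later maps the raw query embedding through the same model (the commit drops the extra min-max normalization before transform). A minimal sketch of that fit-then-transform pattern with umap-learn (the vector counts and dimensions below are stand-ins):

import numpy as np
import umap

corpus = np.random.rand(200, 384)                 # stand-in for the stored Chroma embeddings
reducer = umap.UMAP(n_components=3, random_state=42)
projected = reducer.fit_transform(corpus)         # what /api/umap returns per document

query_embedding = np.random.rand(384)             # stand-in for a query embedding
query_point = reducer.transform([query_embedding])[0]  # what /api/similarity appends
print(projected.shape, query_point.shape)         # (200, 3) (3,)
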
src/utils/rag.py (144 changed lines)
@@ -35,7 +35,7 @@ __all__ = [

 class ChromaDBFileWatcher(FileSystemEventHandler):
     def __init__(self, llm, watch_directory, loop, persist_directory=None, collection_name="documents",
-                 chunk_size=1000, chunk_overlap=200, recreate=False):
+                 chunk_size=500, chunk_overlap=200, recreate=False):
         self.llm = llm
         self.watch_directory = watch_directory
         self.persist_directory = persist_directory or defines.persist_directory
@@ -47,8 +47,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         # Path for storing file hash state
         self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")

+        # Flag to track if this is a new collection
+        self.is_new_collection = False
+
         # Initialize ChromaDB collection
-        self.collection = self._get_vector_collection(recreate=recreate)
+        self._collection = self._get_vector_collection(recreate=recreate)

         # Setup text splitter
         self.text_splitter = CharacterTextSplitter(
@@ -60,18 +63,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         self.file_hashes = self._load_hash_state()
         self.update_lock = asyncio.Lock()
         self.processing_files = set()

-        # Always scan for new/changed files at startup
-        self._update_file_hashes()

+    @property
+    def collection(self):
+        return self._collection
+
+    @collection.setter
+    def collection(self, value):
+        self._collection = value

     def _save_hash_state(self):
         """Save the current file hash state to disk."""
         try:
@@ -98,12 +94,20 @@ class ChromaDBFileWatcher(FileSystemEventHandler):

         return {}

-    def _update_file_hashes(self):
-        """Update file hashes by checking for new or modified files."""
+    async def scan_directory(self, process_all=False):
+        """
+        Scan directory for new, modified, or deleted files and update collection.
+
+        Args:
+            process_all: If True, process all files regardless of hash status
+        """
         # Check for new or modified files
         file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
         files_checked = 0
         files_changed = 0
+        files_processed = 0
+        files_to_process = []
+
+        logging.info(f"Starting directory scan. Found {len(file_paths)} total paths.")

         for file_path in file_paths:
             if os.path.isfile(file_path):
@@ -112,13 +116,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
                 if not current_hash:
                     continue

-                # If file is new or changed
-                if file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
+                # If file is new, changed, or we're processing all files
+                if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
                     self.file_hashes[file_path] = current_hash
                     files_changed += 1
-                    # Schedule an update for this file
-                    asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
-                    logging.info(f"File changed: {file_path}")
+                    files_to_process.append(file_path)
+                    logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")

+        logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")

         # Check for deleted files
         deleted_files = []
@@ -127,16 +131,28 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
                 deleted_files.append(file_path)
                 # Schedule removal
                 asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
+                # Don't block on result, just let it run
                 logging.info(f"File deleted: {file_path}")

         # Remove deleted files from hash state
         for file_path in deleted_files:
             del self.file_hashes[file_path]

-        logging.info(f"Checked {files_checked} files: {files_changed} new/changed, {len(deleted_files)} deleted")
+        # Process all discovered files using the existing loop
+        if files_to_process:
+            logging.info(f"Starting to process {len(files_to_process)} files")
+            for file_path in files_to_process:
+                async with self.update_lock:
+                    await self._update_document_in_collection(file_path)
+        else:
+            logging.info("No files to process")

         # Save the updated state
         self._save_hash_state()

+        logging.info(f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}")
+        return files_processed

     async def process_file_update(self, file_path):
         """Process a file update event."""
@@ -204,13 +220,24 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             settings=chromadb.Settings(anonymized_telemetry=False)
         )

-        # Check if the collection exists and delete it if recreate is True
-        if recreate and os.path.exists(self.persist_directory):
-            try:
-                chroma_client.delete_collection(name=self.collection_name)
-            except Exception as e:
-                logging.error(f"Failed to delete existing collection: {e}")
+        # Check if the collection exists
+        try:
+            chroma_client.get_collection(self.collection_name)
+            collection_exists = True
+        except:
+            collection_exists = False
+
+        # If collection doesn't exist, mark it as new
+        if not collection_exists:
+            self.is_new_collection = True
+            logging.info(f"Creating new collection: {self.collection_name}")
+
+        # Delete if recreate is True
+        if recreate and collection_exists:
+            chroma_client.delete_collection(name=self.collection_name)
+            self.is_new_collection = True
+            logging.info(f"Recreating collection: {self.collection_name}")

         return chroma_client.get_or_create_collection(
             name=self.collection_name,
             metadata={
@@ -246,14 +273,17 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         """Split documents into chunks using the text splitter."""
         return self.text_splitter.split_documents(docs)

-    def get_embedding(self, text):
+    def get_embedding(self, text, normalize=True):
         """Generate embeddings using Ollama."""
         response = self.llm.embeddings(
             model=defines.model,
             prompt=text,
-            options={"num_ctx": defines.max_context}
+            options={"num_ctx": self.chunk_size * 3} # No need to waste ctx space
         )
-        return self._normalize_embeddings(response["embedding"])
+        if normalize:
+            normalized = self._normalize_embeddings(response["embedding"])
+            return normalized
+        return response["embedding"]

     def add_embeddings_to_collection(self, chunks):
         """Add embeddings for chunks to the collection."""

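get_embedding now returns either the raw Ollama embedding or a normalized one. _normalize_embeddings itself is not shown in this diff; a plausible min-max implementation, consistent with the (x - min) / (max - min) scaling the server previously applied to query embeddings, would look like this (hypothetical helper, for illustration only):

import numpy as np

def normalize_embeddings(embedding):
    # Hypothetical helper: min-max scale a vector into [0, 1]
    v = np.asarray(embedding, dtype=np.float32)
    span = v.max() - v.min()
    if span == 0:
        return np.full_like(v, 0.5).tolist()  # degenerate flat vector
    return ((v - v.min()) / span).tolist()
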
@@ -293,18 +323,6 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             "metadatas": results["metadatas"][0],
         }

-    def _initialize_file_hashes(self):
-        """Initialize the hash dictionary for all files in the directory."""
-        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
-        for file_path in file_paths:
-            if os.path.isfile(file_path):
-                hash = self._get_file_hash(file_path)
-                if hash:
-                    self.file_hashes[file_path] = hash
-
-        # Save the initialized hash state
-        self._save_hash_state()
-
     def _get_file_hash(self, file_path):
         """Calculate MD5 hash of a file."""
         try:
@@ -358,6 +376,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
     async def _update_document_in_collection(self, file_path):
         """Update a document in the ChromaDB collection."""
         try:
+            logging.info(f"Updating document in collection: {file_path}")
             # Remove existing entries for this file
             existing_results = self.collection.get(where={"path": file_path})
             if existing_results and 'ids' in existing_results and existing_results['ids']:
@@ -387,25 +406,27 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         except Exception as e:
             logging.error(f"Error updating document in collection: {e}")

-    def initialize_collection(self):
+    async def initialize_collection(self):
         """Initialize the collection with all documents from the watch directory."""
-        documents = self.load_text_files()
-        logging.info(f"Documents loaded: {len(documents)}")
+        # Process all files regardless of hash state
+        num_processed = await self.scan_directory(process_all=True)

-        chunks = self.create_chunks_from_documents(documents)
-        self.add_embeddings_to_collection(chunks)
-        logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
+        logging.info(f"Vectorstore created with {self.collection.count()} documents")

-        # Display document types
-        doc_types = set(chunk.metadata['doc_type'] for chunk in chunks)
-        logging.info(f"Document types: {doc_types}")
-
-        return len(chunks)
+        # Show stats
+        try:
+            all_metadata = self.collection.get()['metadatas']
+            if all_metadata:
+                doc_types = set(m.get('doc_type', 'unknown') for m in all_metadata)
+                logging.info(f"Document types: {doc_types}")
+        except Exception as e:
+            logging.error(f"Error getting document types: {e}")
+
+        return num_processed

 # Function to start the file watcher
 def start_file_watcher(llm, watch_directory, persist_directory=None,
-                       collection_name="documents", recreate=False):
+                       collection_name="documents", initialize=False, recreate=False):
     """
     Start watching a directory for file changes.

@@ -414,6 +435,7 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
         watch_directory: Directory to watch for changes
         persist_directory: Directory to persist ChromaDB and hash state
         collection_name: Name of the ChromaDB collection
+        initialize: Whether to forcibly initialize the collection with all documents
         recreate: Whether to recreate the collection (will delete existing)
     """
     loop = asyncio.get_event_loop()
@@ -427,9 +449,17 @@ def start_file_watcher(llm, watch_directory, persist_directory=None,
         recreate=recreate
     )

-    # Initialize collection if it does not exist
-    if not os.path.exists(file_watcher.hash_state_path):
-        file_watcher.initialize_collection()
+    # Process all files if:
+    # 1. initialize=True was passed (explicit request to initialize)
+    # 2. This is a new collection (doesn't exist yet)
+    # 3. There's no hash state (first run)
+    if initialize or file_watcher.is_new_collection or not file_watcher.file_hashes:
+        logging.info("Initializing collection with all documents")
+        asyncio.run_coroutine_threadsafe(file_watcher.initialize_collection(), loop)
+    else:
+        # Only process new/changed files
+        logging.info("Scanning for new/changed documents")
+        asyncio.run_coroutine_threadsafe(file_watcher.scan_directory(), loop)

     # Start observer
     observer = Observer()
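A note on the concurrency model these changes lean on: watchdog fires its event callbacks on a worker thread, while the collection updates are coroutines, so work is handed to the asyncio loop with run_coroutine_threadsafe (as start_file_watcher does above). A self-contained sketch of that bridge, with a placeholder coroutine standing in for the real document update:

import asyncio
import threading

async def process_file_update(path: str):
    # Placeholder for the real coroutine that re-chunks and re-embeds a file
    print(f"processing {path}")

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

# Called from a watchdog event handler (a plain, non-async thread):
future = asyncio.run_coroutine_threadsafe(process_file_update("/docs/resume.md"), loop)
future.result(timeout=5)          # optionally block until the coroutine completes
loop.call_soon_threadsafe(loop.stop)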