From 1dd5ae115d85aec1bd83b5d8da627c8d13cbc6c0 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Thu, 17 Apr 2025 18:16:34 -0700 Subject: [PATCH] Updated deployed instance with visualization --- Dockerfile | 11 +- docker-compose.yml | 3 + frontend/deployed/.keep | 0 frontend/src/App.tsx | 32 +--- frontend/src/VectorVisualizer.tsx | 291 +++++++++++++++++++++--------- src/requirements.txt | 12 -- src/server.py | 20 +- src/utils/rag.py | 144 +++++++++------ 8 files changed, 314 insertions(+), 199 deletions(-) delete mode 100644 frontend/deployed/.keep diff --git a/Dockerfile b/Dockerfile index 42ab651..1631c90 100644 --- a/Dockerfile +++ b/Dockerfile @@ -257,7 +257,6 @@ FROM llm-base AS backstory COPY /src/requirements.txt /opt/backstory/src/requirements.txt RUN pip install -r /opt/backstory/src/requirements.txt -COPY /src/ /opt/backstory/src/ SHELL [ "/bin/bash", "-c" ] @@ -303,6 +302,8 @@ ENV SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 ENV SYCL_CACHE_PERSISTENT=1 ENV PATH=/opt/backstory:$PATH +COPY /src/ /opt/backstory/src/ + ENTRYPOINT [ "/entrypoint.sh" ] FROM ubuntu:oracular AS ollama @@ -406,6 +407,14 @@ RUN { \ && chmod +x /fetch-models.sh ENV PYTHONUNBUFFERED=1 +# Enable ext_intel_free_memory +ENV ZES_ENABLE_SYSMAN=1 +# Use all GPUs +ENV OLLAMA_NUM_GPU=999 +# Use immediate command lists +ENV SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 +# Use persistent cache +ENV SYCL_CACHE_PERSISTENT=1 VOLUME [" /root/.ollama" ] diff --git a/docker-compose.yml b/docker-compose.yml index e5151fc..f7f6ab9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,6 +4,7 @@ services: context: . dockerfile: Dockerfile target: backstory + container_name: backstory image: backstory restart: "no" env_file: @@ -70,6 +71,7 @@ services: dockerfile: Dockerfile target: ollama image: ollama + container_name: ollama restart: "always" env_file: - .env @@ -133,6 +135,7 @@ services: dockerfile: Dockerfile target: miniircd image: miniircd + container_name: miniircd restart: "no" env_file: - .env diff --git a/frontend/deployed/.keep b/frontend/deployed/.keep deleted file mode 100644 index e69de29..0000000 diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 7786204..9c8c2f7 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -36,7 +36,7 @@ import { Message, MessageList } from './Message'; import { MessageData } from './MessageMeta'; import { SeverityType } from './Snack'; import { ContextStatus } from './ContextStatus'; -import { VectorVisualizer, ResultData } from './VectorVisualizer'; +import { VectorVisualizer } from './VectorVisualizer'; import './App.css'; @@ -343,7 +343,6 @@ const App = () => { const [resume, setResume] = useState(undefined); const [facts, setFacts] = useState(undefined); const timerRef = useRef(null); - const [result, setResult] = useState(undefined); const startCountdown = (seconds: number) => { if (timerRef.current) clearInterval(timerRef.current); @@ -422,31 +421,6 @@ const App = () => { }); }, [systemInfo, setSystemInfo, connectionBase, setSnack, sessionId]) - // Get the collection to visualize - useEffect(() => { - if (result !== undefined || sessionId === undefined) { - return; - } - const fetchCollection = async () => { - try { - const response = await fetch(connectionBase + `/api/tsne/${sessionId}`, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ dimensions: 3 }), - }); - const data = await response.json(); - setResult(data); - } catch (error) { - console.error('Error obtaining collection 
information:', error); - setSnack("Unable to obtain collection information.", "error"); - }; - }; - - fetchCollection(); - }, [result, setResult, connectionBase, setSnack, sessionId]) - // Get the About markdown useEffect(() => { if (about !== "") { @@ -1101,7 +1075,7 @@ const App = () => { onChange={handleTabChange} aria-label="Backstory navigation"> } iconPosition="start" /> - + } @@ -1215,7 +1189,7 @@ const App = () => { - {result !== undefined && } + diff --git a/frontend/src/VectorVisualizer.tsx b/frontend/src/VectorVisualizer.tsx index 0c2cc1c..a84dc68 100644 --- a/frontend/src/VectorVisualizer.tsx +++ b/frontend/src/VectorVisualizer.tsx @@ -1,13 +1,17 @@ import React, { useEffect, useState } from 'react'; import Box from '@mui/material/Box'; +import Card from '@mui/material/Card'; +import Typography from '@mui/material/Typography'; import Plot from 'react-plotly.js'; import TextField from '@mui/material/TextField'; import Tooltip from '@mui/material/Tooltip'; import Button from '@mui/material/Button'; import SendIcon from '@mui/icons-material/Send'; +import { SeverityType } from './Snack'; + interface Metadata { - type?: string; + doc_type?: string; [key: string]: any; } @@ -18,19 +22,23 @@ interface ResultData { } interface PlotData { - x: number[]; - y: number[]; - z?: number[]; - colors: string[]; - text: string[]; - sizes: number[]; - symbols: string[]; + data: { + x: number[]; + y: number[]; + z?: number[]; + colors: string[]; + text: string[]; + sizes: number[]; + symbols: string[]; + doc_types: string[]; + }; + layout: Partial; } interface VectorVisualizerProps { - result: ResultData; connectionBase: string; sessionId?: string; + setSnack: (message: string, severity: SeverityType) => void; } interface ChromaResult { @@ -43,10 +51,82 @@ interface ChromaResult { vector_embedding?: number[]; } -const VectorVisualizer: React.FC = ({ result, connectionBase, sessionId }) => { +const normalizeDimension = (arr: number[]): number[] => { + const min = Math.min(...arr); + const max = Math.max(...arr); + const range = max - min; + if (range === 0) return arr.map(() => 0.5); // flat dimension + return arr.map(v => (v - min) / range); +}; + +const getTextColorForBackground = (bgColor: string): string => { + const r = parseInt(bgColor.slice(1, 3), 16); + const g = parseInt(bgColor.slice(3, 5), 16); + const b = parseInt(bgColor.slice(5, 7), 16); + const luminance = 0.299 * r + 0.587 * g + 0.114 * b; + return luminance > 186 ? 
'#2E2E2E' : '#FFFFFF'; // Charcoal or white from your theme
+};
+
+const emojiMap: Record<string, string> = {
+  query: '🔍',
+  resume: '📄',
+  projects: '📁',
+  news: '📰',
+};
+
+const colorMap: Record<string, string> = {
+  query: '#D4A017', // Golden Ochre — strong highlight
+  resume: '#4A7A7D', // Dusty Teal — secondary theme color
+  projects: '#1A2536', // Midnight Blue — rich and deep
+  news: '#D3CDBF', // Warm Gray — soft and neutral
+};
+
+const sizeMap: Record<string, number> = {
+  'query': 10,
+};
+
+const symbolMap: Record<string, string> = {
+  'query': 'circle',
+};
+
+const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ setSnack, connectionBase, sessionId }) => {
   const [plotData, setPlotData] = useState<PlotData | null>(null);
   const [query, setQuery] = useState('');
   const [queryEmbedding, setQueryEmbedding] = useState<ChromaResult | undefined>(undefined);
+  const [result, setResult] = useState<ResultData | undefined>(undefined);
+  const [tooltip, setTooltip] = useState<{
+    visible: boolean,
+    // x: number,
+    // y: number,
+    content: string,
+    background: string,
+    color: string,
+  } | null>(null);
+
+  // Get the collection to visualize
+  useEffect(() => {
+    if (result !== undefined || sessionId === undefined) {
+      return;
+    }
+    const fetchCollection = async () => {
+      try {
+        const response = await fetch(connectionBase + `/api/umap/${sessionId}`, {
+          method: 'PUT',
+          headers: {
+            'Content-Type': 'application/json',
+          },
+          body: JSON.stringify({ dimensions: 3 }),
+        });
+        const data = await response.json();
+        setResult(data);
+      } catch (error) {
+        console.error('Error obtaining collection information:', error);
+        setSnack("Unable to obtain collection information.", "error");
+      };
+    };
+
+    fetchCollection();
+  }, [result, setResult, connectionBase, setSnack, sessionId])

   useEffect(() => {
     if (!result || !result.embeddings) return;
@@ -57,7 +137,7 @@ const VectorVisualizer: React.FC = ({ result, connectionB
     const metadatas = [...result.metadatas || []];

     if (queryEmbedding !== undefined && queryEmbedding.vector_embedding !== undefined) {
-      metadatas.unshift({ type: 'query' });
+      metadatas.unshift({ doc_type: 'query' });
       documents.unshift(queryEmbedding.query || '');
       vectors.unshift(queryEmbedding.vector_embedding);
     }
@@ -67,19 +147,8 @@ const VectorVisualizer: React.FC = ({ result, connectionB
       console.error('Vectors are neither 2D nor 3D');
       return;
     }
-    console.log('Vectors:', vectors);

-    // Placeholder color assignment
-    const colorMap: Record<string, string> = {
-      'query': '#00ff00',
-    };
-    const sizeMap: Record<string, number> = {
-      'query': 10,
-    };
-    const symbolMap: Record<string, string> = {
-      'query': 'circle',
-    };
-    const doc_types = metadatas.map(m => m.type || 'unknown');
+    const doc_types = metadatas.map(m => m.doc_type || 'unknown');
     const sizes = doc_types.map(type => {
       if (!sizeMap[type]) {
         sizeMap[type] = 5;
       }
@@ -98,31 +167,51 @@ const VectorVisualizer: React.FC = ({ result, connectionB
       }
       return colorMap[type];
     });
+    const customdata = metadatas.map((m, index) => {
+      return { doc: documents[index], type: m.doc_type || 'unknown' };
+    });
+    const x = normalizeDimension(vectors.map((v: number[]) => v[0]));
+    const y = normalizeDimension(vectors.map((v: number[]) => v[1]));
+    const z = is3D ? 
normalizeDimension(vectors.map((v: number[]) => v[2])) : undefined + + const layout: Partial = { + autosize: true, + paper_bgcolor: '#FFFFFF', // white + plot_bgcolor: '#FFFFFF', // white plot background + font: { + family: 'Roboto, sans-serif', + color: '#2E2E2E', // charcoal black + }, + hovermode: 'closest', + scene: { + bgcolor: '#FFFFFF', // 3D plot background + zaxis: { title: 'Z', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' }, + }, + xaxis: { title: 'X', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' }, + yaxis: { title: 'Y', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' }, + margin: { r: 20, b: 10, l: 10, t: 40 }, + }; + + const data: any = { + x: x, + y: y, + mode: 'markers', + marker: { + size: sizes, + symbol: symbols, + color: colors, + opacity: 0.8, + }, + customdata: customdata, + type: z?.length ? 'scatter3d' : 'scatter', + }; - const x = vectors.map((v: number[]) => v[0]); - const y = vectors.map((v: number[]) => v[1]); - const text = documents.map((doc, i) => `Type: ${doc_types[i]}
Text: ${doc.slice(0, 100)}...`); if (is3D) { - const z = vectors.map((v: number[]) => v[2]); - setPlotData({ - x: x, - y: y, - z: z, - colors: colors, - sizes: sizes, - symbols: symbols, - text: text - }); - } else { - setPlotData({ - x: x, - y: y, - colors: colors, - sizes: sizes, - symbols: symbols, - text: text - }); + data.z = z; } + + setPlotData({ data, layout }); + }, [result, queryEmbedding]); const handleKeyPress = (event: any) => { @@ -134,19 +223,23 @@ const VectorVisualizer: React.FC = ({ result, connectionB const sendQuery = async (query: string) => { if (!query.trim()) return; setQuery(''); - - const response = await fetch(`${connectionBase}/api/similarity/${sessionId}`, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - query: query, - }) - }); - const chroma: ChromaResult = await response.json(); - console.log('Chroma:', chroma); - setQueryEmbedding(chroma); + try { + const response = await fetch(`${connectionBase}/api/similarity/${sessionId}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + query: query, + }) + }); + const chroma: ChromaResult = await response.json(); + console.log('Chroma:', chroma); + setQueryEmbedding(chroma); + } catch (error) { + console.error('Error obtaining query similarity information:', error); + setSnack("Unable to obtain query similarity information.", "error"); + }; }; if (!plotData || sessionId === undefined) return ( @@ -157,45 +250,65 @@ const VectorVisualizer: React.FC = ({ result, connectionB return ( <> + + + Similarity Visualization via Uniform Manifold Approximation and Projection (UMAP) + + { + const point = event.points[0]; + console.log('Point:', point); + const type = point.customdata.type; + const text = point.customdata.doc; + const emoji = emojiMap[type] || '❓'; + setTooltip({ + visible: true, + background: point['marker.color'], + color: getTextColorForBackground(point['marker.color']), + content: `${emoji} ${type.toUpperCase()}\n${text}`, + }); + }} + + data={[plotData.data]} useResizeHandler={true} - config={{ responsive: true }} + config={{ + responsive: true, + displayModeBar: false, + displaylogo: false, + showSendToCloud: false, + staticPlot: false, + }} style={{ width: '100%', height: '100%' }} - layout={{ - autosize: true, - title: 'Vector Store Visualization', - xaxis: { title: 'x' }, - yaxis: { title: 'y' }, - zaxis: { title: 'z' }, - margin: { r: 20, b: 10, l: 10, t: 40 }, - }} + layout={plotData.layout} /> - + + + {tooltip?.content} + + { queryEmbedding !== undefined && - - + + Query: {queryEmbedding.query} - - + + } diff --git a/src/requirements.txt b/src/requirements.txt index 0b64e63..2a5f747 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -48,7 +48,6 @@ Deprecated==1.2.18 diffusers==0.33.1 dill==0.3.8 distro==1.9.0 -dpcpp-cpp-rt==2025.0.4 durationpy==0.9 einops==0.8.1 emoji==2.14.1 @@ -104,15 +103,6 @@ impi-devel==2021.14.1 impi-rt==2021.14.1 importlib_metadata==8.6.1 importlib_resources==6.5.2 -intel-cmplr-lib-rt==2025.0.2 -intel-cmplr-lib-ur==2025.0.2 -intel-cmplr-lic-rt==2025.0.2 -intel-opencl-rt==2025.0.4 -intel-openmp==2025.0.4 -intel-pti==0.10.0 -intel-sycl-rt==2025.0.2 -intel_extension_for_pytorch==2.6.10+xpu -ipex-llm @ file:///opt/wheels/ipex_llm-2.2.0.dev0-py3-none-any.whl#sha256=5023ff4dc9799838486b4d160d5f3dcd5f6d3bb9ac8a2c6cabaf90034b540ba3 ipykernel==6.29.5 ipython==9.1.0 ipython_pygments_lexers==1.1.1 @@ -160,8 +150,6 @@ matplotlib==3.10.1 matplotlib-inline==0.1.7 mdurl==0.1.2 
mistune==3.1.3 -mkl==2025.0.1 -mkl-dpcpp==2025.0.1 mmh3==5.1.0 modal==0.74.4 monotonic==1.6 diff --git a/src/server.py b/src/server.py index c6b2b15..4911fc3 100644 --- a/src/server.py +++ b/src/server.py @@ -25,9 +25,9 @@ try_import('requests') try_import('bs4', 'beautifulsoup4') try_import('fastapi') try_import('uvicorn') -try_import('sklearn') try_import('numpy') try_import('umap') +try_import('sklearn') import ollama import requests @@ -37,8 +37,8 @@ from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, Red from fastapi.middleware.cors import CORSMiddleware import uvicorn import numpy as np -#from sklearn.manifold import TSNE import umap +from sklearn.preprocessing import MinMaxScaler from utils import ( rag as Rag, @@ -46,7 +46,7 @@ from utils import ( ) from tools import ( - DateTime, + DateTime, WeatherForecast, TickerValue, tools @@ -381,8 +381,7 @@ class WebServer: ) @self.app.on_event("startup") - async def startup_event(): - + async def startup_event(): # Start the file watcher self.observer, self.file_watcher = Rag.start_file_watcher( llm=client, @@ -475,8 +474,8 @@ class WebServer: # "document_count": file_watcher.collection.count() # } - @self.app.put('/api/tsne/{context_id}') - async def put_tsne(context_id: str, request: Request): + @self.app.put('/api/umap/{context_id}') + async def put_umap(context_id: str, request: Request): if not self.file_watcher: return @@ -495,7 +494,7 @@ class WebServer: try: result = self.file_watcher.collection.get(include=['embeddings', 'documents', 'metadatas']) vectors = np.array(result['embeddings']) - umap_model = umap.UMAP(n_components=dimensions, random_state=42) + umap_model = umap.UMAP(n_components=dimensions, random_state=42) #, n_neighbors=15, min_dist=0.1) embedding = umap_model.fit_transform(vectors) context['umap_model'] = umap_model result['embeddings'] = embedding.tolist() @@ -531,9 +530,8 @@ class WebServer: if not chroma_results: return JSONResponse({"error": "No results found"}, status_code=404) chroma_embedding = chroma_results["query_embedding"] - normalized = (chroma_embedding - chroma_embedding.min()) / (chroma_embedding.max() - chroma_embedding.min()) - vector_embedding = context["umap_model"].transform([normalized])[0].tolist() - return JSONResponse({ **chroma_results, "query": query, "vector_embedding": vector_embedding }) + umap_embedding = context["umap_model"].transform([chroma_embedding])[0].tolist() + return JSONResponse({ **chroma_results, "query": query, "vector_embedding": umap_embedding }) except Exception as e: logging.error(e) diff --git a/src/utils/rag.py b/src/utils/rag.py index d569b4c..bbb6e43 100644 --- a/src/utils/rag.py +++ b/src/utils/rag.py @@ -35,7 +35,7 @@ __all__ = [ class ChromaDBFileWatcher(FileSystemEventHandler): def __init__(self, llm, watch_directory, loop, persist_directory=None, collection_name="documents", - chunk_size=1000, chunk_overlap=200, recreate=False): + chunk_size=500, chunk_overlap=200, recreate=False): self.llm = llm self.watch_directory = watch_directory self.persist_directory = persist_directory or defines.persist_directory @@ -47,8 +47,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler): # Path for storing file hash state self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json") + # Flag to track if this is a new collection + self.is_new_collection = False + # Initialize ChromaDB collection - self.collection = self._get_vector_collection(recreate=recreate) + self._collection = 
self._get_vector_collection(recreate=recreate)

         # Setup text splitter
         self.text_splitter = CharacterTextSplitter(
@@ -60,18 +63,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         self.file_hashes = self._load_hash_state()
         self.update_lock = asyncio.Lock()
         self.processing_files = set()
-
-        # Always scan for new/changed files at startup
-        self._update_file_hashes()

     @property
     def collection(self):
         return self._collection

-    @collection.setter
-    def collection(self, value):
-        self._collection = value
-
     def _save_hash_state(self):
         """Save the current file hash state to disk."""
         try:
@@ -98,12 +94,20 @@ class ChromaDBFileWatcher(FileSystemEventHandler):

         return {}

-    def _update_file_hashes(self):
-        """Update file hashes by checking for new or modified files."""
+    async def scan_directory(self, process_all=False):
+        """
+        Scan directory for new, modified, or deleted files and update collection.
+
+        Args:
+            process_all: If True, process all files regardless of hash status
+        """
         # Check for new or modified files
         file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
         files_checked = 0
-        files_changed = 0
+        files_processed = 0
+        files_to_process = []
+
+        logging.info(f"Starting directory scan. Found {len(file_paths)} total paths.")

         for file_path in file_paths:
             if os.path.isfile(file_path):
@@ -112,13 +116,13 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
                 if not current_hash:
                     continue

-                # If file is new or changed
-                if file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
+                # If file is new, changed, or we're processing all files
+                if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
                     self.file_hashes[file_path] = current_hash
-                    files_changed += 1
-                    # Schedule an update for this file
-                    asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
-                    logging.info(f"File changed: {file_path}")
+                    files_to_process.append(file_path)
+                    logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
+
+        logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")

         # Check for deleted files
         deleted_files = []
@@ -127,16 +131,28 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
                 deleted_files.append(file_path)
                 # Schedule removal
                 asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
+                # Don't block on result, just let it run
                 logging.info(f"File deleted: {file_path}")

         # Remove deleted files from hash state
         for file_path in deleted_files:
             del self.file_hashes[file_path]
+
+        # Process all discovered files sequentially, serialized by the update lock
+        if files_to_process:
+            logging.info(f"Starting to process {len(files_to_process)} files")
-        logging.info(f"Checked {files_checked} files: {files_changed} new/changed, {len(deleted_files)} deleted")
+            for file_path in files_to_process:
+                async with self.update_lock:
+                    await self._update_document_in_collection(file_path)
+                    files_processed += 1
+        else:
+            logging.info("No files to process")

         # Save the updated state
         self._save_hash_state()
+
+        logging.info(f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}")
+        return files_processed

     async def process_file_update(self, file_path):
         """Process a file update event."""
@@ -204,13 +220,24 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             settings=chromadb.Settings(anonymized_telemetry=False)
         )

-        # Check if the collection exists and delete it if recreate is True
-        if recreate and 
os.path.exists(self.persist_directory):
-            try:
-                chroma_client.delete_collection(name=self.collection_name)
-            except Exception as e:
-                logging.error(f"Failed to delete existing collection: {e}")
-
+        # Check if the collection exists
+        try:
+            chroma_client.get_collection(self.collection_name)
+            collection_exists = True
+        except Exception:
+            collection_exists = False
+
+        # If collection doesn't exist, mark it as new
+        if not collection_exists:
+            self.is_new_collection = True
+            logging.info(f"Creating new collection: {self.collection_name}")
+
+        # Delete if recreate is True
+        if recreate and collection_exists:
+            chroma_client.delete_collection(name=self.collection_name)
+            self.is_new_collection = True
+            logging.info(f"Recreating collection: {self.collection_name}")
+
         return chroma_client.get_or_create_collection(
             name=self.collection_name,
             metadata={
@@ -246,14 +273,17 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         """Split documents into chunks using the text splitter."""
         return self.text_splitter.split_documents(docs)

-    def get_embedding(self, text):
+    def get_embedding(self, text, normalize=True):
         """Generate embeddings using Ollama."""
         response = self.llm.embeddings(
             model=defines.model,
             prompt=text,
-            options={"num_ctx": defines.max_context}
+            options={"num_ctx": self.chunk_size * 3}  # No need to waste context space
         )
-        return self._normalize_embeddings(response["embedding"])
+        if normalize:
+            normalized = self._normalize_embeddings(response["embedding"])
+            return normalized
+        return response["embedding"]

     def add_embeddings_to_collection(self, chunks):
         """Add embeddings for chunks to the collection."""
@@ -293,18 +323,6 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
             "metadatas": results["metadatas"][0],
         }

-    def _initialize_file_hashes(self):
-        """Initialize the hash dictionary for all files in the directory."""
-        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
-        for file_path in file_paths:
-            if os.path.isfile(file_path):
-                hash = self._get_file_hash(file_path)
-                if hash:
-                    self.file_hashes[file_path] = hash
-
-        # Save the initialized hash state
-        self._save_hash_state()
-
     def _get_file_hash(self, file_path):
         """Calculate MD5 hash of a file."""
         try:
@@ -358,6 +376,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
     async def _update_document_in_collection(self, file_path):
         """Update a document in the ChromaDB collection."""
         try:
+            logging.info(f"Updating document in collection: {file_path}")
             # Remove existing entries for this file
             existing_results = self.collection.get(where={"path": file_path})
             if existing_results and 'ids' in existing_results and existing_results['ids']:
@@ -387,25 +406,27 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         except Exception as e:
             logging.error(f"Error updating document in collection: {e}")

-    def initialize_collection(self):
+    async def initialize_collection(self):
         """Initialize the collection with all documents from the watch directory."""
-        documents = self.load_text_files()
-        logging.info(f"Documents loaded: {len(documents)}")
+        # Process all files regardless of hash state
+        num_processed = await self.scan_directory(process_all=True)

-        chunks = self.create_chunks_from_documents(documents)
-        self.add_embeddings_to_collection(chunks)
+        logging.info(f"Vectorstore initialized with {self.collection.count()} documents")

-        logging.info(f"Vectorstore created with {self.collection.count()} documents")
-
-        # Display document types
-        doc_types = set(chunk.metadata['doc_type'] for chunk in chunks)
-        logging.info(f"Document 
types: {doc_types}") - - return len(chunks) + # Show stats + try: + all_metadata = self.collection.get()['metadatas'] + if all_metadata: + doc_types = set(m.get('doc_type', 'unknown') for m in all_metadata) + logging.info(f"Document types: {doc_types}") + except Exception as e: + logging.error(f"Error getting document types: {e}") + + return num_processed # Function to start the file watcher def start_file_watcher(llm, watch_directory, persist_directory=None, - collection_name="documents", recreate=False): + collection_name="documents", initialize=False, recreate=False): """ Start watching a directory for file changes. @@ -414,6 +435,7 @@ def start_file_watcher(llm, watch_directory, persist_directory=None, watch_directory: Directory to watch for changes persist_directory: Directory to persist ChromaDB and hash state collection_name: Name of the ChromaDB collection + initialize: Whether to forcibly initialize the collection with all documents recreate: Whether to recreate the collection (will delete existing) """ loop = asyncio.get_event_loop() @@ -427,9 +449,17 @@ def start_file_watcher(llm, watch_directory, persist_directory=None, recreate=recreate ) - # Initialize collection if it does not exist - if not os.path.exists(file_watcher.hash_state_path): - file_watcher.initialize_collection() + # Process all files if: + # 1. initialize=True was passed (explicit request to initialize) + # 2. This is a new collection (doesn't exist yet) + # 3. There's no hash state (first run) + if initialize or file_watcher.is_new_collection or not file_watcher.file_hashes: + logging.info("Initializing collection with all documents") + asyncio.run_coroutine_threadsafe(file_watcher.initialize_collection(), loop) + else: + # Only process new/changed files + logging.info("Scanning for new/changed documents") + asyncio.run_coroutine_threadsafe(file_watcher.scan_directory(), loop) # Start observer observer = Observer()