from __future__ import annotations

from pydantic import BaseModel  # type: ignore
from typing import List, Optional, Dict, Any
import os
import glob
from pathlib import Path
import hashlib
import asyncio
import logging
import json
import numpy as np  # type: ignore
import traceback

import chromadb  # type: ignore
from watchdog.observers import Observer  # type: ignore
from watchdog.events import FileSystemEventHandler  # type: ignore
import umap  # type: ignore
from markitdown import MarkItDown  # type: ignore
from chromadb.api.models.Collection import Collection  # type: ignore

from .markdown_chunker import (
    MarkdownChunker,
    Chunk,
)

# When imported as a module, use relative imports
import defines
from database.manager import RedisDatabase
from models import ChromaDBGetResponse

__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]

DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100


class RagEntry(BaseModel):
    name: str
    description: str = ""
    enabled: bool = True


class ChromaDBFileWatcher(FileSystemEventHandler):
    def __init__(
        self,
        llm,
        watch_directory,
        loop,
        persist_directory,
        collection_name,
        database: RedisDatabase,
        user_id: str,
        chunk_size=DEFAULT_CHUNK_SIZE,
        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
        recreate=False,
    ):
        """Set up the watcher: ChromaDB collection, markdown chunker, UMAP models, and hash state."""
        self.llm = llm
        self.database = database
        self.user_id = user_id
        self.watch_directory = watch_directory
        self.persist_directory = persist_directory or defines.persist_directory
        self.collection_name = collection_name
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.loop = loop
        self._umap_collection: ChromaDBGetResponse | None = None
        self._umap_embedding_2d: np.ndarray = np.array([])
        self._umap_embedding_3d: np.ndarray = np.array([])
        self._umap_model_2d: Optional[umap.UMAP] = None
        self._umap_model_3d: Optional[umap.UMAP] = None
        self.md = MarkItDown(enable_plugins=False)  # Set to True to enable plugins
        self.processing_lock = asyncio.Lock()

        # Debounce bookkeeping for rapid successive file events
        self.processing_debounce = {}
        self.debounce_delay = 1.0  # seconds

        # self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Path for storing file hash state
        self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")

        # Flag to track if this is a new collection
        self.is_new_collection = False

        # Initialize ChromaDB collection
        self._collection: Collection = self._get_vector_collection(recreate=recreate)
        self._markdown_chunker = MarkdownChunker()
        self._update_umaps()

        # Track file hashes and processing state
        self.file_hashes = self._load_hash_state()
        self.update_lock = asyncio.Lock()
        self.processing_files = set()

    @property
    def collection(self) -> Collection:
        return self._collection

    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        if not self._umap_collection:
            raise ValueError("initialize_collection has not been called")
        return self._umap_collection

    @property
    def umap_embedding_2d(self) -> np.ndarray:
        return self._umap_embedding_2d

    @property
    def umap_embedding_3d(self) -> np.ndarray:
        return self._umap_embedding_3d

    @property
    def umap_model_2d(self):
        return self._umap_model_2d

    @property
    def umap_model_3d(self):
        return self._umap_model_3d

    def _markitdown(self, document: str, markdown: Path):
        """Convert a document to Markdown using MarkItDown and write it alongside the source."""
        logging.info(f"Converting {document} to {markdown}")
        try:
            result = self.md.convert(document)
            markdown.write_text(result.text_content)
        except Exception as e:
            logging.error(f"Error converting via MarkItDown: {e}")

    def _save_hash_state(self):
        """Save the current file hash state to disk."""
        try:
            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)

            with open(self.hash_state_path, "w") as f:
                json.dump(self.file_hashes, f)

            logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
        except Exception as e:
            logging.error(f"Error saving hash state: {e}")

    def _load_hash_state(self):
        """Load the file hash state from disk."""
        if os.path.exists(self.hash_state_path):
            try:
                with open(self.hash_state_path, "r") as f:
                    hash_state = json.load(f)
                logging.info(f"Loaded hash state with {len(hash_state)} entries")
                return hash_state
            except Exception as e:
                logging.error(f"Error loading hash state: {e}")

        return {}

    async def scan_directory(self, process_all=False):
        """
        Scan directory for new, modified, or deleted files and update collection.

        Args:
            process_all: If True, process all files regardless of hash status
        """
        # Check for new or modified files
        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
        files_checked = 0
        files_processed = 0
        files_to_process = []

        logging.info(f"Starting directory scan. Found {len(file_paths)} total paths.")

        for file_path in file_paths:
            if os.path.isfile(file_path):
                # Do not put the resume in RAG, as it is provided with all queries.
                # if file_path == defines.resume_doc:
                #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
                #     continue
                files_checked += 1
                current_hash = self._get_file_hash(file_path)
                if not current_hash:
                    logging.info(f"Unable to obtain hash of {file_path}")
                    continue

                # If file is new, changed, or we're processing all files
                if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
                    self.file_hashes[file_path] = current_hash
                    files_to_process.append(file_path)
                    logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")

        logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")

        # Check for deleted files
        deleted_files = []
        for file_path in self.file_hashes:
            if not os.path.exists(file_path):
                deleted_files.append(file_path)
                # Schedule removal on the event loop; don't block on the result
                asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
                logging.info(f"File deleted: {file_path}")

        # Remove deleted files from hash state
        for file_path in deleted_files:
            del self.file_hashes[file_path]

        # Process all discovered files sequentially, serialized by update_lock
        if files_to_process:
            logging.info(f"Starting to process {len(files_to_process)} files")

            for file_path in files_to_process:
                async with self.update_lock:
                    files_processed += 1
                    await self._update_document_in_collection(file_path)
        else:
            logging.info("No files to process")

        # Save the updated state
        self._save_hash_state()

        logging.info(
            f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
        )
        return files_processed

    async def process_file_update(self, file_path):
        """Process a file update event, debounced to debounce_delay."""

        # Debouncing logic
        current_time = asyncio.get_event_loop().time()
        if file_path in self.processing_debounce:
            time_since_last = current_time - self.processing_debounce[file_path]
            if time_since_last < self.debounce_delay:
                logging.info(f"Debouncing {file_path} (last processed {time_since_last:.2f}s ago)")
                return

        self.processing_debounce[file_path] = current_time

        # Use a lock to make the check-and-add atomic
        async with self.processing_lock:
            if file_path in self.processing_files:
                logging.info(f"{file_path} already in queue. Not adding.")
                return

            logging.info(f"{file_path} not in queue. Adding.")
            self.processing_files.add(file_path)

        try:
            # Wait a moment to ensure the file write is complete
            await asyncio.sleep(0.5)

            # Check if content changed via hash
            current_hash = self._get_file_hash(file_path)
            if not current_hash:
                return

            # Use the update_lock to make hash check and update atomic
            async with self.update_lock:
                if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
                    logging.info(f"Hash has not changed for {file_path}")
                    return

                # Update file hash BEFORE processing to prevent race conditions
                self.file_hashes[file_path] = current_hash

                # Process and update the file in ChromaDB
                await self._update_document_in_collection(file_path)

                # Save the hash state after successful update
                self._save_hash_state()

            # Re-fit the UMAP for the new content
            self._update_umaps()

        except Exception as e:
            logging.error(f"Error processing update for {file_path}: {e}")
        finally:
            self.processing_files.discard(file_path)

    async def remove_file_from_collection(self, file_path):
        """Remove all chunks related to a deleted file."""
        async with self.update_lock:
            try:
                # Find all documents with the specified path
                results = self.collection.get(where={"path": file_path})

                if results and "ids" in results and results["ids"]:
                    self.collection.delete(ids=results["ids"])
                    await self.database.update_user_rag_timestamp(self.user_id)
                    logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")

                # Remove from hash dictionary
                if file_path in self.file_hashes:
                    del self.file_hashes[file_path]
                    # Save the updated hash state
                    self._save_hash_state()

            except Exception as e:
                logging.error(f"Error removing file from collection: {e}")

    def _update_umaps(self):
        """Re-fit the 2D and 3D UMAP models from the collection's current embeddings."""
        # Update the UMAP embeddings
        self._umap_collection = ChromaDBGetResponse.model_validate(
            self._collection.get(include=["embeddings", "documents", "metadatas"])
        )
        if not self._umap_collection or not len(self._umap_collection.embeddings):
            logging.warning("⚠️ No embeddings found in the collection.")
            return

        logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
        vectors = np.array(self._umap_collection.embeddings)
        self._umap_model_2d = umap.UMAP(
            n_components=2,
            random_state=8911,
            metric="cosine",
            n_neighbors=round(min(30, len(self._umap_collection.embeddings) * 0.5)),
            min_dist=0.1,
        )
        self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)  # type: ignore
        # logging.info(
        #     f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
        # )  # Should be 2

        logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
        self._umap_model_3d = umap.UMAP(
            n_components=3,
            random_state=8911,
            metric="cosine",
            n_neighbors=round(min(30, len(self._umap_collection.embeddings) * 0.5)),
            min_dist=0.01,
        )
        self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)  # type: ignore
        # logging.info(
        #     f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
        # )  # Should be 3

    def _get_vector_collection(self, recreate=False) -> Collection:
        """Get or create a ChromaDB collection."""
        # Create the directory if it doesn't exist
        if not os.path.exists(self.persist_directory):
            os.makedirs(self.persist_directory)

        # Initialize ChromaDB client
        chroma_client = chromadb.PersistentClient(
            path=self.persist_directory,
            settings=chromadb.Settings(anonymized_telemetry=False),  # type: ignore
        )

        # Check if the collection exists
        try:
            chroma_client.get_collection(self.collection_name)
            collection_exists = True
        except Exception:
            collection_exists = False

        # If collection doesn't exist, mark it as new
        if not collection_exists:
            self.is_new_collection = True
            logging.info(f"Creating new collection: {self.collection_name}")

        # Delete if recreate is True
        if recreate and collection_exists:
            chroma_client.delete_collection(name=self.collection_name)
            self.is_new_collection = True
            logging.info(f"Recreating collection: {self.collection_name}")

        return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})

    async def get_embedding(self, text: str) -> np.ndarray:
        """Generate and normalize an embedding for the given text."""

        # Get embedding
        try:
            response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text)
            embedding = np.array(response.get_single_embedding())
        except Exception as e:
            logging.error(traceback.format_exc())
            logging.error(f"Failed to get embedding: {e}")
            raise

        # Log diagnostics
        logging.debug(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")

        # Check for invalid embeddings
        if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
            logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
            raise ValueError("Invalid embedding returned from Ollama.")

        # Check normalization
        norm = np.linalg.norm(embedding)
        is_normalized = np.allclose(norm, 1.0, atol=1e-3)
        logging.debug(f"Embedding norm: {norm}, Is normalized: {is_normalized}")

        # Normalize if needed
        if not is_normalized:
            embedding = embedding / norm
            logging.debug("Embedding normalized manually.")

        return embedding

    async def _add_embeddings_to_collection(self, chunks: List[Chunk]):
        """Add embeddings for chunks to the collection."""

        for i, chunk in enumerate(chunks):
            text = chunk["text"]
            metadata = chunk["metadata"]

            # Generate a more unique ID based on content and metadata
            path_hash = ""
            if "path" in metadata:
                path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]
            content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
            chunk_id = f"{path_hash}_{i}_{content_hash}"

            embedding = await self.get_embedding(text)
            try:
                self.collection.add(
                    ids=[chunk_id],
                    documents=[text],
                    embeddings=[embedding],
                    metadatas=[metadata],
                )
            except Exception as e:
                logging.error(f"Error adding chunk to collection: {e}")
                logging.error(traceback.format_exc())
                logging.error(chunk)

    def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer) -> str | None:
        """Expand a chunk's metadata into surrounding file context, `buffer` lines on each side."""
        source_file = meta.get("source_file")
        try:
            source_file = meta["source_file"]
            path_parts = source_file.split(os.sep)
            file_name = path_parts[-1]
            meta["source_file"] = file_name
            with open(source_file, "r") as file:
                lines = file.readlines()
                meta["file_lines"] = len(lines)
                start = max(0, meta["line_begin"] - buffer)
                meta["chunk_begin"] = start
                end = min(meta["lines"], meta["line_end"] + buffer)
                meta["chunk_end"] = end
                return "".join(lines[start:end])
        except Exception:
            logging.warning(f"⚠️ Unable to open {source_file}")
            return None

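    # Illustrative note (assumed example values): prepare_metadata() above widens
    # a retrieved chunk by `buffer` lines of surrounding context. With buffer=5,
    # line_begin=12, and line_end=20, it would return "".join(lines[7:25]) --
    # lines 7 through 24 of the source file, clamped to the file's bounds.
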
    # Cosine Distance | Equivalent Similarity | Retrieval Characteristics
    # 0.2 - 0.3       | 0.85 - 0.90           | Very strict, highly precise results only
    # 0.3 - 0.5       | 0.75 - 0.85           | Strong relevance, good precision
    # 0.5 - 0.7       | 0.65 - 0.75           | Balanced precision/recall
    # 0.7 - 0.9       | 0.55 - 0.65           | Higher recall, more inclusive
    # 0.9 - 1.2       | 0.40 - 0.55           | Very inclusive, may include tangential content
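    #
    # Illustrative usage (a sketch; assumes an initialized watcher instance named
    # `watcher` and a running event loop):
    #
    #     results = await watcher.find_similar(
    #         "experience with distributed systems",
    #         top_k=10,
    #         threshold=0.7,  # balanced precision/recall per the table above
    #     )
    #     for doc, dist in zip(results["documents"], results["distances"]):
    #         print(f"{dist:.3f}  {doc[:80]}")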
    async def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
        """Find similar documents to the query."""

        # collection is configured with hnsw:space cosine
        query_embedding = await self.get_embedding(query)
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"],
        )

        # Extract results
        ids = results["ids"][0]
        documents = results["documents"][0] if results["documents"] else []
        distances = results["distances"][0] if results["distances"] else []
        metadatas = results["metadatas"][0] if results["metadatas"] else []

        filtered_ids = []
        filtered_documents = []
        filtered_distances = []
        filtered_metadatas = []

        for i, distance in enumerate(distances):
            if distance <= threshold:  # For cosine distance, smaller is better
                filtered_ids.append(ids[i])
                filtered_documents.append(documents[i])
                filtered_metadatas.append(metadatas[i])
                filtered_distances.append(distance)

        for index, meta in enumerate(filtered_metadatas):
            content = self.prepare_metadata(meta)
            if content is not None:
                filtered_documents[index] = content

        # Return the filtered results instead of all results
        return {
            "query_embedding": query_embedding,
            "ids": filtered_ids,
            "documents": filtered_documents,
            "distances": filtered_distances,
            "metadatas": filtered_metadatas,
        }

    def _get_file_hash(self, file_path):
        """Calculate MD5 hash of a file."""
        try:
            with open(file_path, "rb") as f:
                return hashlib.md5(f.read()).hexdigest()
        except Exception as e:
            logging.error(f"Error hashing file {file_path}: {e}")
            return None

    def _should_process_file(self, file_path):
        """Check if a file should be processed."""
        # Skip temporary files, hidden files, etc.
        file_name = os.path.basename(file_path)
        if file_name.startswith(".") or file_name.endswith(".tmp"):
            return False

        # Add other filtering logic as needed
        return True

    def on_modified(self, event):
        """Handle file modification events."""
        if event.is_directory or not self._should_process_file(event.src_path):
            return

        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File modified: {file_path}")

    def on_created(self, event):
        """Handle file creation events."""
        if event.is_directory or not self._should_process_file(event.src_path):
            return

        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File created: {file_path}")

    def on_deleted(self, event):
        """Handle file deletion events."""
        if event.is_directory:
            return

        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
        logging.info(f"File deleted: {file_path}")

    def on_moved(self, event):
        """Handle file move events."""
        if event.is_directory:
            return

        file_path = event.src_path
        logging.info(f"TODO: on_moved: {file_path}")

    def _normalize_embeddings(self, embeddings):
        """Normalize the embeddings to unit length."""
        # Handle both single vector and array of vectors
        if isinstance(embeddings[0], (int, float)):
            # Single vector
            norm = np.linalg.norm(embeddings)
            return [e / norm for e in embeddings] if norm > 0 else embeddings
        else:
            # Array of vectors
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            return embeddings / norms

    async def _update_document_in_collection(self, file_path):
        """Update a document in the ChromaDB collection."""
        try:
            # Remove existing entries for this file
            existing_results = self.collection.get(where={"path": file_path})
            if existing_results and "ids" in existing_results and existing_results["ids"]:
                self.collection.delete(ids=existing_results["ids"])
                await self.database.update_user_rag_timestamp(self.user_id)

            extensions = (".docx", ".xlsx", ".xls", ".pdf")
            if file_path.endswith(extensions):
                p = Path(file_path)
                p_as_md = p.with_suffix(".md")

                # If file_path.md doesn't exist or file_path is newer than file_path.md,
                # fire off markitdown
                if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
                    self._markitdown(file_path, p_as_md)
                    # Add the generated .md file to processing_files to prevent double-processing
                    self.processing_files.add(str(p_as_md))
                return

            chunks = self._markdown_chunker.process_file(file_path)
            if not chunks:
                logging.info(f"No chunks found in markdown: {file_path}")
                return

            # Extract top-level directory
            rel_path = os.path.relpath(file_path, self.watch_directory)
            path_parts = rel_path.split(os.sep)
            top_level_dir = path_parts[0]
            # file_name = path_parts[-1]
            for i, chunk in enumerate(chunks):
                chunk["metadata"]["doc_type"] = top_level_dir
                # with open(f"src/tmp/{file_name}.{i}", "w") as f:
                #     f.write(json.dumps(chunk, indent=2))

            # Add chunks to collection
            await self._add_embeddings_to_collection(chunks)
            await self.database.update_user_rag_timestamp(self.user_id)

            logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")

        except Exception as e:
            logging.error(f"Error updating document in collection: {e}")
            logging.error(traceback.format_exc())

    async def initialize_collection(self):
        """Initialize the collection with all documents from the watch directory."""
        # Process all files regardless of hash state
        num_processed = await self.scan_directory(process_all=True)

        logging.info(f"Vectorstore initialized with {self.collection.count()} documents")

        self._update_umaps()

        # Show stats
        try:
            all_metadata = self.collection.get()["metadatas"]
            if all_metadata:
                doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
                logging.info(f"Document types: {doc_types}")
        except Exception as e:
            logging.error(f"Error getting document types: {e}")

        return num_processed


# Function to start the file watcher
def start_file_watcher(
    llm,
    user_id,
    watch_directory,
    persist_directory,
    collection_name,
    database: RedisDatabase,
    initialize=False,
    recreate=False,
):
    """
    Start watching a directory for file changes.

    Args:
        llm: The language model client
        user_id: ID of the user whose RAG data is being watched
        watch_directory: Directory to watch for changes
        persist_directory: Directory to persist ChromaDB and hash state
        collection_name: Name of the ChromaDB collection
        database: Redis database handle used to record RAG update timestamps
        initialize: Whether to forcibly initialize the collection with all documents
        recreate: Whether to recreate the collection (will delete existing)
    """
    loop = asyncio.get_event_loop()

    file_watcher = ChromaDBFileWatcher(
        llm,
        watch_directory=watch_directory,
        loop=loop,
        user_id=user_id,
        persist_directory=persist_directory,
        collection_name=collection_name,
        recreate=recreate,
        database=database,
    )

    # Process all files if:
    # 1. initialize=True was passed (explicit request to initialize)
    # 2. This is a new collection (doesn't exist yet)
    # 3. There's no hash state (first run)
    if initialize or file_watcher.is_new_collection or not file_watcher.file_hashes:
        logging.info("Initializing collection with all documents")
        asyncio.run_coroutine_threadsafe(file_watcher.initialize_collection(), loop)
    else:
        # Only process new/changed files
        logging.info("Scanning for new/changed documents")
        asyncio.run_coroutine_threadsafe(file_watcher.scan_directory(), loop)

    # Start observer
    observer = Observer()
    observer.schedule(file_watcher, watch_directory, recursive=True)
    observer.start()

    logging.info(f"Started watching directory: {watch_directory}")
    return observer, file_watcher
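

# Illustrative wiring sketch (assumes an Ollama-style `llm` client exposing
# `embeddings(...)` and an already-connected RedisDatabase named `database`;
# the paths and names below are placeholders, not part of this module):
#
#     async def main():
#         observer, watcher = start_file_watcher(
#             llm=llm,
#             user_id="user-123",
#             watch_directory="/data/rag/docs",
#             persist_directory="/data/rag/chromadb",
#             collection_name="rag_docs",
#             database=database,
#             initialize=True,
#         )
#         try:
#             while True:
#                 await asyncio.sleep(3600)  # keep the loop alive for the watcher
#         finally:
#             observer.stop()
#             observer.join()
#
#     asyncio.run(main())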