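"""File-watching RAG ingestion for ChromaDB.

Watches a directory for document changes, converts Office/PDF files to Markdown
via MarkItDown, chunks the Markdown, embeds the chunks, and keeps a ChromaDB
collection (plus 2D/3D UMAP projections of its embeddings) in sync, recording
per-user RAG update timestamps in Redis.
"""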
from __future__ import annotations

from pydantic import BaseModel  # type: ignore
from typing import List, Optional, Dict, Any
import os
import glob
from pathlib import Path
import hashlib
import asyncio
import logging
import json
import numpy as np  # type: ignore
import traceback

import chromadb  # type: ignore
from watchdog.observers import Observer  # type: ignore
from watchdog.events import FileSystemEventHandler  # type: ignore
import umap  # type: ignore
from markitdown import MarkItDown  # type: ignore
from chromadb.api.models.Collection import Collection  # type: ignore

from .markdown_chunker import (
    MarkdownChunker,
    Chunk,
)

# When imported as a module, use relative imports
import defines
from database.manager import RedisDatabase
from models import ChromaDBGetResponse

__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]

DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100


class RagEntry(BaseModel):
    name: str
    description: str = ""
    enabled: bool = True


class ChromaDBFileWatcher(FileSystemEventHandler):
    def __init__(
        self,
        llm,
        watch_directory,
        loop,
        persist_directory,
        collection_name,
        database: RedisDatabase,
        user_id: str,
        chunk_size=DEFAULT_CHUNK_SIZE,
        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
        recreate=False,
    ):
        self.llm = llm
        self.database = database
        self.user_id = user_id
        self.watch_directory = watch_directory
        self.persist_directory = persist_directory or defines.persist_directory
        self.collection_name = collection_name
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.loop = loop

        self._umap_collection: ChromaDBGetResponse | None = None
        self._umap_embedding_2d: np.ndarray = np.array([])
        self._umap_embedding_3d: np.ndarray = np.array([])
        self._umap_model_2d: Optional[umap.UMAP] = None
        self._umap_model_3d: Optional[umap.UMAP] = None

        self.md = MarkItDown(enable_plugins=False)  # Set to True to enable plugins

        self.processing_lock = asyncio.Lock()
        self.processing_debounce = {}  # Last-seen time per file path, used for debouncing
        self.debounce_delay = 1.0  # seconds

        # self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Path for storing file hash state
        self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")

        # Flag to track if this is a new collection
        self.is_new_collection = False

        # Initialize ChromaDB collection
        self._collection: Collection = self._get_vector_collection(recreate=recreate)
        self._markdown_chunker = MarkdownChunker()
        self._update_umaps()

        # Track file hashes and processing state
        self.file_hashes = self._load_hash_state()
        self.update_lock = asyncio.Lock()
        self.processing_files = set()

    @property
    def collection(self) -> Collection:
        return self._collection

    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        if not self._umap_collection:
            raise ValueError("initialize_collection has not been called")
        return self._umap_collection

    @property
    def umap_embedding_2d(self) -> np.ndarray:
        return self._umap_embedding_2d

    @property
    def umap_embedding_3d(self) -> np.ndarray:
        return self._umap_embedding_3d

    @property
    def umap_model_2d(self):
        return self._umap_model_2d

    @property
    def umap_model_3d(self):
        return self._umap_model_3d

    def _markitdown(self, document: str, markdown: Path):
        logging.info(f"Converting {document} to {markdown}")
        try:
            result = self.md.convert(document)
            markdown.write_text(result.text_content)
        except Exception as e:
            logging.error(f"Error converting via MarkItDown: {e}")

    def _save_hash_state(self):
        """Save the current file hash state to disk."""
        try:
            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)

            with open(self.hash_state_path, "w") as f:
                json.dump(self.file_hashes, f)
            logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
        except Exception as e:
            logging.error(f"Error saving hash state: {e}")

    def _load_hash_state(self):
        """Load the file hash state from disk."""
        if os.path.exists(self.hash_state_path):
            try:
                with open(self.hash_state_path, "r") as f:
                    hash_state = json.load(f)
                logging.info(f"Loaded hash state with {len(hash_state)} entries")
                return hash_state
            except Exception as e:
                logging.error(f"Error loading hash state: {e}")
        return {}

    async def scan_directory(self, process_all=False):
        """
        Scan directory for new, modified, or deleted files and update collection.

        Args:
            process_all: If True, process all files regardless of hash status
        """
        # Check for new or modified files
        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
        files_checked = 0
        files_processed = 0
        files_to_process = []

        logging.info(f"Starting directory scan. Found {len(file_paths)} total paths.")

        for file_path in file_paths:
            if os.path.isfile(file_path):
                # Do not put the Resume in RAG as it is provided with all queries.
                # if file_path == defines.resume_doc:
                #     logging.info(f"Not adding {file_path} to RAG -- primary resume")
                #     continue
                files_checked += 1
                current_hash = self._get_file_hash(file_path)
                if not current_hash:
                    logging.info(f"Unable to obtain hash of {file_path}")
                    continue

                # If file is new, changed, or we're processing all files
                if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
                    self.file_hashes[file_path] = current_hash
                    files_to_process.append(file_path)
                    logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")

        logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")

        # Check for deleted files
        deleted_files = []
        for file_path in self.file_hashes:
            if not os.path.exists(file_path):
                deleted_files.append(file_path)
                # Schedule removal; don't block on the result, just let it run
                asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
                logging.info(f"File deleted: {file_path}")

        # Remove deleted files from hash state
        for file_path in deleted_files:
            del self.file_hashes[file_path]

        # Process all discovered files sequentially on the existing loop
        if files_to_process:
            logging.info(f"Starting to process {len(files_to_process)} files")
            for file_path in files_to_process:
                async with self.update_lock:
                    files_processed += 1
                    await self._update_document_in_collection(file_path)
        else:
            logging.info("No files to process")

        # Save the updated state
        self._save_hash_state()

        logging.info(
            f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
        )
        return files_processed

    async def process_file_update(self, file_path):
        """Process a file update event, debounced to debounce_delay."""
        # Debouncing logic
        current_time = asyncio.get_event_loop().time()
        if file_path in self.processing_debounce:
            time_since_last = current_time - self.processing_debounce[file_path]
            if time_since_last < self.debounce_delay:
                logging.info(f"Debouncing {file_path} (last processed {time_since_last:.2f}s ago)")
                return

        self.processing_debounce[file_path] = current_time

        # Use a lock to make the check-and-add atomic
        async with self.processing_lock:
            if file_path in self.processing_files:
                logging.info(f"{file_path} already in queue. Not adding.")
                return
            logging.info(f"{file_path} not in queue. Adding.")
            self.processing_files.add(file_path)

        try:
            # Wait a moment to ensure the file write is complete
            await asyncio.sleep(0.5)

            # Check if content changed via hash
            current_hash = self._get_file_hash(file_path)
            if not current_hash:
                return

            # Use the update_lock to make hash check and update atomic
            async with self.update_lock:
                if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
                    logging.info(f"Hash has not changed for {file_path}")
                    return

                # Update file hash BEFORE processing to prevent race conditions
                self.file_hashes[file_path] = current_hash

                # Process and update the file in ChromaDB
                await self._update_document_in_collection(file_path)

                # Save the hash state after successful update
                self._save_hash_state()

                # Re-fit the UMAP for the new content
                self._update_umaps()
        except Exception as e:
            logging.error(f"Error processing update for {file_path}: {e}")
        finally:
            self.processing_files.discard(file_path)

    async def remove_file_from_collection(self, file_path):
        """Remove all chunks related to a deleted file."""
        async with self.update_lock:
            try:
                # Find all documents with the specified path
                results = self.collection.get(where={"path": file_path})
                if results and "ids" in results and results["ids"]:
                    self.collection.delete(ids=results["ids"])
                    await self.database.update_user_rag_timestamp(self.user_id)
                    logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")

                # Remove from hash dictionary
                if file_path in self.file_hashes:
                    del self.file_hashes[file_path]
                    # Save the updated hash state
                    self._save_hash_state()
            except Exception as e:
                logging.error(f"Error removing file from collection: {e}")

    def _update_umaps(self):
        # Update the UMAP embeddings
        self._umap_collection = ChromaDBGetResponse.model_validate(
            self._collection.get(include=["embeddings", "documents", "metadatas"])
        )
        if not self._umap_collection or not len(self._umap_collection.embeddings):
            logging.warning("⚠️ No embeddings found in the collection.")
            return

        logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
        vectors = np.array(self._umap_collection.embeddings)
        self._umap_model_2d = umap.UMAP(
            n_components=2,
            random_state=8911,
            metric="cosine",
            n_neighbors=round(min(30, len(self._umap_collection.embeddings) * 0.5)),
            min_dist=0.1,
        )
        self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)  # type: ignore
        # logging.info(
        #     f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
        # )  # Should be 2

        logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
        self._umap_model_3d = umap.UMAP(
            n_components=3,
            random_state=8911,
            metric="cosine",
            n_neighbors=round(min(30, len(self._umap_collection.embeddings) * 0.5)),
            min_dist=0.01,
        )
        self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)  # type: ignore
        # logging.info(
        #     f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
        # )  # Should be 3

    def _get_vector_collection(self, recreate=False) -> Collection:
        """Get or create a ChromaDB collection."""
        # Create the directory if it doesn't exist
        if not os.path.exists(self.persist_directory):
            os.makedirs(self.persist_directory)

        # Initialize ChromaDB client
        chroma_client = chromadb.PersistentClient(
            path=self.persist_directory,
            settings=chromadb.Settings(anonymized_telemetry=False),  # type: ignore
        )

        # Check if the collection exists
        try:
            chroma_client.get_collection(self.collection_name)
            collection_exists = True
        except Exception:
            collection_exists = False

        # If collection doesn't exist, mark it as new
        if not collection_exists:
            self.is_new_collection = True
            logging.info(f"Creating new collection: {self.collection_name}")

        # Delete if recreate is True
        if recreate and collection_exists:
            chroma_client.delete_collection(name=self.collection_name)
            self.is_new_collection = True
            logging.info(f"Recreating collection: {self.collection_name}")

        return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})

    async def get_embedding(self, text: str) -> np.ndarray:
        """Generate and normalize an embedding for the given text."""
        # Get embedding
        try:
            response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text)
            embedding = np.array(response.get_single_embedding())
        except Exception as e:
            logging.error(traceback.format_exc())
            logging.error(f"Failed to get embedding: {e}")
            raise

        # Log diagnostics
        logging.debug(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")

        # Check for invalid embeddings
        if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
            logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
            raise ValueError("Invalid embedding returned from Ollama.")

        # Check normalization
        norm = np.linalg.norm(embedding)
        is_normalized = np.allclose(norm, 1.0, atol=1e-3)
        logging.debug(f"Embedding norm: {norm}, Is normalized: {is_normalized}")

        # Normalize if needed
        if not is_normalized:
            embedding = embedding / norm
            logging.debug("Embedding normalized manually.")

        return embedding

    async def _add_embeddings_to_collection(self, chunks: List[Chunk]):
        """Add embeddings for chunks to the collection."""
        for i, chunk in enumerate(chunks):
            text = chunk["text"]
            metadata = chunk["metadata"]

            # Generate a more unique ID based on content and metadata
            # (key off "source_file", which is the field hashed below)
            path_hash = ""
            if "source_file" in metadata:
                path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]

            content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
            chunk_id = f"{path_hash}_{i}_{content_hash}"

            embedding = await self.get_embedding(text)
            try:
                self.collection.add(
                    ids=[chunk_id],
                    documents=[text],
                    embeddings=[embedding],
                    metadatas=[metadata],
                )
            except Exception as e:
                logging.error(f"Error adding chunk to collection: {e}")
                logging.error(traceback.format_exc())
                logging.error(chunk)

    def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer) -> str | None:
        source_file = meta.get("source_file")
        try:
            path_parts = source_file.split(os.sep)
            file_name = path_parts[-1]
            meta["source_file"] = file_name

            with open(source_file, "r") as file:
                lines = file.readlines()
            meta["file_lines"] = len(lines)

            # Expand the chunk by `buffer` lines on each side, clamped to the file
            start = max(0, meta["line_begin"] - buffer)
            meta["chunk_begin"] = start
            end = min(len(lines), meta["line_end"] + buffer)
            meta["chunk_end"] = end
            return "".join(lines[start:end])
        except Exception:
            logging.warning(f"⚠️ Unable to open {source_file}")
            return None

    # Cosine Distance   Equivalent Similarity   Retrieval Characteristics
    # 0.2 - 0.3         0.85 - 0.90             Very strict, highly precise results only
    # 0.3 - 0.5         0.75 - 0.85             Strong relevance, good precision
    # 0.5 - 0.7         0.65 - 0.75             Balanced precision/recall
    # 0.7 - 0.9         0.55 - 0.65             Higher recall, more inclusive
    # 0.9 - 1.2         0.40 - 0.55             Very inclusive, may include tangential content
    async def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
        """Find similar documents to the query."""
        # collection is configured with hnsw:space cosine
        query_embedding = await self.get_embedding(query)

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"],
        )

        # Extract results
        ids = results["ids"][0]
        documents = results["documents"][0] if results["documents"] else []
        distances = results["distances"][0] if results["distances"] else []
        metadatas = results["metadatas"][0] if results["metadatas"] else []

        filtered_ids = []
        filtered_documents = []
        filtered_distances = []
        filtered_metadatas = []

        for i, distance in enumerate(distances):
            if distance <= threshold:  # For cosine distance, smaller is better
                filtered_ids.append(ids[i])
                filtered_documents.append(documents[i])
                filtered_metadatas.append(metadatas[i])
                filtered_distances.append(distance)

        for index, meta in enumerate(filtered_metadatas):
            content = self.prepare_metadata(meta)
            if content is not None:
                filtered_documents[index] = content

        # Return the filtered results instead of all results
        return {
            "query_embedding": query_embedding,
            "ids": filtered_ids,
            "documents": filtered_documents,
            "distances": filtered_distances,
            "metadatas": filtered_metadatas,
        }
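
    # Usage sketch for find_similar (illustrative only; the query, top_k, and
    # threshold values are assumptions, not project defaults):
    #
    #     results = await watcher.find_similar("project deadlines", top_k=5, threshold=0.7)
    #     for doc, distance in zip(results["documents"], results["distances"]):
    #         print(f"{distance:.3f}  {doc[:80]}")
    #
    # A threshold of 0.7 keeps matches from the distance bands the table above
    # characterizes as balanced precision/recall or stricter.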

    def _get_file_hash(self, file_path):
        """Calculate MD5 hash of a file."""
        try:
            with open(file_path, "rb") as f:
                return hashlib.md5(f.read()).hexdigest()
        except Exception as e:
            logging.error(f"Error hashing file {file_path}: {e}")
            return None

    def _should_process_file(self, file_path):
        """Check if a file should be processed."""
        # Skip temporary files, hidden files, etc.
        file_name = os.path.basename(file_path)
        if file_name.startswith(".") or file_name.endswith(".tmp"):
            return False

        # Add other filtering logic as needed
        return True

    def on_modified(self, event):
        """Handle file modification events."""
        if event.is_directory or not self._should_process_file(event.src_path):
            return
        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File modified: {file_path}")

    def on_created(self, event):
        """Handle file creation events."""
        if event.is_directory or not self._should_process_file(event.src_path):
            return
        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File created: {file_path}")

    def on_deleted(self, event):
        """Handle file deletion events."""
        if event.is_directory:
            return
        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
        logging.info(f"File deleted: {file_path}")

    def on_moved(self, event):
        """Handle file move events."""
        if event.is_directory:
            return
        file_path = event.src_path
        logging.info(f"TODO: on_moved: {file_path}")

    def _normalize_embeddings(self, embeddings):
        """Normalize the embeddings to unit length."""
        # Handle both single vector and array of vectors
        if isinstance(embeddings[0], (int, float)):
            # Single vector
            norm = np.linalg.norm(embeddings)
            return [e / norm for e in embeddings] if norm > 0 else embeddings
        else:
            # Array of vectors
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            return embeddings / norms

    async def _update_document_in_collection(self, file_path):
        """Update a document in the ChromaDB collection."""
        try:
            # Remove existing entries for this file
            existing_results = self.collection.get(where={"path": file_path})
            if existing_results and "ids" in existing_results and existing_results["ids"]:
                self.collection.delete(ids=existing_results["ids"])
                await self.database.update_user_rag_timestamp(self.user_id)

            extensions = (".docx", ".xlsx", ".xls", ".pdf")
            if file_path.endswith(extensions):
                p = Path(file_path)
                p_as_md = p.with_suffix(".md")
                # If file_path.md doesn't exist or file_path is newer than file_path.md,
                # fire off markitdown
                if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
                    self._markitdown(file_path, p_as_md)
                    # Add the generated .md file to processing_files to prevent double-processing
                    self.processing_files.add(str(p_as_md))
                return

            chunks = self._markdown_chunker.process_file(file_path)
            if not chunks:
                logging.info(f"No chunks found in markdown: {file_path}")
                return

            # Extract top-level directory
            rel_path = os.path.relpath(file_path, self.watch_directory)
            path_parts = rel_path.split(os.sep)
            top_level_dir = path_parts[0]
            # file_name = path_parts[-1]

            for i, chunk in enumerate(chunks):
                chunk["metadata"]["doc_type"] = top_level_dir
                # with open(f"src/tmp/{file_name}.{i}", "w") as f:
                #     f.write(json.dumps(chunk, indent=2))

            # Add chunks to collection
            await self._add_embeddings_to_collection(chunks)
            await self.database.update_user_rag_timestamp(self.user_id)

            logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
        except Exception as e:
            logging.error(f"Error updating document in collection: {e}")
            logging.error(traceback.format_exc())

    async def initialize_collection(self):
        """Initialize the collection with all documents from the watch directory."""
        # Process all files regardless of hash state
        num_processed = await self.scan_directory(process_all=True)

        logging.info(f"Vectorstore initialized with {self.collection.count()} documents")

        self._update_umaps()

        # Show stats
        try:
            all_metadata = self.collection.get()["metadatas"]
            if all_metadata:
                doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
                logging.info(f"Document types: {doc_types}")
        except Exception as e:
            logging.error(f"Error getting document types: {e}")

        return num_processed


# Function to start the file watcher
def start_file_watcher(
    llm,
    user_id,
    watch_directory,
    persist_directory,
    collection_name,
    database: RedisDatabase,
    initialize=False,
    recreate=False,
):
    """
    Start watching a directory for file changes.

    Args:
        llm: The language model client used for embeddings
        user_id: ID of the user whose RAG collection is being updated
        watch_directory: Directory to watch for changes
        persist_directory: Directory to persist ChromaDB and hash state
        collection_name: Name of the ChromaDB collection
        database: RedisDatabase instance used to record RAG update timestamps
        initialize: Whether to forcibly initialize the collection with all documents
        recreate: Whether to recreate the collection (will delete existing)
    """
    loop = asyncio.get_event_loop()

    file_watcher = ChromaDBFileWatcher(
        llm,
        watch_directory=watch_directory,
        loop=loop,
        user_id=user_id,
        persist_directory=persist_directory,
        collection_name=collection_name,
        recreate=recreate,
        database=database,
    )

    # Process all files if:
    # 1. initialize=True was passed (explicit request to initialize)
    # 2. This is a new collection (doesn't exist yet)
    # 3. There's no hash state (first run)
    if initialize or file_watcher.is_new_collection or not file_watcher.file_hashes:
        logging.info("Initializing collection with all documents")
        asyncio.run_coroutine_threadsafe(file_watcher.initialize_collection(), loop)
    else:
        # Only process new/changed files
        logging.info("Scanning for new/changed documents")
        asyncio.run_coroutine_threadsafe(file_watcher.scan_directory(), loop)

    # Start observer
    observer = Observer()
    observer.schedule(file_watcher, watch_directory, recursive=True)
    observer.start()
    logging.info(f"Started watching directory: {watch_directory}")

    return observer, file_watcher
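

# A minimal usage sketch, not part of the original module: `my_embedding_client`
# and `my_redis_database` are placeholders that must come from the surrounding
# application (an async embeddings client and a RedisDatabase instance, as used
# above). Watchdog events are processed on an asyncio loop running in a
# background thread, since start_file_watcher schedules coroutines onto it.
#
# if __name__ == "__main__":
#     import threading
#     import time
#
#     logging.basicConfig(level=logging.INFO)
#
#     loop = asyncio.new_event_loop()
#     asyncio.set_event_loop(loop)
#     threading.Thread(target=loop.run_forever, daemon=True).start()
#
#     observer, watcher = start_file_watcher(
#         llm=my_embedding_client,       # placeholder: provides `await llm.embeddings(...)`
#         user_id="user-123",            # placeholder
#         watch_directory="/data/docs",
#         persist_directory="/data/chroma",
#         collection_name="documents",
#         database=my_redis_database,    # placeholder: RedisDatabase instance
#         initialize=True,
#     )
#     try:
#         while True:
#             time.sleep(1)
#     except KeyboardInterrupt:
#         observer.stop()
#     observer.join()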