
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
from typing import List, Optional, Dict, Any, Union
import os
import glob
from pathlib import Path
import time
import hashlib
import asyncio
import logging
import json
import numpy as np # type: ignore
import traceback
import chromadb
import ollama
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
from markitdown import MarkItDown # type: ignore
from chromadb.api.models.Collection import Collection # type: ignore
from .markdown_chunker import (
MarkdownChunker,
Chunk,
)
# Import your existing modules
if __name__ == "__main__":
# When running directly, use absolute imports
import defines
else:
# When imported as a module, use relative imports
from . import defines
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]
DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100
class ChromaDBGetResponse(BaseModel):
name: str = ""
size: int = 0
ids: List[str] = []
embeddings: List[List[float]] = Field(default=[])
documents: List[str] = []
metadatas: List[Dict[str, Any]] = []
query: str = ""
query_embedding: Optional[List[float]] = Field(default=None)
umap_embedding_2d: Optional[List[float]] = Field(default=None)
umap_embedding_3d: Optional[List[float]] = Field(default=None)
enabled: bool = True
class Config:
validate_assignment = True
@field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
@classmethod
def validate_embeddings(cls, value, field):
logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
if value is None:
return value
if isinstance(value, np.ndarray):
if field.field_name == "embeddings":
if value.ndim != 2:
raise ValueError(f"{field.name} must be a 2-dimensional NumPy array")
return [[float(x) for x in row] for row in value.tolist()]
else:
if value.ndim != 1:
raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
return [float(x) for x in value.tolist()]
if field.field_name == "embeddings":
if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
raise ValueError(f"{field.field_name} must be a list of lists of floats")
return [[float(x) for x in sublist] for sublist in value]
else:
if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
raise ValueError(f"{field.field_name} must be a list of floats")
return [float(x) for x in value]
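# Illustrative only: a minimal ChromaDBGetResponse construction. The validators
# above coerce NumPy arrays to plain Python lists when arrays are supplied, so
# the model stays JSON-serializable; plain lists pass through unchanged. The
# values below are made up:
#
#   example = ChromaDBGetResponse(
#       name="documents",
#       query="sample query",
#       embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],  # List[List[float]]
#       query_embedding=[0.1, 0.2, 0.3],                # List[float]
#   )
#   print(example.model_dump_json())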
class ChromaDBFileWatcher(FileSystemEventHandler):
def __init__(
self,
llm,
watch_directory,
loop,
persist_directory=None,
collection_name="documents",
chunk_size=DEFAULT_CHUNK_SIZE,
chunk_overlap=DEFAULT_CHUNK_OVERLAP,
recreate=False,
):
self.llm = llm
self.watch_directory = watch_directory
self.persist_directory = persist_directory or defines.persist_directory
self.collection_name = collection_name
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.loop = loop
self._umap_collection: ChromaDBGetResponse | None = None
        self._umap_embedding_2d: np.ndarray = np.array([])
        self._umap_embedding_3d: np.ndarray = np.array([])
        self._umap_model_2d: Optional[umap.UMAP] = None
        self._umap_model_3d: Optional[umap.UMAP] = None
self.md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
# self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Path for storing file hash state
self.hash_state_path = os.path.join(
self.persist_directory, f"{collection_name}_hash_state.json"
)
# Flag to track if this is a new collection
self.is_new_collection = False
# Initialize ChromaDB collection
self._collection: Collection = self._get_vector_collection(recreate=recreate)
self._markdown_chunker = MarkdownChunker()
self._update_umaps()
        # Track file hashes and processing state
self.file_hashes = self._load_hash_state()
self.update_lock = asyncio.Lock()
self.processing_files = set()
@property
def collection(self):
return self._collection
@property
def umap_collection(self) -> ChromaDBGetResponse | None:
return self._umap_collection
@property
def umap_embedding_2d(self) -> np.ndarray:
return self._umap_embedding_2d
@property
def umap_embedding_3d(self) -> np.ndarray:
return self._umap_embedding_3d
@property
def umap_model_2d(self):
return self._umap_model_2d
@property
def umap_model_3d(self):
return self._umap_model_3d
def _markitdown(self, document: str, markdown: Path):
logging.info(f"Converting {document} to {markdown}")
try:
result = self.md.convert(document)
markdown.write_text(result.text_content)
except Exception as e:
logging.error(f"Error convering via markdownit: {e}")
def _save_hash_state(self):
"""Save the current file hash state to disk."""
try:
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)
with open(self.hash_state_path, "w") as f:
json.dump(self.file_hashes, f)
logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
except Exception as e:
logging.error(f"Error saving hash state: {e}")
def _load_hash_state(self):
"""Load the file hash state from disk."""
if os.path.exists(self.hash_state_path):
try:
with open(self.hash_state_path, "r") as f:
hash_state = json.load(f)
logging.info(f"Loaded hash state with {len(hash_state)} entries")
return hash_state
except Exception as e:
logging.error(f"Error loading hash state: {e}")
return {}
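    # The persisted hash state is a flat JSON object mapping file paths to MD5
    # content hashes, e.g. (paths and hashes below are hypothetical):
    #
    #   {
    #       "/srv/docs/resume.md": "9e107d9d372bb6826bd81d3542a419d6",
    #       "/srv/docs/projects/notes.md": "e4d909c290d0fb1ca068ffaddf22cbd0"
    #   }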
async def scan_directory(self, process_all=False):
"""
Scan directory for new, modified, or deleted files and update collection.
Args:
process_all: If True, process all files regardless of hash status
"""
# Check for new or modified files
file_paths = glob.glob(
os.path.join(self.watch_directory, "**/*"), recursive=True
)
files_checked = 0
files_processed = 0
files_to_process = []
logging.info(f"Starting directory scan. Found {len(file_paths)} total paths.")
for file_path in file_paths:
if os.path.isfile(file_path):
                # Do not put the Resume in RAG as it is provided with all queries.
# if file_path == defines.resume_doc:
# logging.info(f"Not adding {file_path} to RAG -- primary resume")
# continue
files_checked += 1
current_hash = self._get_file_hash(file_path)
if not current_hash:
continue
# If file is new, changed, or we're processing all files
if (
process_all
or file_path not in self.file_hashes
or self.file_hashes[file_path] != current_hash
):
self.file_hashes[file_path] = current_hash
files_to_process.append(file_path)
logging.info(
f"File {'found' if process_all else 'changed'}: {file_path}"
)
logging.info(
f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
)
# Check for deleted files
deleted_files = []
for file_path in self.file_hashes:
if not os.path.exists(file_path):
deleted_files.append(file_path)
# Schedule removal
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
# Don't block on result, just let it run
logging.info(f"File deleted: {file_path}")
# Remove deleted files from hash state
for file_path in deleted_files:
del self.file_hashes[file_path]
# Process all discovered files using asyncio.gather with the existing loop
if files_to_process:
logging.info(f"Starting to process {len(files_to_process)} files")
            for file_path in files_to_process:
                async with self.update_lock:
                    await self._update_document_in_collection(file_path)
                    files_processed += 1
else:
logging.info("No files to process")
# Save the updated state
self._save_hash_state()
logging.info(
f"Scan complete: Checked {files_checked} files, processed {files_processed}, removed {len(deleted_files)}"
)
return files_processed
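    # Illustrative usage from a coroutine running on the watcher's event loop
    # (the variable name `watcher` is hypothetical):
    #
    #   processed = await watcher.scan_directory(process_all=True)  # full re-index
    #   processed = await watcher.scan_directory()                  # new/changed files only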
async def process_file_update(self, file_path):
"""Process a file update event."""
# Skip if already being processed
if file_path in self.processing_files:
logging.info(f"{file_path} already in queue. Not adding.")
return
# if file_path == defines.resume_doc:
# logging.info(f"Not adding {file_path} to RAG -- primary resume")
# return
try:
logging.info(f"{file_path} not in queue. Adding.")
self.processing_files.add(file_path)
# Wait a moment to ensure the file write is complete
await asyncio.sleep(0.5)
# Check if content changed via hash
current_hash = self._get_file_hash(file_path)
if not current_hash: # File might have been deleted or is inaccessible
return
if (
file_path in self.file_hashes
and self.file_hashes[file_path] == current_hash
):
# File hasn't actually changed in content
logging.info(f"Hash has not changed for {file_path}")
return
# Update file hash
self.file_hashes[file_path] = current_hash
# Process and update the file in ChromaDB
async with self.update_lock:
await self._update_document_in_collection(file_path)
# Save the hash state after successful update
self._save_hash_state()
# Re-fit the UMAP for the new content
self._update_umaps()
except Exception as e:
logging.error(f"Error processing update for {file_path}: {e}")
finally:
self.processing_files.discard(file_path)
async def remove_file_from_collection(self, file_path):
"""Remove all chunks related to a deleted file."""
async with self.update_lock:
try:
# Find all documents with the specified path
results = self.collection.get(where={"path": file_path})
if results and "ids" in results and results["ids"]:
self.collection.delete(ids=results["ids"])
logging.info(
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
)
# Remove from hash dictionary
if file_path in self.file_hashes:
del self.file_hashes[file_path]
# Save the updated hash state
self._save_hash_state()
except Exception as e:
logging.error(f"Error removing file from collection: {e}")
def _update_umaps(self):
# Update the UMAP embeddings
self._umap_collection = self._collection.get(
include=["embeddings", "documents", "metadatas"]
)
if not self._umap_collection or not len(self._umap_collection["embeddings"]):
logging.warning("No embeddings found in the collection.")
return
# During initialization
logging.info(
f"Updating 2D UMAP for {len(self._umap_collection['embeddings'])} vectors"
)
vectors = np.array(self._umap_collection["embeddings"])
self._umap_model_2d = umap.UMAP(
n_components=2,
random_state=8911,
metric="cosine",
n_neighbors=30,
min_dist=0.1,
)
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors)
# logging.info(
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
# ) # Should be 2
logging.info(
f"Updating 3D UMAP for {len(self._umap_collection['embeddings'])} vectors"
)
self._umap_model_3d = umap.UMAP(
n_components=3,
random_state=8911,
metric="cosine",
n_neighbors=30,
min_dist=0.01,
)
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)
# logging.info(
# f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
# ) # Should be 3
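    # Illustrative only: once _update_umaps() has fitted the reducers, a new
    # query embedding can be projected into the same 2-D/3-D space with
    # umap.UMAP.transform (assumes the models are fitted and `query_embedding`
    # has the collection's embedding dimensionality):
    #
    #   point_2d = self._umap_model_2d.transform([query_embedding])[0]
    #   point_3d = self._umap_model_3d.transform([query_embedding])[0]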
def _get_vector_collection(self, recreate=False) -> Collection:
"""Get or create a ChromaDB collection."""
# Initialize ChromaDB client
chroma_client = chromadb.PersistentClient( # type: ignore
path=self.persist_directory,
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore
)
# Check if the collection exists
try:
chroma_client.get_collection(self.collection_name)
collection_exists = True
        except Exception:
            collection_exists = False
# If collection doesn't exist, mark it as new
if not collection_exists:
self.is_new_collection = True
logging.info(f"Creating new collection: {self.collection_name}")
# Delete if recreate is True
if recreate and collection_exists:
chroma_client.delete_collection(name=self.collection_name)
self.is_new_collection = True
logging.info(f"Recreating collection: {self.collection_name}")
return chroma_client.get_or_create_collection(
name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
    def create_chunks_from_documents(self, docs):
        """Split documents into chunks using the text splitter.

        Note: self.text_splitter is never configured here; chunking is handled by MarkdownChunker.
        """
        return self.text_splitter.split_documents(docs)
def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
# Get embedding
try:
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
embedding = np.array(response["embedding"])
except Exception as e:
logging.error(f"Failed to get embedding: {e}")
raise
# Log diagnostics
logging.info(f"Input text: {text}")
logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")
# Check for invalid embeddings
if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
raise ValueError("Invalid embedding returned from Ollama.")
# Check normalization
norm = np.linalg.norm(embedding)
is_normalized = np.allclose(norm, 1.0, atol=1e-3)
logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")
# Normalize if needed
if not is_normalized:
embedding = embedding / norm
logging.info("Embedding normalized manually.")
return embedding
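    # Illustrative usage: embeddings are returned L2-normalized (norm ~ 1.0),
    # so cosine similarity between two of them reduces to a dot product
    # (variable names and texts are hypothetical):
    #
    #   vec_a = self.get_embedding("first passage")
    #   vec_b = self.get_embedding("second passage")
    #   similarity = float(np.dot(vec_a, vec_b))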
def add_embeddings_to_collection(self, chunks: List[Chunk]):
"""Add embeddings for chunks to the collection."""
for i, chunk in enumerate(chunks):
text = chunk["text"]
metadata = chunk["metadata"]
# Generate a more unique ID based on content and metadata
path_hash = ""
if "path" in metadata:
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
:8
]
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{i}_{content_hash}"
embedding = self.get_embedding(text)
try:
self.collection.add(
ids=[chunk_id],
documents=[text],
embeddings=[embedding],
metadatas=[metadata],
)
except Exception as e:
logging.error(f"Error adding chunk to collection: {e}")
logging.error(traceback.format_exc())
logging.error(chunk)
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer)-> str | None:
try:
source_file = meta["source_file"]
path_parts = source_file.split(os.sep)
file_name = path_parts[-1]
meta["source_file"] = file_name
with open(source_file, "r") as file:
lines = file.readlines()
meta["file_lines"] = len(lines)
start = max(0, meta["line_begin"] - buffer)
meta["chunk_begin"] = start
end = min(meta["lines"], meta["line_end"] + buffer)
meta["chunk_end"] = end
return "".join(lines[start:end])
        except Exception as e:
            logging.warning(f"Unable to open {meta.get('source_file')}: {e}")
return None
    # Cosine distance -> equivalent similarity -> retrieval characteristics:
    #   0.2 - 0.3    0.85 - 0.90    Very strict, highly precise results only
    #   0.3 - 0.5    0.75 - 0.85    Strong relevance, good precision
    #   0.5 - 0.7    0.65 - 0.75    Balanced precision/recall
    #   0.7 - 0.9    0.55 - 0.65    Higher recall, more inclusive
    #   0.9 - 1.2    0.40 - 0.55    Very inclusive, may include tangential content
def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
"""Find similar documents to the query."""
# collection is configured with hnsw:space cosine
query_embedding = self.get_embedding(query)
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
include=["documents", "metadatas", "distances"],
)
# Extract results
ids = results["ids"][0]
documents = results["documents"][0]
distances = results["distances"][0]
metadatas = results["metadatas"][0]
filtered_ids = []
filtered_documents = []
filtered_distances = []
filtered_metadatas = []
for i, distance in enumerate(distances):
if distance <= threshold: # For cosine distance, smaller is better
filtered_ids.append(ids[i])
filtered_documents.append(documents[i])
filtered_metadatas.append(metadatas[i])
filtered_distances.append(distance)
for index, meta in enumerate(filtered_metadatas):
content = self.prepare_metadata(meta)
if content is not None:
filtered_documents[index] = content
# Return the filtered results instead of all results
return {
"query_embedding": query_embedding,
"ids": filtered_ids,
"documents": filtered_documents,
"distances": filtered_distances,
"metadatas": filtered_metadatas,
}
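    # Illustrative usage, tying `threshold` back to the table above (the
    # variable name `watcher` and the query string are hypothetical):
    #
    #   results = watcher.find_similar("work history", top_k=5, threshold=0.7)
    #   for doc, meta, dist in zip(results["documents"], results["metadatas"], results["distances"]):
    #       print(f"{dist:.3f} {meta.get('source_file')}: {doc[:80]}")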
def _get_file_hash(self, file_path):
"""Calculate MD5 hash of a file."""
try:
with open(file_path, "rb") as f:
return hashlib.md5(f.read()).hexdigest()
except Exception as e:
logging.error(f"Error hashing file {file_path}: {e}")
return None
def on_modified(self, event):
"""Handle file modification events."""
if event.is_directory:
return
file_path = event.src_path
# Schedule the update using asyncio
asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
logging.info(f"File modified: {file_path}")
def on_created(self, event):
"""Handle file creation events."""
if event.is_directory:
return
file_path = event.src_path
# Schedule the update using asyncio
asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
logging.info(f"File created: {file_path}")
def on_deleted(self, event):
"""Handle file deletion events."""
if event.is_directory:
return
file_path = event.src_path
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
logging.info(f"File deleted: {file_path}")
def on_moved(self, event):
"""Handle move deletion events."""
if event.is_directory:
return
file_path = event.src_path
logging.info(f"TODO: on_moved: ${file_path}")
def _normalize_embeddings(self, embeddings):
"""Normalize the embeddings to unit length."""
# Handle both single vector and array of vectors
if isinstance(embeddings[0], (int, float)):
# Single vector
norm = np.linalg.norm(embeddings)
return [e / norm for e in embeddings] if norm > 0 else embeddings
else:
# Array of vectors
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
return embeddings / norms
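    # Illustrative only (values made up): accepts a single vector or a batch.
    #
    #   self._normalize_embeddings([3.0, 4.0])                         # -> [0.6, 0.8]
    #   self._normalize_embeddings(np.array([[3.0, 4.0], [0.0, 2.0]])) # row-wise unit vectors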
async def _update_document_in_collection(self, file_path):
"""Update a document in the ChromaDB collection."""
try:
# Remove existing entries for this file
existing_results = self.collection.get(where={"path": file_path})
if (
existing_results
and "ids" in existing_results
and existing_results["ids"]
):
self.collection.delete(ids=existing_results["ids"])
extensions = (".docx", ".xlsx", ".xls", ".pdf")
if file_path.endswith(extensions):
p = Path(file_path)
p_as_md = p.with_suffix(".md")
if p_as_md.exists():
logging.info(
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
)
# If file_path.md doesn't exist or file_path is newer than file_path.md,
# fire off markitdown
if (not p_as_md.exists()) or (
p.stat().st_mtime > p_as_md.stat().st_mtime
):
self._markitdown(file_path, p_as_md)
return
chunks = self._markdown_chunker.process_file(file_path)
if not chunks:
return
# Extract top-level directory
rel_path = os.path.relpath(file_path, self.watch_directory)
path_parts = rel_path.split(os.sep)
top_level_dir = path_parts[0]
# file_name = path_parts[-1]
for i, chunk in enumerate(chunks):
chunk["metadata"]["doc_type"] = top_level_dir
# with open(f"src/tmp/{file_name}.{i}", "w") as f:
# f.write(json.dumps(chunk, indent=2))
# Add chunks to collection
self.add_embeddings_to_collection(chunks)
logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
except Exception as e:
logging.error(f"Error updating document in collection: {e}")
logging.error(traceback.format_exc())
async def initialize_collection(self):
"""Initialize the collection with all documents from the watch directory."""
# Process all files regardless of hash state
num_processed = await self.scan_directory(process_all=True)
logging.info(
f"Vectorstore initialized with {self.collection.count()} documents"
)
self._update_umaps()
# Show stats
try:
all_metadata = self.collection.get()["metadatas"]
if all_metadata:
doc_types = set(m.get("doc_type", "unknown") for m in all_metadata)
logging.info(f"Document types: {doc_types}")
except Exception as e:
logging.error(f"Error getting document types: {e}")
return num_processed
# Function to start the file watcher
def start_file_watcher(
llm,
watch_directory,
persist_directory=None,
collection_name="documents",
initialize=False,
recreate=False,
):
"""
Start watching a directory for file changes.
Args:
llm: The language model client
watch_directory: Directory to watch for changes
persist_directory: Directory to persist ChromaDB and hash state
collection_name: Name of the ChromaDB collection
initialize: Whether to forcibly initialize the collection with all documents
recreate: Whether to recreate the collection (will delete existing)
"""
loop = asyncio.get_event_loop()
file_watcher = ChromaDBFileWatcher(
llm,
watch_directory,
loop=loop,
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate,
)
# Process all files if:
# 1. initialize=True was passed (explicit request to initialize)
# 2. This is a new collection (doesn't exist yet)
# 3. There's no hash state (first run)
if initialize or file_watcher.is_new_collection or not file_watcher.file_hashes:
logging.info("Initializing collection with all documents")
asyncio.run_coroutine_threadsafe(file_watcher.initialize_collection(), loop)
else:
# Only process new/changed files
logging.info("Scanning for new/changed documents")
asyncio.run_coroutine_threadsafe(file_watcher.scan_directory(), loop)
# Start observer
observer = Observer()
observer.schedule(file_watcher, watch_directory, recursive=True)
observer.start()
logging.info(f"Started watching directory: {watch_directory}")
return observer, file_watcher
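# Illustrative only: typical wiring from an application whose asyncio event
# loop is already running (so the scheduled coroutines actually execute);
# paths and variable names are hypothetical:
#
#   llm = ollama.Client(host=defines.ollama_api_url)
#   observer, watcher = start_file_watcher(llm, "/srv/docs", initialize=True)
#   ...
#   observer.stop()
#   observer.join()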
if __name__ == "__main__":
# When running directly, use absolute imports
import defines
# Initialize Ollama client
llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
# Start the file watcher (with initialization)
observer, file_watcher = start_file_watcher(
llm,
defines.doc_dir,
recreate=True, # Start fresh
)
# Example query
query = "Can you describe James Ketrenos' work history?"
top_docs = file_watcher.find_similar(query, top_k=3)
logging.info(top_docs)
try:
# Keep the main thread running
while True:
time.sleep(1)
except KeyboardInterrupt:
observer.stop()
observer.join()