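"""Keep a ChromaDB vector collection in sync with the files in a watched directory.

Built on watchdog for filesystem events, LangChain's CharacterTextSplitter for
chunking, and Ollama for embeddings. Can be run directly as a script or imported
and started via start_file_watcher().
"""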

import os
import glob
import time
import hashlib
import asyncio
import logging
import json
import pickle
import numpy as np
import chromadb
import ollama
from langchain.text_splitter import CharacterTextSplitter
from langchain.schema import Document
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
# Import your existing modules
if __name__ == "__main__":
# When running directly, use absolute imports
import defines
else:
# When imported as a module, use relative imports
from . import defines
__all__ = [
'ChromaDBFileWatcher',
'start_file_watcher'
]
class ChromaDBFileWatcher(FileSystemEventHandler):
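    """Watch a directory and keep a ChromaDB collection in sync with its files.

    File contents are hashed (MD5) to detect real changes; new or modified files
    are split into chunks, embedded via Ollama, and upserted into the collection,
    while deleted files have their chunks removed. The hash state is persisted to
    disk so restarts only reprocess files that changed while the watcher was offline.
    """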
    def __init__(self, llm, watch_directory, loop, persist_directory=None, collection_name="documents",
                 chunk_size=1000, chunk_overlap=200, recreate=False):
        self.llm = llm
        self.watch_directory = watch_directory
        self.loop = loop
        self.persist_directory = persist_directory or defines.persist_directory
        self.collection_name = collection_name
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # Path for storing file hash state
        self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
        # Initialize ChromaDB collection
        self._collection = self._get_vector_collection(recreate=recreate)
        # Setup text splitter
        self.text_splitter = CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        # Track file hashes and processing state
        self.file_hashes: dict[str, str] = self._load_hash_state()
        self.update_lock = asyncio.Lock()
        self.processing_files = set()
        # Remember whether a previous hash state existed so start_file_watcher()
        # can tell a first run from a restart
        self.is_first_run = not self.file_hashes
        # On a first run, record hashes for all files; otherwise only scan for
        # files that changed while the watcher was not running
        if self.is_first_run:
            self._initialize_file_hashes()
        else:
            self._update_file_hashes()
    @property
    def collection(self):
        return self._collection
def _save_hash_state(self):
"""Save the current file hash state to disk."""
try:
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)
with open(self.hash_state_path, 'w') as f:
json.dump(self.file_hashes, f)
logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
except Exception as e:
logging.error(f"Error saving hash state: {e}")
def _load_hash_state(self):
"""Load the file hash state from disk."""
if os.path.exists(self.hash_state_path):
try:
with open(self.hash_state_path, 'r') as f:
hash_state = json.load(f)
logging.info(f"Loaded hash state with {len(hash_state)} entries")
return hash_state
except Exception as e:
logging.error(f"Error loading hash state: {e}")
return {}
def _update_file_hashes(self):
"""Update file hashes by checking for new or modified files."""
# Check for new or modified files
file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
files_checked = 0
files_changed = 0
for file_path in file_paths:
if os.path.isfile(file_path):
files_checked += 1
current_hash = self._get_file_hash(file_path)
if not current_hash:
continue
# If file is new or changed
if file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
self.file_hashes[file_path] = current_hash
files_changed += 1
# Schedule an update for this file
asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
logging.info(f"File changed: {file_path}")
# Check for deleted files
deleted_files = []
for file_path in self.file_hashes:
if not os.path.exists(file_path):
deleted_files.append(file_path)
# Schedule removal
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
logging.info(f"File deleted: {file_path}")
# Remove deleted files from hash state
for file_path in deleted_files:
del self.file_hashes[file_path]
logging.info(f"Checked {files_checked} files: {files_changed} new/changed, {len(deleted_files)} deleted")
# Save the updated state
self._save_hash_state()
async def process_file_update(self, file_path):
"""Process a file update event."""
# Skip if already being processed
if file_path in self.processing_files:
return
try:
self.processing_files.add(file_path)
# Wait a moment to ensure the file write is complete
await asyncio.sleep(0.5)
# Check if content changed via hash
current_hash = self._get_file_hash(file_path)
if not current_hash: # File might have been deleted or is inaccessible
return
if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
# File hasn't actually changed in content
return
# Update file hash
self.file_hashes[file_path] = current_hash
# Process and update the file in ChromaDB
async with self.update_lock:
await self._update_document_in_collection(file_path)
# Save the hash state after successful update
self._save_hash_state()
except Exception as e:
logging.error(f"Error processing update for {file_path}: {e}")
finally:
self.processing_files.discard(file_path)
async def remove_file_from_collection(self, file_path):
"""Remove all chunks related to a deleted file."""
async with self.update_lock:
try:
# Find all documents with the specified path
results = self.collection.get(
where={"path": file_path}
)
if results and 'ids' in results and results['ids']:
self.collection.delete(ids=results['ids'])
logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
# Remove from hash dictionary
if file_path in self.file_hashes:
del self.file_hashes[file_path]
# Save the updated hash state
self._save_hash_state()
except Exception as e:
logging.error(f"Error removing file from collection: {e}")
def _get_vector_collection(self, recreate=False):
"""Get or create a ChromaDB collection."""
# Initialize ChromaDB client
chroma_client = chromadb.PersistentClient(
path=self.persist_directory,
settings=chromadb.Settings(anonymized_telemetry=False)
)
# Check if the collection exists and delete it if recreate is True
if recreate and os.path.exists(self.persist_directory):
try:
chroma_client.delete_collection(name=self.collection_name)
except Exception as e:
logging.error(f"Failed to delete existing collection: {e}")
return chroma_client.get_or_create_collection(
name=self.collection_name,
metadata={
"hnsw:space": "cosine"
})
def load_text_files(self, directory=None, encoding="utf-8"):
"""Load all text files from a directory into Document objects."""
directory = directory or self.watch_directory
file_paths = glob.glob(os.path.join(directory, "**/*"), recursive=True)
documents = []
for file_path in file_paths:
if os.path.isfile(file_path): # Ensure it's a file, not a directory
try:
with open(file_path, "r", encoding=encoding) as f:
content = f.read()
# Extract top-level directory
rel_path = os.path.relpath(file_path, directory)
top_level_dir = rel_path.split(os.sep)[0]
documents.append(Document(
page_content=content,
metadata={"doc_type": top_level_dir, "path": file_path}
))
except Exception as e:
logging.error(f"Failed to load {file_path}: {e}")
return documents
def create_chunks_from_documents(self, docs):
"""Split documents into chunks using the text splitter."""
return self.text_splitter.split_documents(docs)
def get_embedding(self, text):
"""Generate embeddings using Ollama."""
response = self.llm.embeddings(
model=defines.model,
prompt=text,
options={"num_ctx": defines.max_context}
)
return self._normalize_embeddings(response["embedding"])
def add_embeddings_to_collection(self, chunks):
"""Add embeddings for chunks to the collection."""
for i, chunk in enumerate(chunks):
text = chunk.page_content
metadata = chunk.metadata
# Generate a more unique ID based on content and metadata
content_hash = hashlib.md5(text.encode()).hexdigest()
path_hash = ""
if "path" in metadata:
path_hash = hashlib.md5(metadata["path"].encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{content_hash}_{i}"
embedding = self.get_embedding(text)
self.collection.add(
ids=[chunk_id],
documents=[text],
embeddings=[embedding],
metadatas=[metadata]
)
def find_similar(self, query, top_k=3):
"""Find similar documents to the query."""
query_embedding = self.get_embedding(query)
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
include=["documents", "metadatas", "distances"]
)
return {
"query_embedding": query_embedding,
"ids": results["ids"][0],
"documents": results["documents"][0],
"distances": results["distances"][0],
"metadatas": results["metadatas"][0],
}
    def _initialize_file_hashes(self):
        """Initialize the hash dictionary for all files in the directory."""
        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
        for file_path in file_paths:
            if os.path.isfile(file_path):
                file_hash = self._get_file_hash(file_path)
                if file_hash:
                    self.file_hashes[file_path] = file_hash
        # Persist the initial state so the next start is treated as a restart
        self._save_hash_state()
def _get_file_hash(self, file_path):
"""Calculate MD5 hash of a file."""
try:
with open(file_path, 'rb') as f:
return hashlib.md5(f.read()).hexdigest()
except Exception as e:
logging.error(f"Error hashing file {file_path}: {e}")
return None
def on_modified(self, event):
"""Handle file modification events."""
if event.is_directory:
return
file_path = event.src_path
# Schedule the update using asyncio
asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
logging.info(f"File modified: {file_path}")
def on_created(self, event):
"""Handle file creation events."""
if event.is_directory:
return
file_path = event.src_path
# Schedule the update using asyncio
asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
logging.info(f"File created: {file_path}")
def on_deleted(self, event):
"""Handle file deletion events."""
if event.is_directory:
return
file_path = event.src_path
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
logging.info(f"File deleted: {file_path}")
    def _normalize_embeddings(self, embeddings):
        """L2-normalize a single embedding vector or a batch of vectors."""
        embeddings = np.asarray(embeddings)
        if embeddings.ndim == 1:  # single vector, as returned by get_embedding()
            return embeddings / np.linalg.norm(embeddings)
        return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
async def _update_document_in_collection(self, file_path):
"""Update a document in the ChromaDB collection."""
try:
# Remove existing entries for this file
existing_results = self.collection.get(where={"path": file_path})
if existing_results and 'ids' in existing_results and existing_results['ids']:
self.collection.delete(ids=existing_results['ids'])
# Create document object in LangChain format
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Extract top-level directory
rel_path = os.path.relpath(file_path, self.watch_directory)
top_level_dir = rel_path.split(os.sep)[0]
document = Document(
page_content=content,
metadata={"doc_type": top_level_dir, "path": file_path}
)
# Create chunks
chunks = self.text_splitter.split_documents([document])
# Add chunks to collection
self.add_embeddings_to_collection(chunks)
logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
except Exception as e:
logging.error(f"Error updating document in collection: {e}")
def initialize_collection(self):
"""Initialize the collection with all documents from the watch directory."""
documents = self.load_text_files()
logging.info(f"Documents loaded: {len(documents)}")
chunks = self.create_chunks_from_documents(documents)
self.add_embeddings_to_collection(chunks)
logging.info(f"Vectorstore created with {self.collection.count()} documents")
# Display document types
doc_types = set(chunk.metadata['doc_type'] for chunk in chunks)
logging.info(f"Document types: {doc_types}")
return len(chunks)
# Function to start the file watcher
def start_file_watcher(llm, watch_directory, persist_directory=None,
collection_name="documents", initialize=False, recreate=False):
"""
Start watching a directory for file changes.
Args:
llm: The language model client
watch_directory: Directory to watch for changes
persist_directory: Directory to persist ChromaDB and hash state
collection_name: Name of the ChromaDB collection
initialize: Whether to initialize the collection with all documents (only needed first time)
recreate: Whether to recreate the collection (will delete existing)
"""
loop = asyncio.get_event_loop()
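    # Coroutines scheduled onto this loop with run_coroutine_threadsafe only
    # execute once the loop is running (e.g. loop.run_forever() in the caller)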
file_watcher = ChromaDBFileWatcher(
llm,
watch_directory,
loop=loop,
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate
)
    # Initialize collection if requested and there was no saved hash state
    if initialize and file_watcher.is_first_run:
        file_watcher.initialize_collection()
# Start observer
observer = Observer()
observer.schedule(file_watcher, watch_directory, recursive=True)
observer.start()
logging.info(f"Started watching directory: {watch_directory}")
return observer, file_watcher
if __name__ == "__main__":
# When running directly, use absolute imports
import defines
    # Configure logging so progress information is visible
    logging.basicConfig(level=logging.INFO)
    # Initialize Ollama client
    llm = ollama.Client(host=defines.ollama_api_url)
# Start the file watcher (with initialization)
observer, file_watcher = start_file_watcher(
llm,
defines.doc_dir,
recreate=True, # Start fresh
initialize=True # Load all documents initially
)
# Example query
query = "Can you describe James Ketrenos' work history?"
top_docs = file_watcher.find_similar(query, top_k=3)
logging.info(top_docs)
    try:
        # Run the event loop so coroutines scheduled by the watcher
        # (file updates and deletions) are actually processed
        loop = asyncio.get_event_loop()
        loop.run_forever()
    except KeyboardInterrupt:
        observer.stop()
    observer.join()