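"""Watch a directory for document changes and keep a ChromaDB collection in
sync: files are hashed, split into chunks, embedded with Ollama, and re-indexed
whenever watchdog reports a create, modify, or delete event."""
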
import os
import glob
import hashlib
import asyncio
import logging
import json

import numpy as np

import chromadb
import ollama
from langchain.text_splitter import CharacterTextSplitter
from langchain.schema import Document
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

# Import your existing modules
if __name__ == "__main__":
    # When running directly, use absolute imports
    import defines
else:
    # When imported as a module, use relative imports
    from . import defines

__all__ = [
    'ChromaDBFileWatcher',
    'start_file_watcher'
]


class ChromaDBFileWatcher(FileSystemEventHandler):
    def __init__(self, llm, watch_directory, loop, persist_directory=None, collection_name="documents",
                 chunk_size=1000, chunk_overlap=200, recreate=False):
        self.llm = llm
        self.watch_directory = watch_directory
        self.persist_directory = persist_directory or defines.persist_directory
        self.collection_name = collection_name
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.loop = loop

        # Path for storing file hash state
        self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")

        # Initialize ChromaDB collection
        self._collection = self._get_vector_collection(recreate=recreate)

        # Setup text splitter
        self.text_splitter = CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

        # Track file hashes and processing state
        self.file_hashes: dict[str, str] = self._load_hash_state()
        self.update_lock = asyncio.Lock()
        self.processing_files = set()

        # Only scan the whole tree if there is no previous hash state;
        # otherwise just pick up files that changed while we were not watching.
        if not self.file_hashes:
            self._initialize_file_hashes()
        else:
            self._update_file_hashes()

    @property
    def collection(self):
        return self._collection

    def _save_hash_state(self):
        """Save the current file hash state to disk."""
        try:
            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(self.hash_state_path), exist_ok=True)

            with open(self.hash_state_path, 'w') as f:
                json.dump(self.file_hashes, f)

            logging.info(f"Saved hash state with {len(self.file_hashes)} entries")
        except Exception as e:
            logging.error(f"Error saving hash state: {e}")

    def _load_hash_state(self):
        """Load the file hash state from disk."""
        if os.path.exists(self.hash_state_path):
            try:
                with open(self.hash_state_path, 'r') as f:
                    hash_state = json.load(f)
                logging.info(f"Loaded hash state with {len(hash_state)} entries")
                return hash_state
            except Exception as e:
                logging.error(f"Error loading hash state: {e}")

        return {}

    def _update_file_hashes(self):
        """Update file hashes by checking for new or modified files."""
        # Check for new or modified files
        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
        files_checked = 0
        files_changed = 0

        for file_path in file_paths:
            if os.path.isfile(file_path):
                files_checked += 1
                current_hash = self._get_file_hash(file_path)
                if not current_hash:
                    continue

                # If file is new or changed
                if file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
                    self.file_hashes[file_path] = current_hash
                    files_changed += 1
                    # Schedule an update for this file
                    asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
                    logging.info(f"File changed: {file_path}")

        # Check for deleted files
        deleted_files = []
        for file_path in self.file_hashes:
            if not os.path.exists(file_path):
                deleted_files.append(file_path)
                # Schedule removal
                asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
                logging.info(f"File deleted: {file_path}")

        # Remove deleted files from hash state
        for file_path in deleted_files:
            del self.file_hashes[file_path]

        logging.info(f"Checked {files_checked} files: {files_changed} new/changed, {len(deleted_files)} deleted")

        # Save the updated state
        self._save_hash_state()

    async def process_file_update(self, file_path):
        """Process a file update event."""
        # Skip if already being processed
        if file_path in self.processing_files:
            return

        try:
            self.processing_files.add(file_path)

            # Wait a moment to ensure the file write is complete
            await asyncio.sleep(0.5)

            # Check if content changed via hash
            current_hash = self._get_file_hash(file_path)
            if not current_hash:  # File might have been deleted or is inaccessible
                return

            if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
                # File hasn't actually changed in content
                return

            # Update file hash
            self.file_hashes[file_path] = current_hash

            # Process and update the file in ChromaDB
            async with self.update_lock:
                await self._update_document_in_collection(file_path)

            # Save the hash state after successful update
            self._save_hash_state()

        except Exception as e:
            logging.error(f"Error processing update for {file_path}: {e}")
        finally:
            self.processing_files.discard(file_path)

    async def remove_file_from_collection(self, file_path):
        """Remove all chunks related to a deleted file."""
        async with self.update_lock:
            try:
                # Find all documents with the specified path
                results = self.collection.get(
                    where={"path": file_path}
                )

                if results and 'ids' in results and results['ids']:
                    self.collection.delete(ids=results['ids'])
                    logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")

                # Remove from hash dictionary
                if file_path in self.file_hashes:
                    del self.file_hashes[file_path]
                    # Save the updated hash state
                    self._save_hash_state()

            except Exception as e:
                logging.error(f"Error removing file from collection: {e}")

    def _get_vector_collection(self, recreate=False):
        """Get or create a ChromaDB collection."""
        # Initialize ChromaDB client
        chroma_client = chromadb.PersistentClient(
            path=self.persist_directory,
            settings=chromadb.Settings(anonymized_telemetry=False)
        )

        # Check if the collection exists and delete it if recreate is True
        if recreate and os.path.exists(self.persist_directory):
            try:
                chroma_client.delete_collection(name=self.collection_name)
            except Exception as e:
                logging.error(f"Failed to delete existing collection: {e}")

        return chroma_client.get_or_create_collection(
            name=self.collection_name,
            metadata={
                "hnsw:space": "cosine"
            })

    def load_text_files(self, directory=None, encoding="utf-8"):
        """Load all text files from a directory into Document objects."""
        directory = directory or self.watch_directory
        file_paths = glob.glob(os.path.join(directory, "**/*"), recursive=True)
        documents = []

        for file_path in file_paths:
            if os.path.isfile(file_path):  # Ensure it's a file, not a directory
                try:
                    with open(file_path, "r", encoding=encoding) as f:
                        content = f.read()

                    # Extract top-level directory
                    rel_path = os.path.relpath(file_path, directory)
                    top_level_dir = rel_path.split(os.sep)[0]

                    documents.append(Document(
                        page_content=content,
                        metadata={"doc_type": top_level_dir, "path": file_path}
                    ))
                except Exception as e:
                    logging.error(f"Failed to load {file_path}: {e}")

        return documents

    def create_chunks_from_documents(self, docs):
        """Split documents into chunks using the text splitter."""
        return self.text_splitter.split_documents(docs)

    def get_embedding(self, text):
        """Generate embeddings using Ollama."""
        response = self.llm.embeddings(
            model=defines.model,
            prompt=text,
            options={"num_ctx": defines.max_context}
        )
        return self._normalize_embeddings(response["embedding"])

    def add_embeddings_to_collection(self, chunks):
        """Add embeddings for chunks to the collection."""
        for i, chunk in enumerate(chunks):
            text = chunk.page_content
            metadata = chunk.metadata

            # Generate a more unique ID based on content and metadata
            content_hash = hashlib.md5(text.encode()).hexdigest()
            path_hash = ""
            if "path" in metadata:
                path_hash = hashlib.md5(metadata["path"].encode()).hexdigest()[:8]

            chunk_id = f"{path_hash}_{content_hash}_{i}"

            embedding = self.get_embedding(text)
            self.collection.add(
                ids=[chunk_id],
                documents=[text],
                embeddings=[embedding],
                metadatas=[metadata]
            )

    def find_similar(self, query, top_k=3):
        """Find similar documents to the query."""
        query_embedding = self.get_embedding(query)
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"]
        )
        return {
            "query_embedding": query_embedding,
            "ids": results["ids"][0],
            "documents": results["documents"][0],
            "distances": results["distances"][0],
            "metadatas": results["metadatas"][0],
        }

    def _initialize_file_hashes(self):
        """Initialize the hash dictionary for all files in the directory."""
        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
        for file_path in file_paths:
            if os.path.isfile(file_path):
                file_hash = self._get_file_hash(file_path)
                if file_hash:
                    self.file_hashes[file_path] = file_hash

    def _get_file_hash(self, file_path):
        """Calculate MD5 hash of a file."""
        try:
            with open(file_path, 'rb') as f:
                return hashlib.md5(f.read()).hexdigest()
        except Exception as e:
            logging.error(f"Error hashing file {file_path}: {e}")
            return None

    def on_modified(self, event):
        """Handle file modification events."""
        if event.is_directory:
            return

        file_path = event.src_path
        # Schedule the update using asyncio
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File modified: {file_path}")

    def on_created(self, event):
        """Handle file creation events."""
        if event.is_directory:
            return

        file_path = event.src_path
        # Schedule the update using asyncio
        asyncio.run_coroutine_threadsafe(self.process_file_update(file_path), self.loop)
        logging.info(f"File created: {file_path}")

    def on_deleted(self, event):
        """Handle file deletion events."""
        if event.is_directory:
            return

        file_path = event.src_path
        asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
        logging.info(f"File deleted: {file_path}")

    def _normalize_embeddings(self, embeddings):
        """L2-normalize a single embedding vector (Ollama returns a 1-D list)."""
        embeddings = np.asarray(embeddings)
        norm = np.linalg.norm(embeddings)
        return (embeddings / norm).tolist() if norm else embeddings.tolist()

    async def _update_document_in_collection(self, file_path):
        """Update a document in the ChromaDB collection."""
        try:
            # Remove existing entries for this file
            existing_results = self.collection.get(where={"path": file_path})
            if existing_results and 'ids' in existing_results and existing_results['ids']:
                self.collection.delete(ids=existing_results['ids'])

            # Create document object in LangChain format
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()

            # Extract top-level directory
            rel_path = os.path.relpath(file_path, self.watch_directory)
            top_level_dir = rel_path.split(os.sep)[0]

            document = Document(
                page_content=content,
                metadata={"doc_type": top_level_dir, "path": file_path}
            )

            # Create chunks
            chunks = self.text_splitter.split_documents([document])

            # Add chunks to collection
            self.add_embeddings_to_collection(chunks)

            logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")

        except Exception as e:
            logging.error(f"Error updating document in collection: {e}")

    def initialize_collection(self):
        """Initialize the collection with all documents from the watch directory."""
        documents = self.load_text_files()
        logging.info(f"Documents loaded: {len(documents)}")

        chunks = self.create_chunks_from_documents(documents)
        self.add_embeddings_to_collection(chunks)

        logging.info(f"Vectorstore created with {self.collection.count()} documents")

        # Display document types
        doc_types = set(chunk.metadata['doc_type'] for chunk in chunks)
        logging.info(f"Document types: {doc_types}")

        return len(chunks)


# Function to start the file watcher
def start_file_watcher(llm, watch_directory, persist_directory=None,
                       collection_name="documents", initialize=False, recreate=False):
    """
    Start watching a directory for file changes.

    Args:
        llm: The language model client
        watch_directory: Directory to watch for changes
        persist_directory: Directory to persist ChromaDB and hash state
        collection_name: Name of the ChromaDB collection
        initialize: Whether to initialize the collection with all documents (only needed first time)
        recreate: Whether to recreate the collection (will delete existing)
    """
    loop = asyncio.get_event_loop()

    file_watcher = ChromaDBFileWatcher(
        llm,
        watch_directory,
        loop=loop,
        persist_directory=persist_directory,
        collection_name=collection_name,
        recreate=recreate
    )

    # Initialize collection if requested and no existing hash state
    if initialize and not file_watcher.file_hashes:
        file_watcher.initialize_collection()

    # Start observer
    observer = Observer()
    observer.schedule(file_watcher, watch_directory, recursive=True)
    observer.start()

    logging.info(f"Started watching directory: {watch_directory}")
    return observer, file_watcher


if __name__ == "__main__":
    # When running directly, use absolute imports
    import defines

    # Make the module's logging.info output visible when run as a script
    logging.basicConfig(level=logging.INFO)

    # Create and register an event loop; start_file_watcher() picks it up via
    # asyncio.get_event_loop() and the watcher schedules its async work on it.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Initialize Ollama client
    llm = ollama.Client(host=defines.ollama_api_url)

    # Start the file watcher (with initialization)
    observer, file_watcher = start_file_watcher(
        llm,
        defines.doc_dir,
        recreate=True,    # Start fresh
        initialize=True   # Load all documents initially
    )

    # Example query
    query = "Can you describe James Ketrenos' work history?"
    top_docs = file_watcher.find_similar(query, top_k=3)
    logging.info(top_docs)

    try:
        # Run the event loop so file updates scheduled by the watcher are processed
        loop.run_forever()
    except KeyboardInterrupt:
        observer.stop()
    observer.join()