"""Chunk documents, embed them with Ollama, and store/query them in ChromaDB."""
import chromadb
import ollama

from typing import List, Dict, Any, Union

from . import defines
from .chunk import chunk_document


def init_chroma_client(persist_directory: str = defines.persist_directory):
    """Initialize and return a ChromaDB client.

    Args:
        persist_directory: Intended on-disk storage location.
            NOTE(review): currently unused — an ephemeral in-memory client is
            returned, so nothing is persisted across runs. Switch to
            ``chromadb.PersistentClient(path=persist_directory)`` to enable
            persistence.

    Returns:
        An ephemeral ``chromadb.Client`` instance.
    """
    return chromadb.Client()
|
def create_or_get_collection(db: chromadb.Client, collection_name: str):
    """Return the named ChromaDB collection, creating it if necessary.

    Newly created collections are configured to use cosine distance for
    HNSW similarity search.

    Args:
        db: An initialized ChromaDB client.
        collection_name: Name of the collection to fetch or create.

    Returns:
        The existing or newly created collection.
    """
    try:
        return db.get_collection(name=collection_name)
    except Exception:
        # get_collection raises when the name does not exist yet; create the
        # collection with cosine similarity instead of the default L2 metric.
        # (Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.)
        return db.create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
|
def process_documents_to_chroma(
    client: ollama.Client,
    documents: List[Dict[str, Any]],
    collection_name: str = "document_collection",
    text_key: str = "text",
    max_tokens: int = 512,
    overlap: int = 50,
    model: str = defines.encoding_model,
    persist_directory: str = defines.persist_directory
):
    """
    Process documents, chunk them, compute embeddings, and store in ChromaDB.

    Args:
        client: Ollama client used to compute chunk embeddings
        documents: List of document dictionaries
        collection_name: Name for the ChromaDB collection
        text_key: The key containing text content
        max_tokens: Maximum tokens per chunk
        overlap: Token overlap between chunks
        model: Ollama model for embeddings
        persist_directory: Directory to store ChromaDB data

    Returns:
        The ChromaDB collection the chunks were added to.
    """
    # Initialize ChromaDB client and collection
    db = init_chroma_client(persist_directory)
    collection = create_or_get_collection(db, collection_name)

    # Process each document
    for doc in documents:
        # Chunk the document
        doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)
        if not doc_chunks:
            # collection.add() rejects empty batches, so skip documents
            # that produced no chunks.
            continue

        # Unique ID per chunk: "<document id>_<chunk index>"
        ids = [f"{chunk['id']}_{chunk['chunk_id']}" for chunk in doc_chunks]
        texts = [chunk[text_key] for chunk in doc_chunks]
        # Metadata excludes the text and any embedding to avoid duplication.
        metadatas = [
            {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}
            for chunk in doc_chunks
        ]

        # Embed all chunks of this document in one batched call — the Ollama
        # embed API accepts a list input and returns one embedding per item,
        # in order. This replaces one round-trip per chunk.
        response = client.embed(model=model, input=texts)
        embeddings = response["embeddings"]

        # Add chunks to ChromaDB collection
        collection.add(
            ids=ids,
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    return collection
|
def query_chroma(
    client: ollama.Client,
    query_text: str,
    collection_name: str = "document_collection",
    n_results: int = 5,
    model: str = defines.encoding_model,
    persist_directory: str = defines.persist_directory
):
    """
    Search a ChromaDB collection for entries similar to a query string.

    Args:
        client: Ollama client used to embed the query text
        query_text: The text to search for
        collection_name: Name of the ChromaDB collection
        n_results: Number of results to return
        model: Ollama model for embedding the query
        persist_directory: Directory where ChromaDB data is stored

    Returns:
        Query results from ChromaDB
    """
    # Open the vector store and fetch the target collection.
    chroma_db = init_chroma_client(persist_directory)
    target_collection = create_or_get_collection(chroma_db, collection_name)

    # Embed the query with Ollama, then run a nearest-neighbour search
    # against the collection.
    embed_response = client.embed(model=model, input=query_text)
    return target_collection.query(
        query_embeddings=embed_response["embeddings"],
        n_results=n_results,
    )