# backstory/src/tests/test-rag.py
#
# 82 lines
# 2.7 KiB
# Python

# From /opt/backstory run:
# python -m src.tests.test-rag
from ..utils import logger
from pydantic import BaseModel, field_validator # type: ignore
from prometheus_client import CollectorRegistry # type: ignore
from typing import List, Dict, Any, Optional
import ollama
import numpy as np # type: ignore
from ..utils import (rag as Rag, ChromaDBGetResponse)
from ..utils import Context
from ..utils import defines
import json
# --- Offline checks: exercise ChromaDBGetResponse with hand-built data -------
# Literal stand-in carrying the same keys the live query below reads from
# find_similar(): ids, embeddings, documents, metadatas, query_embedding.
chroma_results = dict(
    ids=["1", "2"],
    embeddings=np.array([[1.0, 2.0], [3.0, 4.0]]),
    documents=["doc1", "doc2"],
    metadatas=[{"meta": "data1"}, {"meta": "data2"}],
    query_embedding=np.array([0.1, 0.2, 0.3]),
)

# Fixed arrays used in place of real UMAP output.
umap_2d = np.array([0.4, 0.5])
umap_3d = np.array([0.6, 0.7, 0.8])

# Pull the query vector out of the fixture as a flat 1-D array.
query_embedding = np.array(chroma_results["query_embedding"]).flatten()

# Build the response with only a subset of fields; the remaining fields are
# left to whatever defaults the model declares.
rag_metadata = ChromaDBGetResponse(
    query="test",
    query_embedding=query_embedding,
    name="JPK",
    ids=chroma_results.get("ids", []),
    size=2,
)

# Dump once before and once after assigning the 2-D projection, so the two log
# lines can be compared. Both dumps must survive json.dumps(), i.e. the numpy
# fields have to come out of model_dump(mode="json") as JSON-compatible values.
logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
logger.info(f"Assigning type {type(umap_2d)} to rag_metadata.umap_embedding_2d")
rag_metadata.umap_embedding_2d = umap_2d
logger.info(json.dumps(rag_metadata.model_dump(mode="json")))

# Round trip: start from an all-defaults instance, attach a 2x2 embeddings
# array, dump it, and feed the dump straight back into model_validate().
rag = ChromaDBGetResponse()
rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])

# NOTE: despite the name, json_str is the raw model_dump(mode="json") result;
# unlike the dumps logged above it is never passed through json.dumps().
json_str = rag.model_dump(mode="json")
logger.info(json_str)

# The rebuilt instance is not inspected afterwards; the check passes as long
# as validation does not raise.
rag = ChromaDBGetResponse.model_validate(json_str)
# --- Live setup: real LLM client, metrics registry and document watcher ------
# A private registry for this run, handed to the Context below.
prometheus_collector = CollectorRegistry()

# Client for the Ollama endpoint configured in defines.ollama_api_url.
llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore

# start_file_watcher() hands back a pair; only file_watcher is used below.
# NOTE(review): observer is never referenced again in this script, so it is
# not explicitly stopped here - confirm that is intended.
observer, file_watcher = Rag.start_file_watcher(
    llm=llm,
    watch_directory=defines.doc_dir,
    recreate=False,  # keep an existing store instead of rebuilding it
)

# Bundle the watcher and the registry into the Context queried further down.
context = Context(
    file_watcher=file_watcher,
    prometheus_collector=prometheus_collector,
)
# --- Live query: similarity search for one skill phrase ----------------------
skill = "Codes in C++"

# Run the search only when the context carries a file watcher, and build the
# response only when the search returned something truthy. The short-circuit
# `and` keeps chroma_results untouched when there is no watcher.
if context.file_watcher and (
    chroma_results := context.file_watcher.find_similar(
        query=skill, top_k=10, threshold=0.5
    )
):
    # Flatten the returned query vector to 1-D before projecting it.
    query_embedding = np.array(chroma_results["query_embedding"]).flatten()

    # Each UMAP model is given a one-element list and the first row of the
    # result is kept as the projection of the query.
    umap_2d = context.file_watcher.umap_model_2d.transform([query_embedding])[0]
    umap_3d = context.file_watcher.umap_model_3d.transform([query_embedding])[0]

    # Same model as the offline check, this time with every field populated
    # from the live result; size is the current item count of the collection.
    # NOTE(review): rag_metadata is built here but not logged or read again in
    # the visible code - constructing it without an exception is the check.
    rag_metadata = ChromaDBGetResponse(
        query=skill,
        query_embedding=query_embedding,
        name="JPK",
        ids=chroma_results.get("ids", []),
        embeddings=chroma_results.get("embeddings", []),
        documents=chroma_results.get("documents", []),
        metadatas=chroma_results.get("metadatas", []),
        umap_embedding_2d=umap_2d,
        umap_embedding_3d=umap_3d,
        size=context.file_watcher.collection.count(),
    )

# Finally make sure the whole Context serializes in JSON mode and log the
# result (json_str is the model_dump() output, not a json.dumps() string).
json_str = context.model_dump(mode="json")
logger.info(json_str)