82 lines
2.7 KiB
Python
82 lines
2.7 KiB
Python
# From /opt/backstory run:
#   python -m src.tests.test-rag
from ..utils import logger
|
|
from pydantic import BaseModel, field_validator # type: ignore
|
|
from prometheus_client import CollectorRegistry # type: ignore
|
|
from typing import List, Dict, Any, Optional
|
|
import ollama
|
|
import numpy as np # type: ignore
|
|
from ..utils import (rag as Rag, ChromaDBGetResponse)
|
|
from ..utils import Context
|
|
from ..utils import defines
|
|
|
|
import json
|
|
|
|
# Synthetic data with the same keys the live find_similar() result is read
# with further down (ids, embeddings, documents, metadatas, query_embedding),
# so ChromaDBGetResponse serialization can be exercised without ChromaDB/Ollama.
chroma_results = {
    "ids": ["1", "2"],
    "embeddings": np.array([[1.0, 2.0], [3.0, 4.0]]),
    "documents": ["doc1", "doc2"],
    "metadatas": [{"meta": "data1"}, {"meta": "data2"}],
    "query_embedding": np.array([0.1, 0.2, 0.3]),
}

# Flatten to a 1-D vector; flatten() returns a copy, so later edits to
# query_embedding never touch the fixture dict.
query_embedding = np.asarray(chroma_results["query_embedding"]).flatten()
umap_2d = np.array([0.4, 0.5])  # Example UMAP output
umap_3d = np.array([0.6, 0.7, 0.8])  # Example UMAP output
# Build metadata from the synthetic fixture and log its JSON form, both before
# and after a NumPy array is assigned to umap_embedding_2d.
rag_metadata = ChromaDBGetResponse(
    query="test",
    query_embedding=query_embedding,
    name="JPK",
    ids=chroma_results.get("ids", []),
    size=2,
)

logger.info(json.dumps(rag_metadata.model_dump(mode="json")))

logger.info(f"Assigning type {type(umap_2d)} to rag_metadata.umap_embedding_2d")
# NOTE(review): in pydantic v2, plain attribute assignment only runs field
# validators when the model sets validate_assignment=True — confirm in
# ChromaDBGetResponse whether the ndarray is converted here or stored as-is.
rag_metadata.umap_embedding_2d = umap_2d

logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
# Round-trip a NumPy embeddings array through JSON text and verify that
# ChromaDBGetResponse can both serialize and re-parse it.
rag = ChromaDBGetResponse()
expected_embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])
rag.embeddings = expected_embeddings

# model_dump(mode="json") returns a dict; model_dump_json() returns the actual
# JSON string, so the round trip genuinely goes through serialized text.
json_str = rag.model_dump_json()
logger.info(json_str)

rag = ChromaDBGetResponse.model_validate_json(json_str)

# Without an explicit check the round trip could silently lose or alter data.
assert np.allclose(np.asarray(rag.embeddings), expected_embeddings), (
    f"embeddings did not survive JSON round trip: {rag.embeddings!r}"
)
# Live test: start the document file watcher, run one similarity query, and
# log the resulting metadata and context.
llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore
prometheus_collector = CollectorRegistry()

observer, file_watcher = Rag.start_file_watcher(
    llm=llm,
    watch_directory=defines.doc_dir,
    recreate=False,  # Don't recreate if exists
)

try:
    context = Context(
        file_watcher=file_watcher,
        prometheus_collector=prometheus_collector,
    )

    skill = "Codes in C++"
    if context.file_watcher:
        chroma_results = context.file_watcher.find_similar(query=skill, top_k=10, threshold=0.5)
        if chroma_results:
            query_embedding = np.array(chroma_results["query_embedding"]).flatten()

            # Project the query embedding with the watcher's 2-D and 3-D UMAP models.
            umap_2d = context.file_watcher.umap_model_2d.transform([query_embedding])[0]
            umap_3d = context.file_watcher.umap_model_3d.transform([query_embedding])[0]

            rag_metadata = ChromaDBGetResponse(
                query=skill,
                query_embedding=query_embedding,
                name="JPK",
                ids=chroma_results.get("ids", []),
                embeddings=chroma_results.get("embeddings", []),
                documents=chroma_results.get("documents", []),
                metadatas=chroma_results.get("metadatas", []),
                umap_embedding_2d=umap_2d,
                umap_embedding_3d=umap_3d,
                size=context.file_watcher.collection.count(),
            )
            # Log the live result the same way as the synthetic case, so its
            # JSON serialization is actually exercised.
            logger.info(json.dumps(rag_metadata.model_dump(mode="json")))
        else:
            logger.info(f"No similar documents found for {skill!r}")

    json_str = context.model_dump(mode="json")
    logger.info(json_str)
finally:
    # The observer returned by start_file_watcher was otherwise never shut down.
    # NOTE(review): assumes it is a watchdog-style Observer exposing
    # stop()/join() — confirm in utils.rag.start_file_watcher.
    observer.stop()
    observer.join()