# backstory/src/tests/test-embedding.py

# From /opt/backstory run:
# python -m src.tests.test-embedding
import numpy as np
import logging
import argparse
from ollama import Client
from ..utils import defines
# Configure the root logger when this module is loaded: INFO level, with each
# record rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
def get_embedding(text: str, embedding_model: str, ollama_server: str) -> np.ndarray:
    """Generate a unit-length embedding for ``text`` from an Ollama server.

    Args:
        text: Prompt to embed.
        embedding_model: Model name passed to the embeddings request.
        ollama_server: Host used to construct the Ollama ``Client``.

    Returns:
        The embedding as a NumPy float array. If the vector returned by the
        server is not already (approximately) unit length, it is divided by
        its L2 norm before being returned.

    Raises:
        ValueError: If the embedding is empty, contains NaN or infinite
            values, or has zero norm (and therefore cannot be normalized).
        Exception: Any failure while requesting or converting the embedding
            is logged and re-raised unchanged.
    """
    llm = Client(host=ollama_server)

    # Get embedding
    try:
        response = llm.embeddings(model=embedding_model, prompt=text)
        # Force a float array so a malformed (non-numeric) payload fails here,
        # inside the logged try block, rather than later in np.isnan().
        embedding = np.array(response["embedding"], dtype=float)
    except Exception as e:
        logging.error(f"Failed to get embedding: {e}")
        raise

    # Log diagnostics
    logging.info(f"Input text: {text}")
    logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")

    # Check for invalid embeddings
    if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
        logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
        raise ValueError("Invalid embedding returned from Ollama.")

    # Check normalization
    norm = np.linalg.norm(embedding)
    is_normalized = np.allclose(norm, 1.0, atol=1e-3)
    logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")

    # An all-zero vector has no direction; dividing by its norm would silently
    # produce NaNs, so reject it explicitly instead.
    if norm == 0:
        logging.error("Invalid embedding: zero norm, cannot normalize.")
        raise ValueError("Invalid embedding returned from Ollama.")

    # Normalize if needed
    if not is_normalized:
        embedding = embedding / norm
        logging.info("Embedding normalized manually.")

    return embedding
def main() -> None:
    """Parse command-line arguments and generate an embedding for each text.

    Options:
        --text: One or more texts to embed (default: "Test sentence.").
        --ollama-server: Ollama server URL (default: ``defines.ollama_api_url``).
        --embedding-model: Model name (default: ``defines.embedding_model``).

    Each text is passed to ``get_embedding``, which logs the diagnostics for
    the resulting vector.

    Raises:
        ValueError: If any supplied text is empty or contains only whitespace.
    """
    parser = argparse.ArgumentParser(description="Generate embeddings for text using mxbai-embed-large.")
    parser.add_argument(
        "--text",
        type=str,
        nargs="+",  # Allow multiple text inputs
        default=["Test sentence."],
        help="Text(s) to generate embeddings for (default: 'Test sentence.')",
    )
    parser.add_argument(
        "--ollama-server",
        type=str,
        default=defines.ollama_api_url,
        help=f"Ollama server URL (default: {defines.ollama_api_url})",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default=defines.embedding_model,
        help=f"Embedding model name (default: {defines.embedding_model})",
    )
    args = parser.parse_args()

    # Validate every input up front so no request is sent if any text is bad.
    # Whitespace-only text is treated as empty.
    for text in args.text:
        if not isinstance(text, str) or not text.strip():
            logging.error("Input text must be a non-empty string.")
            raise ValueError("Input text must be a non-empty string.")

    # Generate embeddings for each text; get_embedding logs the results, so
    # the returned vectors are not collected here.
    for text in args.text:
        get_embedding(
            text=text,
            embedding_model=args.embedding_model,
            ollama_server=args.ollama_server,
        )
# Run the CLI only when executed as a script/module, not when imported.
if __name__ == "__main__":
    main()