# From /opt/backstory run:
# python -m src.tests.test-embedding
"""Smoke test: fetch embeddings from an Ollama server and check they are usable.

Each input text is embedded with the configured model, validated (non-empty,
finite, non-zero) and L2-normalized if the server did not already return a
unit-norm vector.
"""

import argparse
import logging

import numpy as np
from ollama import Client

from ..utils import defines

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

logger = logging.getLogger(__name__)

# Absolute tolerance used to decide whether the returned vector already has
# unit L2 norm (matches the original np.allclose(..., atol=1e-3) check).
_NORM_TOLERANCE = 1e-3


def get_embedding(
    text: str,
    embedding_model: str,
    ollama_server: str,
    *,
    client: "Client | None" = None,
) -> np.ndarray:
    """Generate and L2-normalize an embedding for ``text``.

    Args:
        text: Text to embed.
        embedding_model: Name of the Ollama embedding model to use.
        ollama_server: Ollama server URL; only used to build a client when
            ``client`` is not supplied.
        client: Optional pre-built ``ollama.Client`` to reuse across calls
            instead of constructing a new one per text.

    Returns:
        The embedding as a float64 array, scaled to unit L2 norm if the
        server's vector was not already within ``_NORM_TOLERANCE`` of 1.

    Raises:
        ValueError: If the embedding is empty, contains NaN/inf values, or is
            the zero vector (which cannot be normalized).
        Exception: Any error from the Ollama request or from reading the
            ``"embedding"`` field is logged with its traceback and re-raised.
    """
    llm = client if client is not None else Client(host=ollama_server)

    # Get embedding. dtype=float64 turns non-numeric entries (e.g. None) into
    # NaN so they are caught by the validity check below instead of making
    # np.isnan fail on an object array.
    try:
        response = llm.embeddings(model=embedding_model, prompt=text)
        embedding = np.asarray(response["embedding"], dtype=np.float64)
    except Exception as exc:
        logger.exception("Failed to get embedding: %s", exc)
        raise

    # Log diagnostics
    logger.info("Input text: %s", text)
    logger.info("Embedding shape: %s, First 5 values: %s", embedding.shape, embedding[:5])

    # Check for invalid embeddings
    if embedding.size == 0 or not np.all(np.isfinite(embedding)):
        logger.error("Invalid embedding: contains NaN, infinite, or empty values.")
        raise ValueError("Invalid embedding returned from Ollama.")

    # Check normalization
    norm = np.linalg.norm(embedding)
    is_normalized = bool(np.allclose(norm, 1.0, atol=_NORM_TOLERANCE))
    logger.info("Embedding norm: %s, Is normalized: %s", norm, is_normalized)

    # Normalize if needed. A zero vector would divide by zero and silently
    # produce an all-NaN embedding, so reject it explicitly.
    if not is_normalized:
        if norm == 0:
            logger.error("Invalid embedding: zero vector cannot be normalized.")
            raise ValueError("Invalid embedding returned from Ollama: zero vector.")
        embedding = embedding / norm
        logger.info("Embedding normalized manually.")

    return embedding


def main() -> "list[np.ndarray]":
    """Embed each ``--text`` argument and return the normalized embeddings.

    Returns:
        One normalized embedding per input text, in input order. The value
        is ignored when the module is run as a script.

    Raises:
        ValueError: If any input text is empty or whitespace-only.
    """
    parser = argparse.ArgumentParser(
        description="Generate embeddings for text using an Ollama embedding model."
    )
    parser.add_argument(
        "--text",
        type=str,
        nargs="+",  # Allow multiple text inputs
        default=["Test sentence."],
        help="Text(s) to generate embeddings for (default: 'Test sentence.')",
    )
    parser.add_argument(
        "--ollama-server",
        type=str,
        default=defines.ollama_api_url,
        help=f"Ollama server URL (default: {defines.ollama_api_url})",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default=defines.embedding_model,
        help=f"Embedding model name (default: {defines.embedding_model})",
    )
    args = parser.parse_args()

    # Validate every input before making any network calls. argparse
    # (type=str) already guarantees strings; reject empty/whitespace-only text.
    for text in args.text:
        if not text.strip():
            logger.error("Input text must be a non-empty string.")
            raise ValueError("Input text must be a non-empty string.")

    # Reuse a single client for all texts rather than one per request.
    client = Client(host=args.ollama_server)
    embeddings = [
        get_embedding(
            text=text,
            embedding_model=args.embedding_model,
            ollama_server=args.ollama_server,
            client=client,
        )
        for text in args.text
    ]

    logger.info(
        "Generated %d embedding(s) with model %s.",
        len(embeddings),
        args.embedding_model,
    )
    return embeddings


if __name__ == "__main__":
    main()