# From /opt/backstory run:
#   python -m src.tests.test-embedding
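#
# Example invocation (flags are defined in main() below; the model name matches
# the default described there):
#   python -m src.tests.test-embedding --text "First sentence." "Second sentence." \
#       --embedding-model mxbai-embed-large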
import argparse
import logging

import numpy as np
from ollama import Client

from ..utils import defines

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


def get_embedding(text: str, embedding_model: str, ollama_server: str) -> np.ndarray:
    """Generate and normalize an embedding for the given text."""
    llm = Client(host=ollama_server)

    # Get the embedding from the Ollama server
    try:
        response = llm.embeddings(model=embedding_model, prompt=text)
        embedding = np.array(response["embedding"])
    except Exception as e:
        logging.error(f"Failed to get embedding: {e}")
        raise

    # Log diagnostics
    logging.info(f"Input text: {text}")
    logging.info(f"Embedding shape: {embedding.shape}, First 5 values: {embedding[:5]}")

    # Check for invalid embeddings
    if embedding.size == 0 or np.any(np.isnan(embedding)) or np.any(np.isinf(embedding)):
        logging.error("Invalid embedding: contains NaN, infinite, or empty values.")
        raise ValueError("Invalid embedding returned from Ollama.")

    # Check normalization
    norm = np.linalg.norm(embedding)
    is_normalized = np.allclose(norm, 1.0, atol=1e-3)
    logging.info(f"Embedding norm: {norm}, Is normalized: {is_normalized}")

    # Normalize if needed; a zero-norm vector cannot be normalized, so treat it
    # as an error rather than dividing by zero.
    if not is_normalized:
        if norm == 0:
            raise ValueError("Cannot normalize a zero-norm embedding.")
        embedding = embedding / norm
        logging.info("Embedding normalized manually.")

    return embedding

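
# Note: once embeddings are normalized to unit length, cosine similarity reduces
# to a plain dot product. A minimal sketch of comparing two vectors returned by
# get_embedding():
#
#   a = get_embedding("first text", embedding_model, ollama_server)
#   b = get_embedding("second text", embedding_model, ollama_server)
#   similarity = float(np.dot(a, b))  # in [-1.0, 1.0] for unit-length vectors
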
def main():
    """Generate and normalize embeddings for text(s) supplied on the command line."""
    parser = argparse.ArgumentParser(
        description="Generate embeddings for text using mxbai-embed-large."
    )
    parser.add_argument(
        "--text",
        type=str,
        nargs="+",  # Allow multiple text inputs
        default=["Test sentence."],
        help="Text(s) to generate embeddings for (default: 'Test sentence.')",
    )
    parser.add_argument(
        "--ollama-server",
        type=str,
        default=defines.ollama_api_url,
        help=f"Ollama server URL (default: {defines.ollama_api_url})",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default=defines.embedding_model,
        help=f"Embedding model name (default: {defines.embedding_model})",
    )
    args = parser.parse_args()

    # Validate input (argparse already guarantees strings, so only reject
    # empty or whitespace-only values)
    for text in args.text:
        if not text.strip():
            logging.error("Input text must be a non-empty string.")
            raise ValueError("Input text must be a non-empty string.")

    # Generate embeddings for each text
    embeddings = []
    for text in args.text:
        embedding = get_embedding(
            text=text,
            embedding_model=args.embedding_model,
            ollama_server=args.ollama_server,
        )
        embeddings.append(embedding)
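    # At this point `embeddings` holds one unit-length vector per input text;
    # this smoke test only exercises get_embedding, so the list is not used further.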

if __name__ == "__main__":
    main()