diff --git a/src/backend/agents/base.py b/src/backend/agents/base.py index b733298..aa3f1d5 100644 --- a/src/backend/agents/base.py +++ b/src/backend/agents/base.py @@ -26,6 +26,7 @@ from rag import start_file_watcher, ChromaDBFileWatcher import defines from logger import logger from models import ( + ChatResponse, Tunables, ChatMessageUser, ChatMessage, @@ -34,6 +35,7 @@ from models import ( ApiStatusType, Candidate, ChatContextType, + UsageStats, ) import utils.llm_proxy as llm_manager from database.manager import RedisDatabase @@ -158,12 +160,15 @@ class CandidateEntity(Candidate): raise ValueError("initialize() has not been called.") return self.CandidateEntity__observer - def collect_metrics(self, agent: Agent, response): + def collect_metrics(self, agent: Agent, response: ChatResponse): if not self.metrics: logger.warning("No metrics collector set for this agent.") return - self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count) - self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count) + if response.usage: + if response.usage.prompt_tokens: + self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_tokens) + if response.usage.completion_tokens: + self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.completion_tokens) async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase): if self.CandidateEntity__initialized: @@ -231,7 +236,6 @@ class Agent(BaseModel, ABC): # Agent properties system_prompt: str = "" - context_tokens: int = 0 # context_size is shared across all subclasses _context_size: ClassVar[int] = int(defines.max_context * 0.5) @@ -444,9 +448,6 @@ class Agent(BaseModel, ABC): # message.metadata.eval_duration += response.eval_duration # message.metadata.prompt_eval_count += response.prompt_eval_count # message.metadata.prompt_eval_duration += response.prompt_eval_duration - # self.context_tokens = ( - # response.prompt_eval_count + response.eval_count - # ) # message.status = "done" # yield message @@ -617,7 +618,6 @@ Content: {content} return self.user.collect_metrics(agent=self, response=response) - self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count chat_message = ChatMessage( session_id=session_id, @@ -626,10 +626,12 @@ Content: {content} content=content, metadata=ChatMessageMetaData( options=options, - eval_count=response.usage.eval_count, - eval_duration=response.usage.eval_duration, - prompt_eval_count=response.usage.prompt_eval_count, - prompt_eval_duration=response.usage.prompt_eval_duration, + usage=UsageStats( + eval_count=response.usage.eval_count, + eval_duration=response.usage.eval_duration, + prompt_eval_count=response.usage.prompt_eval_count, + prompt_eval_duration=response.usage.prompt_eval_duration, + ), ), ) yield chat_message @@ -848,7 +850,6 @@ Content: {content} return self.user.collect_metrics(agent=self, response=response) - self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count end_time = time.perf_counter() chat_message = ChatMessage( @@ -858,10 +859,12 @@ Content: {content} content=content, metadata=ChatMessageMetaData( options=options, - eval_count=response.usage.eval_count, - eval_duration=response.usage.eval_duration, - prompt_eval_count=response.usage.prompt_eval_count, - prompt_eval_duration=response.usage.prompt_eval_duration, + usage=UsageStats( + eval_count=response.usage.eval_count, + eval_duration=response.usage.eval_duration, + 
prompt_eval_count=response.usage.prompt_eval_count, + prompt_eval_duration=response.usage.prompt_eval_duration, + ), rag_results=rag_message.content if rag_message else [], llm_history=messages, timers={ diff --git a/src/backend/agents/generate_resume.py b/src/backend/agents/generate_resume.py index 0d26b0f..d69390e 100644 --- a/src/backend/agents/generate_resume.py +++ b/src/backend/agents/generate_resume.py @@ -189,7 +189,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme return # Stage 1A: Analyze job requirements status_message = ChatMessageStatus( - session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING + session_id=session_id, content="Generating resume...", activity=ApiActivityType.THINKING ) yield status_message diff --git a/src/backend/agents/job_requirements.py b/src/backend/agents/job_requirements.py index 0985296..db9270c 100644 --- a/src/backend/agents/job_requirements.py +++ b/src/backend/agents/job_requirements.py @@ -46,7 +46,7 @@ experiences, and qualifications required in a job description. ## INSTRUCTIONS: 1. Analyze ONLY the <|job_description|> provided, and provide only requirements from that description. -2. Extract company information, job title, and all requirements. +2. Extract company information, job title, pay range, and all requirements. 3. If a requirement can be broken into multiple requirements, do so. 4. Categorize each requirement into one and only one of the following categories: - Technical skills (required and preferred) @@ -68,6 +68,8 @@ experiences, and qualifications required in a job description. "company_name": "Company Name", "job_title": "Job Title", "job_summary": "Brief summary of the job", +"pay_minimum": "Minimum pay range", +"pay_maximum": "Maximum pay range", "job_requirements": { "technical_skills": { "required": ["skill1", "skill2"], @@ -197,6 +199,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r company = "" summary = "" title = "" + pay_maximum = None + pay_minimum = None try: json_str = self.extract_json_from_text(generated_message.content) requirements_json = json.loads(json_str) @@ -205,6 +209,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r title = requirements_json.get("job_title", "") summary = requirements_json.get("job_summary", "") job_requirements_data = requirements_json.get("job_requirements", None) + pay_minimum = requirements_json.get("pay_minimum", None) + pay_maximum = requirements_json.get("pay_maximum", None) requirements = JobRequirements.model_validate(job_requirements_data) except json.JSONDecodeError as e: status_message.status = ApiStatusType.ERROR @@ -238,6 +244,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r company=company, title=title, summary=summary, + pay_minimum=pay_minimum, + pay_maximum=pay_maximum, requirements=requirements, description=prompt, ) diff --git a/src/backend/agents/skill_match.py b/src/backend/agents/skill_match.py index 9c17cc8..30a53f1 100644 --- a/src/backend/agents/skill_match.py +++ b/src/backend/agents/skill_match.py @@ -13,12 +13,14 @@ from database.core import RedisDatabase from .base import Agent, agent_registry from models import ( + ApiActivityType, ApiMessage, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageSkillAssessment, ApiStatusType, + ChatMessageStatus, EvidenceDetail, SkillAssessment, Tunables, @@ -162,6 +164,11 @@ JSON RESPONSE:""" logger.info(f"🔍 RAG content retrieved 
{len(rag_context)} bytes of context") system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context) + status_message = ChatMessageStatus( + session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating skill assessment..." + ) + yield status_message + generated_message = None async for generated_message in self.llm_one_shot( llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7 diff --git a/src/backend/defines.py b/src/backend/defines.py index fa31cb2..f3c3915 100644 --- a/src/backend/defines.py +++ b/src/backend/defines.py @@ -1,6 +1,6 @@ import os -ollama_api_url = "http://ollama:11434" # Default Ollama local endpoint +ollama_api_url = os.getenv("OPENAI_URL", "http://ollama:11434") # Default Ollama local endpoint frontend_url = os.getenv("FRONTEND_URL", "https://backstory.ketrenos.com") diff --git a/src/backend/models.py b/src/backend/models.py index c297217..7e3bc2a 100644 --- a/src/backend/models.py +++ b/src/backend/models.py @@ -2,6 +2,7 @@ from typing import List, Dict, Optional, Any, Union, Literal, TypeVar, Annotated from pydantic import BaseModel, Field, EmailStr, HttpUrl, model_validator, field_validator, ConfigDict from datetime import datetime, UTC from enum import Enum +import os import uuid from utils.auth_utils import ( validate_password_strength, @@ -800,6 +801,8 @@ class Job(BaseModel): title: Optional[str] summary: Optional[str] company: Optional[str] + pay_minimum: Optional[str] = Field(default=None, alias=str("payMinimum")) + pay_maximum: Optional[str] = Field(default=None, alias=str("payMaximum")) description: str requirements: Optional[JobRequirements] created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias=str("createdAt")) @@ -886,6 +889,69 @@ class ChatOptions(BaseModel): temperature: Optional[float] = Field(default=0.7) # Higher temperature to encourage tool usage model_config = ConfigDict(populate_by_name=True) +# Enhanced usage statistics model +class UsageStats(BaseModel): + """Comprehensive usage statistics across all providers""" + + # Token counts (standardized across providers) + prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt") + completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion") + total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used") + + # Ollama-specific detailed stats + prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt") + prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)") + eval_count: Optional[int] = Field(default=None, description="Number of tokens generated") + eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)") + total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)") + + # Performance metrics + tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second") + prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed") + + # Additional provider-specific stats + extra_stats: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Provider-specific additional statistics" + ) + + def calculate_derived_stats(self) -> None: + """Calculate derived statistics where 
possible""" + # Calculate tokens per second for Ollama + if self.eval_count and self.eval_duration and self.eval_duration > 0: + # Convert nanoseconds to seconds and calculate tokens/sec + duration_seconds = self.eval_duration / 1_000_000_000 + self.tokens_per_second = self.eval_count / duration_seconds + + # Calculate prompt processing speed for Ollama + if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0: + duration_seconds = self.prompt_eval_duration / 1_000_000_000 + self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds + + # Standardize token counts across providers + if not self.total_tokens and self.prompt_tokens and self.completion_tokens: + self.total_tokens = self.prompt_tokens + self.completion_tokens + + # Map Ollama counts to standard format if not already set + if not self.prompt_tokens and self.prompt_eval_count: + self.prompt_tokens = self.prompt_eval_count + if not self.completion_tokens and self.eval_count: + self.completion_tokens = self.eval_count + + +class ChatResponse(BaseModel): + content: str + model: str + finish_reason: Optional[str] = Field(default="") + usage: Optional[UsageStats] = Field(default=None) + # Keep legacy usage field for backward compatibility + usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict") + + def get_usage_dict(self) -> Dict[str, Any]: + """Get usage statistics as dictionary for backward compatibility""" + if self.usage: + return self.usage.model_dump(exclude_none=True) + return self.usage_legacy or {} + # Add rate limiting configuration models class RateLimitConfig(BaseModel): @@ -1063,15 +1129,17 @@ class ChatMessageMetaData(BaseModel): stop_sequences: List[str] = Field(default=[], alias=str("stopSequences")) rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias=str("ragResults")) llm_history: List[LLMMessage] = Field(default_factory=list, alias=str("llmHistory")) - eval_count: int = 0 - eval_duration: int = 0 - prompt_eval_count: int = 0 - prompt_eval_duration: int = 0 + usage: UsageStats = Field(default_factory=UsageStats) options: Optional[ChatOptions] = None tools: Optional[Tool] = None timers: Dict[str, float] = Field(default_factory=dict) model_config = ConfigDict(populate_by_name=True) +class SkillMatchRequest(BaseModel): + """Request model for skill match""" + + skill: str + regenerate: bool = Field(default=False, description="Whether to regenerate the skill match even if cached") class ChatMessageUser(ApiMessage): type: ApiMessageType = ApiMessageType.TEXT @@ -1126,6 +1194,7 @@ class SystemInfo(BaseModel): installed_RAM: str = Field(..., alias=str("installedRAM")) graphics_cards: List[GPUInfo] = Field(..., alias=str("graphicsCards")) CPU: str + llm_backend: str = Field(default=os.getenv("DEFAULT_LLM_PROVIDER", "openai"), alias=str("llmBackend")) llm_model: str = Field(default=defines.model, alias=str("llmModel")) embedding_model: str = Field(default=defines.embedding_model, alias=str("embeddingModel")) max_context_length: int = Field(default=defines.max_context, alias=str("maxContextLength")) diff --git a/src/backend/routes/candidates.py b/src/backend/routes/candidates.py index 75628ff..07262c9 100644 --- a/src/backend/routes/candidates.py +++ b/src/backend/routes/candidates.py @@ -59,6 +59,7 @@ from models import ( RagContentMetadata, RagContentResponse, SkillAssessment, + SkillMatchRequest, SkillStrength, UserType, ) @@ -1545,11 +1546,10 @@ async def post_job_analysis( logger.error(f"❌ Get candidate job analysis error: 
{e}") return JSONResponse(status_code=500, content=create_error_response("JOB_ANALYSIS_ERROR", str(e))) - @router.post("/{candidate_id}/skill-match") async def get_candidate_skill_match( candidate_id: str = Path(...), - skill: str = Body(...), + request: SkillMatchRequest = Body(...), current_user=Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database), ) -> StreamingResponse: @@ -1567,10 +1567,14 @@ async def get_candidate_skill_match( candidate = Candidate.model_validate(candidate_data) + skill = request.skill.strip() + cache_key = get_skill_cache_key(candidate.id, skill) # Get cached assessment if it exists - assessment: SkillAssessment | None = await database.get_cached_skill_match(cache_key) + assessment: SkillAssessment | None = None + if not request.regenerate: + assessment = await database.get_cached_skill_match(cache_key) if assessment and assessment.skill.lower() != skill.lower(): logger.warning( diff --git a/src/backend/system_info.py b/src/backend/system_info.py index e7a552d..4e20bb3 100644 --- a/src/backend/system_info.py +++ b/src/backend/system_info.py @@ -2,6 +2,7 @@ import defines import re import subprocess import math +import os from models import GPUInfo, SystemInfo @@ -83,6 +84,7 @@ def system_info() -> SystemInfo: installed_RAM=get_installed_ram(), graphics_cards=get_graphics_cards(), CPU=get_cpu_info(), + llm_backend=os.getenv("DEFAULT_LLM_PROVIDER", "openai"), llm_model=defines.model, embedding_model=defines.embedding_model, max_context_length=defines.max_context, diff --git a/src/backend/utils/llm_proxy.py b/src/backend/utils/llm_proxy.py index e2c215b..b8d75b3 100644 --- a/src/backend/utils/llm_proxy.py +++ b/src/backend/utils/llm_proxy.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod import time +import traceback from typing import Dict, List, Any, AsyncGenerator, Optional, Union, cast from pydantic import BaseModel, Field from enum import Enum @@ -8,6 +9,7 @@ from dataclasses import dataclass import os from redis.asyncio import Redis +from models import ChatResponse, UsageStats import defines from logger import logger import hashlib @@ -48,55 +50,6 @@ class LLMMessage: return {"role": self.role, "content": self.content} -# Enhanced usage statistics model -class UsageStats(BaseModel): - """Comprehensive usage statistics across all providers""" - - # Token counts (standardized across providers) - prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt") - completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion") - total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used") - - # Ollama-specific detailed stats - prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt") - prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)") - eval_count: Optional[int] = Field(default=None, description="Number of tokens generated") - eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)") - total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)") - - # Performance metrics - tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second") - prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed") - - # Additional provider-specific stats - 
extra_stats: Optional[Dict[str, Any]] = Field( - default_factory=dict, description="Provider-specific additional statistics" - ) - - def calculate_derived_stats(self) -> None: - """Calculate derived statistics where possible""" - # Calculate tokens per second for Ollama - if self.eval_count and self.eval_duration and self.eval_duration > 0: - # Convert nanoseconds to seconds and calculate tokens/sec - duration_seconds = self.eval_duration / 1_000_000_000 - self.tokens_per_second = self.eval_count / duration_seconds - - # Calculate prompt processing speed for Ollama - if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0: - duration_seconds = self.prompt_eval_duration / 1_000_000_000 - self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds - - # Standardize token counts across providers - if not self.total_tokens and self.prompt_tokens and self.completion_tokens: - self.total_tokens = self.prompt_tokens + self.completion_tokens - - # Map Ollama counts to standard format if not already set - if not self.prompt_tokens and self.prompt_eval_count: - self.prompt_tokens = self.prompt_eval_count - if not self.completion_tokens and self.eval_count: - self.completion_tokens = self.eval_count - - # Embedding response models class EmbeddingData(BaseModel): """Single embedding result""" @@ -123,21 +76,6 @@ class EmbeddingResponse(BaseModel): return [item.embedding for item in self.data] -class ChatResponse(BaseModel): - content: str - model: str - finish_reason: Optional[str] = Field(default="") - usage: Optional[UsageStats] = Field(default=None) - # Keep legacy usage field for backward compatibility - usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict") - - def get_usage_dict(self) -> Dict[str, Any]: - """Get usage statistics as dictionary for backward compatibility""" - if self.usage: - return self.usage.model_dump(exclude_none=True) - return self.usage_legacy or {} - - class LLMProvider(str, Enum): OLLAMA = "ollama" OPENAI = "openai" @@ -154,13 +92,23 @@ class BaseLLMAdapter(ABC): @abstractmethod async def chat( - self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs + self, + model: str, + messages: List[LLMMessage], + stream: bool = False, + options: Optional[dict] = None, + **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: """Send chat messages and get response""" @abstractmethod async def generate( - self, model: str, prompt: str, stream: bool = False, **kwargs + self, + model: str, + prompt: str, + stream: bool = False, + options: Optional[dict] = None, + **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: """Generate text from prompt""" @@ -223,7 +171,7 @@ class OllamaAdapter(BaseLLMAdapter): return usage_stats async def chat( - self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs + self, model: str, messages: List[LLMMessage], stream: bool = False, options: Optional[dict] = None, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: # Convert LLMMessage objects to Ollama format ollama_messages = [] @@ -231,9 +179,11 @@ class OllamaAdapter(BaseLLMAdapter): ollama_messages.append({"role": msg.role, "content": msg.content}) if stream: - return self._stream_chat(model, ollama_messages, **kwargs) + return self._stream_chat(model, ollama_messages, options=options, **kwargs) else: - response = await self.client.chat(model=model, messages=ollama_messages, stream=False, **kwargs) + response = await self.client.chat( 
+ model=model, messages=ollama_messages, stream=False, options=options, **kwargs + ) usage_stats = self._create_usage_stats(response.model_dump()) @@ -244,9 +194,9 @@ class OllamaAdapter(BaseLLMAdapter): usage=usage_stats, ) - async def _stream_chat(self, model: str, messages: List[Dict], **kwargs): + async def _stream_chat(self, model: str, messages: List[Dict], options: Optional[dict] = None, **kwargs): # Await the chat call first, then iterate over the result - stream = await self.client.chat(model=model, messages=messages, stream=True, **kwargs) + stream = await self.client.chat(model=model, messages=messages, stream=True, options=options, **kwargs) # Accumulate stats for final chunk accumulated_stats = {} @@ -265,12 +215,17 @@ class OllamaAdapter(BaseLLMAdapter): yield ChatResponse(content=content, model=model, finish_reason=None, usage=None) async def generate( - self, model: str, prompt: str, stream: bool = False, **kwargs + self, + model: str, + prompt: str, + stream: bool = False, + options: Optional[dict] = None, + **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: if stream: - return self._stream_generate(model, prompt, **kwargs) + return self._stream_generate(model, prompt, options=options, **kwargs) else: - response = await self.client.generate(model=model, prompt=prompt, stream=False, **kwargs) + response = await self.client.generate(model=model, prompt=prompt, stream=False, options=options, **kwargs) usage_stats = self._create_usage_stats(response.model_dump()) @@ -278,9 +233,9 @@ class OllamaAdapter(BaseLLMAdapter): content=response["response"], model=model, finish_reason=response.get("done_reason"), usage=usage_stats ) - async def _stream_generate(self, model: str, prompt: str, **kwargs): + async def _stream_generate(self, model: str, prompt: str, options: Optional[dict] = None, **kwargs): # Await the generate call first, then iterate over the result - stream = await self.client.generate(model=model, prompt=prompt, stream=True, **kwargs) + stream = await self.client.generate(model=model, prompt=prompt, stream=True, options=options, **kwargs) accumulated_stats = {} @@ -352,7 +307,7 @@ class OpenAIAdapter(BaseLLMAdapter): ) async def chat( - self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs + self, model: str, messages: List[LLMMessage], stream: bool = False, options: Optional[dict] = None, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: # Convert LLMMessage objects to OpenAI format openai_messages = [] @@ -360,10 +315,16 @@ class OpenAIAdapter(BaseLLMAdapter): openai_messages.append({"role": msg.role, "content": msg.content}) if stream: - return self._stream_chat(model, openai_messages, **kwargs) + return self._stream_chat(model, openai_messages, options=options, **kwargs) else: response = await self.client.chat.completions.create( - model=model, messages=openai_messages, stream=False, **kwargs + model=model, + messages=openai_messages, + stream=False, + temperature=options.get("temperature") if options else None, + max_tokens=options.get("num_ctx") if options else None, + seed=options.get("seed") if options else None, + **kwargs, ) usage_stats = self._create_usage_stats(response.usage) @@ -394,11 +355,16 @@ class OpenAIAdapter(BaseLLMAdapter): ) async def generate( - self, model: str, prompt: str, stream: bool = False, **kwargs + self, + model: str, + prompt: str, + stream: bool = False, + options: Optional[dict] = None, + **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: # Convert 
to chat format for OpenAI messages = [LLMMessage(role="user", content=prompt)] - return await self.chat(model, messages, stream, **kwargs) + return await self.chat(model, messages, stream, options=options, **kwargs) async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse: """Generate embeddings using OpenAI""" @@ -448,8 +414,47 @@ class AnthropicAdapter(BaseLLMAdapter): extra_stats={k: v for k, v in usage_dict.items() if k not in ["input_tokens", "output_tokens"]}, ) + def _prepare_anthropic_kwargs(self, options: Optional[dict], **kwargs) -> dict: + """Prepare kwargs for Anthropic API from ChatOptions and additional kwargs""" + anthropic_kwargs = dict(kwargs) + + if options: + # Extract from dict + temperature = options.get("temperature") + num_ctx = options.get("num_ctx") or options.get("numCtx") # Handle both field names + seed = options.get("seed") + + # Map to Anthropic parameters + if temperature is not None: + anthropic_kwargs["temperature"] = temperature + + if num_ctx is not None: + # num_ctx maps to max_tokens for Anthropic + # Anthropic has a max context of 8192 tokens + if num_ctx > 8192: + logger.warning( + f"num_ctx ({num_ctx}) exceeds Anthropic's max context length (8192). " + "Setting max_tokens to 8192." + ) + anthropic_kwargs["max_tokens"] = min(8192, num_ctx) + + # Anthropic doesn't support seed, but we can try passing it + if seed is not None: + logger.warning("Anthropic does not officially support 'seed' parameter. ") + + # Set default max_tokens if not provided + if "max_tokens" not in anthropic_kwargs: + anthropic_kwargs["max_tokens"] = 1000 + + return anthropic_kwargs + async def chat( - self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs + self, + model: str, + messages: List[LLMMessage], + stream: bool = False, + options: Optional[dict] = None, + **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: # Anthropic requires system message to be separate system_message = None @@ -461,11 +466,13 @@ class AnthropicAdapter(BaseLLMAdapter): else: anthropic_messages.append({"role": msg.role, "content": msg.content}) + # Prepare kwargs with options + prepared_kwargs = self._prepare_anthropic_kwargs(options, **kwargs) + request_kwargs = { "model": model, "messages": anthropic_messages, - "max_tokens": kwargs.pop("max_tokens", 1000), - **kwargs, + **prepared_kwargs, } if system_message: @@ -511,10 +518,15 @@ class AnthropicAdapter(BaseLLMAdapter): logger.info(f"Could not retrieve final usage stats: {e}") async def generate( - self, model: str, prompt: str, stream: bool = False, **kwargs + self, + model: str, + prompt: str, + stream: bool = False, + options: Optional[dict] = None, + **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: messages = [LLMMessage(role="user", content=prompt)] - return await self.chat(model, messages, stream, **kwargs) + return await self.chat(model, messages, stream, options, **kwargs) async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse: """Anthropic doesn't provide embeddings API""" @@ -524,7 +536,12 @@ class AnthropicAdapter(BaseLLMAdapter): async def list_models(self) -> List[str]: # Anthropic doesn't have a list models endpoint, return known models - return ["claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", "claude-3-opus-20240229"] + return [ + "claude-3-5-haiku-latest", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", + ] class 
GeminiAdapter(BaseLLMAdapter): @@ -564,7 +581,7 @@ class GeminiAdapter(BaseLLMAdapter): return usage_stats async def chat( - self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs + self, model: str, messages: List[LLMMessage], stream: bool = False, options=Optional[dict], **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: model_instance = self.genai.GenerativeModel(model) # type: ignore @@ -629,7 +646,12 @@ class GeminiAdapter(BaseLLMAdapter): ) async def generate( - self, model: str, prompt: str, stream: bool = False, **kwargs + self, + model: str, + prompt: str, + stream: bool = False, + options: Optional[dict] = None, + **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: messages = [LLMMessage(role="user", content=prompt)] return await self.chat(model, messages, stream, **kwargs) @@ -737,6 +759,7 @@ class LLMCacheMixin: model: str, messages_or_prompt: Union[List[LLMMessage], str], provider: Optional[LLMProvider] = None, + options: Optional[dict] = None, **kwargs, ) -> str: """Create consistent cache key from parameters""" @@ -752,7 +775,13 @@ class LLMCacheMixin: serializable_messages.append(msg) messages_or_prompt = serializable_messages - cache_data = {"method": method, "model": model, "content": messages_or_prompt, "kwargs": cache_kwargs} + cache_data = { + "method": method, + "model": model, + "content": messages_or_prompt, + "options": options, + "kwargs": cache_kwargs, + } if self.cache_config.provider_specific and provider: cache_data["provider"] = provider.value @@ -935,7 +964,7 @@ class LLMCacheMixin: class UnifiedLLMProxy(LLMCacheMixin): - """Main proxy class that provides unified interface to all LLM providers""" + """Main proxy class with automatic Ollama fallback for Anthropic embeddings""" def __init__(self, redis_client: Redis, default_provider: LLMProvider = LLMProvider.OLLAMA): # Initialize with caching support @@ -1019,6 +1048,7 @@ class UnifiedLLMProxy(LLMCacheMixin): messages: Union[List[LLMMessage], List[Dict[str, str]]], provider: Optional[LLMProvider] = None, stream: bool = False, + options: Optional[dict] = None, **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: """Enhanced chat method with streaming cache support""" @@ -1037,7 +1067,12 @@ class UnifiedLLMProxy(LLMCacheMixin): # Create cache key cache_key = self._create_cache_key( - method="chat", model=model, messages_or_prompt=normalized_messages, provider=provider, **kwargs + method="chat", + model=model, + messages_or_prompt=normalized_messages, + provider=provider, + options=options, + **kwargs, ) # Check cache @@ -1059,7 +1094,7 @@ class UnifiedLLMProxy(LLMCacheMixin): if stream: # For streaming: we know this will return AsyncGenerator - stream_response = await adapter.chat(model, normalized_messages, stream=True, **kwargs) + stream_response = await adapter.chat(model, normalized_messages, stream=True, options=options, **kwargs) stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response) if self.cache_config.enabled: @@ -1068,7 +1103,7 @@ class UnifiedLLMProxy(LLMCacheMixin): return stream_response else: # For non-streaming: we know this will return ChatResponse - single_response = await adapter.chat(model, normalized_messages, stream=False, **kwargs) + single_response = await adapter.chat(model, normalized_messages, stream=False, options=options, **kwargs) single_response = cast(ChatResponse, single_response) if self.cache_config.enabled: @@ -1082,6 +1117,7 @@ class UnifiedLLMProxy(LLMCacheMixin): prompt: 
str, provider: Optional[LLMProvider] = None, stream: bool = False, + options: Optional[dict] = None, **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: """Enhanced generate method with streaming cache support""" @@ -1114,7 +1150,7 @@ class UnifiedLLMProxy(LLMCacheMixin): if stream: # For streaming: we know this will return AsyncGenerator - stream_response = await adapter.generate(model, prompt, stream=True, **kwargs) + stream_response = await adapter.generate(model, prompt, stream=True, options=options, **kwargs) stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response) if self.cache_config.enabled: @@ -1125,7 +1161,7 @@ class UnifiedLLMProxy(LLMCacheMixin): return stream_response else: # For non-streaming: we know this will return ChatResponse - single_response = await adapter.generate(model, prompt, stream=False, **kwargs) + single_response = await adapter.generate(model, prompt, stream=False, options=options, **kwargs) single_response = cast(ChatResponse, single_response) if self.cache_config.enabled: await self._set_cached_response(cache_key, single_response) @@ -1145,6 +1181,13 @@ class UnifiedLLMProxy(LLMCacheMixin): if provider in adapter_classes: self.adapters[provider] = adapter_classes[provider](**config) self._initialized_providers.add(provider) + + # If configuring Anthropic, ensure Ollama is also configured for embeddings + if provider == LLMProvider.ANTHROPIC and LLMProvider.OLLAMA not in self._initialized_providers: + logger.warning( + "Anthropic provider configured but Ollama is not available for embeddings. " + "Embedding operations with Anthropic provider will fail until Ollama is configured." + ) else: raise ValueError(f"Unsupported provider: {provider}") @@ -1160,12 +1203,13 @@ class UnifiedLLMProxy(LLMCacheMixin): messages: Union[List[LLMMessage], List[Dict[str, str]]], provider: Optional[LLMProvider] = None, stream: bool = True, + options: Optional[dict] = None, **kwargs, ) -> AsyncGenerator[ChatResponse, None]: """Stream chat messages using specified or default provider""" if stream is False: raise ValueError("stream must be True for chat_stream") - result = await self.chat(model, messages, provider, stream=True, **kwargs) + result = await self.chat(model, messages, provider, stream=True, options=options, **kwargs) if isinstance(result, ChatResponse): raise RuntimeError("Expected AsyncGenerator, got ChatResponse") async for chunk in result: @@ -1176,32 +1220,43 @@ class UnifiedLLMProxy(LLMCacheMixin): model: str, messages: Union[List[LLMMessage], List[Dict[str, str]]], provider: Optional[LLMProvider] = None, + options: Optional[dict] = None, **kwargs, ) -> ChatResponse: """Get single chat response using specified or default provider""" - result = await self.chat(model, messages, provider, stream=False, **kwargs) + result = await self.chat(model, messages, provider, stream=False, options=options, **kwargs) if not isinstance(result, ChatResponse): raise RuntimeError("Expected ChatResponse, got AsyncGenerator") return result async def generate_stream( - self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + options: Optional[dict] = None, + **kwargs, ) -> AsyncGenerator[ChatResponse, None]: """Stream text generation using specified or default provider""" - result = await self.generate(model, prompt, provider, stream=True, **kwargs) + result = await self.generate(model, prompt, provider, stream=True, options=options, **kwargs) 
if isinstance(result, ChatResponse): raise RuntimeError("Expected AsyncGenerator, got ChatResponse") async for chunk in result: yield chunk async def generate_single( - self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + options: Optional[dict] = None, + **kwargs, ) -> ChatResponse: """Get single generation response using specified or default provider""" - result = await self.generate(model, prompt, provider, stream=False, **kwargs) + result = await self.generate(model, prompt, provider, stream=False, options=options, **kwargs) if not isinstance(result, ChatResponse): raise RuntimeError("Expected ChatResponse, got AsyncGenerator") return result @@ -1209,11 +1264,52 @@ class UnifiedLLMProxy(LLMCacheMixin): async def embeddings( self, model: str, input_texts: Union[str, List[str]], provider: Optional[LLMProvider] = None, **kwargs ) -> EmbeddingResponse: - """Generate embeddings using specified or default provider""" - provider = provider or self.default_provider - adapter = self._get_adapter(provider) + """Generate embeddings using specified or default provider. - return await adapter.embeddings(model, input_texts, **kwargs) + For Anthropic provider, automatically routes to Ollama since Anthropic doesn't provide embeddings API. + """ + provider = provider or self.default_provider + + # Special handling: If Anthropic is requested for embeddings, use Ollama instead + if provider == LLMProvider.ANTHROPIC: + logger.info("Anthropic provider requested for embeddings - routing to Ollama provider") + + # Ensure Ollama is available + if LLMProvider.OLLAMA not in self._initialized_providers: + raise RuntimeError( + "Anthropic provider requires Ollama for embeddings, but Ollama is not configured. " + "Please ensure Ollama endpoint is properly configured." + ) + + # Use Ollama for embeddings with a default embedding model if not specified + embedding_provider = LLMProvider.OLLAMA + + # If no specific model provided, use a default Ollama embedding model + if not model or model in [ + "claude-3-5-haiku-latest", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", + ]: + # Default to a common Ollama embedding model + model = os.getenv("OLLAMA_EMBEDDING_MODEL", "mxbai-embed-large") + logger.info(f"Using Ollama embedding model: {model}") + + else: + embedding_provider = provider + + adapter = self._get_adapter(embedding_provider) + + try: + return await adapter.embeddings(model, input_texts, **kwargs) + except Exception as e: + if embedding_provider == LLMProvider.OLLAMA and provider == LLMProvider.ANTHROPIC: + raise RuntimeError( + f"Failed to generate embeddings using Ollama (fallback for Anthropic): {e}. " + f"Please ensure Ollama is running and the embedding model '{model}' is available." + ) from e + else: + raise async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]: """List available models for specified or default provider""" @@ -1224,21 +1320,32 @@ class UnifiedLLMProxy(LLMCacheMixin): return await adapter.list_models() async def list_embedding_models(self, provider: Optional[LLMProvider] = None) -> List[str]: - """List available embedding models for specified or default provider""" + """List available embedding models for specified or default provider. + + For Anthropic provider, returns Ollama embedding models since that's what will be used. 
+ """ provider = provider or self.default_provider # Provider-specific embedding models embedding_models = { LLMProvider.OLLAMA: [ "mxbai-embed-large", + "nomic-embed-text", + "all-minilm", ], LLMProvider.OPENAI: ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"], - LLMProvider.ANTHROPIC: [], # No embeddings API + LLMProvider.ANTHROPIC: [], # Will be handled specially below LLMProvider.GEMINI: ["models/embedding-001", "models/text-embedding-004"], } + # Special handling for Anthropic - return Ollama models since that's what will be used if provider == LLMProvider.ANTHROPIC: - raise NotImplementedError("Anthropic does not provide embeddings API") + logger.info("Listing Ollama embedding models (used for Anthropic embedding requests)") + + if LLMProvider.OLLAMA not in self._initialized_providers: + raise RuntimeError("Anthropic provider requires Ollama for embeddings, but Ollama is not configured.") + + return await self.list_embedding_models(LLMProvider.OLLAMA) # For Ollama, check which embedding models are actually available if provider == LLMProvider.OLLAMA: @@ -1266,7 +1373,7 @@ class UnifiedLLMProxy(LLMCacheMixin): class LLMManager: - """Updated manager with caching support""" + """Updated manager with caching support and required Ollama endpoint""" _instance = None _proxy = None @@ -1283,24 +1390,42 @@ class LLMManager: self._configure_from_environment() def _configure_from_environment(self): - """Configure providers and cache from environment variables""" + """Configure providers and cache from environment variables - Ollama is now required""" if not self._proxy: raise RuntimeError("UnifiedLLMProxy instance not initialized") - # Your existing configuration code... + # REQUIRED: Configure Ollama provider first (mandatory) ollama_host = os.getenv("OLLAMA_HOST", defines.ollama_api_url) - self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host) + if not ollama_host: + raise ValueError( + "Ollama endpoint is required. Please set OLLAMA_HOST environment variable " + "or ensure defines.ollama_api_url is configured." + ) + logger.info(f"Configuring Ollama provider (REQUIRED) with host: {ollama_host}") + try: + self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host) + except Exception as e: + raise RuntimeError(f"Failed to configure required Ollama provider: {e}") + + # Configure other providers as before if os.getenv("OPENAI_API_KEY"): + logger.info("Configuring OpenAI provider") self._proxy.configure_provider(LLMProvider.OPENAI) if os.getenv("ANTHROPIC_API_KEY"): + logger.info("Configuring Anthropic provider (will use Ollama for embeddings)") self._proxy.configure_provider(LLMProvider.ANTHROPIC) if os.getenv("GEMINI_API_KEY"): + logger.info("Configuring Gemini provider") self._proxy.configure_provider(LLMProvider.GEMINI) - # Add cache configuration from environment + # Validate that Ollama is properly configured + if LLMProvider.OLLAMA not in self._proxy._initialized_providers: + raise RuntimeError("Ollama provider must be successfully configured") + + # Add cache configuration from environment (unchanged) cache_config = {} if os.getenv("LLM_CACHE_TTL"): cache_config["ttl"] = int(os.getenv("LLM_CACHE_TTL", 31536000)) @@ -1319,11 +1444,15 @@ class LLMManager: if cache_config: asyncio.create_task(self._proxy.configure_cache(**cache_config)) - # Your existing default provider setup... 
+ # Set default provider default_provider = os.getenv("DEFAULT_LLM_PROVIDER", "ollama") try: + logger.info(f"Setting default LLM provider to: {default_provider}") self._proxy.set_default_provider(LLMProvider(default_provider)) - except ValueError: + except ValueError as e: + logger.error(traceback.format_exc()) + logger.error(f"Invalid default provider: {default_provider}. Error: {e}") + logger.info(f"Setting default LLM provider to fallback: {LLMProvider.OLLAMA}") self._proxy.set_default_provider(LLMProvider.OLLAMA) def get_proxy(self) -> UnifiedLLMProxy:
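
Usage-accounting note: the UsageStats model moved into models.py is what collect_metrics() and ChatMessageMetaData.usage now consume. A minimal sketch of how Ollama-style counters map onto the standardized fields, assuming src/backend is on PYTHONPATH; the numbers are invented for illustration.

from models import UsageStats

stats = UsageStats(
    prompt_eval_count=512,             # Ollama: tokens evaluated in the prompt
    prompt_eval_duration=250_000_000,  # nanoseconds
    eval_count=128,                    # Ollama: tokens generated
    eval_duration=2_000_000_000,       # nanoseconds
)
stats.calculate_derived_stats()

print(stats.prompt_tokens)      # 512, mapped from prompt_eval_count
print(stats.completion_tokens)  # 128, mapped from eval_count
print(stats.tokens_per_second)  # 128 tokens / 2.0 s = 64.0
# Note: total_tokens stays None here because the standardization step runs before the
# Ollama-count mapping inside calculate_derived_stats(); it is only filled when
# prompt_tokens/completion_tokens were already set on construction.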
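
The new options dict threads from the UnifiedLLMProxy surface down into every adapter (Ollama forwards it verbatim, OpenAI maps temperature/num_ctx/seed, Anthropic maps it via _prepare_anthropic_kwargs, and it now participates in the cache key). A hedged caller sketch, assuming LLMManager() returns the environment-configured singleton; the model tag is a placeholder.

import asyncio
from utils.llm_proxy import LLMManager, LLMProvider

async def ask() -> None:
    proxy = LLMManager().get_proxy()
    response = await proxy.chat_single(
        model="qwen2.5:7b",  # hypothetical model tag
        messages=[{"role": "user", "content": "Summarize this job description..."}],
        provider=LLMProvider.OLLAMA,
        options={"temperature": 0.7, "num_ctx": 8192, "seed": 42},
    )
    print(response.content)
    print(response.get_usage_dict())  # standardized UsageStats as a dict

asyncio.run(ask())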
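
Embeddings requested through the Anthropic provider now reroute to the Ollama adapter, swapping a Claude model name for OLLAMA_EMBEDDING_MODEL (default "mxbai-embed-large"). A sketch under the same assumptions as above; it also requires a running Ollama instance with that embedding model pulled.

import asyncio
from utils.llm_proxy import LLMManager, LLMProvider

async def embed() -> None:
    proxy = LLMManager().get_proxy()
    response = await proxy.embeddings(
        model="claude-3-5-sonnet-20241022",  # Claude name triggers the model substitution
        input_texts=["Python", "Kubernetes", "Prometheus"],
        provider=LLMProvider.ANTHROPIC,      # routed internally to LLMProvider.OLLAMA
    )
    vectors = response.get_embeddings()      # List[List[float]]
    print(len(vectors), len(vectors[0]))

asyncio.run(embed())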
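
The skill-match route now takes a SkillMatchRequest body rather than a bare skill string, so clients can set regenerate to bypass the cached assessment. A hypothetical client-side call: the base URL, route prefix, and auth token are placeholders; only the path suffix and body shape come from this diff.

import httpx

candidate_id = "abc123"  # placeholder
payload = {"skill": "Kubernetes", "regenerate": True}  # regenerate=True skips the cache lookup

with httpx.stream(
    "POST",
    f"https://backstory.example.com/candidates/{candidate_id}/skill-match",  # prefix assumed
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # placeholder auth
    timeout=None,
) as response:
    for line in response.iter_lines():
        print(line)  # streamed status/assessment messages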