Added regenerate option for skill-match requests

Fixed Anthropic backend (options handling and Ollama-based embeddings fallback)
James Ketr 2025-07-11 13:23:38 -07:00
parent a46172d696
commit fd6ec5015c
9 changed files with 365 additions and 143 deletions

View File

@ -26,6 +26,7 @@ from rag import start_file_watcher, ChromaDBFileWatcher
import defines
from logger import logger
from models import (
ChatResponse,
Tunables,
ChatMessageUser,
ChatMessage,
@ -34,6 +35,7 @@ from models import (
ApiStatusType,
Candidate,
ChatContextType,
UsageStats,
)
import utils.llm_proxy as llm_manager
from database.manager import RedisDatabase
@ -158,12 +160,15 @@ class CandidateEntity(Candidate):
raise ValueError("initialize() has not been called.")
return self.CandidateEntity__observer
def collect_metrics(self, agent: Agent, response):
def collect_metrics(self, agent: Agent, response: ChatResponse):
if not self.metrics:
logger.warning("No metrics collector set for this agent.")
return
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
if response.usage:
if response.usage.prompt_tokens:
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_tokens)
if response.usage.completion_tokens:
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.completion_tokens)
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
if self.CandidateEntity__initialized:
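For illustration, a minimal sketch of how a provider-agnostic ChatResponse now feeds the Prometheus counters; entity and agent are placeholder names, and only the model fields come from this commit:

# Hypothetical usage sketch, not code from the repository.
usage = UsageStats(prompt_eval_count=812, eval_count=143)
usage.calculate_derived_stats()   # maps the Ollama counts onto prompt_tokens/completion_tokens
response = ChatResponse(content="...", model="example-model", usage=usage)
entity.collect_metrics(agent=agent, response=response)
# tokens_prompt is incremented by 812 and tokens_eval by 143; either counter is skipped
# when its field is missing, which is what the None checks above guard against.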
@ -231,7 +236,6 @@ class Agent(BaseModel, ABC):
# Agent properties
system_prompt: str = ""
context_tokens: int = 0
# context_size is shared across all subclasses
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
@ -444,9 +448,6 @@ class Agent(BaseModel, ABC):
# message.metadata.eval_duration += response.eval_duration
# message.metadata.prompt_eval_count += response.prompt_eval_count
# message.metadata.prompt_eval_duration += response.prompt_eval_duration
# self.context_tokens = (
# response.prompt_eval_count + response.eval_count
# )
# message.status = "done"
# yield message
@ -617,7 +618,6 @@ Content: {content}
return
self.user.collect_metrics(agent=self, response=response)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
chat_message = ChatMessage(
session_id=session_id,
@ -626,10 +626,12 @@ Content: {content}
content=content,
metadata=ChatMessageMetaData(
options=options,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
usage=UsageStats(
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
),
),
)
yield chat_message
@ -848,7 +850,6 @@ Content: {content}
return
self.user.collect_metrics(agent=self, response=response)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
end_time = time.perf_counter()
chat_message = ChatMessage(
@ -858,10 +859,12 @@ Content: {content}
content=content,
metadata=ChatMessageMetaData(
options=options,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
usage=UsageStats(
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
),
rag_results=rag_message.content if rag_message else [],
llm_history=messages,
timers={

View File

@ -189,7 +189,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
return
# Stage 1A: Analyze job requirements
status_message = ChatMessageStatus(
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
session_id=session_id, content="Generating resume...", activity=ApiActivityType.THINKING
)
yield status_message

View File

@ -46,7 +46,7 @@ experiences, and qualifications required in a job description.
## INSTRUCTIONS:
1. Analyze ONLY the <|job_description|> provided, and provide only requirements from that description.
2. Extract company information, job title, and all requirements.
2. Extract company information, job title, pay range, and all requirements.
3. If a requirement can be broken into multiple requirements, do so.
4. Categorize each requirement into one and only one of the following categories:
- Technical skills (required and preferred)
@ -68,6 +68,8 @@ experiences, and qualifications required in a job description.
"company_name": "Company Name",
"job_title": "Job Title",
"job_summary": "Brief summary of the job",
"pay_minimum": "Minimum pay range",
"pay_maximum": "Maximum pay range",
"job_requirements": {
"technical_skills": {
"required": ["skill1", "skill2"],
@ -197,6 +199,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
company = ""
summary = ""
title = ""
pay_maximum = None
pay_minimum = None
try:
json_str = self.extract_json_from_text(generated_message.content)
requirements_json = json.loads(json_str)
@ -205,6 +209,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
title = requirements_json.get("job_title", "")
summary = requirements_json.get("job_summary", "")
job_requirements_data = requirements_json.get("job_requirements", None)
pay_minimum = requirements_json.get("pay_minimum", None)
pay_maximum = requirements_json.get("pay_maximum", None)
requirements = JobRequirements.model_validate(job_requirements_data)
except json.JSONDecodeError as e:
status_message.status = ApiStatusType.ERROR
@ -238,6 +244,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
company=company,
title=title,
summary=summary,
pay_minimum=pay_minimum,
pay_maximum=pay_maximum,
requirements=requirements,
description=prompt,
)
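As a small, hedged illustration of the new pay-range plumbing (values invented), the analysis JSON now carries pay_minimum/pay_maximum and they flow straight into the Job model:

# Hypothetical analysis output; only the field names come from this commit.
requirements_json = {
    "company_name": "Acme Corp",
    "job_title": "Staff Engineer",
    "job_summary": "Builds and scales the hiring platform.",
    "pay_minimum": "$150,000",
    "pay_maximum": "$190,000",
    "job_requirements": {"technical_skills": {"required": ["Python"], "preferred": ["FastAPI"]}},
}
pay_minimum = requirements_json.get("pay_minimum", None)   # "$150,000"
pay_maximum = requirements_json.get("pay_maximum", None)   # "$190,000"
# Both values are passed through to Job(pay_minimum=..., pay_maximum=...) and serialized
# as payMinimum/payMaximum thanks to the new field aliases.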

View File

@ -13,12 +13,14 @@ from database.core import RedisDatabase
from .base import Agent, agent_registry
from models import (
ApiActivityType,
ApiMessage,
ChatMessage,
ChatMessageError,
ChatMessageRagSearch,
ChatMessageSkillAssessment,
ApiStatusType,
ChatMessageStatus,
EvidenceDetail,
SkillAssessment,
Tunables,
@ -162,6 +164,11 @@ JSON RESPONSE:"""
logger.info(f"🔍 RAG content retrieved {len(rag_context)} bytes of context")
system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)
status_message = ChatMessageStatus(
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating skill assessment..."
)
yield status_message
generated_message = None
async for generated_message in self.llm_one_shot(
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7

View File

@ -1,6 +1,6 @@
import os
ollama_api_url = "http://ollama:11434" # Default Ollama local endpoint
ollama_api_url = os.getenv("OPENAI_URL", "http://ollama:11434") # Default Ollama local endpoint
frontend_url = os.getenv("FRONTEND_URL", "https://backstory.ketrenos.com")

View File

@ -2,6 +2,7 @@ from typing import List, Dict, Optional, Any, Union, Literal, TypeVar, Annotated
from pydantic import BaseModel, Field, EmailStr, HttpUrl, model_validator, field_validator, ConfigDict
from datetime import datetime, UTC
from enum import Enum
import os
import uuid
from utils.auth_utils import (
validate_password_strength,
@ -800,6 +801,8 @@ class Job(BaseModel):
title: Optional[str]
summary: Optional[str]
company: Optional[str]
pay_minimum: Optional[str] = Field(default=None, alias=str("payMinimum"))
pay_maximum: Optional[str] = Field(default=None, alias=str("payMaximum"))
description: str
requirements: Optional[JobRequirements]
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias=str("createdAt"))
@ -886,6 +889,69 @@ class ChatOptions(BaseModel):
temperature: Optional[float] = Field(default=0.7) # Higher temperature to encourage tool usage
model_config = ConfigDict(populate_by_name=True)
# Enhanced usage statistics model
class UsageStats(BaseModel):
"""Comprehensive usage statistics across all providers"""
# Token counts (standardized across providers)
prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt")
completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion")
total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used")
# Ollama-specific detailed stats
prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt")
prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)")
eval_count: Optional[int] = Field(default=None, description="Number of tokens generated")
eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)")
total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)")
# Performance metrics
tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second")
prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed")
# Additional provider-specific stats
extra_stats: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Provider-specific additional statistics"
)
def calculate_derived_stats(self) -> None:
"""Calculate derived statistics where possible"""
# Calculate tokens per second for Ollama
if self.eval_count and self.eval_duration and self.eval_duration > 0:
# Convert nanoseconds to seconds and calculate tokens/sec
duration_seconds = self.eval_duration / 1_000_000_000
self.tokens_per_second = self.eval_count / duration_seconds
# Calculate prompt processing speed for Ollama
if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0:
duration_seconds = self.prompt_eval_duration / 1_000_000_000
self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds
# Standardize token counts across providers
if not self.total_tokens and self.prompt_tokens and self.completion_tokens:
self.total_tokens = self.prompt_tokens + self.completion_tokens
# Map Ollama counts to standard format if not already set
if not self.prompt_tokens and self.prompt_eval_count:
self.prompt_tokens = self.prompt_eval_count
if not self.completion_tokens and self.eval_count:
self.completion_tokens = self.eval_count
class ChatResponse(BaseModel):
content: str
model: str
finish_reason: Optional[str] = Field(default="")
usage: Optional[UsageStats] = Field(default=None)
# Keep legacy usage field for backward compatibility
usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict")
def get_usage_dict(self) -> Dict[str, Any]:
"""Get usage statistics as dictionary for backward compatibility"""
if self.usage:
return self.usage.model_dump(exclude_none=True)
return self.usage_legacy or {}
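A brief sketch of what calculate_derived_stats yields for a typical Ollama-style response (numbers invented):

# Hypothetical numbers; demonstrates the derivations defined above.
stats = UsageStats(
    prompt_eval_count=500,
    prompt_eval_duration=1_000_000_000,   # 1 s in nanoseconds
    eval_count=200,
    eval_duration=4_000_000_000,          # 4 s in nanoseconds
)
stats.calculate_derived_stats()
assert stats.prompt_tokens == 500                 # mapped from prompt_eval_count
assert stats.completion_tokens == 200             # mapped from eval_count
assert stats.prompt_tokens_per_second == 500.0    # 500 tokens / 1 s
assert stats.tokens_per_second == 50.0            # 200 tokens / 4 s
# Note: total_tokens stays None here, because the standardization step runs before the
# Ollama counts are mapped onto prompt_tokens/completion_tokens.
print(ChatResponse(content="ok", model="example-model", usage=stats).get_usage_dict())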
# Add rate limiting configuration models
class RateLimitConfig(BaseModel):
@ -1063,15 +1129,17 @@ class ChatMessageMetaData(BaseModel):
stop_sequences: List[str] = Field(default=[], alias=str("stopSequences"))
rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias=str("ragResults"))
llm_history: List[LLMMessage] = Field(default_factory=list, alias=str("llmHistory"))
eval_count: int = 0
eval_duration: int = 0
prompt_eval_count: int = 0
prompt_eval_duration: int = 0
usage: UsageStats = Field(default_factory=UsageStats)
options: Optional[ChatOptions] = None
tools: Optional[Tool] = None
timers: Dict[str, float] = Field(default_factory=dict)
model_config = ConfigDict(populate_by_name=True)
class SkillMatchRequest(BaseModel):
"""Request model for skill match"""
skill: str
regenerate: bool = Field(default=False, description="Whether to regenerate the skill match even if cached")
class ChatMessageUser(ApiMessage):
type: ApiMessageType = ApiMessageType.TEXT
@ -1126,6 +1194,7 @@ class SystemInfo(BaseModel):
installed_RAM: str = Field(..., alias=str("installedRAM"))
graphics_cards: List[GPUInfo] = Field(..., alias=str("graphicsCards"))
CPU: str
llm_backend: str = Field(default=os.getenv("DEFAULT_LLM_PROVIDER", "openai"), alias=str("llmBackend"))
llm_model: str = Field(default=defines.model, alias=str("llmModel"))
embedding_model: str = Field(default=defines.embedding_model, alias=str("embeddingModel"))
max_context_length: int = Field(default=defines.max_context, alias=str("maxContextLength"))

View File

@ -59,6 +59,7 @@ from models import (
RagContentMetadata,
RagContentResponse,
SkillAssessment,
SkillMatchRequest,
SkillStrength,
UserType,
)
@ -1545,11 +1546,10 @@ async def post_job_analysis(
logger.error(f"❌ Get candidate job analysis error: {e}")
return JSONResponse(status_code=500, content=create_error_response("JOB_ANALYSIS_ERROR", str(e)))
@router.post("/{candidate_id}/skill-match")
async def get_candidate_skill_match(
candidate_id: str = Path(...),
skill: str = Body(...),
request: SkillMatchRequest = Body(...),
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
) -> StreamingResponse:
@ -1567,10 +1567,14 @@ async def get_candidate_skill_match(
candidate = Candidate.model_validate(candidate_data)
skill = request.skill.strip()
cache_key = get_skill_cache_key(candidate.id, skill)
# Get cached assessment if it exists
assessment: SkillAssessment | None = await database.get_cached_skill_match(cache_key)
assessment: SkillAssessment | None = None
if not request.regenerate:
assessment = await database.get_cached_skill_match(cache_key)
if assessment and assessment.skill.lower() != skill.lower():
logger.warning(
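A hedged client-side sketch of the reworked endpoint; the host and route prefix below are assumptions, since only the /{candidate_id}/skill-match path appears in this commit:

# Hypothetical client call using httpx; URL prefix and auth handling are assumptions.
import asyncio
import httpx

async def main() -> None:
    candidate_id = "abc123"                                  # placeholder
    payload = {"skill": "Kubernetes", "regenerate": True}    # True bypasses the cached assessment
    url = f"https://backstory.ketrenos.com/api/candidates/{candidate_id}/skill-match"
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload) as response:
            async for line in response.aiter_lines():        # the endpoint returns a StreamingResponse
                print(line)

asyncio.run(main())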

View File

@ -2,6 +2,7 @@ import defines
import re
import subprocess
import math
import os
from models import GPUInfo, SystemInfo
@ -83,6 +84,7 @@ def system_info() -> SystemInfo:
installed_RAM=get_installed_ram(),
graphics_cards=get_graphics_cards(),
CPU=get_cpu_info(),
llm_backend=os.getenv("DEFAULT_LLM_PROVIDER", "openai"),
llm_model=defines.model,
embedding_model=defines.embedding_model,
max_context_length=defines.max_context,
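A small hedged sketch of the effect on the system info payload (other fields elided):

# Hypothetical: with DEFAULT_LLM_PROVIDER=anthropic set, system_info() now reports the backend.
info = system_info()
print(info.llm_backend)   # "anthropic" (defaults to "openai" when the variable is unset)
print(info.llm_model)     # defines.model, unchanged; serialized as llmBackend/llmModel via aliases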

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
import time
import traceback
from typing import Dict, List, Any, AsyncGenerator, Optional, Union, cast
from pydantic import BaseModel, Field
from enum import Enum
@ -8,6 +9,7 @@ from dataclasses import dataclass
import os
from redis.asyncio import Redis
from models import ChatResponse, UsageStats
import defines
from logger import logger
import hashlib
@ -48,55 +50,6 @@ class LLMMessage:
return {"role": self.role, "content": self.content}
# Enhanced usage statistics model
class UsageStats(BaseModel):
"""Comprehensive usage statistics across all providers"""
# Token counts (standardized across providers)
prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt")
completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion")
total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used")
# Ollama-specific detailed stats
prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt")
prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)")
eval_count: Optional[int] = Field(default=None, description="Number of tokens generated")
eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)")
total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)")
# Performance metrics
tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second")
prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed")
# Additional provider-specific stats
extra_stats: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Provider-specific additional statistics"
)
def calculate_derived_stats(self) -> None:
"""Calculate derived statistics where possible"""
# Calculate tokens per second for Ollama
if self.eval_count and self.eval_duration and self.eval_duration > 0:
# Convert nanoseconds to seconds and calculate tokens/sec
duration_seconds = self.eval_duration / 1_000_000_000
self.tokens_per_second = self.eval_count / duration_seconds
# Calculate prompt processing speed for Ollama
if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0:
duration_seconds = self.prompt_eval_duration / 1_000_000_000
self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds
# Standardize token counts across providers
if not self.total_tokens and self.prompt_tokens and self.completion_tokens:
self.total_tokens = self.prompt_tokens + self.completion_tokens
# Map Ollama counts to standard format if not already set
if not self.prompt_tokens and self.prompt_eval_count:
self.prompt_tokens = self.prompt_eval_count
if not self.completion_tokens and self.eval_count:
self.completion_tokens = self.eval_count
# Embedding response models
class EmbeddingData(BaseModel):
"""Single embedding result"""
@ -123,21 +76,6 @@ class EmbeddingResponse(BaseModel):
return [item.embedding for item in self.data]
class ChatResponse(BaseModel):
content: str
model: str
finish_reason: Optional[str] = Field(default="")
usage: Optional[UsageStats] = Field(default=None)
# Keep legacy usage field for backward compatibility
usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict")
def get_usage_dict(self) -> Dict[str, Any]:
"""Get usage statistics as dictionary for backward compatibility"""
if self.usage:
return self.usage.model_dump(exclude_none=True)
return self.usage_legacy or {}
class LLMProvider(str, Enum):
OLLAMA = "ollama"
OPENAI = "openai"
@ -154,13 +92,23 @@ class BaseLLMAdapter(ABC):
@abstractmethod
async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self,
model: str,
messages: List[LLMMessage],
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send chat messages and get response"""
@abstractmethod
async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Generate text from prompt"""
@ -223,7 +171,7 @@ class OllamaAdapter(BaseLLMAdapter):
return usage_stats
async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self, model: str, messages: List[LLMMessage], stream: bool = False, options: Optional[dict] = None, **kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert LLMMessage objects to Ollama format
ollama_messages = []
@ -231,9 +179,11 @@ class OllamaAdapter(BaseLLMAdapter):
ollama_messages.append({"role": msg.role, "content": msg.content})
if stream:
return self._stream_chat(model, ollama_messages, **kwargs)
return self._stream_chat(model, ollama_messages, options=options, **kwargs)
else:
response = await self.client.chat(model=model, messages=ollama_messages, stream=False, **kwargs)
response = await self.client.chat(
model=model, messages=ollama_messages, stream=False, options=options, **kwargs
)
usage_stats = self._create_usage_stats(response.model_dump())
@ -244,9 +194,9 @@ class OllamaAdapter(BaseLLMAdapter):
usage=usage_stats,
)
async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
async def _stream_chat(self, model: str, messages: List[Dict], options: Optional[dict] = None, **kwargs):
# Await the chat call first, then iterate over the result
stream = await self.client.chat(model=model, messages=messages, stream=True, **kwargs)
stream = await self.client.chat(model=model, messages=messages, stream=True, options=options, **kwargs)
# Accumulate stats for final chunk
accumulated_stats = {}
@ -265,12 +215,17 @@ class OllamaAdapter(BaseLLMAdapter):
yield ChatResponse(content=content, model=model, finish_reason=None, usage=None)
async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
if stream:
return self._stream_generate(model, prompt, **kwargs)
return self._stream_generate(model, prompt, options=options, **kwargs)
else:
response = await self.client.generate(model=model, prompt=prompt, stream=False, **kwargs)
response = await self.client.generate(model=model, prompt=prompt, stream=False, options=options, **kwargs)
usage_stats = self._create_usage_stats(response.model_dump())
@ -278,9 +233,9 @@ class OllamaAdapter(BaseLLMAdapter):
content=response["response"], model=model, finish_reason=response.get("done_reason"), usage=usage_stats
)
async def _stream_generate(self, model: str, prompt: str, **kwargs):
async def _stream_generate(self, model: str, prompt: str, options: Optional[dict] = None, **kwargs):
# Await the generate call first, then iterate over the result
stream = await self.client.generate(model=model, prompt=prompt, stream=True, **kwargs)
stream = await self.client.generate(model=model, prompt=prompt, stream=True, options=options, **kwargs)
accumulated_stats = {}
@ -352,7 +307,7 @@ class OpenAIAdapter(BaseLLMAdapter):
)
async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self, model: str, messages: List[LLMMessage], stream: bool = False, options: Optional[dict] = None, **kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert LLMMessage objects to OpenAI format
openai_messages = []
@ -360,10 +315,16 @@ class OpenAIAdapter(BaseLLMAdapter):
openai_messages.append({"role": msg.role, "content": msg.content})
if stream:
return self._stream_chat(model, openai_messages, **kwargs)
return self._stream_chat(model, openai_messages, options=options, **kwargs)
else:
response = await self.client.chat.completions.create(
model=model, messages=openai_messages, stream=False, **kwargs
model=model,
messages=openai_messages,
stream=False,
temperature=options.get("temperature") if options else None,
max_tokens=options.get("num_ctx") if options else None,
seed=options.get("seed") if options else None,
**kwargs,
)
usage_stats = self._create_usage_stats(response.usage)
@ -394,11 +355,16 @@ class OpenAIAdapter(BaseLLMAdapter):
)
async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert to chat format for OpenAI
messages = [LLMMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
return await self.chat(model, messages, stream, options=options, **kwargs)
async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
"""Generate embeddings using OpenAI"""
@ -448,8 +414,47 @@ class AnthropicAdapter(BaseLLMAdapter):
extra_stats={k: v for k, v in usage_dict.items() if k not in ["input_tokens", "output_tokens"]},
)
def _prepare_anthropic_kwargs(self, options: Optional[dict], **kwargs) -> dict:
"""Prepare kwargs for Anthropic API from ChatOptions and additional kwargs"""
anthropic_kwargs = dict(kwargs)
if options:
# Extract from dict
temperature = options.get("temperature")
num_ctx = options.get("num_ctx") or options.get("numCtx") # Handle both field names
seed = options.get("seed")
# Map to Anthropic parameters
if temperature is not None:
anthropic_kwargs["temperature"] = temperature
if num_ctx is not None:
# num_ctx maps to max_tokens for Anthropic
# Anthropic has a max context of 8192 tokens
if num_ctx > 8192:
logger.warning(
f"num_ctx ({num_ctx}) exceeds Anthropic's max context length (8192). "
"Setting max_tokens to 8192."
)
anthropic_kwargs["max_tokens"] = min(8192, num_ctx)
# Anthropic doesn't support seed, but we can try passing it
if seed is not None:
logger.warning("Anthropic does not officially support 'seed' parameter. ")
# Set default max_tokens if not provided
if "max_tokens" not in anthropic_kwargs:
anthropic_kwargs["max_tokens"] = 1000
return anthropic_kwargs
async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self,
model: str,
messages: List[LLMMessage],
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Anthropic requires system message to be separate
system_message = None
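To make the option mapping concrete, a hedged sketch of what _prepare_anthropic_kwargs produces for an Ollama-style options dict (values invented; adapter is an AnthropicAdapter instance):

# Hypothetical illustration of the mapping implemented above.
options = {"temperature": 0.2, "num_ctx": 16384, "seed": 42}
prepared = adapter._prepare_anthropic_kwargs(options)
# prepared == {"temperature": 0.2, "max_tokens": 8192}
#   - num_ctx is clamped to the 8192-token cap used above (a warning is logged)
#   - seed is dropped with a warning, since Anthropic does not support it
# With options omitted entirely, max_tokens defaults to 1000.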
@ -461,11 +466,13 @@ class AnthropicAdapter(BaseLLMAdapter):
else:
anthropic_messages.append({"role": msg.role, "content": msg.content})
# Prepare kwargs with options
prepared_kwargs = self._prepare_anthropic_kwargs(options, **kwargs)
request_kwargs = {
"model": model,
"messages": anthropic_messages,
"max_tokens": kwargs.pop("max_tokens", 1000),
**kwargs,
**prepared_kwargs,
}
if system_message:
@ -511,10 +518,15 @@ class AnthropicAdapter(BaseLLMAdapter):
logger.info(f"Could not retrieve final usage stats: {e}")
async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
messages = [LLMMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
return await self.chat(model, messages, stream, options, **kwargs)
async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
"""Anthropic doesn't provide embeddings API"""
@ -524,7 +536,12 @@ class AnthropicAdapter(BaseLLMAdapter):
async def list_models(self) -> List[str]:
# Anthropic doesn't have a list models endpoint, return known models
return ["claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", "claude-3-opus-20240229"]
return [
"claude-3-5-haiku-latest",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229",
]
class GeminiAdapter(BaseLLMAdapter):
@ -564,7 +581,7 @@ class GeminiAdapter(BaseLLMAdapter):
return usage_stats
async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self, model: str, messages: List[LLMMessage], stream: bool = False, options=Optional[dict], **kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
model_instance = self.genai.GenerativeModel(model) # type: ignore
@ -629,7 +646,12 @@ class GeminiAdapter(BaseLLMAdapter):
)
async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
messages = [LLMMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
@ -737,6 +759,7 @@ class LLMCacheMixin:
model: str,
messages_or_prompt: Union[List[LLMMessage], str],
provider: Optional[LLMProvider] = None,
options: Optional[dict] = None,
**kwargs,
) -> str:
"""Create consistent cache key from parameters"""
@ -752,7 +775,13 @@ class LLMCacheMixin:
serializable_messages.append(msg)
messages_or_prompt = serializable_messages
cache_data = {"method": method, "model": model, "content": messages_or_prompt, "kwargs": cache_kwargs}
cache_data = {
"method": method,
"model": model,
"content": messages_or_prompt,
"options": options,
"kwargs": cache_kwargs,
}
if self.cache_config.provider_specific and provider:
cache_data["provider"] = provider.value
@ -935,7 +964,7 @@ class LLMCacheMixin:
class UnifiedLLMProxy(LLMCacheMixin):
"""Main proxy class that provides unified interface to all LLM providers"""
"""Main proxy class with automatic Ollama fallback for Anthropic embeddings"""
def __init__(self, redis_client: Redis, default_provider: LLMProvider = LLMProvider.OLLAMA):
# Initialize with caching support
@ -1019,6 +1048,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
messages: Union[List[LLMMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Enhanced chat method with streaming cache support"""
@ -1037,7 +1067,12 @@ class UnifiedLLMProxy(LLMCacheMixin):
# Create cache key
cache_key = self._create_cache_key(
method="chat", model=model, messages_or_prompt=normalized_messages, provider=provider, **kwargs
method="chat",
model=model,
messages_or_prompt=normalized_messages,
provider=provider,
options=options,
**kwargs,
)
# Check cache
@ -1059,7 +1094,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
if stream:
# For streaming: we know this will return AsyncGenerator
stream_response = await adapter.chat(model, normalized_messages, stream=True, **kwargs)
stream_response = await adapter.chat(model, normalized_messages, stream=True, options=options, **kwargs)
stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response)
if self.cache_config.enabled:
@ -1068,7 +1103,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
return stream_response
else:
# For non-streaming: we know this will return ChatResponse
single_response = await adapter.chat(model, normalized_messages, stream=False, **kwargs)
single_response = await adapter.chat(model, normalized_messages, stream=False, options=options, **kwargs)
single_response = cast(ChatResponse, single_response)
if self.cache_config.enabled:
@ -1082,6 +1117,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
prompt: str,
provider: Optional[LLMProvider] = None,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Enhanced generate method with streaming cache support"""
@ -1114,7 +1150,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
if stream:
# For streaming: we know this will return AsyncGenerator
stream_response = await adapter.generate(model, prompt, stream=True, **kwargs)
stream_response = await adapter.generate(model, prompt, stream=True, options=options, **kwargs)
stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response)
if self.cache_config.enabled:
@ -1125,7 +1161,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
return stream_response
else:
# For non-streaming: we know this will return ChatResponse
single_response = await adapter.generate(model, prompt, stream=False, **kwargs)
single_response = await adapter.generate(model, prompt, stream=False, options=options, **kwargs)
single_response = cast(ChatResponse, single_response)
if self.cache_config.enabled:
await self._set_cached_response(cache_key, single_response)
@ -1145,6 +1181,13 @@ class UnifiedLLMProxy(LLMCacheMixin):
if provider in adapter_classes:
self.adapters[provider] = adapter_classes[provider](**config)
self._initialized_providers.add(provider)
# If configuring Anthropic, ensure Ollama is also configured for embeddings
if provider == LLMProvider.ANTHROPIC and LLMProvider.OLLAMA not in self._initialized_providers:
logger.warning(
"Anthropic provider configured but Ollama is not available for embeddings. "
"Embedding operations with Anthropic provider will fail until Ollama is configured."
)
else:
raise ValueError(f"Unsupported provider: {provider}")
@ -1160,12 +1203,13 @@ class UnifiedLLMProxy(LLMCacheMixin):
messages: Union[List[LLMMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
stream: bool = True,
options: Optional[dict] = None,
**kwargs,
) -> AsyncGenerator[ChatResponse, None]:
"""Stream chat messages using specified or default provider"""
if stream is False:
raise ValueError("stream must be True for chat_stream")
result = await self.chat(model, messages, provider, stream=True, **kwargs)
result = await self.chat(model, messages, provider, stream=True, options=options, **kwargs)
if isinstance(result, ChatResponse):
raise RuntimeError("Expected AsyncGenerator, got ChatResponse")
async for chunk in result:
@ -1176,32 +1220,43 @@ class UnifiedLLMProxy(LLMCacheMixin):
model: str,
messages: Union[List[LLMMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
options: Optional[dict] = None,
**kwargs,
) -> ChatResponse:
"""Get single chat response using specified or default provider"""
result = await self.chat(model, messages, provider, stream=False, **kwargs)
result = await self.chat(model, messages, provider, stream=False, options=options, **kwargs)
if not isinstance(result, ChatResponse):
raise RuntimeError("Expected ChatResponse, got AsyncGenerator")
return result
async def generate_stream(
self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
options: Optional[dict] = None,
**kwargs,
) -> AsyncGenerator[ChatResponse, None]:
"""Stream text generation using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
result = await self.generate(model, prompt, provider, stream=True, options=options, **kwargs)
if isinstance(result, ChatResponse):
raise RuntimeError("Expected AsyncGenerator, got ChatResponse")
async for chunk in result:
yield chunk
async def generate_single(
self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
options: Optional[dict] = None,
**kwargs,
) -> ChatResponse:
"""Get single generation response using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
result = await self.generate(model, prompt, provider, stream=False, options=options, **kwargs)
if not isinstance(result, ChatResponse):
raise RuntimeError("Expected ChatResponse, got AsyncGenerator")
return result
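As a usage sketch of the options plumbing, both the single-shot and streaming helpers now accept an Ollama-style options dict, and that dict is folded into the cache key; llm is assumed to be a configured UnifiedLLMProxy, called from async code:

# Hypothetical call site; only the signatures and option names come from this commit.
messages = [LLMMessage(role="user", content="Summarize this resume in three bullets.")]
response = await llm.chat_single(
    model="example-model",                                   # placeholder model name
    messages=messages,
    options={"temperature": 0.7, "num_ctx": 8192, "seed": 1},
)
print(response.content, response.usage)

async for chunk in llm.chat_stream(model="example-model", messages=messages,
                                   options={"temperature": 0.7}):
    print(chunk.content, end="")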
@ -1209,11 +1264,52 @@ class UnifiedLLMProxy(LLMCacheMixin):
async def embeddings(
self, model: str, input_texts: Union[str, List[str]], provider: Optional[LLMProvider] = None, **kwargs
) -> EmbeddingResponse:
"""Generate embeddings using specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
"""Generate embeddings using specified or default provider.
return await adapter.embeddings(model, input_texts, **kwargs)
For Anthropic provider, automatically routes to Ollama since Anthropic doesn't provide embeddings API.
"""
provider = provider or self.default_provider
# Special handling: If Anthropic is requested for embeddings, use Ollama instead
if provider == LLMProvider.ANTHROPIC:
logger.info("Anthropic provider requested for embeddings - routing to Ollama provider")
# Ensure Ollama is available
if LLMProvider.OLLAMA not in self._initialized_providers:
raise RuntimeError(
"Anthropic provider requires Ollama for embeddings, but Ollama is not configured. "
"Please ensure Ollama endpoint is properly configured."
)
# Use Ollama for embeddings with a default embedding model if not specified
embedding_provider = LLMProvider.OLLAMA
# If no specific model provided, use a default Ollama embedding model
if not model or model in [
"claude-3-5-haiku-latest",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229",
]:
# Default to a common Ollama embedding model
model = os.getenv("OLLAMA_EMBEDDING_MODEL", "mxbai-embed-large")
logger.info(f"Using Ollama embedding model: {model}")
else:
embedding_provider = provider
adapter = self._get_adapter(embedding_provider)
try:
return await adapter.embeddings(model, input_texts, **kwargs)
except Exception as e:
if embedding_provider == LLMProvider.OLLAMA and provider == LLMProvider.ANTHROPIC:
raise RuntimeError(
f"Failed to generate embeddings using Ollama (fallback for Anthropic): {e}. "
f"Please ensure Ollama is running and the embedding model '{model}' is available."
) from e
else:
raise
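A hedged sketch of the new embedding fallback: even when Anthropic is the selected provider, embedding requests are served by Ollama (llm as above, called from async code):

# Hypothetical call; assumes both the Anthropic and Ollama providers are configured.
result = await llm.embeddings(
    model="claude-3-5-sonnet-20241022",      # an Anthropic model name triggers the fallback
    input_texts=["resume text", "job description"],
    provider=LLMProvider.ANTHROPIC,
)
# The request is routed to Ollama with "mxbai-embed-large" (or OLLAMA_EMBEDDING_MODEL), and a
# RuntimeError with a descriptive message is raised if Ollama is not configured or the call fails.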
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
"""List available models for specified or default provider"""
@ -1224,21 +1320,32 @@ class UnifiedLLMProxy(LLMCacheMixin):
return await adapter.list_models()
async def list_embedding_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
"""List available embedding models for specified or default provider"""
"""List available embedding models for specified or default provider.
For Anthropic provider, returns Ollama embedding models since that's what will be used.
"""
provider = provider or self.default_provider
# Provider-specific embedding models
embedding_models = {
LLMProvider.OLLAMA: [
"mxbai-embed-large",
"nomic-embed-text",
"all-minilm",
],
LLMProvider.OPENAI: ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
LLMProvider.ANTHROPIC: [], # No embeddings API
LLMProvider.ANTHROPIC: [], # Will be handled specially below
LLMProvider.GEMINI: ["models/embedding-001", "models/text-embedding-004"],
}
# Special handling for Anthropic - return Ollama models since that's what will be used
if provider == LLMProvider.ANTHROPIC:
raise NotImplementedError("Anthropic does not provide embeddings API")
logger.info("Listing Ollama embedding models (used for Anthropic embedding requests)")
if LLMProvider.OLLAMA not in self._initialized_providers:
raise RuntimeError("Anthropic provider requires Ollama for embeddings, but Ollama is not configured.")
return await self.list_embedding_models(LLMProvider.OLLAMA)
# For Ollama, check which embedding models are actually available
if provider == LLMProvider.OLLAMA:
@ -1266,7 +1373,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
class LLMManager:
"""Updated manager with caching support"""
"""Updated manager with caching support and required Ollama endpoint"""
_instance = None
_proxy = None
@ -1283,24 +1390,42 @@ class LLMManager:
self._configure_from_environment()
def _configure_from_environment(self):
"""Configure providers and cache from environment variables"""
"""Configure providers and cache from environment variables - Ollama is now required"""
if not self._proxy:
raise RuntimeError("UnifiedLLMProxy instance not initialized")
# Your existing configuration code...
# REQUIRED: Configure Ollama provider first (mandatory)
ollama_host = os.getenv("OLLAMA_HOST", defines.ollama_api_url)
self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
if not ollama_host:
raise ValueError(
"Ollama endpoint is required. Please set OLLAMA_HOST environment variable "
"or ensure defines.ollama_api_url is configured."
)
logger.info(f"Configuring Ollama provider (REQUIRED) with host: {ollama_host}")
try:
self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
except Exception as e:
raise RuntimeError(f"Failed to configure required Ollama provider: {e}")
# Configure other providers as before
if os.getenv("OPENAI_API_KEY"):
logger.info("Configuring OpenAI provider")
self._proxy.configure_provider(LLMProvider.OPENAI)
if os.getenv("ANTHROPIC_API_KEY"):
logger.info("Configuring Anthropic provider (will use Ollama for embeddings)")
self._proxy.configure_provider(LLMProvider.ANTHROPIC)
if os.getenv("GEMINI_API_KEY"):
logger.info("Configuring Gemini provider")
self._proxy.configure_provider(LLMProvider.GEMINI)
# Add cache configuration from environment
# Validate that Ollama is properly configured
if LLMProvider.OLLAMA not in self._proxy._initialized_providers:
raise RuntimeError("Ollama provider must be successfully configured")
# Add cache configuration from environment (unchanged)
cache_config = {}
if os.getenv("LLM_CACHE_TTL"):
cache_config["ttl"] = int(os.getenv("LLM_CACHE_TTL", 31536000))
@ -1319,11 +1444,15 @@ class LLMManager:
if cache_config:
asyncio.create_task(self._proxy.configure_cache(**cache_config))
# Your existing default provider setup...
# Set default provider
default_provider = os.getenv("DEFAULT_LLM_PROVIDER", "ollama")
try:
logger.info(f"Setting default LLM provider to: {default_provider}")
self._proxy.set_default_provider(LLMProvider(default_provider))
except ValueError:
except ValueError as e:
logger.error(traceback.format_exc())
logger.error(f"Invalid default provider: {default_provider}. Error: {e}")
logger.info(f"Setting default LLM provider to fallback: {LLMProvider.OLLAMA}")
self._proxy.set_default_provider(LLMProvider.OLLAMA)
def get_proxy(self) -> UnifiedLLMProxy:
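For reference, a hedged sketch of the environment _configure_from_environment now expects; only the variable names appear in this commit, the values are examples:

# Hypothetical environment for LLMManager; names are from this commit, values are examples.
import os

os.environ.setdefault("OLLAMA_HOST", "http://ollama:11434")            # falls back to defines.ollama_api_url
os.environ.setdefault("ANTHROPIC_API_KEY", "sk-ant-example")           # enables the Anthropic provider
os.environ.setdefault("DEFAULT_LLM_PROVIDER", "anthropic")             # invalid values fall back to ollama
os.environ.setdefault("OLLAMA_EMBEDDING_MODEL", "mxbai-embed-large")   # used for Anthropic embedding requests
os.environ.setdefault("LLM_CACHE_TTL", "86400")                        # optional cache TTL (seconds assumed)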