Added regen
Fixed Anthropic backend
parent a46172d696
commit fd6ec5015c
@@ -26,6 +26,7 @@ from rag import start_file_watcher, ChromaDBFileWatcher
import defines
from logger import logger
from models import (
ChatResponse,
Tunables,
ChatMessageUser,
ChatMessage,
@@ -34,6 +35,7 @@ from models import (
ApiStatusType,
Candidate,
ChatContextType,
UsageStats,
)
import utils.llm_proxy as llm_manager
from database.manager import RedisDatabase
@@ -158,12 +160,15 @@ class CandidateEntity(Candidate):
raise ValueError("initialize() has not been called.")
return self.CandidateEntity__observer

def collect_metrics(self, agent: Agent, response):
def collect_metrics(self, agent: Agent, response: ChatResponse):
if not self.metrics:
logger.warning("No metrics collector set for this agent.")
return
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
if response.usage:
if response.usage.prompt_tokens:
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_tokens)
if response.usage.completion_tokens:
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.completion_tokens)

async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
if self.CandidateEntity__initialized:
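For illustration, a minimal sketch of how the counters used by collect_metrics() above could be registered with prometheus_client. Only the tokens_prompt/tokens_eval names and the labels(agent=...).inc(...) pattern come from this change; the Metrics wrapper, registry wiring, agent label value, and usage numbers are assumptions.

from prometheus_client import CollectorRegistry, Counter

class Metrics:
    def __init__(self, registry: CollectorRegistry):
        # One counter per token direction, labeled by agent type
        self.tokens_prompt = Counter("tokens_prompt", "Prompt tokens consumed", ["agent"], registry=registry)
        self.tokens_eval = Counter("tokens_eval", "Completion tokens generated", ["agent"], registry=registry)

registry = CollectorRegistry()
metrics = Metrics(registry)

# After a ChatResponse comes back, usage may be partially populated, so each
# field is checked before incrementing (mirroring the guarded code above).
usage = {"prompt_tokens": 1200, "completion_tokens": 85}  # invented values
if usage.get("prompt_tokens"):
    metrics.tokens_prompt.labels(agent="skill_match").inc(usage["prompt_tokens"])
if usage.get("completion_tokens"):
    metrics.tokens_eval.labels(agent="skill_match").inc(usage["completion_tokens"])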
@@ -231,7 +236,6 @@ class Agent(BaseModel, ABC):

# Agent properties
system_prompt: str = ""
context_tokens: int = 0

# context_size is shared across all subclasses
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
@@ -444,9 +448,6 @@ class Agent(BaseModel, ABC):
# message.metadata.eval_duration += response.eval_duration
# message.metadata.prompt_eval_count += response.prompt_eval_count
# message.metadata.prompt_eval_duration += response.prompt_eval_duration
# self.context_tokens = (
# response.prompt_eval_count + response.eval_count
# )
# message.status = "done"
# yield message

@@ -617,7 +618,6 @@ Content: {content}
return

self.user.collect_metrics(agent=self, response=response)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count

chat_message = ChatMessage(
session_id=session_id,
@@ -626,10 +626,12 @@ Content: {content}
content=content,
metadata=ChatMessageMetaData(
options=options,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
usage=UsageStats(
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
),
),
)
yield chat_message
@@ -848,7 +850,6 @@ Content: {content}
return

self.user.collect_metrics(agent=self, response=response)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
end_time = time.perf_counter()

chat_message = ChatMessage(
@@ -858,10 +859,12 @@ Content: {content}
content=content,
metadata=ChatMessageMetaData(
options=options,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
usage=UsageStats(
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
),
rag_results=rag_message.content if rag_message else [],
llm_history=messages,
timers={

@@ -189,7 +189,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
return
# Stage 1A: Analyze job requirements
status_message = ChatMessageStatus(
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
session_id=session_id, content="Generating resume...", activity=ApiActivityType.THINKING
)
yield status_message

@@ -46,7 +46,7 @@ experiences, and qualifications required in a job description.
## INSTRUCTIONS:

1. Analyze ONLY the <|job_description|> provided, and provide only requirements from that description.
2. Extract company information, job title, and all requirements.
2. Extract company information, job title, pay range, and all requirements.
3. If a requirement can be broken into multiple requirements, do so.
4. Categorize each requirement into one and only one of the following categories:
- Technical skills (required and preferred)
@@ -68,6 +68,8 @@ experiences, and qualifications required in a job description.
"company_name": "Company Name",
"job_title": "Job Title",
"job_summary": "Brief summary of the job",
"pay_minimum": "Minimum pay range",
"pay_maximum": "Maximum pay range",
"job_requirements": {
"technical_skills": {
"required": ["skill1", "skill2"],
@@ -197,6 +199,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
company = ""
summary = ""
title = ""
pay_maximum = None
pay_minimum = None
try:
json_str = self.extract_json_from_text(generated_message.content)
requirements_json = json.loads(json_str)
@@ -205,6 +209,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
title = requirements_json.get("job_title", "")
summary = requirements_json.get("job_summary", "")
job_requirements_data = requirements_json.get("job_requirements", None)
pay_minimum = requirements_json.get("pay_minimum", None)
pay_maximum = requirements_json.get("pay_maximum", None)
requirements = JobRequirements.model_validate(job_requirements_data)
except json.JSONDecodeError as e:
status_message.status = ApiStatusType.ERROR
@@ -238,6 +244,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
company=company,
title=title,
summary=summary,
pay_minimum=pay_minimum,
pay_maximum=pay_maximum,
requirements=requirements,
description=prompt,
)

@@ -13,12 +13,14 @@ from database.core import RedisDatabase

from .base import Agent, agent_registry
from models import (
ApiActivityType,
ApiMessage,
ChatMessage,
ChatMessageError,
ChatMessageRagSearch,
ChatMessageSkillAssessment,
ApiStatusType,
ChatMessageStatus,
EvidenceDetail,
SkillAssessment,
Tunables,
@@ -162,6 +164,11 @@ JSON RESPONSE:"""
logger.info(f"🔍 RAG content retrieved {len(rag_context)} bytes of context")
system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)

status_message = ChatMessageStatus(
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating skill assessment..."
)
yield status_message

generated_message = None
async for generated_message in self.llm_one_shot(
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7

@@ -1,6 +1,6 @@
import os

ollama_api_url = "http://ollama:11434" # Default Ollama local endpoint
ollama_api_url = os.getenv("OPENAI_URL", "http://ollama:11434") # Default Ollama local endpoint

frontend_url = os.getenv("FRONTEND_URL", "https://backstory.ketrenos.com")

@@ -2,6 +2,7 @@ from typing import List, Dict, Optional, Any, Union, Literal, TypeVar, Annotated
from pydantic import BaseModel, Field, EmailStr, HttpUrl, model_validator, field_validator, ConfigDict
from datetime import datetime, UTC
from enum import Enum
import os
import uuid
from utils.auth_utils import (
validate_password_strength,
@@ -800,6 +801,8 @@ class Job(BaseModel):
title: Optional[str]
summary: Optional[str]
company: Optional[str]
pay_minimum: Optional[str] = Field(default=None, alias=str("payMinimum"))
pay_maximum: Optional[str] = Field(default=None, alias=str("payMaximum"))
description: str
requirements: Optional[JobRequirements]
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias=str("createdAt"))
@@ -886,6 +889,69 @@ class ChatOptions(BaseModel):
temperature: Optional[float] = Field(default=0.7) # Higher temperature to encourage tool usage
model_config = ConfigDict(populate_by_name=True)

# Enhanced usage statistics model
class UsageStats(BaseModel):
"""Comprehensive usage statistics across all providers"""

# Token counts (standardized across providers)
prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt")
completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion")
total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used")

# Ollama-specific detailed stats
prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt")
prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)")
eval_count: Optional[int] = Field(default=None, description="Number of tokens generated")
eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)")
total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)")

# Performance metrics
tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second")
prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed")

# Additional provider-specific stats
extra_stats: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Provider-specific additional statistics"
)

def calculate_derived_stats(self) -> None:
"""Calculate derived statistics where possible"""
# Calculate tokens per second for Ollama
if self.eval_count and self.eval_duration and self.eval_duration > 0:
# Convert nanoseconds to seconds and calculate tokens/sec
duration_seconds = self.eval_duration / 1_000_000_000
self.tokens_per_second = self.eval_count / duration_seconds

# Calculate prompt processing speed for Ollama
if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0:
duration_seconds = self.prompt_eval_duration / 1_000_000_000
self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds

# Standardize token counts across providers
if not self.total_tokens and self.prompt_tokens and self.completion_tokens:
self.total_tokens = self.prompt_tokens + self.completion_tokens

# Map Ollama counts to standard format if not already set
if not self.prompt_tokens and self.prompt_eval_count:
self.prompt_tokens = self.prompt_eval_count
if not self.completion_tokens and self.eval_count:
self.completion_tokens = self.eval_count


class ChatResponse(BaseModel):
content: str
model: str
finish_reason: Optional[str] = Field(default="")
usage: Optional[UsageStats] = Field(default=None)
# Keep legacy usage field for backward compatibility
usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict")

def get_usage_dict(self) -> Dict[str, Any]:
"""Get usage statistics as dictionary for backward compatibility"""
if self.usage:
return self.usage.model_dump(exclude_none=True)
return self.usage_legacy or {}
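A quick illustration of how the relocated UsageStats model behaves when fed Ollama-style nanosecond counters. The field names are the ones defined above; the numbers are invented, and the import assumes the models module shown in this commit.

from models import UsageStats

stats = UsageStats(
    prompt_eval_count=1024,
    prompt_eval_duration=2_000_000_000,  # 2 s expressed in nanoseconds
    eval_count=256,
    eval_duration=4_000_000_000,         # 4 s expressed in nanoseconds
)
stats.calculate_derived_stats()
print(stats.tokens_per_second)          # 64.0  (256 tokens / 4 s)
print(stats.prompt_tokens_per_second)   # 512.0 (1024 tokens / 2 s)
print(stats.prompt_tokens, stats.completion_tokens)  # 1024 256, mapped from the Ollama counters
# Note: total_tokens stays None on this pass, because the standardization step
# above runs before the Ollama counters are mapped into prompt/completion tokens.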
# Add rate limiting configuration models
class RateLimitConfig(BaseModel):
@@ -1063,15 +1129,17 @@ class ChatMessageMetaData(BaseModel):
stop_sequences: List[str] = Field(default=[], alias=str("stopSequences"))
rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias=str("ragResults"))
llm_history: List[LLMMessage] = Field(default_factory=list, alias=str("llmHistory"))
eval_count: int = 0
eval_duration: int = 0
prompt_eval_count: int = 0
prompt_eval_duration: int = 0
usage: UsageStats = Field(default_factory=UsageStats)
options: Optional[ChatOptions] = None
tools: Optional[Tool] = None
timers: Dict[str, float] = Field(default_factory=dict)
model_config = ConfigDict(populate_by_name=True)

class SkillMatchRequest(BaseModel):
"""Request model for skill match"""

skill: str
regenerate: bool = Field(default=False, description="Whether to regenerate the skill match even if cached")

class ChatMessageUser(ApiMessage):
type: ApiMessageType = ApiMessageType.TEXT
@@ -1126,6 +1194,7 @@ class SystemInfo(BaseModel):
installed_RAM: str = Field(..., alias=str("installedRAM"))
graphics_cards: List[GPUInfo] = Field(..., alias=str("graphicsCards"))
CPU: str
llm_backend: str = Field(default=os.getenv("DEFAULT_LLM_PROVIDER", "openai"), alias=str("llmBackend"))
llm_model: str = Field(default=defines.model, alias=str("llmModel"))
embedding_model: str = Field(default=defines.embedding_model, alias=str("embeddingModel"))
max_context_length: int = Field(default=defines.max_context, alias=str("maxContextLength"))

@@ -59,6 +59,7 @@ from models import (
RagContentMetadata,
RagContentResponse,
SkillAssessment,
SkillMatchRequest,
SkillStrength,
UserType,
)
@@ -1545,11 +1546,10 @@ async def post_job_analysis(
logger.error(f"❌ Get candidate job analysis error: {e}")
return JSONResponse(status_code=500, content=create_error_response("JOB_ANALYSIS_ERROR", str(e)))


@router.post("/{candidate_id}/skill-match")
async def get_candidate_skill_match(
candidate_id: str = Path(...),
skill: str = Body(...),
request: SkillMatchRequest = Body(...),
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
) -> StreamingResponse:
@@ -1567,10 +1567,14 @@ async def get_candidate_skill_match(

candidate = Candidate.model_validate(candidate_data)

skill = request.skill.strip()

cache_key = get_skill_cache_key(candidate.id, skill)

# Get cached assessment if it exists
assessment: SkillAssessment | None = await database.get_cached_skill_match(cache_key)
assessment: SkillAssessment | None = None
if not request.regenerate:
assessment = await database.get_cached_skill_match(cache_key)

if assessment and assessment.skill.lower() != skill.lower():
logger.warning(
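A sketch of what a client call against the reworked skill-match route might look like: the body is now a SkillMatchRequest object rather than a bare skill string, and regenerate=True bypasses the cached assessment. The base URL, route prefix, and candidate id are placeholders.

import httpx

payload = {"skill": "Kubernetes", "regenerate": True}  # regenerate forces a fresh assessment
with httpx.Client(base_url="http://localhost:8000") as client:
    # Response is a streaming body of status/assessment messages
    response = client.post("/candidates/<candidate_id>/skill-match", json=payload)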
@@ -2,6 +2,7 @@ import defines
import re
import subprocess
import math
import os
from models import GPUInfo, SystemInfo


@@ -83,6 +84,7 @@ def system_info() -> SystemInfo:
installed_RAM=get_installed_ram(),
graphics_cards=get_graphics_cards(),
CPU=get_cpu_info(),
llm_backend=os.getenv("DEFAULT_LLM_PROVIDER", "openai"),
llm_model=defines.model,
embedding_model=defines.embedding_model,
max_context_length=defines.max_context,

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
import time
import traceback
from typing import Dict, List, Any, AsyncGenerator, Optional, Union, cast
from pydantic import BaseModel, Field
from enum import Enum
@@ -8,6 +9,7 @@ from dataclasses import dataclass
import os

from redis.asyncio import Redis
from models import ChatResponse, UsageStats
import defines
from logger import logger
import hashlib
@@ -48,55 +50,6 @@ class LLMMessage:
return {"role": self.role, "content": self.content}


# Enhanced usage statistics model
class UsageStats(BaseModel):
"""Comprehensive usage statistics across all providers"""

# Token counts (standardized across providers)
prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt")
completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion")
total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used")

# Ollama-specific detailed stats
prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt")
prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)")
eval_count: Optional[int] = Field(default=None, description="Number of tokens generated")
eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)")
total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)")

# Performance metrics
tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second")
prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed")

# Additional provider-specific stats
extra_stats: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Provider-specific additional statistics"
)

def calculate_derived_stats(self) -> None:
"""Calculate derived statistics where possible"""
# Calculate tokens per second for Ollama
if self.eval_count and self.eval_duration and self.eval_duration > 0:
# Convert nanoseconds to seconds and calculate tokens/sec
duration_seconds = self.eval_duration / 1_000_000_000
self.tokens_per_second = self.eval_count / duration_seconds

# Calculate prompt processing speed for Ollama
if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0:
duration_seconds = self.prompt_eval_duration / 1_000_000_000
self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds

# Standardize token counts across providers
if not self.total_tokens and self.prompt_tokens and self.completion_tokens:
self.total_tokens = self.prompt_tokens + self.completion_tokens

# Map Ollama counts to standard format if not already set
if not self.prompt_tokens and self.prompt_eval_count:
self.prompt_tokens = self.prompt_eval_count
if not self.completion_tokens and self.eval_count:
self.completion_tokens = self.eval_count


# Embedding response models
class EmbeddingData(BaseModel):
"""Single embedding result"""
@@ -123,21 +76,6 @@ class EmbeddingResponse(BaseModel):
return [item.embedding for item in self.data]


class ChatResponse(BaseModel):
content: str
model: str
finish_reason: Optional[str] = Field(default="")
usage: Optional[UsageStats] = Field(default=None)
# Keep legacy usage field for backward compatibility
usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict")

def get_usage_dict(self) -> Dict[str, Any]:
"""Get usage statistics as dictionary for backward compatibility"""
if self.usage:
return self.usage.model_dump(exclude_none=True)
return self.usage_legacy or {}


class LLMProvider(str, Enum):
OLLAMA = "ollama"
OPENAI = "openai"
@@ -154,13 +92,23 @@ class BaseLLMAdapter(ABC):

@abstractmethod
async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self,
model: str,
messages: List[LLMMessage],
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send chat messages and get response"""

@abstractmethod
async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Generate text from prompt"""

@@ -223,7 +171,7 @@ class OllamaAdapter(BaseLLMAdapter):
return usage_stats

async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self, model: str, messages: List[LLMMessage], stream: bool = False, options: Optional[dict] = None, **kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert LLMMessage objects to Ollama format
ollama_messages = []
@@ -231,9 +179,11 @@ class OllamaAdapter(BaseLLMAdapter):
ollama_messages.append({"role": msg.role, "content": msg.content})

if stream:
return self._stream_chat(model, ollama_messages, **kwargs)
return self._stream_chat(model, ollama_messages, options=options, **kwargs)
else:
response = await self.client.chat(model=model, messages=ollama_messages, stream=False, **kwargs)
response = await self.client.chat(
model=model, messages=ollama_messages, stream=False, options=options, **kwargs
)

usage_stats = self._create_usage_stats(response.model_dump())

@@ -244,9 +194,9 @@ class OllamaAdapter(BaseLLMAdapter):
usage=usage_stats,
)

async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
async def _stream_chat(self, model: str, messages: List[Dict], options: Optional[dict] = None, **kwargs):
# Await the chat call first, then iterate over the result
stream = await self.client.chat(model=model, messages=messages, stream=True, **kwargs)
stream = await self.client.chat(model=model, messages=messages, stream=True, options=options, **kwargs)

# Accumulate stats for final chunk
accumulated_stats = {}
@@ -265,12 +215,17 @@ class OllamaAdapter(BaseLLMAdapter):
yield ChatResponse(content=content, model=model, finish_reason=None, usage=None)

async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
if stream:
return self._stream_generate(model, prompt, **kwargs)
return self._stream_generate(model, prompt, options=options, **kwargs)
else:
response = await self.client.generate(model=model, prompt=prompt, stream=False, **kwargs)
response = await self.client.generate(model=model, prompt=prompt, stream=False, options=options, **kwargs)

usage_stats = self._create_usage_stats(response.model_dump())

@@ -278,9 +233,9 @@ class OllamaAdapter(BaseLLMAdapter):
content=response["response"], model=model, finish_reason=response.get("done_reason"), usage=usage_stats
)

async def _stream_generate(self, model: str, prompt: str, **kwargs):
async def _stream_generate(self, model: str, prompt: str, options: Optional[dict] = None, **kwargs):
# Await the generate call first, then iterate over the result
stream = await self.client.generate(model=model, prompt=prompt, stream=True, **kwargs)
stream = await self.client.generate(model=model, prompt=prompt, stream=True, options=options, **kwargs)

accumulated_stats = {}

@@ -352,7 +307,7 @@ class OpenAIAdapter(BaseLLMAdapter):
)

async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self, model: str, messages: List[LLMMessage], stream: bool = False, options: Optional[dict] = None, **kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert LLMMessage objects to OpenAI format
openai_messages = []
@@ -360,10 +315,16 @@ class OpenAIAdapter(BaseLLMAdapter):
openai_messages.append({"role": msg.role, "content": msg.content})

if stream:
return self._stream_chat(model, openai_messages, **kwargs)
return self._stream_chat(model, openai_messages, options=options, **kwargs)
else:
response = await self.client.chat.completions.create(
model=model, messages=openai_messages, stream=False, **kwargs
model=model,
messages=openai_messages,
stream=False,
temperature=options.get("temperature") if options else None,
max_tokens=options.get("num_ctx") if options else None,
seed=options.get("seed") if options else None,
**kwargs,
)

usage_stats = self._create_usage_stats(response.usage)
@@ -394,11 +355,16 @@ class OpenAIAdapter(BaseLLMAdapter):
)

async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert to chat format for OpenAI
messages = [LLMMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
return await self.chat(model, messages, stream, options=options, **kwargs)

async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
"""Generate embeddings using OpenAI"""
@@ -448,8 +414,47 @@ class AnthropicAdapter(BaseLLMAdapter):
extra_stats={k: v for k, v in usage_dict.items() if k not in ["input_tokens", "output_tokens"]},
)

def _prepare_anthropic_kwargs(self, options: Optional[dict], **kwargs) -> dict:
"""Prepare kwargs for Anthropic API from ChatOptions and additional kwargs"""
anthropic_kwargs = dict(kwargs)

if options:
# Extract from dict
temperature = options.get("temperature")
num_ctx = options.get("num_ctx") or options.get("numCtx") # Handle both field names
seed = options.get("seed")

# Map to Anthropic parameters
if temperature is not None:
anthropic_kwargs["temperature"] = temperature

if num_ctx is not None:
# num_ctx maps to max_tokens for Anthropic
# Anthropic has a max context of 8192 tokens
if num_ctx > 8192:
logger.warning(
f"num_ctx ({num_ctx}) exceeds Anthropic's max context length (8192). "
"Setting max_tokens to 8192."
)
anthropic_kwargs["max_tokens"] = min(8192, num_ctx)

# Anthropic doesn't support seed, but we can try passing it
if seed is not None:
logger.warning("Anthropic does not officially support 'seed' parameter. ")

# Set default max_tokens if not provided
if "max_tokens" not in anthropic_kwargs:
anthropic_kwargs["max_tokens"] = 1000

return anthropic_kwargs

async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self,
model: str,
messages: List[LLMMessage],
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Anthropic requires system message to be separate
system_message = None
@@ -461,11 +466,13 @@ class AnthropicAdapter(BaseLLMAdapter):
else:
anthropic_messages.append({"role": msg.role, "content": msg.content})

# Prepare kwargs with options
prepared_kwargs = self._prepare_anthropic_kwargs(options, **kwargs)

request_kwargs = {
"model": model,
"messages": anthropic_messages,
"max_tokens": kwargs.pop("max_tokens", 1000),
**kwargs,
**prepared_kwargs,
}

if system_message:
@@ -511,10 +518,15 @@ class AnthropicAdapter(BaseLLMAdapter):
logger.info(f"Could not retrieve final usage stats: {e}")

async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
messages = [LLMMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
return await self.chat(model, messages, stream, options, **kwargs)

async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
"""Anthropic doesn't provide embeddings API"""
@@ -524,7 +536,12 @@ class AnthropicAdapter(BaseLLMAdapter):

async def list_models(self) -> List[str]:
# Anthropic doesn't have a list models endpoint, return known models
return ["claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", "claude-3-opus-20240229"]
return [
"claude-3-5-haiku-latest",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229",
]
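A rough illustration of the options mapping performed by _prepare_anthropic_kwargs above. The option keys follow the Ollama-style ChatOptions used elsewhere in this commit; the values are invented.

options = {"temperature": 0.7, "num_ctx": 16384, "seed": 42}
# adapter._prepare_anthropic_kwargs(options) would return roughly:
#   {"temperature": 0.7, "max_tokens": 8192}
# logging warnings that num_ctx was clamped to 8192 and that Anthropic does not
# officially support seed; if num_ctx is absent, max_tokens defaults to 1000.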
class GeminiAdapter(BaseLLMAdapter):
@@ -564,7 +581,7 @@ class GeminiAdapter(BaseLLMAdapter):
return usage_stats

async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
self, model: str, messages: List[LLMMessage], stream: bool = False, options=Optional[dict], **kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
model_instance = self.genai.GenerativeModel(model) # type: ignore

@@ -629,7 +646,12 @@ class GeminiAdapter(BaseLLMAdapter):
)

async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
self,
model: str,
prompt: str,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
messages = [LLMMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
@@ -737,6 +759,7 @@ class LLMCacheMixin:
model: str,
messages_or_prompt: Union[List[LLMMessage], str],
provider: Optional[LLMProvider] = None,
options: Optional[dict] = None,
**kwargs,
) -> str:
"""Create consistent cache key from parameters"""
@@ -752,7 +775,13 @@ class LLMCacheMixin:
serializable_messages.append(msg)
messages_or_prompt = serializable_messages

cache_data = {"method": method, "model": model, "content": messages_or_prompt, "kwargs": cache_kwargs}
cache_data = {
"method": method,
"model": model,
"content": messages_or_prompt,
"options": options,
"kwargs": cache_kwargs,
}

if self.cache_config.provider_specific and provider:
cache_data["provider"] = provider.value
@@ -935,7 +964,7 @@ class LLMCacheMixin:


class UnifiedLLMProxy(LLMCacheMixin):
"""Main proxy class that provides unified interface to all LLM providers"""
"""Main proxy class with automatic Ollama fallback for Anthropic embeddings"""

def __init__(self, redis_client: Redis, default_provider: LLMProvider = LLMProvider.OLLAMA):
# Initialize with caching support
@@ -1019,6 +1048,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
messages: Union[List[LLMMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Enhanced chat method with streaming cache support"""
@@ -1037,7 +1067,12 @@ class UnifiedLLMProxy(LLMCacheMixin):

# Create cache key
cache_key = self._create_cache_key(
method="chat", model=model, messages_or_prompt=normalized_messages, provider=provider, **kwargs
method="chat",
model=model,
messages_or_prompt=normalized_messages,
provider=provider,
options=options,
**kwargs,
)

# Check cache
@@ -1059,7 +1094,7 @@ class UnifiedLLMProxy(LLMCacheMixin):

if stream:
# For streaming: we know this will return AsyncGenerator
stream_response = await adapter.chat(model, normalized_messages, stream=True, **kwargs)
stream_response = await adapter.chat(model, normalized_messages, stream=True, options=options, **kwargs)
stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response)

if self.cache_config.enabled:
@@ -1068,7 +1103,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
return stream_response
else:
# For non-streaming: we know this will return ChatResponse
single_response = await adapter.chat(model, normalized_messages, stream=False, **kwargs)
single_response = await adapter.chat(model, normalized_messages, stream=False, options=options, **kwargs)
single_response = cast(ChatResponse, single_response)

if self.cache_config.enabled:
@@ -1082,6 +1117,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
prompt: str,
provider: Optional[LLMProvider] = None,
stream: bool = False,
options: Optional[dict] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Enhanced generate method with streaming cache support"""
@@ -1114,7 +1150,7 @@ class UnifiedLLMProxy(LLMCacheMixin):

if stream:
# For streaming: we know this will return AsyncGenerator
stream_response = await adapter.generate(model, prompt, stream=True, **kwargs)
stream_response = await adapter.generate(model, prompt, stream=True, options=options, **kwargs)
stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response)

if self.cache_config.enabled:
@@ -1125,7 +1161,7 @@ class UnifiedLLMProxy(LLMCacheMixin):
return stream_response
else:
# For non-streaming: we know this will return ChatResponse
single_response = await adapter.generate(model, prompt, stream=False, **kwargs)
single_response = await adapter.generate(model, prompt, stream=False, options=options, **kwargs)
single_response = cast(ChatResponse, single_response)
if self.cache_config.enabled:
await self._set_cached_response(cache_key, single_response)
@@ -1145,6 +1181,13 @@ class UnifiedLLMProxy(LLMCacheMixin):
if provider in adapter_classes:
self.adapters[provider] = adapter_classes[provider](**config)
self._initialized_providers.add(provider)

# If configuring Anthropic, ensure Ollama is also configured for embeddings
if provider == LLMProvider.ANTHROPIC and LLMProvider.OLLAMA not in self._initialized_providers:
logger.warning(
"Anthropic provider configured but Ollama is not available for embeddings. "
"Embedding operations with Anthropic provider will fail until Ollama is configured."
)
else:
raise ValueError(f"Unsupported provider: {provider}")

@@ -1160,12 +1203,13 @@ class UnifiedLLMProxy(LLMCacheMixin):
messages: Union[List[LLMMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
stream: bool = True,
options: Optional[dict] = None,
**kwargs,
) -> AsyncGenerator[ChatResponse, None]:
"""Stream chat messages using specified or default provider"""
if stream is False:
raise ValueError("stream must be True for chat_stream")
result = await self.chat(model, messages, provider, stream=True, **kwargs)
result = await self.chat(model, messages, provider, stream=True, options=options, **kwargs)
if isinstance(result, ChatResponse):
raise RuntimeError("Expected AsyncGenerator, got ChatResponse")
async for chunk in result:
@@ -1176,32 +1220,43 @@ class UnifiedLLMProxy(LLMCacheMixin):
model: str,
messages: Union[List[LLMMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
options: Optional[dict] = None,
**kwargs,
) -> ChatResponse:
"""Get single chat response using specified or default provider"""

result = await self.chat(model, messages, provider, stream=False, **kwargs)
result = await self.chat(model, messages, provider, stream=False, options=options, **kwargs)
if not isinstance(result, ChatResponse):
raise RuntimeError("Expected ChatResponse, got AsyncGenerator")
return result

async def generate_stream(
self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
options: Optional[dict] = None,
**kwargs,
) -> AsyncGenerator[ChatResponse, None]:
"""Stream text generation using specified or default provider"""

result = await self.generate(model, prompt, provider, stream=True, **kwargs)
result = await self.generate(model, prompt, provider, stream=True, options=options, **kwargs)
if isinstance(result, ChatResponse):
raise RuntimeError("Expected AsyncGenerator, got ChatResponse")
async for chunk in result:
yield chunk

async def generate_single(
self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
options: Optional[dict] = None,
**kwargs,
) -> ChatResponse:
"""Get single generation response using specified or default provider"""

result = await self.generate(model, prompt, provider, stream=False, **kwargs)
result = await self.generate(model, prompt, provider, stream=False, options=options, **kwargs)
if not isinstance(result, ChatResponse):
raise RuntimeError("Expected ChatResponse, got AsyncGenerator")
return result
@@ -1209,11 +1264,52 @@ class UnifiedLLMProxy(LLMCacheMixin):
async def embeddings(
self, model: str, input_texts: Union[str, List[str]], provider: Optional[LLMProvider] = None, **kwargs
) -> EmbeddingResponse:
"""Generate embeddings using specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
"""Generate embeddings using specified or default provider.

return await adapter.embeddings(model, input_texts, **kwargs)
For Anthropic provider, automatically routes to Ollama since Anthropic doesn't provide embeddings API.
"""
provider = provider or self.default_provider

# Special handling: If Anthropic is requested for embeddings, use Ollama instead
if provider == LLMProvider.ANTHROPIC:
logger.info("Anthropic provider requested for embeddings - routing to Ollama provider")

# Ensure Ollama is available
if LLMProvider.OLLAMA not in self._initialized_providers:
raise RuntimeError(
"Anthropic provider requires Ollama for embeddings, but Ollama is not configured. "
"Please ensure Ollama endpoint is properly configured."
)

# Use Ollama for embeddings with a default embedding model if not specified
embedding_provider = LLMProvider.OLLAMA

# If no specific model provided, use a default Ollama embedding model
if not model or model in [
"claude-3-5-haiku-latest",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229",
]:
# Default to a common Ollama embedding model
model = os.getenv("OLLAMA_EMBEDDING_MODEL", "mxbai-embed-large")
logger.info(f"Using Ollama embedding model: {model}")

else:
embedding_provider = provider

adapter = self._get_adapter(embedding_provider)

try:
return await adapter.embeddings(model, input_texts, **kwargs)
except Exception as e:
if embedding_provider == LLMProvider.OLLAMA and provider == LLMProvider.ANTHROPIC:
raise RuntimeError(
f"Failed to generate embeddings using Ollama (fallback for Anthropic): {e}. "
f"Please ensure Ollama is running and the embedding model '{model}' is available."
) from e
else:
raise
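A sketch of the new fallback from the caller's side: with Anthropic selected, an embeddings request is transparently served by Ollama. It assumes an already-configured UnifiedLLMProxy instance named proxy; the input strings are invented, and the claude model name is only there to show it being swapped for the Ollama embedding model.

async def embed_with_anthropic_default(proxy):
    response = await proxy.embeddings(
        model="claude-3-5-haiku-latest",  # Anthropic chat model: triggers the Ollama reroute above
        input_texts=["candidate resume text", "job description text"],
        provider=LLMProvider.ANTHROPIC,
    )
    # EmbeddingResponse.data holds one embedding per input text, as in the adapter code
    return [item.embedding for item in response.data]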
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
"""List available models for specified or default provider"""
@@ -1224,21 +1320,32 @@ class UnifiedLLMProxy(LLMCacheMixin):
return await adapter.list_models()

async def list_embedding_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
"""List available embedding models for specified or default provider"""
"""List available embedding models for specified or default provider.

For Anthropic provider, returns Ollama embedding models since that's what will be used.
"""
provider = provider or self.default_provider

# Provider-specific embedding models
embedding_models = {
LLMProvider.OLLAMA: [
"mxbai-embed-large",
"nomic-embed-text",
"all-minilm",
],
LLMProvider.OPENAI: ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
LLMProvider.ANTHROPIC: [], # No embeddings API
LLMProvider.ANTHROPIC: [], # Will be handled specially below
LLMProvider.GEMINI: ["models/embedding-001", "models/text-embedding-004"],
}

# Special handling for Anthropic - return Ollama models since that's what will be used
if provider == LLMProvider.ANTHROPIC:
raise NotImplementedError("Anthropic does not provide embeddings API")
logger.info("Listing Ollama embedding models (used for Anthropic embedding requests)")

if LLMProvider.OLLAMA not in self._initialized_providers:
raise RuntimeError("Anthropic provider requires Ollama for embeddings, but Ollama is not configured.")

return await self.list_embedding_models(LLMProvider.OLLAMA)

# For Ollama, check which embedding models are actually available
if provider == LLMProvider.OLLAMA:
@@ -1266,7 +1373,7 @@ class UnifiedLLMProxy(LLMCacheMixin):


class LLMManager:
"""Updated manager with caching support"""
"""Updated manager with caching support and required Ollama endpoint"""

_instance = None
_proxy = None
@@ -1283,24 +1390,42 @@ class LLMManager:
self._configure_from_environment()

def _configure_from_environment(self):
"""Configure providers and cache from environment variables"""
"""Configure providers and cache from environment variables - Ollama is now required"""
if not self._proxy:
raise RuntimeError("UnifiedLLMProxy instance not initialized")

# Your existing configuration code...
# REQUIRED: Configure Ollama provider first (mandatory)
ollama_host = os.getenv("OLLAMA_HOST", defines.ollama_api_url)
self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
if not ollama_host:
raise ValueError(
"Ollama endpoint is required. Please set OLLAMA_HOST environment variable "
"or ensure defines.ollama_api_url is configured."
)

logger.info(f"Configuring Ollama provider (REQUIRED) with host: {ollama_host}")
try:
self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
except Exception as e:
raise RuntimeError(f"Failed to configure required Ollama provider: {e}")

# Configure other providers as before
if os.getenv("OPENAI_API_KEY"):
logger.info("Configuring OpenAI provider")
self._proxy.configure_provider(LLMProvider.OPENAI)

if os.getenv("ANTHROPIC_API_KEY"):
logger.info("Configuring Anthropic provider (will use Ollama for embeddings)")
self._proxy.configure_provider(LLMProvider.ANTHROPIC)

if os.getenv("GEMINI_API_KEY"):
logger.info("Configuring Gemini provider")
self._proxy.configure_provider(LLMProvider.GEMINI)

# Add cache configuration from environment
# Validate that Ollama is properly configured
if LLMProvider.OLLAMA not in self._proxy._initialized_providers:
raise RuntimeError("Ollama provider must be successfully configured")

# Add cache configuration from environment (unchanged)
cache_config = {}
if os.getenv("LLM_CACHE_TTL"):
cache_config["ttl"] = int(os.getenv("LLM_CACHE_TTL", 31536000))
@@ -1319,11 +1444,15 @@ class LLMManager:
if cache_config:
asyncio.create_task(self._proxy.configure_cache(**cache_config))

# Your existing default provider setup...
# Set default provider
default_provider = os.getenv("DEFAULT_LLM_PROVIDER", "ollama")
try:
logger.info(f"Setting default LLM provider to: {default_provider}")
self._proxy.set_default_provider(LLMProvider(default_provider))
except ValueError:
except ValueError as e:
logger.error(traceback.format_exc())
logger.error(f"Invalid default provider: {default_provider}. Error: {e}")
logger.info(f"Setting default LLM provider to fallback: {LLMProvider.OLLAMA}")
self._proxy.set_default_provider(LLMProvider.OLLAMA)

def get_proxy(self) -> UnifiedLLMProxy: