Added caching to LLM queries

James Ketr 2025-07-09 13:08:36 -07:00
parent b130cd3974
commit 2fe1f5b181
9 changed files with 624 additions and 110 deletions

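Usage sketch (not part of the commit's diff): after this change, get_llm() requires a Redis client and responses are cached transparently in Redis. The bootstrap below mirrors the __main__ example added to utils/llm_proxy.py; the prompt text and printout are illustrative only.

import asyncio

import defines
from database.manager import DatabaseManager
from utils.llm_proxy import LLMMessage, get_llm


async def demo() -> None:
    # Obtain a Redis-backed LLM proxy; get_llm() now takes a Redis client.
    db_manager = DatabaseManager()
    await db_manager.initialize()
    llm = get_llm(db_manager.get_database().redis)

    messages = [LLMMessage(role="user", content="Explain quantum computing in one paragraph")]

    first = await llm.chat_single(defines.model, messages)   # cache miss: goes to the provider
    second = await llm.chat_single(defines.model, messages)  # identical request: served from Redis

    stats = await llm.get_cache_stats()
    print(first.content == second.content, f"hit rate: {stats.hit_rate:.0%}")


asyncio.run(demo())

Identical requests hash to the same SHA-256 cache key, so the second call is answered from Redis; streaming calls are replayed from the cached response in small chunks.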
View File

@ -137,7 +137,11 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = () => {
}
// Initialize step
const urlStep = stepId ? parseInt(stepId, 10) : undefined;
let urlStep = stepId ? parseInt(stepId, 10) : undefined;
// If job and candidate are both provided, default to step 2
if (urlStep === undefined) {
urlStep = candidate && job ? 2 : 0;
}
if (urlStep !== undefined && !isNaN(urlStep) && urlStep !== activeStep) {
setActiveStep(urlStep);
}

View File

@ -0,0 +1,30 @@
import React from 'react';
import { SxProps } from '@mui/material';
import * as Types from 'types/types'; // Adjust the import path as necessary
import { useAuth } from 'hooks/AuthContext';
import { JobsView } from 'components/ui/JobsView';
interface JobsViewPageProps {
sx?: SxProps;
}
const JobsViewPage: React.FC<JobsViewPageProps> = (props: JobsViewPageProps) => {
const { sx } = props;
const { apiClient } = useAuth();
return (
<JobsView
onJobSelect={(selectedJobs: Types.Job[]): void => console.log('Selected:', selectedJobs)}
onJobView={(job: Types.Job): void => console.log('View job:', job)}
onJobEdit={(job: Types.Job): void => console.log('Edit job:', job)}
onJobDelete={async (job: Types.Job): Promise<void> => {
await apiClient.deleteJob(job.id || '');
}}
selectable={true}
showActions={true}
sx={sx}
/>
);
};
export { JobsViewPage };

View File

@ -997,7 +997,7 @@ class ApiClient {
}
async getSystemInfo(): Promise<Types.SystemInfo> {
const response = await fetch(`${this.baseUrl}/system-info`, {
const response = await fetch(`${this.baseUrl}/system/info`, {
method: 'GET',
headers: this.defaultHeaders,
});

View File

@ -192,7 +192,7 @@ class CandidateEntity(Candidate):
os.makedirs(rag_content_dir, exist_ok=True)
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
llm=llm_manager.get_llm(),
llm=llm_manager.get_llm(database.redis),
user_id=self.id,
collection_name=self.username,
persist_directory=vector_db_dir,

View File

@ -109,7 +109,7 @@ async def create_candidate_ai(
resume_message = None
state = 0 # 0 -- create persona, 1 -- create resume
async for generated_message in generate_agent.generate(
llm=llm_manager.get_llm(),
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=user_message.session_id,
prompt=user_message.content,
@ -754,7 +754,7 @@ async def get_candidate_qr_code(
import pyqrcode
if job:
qrobj = pyqrcode.create(f"{defines.frontend_url}/job-analysis/{candidate.id}/{job.id}")
qrobj = pyqrcode.create(f"{defines.frontend_url}/job-analysis/{candidate.id}/{job.id}/2")
else:
qrobj = pyqrcode.create(f"{defines.frontend_url}/u/{candidate.id}")
with open(file_path, "wb") as f:
@ -1341,7 +1341,11 @@ async def search_candidates(
@router.post("/rag-search")
async def post_candidate_rag_search(query: str = Body(...), current_user=Depends(get_current_user)):
async def post_candidate_rag_search(
query: str = Body(...),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Get chat activity summary for a candidate"""
try:
if current_user.user_type != "candidate":
@ -1367,7 +1371,7 @@ async def post_candidate_rag_search(query: str = Body(...), current_user=Depends
)
rag_message: Any = None
async for generated_message in chat_agent.generate(
llm=llm_manager.get_llm(),
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=user_message.session_id,
prompt=user_message.content,
@ -1614,7 +1618,7 @@ async def get_candidate_skill_match(
# Generate new skill match
final_message = None
async for generated_message in agent.generate(
llm=llm_manager.get_llm(),
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=MOCK_UUID,
prompt=skill,
@ -1944,7 +1948,7 @@ async def generate_resume(
yield error_message
return
async for generated_message in agent.generate_resume(
llm=llm_manager.get_llm(),
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=MOCK_UUID,
skills=skills,

View File

@ -61,7 +61,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
message = None
async for message in chat_agent.llm_one_shot(
llm=llm_manager.get_llm(),
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=MOCK_UUID,
prompt=content,
@ -132,7 +132,10 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
message = None
async for message in chat_agent.generate(
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content,
):
if message.status != ApiStatusType.DONE:
yield message

View File

@ -1,17 +1,21 @@
from typing import Optional
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Depends
from database.core import RedisDatabase
from utils.llm_proxy import LLMProvider, get_llm
from utils.dependencies import get_database
router = APIRouter(prefix="/providers", tags=["providers"])
@router.get("/models")
async def list_models(provider: Optional[str] = None):
async def list_models(
provider: Optional[str] = None,
database: RedisDatabase = Depends(get_database),
):
"""List available models for a provider"""
try:
llm = get_llm()
llm = get_llm(database.redis)
provider_enum = None
if provider:
@ -28,9 +32,11 @@ async def list_models(provider: Optional[str] = None):
@router.get("")
async def list_providers():
async def list_providers(
database: RedisDatabase = Depends(get_database),
):
"""List all configured providers"""
llm = get_llm()
llm = get_llm(database.redis)
return {
"providers": [provider.value for provider in llm._initialized_providers],
"default": llm.default_provider.value,
@ -38,10 +44,13 @@ async def list_providers():
@router.post("/{provider}/set-default")
async def set_default_provider(provider: str):
async def set_default_provider(
provider: str,
database: RedisDatabase = Depends(get_database),
):
"""Set the default provider"""
try:
llm = get_llm()
llm = get_llm(database.redis)
provider_enum = LLMProvider(provider.lower())
llm.set_default_provider(provider_enum)
return {"message": f"Default provider set to {provider}", "default": provider}
@ -51,9 +60,11 @@ async def set_default_provider(provider: str):
# Health check endpoint
@router.get("/health")
async def health_check():
async def health_check(
database: RedisDatabase = Depends(get_database),
):
"""Health check endpoint"""
llm = get_llm()
llm = get_llm(database.redis)
return {
"status": "healthy",
"providers_configured": len(llm._initialized_providers),

View File

@ -65,7 +65,7 @@ def filter_and_paginate(
return paginated_items, total
async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse:
async def stream_agent_response(chat_agent, user_message, database, chat_session_data=None) -> StreamingResponse:
"""Stream agent response with proper formatting"""
async def message_stream_generator():
@ -75,7 +75,7 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None
import utils.llm_proxy as llm_manager
async for generated_message in chat_agent.generate(
llm=llm_manager.get_llm(),
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=user_message.session_id,
prompt=user_message.content,
@ -186,7 +186,7 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
message = None
async for message in chat_agent.llm_one_shot(
llm=llm_manager.get_llm(),
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=MOCK_UUID,
prompt=content,
@ -267,7 +267,10 @@ async def create_job_from_content(database, current_user, content: str):
message = None
async for message in chat_agent.generate(
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
llm=llm_manager.get_llm(database.redis),
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content,
):
if message.status != ApiStatusType.DONE:
yield message

View File

@ -1,13 +1,18 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any, AsyncGenerator, Optional, Union
import time
from typing import Dict, List, Any, AsyncGenerator, Optional, Union, cast
from pydantic import BaseModel, Field
from enum import Enum
import asyncio
from dataclasses import dataclass
import os
from redis.asyncio import Redis
import defines
from logger import logger
import hashlib
import json
from dataclasses import asdict
# Standard message format for all providers
@dataclass
@ -364,7 +369,7 @@ class OpenAIAdapter(BaseLLMAdapter):
usage_stats = self._create_usage_stats(response.usage)
return ChatResponse(
content=response.choices[0].message.content,
content=response.choices[0].message.content, # type: ignore
model=model,
finish_reason=response.choices[0].finish_reason,
usage=usage_stats,
@ -372,7 +377,7 @@ class OpenAIAdapter(BaseLLMAdapter):
async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
# Await the stream creation first, then iterate
stream = await self.client.chat.completions.create(model=model, messages=messages, stream=True, **kwargs)
stream = await self.client.chat.completions.create(model=model, messages=messages, stream=True, **kwargs) # type: ignore
async for chunk in stream:
if chunk.choices[0].delta.content:
@ -474,7 +479,10 @@ class AnthropicAdapter(BaseLLMAdapter):
usage_stats = self._create_usage_stats(response.usage)
return ChatResponse(
content=response.content[0].text, model=model, finish_reason=response.stop_reason, usage=usage_stats
content=response.content[0].text, # type: ignore
model=model,
finish_reason=response.stop_reason,
usage=usage_stats, # type: ignore
)
async def _stream_chat(self, **kwargs):
@ -500,7 +508,7 @@ class AnthropicAdapter(BaseLLMAdapter):
# Yield a final empty response with usage stats
yield ChatResponse(content="", model=model, usage=final_usage_stats, finish_reason="stop")
except Exception as e:
logger.debug(f"Could not retrieve final usage stats: {e}")
logger.info(f"Could not retrieve final usage stats: {e}")
async def generate(
self, model: str, prompt: str, stream: bool = False, **kwargs
@ -524,9 +532,9 @@ class GeminiAdapter(BaseLLMAdapter):
def __init__(self, **config):
super().__init__(**config)
import google.generativeai as genai # type: ignore
import google.generativeai as genai # type: ignore
genai.configure(api_key=config.get("api_key", os.getenv("GEMINI_API_KEY")))
genai.configure(api_key=config.get("api_key", os.getenv("GEMINI_API_KEY"))) # type: ignore
self.genai = genai
def _create_usage_stats(self, response: Any) -> UsageStats:
@ -558,7 +566,7 @@ class GeminiAdapter(BaseLLMAdapter):
async def chat(
self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
model_instance = self.genai.GenerativeModel(model)
model_instance = self.genai.GenerativeModel(model) # type: ignore
# Convert messages to Gemini format
chat_history = []
@ -638,7 +646,7 @@ class GeminiAdapter(BaseLLMAdapter):
# Gemini embeddings - use embedding model
for i, text in enumerate(texts):
response = self.genai.embed_content(model=model, content=text, **kwargs)
response = self.genai.embed_content(model=model, content=text, **kwargs) # type: ignore
results.append(EmbeddingData(embedding=response["embedding"], index=i))
@ -649,18 +657,481 @@ class GeminiAdapter(BaseLLMAdapter):
)
async def list_models(self) -> List[str]:
models = self.genai.list_models()
models = self.genai.list_models() # type: ignore
return [model.name for model in models if "generateContent" in model.supported_generation_methods]
class UnifiedLLMProxy:
@dataclass
class CacheStats:
total_requests: int = 0
cache_hits: int = 0
cache_misses: int = 0
total_cached_responses: int = 0
cache_size_bytes: int = 0
@property
def hit_rate(self) -> float:
total = self.cache_hits + self.cache_misses
return self.cache_hits / total if total > 0 else 0.0
@dataclass
class CacheConfig:
enabled: bool = True
ttl: int = 31536000 # 1 year
max_size: int = 1000 # Maximum number of cached entries (COUNT, not bytes)
provider_specific: bool = False
cache_streaming: bool = True # Enable/disable streaming cache
streaming_chunk_size: int = 5 # Size of chunks when simulating streaming from cache
streaming_delay: float = 0.01 # Delay between chunks in seconds
hash_algorithm: str = "sha256"
# Cache key prefixes
CACHE_KEY_PREFIXES = {"llm_cache": "llm_cache:", "cache_stats": "cache_stats:", "cache_config": "cache_config:"}
class LLMCacheMixin:
"""Mixin class to add caching capabilities to LLM adapters"""
def __init__(self, redis_client: Redis, **kwargs):
super().__init__(**kwargs)
self.redis: Redis = redis_client
self.cache_config = CacheConfig()
self.cache_stats = CacheStats()
def _serialize_chat_response(self, response: ChatResponse) -> Dict[str, Any]:
"""Safely serialize ChatResponse to dictionary"""
return {
"content": response.content,
"model": response.model,
"finish_reason": response.finish_reason,
"usage": response.usage.model_dump(exclude_none=True) if response.usage else None,
"usage_legacy": response.usage_legacy,
}
def _serialize_for_hash(self, obj: Any) -> str:
"""Serialize object to consistent string for hashing"""
def json_default(o):
if hasattr(o, "__dict__"):
return o.__dict__
elif hasattr(o, "_asdict"):
return o._asdict()
elif isinstance(o, (set, frozenset)):
return sorted(list(o))
elif hasattr(o, "isoformat"):
return o.isoformat()
else:
return str(o)
try:
return json.dumps(obj, default=json_default, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
except (TypeError, ValueError) as e:
logger.warning(f"Serialization fallback for {type(obj)}: {e}")
return str(obj)
def _create_cache_key(
self,
method: str,
model: str,
messages_or_prompt: Union[List[LLMMessage], str],
provider: Optional[LLMProvider] = None,
**kwargs,
) -> str:
"""Create consistent cache key from parameters"""
cache_kwargs = {k: v for k, v in kwargs.items() if k not in ["stream", "callback", "timeout"]}
if isinstance(messages_or_prompt, list):
serializable_messages = []
for msg in messages_or_prompt:
if isinstance(msg, LLMMessage):
serializable_messages.append(msg.to_dict())
else:
serializable_messages.append(msg)
messages_or_prompt = serializable_messages
cache_data = {"method": method, "model": model, "content": messages_or_prompt, "kwargs": cache_kwargs}
if self.cache_config.provider_specific and provider:
cache_data["provider"] = provider.value
serialized = self._serialize_for_hash(cache_data)
hash_obj = hashlib.sha256(serialized.encode("utf-8"))
return hash_obj.hexdigest()
async def _track_cache_access(self, cache_key: str) -> None:
"""Track cache access time for LRU eviction"""
if not self.redis:
return
try:
access_key = f"{CACHE_KEY_PREFIXES['llm_cache']}access_times"
current_time = time.time()
await self.redis.zadd(access_key, {cache_key: current_time})
except Exception as e:
logger.error(f"Error tracking cache access: {e}")
async def _enforce_cache_size_limit_efficient(self) -> None:
"""Efficient LRU enforcement using sorted sets"""
if not self.redis or not self.cache_config.enabled:
return
try:
access_key = f"{CACHE_KEY_PREFIXES['llm_cache']}access_times"
# Get current cache count
current_count = await self.redis.zcard(access_key)
if current_count >= self.cache_config.max_size:
# Calculate how many to remove
entries_to_remove = current_count - self.cache_config.max_size + 1
# Get oldest entries (lowest scores)
oldest_entries = await self.redis.zrange(access_key, 0, entries_to_remove - 1)
if oldest_entries:
# Remove from both cache and access tracking
pipe = self.redis.pipeline()
for cache_key in oldest_entries:
full_cache_key = f"{CACHE_KEY_PREFIXES['llm_cache']}{cache_key}"
pipe.delete(full_cache_key)
pipe.zrem(access_key, cache_key)
await pipe.execute()
logger.info(f"LRU eviction: removed {len(oldest_entries)} entries")
except Exception as e:
logger.error(f"Error in LRU eviction: {e}")
async def _get_cached_response(self, cache_key: str) -> Optional[ChatResponse]:
"""Retrieve cached response and update access time"""
if not self.redis or not self.cache_config.enabled:
return None
try:
full_key = f"{CACHE_KEY_PREFIXES['llm_cache']}{cache_key}"
cached_data = await self.redis.get(full_key)
if cached_data:
# Update access time for LRU
await self._track_cache_access(cache_key)
if isinstance(cached_data, bytes):
cached_data = cached_data.decode("utf-8")
response_data = json.loads(cached_data)
if "usage" in response_data and response_data["usage"]:
response_data["usage"] = UsageStats(**response_data["usage"])
self.cache_stats.cache_hits += 1
await self._update_cache_stats()
logger.info(f"Cache hit for key: {cache_key[:16]}...")
return ChatResponse(**response_data)
except Exception as e:
logger.error(f"Error retrieving cached response: {e}")
raise e
return None
async def _set_cached_response(self, cache_key: str, response: ChatResponse) -> None:
"""Store response with efficient LRU tracking"""
if not self.redis or not self.cache_config.enabled:
return
try:
# Enforce size limit before adding
await self._enforce_cache_size_limit_efficient()
full_key = f"{CACHE_KEY_PREFIXES['llm_cache']}{cache_key}"
response_data = self._serialize_chat_response(response)
response_data = {k: v for k, v in response_data.items() if v is not None}
serialized_data = json.dumps(response_data, default=str)
# Store data and track access
pipe = self.redis.pipeline()
pipe.setex(full_key, self.cache_config.ttl, serialized_data)
await pipe.execute()
await self._track_cache_access(cache_key)
self.cache_stats.total_cached_responses += 1
self.cache_stats.cache_size_bytes += len(serialized_data)
await self._update_cache_stats()
logger.info(f"Cached response for key: {cache_key[:16]}...")
except Exception as e:
logger.error(f"Error caching response: {e}")
async def _update_cache_stats(self) -> None:
"""Update cache statistics in Redis"""
if not self.redis:
return
try:
stats_key = f"{CACHE_KEY_PREFIXES['cache_stats']}global"
stats_data = asdict(self.cache_stats)
serialized_stats = json.dumps(stats_data)
await self.redis.set(stats_key, serialized_stats)
except Exception as e:
logger.error(f"Error updating cache stats: {e}")
async def get_cache_stats(self) -> CacheStats:
"""Get current cache statistics"""
if not self.redis:
return self.cache_stats
try:
stats_key = f"{CACHE_KEY_PREFIXES['cache_stats']}global"
cached_stats = await self.redis.get(stats_key)
if cached_stats:
if isinstance(cached_stats, bytes):
cached_stats = cached_stats.decode("utf-8")
stats_data = json.loads(cached_stats)
self.cache_stats = CacheStats(**stats_data)
except Exception as e:
logger.error(f"Error retrieving cache stats: {e}")
return self.cache_stats
async def clear_cache(self, pattern: str = "*") -> int:
"""Clear cached responses matching pattern"""
if not self.redis:
return 0
try:
search_pattern = f"{CACHE_KEY_PREFIXES['llm_cache']}{pattern}"
keys = await self.redis.keys(search_pattern)
if keys:
deleted = await self.redis.delete(*keys)
logger.info(f"Cleared {deleted} cached responses")
self.cache_stats = CacheStats()
await self._update_cache_stats()
return deleted
except Exception as e:
logger.error(f"Error clearing cache: {e}")
return 0
async def configure_cache(self, **config_kwargs) -> None:
"""Update cache configuration"""
for key, value in config_kwargs.items():
if hasattr(self.cache_config, key):
setattr(self.cache_config, key, value)
logger.info(f"Updated cache config: {key} = {value}")
class UnifiedLLMProxy(LLMCacheMixin):
"""Main proxy class that provides unified interface to all LLM providers"""
def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA):
def __init__(self, redis_client: Redis, default_provider: LLMProvider = LLMProvider.OLLAMA):
# Initialize with caching support
super().__init__(redis_client=redis_client)
self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {}
self.default_provider = default_provider
self._initialized_providers = set()
async def _simulate_streaming_from_cache(
self, cached_response: ChatResponse, chunk_size: Optional[int] = None
) -> AsyncGenerator[ChatResponse, None]:
"""Simulate streaming by yielding cached response in chunks"""
content = cached_response.content
# Use configured chunk size if not provided
if chunk_size is None:
chunk_size = self.cache_config.streaming_chunk_size
# Yield content in chunks to simulate streaming
for i in range(0, len(content), chunk_size):
chunk_content = content[i : i + chunk_size]
# Create chunk response
chunk_response = ChatResponse(
content=chunk_content,
model=cached_response.model,
finish_reason=None, # Not finished yet
usage=None, # Usage only in final chunk
)
yield chunk_response
# Use configured streaming delay
await asyncio.sleep(self.cache_config.streaming_delay)
# Yield final empty chunk with usage stats and finish reason
final_chunk = ChatResponse(
content="",
model=cached_response.model,
finish_reason=cached_response.finish_reason,
usage=cached_response.usage,
)
yield final_chunk
async def _accumulate_and_cache_stream(
self, cache_key: str, stream: AsyncGenerator[ChatResponse, None], model: str
) -> AsyncGenerator[ChatResponse, None]:
"""Accumulate streaming response and cache the final result"""
accumulated_content = ""
final_usage = None
final_finish_reason = None
async for chunk in stream:
# Yield the chunk immediately to maintain streaming behavior
yield chunk
# Accumulate content and final metadata
if chunk.content:
accumulated_content += chunk.content
if chunk.usage:
final_usage = chunk.usage
if chunk.finish_reason:
final_finish_reason = chunk.finish_reason
# Create final response for caching
if accumulated_content: # Only cache if we got content
final_response = ChatResponse(
content=accumulated_content, model=model, finish_reason=final_finish_reason, usage=final_usage
)
# Cache the complete response
await self._set_cached_response(cache_key, final_response)
logger.info(f"Cached accumulated streaming response for key: {cache_key[:16]}...")
async def chat(
self,
model: str,
messages: Union[List[LLMMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
stream: bool = False,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Enhanced chat method with streaming cache support"""
self.cache_stats.total_requests += 1
# Normalize messages for caching
normalized_messages = []
for msg in messages:
if isinstance(msg, LLMMessage):
normalized_messages.append(msg)
elif isinstance(msg, dict):
normalized_messages.append(LLMMessage.from_dict(msg))
else:
raise ValueError(f"Invalid message type: {type(msg)}")
# Create cache key
cache_key = self._create_cache_key(
method="chat", model=model, messages_or_prompt=normalized_messages, provider=provider, **kwargs
)
# Check cache
if self.cache_config.enabled:
cached_response = await self._get_cached_response(cache_key)
if cached_response:
if stream:
logger.info("Cache hit - returning cached streaming response")
return self._simulate_streaming_from_cache(cached_response)
else:
logger.info("Cache hit - returning cached single response")
return cached_response
# Cache miss
self.cache_stats.cache_misses += 1
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
if stream:
# For streaming: we know this will return AsyncGenerator
stream_response = await adapter.chat(model, normalized_messages, stream=True, **kwargs)
stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response)
if self.cache_config.enabled:
return self._accumulate_and_cache_stream(cache_key, stream_response, model)
else:
return stream_response
else:
# For non-streaming: we know this will return ChatResponse
single_response = await adapter.chat(model, normalized_messages, stream=False, **kwargs)
single_response = cast(ChatResponse, single_response)
if self.cache_config.enabled:
await self._set_cached_response(cache_key, single_response)
return single_response
async def generate(
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
stream: bool = False,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Enhanced generate method with streaming cache support"""
self.cache_stats.total_requests += 1
# Create cache key
cache_key = self._create_cache_key(
method="generate", model=model, messages_or_prompt=prompt, provider=provider, **kwargs
)
# Check cache for both streaming and non-streaming requests
if self.cache_config.enabled:
cached_response = await self._get_cached_response(cache_key)
if cached_response:
logger.info("Cache hit - returning cached response")
if stream:
# Return cached response as simulated stream
return self._simulate_streaming_from_cache(cached_response)
else:
# Return cached response directly
return cached_response
# Cache miss - call the actual LLM
self.cache_stats.cache_misses += 1
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
if stream:
# For streaming: we know this will return AsyncGenerator
stream_response = await adapter.generate(model, prompt, stream=True, **kwargs)
stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response)
if self.cache_config.enabled:
# Return stream that accumulates and caches
return self._accumulate_and_cache_stream(cache_key, stream_response, model)
else:
# Return original stream if caching disabled
return stream_response
else:
# For non-streaming: we know this will return ChatResponse
single_response = await adapter.generate(model, prompt, stream=False, **kwargs)
single_response = cast(ChatResponse, single_response)
if self.cache_config.enabled:
await self._set_cached_response(cache_key, single_response)
return single_response
def configure_provider(self, provider: LLMProvider, **config):
"""Configure a specific provider with its settings"""
adapter_classes = {
@ -683,31 +1154,6 @@ class UnifiedLLMProxy:
raise ValueError(f"Provider {provider} not configured")
self.default_provider = provider
async def chat(
self,
model: str,
messages: Union[List[LLMMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
stream: bool = False,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send chat messages using specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
# Normalize messages to LLMMessage objects
normalized_messages = []
for msg in messages:
if isinstance(msg, LLMMessage):
normalized_messages.append(msg)
elif isinstance(msg, dict):
normalized_messages.append(LLMMessage.from_dict(msg))
else:
raise ValueError(f"Invalid message type: {type(msg)}")
return await adapter.chat(model, normalized_messages, stream, **kwargs)
async def chat_stream(
self,
model: str,
@ -739,16 +1185,6 @@ class UnifiedLLMProxy:
raise RuntimeError("Expected ChatResponse, got AsyncGenerator")
return result
async def generate(
self, model: str, prompt: str, provider: Optional[LLMProvider] = None, stream: bool = False, **kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Generate text using specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
return await adapter.generate(model, prompt, stream, **kwargs)
async def generate_stream(
self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs
) -> AsyncGenerator[ChatResponse, None]:
@ -793,7 +1229,9 @@ class UnifiedLLMProxy:
# Provider-specific embedding models
embedding_models = {
LLMProvider.OLLAMA: ["nomic-embed-text", "mxbai-embed-large", "all-minilm", "snowflake-arctic-embed"],
LLMProvider.OLLAMA: [
"mxbai-embed-large",
],
LLMProvider.OPENAI: ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
LLMProvider.ANTHROPIC: [], # No embeddings API
LLMProvider.GEMINI: ["models/embedding-001", "models/text-embedding-004"],
@ -815,7 +1253,7 @@ class UnifiedLLMProxy:
available_embedding_models.append(model)
return available_embedding_models if available_embedding_models else suggested_models
except:
except Exception:
return embedding_models[provider]
return embedding_models.get(provider, [])
@ -827,77 +1265,90 @@ class UnifiedLLMProxy:
return self.adapters[provider]
# Example usage and configuration
class LLMManager:
"""Singleton manager for the unified LLM proxy"""
"""Updated manager with caching support"""
_instance = None
_proxy = None
@classmethod
def get_instance(cls):
def get_instance(cls, redis_client: Redis):
if cls._instance is None:
cls._instance = cls()
cls._instance = cls(redis_client)
return cls._instance
def __init__(self):
def __init__(self, redis_client: Redis):
if LLMManager._proxy is None:
LLMManager._proxy = UnifiedLLMProxy()
LLMManager._proxy = UnifiedLLMProxy(redis_client=redis_client)
self._configure_from_environment()
def _configure_from_environment(self):
"""Configure providers based on environment variables"""
"""Configure providers and cache from environment variables"""
if not self._proxy:
raise RuntimeError("UnifiedLLMProxy instance not initialized")
# Configure Ollama if available
# Your existing configuration code...
ollama_host = os.getenv("OLLAMA_HOST", defines.ollama_api_url)
self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
# Configure OpenAI if API key is available
if os.getenv("OPENAI_API_KEY"):
self._proxy.configure_provider(LLMProvider.OPENAI)
# Configure Anthropic if API key is available
if os.getenv("ANTHROPIC_API_KEY"):
self._proxy.configure_provider(LLMProvider.ANTHROPIC)
# Configure Gemini if API key is available
if os.getenv("GEMINI_API_KEY"):
self._proxy.configure_provider(LLMProvider.GEMINI)
# Set default provider from environment or use Ollama
# Add cache configuration from environment
cache_config = {}
if os.getenv("LLM_CACHE_TTL"):
cache_config["ttl"] = int(os.getenv("LLM_CACHE_TTL", 31536000))
if os.getenv("LLM_CACHE_MAX_SIZE"):
cache_config["max_size"] = int(os.getenv("LLM_CACHE_MAX_SIZE", 5000))
if os.getenv("LLM_CACHE_PROVIDER_SPECIFIC"):
cache_config["provider_specific"] = os.getenv("LLM_CACHE_PROVIDER_SPECIFIC", "false").lower() == "true"
if os.getenv("LLM_CACHE_ENABLED"):
cache_config["enabled"] = os.getenv("LLM_CACHE_ENABLED", "true").lower() == "true"
if os.getenv("LLM_CACHE_STREAMING_DELAY"):
cache_config["streaming_delay"] = float(os.getenv("LLM_CACHE_STREAMING_DELAY", "0.01"))
if os.getenv("LLM_CACHE_STREAMING_CHUNK_SIZE"):
cache_config["streaming_chunk_size"] = int(os.getenv("LLM_CACHE_STREAMING_CHUNK_SIZE", "50"))
if os.getenv("LLM_CACHE_STREAMING_ENABLED"):
cache_config["cache_streaming"] = os.getenv("LLM_CACHE_STREAMING_ENABLED", "true").lower() == "true"
if cache_config:
asyncio.create_task(self._proxy.configure_cache(**cache_config))
# Your existing default provider setup...
default_provider = os.getenv("DEFAULT_LLM_PROVIDER", "ollama")
try:
self._proxy.set_default_provider(LLMProvider(default_provider))
except ValueError:
# Fallback to Ollama if specified provider not available
self._proxy.set_default_provider(LLMProvider.OLLAMA)
def get_proxy(self) -> UnifiedLLMProxy:
"""Get the unified LLM proxy instance"""
"""Get the LLM proxy instance"""
if not self._proxy:
raise RuntimeError("UnifiedLLMProxy instance not initialized")
return self._proxy
# Convenience function for easy access
def get_llm() -> UnifiedLLMProxy:
"""Get the configured LLM proxy"""
return LLMManager.get_instance().get_proxy()
def get_llm(redis_client: Redis) -> UnifiedLLMProxy:
"""Get the configured LLM proxy with caching"""
return LLMManager.get_instance(redis_client).get_proxy()
# Example usage with detailed statistics
async def example_usage():
async def example_usage(llm: UnifiedLLMProxy):
"""Example showing how to access detailed statistics"""
llm = get_llm()
# Configure providers
llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434")
llm.configure_provider(LLMProvider.OLLAMA, host=defines.ollama_api_url)
# Simple chat
messages = [LLMMessage(role="user", content="Explain quantum computing in one paragraph")]
response = await llm.chat_single("llama2", messages)
response = await llm.chat_single(defines.model, messages)
print(f"Content: {response.content}")
print(f"Model: {response.model}")
@ -923,12 +1374,11 @@ async def example_usage():
print(f"Usage as dict: {usage_dict}")
async def example_embeddings_usage():
async def example_embeddings_usage(llm: UnifiedLLMProxy):
"""Example showing how to use the embeddings API"""
llm = get_llm()
# Configure providers
llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434")
llm.configure_provider(LLMProvider.OLLAMA, host=defines.ollama_api_url)
if os.getenv("OPENAI_API_KEY"):
llm.configure_provider(LLMProvider.OPENAI)
@ -952,7 +1402,7 @@ async def example_embeddings_usage():
single_text = "The quick brown fox jumps over the lazy dog"
try:
response = await llm.embeddings("nomic-embed-text", single_text, provider=LLMProvider.OLLAMA)
response = await llm.embeddings(defines.embedding_model, single_text, provider=LLMProvider.OLLAMA)
print(f"Model: {response.model}")
print(f"Embedding dimension: {len(response.get_single_embedding())}")
print(f"First 5 values: {response.get_single_embedding()[:5]}")
@ -994,7 +1444,7 @@ async def example_embeddings_usage():
print(f"Usage: {response.usage.model_dump(exclude_none=True)}")
else:
# Fallback to Ollama for batch
response = await llm.embeddings("nomic-embed-text", texts, provider=LLMProvider.OLLAMA)
response = await llm.embeddings(defines.embedding_model, texts, provider=LLMProvider.OLLAMA)
print(f"Ollama Model: {response.model}")
print(f"Number of embeddings: {len(response.data)}")
print(f"Embedding dimension: {len(response.data[0].embedding)}")
@ -1003,16 +1453,15 @@ async def example_embeddings_usage():
print(f"Batch embedding failed: {e}")
async def example_streaming_with_stats():
async def example_streaming_with_stats(llm: UnifiedLLMProxy):
"""Example showing how to collect usage stats from streaming responses"""
llm = get_llm()
messages = [LLMMessage(role="user", content="Write a short story about AI")]
print("Streaming response:")
final_stats = None
async for chunk in llm.chat_stream("llama2", messages):
async for chunk in llm.chat_stream(defines.model, messages):
# Print content as it streams
if chunk.content:
print(chunk.content, end="", flush=True)
@ -1034,8 +1483,18 @@ async def example_streaming_with_stats():
if __name__ == "__main__":
asyncio.run(example_usage())
print("\n" + "=" * 50 + "\n")
asyncio.run(example_streaming_with_stats())
print("\n" + "=" * 50 + "\n")
asyncio.run(example_embeddings_usage())
from database.manager import DatabaseManager
async def main():
"""Main entry point to run examples"""
db_manager = DatabaseManager()
await db_manager.initialize()
redis_client = db_manager.get_database().redis
llm = get_llm(redis_client)
await example_usage(llm)
print("\n" + "=" * 50 + "\n")
await example_streaming_with_stats(llm)
print("\n" + "=" * 50 + "\n")
# await example_embeddings_usage(llm)
asyncio.run(main())
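For reference, a sketch of the cache tuning knobs that LLMManager._configure_from_environment now reads; the values below are illustrative, and anything left unset falls back to the CacheConfig defaults.

import os

# Set before the first get_llm(redis_client) call; LLMManager applies these via configure_cache().
os.environ["LLM_CACHE_ENABLED"] = "true"              # maps to CacheConfig.enabled
os.environ["LLM_CACHE_TTL"] = "86400"                 # entry TTL in seconds (default is one year)
os.environ["LLM_CACHE_MAX_SIZE"] = "1000"             # max number of cached entries before LRU eviction
os.environ["LLM_CACHE_PROVIDER_SPECIFIC"] = "false"   # whether the provider is part of the cache key
os.environ["LLM_CACHE_STREAMING_ENABLED"] = "true"    # maps to CacheConfig.cache_streaming
os.environ["LLM_CACHE_STREAMING_CHUNK_SIZE"] = "5"    # characters per chunk when replaying a cached stream
os.environ["LLM_CACHE_STREAMING_DELAY"] = "0.01"      # delay in seconds between replayed chunks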