From 2fe1f5b18144ac587c65a10abeec9d2780ee0179 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Wed, 9 Jul 2025 13:08:36 -0700 Subject: [PATCH] Added caching to LLM queries --- frontend/src/pages/JobAnalysisPage.tsx | 6 +- frontend/src/pages/JobsViewPage.tsx | 30 ++ frontend/src/services/api-client.ts | 2 +- src/backend/agents/base.py | 2 +- src/backend/routes/candidates.py | 16 +- src/backend/routes/jobs.py | 7 +- src/backend/routes/providers.py | 31 +- src/backend/utils/helpers.py | 11 +- src/backend/utils/llm_proxy.py | 629 +++++++++++++++++++++---- 9 files changed, 624 insertions(+), 110 deletions(-) create mode 100644 frontend/src/pages/JobsViewPage.tsx diff --git a/frontend/src/pages/JobAnalysisPage.tsx b/frontend/src/pages/JobAnalysisPage.tsx index ef40248..472c42e 100644 --- a/frontend/src/pages/JobAnalysisPage.tsx +++ b/frontend/src/pages/JobAnalysisPage.tsx @@ -137,7 +137,11 @@ const JobAnalysisPage: React.FC = () => { } // Initialize step - const urlStep = stepId ? parseInt(stepId, 10) : undefined; + let urlStep = stepId ? parseInt(stepId, 10) : undefined; + // If job and candidate are both provided, default to step 2 + if (urlStep === undefined) { + urlStep = candidate && job ? 2 : 0; + } if (urlStep !== undefined && !isNaN(urlStep) && urlStep !== activeStep) { setActiveStep(urlStep); } diff --git a/frontend/src/pages/JobsViewPage.tsx b/frontend/src/pages/JobsViewPage.tsx new file mode 100644 index 0000000..0e23fef --- /dev/null +++ b/frontend/src/pages/JobsViewPage.tsx @@ -0,0 +1,30 @@ +import React from 'react'; +import { SxProps } from '@mui/material'; +import * as Types from 'types/types'; // Adjust the import path as necessary +import { useAuth } from 'hooks/AuthContext'; +import { JobsView } from 'components/ui/JobsView'; + +interface JobsViewPageProps { + sx?: SxProps; +} + +const JobsViewPage: React.FC = (props: JobsViewPageProps) => { + const { sx } = props; + const { apiClient } = useAuth(); + + return ( + console.log('Selected:', selectedJobs)} + onJobView={(job: Types.Job): void => console.log('View job:', job)} + onJobEdit={(job: Types.Job): void => console.log('Edit job:', job)} + onJobDelete={async (job: Types.Job): Promise => { + await apiClient.deleteJob(job.id || ''); + }} + selectable={true} + showActions={true} + sx={sx} + /> + ); +}; + +export { JobsViewPage }; diff --git a/frontend/src/services/api-client.ts b/frontend/src/services/api-client.ts index a4f0f44..0ccba4d 100644 --- a/frontend/src/services/api-client.ts +++ b/frontend/src/services/api-client.ts @@ -997,7 +997,7 @@ class ApiClient { } async getSystemInfo(): Promise { - const response = await fetch(`${this.baseUrl}/system-info`, { + const response = await fetch(`${this.baseUrl}/system/info`, { method: 'GET', headers: this.defaultHeaders, }); diff --git a/src/backend/agents/base.py b/src/backend/agents/base.py index e52b2fc..e3c3f73 100644 --- a/src/backend/agents/base.py +++ b/src/backend/agents/base.py @@ -192,7 +192,7 @@ class CandidateEntity(Candidate): os.makedirs(rag_content_dir, exist_ok=True) self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher( - llm=llm_manager.get_llm(), + llm=llm_manager.get_llm(database.redis), user_id=self.id, collection_name=self.username, persist_directory=vector_db_dir, diff --git a/src/backend/routes/candidates.py b/src/backend/routes/candidates.py index 8e4a7d7..dcc589e 100644 --- a/src/backend/routes/candidates.py +++ b/src/backend/routes/candidates.py @@ -109,7 +109,7 @@ async def create_candidate_ai( resume_message = None 
state = 0 # 0 -- create persona, 1 -- create resume async for generated_message in generate_agent.generate( - llm=llm_manager.get_llm(), + llm=llm_manager.get_llm(database.redis), model=defines.model, session_id=user_message.session_id, prompt=user_message.content, @@ -754,7 +754,7 @@ async def get_candidate_qr_code( import pyqrcode if job: - qrobj = pyqrcode.create(f"{defines.frontend_url}/job-analysis/{candidate.id}/{job.id}") + qrobj = pyqrcode.create(f"{defines.frontend_url}/job-analysis/{candidate.id}/{job.id}/2") else: qrobj = pyqrcode.create(f"{defines.frontend_url}/u/{candidate.id}") with open(file_path, "wb") as f: @@ -1341,7 +1341,11 @@ async def search_candidates( @router.post("/rag-search") -async def post_candidate_rag_search(query: str = Body(...), current_user=Depends(get_current_user)): +async def post_candidate_rag_search( + query: str = Body(...), + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), +): """Get chat activity summary for a candidate""" try: if current_user.user_type != "candidate": @@ -1367,7 +1371,7 @@ async def post_candidate_rag_search(query: str = Body(...), current_user=Depends ) rag_message: Any = None async for generated_message in chat_agent.generate( - llm=llm_manager.get_llm(), + llm=llm_manager.get_llm(database.redis), model=defines.model, session_id=user_message.session_id, prompt=user_message.content, @@ -1614,7 +1618,7 @@ async def get_candidate_skill_match( # Generate new skill match final_message = None async for generated_message in agent.generate( - llm=llm_manager.get_llm(), + llm=llm_manager.get_llm(database.redis), model=defines.model, session_id=MOCK_UUID, prompt=skill, @@ -1944,7 +1948,7 @@ async def generate_resume( yield error_message return async for generated_message in agent.generate_resume( - llm=llm_manager.get_llm(), + llm=llm_manager.get_llm(database.redis), model=defines.model, session_id=MOCK_UUID, skills=skills, diff --git a/src/backend/routes/jobs.py b/src/backend/routes/jobs.py index a11146b..5e50959 100644 --- a/src/backend/routes/jobs.py +++ b/src/backend/routes/jobs.py @@ -61,7 +61,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid message = None async for message in chat_agent.llm_one_shot( - llm=llm_manager.get_llm(), + llm=llm_manager.get_llm(database.redis), model=defines.model, session_id=MOCK_UUID, prompt=content, @@ -132,7 +132,10 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida message = None async for message in chat_agent.generate( - llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content + llm=llm_manager.get_llm(database.redis), + model=defines.model, + session_id=MOCK_UUID, + prompt=markdown_message.content, ): if message.status != ApiStatusType.DONE: yield message diff --git a/src/backend/routes/providers.py b/src/backend/routes/providers.py index ed4d3c2..1271729 100644 --- a/src/backend/routes/providers.py +++ b/src/backend/routes/providers.py @@ -1,17 +1,21 @@ from typing import Optional -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Depends +from database.core import RedisDatabase from utils.llm_proxy import LLMProvider, get_llm +from utils.dependencies import get_database router = APIRouter(prefix="/providers", tags=["providers"]) - @router.get("/models") -async def list_models(provider: Optional[str] = None): +async def list_models( + provider: Optional[str] = None, + database: RedisDatabase = 
Depends(get_database), +): """List available models for a provider""" try: - llm = get_llm() + llm = get_llm(database.redis) provider_enum = None if provider: @@ -28,9 +32,11 @@ async def list_models(provider: Optional[str] = None): @router.get("") -async def list_providers(): +async def list_providers( + database: RedisDatabase = Depends(get_database), +): """List all configured providers""" - llm = get_llm() + llm = get_llm(database.redis) return { "providers": [provider.value for provider in llm._initialized_providers], "default": llm.default_provider.value, @@ -38,10 +44,13 @@ async def list_providers(): @router.post("/{provider}/set-default") -async def set_default_provider(provider: str): +async def set_default_provider( + provider: str, + database: RedisDatabase = Depends(get_database), +): """Set the default provider""" try: - llm = get_llm() + llm = get_llm(database.redis) provider_enum = LLMProvider(provider.lower()) llm.set_default_provider(provider_enum) return {"message": f"Default provider set to {provider}", "default": provider} @@ -51,9 +60,11 @@ async def set_default_provider(provider: str): # Health check endpoint @router.get("/health") -async def health_check(): +async def health_check( + database: RedisDatabase = Depends(get_database), +): """Health check endpoint""" - llm = get_llm() + llm = get_llm(database.redis) return { "status": "healthy", "providers_configured": len(llm._initialized_providers), diff --git a/src/backend/utils/helpers.py b/src/backend/utils/helpers.py index e4ba74b..7507e8f 100644 --- a/src/backend/utils/helpers.py +++ b/src/backend/utils/helpers.py @@ -65,7 +65,7 @@ def filter_and_paginate( return paginated_items, total -async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse: +async def stream_agent_response(chat_agent, user_message, database, chat_session_data=None) -> StreamingResponse: """Stream agent response with proper formatting""" async def message_stream_generator(): @@ -75,7 +75,7 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None import utils.llm_proxy as llm_manager async for generated_message in chat_agent.generate( - llm=llm_manager.get_llm(), + llm=llm_manager.get_llm(database.redis), model=defines.model, session_id=user_message.session_id, prompt=user_message.content, @@ -186,7 +186,7 @@ async def reformat_as_markdown(database, candidate_entity, content: str): message = None async for message in chat_agent.llm_one_shot( - llm=llm_manager.get_llm(), + llm=llm_manager.get_llm(database.redis), model=defines.model, session_id=MOCK_UUID, prompt=content, @@ -267,7 +267,10 @@ async def create_job_from_content(database, current_user, content: str): message = None async for message in chat_agent.generate( - llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content + llm=llm_manager.get_llm(database.redis), + model=defines.model, + session_id=MOCK_UUID, + prompt=markdown_message.content, ): if message.status != ApiStatusType.DONE: yield message diff --git a/src/backend/utils/llm_proxy.py b/src/backend/utils/llm_proxy.py index 4edf13f..e2c215b 100644 --- a/src/backend/utils/llm_proxy.py +++ b/src/backend/utils/llm_proxy.py @@ -1,13 +1,18 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Any, AsyncGenerator, Optional, Union +import time +from typing import Dict, List, Any, AsyncGenerator, Optional, Union, cast from pydantic import BaseModel, Field from enum import Enum import asyncio 
from dataclasses import dataclass import os + +from redis.asyncio import Redis import defines from logger import logger - +import hashlib +import json +from dataclasses import asdict # Standard message format for all providers @dataclass @@ -364,7 +369,7 @@ class OpenAIAdapter(BaseLLMAdapter): usage_stats = self._create_usage_stats(response.usage) return ChatResponse( - content=response.choices[0].message.content, + content=response.choices[0].message.content, # type: ignore model=model, finish_reason=response.choices[0].finish_reason, usage=usage_stats, @@ -372,7 +377,7 @@ class OpenAIAdapter(BaseLLMAdapter): async def _stream_chat(self, model: str, messages: List[Dict], **kwargs): # Await the stream creation first, then iterate - stream = await self.client.chat.completions.create(model=model, messages=messages, stream=True, **kwargs) + stream = await self.client.chat.completions.create(model=model, messages=messages, stream=True, **kwargs) # type: ignore async for chunk in stream: if chunk.choices[0].delta.content: @@ -474,7 +479,10 @@ class AnthropicAdapter(BaseLLMAdapter): usage_stats = self._create_usage_stats(response.usage) return ChatResponse( - content=response.content[0].text, model=model, finish_reason=response.stop_reason, usage=usage_stats + content=response.content[0].text, # type: ignore + model=model, + finish_reason=response.stop_reason, + usage=usage_stats, # type: ignore ) async def _stream_chat(self, **kwargs): @@ -500,7 +508,7 @@ class AnthropicAdapter(BaseLLMAdapter): # Yield a final empty response with usage stats yield ChatResponse(content="", model=model, usage=final_usage_stats, finish_reason="stop") except Exception as e: - logger.debug(f"Could not retrieve final usage stats: {e}") + logger.info(f"Could not retrieve final usage stats: {e}") async def generate( self, model: str, prompt: str, stream: bool = False, **kwargs @@ -524,9 +532,9 @@ class GeminiAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) - import google.generativeai as genai # type: ignore + import google.generativeai as genai # type: ignore - genai.configure(api_key=config.get("api_key", os.getenv("GEMINI_API_KEY"))) + genai.configure(api_key=config.get("api_key", os.getenv("GEMINI_API_KEY"))) # type: ignore self.genai = genai def _create_usage_stats(self, response: Any) -> UsageStats: @@ -558,7 +566,7 @@ class GeminiAdapter(BaseLLMAdapter): async def chat( self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - model_instance = self.genai.GenerativeModel(model) + model_instance = self.genai.GenerativeModel(model) # type: ignore # Convert messages to Gemini format chat_history = [] @@ -638,7 +646,7 @@ class GeminiAdapter(BaseLLMAdapter): # Gemini embeddings - use embedding model for i, text in enumerate(texts): - response = self.genai.embed_content(model=model, content=text, **kwargs) + response = self.genai.embed_content(model=model, content=text, **kwargs) # type: ignore results.append(EmbeddingData(embedding=response["embedding"], index=i)) @@ -649,18 +657,481 @@ class GeminiAdapter(BaseLLMAdapter): ) async def list_models(self) -> List[str]: - models = self.genai.list_models() + models = self.genai.list_models() # type: ignore return [model.name for model in models if "generateContent" in model.supported_generation_methods] -class UnifiedLLMProxy: +@dataclass +class CacheStats: + total_requests: int = 0 + cache_hits: int = 0 + cache_misses: int = 0 + total_cached_responses: int = 
0 + cache_size_bytes: int = 0 + + @property + def hit_rate(self) -> float: + total = self.cache_hits + self.cache_misses + return self.cache_hits / total if total > 0 else 0.0 + + +@dataclass +class CacheConfig: + enabled: bool = True + ttl: int = 31536000 # 1 year + max_size: int = 1000 # Maximum number of cached entries (COUNT, not bytes) + provider_specific: bool = False + cache_streaming: bool = True # Enable/disable streaming cache + streaming_chunk_size: int = 5 # Size of chunks when simulating streaming from cache + streaming_delay: float = 0.01 # Delay between chunks in seconds + hash_algorithm: str = "sha256" + + +# Cache key prefixes +CACHE_KEY_PREFIXES = {"llm_cache": "llm_cache:", "cache_stats": "cache_stats:", "cache_config": "cache_config:"} + + +class LLMCacheMixin: + """Mixin class to add caching capabilities to LLM adapters""" + + def __init__(self, redis_client: Redis, **kwargs): + super().__init__(**kwargs) + self.redis: Redis = redis_client + self.cache_config = CacheConfig() + self.cache_stats = CacheStats() + + def _serialize_chat_response(self, response: ChatResponse) -> Dict[str, Any]: + """Safely serialize ChatResponse to dictionary""" + return { + "content": response.content, + "model": response.model, + "finish_reason": response.finish_reason, + "usage": response.usage.model_dump(exclude_none=True) if response.usage else None, + "usage_legacy": response.usage_legacy, + } + + def _serialize_for_hash(self, obj: Any) -> str: + """Serialize object to consistent string for hashing""" + + def json_default(o): + if hasattr(o, "__dict__"): + return o.__dict__ + elif hasattr(o, "_asdict"): + return o._asdict() + elif isinstance(o, (set, frozenset)): + return sorted(list(o)) + elif hasattr(o, "isoformat"): + return o.isoformat() + else: + return str(o) + + try: + return json.dumps(obj, default=json_default, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + except (TypeError, ValueError) as e: + logger.warning(f"Serialization fallback for {type(obj)}: {e}") + return str(obj) + + def _create_cache_key( + self, + method: str, + model: str, + messages_or_prompt: Union[List[LLMMessage], str], + provider: Optional[LLMProvider] = None, + **kwargs, + ) -> str: + """Create consistent cache key from parameters""" + + cache_kwargs = {k: v for k, v in kwargs.items() if k not in ["stream", "callback", "timeout"]} + + if isinstance(messages_or_prompt, list): + serializable_messages = [] + for msg in messages_or_prompt: + if isinstance(msg, LLMMessage): + serializable_messages.append(msg.to_dict()) + else: + serializable_messages.append(msg) + messages_or_prompt = serializable_messages + + cache_data = {"method": method, "model": model, "content": messages_or_prompt, "kwargs": cache_kwargs} + + if self.cache_config.provider_specific and provider: + cache_data["provider"] = provider.value + + serialized = self._serialize_for_hash(cache_data) + hash_obj = hashlib.sha256(serialized.encode("utf-8")) + return hash_obj.hexdigest() + + async def _track_cache_access(self, cache_key: str) -> None: + """Track cache access time for LRU eviction""" + if not self.redis: + return + + try: + access_key = f"{CACHE_KEY_PREFIXES['llm_cache']}access_times" + current_time = time.time() + await self.redis.zadd(access_key, {cache_key: current_time}) + except Exception as e: + logger.error(f"Error tracking cache access: {e}") + + async def _enforce_cache_size_limit_efficient(self) -> None: + """Efficient LRU enforcement using sorted sets""" + if not self.redis or not self.cache_config.enabled: + 
return + + try: + access_key = f"{CACHE_KEY_PREFIXES['llm_cache']}access_times" + + # Get current cache count + current_count = await self.redis.zcard(access_key) + + if current_count >= self.cache_config.max_size: + # Calculate how many to remove + entries_to_remove = current_count - self.cache_config.max_size + 1 + + # Get oldest entries (lowest scores) + oldest_entries = await self.redis.zrange(access_key, 0, entries_to_remove - 1) + + if oldest_entries: + # Remove from both cache and access tracking + pipe = self.redis.pipeline() + + for cache_key in oldest_entries: + full_cache_key = f"{CACHE_KEY_PREFIXES['llm_cache']}{cache_key}" + pipe.delete(full_cache_key) + pipe.zrem(access_key, cache_key) + + await pipe.execute() + logger.info(f"LRU eviction: removed {len(oldest_entries)} entries") + + except Exception as e: + logger.error(f"Error in LRU eviction: {e}") + + async def _get_cached_response(self, cache_key: str) -> Optional[ChatResponse]: + """Retrieve cached response and update access time""" + if not self.redis or not self.cache_config.enabled: + return None + + try: + full_key = f"{CACHE_KEY_PREFIXES['llm_cache']}{cache_key}" + cached_data = await self.redis.get(full_key) + + if cached_data: + # Update access time for LRU + await self._track_cache_access(cache_key) + + if isinstance(cached_data, bytes): + cached_data = cached_data.decode("utf-8") + + response_data = json.loads(cached_data) + + if "usage" in response_data and response_data["usage"]: + response_data["usage"] = UsageStats(**response_data["usage"]) + + self.cache_stats.cache_hits += 1 + await self._update_cache_stats() + + logger.info(f"Cache hit for key: {cache_key[:16]}...") + return ChatResponse(**response_data) + + except Exception as e: + logger.error(f"Error retrieving cached response: {e}") + raise e + + return None + + async def _set_cached_response(self, cache_key: str, response: ChatResponse) -> None: + """Store response with efficient LRU tracking""" + if not self.redis or not self.cache_config.enabled: + return + + try: + # Enforce size limit before adding + await self._enforce_cache_size_limit_efficient() + + full_key = f"{CACHE_KEY_PREFIXES['llm_cache']}{cache_key}" + + response_data = self._serialize_chat_response(response) + response_data = {k: v for k, v in response_data.items() if v is not None} + serialized_data = json.dumps(response_data, default=str) + + # Store data and track access + pipe = self.redis.pipeline() + pipe.setex(full_key, self.cache_config.ttl, serialized_data) + await pipe.execute() + + await self._track_cache_access(cache_key) + + self.cache_stats.total_cached_responses += 1 + self.cache_stats.cache_size_bytes += len(serialized_data) + await self._update_cache_stats() + + logger.info(f"Cached response for key: {cache_key[:16]}...") + + except Exception as e: + logger.error(f"Error caching response: {e}") + + async def _update_cache_stats(self) -> None: + """Update cache statistics in Redis""" + if not self.redis: + return + + try: + stats_key = f"{CACHE_KEY_PREFIXES['cache_stats']}global" + stats_data = asdict(self.cache_stats) + serialized_stats = json.dumps(stats_data) + await self.redis.set(stats_key, serialized_stats) + except Exception as e: + logger.error(f"Error updating cache stats: {e}") + + async def get_cache_stats(self) -> CacheStats: + """Get current cache statistics""" + if not self.redis: + return self.cache_stats + + try: + stats_key = f"{CACHE_KEY_PREFIXES['cache_stats']}global" + cached_stats = await self.redis.get(stats_key) + + if cached_stats: + if 
isinstance(cached_stats, bytes): + cached_stats = cached_stats.decode("utf-8") + stats_data = json.loads(cached_stats) + self.cache_stats = CacheStats(**stats_data) + + except Exception as e: + logger.error(f"Error retrieving cache stats: {e}") + + return self.cache_stats + + async def clear_cache(self, pattern: str = "*") -> int: + """Clear cached responses matching pattern""" + if not self.redis: + return 0 + + try: + search_pattern = f"{CACHE_KEY_PREFIXES['llm_cache']}{pattern}" + keys = await self.redis.keys(search_pattern) + + if keys: + deleted = await self.redis.delete(*keys) + logger.info(f"Cleared {deleted} cached responses") + + self.cache_stats = CacheStats() + await self._update_cache_stats() + + return deleted + + except Exception as e: + logger.error(f"Error clearing cache: {e}") + + return 0 + + async def configure_cache(self, **config_kwargs) -> None: + """Update cache configuration""" + for key, value in config_kwargs.items(): + if hasattr(self.cache_config, key): + setattr(self.cache_config, key, value) + logger.info(f"Updated cache config: {key} = {value}") + + +class UnifiedLLMProxy(LLMCacheMixin): """Main proxy class that provides unified interface to all LLM providers""" - def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA): + def __init__(self, redis_client: Redis, default_provider: LLMProvider = LLMProvider.OLLAMA): + # Initialize with caching support + super().__init__(redis_client=redis_client) self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {} self.default_provider = default_provider self._initialized_providers = set() + async def _simulate_streaming_from_cache( + self, cached_response: ChatResponse, chunk_size: Optional[int] = None + ) -> AsyncGenerator[ChatResponse, None]: + """Simulate streaming by yielding cached response in chunks""" + content = cached_response.content + + # Use configured chunk size if not provided + if chunk_size is None: + chunk_size = self.cache_config.streaming_chunk_size + + # Yield content in chunks to simulate streaming + for i in range(0, len(content), chunk_size): + chunk_content = content[i : i + chunk_size] + + # Create chunk response + chunk_response = ChatResponse( + content=chunk_content, + model=cached_response.model, + finish_reason=None, # Not finished yet + usage=None, # Usage only in final chunk + ) + + yield chunk_response + + # Use configured streaming delay + await asyncio.sleep(self.cache_config.streaming_delay) + + # Yield final empty chunk with usage stats and finish reason + final_chunk = ChatResponse( + content="", + model=cached_response.model, + finish_reason=cached_response.finish_reason, + usage=cached_response.usage, + ) + + yield final_chunk + + async def _accumulate_and_cache_stream( + self, cache_key: str, stream: AsyncGenerator[ChatResponse, None], model: str + ) -> AsyncGenerator[ChatResponse, None]: + """Accumulate streaming response and cache the final result""" + accumulated_content = "" + final_usage = None + final_finish_reason = None + + async for chunk in stream: + # Yield the chunk immediately to maintain streaming behavior + yield chunk + + # Accumulate content and final metadata + if chunk.content: + accumulated_content += chunk.content + + if chunk.usage: + final_usage = chunk.usage + + if chunk.finish_reason: + final_finish_reason = chunk.finish_reason + + # Create final response for caching + if accumulated_content: # Only cache if we got content + final_response = ChatResponse( + content=accumulated_content, model=model, finish_reason=final_finish_reason, 
usage=final_usage + ) + + # Cache the complete response + await self._set_cached_response(cache_key, final_response) + logger.info(f"Cached accumulated streaming response for key: {cache_key[:16]}...") + + async def chat( + self, + model: str, + messages: Union[List[LLMMessage], List[Dict[str, str]]], + provider: Optional[LLMProvider] = None, + stream: bool = False, + **kwargs, + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Enhanced chat method with streaming cache support""" + + self.cache_stats.total_requests += 1 + + # Normalize messages for caching + normalized_messages = [] + for msg in messages: + if isinstance(msg, LLMMessage): + normalized_messages.append(msg) + elif isinstance(msg, dict): + normalized_messages.append(LLMMessage.from_dict(msg)) + else: + raise ValueError(f"Invalid message type: {type(msg)}") + + # Create cache key + cache_key = self._create_cache_key( + method="chat", model=model, messages_or_prompt=normalized_messages, provider=provider, **kwargs + ) + + # Check cache + if self.cache_config.enabled: + cached_response = await self._get_cached_response(cache_key) + if cached_response: + if stream: + logger.info("Cache hit - returning cached streaming response") + return self._simulate_streaming_from_cache(cached_response) + else: + logger.info("Cache hit - returning cached single response") + return cached_response + + # Cache miss + self.cache_stats.cache_misses += 1 + + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + if stream: + # For streaming: we know this will return AsyncGenerator + stream_response = await adapter.chat(model, normalized_messages, stream=True, **kwargs) + stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response) + + if self.cache_config.enabled: + return self._accumulate_and_cache_stream(cache_key, stream_response, model) + else: + return stream_response + else: + # For non-streaming: we know this will return ChatResponse + single_response = await adapter.chat(model, normalized_messages, stream=False, **kwargs) + single_response = cast(ChatResponse, single_response) + + if self.cache_config.enabled: + await self._set_cached_response(cache_key, single_response) + + return single_response + + async def generate( + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + stream: bool = False, + **kwargs, + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Enhanced generate method with streaming cache support""" + + self.cache_stats.total_requests += 1 + + # Create cache key + cache_key = self._create_cache_key( + method="generate", model=model, messages_or_prompt=prompt, provider=provider, **kwargs + ) + + # Check cache for both streaming and non-streaming requests + if self.cache_config.enabled: + cached_response = await self._get_cached_response(cache_key) + if cached_response: + logger.info("Cache hit - returning cached response") + + if stream: + # Return cached response as simulated stream + return self._simulate_streaming_from_cache(cached_response) + else: + # Return cached response directly + return cached_response + + # Cache miss - call the actual LLM + self.cache_stats.cache_misses += 1 + + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + if stream: + # For streaming: we know this will return AsyncGenerator + stream_response = await adapter.generate(model, prompt, stream=True, **kwargs) + stream_response = cast(AsyncGenerator[ChatResponse, None], stream_response) + + if 
self.cache_config.enabled: + # Return stream that accumulates and caches + return self._accumulate_and_cache_stream(cache_key, stream_response, model) + else: + # Return original stream if caching disabled + return stream_response + else: + # For non-streaming: we know this will return ChatResponse + single_response = await adapter.generate(model, prompt, stream=False, **kwargs) + single_response = cast(ChatResponse, single_response) + if self.cache_config.enabled: + await self._set_cached_response(cache_key, single_response) + + return single_response + def configure_provider(self, provider: LLMProvider, **config): """Configure a specific provider with its settings""" adapter_classes = { @@ -683,31 +1154,6 @@ class UnifiedLLMProxy: raise ValueError(f"Provider {provider} not configured") self.default_provider = provider - async def chat( - self, - model: str, - messages: Union[List[LLMMessage], List[Dict[str, str]]], - provider: Optional[LLMProvider] = None, - stream: bool = False, - **kwargs, - ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - """Send chat messages using specified or default provider""" - - provider = provider or self.default_provider - adapter = self._get_adapter(provider) - - # Normalize messages to LLMMessage objects - normalized_messages = [] - for msg in messages: - if isinstance(msg, LLMMessage): - normalized_messages.append(msg) - elif isinstance(msg, dict): - normalized_messages.append(LLMMessage.from_dict(msg)) - else: - raise ValueError(f"Invalid message type: {type(msg)}") - - return await adapter.chat(model, normalized_messages, stream, **kwargs) - async def chat_stream( self, model: str, @@ -739,16 +1185,6 @@ class UnifiedLLMProxy: raise RuntimeError("Expected ChatResponse, got AsyncGenerator") return result - async def generate( - self, model: str, prompt: str, provider: Optional[LLMProvider] = None, stream: bool = False, **kwargs - ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - """Generate text using specified or default provider""" - - provider = provider or self.default_provider - adapter = self._get_adapter(provider) - - return await adapter.generate(model, prompt, stream, **kwargs) - async def generate_stream( self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs ) -> AsyncGenerator[ChatResponse, None]: @@ -793,7 +1229,9 @@ class UnifiedLLMProxy: # Provider-specific embedding models embedding_models = { - LLMProvider.OLLAMA: ["nomic-embed-text", "mxbai-embed-large", "all-minilm", "snowflake-arctic-embed"], + LLMProvider.OLLAMA: [ + "mxbai-embed-large", + ], LLMProvider.OPENAI: ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"], LLMProvider.ANTHROPIC: [], # No embeddings API LLMProvider.GEMINI: ["models/embedding-001", "models/text-embedding-004"], @@ -815,7 +1253,7 @@ class UnifiedLLMProxy: available_embedding_models.append(model) return available_embedding_models if available_embedding_models else suggested_models - except: + except Exception: return embedding_models[provider] return embedding_models.get(provider, []) @@ -827,77 +1265,90 @@ class UnifiedLLMProxy: return self.adapters[provider] -# Example usage and configuration class LLMManager: - """Singleton manager for the unified LLM proxy""" + """Updated manager with caching support""" _instance = None _proxy = None @classmethod - def get_instance(cls): + def get_instance(cls, redis_client: Redis): if cls._instance is None: - cls._instance = cls() + cls._instance = cls(redis_client) return cls._instance - def 
__init__(self): + def __init__(self, redis_client: Redis): if LLMManager._proxy is None: - LLMManager._proxy = UnifiedLLMProxy() + LLMManager._proxy = UnifiedLLMProxy(redis_client=redis_client) self._configure_from_environment() def _configure_from_environment(self): - """Configure providers based on environment variables""" + """Configure providers and cache from environment variables""" if not self._proxy: raise RuntimeError("UnifiedLLMProxy instance not initialized") - # Configure Ollama if available + + # Your existing configuration code... ollama_host = os.getenv("OLLAMA_HOST", defines.ollama_api_url) self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host) - # Configure OpenAI if API key is available if os.getenv("OPENAI_API_KEY"): self._proxy.configure_provider(LLMProvider.OPENAI) - # Configure Anthropic if API key is available if os.getenv("ANTHROPIC_API_KEY"): self._proxy.configure_provider(LLMProvider.ANTHROPIC) - # Configure Gemini if API key is available if os.getenv("GEMINI_API_KEY"): self._proxy.configure_provider(LLMProvider.GEMINI) - # Set default provider from environment or use Ollama + # Add cache configuration from environment + cache_config = {} + if os.getenv("LLM_CACHE_TTL"): + cache_config["ttl"] = int(os.getenv("LLM_CACHE_TTL", 31536000)) + if os.getenv("LLM_CACHE_MAX_SIZE"): + cache_config["max_size"] = int(os.getenv("LLM_CACHE_MAX_SIZE", 5000)) + if os.getenv("LLM_CACHE_PROVIDER_SPECIFIC"): + cache_config["provider_specific"] = os.getenv("LLM_CACHE_PROVIDER_SPECIFIC", "false").lower() == "true" + if os.getenv("LLM_CACHE_ENABLED"): + cache_config["enabled"] = os.getenv("LLM_CACHE_ENABLED", "true").lower() == "true" + if os.getenv("LLM_CACHE_STREAMING_DELAY"): + cache_config["streaming_delay"] = float(os.getenv("LLM_CACHE_STREAMING_DELAY", "0.01")) + if os.getenv("LLM_CACHE_STREAMING_CHUNK_SIZE"): + cache_config["streaming_chunk_size"] = int(os.getenv("LLM_CACHE_STREAMING_CHUNK_SIZE", "50")) + if os.getenv("LLM_CACHE_STREAMING_ENABLED"): + cache_config["cache_streaming"] = os.getenv("LLM_CACHE_STREAMING_ENABLED", "true").lower() == "true" + if cache_config: + asyncio.create_task(self._proxy.configure_cache(**cache_config)) + + # Your existing default provider setup... 
default_provider = os.getenv("DEFAULT_LLM_PROVIDER", "ollama") try: self._proxy.set_default_provider(LLMProvider(default_provider)) except ValueError: - # Fallback to Ollama if specified provider not available self._proxy.set_default_provider(LLMProvider.OLLAMA) def get_proxy(self) -> UnifiedLLMProxy: - """Get the unified LLM proxy instance""" + """Get the LLM proxy instance""" if not self._proxy: raise RuntimeError("UnifiedLLMProxy instance not initialized") return self._proxy -# Convenience function for easy access -def get_llm() -> UnifiedLLMProxy: - """Get the configured LLM proxy""" - return LLMManager.get_instance().get_proxy() +def get_llm(redis_client: Redis) -> UnifiedLLMProxy: + """Get the configured LLM proxy with caching""" + return LLMManager.get_instance(redis_client).get_proxy() # Example usage with detailed statistics -async def example_usage(): +async def example_usage(llm: UnifiedLLMProxy): """Example showing how to access detailed statistics""" - llm = get_llm() # Configure providers - llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434") + llm.configure_provider(LLMProvider.OLLAMA, host=defines.ollama_api_url) # Simple chat messages = [LLMMessage(role="user", content="Explain quantum computing in one paragraph")] - response = await llm.chat_single("llama2", messages) + response = await llm.chat_single(defines.model, messages) print(f"Content: {response.content}") print(f"Model: {response.model}") @@ -923,12 +1374,11 @@ async def example_usage(): print(f"Usage as dict: {usage_dict}") -async def example_embeddings_usage(): +async def example_embeddings_usage(llm: UnifiedLLMProxy): """Example showing how to use the embeddings API""" - llm = get_llm() # Configure providers - llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434") + llm.configure_provider(LLMProvider.OLLAMA, host=defines.ollama_api_url) if os.getenv("OPENAI_API_KEY"): llm.configure_provider(LLMProvider.OPENAI) @@ -952,7 +1402,7 @@ async def example_embeddings_usage(): single_text = "The quick brown fox jumps over the lazy dog" try: - response = await llm.embeddings("nomic-embed-text", single_text, provider=LLMProvider.OLLAMA) + response = await llm.embeddings(defines.embedding_model, single_text, provider=LLMProvider.OLLAMA) print(f"Model: {response.model}") print(f"Embedding dimension: {len(response.get_single_embedding())}") print(f"First 5 values: {response.get_single_embedding()[:5]}") @@ -994,7 +1444,7 @@ async def example_embeddings_usage(): print(f"Usage: {response.usage.model_dump(exclude_none=True)}") else: # Fallback to Ollama for batch - response = await llm.embeddings("nomic-embed-text", texts, provider=LLMProvider.OLLAMA) + response = await llm.embeddings(defines.embedding_model, texts, provider=LLMProvider.OLLAMA) print(f"Ollama Model: {response.model}") print(f"Number of embeddings: {len(response.data)}") print(f"Embedding dimension: {len(response.data[0].embedding)}") @@ -1003,16 +1453,15 @@ async def example_embeddings_usage(): print(f"Batch embedding failed: {e}") -async def example_streaming_with_stats(): +async def example_streaming_with_stats(llm: UnifiedLLMProxy): """Example showing how to collect usage stats from streaming responses""" - llm = get_llm() messages = [LLMMessage(role="user", content="Write a short story about AI")] print("Streaming response:") final_stats = None - async for chunk in llm.chat_stream("llama2", messages): + async for chunk in llm.chat_stream(defines.model, messages): # Print content as it streams if chunk.content: 
print(chunk.content, end="", flush=True) @@ -1034,8 +1483,18 @@ async def example_streaming_with_stats(): if __name__ == "__main__": - asyncio.run(example_usage()) - print("\n" + "=" * 50 + "\n") - asyncio.run(example_streaming_with_stats()) - print("\n" + "=" * 50 + "\n") - asyncio.run(example_embeddings_usage()) + from database.manager import DatabaseManager + + async def main(): + """Main entry point to run examples""" + db_manager = DatabaseManager() + await db_manager.initialize() + redis_client = db_manager.get_database().redis + llm = get_llm(redis_client) + await example_usage(llm) + print("\n" + "=" * 50 + "\n") + await example_streaming_with_stats(llm) + print("\n" + "=" * 50 + "\n") + # await example_embeddings_usage(llm) + + asyncio.run(main())
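Usage sketch (not part of the patch): after this change, get_llm() requires a Redis client, which route handlers obtain through the existing get_database dependency and pass as database.redis. The snippet below is a minimal, illustrative route showing how the cached proxy would be exercised; the router prefix, route path, and variable names are hypothetical, while get_llm, LLMMessage, the CacheStats fields, defines.model, RedisDatabase, and get_database all come from the diff above.

    from fastapi import APIRouter, Depends

    import defines
    from database.core import RedisDatabase
    from utils.dependencies import get_database
    from utils.llm_proxy import LLMMessage, get_llm

    # Hypothetical demo router, not part of this patch.
    router = APIRouter(prefix="/cache-demo", tags=["cache-demo"])


    @router.post("/ask")
    async def ask(prompt: str, database: RedisDatabase = Depends(get_database)):
        """Repeating the same prompt should be answered from the Redis-backed cache."""
        llm = get_llm(database.redis)

        # Optional: tighten the TTL for this process (defaults to one year);
        # the same setting is also available via the LLM_CACHE_TTL environment variable.
        await llm.configure_cache(ttl=3600)

        # Non-streaming call: on a cache hit the stored ChatResponse is returned directly;
        # on a miss the provider adapter is called and the result is cached with LRU bookkeeping.
        response = await llm.chat(
            model=defines.model,
            messages=[LLMMessage(role="user", content=prompt)],
        )

        stats = await llm.get_cache_stats()
        return {
            "content": response.content,
            "cache_hit_rate": stats.hit_rate,
            "cached_responses": stats.total_cached_responses,
        }

One point worth flagging for reviewers: LLMManager.get_instance() keeps the Redis client it receives on the first call and ignores the argument on later calls, so the client supplied by the first get_llm() caller is the one the cache uses for the lifetime of the process.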