diff --git a/src/backend/agents/base.py b/src/backend/agents/base.py index 1031a2c..08f6981 100644 --- a/src/backend/agents/base.py +++ b/src/backend/agents/base.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, model_validator # type: ignore from typing import ( Literal, get_args, @@ -20,8 +20,8 @@ import re from abc import ABC import asyncio from datetime import datetime, UTC -from prometheus_client import Counter, Summary, CollectorRegistry -import numpy as np +from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore +import numpy as np # type: ignore from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate) from logger import logger @@ -31,7 +31,7 @@ from metrics import Metrics import model_cast import backstory_traceback as traceback -from rag import ( ChromaDBGetResponse ) +from models import ( ChromaDBGetResponse ) class Agent(BaseModel, ABC): """ @@ -509,21 +509,23 @@ Content: {content} self.metrics.generate_count.labels(agent=self.agent_type).inc() with self.metrics.generate_duration.labels(agent=self.agent_type).time(): context = None + rag_message : ChatMessageRagSearch | None = None if self.user: - rag_message = None - async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt): - if rag_message.status == ApiStatusType.ERROR: - yield rag_message + message = None + async for message in self.generate_rag_results(session_id=session_id, prompt=prompt): + if message.status == ApiStatusType.ERROR: + yield message return # Only yield messages that are in a streaming state - if rag_message.status == ApiStatusType.STATUS: - yield rag_message + if message.status == ApiStatusType.STATUS: + yield message - if not isinstance(rag_message, ChatMessageRagSearch): + if not isinstance(message, ChatMessageRagSearch): raise ValueError( f"Expected ChatMessageRagSearch, got {type(rag_message)}" ) + rag_message = message context = self.get_rag_context(rag_message) # Create a pruned down message list based purely on the prompt and responses, @@ -724,6 +726,7 @@ Content: {content} eval_duration=response.usage.eval_duration, prompt_eval_count=response.usage.prompt_eval_count, prompt_eval_duration=response.usage.prompt_eval_duration, + rag_results=rag_message.content if rag_message else [], timers={ "llm_streamed": end_time - start_time, "llm_with_tools": 0, # Placeholder for tool processing time diff --git a/src/backend/agents/general.py b/src/backend/agents/general.py index 445ab72..308de40 100644 --- a/src/backend/agents/general.py +++ b/src/backend/agents/general.py @@ -52,8 +52,7 @@ class Chat(Agent): """ Chat Agent """ - - agent_type: Literal["general"] = "general" + agent_type: Literal["general"] = "general" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_message diff --git a/src/backend/agents/rag_search.py b/src/backend/agents/rag_search.py index 277d52b..37628c7 100644 --- a/src/backend/agents/rag_search.py +++ b/src/backend/agents/rag_search.py @@ -7,15 +7,13 @@ from .base import Agent, agent_registry from logger import logger from .registry import agent_registry -from models import ( ChatMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables ) -from rag import ( ChromaDBGetResponse ) +from models import ( ChatMessage, ChromaDBGetResponse, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables ) class Chat(Agent): """ Chat Agent """ - - agent_type: Literal["rag_search"] = "rag_search" + agent_type: Literal["rag_search"] = "rag_search" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration async def generate( diff --git a/src/backend/entities/candidate_entity.py b/src/backend/entities/candidate_entity.py index 691b3d2..e693c55 100644 --- a/src/backend/entities/candidate_entity.py +++ b/src/backend/entities/candidate_entity.py @@ -1,20 +1,20 @@ from __future__ import annotations -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, model_validator # type: ignore from uuid import uuid4 from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal from typing_extensions import Annotated, Union -import numpy as np +import numpy as np # type: ignore from uuid import uuid4 -from prometheus_client import CollectorRegistry, Counter +from prometheus_client import CollectorRegistry, Counter # type: ignore import traceback import os import json import re from pathlib import Path -from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse +from rag import start_file_watcher, ChromaDBFileWatcher import defines from logger import logger import agents as agents @@ -22,6 +22,7 @@ from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, R import llm_proxy as llm_manager from agents.base import Agent from database import RedisDatabase +from models import ChromaDBGetResponse class CandidateEntity(Candidate): model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc diff --git a/src/backend/main.py b/src/backend/main.py index b4b9a8e..c8e80ab 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -2855,6 +2855,7 @@ async def post_candidate_vector_content( ): try: if current_user.user_type != "candidate": + logger.warning(f"⚠️ Unauthorized access attempt by user type: {current_user.user_type}") return JSONResponse( status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint") @@ -2864,13 +2865,11 @@ async def post_candidate_vector_content( async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: collection = candidate_entity.umap_collection if not collection: + logger.warning(f"⚠️ No UMAP collection found for candidate {candidate.username}") return JSONResponse( {"error": "No UMAP collection found"}, status_code=404 ) - if not collection.metadatas or collection.ids: - return JSONResponse(f"Document id {rag_document.id} not found.", 404) - for index, id in enumerate(collection.ids): if id == rag_document.id: metadata = collection.metadatas[index].copy() @@ -2883,9 +2882,11 @@ async def post_candidate_vector_content( logger.warning(f"⚠️ No content found for document id {id} for candidate {candidate.username}") return JSONResponse(f"No content found for document id {rag_document.id}.", 404) return create_success_response(rag_response.model_dump(by_alias=True)) - + + logger.warning(f"⚠️ Document id {rag_document.id} not found in UMAP collection for candidate {candidate.username}") return JSONResponse(f"Document id {rag_document.id} not found.", 404) except Exception as e: + logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Post candidate content error: {e}") return JSONResponse( status_code=500, @@ -2942,6 +2943,7 @@ async def post_candidate_vectors( return create_success_response(result) except Exception as e: + logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Post candidate vectors error: {e}") return JSONResponse( status_code=500, diff --git a/src/backend/rag/__init__.py b/src/backend/rag/__init__.py index 25a5745..d499be5 100644 --- a/src/backend/rag/__init__.py +++ b/src/backend/rag/__init__.py @@ -1,7 +1,6 @@ -from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher, RagEntry +from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry __all__ = [ "ChromaDBFileWatcher", - "ChromaDBGetResponse", "start_file_watcher", "RagEntry" ] diff --git a/src/backend/rag/rag.py b/src/backend/rag/rag.py index 2907cb6..af7c5ce 100644 --- a/src/backend/rag/rag.py +++ b/src/backend/rag/rag.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field +from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore from typing import List, Optional, Dict, Any, Union import os import glob @@ -9,15 +9,15 @@ import hashlib import asyncio import logging import json -import numpy as np +import numpy as np # type: ignore import traceback -import chromadb -from watchdog.observers import Observer -from watchdog.events import FileSystemEventHandler -import umap -from markitdown import MarkItDown -from chromadb.api.models.Collection import Collection +import chromadb # type: ignore +from watchdog.observers import Observer # type: ignore +from watchdog.events import FileSystemEventHandler # type: ignore +import umap # type: ignore +from markitdown import MarkItDown # type: ignore +from chromadb.api.models.Collection import Collection # type: ignore from .markdown_chunker import ( MarkdownChunker, @@ -305,18 +305,18 @@ class ChromaDBFileWatcher(FileSystemEventHandler): def _update_umaps(self): # Update the UMAP embeddings - self._umap_collection = self._collection.get( + self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get( include=["embeddings", "documents", "metadatas"] - ) - if not self._umap_collection or not len(self._umap_collection["embeddings"]): + )) + if not self._umap_collection or not len(self._umap_collection.embeddings): logging.warning("No embeddings found in the collection.") return # During initialization logging.info( - f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors" + f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors" ) - vectors = np.array(self._umap_collection["embeddings"]) + vectors = np.array(self._umap_collection.embeddings) self._umap_model_2d = umap.UMAP( n_components=2, random_state=8911, @@ -330,7 +330,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): # ) # Should be 2 logging.info( - f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors" + f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors" ) self._umap_model_3d = umap.UMAP( n_components=3,