RAG working

James Ketr 2025-06-08 21:08:29 -07:00
parent d477b85e5a
commit 3970cca715
7 changed files with 43 additions and 41 deletions

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, model_validator # type: ignore
from typing import (
Literal,
get_args,
@@ -20,8 +20,8 @@ import re
from abc import ABC
import asyncio
from datetime import datetime, UTC
-from prometheus_client import Counter, Summary, CollectorRegistry
-import numpy as np
+from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
+import numpy as np # type: ignore
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
from logger import logger
@@ -31,7 +31,7 @@ from metrics import Metrics
import model_cast
import backstory_traceback as traceback
-from rag import ( ChromaDBGetResponse )
+from models import ( ChromaDBGetResponse )
class Agent(BaseModel, ABC):
"""
@@ -509,21 +509,23 @@ Content: {content}
self.metrics.generate_count.labels(agent=self.agent_type).inc()
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
context = None
+rag_message : ChatMessageRagSearch | None = None
if self.user:
-rag_message = None
-async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt):
-if rag_message.status == ApiStatusType.ERROR:
-yield rag_message
+message = None
+async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
+if message.status == ApiStatusType.ERROR:
+yield message
return
# Only yield messages that are in a streaming state
-if rag_message.status == ApiStatusType.STATUS:
-yield rag_message
+if message.status == ApiStatusType.STATUS:
+yield message
-if not isinstance(rag_message, ChatMessageRagSearch):
+if not isinstance(message, ChatMessageRagSearch):
raise ValueError(
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
)
+rag_message = message
context = self.get_rag_context(rag_message)
# Create a pruned down message list based purely on the prompt and responses,
@@ -724,6 +726,7 @@ Content: {content}
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
+rag_results=rag_message.content if rag_message else [],
timers={
"llm_streamed": end_time - start_time,
"llm_with_tools": 0, # Placeholder for tool processing time

View File

@@ -52,8 +52,7 @@ class Chat(Agent):
"""
Chat Agent
"""
-agent_type: Literal["general"] = "general"
+agent_type: Literal["general"] = "general" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message

View File

@@ -7,15 +7,13 @@ from .base import Agent, agent_registry
from logger import logger
from .registry import agent_registry
-from models import ( ChatMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
-from rag import ( ChromaDBGetResponse )
+from models import ( ChatMessage, ChromaDBGetResponse, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
class Chat(Agent):
"""
Chat Agent
"""
-agent_type: Literal["rag_search"] = "rag_search"
+agent_type: Literal["rag_search"] = "rag_search" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
async def generate(

View File

@@ -1,20 +1,20 @@
from __future__ import annotations
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, model_validator # type: ignore
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
from typing_extensions import Annotated, Union
-import numpy as np
+import numpy as np # type: ignore
from uuid import uuid4
-from prometheus_client import CollectorRegistry, Counter
+from prometheus_client import CollectorRegistry, Counter # type: ignore
import traceback
import os
import json
import re
from pathlib import Path
-from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
+from rag import start_file_watcher, ChromaDBFileWatcher
import defines
from logger import logger
import agents as agents
@@ -22,6 +22,7 @@ from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, R
import llm_proxy as llm_manager
from agents.base import Agent
from database import RedisDatabase
+from models import ChromaDBGetResponse
class CandidateEntity(Candidate):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc

View File

@@ -2855,6 +2855,7 @@ async def post_candidate_vector_content(
):
try:
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized access attempt by user type: {current_user.user_type}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint")
@@ -2864,13 +2865,11 @@ async def post_candidate_vector_content(
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
collection = candidate_entity.umap_collection
if not collection:
logger.warning(f"⚠️ No UMAP collection found for candidate {candidate.username}")
return JSONResponse(
{"error": "No UMAP collection found"}, status_code=404
)
-if not collection.metadatas or collection.ids:
-return JSONResponse(f"Document id {rag_document.id} not found.", 404)
for index, id in enumerate(collection.ids):
if id == rag_document.id:
metadata = collection.metadatas[index].copy()
@@ -2883,9 +2882,11 @@ async def post_candidate_vector_content(
logger.warning(f"⚠️ No content found for document id {id} for candidate {candidate.username}")
return JSONResponse(f"No content found for document id {rag_document.id}.", 404)
return create_success_response(rag_response.model_dump(by_alias=True))
logger.warning(f"⚠️ Document id {rag_document.id} not found in UMAP collection for candidate {candidate.username}")
return JSONResponse(f"Document id {rag_document.id} not found.", 404)
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Post candidate content error: {e}")
return JSONResponse(
status_code=500,
@@ -2942,6 +2943,7 @@ async def post_candidate_vectors(
return create_success_response(result)
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Post candidate vectors error: {e}")
return JSONResponse(
status_code=500,
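
The endpoint change above removes the pre-loop check (if not collection.metadatas or collection.ids:) and its early 404, and instead logs a warning when the id scan completes without a match before the post-loop not-found response. A simplified, synchronous sketch of that lookup flow, with FastAPI and JSONResponse omitted and function and parameter names purely illustrative:

from typing import Any, Dict, List, Optional

def find_document_metadata(ids: List[str],
                           metadatas: List[Dict[str, Any]],
                           document_id: str) -> Optional[Dict[str, Any]]:
    # Scan the UMAP collection ids for the requested document.
    for index, doc_id in enumerate(ids):
        if doc_id == document_id:
            return metadatas[index].copy()
    # Reached only when no id matched; the endpoint maps this to a 404.
    return None

# Example:
# find_document_metadata(["a", "b"], [{"path": "x"}, {"path": "y"}], "b")
# -> {"path": "y"}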

View File

@@ -1,7 +1,6 @@
-from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher, RagEntry
+from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry
__all__ = [
"ChromaDBFileWatcher",
"ChromaDBGetResponse",
"start_file_watcher",
"RagEntry"
]

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
-from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
+from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
from typing import List, Optional, Dict, Any, Union
import os
import glob
@@ -9,15 +9,15 @@ import hashlib
import asyncio
import logging
import json
-import numpy as np
+import numpy as np # type: ignore
import traceback
-import chromadb
-from watchdog.observers import Observer
-from watchdog.events import FileSystemEventHandler
-import umap
-from markitdown import MarkItDown
-from chromadb.api.models.Collection import Collection
+import chromadb # type: ignore
+from watchdog.observers import Observer # type: ignore
+from watchdog.events import FileSystemEventHandler # type: ignore
+import umap # type: ignore
+from markitdown import MarkItDown # type: ignore
+from chromadb.api.models.Collection import Collection # type: ignore
from .markdown_chunker import (
MarkdownChunker,
@@ -305,18 +305,18 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
def _update_umaps(self):
# Update the UMAP embeddings
-self._umap_collection = self._collection.get(
+self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
include=["embeddings", "documents", "metadatas"]
-)
-if not self._umap_collection or not len(self._umap_collection["embeddings"]):
+))
+if not self._umap_collection or not len(self._umap_collection.embeddings):
logging.warning("No embeddings found in the collection.")
return
# During initialization
logging.info(
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
)
-vectors = np.array(self._umap_collection["embeddings"])
+vectors = np.array(self._umap_collection.embeddings)
self._umap_model_2d = umap.UMAP(
n_components=2,
random_state=8911,
@@ -330,7 +330,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# ) # Should be 2
logging.info(
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
)
self._umap_model_3d = umap.UMAP(
n_components=3,
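
The _update_umaps() change above wraps the raw chromadb .get() payload in ChromaDBGetResponse (now imported from models rather than rag), so callers use attribute access: self._umap_collection.embeddings here, collection.ids and collection.metadatas in the endpoint code. The model's definition is not part of this diff; a minimal Pydantic shape that would satisfy those accesses might look like the sketch below, where everything beyond the four field names used in this commit is an assumption.

from typing import Any, Dict, List
from pydantic import BaseModel, Field

class ChromaDBGetResponse(BaseModel):
    # Hypothetical minimal shape; the real models.ChromaDBGetResponse may differ.
    model_config = {"extra": "ignore"}   # chromadb .get() can return extra keys

    ids: List[str] = Field(default_factory=list)
    embeddings: List[List[float]] = Field(default_factory=list)
    documents: List[str] = Field(default_factory=list)
    metadatas: List[Dict[str, Any]] = Field(default_factory=list)

# Usage mirroring _update_umaps():
# umap_collection = ChromaDBGetResponse.model_validate(
#     collection.get(include=["embeddings", "documents", "metadatas"])
# )
# vectors = np.array(umap_collection.embeddings)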