RAG working
This commit is contained in:
parent
d477b85e5a
commit
3970cca715
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from typing import (
|
||||
Literal,
|
||||
get_args,
|
||||
@ -20,8 +20,8 @@ import re
|
||||
from abc import ABC
|
||||
import asyncio
|
||||
from datetime import datetime, UTC
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry
|
||||
import numpy as np
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
|
||||
from logger import logger
|
||||
@ -31,7 +31,7 @@ from metrics import Metrics
|
||||
import model_cast
|
||||
import backstory_traceback as traceback
|
||||
|
||||
from rag import ( ChromaDBGetResponse )
|
||||
from models import ( ChromaDBGetResponse )
|
||||
|
||||
class Agent(BaseModel, ABC):
|
||||
"""
|
||||
@ -509,21 +509,23 @@ Content: {content}
|
||||
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
|
||||
context = None
|
||||
rag_message : ChatMessageRagSearch | None = None
|
||||
if self.user:
|
||||
rag_message = None
|
||||
async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt):
|
||||
if rag_message.status == ApiStatusType.ERROR:
|
||||
yield rag_message
|
||||
message = None
|
||||
async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
|
||||
if message.status == ApiStatusType.ERROR:
|
||||
yield message
|
||||
return
|
||||
# Only yield messages that are in a streaming state
|
||||
if rag_message.status == ApiStatusType.STATUS:
|
||||
yield rag_message
|
||||
if message.status == ApiStatusType.STATUS:
|
||||
yield message
|
||||
|
||||
if not isinstance(rag_message, ChatMessageRagSearch):
|
||||
if not isinstance(message, ChatMessageRagSearch):
|
||||
raise ValueError(
|
||||
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
|
||||
)
|
||||
|
||||
rag_message = message
|
||||
context = self.get_rag_context(rag_message)
|
||||
|
||||
# Create a pruned down message list based purely on the prompt and responses,
|
||||
@ -724,6 +726,7 @@ Content: {content}
|
||||
eval_duration=response.usage.eval_duration,
|
||||
prompt_eval_count=response.usage.prompt_eval_count,
|
||||
prompt_eval_duration=response.usage.prompt_eval_duration,
|
||||
rag_results=rag_message.content if rag_message else [],
|
||||
timers={
|
||||
"llm_streamed": end_time - start_time,
|
||||
"llm_with_tools": 0, # Placeholder for tool processing time
|
||||
|
@ -52,8 +52,7 @@ class Chat(Agent):
|
||||
"""
|
||||
Chat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["general"] = "general"
|
||||
agent_type: Literal["general"] = "general" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
system_prompt: str = system_message
|
||||
|
@ -7,15 +7,13 @@ from .base import Agent, agent_registry
|
||||
from logger import logger
|
||||
|
||||
from .registry import agent_registry
|
||||
from models import ( ChatMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
|
||||
from rag import ( ChromaDBGetResponse )
|
||||
from models import ( ChatMessage, ChromaDBGetResponse, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
|
||||
|
||||
class Chat(Agent):
|
||||
"""
|
||||
Chat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["rag_search"] = "rag_search"
|
||||
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
async def generate(
|
||||
|
@ -1,20 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from uuid import uuid4
|
||||
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
||||
|
||||
from typing_extensions import Annotated, Union
|
||||
import numpy as np
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from uuid import uuid4
|
||||
from prometheus_client import CollectorRegistry, Counter
|
||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
import traceback
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
|
||||
from rag import start_file_watcher, ChromaDBFileWatcher
|
||||
import defines
|
||||
from logger import logger
|
||||
import agents as agents
|
||||
@ -22,6 +22,7 @@ from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, R
|
||||
import llm_proxy as llm_manager
|
||||
from agents.base import Agent
|
||||
from database import RedisDatabase
|
||||
from models import ChromaDBGetResponse
|
||||
|
||||
class CandidateEntity(Candidate):
|
||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
||||
|
@ -2855,6 +2855,7 @@ async def post_candidate_vector_content(
|
||||
):
|
||||
try:
|
||||
if current_user.user_type != "candidate":
|
||||
logger.warning(f"⚠️ Unauthorized access attempt by user type: {current_user.user_type}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint")
|
||||
@ -2864,13 +2865,11 @@ async def post_candidate_vector_content(
|
||||
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
||||
collection = candidate_entity.umap_collection
|
||||
if not collection:
|
||||
logger.warning(f"⚠️ No UMAP collection found for candidate {candidate.username}")
|
||||
return JSONResponse(
|
||||
{"error": "No UMAP collection found"}, status_code=404
|
||||
)
|
||||
|
||||
if not collection.metadatas or collection.ids:
|
||||
return JSONResponse(f"Document id {rag_document.id} not found.", 404)
|
||||
|
||||
for index, id in enumerate(collection.ids):
|
||||
if id == rag_document.id:
|
||||
metadata = collection.metadatas[index].copy()
|
||||
@ -2883,9 +2882,11 @@ async def post_candidate_vector_content(
|
||||
logger.warning(f"⚠️ No content found for document id {id} for candidate {candidate.username}")
|
||||
return JSONResponse(f"No content found for document id {rag_document.id}.", 404)
|
||||
return create_success_response(rag_response.model_dump(by_alias=True))
|
||||
|
||||
|
||||
logger.warning(f"⚠️ Document id {rag_document.id} not found in UMAP collection for candidate {candidate.username}")
|
||||
return JSONResponse(f"Document id {rag_document.id} not found.", 404)
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Post candidate content error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
@ -2942,6 +2943,7 @@ async def post_candidate_vectors(
|
||||
return create_success_response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Post candidate vectors error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
|
@ -1,7 +1,6 @@
|
||||
from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher, RagEntry
|
||||
from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry
|
||||
__all__ = [
|
||||
"ChromaDBFileWatcher",
|
||||
"ChromaDBGetResponse",
|
||||
"start_file_watcher",
|
||||
"RagEntry"
|
||||
]
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
|
||||
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
import os
|
||||
import glob
|
||||
@ -9,15 +9,15 @@ import hashlib
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import numpy as np
|
||||
import numpy as np # type: ignore
|
||||
import traceback
|
||||
|
||||
import chromadb
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
import umap
|
||||
from markitdown import MarkItDown
|
||||
from chromadb.api.models.Collection import Collection
|
||||
import chromadb # type: ignore
|
||||
from watchdog.observers import Observer # type: ignore
|
||||
from watchdog.events import FileSystemEventHandler # type: ignore
|
||||
import umap # type: ignore
|
||||
from markitdown import MarkItDown # type: ignore
|
||||
from chromadb.api.models.Collection import Collection # type: ignore
|
||||
|
||||
from .markdown_chunker import (
|
||||
MarkdownChunker,
|
||||
@ -305,18 +305,18 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
|
||||
def _update_umaps(self):
|
||||
# Update the UMAP embeddings
|
||||
self._umap_collection = self._collection.get(
|
||||
self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
|
||||
include=["embeddings", "documents", "metadatas"]
|
||||
)
|
||||
if not self._umap_collection or not len(self._umap_collection["embeddings"]):
|
||||
))
|
||||
if not self._umap_collection or not len(self._umap_collection.embeddings):
|
||||
logging.warning("No embeddings found in the collection.")
|
||||
return
|
||||
|
||||
# During initialization
|
||||
logging.info(
|
||||
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
|
||||
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
|
||||
)
|
||||
vectors = np.array(self._umap_collection["embeddings"])
|
||||
vectors = np.array(self._umap_collection.embeddings)
|
||||
self._umap_model_2d = umap.UMAP(
|
||||
n_components=2,
|
||||
random_state=8911,
|
||||
@ -330,7 +330,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# ) # Should be 2
|
||||
|
||||
logging.info(
|
||||
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
|
||||
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
|
||||
)
|
||||
self._umap_model_3d = umap.UMAP(
|
||||
n_components=3,
|
||||
|
Loading…
x
Reference in New Issue
Block a user