RAG working
This commit is contained in:
parent
d477b85e5a
commit
3970cca715
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal,
|
Literal,
|
||||||
get_args,
|
get_args,
|
||||||
@ -20,8 +20,8 @@ import re
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from prometheus_client import Counter, Summary, CollectorRegistry
|
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||||
import numpy as np
|
import numpy as np # type: ignore
|
||||||
|
|
||||||
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
|
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
|
||||||
from logger import logger
|
from logger import logger
|
||||||
@ -31,7 +31,7 @@ from metrics import Metrics
|
|||||||
import model_cast
|
import model_cast
|
||||||
import backstory_traceback as traceback
|
import backstory_traceback as traceback
|
||||||
|
|
||||||
from rag import ( ChromaDBGetResponse )
|
from models import ( ChromaDBGetResponse )
|
||||||
|
|
||||||
class Agent(BaseModel, ABC):
|
class Agent(BaseModel, ABC):
|
||||||
"""
|
"""
|
||||||
@ -509,21 +509,23 @@ Content: {content}
|
|||||||
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||||
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
|
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
|
||||||
context = None
|
context = None
|
||||||
|
rag_message : ChatMessageRagSearch | None = None
|
||||||
if self.user:
|
if self.user:
|
||||||
rag_message = None
|
message = None
|
||||||
async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt):
|
async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
|
||||||
if rag_message.status == ApiStatusType.ERROR:
|
if message.status == ApiStatusType.ERROR:
|
||||||
yield rag_message
|
yield message
|
||||||
return
|
return
|
||||||
# Only yield messages that are in a streaming state
|
# Only yield messages that are in a streaming state
|
||||||
if rag_message.status == ApiStatusType.STATUS:
|
if message.status == ApiStatusType.STATUS:
|
||||||
yield rag_message
|
yield message
|
||||||
|
|
||||||
if not isinstance(rag_message, ChatMessageRagSearch):
|
if not isinstance(message, ChatMessageRagSearch):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
|
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rag_message = message
|
||||||
context = self.get_rag_context(rag_message)
|
context = self.get_rag_context(rag_message)
|
||||||
|
|
||||||
# Create a pruned down message list based purely on the prompt and responses,
|
# Create a pruned down message list based purely on the prompt and responses,
|
||||||
@ -724,6 +726,7 @@ Content: {content}
|
|||||||
eval_duration=response.usage.eval_duration,
|
eval_duration=response.usage.eval_duration,
|
||||||
prompt_eval_count=response.usage.prompt_eval_count,
|
prompt_eval_count=response.usage.prompt_eval_count,
|
||||||
prompt_eval_duration=response.usage.prompt_eval_duration,
|
prompt_eval_duration=response.usage.prompt_eval_duration,
|
||||||
|
rag_results=rag_message.content if rag_message else [],
|
||||||
timers={
|
timers={
|
||||||
"llm_streamed": end_time - start_time,
|
"llm_streamed": end_time - start_time,
|
||||||
"llm_with_tools": 0, # Placeholder for tool processing time
|
"llm_with_tools": 0, # Placeholder for tool processing time
|
||||||
|
@ -52,8 +52,7 @@ class Chat(Agent):
|
|||||||
"""
|
"""
|
||||||
Chat Agent
|
Chat Agent
|
||||||
"""
|
"""
|
||||||
|
agent_type: Literal["general"] = "general" # type: ignore
|
||||||
agent_type: Literal["general"] = "general"
|
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
system_prompt: str = system_message
|
system_prompt: str = system_message
|
||||||
|
@ -7,15 +7,13 @@ from .base import Agent, agent_registry
|
|||||||
from logger import logger
|
from logger import logger
|
||||||
|
|
||||||
from .registry import agent_registry
|
from .registry import agent_registry
|
||||||
from models import ( ChatMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
|
from models import ( ChatMessage, ChromaDBGetResponse, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
|
||||||
from rag import ( ChromaDBGetResponse )
|
|
||||||
|
|
||||||
class Chat(Agent):
|
class Chat(Agent):
|
||||||
"""
|
"""
|
||||||
Chat Agent
|
Chat Agent
|
||||||
"""
|
"""
|
||||||
|
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
|
||||||
agent_type: Literal["rag_search"] = "rag_search"
|
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
|
@ -1,20 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
||||||
|
|
||||||
from typing_extensions import Annotated, Union
|
from typing_extensions import Annotated, Union
|
||||||
import numpy as np
|
import numpy as np # type: ignore
|
||||||
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from prometheus_client import CollectorRegistry, Counter
|
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||||
import traceback
|
import traceback
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
|
from rag import start_file_watcher, ChromaDBFileWatcher
|
||||||
import defines
|
import defines
|
||||||
from logger import logger
|
from logger import logger
|
||||||
import agents as agents
|
import agents as agents
|
||||||
@ -22,6 +22,7 @@ from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, R
|
|||||||
import llm_proxy as llm_manager
|
import llm_proxy as llm_manager
|
||||||
from agents.base import Agent
|
from agents.base import Agent
|
||||||
from database import RedisDatabase
|
from database import RedisDatabase
|
||||||
|
from models import ChromaDBGetResponse
|
||||||
|
|
||||||
class CandidateEntity(Candidate):
|
class CandidateEntity(Candidate):
|
||||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
||||||
|
@ -2855,6 +2855,7 @@ async def post_candidate_vector_content(
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if current_user.user_type != "candidate":
|
if current_user.user_type != "candidate":
|
||||||
|
logger.warning(f"⚠️ Unauthorized access attempt by user type: {current_user.user_type}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint")
|
content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint")
|
||||||
@ -2864,13 +2865,11 @@ async def post_candidate_vector_content(
|
|||||||
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
||||||
collection = candidate_entity.umap_collection
|
collection = candidate_entity.umap_collection
|
||||||
if not collection:
|
if not collection:
|
||||||
|
logger.warning(f"⚠️ No UMAP collection found for candidate {candidate.username}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
{"error": "No UMAP collection found"}, status_code=404
|
{"error": "No UMAP collection found"}, status_code=404
|
||||||
)
|
)
|
||||||
|
|
||||||
if not collection.metadatas or collection.ids:
|
|
||||||
return JSONResponse(f"Document id {rag_document.id} not found.", 404)
|
|
||||||
|
|
||||||
for index, id in enumerate(collection.ids):
|
for index, id in enumerate(collection.ids):
|
||||||
if id == rag_document.id:
|
if id == rag_document.id:
|
||||||
metadata = collection.metadatas[index].copy()
|
metadata = collection.metadatas[index].copy()
|
||||||
@ -2883,9 +2882,11 @@ async def post_candidate_vector_content(
|
|||||||
logger.warning(f"⚠️ No content found for document id {id} for candidate {candidate.username}")
|
logger.warning(f"⚠️ No content found for document id {id} for candidate {candidate.username}")
|
||||||
return JSONResponse(f"No content found for document id {rag_document.id}.", 404)
|
return JSONResponse(f"No content found for document id {rag_document.id}.", 404)
|
||||||
return create_success_response(rag_response.model_dump(by_alias=True))
|
return create_success_response(rag_response.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
logger.warning(f"⚠️ Document id {rag_document.id} not found in UMAP collection for candidate {candidate.username}")
|
||||||
return JSONResponse(f"Document id {rag_document.id} not found.", 404)
|
return JSONResponse(f"Document id {rag_document.id} not found.", 404)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(backstory_traceback.format_exc())
|
||||||
logger.error(f"❌ Post candidate content error: {e}")
|
logger.error(f"❌ Post candidate content error: {e}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
@ -2942,6 +2943,7 @@ async def post_candidate_vectors(
|
|||||||
return create_success_response(result)
|
return create_success_response(result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(backstory_traceback.format_exc())
|
||||||
logger.error(f"❌ Post candidate vectors error: {e}")
|
logger.error(f"❌ Post candidate vectors error: {e}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from .rag import ChromaDBFileWatcher, ChromaDBGetResponse, start_file_watcher, RagEntry
|
from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ChromaDBFileWatcher",
|
"ChromaDBFileWatcher",
|
||||||
"ChromaDBGetResponse",
|
|
||||||
"start_file_watcher",
|
"start_file_watcher",
|
||||||
"RagEntry"
|
"RagEntry"
|
||||||
]
|
]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
|
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
|
||||||
from typing import List, Optional, Dict, Any, Union
|
from typing import List, Optional, Dict, Any, Union
|
||||||
import os
|
import os
|
||||||
import glob
|
import glob
|
||||||
@ -9,15 +9,15 @@ import hashlib
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np # type: ignore
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import chromadb
|
import chromadb # type: ignore
|
||||||
from watchdog.observers import Observer
|
from watchdog.observers import Observer # type: ignore
|
||||||
from watchdog.events import FileSystemEventHandler
|
from watchdog.events import FileSystemEventHandler # type: ignore
|
||||||
import umap
|
import umap # type: ignore
|
||||||
from markitdown import MarkItDown
|
from markitdown import MarkItDown # type: ignore
|
||||||
from chromadb.api.models.Collection import Collection
|
from chromadb.api.models.Collection import Collection # type: ignore
|
||||||
|
|
||||||
from .markdown_chunker import (
|
from .markdown_chunker import (
|
||||||
MarkdownChunker,
|
MarkdownChunker,
|
||||||
@ -305,18 +305,18 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
|
|
||||||
def _update_umaps(self):
|
def _update_umaps(self):
|
||||||
# Update the UMAP embeddings
|
# Update the UMAP embeddings
|
||||||
self._umap_collection = self._collection.get(
|
self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
|
||||||
include=["embeddings", "documents", "metadatas"]
|
include=["embeddings", "documents", "metadatas"]
|
||||||
)
|
))
|
||||||
if not self._umap_collection or not len(self._umap_collection["embeddings"]):
|
if not self._umap_collection or not len(self._umap_collection.embeddings):
|
||||||
logging.warning("No embeddings found in the collection.")
|
logging.warning("No embeddings found in the collection.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# During initialization
|
# During initialization
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
|
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
|
||||||
)
|
)
|
||||||
vectors = np.array(self._umap_collection["embeddings"])
|
vectors = np.array(self._umap_collection.embeddings)
|
||||||
self._umap_model_2d = umap.UMAP(
|
self._umap_model_2d = umap.UMAP(
|
||||||
n_components=2,
|
n_components=2,
|
||||||
random_state=8911,
|
random_state=8911,
|
||||||
@ -330,7 +330,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
# ) # Should be 2
|
# ) # Should be 2
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection['embeddings'])} vectors"
|
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
|
||||||
)
|
)
|
||||||
self._umap_model_3d = umap.UMAP(
|
self._umap_model_3d = umap.UMAP(
|
||||||
n_components=3,
|
n_components=3,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user