238 lines
9.4 KiB
Python
238 lines
9.4 KiB
Python
from __future__ import annotations
|
|
from pydantic import BaseModel, Field, model_validator # type: ignore
|
|
from uuid import uuid4
|
|
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING
|
|
|
|
from typing_extensions import Annotated, Union
|
|
import numpy as np # type: ignore
|
|
import logging
|
|
from uuid import uuid4
|
|
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
|
import traceback
|
|
import os
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
|
|
|
|
from . rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
|
|
from . import defines
|
|
from . import Message
|
|
#from . import Context
|
|
|
|
# Configure root logging at INFO so messages from this module are emitted by default.
logging.basicConfig(level=logging.INFO)

# Logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
from .rag import RagEntry
|
|
from .message import Tunables
|
|
|
|
class Question(BaseModel):
    """A question string together with its associated ``Tunables`` settings."""

    # Required text of the question.
    question: str = Field(...)
    # Each instance gets its own freshly constructed Tunables when none is given.
    tunables: Tunables = Field(default_factory=Tunables)
|
|
|
|
class User(BaseModel):
    """A user's profile plus the per-user RAG state created by ``initialize()``.

    Profile fields (names, description, profile URL, contact info, suggested
    questions) are refreshed from the JSON file
    ``<defines.user_dir>/<username>/<defines.user_info_file>`` when
    ``initialize()`` runs. The ``file_watcher``, ``observer``,
    ``prometheus_collector`` and ``umap_collection`` properties raise
    ``ValueError`` until ``initialize()`` has populated the backing fields.
    """

    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher, etc

    username: str
    llm: Any = Field(exclude=True)
    rags: List[RagEntry] = Field(default_factory=list)
    first_name: str = ""
    last_name: str = ""
    full_name: str = ""
    description: str = ""
    profile_url: str = ""
    rag_content_size: int = 0
    # default_factory makes the "fresh container per instance" intent explicit
    # and matches `rags` above.
    contact_info: Dict[str, str] = Field(default_factory=dict)
    user_questions: List[Question] = Field(default_factory=list)

    # Internal instance members: populated by initialize(), excluded from serialization.
    User__observer: Optional[Any] = Field(default=None, exclude=True)
    User__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
    User__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
    # Set to True at the start of the first initialize() call, even if that call later fails.
    User__initialized: bool = Field(default=False, exclude=True)

    @classmethod
    def exists(cls, username: str) -> bool:
        """Return True if ``username`` is well-formed and has an info file on disk.

        The name must be 3-32 characters drawn from ``[a-zA-Z0-9_-]``; otherwise
        False is returned without touching the filesystem. The resolved user
        directory must also lie inside ``defines.user_dir``.
        """
        # fullmatch instead of match + "$": "$" also matches just before a
        # trailing newline, so e.g. "alice\n" would pass a match-based check.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", username):
            return False  # Invalid username characters

        if not (3 <= len(username) <= 32):
            return False  # Invalid username length

        base_dir = Path(defines.user_dir)
        user_dir = base_dir / username
        user_info_path = user_dir / defines.user_info_file

        # Defense in depth against directory traversal: after resolve() (which
        # follows symlinks) the directory must still be inside base_dir.
        try:
            if not user_dir.resolve().is_relative_to(base_dir.resolve()):
                return False  # Path traversal attempt detected
        except (ValueError, RuntimeError, OSError):  # Potential exceptions from resolve()
            return False

        return user_info_path.is_file()

    def _require_initialized(self, value: Any) -> Any:
        """Return ``value``, raising ValueError if initialize() has not set it."""
        # `is None` rather than truthiness: a populated but falsy object (for
        # example one defining __len__) must not read as "not initialized".
        if value is None:
            raise ValueError("initialize() has not been called.")
        return value

    # Wrapper properties that map into file_watcher
    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        """The ``umap_collection`` exposed by the file watcher."""
        return self.file_watcher.umap_collection

    @property
    def file_watcher(self) -> ChromaDBFileWatcher:
        """The ChromaDB file watcher started by initialize()."""
        return self._require_initialized(self.User__file_watcher)

    @property
    def prometheus_collector(self) -> CollectorRegistry:
        """The Prometheus registry passed to initialize()."""
        return self._require_initialized(self.User__prometheus_collector)

    @property
    def observer(self) -> Any:
        """The observer object returned by start_file_watcher() in initialize()."""
        return self._require_initialized(self.User__observer)

    def generate_rag_results(
        self, message: Message, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold
    ) -> Generator[Message, None, None]:
        """Query each enabled RAG entry for context relevant to ``message.prompt``.

        Progress is reported by mutating and re-yielding the same ``message``.
        ``status`` is "processing" while running, then "done", or "error" if
        anything raises (the error is logged rather than propagated).
        ``response`` carries a human-readable progress line.

        For every query that returns results, a ``ChromaDBGetResponse`` is
        appended to ``message.metadata.rag``. It holds the ids, embeddings,
        documents, metadatas, and 2D/3D UMAP projections of the query embedding.

        Args:
            message: Message whose ``prompt`` is the query; updated in place.
            top_k: Maximum number of results requested per RAG entry.
            threshold: Similarity threshold passed to ``find_similar``.

        Yields:
            ``message`` after each progress update.
        """
        try:
            message.status = "processing"
            entries = 0  # total documents gathered across all RAG entries

            for rag in self.rags:
                if not rag.enabled:
                    continue
                message.response = f"Checking RAG context {rag.name}..."
                yield message

                # NOTE(review): every enabled entry queries the same
                # self.file_watcher collection regardless of rag.name; confirm
                # this is intended.
                chroma_results = self.file_watcher.find_similar(
                    query=message.prompt, top_k=top_k, threshold=threshold
                )
                if not chroma_results:
                    continue

                documents = chroma_results.get("documents", [])
                document_count = len(documents) if documents else 0

                query_embedding = np.array(chroma_results["query_embedding"]).flatten()
                umap_2d = self.file_watcher.umap_model_2d.transform([query_embedding])[0]
                umap_3d = self.file_watcher.umap_model_3d.transform([query_embedding])[0]

                rag_metadata = ChromaDBGetResponse(
                    query=message.prompt,
                    query_embedding=query_embedding.tolist(),
                    name=rag.name,
                    ids=chroma_results.get("ids", []),
                    embeddings=chroma_results.get("embeddings", []),
                    documents=documents,
                    metadatas=chroma_results.get("metadatas", []),
                    umap_embedding_2d=umap_2d.tolist(),
                    umap_embedding_3d=umap_3d.tolist(),
                    size=self.file_watcher.collection.count(),
                )
                message.metadata.rag.append(rag_metadata)

                # Previously `entries` was never updated, so the summary always said 0.
                entries += document_count
                message.response = f"Results from {rag.name} RAG: {document_count} results."
                yield message

            message.response = (
                f"RAG context gathered from results from {entries} documents."
            )
            message.status = "done"
            yield message

        except Exception as e:
            message.status = "error"
            message.response = f"Error generating RAG results: {str(e)}"
            logger.error(traceback.format_exc())
            logger.error(message.response)
            yield message

    def _load_user_info(self, user_info: str) -> Dict[str, Any]:
        """Read the user's JSON info file; log and return {} if it cannot be used."""
        try:
            # JSON text is UTF-8; don't depend on the platform's locale encoding.
            with open(user_info, "r", encoding="utf-8") as f:
                info = json.load(f)
        except Exception as e:
            logger.error(f"Error processing {user_info}: {e}")
            return {}
        if not isinstance(info, dict):
            # Callers use info.get(...), so anything but a JSON object is unusable.
            logger.error(f"Error processing {user_info}: expected a JSON object, got {type(info).__name__}")
            return {}
        if info:
            # Diagnostic dump of the whole record; this is not an error condition.
            logger.debug(f"info={info}")
        return info

    def _apply_user_info(self, info: Dict[str, Any], user_info: str) -> None:
        """Refresh profile fields and suggested questions from ``info``, with defaults."""
        self.first_name = info.get("first_name", self.username)
        self.last_name = info.get("last_name", "")
        # strip() avoids a trailing space when there is no last name.
        self.full_name = info.get("full_name", f"{self.first_name} {self.last_name}".strip())
        self.description = info.get("description", self.description)
        # Fall back to the current profile_url (this previously fell back to description).
        self.profile_url = info.get("profile_url", self.profile_url)
        self.contact_info = info.get("contact_info", {})

        questions = info.get("questions", [f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
        parsed: List[Question] = []
        for question in questions:
            if isinstance(question, str):
                # Plain-string questions are created with tools disabled.
                parsed.append(Question(question=question, tunables=Tunables(enable_tools=False)))
                continue
            try:
                parsed.append(Question.model_validate(question))
            except Exception as e:
                # Best effort: skip malformed entries but keep the rest.
                logger.warning(f"Unable to initialize all questions from {user_info}: {e}")
        self.user_questions = parsed

    def initialize(self, prometheus_collector):
        """Load the user's profile and start the RAG file watcher.

        Steps:
            - Reads ``<defines.user_dir>/<username>/<defines.user_info_file>`` as
              JSON and refreshes the profile fields and suggested questions,
              falling back to defaults when keys or the file are missing.
            - Creates the persist and watch directories.
            - Starts the file watcher.
            - Ensures a RAG entry named after the user exists.
            - Records the collection's document count in ``rag_content_size``.

        Args:
            prometheus_collector: Registry stored for the ``prometheus_collector``
                property.

        Raises:
            ValueError: if called more than once on the same instance.
        """
        if self.User__initialized:
            # Initialization can only be attempted once; if there are multiple attempts, it means
            # a subsystem is failing or there is a logic bug in the code.
            #
            # NOTE: It is intentional that self.User__initialized = True regardless of whether it
            # succeeded. This prevents server loops on failure
            raise ValueError("initialize can only be attempted once")
        self.User__initialized = True

        user_dir = os.path.join(defines.user_dir, self.username)
        user_info = os.path.join(user_dir, defines.user_info_file)
        persist_directory = os.path.join(user_dir, defines.persist_directory)
        watch_directory = os.path.join(user_dir, defines.rag_content_dir)
        logger.info(f"User(username={self.username}, user_dir={user_dir}, persist_directory={persist_directory}, watch_directory={watch_directory})")

        # Always re-initialize the user's name and contact data from the info file in case it is changed
        info = self._load_user_info(user_info)
        self._apply_user_info(info, user_info)

        os.makedirs(persist_directory, exist_ok=True)
        os.makedirs(watch_directory, exist_ok=True)

        self.User__prometheus_collector = prometheus_collector
        self.User__observer, self.User__file_watcher = start_file_watcher(
            llm=self.llm,
            collection_name=self.username,
            persist_directory=persist_directory,
            watch_directory=watch_directory,
            recreate=False,  # Don't recreate if exists
        )

        # RagEntry objects are read by attribute elsewhere (rag.name, rag.enabled);
        # subscripting them (rag["name"]) fails once `rags` is non-empty.
        if not any(rag.name == self.username for rag in self.rags):
            self.rags.append(RagEntry(
                name=self.username,
                description=f"Expert data about {self.full_name}.",
            ))
        self.rag_content_size = self.file_watcher.collection.count()
|
|
# Rebuild User's pydantic schema so postponed (string) annotations, which this
# module produces via `from __future__ import annotations`, are resolved
# against this module's namespace. Every referenced name is imported or
# defined above, so this may already be a no-op.
# NOTE(review): confirm whether the explicit rebuild is still required.
User.model_rebuild()
|