backstory/src/utils/user.py
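
"""User model for backstory.

Module docstring added for clarity (a summary of what the code below does):
defines the pydantic `User` model, which loads per-user profile data from an
info file on disk, wraps a ChromaDB file watcher for RAG content, and
generates RAG results for incoming messages.
"""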
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator  # type: ignore
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING
from typing_extensions import Annotated, Union
import numpy as np  # type: ignore
import logging
from prometheus_client import CollectorRegistry, Counter  # type: ignore
import traceback
import os
import json
import re
from pathlib import Path

from .rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse, RagEntry
from . import defines
from . import Message
from .message import Tunables
# from . import Context

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
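

# A user question loaded from the user's info file, carrying optional
# per-question tunables (e.g. whether tool use is enabled when answering it).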
class Question(BaseModel):
    question: str
    tunables: Tunables = Field(default_factory=Tunables)


class User(BaseModel):
    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher, etc.

    username: str
    llm: Any = Field(exclude=True)
    rags: List[RagEntry] = Field(default_factory=list)
    first_name: str = ""
    last_name: str = ""
    full_name: str = ""
    description: str = ""
    profile_url: str = ""
    rag_content_size: int = 0
    contact_info: Dict[str, str] = {}
    user_questions: List[Question] = []
    # context: Optional[List[Context]] = []

    # file_watcher: ChromaDBFileWatcher = set by initialize()
    # observer: Any = set by initialize()
    # prometheus_collector: CollectorRegistry = set by initialize()

    # Internal instance members
    User__observer: Optional[Any] = Field(default=None, exclude=True)
    User__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
    User__prometheus_collector: Optional[CollectorRegistry] = Field(
        default=None, exclude=True
    )

    @classmethod
    def exists(cls, username: str):
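        """Return True if an info file exists for the given username.

        The username is validated (character set and length), and the resolved
        path is checked to stay within the configured user directory to guard
        against path traversal.
        """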
        # Validate username format (only allow safe characters)
        if not re.match(r'^[a-zA-Z0-9_-]+$', username):
            return False  # Invalid username characters

        # Check for minimum and maximum length
        if not (3 <= len(username) <= 32):
            return False  # Invalid username length

        # Use Path for safe path handling and normalization
        user_dir = Path(defines.user_dir) / username
        user_info_path = user_dir / defines.user_info_file

        # Ensure the final path is actually within the intended parent directory
        # to help prevent directory traversal attacks
        try:
            if not user_dir.resolve().is_relative_to(Path(defines.user_dir).resolve()):
                return False  # Path traversal attempt detected
        except (ValueError, RuntimeError):  # Potential exceptions from resolve()
            return False

        # Check if file exists
        return user_info_path.is_file()

    # Wrapper properties that map into file_watcher
    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        if not self.User__file_watcher:
            raise ValueError("initialize() has not been called.")
        return self.User__file_watcher.umap_collection

    # Fields managed by initialize()
    User__initialized: bool = Field(default=False, exclude=True)

    @property
    def file_watcher(self) -> ChromaDBFileWatcher:
        if not self.User__file_watcher:
            raise ValueError("initialize() has not been called.")
        return self.User__file_watcher

    @property
    def prometheus_collector(self) -> CollectorRegistry:
        if not self.User__prometheus_collector:
            raise ValueError("initialize() has not been called.")
        return self.User__prometheus_collector

    @property
    def observer(self) -> Any:
        if not self.User__observer:
            raise ValueError("initialize() has not been called.")
        return self.User__observer

    def generate_rag_results(
        self, message: Message, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold
    ) -> Generator[Message, None, None]:
        """
        Generate RAG results for the prompt in the given message.

        Args:
            message: The message whose prompt is used to query the enabled RAG collections.
            top_k: Maximum number of similar documents to retrieve.
            threshold: Similarity threshold for matching documents.

        Yields:
            The same Message, updated with status, response text, and RAG metadata.
        """
        try:
            message.status = "processing"
            entries: int = 0
            for rag in self.rags:
                if not rag.enabled:
                    continue
                message.response = f"Checking RAG context {rag.name}..."
                yield message
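                # Query the vector store for documents similar to the message prompt.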
                chroma_results = self.file_watcher.find_similar(
                    query=message.prompt, top_k=top_k, threshold=threshold
                )
                if chroma_results:
                    query_embedding = np.array(chroma_results["query_embedding"]).flatten()
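                    # Project the query embedding into the pre-fit 2D and 3D UMAP
                    # spaces used by the collection's document embeddings.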
                    umap_2d = self.file_watcher.umap_model_2d.transform([query_embedding])[0]
                    umap_3d = self.file_watcher.umap_model_3d.transform([query_embedding])[0]
                    rag_metadata = ChromaDBGetResponse(
                        query=message.prompt,
                        query_embedding=query_embedding.tolist(),
                        name=rag.name,
                        ids=chroma_results.get("ids", []),
                        embeddings=chroma_results.get("embeddings", []),
                        documents=chroma_results.get("documents", []),
                        metadatas=chroma_results.get("metadatas", []),
                        umap_embedding_2d=umap_2d.tolist(),
                        umap_embedding_3d=umap_3d.tolist(),
                        size=self.file_watcher.collection.count(),
                    )
                    message.metadata.rag.append(rag_metadata)
                    entries += len(chroma_results.get("documents", []))  # count documents gathered
                    message.response = f"Results from {rag.name} RAG: {len(chroma_results['documents'])} results."
                    yield message

            message.response = f"RAG context gathered from {entries} documents."
            message.status = "done"
            yield message
            return
        except Exception as e:
            message.status = "error"
            message.response = f"Error generating RAG results: {str(e)}"
            logger.error(traceback.format_exc())
            logger.error(message.response)
            yield message
            return

    def initialize(self, prometheus_collector):
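        """One-time setup for this user.

        Loads the user's info file, ensures the persistence and RAG content
        directories exist, registers the Prometheus collector, and starts the
        ChromaDB file watcher. May only be called once per instance.
        """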
        if self.User__initialized:
            # Initialization can only be attempted once; if there are multiple attempts, it means
            # a subsystem is failing or there is a logic bug in the code.
            #
            # NOTE: It is intentional that self.User__initialized is set to True regardless of
            # whether initialization succeeded. This prevents server loops on failure.
            raise ValueError("initialize can only be attempted once")
        self.User__initialized = True

        user_dir = os.path.join(defines.user_dir, self.username)
        user_info = os.path.join(user_dir, defines.user_info_file)
        persist_directory = os.path.join(user_dir, defines.persist_directory)
        watch_directory = os.path.join(user_dir, defines.rag_content_dir)
        logger.info(
            f"User(username={self.username}, user_dir={user_dir}, "
            f"persist_directory={persist_directory}, watch_directory={watch_directory})"
        )

        info = {}
        # Always re-initialize the user's name and contact data from the info file in case it changed
        try:
            with open(user_info, "r") as f:
                info = json.loads(f.read())
        except Exception as e:
            logger.error(f"Error processing {user_info}: {e}")

        if info:
            logger.info(f"info={info}")
            self.first_name = info.get("first_name", self.username)
            self.last_name = info.get("last_name", "")
            self.full_name = info.get("full_name", f"{self.first_name} {self.last_name}")
            self.description = info.get("description", self.description)
            self.profile_url = info.get("profile_url", self.profile_url)
            self.contact_info = info.get("contact_info", {})
            questions = info.get(
                "questions",
                [
                    f"Tell me about {self.first_name}.",
                    f"What are {self.first_name}'s professional strengths?",
                ],
            )
            self.user_questions = []
            for question in questions:
                if isinstance(question, str):
                    self.user_questions.append(
                        Question(question=question, tunables=Tunables(enable_tools=False))
                    )
                else:
                    try:
                        self.user_questions.append(Question.model_validate(question))
                    except Exception as e:
                        logger.info(f"Unable to initialize all questions from {user_info}: {e}")

        os.makedirs(persist_directory, exist_ok=True)
        os.makedirs(watch_directory, exist_ok=True)
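
        # Register the shared Prometheus collector and start watching the user's
        # RAG content directory; the existing collection is reused if present.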
        self.User__prometheus_collector = prometheus_collector
        self.User__observer, self.User__file_watcher = start_file_watcher(
            llm=self.llm,
            collection_name=self.username,
            persist_directory=persist_directory,
            watch_directory=watch_directory,
            recreate=False,  # Don't recreate if exists
        )
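
        # Ensure the user always has a RAG entry named after their username.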
        has_username_rag = any(item.name == self.username for item in self.rags)
        if not has_username_rag:
            self.rags.append(
                RagEntry(
                    name=self.username,
                    description=f"Expert data about {self.full_name}.",
                )
            )
        self.rag_content_size = self.file_watcher.collection.count()


User.model_rebuild()