227 lines
9.7 KiB
Python
227 lines
9.7 KiB
Python
from __future__ import annotations
|
|
from pydantic import BaseModel, Field, model_validator # type: ignore
|
|
from uuid import uuid4
|
|
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
|
|
|
from typing_extensions import Annotated, Union
|
|
import numpy as np # type: ignore
|
|
|
|
from uuid import uuid4
|
|
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
|
import traceback
|
|
import os
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
|
|
from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
|
|
import defines
|
|
from logger import logger
|
|
import agents as agents
|
|
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageType, ChatMessageMetaData, ChatStatusType, Candidate, ChatContextType)
|
|
from llm_manager import llm_manager
|
|
|
|
class CandidateEntity(Candidate):
    """Candidate model extended with per-user runtime resources.

    On top of the persisted ``Candidate`` fields, ``initialize()`` attaches a
    ChromaDB file watcher over the user's RAG content directory (plus the
    observer returned alongside it) and an optional Prometheus collector.
    These runtime members are excluded from serialization, and the accessor
    properties raise ``ValueError`` until they have been set.
    """

    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher, etc

    # Internal instance members, excluded from serialization. All of them are
    # set by initialize(); read them through the properties below.
    CandidateEntity__initialized: bool = Field(default=False, exclude=True)
    CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
    CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
    CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
        default=None, exclude=True
    )

    def __init__(self, candidate: Optional[Candidate] = None):
        """Create an entity, optionally copying every field from ``candidate``.

        Args:
            candidate: Existing model whose instance attributes (``vars()``) are
                passed as keyword arguments to the parent constructor. When
                ``None``, the model is built from its field defaults.
        """
        if candidate is not None:
            super().__init__(**vars(candidate))
        else:
            super().__init__()

    @classmethod
    def exists(cls, username: str) -> bool:
        """Return True if ``username`` is well-formed and has a user info file.

        A username is well-formed when it is 3-32 characters drawn only from
        ASCII letters, digits, ``_`` and ``-``. The info file looked up is
        ``defines.user_dir / username / defines.user_info_file``.

        Args:
            username: Candidate username to check; may be untrusted input.

        Returns:
            False for malformed names or paths resolving outside
            ``defines.user_dir``; otherwise whether the info file exists.
        """
        # fullmatch (unlike match with ``$``) also rejects a trailing newline,
        # e.g. "alice\n", which ``$`` would happily accept.
        if not isinstance(username, str) or not re.fullmatch(r"[a-zA-Z0-9_-]+", username):
            return False  # Invalid username characters

        if not 3 <= len(username) <= 32:
            return False  # Invalid username length

        base_dir = Path(defines.user_dir)
        user_dir = base_dir / username
        user_info_path = user_dir / defines.user_info_file

        # Defense in depth against directory traversal: the resolved user
        # directory must stay inside the resolved users root.
        try:
            if not user_dir.resolve().is_relative_to(base_dir.resolve()):
                return False  # Path traversal attempt detected
        except (ValueError, RuntimeError, OSError):  # Possible failures from resolve()
            return False

        return user_info_path.is_file()

    def get_or_create_agent(self, agent_type: ChatContextType, **kwargs) -> agents.Agent:
        """Get or create an agent of the specified type for this candidate.

        Delegates to ``agents.get_or_create_agent`` with this entity as
        ``user`` and this entity's Prometheus collector.

        Args:
            agent_type: The type of agent to get or create (required).
            **kwargs: Additional fields required by the specific agent subclass.

        Returns:
            The agent instance returned by ``agents.get_or_create_agent``.

        Raises:
            ValueError: If initialize() was not given a Prometheus collector.
        """
        return agents.get_or_create_agent(
            agent_type=agent_type,
            user=self,
            prometheus_collector=self.prometheus_collector,
            **kwargs,
        )

    # Wrapper properties that map into file_watcher / initialize() state.
    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        """The file watcher's ``umap_collection``.

        Raises:
            ValueError: If initialize() has not set up the file watcher.
        """
        return self.file_watcher.umap_collection

    @property
    def file_watcher(self) -> ChromaDBFileWatcher:
        """The ChromaDB file watcher created by initialize().

        Raises:
            ValueError: If initialize() has not set up the file watcher.
        """
        # Compare with None: an object defining __len__/__bool__ could be falsy.
        if self.CandidateEntity__file_watcher is None:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher

    @property
    def prometheus_collector(self) -> CollectorRegistry:
        """The Prometheus registry passed to initialize().

        Raises:
            ValueError: If initialize() was not given a collector.
        """
        if self.CandidateEntity__prometheus_collector is None:
            raise ValueError("initialize() has not been called with a prometheus_collector.")
        return self.CandidateEntity__prometheus_collector

    @property
    def observer(self) -> Any:
        """The observer returned by ``start_file_watcher`` in initialize().

        Raises:
            ValueError: If initialize() has not set up the observer.
        """
        if self.CandidateEntity__observer is None:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__observer

    @classmethod
    def sanitize(cls, user: Dict[str, Any]) -> Dict[str, Any]:
        """Build a normalized copy of a raw user-info dict.

        Missing keys get defaults. ``contact_info`` is reduced to scalar
        values. ``questions`` entries (strings or ``CandidateQuestion``-shaped
        data) become ``{"question": <text>}`` dicts under ``user_questions``;
        entries that fail validation are logged and skipped.

        Args:
            user: Raw user data, e.g. parsed from a user's info JSON file.

        Returns:
            A new dict; ``user`` itself is not modified.
        """
        sanitized: Dict[str, Any] = {}
        sanitized["username"] = user.get("username", "default")
        sanitized["first_name"] = user.get("first_name", sanitized["username"])
        sanitized["last_name"] = user.get("last_name", "")
        sanitized["title"] = user.get("title", "")
        sanitized["phone"] = user.get("phone", "")
        sanitized["location"] = user.get("location", "")
        sanitized["email"] = user.get("email", "")
        # Single quotes inside the f-string keep this valid before Python 3.12
        # (nested same-type quotes are only allowed from 3.12, PEP 701).
        sanitized["full_name"] = user.get(
            "full_name", f"{sanitized['first_name']} {sanitized['last_name']}"
        )
        sanitized["description"] = user.get("description", "")

        # NOTE(review): username comes from the input dict; this is only an
        # existence check, but confirm callers pass a validated username.
        profile_image = os.path.join(defines.user_dir, sanitized["username"], "profile.png")
        sanitized["has_profile"] = os.path.exists(profile_image)

        # Keep only scalar contact values; anything else (lists, dicts, ...) is
        # dropped. A missing or non-dict contact_info yields an empty dict.
        contact_info = user.get("contact_info") or {}
        if not isinstance(contact_info, dict):
            contact_info = {}
        sanitized["contact_info"] = {
            key: value
            for key, value in contact_info.items()
            if isinstance(value, (str, int, float, complex))
        }

        default_questions = [
            f"Tell me about {sanitized['first_name']}.",
            f"What are {sanitized['first_name']}'s professional strengths?",
        ]
        questions = user.get("questions", default_questions)
        # A bare string is one question, not a sequence of one-letter questions;
        # any other non-sequence value (e.g. None) contributes no questions.
        if isinstance(questions, str):
            questions = [questions]
        elif not isinstance(questions, (list, tuple)):
            questions = []

        sanitized["user_questions"] = []
        for question in questions:
            if isinstance(question, str):
                sanitized["user_questions"].append({"question": question})
                continue
            try:
                parsed = CandidateQuestion.model_validate(question)
            except Exception as e:
                # Best-effort: skip malformed entries, but leave a trace.
                logger.info(f"Skipping invalid question for {sanitized['username']}: {str(e)}")
                continue
            sanitized["user_questions"].append({"question": parsed.question})

        return sanitized

    @classmethod
    def get_users(cls) -> List[Dict[str, Any]]:
        """Load the ``info.json`` of every user directory under ``defines.user_dir``.

        Each returned dict is the parsed JSON object, with ``username`` set to
        the directory name and ``has_profile`` set to whether ``profile.png``
        exists in that directory. Unreadable files, invalid JSON, and JSON that
        is not an object are logged and skipped.

        Returns:
            User dicts ordered by directory name; empty if the users directory
            does not exist.
        """
        user_data: List[Dict[str, Any]] = []

        users_dir = defines.user_dir
        if not os.path.exists(users_dir):
            return user_data

        # sorted(): os.listdir order is arbitrary; make the result deterministic.
        for item in sorted(os.listdir(users_dir)):
            item_path = os.path.join(users_dir, item)
            if not os.path.isdir(item_path):
                continue

            # NOTE(review): exists() looks for defines.user_info_file while this
            # hard-codes "info.json"; confirm the two agree.
            info_path = os.path.join(item_path, "info.json")
            if not os.path.exists(info_path):
                continue

            try:
                # JSON text is UTF-8; don't depend on the platform default encoding.
                with open(info_path, "r", encoding="utf-8") as file:
                    data = json.load(file)
            except json.JSONDecodeError as e:
                # Skip files that aren't valid JSON
                logger.info(f"Invalid JSON for {info_path}: {str(e)}")
                continue
            except Exception as e:
                # Skip files that can't be read
                logger.info(f"Exception processing {info_path}: {str(e)}")
                continue

            if not isinstance(data, dict):
                logger.info(f"Invalid JSON for {info_path}: expected a JSON object")
                continue

            data["username"] = item
            data["has_profile"] = os.path.exists(os.path.join(item_path, "profile.png"))
            user_data.append(data)

        return user_data

    async def initialize(self, prometheus_collector: Optional[CollectorRegistry]):
        """Set up per-user storage and the RAG file watcher; callable only once.

        Creates the user's vector-DB and RAG content directories under
        ``defines.user_dir``, stores ``prometheus_collector`` (when not None),
        starts the file watcher, ensures ``self.rags`` has an entry named after
        the user, and stores ``file_watcher.collection.count()`` in
        ``rag_content_size``.

        Args:
            prometheus_collector: Registry to expose via ``prometheus_collector``;
                when None, that property keeps raising ``ValueError``.

        Raises:
            ValueError: On any call after the first, or if ``username`` is empty.
        """
        if self.CandidateEntity__initialized:
            # Initialization can only be attempted once; if there are multiple
            # attempts, it means a subsystem is failing or there is a logic bug.
            #
            # NOTE: It is intentional that self.CandidateEntity__initialized is
            # set to True before any work, regardless of whether initialization
            # succeeds. This prevents server loops on failure.
            raise ValueError("initialize can only be attempted once")
        self.CandidateEntity__initialized = True

        if not self.username:
            raise ValueError("username can not be empty")

        user_dir = os.path.join(defines.user_dir, self.username)
        vector_db_dir = os.path.join(user_dir, defines.persist_directory)
        rag_content_dir = os.path.join(user_dir, defines.rag_content_dir)

        logger.info(
            f"CandidateEntity(username={self.username}, user_dir={user_dir}, "
            f"persist_directory={vector_db_dir}, rag_content_dir={rag_content_dir})"
        )

        os.makedirs(vector_db_dir, exist_ok=True)
        os.makedirs(rag_content_dir, exist_ok=True)

        if prometheus_collector is not None:
            self.CandidateEntity__prometheus_collector = prometheus_collector

        self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
            llm=llm_manager.get_llm(),
            collection_name=self.username,
            persist_directory=vector_db_dir,
            watch_directory=rag_content_dir,
            recreate=False,  # Don't recreate if exists
        )

        def _entry_name(entry: Any) -> Any:
            # self.rags may hold RagEntry objects (this method appends them) as
            # well as plain dicts. Subscripting a model without __getitem__
            # raises TypeError, so read "name" either way.
            if isinstance(entry, dict):
                return entry.get("name")
            return getattr(entry, "name", None)

        has_username_rag = any(_entry_name(entry) == self.username for entry in self.rags)
        if not has_username_rag:
            self.rags.append(RagEntry(
                name=self.username,
                description=f"Expert data about {self.full_name}.",
            ))
        self.rag_content_size = self.file_watcher.collection.count()


# Rebuild the pydantic schema so the string annotations produced by
# ``from __future__ import annotations`` are resolved against module globals.
CandidateEntity.model_rebuild()
|