from __future__ import annotations
from pydantic import BaseModel, Field, model_validator  # type: ignore
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
from typing_extensions import Annotated, Union
import numpy as np  # type: ignore
from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter  # type: ignore
import traceback
import os
import json
import re
from pathlib import Path

from rag import start_file_watcher, ChromaDBFileWatcher
import defines
from logger import logger
import agents as agents
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
import llm_proxy as llm_manager
from agents.base import Agent
from database import RedisDatabase
from models import ChromaDBGetResponse


def _rag_entry_name(entry: Any) -> Optional[str]:
    """Return the ``name`` of a RAG entry given as a mapping or as an object."""
    if isinstance(entry, dict):
        return entry.get("name")
    return getattr(entry, "name", None)


class CandidateEntity(Candidate):
    """A ``Candidate`` extended with per-user runtime state.

    The runtime members (file watcher, observer, prometheus collector) start
    out as ``None`` and are populated by :meth:`initialize`. The accessor
    properties raise ``ValueError`` until that has happened.
    """

    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher, etc

    # Internal instance members
    CandidateEntity__agents: List[Agent] = []
    CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
    CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
    CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
        default=None, exclude=True
    )

    def __init__(self, candidate=None):
        """Build the entity, optionally copying the attributes of ``candidate``.

        Args:
            candidate: An existing candidate instance whose attributes (as
                returned by ``vars()``) are passed to the parent constructor.
                When ``None``, the parent constructor is called with no
                arguments.
        """
        if candidate is not None:
            # Copy attributes from the candidate instance
            super().__init__(**vars(candidate))
        else:
            super().__init__()

    @classmethod
    def exists(cls, username: str) -> bool:
        """Report whether a user-info file exists for ``username``.

        Args:
            username: Name to look up. Must be 3-32 characters drawn from
                ``[a-zA-Z0-9_-]``.

        Returns:
            ``True`` when the username is well formed, its directory resolves
            to a location inside ``defines.user_dir``, and the user-info file
            is present; ``False`` otherwise.
        """
        # Validate username format (only allow safe characters).
        # fullmatch is used instead of match with '^...$' because '$' also
        # matches just before a trailing newline, which would let "name\n"
        # through.
        if not re.fullmatch(r'[a-zA-Z0-9_-]+', username):
            return False  # Invalid username characters

        # Check for minimum and maximum length
        if not (3 <= len(username) <= 32):
            return False  # Invalid username length

        # Use Path for safe path handling and normalization
        base_dir = Path(defines.user_dir)
        user_dir = base_dir / username
        user_info_path = user_dir / defines.user_info_file

        # Ensure the final path is actually within the intended parent directory
        # to help prevent directory traversal attacks
        try:
            if not user_dir.resolve().is_relative_to(base_dir.resolve()):
                return False  # Path traversal attempt detected
        except (ValueError, RuntimeError):  # Potential exceptions from resolve()
            return False

        # Check if file exists
        return user_info_path.is_file()

    def get_or_create_agent(self, agent_type: ChatContextType) -> agents.Agent:
        """Get or create an agent of the specified type for this candidate.

        Args:
            agent_type: The type of agent to return.

        Returns:
            The agent already held by this candidate for ``agent_type`` if
            there is one, otherwise the result of
            ``agents.get_or_create_agent``.

        Raises:
            ValueError: If no matching agent is held and no prometheus
                collector has been set (see ``prometheus_collector``).
        """
        # Only instantiate one agent of each type per user
        for agent in self.CandidateEntity__agents:
            if agent.agent_type == agent_type:
                return agent

        # NOTE(review): this method never appends to CandidateEntity__agents;
        # presumably agents.get_or_create_agent registers the new agent on
        # ``user`` -- confirm.
        return agents.get_or_create_agent(
            agent_type=agent_type,
            user=self,
            prometheus_collector=self.prometheus_collector,
        )

    # Wrapper properties that map into file_watcher
    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        """The file watcher's ``umap_collection``.

        Raises:
            ValueError: If the file watcher has not been set.
        """
        if self.CandidateEntity__file_watcher is None:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher.umap_collection

    # Fields managed by initialize()
    CandidateEntity__initialized: bool = Field(default=False, exclude=True)

    @property
    def file_watcher(self) -> ChromaDBFileWatcher:
        """The file watcher stored by ``initialize()``.

        Raises:
            ValueError: If the file watcher has not been set.
        """
        if self.CandidateEntity__file_watcher is None:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher

    @property
    def prometheus_collector(self) -> CollectorRegistry:
        """The prometheus collector stored by ``initialize()``.

        Raises:
            ValueError: If no collector has been set.
        """
        if self.CandidateEntity__prometheus_collector is None:
            raise ValueError("initialize() has not been called with a prometheus_collector.")
        return self.CandidateEntity__prometheus_collector

    @property
    def observer(self) -> Any:
        """The observer returned by ``start_file_watcher`` during ``initialize()``.

        Raises:
            ValueError: If the observer has not been set.
        """
        if self.CandidateEntity__observer is None:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__observer

    async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
        """Set up the per-user directories, file watcher and RAG entry.

        Args:
            prometheus_collector: Collector to store on this entity; stored
                only when truthy.
            database: Database handle passed through to ``start_file_watcher``.

        Raises:
            ValueError: If initialization was already attempted, or if
                ``username`` is empty.
        """
        if self.CandidateEntity__initialized:
            # Initialization can only be attempted once; if there are multiple attempts, it means
            # a subsystem is failing or there is a logic bug in the code.
            #
            # NOTE: It is intentional that self.CandidateEntity__initialized = True regardless of
            # whether it succeeded. This prevents server loops on failure
            raise ValueError("initialize can only be attempted once")
        self.CandidateEntity__initialized = True

        if not self.username:
            raise ValueError("username can not be empty")

        user_dir = os.path.join(defines.user_dir, self.username)
        vector_db_dir = os.path.join(user_dir, defines.persist_directory)
        rag_content_dir = os.path.join(user_dir, defines.rag_content_dir)

        os.makedirs(vector_db_dir, exist_ok=True)
        os.makedirs(rag_content_dir, exist_ok=True)

        if prometheus_collector:
            self.CandidateEntity__prometheus_collector = prometheus_collector

        self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
            llm=llm_manager.get_llm(),
            user_id=self.id,
            collection_name=self.username,
            persist_directory=vector_db_dir,
            watch_directory=rag_content_dir,
            database=database,
            recreate=False,  # Don't recreate if exists
        )

        # Entries are compared by name via _rag_entry_name so that both plain
        # mappings and RagEntry objects (as appended below) are handled;
        # subscripting a RagEntry object directly is not assumed to work.
        has_username_rag = any(
            _rag_entry_name(item) == self.username for item in self.rags
        )
        if not has_username_rag:
            self.rags.append(RagEntry(
                name=self.username,
                description=f"Expert data about {self.full_name}.",
            ))
        self.rag_content_size = self.file_watcher.collection.count()


CandidateEntity.model_rebuild()