from __future__ import annotations from pydantic import BaseModel, Field, model_validator from uuid import uuid4 from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING from typing_extensions import Annotated, Union import numpy as np import logging from uuid import uuid4 import traceback from . rag import RagEntry from . tools import ToolEntry from . agents import AnyAgent from . import User if TYPE_CHECKING: from .user import User logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class Context(BaseModel): class Config: validate_by_name = True # Allow 'user' to be set via constructor arbitrary_types_allowed = True # Allow ChromaDBFileWatcher id: str = Field( default_factory=lambda: str(uuid4()), pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", ) tools: List[ToolEntry] rags: List[RagEntry] username: str = "__invalid__" # "user" is not serialized and must be set after construction Context__user: User = Field( default_factory=lambda: User(username="__invalid__", llm=None, rags=[]), alias="user", exclude=True) # Optional fields user_resume: Optional[str] = None user_job_description: Optional[str] = None user_facts: Optional[str] = None # Class managed fields agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( default_factory=list ) @model_validator(mode='after') def set_user_and_username(self): if self.Context__user.username != "__invalid__": if self.username == "__invalid__": logger.info(f"Binding context {self.id} to user {self.Context__user.username}") self.username = self.Context__user.username else: raise ValueError("user can only be set once") return self # Only allow dereference of 'user' if it has been set @property def user(self) -> User: if self.Context__user.username == "__invalid__": # After deserializing Context(), you must explicitly set the # user: # # context = Context(...) # context.user = raise ValueError("Attempt to dereference default_factory constructed User") return self.Context__user @user.setter def user(self, new_user: User) -> User: logger.info(f"Binding context {self.id} to user {new_user.username}") if self.Context__user.username != "__invalid__" and new_user.username != self.Context__user.username: logger.info(f"Resetting context changing from {self.Context__user.username} to user {new_user.username}") self.agents = [a for a in self.agents if a.agent_type == "chat"] for a in self.agents: a.conversation.reset() self.username = new_user.username self.Context__user = new_user return new_user processing: bool = Field(default=False, exclude=True) # @model_validator(mode="before") # @classmethod # def before_model_validator(cls, values: Any): # logger.info(f"Preparing model data: {cls} {values}") # return values @model_validator(mode="after") def after_model_validator(self): """Ensure at most one agent per agent_type.""" logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.") agent_types = [agent.agent_type for agent in self.agents] if len(agent_types) != len(set(agent_types)): raise ValueError( "Context cannot contain multiple agents of the same agent_type" ) # for agent in self.agents: # agent.set_context(self) return self def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent: """ Get or create and append a new agent of the specified type, ensuring only one agent per type exists. Args: agent_type: The type of agent to create (e.g., 'web', 'database'). **kwargs: Additional fields required by the specific agent subclass. Returns: The created agent instance. Raises: ValueError: If no matching agent type is found or if a agent of this type already exists. """ # Check if a agent with the given agent_type already exists for agent in self.agents: if agent.agent_type == agent_type: return agent # Find the matching subclass for agent_cls in Agent.__subclasses__(): if agent_cls.model_fields["agent_type"].default == agent_type: # Create the agent instance with provided kwargs agent = agent_cls(agent_type=agent_type, **kwargs) # set_context after constructor to initialize any non-serialized data agent.set_context(self) if agent.agent_persist: # If an agent is not set to persist, do not add it to the context self.agents.append(agent) return agent raise ValueError(f"No agent class found for agent_type: {agent_type}") def add_agent(self, agent: AnyAgent) -> None: """Add a Agent to the context, ensuring no duplicate agent_type.""" if any(s.agent_type == agent.agent_type for s in self.agents): raise ValueError( f"A agent with agent_type '{agent.agent_type}' already exists" ) self.agents.append(agent) def remove_agent(self, agent: AnyAgent) -> None: """Remove an Agent from the context based on its agent_type.""" self.agents = [a for a in self.agents if a.agent_type != agent.agent_type] def get_agent(self, agent_type: str) -> Agent | None: """Return the Agent with the given agent_type, or None if not found.""" for agent in self.agents: if agent.agent_type == agent_type: return agent return None def is_valid_agent_type(self, agent_type: str) -> bool: """Check if the given agent_type is valid.""" return agent_type in Agent.valid_agent_types() def get_summary(self) -> str: """Return a summary of the context.""" if not self.agents: return f"Context {self.uuid}: No agents." summary = f"Context {self.uuid}:\n" for i, agent in enumerate(self.agents, 1): summary += f"\nAgent {i} ({agent.agent_type}):\n" summary += agent.conversation.get_summary() if agent.agent_type == "resume": summary += f"\nResume: {agent.get_resume()}\n" elif agent.agent_type == "job_description": summary += f"\nJob Description: {agent.job_description}\n" elif agent.agent_type == "fact_check": summary += f"\nFacts: {agent.facts}\n" elif agent.agent_type == "chat": summary += f"\nChat Name: {agent.name}\n" return summary from .agents import Agent Context.model_rebuild()