# backstory/src/utils/context.py
#
# 182 lines
# 7.0 KiB
# Python

from __future__ import annotations
from pydantic import BaseModel, Field, model_validator
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING
from typing_extensions import Annotated, Union
import numpy as np
import logging
from uuid import uuid4
import traceback
from . rag import RagEntry
from . tools import ToolEntry
from . agents import AnyAgent
from . import User
if TYPE_CHECKING:
from .user import User
# Configure logging at INFO level as an import-time side effect, then create
# the logger used throughout this module (named after the module itself).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Context(BaseModel):
class Config:
validate_by_name = True # Allow 'user' to be set via constructor
arbitrary_types_allowed = True # Allow ChromaDBFileWatcher
id: str = Field(
default_factory=lambda: str(uuid4()),
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
)
tools: List[ToolEntry]
rags: List[RagEntry]
username: str = "__invalid__"
# "user" is not serialized and must be set after construction
Context__user: User = Field(
default_factory=lambda: User(username="__invalid__", llm=None, rags=[]),
alias="user",
exclude=True)
# Optional fields
user_resume: Optional[str] = None
user_job_description: Optional[str] = None
user_facts: Optional[str] = None
# Class managed fields
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
default_factory=list
)
@model_validator(mode='after')
def set_user_and_username(self):
if self.Context__user.username != "__invalid__":
if self.username == "__invalid__":
logger.info(f"Binding context {self.id} to user {self.Context__user.username}")
self.username = self.Context__user.username
else:
raise ValueError("user can only be set once")
return self
# Only allow dereference of 'user' if it has been set
@property
def user(self) -> User:
if self.Context__user.username == "__invalid__":
# After deserializing Context(), you must explicitly set the
# user:
#
# context = Context(...)
# context.user = <valid user>
raise ValueError("Attempt to dereference default_factory constructed User")
return self.Context__user
@user.setter
def user(self, new_user: User) -> User:
logger.info(f"Binding context {self.id} to user {new_user.username}")
if self.Context__user.username != "__invalid__" and new_user.username != self.Context__user.username:
logger.info(f"Resetting context changing from {self.Context__user.username} to user {new_user.username}")
self.agents = [a for a in self.agents if a.agent_type == "chat"]
for a in self.agents:
a.conversation.reset()
self.username = new_user.username
self.Context__user = new_user
return new_user
processing: bool = Field(default=False, exclude=True)
# @model_validator(mode="before")
# @classmethod
# def before_model_validator(cls, values: Any):
# logger.info(f"Preparing model data: {cls} {values}")
# return values
@model_validator(mode="after")
def after_model_validator(self):
"""Ensure at most one agent per agent_type."""
logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
agent_types = [agent.agent_type for agent in self.agents]
if len(agent_types) != len(set(agent_types)):
raise ValueError(
"Context cannot contain multiple agents of the same agent_type"
)
# for agent in self.agents:
# agent.set_context(self)
return self
def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
"""
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
Args:
agent_type: The type of agent to create (e.g., 'web', 'database').
**kwargs: Additional fields required by the specific agent subclass.
Returns:
The created agent instance.
Raises:
ValueError: If no matching agent type is found or if a agent of this type already exists.
"""
# Check if a agent with the given agent_type already exists
for agent in self.agents:
if agent.agent_type == agent_type:
return agent
# Find the matching subclass
for agent_cls in Agent.__subclasses__():
if agent_cls.model_fields["agent_type"].default == agent_type:
# Create the agent instance with provided kwargs
agent = agent_cls(agent_type=agent_type, **kwargs)
# set_context after constructor to initialize any non-serialized data
agent.set_context(self)
if agent.agent_persist: # If an agent is not set to persist, do not add it to the context
self.agents.append(agent)
return agent
raise ValueError(f"No agent class found for agent_type: {agent_type}")
def add_agent(self, agent: AnyAgent) -> None:
"""Add a Agent to the context, ensuring no duplicate agent_type."""
if any(s.agent_type == agent.agent_type for s in self.agents):
raise ValueError(
f"A agent with agent_type '{agent.agent_type}' already exists"
)
self.agents.append(agent)
def remove_agent(self, agent: AnyAgent) -> None:
"""Remove an Agent from the context based on its agent_type."""
self.agents = [a for a in self.agents if a.agent_type != agent.agent_type]
def get_agent(self, agent_type: str) -> Agent | None:
"""Return the Agent with the given agent_type, or None if not found."""
for agent in self.agents:
if agent.agent_type == agent_type:
return agent
return None
def is_valid_agent_type(self, agent_type: str) -> bool:
"""Check if the given agent_type is valid."""
return agent_type in Agent.valid_agent_types()
def get_summary(self) -> str:
"""Return a summary of the context."""
if not self.agents:
return f"Context {self.uuid}: No agents."
summary = f"Context {self.uuid}:\n"
for i, agent in enumerate(self.agents, 1):
summary += f"\nAgent {i} ({agent.agent_type}):\n"
summary += agent.conversation.get_summary()
if agent.agent_type == "resume":
summary += f"\nResume: {agent.get_resume()}\n"
elif agent.agent_type == "job_description":
summary += f"\nJob Description: {agent.job_description}\n"
elif agent.agent_type == "fact_check":
summary += f"\nFacts: {agent.facts}\n"
elif agent.agent_type == "chat":
summary += f"\nChat Name: {agent.name}\n"
return summary
# Agent is imported only here, after Context has been defined (presumably to
# sidestep a circular import with .agents -- TODO confirm).  Context's
# annotations and methods refer to Agent, so the model is rebuilt once the
# name is in scope.
from .agents import Agent
Context.model_rebuild()