182 lines
7.0 KiB
Python
182 lines
7.0 KiB
Python
from __future__ import annotations
|
|
from pydantic import BaseModel, Field, model_validator
|
|
from uuid import uuid4
|
|
from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING
|
|
from typing_extensions import Annotated, Union
|
|
import numpy as np
|
|
import logging
|
|
from uuid import uuid4
|
|
import traceback
|
|
|
|
from . rag import RagEntry
|
|
from . tools import ToolEntry
|
|
from . agents import AnyAgent
|
|
from . import User
|
|
|
|
if TYPE_CHECKING:
|
|
from .user import User
|
|
|
|
# Module logger; named after this module so records can be filtered per file.
logger = logging.getLogger(__name__)

# NOTE(review): configuring the root logger at import time affects every
# importer of this module; kept as-is to preserve existing behavior.
logging.basicConfig(level=logging.INFO)
|
|
|
|
class Context(BaseModel):
|
|
class Config:
|
|
validate_by_name = True # Allow 'user' to be set via constructor
|
|
arbitrary_types_allowed = True # Allow ChromaDBFileWatcher
|
|
|
|
id: str = Field(
|
|
default_factory=lambda: str(uuid4()),
|
|
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
|
)
|
|
tools: List[ToolEntry]
|
|
rags: List[RagEntry]
|
|
username: str = "__invalid__"
|
|
|
|
# "user" is not serialized and must be set after construction
|
|
Context__user: User = Field(
|
|
default_factory=lambda: User(username="__invalid__", llm=None, rags=[]),
|
|
alias="user",
|
|
exclude=True)
|
|
|
|
# Optional fields
|
|
user_resume: Optional[str] = None
|
|
user_job_description: Optional[str] = None
|
|
user_facts: Optional[str] = None
|
|
|
|
# Class managed fields
|
|
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
|
|
default_factory=list
|
|
)
|
|
|
|
@model_validator(mode='after')
|
|
def set_user_and_username(self):
|
|
if self.Context__user.username != "__invalid__":
|
|
if self.username == "__invalid__":
|
|
logger.info(f"Binding context {self.id} to user {self.Context__user.username}")
|
|
self.username = self.Context__user.username
|
|
else:
|
|
raise ValueError("user can only be set once")
|
|
return self
|
|
|
|
# Only allow dereference of 'user' if it has been set
|
|
@property
|
|
def user(self) -> User:
|
|
if self.Context__user.username == "__invalid__":
|
|
# After deserializing Context(), you must explicitly set the
|
|
# user:
|
|
#
|
|
# context = Context(...)
|
|
# context.user = <valid user>
|
|
raise ValueError("Attempt to dereference default_factory constructed User")
|
|
return self.Context__user
|
|
|
|
@user.setter
|
|
def user(self, new_user: User) -> User:
|
|
logger.info(f"Binding context {self.id} to user {new_user.username}")
|
|
if self.Context__user.username != "__invalid__" and new_user.username != self.Context__user.username:
|
|
logger.info(f"Resetting context changing from {self.Context__user.username} to user {new_user.username}")
|
|
self.agents = [a for a in self.agents if a.agent_type == "chat"]
|
|
for a in self.agents:
|
|
a.conversation.reset()
|
|
self.username = new_user.username
|
|
self.Context__user = new_user
|
|
return new_user
|
|
|
|
processing: bool = Field(default=False, exclude=True)
|
|
|
|
# @model_validator(mode="before")
|
|
# @classmethod
|
|
# def before_model_validator(cls, values: Any):
|
|
# logger.info(f"Preparing model data: {cls} {values}")
|
|
# return values
|
|
|
|
@model_validator(mode="after")
|
|
def after_model_validator(self):
|
|
"""Ensure at most one agent per agent_type."""
|
|
logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
|
|
agent_types = [agent.agent_type for agent in self.agents]
|
|
if len(agent_types) != len(set(agent_types)):
|
|
raise ValueError(
|
|
"Context cannot contain multiple agents of the same agent_type"
|
|
)
|
|
# for agent in self.agents:
|
|
# agent.set_context(self)
|
|
return self
|
|
|
|
def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
|
|
"""
|
|
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
|
|
|
|
Args:
|
|
agent_type: The type of agent to create (e.g., 'web', 'database').
|
|
**kwargs: Additional fields required by the specific agent subclass.
|
|
|
|
Returns:
|
|
The created agent instance.
|
|
|
|
Raises:
|
|
ValueError: If no matching agent type is found or if a agent of this type already exists.
|
|
"""
|
|
# Check if a agent with the given agent_type already exists
|
|
for agent in self.agents:
|
|
if agent.agent_type == agent_type:
|
|
return agent
|
|
|
|
# Find the matching subclass
|
|
for agent_cls in Agent.__subclasses__():
|
|
if agent_cls.model_fields["agent_type"].default == agent_type:
|
|
# Create the agent instance with provided kwargs
|
|
agent = agent_cls(agent_type=agent_type, **kwargs)
|
|
# set_context after constructor to initialize any non-serialized data
|
|
agent.set_context(self)
|
|
if agent.agent_persist: # If an agent is not set to persist, do not add it to the context
|
|
self.agents.append(agent)
|
|
return agent
|
|
|
|
raise ValueError(f"No agent class found for agent_type: {agent_type}")
|
|
|
|
def add_agent(self, agent: AnyAgent) -> None:
|
|
"""Add a Agent to the context, ensuring no duplicate agent_type."""
|
|
if any(s.agent_type == agent.agent_type for s in self.agents):
|
|
raise ValueError(
|
|
f"A agent with agent_type '{agent.agent_type}' already exists"
|
|
)
|
|
self.agents.append(agent)
|
|
|
|
def remove_agent(self, agent: AnyAgent) -> None:
|
|
"""Remove an Agent from the context based on its agent_type."""
|
|
self.agents = [a for a in self.agents if a.agent_type != agent.agent_type]
|
|
|
|
def get_agent(self, agent_type: str) -> Agent | None:
|
|
"""Return the Agent with the given agent_type, or None if not found."""
|
|
for agent in self.agents:
|
|
if agent.agent_type == agent_type:
|
|
return agent
|
|
return None
|
|
|
|
def is_valid_agent_type(self, agent_type: str) -> bool:
|
|
"""Check if the given agent_type is valid."""
|
|
return agent_type in Agent.valid_agent_types()
|
|
|
|
def get_summary(self) -> str:
|
|
"""Return a summary of the context."""
|
|
if not self.agents:
|
|
return f"Context {self.uuid}: No agents."
|
|
summary = f"Context {self.uuid}:\n"
|
|
for i, agent in enumerate(self.agents, 1):
|
|
summary += f"\nAgent {i} ({agent.agent_type}):\n"
|
|
summary += agent.conversation.get_summary()
|
|
if agent.agent_type == "resume":
|
|
summary += f"\nResume: {agent.get_resume()}\n"
|
|
elif agent.agent_type == "job_description":
|
|
summary += f"\nJob Description: {agent.job_description}\n"
|
|
elif agent.agent_type == "fact_check":
|
|
summary += f"\nFacts: {agent.facts}\n"
|
|
elif agent.agent_type == "chat":
|
|
summary += f"\nChat Name: {agent.name}\n"
|
|
return summary
|
|
|
|
# Agent is imported only after Context is defined. This presumably avoids a
# circular import between this module and .agents; confirm before moving it
# to the top of the file.
from .agents import Agent

# With `from __future__ import annotations`, Context's annotations (including
# the discriminated union built from Agent.__subclasses__()) are strings that
# could not be resolved while Agent was undefined. Rebuild the model now that
# Agent is in the module namespace so pydantic can complete the schema.
Context.model_rebuild()
|