# backstory/src/utils/context.py
#
# (source listing: 192 lines, 7.7 KiB, Python)

from __future__ import annotations
from pydantic import BaseModel, Field, model_validator# type: ignore
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any
from typing_extensions import Annotated, Union
import numpy as np # type: ignore
import logging
from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore
from . message import Message, Tunables
from . rag import ChromaDBFileWatcher
from . import defines
from . import tools as Tools
from . agents import AnyAgent
# Module logging setup.
# NOTE(review): basicConfig() configures the *root* logger as a side effect of
# importing this module; consider moving it to the application entry point.
logging.basicConfig(level="INFO")
logger: logging.Logger = logging.getLogger(name=__name__)
class Context(BaseModel):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
# Required fields
file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
# Optional fields
id: str = Field(
default_factory=lambda: str(uuid4()),
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
)
user_resume: Optional[str] = None
user_job_description: Optional[str] = None
user_facts: Optional[str] = None
tools: List[dict] = Tools.enabled_tools(Tools.tools)
rags: List[dict] = []
message_history_length: int = 5
# Class managed fields
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
default_factory=list
)
processing: bool = Field(default=False, exclude=True)
# @model_validator(mode="before")
# @classmethod
# def before_model_validator(cls, values: Any):
# logger.info(f"Preparing model data: {cls} {values}")
# return values
@model_validator(mode="after")
def after_model_validator(self):
"""Ensure at most one agent per agent_type."""
logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
agent_types = [agent.agent_type for agent in self.agents]
if len(agent_types) != len(set(agent_types)):
raise ValueError("Context cannot contain multiple agents of the same agent_type")
# for agent in self.agents:
# agent.set_context(self)
return self
def generate_rag_results(self, message: Message) -> Generator[Message, None, None]:
"""
Generate RAG results for the given query.
Args:
query: The query string to generate RAG results for.
Returns:
A list of dictionaries containing the RAG results.
"""
try:
message.status = "processing"
entries : int = 0
if not self.file_watcher:
message.response = "No RAG context available."
del message.metadata["rag"]
message.status = "done"
yield message
return
message.metadata["rag"] = []
for rag in self.rags:
if not rag["enabled"]:
continue
message.response = f"Checking RAG context {rag['name']}..."
yield message
chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10, threshold=0.7)
if chroma_results:
entries += len(chroma_results["documents"])
chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
print(f"Chroma embedding shape: {chroma_embedding.shape}")
umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
message.metadata["rag"].append({
"name": rag["name"],
**chroma_results,
"umap_embedding_2d": umap_2d,
"umap_embedding_3d": umap_3d
})
message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
yield message
if entries == 0:
del message.metadata["rag"]
message.response = f"RAG context gathered from results from {entries} documents."
message.status = "done"
yield message
return
except Exception as e:
message.status = "error"
message.response = f"Error generating RAG results: {str(e)}"
logger.error(e)
yield message
return
def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
"""
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
Args:
agent_type: The type of agent to create (e.g., 'web', 'database').
**kwargs: Additional fields required by the specific agent subclass.
Returns:
The created agent instance.
Raises:
ValueError: If no matching agent type is found or if a agent of this type already exists.
"""
# Check if a agent with the given agent_type already exists
for agent in self.agents:
if agent.agent_type == agent_type:
return agent
# Find the matching subclass
for agent_cls in Agent.__subclasses__():
if agent_cls.model_fields["agent_type"].default == agent_type:
# Create the agent instance with provided kwargs
agent = agent_cls(agent_type=agent_type, **kwargs)
# set_context after constructor to initialize any non-serialized data
agent.set_context(self)
self.agents.append(agent)
return agent
raise ValueError(f"No agent class found for agent_type: {agent_type}")
def add_agent(self, agent: AnyAgent) -> None:
"""Add a Agent to the context, ensuring no duplicate agent_type."""
if any(s.agent_type == agent.agent_type for s in self.agents):
raise ValueError(f"A agent with agent_type '{agent.agent_type}' already exists")
self.agents.append(agent)
def get_agent(self, agent_type: str) -> Agent | None:
"""Return the Agent with the given agent_type, or None if not found."""
for agent in self.agents:
if agent.agent_type == agent_type:
return agent
return None
def is_valid_agent_type(self, agent_type: str) -> bool:
"""Check if the given agent_type is valid."""
return agent_type in Agent.valid_agent_types()
def get_summary(self) -> str:
"""Return a summary of the context."""
if not self.agents:
return f"Context {self.uuid}: No agents."
summary = f"Context {self.uuid}:\n"
for i, agent in enumerate(self.agents, 1):
summary += f"\nAgent {i} ({agent.agent_type}):\n"
summary += agent.conversation.get_summary()
if agent.agent_type == "resume":
summary += f"\nResume: {agent.get_resume()}\n"
elif agent.agent_type == "job_description":
summary += f"\nJob Description: {agent.job_description}\n"
elif agent.agent_type == "fact_check":
summary += f"\nFacts: {agent.facts}\n"
elif agent.agent_type == "chat":
summary += f"\nChat Name: {agent.name}\n"
return summary
# Deferred import: `Agent` is referenced by Context's (string, per
# `from __future__ import annotations`) annotations — e.g. the `agents` field —
# and inside its methods. It is imported here, after Context is defined,
# presumably to avoid a circular import with the `agents` package; TODO confirm.
from . agents import Agent
# Re-resolve Context's annotations now that `Agent` is in the module namespace.
Context.model_rebuild()