192 lines
7.7 KiB
Python
192 lines
7.7 KiB
Python
from __future__ import annotations
|
|
from pydantic import BaseModel, Field, model_validator# type: ignore
|
|
from uuid import uuid4
|
|
from typing import List, Optional, Generator, ClassVar, Any
|
|
from typing_extensions import Annotated, Union
|
|
import numpy as np # type: ignore
|
|
import logging
|
|
from uuid import uuid4
|
|
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
|
|
|
from . message import Message, Tunables
|
|
from . rag import ChromaDBFileWatcher
|
|
from . import defines
|
|
from . import tools as Tools
|
|
from . agents import AnyAgent
|
|
|
|
# Import-time logging setup: give the root logger an INFO-level default
# configuration (basicConfig does nothing if the root logger already has
# handlers), then obtain the logger named after this module.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class Context(BaseModel):
|
|
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
|
|
# Required fields
|
|
file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
|
|
prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
|
|
|
|
# Optional fields
|
|
id: str = Field(
|
|
default_factory=lambda: str(uuid4()),
|
|
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
|
|
)
|
|
user_resume: Optional[str] = None
|
|
user_job_description: Optional[str] = None
|
|
user_facts: Optional[str] = None
|
|
tools: List[dict] = Tools.enabled_tools(Tools.tools)
|
|
rags: List[dict] = []
|
|
message_history_length: int = 5
|
|
# Class managed fields
|
|
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
|
|
default_factory=list
|
|
)
|
|
|
|
processing: bool = Field(default=False, exclude=True)
|
|
|
|
# @model_validator(mode="before")
|
|
# @classmethod
|
|
# def before_model_validator(cls, values: Any):
|
|
# logger.info(f"Preparing model data: {cls} {values}")
|
|
# return values
|
|
|
|
@model_validator(mode="after")
|
|
def after_model_validator(self):
|
|
"""Ensure at most one agent per agent_type."""
|
|
logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
|
|
agent_types = [agent.agent_type for agent in self.agents]
|
|
if len(agent_types) != len(set(agent_types)):
|
|
raise ValueError("Context cannot contain multiple agents of the same agent_type")
|
|
# for agent in self.agents:
|
|
# agent.set_context(self)
|
|
return self
|
|
|
|
def generate_rag_results(self, message: Message) -> Generator[Message, None, None]:
|
|
"""
|
|
Generate RAG results for the given query.
|
|
|
|
Args:
|
|
query: The query string to generate RAG results for.
|
|
|
|
Returns:
|
|
A list of dictionaries containing the RAG results.
|
|
"""
|
|
try:
|
|
message.status = "processing"
|
|
|
|
entries : int = 0
|
|
|
|
if not self.file_watcher:
|
|
message.response = "No RAG context available."
|
|
del message.metadata["rag"]
|
|
message.status = "done"
|
|
yield message
|
|
return
|
|
|
|
message.metadata["rag"] = []
|
|
for rag in self.rags:
|
|
if not rag["enabled"]:
|
|
continue
|
|
message.response = f"Checking RAG context {rag['name']}..."
|
|
yield message
|
|
chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10, threshold=0.7)
|
|
if chroma_results:
|
|
entries += len(chroma_results["documents"])
|
|
|
|
chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
|
|
print(f"Chroma embedding shape: {chroma_embedding.shape}")
|
|
|
|
umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
|
|
print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
|
|
|
|
umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
|
|
print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
|
|
|
|
message.metadata["rag"].append({
|
|
"name": rag["name"],
|
|
**chroma_results,
|
|
"umap_embedding_2d": umap_2d,
|
|
"umap_embedding_3d": umap_3d
|
|
})
|
|
message.response = f"Results from {rag['name']} RAG: {len(chroma_results['documents'])} results."
|
|
yield message
|
|
|
|
if entries == 0:
|
|
del message.metadata["rag"]
|
|
|
|
message.response = f"RAG context gathered from results from {entries} documents."
|
|
message.status = "done"
|
|
yield message
|
|
return
|
|
except Exception as e:
|
|
message.status = "error"
|
|
message.response = f"Error generating RAG results: {str(e)}"
|
|
logger.error(e)
|
|
yield message
|
|
return
|
|
|
|
def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
|
|
"""
|
|
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
|
|
|
|
Args:
|
|
agent_type: The type of agent to create (e.g., 'web', 'database').
|
|
**kwargs: Additional fields required by the specific agent subclass.
|
|
|
|
Returns:
|
|
The created agent instance.
|
|
|
|
Raises:
|
|
ValueError: If no matching agent type is found or if a agent of this type already exists.
|
|
"""
|
|
# Check if a agent with the given agent_type already exists
|
|
for agent in self.agents:
|
|
if agent.agent_type == agent_type:
|
|
return agent
|
|
|
|
# Find the matching subclass
|
|
for agent_cls in Agent.__subclasses__():
|
|
if agent_cls.model_fields["agent_type"].default == agent_type:
|
|
# Create the agent instance with provided kwargs
|
|
agent = agent_cls(agent_type=agent_type, **kwargs)
|
|
# set_context after constructor to initialize any non-serialized data
|
|
agent.set_context(self)
|
|
self.agents.append(agent)
|
|
return agent
|
|
|
|
raise ValueError(f"No agent class found for agent_type: {agent_type}")
|
|
|
|
def add_agent(self, agent: AnyAgent) -> None:
|
|
"""Add a Agent to the context, ensuring no duplicate agent_type."""
|
|
if any(s.agent_type == agent.agent_type for s in self.agents):
|
|
raise ValueError(f"A agent with agent_type '{agent.agent_type}' already exists")
|
|
self.agents.append(agent)
|
|
|
|
def get_agent(self, agent_type: str) -> Agent | None:
|
|
"""Return the Agent with the given agent_type, or None if not found."""
|
|
for agent in self.agents:
|
|
if agent.agent_type == agent_type:
|
|
return agent
|
|
return None
|
|
|
|
def is_valid_agent_type(self, agent_type: str) -> bool:
|
|
"""Check if the given agent_type is valid."""
|
|
return agent_type in Agent.valid_agent_types()
|
|
|
|
def get_summary(self) -> str:
|
|
"""Return a summary of the context."""
|
|
if not self.agents:
|
|
return f"Context {self.uuid}: No agents."
|
|
summary = f"Context {self.uuid}:\n"
|
|
for i, agent in enumerate(self.agents, 1):
|
|
summary += f"\nAgent {i} ({agent.agent_type}):\n"
|
|
summary += agent.conversation.get_summary()
|
|
if agent.agent_type == "resume":
|
|
summary += f"\nResume: {agent.get_resume()}\n"
|
|
elif agent.agent_type == "job_description":
|
|
summary += f"\nJob Description: {agent.job_description}\n"
|
|
elif agent.agent_type == "fact_check":
|
|
summary += f"\nFacts: {agent.facts}\n"
|
|
elif agent.agent_type == "chat":
|
|
summary += f"\nChat Name: {agent.name}\n"
|
|
return summary
|
|
|
|
# Imported only after Context is defined. Context's `agents` annotation (a
# string under `from __future__ import annotations`) and several of its
# methods refer to Agent.
# NOTE(review): placed here presumably to avoid a circular import with
# `.agents` -- confirm before moving it to the top of the file.
from . agents import Agent

# Re-resolve Context's deferred annotations now that Agent is in the module
# namespace; must run after the import above.
Context.model_rebuild()