from __future__ import annotations
import traceback
from pydantic import BaseModel, ConfigDict, Field, model_validator  # type: ignore
from typing import (
    Literal,
    get_args,
    List,
    AsyncGenerator,
    Optional,
    ClassVar,
    Any,
)
import json
import time
import inspect
from abc import ABC
import asyncio
from datetime import datetime, UTC
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
import numpy as np # type: ignore
from models import (
    LLMMessage,
    ChatQuery,
    ChatMessage,
    ChatOptions,
    ChatMessageBase,
    ChatMessageUser,
    Tunables,
    ChatMessageType,
    ChatSenderType,
    ChatStatusType,
    ChatMessageMetaData,
    Candidate,
)
from logger import logger
import defines
from .registry import agent_registry
from metrics import Metrics
import model_cast
from rag import ChromaDBGetResponse


class Agent(BaseModel, ABC):
"""
Base class for all agent types.
This class defines the common attributes and methods for all agent types.
"""
    model_config = ConfigDict(arbitrary_types_allowed=True)  # Allow arbitrary types like CollectorRegistry
# Agent management with pydantic
agent_type: Literal["base"] = "base"
    _agent_type: ClassVar[str] = agent_type  # Used for registry lookups
user: Optional[Candidate] = None
prometheus_collector: CollectorRegistry = Field(..., description="Prometheus collector for this agent, used to track metrics.", exclude=True)
# Tunables (sets default for new Messages attached to this agent)
tunables: Tunables = Field(default_factory=Tunables)
    metrics: Optional[Metrics] = Field(
        default=None, description="Metrics collector for this agent, used to track performance and usage."
    )
@model_validator(mode="after")
def initialize_metrics(self) -> "Agent":
if self.metrics is None:
self.metrics = Metrics(prometheus_collector=self.prometheus_collector)
return self
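    # Construction sketch (illustrative; ChatAgent and candidate are hypothetical):
    #     registry = CollectorRegistry()
    #     agent = ChatAgent(prometheus_collector=registry, user=candidate)
    # The validator above then builds a Metrics instance bound to that registry.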
# Agent properties
system_prompt: str = ""
context_tokens: int = 0
# context_size is shared across all subclasses
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
conversation: List[ChatMessage] = Field(
default_factory=list,
description="Conversation history for this agent, used to maintain context across messages."
)
@property
def context_size(self) -> int:
return Agent._context_size
@context_size.setter
def context_size(self, value: int):
Agent._context_size = value
    def set_optimal_context_size(
        self, llm: Any, model: str, prompt: str, ctx_buffer: int = 2048
    ) -> int:
# Most models average 1.3-1.5 tokens per word
word_count = len(prompt.split())
tokens = int(word_count * 1.4)
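        # e.g. a 1,000-word prompt estimates to ~1,400 tokens; adding the default
        # 2,048-token buffer yields a requested context of 3,448 tokens.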
# Add buffer for safety
total_ctx = tokens + ctx_buffer
if total_ctx > self.context_size:
logger.info(
f"Increasing context size from {self.context_size} to {total_ctx}"
)
# Grow the context size if necessary
self.context_size = max(self.context_size, total_ctx)
        # Return the (possibly grown) context size shared across all agents
return self.context_size
# Class and pydantic model management
def __init_subclass__(cls, **kwargs) -> None:
"""Auto-register subclasses"""
super().__init_subclass__(**kwargs)
# Register this class if it has an agent_type
if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
agent_registry.register(cls.agent_type, cls)
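    # Registration sketch (hypothetical subclass): narrowing agent_type is enough;
    #     class ChatAgent(Agent):
    #         agent_type: Literal["chat"] = "chat"
    # __init_subclass__ then registers it under "chat" with no explicit call needed.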
def model_dump(self, *args, **kwargs) -> Any:
        # Ensure context is always excluded, even with exclude_unset=True
        # (kwargs.get guards against an explicitly passed exclude=None)
        if kwargs.get("exclude") is None:
            kwargs["exclude"] = set()
        if isinstance(kwargs["exclude"], set):
            kwargs["exclude"].add("context")
        elif isinstance(kwargs["exclude"], dict):
            kwargs["exclude"]["context"] = True
return super().model_dump(*args, **kwargs)
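    # e.g. agent.model_dump(exclude={"conversation"}) still excludes "context",
    # since the override folds it into the caller-supplied exclude set or dict.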
@classmethod
def valid_agent_types(cls) -> set[str]:
"""Return the set of valid agent_type values."""
return set(get_args(cls.__annotations__["agent_type"]))
# Agent methods
    def get_agent_type(self) -> str:
        return self._agent_type
# async def process_tool_calls(
# self,
# llm: Any,
# model: str,
# message: ChatMessage,
# tool_message: Any, # llama response message
# messages: List[LLMMessage],
# ) -> AsyncGenerator[ChatMessage, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# self.metrics.tool_count.labels(agent=self.agent_type).inc()
# with self.metrics.tool_duration.labels(agent=self.agent_type).time():
# if not self.context:
# raise ValueError("Context is not set for this agent.")
# if not message.metadata.tools:
# raise ValueError("tools field not initialized")
# tool_metadata = message.metadata.tools
# tool_metadata["tool_calls"] = []
# message.status = "tooling"
# for i, tool_call in enumerate(tool_message.tool_calls):
# arguments = tool_call.function.arguments
# tool = tool_call.function.name
# # Yield status update before processing each tool
# message.content = (
# f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
# )
# yield message
# logger.info(f"LLM - {message.content}")
# # Process the tool based on its type
# match tool:
# case "TickerValue":
# ticker = arguments.get("ticker")
# if not ticker:
# ret = None
# else:
# ret = TickerValue(ticker)
# case "AnalyzeSite":
# url = arguments.get("url")
# question = arguments.get(
# "question", "what is the summary of this content?"
# )
# # Additional status update for long-running operations
# message.content = (
# f"Retrieving and summarizing content from {url}..."
# )
# yield message
# ret = await AnalyzeSite(
# llm=llm, model=model, url=url, question=question
# )
# case "GenerateImage":
# prompt = arguments.get("prompt", None)
# if not prompt:
# logger.info("No prompt supplied to GenerateImage")
# ret = { "error": "No prompt supplied to GenerateImage" }
# # Additional status update for long-running operations
# message.content = (
# f"Generating image for {prompt}..."
# )
# yield message
# ret = await GenerateImage(
# llm=llm, model=model, prompt=prompt
# )
# logger.info("GenerateImage returning", ret)
# case "DateTime":
# tz = arguments.get("timezone")
# ret = DateTime(tz)
# case "WeatherForecast":
# city = arguments.get("city")
# state = arguments.get("state")
# message.content = (
# f"Fetching weather data for {city}, {state}..."
# )
# yield message
# ret = WeatherForecast(city, state)
# case _:
# logger.error(f"Requested tool {tool} does not exist")
# ret = None
# # Build response for this tool
# tool_response = {
# "role": "tool",
# "content": json.dumps(ret),
# "name": tool_call.function.name,
# }
# tool_metadata["tool_calls"].append(tool_response)
# if len(tool_metadata["tool_calls"]) == 0:
# message.status = "done"
# yield message
# return
# message_dict = LLMMessage(
# role=tool_message.get("role", "assistant"),
# content=tool_message.get("content", ""),
# tool_calls=[
# {
# "function": {
# "name": tc["function"]["name"],
# "arguments": tc["function"]["arguments"],
# }
# }
# for tc in tool_message.tool_calls
# ],
# )
# messages.append(message_dict)
# messages.extend(tool_metadata["tool_calls"])
# message.status = "thinking"
# message.content = "Incorporating tool results into response..."
# yield message
# # Decrease creativity when processing tool call requests
# message.content = ""
# start_time = time.perf_counter()
# for response in llm.chat(
# model=model,
# messages=messages,
# options={
# **message.metadata.options,
# },
# stream=True,
# ):
# # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
# message.status = "streaming"
# message.chunk = response.message.content
# message.content += message.chunk
# if not response.done:
# yield message
# if response.done:
# self.collect_metrics(response)
# message.metadata.eval_count += response.eval_count
# message.metadata.eval_duration += response.eval_duration
# message.metadata.prompt_eval_count += response.prompt_eval_count
# message.metadata.prompt_eval_duration += response.prompt_eval_duration
# self.context_tokens = (
# response.prompt_eval_count + response.eval_count
# )
# message.status = "done"
# yield message
# end_time = time.perf_counter()
# message.metadata.timers["llm_with_tools"] = end_time - start_time
# return
def collect_metrics(self, response):
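        # Record prompt and completion token counts for this agent in Prometheus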
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
response.prompt_eval_count
)
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
async def generate_rag_results(
self,
chat_message: ChatMessage,
        top_k: int = defines.default_rag_top_k,
        threshold: float = defines.default_rag_threshold,
) -> AsyncGenerator[ChatMessage, None]:
"""
Generate RAG results for the given query.
Args:
query: The query string to generate RAG results for.
Returns:
A list of dictionaries containing the RAG results.
"""
rag_message = ChatMessage(
session_id=chat_message.session_id,
tunables=chat_message.tunables,
status=ChatStatusType.INITIALIZING,
type=ChatMessageType.PREPARING,
sender=ChatSenderType.ASSISTANT,
content="",
timestamp=datetime.now(UTC),
metadata=ChatMessageMetaData()
)
if not self.user:
rag_message.status = ChatStatusType.DONE
rag_message.content = "No user connected to this chat, so no RAG content."
yield rag_message
return
try:
entries: int = 0
user: Candidate = self.user
for rag in user.rags:
if not rag.enabled:
continue
rag_message.type = ChatMessageType.SEARCHING
rag_message.status = ChatStatusType.INITIALIZING
rag_message.content = f"Checking RAG context {rag.name}..."
yield rag_message
                # Query against the user's message, not the status text set above
                chroma_results = user.file_watcher.find_similar(
                    query=chat_message.content, top_k=top_k, threshold=threshold
                )
if chroma_results:
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]
rag_metadata = ChromaDBGetResponse(
query=chat_message.content,
query_embedding=query_embedding.tolist(),
name=rag.name,
ids=chroma_results.get("ids", []),
embeddings=chroma_results.get("embeddings", []),
documents=chroma_results.get("documents", []),
metadatas=chroma_results.get("metadatas", []),
umap_embedding_2d=umap_2d.tolist(),
umap_embedding_3d=umap_3d.tolist(),
size=user.file_watcher.collection.count()
)
entries += len(rag_metadata.documents)
rag_message.metadata.rag_results.append(rag_metadata)
rag_message.content = f"Results from {rag.name} RAG: {len(rag_metadata.documents)} results."
yield rag_message
            rag_message.content = (
                f"RAG context gathered from {entries} matching documents."
            )
rag_message.status = ChatStatusType.DONE
yield rag_message
return
except Exception as e:
rag_message.status = ChatStatusType.ERROR
rag_message.content = f"Error generating RAG results: {str(e)}"
logger.error(traceback.format_exc())
logger.error(rag_message.content)
yield rag_message
return
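    # Consumption sketch (assumes an agent and a prepared ChatMessage `msg`):
    #     async for update in agent.generate_rag_results(chat_message=msg):
    #         ...  # stream updates; the DONE message carries metadata.rag_results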
    async def generate(
        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature: float = 0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageBase, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
chat_message = ChatMessage(
session_id=user_message.session_id,
tunables=user_message.tunables,
status=ChatStatusType.INITIALIZING,
type=ChatMessageType.PREPARING,
sender=ChatSenderType.ASSISTANT,
content="",
timestamp=datetime.now(UTC)
)
chat_message.metadata = ChatMessageMetaData()
chat_message.metadata.options = ChatOptions(
seed=8911,
num_ctx=self.context_size,
            temperature=temperature,  # Caller-supplied sampling temperature
)
# Create a dict for storing various timing stats
chat_message.metadata.timers = {}
self.metrics.generate_count.labels(agent=self.agent_type).inc()
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
            rag_message: Optional[ChatMessage] = None
async for rag_message in self.generate_rag_results(chat_message=user_message):
if rag_message.status == ChatStatusType.ERROR:
chat_message.status = rag_message.status
chat_message.content = rag_message.content
yield chat_message
return
yield rag_message
rag_context = ""
if rag_message:
rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results
chat_message.metadata.rag_results = rag_results
for chroma_results in rag_results:
for index, metadata in enumerate(chroma_results.metadatas):
content = "\n".join([
line.strip()
for line in chroma_results.documents[index].split("\n")
if line
]).strip()
rag_context += f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: {content}
"""
# Create a pruned down message list based purely on the prompt and responses,
# discarding the full preamble generated by prepare_message
messages: List[LLMMessage] = [
LLMMessage(role="system", content=self.system_prompt)
]
# Add the conversation history to the messages
messages.extend([
LLMMessage(role=m.sender, content=m.content.strip())
for m in self.conversation
])
# Add the RAG context to the messages if available
if rag_context and user:
messages.append(
LLMMessage(
role="user",
content=f"<|context|>\nThe following is context information about {user.full_name}:\n{rag_context.strip()}\n</|context|>\n\nPrompt to respond to:\n{user_message.content.strip()}\n"
)
)
else:
# Only the actual user query is provided with the full context message
messages.append(
LLMMessage(role=user_message.sender, content=user_message.content.strip())
)
chat_message.metadata.llm_history = messages
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
# message.metadata.tools = {
# "available": llm_tools(self.context.tools),
# "used": False,
# }
# tool_metadata = message.metadata.tools
# if use_tools:
# message.status = "thinking"
# message.content = f"Performing tool analysis step 1/2..."
# yield message
# logger.info("Checking for LLM tool usage")
# start_time = time.perf_counter()
# # Tools are enabled and available, so query the LLM with a short context of messages
# # in case the LLM did something like ask "Do you want me to run the tool?" and the
# # user said "Yes" -- need to keep the context in the thread.
# tool_metadata["messages"] = (
# [{"role": "system", "content": self.system_prompt}] + messages[-6:]
# if len(messages) >= 7
# else messages
# )
# response = llm.chat(
# model=model,
# messages=tool_metadata["messages"],
# tools=tool_metadata["available"],
# options={
# **message.metadata.options,
# },
# stream=False, # No need to stream the probe
# )
# self.collect_metrics(response)
# end_time = time.perf_counter()
# message.metadata.timers["tool_check"] = end_time - start_time
# if not response.message.tool_calls:
# logger.info("LLM indicates tools will not be used")
# # The LLM will not use tools, so disable use_tools so we can stream the full response
# use_tools = False
# else:
# tool_metadata["attempted"] = response.message.tool_calls
# if use_tools:
# logger.info("LLM indicates tools will be used")
# # Tools are enabled and available and the LLM indicated it will use them
# message.content = (
# f"Performing tool analysis step 2/2 (tool use suspected)..."
# )
# yield message
# logger.info(f"Performing LLM call with tools")
# start_time = time.perf_counter()
# response = llm.chat(
# model=model,
# messages=tool_metadata["messages"], # messages,
# tools=tool_metadata["available"],
# options={
# **message.metadata.options,
# },
# stream=False,
# )
# self.collect_metrics(response)
# end_time = time.perf_counter()
# message.metadata.timers["non_streaming"] = end_time - start_time
# if not response:
# message.status = "error"
# message.content = "No response from LLM."
# yield message
# return
# if response.message.tool_calls:
# tool_metadata["used"] = response.message.tool_calls
# # Process all yielded items from the handler
# start_time = time.perf_counter()
# async for message in self.process_tool_calls(
# llm=llm,
# model=model,
# message=message,
# tool_message=response.message,
# messages=messages,
# ):
# if message.status == "error":
# yield message
# return
# yield message
# end_time = time.perf_counter()
# message.metadata.timers["process_tool_calls"] = end_time - start_time
# message.status = "done"
# return
# logger.info("LLM indicated tools will be used, and then they weren't")
# message.content = response.message.content
# message.status = "done"
# yield message
# return
# not use_tools
chat_message.type = ChatMessageType.THINKING
chat_message.content = f"Generating response..."
yield chat_message
# Reset the response for streaming
chat_message.content = ""
start_time = time.perf_counter()
chat_message.type = ChatMessageType.GENERATING
chat_message.status = ChatStatusType.STREAMING
for response in llm.chat(
model=model,
messages=messages,
options={
**chat_message.metadata.options.model_dump(exclude_unset=True),
},
stream=True,
):
if not response:
chat_message.status = ChatStatusType.ERROR
chat_message.content = "No response from LLM."
yield chat_message
return
chat_message.content += response.message.content
                if not response.done:
                    # Stream a lightweight chunk carrying only the new content
                    chat_chunk = model_cast.cast_to_model(ChatMessageBase, chat_message)
                    chat_chunk.content = response.message.content
                    yield chat_chunk
                    continue
                # response.done: record metrics and emit the final message
                self.collect_metrics(response)
                chat_message.metadata.eval_count += response.eval_count
                chat_message.metadata.eval_duration += response.eval_duration
                chat_message.metadata.prompt_eval_count += response.prompt_eval_count
                chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
                self.context_tokens = response.prompt_eval_count + response.eval_count
                chat_message.type = ChatMessageType.RESPONSE
                chat_message.status = ChatStatusType.DONE
                yield chat_message
end_time = time.perf_counter()
chat_message.metadata.timers["streamed"] = end_time - start_time
# Add the user and chat messages to the conversation
self.conversation.append(user_message)
self.conversation.append(chat_message)
return
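    # Consumption sketch (assumes an LLM client and model name; values illustrative):
    #     async for update in agent.generate(llm=client, model="model-name",
    #                                        user_message=msg, user=agent.user):
    #         ...  # forward streaming updates; the final message has status DONE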
# async def process_message(
# self, llm: Any, model: str, message: Message
# ) -> AsyncGenerator[Message, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# self.metrics.process_count.labels(agent=self.agent_type).inc()
# with self.metrics.process_duration.labels(agent=self.agent_type).time():
# if not self.context:
# raise ValueError("Context is not set for this agent.")
# logger.info(
# "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
# )
# spinner: List[str] = ["\\", "|", "/", "-"]
# tick: int = 0
# while self.context.processing:
# message.status = "waiting"
# message.content = (
# f"Busy processing another request. Please wait. {spinner[tick]}"
# )
# tick = (tick + 1) % len(spinner)
# yield message
# await asyncio.sleep(1) # Allow the event loop to process the write
# self.context.processing = True
# message.system_prompt = (
# f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
# )
# message.context_prompt = ""
# for p in message.preamble.keys():
# message.context_prompt += (
# f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
# )
# message.context_prompt += f"{message.prompt}"
# # Estimate token length of new messages
# message.content = f"Optimizing context..."
# message.status = "thinking"
# yield message
# message.context_size = self.set_optimal_context_size(
# llm, model, prompt=message.context_prompt
# )
# message.content = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
# message.status = "thinking"
# yield message
# async for message in self.generate_llm_response(
# llm=llm, model=model, message=message
# ):
# # logger.info(f"LLM: {message.status} - {f'...{message.content[-20:]}' if len(message.content) > 20 else message.content}")
# if message.status == "error":
# yield message
# self.context.processing = False
# return
# yield message
# # Done processing, add message to conversation
# message.status = "done"
# self.conversation.add(message)
# self.context.processing = False
# return
# Register the base agent
agent_registry.register(Agent._agent_type, Agent)