
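"""Base Agent implementation.

Defines the abstract Agent pydantic model that all agent types derive from:
shared context sizing, RAG-aware message preparation, LLM tool-call handling,
and streaming response generation. Subclasses are auto-registered with the
agent type registry via __init_subclass__.
"""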
from __future__ import annotations
from pydantic import BaseModel, PrivateAttr, Field  # type: ignore
from typing import (
    Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any,
    TypeAlias, Dict, Tuple
)
from abc import ABC
from ..setup_logging import setup_logging
from .. import defines
import logging
from ..message import Message
from .. import tools as Tools
import json
import time
import inspect
logger = setup_logging()
# Only import Context for type checking
if TYPE_CHECKING:
    from ..context import Context
from .types import registry
from ..conversation import Conversation
class Agent(BaseModel, ABC):
"""
Base class for all agent types.
This class defines the common attributes and methods for all agent types.
"""
# Agent management with pydantic
agent_type: Literal["base"] = "base"
_agent_type: ClassVar[str] = agent_type # Add this for registration
# context_size is shared across all subclasses
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
@property
def context_size(self) -> int:
return Agent._context_size
@context_size.setter
def context_size(self, value: int):
Agent._context_size = value
# Agent properties
system_prompt: str # Mandatory
conversation: Conversation = Conversation()
context_tokens: int = 0
context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization
_content_seed: str = PrivateAttr(default="")
    def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer: int = 2048) -> int:
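        """Estimate the prompt's token footprint and grow the shared context size if needed.

        Uses a rough word-count heuristic (~1.4 tokens per word) plus a safety
        buffer rather than a real tokenizer; see the commented-out block below
        for the more accurate prompt_eval_count approach.
        """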
# # Get more accurate token count estimate using tiktoken or similar
# response = llm.generate(
# model=model,
# prompt=prompt,
# options={
# "num_ctx": self.context_size,
# "num_predict": 0,
# } # Don't generate any tokens, just tokenize
# )
# # The prompt_eval_count gives you the token count of your input
# tokens = response.get("prompt_eval_count", 0)
# Most models average 1.3-1.5 tokens per word
word_count = len(prompt.split())
tokens = int(word_count * 1.4)
# Add buffer for safety
total_ctx = tokens + ctx_buffer
if total_ctx > self.context_size:
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
# Grow the context size if necessary
self.context_size = max(self.context_size, total_ctx)
# Use actual model maximum context size
return self.context_size
# Class and pydantic model management
def __init_subclass__(cls, **kwargs) -> None:
"""Auto-register subclasses"""
super().__init_subclass__(**kwargs)
# Register this class if it has an agent_type
if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type:
registry.register(cls.agent_type, cls)
def model_dump(self, *args, **kwargs) -> Any:
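        """Serialize the agent, always excluding the non-serializable context field."""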
# Ensure context is always excluded, even with exclude_unset=True
kwargs.setdefault("exclude", set())
if isinstance(kwargs["exclude"], set):
kwargs["exclude"].add("context")
elif isinstance(kwargs["exclude"], dict):
kwargs["exclude"]["context"] = True
return super().model_dump(*args, **kwargs)
@classmethod
def valid_agent_types(cls) -> set[str]:
"""Return the set of valid agent_type values."""
return set(get_args(cls.__annotations__["agent_type"]))
def set_context(self, context):
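        """Attach the runtime Context, bypassing pydantic's __setattr__ validation."""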
object.__setattr__(self, "context", context)
# Agent methods
def get_agent_type(self):
return self._agent_type
async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
"""
Prepare message with context information in message.preamble
"""
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
# Generate RAG content if enabled, based on the content
rag_context = ""
if message.enable_rag:
# Gather RAG results, yielding each result
# as it becomes available
for message in self.context.generate_rag_results(message):
logging.info(f"RAG: {message.status} - {message.response}")
if message.status == "error":
yield message
return
if message.status != "done":
yield message
if "rag" in message.metadata and message.metadata["rag"]:
for rag in message.metadata["rag"]:
for doc in rag["documents"]:
rag_context += f"{doc}\n"
message.preamble = {}
if rag_context:
message.preamble["context"] = rag_context
if self.context.user_resume:
message.preamble["resume"] = self.context.user_resume
if message.preamble:
preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
preamble_types_AND = " and ".join(preamble_types)
preamble_types_OR = " or ".join(preamble_types)
message.preamble["rules"] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and referring to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
message.preamble["question"] = "Respond to:"
message.system_prompt = self.system_prompt
message.status = "done"
yield message
return
async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[Any]) -> AsyncGenerator[Message, None]:
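        """Execute the tool calls requested by the LLM and stream the follow-up response.

        Each tool result is appended as a 'tool' role message, then the full
        message list is resubmitted to the LLM and its answer is streamed back
        through the yielded Message.
        """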
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
        if not message.metadata.get("tools"):
raise ValueError("tools field not initialized")
tool_metadata = message.metadata["tools"]
tool_metadata["messages"] = messages
tool_metadata["tool_calls"] = []
message.status = "tooling"
for i, tool_call in enumerate(tool_message.tool_calls):
arguments = tool_call.function.arguments
tool = tool_call.function.name
# Yield status update before processing each tool
message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
yield message
logging.info(f"LLM - {message.response}")
# Process the tool based on its type
match tool:
case "TickerValue":
ticker = arguments.get("ticker")
if not ticker:
ret = None
else:
ret = Tools.TickerValue(ticker)
case "AnalyzeSite":
url = arguments.get("url")
question = arguments.get("question", "what is the summary of this content?")
# Additional status update for long-running operations
message.response = f"Retrieving and summarizing content from {url}..."
yield message
ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question)
case "DateTime":
tz = arguments.get("timezone")
ret = Tools.DateTime(tz)
case "WeatherForecast":
city = arguments.get("city")
state = arguments.get("state")
message.response = f"Fetching weather data for {city}, {state}..."
yield message
ret = Tools.WeatherForecast(city, state)
case _:
ret = None
# Build response for this tool
tool_response = {
"role": "tool",
"content": json.dumps(ret),
"name": tool_call.function.name
}
tool_metadata["tool_calls"].append(tool_response)
if len(tool_metadata["tool_calls"]) == 0:
message.status = "done"
yield message
return
message_dict = {
"role": tool_message.get("role", "assistant"),
"content": tool_message.get("content", ""),
"tool_calls": [ {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
} for tc in tool_message.tool_calls
]
}
messages.append(message_dict)
messages.extend(tool_metadata["tool_calls"])
message.status = "thinking"
# Decrease creativity when processing tool call requests
message.response = ""
start_time = time.perf_counter()
for response in llm.chat(
model=model,
messages=messages,
stream=True,
options={
**message.metadata["options"],
# "temperature": 0.5,
}
):
# logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
message.status = "streaming"
message.response += response.message.content
if not response.done:
yield message
if response.done:
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.status = "done"
yield message
end_time = time.perf_counter()
message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
return
async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
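        """Build the chat history, probe the LLM for tool usage, and stream the reply.

        When tools are enabled, a non-streaming probe call decides whether tool
        processing is required; otherwise the response is streamed directly.
        """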
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
messages = [ { "role": "system", "content": message.system_prompt } ]
messages.extend([
item for m in self.conversation.messages
for item in [
{"role": "user", "content": m.prompt.strip()},
{"role": "assistant", "content": m.response.strip()}
]
])
messages.append({
"role": "user",
"content": message.context_prompt.strip(),
})
message.metadata["messages"] = messages
message.metadata["options"]={
"seed": 8911,
"num_ctx": self.context_size,
#"temperature": 0.9, # Higher temperature to encourage tool usage
}
message.metadata["timers"] = {}
use_tools = message.enable_tools and len(self.context.tools) > 0
message.metadata["tools"] = {
"available": Tools.llm_tools(self.context.tools),
"used": False
}
tool_metadata = message.metadata["tools"]
if use_tools:
message.status = "thinking"
message.response = f"Performing tool analysis step 1/2..."
yield message
logging.info("Checking for LLM tool usage")
start_time = time.perf_counter()
# Tools are enabled and available, so query the LLM with a short token target to see if it will
# use the tools
tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}]
response = llm.chat(
model=model,
messages=tool_metadata["messages"],
tools=tool_metadata["available"],
options={
**message.metadata["options"],
#"num_predict": 1024, # "Low" token limit to cut off after tool call
},
stream=False # No need to stream the probe
)
end_time = time.perf_counter()
message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
if not response.message.tool_calls:
logging.info("LLM indicates tools will not be used")
# The LLM will not use tools, so disable use_tools so we can stream the full response
use_tools = False
if use_tools:
logging.info("LLM indicates tools will be used")
# Tools are enabled and available and the LLM indicated it will use them
tool_metadata["attempted"] = response.message.tool_calls
message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
yield message
logging.info(f"Performing LLM call with tools")
start_time = time.perf_counter()
response = llm.chat(
model=model,
messages=tool_metadata["messages"], # messages,
tools=tool_metadata["available"],
options={
**message.metadata["options"],
},
stream=False
)
end_time = time.perf_counter()
message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
if response.message.tool_calls:
tool_metadata["used"] = response.message.tool_calls
# Process all yielded items from the handler
start_time = time.perf_counter()
async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
if message.status == "error":
yield message
return
yield message
end_time = time.perf_counter()
message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
message.status = "done"
return
logging.info("LLM indicated tools will be used, and then they weren't")
message.response = response.message.content
message.status = "done"
yield message
return
# not use_tools
yield message
# Reset the response for streaming
message.response = ""
start_time = time.perf_counter()
for response in llm.chat(
model=model,
messages=messages,
options={
**message.metadata["options"],
},
stream=True,
):
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
message.status = "streaming"
message.response += response.message.content
if not response.done:
yield message
if response.done:
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.status = "done"
yield message
end_time = time.perf_counter()
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
return
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
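        """Top-level message pipeline: build the context prompt, size the context, and stream the LLM response.

        Rejects concurrent requests while the context is already processing, and
        records the completed message in the conversation when finished.
        """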
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
if self.context.processing:
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
message.status = "error"
message.response = "Busy processing another request."
yield message
return
self.context.processing = True
message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n"
message.context_prompt = ""
for p in message.preamble.keys():
message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n"
message.context_prompt += f"{message.prompt}"
# Estimate token length of new messages
message.response = f"Optimizing context..."
message.status = "thinking"
yield message
message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.status = "thinking"
yield message
async for message in self.generate_llm_response(llm, model, message):
# logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
if message.status == "error":
yield message
self.context.processing = False
return
yield message
# Done processing, add message to conversation
message.status = "done"
self.conversation.add_message(message)
self.context.processing = False
return
# Register the base agent
registry.register(Agent._agent_type, Agent)
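# Example (sketch, not part of this module): a concrete agent registers itself
# automatically through Agent.__init_subclass__ just by declaring its agent_type:
#
#   class ChatAgent(Agent):
#       agent_type: Literal["chat"] = "chat"
#       _agent_type: ClassVar[str] = agent_type
#
# The "chat" type above is purely illustrative; real subclasses live in their
# own modules and are resolved through the shared registry.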