from __future__ import annotations
from pydantic import BaseModel, Field  # type: ignore
from typing import (
    Literal,
    get_args,
    List,
    AsyncGenerator,
    TYPE_CHECKING,
    Optional,
    ClassVar,
    Any,
    TypeAlias,
    Dict,
    Tuple,
)
import json
import time
import inspect
from abc import ABC
import asyncio

from prometheus_client import Counter, Summary, CollectorRegistry  # type: ignore

from ..setup_logging import setup_logging

logger = setup_logging()

# Only import Context for type checking
if TYPE_CHECKING:
    from ..context import Context

from .types import agent_registry
from .. import defines
from ..message import Message, Tunables
from ..metrics import Metrics
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools  # type: ignore -- dynamically added to __all__
from ..conversation import Conversation


class LLMMessage(BaseModel):
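    """Minimal chat message shape (role, content, optional tool calls) passed to the LLM API."""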
    role: str = Field(default="")
    content: str = Field(default="")
    tool_calls: Optional[List[Dict]] = Field(default=None, exclude=True)


class Agent(BaseModel, ABC):
    """
    Base class for all agent types.
    This class defines the common attributes and methods for all agent types.
    """

    # Agent management with pydantic
    agent_type: Literal["base"] = "base"
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    # Tunables (sets default for new Messages attached to this agent)
    tunables: Tunables = Field(default_factory=Tunables)

    # Agent properties
    system_prompt: str  # Mandatory
    conversation: Conversation = Conversation()
    context_tokens: int = 0
    context: Optional[Context] = Field(
        default=None, exclude=True
    )  # Avoid circular reference, require as param, and prevent serialization
    metrics: Metrics = Field(default_factory=Metrics, exclude=True)

    # context_size is shared across all subclasses
    _context_size: ClassVar[int] = int(defines.max_context * 0.5)

    @property
    def context_size(self) -> int:
        return Agent._context_size

    @context_size.setter
    def context_size(self, value: int):
        Agent._context_size = value

    def set_optimal_context_size(
        self, llm: Any, model: str, prompt: str, ctx_buffer=2048
    ) -> int:
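        """
        Estimate the token footprint of `prompt` (~1.4 tokens per word plus a
        safety buffer) and grow the shared context size if the estimate exceeds it.
        Returns the (possibly enlarged) context size to use for the next LLM call.
        """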
        # # Get more accurate token count estimate using tiktoken or similar
        # response = llm.generate(
        #     model=model,
        #     prompt=prompt,
        #     options={
        #         "num_ctx": self.context_size,
        #         "num_predict": 0,
        #     }  # Don't generate any tokens, just tokenize
        # )
        # # The prompt_eval_count gives you the token count of your input
        # tokens = response.get("prompt_eval_count", 0)

        # Most models average 1.3-1.5 tokens per word
        word_count = len(prompt.split())
        tokens = int(word_count * 1.4)

        # Add buffer for safety
        total_ctx = tokens + ctx_buffer

        if total_ctx > self.context_size:
            logger.info(
                f"Increasing context size from {self.context_size} to {total_ctx}"
            )

        # Grow the context size if necessary
        self.context_size = max(self.context_size, total_ctx)
        # Return the shared (possibly enlarged) context size
        return self.context_size

    # Class and pydantic model management
    def __init_subclass__(cls, **kwargs) -> None:
        """Auto-register subclasses"""
        super().__init_subclass__(**kwargs)
        # Register this class if it has an agent_type
        if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
            agent_registry.register(cls.agent_type, cls)

    # def __init__(self, *, context=context, **data):
    #     super().__init__(**data)
    #     self.set_context(context)

    def model_dump(self, *args, **kwargs) -> Any:
        # Ensure context is always excluded, even with exclude_unset=True
        kwargs.setdefault("exclude", set())
        if isinstance(kwargs["exclude"], set):
            kwargs["exclude"].add("context")
        elif isinstance(kwargs["exclude"], dict):
            kwargs["exclude"]["context"] = True
        return super().model_dump(*args, **kwargs)

    @classmethod
    def valid_agent_types(cls) -> set[str]:
        """Return the set of valid agent_type values."""
        return set(get_args(cls.__annotations__["agent_type"]))

    def set_context(self, context: Context):
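        # Assign via object.__setattr__ so the excluded `context` field is attached
        # without invoking pydantic's validation/assignment hooks on the circular reference.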
        object.__setattr__(self, "context", context)

    # Agent methods
    def get_agent_type(self):
        return self._agent_type

    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
        """
        Prepare message with context information in message.preamble
        """
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        self.metrics.prepare_count.labels(agent=self.agent_type).inc()
        with self.metrics.prepare_duration.labels(agent=self.agent_type).time():
            if not self.context:
                raise ValueError("Context is not set for this agent.")

            # Generate RAG content if enabled, based on the content
            rag_context = ""
            if message.tunables.enable_rag and message.prompt:
                # Gather RAG results, yielding each result
                # as it becomes available
                for message in self.context.generate_rag_results(message):
                    logger.info(f"RAG: {message.status} - {message.response}")
                    if message.status == "error":
                        yield message
                        return
                    if message.status != "done":
                        yield message

                for rag in message.metadata.rag:
                    for doc in rag.documents:
                        rag_context += f"{doc}\n"

            message.preamble = {}

            if rag_context:
                message.preamble["context"] = rag_context

            if message.tunables.enable_context and self.context.user_resume:
                message.preamble["resume"] = self.context.user_resume

            message.system_prompt = self.system_prompt
            message.status = "done"
            yield message

        return

    async def process_tool_calls(
        self,
        llm: Any,
        model: str,
        message: Message,
        tool_message: Any,  # llama response message
        messages: List[LLMMessage],
    ) -> AsyncGenerator[Message, None]:
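        """
        Execute the tool calls requested by the LLM, append their results to the
        message history, then re-query the LLM (streaming) so it can fold the tool
        output into its final response. Yields status and streaming updates as it goes.
        """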
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        self.metrics.tool_count.labels(agent=self.agent_type).inc()
        with self.metrics.tool_duration.labels(agent=self.agent_type).time():

            if not self.context:
                raise ValueError("Context is not set for this agent.")
            if not message.metadata.tools:
                raise ValueError("tools field not initialized")

            tool_metadata = message.metadata.tools
            tool_metadata["tool_calls"] = []

            message.status = "tooling"

            for i, tool_call in enumerate(tool_message.tool_calls):
                arguments = tool_call.function.arguments
                tool = tool_call.function.name

                # Yield status update before processing each tool
                message.response = (
                    f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
                )
                yield message
                logger.info(f"LLM - {message.response}")

                # Process the tool based on its type
                match tool:
                    case "TickerValue":
                        ticker = arguments.get("ticker")
                        if not ticker:
                            ret = None
                        else:
                            ret = TickerValue(ticker)

                    case "AnalyzeSite":
                        url = arguments.get("url")
                        question = arguments.get(
                            "question", "what is the summary of this content?"
                        )

                        # Additional status update for long-running operations
                        message.response = (
                            f"Retrieving and summarizing content from {url}..."
                        )
                        yield message
                        ret = await AnalyzeSite(
                            llm=llm, model=model, url=url, question=question
                        )

                    case "DateTime":
                        tz = arguments.get("timezone")
                        ret = DateTime(tz)

                    case "WeatherForecast":
                        city = arguments.get("city")
                        state = arguments.get("state")

                        message.response = (
                            f"Fetching weather data for {city}, {state}..."
                        )
                        yield message
                        ret = WeatherForecast(city, state)

                    case _:
                        ret = None

                # Build response for this tool
                tool_response = {
                    "role": "tool",
                    "content": json.dumps(ret),
                    "name": tool_call.function.name,
                }

                tool_metadata["tool_calls"].append(tool_response)

            if len(tool_metadata["tool_calls"]) == 0:
                message.status = "done"
                yield message
                return

            message_dict = LLMMessage(
                role=tool_message.get("role", "assistant"),
                content=tool_message.get("content", ""),
                tool_calls=[
                    {
                        "function": {
                            "name": tc["function"]["name"],
                            "arguments": tc["function"]["arguments"],
                        }
                    }
                    for tc in tool_message.tool_calls
                ],
            )

            messages.append(message_dict)
            messages.extend(tool_metadata["tool_calls"])

            message.status = "thinking"
            message.response = "Incorporating tool results into response..."
            yield message

            # Re-query the LLM with the tool results appended so it can fold them into the final response
            message.response = ""
            start_time = time.perf_counter()
            for response in llm.chat(
                model=model,
                messages=messages,
                options={
                    **message.metadata.options,
                },
                stream=True,
            ):
                # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
                message.status = "streaming"
                message.chunk = response.message.content
                message.response += message.chunk

                if not response.done:
                    yield message

                if response.done:
                    self.collect_metrics(response)
                    message.metadata.eval_count += response.eval_count
                    message.metadata.eval_duration += response.eval_duration
                    message.metadata.prompt_eval_count += response.prompt_eval_count
                    message.metadata.prompt_eval_duration += response.prompt_eval_duration
                    self.context_tokens = (
                        response.prompt_eval_count + response.eval_count
                    )
                    message.status = "done"
                    yield message

            end_time = time.perf_counter()
            message.metadata.timers["llm_with_tools"] = end_time - start_time
        return

    def collect_metrics(self, response):
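        """Record prompt and eval token counts from an LLM response in the agent's metrics."""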
        self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
            response.prompt_eval_count
        )
        self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)

    async def generate_llm_response(
        self, llm: Any, model: str, message: Message, temperature=0.7
    ) -> AsyncGenerator[Message, None]:
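        """
        Build the chat history (system prompt, prior turns, current context prompt),
        optionally probe for tool use, then stream the LLM response back as a series
        of Message status updates.
        """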
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        self.metrics.generate_count.labels(agent=self.agent_type).inc()
        with self.metrics.generate_duration.labels(agent=self.agent_type).time():
            if not self.context:
                raise ValueError("Context is not set for this agent.")

            # Create a pruned down message list based purely on the prompt and responses,
            # discarding the full preamble generated by prepare_message
            messages: List[LLMMessage] = [
                LLMMessage(role="system", content=message.system_prompt)
            ]
            messages.extend(
                [
                    item
                    for m in self.conversation
                    for item in [
                        LLMMessage(role="user", content=m.prompt.strip()),
                        LLMMessage(role="assistant", content=m.response.strip()),
                    ]
                ]
            )
            # Only the actual user query is provided with the full context message
            messages.append(
                LLMMessage(role="user", content=message.context_prompt.strip())
            )

            # message.messages = messages
            message.metadata.options = {
                "seed": 8911,
                "num_ctx": self.context_size,
                "temperature": temperature,  # Higher temperature to encourage tool usage
            }

            # Create a dict for storing various timing stats
            message.metadata.timers = {}

            use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
            message.metadata.tools = {
                "available": llm_tools(self.context.tools),
                "used": False,
            }
            tool_metadata = message.metadata.tools

            if use_tools:
                message.status = "thinking"
                message.response = "Performing tool analysis step 1/2..."
                yield message

                logger.info("Checking for LLM tool usage")
                start_time = time.perf_counter()
                # Tools are enabled and available, so query the LLM with a short context of messages
                # in case the LLM did something like ask "Do you want me to run the tool?" and the
                # user said "Yes" -- need to keep the context in the thread.
                tool_metadata["messages"] = (
                    [{"role": "system", "content": self.system_prompt}] + messages[-6:]
                    if len(messages) >= 7
                    else messages
                )

                response = llm.chat(
                    model=model,
                    messages=tool_metadata["messages"],
                    tools=tool_metadata["available"],
                    options={
                        **message.metadata.options,
                    },
                    stream=False,  # No need to stream the probe
                )
                self.collect_metrics(response)

                end_time = time.perf_counter()
                message.metadata.timers["tool_check"] = end_time - start_time
                if not response.message.tool_calls:
                    logger.info("LLM indicates tools will not be used")
                    # The LLM will not use tools, so disable use_tools so we can stream the full response
                    use_tools = False
                else:
                    tool_metadata["attempted"] = response.message.tool_calls

            if use_tools:
                logger.info("LLM indicates tools will be used")

                # Tools are enabled and available and the LLM indicated it will use them
                message.response = (
                    "Performing tool analysis step 2/2 (tool use suspected)..."
                )
                yield message

                logger.info("Performing LLM call with tools")
                start_time = time.perf_counter()
                response = llm.chat(
                    model=model,
                    messages=tool_metadata["messages"],  # messages,
                    tools=tool_metadata["available"],
                    options={
                        **message.metadata.options,
                    },
                    stream=False,
                )
                self.collect_metrics(response)

                end_time = time.perf_counter()
                message.metadata.timers["non_streaming"] = end_time - start_time

                if not response:
                    message.status = "error"
                    message.response = "No response from LLM."
                    yield message
                    return

                if response.message.tool_calls:
                    tool_metadata["used"] = response.message.tool_calls
                    # Process all yielded items from the handler
                    start_time = time.perf_counter()
                    async for message in self.process_tool_calls(
                        llm=llm,
                        model=model,
                        message=message,
                        tool_message=response.message,
                        messages=messages,
                    ):
                        if message.status == "error":
                            yield message
                            return
                        yield message
                    end_time = time.perf_counter()
                    message.metadata.timers["process_tool_calls"] = end_time - start_time
                    message.status = "done"
                    return

                logger.info("LLM indicated tool use, but no tool calls were returned")
                message.response = response.message.content
                message.status = "done"
                yield message
                return

            # not use_tools
            message.status = "thinking"
            message.response = "Generating response..."
            yield message
            # Reset the response for streaming
            message.response = ""
            start_time = time.perf_counter()
            for response in llm.chat(
                model=model,
                messages=messages,
                options={
                    **message.metadata.options,
                },
                stream=True,
            ):
                if not response:
                    message.status = "error"
                    message.response = "No response from LLM."
                    yield message
                    return

                message.status = "streaming"
                message.chunk = response.message.content
                message.response += message.chunk

                if not response.done:
                    yield message

                if response.done:
                    self.collect_metrics(response)
                    message.metadata.eval_count += response.eval_count
                    message.metadata.eval_duration += response.eval_duration
                    message.metadata.prompt_eval_count += response.prompt_eval_count
                    message.metadata.prompt_eval_duration += response.prompt_eval_duration
                    self.context_tokens = (
                        response.prompt_eval_count + response.eval_count
                    )
                    message.status = "done"
                    yield message

            end_time = time.perf_counter()
            message.metadata.timers["streamed"] = end_time - start_time
        return

    async def process_message(
        self, llm: Any, model: str, message: Message
    ) -> AsyncGenerator[Message, None]:
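        """
        Top-level entry point: waits for any in-flight request on this context to
        finish, assembles the system/context prompts from the preamble, sizes the
        context window, then delegates to generate_llm_response and records the
        completed exchange in the conversation.
        """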
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        self.metrics.process_count.labels(agent=self.agent_type).inc()
        with self.metrics.process_duration.labels(agent=self.agent_type).time():

            if not self.context:
                raise ValueError("Context is not set for this agent.")

            logger.info(
                "TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time"
            )
            spinner: List[str] = ["\\", "|", "/", "-"]
            tick: int = 0
            while self.context.processing:
                message.status = "waiting"
                message.response = (
                    f"Busy processing another request. Please wait. {spinner[tick]}"
                )
                tick = (tick + 1) % len(spinner)
                yield message
                await asyncio.sleep(1)  # Allow the event loop to process the write

            self.context.processing = True

            message.system_prompt = (
                f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
            )
            message.context_prompt = ""
            for p in message.preamble.keys():
                message.context_prompt += (
                    f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}|>\n\n"
                )
            message.context_prompt += f"{message.prompt}"

            # Estimate token length of new messages
            message.response = "Optimizing context..."
            message.status = "thinking"
            yield message

            message.context_size = self.set_optimal_context_size(
                llm, model, prompt=message.context_prompt
            )

            message.response = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
            message.status = "thinking"
            yield message

            async for message in self.generate_llm_response(
                llm=llm, model=model, message=message
            ):
                # logger.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
                if message.status == "error":
                    yield message
                    self.context.processing = False
                    return
                yield message

            # Done processing, add message to conversation
            message.status = "done"
            self.conversation.add(message)
            self.context.processing = False

        return


# Register the base agent
agent_registry.register(Agent._agent_type, Agent)
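
# Example (illustrative sketch, not part of this module): a subclass only needs to
# declare a unique `agent_type`; Agent.__init_subclass__ auto-registers it with
# agent_registry. The names below (ChatAgent, "chat") are hypothetical.
#
#   class ChatAgent(Agent):
#       agent_type: Literal["chat"] = "chat"
#       _agent_type: ClassVar[str] = agent_type
#       system_prompt: str = "You are a helpful assistant."
#
#   agent = ChatAgent()
#   agent.set_context(context)  # `context` is supplied by the caller
#   # async for update in agent.process_message(llm=llm, model=model, message=message):
#   #     ...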