from __future__ import annotations
from pydantic import BaseModel, PrivateAttr, Field  # type: ignore
from typing import (
    Literal,
    get_args,
    List,
    AsyncGenerator,
    TYPE_CHECKING,
    Optional,
    ClassVar,
    Any,
    TypeAlias,
    Dict,
    Tuple,
)
import json
import time
import inspect
from abc import ABC
import asyncio
from prometheus_client import Counter, Summary, CollectorRegistry  # type: ignore

from ..setup_logging import setup_logging

logger = setup_logging()

# Only import Context for type checking
if TYPE_CHECKING:
    from ..context import Context

from .types import agent_registry
from .. import defines
from ..message import Message, Tunables
from ..metrics import Metrics
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools  # type: ignore -- dynamically added to __all__
from ..conversation import Conversation


class LLMMessage(BaseModel):
    role: str = Field(default="")
    content: str = Field(default="")
    tool_calls: Optional[List[Dict]] = Field(default=None, exclude=True)


class Agent(BaseModel, ABC):
    """
    Base class for all agent types.

    Defines the attributes and methods shared by every agent type and
    auto-registers subclasses with the agent registry.
    """

    # Agent management with pydantic
    agent_type: Literal["base"] = "base"
    _agent_type: ClassVar[str] = agent_type  # Used for subclass registration

    # Tunables (sets defaults for new Messages attached to this agent)
    tunables: Tunables = Field(default_factory=Tunables)

    # Agent properties
    system_prompt: str  # Mandatory
    conversation: Conversation = Field(default_factory=Conversation)
    context_tokens: int = 0
    context: Optional[Context] = Field(
        default=None, exclude=True
    )  # Avoid circular reference, require as param, and prevent serialization
    metrics: Metrics = Field(default_factory=Metrics, exclude=True)

    # context_size is shared across all subclasses
    _context_size: ClassVar[int] = int(defines.max_context * 0.5)

    @property
    def context_size(self) -> int:
        return Agent._context_size

    @context_size.setter
    def context_size(self, value: int):
        Agent._context_size = value

    def set_optimal_context_size(
        self, llm: Any, model: str, prompt: str, ctx_buffer=2048
    ) -> int:
        # # Get a more accurate token count estimate using tiktoken or similar
        # response = llm.generate(
        #     model=model,
        #     prompt=prompt,
        #     options={
        #         "num_ctx": self.context_size,
        #         "num_predict": 0,
        #     },  # Don't generate any tokens, just tokenize
        # )
        # # The prompt_eval_count gives the token count of the input
        # tokens = response.get("prompt_eval_count", 0)

        # Most models average 1.3-1.5 tokens per word, so estimate from the word count
        # (e.g. a 1,000-word prompt is treated as roughly 1,400 tokens).
        word_count = len(prompt.split())
        tokens = int(word_count * 1.4)

        # Add a buffer for safety
        total_ctx = tokens + ctx_buffer

        if total_ctx > self.context_size:
            logger.info(
                f"Increasing context size from {self.context_size} to {total_ctx}"
            )

        # Grow the context size if necessary
        self.context_size = max(self.context_size, total_ctx)

        # Use actual model maximum context size
        return self.context_size

    # Class and pydantic model management
    def __init_subclass__(cls, **kwargs) -> None:
        """Auto-register subclasses that declare their own agent_type."""
        super().__init_subclass__(**kwargs)
        # Register this class if it defines an agent_type other than the base type
        if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
            agent_registry.register(cls.agent_type, cls)

    # def __init__(self, *, context=context, **data):
    #     super().__init__(**data)
    #     self.set_context(context)

    def model_dump(self, *args, **kwargs) -> Any:
        # Ensure context is always excluded, even with exclude_unset=True
        kwargs.setdefault("exclude", set())
        if isinstance(kwargs["exclude"], set):
            kwargs["exclude"].add("context")
        elif isinstance(kwargs["exclude"], dict):
kwargs["exclude"]["context"] = True return super().model_dump(*args, **kwargs) @classmethod def valid_agent_types(cls) -> set[str]: """Return the set of valid agent_type values.""" return set(get_args(cls.__annotations__["agent_type"])) def set_context(self, context: Context): object.__setattr__(self, "context", context) # Agent methods def get_agent_type(self): return self._agent_type async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]: """ Prepare message with context information in message.preamble """ logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") self.metrics.prepare_count.labels(agent=self.agent_type).inc() with self.metrics.prepare_duration.labels(agent=self.agent_type).time(): if not self.context: raise ValueError("Context is not set for this agent.") # Generate RAG content if enabled, based on the content rag_context = "" if message.tunables.enable_rag and message.prompt: # Gather RAG results, yielding each result # as it becomes available for message in self.context.generate_rag_results(message): logger.info(f"RAG: {message.status} - {message.response}") if message.status == "error": yield message return if message.status != "done": yield message if "rag" in message.metadata and message.metadata["rag"]: for rag in message.metadata["rag"]: for doc in rag["documents"]: rag_context += f"{doc}\n" message.preamble = {} if rag_context: message.preamble["context"] = rag_context if message.tunables.enable_context and self.context.user_resume: message.preamble["resume"] = self.context.user_resume message.system_prompt = self.system_prompt message.status = "done" yield message return async def process_tool_calls( self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[LLMMessage], ) -> AsyncGenerator[Message, None]: logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") self.metrics.tool_count.labels(agent=self.agent_type).inc() with self.metrics.tool_duration.labels(agent=self.agent_type).time(): if not self.context: raise ValueError("Context is not set for this agent.") if not message.metadata["tools"]: raise ValueError("tools field not initialized") tool_metadata = message.metadata["tools"] tool_metadata["tool_calls"] = [] message.status = "tooling" for i, tool_call in enumerate(tool_message.tool_calls): arguments = tool_call.function.arguments tool = tool_call.function.name # Yield status update before processing each tool message.response = ( f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..." ) yield message logger.info(f"LLM - {message.response}") # Process the tool based on its type match tool: case "TickerValue": ticker = arguments.get("ticker") if not ticker: ret = None else: ret = TickerValue(ticker) case "AnalyzeSite": url = arguments.get("url") question = arguments.get( "question", "what is the summary of this content?" ) # Additional status update for long-running operations message.response = ( f"Retrieving and summarizing content from {url}..." ) yield message ret = await AnalyzeSite( llm=llm, model=model, url=url, question=question ) case "DateTime": tz = arguments.get("timezone") ret = DateTime(tz) case "WeatherForecast": city = arguments.get("city") state = arguments.get("state") message.response = ( f"Fetching weather data for {city}, {state}..." 
                        )
                        yield message
                        ret = WeatherForecast(city, state)

                    case _:
                        ret = None

                # Build the response for this tool
                tool_response = {
                    "role": "tool",
                    "content": json.dumps(ret),
                    "name": tool_call.function.name,
                }
                tool_metadata["tool_calls"].append(tool_response)

            if len(tool_metadata["tool_calls"]) == 0:
                message.status = "done"
                yield message
                return

            message_dict = LLMMessage(
                role=tool_message.get("role", "assistant"),
                content=tool_message.get("content", ""),
                tool_calls=[
                    {
                        "function": {
                            "name": tc["function"]["name"],
                            "arguments": tc["function"]["arguments"],
                        }
                    }
                    for tc in tool_message.tool_calls
                ],
            )

            messages.append(message_dict)
            messages.extend(tool_metadata["tool_calls"])

            message.status = "thinking"
            message.response = "Incorporating tool results into response..."
            yield message

            # Decrease creativity when processing tool call requests
            message.response = ""
            start_time = time.perf_counter()
            for response in llm.chat(
                model=model,
                messages=messages,
                options={
                    **message.metadata["options"],
                    # "temperature": 0.5,
                },
                stream=True,
            ):
                # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
                message.status = "streaming"
                message.chunk = response.message.content
                message.response += message.chunk

                if not response.done:
                    yield message

                if response.done:
                    self.collect_metrics(response)
                    message.metadata["eval_count"] += response.eval_count
                    message.metadata["eval_duration"] += response.eval_duration
                    message.metadata["prompt_eval_count"] += response.prompt_eval_count
                    message.metadata[
                        "prompt_eval_duration"
                    ] += response.prompt_eval_duration
                    self.context_tokens = (
                        response.prompt_eval_count + response.eval_count
                    )
                    message.status = "done"
                    yield message

            end_time = time.perf_counter()
            message.metadata["timers"][
                "llm_with_tools"
            ] = f"{(end_time - start_time):.4f}"
            return

    def collect_metrics(self, response):
        """Record prompt and eval token counts from an LLM response."""
        self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
            response.prompt_eval_count
        )
        self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)

    async def generate_llm_response(
        self, llm: Any, model: str, message: Message, temperature=0.7
    ) -> AsyncGenerator[Message, None]:
        """Generate the LLM response for a message, optionally routing through tool calls."""
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
        self.metrics.generate_count.labels(agent=self.agent_type).inc()
        with self.metrics.generate_duration.labels(agent=self.agent_type).time():
            if not self.context:
                raise ValueError("Context is not set for this agent.")

            # Create a pruned-down message list based purely on the prompts and responses,
            # discarding the full preamble generated by prepare_message
            messages: List[LLMMessage] = [
                LLMMessage(role="system", content=message.system_prompt)
            ]
            messages.extend(
                [
                    item
                    for m in self.conversation
                    for item in [
                        LLMMessage(role="user", content=m.prompt.strip()),
                        LLMMessage(role="assistant", content=m.response.strip()),
                    ]
                ]
            )
            # Only the actual user query is provided with the full context message
            messages.append(
                LLMMessage(role="user", content=message.context_prompt.strip())
            )

            # message.metadata["messages"] = messages
            message.metadata["options"] = {
                "seed": 8911,
                "num_ctx": self.context_size,
                "temperature": temperature,  # Higher temperature encourages tool usage
            }

            # Create a dict for storing various timing stats
            message.metadata["timers"] = {}

            use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
            message.metadata["tools"] = {
                "available": llm_tools(self.context.tools),
                "used": False,
            }
            tool_metadata = message.metadata["tools"]

            if use_tools:
                message.status = "thinking"
                message.response = "Performing tool analysis step 1/2..."
                yield message

                logger.info("Checking for LLM tool usage")
                start_time = time.perf_counter()
                # Tools are enabled and available, so query the LLM with a short context of
                # messages in case the LLM asked something like "Do you want me to run the
                # tool?" and the user said "Yes" -- the thread context needs to be kept.
                tool_metadata["messages"] = (
                    [{"role": "system", "content": self.system_prompt}] + messages[-6:]
                    if len(messages) >= 7
                    else messages
                )
                response = llm.chat(
                    model=model,
                    messages=tool_metadata["messages"],
                    tools=tool_metadata["available"],
                    options={
                        **message.metadata["options"],
                        # "num_predict": 1024,  # "Low" token limit to cut off after tool call
                    },
                    stream=False,  # No need to stream the probe
                )
                self.collect_metrics(response)
                end_time = time.perf_counter()
                message.metadata["timers"][
                    "tool_check"
                ] = f"{(end_time - start_time):.4f}"

                if not response.message.tool_calls:
                    logger.info("LLM indicates tools will not be used")
                    # The LLM will not use tools, so disable use_tools and stream the full response
                    use_tools = False
                else:
                    tool_metadata["attempted"] = response.message.tool_calls

            if use_tools:
                logger.info("LLM indicates tools will be used")
                # Tools are enabled and available, and the LLM indicated it will use them
                message.response = (
                    "Performing tool analysis step 2/2 (tool use suspected)..."
                )
                yield message

                logger.info("Performing LLM call with tools")
                start_time = time.perf_counter()
                response = llm.chat(
                    model=model,
                    messages=tool_metadata["messages"],  # messages,
                    tools=tool_metadata["available"],
                    options={
                        **message.metadata["options"],
                    },
                    stream=False,
                )
                self.collect_metrics(response)
                end_time = time.perf_counter()
                message.metadata["timers"][
                    "non_streaming"
                ] = f"{(end_time - start_time):.4f}"

                if not response:
                    message.status = "error"
                    message.response = "No response from LLM."
                    yield message
                    return

                if response.message.tool_calls:
                    tool_metadata["used"] = response.message.tool_calls
                    # Process all yielded items from the handler
                    start_time = time.perf_counter()
                    async for message in self.process_tool_calls(
                        llm=llm,
                        model=model,
                        message=message,
                        tool_message=response.message,
                        messages=messages,
                    ):
                        if message.status == "error":
                            yield message
                            return
                        yield message
                    end_time = time.perf_counter()
                    message.metadata["timers"][
                        "process_tool_calls"
                    ] = f"{(end_time - start_time):.4f}"
                    message.status = "done"
                    return

                logger.info("LLM indicated tools would be used, and then they weren't")
                message.response = response.message.content
                message.status = "done"
                yield message
                return

            # not use_tools
            message.status = "thinking"
            message.response = "Generating response..."
            yield message

            # Reset the response for streaming
            message.response = ""
            start_time = time.perf_counter()
            for response in llm.chat(
                model=model,
                messages=messages,
                options={
                    **message.metadata["options"],
                },
                stream=True,
            ):
                if not response:
                    message.status = "error"
                    message.response = "No response from LLM."
                    yield message
                    return

                message.status = "streaming"
                message.chunk = response.message.content
                message.response += message.chunk

                if not response.done:
                    yield message

                if response.done:
                    self.collect_metrics(response)
                    message.metadata["eval_count"] += response.eval_count
                    message.metadata["eval_duration"] += response.eval_duration
                    message.metadata["prompt_eval_count"] += response.prompt_eval_count
                    message.metadata[
                        "prompt_eval_duration"
                    ] += response.prompt_eval_duration
                    self.context_tokens = (
                        response.prompt_eval_count + response.eval_count
                    )
                    message.status = "done"
                    yield message

            end_time = time.perf_counter()
            message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
            return

    async def process_message(
        self, llm: Any, model: str, message: Message
    ) -> AsyncGenerator[Message, None]:
        """Wait for the context to be free, build the context prompt, and stream the LLM response."""
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
        self.metrics.process_count.labels(agent=self.agent_type).inc()
        with self.metrics.process_duration.labels(agent=self.agent_type).time():
            if not self.context:
                raise ValueError("Context is not set for this agent.")

            logger.info(
                "TODO: Implement delay queuing; busy for the same agent, otherwise return queue size and estimated wait time"
            )
            spinner: List[str] = ["\\", "|", "/", "-"]
            tick: int = 0
            while self.context.processing:
                message.status = "waiting"
                message.response = (
                    f"Busy processing another request. Please wait. {spinner[tick]}"
                )
                tick = (tick + 1) % len(spinner)
                yield message
                await asyncio.sleep(1)  # Allow the event loop to process the write

            self.context.processing = True

            message.metadata["system_prompt"] = (
                f"<|system|>\n{self.system_prompt.strip()}\n"
            )
            message.context_prompt = ""
            for p in message.preamble.keys():
                message.context_prompt += (
                    f"\n<|{p}|>\n{message.preamble[p].strip()}\n\n\n"
                )
            message.context_prompt += f"{message.prompt}"

            # Estimate the token length of the new messages
            message.response = "Optimizing context..."
            message.status = "thinking"
            yield message
            message.metadata["context_size"] = self.set_optimal_context_size(
                llm, model, prompt=message.context_prompt
            )

            message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
            message.status = "thinking"
            yield message

            async for message in self.generate_llm_response(
                llm=llm, model=model, message=message
            ):
                # logger.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
                if message.status == "error":
                    yield message
                    self.context.processing = False
                    return
                yield message

            # Done processing, add the message to the conversation
            message.status = "done"
            self.conversation.add(message)
            self.context.processing = False
            return


# Register the base agent
agent_registry.register(Agent._agent_type, Agent)
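
# Illustrative sketch (hypothetical; ``ChatAgent`` and the "chat" agent_type below are
# examples only): a concrete agent declares its own ``agent_type`` literal, and the
# ``Agent.__init_subclass__`` hook above registers it with ``agent_registry``
# automatically -- no explicit ``register`` call is needed for subclasses.
#
#     class ChatAgent(Agent):
#         agent_type: Literal["chat"] = "chat"  # type: ignore[assignment]
#         _agent_type: ClassVar[str] = agent_type
#
#         system_prompt: str = "You are a helpful assistant."
#
#     # The new type is then available from the registry under "chat"; a Context
#     # still has to be attached via set_context() before the agent can be used.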