from __future__ import annotations from pydantic import BaseModel, Field, model_validator # type: ignore from typing import ( Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any, TypeAlias, Dict, Tuple, ) import json import time import inspect from abc import ABC import asyncio from datetime import datetime, UTC from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore from models import ( ChatQuery, ChatMessage, ChatOptions, ChatMessageBase, ChatMessageUser, Tunables, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData) from logger import logger import defines from .registry import agent_registry from metrics import Metrics from database import RedisDatabase # type: ignore import model_cast class LLMMessage(BaseModel): role: str = Field(default="") content: str = Field(default="") tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True) class Agent(BaseModel, ABC): """ Base class for all agent types. This class defines the common attributes and methods for all agent types. """ class Config: arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase # Agent management with pydantic agent_type: Literal["base"] = "base" _agent_type: ClassVar[str] = agent_type # Add this for registration agent_persist: bool = True # Whether this agent will persist in the database database: RedisDatabase = Field( ..., description="Database connection for this agent, used to store and retrieve data." ) prometheus_collector: CollectorRegistry = Field(..., description="Prometheus collector for this agent, used to track metrics.", exclude=True) # Tunables (sets default for new Messages attached to this agent) tunables: Tunables = Field(default_factory=Tunables) metrics: Metrics = Field( None, description="Metrics collector for this agent, used to track performance and usage." ) @model_validator(mode="after") def initialize_metrics(self) -> "Agent": if self.metrics is None: self.metrics = Metrics(prometheus_collector=self.prometheus_collector) return self # Agent properties system_prompt: str # Mandatory context_tokens: int = 0 # context_size is shared across all subclasses _context_size: ClassVar[int] = int(defines.max_context * 0.5) conversation: List[ChatMessage] = Field( default_factory=list, description="Conversation history for this agent, used to maintain context across messages." ) @property def context_size(self) -> int: return Agent._context_size @context_size.setter def context_size(self, value: int): Agent._context_size = value def set_optimal_context_size( self, llm: Any, model: str, prompt: str, ctx_buffer=2048 ) -> int: # Most models average 1.3-1.5 tokens per word word_count = len(prompt.split()) tokens = int(word_count * 1.4) # Add buffer for safety total_ctx = tokens + ctx_buffer if total_ctx > self.context_size: logger.info( f"Increasing context size from {self.context_size} to {total_ctx}" ) # Grow the context size if necessary self.context_size = max(self.context_size, total_ctx) # Use actual model maximum context size return self.context_size # Class and pydantic model management def __init_subclass__(cls, **kwargs) -> None: """Auto-register subclasses""" super().__init_subclass__(**kwargs) # Register this class if it has an agent_type if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type: agent_registry.register(cls.agent_type, cls) def model_dump(self, *args, **kwargs) -> Any: # Ensure context is always excluded, even with exclude_unset=True kwargs.setdefault("exclude", set()) if isinstance(kwargs["exclude"], set): kwargs["exclude"].add("context") elif isinstance(kwargs["exclude"], dict): kwargs["exclude"]["context"] = True return super().model_dump(*args, **kwargs) @classmethod def valid_agent_types(cls) -> set[str]: """Return the set of valid agent_type values.""" return set(get_args(cls.__annotations__["agent_type"])) # Agent methods def get_agent_type(self): return self._agent_type # async def prepare_message(self, message: ChatMessage) -> AsyncGenerator[ChatMessage, None]: # """ # Prepare message with context information in message.preamble # """ # logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") # self.metrics.prepare_count.labels(agent=self.agent_type).inc() # with self.metrics.prepare_duration.labels(agent=self.agent_type).time(): # if not self.context: # raise ValueError("Context is not set for this agent.") # # Generate RAG content if enabled, based on the content # rag_context = "" # if message.tunables.enable_rag and message.prompt: # # Gather RAG results, yielding each result # # as it becomes available # for message in self.context.user.generate_rag_results(message): # logger.info(f"RAG: {message.status} - {message.content}") # if message.status == "error": # yield message # return # if message.status != "done": # yield message # # for rag in message.metadata.rag: # # for doc in rag.documents: # # rag_context += f"{doc}\n" # message.preamble = {} # if rag_context: # message.preamble["context"] = f"The following is context information about {self.context.user.full_name}:\n{rag_context}" # if message.tunables.enable_context and self.context.user_resume: # message.preamble["resume"] = self.context.user_resume # message.system_prompt = self.system_prompt # message.status = ChatStatusType.DONE # yield message # return # async def process_tool_calls( # self, # llm: Any, # model: str, # message: ChatMessage, # tool_message: Any, # llama response message # messages: List[LLMMessage], # ) -> AsyncGenerator[ChatMessage, None]: # logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") # self.metrics.tool_count.labels(agent=self.agent_type).inc() # with self.metrics.tool_duration.labels(agent=self.agent_type).time(): # if not self.context: # raise ValueError("Context is not set for this agent.") # if not message.metadata.tools: # raise ValueError("tools field not initialized") # tool_metadata = message.metadata.tools # tool_metadata["tool_calls"] = [] # message.status = "tooling" # for i, tool_call in enumerate(tool_message.tool_calls): # arguments = tool_call.function.arguments # tool = tool_call.function.name # # Yield status update before processing each tool # message.content = ( # f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..." # ) # yield message # logger.info(f"LLM - {message.content}") # # Process the tool based on its type # match tool: # case "TickerValue": # ticker = arguments.get("ticker") # if not ticker: # ret = None # else: # ret = TickerValue(ticker) # case "AnalyzeSite": # url = arguments.get("url") # question = arguments.get( # "question", "what is the summary of this content?" # ) # # Additional status update for long-running operations # message.content = ( # f"Retrieving and summarizing content from {url}..." # ) # yield message # ret = await AnalyzeSite( # llm=llm, model=model, url=url, question=question # ) # case "GenerateImage": # prompt = arguments.get("prompt", None) # if not prompt: # logger.info("No prompt supplied to GenerateImage") # ret = { "error": "No prompt supplied to GenerateImage" } # # Additional status update for long-running operations # message.content = ( # f"Generating image for {prompt}..." # ) # yield message # ret = await GenerateImage( # llm=llm, model=model, prompt=prompt # ) # logger.info("GenerateImage returning", ret) # case "DateTime": # tz = arguments.get("timezone") # ret = DateTime(tz) # case "WeatherForecast": # city = arguments.get("city") # state = arguments.get("state") # message.content = ( # f"Fetching weather data for {city}, {state}..." # ) # yield message # ret = WeatherForecast(city, state) # case _: # logger.error(f"Requested tool {tool} does not exist") # ret = None # # Build response for this tool # tool_response = { # "role": "tool", # "content": json.dumps(ret), # "name": tool_call.function.name, # } # tool_metadata["tool_calls"].append(tool_response) # if len(tool_metadata["tool_calls"]) == 0: # message.status = "done" # yield message # return # message_dict = LLMMessage( # role=tool_message.get("role", "assistant"), # content=tool_message.get("content", ""), # tool_calls=[ # { # "function": { # "name": tc["function"]["name"], # "arguments": tc["function"]["arguments"], # } # } # for tc in tool_message.tool_calls # ], # ) # messages.append(message_dict) # messages.extend(tool_metadata["tool_calls"]) # message.status = "thinking" # message.content = "Incorporating tool results into response..." # yield message # # Decrease creativity when processing tool call requests # message.content = "" # start_time = time.perf_counter() # for response in llm.chat( # model=model, # messages=messages, # options={ # **message.metadata.options, # }, # stream=True, # ): # # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}") # message.status = "streaming" # message.chunk = response.message.content # message.content += message.chunk # if not response.done: # yield message # if response.done: # self.collect_metrics(response) # message.metadata.eval_count += response.eval_count # message.metadata.eval_duration += response.eval_duration # message.metadata.prompt_eval_count += response.prompt_eval_count # message.metadata.prompt_eval_duration += response.prompt_eval_duration # self.context_tokens = ( # response.prompt_eval_count + response.eval_count # ) # message.status = "done" # yield message # end_time = time.perf_counter() # message.metadata.timers["llm_with_tools"] = end_time - start_time # return def collect_metrics(self, response): self.metrics.tokens_prompt.labels(agent=self.agent_type).inc( response.prompt_eval_count ) self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count) async def generate( self, llm: Any, model: str, query: ChatQuery, session_id: str, user_id: str, temperature=0.7 ) -> AsyncGenerator[ChatMessage | ChatMessageBase, None]: logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") user_message = ChatMessageUser( session_id=session_id, tunables=query.tunables, type=ChatMessageType.USER, status=ChatStatusType.DONE, sender=ChatSenderType.USER, content=query.prompt.strip(), timestamp=datetime.now(UTC) ) chat_message = ChatMessage( session_id=session_id, tunables=query.tunables, status=ChatStatusType.INITIALIZING, type=ChatMessageType.PREPARING, sender=ChatSenderType.ASSISTANT, content="", timestamp=datetime.now(UTC) ) self.metrics.generate_count.labels(agent=self.agent_type).inc() with self.metrics.generate_duration.labels(agent=self.agent_type).time(): # Create a pruned down message list based purely on the prompt and responses, # discarding the full preamble generated by prepare_message messages: List[LLMMessage] = [ LLMMessage(role="system", content=self.system_prompt) ] messages.extend([ LLMMessage(role=m.sender, content=m.content.strip()) for m in self.conversation ]) # Only the actual user query is provided with the full context message messages.append( LLMMessage(role=user_message.sender, content=user_message.content.strip()) ) # message.messages = messages chat_message.metadata = ChatMessageMetaData() chat_message.metadata.options = ChatOptions( seed=8911, num_ctx=self.context_size, temperature=temperature, # Higher temperature to encourage tool usage ) # Create a dict for storing various timing stats chat_message.metadata.timers = {} # use_tools = message.tunables.enable_tools and len(self.context.tools) > 0 # message.metadata.tools = { # "available": llm_tools(self.context.tools), # "used": False, # } # tool_metadata = message.metadata.tools # if use_tools: # message.status = "thinking" # message.content = f"Performing tool analysis step 1/2..." # yield message # logger.info("Checking for LLM tool usage") # start_time = time.perf_counter() # # Tools are enabled and available, so query the LLM with a short context of messages # # in case the LLM did something like ask "Do you want me to run the tool?" and the # # user said "Yes" -- need to keep the context in the thread. # tool_metadata["messages"] = ( # [{"role": "system", "content": self.system_prompt}] + messages[-6:] # if len(messages) >= 7 # else messages # ) # response = llm.chat( # model=model, # messages=tool_metadata["messages"], # tools=tool_metadata["available"], # options={ # **message.metadata.options, # }, # stream=False, # No need to stream the probe # ) # self.collect_metrics(response) # end_time = time.perf_counter() # message.metadata.timers["tool_check"] = end_time - start_time # if not response.message.tool_calls: # logger.info("LLM indicates tools will not be used") # # The LLM will not use tools, so disable use_tools so we can stream the full response # use_tools = False # else: # tool_metadata["attempted"] = response.message.tool_calls # if use_tools: # logger.info("LLM indicates tools will be used") # # Tools are enabled and available and the LLM indicated it will use them # message.content = ( # f"Performing tool analysis step 2/2 (tool use suspected)..." # ) # yield message # logger.info(f"Performing LLM call with tools") # start_time = time.perf_counter() # response = llm.chat( # model=model, # messages=tool_metadata["messages"], # messages, # tools=tool_metadata["available"], # options={ # **message.metadata.options, # }, # stream=False, # ) # self.collect_metrics(response) # end_time = time.perf_counter() # message.metadata.timers["non_streaming"] = end_time - start_time # if not response: # message.status = "error" # message.content = "No response from LLM." # yield message # return # if response.message.tool_calls: # tool_metadata["used"] = response.message.tool_calls # # Process all yielded items from the handler # start_time = time.perf_counter() # async for message in self.process_tool_calls( # llm=llm, # model=model, # message=message, # tool_message=response.message, # messages=messages, # ): # if message.status == "error": # yield message # return # yield message # end_time = time.perf_counter() # message.metadata.timers["process_tool_calls"] = end_time - start_time # message.status = "done" # return # logger.info("LLM indicated tools will be used, and then they weren't") # message.content = response.message.content # message.status = "done" # yield message # return # not use_tools chat_message.type = ChatMessageType.THINKING chat_message.content = f"Generating response..." yield chat_message # Reset the response for streaming chat_message.content = "" start_time = time.perf_counter() chat_message.type = ChatMessageType.GENERATING chat_message.status = ChatStatusType.STREAMING for response in llm.chat( model=model, messages=messages, options={ **chat_message.metadata.model_dump(exclude_unset=True), }, stream=True, ): if not response: chat_message.status = ChatStatusType.ERROR chat_message.content = "No response from LLM." yield chat_message return chat_message.content += response.message.content if not response.done: chat_chunk = model_cast.cast_to_model(ChatMessageBase, chat_message) chat_chunk.content = response.message.content yield chat_message continue if response.done: self.collect_metrics(response) chat_message.metadata.eval_count += response.eval_count chat_message.metadata.eval_duration += response.eval_duration chat_message.metadata.prompt_eval_count += response.prompt_eval_count chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration self.context_tokens = ( response.prompt_eval_count + response.eval_count ) chat_message.type = ChatMessageType.RESPONSE chat_message.status = ChatStatusType.DONE yield chat_message end_time = time.perf_counter() chat_message.metadata.timers["streamed"] = end_time - start_time # Add the user and chat messages to the conversation self.conversation.append(user_message) self.conversation.append(chat_message) return # async def process_message( # self, llm: Any, model: str, message: Message # ) -> AsyncGenerator[Message, None]: # logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") # self.metrics.process_count.labels(agent=self.agent_type).inc() # with self.metrics.process_duration.labels(agent=self.agent_type).time(): # if not self.context: # raise ValueError("Context is not set for this agent.") # logger.info( # "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time" # ) # spinner: List[str] = ["\\", "|", "/", "-"] # tick: int = 0 # while self.context.processing: # message.status = "waiting" # message.content = ( # f"Busy processing another request. Please wait. {spinner[tick]}" # ) # tick = (tick + 1) % len(spinner) # yield message # await asyncio.sleep(1) # Allow the event loop to process the write # self.context.processing = True # message.system_prompt = ( # f"<|system|>\n{self.system_prompt.strip()}\n" # ) # message.context_prompt = "" # for p in message.preamble.keys(): # message.context_prompt += ( # f"\n<|{p}|>\n{message.preamble[p].strip()}\n\n\n" # ) # message.context_prompt += f"{message.prompt}" # # Estimate token length of new messages # message.content = f"Optimizing context..." # message.status = "thinking" # yield message # message.context_size = self.set_optimal_context_size( # llm, model, prompt=message.context_prompt # ) # message.content = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..." # message.status = "thinking" # yield message # async for message in self.generate_llm_response( # llm=llm, model=model, message=message # ): # # logger.info(f"LLM: {message.status} - {f'...{message.content[-20:]}' if len(message.content) > 20 else message.content}") # if message.status == "error": # yield message # self.context.processing = False # return # yield message # # Done processing, add message to conversation # message.status = "done" # self.conversation.add(message) # self.context.processing = False # return # Register the base agent agent_registry.register(Agent._agent_type, Agent)