from __future__ import annotations

from abc import ABC, abstractmethod
from pydantic import BaseModel, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from typing_extensions import Annotated
import logging

from .base import Agent, registry
from ..conversation import Conversation
from ..message import Message


class Chat(Agent, ABC):
    """
    Chat agent.

    Prepares messages with RAG context and a structured preamble, then
    streams LLM responses back to the caller.
    """
    agent_type: Literal["chat"] = "chat"
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    # Seed text used by the content-seed helpers below; pydantic models require
    # PrivateAttr for underscore-prefixed instance attributes.
    _content_seed: str = PrivateAttr(default="")

    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
        """
        Prepare message with context information in message.preamble
        """
        if not self.context:
            raise ValueError("Context is not set for this agent.")

        # Generate RAG content if enabled, based on the content
        rag_context = ""
        if not message.disable_rag:
            # Gather RAG results, yielding each result as it becomes
            # available (assumes generate_rag_results is an async generator).
            async for value in self.context.generate_rag_results(message):
                logging.info(f"RAG: {value.status} - {value.response}")
                if value.status != "done":
                    yield value
                if value.status == "error":
                    message.status = "error"
                    message.response = value.response
                    yield message
                    return

        if "rag" in message.metadata and message.metadata["rag"]:
            for rag in message.metadata["rag"]:
                for doc in rag["documents"]:
                    rag_context += f"{doc}\n"

        message.preamble = {}

        if rag_context:
            message.preamble["context"] = rag_context

        if self.context.user_resume:
            message.preamble["resume"] = self.context.user_resume

        if message.preamble:
            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
            preamble_types_AND = " and ".join(preamble_types)
            preamble_types_OR = " or ".join(preamble_types)
            message.preamble["rules"] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and referring to it in natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
            message.preamble["question"] = "Use that information to respond to:"
        else:
            message.preamble["question"] = "Respond to:"

        message.system_prompt = self.system_prompt
        message.status = "done"
        yield message
        return
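    # For illustration: with RAG hits and a stored resume, process_message
    # (below) renders the preamble built here into message.full_content
    # roughly as follows, with each section's content following its tag
    # inline and sections separated by blank lines:
    #
    #   <|context|>...retrieved RAG documents...
    #
    #   <|resume|>...self.context.user_resume...
    #
    #   <|rules|>- Answer the question based on the information provided in ...
    #
    #   <|question|>Use that information to respond to: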
    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
        if self.context.processing:
            logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
            message.status = "error"
            message.response = "Busy processing another request."
            yield message
            return

        self.context.processing = True
        try:
            # Seed the chat history from the prompt prepared upstream; exactly
            # how history is assembled before this call is an assumption here.
            messages = [
                {"role": "system", "content": message.system_prompt},
                {"role": "user", "content": message.full_content},
            ]

            response = None
            for value in self.llm.chat(
                model=self.model,
                messages=messages,
                # tools=llm_tools(context.tools) if message.enable_tools else None,
                options={"num_ctx": message.ctx_size},
            ):
                logging.info(f"LLM: {value.status} - {value.response}")
                if value.status != "done":
                    message.status = value.status
                    message.response = value.response
                    yield message
                if value.status == "error":
                    return
                response = value

            message.metadata["eval_count"] += response["eval_count"]
            message.metadata["eval_duration"] += response["eval_duration"]
            message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
            message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
            self.context_tokens = response["prompt_eval_count"] + response["eval_count"]

            tools_used = []

            message.status = "processing"
            message.response = "Initial response received..."
            yield message

            if "tool_calls" in response.get("message", {}):
                message.response = "Processing tool calls..."
                yield message

                tool_message = response["message"]
                tool_result = None

                # Process all yielded items from the handler
                async for item in self.handle_tool_calls(tool_message):
                    if isinstance(item, tuple) and len(item) == 2:
                        # This is the final result tuple (tool_result, tools_used)
                        tool_result, tools_used = item
                    else:
                        # This is a status update, forward it
                        yield item

                message_dict = {
                    "role": tool_message.get("role", "assistant"),
                    "content": tool_message.get("content", "")
                }

                if "tool_calls" in tool_message:
                    message_dict["tool_calls"] = [
                        {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
                        for tc in tool_message["tool_calls"]
                    ]

                pre_add_index = len(messages)
                messages.append(message_dict)

                if isinstance(tool_result, list):
                    messages.extend(tool_result)
                elif tool_result:
                    messages.append(tool_result)

                message.metadata["tools"] = tools_used

                # Estimate token length of new messages
                ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=messages[pre_add_index:])

                message.response = "Generating final response..."
                message.ctx_size = ctx_size
                yield message

                # Decrease creativity when processing tool call requests
                response = self.llm.chat(model=self.model, messages=messages, stream=False,
                                         options={"num_ctx": ctx_size})  # , "temperature": 0.5 })

                message.metadata["eval_count"] += response["eval_count"]
                message.metadata["eval_duration"] += response["eval_duration"]
                message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
                message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
                self.context_tokens = response["prompt_eval_count"] + response["eval_count"]

            reply = response["message"]["content"]
            message.response = reply
            message.metadata["origin"] = self.agent_type

            # final_message = {"role": "assistant", "content": reply }
            # # history is provided to the LLM and should not have additional metadata
            # llm_history.append(final_message)

            # user_history is provided to the REST API and does not include CONTEXT
            # It does include metadata
            # final_message["metadata"] = message.metadata
            # user_history.append({**final_message, "origin": message.metadata["origin"]})

            # Return the final message to the REST API with metadata attached
            message.status = "done"
            yield message
        finally:
            # Always release the busy flag, including on the error paths above.
            self.context.processing = False
        return
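    # `handle_tool_calls` is referenced above but not defined in this section.
    # From the call site it is expected to be an async generator that yields
    # interim status updates and, as its last item, a 2-tuple of
    # (tool_result, tools_used), where tool_result is a chat-message dict or a
    # list of them. A minimal conforming sketch; the dict shape of the interim
    # updates and the call_tool helper are assumptions, not this project's API:
    #
    #   async def handle_tool_calls(self, tool_message):
    #       tools_used, results = [], []
    #       for tc in tool_message.get("tool_calls", []):
    #           name = tc["function"]["name"]
    #           yield {"status": "processing", "message": f"Calling {name}..."}
    #           result = await self.call_tool(name, tc["function"]["arguments"])
    #           tools_used.append(name)
    #           results.append({"role": "tool", "content": str(result)})
    #       yield (results, tools_used)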
    async def process_message(self, message: Message) -> AsyncGenerator[Message, None]:
        message.full_content = ""
        for i, p in enumerate(message.preamble.keys()):
            # Join sections with blank lines; the first section gets no leading separator.
            message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"

        # Estimate token length of new messages
        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)

        message.response = f"Processing {'RAG augmented ' if message.metadata.get('rag') else ''}query..."
        message.status = "thinking"
        yield message

        async for value in self.generate_llm_response(message):
            logging.info(f"LLM: {value.status} - {value.response}")
            yield value
            if value.status == "error":
                return

    def get_and_reset_content_seed(self):
        """Return the current content seed and clear it."""
        tmp = self._content_seed
        self._content_seed = ""
        return tmp

    def set_content_seed(self, content: str) -> None:
        """Set the content seed for the agent."""
        self._content_seed = content

    def get_content_seed(self) -> str:
        """Get the content seed for the agent."""
        return self._content_seed

    @classmethod
    def valid_agent_types(cls) -> set[str]:
        """Return the set of valid agent_type values."""
        return set(get_args(cls.__annotations__["agent_type"]))


# Register the Chat agent
registry.register(Chat._agent_type, Chat)
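# Usage sketch (illustrative only): the two async generators are designed to be
# chained, with prepare_message building the preamble that process_message
# renders and sends to the LLM. Assumes a fully wired agent (context, llm,
# model) and a caller-constructed Message:
#
#   async def run(agent: Chat, message: Message) -> Message:
#       async for update in agent.prepare_message(message):
#           if update.status == "error":
#               return update
#       async for update in agent.process_message(message):
#           if update.status == "error":
#               return update
#       return message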