
from __future__ import annotations

import inspect
import json
import logging
import time
from abc import ABC, abstractmethod
from typing import (Literal, TypeAlias, get_args, List, Generator, Iterator,
                    AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any)

from ollama import ChatResponse
from pydantic import BaseModel, model_validator, PrivateAttr
from typing_extensions import Annotated

from .base import Agent, registry
from ..conversation import Conversation
from ..message import Message
from .. import defines
from .. import tools as Tools


class Chat(Agent, ABC):
    """
    Base class for chat agents.

    Defines the message preparation, tool-call handling, and LLM response
    flow shared by the agent types built on top of it.
    """
    agent_type: Literal["chat"] = "chat"
    _agent_type: ClassVar[str] = agent_type  # Used when registering this agent type

    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
        """
        Prepare a message for the LLM by populating message.preamble.

        The preamble may include "context" (RAG results), "resume", "rules",
        and "question" sections; intermediate status messages are yielded
        while RAG results are gathered.
        """
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
# Generate RAG content if enabled, based on the content
rag_context = ""
if message.enable_rag:
# Gather RAG results, yielding each result
# as it becomes available
for message in self.context.generate_rag_results(message):
logging.info(f"RAG: {message.status} - {message.response}")
if message.status == "error":
yield message
return
if message.status != "done":
yield message
if "rag" in message.metadata and message.metadata["rag"]:
for rag in message.metadata["rag"]:
for doc in rag["documents"]:
rag_context += f"{doc}\n"
message.preamble = {}
if rag_context:
message.preamble["context"] = rag_context
if self.context.user_resume:
message.preamble["resume"] = self.context.user_resume
if message.preamble:
preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
preamble_types_AND = " and ".join(preamble_types)
preamble_types_OR = " or ".join(preamble_types)
message.preamble["rules"] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
message.preamble["question"] = "Respond to:"
message.system_prompt = self.system_prompt
message.status = "done"
yield message
return

    async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[Any]) -> AsyncGenerator[Message, None]:
        """
        Execute the tool calls requested by the LLM, then stream a follow-up
        LLM response that incorporates the tool results.
        """
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
if not message.metadata["tools"]:
raise ValueError("tools field not initialized")
tool_metadata = message.metadata["tools"]
tool_metadata["messages"] = messages
tool_metadata["tool_calls"] = []
message.status = "tooling"
for i, tool_call in enumerate(tool_message.tool_calls):
arguments = tool_call.function.arguments
tool = tool_call.function.name
# Yield status update before processing each tool
message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
yield message
logging.info(f"LLM - {message.response}")
# Process the tool based on its type
match tool:
case "TickerValue":
ticker = arguments.get("ticker")
if not ticker:
ret = None
else:
ret = Tools.TickerValue(ticker)
case "AnalyzeSite":
url = arguments.get("url")
question = arguments.get("question", "what is the summary of this content?")
# Additional status update for long-running operations
message.response = f"Retrieving and summarizing content from {url}..."
yield message
ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question)
case "DateTime":
tz = arguments.get("timezone")
ret = Tools.DateTime(tz)
case "WeatherForecast":
city = arguments.get("city")
state = arguments.get("state")
message.response = f"Fetching weather data for {city}, {state}..."
yield message
ret = Tools.WeatherForecast(city, state)
case _:
ret = None
# Build response for this tool
tool_response = {
"role": "tool",
"content": json.dumps(ret),
"name": tool_call.function.name
}
tool_metadata["tool_calls"].append(tool_response)
if len(tool_metadata["tool_calls"]) == 0:
message.status = "done"
yield message
return
        # Echo the assistant's tool-call request into the history so the
        # follow-up chat call can see which tools it asked for.
        message_dict = {
            "role": getattr(tool_message, "role", "assistant"),
            "content": tool_message.content or "",
            "tool_calls": [
                {
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    }
                }
                for tc in tool_message.tool_calls
            ],
        }
messages.append(message_dict)
messages.extend(tool_metadata["tool_calls"])
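        # Re-query the LLM with the tool results appended so it can compose
        # the final user-facing answer.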
message.status = "thinking"
# Decrease creativity when processing tool call requests
message.response = ""
start_time = time.perf_counter()
for response in llm.chat(
model=model,
messages=messages,
stream=True,
options={
**message.metadata["options"],
# "temperature": 0.5,
}
):
# logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
message.status = "streaming"
message.response += response.message.content
if not response.done:
yield message
if response.done:
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.status = "done"
yield message
end_time = time.perf_counter()
message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
return

    async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
        """
        Generate the LLM response for a prepared message, probing for tool
        use first when tools are enabled and available.
        """
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
messages = [ { "role": "system", "content": message.system_prompt } ]
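        # Flatten prior conversation turns into alternating user/assistant entries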
messages.extend([
item for m in self.conversation.messages
for item in [
{"role": "user", "content": m.prompt.strip()},
{"role": "assistant", "content": m.response.strip()}
]
])
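        # Append the current context-augmented prompt as the newest user turn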
messages.append({
"role": "user",
"content": message.context_prompt.strip(),
})
message.metadata["messages"] = messages
message.metadata["options"]={
"seed": 8911,
"num_ctx": self.context_size,
#"temperature": 0.9, # Higher temperature to encourage tool usage
}
message.metadata["timers"] = {}
use_tools = message.enable_tools and len(self.context.tools) > 0
message.metadata["tools"] = {
"available": Tools.llm_tools(self.context.tools),
"used": False
}
tool_metadata = message.metadata["tools"]
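        # Two-phase tool strategy: first a cheap non-streaming probe to see whether
        # the model wants to call tools at all, then a full tool-enabled call only
        # if the probe requested one.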
if use_tools:
message.status = "thinking"
message.response = f"Performing tool analysis step 1/2..."
yield message
logging.info("Checking for LLM tool usage")
start_time = time.perf_counter()
# Tools are enabled and available, so query the LLM with a short token target to see if it will
# use the tools
tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}]
response = llm.chat(
model=model,
messages=tool_metadata["messages"],
tools=tool_metadata["available"],
options={
**message.metadata["options"],
#"num_predict": 1024, # "Low" token limit to cut off after tool call
},
stream=False # No need to stream the probe
)
end_time = time.perf_counter()
message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
if not response.message.tool_calls:
logging.info("LLM indicates tools will not be used")
# The LLM will not use tools, so disable use_tools so we can stream the full response
use_tools = False
if use_tools:
logging.info("LLM indicates tools will be used")
# Tools are enabled and available and the LLM indicated it will use them
tool_metadata["attempted"] = response.message.tool_calls
message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
yield message
logging.info(f"Performing LLM call with tools")
start_time = time.perf_counter()
response = llm.chat(
model=model,
messages=tool_metadata["messages"], # messages,
tools=tool_metadata["available"],
options={
**message.metadata["options"],
},
stream=False
)
end_time = time.perf_counter()
message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
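            # The tool-enabled call confirmed tool use; execute the tools and
            # stream the follow-up response.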
if response.message.tool_calls:
tool_metadata["used"] = response.message.tool_calls
# Process all yielded items from the handler
start_time = time.perf_counter()
async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
if message.status == "error":
yield message
return
yield message
end_time = time.perf_counter()
message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
message.status = "done"
return
logging.info("LLM indicated tools will be used, and then they weren't")
message.response = response.message.content
message.status = "done"
yield message
return
        # Tools were not used; stream the full response
        yield message
        # Reset the response accumulator for streaming
        message.response = ""
start_time = time.perf_counter()
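        # No tools involved: stream the answer token-by-token, yielding partial
        # updates as they arrive.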
for response in llm.chat(
model=model,
messages=messages,
options={
**message.metadata["options"],
},
stream=True,
):
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
message.status = "streaming"
message.response += response.message.content
if not response.done:
yield message
if response.done:
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.status = "done"
yield message
end_time = time.perf_counter()
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
return

    async def process_message(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
        """
        Top-level entry point: guard against concurrent use, build the final
        context prompt from the preamble, and stream the LLM response.
        """
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
if self.context.processing:
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
message.status = "error"
message.response = "Busy processing another request."
yield message
return
self.context.processing = True
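        # Assemble the final prompt: system prompt, each <|section|> preamble
        # block, then the user's prompt itself.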
message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n"
message.context_prompt = ""
        for section, content in message.preamble.items():
            message.context_prompt += f"\n<|{section}|>\n{content.strip()}\n"
        message.context_prompt += message.prompt
# Estimate token length of new messages
message.response = f"Optimizing context..."
message.status = "thinking"
yield message
message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.status = "thinking"
yield message
async for message in self.generate_llm_response(llm, model, message):
# logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
if message.status == "error":
yield message
self.context.processing = False
return
yield message
# Done processing, add message to conversation
message.status = "done"
self.conversation.add_message(message)
self.context.processing = False
return
# Register the base agent
registry.register(Chat._agent_type, Chat)
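
# A minimal usage sketch (illustrative only): the Context wiring and Message
# constructor below are assumptions about the surrounding package, not its
# confirmed API, and the model name is arbitrary.
#
#     import asyncio
#     from ollama import Client
#
#     async def main() -> None:
#         llm = Client()                    # local Ollama instance
#         agent = Chat(context=context)     # hypothetical: context built elsewhere
#         message = Message(prompt="What's the weather in Portland, OR?")
#         async for update in agent.prepare_message(message):
#             pass                          # surface status updates to a UI
#         async for update in agent.process_message(llm, "llama3.2", message):
#             print(update.status, update.response)
#
#     asyncio.run(main())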