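"""Chat agent implementation.

Prepares prompts with optional RAG and resume preamble, probes the LLM for
tool usage, dispatches requested tool calls, and streams responses back as
incremental Message updates.
"""
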
from __future__ import annotations

from pydantic import BaseModel, model_validator, PrivateAttr
from typing import (
    Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator,
    TYPE_CHECKING, Optional, ClassVar, Any,
)
from typing_extensions import Annotated
from abc import ABC, abstractmethod
import logging

from .base import Agent, registry
from ..conversation import Conversation
from ..message import Message
from .. import defines
from .. import tools as Tools
from ollama import ChatResponse

import json
import time
import inspect


class Chat(Agent, ABC):
    """
    Base class for chat agents.

    Defines the message preparation, tool-call handling, and LLM response
    streaming behavior shared by chat-style agents.
    """

    agent_type: Literal["chat"] = "chat"
    _agent_type: ClassVar[str] = agent_type  # Used to register this class with the agent registry

    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
        """
        Prepare message with context information in message.preamble.
        """
        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")

        if not self.context:
            raise ValueError("Context is not set for this agent.")

        # Generate RAG content if enabled, based on the message content
        rag_context = ""
        if message.enable_rag:
            # Gather RAG results, yielding each result as it becomes available
            for message in self.context.generate_rag_results(message):
                logging.info(f"RAG: {message.status} - {message.response}")
                if message.status == "error":
                    yield message
                    return
                if message.status != "done":
                    yield message

            if "rag" in message.metadata and message.metadata["rag"]:
                for rag in message.metadata["rag"]:
                    for doc in rag["documents"]:
                        rag_context += f"{doc}\n"

        message.preamble = {}

        if rag_context:
            message.preamble["context"] = rag_context

        if self.context.user_resume:
            message.preamble["resume"] = self.context.user_resume

        if message.preamble:
            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
            preamble_types_AND = " and ".join(preamble_types)
            preamble_types_OR = " or ".join(preamble_types)
            message.preamble["rules"] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and referring to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge, or use any available tools.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
            message.preamble["question"] = "Respond to:"

        message.system_prompt = self.system_prompt
        message.status = "done"
        yield message
        return

    async def process_tool_calls(
        self,
        llm: Any,
        model: str,
        message: Message,
        tool_message: Any,
        messages: List[Any],
    ) -> AsyncGenerator[Message, None]:
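        """
        Execute the tool calls requested by the LLM, then stream the follow-up response.

        Dispatches each requested tool to the matching helper in Tools, records
        the results as "tool" role messages, appends them to the running message
        list, and re-queries the LLM (streaming) so it can incorporate the tool
        output into its final answer.
        """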
        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")

        if not self.context:
            raise ValueError("Context is not set for this agent.")
        if not message.metadata["tools"]:
            raise ValueError("tools field not initialized")

        tool_metadata = message.metadata["tools"]
        tool_metadata["messages"] = messages
        tool_metadata["tool_calls"] = []

        message.status = "tooling"

        for i, tool_call in enumerate(tool_message.tool_calls):
            arguments = tool_call.function.arguments
            tool = tool_call.function.name

            # Yield status update before processing each tool
            message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
            yield message
            logging.info(f"LLM - {message.response}")

            # Process the tool based on its type
            match tool:
                case "TickerValue":
                    ticker = arguments.get("ticker")
                    if not ticker:
                        ret = None
                    else:
                        ret = Tools.TickerValue(ticker)

                case "AnalyzeSite":
                    url = arguments.get("url")
                    question = arguments.get("question", "what is the summary of this content?")

                    # Additional status update for long-running operations
                    message.response = f"Retrieving and summarizing content from {url}..."
                    yield message
                    ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question)

                case "DateTime":
                    tz = arguments.get("timezone")
                    ret = Tools.DateTime(tz)

                case "WeatherForecast":
                    city = arguments.get("city")
                    state = arguments.get("state")

                    message.response = f"Fetching weather data for {city}, {state}..."
                    yield message
                    ret = Tools.WeatherForecast(city, state)

                case _:
                    ret = None

            # Build response for this tool
            tool_response = {
                "role": "tool",
                "content": json.dumps(ret),
                "name": tool_call.function.name,
            }
            tool_metadata["tool_calls"].append(tool_response)

        if len(tool_metadata["tool_calls"]) == 0:
            message.status = "done"
            yield message
            return

        message_dict = {
            "role": tool_message.get("role", "assistant"),
            "content": tool_message.get("content", ""),
            "tool_calls": [
                {
                    "function": {
                        "name": tc["function"]["name"],
                        "arguments": tc["function"]["arguments"],
                    }
                }
                for tc in tool_message.tool_calls
            ],
        }

        messages.append(message_dict)
        messages.extend(tool_metadata["tool_calls"])

        message.status = "thinking"

        # Decrease creativity when processing tool call requests
        message.response = ""
        start_time = time.perf_counter()
        for response in llm.chat(
            model=model,
            messages=messages,
            stream=True,
            options={
                **message.metadata["options"],
                # "temperature": 0.5,
            },
        ):
            # logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
            message.status = "streaming"
            message.response += response.message.content
            if not response.done:
                yield message
            if response.done:
                message.metadata["eval_count"] += response.eval_count
                message.metadata["eval_duration"] += response.eval_duration
                message.metadata["prompt_eval_count"] += response.prompt_eval_count
                message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
                self.context_tokens = response.prompt_eval_count + response.eval_count
                message.status = "done"
                yield message

        end_time = time.perf_counter()
        message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
        return

    async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
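        """
        Generate and stream the LLM response for a prepared message.

        Builds the chat history from the conversation, optionally probes the
        model once (non-streaming) to see whether it wants to call tools, hands
        off to process_tool_calls when it does, and otherwise streams the plain
        chat completion as incremental "streaming" updates followed by "done".
        """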
        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")

        if not self.context:
            raise ValueError("Context is not set for this agent.")

        messages = [{"role": "system", "content": message.system_prompt}]
        messages.extend([
            item for m in self.conversation.messages
            for item in [
                {"role": "user", "content": m.prompt.strip()},
                {"role": "assistant", "content": m.response.strip()},
            ]
        ])
        messages.append({
            "role": "user",
            "content": message.context_prompt.strip(),
        })
        message.metadata["messages"] = messages
        message.metadata["options"] = {
            "seed": 8911,
            "num_ctx": self.context_size,
            # "temperature": 0.9,  # Higher temperature to encourage tool usage
        }

        message.metadata["timers"] = {}

        use_tools = message.enable_tools and len(self.context.tools) > 0
        message.metadata["tools"] = {
            "available": Tools.llm_tools(self.context.tools),
            "used": False,
        }
        tool_metadata = message.metadata["tools"]

        if use_tools:
            message.status = "thinking"
            message.response = "Performing tool analysis step 1/2..."
            yield message

            logging.info("Checking for LLM tool usage")
            start_time = time.perf_counter()
            # Tools are enabled and available, so query the LLM with a short
            # token target to see if it will use the tools
            tool_metadata["messages"] = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": message.prompt},
            ]
            response = llm.chat(
                model=model,
                messages=tool_metadata["messages"],
                tools=tool_metadata["available"],
                options={
                    **message.metadata["options"],
                    # "num_predict": 1024,  # "Low" token limit to cut off after tool call
                },
                stream=False,  # No need to stream the probe
            )
            end_time = time.perf_counter()
            message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
            if not response.message.tool_calls:
                logging.info("LLM indicates tools will not be used")
                # The LLM will not use tools, so disable use_tools so we can stream the full response
                use_tools = False

        if use_tools:
            logging.info("LLM indicates tools will be used")

            # Tools are enabled and available and the LLM indicated it will use them
            tool_metadata["attempted"] = response.message.tool_calls
            message.response = "Performing tool analysis step 2/2 (tool use suspected)..."
            yield message

            logging.info("Performing LLM call with tools")
            start_time = time.perf_counter()
            response = llm.chat(
                model=model,
                messages=tool_metadata["messages"],  # messages,
                tools=tool_metadata["available"],
                options={
                    **message.metadata["options"],
                },
                stream=False,
            )
            end_time = time.perf_counter()
            message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"

            if not response:
                message.status = "error"
                message.response = "No response from LLM."
                yield message
                return

            if response.message.tool_calls:
                tool_metadata["used"] = response.message.tool_calls
                # Process all yielded items from the handler
                start_time = time.perf_counter()
                async for message in self.process_tool_calls(
                    llm=llm, model=model, message=message,
                    tool_message=response.message, messages=messages,
                ):
                    if message.status == "error":
                        yield message
                        return
                    yield message
                end_time = time.perf_counter()
                message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
                message.status = "done"
                return

            logging.info("LLM indicated tools would be used, but no tool calls were produced")
            message.response = response.message.content
            message.status = "done"
            yield message
            return

        # Tools are not in use: stream the standard chat response
        yield message
        # Reset the response for streaming
        message.response = ""
        start_time = time.perf_counter()
        for response in llm.chat(
            model=model,
            messages=messages,
            options={
                **message.metadata["options"],
            },
            stream=True,
        ):
            if not response:
                message.status = "error"
                message.response = "No response from LLM."
                yield message
                return
            message.status = "streaming"
            message.response += response.message.content
            if not response.done:
                yield message
            if response.done:
                message.metadata["eval_count"] += response.eval_count
                message.metadata["eval_duration"] += response.eval_duration
                message.metadata["prompt_eval_count"] += response.prompt_eval_count
                message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
                self.context_tokens = response.prompt_eval_count + response.eval_count
                message.status = "done"
                yield message

        end_time = time.perf_counter()
        message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
        return

    async def process_message(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
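        """
        Top-level entry point: assemble the prompt, run the LLM, and stream results.

        Rejects the request if the shared context is already processing,
        composes the context prompt from the message preamble, sizes the
        context window, delegates to generate_llm_response, and records the
        finished message in the conversation.
        """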
        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")

        if not self.context:
            raise ValueError("Context is not set for this agent.")

        if self.context.processing:
            logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
            message.status = "error"
            message.response = "Busy processing another request."
            yield message
            return

        self.context.processing = True

        message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n"
        message.context_prompt = ""
        for p in message.preamble.keys():
            message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n"
        message.context_prompt += message.prompt

        # Estimate token length of new messages
        message.response = "Optimizing context..."
        message.status = "thinking"
        yield message
        message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)

        message.response = f"Processing {'RAG augmented ' if message.metadata.get('rag') else ''}query..."
        message.status = "thinking"
        yield message

        async for message in self.generate_llm_response(llm, model, message):
            # logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
            if message.status == "error":
                yield message
                self.context.processing = False
                return
            yield message

        # Done processing, add message to conversation
        message.status = "done"
        self.conversation.add_message(message)
        self.context.processing = False
        return


# Register the base agent
registry.register(Chat._agent_type, Chat)
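
# Usage sketch (illustrative only, not executed by this module). The agent
# instance, the ollama client `llm`, the model name, and a prepared `message`
# are assumptions supplied by the surrounding application; the registry lookup
# shown here is hypothetical and only illustrates how the streaming generators
# above are meant to be consumed:
#
#     agent = registry.get("chat")  # hypothetical lookup of the registered Chat class
#     async for update in agent.prepare_message(message):
#         ...                       # status updates while RAG context is gathered
#     async for update in agent.process_message(llm=llm, model=model, message=message):
#         if update.status == "error":
#             break
#         render(update.status, update.response)  # e.g., push partial output to a UI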