From 202060f5b56cc965515e176b4d061043a7faa800 Mon Sep 17 00:00:00 2001
From: James Ketrenos
Date: Fri, 2 May 2025 13:57:09 -0700
Subject: [PATCH] Tools are working and shared context is in use across all agents

---
 frontend/src/ChatBubble.tsx | 1 -
 frontend/src/Conversation.tsx | 21 ++----
 frontend/src/Message.tsx | 123 ++++++++++++++--------------
 src/server.py | 43 +++++++++---
 src/utils/agents/base.py | 39 ++++++++++-
 src/utils/agents/chat.py | 54 ++++++++-------
 src/utils/context.py | 5 --
 src/utils/message.py | 6 +-
 8 files changed, 156 insertions(+), 136 deletions(-)

diff --git a/frontend/src/ChatBubble.tsx b/frontend/src/ChatBubble.tsx
index a5276b4..d0d613d 100644
--- a/frontend/src/ChatBubble.tsx
+++ b/frontend/src/ChatBubble.tsx
@@ -162,7 +162,6 @@ function ChatBubble(props: ChatBubbleProps) { ); } - console.log(role); return ( {icons[role] !== undefined && icons[role]}
diff --git a/frontend/src/Conversation.tsx b/frontend/src/Conversation.tsx
index 0d5100f..e3350e4 100644
--- a/frontend/src/Conversation.tsx
+++ b/frontend/src/Conversation.tsx
@@ -201,17 +201,13 @@ const Conversation = forwardRef(({ // isProcessing?: boolean, // metadata?: MessageMetaData // }; - setConversation(backstoryMessages.flatMap((message: BackstoryMessage) => [{ + setConversation(backstoryMessages.flatMap((backstoryMessage: BackstoryMessage) => [{ role: 'user', - content: message.prompt || "", + content: backstoryMessage.prompt || "", }, { + ...backstoryMessage, role: 'assistant', - prompt: message.prompt || "", - preamble: message.preamble || {}, - full_content: message.full_content || "", - content: message.response || "", - metadata: message.metadata, - actions: message.actions, + content: backstoryMessage.response || "", }] as MessageList)); setNoInteractions(false); }
@@ -400,17 +396,10 @@ const Conversation = forwardRef(({ const backstoryMessage: BackstoryMessage = update; setConversation([ ...conversationRef.current, { - // role: 'user', - // content: backstoryMessage.prompt || "", - // }, { + ...backstoryMessage, role: 'assistant', origin: type, content: backstoryMessage.response || "", - prompt: backstoryMessage.prompt || "", - preamble: backstoryMessage.preamble || {}, - full_content: backstoryMessage.full_content || "", - metadata: backstoryMessage.metadata, - actions: backstoryMessage.actions, }] as MessageList); // Add a small delay to ensure React has time to update the UI await new Promise(resolve => setTimeout(resolve, 0));
diff --git a/frontend/src/Message.tsx b/frontend/src/Message.tsx
index 4442b6d..01fff20 100644
--- a/frontend/src/Message.tsx
+++ b/frontend/src/Message.tsx
@@ -33,7 +33,6 @@ type MessageRoles = 'info' | 'user' | 'assistant' | 'system' | 'status' | 'error type MessageData = { role: MessageRoles, content: string, - full_content?: string, disableCopy?: boolean, user?: string,
@@ -101,56 +100,46 @@ const MessageMeta = (props: MessageMetaProps) => { const message = props.messageProps.message; return (<> - - Below is the LLM performance of this query. Note that if tools are called, the - entire context is processed for each separate tool request by the LLM. This - can dramatically increase the total time for a response.
- - - - - - - Tokens - Time (s) - TPS - - - - - Prompt - {prompt_eval_count} - {Math.round(prompt_eval_duration / 10 ** 7) / 100} - {Math.round(prompt_eval_count * 10 ** 9 / prompt_eval_duration)} - - - Response - {eval_count} - {Math.round(eval_duration / 10 ** 7) / 100} - {Math.round(eval_count * 10 ** 9 / eval_duration)} - - - Total - {prompt_eval_count + eval_count} - {Math.round((prompt_eval_duration + eval_duration) / 10 ** 7) / 100} - {Math.round((prompt_eval_count + eval_count) * 10 ** 9 / (prompt_eval_duration + eval_duration))} - - -
-
- { - message.full_content !== undefined && - - }> - - Full Query - - - -
{message.full_content?.trim()}
-
-
+ prompt_eval_duration !== 0 && eval_duration !== 0 && <> + + Below is the LLM performance of this query. Note that if tools are called, the + entire context is processed for each separate tool request by the LLM. This + can dramatically increase the total time for a response. + + + + + + + Tokens + Time (s) + TPS + + + + + Prompt + {prompt_eval_count} + {Math.round(prompt_eval_duration / 10 ** 7) / 100} + {Math.round(prompt_eval_count * 10 ** 9 / prompt_eval_duration)} + + + Response + {eval_count} + {Math.round(eval_duration / 10 ** 7) / 100} + {Math.round(eval_count * 10 ** 9 / eval_duration)} + + + Total + {prompt_eval_count + eval_count} + {Math.round((prompt_eval_duration + eval_duration) / 10 ** 7) / 100} + {Math.round((prompt_eval_count + eval_count) * 10 ** 9 / (prompt_eval_duration + eval_duration))} + + +
+
+ } { tools !== undefined && tools.tool_calls && tools.tool_calls.length !== 0 && @@ -216,33 +205,19 @@ const MessageMeta = (props: MessageMetaProps) => { }> - All response fields + Full Response Details - {Object.entries(message) - .filter(([key, value]) => key !== undefined && value !== undefined) - .map(([key, value]) => (typeof (value) !== "string" || value?.trim() !== "") && - - }> - {key} - - - {typeof (value) === "string" ? -
{value}
: - - { - if (typeof (children) === "string" && children.match("\n")) { - return
{children}
- } - }} - /> -
- } -
-
- )} + + { + if (typeof (children) === "string" && children.match("\n")) { + return
{children.trim()}
+ } + }} + /> +
); diff --git a/src/server.py b/src/server.py index 5183688..f73c445 100644 --- a/src/server.py +++ b/src/server.py @@ -17,6 +17,7 @@ import re import math import warnings from typing import Any +from collections import deque from uuid import uuid4 @@ -66,12 +67,6 @@ rags = [ system_message = f""" Launched on {Tools.DateTime()}. -You have access to tools to get real time access to: -- AnalyzeSite: Allows you to look up information on the Internet -- TickerValue: Allows you to find stock price values -- DateTime: Allows you to get the current date and time -- WeatherForecast: Allows you to get the weather forecast for a given location - When answering queries, follow these steps: - First analyze the query to determine if real-time information from the tools might be helpful @@ -87,6 +82,22 @@ When answering queries, follow these steps: Always use tools and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so. """ +system_message_old = f""" +Launched on {Tools.DateTime()}. + +When answering queries, follow these steps: + +1. First analyze the query to determine if real-time information might be helpful +2. Even when <|context|> is provided, consider whether the tools would provide more current or comprehensive information +3. Use the provided tools whenever they would enhance your response, regardless of whether context is also available +4. When presenting weather forecasts, include relevant emojis immediately before the corresponding text. For example, for a sunny day, say \"☀️ Sunny\" or if the forecast says there will be \"rain showers, say \"🌧️ Rain showers\". Use this mapping for weather emojis: Sunny: ☀️, Cloudy: ☁️, Rainy: 🌧️, Snowy: ❄️ +4. When both <|context|> and tool outputs are relevant, synthesize information from both sources to provide the most complete answer +5. Always prioritize the most up-to-date and relevant information, whether it comes from <|context|> or tools +6. If <|context|> and tool outputs contain conflicting information, prefer the tool outputs as they likely represent more current data + +Always use tools and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so. +""".strip() + system_generate_resume = f""" Launched on {Tools.DateTime()}. @@ -585,13 +596,25 @@ class WebServer: # Create a custom generator that ensures flushing async def flush_generator(): + logging.info(f"Message starting. Streaming partial results.") async for message in self.generate_response(context=context, agent=agent, content=data["content"]): + if message.status != "done": + result = { + "status": message.status, + "response": message.response + } + else: + logging.info(f"Message complete. 
Providing full response.") + result = message.model_dump(mode='json') + result = json.dumps(result) + "\n" + message.network_packets += 1 + message.network_bytes += len(result) # Convert to JSON and add newline - yield json.dumps(message.model_dump(mode='json')) + "\n" - # Save the history as its generated - self.save_context(context_id) + yield result # Explicitly flush after each yield await asyncio.sleep(0) # Allow the event loop to process the write + # Save the history once completed + self.save_context(context_id) # Return StreamingResponse with appropriate headers return StreamingResponse( @@ -914,7 +937,7 @@ class WebServer: } else: yield {"status": "complete", "message": "RAG processing complete"} - + async def generate_response(self, context : Context, agent : Agent, content : str) -> AsyncGenerator[Message, None]: if not self.file_watcher: raise Exception("File watcher not initialized") diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 73416fa..23f49ff 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -4,6 +4,7 @@ from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, Asyn from abc import ABC, abstractmethod from typing_extensions import Annotated from .. setup_logging import setup_logging +from .. import defines logger = setup_logging() @@ -22,6 +23,16 @@ class Agent(BaseModel, ABC): This class defines the common attributes and methods for all agent types. """ + # context_size is shared across all subclasses + _context_size: ClassVar[int] = int(defines.max_context * 0.5) + @property + def context_size(self) -> int: + return Agent._context_size + + @context_size.setter + def context_size(self, value: int): + Agent._context_size = value + # Agent management with pydantic agent_type: Literal["base"] = "base" _agent_type: ClassVar[str] = agent_type # Add this for registration @@ -34,15 +45,39 @@ class Agent(BaseModel, ABC): _content_seed: str = PrivateAttr(default="") + def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int: + # Get more accurate token count estimate using tiktoken or similar + response = llm.generate( + model=model, + prompt=prompt, + options={ + "num_ctx": self.context_size, + "num_predict": 0, + } # Don't generate any tokens, just tokenize + ) + # The prompt_eval_count gives you the token count of your input + tokens = response.get("prompt_eval_count", 0) + + # Add buffer for safety + total_ctx = tokens + ctx_buffer + + if total_ctx > self.context_size: + logger.info(f"Increasing context size from {self.context_size} to {total_ctx}") + + # Grow the context size if necessary + self.context_size = max(self.context_size, total_ctx) + # Use actual model maximum context size + return self.context_size + # Class and pydantic model management - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs) -> None: """Auto-register subclasses""" super().__init_subclass__(**kwargs) # Register this class if it has an agent_type if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type: registry.register(cls.agent_type, cls) - def model_dump(self, *args, **kwargs): + def model_dump(self, *args, **kwargs) -> Any: # Ensure context is always excluded, even with exclude_unset=True kwargs.setdefault("exclude", set()) if isinstance(kwargs["exclude"], set): diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py index 024d47e..62a44e0 100644 --- a/src/utils/agents/chat.py +++ b/src/utils/agents/chat.py @@ -62,13 +62,11 @@ class Chat(Agent, 
ABC): preamble_types_OR = " or ".join(preamble_types) message.preamble["rules"] = f"""\ - Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly. -- If there is no information in these sections, answer based on your knowledge. +- If there is no information in these sections, answer based on your knowledge, or use any available tools. - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}. """ - message.preamble["question"] = "Use that information to respond to:" - else: - message.preamble["question"] = "Respond to:" - + message.preamble["question"] = "Respond to:" + message.system_prompt = self.system_prompt message.status = "done" yield message @@ -80,7 +78,6 @@ class Chat(Agent, ABC): raise ValueError("Context is not set for this agent.") if not message.metadata["tools"]: raise ValueError("tools field not initialized") - logging.info(f"LLM - tool processing - {tool_message}") tool_metadata = message.metadata["tools"] tool_metadata["messages"] = messages @@ -95,6 +92,7 @@ class Chat(Agent, ABC): # Yield status update before processing each tool message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..." yield message + logging.info(f"LLM - {message.response}") # Process the tool based on its type match tool: @@ -186,10 +184,10 @@ class Chat(Agent, ABC): message.metadata["prompt_eval_duration"] += response.prompt_eval_duration self.context_tokens = response.prompt_eval_count + response.eval_count message.status = "done" + yield message + end_time = time.perf_counter() message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}" - message.status = "done" - yield message return async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]: @@ -197,22 +195,23 @@ class Chat(Agent, ABC): if not self.context: raise ValueError("Context is not set for this agent.") - messages = [ + messages = [ { "role": "system", "content": message.system_prompt } ] + messages.extend([ item for m in self.conversation.messages for item in [ - {"role": "user", "content": m.prompt}, - {"role": "assistant", "content": m.response} + {"role": "user", "content": m.prompt.strip()}, + {"role": "assistant", "content": m.response.strip()} ] - ] + ]) messages.append({ "role": "user", - "content": message.full_content, + "content": message.context_prompt.strip(), }) - + message.metadata["messages"] = messages message.metadata["options"]={ "seed": 8911, - "num_ctx": message.metadata["ctx_size"] if message.metadata["ctx_size"] else defines.max_context, - "temperature": 0.9, # Higher temperature to encourage tool usage + "num_ctx": self.context_size, + #"temperature": 0.9, # Higher temperature to encourage tool usage } message.metadata["timers"] = {} @@ -222,6 +221,7 @@ class Chat(Agent, ABC): "available": Tools.llm_tools(self.context.tools), "used": False } + tool_metadata = message.metadata["tools"] if use_tools: message.status = "thinking" @@ -232,10 +232,11 @@ class Chat(Agent, ABC): start_time = time.perf_counter() # Tools are enabled and available, so query the LLM with a short token target to see if it will # use the tools + tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}] response = llm.chat( model=model, - messages=messages, #[{ 
"role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}], - tools=message.metadata["tools"]["available"], + messages=tool_metadata["messages"], + tools=tool_metadata["available"], options={ **message.metadata["options"], #"num_predict": 1024, # "Low" token limit to cut off after tool call @@ -253,7 +254,7 @@ class Chat(Agent, ABC): logging.info("LLM indicates tools will be used") # Tools are enabled and available and the LLM indicated it will use them - message.metadata["tools"]["attempted"] = response.message.tool_calls + tool_metadata["attempted"] = response.message.tool_calls message.response = f"Performing tool analysis step 2/2 (tool use suspected)..." yield message @@ -261,8 +262,8 @@ class Chat(Agent, ABC): start_time = time.perf_counter() response = llm.chat( model=model, - messages=messages, - tools=message.metadata["tools"]["available"], + messages=tool_metadata["messages"], # messages, + tools=tool_metadata["available"], options={ **message.metadata["options"], }, @@ -278,7 +279,7 @@ class Chat(Agent, ABC): return if response.message.tool_calls: - message.metadata["tools"]["used"] = response.message.tool_calls + tool_metadata["used"] = response.message.tool_calls # Process all yielded items from the handler start_time = time.perf_counter() async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages): @@ -345,13 +346,14 @@ class Chat(Agent, ABC): self.context.processing = True - message.metadata["system_prompt"] = f"<|system|>{self.system_prompt.strip()}\n" + message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n" + message.context_prompt = "" for p in message.preamble.keys(): - message.full_content += f"\n<|{p}|>\n{message.preamble[p].strip()}\n" - message.full_content += f"{message.prompt}" + message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n" + message.context_prompt += f"{message.prompt}" # Estimate token length of new messages - message.metadata["ctx_size"] = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content) + message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt) message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..." message.status = "thinking" diff --git a/src/utils/context.py b/src/utils/context.py index e97112b..f4bb32c 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -33,7 +33,6 @@ class Context(BaseModel): tools: List[dict] = Tools.default_tools(Tools.tools) rags: List[dict] = [] message_history_length: int = 5 - context_tokens: int = 0 # Class managed fields agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( default_factory=list @@ -58,10 +57,6 @@ class Context(BaseModel): agent.set_context(self) return self - def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096): - ctx = round(context + len(str(messages)) * 3 / 4) - return max(defines.max_context, min(2048, ctx + ctx_buffer)) - def generate_rag_results(self, message: Message) -> Generator[Message, None, None]: """ Generate RAG results for the given query. 
diff --git a/src/utils/message.py b/src/utils/message.py index 0a9c542..d106bd2 100644 --- a/src/utils/message.py +++ b/src/utils/message.py @@ -14,7 +14,7 @@ class Message(BaseModel): status: str = "" # Status of the message preamble: dict[str,str] = {} # Preamble to be prepended to the prompt system_prompt: str = "" # System prompt provided to the LLM - full_content: str = "" # Full content of the message (preamble + prompt) + context_prompt: str = "" # Full content of the message (preamble + prompt) response: str = "" # LLM response to the preamble + query metadata: dict[str, Any] = { "rag": List[dict[str, Any]], @@ -22,8 +22,10 @@ class Message(BaseModel): "eval_duration": 0, "prompt_eval_count": 0, "prompt_eval_duration": 0, - "ctx_size": 0, + "context_size": 0, } + network_packets: int = 0 # Total number of streaming packets + network_bytes: int = 0 # Total bytes sent while streaming packets actions: List[str] = [] # Other session modifying actions performed while processing the message timestamp: datetime = datetime.now(timezone.utc)