diff --git a/src/server.py b/src/server.py
index cede209..acfa475 100644
--- a/src/server.py
+++ b/src/server.py
@@ -649,7 +649,7 @@ class WebServer:
             async def flush_generator():
                 async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
                     # Convert to JSON and add newline
-                    yield str(message) + "\n"
+                    yield json.dumps(message.model_dump(mode='json')) + "\n"
                     # Save the history as its generated
                     self.save_context(context_id)
                     # Explicitly flush after each yield
@@ -987,7 +987,7 @@ class WebServer:
     # * First message sets Fact Check and is Q&A
     # * Has content
     # * Then Q&A of Fact Check
-    async def generate_response(self, context : Context, agent : Agent, content : str) -> Generator[Message, Any, None]:
+    async def generate_response(self, context : Context, agent : Agent, content : str) -> AsyncGenerator[Message, None]:
         if not self.file_watcher:
             raise Exception("File watcher not initialized")
 
@@ -996,7 +996,7 @@ class WebServer:
         if agent_type == "chat":
             message = Message(prompt=content)
             async for value in agent.prepare_message(message):
-                logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
+                # logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
                 if value.status != "done":
                     yield value
                 if value.status == "error":
@@ -1004,17 +1004,8 @@ class WebServer:
                     message.response = value.response
                     yield message
                     return
-            async for value in agent.process_message(message):
-                logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
-                if value.status != "done":
-                    yield value
-                if value.status == "error":
-                    message.status = "error"
-                    message.response = value.response
-                    yield message
-                    return
-            async for value in agent.generate_llm_response(message):
-                logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
+            async for value in agent.process_message(self.llm, self.model, message):
+                # logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
                 if value.status != "done":
                     yield value
                 if value.status == "error":
@@ -1022,6 +1013,15 @@ class WebServer:
                     message.response = value.response
                     yield message
                     return
+            # async for value in agent.generate_llm_response(message):
+            #     logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
+            #     if value.status != "done":
+            #         yield value
+            #     if value.status == "error":
+            #         message.status = "error"
+            #         message.response = value.response
+            #         yield message
+            #         return
 
         logger.info("TODO: There is more to do...")
         return
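Note on the `flush_generator()` change above: `str(message)` emitted Python's repr rather than JSON, so NDJSON clients could not parse the stream line-by-line. `json.dumps(message.model_dump(mode='json'))` produces one parseable JSON object per line (this assumes `src/server.py` imports `json`). A minimal sketch of the contract, using a stand-in `Message` model whose fields are illustrative, not the project's actual schema:

```python
import json
from pydantic import BaseModel

# Stand-in for src/utils/message.py's Message; fields are illustrative only.
class Message(BaseModel):
    prompt: str = ""
    response: str = ""
    status: str = "thinking"

def to_ndjson_line(message: Message) -> str:
    # mode="json" coerces datetimes and other rich types into JSON-safe
    # primitives before json.dumps, which str(message) does not guarantee.
    return json.dumps(message.model_dump(mode="json")) + "\n"

print(to_ndjson_line(Message(prompt="hi", response="hello", status="done")), end="")
```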
diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py
index 80e55a2..b20e74d 100644
--- a/src/utils/agents/base.py
+++ b/src/utils/agents/base.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from pydantic import BaseModel, model_validator, PrivateAttr, Field
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef, Any
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
 from .. setup_logging import setup_logging
@@ -220,7 +220,7 @@ class Agent(BaseModel, ABC):
         self.context.processing = False
         return
 
-    async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
+    async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
         message.full_content = ""
         for i, p in enumerate(message.preamble.keys()):
             message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py
index f8b479c..b2a7011 100644
--- a/src/utils/agents/chat.py
+++ b/src/utils/agents/chat.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from pydantic import BaseModel, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any
 from typing_extensions import Annotated
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
@@ -8,6 +8,7 @@ import logging
 from .base import Agent, registry
 from .. conversation import Conversation
 from .. message import Message
+from .. import defines
 
 class Chat(Agent, ABC):
     """
@@ -70,7 +71,10 @@ class Chat(Agent, ABC):
             yield message
             return
 
-    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
+    async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
         if self.context.processing:
             logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
             message.status = "error"
@@ -80,47 +84,62 @@ class Chat(Agent, ABC):
 
         self.context.processing = True
 
-        messages = []
+        self.conversation.add_message(message)
 
-        for value in self.llm.chat(
-            model=self.model,
+        messages = [
+            item for m in self.conversation.messages
+            for item in [
+                {"role": "user", "content": m.prompt},
+                {"role": "assistant", "content": m.response}
+            ]
+        ]
+
+        response = None  # ensure 'response' is bound even if the stream yields nothing
+        for value in llm.chat(
+            model=model,
             messages=messages,
             #tools=llm_tools(context.tools) if message.enable_tools else None,
-            options={ "num_ctx": message.ctx_size }
+            options={ "num_ctx": message.metadata["ctx_size"] if message.metadata["ctx_size"] else defines.max_context },
+            stream=True,
         ):
-            logging.info(f"LLM: {value.status} - {value.response}")
-            if value.status != "done":
-                message.status = value.status
-                message.response = value.response
-                yield message
-            if value.status == "error":
-                return
-            response = value
+            logging.info(f"LLM: {'done' if value.done else 'thinking'} - {value.message.content}")
+            message.response += value.message.content
+            yield message
+            if value.done:
+                response = value
+
+        if not response:
+            message.status = "error"
+            message.response = "No response from LLM."
+            yield message
+            return
+
         message.metadata["eval_count"] += response["eval_count"]
         message.metadata["eval_duration"] += response["eval_duration"]
         message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
         message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-        agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+        self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+        yield message
+        return
+
         tools_used = []
-        yield {"status": "processing", "message": "Initial response received..."}
         if "tool_calls" in response.get("message", {}):
-            yield {"status": "processing", "message": "Processing tool calls..."}
+            message.status = "thinking"
+            message.response = "Processing tool calls..."
            tool_message = response["message"]
             tool_result = None
             # Process all yielded items from the handler
-            async for item in self.handle_tool_calls(tool_message):
-                if isinstance(item, tuple) and len(item) == 2:
+            async for value in self.handle_tool_calls(tool_message):
+                if isinstance(value, tuple) and len(value) == 2:
                     # This is the final result tuple (tool_result, tools_used)
-                    tool_result, tools_used = item
+                    tool_result, tools_used = value
                 else:
                     # This is a status update, forward it
-                    yield item
+                    yield value
 
             message_dict = {
                 "role": tool_message.get("role", "assistant"),
@@ -179,19 +198,22 @@ class Chat(Agent, ABC):
         self.context.processing = False
         return
 
-    async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
+    async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
         message.full_content = ""
         for i, p in enumerate(message.preamble.keys()):
             message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
 
         # Estimate token length of new messages
-        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
+        message.metadata["ctx_size"] = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
 
         message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
         message.status = "thinking"
         yield message
 
-        for value in self.generate_llm_response(message):
+        async for value in self.generate_llm_response(llm, model, message):
             logging.info(f"LLM: {value.status} - {value.response}")
             if value.status != "done":
                 yield value
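The rewritten `generate_llm_response()` above switches from polling status objects to true streaming (`stream=True`) and flattens the stored conversation into the alternating `user`/`assistant` list the chat API expects. A minimal sketch of that flattening with plain dicts standing in for the `Conversation`/`Message` models; note that because the in-flight message is added to the conversation before flattening, its still-empty assistant response is also sent to the model:

```python
# History as prompt/response pairs; the last pair is the in-flight message.
history = [
    {"prompt": "What is RAG?", "response": "Retrieval-augmented generation ..."},
    {"prompt": "Give an example.", "response": ""},  # response not generated yet
]

# Same flattening shape as the diff: one user + one assistant entry per pair.
messages = [
    item
    for m in history
    for item in (
        {"role": "user", "content": m["prompt"]},
        {"role": "assistant", "content": m["response"]},
    )
]
assert messages[-1] == {"role": "assistant", "content": ""}
```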
diff --git a/src/utils/context.py b/src/utils/context.py
index c2d5c4b..a4d60b0 100644
--- a/src/utils/context.py
+++ b/src/utils/context.py
@@ -39,6 +39,8 @@ class Context(BaseModel):
         default_factory=list
     )
 
+    processing: bool = Field(default=False, exclude=True)
+
     # @model_validator(mode="before")
     # @classmethod
     # def before_model_validator(cls, values: Any):
diff --git a/src/utils/message.py b/src/utils/message.py
index 8f5688b..aee6a84 100644
--- a/src/utils/message.py
+++ b/src/utils/message.py
@@ -23,6 +23,7 @@ class Message(BaseModel):
         "eval_duration": 0,
         "prompt_eval_count": 0,
         "prompt_eval_duration": 0,
+        "ctx_size": 0,
     }
     actions: List[str] = [] # Other session modifying actions performed while processing the message
     timestamp: datetime = datetime.now(timezone.utc)
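The two model changes above close the loop: `ctx_size` now lives in `Message.metadata` with a zero default, so the lookup in `chat.py` never raises `KeyError` and falls back to `defines.max_context` when unset, and `Context.processing` is declared with `exclude=True` so transient busy-state never round-trips through persisted context JSON. A minimal sketch of the exclusion behavior, using a stand-in model rather than the project's `Context`:

```python
from pydantic import BaseModel, Field

# Stand-in for src/utils/context.py's Context; only the relevant field shown.
class Context(BaseModel):
    id: str = "ctx-1"
    processing: bool = Field(default=False, exclude=True)

ctx = Context(processing=True)
assert ctx.processing is True                 # usable at runtime
assert "processing" not in ctx.model_dump()   # omitted from serialized output
```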