onion peeling

James Ketr 2025-04-30 16:43:02 -07:00
parent 3094288e46
commit 7f24d8870c
5 changed files with 65 additions and 40 deletions

View File

@@ -649,7 +649,7 @@ class WebServer:
async def flush_generator():
async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
# Convert to JSON and add newline
yield str(message) + "\n"
yield json.dumps(message.model_dump(mode='json')) + "\n"
# Save the history as it's generated
self.save_context(context_id)
# Explicitly flush after each yield
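
The flush_generator change above switches the stream from str(message) to one complete JSON document per line, which a client can decode incrementally as JSON Lines. A minimal sketch of the pattern, assuming a pydantic v2 model; the Message fields and fake_generate_response below are stand-ins, not the repository's real types:

import asyncio
import json
from pydantic import BaseModel

class Message(BaseModel):
    # Trimmed stand-in for the repository's Message model (assumption).
    status: str = "thinking"
    response: str = ""

async def fake_generate_response():
    # Stand-in for WebServer.generate_response(): partial updates, then a final message.
    yield Message(status="thinking", response="partial...")
    yield Message(status="done", response="final answer")

async def flush_generator():
    # Each chunk is a complete JSON document terminated by a newline,
    # so the client can decode the stream line by line.
    async for message in fake_generate_response():
        yield json.dumps(message.model_dump(mode="json")) + "\n"

async def main():
    async for line in flush_generator():
        print(line, end="")

asyncio.run(main())
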
@@ -987,7 +987,7 @@ class WebServer:
# * First message sets Fact Check and is Q&A
# * Has content
# * Then Q&A of Fact Check
async def generate_response(self, context : Context, agent : Agent, content : str) -> Generator[Message, Any, None]:
async def generate_response(self, context : Context, agent : Agent, content : str) -> AsyncGenerator[Message, None]:
if not self.file_watcher:
raise Exception("File watcher not initialized")
@@ -996,7 +996,7 @@ class WebServer:
if agent_type == "chat":
message = Message(prompt=content)
async for value in agent.prepare_message(message):
logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
# logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
if value.status != "done":
yield value
if value.status == "error":
@@ -1004,17 +1004,8 @@ class WebServer:
message.response = value.response
yield message
return
async for value in agent.process_message(message):
logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
if value.status != "done":
yield value
if value.status == "error":
message.status = "error"
message.response = value.response
yield message
return
async for value in agent.generate_llm_response(message):
logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
async for value in agent.process_message(self.llm, self.model, message):
# logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
if value.status != "done":
yield value
if value.status == "error":
@@ -1022,6 +1013,15 @@ class WebServer:
message.response = value.response
yield message
return
# async for value in agent.generate_llm_response(message):
# logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
# if value.status != "done":
# yield value
# if value.status == "error":
# message.status = "error"
# message.response = value.response
# yield message
# return
logger.info("TODO: There is more to do...")
return
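
Taken together, the changes to this file peel one layer out of the response pipeline: generate_response is now typed as an AsyncGenerator, calls agent.process_message(self.llm, self.model, message) directly, and the separate generate_llm_response pass is commented out. A sketch of the resulting control flow, reconstructed from the hunks above with stand-in types rather than the verbatim file:

from typing import Any, AsyncGenerator

async def generate_response(llm: Any, model: str, agent: Any, message: Any) -> AsyncGenerator[Any, None]:
    # Stage 1: let the agent prepare the prompt, forwarding intermediate updates.
    async for value in agent.prepare_message(message):
        if value.status != "done":
            yield value
        if value.status == "error":
            message.status = "error"
            message.response = value.response
            yield message
            return
    # Stage 2 (the single remaining layer): the agent drives the LLM itself.
    async for value in agent.process_message(llm, model, message):
        if value.status != "done":
            yield value
        if value.status == "error":
            message.status = "error"
            message.response = value.response
            yield message
            return
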

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from pydantic import BaseModel, model_validator, PrivateAttr, Field
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef, Any
from abc import ABC, abstractmethod
from typing_extensions import Annotated
from .. setup_logging import setup_logging
@@ -220,7 +220,7 @@ class Agent(BaseModel, ABC):
self.context.processing = False
return
async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
message.full_content = ""
for i, p in enumerate(message.preamble.keys()):
message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"
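
The loop above assembles message.full_content from the preamble sections using <|section|> markers. A minimal sketch of that rendering, with the separator parenthesized so the first section is not dropped; the section names in the example are hypothetical:

def render_preamble(preamble: dict[str, str]) -> str:
    # Join each section as "<|name|>content", separated by blank lines.
    full_content = ""
    for i, name in enumerate(preamble.keys()):
        full_content += ('' if i == 0 else '\n\n') + f"<|{name}|>{preamble[name].strip()}\n"
    return full_content

# Example with hypothetical section names:
print(render_preamble({"system": "You are a helpful assistant.", "context": "RAG excerpts..."}))
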

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from pydantic import BaseModel, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any
from typing_extensions import Annotated
from abc import ABC, abstractmethod
from typing_extensions import Annotated
@@ -8,6 +8,7 @@ import logging
from .base import Agent, registry
from .. conversation import Conversation
from .. message import Message
from .. import defines
class Chat(Agent, ABC):
"""
@@ -70,7 +71,10 @@ class Chat(Agent, ABC):
yield message
return
async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
if not self.context:
raise ValueError("Context is not set for this agent.")
if self.context.processing:
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
message.status = "error"
@@ -80,47 +84,62 @@ class Chat(Agent, ABC):
self.context.processing = True
messages = []
self.conversation.add_message(message)
for value in self.llm.chat(
model=self.model,
messages = [
item for m in self.conversation.messages
for item in [
{"role": "user", "content": m.prompt},
{"role": "assistant", "content": m.response}
]
]
for value in llm.chat(
model=model,
messages=messages,
#tools=llm_tools(context.tools) if message.enable_tools else None,
options={ "num_ctx": message.ctx_size }
options={ "num_ctx": message.metadata["ctx_size"] if message.metadata["ctx_size"] else defines.max_context },
stream=True,
):
logging.info(f"LLM: {value.status} - {value.response}")
if value.status != "done":
message.status = value.status
message.response = value.response
yield message
if value.status == "error":
return
response = value
logging.info(f"LLM: {'done' if value.done else 'thinking'} - {value.message.content}")
message.response += value.message.content
yield message
if value.done:
response = value
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
message.metadata["eval_count"] += response["eval_count"]
message.metadata["eval_duration"] += response["eval_duration"]
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
yield message
return
tools_used = []
yield {"status": "processing", "message": "Initial response received..."}
if "tool_calls" in response.get("message", {}):
yield {"status": "processing", "message": "Processing tool calls..."}
message.status = "thinking"
message.response = "Processing tool calls..."
tool_message = response["message"]
tool_result = None
# Process all yielded items from the handler
async for item in self.handle_tool_calls(tool_message):
if isinstance(item, tuple) and len(item) == 2:
async for value in self.handle_tool_calls(tool_message):
if isinstance(value, tuple) and len(value) == 2:
# This is the final result tuple (tool_result, tools_used)
tool_result, tools_used = item
tool_result, tools_used = value
else:
# This is a status update, forward it
yield item
yield value
message_dict = {
"role": tool_message.get("role", "assistant"),
@@ -179,19 +198,22 @@ class Chat(Agent, ABC):
self.context.processing = False
return
async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
if not self.context:
raise ValueError("Context is not set for this agent.")
message.full_content = ""
for i, p in enumerate(message.preamble.keys()):
message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"
# Estimate token length of new messages
message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
message.metadata["ctx_size"] = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.status = "thinking"
yield message
for value in self.generate_llm_response(message):
async for value in self.generate_llm_response(llm, model, message):
logging.info(f"LLM: {value.status} - {value.response}")
if value.status != "done":
yield value
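
The rewritten generate_llm_response consumes an Ollama-style streaming chat() directly: each chunk contributes message.content, and the final chunk (done=True) carries the eval counters that get folded into message.metadata and context_tokens. A standalone sketch of that loop; the llm client and its field access patterns mirror the diff and are otherwise assumptions:

from typing import Any

def stream_chat(llm: Any, model: str, prompt: str, ctx_size: int) -> dict:
    # llm is assumed to expose chat(model=..., messages=..., options=..., stream=True)
    # yielding chunks with .message.content, .done, and eval counters on the last chunk.
    response_text = ""
    final = None
    for chunk in llm.chat(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        options={"num_ctx": ctx_size},
        stream=True,
    ):
        response_text += chunk.message.content   # incremental tokens
        if chunk.done:
            final = chunk                         # last chunk carries the counters
    if final is None:
        raise RuntimeError("No response from LLM.")
    return {
        "response": response_text,
        "eval_count": final["eval_count"],
        "eval_duration": final["eval_duration"],
        "prompt_eval_count": final["prompt_eval_count"],
        "prompt_eval_duration": final["prompt_eval_duration"],
        # What the commit stores back on the agent as context_tokens:
        "context_tokens": final["prompt_eval_count"] + final["eval_count"],
    }
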

View File

@@ -39,6 +39,8 @@ class Context(BaseModel):
default_factory=list
)
processing: bool = Field(default=False, exclude=True)
# @model_validator(mode="before")
# @classmethod
# def before_model_validator(cls, values: Any):
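
The new processing flag is the busy guard that generate_llm_response checks before opening a stream against the same context. A minimal sketch of that guard, assuming a pydantic Context like the one above; the try/finally is a small hardening not present in the diff:

from pydantic import BaseModel, Field

class Context(BaseModel):
    # Transient run-state; excluded from serialization, as in the diff.
    processing: bool = Field(default=False, exclude=True)

async def guarded_stream(context: Context):
    # Reject (or, per the TODO, eventually queue) concurrent requests on one context.
    if context.processing:
        yield {"status": "error", "response": "Busy processing another request."}
        return
    context.processing = True
    try:
        yield {"status": "done", "response": "..."}
    finally:
        context.processing = False
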

View File

@@ -23,6 +23,7 @@ class Message(BaseModel):
"eval_duration": 0,
"prompt_eval_count": 0,
"prompt_eval_duration": 0,
"ctx_size": 0,
}
actions: List[str] = [] # Other session-modifying actions performed while processing the message
timestamp: datetime = datetime.now(timezone.utc)
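
The added ctx_size entry lets chat.py size num_ctx per request and fall back to defines.max_context when no estimate has been made. A minimal sketch of that bookkeeping; MAX_CONTEXT below is an assumed stand-in for defines.max_context:

from typing import Any

MAX_CONTEXT = 16384  # stand-in for defines.max_context (assumed value)

def default_metadata() -> dict[str, Any]:
    # Mirrors Message.metadata defaults after this commit.
    return {
        "eval_count": 0,
        "eval_duration": 0,
        "prompt_eval_count": 0,
        "prompt_eval_duration": 0,
        "ctx_size": 0,
    }

def num_ctx_option(metadata: dict[str, Any]) -> dict[str, int]:
    # num_ctx uses the per-message estimate when set, else the global maximum.
    return {"num_ctx": metadata["ctx_size"] if metadata["ctx_size"] else MAX_CONTEXT}

def accumulate(metadata: dict[str, Any], final: dict[str, int]) -> None:
    # Fold the final chunk's counters into the running totals, as chat.py does.
    for key in ("eval_count", "eval_duration", "prompt_eval_count", "prompt_eval_duration"):
        metadata[key] += final[key]
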