onion peeling

commit 7f24d8870c
parent 3094288e46
@@ -649,7 +649,7 @@ class WebServer:
         async def flush_generator():
             async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
                 # Convert to JSON and add newline
-                yield str(message) + "\n"
+                yield json.dumps(message.model_dump(mode='json')) + "\n"
                 # Save the history as its generated
                 self.save_context(context_id)
                 # Explicitly flush after each yield
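The old code relied on str(message), which is not JSON for a pydantic model; model_dump(mode='json') first converts fields such as datetimes to JSON-safe values so each streamed message becomes one newline-delimited JSON record. A minimal sketch of that framing, using a simplified stand-in for the project's Message model (the real model carries more fields):

    import json
    from datetime import datetime, timezone
    from pydantic import BaseModel

    # Simplified stand-in for the project's Message model (assumption:
    # the real model also has status/response/metadata fields).
    class Message(BaseModel):
        prompt: str
        response: str = ""
        timestamp: datetime = datetime.now(timezone.utc)

    message = Message(prompt="hello")

    # mode='json' converts the datetime to a JSON-safe string, so
    # json.dumps() does not raise TypeError.
    line = json.dumps(message.model_dump(mode='json')) + "\n"
    print(line)  # one NDJSON record per streamed message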
@@ -987,7 +987,7 @@ class WebServer:
     # * First message sets Fact Check and is Q&A
     # * Has content
     # * Then Q&A of Fact Check
-    async def generate_response(self, context : Context, agent : Agent, content : str) -> Generator[Message, Any, None]:
+    async def generate_response(self, context : Context, agent : Agent, content : str) -> AsyncGenerator[Message, None]:
         if not self.file_watcher:
             raise Exception("File watcher not initialized")
 
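The return annotation changes from Generator[Message, Any, None] to AsyncGenerator[Message, None] because generate_response is an async def that yields; such a function is an async generator and must be consumed with "async for", as flush_generator above already does. A small, self-contained illustration of the distinction (names here are illustrative only):

    import asyncio
    from typing import AsyncGenerator, Generator

    def sync_numbers() -> Generator[int, None, None]:
        yield from range(3)  # a plain generator: consumed with "for"

    async def async_numbers() -> AsyncGenerator[int, None]:
        for i in range(3):
            yield i  # an async generator: consumed only with "async for"

    async def consume() -> None:
        async for n in async_numbers():
            print(n)

    asyncio.run(consume())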
@@ -996,7 +996,7 @@ class WebServer:
         if agent_type == "chat":
             message = Message(prompt=content)
             async for value in agent.prepare_message(message):
-                logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
+                # logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
                 if value.status != "done":
                     yield value
                     if value.status == "error":
@@ -1004,17 +1004,8 @@ class WebServer:
                         message.response = value.response
                         yield message
                         return
-            async for value in agent.process_message(message):
-                logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
-                if value.status != "done":
-                    yield value
-                    if value.status == "error":
-                        message.status = "error"
-                        message.response = value.response
-                        yield message
-                        return
-            async for value in agent.generate_llm_response(message):
-                logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
+            async for value in agent.process_message(self.llm, self.model, message):
+                # logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
                 if value.status != "done":
                     yield value
                     if value.status == "error":
@@ -1022,6 +1013,15 @@ class WebServer:
                         message.response = value.response
                         yield message
                         return
+            # async for value in agent.generate_llm_response(message):
+            #     logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
+            #     if value.status != "done":
+            #         yield value
+            #         if value.status == "error":
+            #             message.status = "error"
+            #             message.response = value.response
+            #             yield message
+            #             return
             logger.info("TODO: There is more to do...")
             return
 
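After this hunk the server drives only the outermost layer: it iterates agent.process_message(self.llm, self.model, message), and each layer forwards the interim messages of the layer beneath it (the "onion" the commit title refers to). A minimal sketch of that layering with hypothetical layer names; the statuses ("thinking", "error", "done") follow the ones used in the diff:

    import asyncio
    from typing import AsyncGenerator

    # Hypothetical stand-ins for the real layers (Agent.process_message
    # wrapping Chat.generate_llm_response): each layer forwards the inner
    # layer's progress messages and yields its own final one.
    async def inner_layer() -> AsyncGenerator[dict, None]:
        yield {"status": "thinking", "response": "calling LLM..."}
        yield {"status": "done", "response": "final answer"}

    async def outer_layer() -> AsyncGenerator[dict, None]:
        yield {"status": "thinking", "response": "building prompt..."}
        async for value in inner_layer():
            if value["status"] != "done":
                yield value              # forward progress updates unchanged
                if value["status"] == "error":
                    return               # abort the whole onion on error
        yield value                      # the inner layer's final message

    async def main() -> None:
        async for value in outer_layer():   # what the web server's loop does
            print(value["status"], "-", value["response"])

    asyncio.run(main())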
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from pydantic import BaseModel, model_validator, PrivateAttr, Field
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef, Any
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
 from .. setup_logging import setup_logging
@@ -220,7 +220,7 @@ class Agent(BaseModel, ABC):
             self.context.processing = False
             return
 
-    async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
+    async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
         message.full_content = ""
         for i, p in enumerate(message.preamble.keys()):
             message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
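The base Agent's process_message now receives the LLM client and model name from the server rather than reading them off the agent, and the loop shown above stitches the preamble sections into full_content with <|section|> tags. A small worked example of that exact expression; note that the conditional binds to the entire right-hand side, so the i == 0 iteration appends only the empty string:

    preamble = {"context": "Some retrieved text.", "question": "What is RAG?"}

    full_content = ""
    for i, p in enumerate(preamble.keys()):
        # "A if cond else B": the f-string is part of the else branch,
        # so the first key contributes nothing as written.
        full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{preamble[p].strip()}\n"

    print(repr(full_content))
    # '\n\n<|question|>What is RAG?\n'  -- only the second key appears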
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from pydantic import BaseModel, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any
 from typing_extensions import Annotated
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
@@ -8,6 +8,7 @@ import logging
 from .base import Agent, registry
 from .. conversation import Conversation
 from .. message import Message
+from .. import defines
 
 class Chat(Agent, ABC):
     """
@@ -70,7 +71,10 @@ class Chat(Agent, ABC):
             yield message
             return
 
-    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
+    async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
         if self.context.processing:
             logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
             message.status = "error"
@@ -80,47 +84,62 @@ class Chat(Agent, ABC):
 
         self.context.processing = True
 
-        messages = []
+        self.conversation.add_message(message)
 
-        for value in self.llm.chat(
-            model=self.model,
+        messages = [
+            item for m in self.conversation.messages
+            for item in [
+                {"role": "user", "content": m.prompt},
+                {"role": "assistant", "content": m.response}
+            ]
+        ]
+
+        for value in llm.chat(
+            model=model,
             messages=messages,
             #tools=llm_tools(context.tools) if message.enable_tools else None,
-            options={ "num_ctx": message.ctx_size }
+            options={ "num_ctx": message.metadata["ctx_size"] if message.metadata["ctx_size"] else defines.max_context },
+            stream=True,
         ):
-            logging.info(f"LLM: {value.status} - {value.response}")
-            if value.status != "done":
-                message.status = value.status
-                message.response = value.response
-                yield message
-                if value.status == "error":
-                    return
-            response = value
+            logging.info(f"LLM: {'done' if value.done else 'thinking'} - {value.message.content}")
+            message.response += value.message.content
+            yield message
+            if value.done:
+                response = value
+
+        if not response:
+            message.status = "error"
+            message.response = "No response from LLM."
+            yield message
+            return
 
         message.metadata["eval_count"] += response["eval_count"]
         message.metadata["eval_duration"] += response["eval_duration"]
         message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
         message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-        agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+        self.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+
+        yield message
+        return
 
         tools_used = []
 
-        yield {"status": "processing", "message": "Initial response received..."}
-
         if "tool_calls" in response.get("message", {}):
-            yield {"status": "processing", "message": "Processing tool calls..."}
+            message.status = "thinking"
+            message.response = "Processing tool calls..."
 
             tool_message = response["message"]
             tool_result = None
 
             # Process all yielded items from the handler
-            async for item in self.handle_tool_calls(tool_message):
-                if isinstance(item, tuple) and len(item) == 2:
+            async for value in self.handle_tool_calls(tool_message):
+                if isinstance(value, tuple) and len(value) == 2:
                     # This is the final result tuple (tool_result, tools_used)
-                    tool_result, tools_used = item
+                    tool_result, tools_used = value
                 else:
                     # This is a status update, forward it
-                    yield item
+                    yield value
 
             message_dict = {
                 "role": tool_message.get("role", "assistant"),
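generate_llm_response now rebuilds the chat history from self.conversation instead of sending an empty messages list, flattening each stored exchange into the role/content pairs a chat API expects; the streamed chunk fields it reads (value.message.content, value.done) suggest an Ollama-style streaming client, though the diff does not name it. A worked example of the flattening comprehension, using a hypothetical stand-in for the conversation's entries:

    # Hypothetical stand-in for the project's Message entries in a Conversation.
    class Msg:
        def __init__(self, prompt: str, response: str) -> None:
            self.prompt = prompt
            self.response = response

    conversation = [
        Msg("Hi", "Hello!"),
        Msg("What is RAG?", "Retrieval augmented generation."),
    ]

    # Same shape as the comprehension in the diff: each exchange becomes a
    # user message followed by an assistant message, in order.
    messages = [
        item for m in conversation
        for item in [
            {"role": "user", "content": m.prompt},
            {"role": "assistant", "content": m.response},
        ]
    ]

    print(messages)
    # [{'role': 'user', 'content': 'Hi'},
    #  {'role': 'assistant', 'content': 'Hello!'},
    #  {'role': 'user', 'content': 'What is RAG?'},
    #  {'role': 'assistant', 'content': 'Retrieval augmented generation.'}]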
@@ -179,19 +198,22 @@ class Chat(Agent, ABC):
             self.context.processing = False
             return
 
-    async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
+    async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
         message.full_content = ""
         for i, p in enumerate(message.preamble.keys()):
             message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
 
         # Estimate token length of new messages
-        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
+        message.metadata["ctx_size"] = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
 
         message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
         message.status = "thinking"
         yield message
 
-        for value in self.generate_llm_response(message):
+        async for value in self.generate_llm_response(llm, model, message):
             logging.info(f"LLM: {value.status} - {value.response}")
             if value.status != "done":
                 yield value
@@ -39,6 +39,8 @@ class Context(BaseModel):
         default_factory=list
     )
 
+    processing: bool = Field(default=False, exclude=True)
+
     # @model_validator(mode="before")
     # @classmethod
     # def before_model_validator(cls, values: Any):
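processing is declared with Field(default=False, exclude=True), so it works as transient runtime state but is never written out when the context is serialized (for example, by save_context). A minimal sketch of what exclude=True does in pydantic v2, with a simplified stand-in for the real Context model:

    from pydantic import BaseModel, Field

    class Context(BaseModel):
        # Minimal stand-in for the real Context model.
        id: str = "abc123"
        processing: bool = Field(default=False, exclude=True)

    ctx = Context()
    ctx.processing = True

    print(ctx.model_dump())        # {'id': 'abc123'} -- 'processing' is excluded
    print(ctx.model_dump_json())   # '{"id":"abc123"}'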
@@ -23,6 +23,7 @@ class Message(BaseModel):
         "eval_duration": 0,
         "prompt_eval_count": 0,
         "prompt_eval_duration": 0,
+        "ctx_size": 0,
     }
     actions: List[str] = [] # Other session modifying actions performed while processing the message
     timestamp: datetime = datetime.now(timezone.utc)
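ctx_size joins the other accumulators in metadata with a default of 0. Because the chat agent selects the context window with message.metadata["ctx_size"] if message.metadata["ctx_size"] else defines.max_context, a message whose size was never estimated (still 0, which is falsy) falls back to the project-wide defines.max_context. A small illustration of that fallback; the numeric values below are assumptions for the example only:

    # Stand-in for the repository's defines module.
    class defines:
        max_context = 8192  # assumption: illustrative value only

    metadata = {"ctx_size": 0}  # the new default from the Message model

    num_ctx = metadata["ctx_size"] if metadata["ctx_size"] else defines.max_context
    print(num_ctx)  # 8192 -- 0 is falsy, so the fallback applies

    metadata["ctx_size"] = 2048  # as set by process_message via get_optimal_ctx_size
    num_ctx = metadata["ctx_size"] if metadata["ctx_size"] else defines.max_context
    print(num_ctx)  # 2048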