Almost working?
This commit is contained in:
parent c3cf9a9c76
commit 622c33545e
@@ -978,8 +978,40 @@ class WebServer:
         if not self.file_watcher:
             return
 
-        if agent.agent_type == "Chat":
-            agent.proces
+        agent_type = agent.get_agent_type()
+        logging.info(f"generate_response: {agent_type}")
+        if agent_type == "chat":
+            message = Message(prompt=content)
+            async for value in agent.prepare_message(message):
+                logging.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
+                if value.status != "done":
+                    yield value
+                if value.status == "error":
+                    message.status = "error"
+                    message.response = value.response
+                    yield message
+                    return
+            async for value in agent.process_message(message):
+                logging.info(f"{agent_type}.process_message: {value.status} - {value.response}")
+                if value.status != "done":
+                    yield value
+                if value.status == "error":
+                    message.status = "error"
+                    message.response = value.response
+                    yield message
+                    return
+            async for value in agent.generate_llm_response(message):
+                logging.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
+                if value.status != "done":
+                    yield value
+                if value.status == "error":
+                    message.status = "error"
+                    message.response = value.response
+                    yield message
+                    return
+            logging.info("TODO: There is more to do...")
+            return
+
 
         if self.processing:
             logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
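Note: the new flow runs three agent stages in order (prepare_message, process_message, generate_llm_response), each wrapped in an identical status-forwarding loop. A minimal sketch of how that repetition could collapse into one helper, assuming only the status contract visible in this hunk ("done", "error", or a progress state); run_stages itself is hypothetical, not part of this commit:

import logging
from typing import AsyncGenerator

async def run_stages(agent, message) -> AsyncGenerator:
    # Each stage is an async generator yielding Message objects, per the diff.
    stages = [agent.prepare_message, agent.process_message,
              agent.generate_llm_response]
    for stage in stages:
        async for value in stage(message):
            logging.info(f"{stage.__name__}: {value.status} - {value.response}")
            if value.status != "done":
                yield value            # forward progress updates to the caller
            if value.status == "error":
                message.status = "error"
                message.response = value.response
                yield message          # surface the failure, then stop
                return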
@@ -1003,7 +1035,7 @@ class WebServer:
         enable_rag = False
 
         # RAG is disabled when asking questions about the resume
-        if agent.agent_type == "resume":
+        if agent.get_agent_type() == "resume":
             enable_rag = False
 
         # The first time through each agent agent_type a content_seed may be set for
@@ -1014,7 +1046,7 @@ class WebServer:
         # After the first time a particular agent agent_type is used, it is handled as a chat.
         # The number of messages indicating the agent is ready for chat varies based on
         # the agent_type of agent
-        process_type = agent.agent_type
+        process_type = agent.get_agent_type()
         match process_type:
             case "job_description":
                 logging.info(f"job_description user_history len: {len(conversation.messages)}")
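Note: the match on process_type branches on the same strings that agents/base.py registers classes under, so a string-to-class mapping already exists at runtime. A sketch of what that registry plausibly looks like (the diff shows only register() being called; the get() accessor is an assumption):

class AgentRegistry:
    # Maps agent_type strings to Agent subclasses, as implied by
    # registry.register(cls.agent_type, cls) in agents/base.py.
    def __init__(self) -> None:
        self._classes: dict[str, type] = {}

    def register(self, agent_type: str, cls: type) -> None:
        self._classes[agent_type] = cls

    def get(self, agent_type: str) -> type:  # hypothetical accessor
        return self._classes[agent_type]

registry = AgentRegistry()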
@@ -1281,7 +1313,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
         if len(conversation.messages) > 2:
             processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..."
         else:
-            match agent.agent_type:
+            match agent.get_agent_type():
                 case "job_description":
                     processing_message = f"Generating {'RAG augmented ' if enable_rag else ''}resume..."
                 case "resume":
@@ -1363,7 +1395,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
 
         reply = response["message"]["content"]
         message.response = reply
-        message.metadata["origin"] = agent.agent_type
+        message.metadata["origin"] = agent.get_agent_type()
         # final_message = {"role": "assistant", "content": reply }
 
         # # history is provided to the LLM and should not have additional metadata
@@ -1,256 +0,0 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
-import logging
-
-from .types import AgentBase, registry
-
-# Only import Context for type checking
-if TYPE_CHECKING:
-    from .context import Context
-
-from .types import AgentBase
-
-from .conversation import Conversation
-from .message import Message
-
-class Agent(AgentBase):
-    """
-    Base class for all agent types.
-    This class defines the common attributes and methods for all agent types.
-    """
-    agent_type: str = Field(default="agent", const=True) # discriminator value
-
-
-    def __init_subclass__(cls, **kwargs):
-        """Auto-register subclasses"""
-        super().__init_subclass__(**kwargs)
-        # Register this class if it has an agent_type
-        if hasattr(cls, 'agent_type') and cls.agent_type != Agent.agent_type:
-            registry.register(cls.agent_type, cls)
-
-    def __init__(self, **data):
-        # Set agent_type from class if not provided
-        if 'agent_type' not in data:
-            data['agent_type'] = self.__class__.agent_type
-        super().__init__(**data)
-
-    system_prompt: str # Mandatory
-    conversation: Conversation = Conversation()
-    context_tokens: int = 0
-
-    # Add a property for context if needed without creating a circular reference
-    @property
-    def context(self) -> Optional['Context']:
-        if TYPE_CHECKING:
-            from .context import Context
-        # Implement logic to fetch context by ID if needed
-        return None
-
-    #context: Context
-
-    _content_seed: str = PrivateAttr(default="")
-
-    async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
-        """
-        Prepare message with context information in message.preamble
-        """
-        # Generate RAG content if enabled, based on the content
-        rag_context = ""
-        if not message.disable_rag:
-            # Gather RAG results, yielding each result
-            # as it becomes available
-            for value in self.context.generate_rag_results(message):
-                logging.info(f"RAG: {value.status} - {value.response}")
-                if value.status != "done":
-                    yield value
-                if value.status == "error":
-                    message.status = "error"
-                    message.response = value.response
-                    yield message
-                    return
-
-            if message.metadata["rag"]:
-                for rag_collection in message.metadata["rag"]:
-                    for doc in rag_collection["documents"]:
-                        rag_context += f"{doc}\n"
-
-        if rag_context:
-            message["context"] = rag_context
-
-        if self.context.user_resume:
-            message["resume"] = self.content.user_resume
-
-        if message.preamble:
-            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
-            preamble_types_AND = " and ".join(preamble_types)
-            preamble_types_OR = " or ".join(preamble_types)
-            message.preamble["rules"] = f"""\
-- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly.
-- If there is no information in these sections, answer based on your knowledge.
-- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
-"""
-            message.preamble["question"] = "Use that information to respond to:"
-        else:
-            message.preamble["question"] = "Respond to:"
-
-        message.system_prompt = self.system_prompt
-        message.status = "done"
-        yield message
-        return
-
-    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
-        if self.context.processing:
-            logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
-            message.status = "error"
-            message.response = "Busy processing another request."
-            yield message
-            return
-
-        self.context.processing = True
-
-        messages = []
-
-        for value in self.llm.chat(
-            model=self.model,
-            messages=messages,
-            #tools=llm_tools(context.tools) if message.enable_tools else None,
-            options={ "num_ctx": message.ctx_size }
-        ):
-            logging.info(f"LLM: {value.status} - {value.response}")
-            if value.status != "done":
-                message.status = value.status
-                message.response = value.response
-                yield message
-            if value.status == "error":
-                return
-            response = value
-
-        message.metadata["eval_count"] += response["eval_count"]
-        message.metadata["eval_duration"] += response["eval_duration"]
-        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
-        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-        agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
-
-        tools_used = []
-
-        yield {"status": "processing", "message": "Initial response received..."}
-
-        if "tool_calls" in response.get("message", {}):
-            yield {"status": "processing", "message": "Processing tool calls..."}
-
-            tool_message = response["message"]
-            tool_result = None
-
-            # Process all yielded items from the handler
-            async for item in self.handle_tool_calls(tool_message):
-                if isinstance(item, tuple) and len(item) == 2:
-                    # This is the final result tuple (tool_result, tools_used)
-                    tool_result, tools_used = item
-                else:
-                    # This is a status update, forward it
-                    yield item
-
-            message_dict = {
-                "role": tool_message.get("role", "assistant"),
-                "content": tool_message.get("content", "")
-            }
-
-            if "tool_calls" in tool_message:
-                message_dict["tool_calls"] = [
-                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
-                    for tc in tool_message["tool_calls"]
-                ]
-
-            pre_add_index = len(messages)
-            messages.append(message_dict)
-
-            if isinstance(tool_result, list):
-                messages.extend(tool_result)
-            else:
-                if tool_result:
-                    messages.append(tool_result)
-
-            message.metadata["tools"] = tools_used
-
-            # Estimate token length of new messages
-            ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
-            yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
-            # Decrease creativity when processing tool call requests
-            response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
-            message.metadata["eval_count"] += response["eval_count"]
-            message.metadata["eval_duration"] += response["eval_duration"]
-            message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
-            message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-            agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
-
-        reply = response["message"]["content"]
-        message.response = reply
-        message.metadata["origin"] = agent.agent_type
-        # final_message = {"role": "assistant", "content": reply }
-
-        # # history is provided to the LLM and should not have additional metadata
-        # llm_history.append(final_message)
-
-        # user_history is provided to the REST API and does not include CONTEXT
-        # It does include metadata
-        # final_message["metadata"] = message.metadata
-        # user_history.append({**final_message, "origin": message.metadata["origin"]})
-
-        # Return the REST API with metadata
-        yield {
-            "status": "done",
-            "message": {
-                **message.model_dump(mode='json'),
-            }
-        }
-
-        self.context.processing = False
-        return
-
-    async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
-        message.full_content = ""
-        for i, p in enumerate(message.preamble.keys()):
-            message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
-
-        # Estimate token length of new messages
-        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
-
-        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
-        message.status = "thinking"
-        yield message
-
-        for value in self.generate_llm_response(message):
-            logging.info(f"LLM: {value.status} - {value.response}")
-            if value.status != "done":
-                yield value
-            if value.status == "error":
-                return
-
-    def get_and_reset_content_seed(self):
-        tmp = self._content_seed
-        self._content_seed = ""
-        return tmp
-
-    def set_content_seed(self, content: str) -> None:
-        """Set the content seed for the agent."""
-        self._content_seed = content
-
-    def get_content_seed(self) -> str:
-        """Get the content seed for the agent."""
-        return self._content_seed
-
-    @classmethod
-    def valid_agent_types(cls) -> set[str]:
-        """Return the set of valid agent_type values."""
-        return set(get_args(cls.__annotations__["agent_type"]))
-
-
-# Register the base agent
-registry.register(Agent.agent_type, Agent)
-
-# Type alias for Agent or any subclass
-AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
-
-import ./agents
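Note: the deleted module above carried several latent bugs that leave with it. prepare_message interpolates an undefined preamble_or_types (the variable actually defined is preamble_types_OR) and reads self.content.user_resume where self.context.user_resume is evidently meant; Message objects are indexed dict-style (message["context"]); the response-accounting lines reference a bare agent inside methods of the agent itself; and process_message iterates the async generator generate_llm_response with a plain for. Its preamble concatenation also has an operator-precedence bug: the conditional expression owns the whole else branch, so the first preamble section is silently dropped. A corrected sketch of that loop, assuming the same Message fields:

# Corrected sketch of the deleted process_message concatenation. In the
# original, '' if i == 0 else '\n\n' + f"<|{p}|>..." parses as
# '' if i == 0 else ('\n\n' + f"..."), so nothing is appended when i == 0.
message.full_content = ""
for i, p in enumerate(message.preamble.keys()):
    separator = "" if i == 0 else "\n\n"
    message.full_content += f"{separator}<|{p}|>{message.preamble[p].strip()}\n"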
@@ -8,7 +8,7 @@ import logging
 if TYPE_CHECKING:
     from .. context import Context
 
-from .types import AgentBase, ContextBase, registry
+from .types import AgentBase, ContextRef, registry
 
 from .. conversation import Conversation
 from .. message import Message
@@ -18,9 +18,18 @@ class Agent(AgentBase):
     Base class for all agent types.
     This class defines the common attributes and methods for all agent types.
     """
+
+    # Agent management with pydantic
     agent_type: Literal["base"] = "base"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
+
+    # Agent properties
+    system_prompt: str # Mandatory
+    conversation: Conversation = Conversation()
+    context_tokens: int = 0
+    context: ContextRef # Avoid circular reference
+    _content_seed: str = PrivateAttr(default="")
 
     def __init_subclass__(cls, **kwargs):
         """Auto-register subclasses"""
         super().__init_subclass__(**kwargs)
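Note: the restructured class groups every pydantic field into one annotated block ahead of the dunder methods and swaps ContextBase for ContextRef, imported from .types. The diff does not show how ContextRef is defined; a minimal sketch of one way such an alias can break the Agent/Context import cycle (everything below beyond the ContextRef name is an assumption):

# Hypothetical contents for .types; the diff only shows that ContextRef
# is imported from there, not how it is defined.
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from .context import Context
    ContextRef = Context   # precise type for static checkers
else:
    ContextRef = Any       # runtime stand-in; avoids importing context at runtime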
@@ -34,12 +43,8 @@ class Agent(AgentBase):
             data['agent_type'] = self.__class__.agent_type
         super().__init__(**data)
 
-    system_prompt: str # Mandatory
-    conversation: Conversation = Conversation()
-    context_tokens: int = 0
-    context: ContextBase # Avoid circular reference
-
-    _content_seed: str = PrivateAttr(default="")
+    def get_agent_type(self):
+        return self._agent_type
 
     async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
         """
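Note: get_agent_type() returns _agent_type, the ClassVar each subclass shadows next to its Literal field; this is the accessor the WebServer hunks above switch to. A sketch of how a subclass participates in the pattern (the Resume class is illustrative, not part of this diff; only the "resume" type string appears elsewhere in the commit):

from typing import ClassVar, Literal

class Resume(Agent):
    # Follows the discriminator pattern that base.py and chat.py use.
    agent_type: Literal["resume"] = "resume"
    _agent_type: ClassVar[str] = agent_type   # shadowed per subclass
    # Agent.__init_subclass__ auto-registers: registry.register("resume", Resume)

agent = Resume(system_prompt="You answer questions about a resume.")
assert agent.get_agent_type() == "resume"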
@@ -4,7 +4,7 @@ from typing_extensions import Annotated
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
 import logging
-from .base import Agent, ContextBase, registry
+from .base import Agent, registry
 from .. conversation import Conversation
 from .. message import Message
 
@@ -16,13 +16,6 @@ class Chat(Agent, ABC):
     agent_type: Literal["chat"] = "chat"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
-
-    system_prompt: str # Mandatory
-    conversation: Conversation = Conversation()
-    context_tokens: int = 0
-    context: ContextBase
-
-    _content_seed: str = PrivateAttr(default="")
 
     async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
         """
         Prepare message with context information in message.preamble
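Note: this hunk deletes the fields Chat re-declared but already inherits from Agent (system_prompt, conversation, context_tokens, context, _content_seed); with pydantic, inherited fields need no restating, so only the discriminator pair remains. The resulting shape, per the surviving lines of the hunk (method bodies elided):

from abc import ABC
from typing import ClassVar, Literal

class Chat(Agent, ABC):
    agent_type: Literal["chat"] = "chat"
    _agent_type: ClassVar[str] = agent_type  # registration discriminator
    # system_prompt, conversation, context_tokens, context, and
    # _content_seed are all inherited from Agent now.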
@@ -9,7 +9,7 @@ from .message import Message
 from .rag import ChromaDBFileWatcher
 from . import defines
 
-from .agents import Agent, ContextBase
+from .agents import Agent
 
 # Import only agent types, not actual classes
 if TYPE_CHECKING:
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
 
     from .agents import AnyAgent
 
-class Context(ContextBase):
+class Context(BaseModel):
     model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
 
     id: str = Field(
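Note: Context now derives from pydantic's BaseModel directly rather than ContextBase, keeping arbitrary_types_allowed so the non-pydantic ChromaDBFileWatcher can sit on the model. A small illustration of why the flag matters (the file_watcher field name is an assumption; only the class line and config line appear in the hunk):

from pydantic import BaseModel

class ChromaDBFileWatcher:        # stand-in for the real watcher class
    pass

class Context(BaseModel):
    model_config = {"arbitrary_types_allowed": True}
    # Without the flag above, pydantic raises at class-definition time
    # because ChromaDBFileWatcher provides no schema of its own.
    file_watcher: ChromaDBFileWatcher | None = None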
@@ -6,6 +6,10 @@ class Message(BaseModel):
     # Required
     prompt: str # Query to be answered
 
+    # Tunables
+    disable_rag: bool = False
+    disable_tools: bool = False
+
     # Generated while processing message
     preamble: dict[str,str] = {} # Preamble to be prepended to the prompt
     system_prompt: str = "" # System prompt provided to the LLM
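Note: the new Tunables block lets callers switch RAG and tool use off per message. Given the fields visible in this hunk, construction looks like this (a usage sketch; no fields beyond those shown are assumed):

# Usage sketch based only on the fields visible above.
message = Message(prompt="Summarize this job description.", disable_rag=True)
assert message.disable_rag is True
assert message.disable_tools is False
assert message.preamble == {} and message.system_prompt == ""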