diff --git a/src/server.py b/src/server.py
index d1cfa19..61ab2f0 100644
--- a/src/server.py
+++ b/src/server.py
@@ -978,8 +978,40 @@ class WebServer:
         if not self.file_watcher:
             return
 
-        if agent.agent_type == "Chat":
-            agent.proces
+        agent_type = agent.get_agent_type()
+        logging.info(f"generate_response: {agent_type}")
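+        # The chat path runs the agent as a staged pipeline of async
+        # generators: prepare_message -> process_message ->
+        # generate_llm_response. Interim (non-"done") updates are forwarded
+        # to the caller, and an "error" status from any stage aborts the
+        # request after yielding the failed message.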
case "resume": @@ -1363,7 +1395,7 @@ Use the above <|resume|> and <|job_description|> to answer this query: reply = response["message"]["content"] message.response = reply - message.metadata["origin"] = agent.agent_type + message.metadata["origin"] = agent.get_agent_type() # final_message = {"role": "assistant", "content": reply } # # history is provided to the LLM and should not have additional metadata diff --git a/src/utils/agent.py b/src/utils/agent.py deleted file mode 100644 index e723f54..0000000 --- a/src/utils/agent.py +++ /dev/null @@ -1,256 +0,0 @@ -from pydantic import BaseModel, Field, model_validator, PrivateAttr -from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar -from abc import ABC, abstractmethod -from typing_extensions import Annotated -import logging - -from .types import AgentBase, registry - -# Only import Context for type checking -if TYPE_CHECKING: - from .context import Context - -from .types import AgentBase - -from .conversation import Conversation -from .message import Message - -class Agent(AgentBase): - """ - Base class for all agent types. - This class defines the common attributes and methods for all agent types. - """ - agent_type: str = Field(default="agent", const=True) # discriminator value - - - def __init_subclass__(cls, **kwargs): - """Auto-register subclasses""" - super().__init_subclass__(**kwargs) - # Register this class if it has an agent_type - if hasattr(cls, 'agent_type') and cls.agent_type != Agent.agent_type: - registry.register(cls.agent_type, cls) - - def __init__(self, **data): - # Set agent_type from class if not provided - if 'agent_type' not in data: - data['agent_type'] = self.__class__.agent_type - super().__init__(**data) - - system_prompt: str # Mandatory - conversation: Conversation = Conversation() - context_tokens: int = 0 - - # Add a property for context if needed without creating a circular reference - @property - def context(self) -> Optional['Context']: - if TYPE_CHECKING: - from .context import Context - # Implement logic to fetch context by ID if needed - return None - - #context: Context - - _content_seed: str = PrivateAttr(default="") - - async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: - """ - Prepare message with context information in message.preamble - """ - # Generate RAG content if enabled, based on the content - rag_context = "" - if not message.disable_rag: - # Gather RAG results, yielding each result - # as it becomes available - for value in self.context.generate_rag_results(message): - logging.info(f"RAG: {value.status} - {value.response}") - if value.status != "done": - yield value - if value.status == "error": - message.status = "error" - message.response = value.response - yield message - return - - if message.metadata["rag"]: - for rag_collection in message.metadata["rag"]: - for doc in rag_collection["documents"]: - rag_context += f"{doc}\n" - - if rag_context: - message["context"] = rag_context - - if self.context.user_resume: - message["resume"] = self.content.user_resume - - if message.preamble: - preamble_types = [f"<|{p}|>" for p in message.preamble.keys()] - preamble_types_AND = " and ".join(preamble_types) - preamble_types_OR = " or ".join(preamble_types) - message.preamble["rules"] = f"""\ -- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} 
+        message.metadata["origin"] = agent.get_agent_type()
         # final_message = {"role": "assistant", "content": reply }
 
         # # history is provided to the LLM and should not have additional metadata
diff --git a/src/utils/agent.py b/src/utils/agent.py
deleted file mode 100644
index e723f54..0000000
--- a/src/utils/agent.py
+++ /dev/null
@@ -1,256 +0,0 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
-import logging
-
-from .types import AgentBase, registry
-
-# Only import Context for type checking
-if TYPE_CHECKING:
-    from .context import Context
-
-from .types import AgentBase
-
-from .conversation import Conversation
-from .message import Message
-
-class Agent(AgentBase):
-    """
-    Base class for all agent types.
-    This class defines the common attributes and methods for all agent types.
-    """
-    agent_type: str = Field(default="agent", const=True) # discriminator value
-
-
-    def __init_subclass__(cls, **kwargs):
-        """Auto-register subclasses"""
-        super().__init_subclass__(**kwargs)
-        # Register this class if it has an agent_type
-        if hasattr(cls, 'agent_type') and cls.agent_type != Agent.agent_type:
-            registry.register(cls.agent_type, cls)
-
-    def __init__(self, **data):
-        # Set agent_type from class if not provided
-        if 'agent_type' not in data:
-            data['agent_type'] = self.__class__.agent_type
-        super().__init__(**data)
-
-    system_prompt: str # Mandatory
-    conversation: Conversation = Conversation()
-    context_tokens: int = 0
-
-    # Add a property for context if needed without creating a circular reference
-    @property
-    def context(self) -> Optional['Context']:
-        if TYPE_CHECKING:
-            from .context import Context
-        # Implement logic to fetch context by ID if needed
-        return None
-
-    #context: Context
-
-    _content_seed: str = PrivateAttr(default="")
-
-    async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
-        """
-        Prepare message with context information in message.preamble
-        """
-        # Generate RAG content if enabled, based on the content
-        rag_context = ""
-        if not message.disable_rag:
-            # Gather RAG results, yielding each result
-            # as it becomes available
-            for value in self.context.generate_rag_results(message):
-                logging.info(f"RAG: {value.status} - {value.response}")
-                if value.status != "done":
-                    yield value
-                if value.status == "error":
-                    message.status = "error"
-                    message.response = value.response
-                    yield message
-                    return
-
-            if message.metadata["rag"]:
-                for rag_collection in message.metadata["rag"]:
-                    for doc in rag_collection["documents"]:
-                        rag_context += f"{doc}\n"
-
-        if rag_context:
-            message["context"] = rag_context
-
-        if self.context.user_resume:
-            message["resume"] = self.content.user_resume
-
-        if message.preamble:
-            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
-            preamble_types_AND = " and ".join(preamble_types)
-            preamble_types_OR = " or ".join(preamble_types)
-            message.preamble["rules"] = f"""\
-- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly.
-- If there is no information in these sections, answer based on your knowledge.
-- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
-"""
-            message.preamble["question"] = "Use that information to respond to:"
-        else:
-            message.preamble["question"] = "Respond to:"
-
-        message.system_prompt = self.system_prompt
-        message.status = "done"
-        yield message
-        return
-
-    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
-        if self.context.processing:
-            logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
-            message.status = "error"
-            message.response = "Busy processing another request."
-            yield message
-            return
-
-        self.context.processing = True
-
-        messages = []
-
-        for value in self.llm.chat(
-            model=self.model,
-            messages=messages,
-            #tools=llm_tools(context.tools) if message.enable_tools else None,
-            options={ "num_ctx": message.ctx_size }
-        ):
-            logging.info(f"LLM: {value.status} - {value.response}")
-            if value.status != "done":
-                message.status = value.status
-                message.response = value.response
-                yield message
-            if value.status == "error":
-                return
-            response = value
-
-        message.metadata["eval_count"] += response["eval_count"]
-        message.metadata["eval_duration"] += response["eval_duration"]
-        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
-        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-        agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
-
-        tools_used = []
-
-        yield {"status": "processing", "message": "Initial response received..."}
-
-        if "tool_calls" in response.get("message", {}):
-            yield {"status": "processing", "message": "Processing tool calls..."}
-
-            tool_message = response["message"]
-            tool_result = None
-
-            # Process all yielded items from the handler
-            async for item in self.handle_tool_calls(tool_message):
-                if isinstance(item, tuple) and len(item) == 2:
-                    # This is the final result tuple (tool_result, tools_used)
-                    tool_result, tools_used = item
-                else:
-                    # This is a status update, forward it
-                    yield item
-
-            message_dict = {
-                "role": tool_message.get("role", "assistant"),
-                "content": tool_message.get("content", "")
-            }
-
-            if "tool_calls" in tool_message:
-                message_dict["tool_calls"] = [
-                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
-                    for tc in tool_message["tool_calls"]
-                ]
-
-            pre_add_index = len(messages)
-            messages.append(message_dict)
-
-            if isinstance(tool_result, list):
-                messages.extend(tool_result)
-            else:
-                if tool_result:
-                    messages.append(tool_result)
-
-            message.metadata["tools"] = tools_used
-
-            # Estimate token length of new messages
-            ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
-            yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
-            # Decrease creativity when processing tool call requests
-            response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
-            message.metadata["eval_count"] += response["eval_count"]
-            message.metadata["eval_duration"] += response["eval_duration"]
-            message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
-            message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-            agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
-
-        reply = response["message"]["content"]
-        message.response = reply
-        message.metadata["origin"] = agent.agent_type
-        # final_message = {"role": "assistant", "content": reply }
-
-        # # history is provided to the LLM and should not have additional metadata
-        # llm_history.append(final_message)
-
-        # user_history is provided to the REST API and does not include CONTEXT
-        # It does include metadata
-        # final_message["metadata"] = message.metadata
-        # user_history.append({**final_message, "origin": message.metadata["origin"]})
-
-        # Return the REST API with metadata
-        yield {
-            "status": "done",
-            "message": {
-                **message.model_dump(mode='json'),
-            }
-        }
-
-        self.context.processing = False
-        return
-
-    async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
-        message.full_content = ""
-        for i, p in enumerate(message.preamble.keys()):
-            message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
-
-        # Estimate token length of new messages
-        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
-
-        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
-        message.status = "thinking"
-        yield message
-
-        for value in self.generate_llm_response(message):
-            logging.info(f"LLM: {value.status} - {value.response}")
-            if value.status != "done":
-                yield value
-            if value.status == "error":
-                return
-
-    def get_and_reset_content_seed(self):
-        tmp = self._content_seed
-        self._content_seed = ""
-        return tmp
-
-    def set_content_seed(self, content: str) -> None:
-        """Set the content seed for the agent."""
-        self._content_seed = content
-
-    def get_content_seed(self) -> str:
-        """Get the content seed for the agent."""
-        return self._content_seed
-
-    @classmethod
-    def valid_agent_types(cls) -> set[str]:
-        """Return the set of valid agent_type values."""
-        return set(get_args(cls.__annotations__["agent_type"]))
-
-
-# Register the base agent
-registry.register(Agent.agent_type, Agent)
-
-# Type alias for Agent or any subclass
-AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
-
-import ./agents
\ No newline at end of file
diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py
index 489673d..f77deb3 100644
--- a/src/utils/agents/base.py
+++ b/src/utils/agents/base.py
@@ -8,7 +8,7 @@ import logging
 if TYPE_CHECKING:
     from .. context import Context
 
-from .types import AgentBase, ContextBase, registry
+from .types import AgentBase, ContextRef, registry
 from .. conversation import Conversation
 from .. message import Message
 
@@ -18,9 +18,18 @@ class Agent(AgentBase):
     Base class for all agent types.
     This class defines the common attributes and methods for all agent types.
     """
+
+    # Agent management with pydantic
     agent_type: Literal["base"] = "base"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
 
+    # Agent properties
+    system_prompt: str # Mandatory
+    conversation: Conversation = Conversation()
+    context_tokens: int = 0
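+    # ContextRef (defined in .types) appears to stand in for the concrete
+    # Context type so this module need not import context.py, breaking the
+    # Agent <-> Context circular dependency.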
response["eval_count"] - - reply = response["message"]["content"] - message.response = reply - message.metadata["origin"] = agent.agent_type - # final_message = {"role": "assistant", "content": reply } - - # # history is provided to the LLM and should not have additional metadata - # llm_history.append(final_message) - - # user_history is provided to the REST API and does not include CONTEXT - # It does include metadata - # final_message["metadata"] = message.metadata - # user_history.append({**final_message, "origin": message.metadata["origin"]}) - - # Return the REST API with metadata - yield { - "status": "done", - "message": { - **message.model_dump(mode='json'), - } - } - - self.context.processing = False - return - - async def process_message(self, message:Message) -> AsyncGenerator[Message, None]: - message.full_content = "" - for i, p in enumerate(message.preamble.keys()): - message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n" - - # Estimate token length of new messages - message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content) - - message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..." - message.status = "thinking" - yield message - - for value in self.generate_llm_response(message): - logging.info(f"LLM: {value.status} - {value.response}") - if value.status != "done": - yield value - if value.status == "error": - return - - def get_and_reset_content_seed(self): - tmp = self._content_seed - self._content_seed = "" - return tmp - - def set_content_seed(self, content: str) -> None: - """Set the content seed for the agent.""" - self._content_seed = content - - def get_content_seed(self) -> str: - """Get the content seed for the agent.""" - return self._content_seed - - @classmethod - def valid_agent_types(cls) -> set[str]: - """Return the set of valid agent_type values.""" - return set(get_args(cls.__annotations__["agent_type"])) - - -# Register the base agent -registry.register(Agent.agent_type, Agent) - -# Type alias for Agent or any subclass -AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses - -import ./agents \ No newline at end of file diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 489673d..f77deb3 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -8,7 +8,7 @@ import logging if TYPE_CHECKING: from .. context import Context -from .types import AgentBase, ContextBase, registry +from .types import AgentBase, ContextRef, registry from .. conversation import Conversation from .. message import Message @@ -18,9 +18,18 @@ class Agent(AgentBase): Base class for all agent types. This class defines the common attributes and methods for all agent types. 
""" + + # Agent management with pydantic agent_type: Literal["base"] = "base" _agent_type: ClassVar[str] = agent_type # Add this for registration + # Agent properties + system_prompt: str # Mandatory + conversation: Conversation = Conversation() + context_tokens: int = 0 + context: ContextRef # Avoid circular reference + _content_seed: str = PrivateAttr(default="") + def __init_subclass__(cls, **kwargs): """Auto-register subclasses""" super().__init_subclass__(**kwargs) @@ -34,12 +43,8 @@ class Agent(AgentBase): data['agent_type'] = self.__class__.agent_type super().__init__(**data) - system_prompt: str # Mandatory - conversation: Conversation = Conversation() - context_tokens: int = 0 - context: ContextBase # Avoid circular reference - - _content_seed: str = PrivateAttr(default="") + def get_agent_type(self): + return self._agent_type async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: """ diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py index 14475e2..eb1f492 100644 --- a/src/utils/agents/chat.py +++ b/src/utils/agents/chat.py @@ -4,7 +4,7 @@ from typing_extensions import Annotated from abc import ABC, abstractmethod from typing_extensions import Annotated import logging -from .base import Agent, ContextBase, registry +from .base import Agent, registry from .. conversation import Conversation from .. message import Message @@ -16,13 +16,6 @@ class Chat(Agent, ABC): agent_type: Literal["chat"] = "chat" _agent_type: ClassVar[str] = agent_type # Add this for registration - system_prompt: str # Mandatory - conversation: Conversation = Conversation() - context_tokens: int = 0 - context: ContextBase - - _content_seed: str = PrivateAttr(default="") - async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: """ Prepare message with context information in message.preamble diff --git a/src/utils/context.py b/src/utils/context.py index ecd4d8c..a9582ab 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -9,7 +9,7 @@ from .message import Message from .rag import ChromaDBFileWatcher from . import defines -from .agents import Agent, ContextBase +from .agents import Agent # Import only agent types, not actual classes if TYPE_CHECKING: @@ -17,7 +17,7 @@ if TYPE_CHECKING: from .agents import AnyAgent -class Context(ContextBase): +class Context(BaseModel): model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher id: str = Field( diff --git a/src/utils/message.py b/src/utils/message.py index 37c1f19..6647566 100644 --- a/src/utils/message.py +++ b/src/utils/message.py @@ -6,6 +6,10 @@ class Message(BaseModel): # Required prompt: str # Query to be answered + # Tunables + disable_rag: bool = False + disable_tools: bool = False + # Generated while processing message preamble: dict[str,str] = {} # Preamble to be prepended to the prompt system_prompt: str = "" # System prompt provided to the LLM