From 3094288e46756fd6f4a2b5b9daf3ce85e1284252 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Wed, 30 Apr 2025 16:05:46 -0700 Subject: [PATCH] onion peeling --- src/server.py | 12 ++++-- src/utils/__init__.py | 14 +++---- src/utils/agents/base.py | 1 - src/utils/agents/chat.py | 15 +++++--- src/utils/context.py | 80 ++++++++++++++++++++-------------------- src/utils/message.py | 2 +- 6 files changed, 67 insertions(+), 57 deletions(-) diff --git a/src/server.py b/src/server.py index 319305b..cede209 100644 --- a/src/server.py +++ b/src/server.py @@ -1,5 +1,7 @@ from utils import logger +from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar + # %% # Imports [standard] # Standard library modules (no try-except needed) @@ -647,7 +649,7 @@ class WebServer: async def flush_generator(): async for message in self.generate_response(context=context, agent=agent, content=data["content"]): # Convert to JSON and add newline - yield (message.model_dump_json()) + "\n" + yield str(message) + "\n" # Save the history as its generated self.save_context(context_id) # Explicitly flush after each yield @@ -797,7 +799,9 @@ class WebServer: logger.info("JSON parsed successfully, attempting model validation") # Now try Pydantic validation - self.contexts[context_id] = Context.from_json(json_data, file_watcher=self.file_watcher) + self.contexts[context_id] = Context.model_validate_json(content) + self.contexts[context_id].file_watcher=self.file_watcher + logger.info(f"Successfully loaded context {context_id}") except json.JSONDecodeError as e: logger.error(f"Invalid JSON in file: {e}") @@ -983,7 +987,7 @@ class WebServer: # * First message sets Fact Check and is Q&A # * Has content # * Then Q&A of Fact Check - async def generate_response(self, context : Context, agent : Agent, content : str): + async def generate_response(self, context : Context, agent : Agent, content : str) -> Generator[Message, Any, None]: if not self.file_watcher: raise Exception("File watcher not initialized") @@ -1469,7 +1473,7 @@ def main(): args = parse_args() # Setup logging based on the provided level - logger.setLevel(args.level) + logger.setLevel(args.level.upper()) warnings.filterwarnings( "ignore", diff --git a/src/utils/__init__.py b/src/utils/__init__.py index a06161c..1bbcc4f 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -40,12 +40,12 @@ def rebuild_models(): module = importlib.import_module(module_name) cls = getattr(module, class_name, None) - logger.info(f"Checking: {class_name} in module {module_name}") - logger.info(f" cls: {True if cls else False}") - logger.info(f" isinstance(cls, type): {isinstance(cls, type)}") - logger.info(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}") - logger.info(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}") - logger.info(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}") + logger.debug(f"Checking: {class_name} in module {module_name}") + logger.debug(f" cls: {True if cls else False}") + logger.debug(f" isinstance(cls, type): {isinstance(cls, type)}") + logger.debug(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}") + logger.debug(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}") + logger.debug(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}") if ( cls @@ -54,7 +54,7 @@ def rebuild_models(): and issubclass(cls, AnyAgent) and cls is not AnyAgent ): - logger.info(f"Rebuilding {class_name} from {module_name}") + logger.debug(f"Rebuilding {class_name} from {module_name}") from . agents import Agent from . context import Context cls.model_rebuild() diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 2bf3b55..80e55a2 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -37,7 +37,6 @@ class Agent(BaseModel, ABC): # Class and pydantic model management def __init_subclass__(cls, **kwargs): """Auto-register subclasses""" - logger.info(f"Agent.__init_subclass__({kwargs})") super().__init_subclass__(**kwargs) # Register this class if it has an agent_type if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type: diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py index d248b2e..f8b479c 100644 --- a/src/utils/agents/chat.py +++ b/src/utils/agents/chat.py @@ -21,6 +21,9 @@ class Chat(Agent, ABC): """ Prepare message with context information in message.preamble """ + if not self.context: + raise ValueError("Context is not set for this agent.") + # Generate RAG content if enabled, based on the content rag_context = "" if not message.disable_rag: @@ -37,22 +40,24 @@ class Chat(Agent, ABC): return if "rag" in message.metadata and message.metadata["rag"]: - for rag_collection in message.metadata["rag"]: - for doc in rag_collection["documents"]: + for rag in message.metadata["rag"]: + for doc in rag["documents"]: rag_context += f"{doc}\n" + message.preamble = {} + if rag_context: - message["context"] = rag_context + message.preamble["context"] = rag_context if self.context.user_resume: - message["resume"] = self.content.user_resume + message.preamble["resume"] = self.context.user_resume if message.preamble: preamble_types = [f"<|{p}|>" for p in message.preamble.keys()] preamble_types_AND = " and ".join(preamble_types) preamble_types_OR = " or ".join(preamble_types) message.preamble["rules"] = f"""\ -- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly. +- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly. - If there is no information in these sections, answer based on your knowledge. - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}. """ diff --git a/src/utils/context.py b/src/utils/context.py index b0e11a3..c2d5c4b 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -39,11 +39,11 @@ class Context(BaseModel): default_factory=list ) - @model_validator(mode="before") - @classmethod - def before_model_validator(cls, values: Any): - logger.info(f"Preparing model data: {cls} {values}") - return values + # @model_validator(mode="before") + # @classmethod + # def before_model_validator(cls, values: Any): + # logger.info(f"Preparing model data: {cls} {values}") + # return values @model_validator(mode="after") def after_model_validator(self): @@ -71,49 +71,51 @@ class Context(BaseModel): A list of dictionaries containing the RAG results. """ try: - message.status = "processing" + message.status = "processing" - entries : int = 0 + entries : int = 0 - if not self.file_watcher: - message.response = "No RAG context available." - del message.metadata["rag"] - message.status = "done" - yield message - return + if not self.file_watcher: + message.response = "No RAG context available." + del message.metadata["rag"] + message.status = "done" + yield message + return - for rag in self.rags: - if not rag["enabled"]: - continue - message.response = f"Checking RAG context {rag['name']}..." - yield message - chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10) - if chroma_results: - entries += len(chroma_results["documents"]) + message.metadata["rag"] = [] + for rag in self.rags: + if not rag["enabled"]: + continue + message.response = f"Checking RAG context {rag['name']}..." + yield message + chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10) + if chroma_results: + entries += len(chroma_results["documents"]) - chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape - print(f"Chroma embedding shape: {chroma_embedding.shape}") + chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape + print(f"Chroma embedding shape: {chroma_embedding.shape}") - umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() - print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output + umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() + print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output - umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() - print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output + umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() + print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output - message.metadata["rag"][rag["name"]] = { - **chroma_results, - "umap_embedding_2d": umap_2d, - "umap_embedding_3d": umap_3d - } - yield message + message.metadata["rag"].append({ + "name": rag["name"], + **chroma_results, + "umap_embedding_2d": umap_2d, + "umap_embedding_3d": umap_3d + }) + yield message - if entries == 0: - del message.metadata["rag"] + if entries == 0: + del message.metadata["rag"] - message.response = f"RAG context gathered from results from {entries} documents." - message.status = "done" - yield message - return + message.response = f"RAG context gathered from results from {entries} documents." + message.status = "done" + yield message + return except Exception as e: message.response = f"Error generating RAG results: {str(e)}" message.status = "error" diff --git a/src/utils/message.py b/src/utils/message.py index 1fefee5..8f5688b 100644 --- a/src/utils/message.py +++ b/src/utils/message.py @@ -17,7 +17,7 @@ class Message(BaseModel): full_content: str = "" # Full content of the message (preamble + prompt) response: str = "" # LLM response to the preamble + query metadata: dict[str, Any] = { - "rag": { "documents": [] }, + "rag": List[dict[str, Any]], "tools": [], "eval_count": 0, "eval_duration": 0,