From e607e3a2f23dbb3f7d7f26ca28a63b53d2111581 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Wed, 30 Apr 2025 12:57:51 -0700 Subject: [PATCH] Starting to work again --- src/server.py | 165 +++++++------ src/utils/agents/base.py | 32 ++- src/utils/agents/chat.py | 345 ++++++++++++++-------------- src/utils/agents/fact_check.py | 18 +- src/utils/agents/job_description.py | 16 +- src/utils/agents/resume.py | 28 +-- src/utils/context.py | 37 ++- src/utils/message.py | 1 + 8 files changed, 351 insertions(+), 291 deletions(-) diff --git a/src/server.py b/src/server.py index 2a91789..2b9001d 100644 --- a/src/server.py +++ b/src/server.py @@ -1,5 +1,10 @@ +import os +os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" + import warnings warnings.filterwarnings("ignore", message="Overriding a previously registered kernel") +warnings.filterwarnings("ignore", message="Warning only once for all operators") +warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv") # %% # Imports [standard] @@ -37,6 +42,7 @@ try_import("sklearn") import ollama import requests from bs4 import BeautifulSoup +from contextlib import asynccontextmanager from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse from fastapi.middleware.cors import CORSMiddleware @@ -363,8 +369,23 @@ def llm_tools(tools): # %% class WebServer: + @asynccontextmanager + async def lifespan(self, app: FastAPI): + # Start the file watcher + self.observer, self.file_watcher = Rag.start_file_watcher( + llm=self.llm, + watch_directory=defines.doc_dir, + recreate=False # Don't recreate if exists + ) + logging.info(f"API started with {self.file_watcher.collection.count()} documents in the collection") + yield + if self.observer: + self.observer.stop() + self.observer.join() + logging.info("File watcher stopped") + def __init__(self, llm, model=MODEL_NAME): - self.app = FastAPI() + self.app = FastAPI(lifespan=self.lifespan) self.contexts = {} self.llm = llm self.model = model @@ -389,24 +410,6 @@ class WebServer: allow_headers=["*"], ) - @self.app.on_event("startup") - async def startup_event(): - # Start the file watcher - self.observer, self.file_watcher = Rag.start_file_watcher( - llm=llm, - watch_directory=defines.doc_dir, - recreate=False # Don't recreate if exists - ) - - print(f"API started with {self.file_watcher.collection.count()} documents in the collection") - - @self.app.on_event("shutdown") - async def shutdown_event(): - if self.observer: - self.observer.stop() - self.observer.join() - print("File watcher stopped") - self.setup_routes() def setup_routes(self): @@ -444,14 +447,16 @@ class WebServer: return JSONResponse(result) except Exception as e: - logging.error(e) + logging.error(f"put_umap error: {str(e)}") + import traceback + logging.error(traceback.format_exc()) return JSONResponse({"error": str(e)}, 500) @self.app.put("/api/similarity/{context_id}") async def put_similarity(context_id: str, request: Request): logging.info(f"{request.method} {request.url.path}") if not self.file_watcher: - return + raise Exception("File watcher not initialized") if not is_valid_uuid(context_id): logging.warning(f"Invalid context_id: {context_id}") @@ -471,13 +476,13 @@ class WebServer: return JSONResponse({"error": "No results found"}, status_code=404) chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape - print(f"Chroma embedding shape: {chroma_embedding.shape}") + logging.info(f"Chroma embedding shape: 
{chroma_embedding.shape}") umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() - print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output + logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() - print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output + logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output return JSONResponse({ **chroma_results, @@ -666,7 +671,7 @@ class WebServer: async def flush_generator(): async for message in self.generate_response(context=context, agent=agent, content=data["content"]): # Convert to JSON and add newline - yield json.dumps(message) + "\n" + yield (message.model_dump_json()) + "\n" # Save the history as its generated self.save_context(context_id) # Explicitly flush after each yield @@ -704,7 +709,9 @@ class WebServer: logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.") return agent.conversation except Exception as e: - logging.error(f"Error in get_history: {e}") + logging.error(f"get_history error: {str(e)}") + import traceback + logging.error(traceback.format_exc()) return JSONResponse({"error": str(e)}, status_code=404) @self.app.get("/api/tools/{context_id}") @@ -759,52 +766,73 @@ class WebServer: logging.info(f"Serve index.html for {path}") return FileResponse(os.path.join(defines.static_content, "index.html")) - def save_context(self, agent_id): + def save_context(self, context_id): """ Serialize a Python dictionary to a file in the agents directory. Args: data: Dictionary containing the agent data - agent_id: UUID string for the context. If it doesn't exist, it is created + context_id: UUID string for the context. If it doesn't exist, it is created Returns: - The agent_id used for the file + The context_id used for the file """ - context = self.upsert_context(agent_id) + context = self.upsert_context(context_id) # Create agents directory if it doesn't exist if not os.path.exists(defines.context_dir): os.makedirs(defines.context_dir) # Create the full file path - file_path = os.path.join(defines.context_dir, agent_id) + file_path = os.path.join(defines.context_dir, context_id) # Serialize the data to JSON and write to file with open(file_path, "w") as f: f.write(context.model_dump_json()) - return agent_id + return context_id - def load_context(self, agent_id) -> Context: + def load_or_create_context(self, context_id) -> Context: """ - Load a context from a file in the agents directory. + Load a context from a file in the context directory or create a new one if it doesn't exist. Args: - agent_id: UUID string for the context. If it doesn't exist, a new context is created. + context_id: UUID string for the context. Returns: A Context object with the specified ID and default settings. """ + if not self.file_watcher: + raise Exception("File watcher not initialized") - file_path = os.path.join(defines.context_dir, agent_id) + file_path = os.path.join(defines.context_dir, context_id) # Check if the file exists if not os.path.exists(file_path): - self.contexts[agent_id] = self.create_context(agent_id) + logging.info(f"Context file {file_path} not found. 
Creating new context.") + self.contexts[context_id] = self.create_context(context_id) else: # Read and deserialize the data with open(file_path, "r") as f: - self.contexts[agent_id] = Context.model_validate_json(f.read()) + content = f.read() + logging.info(f"Loading context from {file_path}, content length: {len(content)}") + try: + # Try parsing as JSON first to ensure valid JSON + import json + json_data = json.loads(content) + logging.info("JSON parsed successfully, attempting model validation") + + # Now try Pydantic validation + self.contexts[context_id] = Context.from_json(json_data, file_watcher=self.file_watcher) + logging.info(f"Successfully loaded context {context_id}") + except json.JSONDecodeError as e: + logging.error(f"Invalid JSON in file: {e}") + except Exception as e: + logging.error(f"Error validating context: {str(e)}") + import traceback + logging.error(traceback.format_exc()) + # Fallback to creating a new context + self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher) - return self.contexts[agent_id] + return self.contexts[context_id] def create_context(self, context_id = None) -> Context: """ @@ -814,7 +842,11 @@ class WebServer: Returns: A Context object with the specified ID and default settings. """ - context = Context(id=context_id) + if not self.file_watcher: + raise Exception("File watcher not initialized") + + logging.info(f"Creating new context with ID: {context_id}") + context = Context(id=context_id, file_watcher=self.file_watcher) if os.path.exists(defines.resume_doc): context.user_resume = open(defines.resume_doc, "r").read() @@ -911,42 +943,40 @@ class WebServer: if not context_id: logging.warning("No context ID provided. Creating a new context.") return self.create_context() - - if not is_valid_uuid(context_id): - logging.info(f"User requested invalid context_id: {context_id}") - raise ValueError("Invalid context_id: {context_id}") if context_id in self.contexts: return self.contexts[context_id] - logging.info(f"Context {context_id} not found. Creating new context.") - return self.load_context(context_id) + logging.info(f"Context {context_id} is not yet loaded.") + return self.load_or_create_context(context_id) def generate_rag_results(self, context, content): + if not self.file_watcher: + raise Exception("File watcher not initialized") + results_found = False - if self.file_watcher: - for rag in context.rags: - if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now... - yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."} - chroma_results = self.file_watcher.find_similar(query=content, top_k=10) - if chroma_results: - results_found = True - chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape - print(f"Chroma embedding shape: {chroma_embedding.shape}") + for rag in context.rags: + if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now... 
+ yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."} + chroma_results = self.file_watcher.find_similar(query=content, top_k=10) + if chroma_results: + results_found = True + chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape + logging.info(f"Chroma embedding shape: {chroma_embedding.shape}") - umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() - print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output + umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() + logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output - umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() - print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output + umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() + logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output - yield { - **chroma_results, - "name": rag["name"], - "umap_embedding_2d": umap_2d, - "umap_embedding_3d": umap_3d - } + yield { + **chroma_results, + "name": rag["name"], + "umap_embedding_2d": umap_2d, + "umap_embedding_3d": umap_3d + } if not results_found: yield {"status": "complete", "message": "No RAG context found"} @@ -979,7 +1009,7 @@ class WebServer: # * Then Q&A of Fact Check async def generate_response(self, context : Context, agent : Agent, content : str): if not self.file_watcher: - return + raise Exception("File watcher not initialized") agent_type = agent.get_agent_type() logging.info(f"generate_response: {agent_type}") @@ -1014,7 +1044,8 @@ class WebServer: return logging.info("TODO: There is more to do...") return - + + return if self.processing: logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 7e06312..282e905 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -1,4 +1,5 @@ -from pydantic import BaseModel, Field, model_validator, PrivateAttr +from __future__ import annotations +from pydantic import BaseModel, model_validator, PrivateAttr, Field from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef from abc import ABC, abstractmethod from typing_extensions import Annotated @@ -8,8 +9,6 @@ import logging if TYPE_CHECKING: from .. context import Context -ContextRef = ForwardRef('Context') - from .types import AgentBase, registry from .. 
conversation import Conversation @@ -29,9 +28,11 @@ class Agent(AgentBase): system_prompt: str # Mandatory conversation: Conversation = Conversation() context_tokens: int = 0 - context: ContextRef # Avoid circular reference + context: object = Field(..., exclude=True) # Avoid circular reference, require as param, and prevent serialization + _content_seed: str = PrivateAttr(default="") + # Class and pydantic model management def __init_subclass__(cls, **kwargs): """Auto-register subclasses""" super().__init_subclass__(**kwargs) @@ -48,6 +49,24 @@ class Agent(AgentBase): self.__class__.model_rebuild() super().__init__(**data) + def model_dump(self, *args, **kwargs): + # Ensure context is always excluded, even with exclude_unset=True + kwargs.setdefault("exclude", set()) + if isinstance(kwargs["exclude"], set): + kwargs["exclude"].add("context") + elif isinstance(kwargs["exclude"], dict): + kwargs["exclude"]["context"] = True + return super().model_dump(*args, **kwargs) + + @classmethod + def valid_agent_types(cls) -> set[str]: + """Return the set of valid agent_type values.""" + return set(get_args(cls.__annotations__["agent_type"])) + + def set_context(self, context): + object.__setattr__(self, "context", context) + + # Agent methods def get_agent_type(self): return self._agent_type @@ -240,11 +259,6 @@ class Agent(AgentBase): """Get the content seed for the agent.""" return self._content_seed - @classmethod - def valid_agent_types(cls) -> set[str]: - """Return the set of valid agent_type values.""" - return set(get_args(cls.__annotations__["agent_type"])) - # Register the base agent registry.register(Agent._agent_type, Agent) diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py index eb1f492..d248b2e 100644 --- a/src/utils/agents/chat.py +++ b/src/utils/agents/chat.py @@ -1,4 +1,5 @@ -from pydantic import BaseModel, Field, model_validator, PrivateAttr +from __future__ import annotations +from pydantic import BaseModel, model_validator, PrivateAttr from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar from typing_extensions import Annotated from abc import ABC, abstractmethod @@ -9,206 +10,206 @@ from .. conversation import Conversation from .. message import Message class Chat(Agent, ABC): + """ + Base class for all agent types. + This class defines the common attributes and methods for all agent types. + """ + agent_type: Literal["chat"] = "chat" + _agent_type: ClassVar[str] = agent_type # Add this for registration + + async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: """ - Base class for all agent types. - This class defines the common attributes and methods for all agent types. 
+        Prepare message with context information in message.preamble
         """
+        # Generate RAG content if enabled, based on the content
+        rag_context = ""
+        if not message.disable_rag:
+            # Gather RAG results, yielding each result
+            # as it becomes available
+            for value in self.context.generate_rag_results(message):
+                logging.info(f"RAG: {value.status} - {value.response}")
+                if value.status != "done":
+                    yield value
+                if value.status == "error":
+                    message.status = "error"
+                    message.response = value.response
+                    yield message
+                    return
+
+        if "rag" in message.metadata and message.metadata["rag"]:
+            for rag_collection in message.metadata["rag"]:
+                for doc in rag_collection["documents"]:
+                    rag_context += f"{doc}\n"

-    async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
-        """
-        Prepare message with context information in message.preamble
-        """
-        # Generate RAG content if enabled, based on the content
-        rag_context = ""
-        if not message.disable_rag:
-            # Gather RAG results, yielding each result
-            # as it becomes available
-            for value in self.context.generate_rag_results(message):
-                logging.info(f"RAG: {value.status} - {value.response}")
-                if value.status != "done":
-                    yield value
-                if value.status == "error":
-                    message.status = "error"
-                    message.response = value.response
-                    yield message
-                    return
-
-            if message.metadata["rag"]:
-                for rag_collection in message.metadata["rag"]:
-                    for doc in rag_collection["documents"]:
-                        rag_context += f"{doc}\n"
+        if rag_context:
+            message.preamble["context"] = rag_context

-            if rag_context:
-                message["context"] = rag_context
+        if self.context.user_resume:
+            message.preamble["resume"] = self.context.user_resume

-            if self.context.user_resume:
-                message["resume"] = self.content.user_resume
-
-            if message.preamble:
-                preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
-                preamble_types_AND = " and ".join(preamble_types)
-                preamble_types_OR = " or ".join(preamble_types)
-                message.preamble["rules"] = f"""\
+        if message.preamble:
+            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
+            preamble_types_AND = " and ".join(preamble_types)
+            preamble_types_OR = " or ".join(preamble_types)
+            message.preamble["rules"] = f"""\
-- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly.
+- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
 - If there is no information in these sections, answer based on your knowledge.
 - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
 """
-                message.preamble["question"] = "Use that information to respond to:"
-            else:
-                message.preamble["question"] = "Respond to:"
+            message.preamble["question"] = "Use that information to respond to:"
+        else:
+            message.preamble["question"] = "Respond to:"

-            message.system_prompt = self.system_prompt
-            message.status = "done"
+        message.system_prompt = self.system_prompt
+        message.status = "done"
+        yield message
+        return
+
+    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
+        if self.context.processing:
+            logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
+            message.status = "error"
+            message.response = "Busy processing another request."
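+            # Reject the request rather than queuing it; see the TODO above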
+            yield message
+            return
+
+        self.context.processing = True
+
+        messages = []
+
+        for value in self.llm.chat(
+            model=self.model,
+            messages=messages,
+            #tools=llm_tools(context.tools) if message.enable_tools else None,
+            options={ "num_ctx": message.ctx_size }
+        ):
+            logging.info(f"LLM: {value.status} - {value.response}")
+            if value.status != "done":
+                message.status = value.status
+                message.response = value.response
+                yield message
+            if value.status == "error":
+                return
+            response = value

-    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
-        if self.context.processing:
-            logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
-            message.status = "error"
-            message.response = "Busy processing another request."
-            yield message
-            return
-
-        self.context.processing = True
-
-        messages = []
-
-        for value in self.llm.chat(
-            model=self.model,
-            messages=messages,
-            #tools=llm_tools(context.tools) if message.enable_tools else None,
-            options={ "num_ctx": message.ctx_size }
-        ):
-            logging.info(f"LLM: {value.status} - {value.response}")
-            if value.status != "done":
-                message.status = value.status
-                message.response = value.response
-                yield message
-            if value.status == "error":
-                return
-            response = value

-        message.metadata["eval_count"] += response["eval_count"]
-        message.metadata["eval_duration"] += response["eval_duration"]
-        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
-        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-        agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+        message.metadata["eval_count"] += response["eval_count"]
+        message.metadata["eval_duration"] += response["eval_duration"]
+        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
+        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
+        self.context_tokens = response["prompt_eval_count"] + response["eval_count"]

-        tools_used = []
-
-        yield {"status": "processing", "message": "Initial response received..."}
+        tools_used = []
+
+        yield {"status": "processing", "message": "Initial response received..."}

-        if "tool_calls" in response.get("message", {}):
-            yield {"status": "processing", "message": "Processing tool calls..."}
-
-            tool_message = response["message"]
-            tool_result = None
-
-            # Process all yielded items from the handler
-            async for item in self.handle_tool_calls(tool_message):
-                if isinstance(item, tuple) and len(item) == 2:
-                    # This is the final result tuple (tool_result, tools_used)
-                    tool_result, tools_used = item
-                else:
-                    # This is a status update, forward it
-                    yield item
-
-            message_dict = {
-                "role": tool_message.get("role", "assistant"),
-                "content": tool_message.get("content", "")
-            }
-
-            if "tool_calls" in tool_message:
-                message_dict["tool_calls"] = [
-                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
-                    for tc in tool_message["tool_calls"]
-                ]
+        if "tool_calls" in response.get("message", {}):
+            yield {"status": "processing", "message": "Processing tool calls..."}
+
+            tool_message = response["message"]
+            tool_result = None
+
+            # Process all yielded items from the handler
+            async for item in self.handle_tool_calls(tool_message):
+                if isinstance(item, tuple) and len(item) == 2:
+                    # This is the final result tuple (tool_result, tools_used)
+                    tool_result, tools_used = item
+                else:
+                    # This is a status update, forward it
+                    yield item
+
+            message_dict = {
+                "role": tool_message.get("role", "assistant"),
+                "content": tool_message.get("content", "")
+            }
+
+            if "tool_calls" in tool_message:
+                message_dict["tool_calls"] = [
+                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
+                    for tc in tool_message["tool_calls"]
+                ]

-        pre_add_index = len(messages)
-        messages.append(message_dict)
+        pre_add_index = len(messages)
+        messages.append(message_dict)

-        if isinstance(tool_result, list):
-            messages.extend(tool_result)
-        else:
-            if tool_result:
-                messages.append(tool_result)
+        if isinstance(tool_result, list):
+            messages.extend(tool_result)
+        else:
+            if tool_result:
+                messages.append(tool_result)

-        message.metadata["tools"] = tools_used
+        message.metadata["tools"] = tools_used

-        # Estimate token length of new messages
-        ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
-        yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
-        # Decrease creativity when processing tool call requests
-        response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
-        message.metadata["eval_count"] += response["eval_count"]
-        message.metadata["eval_duration"] += response["eval_duration"]
-        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
-        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-        agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+        # Estimate token length of new messages
+        ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=messages[pre_add_index:])
+        yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
+        # Decrease creativity when processing tool call requests
+        response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
+        message.metadata["eval_count"] += response["eval_count"]
+        message.metadata["eval_duration"] += response["eval_duration"]
+        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
+        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
+        self.context_tokens = response["prompt_eval_count"] + response["eval_count"]

-        reply = response["message"]["content"]
-        message.response = reply
-        message.metadata["origin"] = agent.agent_type
-        # final_message = {"role": "assistant", "content": reply }
+        reply = response["message"]["content"]
+        message.response = reply
+        message.metadata["origin"] = self.agent_type
+        # final_message = {"role": "assistant", "content": reply }

-        # # history is provided to the LLM and should not have additional metadata
-        # llm_history.append(final_message)
+        # # history is provided to the LLM and should not have additional metadata
+        # llm_history.append(final_message)

-        # user_history is provided to the REST API and does not include CONTEXT
-        # It does include metadata
-        # final_message["metadata"] = message.metadata
-        # user_history.append({**final_message, "origin": message.metadata["origin"]})
+        # user_history is provided to the REST API and does not include CONTEXT
+        # It does include metadata
+        # final_message["metadata"] = message.metadata
+        # user_history.append({**final_message, "origin": message.metadata["origin"]})

-        # Return the REST API with metadata
-        yield {
-            "status": "done",
-            "message": {
-                **message.model_dump(mode='json'),
-            }
-        }
+        # Return the REST API with metadata
+        yield {
+            "status": "done",
+            "message": {
+                **message.model_dump(mode='json'),
+            }
+        }

-        self.context.processing = False
-        return
-
-    async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
-        message.full_content = ""
-        for i, p in enumerate(message.preamble.keys()):
-            message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
-
-        # Estimate token length of new messages
-        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
-
-        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
-        message.status = "thinking"
-        yield message
+        self.context.processing = False
+        return
+
+    async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
+        message.full_content = ""
+        for i, p in enumerate(message.preamble.keys()):
+            message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
+
+        # Estimate token length of new messages
+        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
+
+        message.response = f"Processing {'RAG augmented ' if message.metadata.get('rag') else ''}query..."
+        message.status = "thinking"
+        yield message

-        for value in self.generate_llm_response(message):
-            logging.info(f"LLM: {value.status} - {value.response}")
-            if value.status != "done":
-                yield value
-            if value.status == "error":
-                return
+        async for value in self.generate_llm_response(message):
+            logging.info(f"LLM: {value.status} - {value.response}")
+            if value.status != "done":
+                yield value
+            if value.status == "error":
+                return

-    def get_and_reset_content_seed(self):
-        tmp = self._content_seed
-        self._content_seed = ""
-        return tmp
-
-    def set_content_seed(self, content: str) -> None:
-        """Set the content seed for the agent."""
-        self._content_seed = content
-
-    def get_content_seed(self) -> str:
-        """Get the content seed for the agent."""
-        return self._content_seed
+    def get_and_reset_content_seed(self):
+        tmp = self._content_seed
+        self._content_seed = ""
+        return tmp
+
+    def set_content_seed(self, content: str) -> None:
+        """Set the content seed for the agent."""
+        self._content_seed = content
+
+    def get_content_seed(self) -> str:
+        """Get the content seed for the agent."""
+        return self._content_seed

-    @classmethod
-    def valid_agent_types(cls) -> set[str]:
-        """Return the set of valid agent_type values."""
-        return set(get_args(cls.__annotations__["agent_type"]))
+    @classmethod
+    def valid_agent_types(cls) -> set[str]:
+        """Return the set of valid agent_type values."""
+        return set(get_args(cls.__annotations__["agent_type"]))

 # Register the base agent
 registry.register(Chat._agent_type, Chat)
diff --git a/src/utils/agents/fact_check.py b/src/utils/agents/fact_check.py
index c6e5ff9..0800387 100644
--- a/src/utils/agents/fact_check.py
+++ b/src/utils/agents/fact_check.py
@@ -9,16 +9,16 @@ from .. conversation import Conversation
 from .. 
message import Message class FactCheck(Agent): - agent_type: Literal["fact_check"] = "fact_check" - _agent_type: ClassVar[str] = agent_type # Add this for registration + agent_type: Literal["fact_check"] = "fact_check" + _agent_type: ClassVar[str] = agent_type # Add this for registration - facts: str = "" - - @model_validator(mode="after") - def validate_facts(self): - if not self.facts.strip(): - raise ValueError("Facts cannot be empty") - return self + facts: str = "" + + @model_validator(mode="after") + def validate_facts(self): + if not self.facts.strip(): + raise ValueError("Facts cannot be empty") + return self # Register the base agent registry.register(FactCheck._agent_type, FactCheck) diff --git a/src/utils/agents/job_description.py b/src/utils/agents/job_description.py index 360f677..c5da2a1 100644 --- a/src/utils/agents/job_description.py +++ b/src/utils/agents/job_description.py @@ -9,16 +9,16 @@ from .. conversation import Conversation from .. message import Message class JobDescription(Agent): - agent_type: Literal["job_description"] = "job_description" - _agent_type: ClassVar[str] = agent_type # Add this for registration + agent_type: Literal["job_description"] = "job_description" + _agent_type: ClassVar[str] = agent_type # Add this for registration - job_description: str = "" + job_description: str = "" - @model_validator(mode="after") - def validate_job_description(self): - if not self.job_description.strip(): - raise ValueError("Job description cannot be empty") - return self + @model_validator(mode="after") + def validate_job_description(self): + if not self.job_description.strip(): + raise ValueError("Job description cannot be empty") + return self # Register the base agent registry.register(JobDescription._agent_type, JobDescription) diff --git a/src/utils/agents/resume.py b/src/utils/agents/resume.py index 635da5c..16fec73 100644 --- a/src/utils/agents/resume.py +++ b/src/utils/agents/resume.py @@ -9,24 +9,24 @@ from .. conversation import Conversation from .. 
message import Message

 class Resume(Agent):
-    agent_type: Literal["resume"] = "resume"
-    _agent_type: ClassVar[str] = agent_type # Add this for registration
+    agent_type: Literal["resume"] = "resume"
+    _agent_type: ClassVar[str] = agent_type # Add this for registration

-    resume: str = ""
+    resume: str = ""

-    @model_validator(mode="after")
-    def validate_resume(self):
-        if not self.resume.strip():
-            raise ValueError("Resume content cannot be empty")
-        return self
+    @model_validator(mode="after")
+    def validate_resume(self):
+        if not self.resume.strip():
+            raise ValueError("Resume content cannot be empty")
+        return self

-    def get_resume(self) -> str:
-        """Get the resume content."""
-        return self.resume
+    def get_resume(self) -> str:
+        """Get the resume content."""
+        return self.resume

-    def set_resume(self, resume: str) -> None:
-        """Set the resume content."""
-        self.resume = resume
+    def set_resume(self, resume: str) -> None:
+        """Set the resume content."""
+        self.resume = resume

 # Register the base agent
 registry.register(Resume._agent_type, Resume)
diff --git a/src/utils/context.py b/src/utils/context.py
index da5248d..d625c17 100644
--- a/src/utils/context.py
+++ b/src/utils/context.py
@@ -1,9 +1,11 @@
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, model_validator, ValidationError
 from uuid import uuid4
 from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
 from typing_extensions import Annotated, Union
 import numpy as np
 import logging
+import re

 from .message import Message
 from .rag import ChromaDBFileWatcher
@@ -13,22 +15,23 @@ from .agents import Agent

 # Import only agent types, not actual classes
 if TYPE_CHECKING:
-    from .agents import Agent, AnyAgent, Chat, Resume, JobDescription, FactCheck
+    from .agents import Agent, AnyAgent

 from .agents import AnyAgent

+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 class Context(BaseModel):
     model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher

+    # Required fields
+    file_watcher: ChromaDBFileWatcher = Field(..., exclude=True)
+
+    # Optional fields
     id: str = Field(
         default_factory=lambda: str(uuid4()),
         pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
     )
-
-    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
-        default_factory=list
-    )
-
     user_resume: Optional[str] = None
     user_job_description: Optional[str] = None
     user_facts: Optional[str] = None
@@ -36,19 +39,29 @@ class Context(BaseModel):
     rags: List[dict] = []
     message_history_length: int = 5
     context_tokens: int = 0
-    file_watcher: ChromaDBFileWatcher = Field(default=None, exclude=True)

+    # Class managed fields
+    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
+        default_factory=list
+    )

-    def __init__(self, id: Optional[str] = None, **kwargs):
-        super().__init__(id=id if id is not None else str(uuid4()), **kwargs)
+    @classmethod
+    def from_json(cls, json_data: str | dict, file_watcher: ChromaDBFileWatcher):
+        """Custom method to load from JSON with file_watcher injection"""
+        import json
+        data = json.loads(json_data) if isinstance(json_data, str) else json_data
+        return cls(file_watcher=file_watcher, **data)

     @model_validator(mode="after")
     def validate_unique_agent_types(self):
         """Ensure at most one agent per agent_type."""
+        logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
         agent_types = [agent.agent_type for agent in self.agents]
         if len(agent_types) != len(set(agent_types)):
             raise ValueError("Context cannot contain multiple agents of the same agent_type")
+        for agent in self.agents:
+            agent.set_context(self)
         return self
-    
+
     def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
         ctx = round(context + len(str(messages)) * 3 / 4)
-        return max(defines.max_context, min(2048, ctx + ctx_buffer))
+        return min(defines.max_context, max(2048, ctx + ctx_buffer))
@@ -110,7 +123,7 @@ class Context(BaseModel):
         except Exception as e:
             message.response = f"Error generating RAG results: {str(e)}"
             message.status = "error"
-            logging.error(e)
+            logger.error(e)
             yield message
             return

diff --git a/src/utils/message.py b/src/utils/message.py
index 6647566..1fefee5 100644
--- a/src/utils/message.py
+++ b/src/utils/message.py
@@ -11,6 +11,7 @@ class Message(BaseModel):
     disable_tools: bool = False

     # Generated while processing message
+    status: str = ""                  # Status of the message
     preamble: dict[str,str] = {}      # Preamble to be prepended to the prompt
     system_prompt: str = ""           # System prompt provided to the LLM
     full_content: str = ""            # Full content of the message (preamble + prompt)
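---
Reviewer notes (not part of the patch):

1. server.py replaces the deprecated @app.on_event("startup") / @app.on_event("shutdown")
hooks with a lifespan context manager passed to FastAPI(). For reference, a minimal sketch
of that pattern in isolation; the app.state assignments below are placeholders, not this
repo's Rag API:

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Startup: runs once, before the server accepts requests.
        app.state.watcher = "started"   # placeholder for Rag.start_file_watcher(...)
        yield                           # the application serves requests here
        # Shutdown: runs once, after the server stops accepting requests.
        app.state.watcher = None

    app = FastAPI(lifespan=lifespan)

Passing a bound method, as the patch does with FastAPI(lifespan=self.lifespan), works the
same way: FastAPI only needs a callable that takes the app and returns an async context
manager.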
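2. Context now declares file_watcher as Field(..., exclude=True): required at construction
but omitted from model_dump_json(). That is why save_context() can serialize straight to
disk while from_json() must re-inject the live watcher on load. A toy round-trip under the
same scheme; Watcher and Ctx are invented stand-ins for illustration:

    import json
    from pydantic import BaseModel, Field

    class Watcher:  # stand-in for ChromaDBFileWatcher
        pass

    class Ctx(BaseModel):
        model_config = {"arbitrary_types_allowed": True}
        watcher: Watcher = Field(..., exclude=True)  # never serialized
        note: str = "hello"

    w = Watcher()
    saved = Ctx(watcher=w).model_dump_json()        # -> '{"note":"hello"}', watcher omitted
    restored = Ctx(watcher=w, **json.loads(saved))  # caller re-supplies the live watcher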
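3. The agents field is a Pydantic discriminated union over Agent.__subclasses__(), keyed on
agent_type, which is why every agent subclass pins agent_type with a Literal default. A
reduced sketch of how the tag selects the concrete class during validation; the two toy
classes are hypothetical:

    from typing import Annotated, List, Literal, Union
    from pydantic import BaseModel, Field, TypeAdapter

    class ChatAgent(BaseModel):
        agent_type: Literal["chat"] = "chat"

    class ResumeAgent(BaseModel):
        agent_type: Literal["resume"] = "resume"

    AnyAgent = Annotated[Union[ChatAgent, ResumeAgent], Field(discriminator="agent_type")]
    agents = TypeAdapter(List[AnyAgent]).validate_python(
        [{"agent_type": "resume"}, {"agent_type": "chat"}]
    )
    assert isinstance(agents[0], ResumeAgent)  # the tag picked the subclass

One caveat: Union[*Agent.__subclasses__()] is evaluated when context.py is imported, so
every agent module must already be registered by then or its instances will not round-trip.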