From d1940e18e5ea0c853c33236604a462f20a533c36 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Wed, 30 Apr 2025 15:01:50 -0700 Subject: [PATCH] Starting to work again --- src/server.py | 192 +++++++++++++++++---------------------- src/utils/__init__.py | 36 ++++---- src/utils/agents/base.py | 32 +++---- src/utils/context.py | 23 ++--- src/utils/defines.py | 3 +- 5 files changed, 128 insertions(+), 158 deletions(-) diff --git a/src/server.py b/src/server.py index 2b9001d..319305b 100644 --- a/src/server.py +++ b/src/server.py @@ -1,10 +1,4 @@ -import os -os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" - -import warnings -warnings.filterwarnings("ignore", message="Overriding a previously registered kernel") -warnings.filterwarnings("ignore", message="Warning only once for all operators") -warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv") +from utils import logger # %% # Imports [standard] @@ -55,7 +49,8 @@ from utils import ( rag as Rag, Context, Conversation, Message, Agent, - defines + defines, + logger ) from tools import ( @@ -260,25 +255,6 @@ def parse_args(): default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}") return parser.parse_args() -def setup_logging(level): - global logging - - numeric_level = getattr(logging, level.upper(), None) - if not isinstance(numeric_level, int): - raise ValueError(f"Invalid log level: {level}") - - logging.basicConfig( - level=numeric_level, - format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True - ) - - # Now reduce verbosity for FastAPI, Uvicorn, Starlette - for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"): - logging.getLogger(noisy_logger).setLevel(logging.WARNING) - - logging.info(f"Logging is set to {level} level.") # %% @@ -298,10 +274,10 @@ async def AnalyzeSite(llm, model: str, url : str, question : str): headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" } - logging.info(f"Fetching {url}") + logger.info(f"Fetching {url}") response = requests.get(url, headers=headers, timeout=10) response.raise_for_status() - logging.info(f"{url} returned. Processing...") + logger.info(f"{url} returned. Processing...") # Parse the HTML soup = BeautifulSoup(response.text, "html.parser") @@ -323,7 +299,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str): text = text[:max_chars] + "..." # Create Ollama client - # logging.info(f"Requesting summary of: {text}") + # logger.info(f"Requesting summary of: {text}") # Generate summary using Ollama prompt = f"CONTENTS:\n\n{text}\n\n{question}" @@ -331,7 +307,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str): system="You are given the contents of {url}. Answer the question about the contents", prompt=prompt) - #logging.info(response["response"]) + #logger.info(response["response"]) return { "source": "summarizer-llm", @@ -377,12 +353,12 @@ class WebServer: watch_directory=defines.doc_dir, recreate=False # Don't recreate if exists ) - logging.info(f"API started with {self.file_watcher.collection.count()} documents in the collection") + logger.info(f"API started with {self.file_watcher.collection.count()} documents in the collection") yield if self.observer: self.observer.stop() self.observer.join() - logging.info("File watcher stopped") + logger.info("File watcher stopped") def __init__(self, llm, model=MODEL_NAME): self.app = FastAPI(lifespan=self.lifespan) @@ -400,7 +376,7 @@ class WebServer: else: allow_origins=["http://battle-linux.ketrenos.com:3000"] - logging.info(f"Allowed origins: {allow_origins}") + logger.info(f"Allowed origins: {allow_origins}") self.app.add_middleware( CORSMiddleware, @@ -416,13 +392,13 @@ class WebServer: @self.app.get("/") async def root(): context = self.create_context() - logging.info(f"Redirecting non-context to {context.id}") + logger.info(f"Redirecting non-context to {context.id}") return RedirectResponse(url=f"/{context.id}", status_code=307) #return JSONResponse({"redirect": f"/{context.id}"}) @self.app.put("/api/umap/{context_id}") async def put_umap(context_id: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") try: if not self.file_watcher: raise Exception("File watcher not initialized") @@ -436,10 +412,10 @@ class WebServer: dimensions = data.get("dimensions", 2) result = self.file_watcher.umap_collection if dimensions == 2: - logging.info("Returning 2D UMAP") + logger.info("Returning 2D UMAP") umap_embedding = self.file_watcher.umap_embedding_2d else: - logging.info("Returning 3D UMAP") + logger.info("Returning 3D UMAP") umap_embedding = self.file_watcher.umap_embedding_3d result["embeddings"] = umap_embedding.tolist() @@ -447,19 +423,19 @@ class WebServer: return JSONResponse(result) except Exception as e: - logging.error(f"put_umap error: {str(e)}") + logger.error(f"put_umap error: {str(e)}") import traceback - logging.error(traceback.format_exc()) + logger.error(traceback.format_exc()) return JSONResponse({"error": str(e)}, 500) @self.app.put("/api/similarity/{context_id}") async def put_similarity(context_id: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") if not self.file_watcher: raise Exception("File watcher not initialized") if not is_valid_uuid(context_id): - logging.warning(f"Invalid context_id: {context_id}") + logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) try: @@ -476,13 +452,13 @@ class WebServer: return JSONResponse({"error": "No results found"}, status_code=404) chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape - logging.info(f"Chroma embedding shape: {chroma_embedding.shape}") + logger.info(f"Chroma embedding shape: {chroma_embedding.shape}") umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() - logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output + logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() - logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output + logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output return JSONResponse({ **chroma_results, @@ -492,14 +468,14 @@ class WebServer: }) except Exception as e: - logging.error(e) + logger.error(e) #return JSONResponse({"error": str(e)}, 500) @self.app.put("/api/reset/{context_id}/{agent_type}") async def put_reset(context_id: str, agent_type: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): - logging.warning(f"Invalid context_id: {context_id}") + logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) agent = context.get_agent(agent_type) @@ -512,7 +488,7 @@ class WebServer: for reset_operation in data["reset"]: match reset_operation: case "system_prompt": - logging.info(f"Resetting {reset_operation}") + logger.info(f"Resetting {reset_operation}") match agent_type: case "chat": prompt = system_message @@ -526,11 +502,11 @@ class WebServer: agent.system_prompt = prompt response["system_prompt"] = { "system_prompt": prompt } case "rags": - logging.info(f"Resetting {reset_operation}") + logger.info(f"Resetting {reset_operation}") context.rags = rags.copy() response["rags"] = context.rags case "tools": - logging.info(f"Resetting {reset_operation}") + logger.info(f"Resetting {reset_operation}") context.tools = default_tools(tools) response["tools"] = context.tools case "history": @@ -546,13 +522,13 @@ class WebServer: tmp = context.get_agent(mode) if not tmp: continue - logging.info(f"Resetting {reset_operation} for {mode}") + logger.info(f"Resetting {reset_operation} for {mode}") context.conversation = Conversation() context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage response["history"] = [] response["context_used"] = agent.context_tokens case "message_history_length": - logging.info(f"Resetting {reset_operation}") + logger.info(f"Resetting {reset_operation}") context.message_history_length = DEFAULT_HISTORY_LENGTH response["message_history_length"] = DEFAULT_HISTORY_LENGTH @@ -567,7 +543,7 @@ class WebServer: @self.app.put("/api/tunables/{context_id}") async def put_tunables(context_id: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") try: context = self.upsert_context(context_id) @@ -619,14 +595,14 @@ class WebServer: case _: return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404) except Exception as e: - logging.error(f"Error in put_tunables: {e}") + logger.error(f"Error in put_tunables: {e}") return JSONResponse({"error": str(e)}, status_code=500) @self.app.get("/api/tunables/{context_id}") async def get_tunables(context_id: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): - logging.warning(f"Invalid context_id: {context_id}") + logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) agent = context.get_agent("chat") @@ -644,15 +620,15 @@ class WebServer: @self.app.get("/api/system-info/{context_id}") async def get_system_info(context_id: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") return JSONResponse(system_info(self.model)) @self.app.post("/api/chat/{context_id}/{agent_type}") async def post_chat_endpoint(context_id: str, agent_type: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") try: if not is_valid_uuid(context_id): - logging.warning(f"Invalid context_id: {context_id}") + logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) @@ -660,11 +636,11 @@ class WebServer: data = await request.json() agent = context.get_agent(agent_type) if not agent and agent_type == "job_description": - logging.info(f"Agent {agent_type} not found. Returning empty history.") + logger.info(f"Agent {agent_type} not found. Returning empty history.") # Create a new agent if it doesn't exist agent = context.get_or_create_agent("job_description", system_prompt=system_generate_resume, job_description=data["content"]) except Exception as e: - logging.info(f"Attempt to create agent type: {agent_type} failed", e) + logger.info(f"Attempt to create agent type: {agent_type} failed", e) return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404) # Create a custom generator that ensures flushing @@ -688,43 +664,43 @@ class WebServer: } ) except Exception as e: - logging.error(f"Error in post_chat_endpoint: {e}") + logger.error(f"Error in post_chat_endpoint: {e}") return JSONResponse({"error": str(e)}, status_code=500) @self.app.post("/api/context") async def create_context(): context = self.create_context() - logging.info(f"Generated new agent as {context.id}") + logger.info(f"Generated new agent as {context.id}") return JSONResponse({ "id": context.id }) @self.app.get("/api/history/{context_id}/{agent_type}") async def get_history(context_id: str, agent_type: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") try: context = self.upsert_context(context_id) agent = context.get_agent(agent_type) if not agent: - logging.info(f"Agent {agent_type} not found. Returning empty history.") + logger.info(f"Agent {agent_type} not found. Returning empty history.") return JSONResponse({ "messages": [] }) - logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.") + logger.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.") return agent.conversation except Exception as e: - logging.error(f"get_history error: {str(e)}") + logger.error(f"get_history error: {str(e)}") import traceback - logging.error(traceback.format_exc()) + logger.error(traceback.format_exc()) return JSONResponse({"error": str(e)}, status_code=404) @self.app.get("/api/tools/{context_id}") async def get_tools(context_id: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") context = self.upsert_context(context_id) return JSONResponse(context.tools) @self.app.put("/api/tools/{context_id}") async def put_tools(context_id: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): - logging.warning(f"Invalid context_id: {context_id}") + logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) try: @@ -743,9 +719,9 @@ class WebServer: @self.app.get("/api/context-status/{context_id}/{agent_type}") async def get_context_status(context_id, agent_type: str, request: Request): - logging.info(f"{request.method} {request.url.path}") + logger.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): - logging.warning(f"Invalid context_id: {context_id}") + logger.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) agent = context.get_agent(agent_type) @@ -761,9 +737,9 @@ class WebServer: async def serve_static(path: str): full_path = os.path.join(defines.static_content, path) if os.path.exists(full_path) and os.path.isfile(full_path): - logging.info(f"Serve static request for {full_path}") + logger.info(f"Serve static request for {full_path}") return FileResponse(full_path) - logging.info(f"Serve index.html for {path}") + logger.info(f"Serve index.html for {path}") return FileResponse(os.path.join(defines.static_content, "index.html")) def save_context(self, context_id): @@ -807,28 +783,28 @@ class WebServer: # Check if the file exists if not os.path.exists(file_path): - logging.info(f"Context file {file_path} not found. Creating new context.") + logger.info(f"Context file {file_path} not found. Creating new context.") self.contexts[context_id] = self.create_context(context_id) else: # Read and deserialize the data with open(file_path, "r") as f: content = f.read() - logging.info(f"Loading context from {file_path}, content length: {len(content)}") + logger.info(f"Loading context from {file_path}, content length: {len(content)}") try: # Try parsing as JSON first to ensure valid JSON import json json_data = json.loads(content) - logging.info("JSON parsed successfully, attempting model validation") + logger.info("JSON parsed successfully, attempting model validation") # Now try Pydantic validation self.contexts[context_id] = Context.from_json(json_data, file_watcher=self.file_watcher) - logging.info(f"Successfully loaded context {context_id}") + logger.info(f"Successfully loaded context {context_id}") except json.JSONDecodeError as e: - logging.error(f"Invalid JSON in file: {e}") + logger.error(f"Invalid JSON in file: {e}") except Exception as e: - logging.error(f"Error validating context: {str(e)}") + logger.error(f"Error validating context: {str(e)}") import traceback - logging.error(traceback.format_exc()) + logger.error(traceback.format_exc()) # Fallback to creating a new context self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher) @@ -845,7 +821,7 @@ class WebServer: if not self.file_watcher: raise Exception("File watcher not initialized") - logging.info(f"Creating new context with ID: {context_id}") + logger.info(f"Creating new context with ID: {context_id}") context = Context(id=context_id, file_watcher=self.file_watcher) if os.path.exists(defines.resume_doc): @@ -859,7 +835,7 @@ class WebServer: context.tools = default_tools(tools) context.rags = rags.copy() - logging.info(f"{context.id} created and added to contexts.") + logger.info(f"{context.id} created and added to contexts.") self.contexts[context.id] = context self.save_context(context.id) return context @@ -941,13 +917,13 @@ class WebServer: """ if not context_id: - logging.warning("No context ID provided. Creating a new context.") + logger.warning("No context ID provided. Creating a new context.") return self.create_context() if context_id in self.contexts: return self.contexts[context_id] - logging.info(f"Context {context_id} is not yet loaded.") + logger.info(f"Context {context_id} is not yet loaded.") return self.load_or_create_context(context_id) def generate_rag_results(self, context, content): @@ -963,13 +939,13 @@ class WebServer: if chroma_results: results_found = True chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape - logging.info(f"Chroma embedding shape: {chroma_embedding.shape}") + logger.info(f"Chroma embedding shape: {chroma_embedding.shape}") umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() - logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output + logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() - logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output + logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output yield { **chroma_results, @@ -1012,11 +988,11 @@ class WebServer: raise Exception("File watcher not initialized") agent_type = agent.get_agent_type() - logging.info(f"generate_response: {agent_type}") + logger.info(f"generate_response: {agent_type}") if agent_type == "chat": message = Message(prompt=content) async for value in agent.prepare_message(message): - logging.info(f"{agent_type}.prepare_message: {value.status} - {value.response}") + logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}") if value.status != "done": yield value if value.status == "error": @@ -1025,7 +1001,7 @@ class WebServer: yield message return async for value in agent.process_message(message): - logging.info(f"{agent_type}.process_message: {value.status} - {value.response}") + logger.info(f"{agent_type}.process_message: {value.status} - {value.response}") if value.status != "done": yield value if value.status == "error": @@ -1034,7 +1010,7 @@ class WebServer: yield message return async for value in agent.generate_llm_response(message): - logging.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}") + logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}") if value.status != "done": yield value if value.status == "error": @@ -1042,13 +1018,13 @@ class WebServer: message.response = value.response yield message return - logging.info("TODO: There is more to do...") + logger.info("TODO: There is more to do...") return return if self.processing: - logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") + logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") yield {"status": "error", "message": "Busy processing another request."} return @@ -1083,11 +1059,11 @@ class WebServer: process_type = agent.get_agent_type() match process_type: case "job_description": - logging.info(f"job_description user_history len: {len(conversation.messages)}") + logger.info(f"job_description user_history len: {len(conversation.messages)}") if len(conversation.messages) >= 2: # USER, ASSISTANT process_type = "chat" case "resume": - logging.info(f"resume user_history len: {len(conversation.messages)}") + logger.info(f"resume user_history len: {len(conversation.messages)}") if len(conversation.messages) >= 3: # USER, ASSISTANT, FACT_CHECK process_type = "chat" case "fact_check": @@ -1098,7 +1074,7 @@ class WebServer: case "chat": if not message.prompt: yield {"status": "error", "message": "No query provided for chat."} - logging.info(f"user_history len: {len(conversation.messages)}") + logger.info(f"user_history len: {len(conversation.messages)}") self.processing = False return @@ -1207,8 +1183,8 @@ Use to the above information to respond to this prompt: message.add_action("generate_resume") - logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'") - logging.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble") + logger.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'") + logger.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble") # Switch to resume agent for LLM responses # message.metadata["origin"] = "resume" @@ -1288,11 +1264,11 @@ Use the above <|resume|> and <|job_description|> to answer this query: stuffingMessage.metadata["origin"] = "resume" stuffingMessage.metadata["display"] = "hide" stuffingMessage.actions = [ "fact_check" ] - logging.info("TODO: Switch this to use actions to keep the UI from showingit") + logger.info("TODO: Switch this to use actions to keep the UI from showingit") conversation.add_message(stuffingMessage) # For all future calls to job_description, use the system_job_description - logging.info("TODO: Create a system_resume_QA prompt to use for the resume agent") + logger.info("TODO: Create a system_resume_QA prompt to use for the resume agent") agent.system_prompt = system_prompt # Switch to fact_check agent for LLM responses @@ -1364,7 +1340,7 @@ Use the above <|resume|> and <|job_description|> to answer this query: else: response = self.llm.chat(model=self.model, messages=messages, options={ "num_ctx": ctx_size }) except Exception as e: - logging.exception({ "model": self.model, "error": str(e) }) + logger.exception({ "model": self.model, "error": str(e) }) yield {"status": "error", "message": f"An error occurred communicating with LLM"} self.processing = False return @@ -1449,7 +1425,7 @@ Use the above <|resume|> and <|job_description|> to answer this query: } # except Exception as e: - # logging.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) }) + # logger.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) }) # yield {"status": "error", "message": f"An error occurred: {str(e)}"} # finally: @@ -1460,7 +1436,7 @@ Use the above <|resume|> and <|job_description|> to answer this query: def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs): try: if self.ssl_enabled: - logging.info(f"Starting web server at https://{host}:{port}") + logger.info(f"Starting web server at https://{host}:{port}") uvicorn.run( self.app, host=host, @@ -1470,7 +1446,7 @@ Use the above <|resume|> and <|job_description|> to answer this query: ssl_certfile=defines.cert_path ) else: - logging.info(f"Starting web server at http://{host}:{port}") + logger.info(f"Starting web server at http://{host}:{port}") uvicorn.run( self.app, host=host, @@ -1493,7 +1469,7 @@ def main(): args = parse_args() # Setup logging based on the provided level - setup_logging(args.level) + logger.setLevel(args.level) warnings.filterwarnings( "ignore", diff --git a/src/utils/__init__.py b/src/utils/__init__.py index eea4e13..a06161c 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,12 +1,12 @@ from typing import Optional, Type from . import defines -from .rag import ChromaDBFileWatcher, start_file_watcher -from .message import Message -from .conversation import Conversation -from .context import Context +from . rag import ChromaDBFileWatcher, start_file_watcher +from . message import Message +from . conversation import Conversation +from . context import Context from . import agents -import logging +from . setup_logging import setup_logging from .agents import Agent, __all__ as agents_all @@ -17,13 +17,14 @@ __all__ = [ 'Message', 'ChromaDBFileWatcher', 'start_file_watcher' + 'logger', ] + agents_all # Resolve circular dependencies by rebuilding models # Call model_rebuild() on Agent and Context Agent.model_rebuild() Context.model_rebuild() -import logging + import importlib from pydantic import BaseModel from typing import Type @@ -31,19 +32,20 @@ from typing import Type # Assuming class_registry is available from agents/__init__.py from .agents import class_registry, AnyAgent +logger = setup_logging(level=defines.logging_level) + def rebuild_models(): - Context.model_rebuild() for class_name, (module_name, _) in class_registry.items(): try: module = importlib.import_module(module_name) cls = getattr(module, class_name, None) - logging.info(f"Checking: {class_name} in module {module_name}") - logging.info(f" cls: {True if cls else False}") - logging.info(f" isinstance(cls, type): {isinstance(cls, type)}") - logging.info(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}") - logging.info(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}") - logging.info(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}") + logger.info(f"Checking: {class_name} in module {module_name}") + logger.info(f" cls: {True if cls else False}") + logger.info(f" isinstance(cls, type): {isinstance(cls, type)}") + logger.info(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}") + logger.info(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}") + logger.info(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}") if ( cls @@ -52,12 +54,14 @@ def rebuild_models(): and issubclass(cls, AnyAgent) and cls is not AnyAgent ): - logging.info(f"Rebuilding {class_name} from {module_name}") + logger.info(f"Rebuilding {class_name} from {module_name}") + from . agents import Agent + from . context import Context cls.model_rebuild() except ImportError as e: - logging.error(f"Failed to import module {module_name}: {e}") + logger.error(f"Failed to import module {module_name}: {e}") except Exception as e: - logging.error(f"Error processing {class_name} in {module_name}: {e}") + logger.error(f"Error processing {class_name} in {module_name}: {e}") # Call this after all modules are imported rebuild_models() \ No newline at end of file diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 282e905..2bf3b55 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -3,18 +3,20 @@ from pydantic import BaseModel, model_validator, PrivateAttr, Field from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef from abc import ABC, abstractmethod from typing_extensions import Annotated -import logging +from .. setup_logging import setup_logging + +logger = setup_logging() # Only import Context for type checking if TYPE_CHECKING: from .. context import Context -from .types import AgentBase, registry +from .types import registry from .. conversation import Conversation from .. message import Message -class Agent(AgentBase): +class Agent(BaseModel, ABC): """ Base class for all agent types. This class defines the common attributes and methods for all agent types. @@ -28,27 +30,19 @@ class Agent(AgentBase): system_prompt: str # Mandatory conversation: Conversation = Conversation() context_tokens: int = 0 - context: object = Field(..., exclude=True) # Avoid circular reference, require as param, and prevent serialization + context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization _content_seed: str = PrivateAttr(default="") # Class and pydantic model management def __init_subclass__(cls, **kwargs): """Auto-register subclasses""" + logger.info(f"Agent.__init_subclass__({kwargs})") super().__init_subclass__(**kwargs) # Register this class if it has an agent_type - if hasattr(cls, 'agent_type') and cls.agent_type != AgentBase._agent_type: + if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type: registry.register(cls.agent_type, cls) - def __init__(self, **data): - # Set agent_type from class if not provided - if 'agent_type' not in data: - data['agent_type'] = self.__class__.agent_type - from .. context import Context - Context.model_rebuild() - self.__class__.model_rebuild() - super().__init__(**data) - def model_dump(self, *args, **kwargs): # Ensure context is always excluded, even with exclude_unset=True kwargs.setdefault("exclude", set()) @@ -65,7 +59,7 @@ class Agent(AgentBase): def set_context(self, context): object.__setattr__(self, "context", context) - + # Agent methods def get_agent_type(self): return self._agent_type @@ -80,7 +74,7 @@ class Agent(AgentBase): # Gather RAG results, yielding each result # as it becomes available for value in self.context.generate_rag_results(message): - logging.info(f"RAG: {value.status} - {value.response}") + logger.info(f"RAG: {value.status} - {value.response}") if value.status != "done": yield value if value.status == "error": @@ -120,7 +114,7 @@ class Agent(AgentBase): async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]: if self.context.processing: - logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") + logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") message.status = "error" message.response = "Busy processing another request." yield message @@ -136,7 +130,7 @@ class Agent(AgentBase): #tools=llm_tools(context.tools) if message.enable_tools else None, options={ "num_ctx": message.ctx_size } ): - logging.info(f"LLM: {value.status} - {value.response}") + logger.info(f"LLM: {value.status} - {value.response}") if value.status != "done": message.status = value.status message.response = value.response @@ -240,7 +234,7 @@ class Agent(AgentBase): yield message for value in self.generate_llm_response(message): - logging.info(f"LLM: {value.status} - {value.response}") + logger.info(f"LLM: {value.status} - {value.response}") if value.status != "done": yield value if value.status == "error": diff --git a/src/utils/context.py b/src/utils/context.py index d625c17..b0e11a3 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -1,3 +1,4 @@ +from __future__ import annotations from pydantic import BaseModel, Field, model_validator, ValidationError from uuid import uuid4 from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING @@ -11,12 +12,6 @@ from .message import Message from .rag import ChromaDBFileWatcher from . import defines -from .agents import Agent - -# Import only agent types, not actual classes -if TYPE_CHECKING: - from .agents import Agent, AnyAgent - from .agents import AnyAgent logging.basicConfig(level=logging.INFO) @@ -25,7 +20,7 @@ logger = logging.getLogger(__name__) class Context(BaseModel): model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher # Required fields - file_watcher: ChromaDBFileWatcher = Field(..., exclude=True) + file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True) # Optional fields id: str = Field( @@ -44,15 +39,14 @@ class Context(BaseModel): default_factory=list ) + @model_validator(mode="before") @classmethod - def from_json(cls, json_str: str, file_watcher: ChromaDBFileWatcher): - """Custom method to load from JSON with file_watcher injection""" - import json - data = json.loads(json_str) - return cls(file_watcher=file_watcher, **data) + def before_model_validator(cls, values: Any): + logger.info(f"Preparing model data: {cls} {values}") + return values @model_validator(mode="after") - def validate_unique_agent_types(self): + def after_model_validator(self): """Ensure at most one agent per agent_type.""" logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.") agent_types = [agent.agent_type for agent in self.agents] @@ -190,5 +184,6 @@ class Context(BaseModel): elif agent.agent_type == "chat": summary += f"\nChat Name: {agent.name}\n" return summary - + +from . agents import Agent Context.model_rebuild() \ No newline at end of file diff --git a/src/utils/defines.py b/src/utils/defines.py index 97f061b..2cee552 100644 --- a/src/utils/defines.py +++ b/src/utils/defines.py @@ -14,4 +14,5 @@ static_content = "/opt/backstory/frontend/deployed" resume_doc = "/opt/backstory/docs/resume/generic.md" # Only used for testing; backstory-prod will not use this key_path = "/opt/backstory/keys/key.pem" -cert_path = "/opt/backstory/keys/cert.pem" \ No newline at end of file +cert_path = "/opt/backstory/keys/cert.pem" +logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper() \ No newline at end of file