Starting to work again

This commit is contained in:
James Ketr 2025-04-30 15:01:50 -07:00
parent e607e3a2f2
commit d1940e18e5
5 changed files with 128 additions and 158 deletions

View File

@ -1,10 +1,4 @@
import os from utils import logger
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
import warnings
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
warnings.filterwarnings("ignore", message="Warning only once for all operators")
warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")
# %% # %%
# Imports [standard] # Imports [standard]
@ -55,7 +49,8 @@ from utils import (
rag as Rag, rag as Rag,
Context, Conversation, Message, Context, Conversation, Message,
Agent, Agent,
defines defines,
logger
) )
from tools import ( from tools import (
@ -260,25 +255,6 @@ def parse_args():
default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}") default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
return parser.parse_args() return parser.parse_args()
def setup_logging(level):
    """Configure root logging at ``level`` and quiet the web-framework loggers.

    Args:
        level: Case-insensitive name of a standard logging level, such as
            ``"info"`` or ``"DEBUG"``.

    Raises:
        ValueError: If ``level`` does not name an integer-valued attribute
            of the ``logging`` module.
    """
    # Resolve the textual level name to its numeric constant on the module.
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    # force=True replaces any handlers already attached to the root logger,
    # so calling this function again re-applies the configuration.
    logging.basicConfig(
        level=numeric_level,
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )

    # Now reduce verbosity for FastAPI, Uvicorn, Starlette
    for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)

    # Lazy %-formatting: the message is only built if the record is emitted.
    logging.info("Logging is set to %s level.", level)
# %% # %%
@ -298,10 +274,10 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
headers = { headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
} }
logging.info(f"Fetching {url}") logger.info(f"Fetching {url}")
response = requests.get(url, headers=headers, timeout=10) response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status() response.raise_for_status()
logging.info(f"{url} returned. Processing...") logger.info(f"{url} returned. Processing...")
# Parse the HTML # Parse the HTML
soup = BeautifulSoup(response.text, "html.parser") soup = BeautifulSoup(response.text, "html.parser")
@ -323,7 +299,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
text = text[:max_chars] + "..." text = text[:max_chars] + "..."
# Create Ollama client # Create Ollama client
# logging.info(f"Requesting summary of: {text}") # logger.info(f"Requesting summary of: {text}")
# Generate summary using Ollama # Generate summary using Ollama
prompt = f"CONTENTS:\n\n{text}\n\n{question}" prompt = f"CONTENTS:\n\n{text}\n\n{question}"
@ -331,7 +307,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
system="You are given the contents of {url}. Answer the question about the contents", system="You are given the contents of {url}. Answer the question about the contents",
prompt=prompt) prompt=prompt)
#logging.info(response["response"]) #logger.info(response["response"])
return { return {
"source": "summarizer-llm", "source": "summarizer-llm",
@ -377,12 +353,12 @@ class WebServer:
watch_directory=defines.doc_dir, watch_directory=defines.doc_dir,
recreate=False # Don't recreate if exists recreate=False # Don't recreate if exists
) )
logging.info(f"API started with {self.file_watcher.collection.count()} documents in the collection") logger.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
yield yield
if self.observer: if self.observer:
self.observer.stop() self.observer.stop()
self.observer.join() self.observer.join()
logging.info("File watcher stopped") logger.info("File watcher stopped")
def __init__(self, llm, model=MODEL_NAME): def __init__(self, llm, model=MODEL_NAME):
self.app = FastAPI(lifespan=self.lifespan) self.app = FastAPI(lifespan=self.lifespan)
@ -400,7 +376,7 @@ class WebServer:
else: else:
allow_origins=["http://battle-linux.ketrenos.com:3000"] allow_origins=["http://battle-linux.ketrenos.com:3000"]
logging.info(f"Allowed origins: {allow_origins}") logger.info(f"Allowed origins: {allow_origins}")
self.app.add_middleware( self.app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -416,13 +392,13 @@ class WebServer:
@self.app.get("/") @self.app.get("/")
async def root(): async def root():
context = self.create_context() context = self.create_context()
logging.info(f"Redirecting non-context to {context.id}") logger.info(f"Redirecting non-context to {context.id}")
return RedirectResponse(url=f"/{context.id}", status_code=307) return RedirectResponse(url=f"/{context.id}", status_code=307)
#return JSONResponse({"redirect": f"/{context.id}"}) #return JSONResponse({"redirect": f"/{context.id}"})
@self.app.put("/api/umap/{context_id}") @self.app.put("/api/umap/{context_id}")
async def put_umap(context_id: str, request: Request): async def put_umap(context_id: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
try: try:
if not self.file_watcher: if not self.file_watcher:
raise Exception("File watcher not initialized") raise Exception("File watcher not initialized")
@ -436,10 +412,10 @@ class WebServer:
dimensions = data.get("dimensions", 2) dimensions = data.get("dimensions", 2)
result = self.file_watcher.umap_collection result = self.file_watcher.umap_collection
if dimensions == 2: if dimensions == 2:
logging.info("Returning 2D UMAP") logger.info("Returning 2D UMAP")
umap_embedding = self.file_watcher.umap_embedding_2d umap_embedding = self.file_watcher.umap_embedding_2d
else: else:
logging.info("Returning 3D UMAP") logger.info("Returning 3D UMAP")
umap_embedding = self.file_watcher.umap_embedding_3d umap_embedding = self.file_watcher.umap_embedding_3d
result["embeddings"] = umap_embedding.tolist() result["embeddings"] = umap_embedding.tolist()
@ -447,19 +423,19 @@ class WebServer:
return JSONResponse(result) return JSONResponse(result)
except Exception as e: except Exception as e:
logging.error(f"put_umap error: {str(e)}") logger.error(f"put_umap error: {str(e)}")
import traceback import traceback
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
return JSONResponse({"error": str(e)}, 500) return JSONResponse({"error": str(e)}, 500)
@self.app.put("/api/similarity/{context_id}") @self.app.put("/api/similarity/{context_id}")
async def put_similarity(context_id: str, request: Request): async def put_similarity(context_id: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
if not self.file_watcher: if not self.file_watcher:
raise Exception("File watcher not initialized") raise Exception("File watcher not initialized")
if not is_valid_uuid(context_id): if not is_valid_uuid(context_id):
logging.warning(f"Invalid context_id: {context_id}") logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400) return JSONResponse({"error": "Invalid context_id"}, status_code=400)
try: try:
@ -476,13 +452,13 @@ class WebServer:
return JSONResponse({"error": "No results found"}, status_code=404) return JSONResponse({"error": "No results found"}, status_code=404)
chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
logging.info(f"Chroma embedding shape: {chroma_embedding.shape}") logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
return JSONResponse({ return JSONResponse({
**chroma_results, **chroma_results,
@ -492,14 +468,14 @@ class WebServer:
}) })
except Exception as e: except Exception as e:
logging.error(e) logger.error(e)
#return JSONResponse({"error": str(e)}, 500) #return JSONResponse({"error": str(e)}, 500)
@self.app.put("/api/reset/{context_id}/{agent_type}") @self.app.put("/api/reset/{context_id}/{agent_type}")
async def put_reset(context_id: str, agent_type: str, request: Request): async def put_reset(context_id: str, agent_type: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id): if not is_valid_uuid(context_id):
logging.warning(f"Invalid context_id: {context_id}") logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400) return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = self.upsert_context(context_id) context = self.upsert_context(context_id)
agent = context.get_agent(agent_type) agent = context.get_agent(agent_type)
@ -512,7 +488,7 @@ class WebServer:
for reset_operation in data["reset"]: for reset_operation in data["reset"]:
match reset_operation: match reset_operation:
case "system_prompt": case "system_prompt":
logging.info(f"Resetting {reset_operation}") logger.info(f"Resetting {reset_operation}")
match agent_type: match agent_type:
case "chat": case "chat":
prompt = system_message prompt = system_message
@ -526,11 +502,11 @@ class WebServer:
agent.system_prompt = prompt agent.system_prompt = prompt
response["system_prompt"] = { "system_prompt": prompt } response["system_prompt"] = { "system_prompt": prompt }
case "rags": case "rags":
logging.info(f"Resetting {reset_operation}") logger.info(f"Resetting {reset_operation}")
context.rags = rags.copy() context.rags = rags.copy()
response["rags"] = context.rags response["rags"] = context.rags
case "tools": case "tools":
logging.info(f"Resetting {reset_operation}") logger.info(f"Resetting {reset_operation}")
context.tools = default_tools(tools) context.tools = default_tools(tools)
response["tools"] = context.tools response["tools"] = context.tools
case "history": case "history":
@ -546,13 +522,13 @@ class WebServer:
tmp = context.get_agent(mode) tmp = context.get_agent(mode)
if not tmp: if not tmp:
continue continue
logging.info(f"Resetting {reset_operation} for {mode}") logger.info(f"Resetting {reset_operation} for {mode}")
context.conversation = Conversation() context.conversation = Conversation()
context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage
response["history"] = [] response["history"] = []
response["context_used"] = agent.context_tokens response["context_used"] = agent.context_tokens
case "message_history_length": case "message_history_length":
logging.info(f"Resetting {reset_operation}") logger.info(f"Resetting {reset_operation}")
context.message_history_length = DEFAULT_HISTORY_LENGTH context.message_history_length = DEFAULT_HISTORY_LENGTH
response["message_history_length"] = DEFAULT_HISTORY_LENGTH response["message_history_length"] = DEFAULT_HISTORY_LENGTH
@ -567,7 +543,7 @@ class WebServer:
@self.app.put("/api/tunables/{context_id}") @self.app.put("/api/tunables/{context_id}")
async def put_tunables(context_id: str, request: Request): async def put_tunables(context_id: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
try: try:
context = self.upsert_context(context_id) context = self.upsert_context(context_id)
@ -619,14 +595,14 @@ class WebServer:
case _: case _:
return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404) return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404)
except Exception as e: except Exception as e:
logging.error(f"Error in put_tunables: {e}") logger.error(f"Error in put_tunables: {e}")
return JSONResponse({"error": str(e)}, status_code=500) return JSONResponse({"error": str(e)}, status_code=500)
@self.app.get("/api/tunables/{context_id}") @self.app.get("/api/tunables/{context_id}")
async def get_tunables(context_id: str, request: Request): async def get_tunables(context_id: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id): if not is_valid_uuid(context_id):
logging.warning(f"Invalid context_id: {context_id}") logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400) return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = self.upsert_context(context_id) context = self.upsert_context(context_id)
agent = context.get_agent("chat") agent = context.get_agent("chat")
@ -644,15 +620,15 @@ class WebServer:
@self.app.get("/api/system-info/{context_id}") @self.app.get("/api/system-info/{context_id}")
async def get_system_info(context_id: str, request: Request): async def get_system_info(context_id: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
return JSONResponse(system_info(self.model)) return JSONResponse(system_info(self.model))
@self.app.post("/api/chat/{context_id}/{agent_type}") @self.app.post("/api/chat/{context_id}/{agent_type}")
async def post_chat_endpoint(context_id: str, agent_type: str, request: Request): async def post_chat_endpoint(context_id: str, agent_type: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
try: try:
if not is_valid_uuid(context_id): if not is_valid_uuid(context_id):
logging.warning(f"Invalid context_id: {context_id}") logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400) return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = self.upsert_context(context_id) context = self.upsert_context(context_id)
@ -660,11 +636,11 @@ class WebServer:
data = await request.json() data = await request.json()
agent = context.get_agent(agent_type) agent = context.get_agent(agent_type)
if not agent and agent_type == "job_description": if not agent and agent_type == "job_description":
logging.info(f"Agent {agent_type} not found. Returning empty history.") logger.info(f"Agent {agent_type} not found. Returning empty history.")
# Create a new agent if it doesn't exist # Create a new agent if it doesn't exist
agent = context.get_or_create_agent("job_description", system_prompt=system_generate_resume, job_description=data["content"]) agent = context.get_or_create_agent("job_description", system_prompt=system_generate_resume, job_description=data["content"])
except Exception as e: except Exception as e:
logging.info(f"Attempt to create agent type: {agent_type} failed", e) logger.info(f"Attempt to create agent type: {agent_type} failed", e)
return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404) return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)
# Create a custom generator that ensures flushing # Create a custom generator that ensures flushing
@ -688,43 +664,43 @@ class WebServer:
} }
) )
except Exception as e: except Exception as e:
logging.error(f"Error in post_chat_endpoint: {e}") logger.error(f"Error in post_chat_endpoint: {e}")
return JSONResponse({"error": str(e)}, status_code=500) return JSONResponse({"error": str(e)}, status_code=500)
@self.app.post("/api/context") @self.app.post("/api/context")
async def create_context(): async def create_context():
context = self.create_context() context = self.create_context()
logging.info(f"Generated new agent as {context.id}") logger.info(f"Generated new agent as {context.id}")
return JSONResponse({ "id": context.id }) return JSONResponse({ "id": context.id })
@self.app.get("/api/history/{context_id}/{agent_type}") @self.app.get("/api/history/{context_id}/{agent_type}")
async def get_history(context_id: str, agent_type: str, request: Request): async def get_history(context_id: str, agent_type: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
try: try:
context = self.upsert_context(context_id) context = self.upsert_context(context_id)
agent = context.get_agent(agent_type) agent = context.get_agent(agent_type)
if not agent: if not agent:
logging.info(f"Agent {agent_type} not found. Returning empty history.") logger.info(f"Agent {agent_type} not found. Returning empty history.")
return JSONResponse({ "messages": [] }) return JSONResponse({ "messages": [] })
logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.") logger.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
return agent.conversation return agent.conversation
except Exception as e: except Exception as e:
logging.error(f"get_history error: {str(e)}") logger.error(f"get_history error: {str(e)}")
import traceback import traceback
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
return JSONResponse({"error": str(e)}, status_code=404) return JSONResponse({"error": str(e)}, status_code=404)
@self.app.get("/api/tools/{context_id}") @self.app.get("/api/tools/{context_id}")
async def get_tools(context_id: str, request: Request): async def get_tools(context_id: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
context = self.upsert_context(context_id) context = self.upsert_context(context_id)
return JSONResponse(context.tools) return JSONResponse(context.tools)
@self.app.put("/api/tools/{context_id}") @self.app.put("/api/tools/{context_id}")
async def put_tools(context_id: str, request: Request): async def put_tools(context_id: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id): if not is_valid_uuid(context_id):
logging.warning(f"Invalid context_id: {context_id}") logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400) return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = self.upsert_context(context_id) context = self.upsert_context(context_id)
try: try:
@ -743,9 +719,9 @@ class WebServer:
@self.app.get("/api/context-status/{context_id}/{agent_type}") @self.app.get("/api/context-status/{context_id}/{agent_type}")
async def get_context_status(context_id, agent_type: str, request: Request): async def get_context_status(context_id, agent_type: str, request: Request):
logging.info(f"{request.method} {request.url.path}") logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id): if not is_valid_uuid(context_id):
logging.warning(f"Invalid context_id: {context_id}") logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400) return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = self.upsert_context(context_id) context = self.upsert_context(context_id)
agent = context.get_agent(agent_type) agent = context.get_agent(agent_type)
@ -761,9 +737,9 @@ class WebServer:
async def serve_static(path: str): async def serve_static(path: str):
full_path = os.path.join(defines.static_content, path) full_path = os.path.join(defines.static_content, path)
if os.path.exists(full_path) and os.path.isfile(full_path): if os.path.exists(full_path) and os.path.isfile(full_path):
logging.info(f"Serve static request for {full_path}") logger.info(f"Serve static request for {full_path}")
return FileResponse(full_path) return FileResponse(full_path)
logging.info(f"Serve index.html for {path}") logger.info(f"Serve index.html for {path}")
return FileResponse(os.path.join(defines.static_content, "index.html")) return FileResponse(os.path.join(defines.static_content, "index.html"))
def save_context(self, context_id): def save_context(self, context_id):
@ -807,28 +783,28 @@ class WebServer:
# Check if the file exists # Check if the file exists
if not os.path.exists(file_path): if not os.path.exists(file_path):
logging.info(f"Context file {file_path} not found. Creating new context.") logger.info(f"Context file {file_path} not found. Creating new context.")
self.contexts[context_id] = self.create_context(context_id) self.contexts[context_id] = self.create_context(context_id)
else: else:
# Read and deserialize the data # Read and deserialize the data
with open(file_path, "r") as f: with open(file_path, "r") as f:
content = f.read() content = f.read()
logging.info(f"Loading context from {file_path}, content length: {len(content)}") logger.info(f"Loading context from {file_path}, content length: {len(content)}")
try: try:
# Try parsing as JSON first to ensure valid JSON # Try parsing as JSON first to ensure valid JSON
import json import json
json_data = json.loads(content) json_data = json.loads(content)
logging.info("JSON parsed successfully, attempting model validation") logger.info("JSON parsed successfully, attempting model validation")
# Now try Pydantic validation # Now try Pydantic validation
self.contexts[context_id] = Context.from_json(json_data, file_watcher=self.file_watcher) self.contexts[context_id] = Context.from_json(json_data, file_watcher=self.file_watcher)
logging.info(f"Successfully loaded context {context_id}") logger.info(f"Successfully loaded context {context_id}")
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logging.error(f"Invalid JSON in file: {e}") logger.error(f"Invalid JSON in file: {e}")
except Exception as e: except Exception as e:
logging.error(f"Error validating context: {str(e)}") logger.error(f"Error validating context: {str(e)}")
import traceback import traceback
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
# Fallback to creating a new context # Fallback to creating a new context
self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher) self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)
@ -845,7 +821,7 @@ class WebServer:
if not self.file_watcher: if not self.file_watcher:
raise Exception("File watcher not initialized") raise Exception("File watcher not initialized")
logging.info(f"Creating new context with ID: {context_id}") logger.info(f"Creating new context with ID: {context_id}")
context = Context(id=context_id, file_watcher=self.file_watcher) context = Context(id=context_id, file_watcher=self.file_watcher)
if os.path.exists(defines.resume_doc): if os.path.exists(defines.resume_doc):
@ -859,7 +835,7 @@ class WebServer:
context.tools = default_tools(tools) context.tools = default_tools(tools)
context.rags = rags.copy() context.rags = rags.copy()
logging.info(f"{context.id} created and added to contexts.") logger.info(f"{context.id} created and added to contexts.")
self.contexts[context.id] = context self.contexts[context.id] = context
self.save_context(context.id) self.save_context(context.id)
return context return context
@ -941,13 +917,13 @@ class WebServer:
""" """
if not context_id: if not context_id:
logging.warning("No context ID provided. Creating a new context.") logger.warning("No context ID provided. Creating a new context.")
return self.create_context() return self.create_context()
if context_id in self.contexts: if context_id in self.contexts:
return self.contexts[context_id] return self.contexts[context_id]
logging.info(f"Context {context_id} is not yet loaded.") logger.info(f"Context {context_id} is not yet loaded.")
return self.load_or_create_context(context_id) return self.load_or_create_context(context_id)
def generate_rag_results(self, context, content): def generate_rag_results(self, context, content):
@ -963,13 +939,13 @@ class WebServer:
if chroma_results: if chroma_results:
results_found = True results_found = True
chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
logging.info(f"Chroma embedding shape: {chroma_embedding.shape}") logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
yield { yield {
**chroma_results, **chroma_results,
@ -1012,11 +988,11 @@ class WebServer:
raise Exception("File watcher not initialized") raise Exception("File watcher not initialized")
agent_type = agent.get_agent_type() agent_type = agent.get_agent_type()
logging.info(f"generate_response: {agent_type}") logger.info(f"generate_response: {agent_type}")
if agent_type == "chat": if agent_type == "chat":
message = Message(prompt=content) message = Message(prompt=content)
async for value in agent.prepare_message(message): async for value in agent.prepare_message(message):
logging.info(f"{agent_type}.prepare_message: {value.status} - {value.response}") logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
if value.status != "done": if value.status != "done":
yield value yield value
if value.status == "error": if value.status == "error":
@ -1025,7 +1001,7 @@ class WebServer:
yield message yield message
return return
async for value in agent.process_message(message): async for value in agent.process_message(message):
logging.info(f"{agent_type}.process_message: {value.status} - {value.response}") logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
if value.status != "done": if value.status != "done":
yield value yield value
if value.status == "error": if value.status == "error":
@ -1034,7 +1010,7 @@ class WebServer:
yield message yield message
return return
async for value in agent.generate_llm_response(message): async for value in agent.generate_llm_response(message):
logging.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}") logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
if value.status != "done": if value.status != "done":
yield value yield value
if value.status == "error": if value.status == "error":
@ -1042,13 +1018,13 @@ class WebServer:
message.response = value.response message.response = value.response
yield message yield message
return return
logging.info("TODO: There is more to do...") logger.info("TODO: There is more to do...")
return return
return return
if self.processing: if self.processing:
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
yield {"status": "error", "message": "Busy processing another request."} yield {"status": "error", "message": "Busy processing another request."}
return return
@ -1083,11 +1059,11 @@ class WebServer:
process_type = agent.get_agent_type() process_type = agent.get_agent_type()
match process_type: match process_type:
case "job_description": case "job_description":
logging.info(f"job_description user_history len: {len(conversation.messages)}") logger.info(f"job_description user_history len: {len(conversation.messages)}")
if len(conversation.messages) >= 2: # USER, ASSISTANT if len(conversation.messages) >= 2: # USER, ASSISTANT
process_type = "chat" process_type = "chat"
case "resume": case "resume":
logging.info(f"resume user_history len: {len(conversation.messages)}") logger.info(f"resume user_history len: {len(conversation.messages)}")
if len(conversation.messages) >= 3: # USER, ASSISTANT, FACT_CHECK if len(conversation.messages) >= 3: # USER, ASSISTANT, FACT_CHECK
process_type = "chat" process_type = "chat"
case "fact_check": case "fact_check":
@ -1098,7 +1074,7 @@ class WebServer:
case "chat": case "chat":
if not message.prompt: if not message.prompt:
yield {"status": "error", "message": "No query provided for chat."} yield {"status": "error", "message": "No query provided for chat."}
logging.info(f"user_history len: {len(conversation.messages)}") logger.info(f"user_history len: {len(conversation.messages)}")
self.processing = False self.processing = False
return return
@ -1207,8 +1183,8 @@ Use to the above information to respond to this prompt:
message.add_action("generate_resume") message.add_action("generate_resume")
logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'") logger.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'")
logging.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble") logger.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble")
# Switch to resume agent for LLM responses # Switch to resume agent for LLM responses
# message.metadata["origin"] = "resume" # message.metadata["origin"] = "resume"
@ -1288,11 +1264,11 @@ Use the above <|resume|> and <|job_description|> to answer this query:
stuffingMessage.metadata["origin"] = "resume" stuffingMessage.metadata["origin"] = "resume"
stuffingMessage.metadata["display"] = "hide" stuffingMessage.metadata["display"] = "hide"
stuffingMessage.actions = [ "fact_check" ] stuffingMessage.actions = [ "fact_check" ]
logging.info("TODO: Switch this to use actions to keep the UI from showingit") logger.info("TODO: Switch this to use actions to keep the UI from showingit")
conversation.add_message(stuffingMessage) conversation.add_message(stuffingMessage)
# For all future calls to job_description, use the system_job_description # For all future calls to job_description, use the system_job_description
logging.info("TODO: Create a system_resume_QA prompt to use for the resume agent") logger.info("TODO: Create a system_resume_QA prompt to use for the resume agent")
agent.system_prompt = system_prompt agent.system_prompt = system_prompt
# Switch to fact_check agent for LLM responses # Switch to fact_check agent for LLM responses
@ -1364,7 +1340,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
else: else:
response = self.llm.chat(model=self.model, messages=messages, options={ "num_ctx": ctx_size }) response = self.llm.chat(model=self.model, messages=messages, options={ "num_ctx": ctx_size })
except Exception as e: except Exception as e:
logging.exception({ "model": self.model, "error": str(e) }) logger.exception({ "model": self.model, "error": str(e) })
yield {"status": "error", "message": f"An error occurred communicating with LLM"} yield {"status": "error", "message": f"An error occurred communicating with LLM"}
self.processing = False self.processing = False
return return
@ -1449,7 +1425,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
} }
# except Exception as e: # except Exception as e:
# logging.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) }) # logger.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) })
# yield {"status": "error", "message": f"An error occurred: {str(e)}"} # yield {"status": "error", "message": f"An error occurred: {str(e)}"}
# finally: # finally:
@ -1460,7 +1436,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs): def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
try: try:
if self.ssl_enabled: if self.ssl_enabled:
logging.info(f"Starting web server at https://{host}:{port}") logger.info(f"Starting web server at https://{host}:{port}")
uvicorn.run( uvicorn.run(
self.app, self.app,
host=host, host=host,
@ -1470,7 +1446,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
ssl_certfile=defines.cert_path ssl_certfile=defines.cert_path
) )
else: else:
logging.info(f"Starting web server at http://{host}:{port}") logger.info(f"Starting web server at http://{host}:{port}")
uvicorn.run( uvicorn.run(
self.app, self.app,
host=host, host=host,
@ -1493,7 +1469,7 @@ def main():
args = parse_args() args = parse_args()
# Setup logging based on the provided level # Setup logging based on the provided level
setup_logging(args.level) logger.setLevel(args.level)
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",

View File

@ -6,7 +6,7 @@ from .message import Message
from . conversation import Conversation from . conversation import Conversation
from . context import Context from . context import Context
from . import agents from . import agents
import logging from . setup_logging import setup_logging
from .agents import Agent, __all__ as agents_all from .agents import Agent, __all__ as agents_all
@ -17,13 +17,14 @@ __all__ = [
'Message', 'Message',
'ChromaDBFileWatcher', 'ChromaDBFileWatcher',
'start_file_watcher' 'start_file_watcher'
'logger',
] + agents_all ] + agents_all
# Resolve circular dependencies by rebuilding models # Resolve circular dependencies by rebuilding models
# Call model_rebuild() on Agent and Context # Call model_rebuild() on Agent and Context
Agent.model_rebuild() Agent.model_rebuild()
Context.model_rebuild() Context.model_rebuild()
import logging
import importlib import importlib
from pydantic import BaseModel from pydantic import BaseModel
from typing import Type from typing import Type
@ -31,19 +32,20 @@ from typing import Type
# Assuming class_registry is available from agents/__init__.py # Assuming class_registry is available from agents/__init__.py
from .agents import class_registry, AnyAgent from .agents import class_registry, AnyAgent
logger = setup_logging(level=defines.logging_level)
def rebuild_models(): def rebuild_models():
Context.model_rebuild()
for class_name, (module_name, _) in class_registry.items(): for class_name, (module_name, _) in class_registry.items():
try: try:
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
cls = getattr(module, class_name, None) cls = getattr(module, class_name, None)
logging.info(f"Checking: {class_name} in module {module_name}") logger.info(f"Checking: {class_name} in module {module_name}")
logging.info(f" cls: {True if cls else False}") logger.info(f" cls: {True if cls else False}")
logging.info(f" isinstance(cls, type): {isinstance(cls, type)}") logger.info(f" isinstance(cls, type): {isinstance(cls, type)}")
logging.info(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}") logger.info(f" issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}")
logging.info(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}") logger.info(f" issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}")
logging.info(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}") logger.info(f" cls is not AnyAgent: {cls is not AnyAgent if cls else True}")
if ( if (
cls cls
@ -52,12 +54,14 @@ def rebuild_models():
and issubclass(cls, AnyAgent) and issubclass(cls, AnyAgent)
and cls is not AnyAgent and cls is not AnyAgent
): ):
logging.info(f"Rebuilding {class_name} from {module_name}") logger.info(f"Rebuilding {class_name} from {module_name}")
from . agents import Agent
from . context import Context
cls.model_rebuild() cls.model_rebuild()
except ImportError as e: except ImportError as e:
logging.error(f"Failed to import module {module_name}: {e}") logger.error(f"Failed to import module {module_name}: {e}")
except Exception as e: except Exception as e:
logging.error(f"Error processing {class_name} in {module_name}: {e}") logger.error(f"Error processing {class_name} in {module_name}: {e}")
# Call this after all modules are imported # Call this after all modules are imported
rebuild_models() rebuild_models()

View File

@ -3,18 +3,20 @@ from pydantic import BaseModel, model_validator, PrivateAttr, Field
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing_extensions import Annotated from typing_extensions import Annotated
import logging from .. setup_logging import setup_logging
logger = setup_logging()
# Only import Context for type checking # Only import Context for type checking
if TYPE_CHECKING: if TYPE_CHECKING:
from .. context import Context from .. context import Context
from .types import AgentBase, registry from .types import registry
from .. conversation import Conversation from .. conversation import Conversation
from .. message import Message from .. message import Message
class Agent(AgentBase): class Agent(BaseModel, ABC):
""" """
Base class for all agent types. Base class for all agent types.
This class defines the common attributes and methods for all agent types. This class defines the common attributes and methods for all agent types.
@ -28,27 +30,19 @@ class Agent(AgentBase):
system_prompt: str # Mandatory system_prompt: str # Mandatory
conversation: Conversation = Conversation() conversation: Conversation = Conversation()
context_tokens: int = 0 context_tokens: int = 0
context: object = Field(..., exclude=True) # Avoid circular reference, require as param, and prevent serialization context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization
_content_seed: str = PrivateAttr(default="") _content_seed: str = PrivateAttr(default="")
# Class and pydantic model management # Class and pydantic model management
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
"""Auto-register subclasses""" """Auto-register subclasses"""
logger.info(f"Agent.__init_subclass__({kwargs})")
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
# Register this class if it has an agent_type # Register this class if it has an agent_type
if hasattr(cls, 'agent_type') and cls.agent_type != AgentBase._agent_type: if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type:
registry.register(cls.agent_type, cls) registry.register(cls.agent_type, cls)
def __init__(self, **data):
# Set agent_type from class if not provided
if 'agent_type' not in data:
data['agent_type'] = self.__class__.agent_type
from .. context import Context
Context.model_rebuild()
self.__class__.model_rebuild()
super().__init__(**data)
def model_dump(self, *args, **kwargs): def model_dump(self, *args, **kwargs):
# Ensure context is always excluded, even with exclude_unset=True # Ensure context is always excluded, even with exclude_unset=True
kwargs.setdefault("exclude", set()) kwargs.setdefault("exclude", set())
@ -80,7 +74,7 @@ class Agent(AgentBase):
# Gather RAG results, yielding each result # Gather RAG results, yielding each result
# as it becomes available # as it becomes available
for value in self.context.generate_rag_results(message): for value in self.context.generate_rag_results(message):
logging.info(f"RAG: {value.status} - {value.response}") logger.info(f"RAG: {value.status} - {value.response}")
if value.status != "done": if value.status != "done":
yield value yield value
if value.status == "error": if value.status == "error":
@ -120,7 +114,7 @@ class Agent(AgentBase):
async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]: async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
if self.context.processing: if self.context.processing:
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
message.status = "error" message.status = "error"
message.response = "Busy processing another request." message.response = "Busy processing another request."
yield message yield message
@ -136,7 +130,7 @@ class Agent(AgentBase):
#tools=llm_tools(context.tools) if message.enable_tools else None, #tools=llm_tools(context.tools) if message.enable_tools else None,
options={ "num_ctx": message.ctx_size } options={ "num_ctx": message.ctx_size }
): ):
logging.info(f"LLM: {value.status} - {value.response}") logger.info(f"LLM: {value.status} - {value.response}")
if value.status != "done": if value.status != "done":
message.status = value.status message.status = value.status
message.response = value.response message.response = value.response
@ -240,7 +234,7 @@ class Agent(AgentBase):
yield message yield message
for value in self.generate_llm_response(message): for value in self.generate_llm_response(message):
logging.info(f"LLM: {value.status} - {value.response}") logger.info(f"LLM: {value.status} - {value.response}")
if value.status != "done": if value.status != "done":
yield value yield value
if value.status == "error": if value.status == "error":

View File

@ -1,3 +1,4 @@
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator, ValidationError from pydantic import BaseModel, Field, model_validator, ValidationError
from uuid import uuid4 from uuid import uuid4
from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
@ -11,12 +12,6 @@ from .message import Message
from .rag import ChromaDBFileWatcher from .rag import ChromaDBFileWatcher
from . import defines from . import defines
from .agents import Agent
# Import only agent types, not actual classes
if TYPE_CHECKING:
from .agents import Agent, AnyAgent
from .agents import AnyAgent from .agents import AnyAgent
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -25,7 +20,7 @@ logger = logging.getLogger(__name__)
class Context(BaseModel): class Context(BaseModel):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
# Required fields # Required fields
file_watcher: ChromaDBFileWatcher = Field(..., exclude=True) file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
# Optional fields # Optional fields
id: str = Field( id: str = Field(
@ -44,15 +39,14 @@ class Context(BaseModel):
default_factory=list default_factory=list
) )
@model_validator(mode="before")
@classmethod @classmethod
def from_json(cls, json_str: str, file_watcher: ChromaDBFileWatcher): def before_model_validator(cls, values: Any):
"""Custom method to load from JSON with file_watcher injection""" logger.info(f"Preparing model data: {cls} {values}")
import json return values
data = json.loads(json_str)
return cls(file_watcher=file_watcher, **data)
@model_validator(mode="after") @model_validator(mode="after")
def validate_unique_agent_types(self): def after_model_validator(self):
"""Ensure at most one agent per agent_type.""" """Ensure at most one agent per agent_type."""
logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.") logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
agent_types = [agent.agent_type for agent in self.agents] agent_types = [agent.agent_type for agent in self.agents]
@ -191,4 +185,5 @@ class Context(BaseModel):
summary += f"\nChat Name: {agent.name}\n" summary += f"\nChat Name: {agent.name}\n"
return summary return summary
from . agents import Agent
Context.model_rebuild() Context.model_rebuild()

View File

@ -15,3 +15,4 @@ resume_doc = "/opt/backstory/docs/resume/generic.md"
# Only used for testing; backstory-prod will not use this # Only used for testing; backstory-prod will not use this
key_path = "/opt/backstory/keys/key.pem" key_path = "/opt/backstory/keys/key.pem"
cert_path = "/opt/backstory/keys/cert.pem" cert_path = "/opt/backstory/keys/cert.pem"
logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()