Starting to work again

James Ketr 2025-04-30 12:57:51 -07:00
parent 4614dbb237
commit e607e3a2f2
8 changed files with 351 additions and 291 deletions

View File

@@ -1,5 +1,10 @@
 import os
+os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
+import warnings
+warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
+warnings.filterwarnings("ignore", message="Warning only once for all operators")
+warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")

 # %%
 # Imports [standard]
@@ -37,6 +42,7 @@ try_import("sklearn")
 import ollama
 import requests
 from bs4 import BeautifulSoup
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, BackgroundTasks
 from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -363,8 +369,23 @@ def llm_tools(tools):
 # %%
 class WebServer:
+    @asynccontextmanager
+    async def lifespan(self, app: FastAPI):
+        # Start the file watcher
+        self.observer, self.file_watcher = Rag.start_file_watcher(
+            llm=self.llm,
+            watch_directory=defines.doc_dir,
+            recreate=False  # Don't recreate if exists
+        )
+        logging.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
+        yield
+        if self.observer:
+            self.observer.stop()
+            self.observer.join()
+            logging.info("File watcher stopped")
+
     def __init__(self, llm, model=MODEL_NAME):
-        self.app = FastAPI()
+        self.app = FastAPI(lifespan=self.lifespan)
         self.contexts = {}
         self.llm = llm
         self.model = model
@@ -389,24 +410,6 @@ class WebServer:
             allow_headers=["*"],
         )

-        @self.app.on_event("startup")
-        async def startup_event():
-            # Start the file watcher
-            self.observer, self.file_watcher = Rag.start_file_watcher(
-                llm=llm,
-                watch_directory=defines.doc_dir,
-                recreate=False  # Don't recreate if exists
-            )
-            print(f"API started with {self.file_watcher.collection.count()} documents in the collection")
-
-        @self.app.on_event("shutdown")
-        async def shutdown_event():
-            if self.observer:
-                self.observer.stop()
-                self.observer.join()
-                print("File watcher stopped")
-
        self.setup_routes()

    def setup_routes(self):
@@ -444,14 +447,16 @@ class WebServer:
                return JSONResponse(result)
            except Exception as e:
-                logging.error(e)
+                logging.error(f"put_umap error: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, 500)

        @self.app.put("/api/similarity/{context_id}")
        async def put_similarity(context_id: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            if not self.file_watcher:
-                return
+                raise Exception("File watcher not initialized")

            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
@@ -471,13 +476,13 @@ class WebServer:
                return JSONResponse({"error": "No results found"}, status_code=404)

            chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()  # Ensure correct shape
-            print(f"Chroma embedding shape: {chroma_embedding.shape}")
+            logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")

            umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-            print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output
+            logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output

            umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-            print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output
+            logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output

            return JSONResponse({
                **chroma_results,
@@ -666,7 +671,7 @@ class WebServer:
            async def flush_generator():
                async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
                    # Convert to JSON and add newline
-                    yield json.dumps(message) + "\n"
+                    yield message.model_dump_json() + "\n"
                    # Save the history as it's generated
                    self.save_context(context_id)
                    # Explicitly flush after each yield
@@ -704,7 +709,9 @@ class WebServer:
                logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
                return agent.conversation
            except Exception as e:
-                logging.error(f"Error in get_history: {e}")
+                logging.error(f"get_history error: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, status_code=404)

        @self.app.get("/api/tools/{context_id}")
@@ -759,52 +766,73 @@ class WebServer:
            logging.info(f"Serve index.html for {path}")
            return FileResponse(os.path.join(defines.static_content, "index.html"))

-    def save_context(self, agent_id):
+    def save_context(self, context_id):
        """
        Serialize a context to a file in the context directory.

        Args:
-            agent_id: UUID string for the context. If it doesn't exist, it is created
+            context_id: UUID string for the context. If it doesn't exist, it is created

        Returns:
-            The agent_id used for the file
+            The context_id used for the file
        """
-        context = self.upsert_context(agent_id)
+        context = self.upsert_context(context_id)

        # Create the context directory if it doesn't exist
        if not os.path.exists(defines.context_dir):
            os.makedirs(defines.context_dir)

        # Create the full file path
-        file_path = os.path.join(defines.context_dir, agent_id)
+        file_path = os.path.join(defines.context_dir, context_id)

        # Serialize the data to JSON and write to file
        with open(file_path, "w") as f:
            f.write(context.model_dump_json())

-        return agent_id
+        return context_id

-    def load_context(self, agent_id) -> Context:
+    def load_or_create_context(self, context_id) -> Context:
        """
-        Load a context from a file in the agents directory.
+        Load a context from a file in the context directory or create a new one if it doesn't exist.

        Args:
-            agent_id: UUID string for the context. If it doesn't exist, a new context is created.
+            context_id: UUID string for the context.

        Returns:
            A Context object with the specified ID and default settings.
        """
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
-        file_path = os.path.join(defines.context_dir, agent_id)
+        file_path = os.path.join(defines.context_dir, context_id)

        # Check if the file exists
        if not os.path.exists(file_path):
-            self.contexts[agent_id] = self.create_context(agent_id)
+            logging.info(f"Context file {file_path} not found. Creating new context.")
+            self.contexts[context_id] = self.create_context(context_id)
        else:
            # Read and deserialize the data
            with open(file_path, "r") as f:
-                self.contexts[agent_id] = Context.model_validate_json(f.read())
+                content = f.read()
+                logging.info(f"Loading context from {file_path}, content length: {len(content)}")
+                try:
+                    # Try parsing as JSON first to ensure valid JSON
+                    import json
+                    json.loads(content)
+                    logging.info("JSON parsed successfully, attempting model validation")
+                    # Now try Pydantic validation
+                    self.contexts[context_id] = Context.from_json(content, file_watcher=self.file_watcher)
+                    logging.info(f"Successfully loaded context {context_id}")
+                except json.JSONDecodeError as e:
+                    logging.error(f"Invalid JSON in file: {e}")
+                    # Fallback to creating a new context
+                    self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)
+                except Exception as e:
+                    logging.error(f"Error validating context: {str(e)}")
+                    import traceback
+                    logging.error(traceback.format_exc())
+                    # Fallback to creating a new context
+                    self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)

-        return self.contexts[agent_id]
+        return self.contexts[context_id]

    def create_context(self, context_id = None) -> Context:
        """
@@ -814,7 +842,11 @@ class WebServer:
        Returns:
            A Context object with the specified ID and default settings.
        """
-        context = Context(id=context_id)
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
+        logging.info(f"Creating new context with ID: {context_id}")
+        context = Context(id=context_id, file_watcher=self.file_watcher)

        if os.path.exists(defines.resume_doc):
            context.user_resume = open(defines.resume_doc, "r").read()
@@ -912,20 +944,18 @@ class WebServer:
            logging.warning("No context ID provided. Creating a new context.")
            return self.create_context()

        if not is_valid_uuid(context_id):
            logging.info(f"User requested invalid context_id: {context_id}")
            raise ValueError(f"Invalid context_id: {context_id}")

        if context_id in self.contexts:
            return self.contexts[context_id]

-        logging.info(f"Context {context_id} not found. Creating new context.")
-        return self.load_context(context_id)
+        logging.info(f"Context {context_id} is not yet loaded.")
+        return self.load_or_create_context(context_id)

    def generate_rag_results(self, context, content):
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
        results_found = False
-        if self.file_watcher:
        for rag in context.rags:
            if rag["enabled"] and rag["name"] == "JPK":  # Only support JPK rag right now...
                yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
@@ -933,13 +963,13 @@ class WebServer:
                if chroma_results:
                    results_found = True
                    chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()  # Ensure correct shape
-                    print(f"Chroma embedding shape: {chroma_embedding.shape}")
+                    logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")

                    umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                    print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output
+                    logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output

                    umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                    print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output
+                    logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output

                    yield {
                        **chroma_results,
@@ -979,7 +1009,7 @@ class WebServer:
    # * Then Q&A of Fact Check
    async def generate_response(self, context : Context, agent : Agent, content : str):
        if not self.file_watcher:
-            return
+            raise Exception("File watcher not initialized")

        agent_type = agent.get_agent_type()
        logging.info(f"generate_response: {agent_type}")
@@ -1015,6 +1045,7 @@
            logging.info("TODO: There is more to do...")
+            return
        return

        if self.processing:
            logging.info("TODO: Implement delayed queueing; busy for same agent, otherwise return queue size and estimated wait time")
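The notable change in this file is the move from FastAPI's deprecated `@app.on_event("startup")` / `@app.on_event("shutdown")` hooks to a single `lifespan` async context manager. A minimal, self-contained sketch of that pattern, with a placeholder object standing in for the project's file watcher:

```python
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: acquire long-lived resources before requests are served.
    app.state.watcher = object()  # placeholder for a real file watcher
    yield  # the application handles requests while suspended here
    # Shutdown: release resources after the last request completes.
    app.state.watcher = None

app = FastAPI(lifespan=lifespan)
```

Everything before `yield` runs at startup and everything after runs at shutdown, which keeps paired setup/teardown logic in one place instead of two separate handlers.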

View File

@@ -1,4 +1,5 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr, Field
 from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
@@ -8,8 +9,6 @@ import logging

 if TYPE_CHECKING:
     from .. context import Context
-ContextRef = ForwardRef('Context')

 from .types import AgentBase, registry
 from .. conversation import Conversation
@@ -29,9 +28,11 @@ class Agent(AgentBase):
    system_prompt: str  # Mandatory
    conversation: Conversation = Conversation()
    context_tokens: int = 0
-    context: ContextRef  # Avoid circular reference
+    context: object = Field(..., exclude=True)  # Avoid circular reference, require as param, and prevent serialization
    _content_seed: str = PrivateAttr(default="")

+    # Class and pydantic model management
    def __init_subclass__(cls, **kwargs):
        """Auto-register subclasses"""
        super().__init_subclass__(**kwargs)
@@ -48,6 +49,24 @@ class Agent(AgentBase):
        self.__class__.model_rebuild()
        super().__init__(**data)

+    def model_dump(self, *args, **kwargs):
+        # Ensure context is always excluded, even with exclude_unset=True
+        kwargs.setdefault("exclude", set())
+        if isinstance(kwargs["exclude"], set):
+            kwargs["exclude"].add("context")
+        elif isinstance(kwargs["exclude"], dict):
+            kwargs["exclude"]["context"] = True
+        return super().model_dump(*args, **kwargs)
+
+    @classmethod
+    def valid_agent_types(cls) -> set[str]:
+        """Return the set of valid agent_type values."""
+        return set(get_args(cls.__annotations__["agent_type"]))
+
+    def set_context(self, context):
+        object.__setattr__(self, "context", context)
+
+    # Agent methods
    def get_agent_type(self):
        return self._agent_type
@@ -240,11 +259,6 @@ class Agent(AgentBase):
        """Get the content seed for the agent."""
        return self._content_seed

-    @classmethod
-    def valid_agent_types(cls) -> set[str]:
-        """Return the set of valid agent_type values."""
-        return set(get_args(cls.__annotations__["agent_type"]))

# Register the base agent
registry.register(Agent._agent_type, Agent)
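The `context: object = Field(..., exclude=True)` change plus the `model_dump` override keeps a required runtime back-reference out of every serialized form. A small sketch of the same idea, assuming Pydantic v2; the `Child`/`parent` names are illustrative, not the project's API:

```python
from typing import Any
from pydantic import BaseModel, Field

class Child(BaseModel):
    name: str
    parent: Any = Field(..., exclude=True)  # required at construction, never serialized

    def model_dump(self, *args, **kwargs):
        # Force-exclude "parent" even if the caller supplies its own exclude set.
        kwargs.setdefault("exclude", set())
        if isinstance(kwargs["exclude"], set):
            kwargs["exclude"].add("parent")
        return super().model_dump(*args, **kwargs)

child = Child(name="demo", parent=object())
assert "parent" not in child.model_dump()
assert "parent" not in child.model_dump(exclude={"name"})
```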

View File

@@ -1,4 +1,5 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr
 from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
 from typing_extensions import Annotated
 from abc import ABC, abstractmethod
@@ -35,7 +36,7 @@ class Chat(Agent, ABC):
                yield message
                return

-            if message.metadata["rag"]:
+            if "rag" in message.metadata and message.metadata["rag"]:
                for rag_collection in message.metadata["rag"]:
                    for doc in rag_collection["documents"]:
                        rag_context += f"{doc}\n"
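The new guard avoids a `KeyError` when a message carries no RAG metadata. An equivalent, slightly tighter idiom, shown on a stand-in dict:

```python
metadata = {"tools": []}   # no "rag" key present
if metadata.get("rag"):    # same truthiness as: "rag" in metadata and metadata["rag"]
    print("have RAG results")
```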

View File

@@ -1,9 +1,11 @@
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, model_validator, ValidationError
+from uuid import uuid4
 from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
 from typing_extensions import Annotated, Union
 import numpy as np
 import logging
-from uuid import uuid4
 import re
 from .message import Message
 from .rag import ChromaDBFileWatcher
@@ -13,22 +15,23 @@ from .agents import Agent

 # Import only agent types, not actual classes
 if TYPE_CHECKING:
-    from .agents import Agent, AnyAgent, Chat, Resume, JobDescription, FactCheck
-    from .agents import Agent, AnyAgent
+    from .agents import AnyAgent

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 class Context(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher
+
+    # Required fields
+    file_watcher: ChromaDBFileWatcher = Field(..., exclude=True)

+    # Optional fields
     id: str = Field(
         default_factory=lambda: str(uuid4()),
         pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
     )
-    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
-        default_factory=list
-    )
     user_resume: Optional[str] = None
     user_job_description: Optional[str] = None
     user_facts: Optional[str] = None
@@ -36,17 +39,27 @@ class Context(BaseModel):
    rags: List[dict] = []
    message_history_length: int = 5
    context_tokens: int = 0
-    file_watcher: ChromaDBFileWatcher = Field(default=None, exclude=True)

+    # Class managed fields
+    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
+        default_factory=list
+    )

    def __init__(self, id: Optional[str] = None, **kwargs):
        super().__init__(id=id if id is not None else str(uuid4()), **kwargs)

+    @classmethod
+    def from_json(cls, json_str: str, file_watcher: ChromaDBFileWatcher):
+        """Custom method to load from JSON with file_watcher injection"""
+        import json
+        data = json.loads(json_str)
+        return cls(file_watcher=file_watcher, **data)

    @model_validator(mode="after")
    def validate_unique_agent_types(self):
        """Ensure at most one agent per agent_type."""
+        logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
        agent_types = [agent.agent_type for agent in self.agents]
        if len(agent_types) != len(set(agent_types)):
            raise ValueError("Context cannot contain multiple agents of the same agent_type")
+        for agent in self.agents:
+            agent.set_context(self)
        return self

    def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
@@ -110,7 +123,7 @@ class Context(BaseModel):
            except Exception as e:
                message.response = f"Error generating RAG results: {str(e)}"
                message.status = "error"
-                logging.error(e)
+                logger.error(e)
                yield message
                return
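Because `file_watcher` is now required but excluded from serialization, `from_json` exists to re-inject the live dependency when a context is reloaded from disk. A sketch of that round trip under those assumptions; `Ctx` and the stand-in watcher are illustrative:

```python
import json
from typing import Any
from pydantic import BaseModel, Field

class Ctx(BaseModel):
    file_watcher: Any = Field(..., exclude=True)  # required, never written to disk
    id: str

    @classmethod
    def from_json(cls, json_str: str, file_watcher: Any) -> "Ctx":
        data = json.loads(json_str)
        return cls(file_watcher=file_watcher, **data)  # re-inject the dependency

watcher = object()
saved = Ctx(file_watcher=watcher, id="demo").model_dump_json()  # serializes to {"id": "demo"}
restored = Ctx.from_json(saved, file_watcher=watcher)
assert restored.id == "demo" and restored.file_watcher is watcher
```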

View File

@@ -11,6 +11,7 @@ class Message(BaseModel):
    disable_tools: bool = False

    # Generated while processing message
    status: str = ""  # Status of the message
+    preamble: dict[str, str] = {}  # Preamble to be prepended to the prompt
    system_prompt: str = ""  # System prompt provided to the LLM
    full_content: str = ""  # Full content of the message (preamble + prompt)