Starting to work again

James Ketr 2025-04-30 12:57:51 -07:00
parent 4614dbb237
commit e607e3a2f2
8 changed files with 351 additions and 291 deletions

View File

@@ -1,5 +1,10 @@
+import os
+os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
+
 import warnings
 warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
+warnings.filterwarnings("ignore", message="Warning only once for all operators")
+warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")

 # %%
 # Imports [standard]
@@ -37,6 +42,7 @@ try_import("sklearn")
 import ollama
 import requests
 from bs4 import BeautifulSoup
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, BackgroundTasks
 from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -363,8 +369,23 @@ def llm_tools(tools):
 # %%
 class WebServer:
+    @asynccontextmanager
+    async def lifespan(self, app: FastAPI):
+        # Start the file watcher
+        self.observer, self.file_watcher = Rag.start_file_watcher(
+            llm=self.llm,
+            watch_directory=defines.doc_dir,
+            recreate=False  # Don't recreate if exists
+        )
+        logging.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
+        yield
+        if self.observer:
+            self.observer.stop()
+            self.observer.join()
+            logging.info("File watcher stopped")
+
     def __init__(self, llm, model=MODEL_NAME):
-        self.app = FastAPI()
+        self.app = FastAPI(lifespan=self.lifespan)
         self.contexts = {}
         self.llm = llm
         self.model = model
@@ -389,24 +410,6 @@ class WebServer:
             allow_headers=["*"],
         )

-        @self.app.on_event("startup")
-        async def startup_event():
-            # Start the file watcher
-            self.observer, self.file_watcher = Rag.start_file_watcher(
-                llm=llm,
-                watch_directory=defines.doc_dir,
-                recreate=False  # Don't recreate if exists
-            )
-            print(f"API started with {self.file_watcher.collection.count()} documents in the collection")
-
-        @self.app.on_event("shutdown")
-        async def shutdown_event():
-            if self.observer:
-                self.observer.stop()
-                self.observer.join()
-            print("File watcher stopped")
-
         self.setup_routes()

     def setup_routes(self):
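Note: FastAPI has deprecated the `@app.on_event("startup")`/`@app.on_event("shutdown")` hooks removed here; the lifespan context manager added above is the replacement, with everything before `yield` running at startup and everything after it at shutdown. A minimal standalone sketch of the pattern (illustrative names, not this project's):

```python
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once, before the first request is served.
    app.state.ready = True
    yield  # The application serves requests here.
    # Shutdown: runs once, after the server stops accepting requests.
    app.state.ready = False

app = FastAPI(lifespan=lifespan)
```

A side benefit picked up later in this commit: because the lifespan body completes before any route runs, `self.file_watcher` is guaranteed to exist, so routes can treat a missing watcher as a hard error instead of silently returning.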
@@ -444,14 +447,16 @@ class WebServer:
                 return JSONResponse(result)
             except Exception as e:
-                logging.error(e)
+                logging.error(f"put_umap error: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, 500)

         @self.app.put("/api/similarity/{context_id}")
         async def put_similarity(context_id: str, request: Request):
             logging.info(f"{request.method} {request.url.path}")
             if not self.file_watcher:
-                return
+                raise Exception("File watcher not initialized")

             if not is_valid_uuid(context_id):
                 logging.warning(f"Invalid context_id: {context_id}")
@@ -471,13 +476,13 @@ class WebServer:
                 return JSONResponse({"error": "No results found"}, status_code=404)

             chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()  # Ensure correct shape
-            print(f"Chroma embedding shape: {chroma_embedding.shape}")
+            logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")

             umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-            print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output
+            logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output
             umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-            print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output
+            logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output

             return JSONResponse({
                 **chroma_results,
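The projections here assume `umap_model_2d` and `umap_model_3d` are umap-learn reducers fit once over the collection's stored embeddings; `transform` then maps a single query embedding into that fixed 2D/3D space without refitting. A hedged, standalone sketch with stand-in data:

```python
import numpy as np
import umap  # umap-learn

corpus = np.random.rand(200, 384)   # stand-in for the collection's embeddings
reducer_2d = umap.UMAP(n_components=2).fit(corpus)
reducer_3d = umap.UMAP(n_components=3).fit(corpus)

query = np.random.rand(384)         # stand-in for one query embedding
point_2d = reducer_2d.transform([query])[0].tolist()  # 2 coordinates
point_3d = reducer_3d.transform([query])[0].tolist()  # 3 coordinates
```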
@@ -666,7 +671,7 @@ class WebServer:
             async def flush_generator():
                 async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
                     # Convert to JSON and add newline
-                    yield json.dumps(message) + "\n"
+                    yield (message.model_dump_json()) + "\n"
                     # Save the history as it's generated
                     self.save_context(context_id)
                     # Explicitly flush after each yield
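The change from `json.dumps(message)` to `message.model_dump_json()` follows from messages now being pydantic models rather than plain dicts; the route streams newline-delimited JSON so the client can parse each message as it arrives. A self-contained sketch of that shape (illustrative names, not this project's route):

```python
import asyncio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class Msg(BaseModel):
    text: str

@app.get("/stream")
async def stream():
    async def gen():
        for i in range(3):
            # One JSON document per line ("ndjson") keeps framing trivial.
            yield Msg(text=f"chunk {i}").model_dump_json() + "\n"
            await asyncio.sleep(0)  # yield control so chunks flush promptly
    return StreamingResponse(gen(), media_type="application/x-ndjson")
```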
@@ -704,7 +709,9 @@ class WebServer:
                 logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
                 return agent.conversation
             except Exception as e:
-                logging.error(f"Error in get_history: {e}")
+                logging.error(f"get_history error: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, status_code=404)

         @self.app.get("/api/tools/{context_id}")
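The `import traceback` / `format_exc()` pair added in these handlers has a one-line equivalent in the standard library: `logging.exception` logs at ERROR level and appends the active traceback automatically. A minimal sketch:

```python
import logging

try:
    raise RuntimeError("boom")
except Exception as e:
    # Equivalent to logging.error(...) followed by traceback.format_exc():
    logging.exception(f"get_history error: {str(e)}")
```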
@@ -759,52 +766,73 @@ class WebServer:
             logging.info(f"Serve index.html for {path}")
             return FileResponse(os.path.join(defines.static_content, "index.html"))

-    def save_context(self, agent_id):
+    def save_context(self, context_id):
         """
         Serialize a Python dictionary to a file in the agents directory.

         Args:
             data: Dictionary containing the agent data
-            agent_id: UUID string for the context. If it doesn't exist, it is created
+            context_id: UUID string for the context. If it doesn't exist, it is created

         Returns:
-            The agent_id used for the file
+            The context_id used for the file
         """
-        context = self.upsert_context(agent_id)
+        context = self.upsert_context(context_id)

         # Create agents directory if it doesn't exist
         if not os.path.exists(defines.context_dir):
             os.makedirs(defines.context_dir)

         # Create the full file path
-        file_path = os.path.join(defines.context_dir, agent_id)
+        file_path = os.path.join(defines.context_dir, context_id)

         # Serialize the data to JSON and write to file
         with open(file_path, "w") as f:
             f.write(context.model_dump_json())

-        return agent_id
+        return context_id

-    def load_context(self, agent_id) -> Context:
+    def load_or_create_context(self, context_id) -> Context:
         """
-        Load a context from a file in the agents directory.
+        Load a context from a file in the context directory or create a new one if it doesn't exist.

         Args:
-            agent_id: UUID string for the context. If it doesn't exist, a new context is created.
+            context_id: UUID string for the context.

         Returns:
             A Context object with the specified ID and default settings.
         """
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
-        file_path = os.path.join(defines.context_dir, agent_id)
+        file_path = os.path.join(defines.context_dir, context_id)

         # Check if the file exists
         if not os.path.exists(file_path):
-            self.contexts[agent_id] = self.create_context(agent_id)
+            logging.info(f"Context file {file_path} not found. Creating new context.")
+            self.contexts[context_id] = self.create_context(context_id)
         else:
             # Read and deserialize the data
             with open(file_path, "r") as f:
-                self.contexts[agent_id] = Context.model_validate_json(f.read())
+                content = f.read()
+                logging.info(f"Loading context from {file_path}, content length: {len(content)}")
+                try:
+                    # Parse as JSON first to ensure the file contains valid JSON
+                    import json
+                    json.loads(content)
+                    logging.info("JSON parsed successfully, attempting model validation")
+                    # Now try Pydantic validation (from_json takes the raw JSON string)
+                    self.contexts[context_id] = Context.from_json(content, file_watcher=self.file_watcher)
+                    logging.info(f"Successfully loaded context {context_id}")
+                except json.JSONDecodeError as e:
+                    logging.error(f"Invalid JSON in file: {e}")
+                    # Fallback to creating a new context
+                    self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)
+                except Exception as e:
+                    logging.error(f"Error validating context: {str(e)}")
+                    import traceback
+                    logging.error(traceback.format_exc())
+                    # Fallback to creating a new context
+                    self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)

-        return self.contexts[agent_id]
+        return self.contexts[context_id]

     def create_context(self, context_id = None) -> Context:
         """
@@ -814,7 +842,11 @@ class WebServer:
         Returns:
             A Context object with the specified ID and default settings.
         """
-        context = Context(id=context_id)
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
+        logging.info(f"Creating new context with ID: {context_id}")
+        context = Context(id=context_id, file_watcher=self.file_watcher)

         if os.path.exists(defines.resume_doc):
             context.user_resume = open(defines.resume_doc, "r").read()
@@ -912,20 +944,18 @@ class WebServer:
             logging.warning("No context ID provided. Creating a new context.")
             return self.create_context()

-        if not is_valid_uuid(context_id):
-            logging.info(f"User requested invalid context_id: {context_id}")
-            raise ValueError("Invalid context_id: {context_id}")
-
         if context_id in self.contexts:
             return self.contexts[context_id]

-        logging.info(f"Context {context_id} not found. Creating new context.")
-        return self.load_context(context_id)
+        logging.info(f"Context {context_id} is not yet loaded.")
+        return self.load_or_create_context(context_id)

     def generate_rag_results(self, context, content):
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
         results_found = False
-        if self.file_watcher:
         for rag in context.rags:
             if rag["enabled"] and rag["name"] == "JPK":  # Only support JPK rag right now...
                 yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
@@ -933,13 +963,13 @@ class WebServer:
                 if chroma_results:
                     results_found = True
                     chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()  # Ensure correct shape
-                    print(f"Chroma embedding shape: {chroma_embedding.shape}")
+                    logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")

                     umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                    print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output
+                    logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output
                     umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                    print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output
+                    logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output

                     yield {
                         **chroma_results,
@@ -979,7 +1009,7 @@ class WebServer:
     # * Then Q&A of Fact Check
     async def generate_response(self, context : Context, agent : Agent, content : str):
         if not self.file_watcher:
-            return
+            raise Exception("File watcher not initialized")

         agent_type = agent.get_agent_type()
         logging.info(f"generate_response: {agent_type}")
@@ -1015,6 +1045,7 @@ class WebServer:
                 logging.info("TODO: There is more to do...")
                 return

+        return
         if self.processing:
             logging.info("TODO: Implement delay queueing; busy for same agent, otherwise return queue size and estimated wait time")

View File

@@ -1,4 +1,5 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr, Field
 from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
@@ -8,8 +9,6 @@ import logging
 if TYPE_CHECKING:
     from .. context import Context

-ContextRef = ForwardRef('Context')
-
 from .types import AgentBase, registry
 from .. conversation import Conversation
@@ -29,9 +28,11 @@ class Agent(AgentBase):
     system_prompt: str  # Mandatory
     conversation: Conversation = Conversation()
     context_tokens: int = 0
-    context: ContextRef  # Avoid circular reference
+    context: object = Field(..., exclude=True)  # Avoid circular reference, require as param, and prevent serialization
     _content_seed: str = PrivateAttr(default="")

+    # Class and pydantic model management
     def __init_subclass__(cls, **kwargs):
         """Auto-register subclasses"""
         super().__init_subclass__(**kwargs)
@@ -48,6 +49,24 @@ class Agent(AgentBase):
             self.__class__.model_rebuild()
         super().__init__(**data)

+    def model_dump(self, *args, **kwargs):
+        # Ensure context is always excluded, even with exclude_unset=True
+        kwargs.setdefault("exclude", set())
+        if isinstance(kwargs["exclude"], set):
+            kwargs["exclude"].add("context")
+        elif isinstance(kwargs["exclude"], dict):
+            kwargs["exclude"]["context"] = True
+        return super().model_dump(*args, **kwargs)
+
+    @classmethod
+    def valid_agent_types(cls) -> set[str]:
+        """Return the set of valid agent_type values."""
+        return set(get_args(cls.__annotations__["agent_type"]))
+
+    def set_context(self, context):
+        object.__setattr__(self, "context", context)
+
+    # Agent methods
     def get_agent_type(self):
         return self._agent_type
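The `model_dump` override merges `"context"` into whatever `exclude` argument the caller supplies, so the back-reference can never leak into a dump regardless of the options used. A self-contained sketch of the same pattern on a toy model:

```python
from pydantic import BaseModel

class Token(BaseModel):
    user: str = "alice"
    secret: str = "hunter2"

    def model_dump(self, *args, **kwargs):
        # Merge "secret" into the caller's exclude set/dict, if any.
        kwargs.setdefault("exclude", set())
        if isinstance(kwargs["exclude"], set):
            kwargs["exclude"].add("secret")
        elif isinstance(kwargs["exclude"], dict):
            kwargs["exclude"]["secret"] = True
        return super().model_dump(*args, **kwargs)

t = Token()
assert "secret" not in t.model_dump()
assert "secret" not in t.model_dump(exclude={"user"})  # caller's set is merged
```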
@@ -240,11 +259,6 @@ class Agent(AgentBase):
         """Get the content seed for the agent."""
         return self._content_seed

-    @classmethod
-    def valid_agent_types(cls) -> set[str]:
-        """Return the set of valid agent_type values."""
-        return set(get_args(cls.__annotations__["agent_type"]))
-
 # Register the base agent
 registry.register(Agent._agent_type, Agent)

View File

@@ -1,4 +1,5 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr
 from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
 from typing_extensions import Annotated
 from abc import ABC, abstractmethod
@@ -35,7 +36,7 @@ class Chat(Agent, ABC):
             yield message
             return

-        if message.metadata["rag"]:
+        if "rag" in message.metadata and message.metadata["rag"]:
             for rag_collection in message.metadata["rag"]:
                 for doc in rag_collection["documents"]:
                     rag_context += f"{doc}\n"
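The new two-part check avoids a KeyError on messages whose metadata never had a `rag` entry. For reference, `dict.get` expresses the same guard in one clause:

```python
metadata = {}  # stand-in for message.metadata with no "rag" key
if metadata.get("rag"):  # False for a missing, empty, or None value
    print("have rag results")
```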

View File

@@ -1,9 +1,11 @@
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, model_validator, ValidationError
 from uuid import uuid4
 from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
 from typing_extensions import Annotated, Union
 import numpy as np
 import logging
+from uuid import uuid4
+import re

 from .message import Message
 from .rag import ChromaDBFileWatcher
@@ -13,22 +15,23 @@ from .agents import Agent
 # Import only agent types, not actual classes
 if TYPE_CHECKING:
-    from .agents import Agent, AnyAgent, Chat, Resume, JobDescription, FactCheck
+    from .agents import Agent, AnyAgent
     from .agents import AnyAgent

+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 class Context(BaseModel):
     model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher

+    # Required fields
+    file_watcher: ChromaDBFileWatcher = Field(..., exclude=True)
+
+    # Optional fields
     id: str = Field(
         default_factory=lambda: str(uuid4()),
         pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
     )
-    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
-        default_factory=list
-    )
     user_resume: Optional[str] = None
     user_job_description: Optional[str] = None
     user_facts: Optional[str] = None
@@ -36,17 +39,27 @@ class Context(BaseModel):
     rags: List[dict] = []
     message_history_length: int = 5
     context_tokens: int = 0
-    file_watcher: ChromaDBFileWatcher = Field(default=None, exclude=True)

+    # Class managed fields
+    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
+        default_factory=list
+    )

-    def __init__(self, id: Optional[str] = None, **kwargs):
-        super().__init__(id=id if id is not None else str(uuid4()), **kwargs)
+    @classmethod
+    def from_json(cls, json_str: str, file_watcher: ChromaDBFileWatcher):
+        """Custom method to load from JSON with file_watcher injection"""
+        import json
+        data = json.loads(json_str)
+        return cls(file_watcher=file_watcher, **data)

     @model_validator(mode="after")
     def validate_unique_agent_types(self):
         """Ensure at most one agent per agent_type."""
+        logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
         agent_types = [agent.agent_type for agent in self.agents]
         if len(agent_types) != len(set(agent_types)):
             raise ValueError("Context cannot contain multiple agents of the same agent_type")
+
+        for agent in self.agents:
+            agent.set_context(self)
         return self

     def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
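Because `context` is excluded from each agent's serialized form, the back-reference has to be rebuilt whenever a Context is constructed or validated; that is what the new loop in the validator does. An illustrative standalone version of the pattern (toy names, same mechanics):

```python
from typing import List
from pydantic import BaseModel, Field, model_validator

class Child(BaseModel):
    name: str
    parent: object = Field(default=None, exclude=True)

    def set_parent(self, parent):
        # Bypass pydantic's __setattr__ validation for the back-reference.
        object.__setattr__(self, "parent", parent)

class Parent(BaseModel):
    children: List[Child] = []

    @model_validator(mode="after")
    def wire_children(self):
        for child in self.children:
            child.set_parent(self)
        return self

p = Parent(children=[Child(name="a")])
assert p.children[0].parent is p  # wired during validation
```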
@@ -110,7 +123,7 @@ class Context(BaseModel):
         except Exception as e:
             message.response = f"Error generating RAG results: {str(e)}"
             message.status = "error"
-            logging.error(e)
+            logger.error(e)
             yield message
             return

View File

@@ -11,6 +11,7 @@ class Message(BaseModel):
     disable_tools: bool = False

     # Generated while processing message
+    status: str = ""              # Status of the message
     preamble: dict[str,str] = {}  # Preamble to be prepended to the prompt
     system_prompt: str = ""       # System prompt provided to the LLM
     full_content: str = ""        # Full content of the message (preamble + prompt)