Starting to work again

James Ketr 2025-04-30 12:57:51 -07:00
parent 4614dbb237
commit e607e3a2f2
8 changed files with 351 additions and 291 deletions

View File

@@ -1,5 +1,10 @@
+import os
+os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
 import warnings
 warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
+warnings.filterwarnings("ignore", message="Warning only once for all operators")
+warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")

 # %%
 # Imports [standard]
@ -37,6 +42,7 @@ try_import("sklearn")
import ollama import ollama
import requests import requests
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, BackgroundTasks from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@@ -363,8 +369,23 @@ def llm_tools(tools):

 # %%
 class WebServer:
+    @asynccontextmanager
+    async def lifespan(self, app: FastAPI):
+        # Start the file watcher
+        self.observer, self.file_watcher = Rag.start_file_watcher(
+            llm=self.llm,
+            watch_directory=defines.doc_dir,
+            recreate=False # Don't recreate if exists
+        )
+        logging.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
+        yield
+        if self.observer:
+            self.observer.stop()
+            self.observer.join()
+            logging.info("File watcher stopped")
+
     def __init__(self, llm, model=MODEL_NAME):
-        self.app = FastAPI()
+        self.app = FastAPI(lifespan=self.lifespan)
         self.contexts = {}
         self.llm = llm
         self.model = model
@@ -389,24 +410,6 @@ class WebServer:
             allow_headers=["*"],
         )

-        @self.app.on_event("startup")
-        async def startup_event():
-            # Start the file watcher
-            self.observer, self.file_watcher = Rag.start_file_watcher(
-                llm=llm,
-                watch_directory=defines.doc_dir,
-                recreate=False # Don't recreate if exists
-            )
-            print(f"API started with {self.file_watcher.collection.count()} documents in the collection")
-
-        @self.app.on_event("shutdown")
-        async def shutdown_event():
-            if self.observer:
-                self.observer.stop()
-                self.observer.join()
-                print("File watcher stopped")
-
         self.setup_routes()

     def setup_routes(self):
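Note: the two hunks above move the file-watcher startup and shutdown out of the deprecated @app.on_event("startup")/@app.on_event("shutdown") hooks and into a single lifespan context manager passed to the FastAPI constructor. A minimal sketch of that pattern, with print placeholders standing in for Rag.start_file_watcher and the observer teardown (the names here are illustrative, not from this repository):

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # code before the yield runs once, before the app serves requests
        print("watcher started")   # stand-in for Rag.start_file_watcher(...)
        yield                      # the application handles requests here
        # code after the yield runs once, during shutdown
        print("watcher stopped")   # stand-in for observer.stop()/join()

    app = FastAPI(lifespan=lifespan)

Collapsing the paired handlers into one function gives the startup and shutdown logic a single, clearly scoped resource lifetime.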
@@ -444,14 +447,16 @@ class WebServer:
                     return JSONResponse(result)
                 except Exception as e:
-                    logging.error(e)
+                    logging.error(f"put_umap error: {str(e)}")
+                    import traceback
+                    logging.error(traceback.format_exc())
                     return JSONResponse({"error": str(e)}, 500)

         @self.app.put("/api/similarity/{context_id}")
         async def put_similarity(context_id: str, request: Request):
             logging.info(f"{request.method} {request.url.path}")
             if not self.file_watcher:
-                return
+                raise Exception("File watcher not initialized")

             if not is_valid_uuid(context_id):
                 logging.warning(f"Invalid context_id: {context_id}")
@@ -471,13 +476,13 @@ class WebServer:
                 return JSONResponse({"error": "No results found"}, status_code=404)

             chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-            print(f"Chroma embedding shape: {chroma_embedding.shape}")
+            logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")

             umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-            print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+            logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output

             umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-            print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+            logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output

             return JSONResponse({
                 **chroma_results,
@@ -666,7 +671,7 @@ class WebServer:
             async def flush_generator():
                 async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
                     # Convert to JSON and add newline
-                    yield json.dumps(message) + "\n"
+                    yield message.model_dump_json() + "\n"
                     # Save the history as it is generated
                     self.save_context(context_id)
                     # Explicitly flush after each yield
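Note: the switch from json.dumps(message) to message.model_dump_json() reflects that streamed messages are now Pydantic models, which plain json.dumps cannot serialize. A hedged sketch of the newline-delimited JSON streaming pattern; the Update model is illustrative, not part of this codebase:

    import asyncio
    from pydantic import BaseModel

    class Update(BaseModel):
        status: str
        response: str = ""

    async def flush_generator():
        for status in ("thinking", "processing", "done"):
            # model_dump_json() serializes the model directly, including
            # nested models that json.dumps() would reject
            yield Update(status=status).model_dump_json() + "\n"

    async def main():
        async for line in flush_generator():
            print(line, end="")

    asyncio.run(main())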
@@ -704,7 +709,9 @@ class WebServer:
                 logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
                 return agent.conversation
             except Exception as e:
-                logging.error(f"Error in get_history: {e}")
+                logging.error(f"get_history error: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, status_code=404)

         @self.app.get("/api/tools/{context_id}")
@@ -759,52 +766,73 @@ class WebServer:
             logging.info(f"Serve index.html for {path}")
             return FileResponse(os.path.join(defines.static_content, "index.html"))

-    def save_context(self, agent_id):
+    def save_context(self, context_id):
         """
         Serialize a Python dictionary to a file in the agents directory.

         Args:
             data: Dictionary containing the agent data
-            agent_id: UUID string for the context. If it doesn't exist, it is created
+            context_id: UUID string for the context. If it doesn't exist, it is created

         Returns:
-            The agent_id used for the file
+            The context_id used for the file
         """
-        context = self.upsert_context(agent_id)
+        context = self.upsert_context(context_id)

         # Create agents directory if it doesn't exist
         if not os.path.exists(defines.context_dir):
             os.makedirs(defines.context_dir)

         # Create the full file path
-        file_path = os.path.join(defines.context_dir, agent_id)
+        file_path = os.path.join(defines.context_dir, context_id)

         # Serialize the data to JSON and write to file
         with open(file_path, "w") as f:
             f.write(context.model_dump_json())

-        return agent_id
+        return context_id

-    def load_context(self, agent_id) -> Context:
+    def load_or_create_context(self, context_id) -> Context:
         """
-        Load a context from a file in the agents directory.
+        Load a context from a file in the context directory or create a new one if it doesn't exist.

         Args:
-            agent_id: UUID string for the context. If it doesn't exist, a new context is created.
+            context_id: UUID string for the context.

         Returns:
             A Context object with the specified ID and default settings.
         """
-        file_path = os.path.join(defines.context_dir, agent_id)
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
+        file_path = os.path.join(defines.context_dir, context_id)

         # Check if the file exists
         if not os.path.exists(file_path):
-            self.contexts[agent_id] = self.create_context(agent_id)
+            logging.info(f"Context file {file_path} not found. Creating new context.")
+            self.contexts[context_id] = self.create_context(context_id)
         else:
             # Read and deserialize the data
             with open(file_path, "r") as f:
-                self.contexts[agent_id] = Context.model_validate_json(f.read())
-
-        return self.contexts[agent_id]
+                content = f.read()
+            logging.info(f"Loading context from {file_path}, content length: {len(content)}")
+            try:
+                # Try parsing as JSON first to ensure valid JSON
+                import json
+                json_data = json.loads(content)
+                logging.info("JSON parsed successfully, attempting model validation")
+                # Now try Pydantic validation
+                self.contexts[context_id] = Context.from_json(content, file_watcher=self.file_watcher)
+                logging.info(f"Successfully loaded context {context_id}")
+            except json.JSONDecodeError as e:
+                logging.error(f"Invalid JSON in file: {e}")
+            except Exception as e:
+                logging.error(f"Error validating context: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
+                # Fall back to creating a new context
+                self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)

+        return self.contexts[context_id]

     def create_context(self, context_id = None) -> Context:
         """
@@ -814,7 +842,11 @@ class WebServer:
         Returns:
             A Context object with the specified ID and default settings.
         """
-        context = Context(id=context_id)
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
+        logging.info(f"Creating new context with ID: {context_id}")
+        context = Context(id=context_id, file_watcher=self.file_watcher)

         if os.path.exists(defines.resume_doc):
             context.user_resume = open(defines.resume_doc, "r").read()
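Note: load_context becomes load_or_create_context, which validates a saved file in two stages (JSON syntax first, then model validation) and falls back to a fresh Context on any failure. A compact sketch of that pattern under illustrative names (Record stands in for Context; nothing here is from the repository itself):

    import json, logging, os
    from pydantic import BaseModel, ValidationError

    class Record(BaseModel):
        id: str
        note: str = ""

    def load_or_create(path: str, record_id: str) -> Record:
        if not os.path.exists(path):
            return Record(id=record_id)        # nothing saved yet
        try:
            with open(path) as f:
                data = json.loads(f.read())    # stage 1: is it valid JSON?
            return Record(**data)              # stage 2: does it fit the schema?
        except (json.JSONDecodeError, ValidationError) as e:
            logging.error(f"Falling back to a new record: {e}")
            return Record(id=record_id)

Separating the two failure modes makes a corrupt file distinguishable from a schema mismatch in the logs.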
@@ -912,41 +944,39 @@ class WebServer:
             logging.warning("No context ID provided. Creating a new context.")
             return self.create_context()

-        if not is_valid_uuid(context_id):
-            logging.info(f"User requested invalid context_id: {context_id}")
-            raise ValueError("Invalid context_id: {context_id}")
-
         if context_id in self.contexts:
             return self.contexts[context_id]

-        logging.info(f"Context {context_id} not found. Creating new context.")
-        return self.load_context(context_id)
+        logging.info(f"Context {context_id} is not yet loaded.")
+        return self.load_or_create_context(context_id)

     def generate_rag_results(self, context, content):
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
         results_found = False
-        if self.file_watcher:
-            for rag in context.rags:
-                if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now...
-                    yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
-                    chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
-                    if chroma_results:
-                        results_found = True
-                        chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-                        print(f"Chroma embedding shape: {chroma_embedding.shape}")
-                        umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                        print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
-                        umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                        print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
-                        yield {
-                            **chroma_results,
-                            "name": rag["name"],
-                            "umap_embedding_2d": umap_2d,
-                            "umap_embedding_3d": umap_3d
-                        }
+        for rag in context.rags:
+            if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now...
+                yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
+                chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
+                if chroma_results:
+                    results_found = True
+                    chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
+                    logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")
+                    umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
+                    logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+                    umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
+                    logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+                    yield {
+                        **chroma_results,
+                        "name": rag["name"],
+                        "umap_embedding_2d": umap_2d,
+                        "umap_embedding_3d": umap_3d
+                    }

         if not results_found:
             yield {"status": "complete", "message": "No RAG context found"}
@@ -979,7 +1009,7 @@ class WebServer:
     # * Then Q&A of Fact Check
     async def generate_response(self, context: Context, agent: Agent, content: str):
         if not self.file_watcher:
-            return
+            raise Exception("File watcher not initialized")

         agent_type = agent.get_agent_type()
         logging.info(f"generate_response: {agent_type}")
@@ -1015,6 +1045,7 @@ class WebServer:
             logging.info("TODO: There is more to do...")
             return
+        return

         if self.processing:
             logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")

View File

@ -1,4 +1,5 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr from __future__ import annotations
from pydantic import BaseModel, model_validator, PrivateAttr, Field
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing_extensions import Annotated from typing_extensions import Annotated
@@ -8,8 +9,6 @@ import logging

 if TYPE_CHECKING:
     from .. context import Context

-ContextRef = ForwardRef('Context')
-
 from .types import AgentBase, registry
 from .. conversation import Conversation
@@ -29,9 +28,11 @@ class Agent(AgentBase):
     system_prompt: str # Mandatory
     conversation: Conversation = Conversation()
     context_tokens: int = 0
-    context: ContextRef # Avoid circular reference
+    context: object = Field(..., exclude=True) # Avoid circular reference, require as param, and prevent serialization
     _content_seed: str = PrivateAttr(default="")

+    # Class and pydantic model management
     def __init_subclass__(cls, **kwargs):
         """Auto-register subclasses"""
         super().__init_subclass__(**kwargs)
@@ -48,6 +49,24 @@ class Agent(AgentBase):
         self.__class__.model_rebuild()
         super().__init__(**data)

+    def model_dump(self, *args, **kwargs):
+        # Ensure context is always excluded, even with exclude_unset=True
+        kwargs.setdefault("exclude", set())
+        if isinstance(kwargs["exclude"], set):
+            kwargs["exclude"].add("context")
+        elif isinstance(kwargs["exclude"], dict):
+            kwargs["exclude"]["context"] = True
+        return super().model_dump(*args, **kwargs)
+
+    @classmethod
+    def valid_agent_types(cls) -> set[str]:
+        """Return the set of valid agent_type values."""
+        return set(get_args(cls.__annotations__["agent_type"]))
+
+    def set_context(self, context):
+        object.__setattr__(self, "context", context)
+
+    # Agent methods
     def get_agent_type(self):
         return self._agent_type
@@ -240,11 +259,6 @@ class Agent(AgentBase):
         """Get the content seed for the agent."""
         return self._content_seed

-    @classmethod
-    def valid_agent_types(cls) -> set[str]:
-        """Return the set of valid agent_type values."""
-        return set(get_args(cls.__annotations__["agent_type"]))
-
 # Register the base agent
 registry.register(Agent._agent_type, Agent)
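Note: the ForwardRef-based context field is replaced by a plain object field that is required at construction but excluded from every dump, with model_dump overridden so the exclusion survives caller-supplied exclude sets. A sketch of the same back-reference strategy under illustrative names (Node/Tree are stand-ins, not repository classes):

    from pydantic import BaseModel, Field

    class Node(BaseModel):
        model_config = {"arbitrary_types_allowed": True}
        name: str
        parent: object = Field(default=None, exclude=True)  # never serialized

        def set_parent(self, parent):
            # bypass validation, mirroring the set_context helper above
            object.__setattr__(self, "parent", parent)

    class Tree(BaseModel):
        nodes: list[Node] = []

    tree = Tree(nodes=[Node(name="a")])
    for node in tree.nodes:
        node.set_parent(tree)       # re-link after deserialization

    print(tree.model_dump_json())   # parent is excluded, so no recursion

Excluding the back-reference breaks the serialization cycle; a validator can re-link it whenever the model is rebuilt.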

View File

@@ -1,4 +1,5 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr
 from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
 from typing_extensions import Annotated
 from abc import ABC, abstractmethod
@@ -9,206 +10,206 @@ from .. conversation import Conversation
 from .. message import Message

 class Chat(Agent, ABC):
     """
     Base class for all agent types.
     This class defines the common attributes and methods for all agent types.
     """
     agent_type: Literal["chat"] = "chat"
     _agent_type: ClassVar[str] = agent_type # Add this for registration

     async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
         """
         Prepare message with context information in message.preamble
         """
         # Generate RAG content if enabled, based on the content
         rag_context = ""
         if not message.disable_rag:
             # Gather RAG results, yielding each result
             # as it becomes available
             for value in self.context.generate_rag_results(message):
                 logging.info(f"RAG: {value.status} - {value.response}")
                 if value.status != "done":
                     yield value
                 if value.status == "error":
                     message.status = "error"
                     message.response = value.response
                     yield message
                     return

-        if message.metadata["rag"]:
+        if "rag" in message.metadata and message.metadata["rag"]:
             for rag_collection in message.metadata["rag"]:
                 for doc in rag_collection["documents"]:
                     rag_context += f"{doc}\n"

         if rag_context:
             message["context"] = rag_context

         if self.context.user_resume:
             message["resume"] = self.context.user_resume

         if message.preamble:
             preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
             preamble_types_AND = " and ".join(preamble_types)
             preamble_types_OR = " or ".join(preamble_types)
             message.preamble["rules"] = f"""\
 - Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly and referring to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
 - If there is no information in these sections, answer based on your knowledge.
 - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
 """
             message.preamble["question"] = "Use that information to respond to:"
         else:
             message.preamble["question"] = "Respond to:"

         message.system_prompt = self.system_prompt
         message.status = "done"
         yield message
         return

     async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
         if self.context.processing:
             logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
             message.status = "error"
             message.response = "Busy processing another request."
             yield message
             return

         self.context.processing = True

         messages = []
         for value in self.llm.chat(
             model=self.model,
             messages=messages,
             #tools=llm_tools(context.tools) if message.enable_tools else None,
             options={ "num_ctx": message.ctx_size }
         ):
             logging.info(f"LLM: {value.status} - {value.response}")
             if value.status != "done":
                 message.status = value.status
                 message.response = value.response
                 yield message
             if value.status == "error":
                 return
         response = value

         message.metadata["eval_count"] += response["eval_count"]
         message.metadata["eval_duration"] += response["eval_duration"]
         message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
         message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
         agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]

         tools_used = []

         yield {"status": "processing", "message": "Initial response received..."}

         if "tool_calls" in response.get("message", {}):
             yield {"status": "processing", "message": "Processing tool calls..."}

             tool_message = response["message"]
             tool_result = None

             # Process all yielded items from the handler
             async for item in self.handle_tool_calls(tool_message):
                 if isinstance(item, tuple) and len(item) == 2:
                     # This is the final result tuple (tool_result, tools_used)
                     tool_result, tools_used = item
                 else:
                     # This is a status update, forward it
                     yield item

             message_dict = {
                 "role": tool_message.get("role", "assistant"),
                 "content": tool_message.get("content", "")
             }

             if "tool_calls" in tool_message:
                 message_dict["tool_calls"] = [
                     {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
                     for tc in tool_message["tool_calls"]
                 ]

             pre_add_index = len(messages)
             messages.append(message_dict)

             if isinstance(tool_result, list):
                 messages.extend(tool_result)
             else:
                 if tool_result:
                     messages.append(tool_result)

             message.metadata["tools"] = tools_used

             # Estimate token length of new messages
             ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
             yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
             # Decrease creativity when processing tool call requests
             response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
             message.metadata["eval_count"] += response["eval_count"]
             message.metadata["eval_duration"] += response["eval_duration"]
             message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
             message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
             agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]

         reply = response["message"]["content"]
         message.response = reply
         message.metadata["origin"] = agent.agent_type

         # final_message = {"role": "assistant", "content": reply }
         # # history is provided to the LLM and should not have additional metadata
         # llm_history.append(final_message)

         # user_history is provided to the REST API and does not include CONTEXT
         # It does include metadata
         # final_message["metadata"] = message.metadata
         # user_history.append({**final_message, "origin": message.metadata["origin"]})

         # Return the REST API with metadata
         yield {
             "status": "done",
             "message": {
                 **message.model_dump(mode='json'),
             }
         }

         self.context.processing = False
         return

     async def process_message(self, message: Message) -> AsyncGenerator[Message, None]:
         message.full_content = ""
         for i, p in enumerate(message.preamble.keys()):
             message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"

         # Estimate token length of new messages
         message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
         message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
         message.status = "thinking"
         yield message

         for value in self.generate_llm_response(message):
             logging.info(f"LLM: {value.status} - {value.response}")
             if value.status != "done":
                 yield value
             if value.status == "error":
                 return

     def get_and_reset_content_seed(self):
         tmp = self._content_seed
         self._content_seed = ""
         return tmp

     def set_content_seed(self, content: str) -> None:
         """Set the content seed for the agent."""
         self._content_seed = content

     def get_content_seed(self) -> str:
         """Get the content seed for the agent."""
         return self._content_seed

     @classmethod
     def valid_agent_types(cls) -> set[str]:
         """Return the set of valid agent_type values."""
         return set(get_args(cls.__annotations__["agent_type"]))

 # Register the base agent
 registry.register(Chat._agent_type, Chat)
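Note: prepare_message assembles the prompt from tagged preamble sections, and process_message flattens them into <|section|> blocks ahead of the question. A standalone sketch of that assembly with illustrative values:

    preamble = {
        "context": "RAG passages go here...",
        "resume": "Resume text goes here...",
        "question": "Use that information to respond to:",
    }

    full_content = ""
    for i, p in enumerate(preamble.keys()):
        # parentheses matter: only the blank-line separator is conditional
        full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{preamble[p].strip()}\n"

    print(full_content)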

View File

@@ -9,16 +9,16 @@ from .. conversation import Conversation
 from .. message import Message

 class FactCheck(Agent):
     agent_type: Literal["fact_check"] = "fact_check"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
     facts: str = ""

     @model_validator(mode="after")
     def validate_facts(self):
         if not self.facts.strip():
             raise ValueError("Facts cannot be empty")
         return self

 # Register the base agent
 registry.register(FactCheck._agent_type, FactCheck)

View File

@@ -9,16 +9,16 @@ from .. conversation import Conversation
 from .. message import Message

 class JobDescription(Agent):
     agent_type: Literal["job_description"] = "job_description"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
     job_description: str = ""

     @model_validator(mode="after")
     def validate_job_description(self):
         if not self.job_description.strip():
             raise ValueError("Job description cannot be empty")
         return self

 # Register the base agent
 registry.register(JobDescription._agent_type, JobDescription)

View File

@@ -9,24 +9,24 @@ from .. conversation import Conversation
 from .. message import Message

 class Resume(Agent):
     agent_type: Literal["resume"] = "resume"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
     resume: str = ""

     @model_validator(mode="after")
     def validate_resume(self):
         if not self.resume.strip():
             raise ValueError("Resume content cannot be empty")
         return self

     def get_resume(self) -> str:
         """Get the resume content."""
         return self.resume

     def set_resume(self, resume: str) -> None:
         """Set the resume content."""
         self.resume = resume

 # Register the base agent
 registry.register(Resume._agent_type, Resume)

View File

@@ -1,9 +1,11 @@
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, model_validator, ValidationError
 from uuid import uuid4
 from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
 from typing_extensions import Annotated, Union
 import numpy as np
 import logging
+from uuid import uuid4
+import re

 from .message import Message
 from .rag import ChromaDBFileWatcher
@@ -13,22 +15,23 @@ from .agents import Agent

 # Import only agent types, not actual classes
 if TYPE_CHECKING:
-    from .agents import Agent, AnyAgent, Chat, Resume, JobDescription, FactCheck
+    from .agents import Agent, AnyAgent
 from .agents import AnyAgent

+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 class Context(BaseModel):
     model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher

+    # Required fields
+    file_watcher: ChromaDBFileWatcher = Field(..., exclude=True)
+
+    # Optional fields
     id: str = Field(
         default_factory=lambda: str(uuid4()),
         pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
     )
-    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
-        default_factory=list
-    )
     user_resume: Optional[str] = None
     user_job_description: Optional[str] = None
     user_facts: Optional[str] = None
@@ -36,17 +39,27 @@ class Context(BaseModel):
     rags: List[dict] = []
     message_history_length: int = 5
     context_tokens: int = 0
-    file_watcher: ChromaDBFileWatcher = Field(default=None, exclude=True)

-    def __init__(self, id: Optional[str] = None, **kwargs):
-        super().__init__(id=id if id is not None else str(uuid4()), **kwargs)
+    # Class managed fields
+    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
+        default_factory=list
+    )
+
+    @classmethod
+    def from_json(cls, json_str: str, file_watcher: ChromaDBFileWatcher):
+        """Custom method to load from JSON with file_watcher injection"""
+        import json
+        data = json.loads(json_str)
+        return cls(file_watcher=file_watcher, **data)

     @model_validator(mode="after")
     def validate_unique_agent_types(self):
         """Ensure at most one agent per agent_type."""
+        logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
         agent_types = [agent.agent_type for agent in self.agents]
         if len(agent_types) != len(set(agent_types)):
             raise ValueError("Context cannot contain multiple agents of the same agent_type")
+        for agent in self.agents:
+            agent.set_context(self)
         return self

     def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
@@ -110,7 +123,7 @@ class Context(BaseModel):
             except Exception as e:
                 message.response = f"Error generating RAG results: {str(e)}"
                 message.status = "error"
-                logging.error(e)
+                logger.error(e)
                 yield message
         return
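Note: Context now takes the non-serializable ChromaDBFileWatcher as a required, excluded field, and from_json re-injects it at load time instead of persisting it. A sketch of that injection round-trip with an illustrative Watcher class standing in for ChromaDBFileWatcher:

    import json
    from pydantic import BaseModel, Field

    class Watcher:                          # not a Pydantic model; never persisted
        pass

    class Ctx(BaseModel):
        model_config = {"arbitrary_types_allowed": True}
        watcher: Watcher = Field(..., exclude=True)
        id: str

        @classmethod
        def from_json(cls, json_str: str, watcher: Watcher) -> "Ctx":
            return cls(watcher=watcher, **json.loads(json_str))

    ctx = Ctx(watcher=Watcher(), id="abc")
    saved = ctx.model_dump_json()           # watcher is excluded from the JSON
    restored = Ctx.from_json(saved, watcher=Watcher())
    print(restored.id)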

View File

@@ -11,6 +11,7 @@ class Message(BaseModel):
     disable_tools: bool = False

     # Generated while processing message
+    status: str = "" # Status of the message
     preamble: dict[str,str] = {} # Preamble to be prepended to the prompt
     system_prompt: str = "" # System prompt provided to the LLM
     full_content: str = "" # Full content of the message (preamble + prompt)