Starting to work again
parent 4614dbb237
commit e607e3a2f2
163  src/server.py
@@ -1,5 +1,10 @@
+import os
+os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
+
 import warnings
 warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
+warnings.filterwarnings("ignore", message="Warning only once for all operators")
+warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")
 
 # %%
 # Imports [standard]
@@ -37,6 +42,7 @@ try_import("sklearn")
 import ollama
 import requests
 from bs4 import BeautifulSoup
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, BackgroundTasks
 from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -363,8 +369,23 @@ def llm_tools(tools):
 
 # %%
 class WebServer:
+    @asynccontextmanager
+    async def lifespan(self, app: FastAPI):
+        # Start the file watcher
+        self.observer, self.file_watcher = Rag.start_file_watcher(
+            llm=self.llm,
+            watch_directory=defines.doc_dir,
+            recreate=False # Don't recreate if exists
+        )
+        logging.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
+        yield
+        if self.observer:
+            self.observer.stop()
+            self.observer.join()
+            logging.info("File watcher stopped")
+
     def __init__(self, llm, model=MODEL_NAME):
-        self.app = FastAPI()
+        self.app = FastAPI(lifespan=self.lifespan)
        self.contexts = {}
         self.llm = llm
         self.model = model
@@ -389,24 +410,6 @@ class WebServer:
             allow_headers=["*"],
         )
 
-        @self.app.on_event("startup")
-        async def startup_event():
-            # Start the file watcher
-            self.observer, self.file_watcher = Rag.start_file_watcher(
-                llm=llm,
-                watch_directory=defines.doc_dir,
-                recreate=False # Don't recreate if exists
-            )
-
-            print(f"API started with {self.file_watcher.collection.count()} documents in the collection")
-
-        @self.app.on_event("shutdown")
-        async def shutdown_event():
-            if self.observer:
-                self.observer.stop()
-                self.observer.join()
-                print("File watcher stopped")
-
         self.setup_routes()
 
     def setup_routes(self):
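Note: the two hunks above migrate startup/shutdown work from FastAPI's deprecated @app.on_event("startup") and @app.on_event("shutdown") hooks to a lifespan context manager passed to the FastAPI constructor. A minimal sketch of that pattern, with illustrative names rather than this repository's:

from contextlib import asynccontextmanager
from fastapi import FastAPI

class Watcher:  # hypothetical stand-in for the file watcher
    def stop(self) -> None:
        pass

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once before the app begins serving requests
    app.state.watcher = Watcher()
    yield
    # Shutdown: runs once after the server stops accepting requests
    app.state.watcher.stop()

app = FastAPI(lifespan=lifespan)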
@@ -444,14 +447,16 @@ class WebServer:
                 return JSONResponse(result)
 
             except Exception as e:
-                logging.error(e)
+                logging.error(f"put_umap error: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, 500)
 
         @self.app.put("/api/similarity/{context_id}")
         async def put_similarity(context_id: str, request: Request):
             logging.info(f"{request.method} {request.url.path}")
             if not self.file_watcher:
-                return
+                raise Exception("File watcher not initialized")
 
             if not is_valid_uuid(context_id):
                 logging.warning(f"Invalid context_id: {context_id}")
@@ -471,13 +476,13 @@ class WebServer:
                 return JSONResponse({"error": "No results found"}, status_code=404)
 
             chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-            print(f"Chroma embedding shape: {chroma_embedding.shape}")
+            logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")
 
             umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-            print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+            logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
 
             umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-            print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+            logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
 
             return JSONResponse({
                 **chroma_results,
@@ -666,7 +671,7 @@ class WebServer:
             async def flush_generator():
                 async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
                     # Convert to JSON and add newline
-                    yield json.dumps(message) + "\n"
+                    yield (message.model_dump_json()) + "\n"
                     # Save the history as its generated
                     self.save_context(context_id)
                     # Explicitly flush after each yield
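Note: json.dumps cannot serialize a Pydantic model directly, which is presumably why the stream now yields message.model_dump_json() instead. A minimal illustration with a hypothetical model:

import json
from pydantic import BaseModel

class Msg(BaseModel):
    role: str
    content: str

m = Msg(role="assistant", content="hi")
# json.dumps(m) raises TypeError: Object of type Msg is not JSON serializable
line = m.model_dump_json() + "\n"  # '{"role":"assistant","content":"hi"}\n'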
@@ -704,7 +709,9 @@ class WebServer:
                 logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
                 return agent.conversation
             except Exception as e:
-                logging.error(f"Error in get_history: {e}")
+                logging.error(f"get_history error: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
                 return JSONResponse({"error": str(e)}, status_code=404)
 
         @self.app.get("/api/tools/{context_id}")
@@ -759,52 +766,73 @@ class WebServer:
             logging.info(f"Serve index.html for {path}")
             return FileResponse(os.path.join(defines.static_content, "index.html"))
 
-    def save_context(self, agent_id):
+    def save_context(self, context_id):
         """
         Serialize a Python dictionary to a file in the agents directory.
 
         Args:
             data: Dictionary containing the agent data
-            agent_id: UUID string for the context. If it doesn't exist, it is created
+            context_id: UUID string for the context. If it doesn't exist, it is created
 
         Returns:
-            The agent_id used for the file
+            The context_id used for the file
         """
-        context = self.upsert_context(agent_id)
+        context = self.upsert_context(context_id)
 
         # Create agents directory if it doesn't exist
         if not os.path.exists(defines.context_dir):
             os.makedirs(defines.context_dir)
 
         # Create the full file path
-        file_path = os.path.join(defines.context_dir, agent_id)
+        file_path = os.path.join(defines.context_dir, context_id)
 
         # Serialize the data to JSON and write to file
         with open(file_path, "w") as f:
             f.write(context.model_dump_json())
 
-        return agent_id
+        return context_id
 
-    def load_context(self, agent_id) -> Context:
+    def load_or_create_context(self, context_id) -> Context:
         """
-        Load a context from a file in the agents directory.
+        Load a context from a file in the context directory or create a new one if it doesn't exist.
         Args:
-            agent_id: UUID string for the context. If it doesn't exist, a new context is created.
+            context_id: UUID string for the context.
         Returns:
             A Context object with the specified ID and default settings.
         """
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
-        file_path = os.path.join(defines.context_dir, agent_id)
+        file_path = os.path.join(defines.context_dir, context_id)
 
         # Check if the file exists
         if not os.path.exists(file_path):
-            self.contexts[agent_id] = self.create_context(agent_id)
+            logging.info(f"Context file {file_path} not found. Creating new context.")
+            self.contexts[context_id] = self.create_context(context_id)
         else:
             # Read and deserialize the data
             with open(file_path, "r") as f:
-                self.contexts[agent_id] = Context.model_validate_json(f.read())
-
-        return self.contexts[agent_id]
+                content = f.read()
+            logging.info(f"Loading context from {file_path}, content length: {len(content)}")
+            try:
+                # Try parsing as JSON first to ensure valid JSON
+                import json
+                json_data = json.loads(content)
+                logging.info("JSON parsed successfully, attempting model validation")
+
+                # Now try Pydantic validation
+                self.contexts[context_id] = Context.from_json(json_data, file_watcher=self.file_watcher)
+                logging.info(f"Successfully loaded context {context_id}")
+            except json.JSONDecodeError as e:
+                logging.error(f"Invalid JSON in file: {e}")
+            except Exception as e:
+                logging.error(f"Error validating context: {str(e)}")
+                import traceback
+                logging.error(traceback.format_exc())
+                # Fallback to creating a new context
+                self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)
+
+        return self.contexts[context_id]
 
     def create_context(self, context_id = None) -> Context:
         """
@@ -814,7 +842,11 @@ class WebServer:
         Returns:
             A Context object with the specified ID and default settings.
         """
-        context = Context(id=context_id)
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
+        logging.info(f"Creating new context with ID: {context_id}")
+        context = Context(id=context_id, file_watcher=self.file_watcher)
 
         if os.path.exists(defines.resume_doc):
             context.user_resume = open(defines.resume_doc, "r").read()
@@ -912,41 +944,39 @@ class WebServer:
             logging.warning("No context ID provided. Creating a new context.")
             return self.create_context()
 
-        if not is_valid_uuid(context_id):
-            logging.info(f"User requested invalid context_id: {context_id}")
-            raise ValueError("Invalid context_id: {context_id}")
-
         if context_id in self.contexts:
             return self.contexts[context_id]
 
-        logging.info(f"Context {context_id} not found. Creating new context.")
-        return self.load_context(context_id)
+        logging.info(f"Context {context_id} is not yet loaded.")
+        return self.load_or_create_context(context_id)
 
     def generate_rag_results(self, context, content):
+        if not self.file_watcher:
+            raise Exception("File watcher not initialized")
+
         results_found = False
 
-        if self.file_watcher:
-            for rag in context.rags:
-                if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now...
-                    yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
-                    chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
-                    if chroma_results:
-                        results_found = True
-                        chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-                        print(f"Chroma embedding shape: {chroma_embedding.shape}")
-
-                        umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-                        print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
-
-                        umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-                        print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
-
-                        yield {
-                            **chroma_results,
-                            "name": rag["name"],
-                            "umap_embedding_2d": umap_2d,
-                            "umap_embedding_3d": umap_3d
-                        }
+        for rag in context.rags:
+            if rag["enabled"] and rag["name"] == "JPK": # Only support JPK rag right now...
+                yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
+                chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
+                if chroma_results:
+                    results_found = True
+                    chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
+                    logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")
+
+                    umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
+                    logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+
+                    umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
+                    logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+
+                    yield {
+                        **chroma_results,
+                        "name": rag["name"],
+                        "umap_embedding_2d": umap_2d,
+                        "umap_embedding_3d": umap_3d
+                    }
 
         if not results_found:
             yield {"status": "complete", "message": "No RAG context found"}
@@ -979,7 +1009,7 @@ class WebServer:
     # * Then Q&A of Fact Check
     async def generate_response(self, context : Context, agent : Agent, content : str):
         if not self.file_watcher:
-            return
+            raise Exception("File watcher not initialized")
 
         agent_type = agent.get_agent_type()
         logging.info(f"generate_response: {agent_type}")
@@ -1015,6 +1045,7 @@ class WebServer:
             logging.info("TODO: There is more to do...")
             return
 
+        return
 
         if self.processing:
             logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
@@ -1,4 +1,5 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr, Field
 from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
@ -8,8 +9,6 @@ import logging
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .. context import Context
|
from .. context import Context
|
||||||
|
|
||||||
ContextRef = ForwardRef('Context')
|
|
||||||
|
|
||||||
from .types import AgentBase, registry
|
from .types import AgentBase, registry
|
||||||
|
|
||||||
from .. conversation import Conversation
|
from .. conversation import Conversation
|
||||||
@@ -29,9 +28,11 @@ class Agent(AgentBase):
     system_prompt: str # Mandatory
     conversation: Conversation = Conversation()
     context_tokens: int = 0
-    context: ContextRef # Avoid circular reference
+    context: object = Field(..., exclude=True) # Avoid circular reference, require as param, and prevent serialization
 
     _content_seed: str = PrivateAttr(default="")
 
+    # Class and pydantic model management
     def __init_subclass__(cls, **kwargs):
         """Auto-register subclasses"""
         super().__init_subclass__(**kwargs)
@@ -48,6 +49,24 @@ class Agent(AgentBase):
         self.__class__.model_rebuild()
         super().__init__(**data)
 
+    def model_dump(self, *args, **kwargs):
+        # Ensure context is always excluded, even with exclude_unset=True
+        kwargs.setdefault("exclude", set())
+        if isinstance(kwargs["exclude"], set):
+            kwargs["exclude"].add("context")
+        elif isinstance(kwargs["exclude"], dict):
+            kwargs["exclude"]["context"] = True
+        return super().model_dump(*args, **kwargs)
+
+    @classmethod
+    def valid_agent_types(cls) -> set[str]:
+        """Return the set of valid agent_type values."""
+        return set(get_args(cls.__annotations__["agent_type"]))
+
+    def set_context(self, context):
+        object.__setattr__(self, "context", context)
+
+    # Agent methods
     def get_agent_type(self):
         return self._agent_type
 
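Note: the model_dump override added above forces "context" into the exclude set, so the agent's back-reference to its Context is never serialized even when callers pass their own exclude. A standalone sketch of the same Pydantic v2 pattern, using a hypothetical model:

from pydantic import BaseModel, ConfigDict, Field

class Child(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    parent: object = Field(default=None, exclude=True)  # runtime-only back-reference

    def model_dump(self, *args, **kwargs):
        # Merge our always-excluded field into whatever the caller passed
        kwargs.setdefault("exclude", set())
        if isinstance(kwargs["exclude"], set):
            kwargs["exclude"].add("parent")
        elif isinstance(kwargs["exclude"], dict):
            kwargs["exclude"]["parent"] = True
        return super().model_dump(*args, **kwargs)

print(Child(name="a").model_dump())  # {'name': 'a'}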
@@ -240,11 +259,6 @@ class Agent(AgentBase):
         """Get the content seed for the agent."""
         return self._content_seed
 
-    @classmethod
-    def valid_agent_types(cls) -> set[str]:
-        """Return the set of valid agent_type values."""
-        return set(get_args(cls.__annotations__["agent_type"]))
-
 # Register the base agent
 registry.register(Agent._agent_type, Agent)
@@ -1,4 +1,5 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from __future__ import annotations
+from pydantic import BaseModel, model_validator, PrivateAttr
 from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
 from typing_extensions import Annotated
 from abc import ABC, abstractmethod
@@ -9,206 +10,206 @@ from .. conversation import Conversation
 from .. message import Message
 
 class Chat(Agent, ABC):
     """
     Base class for all agent types.
     This class defines the common attributes and methods for all agent types.
     """
     agent_type: Literal["chat"] = "chat"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
 
     async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
         """
         Prepare message with context information in message.preamble
         """
         # Generate RAG content if enabled, based on the content
         rag_context = ""
         if not message.disable_rag:
             # Gather RAG results, yielding each result
             # as it becomes available
             for value in self.context.generate_rag_results(message):
                 logging.info(f"RAG: {value.status} - {value.response}")
                 if value.status != "done":
                     yield value
                 if value.status == "error":
                     message.status = "error"
                     message.response = value.response
                     yield message
                     return
 
-        if message.metadata["rag"]:
+        if "rag" in message.metadata and message.metadata["rag"]:
             for rag_collection in message.metadata["rag"]:
                 for doc in rag_collection["documents"]:
                     rag_context += f"{doc}\n"
 
         if rag_context:
             message["context"] = rag_context
 
         if self.context.user_resume:
             message["resume"] = self.content.user_resume
 
         if message.preamble:
             preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
             preamble_types_AND = " and ".join(preamble_types)
             preamble_types_OR = " or ".join(preamble_types)
             message.preamble["rules"] = f"""\
 - Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly.
 - If there is no information in these sections, answer based on your knowledge.
 - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
 """
             message.preamble["question"] = "Use that information to respond to:"
         else:
             message.preamble["question"] = "Respond to:"
 
         message.system_prompt = self.system_prompt
         message.status = "done"
         yield message
         return
 
     async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
         if self.context.processing:
             logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
             message.status = "error"
             message.response = "Busy processing another request."
             yield message
             return
 
         self.context.processing = True
 
         messages = []
 
         for value in self.llm.chat(
             model=self.model,
             messages=messages,
             #tools=llm_tools(context.tools) if message.enable_tools else None,
             options={ "num_ctx": message.ctx_size }
         ):
             logging.info(f"LLM: {value.status} - {value.response}")
             if value.status != "done":
                 message.status = value.status
                 message.response = value.response
                 yield message
             if value.status == "error":
                 return
             response = value
 
         message.metadata["eval_count"] += response["eval_count"]
         message.metadata["eval_duration"] += response["eval_duration"]
         message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
         message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
         agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
 
         tools_used = []
 
         yield {"status": "processing", "message": "Initial response received..."}
 
         if "tool_calls" in response.get("message", {}):
             yield {"status": "processing", "message": "Processing tool calls..."}
 
             tool_message = response["message"]
             tool_result = None
 
             # Process all yielded items from the handler
             async for item in self.handle_tool_calls(tool_message):
                 if isinstance(item, tuple) and len(item) == 2:
                     # This is the final result tuple (tool_result, tools_used)
                     tool_result, tools_used = item
                 else:
                     # This is a status update, forward it
                     yield item
 
             message_dict = {
                 "role": tool_message.get("role", "assistant"),
                 "content": tool_message.get("content", "")
             }
 
             if "tool_calls" in tool_message:
                 message_dict["tool_calls"] = [
                     {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
                     for tc in tool_message["tool_calls"]
                 ]
 
             pre_add_index = len(messages)
             messages.append(message_dict)
 
             if isinstance(tool_result, list):
                 messages.extend(tool_result)
             else:
                 if tool_result:
                     messages.append(tool_result)
 
             message.metadata["tools"] = tools_used
 
             # Estimate token length of new messages
             ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
             yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
             # Decrease creativity when processing tool call requests
             response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
             message.metadata["eval_count"] += response["eval_count"]
             message.metadata["eval_duration"] += response["eval_duration"]
             message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
             message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
             agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
 
         reply = response["message"]["content"]
         message.response = reply
         message.metadata["origin"] = agent.agent_type
         # final_message = {"role": "assistant", "content": reply }
 
         # # history is provided to the LLM and should not have additional metadata
         # llm_history.append(final_message)
 
         # user_history is provided to the REST API and does not include CONTEXT
         # It does include metadata
         # final_message["metadata"] = message.metadata
         # user_history.append({**final_message, "origin": message.metadata["origin"]})
 
         # Return the REST API with metadata
         yield {
             "status": "done",
             "message": {
                 **message.model_dump(mode='json'),
             }
         }
 
         self.context.processing = False
         return
 
     async def process_message(self, message:Message) -> AsyncGenerator[Message, None]:
         message.full_content = ""
         for i, p in enumerate(message.preamble.keys()):
             message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n"
 
         # Estimate token length of new messages
         message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)
 
         message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
         message.status = "thinking"
         yield message
 
         for value in self.generate_llm_response(message):
             logging.info(f"LLM: {value.status} - {value.response}")
             if value.status != "done":
                 yield value
                 if value.status == "error":
                     return
 
     def get_and_reset_content_seed(self):
         tmp = self._content_seed
         self._content_seed = ""
         return tmp
 
     def set_content_seed(self, content: str) -> None:
         """Set the content seed for the agent."""
         self._content_seed = content
 
     def get_content_seed(self) -> str:
         """Get the content seed for the agent."""
         return self._content_seed
 
     @classmethod
     def valid_agent_types(cls) -> set[str]:
         """Return the set of valid agent_type values."""
         return set(get_args(cls.__annotations__["agent_type"]))
 
 # Register the base agent
 registry.register(Chat._agent_type, Chat)
@@ -9,16 +9,16 @@ from .. conversation import Conversation
 from .. message import Message
 
 class FactCheck(Agent):
     agent_type: Literal["fact_check"] = "fact_check"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
 
     facts: str = ""
 
     @model_validator(mode="after")
     def validate_facts(self):
         if not self.facts.strip():
             raise ValueError("Facts cannot be empty")
         return self
 
 # Register the base agent
 registry.register(FactCheck._agent_type, FactCheck)
@@ -9,16 +9,16 @@ from .. conversation import Conversation
 from .. message import Message
 
 class JobDescription(Agent):
     agent_type: Literal["job_description"] = "job_description"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
 
     job_description: str = ""
 
     @model_validator(mode="after")
     def validate_job_description(self):
         if not self.job_description.strip():
             raise ValueError("Job description cannot be empty")
         return self
 
 # Register the base agent
 registry.register(JobDescription._agent_type, JobDescription)
@@ -9,24 +9,24 @@ from .. conversation import Conversation
 from .. message import Message
 
 class Resume(Agent):
     agent_type: Literal["resume"] = "resume"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
 
     resume: str = ""
 
     @model_validator(mode="after")
     def validate_resume(self):
         if not self.resume.strip():
             raise ValueError("Resume content cannot be empty")
         return self
 
     def get_resume(self) -> str:
         """Get the resume content."""
         return self.resume
 
     def set_resume(self, resume: str) -> None:
         """Set the resume content."""
         self.resume = resume
 
 # Register the base agent
 registry.register(Resume._agent_type, Resume)
@@ -1,9 +1,11 @@
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, model_validator, ValidationError
 from uuid import uuid4
 from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
 from typing_extensions import Annotated, Union
 import numpy as np
 import logging
+from uuid import uuid4
+import re
 
 from .message import Message
 from .rag import ChromaDBFileWatcher
@@ -13,22 +15,23 @@ from .agents import Agent
 
 # Import only agent types, not actual classes
 if TYPE_CHECKING:
-    from .agents import Agent, AnyAgent, Chat, Resume, JobDescription, FactCheck
+    from .agents import Agent, AnyAgent
 
 from .agents import AnyAgent
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 class Context(BaseModel):
     model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
+    # Required fields
+    file_watcher: ChromaDBFileWatcher = Field(..., exclude=True)
+
+    # Optional fields
     id: str = Field(
         default_factory=lambda: str(uuid4()),
         pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
     )
-
-    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
-        default_factory=list
-    )
 
     user_resume: Optional[str] = None
     user_job_description: Optional[str] = None
     user_facts: Optional[str] = None
@@ -36,17 +39,27 @@ class Context(BaseModel):
     rags: List[dict] = []
     message_history_length: int = 5
     context_tokens: int = 0
-    file_watcher: ChromaDBFileWatcher = Field(default=None, exclude=True)
+    # Class managed fields
+    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
+        default_factory=list
+    )
 
-    def __init__(self, id: Optional[str] = None, **kwargs):
-        super().__init__(id=id if id is not None else str(uuid4()), **kwargs)
+    @classmethod
+    def from_json(cls, json_str: str, file_watcher: ChromaDBFileWatcher):
+        """Custom method to load from JSON with file_watcher injection"""
+        import json
+        data = json.loads(json_str)
+        return cls(file_watcher=file_watcher, **data)
 
     @model_validator(mode="after")
     def validate_unique_agent_types(self):
         """Ensure at most one agent per agent_type."""
+        logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
         agent_types = [agent.agent_type for agent in self.agents]
         if len(agent_types) != len(set(agent_types)):
             raise ValueError("Context cannot contain multiple agents of the same agent_type")
+        for agent in self.agents:
+            agent.set_context(self)
         return self
 
     def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
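Note: file_watcher is now declared with Field(..., exclude=True): required at construction, never serialized. Context.from_json exists so a saved context can be re-validated from its JSON while the live watcher is injected back in. A minimal sketch of that round trip under the same assumptions, with hypothetical names:

import json
from pydantic import BaseModel, Field

class Ctx(BaseModel):
    model_config = {"arbitrary_types_allowed": True}
    id: str
    watcher: object = Field(..., exclude=True)  # runtime object, excluded from dumps

    @classmethod
    def from_json(cls, json_str: str, watcher: object):
        data = json.loads(json_str)           # parse first to catch bad JSON early
        return cls(watcher=watcher, **data)   # then validate, injecting the watcher

ctx = Ctx(id="abc", watcher=object())
restored = Ctx.from_json(ctx.model_dump_json(), watcher=ctx.watcher)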
@@ -110,7 +123,7 @@ class Context(BaseModel):
         except Exception as e:
             message.response = f"Error generating RAG results: {str(e)}"
             message.status = "error"
-            logging.error(e)
+            logger.error(e)
             yield message
             return
@@ -11,6 +11,7 @@ class Message(BaseModel):
     disable_tools: bool = False
 
     # Generated while processing message
+    status: str = "" # Status of the message
     preamble: dict[str,str] = {} # Preamble to be prepended to the prompt
     system_prompt: str = "" # System prompt provided to the LLM
     full_content: str = "" # Full content of the message (preamble + prompt)