Starting to work again

commit d1940e18e5 (parent e607e3a2f2)

src/server.py (192 changes)
@@ -1,10 +1,4 @@
-import os
-os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
-
-import warnings
-warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
-warnings.filterwarnings("ignore", message="Warning only once for all operators")
-warnings.filterwarnings("ignore", message="Couldn't find ffmpeg or avconv")
+from utils import logger
 
 # %%
 # Imports [standard]
@@ -55,7 +49,8 @@ from utils import (
     rag as Rag,
     Context, Conversation, Message,
     Agent,
-    defines
+    defines,
+    logger
 )
 
 from tools import (
@@ -260,25 +255,6 @@ def parse_args():
                         default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
     return parser.parse_args()
 
-def setup_logging(level):
-    global logging
-
-    numeric_level = getattr(logging, level.upper(), None)
-    if not isinstance(numeric_level, int):
-        raise ValueError(f"Invalid log level: {level}")
-
-    logging.basicConfig(
-        level=numeric_level,
-        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-        force=True
-    )
-
-    # Now reduce verbosity for FastAPI, Uvicorn, Starlette
-    for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
-        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
-
-    logging.info(f"Logging is set to {level} level.")
-
 # %%
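Note: the setup_logging() helper deleted above is not gone from the project; later hunks in this commit import a shared version from the utils package (from . setup_logging import setup_logging) and call it as setup_logging(level=defines.logging_level). That module's body is not part of this diff; a plausible sketch, inferred from the deleted helper and the call sites, might be:

    import logging

    def setup_logging(level: str = "INFO") -> logging.Logger:
        # Hypothetical reconstruction; the real utils setup_logging module is not shown in this diff.
        numeric_level = getattr(logging, level.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError(f"Invalid log level: {level}")
        logging.basicConfig(
            level=numeric_level,
            format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            force=True,
        )
        # Quiet the web stack, as the deleted server-local helper did.
        for noisy in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
            logging.getLogger(noisy).setLevel(logging.WARNING)
        return logging.getLogger("backstory")  # the returned logger's name is a guess

The key difference from the deleted helper is the return value: call sites keep the returned logger and call logger.info(...) on it instead of going through the logging module directly, which is exactly the substitution the rest of this diff performs.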
@@ -298,10 +274,10 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
-       logging.info(f"Fetching {url}")
+       logger.info(f"Fetching {url}")
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
-       logging.info(f"{url} returned. Processing...")
+       logger.info(f"{url} returned. Processing...")
        # Parse the HTML
        soup = BeautifulSoup(response.text, "html.parser")
 
@@ -323,7 +299,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
            text = text[:max_chars] + "..."
 
        # Create Ollama client
-       # logging.info(f"Requesting summary of: {text}")
+       # logger.info(f"Requesting summary of: {text}")
 
        # Generate summary using Ollama
        prompt = f"CONTENTS:\n\n{text}\n\n{question}"
@@ -331,7 +307,7 @@ async def AnalyzeSite(llm, model: str, url : str, question : str):
            system="You are given the contents of {url}. Answer the question about the contents",
            prompt=prompt)
 
-       #logging.info(response["response"])
+       #logger.info(response["response"])
 
        return {
            "source": "summarizer-llm",
@@ -377,12 +353,12 @@ class WebServer:
            watch_directory=defines.doc_dir,
            recreate=False # Don't recreate if exists
        )
-       logging.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
+       logger.info(f"API started with {self.file_watcher.collection.count()} documents in the collection")
        yield
        if self.observer:
            self.observer.stop()
            self.observer.join()
-           logging.info("File watcher stopped")
+           logger.info("File watcher stopped")
 
    def __init__(self, llm, model=MODEL_NAME):
        self.app = FastAPI(lifespan=self.lifespan)
@@ -400,7 +376,7 @@ class WebServer:
        else:
            allow_origins=["http://battle-linux.ketrenos.com:3000"]
 
-       logging.info(f"Allowed origins: {allow_origins}")
+       logger.info(f"Allowed origins: {allow_origins}")
 
        self.app.add_middleware(
            CORSMiddleware,
@@ -416,13 +392,13 @@ class WebServer:
        @self.app.get("/")
        async def root():
            context = self.create_context()
-           logging.info(f"Redirecting non-context to {context.id}")
+           logger.info(f"Redirecting non-context to {context.id}")
            return RedirectResponse(url=f"/{context.id}", status_code=307)
            #return JSONResponse({"redirect": f"/{context.id}"})
 
        @self.app.put("/api/umap/{context_id}")
        async def put_umap(context_id: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            try:
                if not self.file_watcher:
                    raise Exception("File watcher not initialized")
@@ -436,10 +412,10 @@ class WebServer:
                dimensions = data.get("dimensions", 2)
                result = self.file_watcher.umap_collection
                if dimensions == 2:
-                   logging.info("Returning 2D UMAP")
+                   logger.info("Returning 2D UMAP")
                    umap_embedding = self.file_watcher.umap_embedding_2d
                else:
-                   logging.info("Returning 3D UMAP")
+                   logger.info("Returning 3D UMAP")
                    umap_embedding = self.file_watcher.umap_embedding_3d
 
                result["embeddings"] = umap_embedding.tolist()
@@ -447,19 +423,19 @@ class WebServer:
                return JSONResponse(result)
 
            except Exception as e:
-               logging.error(f"put_umap error: {str(e)}")
+               logger.error(f"put_umap error: {str(e)}")
                import traceback
-               logging.error(traceback.format_exc())
+               logger.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, 500)
 
        @self.app.put("/api/similarity/{context_id}")
        async def put_similarity(context_id: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            if not self.file_watcher:
                raise Exception("File watcher not initialized")
 
            if not is_valid_uuid(context_id):
-               logging.warning(f"Invalid context_id: {context_id}")
+               logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
 
            try:
@@ -476,13 +452,13 @@ class WebServer:
                    return JSONResponse({"error": "No results found"}, status_code=404)
 
                chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-               logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")
+               logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
 
                umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-               logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+               logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
 
                umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-               logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+               logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
 
                return JSONResponse({
                    **chroma_results,
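Note: put_similarity above projects a per-request query embedding into UMAP models that were fit ahead of time over the whole document collection. A minimal, self-contained sketch of that pattern (stand-in data and names; assumes the umap-learn package):

    import numpy as np
    import umap  # umap-learn

    corpus = np.random.rand(200, 384)            # stand-in for the stored document embeddings
    umap_model_2d = umap.UMAP(n_components=2).fit(corpus)

    query = np.random.rand(384)                  # stand-in for chroma_embedding
    point_2d = umap_model_2d.transform([query])[0].tolist()  # [x, y] for the UI scatter plot

transform() maps new points into the already-fitted space without re-running the (expensive) fit, which is what makes the per-request projection in the endpoint viable.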
@@ -492,14 +468,14 @@ class WebServer:
                })
 
            except Exception as e:
-               logging.error(e)
+               logger.error(e)
                #return JSONResponse({"error": str(e)}, 500)
 
        @self.app.put("/api/reset/{context_id}/{agent_type}")
        async def put_reset(context_id: str, agent_type: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
-               logging.warning(f"Invalid context_id: {context_id}")
+               logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            agent = context.get_agent(agent_type)
@@ -512,7 +488,7 @@ class WebServer:
            for reset_operation in data["reset"]:
                match reset_operation:
                    case "system_prompt":
-                       logging.info(f"Resetting {reset_operation}")
+                       logger.info(f"Resetting {reset_operation}")
                        match agent_type:
                            case "chat":
                                prompt = system_message
@@ -526,11 +502,11 @@ class WebServer:
                        agent.system_prompt = prompt
                        response["system_prompt"] = { "system_prompt": prompt }
                    case "rags":
-                       logging.info(f"Resetting {reset_operation}")
+                       logger.info(f"Resetting {reset_operation}")
                        context.rags = rags.copy()
                        response["rags"] = context.rags
                    case "tools":
-                       logging.info(f"Resetting {reset_operation}")
+                       logger.info(f"Resetting {reset_operation}")
                        context.tools = default_tools(tools)
                        response["tools"] = context.tools
                    case "history":
@@ -546,13 +522,13 @@ class WebServer:
                            tmp = context.get_agent(mode)
                            if not tmp:
                                continue
-                           logging.info(f"Resetting {reset_operation} for {mode}")
+                           logger.info(f"Resetting {reset_operation} for {mode}")
                            context.conversation = Conversation()
                            context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage
                            response["history"] = []
                            response["context_used"] = agent.context_tokens
                    case "message_history_length":
-                       logging.info(f"Resetting {reset_operation}")
+                       logger.info(f"Resetting {reset_operation}")
                        context.message_history_length = DEFAULT_HISTORY_LENGTH
                        response["message_history_length"] = DEFAULT_HISTORY_LENGTH
 
@@ -567,7 +543,7 @@ class WebServer:
 
        @self.app.put("/api/tunables/{context_id}")
        async def put_tunables(context_id: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            try:
                context = self.upsert_context(context_id)
 
@@ -619,14 +595,14 @@ class WebServer:
                    case _:
                        return JSONResponse({ "error": f"Unrecognized tunable {k}"}, status_code=404)
            except Exception as e:
-               logging.error(f"Error in put_tunables: {e}")
+               logger.error(f"Error in put_tunables: {e}")
                return JSONResponse({"error": str(e)}, status_code=500)
 
        @self.app.get("/api/tunables/{context_id}")
        async def get_tunables(context_id: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
-               logging.warning(f"Invalid context_id: {context_id}")
+               logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            agent = context.get_agent("chat")
@@ -644,15 +620,15 @@ class WebServer:
 
        @self.app.get("/api/system-info/{context_id}")
        async def get_system_info(context_id: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            return JSONResponse(system_info(self.model))
 
        @self.app.post("/api/chat/{context_id}/{agent_type}")
        async def post_chat_endpoint(context_id: str, agent_type: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            try:
                if not is_valid_uuid(context_id):
-                   logging.warning(f"Invalid context_id: {context_id}")
+                   logger.warning(f"Invalid context_id: {context_id}")
                    return JSONResponse({"error": "Invalid context_id"}, status_code=400)
                context = self.upsert_context(context_id)
 
@@ -660,11 +636,11 @@ class WebServer:
                    data = await request.json()
                    agent = context.get_agent(agent_type)
                    if not agent and agent_type == "job_description":
-                       logging.info(f"Agent {agent_type} not found. Returning empty history.")
+                       logger.info(f"Agent {agent_type} not found. Returning empty history.")
                        # Create a new agent if it doesn't exist
                        agent = context.get_or_create_agent("job_description", system_prompt=system_generate_resume, job_description=data["content"])
                except Exception as e:
-                   logging.info(f"Attempt to create agent type: {agent_type} failed", e)
+                   logger.info(f"Attempt to create agent type: {agent_type} failed", e)
                    return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)
 
                # Create a custom generator that ensures flushing
@@ -688,43 +664,43 @@ class WebServer:
                    }
                )
            except Exception as e:
-               logging.error(f"Error in post_chat_endpoint: {e}")
+               logger.error(f"Error in post_chat_endpoint: {e}")
                return JSONResponse({"error": str(e)}, status_code=500)
 
        @self.app.post("/api/context")
        async def create_context():
            context = self.create_context()
-           logging.info(f"Generated new agent as {context.id}")
+           logger.info(f"Generated new agent as {context.id}")
            return JSONResponse({ "id": context.id })
 
        @self.app.get("/api/history/{context_id}/{agent_type}")
        async def get_history(context_id: str, agent_type: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            try:
                context = self.upsert_context(context_id)
                agent = context.get_agent(agent_type)
                if not agent:
-                   logging.info(f"Agent {agent_type} not found. Returning empty history.")
+                   logger.info(f"Agent {agent_type} not found. Returning empty history.")
                    return JSONResponse({ "messages": [] })
-               logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
+               logger.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
                return agent.conversation
            except Exception as e:
-               logging.error(f"get_history error: {str(e)}")
+               logger.error(f"get_history error: {str(e)}")
                import traceback
-               logging.error(traceback.format_exc())
+               logger.error(traceback.format_exc())
                return JSONResponse({"error": str(e)}, status_code=404)
 
        @self.app.get("/api/tools/{context_id}")
        async def get_tools(context_id: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            context = self.upsert_context(context_id)
            return JSONResponse(context.tools)
 
        @self.app.put("/api/tools/{context_id}")
        async def put_tools(context_id: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
-               logging.warning(f"Invalid context_id: {context_id}")
+               logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            try:
@@ -743,9 +719,9 @@ class WebServer:
 
        @self.app.get("/api/context-status/{context_id}/{agent_type}")
        async def get_context_status(context_id, agent_type: str, request: Request):
-           logging.info(f"{request.method} {request.url.path}")
+           logger.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
-               logging.warning(f"Invalid context_id: {context_id}")
+               logger.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            agent = context.get_agent(agent_type)
@@ -761,9 +737,9 @@ class WebServer:
        async def serve_static(path: str):
            full_path = os.path.join(defines.static_content, path)
            if os.path.exists(full_path) and os.path.isfile(full_path):
-               logging.info(f"Serve static request for {full_path}")
+               logger.info(f"Serve static request for {full_path}")
                return FileResponse(full_path)
-           logging.info(f"Serve index.html for {path}")
+           logger.info(f"Serve index.html for {path}")
            return FileResponse(os.path.join(defines.static_content, "index.html"))
 
    def save_context(self, context_id):
@@ -807,28 +783,28 @@ class WebServer:
 
        # Check if the file exists
        if not os.path.exists(file_path):
-           logging.info(f"Context file {file_path} not found. Creating new context.")
+           logger.info(f"Context file {file_path} not found. Creating new context.")
            self.contexts[context_id] = self.create_context(context_id)
        else:
            # Read and deserialize the data
            with open(file_path, "r") as f:
                content = f.read()
-               logging.info(f"Loading context from {file_path}, content length: {len(content)}")
+               logger.info(f"Loading context from {file_path}, content length: {len(content)}")
                try:
                    # Try parsing as JSON first to ensure valid JSON
                    import json
                    json_data = json.loads(content)
-                   logging.info("JSON parsed successfully, attempting model validation")
+                   logger.info("JSON parsed successfully, attempting model validation")
 
                    # Now try Pydantic validation
                    self.contexts[context_id] = Context.from_json(json_data, file_watcher=self.file_watcher)
-                   logging.info(f"Successfully loaded context {context_id}")
+                   logger.info(f"Successfully loaded context {context_id}")
                except json.JSONDecodeError as e:
-                   logging.error(f"Invalid JSON in file: {e}")
+                   logger.error(f"Invalid JSON in file: {e}")
                except Exception as e:
-                   logging.error(f"Error validating context: {str(e)}")
+                   logger.error(f"Error validating context: {str(e)}")
                    import traceback
-                   logging.error(traceback.format_exc())
+                   logger.error(traceback.format_exc())
                    # Fallback to creating a new context
                    self.contexts[context_id] = Context(id=context_id, file_watcher=self.file_watcher)
 
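Note: load_context above parses in two deliberate stages, so a syntactically broken file (json.JSONDecodeError) is reported differently from a schema mismatch (the broad except), and either failure falls back to a fresh Context. A minimal sketch of the same pattern with hypothetical names:

    import json
    from pydantic import BaseModel, ValidationError

    class Snapshot(BaseModel):
        id: str
        context_tokens: int = 0

    def load_snapshot(raw: str) -> Snapshot:
        try:
            data = json.loads(raw)       # stage 1: is it valid JSON at all?
            return Snapshot(**data)      # stage 2: does it match the model?
        except json.JSONDecodeError as e:
            print(f"Invalid JSON in file: {e}")
        except ValidationError as e:
            print(f"Error validating snapshot: {e}")
        return Snapshot(id="new")        # fallback: start fresh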
@@ -845,7 +821,7 @@ class WebServer:
        if not self.file_watcher:
            raise Exception("File watcher not initialized")
 
-       logging.info(f"Creating new context with ID: {context_id}")
+       logger.info(f"Creating new context with ID: {context_id}")
        context = Context(id=context_id, file_watcher=self.file_watcher)
 
        if os.path.exists(defines.resume_doc):
@@ -859,7 +835,7 @@ class WebServer:
        context.tools = default_tools(tools)
        context.rags = rags.copy()
 
-       logging.info(f"{context.id} created and added to contexts.")
+       logger.info(f"{context.id} created and added to contexts.")
        self.contexts[context.id] = context
        self.save_context(context.id)
        return context
@@ -941,13 +917,13 @@ class WebServer:
        """
 
        if not context_id:
-           logging.warning("No context ID provided. Creating a new context.")
+           logger.warning("No context ID provided. Creating a new context.")
            return self.create_context()
 
        if context_id in self.contexts:
            return self.contexts[context_id]
 
-       logging.info(f"Context {context_id} is not yet loaded.")
+       logger.info(f"Context {context_id} is not yet loaded.")
        return self.load_or_create_context(context_id)
 
    def generate_rag_results(self, context, content):
@@ -963,13 +939,13 @@ class WebServer:
            if chroma_results:
                results_found = True
                chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape
-               logging.info(f"Chroma embedding shape: {chroma_embedding.shape}")
+               logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
 
                umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
-               logging.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
+               logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output
 
                umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
-               logging.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
+               logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output
 
                yield {
                    **chroma_results,
@@ -1012,11 +988,11 @@ class WebServer:
            raise Exception("File watcher not initialized")
 
        agent_type = agent.get_agent_type()
-       logging.info(f"generate_response: {agent_type}")
+       logger.info(f"generate_response: {agent_type}")
        if agent_type == "chat":
            message = Message(prompt=content)
            async for value in agent.prepare_message(message):
-               logging.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
+               logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
                if value.status != "done":
                    yield value
                if value.status == "error":
@@ -1025,7 +1001,7 @@ class WebServer:
                    yield message
                    return
            async for value in agent.process_message(message):
-               logging.info(f"{agent_type}.process_message: {value.status} - {value.response}")
+               logger.info(f"{agent_type}.process_message: {value.status} - {value.response}")
                if value.status != "done":
                    yield value
                if value.status == "error":
@@ -1034,7 +1010,7 @@ class WebServer:
                    yield message
                    return
            async for value in agent.generate_llm_response(message):
-               logging.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
+               logger.info(f"{agent_type}.generate_llm_response: {value.status} - {value.response}")
                if value.status != "done":
                    yield value
                if value.status == "error":
@@ -1042,13 +1018,13 @@ class WebServer:
                    message.response = value.response
                    yield message
                    return
-           logging.info("TODO: There is more to do...")
+           logger.info("TODO: There is more to do...")
            return
 
        return
 
        if self.processing:
-           logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
+           logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
            yield {"status": "error", "message": "Busy processing another request."}
            return
 
@@ -1083,11 +1059,11 @@ class WebServer:
        process_type = agent.get_agent_type()
        match process_type:
            case "job_description":
-               logging.info(f"job_description user_history len: {len(conversation.messages)}")
+               logger.info(f"job_description user_history len: {len(conversation.messages)}")
                if len(conversation.messages) >= 2: # USER, ASSISTANT
                    process_type = "chat"
            case "resume":
-               logging.info(f"resume user_history len: {len(conversation.messages)}")
+               logger.info(f"resume user_history len: {len(conversation.messages)}")
                if len(conversation.messages) >= 3: # USER, ASSISTANT, FACT_CHECK
                    process_type = "chat"
            case "fact_check":
@@ -1098,7 +1074,7 @@ class WebServer:
            case "chat":
                if not message.prompt:
                    yield {"status": "error", "message": "No query provided for chat."}
-                   logging.info(f"user_history len: {len(conversation.messages)}")
+                   logger.info(f"user_history len: {len(conversation.messages)}")
                    self.processing = False
                    return
 
@@ -1207,8 +1183,8 @@ Use to the above information to respond to this prompt:
 
                message.add_action("generate_resume")
 
-               logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'")
-               logging.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble")
+               logger.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'")
+               logger.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble")
 
                # Switch to resume agent for LLM responses
                # message.metadata["origin"] = "resume"
@@ -1288,11 +1264,11 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                stuffingMessage.metadata["origin"] = "resume"
                stuffingMessage.metadata["display"] = "hide"
                stuffingMessage.actions = [ "fact_check" ]
-               logging.info("TODO: Switch this to use actions to keep the UI from showingit")
+               logger.info("TODO: Switch this to use actions to keep the UI from showingit")
                conversation.add_message(stuffingMessage)
 
                # For all future calls to job_description, use the system_job_description
-               logging.info("TODO: Create a system_resume_QA prompt to use for the resume agent")
+               logger.info("TODO: Create a system_resume_QA prompt to use for the resume agent")
                agent.system_prompt = system_prompt
 
                # Switch to fact_check agent for LLM responses
@@ -1364,7 +1340,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                else:
                    response = self.llm.chat(model=self.model, messages=messages, options={ "num_ctx": ctx_size })
            except Exception as e:
-               logging.exception({ "model": self.model, "error": str(e) })
+               logger.exception({ "model": self.model, "error": str(e) })
                yield {"status": "error", "message": f"An error occurred communicating with LLM"}
                self.processing = False
                return
@@ -1449,7 +1425,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
            }
 
        # except Exception as e:
-       #     logging.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) })
+       #     logger.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) })
        #     yield {"status": "error", "message": f"An error occurred: {str(e)}"}
 
        # finally:
@@ -1460,7 +1436,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
    def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
        try:
            if self.ssl_enabled:
-               logging.info(f"Starting web server at https://{host}:{port}")
+               logger.info(f"Starting web server at https://{host}:{port}")
                uvicorn.run(
                    self.app,
                    host=host,
@@ -1470,7 +1446,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                    ssl_certfile=defines.cert_path
                )
            else:
-               logging.info(f"Starting web server at http://{host}:{port}")
+               logger.info(f"Starting web server at http://{host}:{port}")
                uvicorn.run(
                    self.app,
                    host=host,
@@ -1493,7 +1469,7 @@ def main():
    args = parse_args()
 
    # Setup logging based on the provided level
-   setup_logging(args.level)
+   logger.setLevel(args.level)
 
    warnings.filterwarnings(
        "ignore",
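Note: with this hunk, logging configuration leaves server.py entirely. The remaining hunks below appear to touch the utils package (its __init__, the agents base module, context.py, and defines.py, judging by the relative imports); there, setup_logging() runs once at import time and hands back a shared logger, so main() only adjusts the level, which the stdlib accepts as a string name:

    from utils import logger  # module-level logger created when the utils package is imported

    def main():
        level = "DEBUG"                  # stand-in for parse_args().level
        logger.setLevel(level.upper())   # Logger.setLevel() accepts names such as "DEBUG"
        logger.info("Logging level set to %s", level)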
@@ -6,7 +6,7 @@ from .message import Message
 from . conversation import Conversation
 from . context import Context
 from . import agents
-import logging
+from . setup_logging import setup_logging
 
 from .agents import Agent, __all__ as agents_all
 
@@ -17,13 +17,14 @@ __all__ = [
     'Message',
     'ChromaDBFileWatcher',
-    'start_file_watcher'
+    'start_file_watcher',
+    'logger',
 ] + agents_all
 
 # Resolve circular dependencies by rebuilding models
 # Call model_rebuild() on Agent and Context
 Agent.model_rebuild()
 Context.model_rebuild()
-import logging
 import importlib
 from pydantic import BaseModel
 from typing import Type
@@ -31,19 +32,20 @@ from typing import Type
 # Assuming class_registry is available from agents/__init__.py
 from .agents import class_registry, AnyAgent
 
+logger = setup_logging(level=defines.logging_level)
 
 def rebuild_models():
-    Context.model_rebuild()
     for class_name, (module_name, _) in class_registry.items():
        try:
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name, None)
 
-           logging.info(f"Checking: {class_name} in module {module_name}")
-           logging.info(f"  cls: {True if cls else False}")
-           logging.info(f"  isinstance(cls, type): {isinstance(cls, type)}")
-           logging.info(f"  issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}")
-           logging.info(f"  issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}")
-           logging.info(f"  cls is not AnyAgent: {cls is not AnyAgent if cls else True}")
+           logger.info(f"Checking: {class_name} in module {module_name}")
+           logger.info(f"  cls: {True if cls else False}")
+           logger.info(f"  isinstance(cls, type): {isinstance(cls, type)}")
+           logger.info(f"  issubclass(cls, BaseModel): {issubclass(cls, BaseModel) if cls else False}")
+           logger.info(f"  issubclass(cls, AnyAgent): {issubclass(cls, AnyAgent) if cls else False}")
+           logger.info(f"  cls is not AnyAgent: {cls is not AnyAgent if cls else True}")
 
            if (
                cls
@@ -52,12 +54,14 @@ def rebuild_models():
                and issubclass(cls, AnyAgent)
                and cls is not AnyAgent
            ):
-               logging.info(f"Rebuilding {class_name} from {module_name}")
+               logger.info(f"Rebuilding {class_name} from {module_name}")
+               from . agents import Agent
+               from . context import Context
                cls.model_rebuild()
        except ImportError as e:
-           logging.error(f"Failed to import module {module_name}: {e}")
+           logger.error(f"Failed to import module {module_name}: {e}")
        except Exception as e:
-           logging.error(f"Error processing {class_name} in {module_name}: {e}")
+           logger.error(f"Error processing {class_name} in {module_name}: {e}")
 
 # Call this after all modules are imported
 rebuild_models()
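Note: the two imports added just before cls.model_rebuild() above are what let the rebuild succeed: model_rebuild() resolves string forward references against the namespaces visible at the call site, so Agent and Context must be importable at that point. A minimal sketch of the general pattern:

    from __future__ import annotations
    from typing import Optional
    from pydantic import BaseModel

    class Node(BaseModel):
        parent: Optional[Tree] = None  # forward reference; Tree is not defined yet

    class Tree(BaseModel):
        root: Optional[Node] = None

    Node.model_rebuild()               # Tree now exists, so the reference resolves
    print(Node(parent=Tree()))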
@@ -3,18 +3,20 @@ from pydantic import BaseModel, model_validator, PrivateAttr, Field
 from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef
 from abc import ABC, abstractmethod
 from typing_extensions import Annotated
-import logging
+from .. setup_logging import setup_logging
 
+logger = setup_logging()
 
 # Only import Context for type checking
 if TYPE_CHECKING:
     from .. context import Context
 
-from .types import AgentBase, registry
+from .types import registry
 
 from .. conversation import Conversation
 from .. message import Message
 
-class Agent(AgentBase):
+class Agent(BaseModel, ABC):
    """
    Base class for all agent types.
    This class defines the common attributes and methods for all agent types.
@@ -28,27 +30,19 @@ class Agent(AgentBase):
    system_prompt: str # Mandatory
    conversation: Conversation = Conversation()
    context_tokens: int = 0
-   context: object = Field(..., exclude=True) # Avoid circular reference, require as param, and prevent serialization
+   context: Optional[Context] = Field(default=None, exclude=True) # Avoid circular reference, require as param, and prevent serialization
 
    _content_seed: str = PrivateAttr(default="")
 
    # Class and pydantic model management
    def __init_subclass__(cls, **kwargs):
        """Auto-register subclasses"""
+       logger.info(f"Agent.__init_subclass__({kwargs})")
        super().__init_subclass__(**kwargs)
        # Register this class if it has an agent_type
-       if hasattr(cls, 'agent_type') and cls.agent_type != AgentBase._agent_type:
+       if hasattr(cls, 'agent_type') and cls.agent_type != Agent._agent_type:
            registry.register(cls.agent_type, cls)
 
-   def __init__(self, **data):
-       # Set agent_type from class if not provided
-       if 'agent_type' not in data:
-           data['agent_type'] = self.__class__.agent_type
-       from .. context import Context
-       Context.model_rebuild()
-       self.__class__.model_rebuild()
-       super().__init__(**data)
-
    def model_dump(self, *args, **kwargs):
        # Ensure context is always excluded, even with exclude_unset=True
        kwargs.setdefault("exclude", set())
@@ -80,7 +74,7 @@ class Agent(AgentBase):
        # Gather RAG results, yielding each result
        # as it becomes available
        for value in self.context.generate_rag_results(message):
-           logging.info(f"RAG: {value.status} - {value.response}")
+           logger.info(f"RAG: {value.status} - {value.response}")
            if value.status != "done":
                yield value
            if value.status == "error":
@@ -120,7 +114,7 @@ class Agent(AgentBase):
 
    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
        if self.context.processing:
-           logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
+           logger.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
            message.status = "error"
            message.response = "Busy processing another request."
            yield message
@@ -136,7 +130,7 @@ class Agent(AgentBase):
            #tools=llm_tools(context.tools) if message.enable_tools else None,
            options={ "num_ctx": message.ctx_size }
        ):
-           logging.info(f"LLM: {value.status} - {value.response}")
+           logger.info(f"LLM: {value.status} - {value.response}")
            if value.status != "done":
                message.status = value.status
                message.response = value.response
@@ -240,7 +234,7 @@ class Agent(AgentBase):
        yield message
 
        for value in self.generate_llm_response(message):
-           logging.info(f"LLM: {value.status} - {value.response}")
+           logger.info(f"LLM: {value.status} - {value.response}")
            if value.status != "done":
                yield value
            if value.status == "error":
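Note: two things change shape in the Agent model above. The context back-reference becomes Optional with default=None, so an Agent can be constructed before its Context is wired up, and Field(exclude=True) together with the model_dump() override keeps the cycle out of serialization. A minimal sketch of the exclude behaviour, with hypothetical names:

    from typing import Any, Optional
    from pydantic import BaseModel, Field

    class Holder(BaseModel):
        name: str
        backref: Optional[Any] = Field(default=None, exclude=True)

    h = Holder(name="chat", backref=object())
    print(h.model_dump())  # {'name': 'chat'} -- backref never serializes

The deleted custom __init__, which forced a model_rebuild() on every construction, becomes unnecessary once the circular types are rebuilt once, centrally, by rebuild_models().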
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from pydantic import BaseModel, Field, model_validator, ValidationError
 from uuid import uuid4
 from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
@@ -11,12 +12,6 @@ from .message import Message
 from .rag import ChromaDBFileWatcher
 from . import defines
 
-from .agents import Agent
-
-# Import only agent types, not actual classes
-if TYPE_CHECKING:
-    from .agents import Agent, AnyAgent
-
 from .agents import AnyAgent
 
 logging.basicConfig(level=logging.INFO)
@@ -25,7 +20,7 @@ logger = logging.getLogger(__name__)
 class Context(BaseModel):
    model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher
    # Required fields
-   file_watcher: ChromaDBFileWatcher = Field(..., exclude=True)
+   file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
 
    # Optional fields
    id: str = Field(
@@ -44,15 +39,14 @@ class Context(BaseModel):
        default_factory=list
    )
 
+   @model_validator(mode="before")
    @classmethod
-   def from_json(cls, json_str: str, file_watcher: ChromaDBFileWatcher):
-       """Custom method to load from JSON with file_watcher injection"""
-       import json
-       data = json.loads(json_str)
-       return cls(file_watcher=file_watcher, **data)
+   def before_model_validator(cls, values: Any):
+       logger.info(f"Preparing model data: {cls} {values}")
+       return values
 
    @model_validator(mode="after")
-   def validate_unique_agent_types(self):
+   def after_model_validator(self):
        """Ensure at most one agent per agent_type."""
        logger.info(f"Context {self.id} initialized with {len(self.agents)} agents.")
        agent_types = [agent.agent_type for agent in self.agents]
@@ -191,4 +185,5 @@ class Context(BaseModel):
        summary += f"\nChat Name: {agent.name}\n"
        return summary
 
+from . agents import Agent
 Context.model_rebuild()
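Note: from_json() is deleted here even though the server.py hunk above still calls Context.from_json(json_data, file_watcher=self.file_watcher) as unchanged context, so either that call site is updated elsewhere or it is a loose end of this work-in-progress commit. The replacement hooks are standard Pydantic v2 model validators: mode="before" receives the raw input, mode="after" receives the constructed instance. A minimal sketch:

    from typing import Any, List
    from pydantic import BaseModel, model_validator

    class Ctx(BaseModel):
        agents: List[str] = []

        @model_validator(mode="before")
        @classmethod
        def before_model_validator(cls, values: Any):
            # raw payload: log or massage incoming data here
            return values

        @model_validator(mode="after")
        def after_model_validator(self):
            # constructed instance: enforce invariants such as unique agent types
            if len(set(self.agents)) != len(self.agents):
                raise ValueError("duplicate agent types")
            return self

    print(Ctx(agents=["chat", "resume"]))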
@@ -15,3 +15,4 @@ resume_doc = "/opt/backstory/docs/resume/generic.md"
 # Only used for testing; backstory-prod will not use this
 key_path = "/opt/backstory/keys/key.pem"
 cert_path = "/opt/backstory/keys/cert.pem"
+logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()