# backstory/src/server.py
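"""FastAPI web server for Backstory: streams LLM chat responses as
newline-delimited JSON, exposes RAG/UMAP visualization endpoints, and
persists per-session conversation contexts to disk."""
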
# Timeout (in seconds) for a single LLM inference before the session is aborted
LLM_TIMEOUT = 600

# %%
# Imports [standard]
# Standard library modules (no try-except needed)
import argparse
import asyncio
import inspect
import json
import logging
import math
import os
import re
import subprocess
import time
import traceback
import uuid
import warnings
from typing import Any, AsyncGenerator, Dict, List, Optional

from pydantic import BaseModel, Field, ValidationError  # type: ignore
from pydantic_core import PydanticSerializationError  # type: ignore

def try_import(module_name, pip_name=None):
    """Attempt to import a module, printing a pip install hint if it is missing."""
    try:
        __import__(module_name)
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f"    pip install {pip_name or module_name}")

# Third-party modules with import checks
try_import("ollama")
try_import("requests")
try_import("fastapi")
try_import("uvicorn")
try_import("numpy")
try_import("umap")
try_import("sklearn")
try_import("prometheus_client")
try_import("prometheus_fastapi_instrumentator")
import ollama
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException # type: ignore
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore
from fastapi.middleware.cors import CORSMiddleware # type: ignore
import uvicorn # type: ignore
import numpy as np # type: ignore
# Prometheus
from prometheus_client import Summary # type: ignore
from prometheus_fastapi_instrumentator import Instrumentator # type: ignore
from prometheus_client import CollectorRegistry, Counter # type: ignore
from utils import (
    rag as Rag,
    RagEntry,
    tools as Tools,
    Context,
    Conversation,
    Message,
    Agent,
    Metrics,
    Tunables,
    defines,
    User,
    check_serializable,
    logger,
)

class Query(BaseModel):
    prompt: str
    tunables: Tunables = Field(default_factory=Tunables)
    agent_options: Dict[str, str | int | float | Dict] = Field(default={}, exclude=True)

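# Example request body accepted by POST /api/{agent_type}/{context_id}
# (values are illustrative; valid tunables depend on the Tunables model in utils):
#   {"prompt": "Summarize my resume", "tunables": {}, "agent_options": {}}
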
REQUEST_TIME = Summary("request_processing_seconds", "Time spent processing request")
def get_installed_ram():
    try:
        with open("/proc/meminfo", "r") as f:
            meminfo = f.read()
        match = re.search(r"MemTotal:\s+(\d+)", meminfo)
        if match:
            return f"{math.floor(int(match.group(1)) / 1000**2)}GB"  # Convert kB to GB
    except Exception as e:
        return f"Error retrieving RAM: {e}"

def get_graphics_cards():
    gpus = []
    try:
        # Run the ze-monitor utility to enumerate devices
        result = subprocess.run(
            ["ze-monitor"], capture_output=True, text=True, check=True
        )
        # Clean up the output (remove leading/trailing whitespace and newlines)
        output = result.stdout.strip()
        for index in range(len(output.splitlines())):
            result = subprocess.run(
                ["ze-monitor", "--device", f"{index+1}", "--info"],
                capture_output=True,
                text=True,
                check=True,
            )
            gpu_info = result.stdout.strip().splitlines()
            gpu = {
                "discrete": True,  # Assume it's discrete initially
                "name": None,
                "memory": None,
            }
            gpus.append(gpu)
            for line in gpu_info:
                match = re.match(r"^Device: [^(]*\((.*)\)", line)
                if match:
                    gpu["name"] = match.group(1)
                    continue
                match = re.match(r"^\s*Memory: (.*)", line)
                if match:
                    gpu["memory"] = match.group(1)
                    continue
                match = re.match(r"^.*Is integrated with host: Yes.*", line)
                if match:
                    gpu["discrete"] = False
                    continue
        return gpus
    except Exception as e:
        return f"Error retrieving GPU info: {e}"

def get_cpu_info():
    try:
        with open("/proc/cpuinfo", "r") as f:
            cpuinfo = f.read()
        model_match = re.search(r"model name\s+:\s+(.+)", cpuinfo)
        cores_match = re.findall(r"processor\s+:\s+\d+", cpuinfo)
        if model_match and cores_match:
            return f"{model_match.group(1)} with {len(cores_match)} cores"
    except Exception as e:
        return f"Error retrieving CPU info: {e}"

def system_info(model):
    return {
        "System RAM": get_installed_ram(),
        "Graphics Card": get_graphics_cards(),
        "CPU": get_cpu_info(),
        "LLM Model": model,
        "Embedding Model": defines.embedding_model,
        "Context length": defines.max_context,
    }

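# Example return value (values are illustrative):
#   {"System RAM": "64GB",
#    "Graphics Card": [{"discrete": True, "name": "Intel Arc A770", "memory": "16GB"}],
#    "CPU": "Intel(R) Core(TM) i9 with 24 cores", ...}
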
# %%
# Defaults
OLLAMA_API_URL = defines.ollama_api_url
MODEL_NAME = defines.model
LOG_LEVEL = "info"
WEB_HOST = "0.0.0.0"
WEB_PORT = 8911
DEFAULT_HISTORY_LENGTH = 5
# %%
# Cmd line overrides
def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument(
        "--ollama-server",
        type=str,
        default=OLLAMA_API_URL,
        help=f"Ollama API endpoint. default={OLLAMA_API_URL}",
    )
    parser.add_argument(
        "--ollama-model",
        type=str,
        default=MODEL_NAME,
        help=f"LLM model to use. default={MODEL_NAME}",
    )
    parser.add_argument(
        "--web-host",
        type=str,
        default=WEB_HOST,
        help=f"Host to launch the web server on. default={WEB_HOST}",
    )
    parser.add_argument(
        "--web-port",
        type=int,
        default=WEB_PORT,
        help=f"Port to launch the web server on. default={WEB_PORT}",
    )
    parser.add_argument(
        "--level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default=LOG_LEVEL,
        help=f"Set the logging level. default={LOG_LEVEL}",
    )
    return parser.parse_args()

# %%
def is_valid_uuid(value: str) -> bool:
    try:
        uuid_obj = uuid.UUID(value, version=4)
        return str(uuid_obj) == value
    except (ValueError, TypeError):
        return False

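# e.g. is_valid_uuid(str(uuid.uuid4())) -> True; is_valid_uuid("not-a-uuid") -> False
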
# %%
class WebServer:
    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        yield
        # Shutdown: stop each user's file watcher before exiting
        for user in self.users:
            if user.observer:
                user.observer.stop()
                user.observer.join()
        logger.info("File watcher stopped")

    def __init__(self, llm, model=MODEL_NAME):
        self.app = FastAPI(lifespan=self.lifespan)
        self.prometheus_collector = CollectorRegistry()
        self.metrics = Metrics(prometheus_collector=self.prometheus_collector)
        # Keep the Instrumentator instance alive
        self.instrumentator = Instrumentator(registry=self.prometheus_collector)
        # Instrument the FastAPI app
        self.instrumentator.instrument(self.app)
        # Expose the /metrics endpoint
        self.instrumentator.expose(self.app, endpoint="/metrics")
        self.llm = llm
        self.model = model
        self.processing = False
        self.users = []
        self.contexts = {}
        self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(
            defines.cert_path
        )
        if self.ssl_enabled:
            allow_origins = [
                "https://battle-linux.ketrenos.com:3000",
                "https://backstory-beta.ketrenos.com",
            ]
        else:
            allow_origins = [
                "http://battle-linux.ketrenos.com:3000",
                "http://backstory-beta.ketrenos.com",
            ]
        logger.info(f"Allowed origins: {allow_origins}")
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=allow_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        self.setup_routes()

    def setup_routes(self):
        # @self.app.get("/")
        # async def root():
        #     context = self.create_context(username=defines.default_username)
        #     logger.info(f"Redirecting non-context to {context.id}")
        #     return RedirectResponse(url=f"/{context.id}", status_code=307)
        #     # return JSONResponse({"redirect": f"/{context.id}"})

@self.app.get("/api/umap/entry/{doc_id}/{context_id}")
async def get_umap(doc_id: str, context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
try:
context = await self.upsert_context(context_id)
if not context:
return JSONResponse(
{"error": f"Invalid context: {context_id}"}, status_code=400
)
user = context.user
collection = user.umap_collection
if not collection:
return JSONResponse(
{"error": "No UMAP collection found"}, status_code=404
)
if not collection.get("metadatas", None):
return JSONResponse(f"Document id {doc_id} not found.", 404)
for index, id in enumerate(collection.get("ids", [])):
if id == doc_id:
metadata = collection.get("metadatas", [])[index].copy()
content = user.file_watcher.prepare_metadata(metadata)
return JSONResponse(content)
return JSONResponse(f"Document id {doc_id} not found.", 404)
except Exception as e:
logger.error(f"get_umap error: {str(e)}")
logger.error(traceback.format_exc())
return JSONResponse({"error": str(e)}, 500)
@self.app.put("/api/umap/{context_id}")
async def put_umap(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
try:
context = await self.upsert_context(context_id)
if not context:
return JSONResponse(
{"error": f"Invalid context: {context_id}"}, status_code=400
)
user = context.user
data = await request.json()
dimensions = data.get("dimensions", 2)
collection = user.file_watcher.umap_collection
if not collection:
return JSONResponse(
{"error": "No UMAP collection found"}, status_code=404
)
if dimensions == 2:
logger.info("Returning 2D UMAP")
umap_embedding = user.file_watcher.umap_embedding_2d
else:
logger.info("Returning 3D UMAP")
umap_embedding = user.file_watcher.umap_embedding_3d
if len(umap_embedding) == 0:
return JSONResponse(
{"error": "No UMAP embedding found"}, status_code=404
)
result = {
"ids": collection.get("ids", []),
"metadatas": collection.get("metadatas", []),
"documents": collection.get("documents", []),
"embeddings": umap_embedding.tolist(),
"size": user.file_watcher.collection.count()
}
return JSONResponse(result)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"put_umap error: {str(e)}")
return JSONResponse({"error": str(e)}, 500)
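        # put_umap accepts a JSON body like (illustrative): {"dimensions": 3};
        # any value other than 2 selects the 3D embedding.
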
@self.app.put("/api/similarity/{context_id}")
async def put_similarity(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
context = await self.upsert_context(context_id)
user = context.user
try:
data = await request.json()
query = data.get("query", "")
threshold = data.get("threshold", defines.default_rag_threshold)
results = data.get("results", defines.default_rag_top_k)
except:
query = ""
threshold = defines.default_rag_threshold
results = defines.default_rag_top_k
if not query:
return JSONResponse(
{"error": "No query provided for similarity search"},
status_code=400,
)
try:
chroma_results = user.file_watcher.find_similar(
query=query, top_k=results, threshold=threshold
)
if not chroma_results:
return JSONResponse({"error": "No results found"}, status_code=404)
chroma_embedding = np.array(
chroma_results["query_embedding"]
).flatten() # Ensure correct shape
logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
umap_2d = user.file_watcher.umap_model_2d.transform([chroma_embedding])[
0
].tolist()
logger.info(
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
) # Debug output
umap_3d = user.file_watcher.umap_model_3d.transform([chroma_embedding])[
0
].tolist()
logger.info(
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
) # Debug output
return JSONResponse({
"distances": chroma_results["distances"],
"ids": chroma_results["ids"],
"metadatas": chroma_results["metadatas"],
"query": query,
"umap_embedding_2d": umap_2d,
"umap_embedding_3d": umap_3d,
"size": user.file_watcher.collection.count()
})
except Exception as e:
logger.error(e)
logging.error(traceback.format_exc())
return JSONResponse({"error": str(e)}, 500)
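        # put_similarity accepts a JSON body like (illustrative; threshold and
        # result-count defaults come from defines):
        #   {"query": "kernel development experience", "threshold": 0.5, "results": 10}
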
@self.app.put("/api/reset/{context_id}/{agent_type}")
async def put_reset(context_id: str, agent_type: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id):
logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = await self.upsert_context(context_id)
agent = context.get_agent(agent_type)
if not agent:
response = { "history": [] }
return JSONResponse(response)
data = await request.json()
try:
response = {}
for reset_operation in data["reset"]:
match reset_operation:
case "system_prompt":
logger.info(f"Resetting {reset_operation}")
case "rags":
logger.info(f"Resetting {reset_operation}")
context.rags = [ r.model_copy() for r in context.user.rags]
response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
case "tools":
logger.info(f"Resetting {reset_operation}")
context.tools = Tools.all_tools()
response["tools"] = Tools.llm_tools(context.tools)
case "history":
reset_map = {
"job_description": (
"job_description",
"resume",
"fact_check",
),
"resume": ("job_description", "resume", "fact_check"),
"fact_check": (
"job_description",
"resume",
"fact_check",
),
"chat": ("chat",),
}
resets = reset_map.get(agent_type, ())
for mode in resets:
tmp = context.get_agent(mode)
if not tmp:
logger.info(
f"Agent {mode} not found for {context_id}"
)
continue
logger.info(f"Resetting {reset_operation} for {mode}")
if mode != "chat":
logger.info(f"Removing agent {tmp.agent_type} from {context_id}")
context.remove_agent(tmp)
tmp.conversation.reset()
response["history"] = []
response["context_used"] = agent.context_tokens
if not response:
return JSONResponse(
{"error": "Usage: { reset: rags|tools|history|system_prompt}"}
)
else:
await self.save_context(context_id)
return JSONResponse(response)
except Exception as e:
logger.error(f"Error in reset: {e}")
logger.error(traceback.format_exc())
return JSONResponse(
{"error": "Usage: { reset: rags|tools|history|system_prompt}"}
)
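        # put_reset accepts a JSON body like (illustrative):
        #   {"reset": ["history", "rags", "tools", "system_prompt"]}
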
@self.app.put("/api/tunables/{context_id}")
async def put_tunables(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
try:
context = await self.upsert_context(context_id)
data = await request.json()
agent = context.get_agent("chat")
if not agent:
logger.info("chat agent does not exist on this context!")
return JSONResponse(
{"error": f"chat is not recognized", "context": context.id},
status_code=404,
)
for k in data.keys():
match k:
case "tools":
from typing import Any
# { "tools": [{ "tool": tool.name, "enabled": tool.enabled }] }
tools: List[Dict[str, Any]] = data[k]
if not tools:
return JSONResponse(
{
"status": "error",
"message": "Tools can not be empty.",
}
)
for tool in tools:
for context_tool in context.tools:
if context_tool.tool.function.name == tool["name"]:
context_tool.enabled = tool.get("enabled", True)
await self.save_context(context_id)
return JSONResponse({
"tools": [{
**t.function.model_dump(mode='json'),
"enabled": t.enabled,
} for t in context.tools]
})
case "rags":
from typing import Any
# { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
rag_configs: list[dict[str, Any]] = data[k]
if not rag_configs:
return JSONResponse(
{
"status": "error",
"message": "RAGs can not be empty.",
}
)
for config in rag_configs:
for context_rag in context.rags:
if context_rag.name == config["name"]:
context_rag.enabled = config["enabled"]
await self.save_context(context_id)
return JSONResponse({"rags": [ r.model_dump(mode="json") for r in context.rags]})
case "system_prompt":
system_prompt = data[k].strip()
if not system_prompt:
return JSONResponse(
{
"status": "error",
"message": "System prompt can not be empty.",
}
)
agent.system_prompt = system_prompt
await self.save_context(context_id)
return JSONResponse({"system_prompt": system_prompt})
case _:
return JSONResponse(
{"error": f"Unrecognized tunable {k}"}, status_code=404
)
except Exception as e:
logger.error(f"Error in put_tunables: {e}")
return JSONResponse({"error": str(e)}, status_code=500)
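        # put_tunables handles one tunable per request, e.g. (illustrative):
        #   {"system_prompt": "You are a concise assistant."}
        #   {"tools": [{"name": "some_tool", "enabled": false}]}
        #   {"rags": [{"name": "some_rag", "enabled": true}]}
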
@self.app.get("/api/user/{context_id}")
async def get_user(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
context = await self.upsert_context(context_id)
user = context.user
user_data = {
"username": user.username,
"first_name": user.first_name,
"last_name": user.last_name,
"full_name": user.full_name,
"description": user.description,
"contact_info": user.contact_info,
"rag_content_size": user.rag_content_size,
"has_profile": user.has_profile,
"questions": [ q.model_dump(mode='json') for q in user.user_questions],
}
return JSONResponse(user_data)
@self.app.get("/api/tunables/{context_id}")
async def get_tunables(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
context = await self.upsert_context(context_id)
agent = context.get_agent("chat")
if not agent:
logger.info("chat agent does not exist on this context!")
return JSONResponse(
{"error": f"chat is not recognized", "context": context.id},
status_code=404,
)
return JSONResponse(
{
"system_prompt": agent.system_prompt,
"rags": [ r.model_dump(mode="json") for r in context.rags ],
"tools": [{
**t.function.model_dump(mode='json'),
"enabled": t.enabled,
} for t in context.tools],
}
)
@self.app.get("/api/system-info/{context_id}")
async def get_system_info(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
return JSONResponse(system_info(self.model))
@self.app.post("/api/{agent_type}/{context_id}")
async def post_agent_endpoint(
agent_type: str, context_id: str, request: Request
):
logger.info(f"{request.method} {request.url.path}")
try:
context = await self.upsert_context(context_id)
except Exception as e:
error = {
"error": f"Unable to create or access context {context_id}: {e}"
}
logger.info(error)
return JSONResponse(error, status_code=404)
try:
data = await request.json()
query: Query = Query(**data)
except Exception as e:
error = {"error": f"Attempt to parse request: {e}"}
logger.info(error)
return JSONResponse(error, status_code=400)
try:
agent = context.get_or_create_agent(agent_type, **query.agent_options)
except Exception as e:
error = {
"error": f"Attempt to create agent type: {agent_type} failed: {e}"
}
logger.error(traceback.format_exc())
logger.info(error)
return JSONResponse(error, status_code=404)
try:
async def flush_generator():
logger.info(f"{agent.agent_type} - {inspect.stack()[0].function}")
try:
start_time = time.perf_counter()
async for message in self.generate_response(
context=context,
agent=agent,
prompt=query.prompt,
tunables=query.tunables,
):
if message.status != "done" and message.status != "partial":
if message.status == "streaming":
result = {
"status": "streaming",
"chunk": message.chunk,
"remaining_time": LLM_TIMEOUT
- (time.perf_counter() - start_time),
}
else:
start_time = time.perf_counter()
result = {
"status": message.status,
"response": message.response,
"remaining_time": LLM_TIMEOUT,
}
else:
logger.info(f"Providing {message.status} response.")
try:
result = message.model_dump(
by_alias=True, mode="json"
)
except Exception as e:
result = {"status": "error", "response": str(e)}
yield json.dumps(result) + "\n"
break
# Convert to JSON and add newline
message.network_packets += 1
message.network_bytes += len(result)
try:
disconnected = await request.is_disconnected()
except Exception as e:
logger.warning(f"Disconnection check failed: {e}")
disconnected = True
if disconnected:
logger.info("Disconnect detected. Continuing generation to store in cache.")
disconnected = True
if not disconnected:
yield json.dumps(result) + "\n"
current_time = time.perf_counter()
if current_time - start_time > LLM_TIMEOUT:
message.status = "error"
message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query."
message.partial_response = message.response
logger.info(message.response + " Ending session")
result = message.model_dump(by_alias=True, mode="json")
if not disconnected:
yield json.dumps(result) + "\n"
if message.status == "error":
context.processing = False
break
# Allow the event loop to process the write
await asyncio.sleep(0)
except Exception as e:
context.processing = False
logger.error(traceback.format_exc())
logger.error(f"Error in generate_response: {e}")
yield json.dumps({"status": "error", "response": str(e)}) + "\n"
finally:
# Save context on completion or error
await self.save_context(context_id)
logger.info("Flush generator completed normally.")
# Return StreamingResponse with appropriate headers
return StreamingResponse(
flush_generator(),
media_type="application/json",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Prevents Nginx buffering if you're using it
},
)
except Exception as e:
context.processing = False
logger.error(f"Error in post_chat_endpoint: {e}")
return JSONResponse({"error": str(e)}, status_code=500)
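        # post_agent_endpoint streams newline-delimited JSON; a client might
        # consume it like this (illustrative sketch using the httpx library):
        #   async with client.stream("POST", url, json={"prompt": "..."}) as r:
        #       async for line in r.aiter_lines():
        #           packet = json.loads(line)
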
@self.app.post("/api/create-session")
async def create_session(request: Request):
logger.info(f"{request.method} {request.url.path}")
context = await self.create_context(username=defines.default_username)
return JSONResponse({"id": context.id})
@self.app.get("/api/join-session/{context_id}")
async def join_session(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
context = await self.load_context(context_id=context_id)
if not context:
return JSONResponse({"error": f"{context_id} does not exist."}, 404)
return JSONResponse({"id": context.id})
@self.app.get("/api/u/{context_id}")
async def get_users(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
try:
context = await self.load_context(context_id)
if not context:
return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
users = [User.sanitize(u) for u in User.get_users()]
return JSONResponse(users)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"get_users error: {str(e)}")
return JSONResponse({ "error": "Unable to parse users"}, 500)
@self.app.get("/api/u/{username}/profile/{context_id}")
async def get_user_profile(username: str, context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
try:
if not User.exists(username):
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
context = await self.load_context(context_id)
if not context:
return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
profile_path = os.path.join(defines.user_dir, username, f"profile.png")
if not os.path.exists(profile_path):
return JSONResponse({ "error": "User {username} does not have a profile picture"}, status_code=404)
return FileResponse(profile_path)
except Exception as e:
return JSONResponse({ "error": "Unable to load user {username}"}, 500)
@self.app.post("/api/u/{username}/{context_id}")
async def post_user(username: str, context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
try:
if not User.exists(username):
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
context = await self.load_context(context_id)
if not context:
return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
matching_user = next((user for user in self.users if user.username == username), None)
if matching_user:
user = matching_user
else:
user = User(username=username, llm=self.llm)
await user.initialize(prometheus_collector=self.prometheus_collector)
self.users.append(user)
reset_map = (
"chat",
"job_description",
"resume",
"fact_check",
)
for mode in reset_map:
tmp = context.get_agent(mode)
if not tmp:
continue
logger.info(f"User change: Resetting history for {mode}")
if mode != "chat":
context.remove_agent(tmp)
tmp.conversation.reset()
context.user = user
user_data = {
"username": user.username,
"first_name": user.first_name,
"last_name": user.last_name,
"full_name": user.full_name,
"description": user.description,
"contact_info": user.contact_info,
"rag_content_size": user.rag_content_size,
"has_profile": user.has_profile,
"questions": [ q.model_dump(mode='json') for q in user.user_questions],
}
await self.save_context(context_id)
return JSONResponse(user_data)
except Exception as e:
return JSONResponse({ "error": "Unable to load user {username}"}, 500)
@self.app.post("/api/context/u/{username}")
async def create_user_context(username: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
try:
if not User.exists(username):
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
context = await self.create_context(username=username)
logger.info(f"Generated new context {context.id} for {username}")
return JSONResponse({"id": context.id})
except Exception as e:
logger.error(f"create_user_context error: {str(e)}")
logger.error(traceback.format_exc())
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
@self.app.post("/api/context")
async def create_context(request: Request):
logger.info(f"{request.method} {request.url.path}")
return self.app.create_user_context(defines.default_username, request)
@self.app.get("/api/history/{context_id}/{agent_type}")
async def get_history(context_id: str, agent_type: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
try:
context = await self.upsert_context(context_id)
agent = context.get_agent(agent_type)
if not agent:
logger.info(
f"Agent {agent_type} not found. Returning empty history."
)
return JSONResponse({"messages": []})
logger.info(
f"History for {agent_type} contains {len(agent.conversation)} entries."
)
return agent.conversation
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"get_history error: {str(e)}")
return JSONResponse({"error": str(e)}, status_code=404)
@self.app.get("/api/tools/{context_id}")
async def get_tools(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
context = await self.upsert_context(context_id)
return JSONResponse(context.tools)
@self.app.put("/api/tools/{context_id}")
async def put_tools(context_id: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id):
logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = await self.upsert_context(context_id)
try:
data = await request.json()
modify = data["tool"]
enabled = data["enabled"]
for tool in context.tools:
if modify == tool.function.name:
tool.enabled = enabled
await self.save_context(context_id)
return JSONResponse(context.tools)
return JSONResponse(
{"status": f"{modify} not found in tools."}, status_code=404
)
except:
return JSONResponse({"status": "error"}, 405)
@self.app.get("/api/context-status/{context_id}/{agent_type}")
async def get_context_status(context_id, agent_type: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id):
logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = await self.upsert_context(context_id)
agent = context.get_agent(agent_type)
if not agent:
return JSONResponse(
{"context_used": 0, "max_context": defines.max_context}
)
return JSONResponse(
{
"context_used": agent.context_tokens,
"max_context": defines.max_context,
}
)
@self.app.get("/api/health")
async def health_check():
return JSONResponse({"status": "healthy"})
@self.app.get("/{path:path}")
async def serve_static(path: str, request: Request):
full_path = os.path.join(defines.static_content, path)
# Check if the original path exists
if os.path.exists(full_path) and os.path.isfile(full_path):
logger.info(f"Serve static request for {full_path}")
return FileResponse(full_path)
# Check if the path matches /{filename}.png
# if path.startswith("/") and path.endswith(".png"):
# filename = path[1:-4] # Remove leading '/' and trailing '.png'
# alt_path = f"/opt/backstory/users/{filename}/{filename}.png"
# # Check if the alternative path exists
# if os.path.exists(alt_path) and os.path.isfile(alt_path):
# logger.info(f"Serve static request for alternative path {alt_path}")
# return FileResponse(alt_path)
# If neither path exists, return 404
logger.info(f"File not found for path {full_path} -- returning index.html")
return FileResponse(os.path.join(defines.static_content, "index.html"))
    async def save_context(self, context_id):
        """
        Serialize a context to a file in the context directory.

        Args:
            context_id: UUID string for the context. If it doesn't exist, it is created.

        Returns:
            The context_id used for the file.
        """
        context = await self.upsert_context(context_id)

        # Create the context directory if it doesn't exist
        if not os.path.exists(defines.context_dir):
            os.makedirs(defines.context_dir)

        # Create the full file path
        file_path = os.path.join(defines.context_dir, context_id)

        # Serialize the data to JSON and write to file
        try:
            # Check for non-serializable fields before dumping
            serialization_errors = check_serializable(context)
            if serialization_errors:
                for error in serialization_errors:
                    logger.error(error)
                raise ValueError("Found non-serializable fields in the model")
            # Dump the model prior to opening the file in case there is
            # a validation error, so it doesn't clobber the current
            # context session
            json_data = context.model_dump_json(by_alias=True)
            with open(file_path, "w") as f:
                f.write(json_data)
        except ValidationError as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            for error in e.errors():
                logger.error(f"Field: {error['loc'][0]}, Error: {error['msg']}")
        except PydanticSerializationError as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            logger.error(f"Serialization error: {str(e)}")
            # Inspect the model to identify problematic fields
            for field_name, value in context.__dict__.items():
                if isinstance(value, np.ndarray):
                    logger.error(f"Field '{field_name}' contains non-serializable type: {type(value)}")
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error(e)
        return context_id

    async def load_context(self, context_id: str) -> Context | None:
        """
        Load a context from a file in the context directory.

        Args:
            context_id: UUID string for the context.

        Returns:
            The Context object with the specified ID, or None if it does not
            exist or fails validation.
        """
        file_path = os.path.join(defines.context_dir, context_id)

        # Check if the file exists
        if not os.path.exists(file_path):
            return None

        # Read and deserialize the data
        with open(file_path, "r") as f:
            content = f.read()
        logger.info(
            f"Loading context from {file_path}, content length: {len(content)}"
        )
        json_data = {}
        try:
            # Try parsing as JSON first to ensure valid JSON
            json_data = json.loads(content)
            logger.info("JSON parsed successfully, attempting model validation")
            context = Context.model_validate(json_data)
            username = context.username
            if not User.exists(username):
                raise ValueError(f"Attempt to load context {context.id} with invalid user {username}")
            matching_user = next((user for user in self.users if user.username == username), None)
            if matching_user:
                user = matching_user
            else:
                user = User(username=username, llm=self.llm)
                await user.initialize(prometheus_collector=self.prometheus_collector)
                self.users.append(user)
            context.user = user
            # Now set context on agents manually
            agent_types = [agent.agent_type for agent in context.agents]
            if len(agent_types) != len(set(agent_types)):
                raise ValueError(
                    "Context cannot contain multiple agents of the same agent_type"
                )
            for agent in context.agents:
                agent.set_context(context)
            self.contexts[context_id] = context
            logger.info(f"Successfully loaded context {context_id}")
        except ValidationError as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            for error in e.errors():
                logger.error(f"Field: {error['loc'][0]}, Error: {error['msg']}")
            return None
        except Exception as e:
            logger.error(f"Error validating context: {str(e)}")
            logger.error(traceback.format_exc())
            for key in json_data:
                logger.info(f"{key} = {type(json_data[key])} {str(json_data[key])[:60] if json_data[key] else 'None'}")
            logger.info("*" * 50)
            return None
        return self.contexts[context_id]

    async def load_or_create_context(self, context_id: str) -> Context:
        """
        Load a context from a file in the context directory or create a new one if it doesn't exist.

        Args:
            context_id: UUID string for the context.

        Returns:
            A Context object with the specified ID and default settings.
        """
        context = await self.load_context(context_id)
        if context:
            return context
        logger.info(f"Context not found. Creating new instance of context {context_id}.")
        self.contexts[context_id] = await self.create_context(username=defines.default_username, context_id=context_id)
        return self.contexts[context_id]

    async def create_context(self, username: str, context_id=None) -> Context:
        """
        Create a new context with a unique ID and default settings.

        Args:
            username: Name of an existing user to bind to the context.
            context_id: Optional UUID string for the context. If not provided, a new UUID is generated.

        Returns:
            A Context object with the specified ID and default settings.
        """
        if not User.exists(username):
            raise ValueError(f"{username} does not exist.")
        # Reuse an already-initialized user if one exists for this username
        matching_user = next((user for user in self.users if user.username == username), None)
        if matching_user:
            user = matching_user
            logger.info(f"Found matching user: {user.username}")
        else:
            user = User(username=username, llm=self.llm)
            await user.initialize(prometheus_collector=self.prometheus_collector)
            logger.info(f"Created new instance of user: {user.username}")
            self.users.append(user)
        logger.info(f"Creating context {context_id or 'new'} for user: {user.username}")
        try:
            if context_id:
                context = Context(
                    id=context_id,
                    user=user,
                    rags=[rag.model_copy() for rag in user.rags],
                    tools=Tools.all_tools(),
                )
            else:
                context = Context(
                    user=user,
                    rags=[rag.model_copy() for rag in user.rags],
                    tools=Tools.all_tools(),
                )
        except ValidationError as e:
            logger.error(e)
            logger.error(traceback.format_exc())
            for error in e.errors():
                logger.error(f"Field: {error['loc'][0]}, Error: {error['msg']}")
            exit(1)
        logger.info(f"New context created with ID: {context.id}")
        if os.path.exists(defines.resume_doc):
            with open(defines.resume_doc, "r") as f:
                context.user_resume = f.read()
        context.get_or_create_agent(agent_type="chat")
        # system_prompt=system_message)
        # context.add_agent(Resume(system_prompt = system_generate_resume))
        # context.add_agent(JobDescription(system_prompt = system_job_description))
        # context.add_agent(FactCheck(system_prompt = system_fact_check))
        logger.info(f"{context.id} created and added to contexts.")
        self.contexts[context.id] = context
        await self.save_context(context.id)
        return context

    async def upsert_context(self, context_id=None) -> Context:
        """
        Upsert a context based on the provided context_id.

        Args:
            context_id: UUID string for the context. If it doesn't exist, a new context is created.

        Returns:
            A Context object with the specified ID and default settings.
        """
        if not context_id:
            logger.warning("No context ID provided. Creating a new context.")
            return await self.create_context(username=defines.default_username)
        if context_id in self.contexts:
            return self.contexts[context_id]
        logger.info(f"Context {context_id} is not yet loaded.")
        return await self.load_or_create_context(context_id=context_id)

    @REQUEST_TIME.time()
    async def generate_response(
        self, context: Context, agent: Agent, prompt: str, tunables: Tunables | None
    ) -> AsyncGenerator[Message, None]:
        agent_type = agent.get_agent_type()
        logger.info(f"generate_response: type - {agent_type}")
        # Merge tunables to take agent defaults and override with user supplied settings
        agent_tunables = agent.tunables.model_dump() if agent.tunables else {}
        user_tunables = tunables.model_dump() if tunables else {}
        merged_tunables = {**agent_tunables, **user_tunables}
        message = Message(prompt=prompt, tunables=Tunables(**merged_tunables))
        async for message in agent.prepare_message(message):
            # logger.info(f"{agent_type}.prepare_message: {message.status} - {message.response}")
            if message.status == "error":
                yield message
                return
            if message.status != "done":
                yield message
        async for message in agent.process_message(self.llm, self.model, message):
            if message.status != "done":
                yield message
            if message.status == "error":
                return
        logger.info(
            f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}"
        )
        message.status = "done"
        yield message
        return

    def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
        try:
            if self.ssl_enabled:
                logger.info(f"Starting web server at https://{host}:{port}")
                uvicorn.run(
                    self.app,
                    host=host,
                    port=port,
                    log_config=None,
                    ssl_keyfile=defines.key_path,
                    ssl_certfile=defines.cert_path,
                )
            else:
                logger.info(f"Starting web server at http://{host}:{port}")
                uvicorn.run(self.app, host=host, port=port, log_config=None)
        except KeyboardInterrupt:
            # Stop all file watchers first, then wait for their threads to exit
            for user in self.users:
                if user.observer:
                    user.observer.stop()
            for user in self.users:
                if user.observer:
                    user.observer.join()

# %%
# Main function to run everything
def main():
    # Parse command-line arguments
    args = parse_args()

    # Setup logging based on the provided level
    logger.setLevel(args.level.upper())
    warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.*")
    warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")

    llm = ollama.Client(host=args.ollama_server)  # type: ignore
    web_server = WebServer(llm, args.ollama_model)
    # Note: uvicorn has no Flask-style use_reloader option; extra kwargs are ignored by run()
    web_server.run(host=args.web_host, port=args.web_port)

if __name__ == "__main__":
    main()
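# Example invocation (arguments are illustrative):
#   python server.py --ollama-server http://localhost:11434 \
#       --ollama-model llama3 --level DEBUG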