backstory/src/server.py

1316 lines
55 KiB
Python

# Budget, in seconds, for a single LLM inference. The streaming agent endpoint
# reports the time remaining against this value and ends the stream with an
# error message once it is exceeded.
LLM_TIMEOUT = 600
from utils import logger
from pydantic import BaseModel, Field, ValidationError
from pydantic_core import PydanticSerializationError
from typing import List
from typing import AsyncGenerator, Dict, Optional
# %%
# Imports [standard]
# Standard library modules (no try-except needed)
import argparse
import asyncio
import json
import logging
import os
import re
import uuid
import subprocess
import re
import math
import warnings
# from typing import Any
import inspect
import time
import traceback
def try_import(module_name, pip_name=None):
    """Probe whether a module is importable and print install advice if not.

    Args:
        module_name: Name handed to ``__import__``.
        pip_name: Package name to suggest for ``pip install``; when omitted
            (or empty) ``module_name`` is suggested instead.

    An ``ImportError`` is reported on stdout and otherwise ignored; nothing is
    returned either way.
    """
    try:
        __import__(module_name)
        return
    except ImportError:
        hints = (
            f"Module '{module_name}' not found. Install it using:",
            f" pip install {pip_name or module_name}",
        )
    for hint in hints:
        print(hint)
# Third-party modules with import checks: each missing module gets a
# "pip install" hint printed before the real imports below are attempted.
for _required_module in (
    "ollama",
    "requests",
    "fastapi",
    "uvicorn",
    "numpy",
    "umap",
    "sklearn",
    "prometheus_client",
    "prometheus_fastapi_instrumentator",
):
    try_import(_required_module)
import ollama
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import numpy as np
from utils import redis_manager
import redis.asyncio as redis
# Prometheus
from prometheus_client import Summary
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_client import CollectorRegistry, Counter
from utils import (
rag as Rag,
RagEntry,
tools as Tools,
Context,
Conversation,
Message,
Agent,
Metrics,
Tunables,
defines,
User,
check_serializable,
logger,
)
class Query(BaseModel):
    """Request body parsed by the agent POST endpoint."""

    # Text of the request handed to the agent.
    prompt: str

    # Per-request tunables; a fresh ``Tunables`` is built when omitted.
    tunables: Tunables = Field(default_factory=Tunables)

    # Extra keyword arguments passed to ``get_or_create_agent``; excluded
    # from serialized output.
    agent_options: Dict[str, str | int | float | Dict] = Field(
        default={},
        exclude=True,
    )
# Prometheus summary metric registered under "request_processing_seconds".
REQUEST_TIME = Summary("request_processing_seconds", "Time spent processing request")
def get_installed_ram():
    """Report the machine's total RAM as read from ``/proc/meminfo``.

    Returns:
        A string such as ``"16GB"`` (the ``MemTotal`` figure divided by
        1000**2 and floored), ``None`` when no ``MemTotal`` line is found,
        or an ``"Error retrieving RAM: ..."`` string when reading fails.
    """
    try:
        with open("/proc/meminfo", "r") as meminfo_file:
            total = re.search(r"MemTotal:\s+(\d+)", meminfo_file.read())
        if total is None:
            return None
        kilobytes = int(total.group(1))
        return f"{math.floor(kilobytes / 1000**2)}GB"  # Convert KB to GB
    except Exception as e:
        return f"Error retrieving RAM: {e}"
def get_graphics_cards():
    """Describe the GPUs reported by the ``ze-monitor`` utility.

    ``ze-monitor`` is run once to count devices (one per output line), then
    once per device with ``--device N --info`` (N starting at 1) to read its
    details.

    Returns:
        A list of dicts with keys ``discrete``, ``name`` and ``memory``, or
        an ``"Error retrieving GPU info: ..."`` string if anything fails.
    """
    try:
        listing = subprocess.run(
            ["ze-monitor"], capture_output=True, text=True, check=True
        )
        # One line of (stripped) output per device.
        device_count = len(listing.stdout.strip().splitlines())

        gpus = []
        for device_number in range(1, device_count + 1):
            details = subprocess.run(
                ["ze-monitor", "--device", f"{device_number}", "--info"],
                capture_output=True,
                text=True,
                check=True,
            )
            # A device is assumed discrete until the tool says otherwise.
            gpu = {"discrete": True, "name": None, "memory": None}
            for line in details.stdout.strip().splitlines():
                if name := re.match(r"^Device: [^(]*\((.*)\)", line):
                    gpu["name"] = name.group(1)
                elif memory := re.match(r"^\s*Memory: (.*)", line):
                    gpu["memory"] = memory.group(1)
                elif re.match(r"^.*Is integrated with host: Yes.*", line):
                    gpu["discrete"] = False
            gpus.append(gpu)
        return gpus
    except Exception as e:
        return f"Error retrieving GPU info: {e}"
def get_cpu_info():
    """Summarize the CPU model and processor count from ``/proc/cpuinfo``.

    Returns:
        ``"<model name> with <N> cores"`` where N is the number of
        ``processor`` entries, ``None`` when either piece is missing, or an
        ``"Error retrieving CPU info: ..."`` string when reading fails.
    """
    try:
        with open("/proc/cpuinfo", "r") as cpuinfo_file:
            text = cpuinfo_file.read()
        model = re.search(r"model name\s+:\s+(.+)", text)
        processors = re.findall(r"processor\s+:\s+\d+", text)
        if not (model and processors):
            return None
        return f"{model.group(1)} with {len(processors)} cores"
    except Exception as e:
        return f"Error retrieving CPU info: {e}"
def system_info(model):
    """Collect a description of the host and model configuration.

    Args:
        model: Value reported under ``"LLM Model"``.

    Returns:
        A dict with RAM, GPU and CPU details from the probe helpers plus the
        embedding model and context length taken from ``defines``.
    """
    report = {}
    report["System RAM"] = get_installed_ram()
    report["Graphics Card"] = get_graphics_cards()
    report["CPU"] = get_cpu_info()
    report["LLM Model"] = model
    report["Embedding Model"] = defines.embedding_model
    report["Context length"] = defines.max_context
    return report
# %%
# Defaults
# Default for --ollama-server.
OLLAMA_API_URL = defines.ollama_api_url
# Default for --ollama-model and for WebServer's ``model`` argument.
MODEL_NAME = defines.model
# Default for --level. NOTE(review): lower-case, while the option's choices
# are upper-case; argparse does not check defaults against choices.
LOG_LEVEL = "info"
# Default for --web-host.
WEB_HOST = "0.0.0.0"
# Default for --web-port.
WEB_PORT = 8911
# Not referenced in the visible portion of this file.
DEFAULT_HISTORY_LENGTH = 5
# %%
# Cmd line overrides
def parse_args():
    """Parse the server's command-line options.

    Returns:
        An ``argparse.Namespace`` with ``ollama_server``, ``ollama_model``,
        ``web_host``, ``web_port`` (always an ``int``) and ``level``.
    """
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument(
        "--ollama-server",
        type=str,
        default=OLLAMA_API_URL,
        help=f"Ollama API endpoint. default={OLLAMA_API_URL}",
    )
    parser.add_argument(
        "--ollama-model",
        type=str,
        default=MODEL_NAME,
        help=f"LLM model to use. default={MODEL_NAME}",
    )
    parser.add_argument(
        "--web-host",
        type=str,
        default=WEB_HOST,
        help=f"Host to bind the web server to. default={WEB_HOST}",
    )
    parser.add_argument(
        "--web-port",
        # Parsed as int so a value given on the command line has the same
        # type as the integer default.
        type=int,
        default=WEB_PORT,
        help=f"Port to bind the web server to. default={WEB_PORT}",
    )
    parser.add_argument(
        "--level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        # NOTE(review): LOG_LEVEL is lower-case and therefore not one of the
        # choices; argparse only validates values given on the command line.
        default=LOG_LEVEL,
        help=f"Set the logging level. default={LOG_LEVEL}",
    )
    return parser.parse_args()
# %%
# %%
# %%
def is_valid_uuid(value: str) -> bool:
    """Return True only for a canonical, lower-case version-4 UUID string.

    The value is parsed with ``version=4`` (which forces the version and
    variant bits) and must round-trip to exactly the same text, so other UUID
    versions, upper-case or brace/URN forms are rejected.

    Args:
        value: Candidate string. Non-string values yield ``False``.
    """
    try:
        uuid_obj = uuid.UUID(value, version=4)
    except (ValueError, TypeError, AttributeError):
        # AttributeError: uuid.UUID calls str methods on its argument, so a
        # non-string such as an int would otherwise escape as an exception.
        return False
    return str(uuid_obj) == value
# %%
class WebServer:
@asynccontextmanager
async def lifespan(self, app: FastAPI):
    """Manage resources for the lifetime of the FastAPI application.

    On startup the Redis manager is connected. On shutdown every user's
    file observer (if any) is stopped and joined, then Redis is
    disconnected. Cleanup runs in ``finally`` blocks so that Redis is
    disconnected even if the application exits through an exception or an
    observer fails to stop.

    Args:
        app: The application being started (unused).
    """
    # Startup
    await redis_manager.connect()
    try:
        yield
    finally:
        # Shutdown
        try:
            for user in self.users:
                if user.observer:
                    user.observer.stop()
                    user.observer.join()
                    logger.info("File watcher stopped")
        finally:
            await redis_manager.disconnect()
def __init__(self, llm, model=MODEL_NAME):
    """Build the FastAPI app, metrics, CORS policy and routes.

    Args:
        llm: LLM client stored on the server and handed to ``User`` objects.
        model: Model name stored on the server; defaults to ``MODEL_NAME``.
    """
    self.app = FastAPI(lifespan=self.lifespan)

    # Prometheus: one registry shared by Metrics and the instrumentator.
    registry = CollectorRegistry()
    self.prometheus_collector = registry
    self.metrics = Metrics(prometheus_collector=registry)
    # The instrumentator is kept on the instance so it stays alive; it
    # instruments the app and serves the registry at /metrics.
    self.instrumentator = Instrumentator(registry=registry)
    self.instrumentator.instrument(self.app)
    self.instrumentator.expose(self.app, endpoint="/metrics")

    self.llm = llm
    self.model = model
    self.processing = False
    self.users = []
    self.contexts = {}

    # SSL counts as enabled only when both the key and certificate exist;
    # the allowed CORS origins use the matching URL scheme.
    self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(
        defines.cert_path
    )
    allow_origins = (
        [
            "https://battle-linux.ketrenos.com:3000",
            "https://backstory-beta.ketrenos.com",
        ]
        if self.ssl_enabled
        else [
            "http://battle-linux.ketrenos.com:3000",
            "http://backstory-beta.ketrenos.com",
        ]
    )
    logger.info(f"Allowed origins: {allow_origins}")

    self.app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    self.setup_routes()
async def get_redis(self) -> redis.Redis:
    """FastAPI dependency returning the client held by ``redis_manager``."""
    client = redis_manager.get_client()
    return client
def sanitize_input(self, input: str):
    """Validate a value that will be used as a single path component.

    Args:
        input: Candidate string (for example a username or file name).

    Raises:
        ValueError: If the value is empty, contains any character other
            than ASCII letters, digits, ``.``, ``_`` or ``-``, or contains
            two consecutive dots anywhere.
    """
    # fullmatch (rather than ^...$) so a trailing newline is not accepted.
    if not re.fullmatch(r"[a-zA-Z0-9._-]+", input):
        raise ValueError("Invalid input format.")
    # ".." must be rejected wherever it occurs, not only at the start.
    if ".." in input:
        raise ValueError("Invalid input format.")
def setup_routes(self):
# @self.app.get("/")
# async def root():
# context = self.create_context(username=defines.default_username)
# logger.info(f"Redirecting non-context to {context.id}")
# return RedirectResponse(url=f"/{context.id}", status_code=307)
# # return JSONResponse({"redirect": f"/{context.id}"})
@self.app.get("/api/umap/entry/{doc_id}/{context_id}")
async def get_umap(doc_id: str, context_id: str, request: Request):
    """Return the prepared metadata of one document in the UMAP collection.

    Responds 400 when no context is available, 404 when there is no
    collection, no metadata, or no entry with ``doc_id``, and 500 on any
    unexpected error.
    """
    logger.info(f"{request.method} {request.url.path}")
    try:
        context = await self.upsert_context(context_id)
        if not context:
            return JSONResponse(
                {"error": f"Invalid context: {context_id}"}, status_code=400
            )

        user = context.user
        # NOTE(review): read from the user here, whereas the PUT handler
        # reads user.file_watcher.umap_collection -- confirm both exist.
        collection = user.umap_collection
        if not collection:
            return JSONResponse(
                {"error": "No UMAP collection found"}, status_code=404
            )
        if not collection.get("metadatas", None):
            return JSONResponse(f"Document id {doc_id} not found.", 404)

        # Position of the first entry whose id equals doc_id, if any.
        known_ids = collection.get("ids", [])
        position = next(
            (i for i, candidate in enumerate(known_ids) if candidate == doc_id),
            None,
        )
        if position is None:
            return JSONResponse(f"Document id {doc_id} not found.", 404)

        # Work on a copy so preparing the metadata cannot alter the stored one.
        metadata = collection.get("metadatas", [])[position].copy()
        content = user.file_watcher.prepare_metadata(metadata)
        return JSONResponse(content)
    except Exception as e:
        logger.error(f"get_umap error: {str(e)}")
        logger.error(traceback.format_exc())
        return JSONResponse({"error": str(e)}, 500)
@self.app.put("/api/umap/{context_id}")
async def put_umap(context_id: str, request: Request):
    """Return the UMAP collection together with its 2D or 3D embedding.

    The JSON body may carry ``dimensions``; 2 (the default) selects the 2D
    embedding and any other value the 3D one. Responds 400 when no context
    is available, 404 when the collection or embedding is missing, and 500
    on any unexpected error.
    """
    logger.info(f"{request.method} {request.url.path}")
    try:
        context = await self.upsert_context(context_id)
        if not context:
            return JSONResponse(
                {"error": f"Invalid context: {context_id}"}, status_code=400
            )

        user = context.user
        data = await request.json()
        dimensions = data.get("dimensions", 2)

        collection = user.file_watcher.umap_collection
        if not collection:
            return JSONResponse(
                {"error": "No UMAP collection found"}, status_code=404
            )

        want_2d = dimensions == 2
        logger.info("Returning 2D UMAP" if want_2d else "Returning 3D UMAP")
        umap_embedding = (
            user.file_watcher.umap_embedding_2d
            if want_2d
            else user.file_watcher.umap_embedding_3d
        )
        if len(umap_embedding) == 0:
            return JSONResponse(
                {"error": "No UMAP embedding found"}, status_code=404
            )

        result = {
            field: collection.get(field, [])
            for field in ("ids", "metadatas", "documents")
        }
        result["embeddings"] = umap_embedding.tolist()
        result["size"] = user.file_watcher.collection.count()
        return JSONResponse(result)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"put_umap error: {str(e)}")
        return JSONResponse({"error": str(e)}, 500)
@self.app.put("/api/similarity/{context_id}")
async def put_similarity(context_id: str, request: Request):
    """Find documents similar to a query and project the query with UMAP.

    The JSON body may carry ``query``, ``threshold`` and ``results``
    (top-k); missing values fall back to the defaults in ``defines``. An
    unreadable body is treated as an empty query.

    Responds 400 when no context is available or no query was given, 404
    when the search returns nothing, and 500 on any unexpected error.
    """
    logger.info(f"{request.method} {request.url.path}")
    context = await self.upsert_context(context_id)
    if not context:
        return JSONResponse(
            {"error": f"Invalid context: {context_id}"}, status_code=400
        )
    user = context.user

    try:
        data = await request.json()
        query = data.get("query", "")
        threshold = data.get("threshold", defines.default_rag_threshold)
        results = data.get("results", defines.default_rag_top_k)
    except Exception as e:
        # Best effort: fall back to defaults, which yields the 400 below.
        # `except Exception` (not a bare except) lets task cancellation
        # propagate.
        logger.warning(f"Unable to parse similarity request: {e}")
        query = ""
        threshold = defines.default_rag_threshold
        results = defines.default_rag_top_k

    if not query:
        return JSONResponse(
            {"error": "No query provided for similarity search"},
            status_code=400,
        )

    try:
        chroma_results = user.file_watcher.find_similar(
            query=query, top_k=results, threshold=threshold
        )
        if not chroma_results:
            return JSONResponse({"error": "No results found"}, status_code=404)

        # Flatten so the embedding is a 1-D vector before transforming.
        chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()
        logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")

        umap_2d = user.file_watcher.umap_model_2d.transform([chroma_embedding])[
            0
        ].tolist()
        logger.info(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")

        umap_3d = user.file_watcher.umap_model_3d.transform([chroma_embedding])[
            0
        ].tolist()
        logger.info(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")

        return JSONResponse({
            "distances": chroma_results["distances"],
            "ids": chroma_results["ids"],
            "metadatas": chroma_results["metadatas"],
            "query": query,
            "umap_embedding_2d": umap_2d,
            "umap_embedding_3d": umap_3d,
            "size": user.file_watcher.collection.count(),
        })
    except Exception as e:
        logger.error(e)
        # Use the module logger (not the root `logging` module) so the
        # traceback goes to the same place as every other handler's.
        logger.error(traceback.format_exc())
        return JSONResponse({"error": str(e)}, 500)
@self.app.put("/api/reset/{context_id}/{agent_type}")
async def put_reset(context_id: str, agent_type: str, request: Request):
logger.info(f"{request.method} {request.url.path}")
if not is_valid_uuid(context_id):
logger.warning(f"Invalid context_id: {context_id}")
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
context = await self.upsert_context(context_id)
agent = context.get_agent(agent_type)
if not agent:
response = { "history": [] }
return JSONResponse(response)
data = await request.json()
try:
response = {}
for reset_operation in data["reset"]:
match reset_operation:
case "system_prompt":
logger.info(f"Resetting {reset_operation}")
case "rags":
logger.info(f"Resetting {reset_operation}")
context.rags = [ r.model_copy() for r in context.user.rags]
response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
case "tools":
logger.info(f"Resetting {reset_operation}")
context.tools = Tools.all_tools()
response["tools"] = Tools.llm_tools(context.tools)
case "history":
reset_map = {
"job_description": (
"job_description",
"resume",
"fact_check",
),
"resume": ("job_description", "resume", "fact_check"),
"fact_check": (
"job_description",
"resume",
"fact_check",
),
"chat": ("chat",),
}
resets = reset_map.get(agent_type, ())
for mode in resets:
tmp = context.get_agent(mode)
if not tmp:
logger.info(
f"Agent {mode} not found for {context_id}"
)
continue
logger.info(f"Resetting {reset_operation} for {mode}")
if mode != "chat":
logger.info(f"Removing agent {tmp.agent_type} from {context_id}")
context.remove_agent(tmp)
tmp.conversation.reset()
response["history"] = []
response["context_used"] = agent.context_tokens
if not response:
return JSONResponse(
{"error": "Usage: { reset: rags|tools|history|system_prompt}"}
)
else:
await self.save_context(context_id)
return JSONResponse(response)
except Exception as e:
logger.error(f"Error in reset: {e}")
logger.error(traceback.format_exc())
return JSONResponse(
{"error": "Usage: { reset: rags|tools|history|system_prompt}"}
)
@self.app.put("/api/tunables/{context_id}")
async def put_tunables(context_id: str, request: Request):
    """Update one tunable (tools, rags or system prompt) of a context.

    Only the first key of the JSON body is acted on, because every branch
    returns. Responds 404 when the chat agent is missing or the key is not
    a known tunable, and 500 on any unexpected error; "can not be empty"
    errors are sent without an explicit status code.
    """
    from typing import Any

    logger.info(f"{request.method} {request.url.path}")
    try:
        context = await self.upsert_context(context_id)
        data = await request.json()
        agent = context.get_agent("chat")
        if not agent:
            logger.info("chat agent does not exist on this context!")
            return JSONResponse(
                {"error": f"chat is not recognized", "context": context.id},
                status_code=404,
            )

        for k in data.keys():
            if k == "tools":
                # { "tools": [{ "tool": tool.name, "enabled": tool.enabled }] }
                tools: List[Dict[str, Any]] = data[k]
                if not tools:
                    return JSONResponse(
                        {
                            "status": "error",
                            "message": "Tools can not be empty.",
                        }
                    )
                for tool in tools:
                    # NOTE(review): matched via `.tool.function.name` here
                    # but dumped via `.function` below -- confirm which
                    # attribute path the tool entries really have.
                    for context_tool in context.tools:
                        if context_tool.tool.function.name == tool["name"]:
                            context_tool.enabled = tool.get("enabled", True)
                await self.save_context(context_id)
                tool_states = [
                    {
                        **t.function.model_dump(mode='json'),
                        "enabled": t.enabled,
                    }
                    for t in context.tools
                ]
                return JSONResponse({"tools": tool_states})

            elif k == "rags":
                # { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
                rag_configs: list[dict[str, Any]] = data[k]
                if not rag_configs:
                    return JSONResponse(
                        {
                            "status": "error",
                            "message": "RAGs can not be empty.",
                        }
                    )
                for config in rag_configs:
                    for context_rag in context.rags:
                        if context_rag.name == config["name"]:
                            context_rag.enabled = config["enabled"]
                await self.save_context(context_id)
                rag_states = [r.model_dump(mode="json") for r in context.rags]
                return JSONResponse({"rags": rag_states})

            elif k == "system_prompt":
                system_prompt = data[k].strip()
                if not system_prompt:
                    return JSONResponse(
                        {
                            "status": "error",
                            "message": "System prompt can not be empty.",
                        }
                    )
                agent.system_prompt = system_prompt
                await self.save_context(context_id)
                return JSONResponse({"system_prompt": system_prompt})

            else:
                return JSONResponse(
                    {"error": f"Unrecognized tunable {k}"}, status_code=404
                )
    except Exception as e:
        logger.error(f"Error in put_tunables: {e}")
        return JSONResponse({"error": str(e)}, status_code=500)
@self.app.get("/api/user/{context_id}")
async def get_user(context_id: str, request: Request):
    """Return the public profile fields of the user bound to a context."""
    logger.info(f"{request.method} {request.url.path}")
    context = await self.upsert_context(context_id)
    user = context.user

    # Attributes copied verbatim, in response order.
    profile_fields = (
        "username",
        "first_name",
        "last_name",
        "full_name",
        "description",
        "contact_info",
        "title",
        "phone",
        "location",
        "email",
        "rag_content_size",
        "has_profile",
    )
    user_data = {field: getattr(user, field) for field in profile_fields}
    user_data["questions"] = [
        q.model_dump(mode='json') for q in user.user_questions
    ]
    return JSONResponse(user_data)
@self.app.get("/api/tunables/{context_id}")
async def get_tunables(context_id: str, request: Request):
    """Return the chat agent's system prompt plus the context's rags and tools.

    Responds 404 when the context has no chat agent.
    """
    logger.info(f"{request.method} {request.url.path}")
    context = await self.upsert_context(context_id)
    agent = context.get_agent("chat")
    if not agent:
        logger.info("chat agent does not exist on this context!")
        return JSONResponse(
            {"error": f"chat is not recognized", "context": context.id},
            status_code=404,
        )

    payload = {"system_prompt": agent.system_prompt}
    payload["rags"] = [r.model_dump(mode="json") for r in context.rags]
    payload["tools"] = [
        {**t.function.model_dump(mode='json'), "enabled": t.enabled}
        for t in context.tools
    ]
    return JSONResponse(payload)
@self.app.get("/api/system-info/{context_id}")
async def get_system_info(context_id: str, request: Request):
    """Return host and model details; ``context_id`` is accepted but unused."""
    logger.info(f"{request.method} {request.url.path}")
    details = system_info(self.model)
    return JSONResponse(details)
@self.app.post("/api/{agent_type}/{context_id}")
async def post_agent_endpoint(
    agent_type: str, context_id: str, request: Request
):
    """Run a prompt through an agent and stream the outcome as JSON lines.

    The body is validated as a ``Query``. Responds 404 when the context or
    agent cannot be obtained and 400 when the body cannot be parsed;
    otherwise returns a ``StreamingResponse`` emitting one JSON document
    per line. If the client disconnects, generation continues (nothing more
    is sent) and the context is saved when the generator finishes.
    """
    logger.info(f"{request.method} {request.url.path}")
    try:
        context = await self.upsert_context(context_id)
    except Exception as e:
        error = {
            "error": f"Unable to create or access context {context_id}: {e}"
        }
        logger.info(error)
        return JSONResponse(error, status_code=404)

    try:
        data = await request.json()
        query: Query = Query(**data)
    except Exception as e:
        error = {"error": f"Attempt to parse request: {e}"}
        logger.info(error)
        return JSONResponse(error, status_code=400)

    try:
        agent = context.get_or_create_agent(agent_type, **query.agent_options)
    except Exception as e:
        error = {
            "error": f"Attempt to create agent type: {agent_type} failed: {e}"
        }
        logger.error(traceback.format_exc())
        logger.info(error)
        return JSONResponse(error, status_code=404)

    try:
        async def flush_generator():
            logger.info(f"{agent.agent_type} - {inspect.stack()[0].function}")
            try:
                start_time = time.perf_counter()
                async for message in self.generate_response(
                    context=context,
                    agent=agent,
                    prompt=query.prompt,
                    tunables=query.tunables,
                ):
                    if message.status != "done" and message.status != "partial":
                        if message.status == "streaming":
                            result = {
                                "status": "streaming",
                                "chunk": message.chunk,
                                "remaining_time": LLM_TIMEOUT
                                - (time.perf_counter() - start_time),
                            }
                        else:
                            # Any other status restarts the timeout window.
                            start_time = time.perf_counter()
                            result = {
                                "status": message.status,
                                "response": message.response,
                                "remaining_time": LLM_TIMEOUT,
                            }
                    else:
                        logger.info(f"Providing {message.status} response.")
                        try:
                            result = message.model_dump(
                                by_alias=True, mode="json"
                            )
                        except Exception as e:
                            # The message cannot be serialized: report it
                            # and end the stream.
                            result = {"status": "error", "response": str(e)}
                            yield json.dumps(result) + "\n"
                            break

                    # Serialize once so the byte counter reflects what is
                    # actually sent (len() of the dict counted its keys).
                    line = json.dumps(result) + "\n"
                    message.network_packets += 1
                    message.network_bytes += len(line.encode("utf-8"))

                    try:
                        disconnected = await request.is_disconnected()
                    except Exception as e:
                        logger.warning(f"Disconnection check failed: {e}")
                        disconnected = True
                    if disconnected:
                        logger.info("Disconnect detected. Continuing generation to store in cache.")
                    else:
                        yield line

                    current_time = time.perf_counter()
                    if current_time - start_time > LLM_TIMEOUT:
                        message.status = "error"
                        message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query."
                        message.partial_response = message.response
                        logger.info(message.response + " Ending session")
                        result = message.model_dump(by_alias=True, mode="json")
                        if not disconnected:
                            yield json.dumps(result) + "\n"

                    if message.status == "error":
                        context.processing = False
                        break

                    # Allow the event loop to process the write
                    await asyncio.sleep(0)
            except Exception as e:
                context.processing = False
                logger.error(traceback.format_exc())
                logger.error(f"Error in generate_response: {e}")
                yield json.dumps({"status": "error", "response": str(e)}) + "\n"
            finally:
                # Save context on completion or error
                await self.save_context(context_id)
                logger.info("Flush generator completed normally.")

        # Return StreamingResponse with appropriate headers
        return StreamingResponse(
            flush_generator(),
            media_type="application/json",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # Prevents Nginx buffering if you're using it
            },
        )
    except Exception as e:
        context.processing = False
        logger.error(f"Error in post_chat_endpoint: {e}")
        return JSONResponse({"error": str(e)}, status_code=500)
@self.app.post("/api/create-session")
async def create_session(request: Request):
    """Create a context for the default user and return its id."""
    logger.info(f"{request.method} {request.url.path}")
    new_context = await self.create_context(username=defines.default_username)
    payload = {"id": new_context.id}
    return JSONResponse(payload)
@self.app.get("/api/join-session/{context_id}")
async def join_session(context_id: str, request: Request):
    """Return the id of an existing context, or 404 if it cannot be loaded."""
    logger.info(f"{request.method} {request.url.path}")
    context = await self.load_context(context_id=context_id)
    if context:
        return JSONResponse({"id": context.id})
    return JSONResponse({"error": f"{context_id} does not exist."}, 404)
@self.app.get("/api/u/{context_id}")
async def get_users(context_id: str, request: Request):
    """List all users in sanitized form for a caller with a valid context.

    Responds 404 when the context cannot be loaded and 500 on any error.
    """
    logger.info(f"{request.method} {request.url.path}")
    try:
        context = await self.load_context(context_id)
        if not context:
            return JSONResponse(
                {"error": f"Context {context_id} not found."}, status_code=404
            )
        users = []
        for u in User.get_users():
            users.append(User.sanitize(u))
        return JSONResponse(users)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"get_users error: {str(e)}")
        return JSONResponse({ "error": "Unable to parse users"}, 500)
@self.app.get("/api/u/{username}/images/{image_id}/{context_id}")
async def get_user_image(username: str, image_id: str, context_id: str, request: Request):
    """Serve one image from a user's ``images`` directory.

    All three path parameters are validated with ``sanitize_input`` before
    any path is built. Responds 400 for invalid input, 404 when the user,
    context or image does not exist, and 500 on any other error.
    """
    logger.info(f"{request.method} {request.url.path}")
    try:
        self.sanitize_input(context_id)
        self.sanitize_input(username)
        self.sanitize_input(image_id)
        if not User.exists(username):
            return JSONResponse({"error": f"User {username} not found."}, status_code=404)
        context = await self.load_context(context_id)
        if not context:
            return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
        image_path = os.path.join(defines.user_dir, username, "images", image_id)
        if not os.path.exists(image_path):
            # Was a plain string, so the placeholders were sent literally.
            return JSONResponse(
                {"error": f"User {username} does not have image {image_id}"},
                status_code=404,
            )
        return FileResponse(image_path)
    except ValueError:
        return JSONResponse({ "error": f"Invalid input: {image_id}" }, 400)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(e)
        return JSONResponse({ "error": f"Unable to get image {username} {image_id}"}, 500)
@self.app.get("/api/u/{username}/profile/{context_id}")
async def get_user_profile(username: str, context_id: str, request: Request):
    """Serve a user's ``profile.png``.

    ``username`` and ``context_id`` are validated with ``sanitize_input``
    before any path is built, matching the image endpoint. Responds 400
    for invalid input, 404 when the user, context or picture does not
    exist, and 500 on any other error.
    """
    logger.info(f"{request.method} {request.url.path}")
    try:
        self.sanitize_input(context_id)
        self.sanitize_input(username)
        if not User.exists(username):
            return JSONResponse({"error": f"User {username} not found."}, status_code=404)
        context = await self.load_context(context_id)
        if not context:
            return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
        profile_path = os.path.join(defines.user_dir, username, "profile.png")
        if not os.path.exists(profile_path):
            # Was a plain string, so "{username}" was sent literally.
            return JSONResponse(
                {"error": f"User {username} does not have a profile picture"},
                status_code=404,
            )
        return FileResponse(profile_path)
    except ValueError:
        return JSONResponse({"error": f"Invalid input: {username}"}, 400)
    except Exception as e:
        # Log instead of silently discarding the cause.
        logger.error(traceback.format_exc())
        logger.error(e)
        return JSONResponse({ "error": f"Unable to load user {username}"}, 500)
@self.app.post("/api/u/{username}/{context_id}")
async def post_user(username: str, context_id: str, request: Request):
    """Switch a context to another user and return that user's profile.

    An already-loaded user is reused; otherwise a ``User`` is created,
    initialized and remembered on the server. The conversations of the
    chat, job_description, resume and fact_check agents are reset (all but
    chat are also removed) before the context is saved.

    Responds 404 when the user or context does not exist and 500 on any
    other error.
    """
    logger.info(f"{request.method} {request.url.path}")
    try:
        if not User.exists(username):
            return JSONResponse({"error": f"User {username} not found."}, status_code=404)
        context = await self.load_context(context_id)
        if not context:
            return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)

        user = next((u for u in self.users if u.username == username), None)
        if user is None:
            user = User(username=username, llm=self.llm)
            await user.initialize(prometheus_collector=self.prometheus_collector)
            self.users.append(user)

        for mode in ("chat", "job_description", "resume", "fact_check"):
            tmp = context.get_agent(mode)
            if not tmp:
                continue
            logger.info(f"User change: Resetting history for {mode}")
            if mode != "chat":
                context.remove_agent(tmp)
            tmp.conversation.reset()

        context.user = user
        user_data = {
            "username": user.username,
            "first_name": user.first_name,
            "last_name": user.last_name,
            "full_name": user.full_name,
            "title": user.title,
            "phone": user.phone,
            "location": user.location,
            "email": user.email,
            "description": user.description,
            "contact_info": user.contact_info,
            "rag_content_size": user.rag_content_size,
            "has_profile": user.has_profile,
            "questions": [q.model_dump(mode='json') for q in user.user_questions],
        }
        await self.save_context(context_id)
        return JSONResponse(user_data)
    except Exception as e:
        # Log the cause; previously the failure was discarded silently.
        logger.error(traceback.format_exc())
        logger.error(f"post_user error: {e}")
        return JSONResponse({ "error": f"Unable to load user {username}"}, 500)
@self.app.post("/api/context/u/{username}")
async def create_user_context(username: str, request: Request):
    """Create a new context for an existing user and return its id.

    Responds 404 when the user does not exist; any error during creation
    is logged and also reported as a 404 "not found".
    """
    logger.info(f"{request.method} {request.url.path}")
    not_found = {"error": f"User {username} not found."}
    try:
        if User.exists(username):
            context = await self.create_context(username=username)
            logger.info(f"Generated new context {context.id} for {username}")
            return JSONResponse({"id": context.id})
        return JSONResponse(not_found, status_code=404)
    except Exception as e:
        logger.error(f"create_user_context error: {str(e)}")
        logger.error(traceback.format_exc())
        return JSONResponse(not_found, status_code=404)
@self.app.post("/api/context")
async def create_context(request: Request):
    """Create a new context for the default user and return its id.

    Delegates to the ``create_user_context`` route handler defined earlier
    in this method. The FastAPI application object has no
    ``create_user_context`` attribute, and the coroutine must be awaited
    for a response to be produced.
    """
    logger.info(f"{request.method} {request.url.path}")
    return await create_user_context(defines.default_username, request)
@self.app.get("/api/history/{context_id}/{agent_type}")
async def get_history(context_id: str, agent_type: str, request: Request):
    """Return an agent's conversation, or an empty message list.

    The conversation object itself is returned for the framework to
    serialize. Any error is logged and reported with status 404.
    """
    logger.info(f"{request.method} {request.url.path}")
    try:
        context = await self.upsert_context(context_id)
        agent = context.get_agent(agent_type)
        if agent:
            logger.info(
                f"History for {agent_type} contains {len(agent.conversation)} entries."
            )
            return agent.conversation
        logger.info(
            f"Agent {agent_type} not found. Returning empty history."
        )
        return JSONResponse({"messages": []})
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"get_history error: {str(e)}")
        return JSONResponse({"error": str(e)}, status_code=404)
@self.app.get("/api/tools/{context_id}")
async def get_tools(context_id: str, request: Request):
    """Return the context's tool list."""
    logger.info(f"{request.method} {request.url.path}")
    context = await self.upsert_context(context_id)
    # NOTE(review): the list is passed to JSONResponse as-is, whereas the
    # tunables endpoints dump each entry first -- confirm it serializes.
    tools = context.tools
    return JSONResponse(tools)
@self.app.put("/api/tools/{context_id}")
async def put_tools(context_id: str, request: Request):
    """Enable or disable one tool of a context.

    The JSON body must contain ``tool`` (the function name) and
    ``enabled``. Responds 400 for an invalid context id, 404 when no tool
    has that name, and 405 when the request cannot be processed.
    """
    logger.info(f"{request.method} {request.url.path}")
    if not is_valid_uuid(context_id):
        logger.warning(f"Invalid context_id: {context_id}")
        return JSONResponse({"error": "Invalid context_id"}, status_code=400)
    context = await self.upsert_context(context_id)
    try:
        data = await request.json()
        modify = data["tool"]
        enabled = data["enabled"]
        for tool in context.tools:
            if modify == tool.function.name:
                tool.enabled = enabled
                await self.save_context(context_id)
                return JSONResponse(context.tools)
        return JSONResponse(
            {"status": f"{modify} not found in tools."}, status_code=404
        )
    except Exception as e:
        # `except Exception` (not a bare except) lets task cancellation
        # propagate; the cause is logged instead of being discarded.
        logger.error(f"put_tools error: {e}")
        logger.error(traceback.format_exc())
        return JSONResponse({"status": "error"}, 405)
@self.app.get("/api/context-status/{context_id}/{agent_type}")
async def get_context_status(context_id, agent_type: str, request: Request):
    """Report how much of the context window an agent has used.

    Responds 400 for an invalid context id. A missing agent is reported as
    zero tokens used.
    """
    logger.info(f"{request.method} {request.url.path}")
    if not is_valid_uuid(context_id):
        logger.warning(f"Invalid context_id: {context_id}")
        return JSONResponse({"error": "Invalid context_id"}, status_code=400)

    context = await self.upsert_context(context_id)
    agent = context.get_agent(agent_type)
    used = agent.context_tokens if agent else 0
    status = {"context_used": used, "max_context": defines.max_context}
    return JSONResponse(status)
@self.app.get("/api/health")
async def health_check(redis_client: redis.Redis = Depends(self.get_redis)):
    """Ping Redis; report healthy on success, raise a 503 otherwise."""
    try:
        await redis_client.ping()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Redis connection failed: {e}")
    else:
        return {"status": "healthy", "redis": "connected"}
@self.app.get("/api/redis/stats")
async def redis_stats(redis_client: redis.Redis = Depends(self.get_redis)):
    """Return selected fields of the Redis INFO output; 503 on failure.

    Fields absent from the server's reply are reported as ``None``.
    """
    reported_fields = (
        "connected_clients",
        "used_memory_human",
        "total_commands_processed",
        "keyspace_hits",
        "keyspace_misses",
        "uptime_in_seconds",
    )
    try:
        info = await redis_client.info()
        return {field: info.get(field) for field in reported_fields}
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")
@self.app.get("/{path:path}")
async def serve_static(path: str, request: Request):
    """Serve a file from the static content directory, else ``index.html``.

    ``path`` comes straight from the request URL, so the resolved location
    is required to stay inside ``defines.static_content``. Without that
    check an absolute path (which makes ``os.path.join`` discard the base
    directory) or ``..`` segments could read arbitrary files. Anything
    that is not an existing file inside the directory gets ``index.html``.
    """
    static_root = os.path.abspath(defines.static_content)
    full_path = os.path.abspath(os.path.join(static_root, path))
    try:
        inside_root = os.path.commonpath([static_root, full_path]) == static_root
    except ValueError:
        # Paths that cannot be compared (e.g. different drives).
        inside_root = False

    if inside_root and os.path.isfile(full_path):
        logger.info(f"Serve static request for {full_path}")
        return FileResponse(full_path)

    logger.info(f"File not found for path {full_path} -- returning index.html")
    return FileResponse(os.path.join(defines.static_content, "index.html"))
async def save_context(self, context_id):
    """Write the context identified by ``context_id`` to disk as JSON.

    The context is obtained through ``self.upsert_context`` and stored in
    ``defines.context_dir`` under a file named after the id. Failures to
    serialize or write are logged and swallowed; they never propagate.

    Args:
        context_id: UUID string for the context.

    Returns:
        The ``context_id`` that was passed in.
    """
    context = await self.upsert_context(context_id)

    # exist_ok avoids failing when the directory appears between a check
    # and the creation.
    os.makedirs(defines.context_dir, exist_ok=True)
    file_path = os.path.join(defines.context_dir, context_id)

    try:
        # Check for non-serializable fields before dumping
        serialization_errors = check_serializable(context)
        if serialization_errors:
            for error in serialization_errors:
                logger.error(error)
            raise ValueError("Found non-serializable fields in the model")

        # Dump the model before opening the file: if dumping fails, the
        # previously saved session must not be truncated.
        json_data = context.model_dump_json(by_alias=True)
        with open(file_path, "w") as f:
            f.write(json_data)
    except ValidationError as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        for error in e.errors():
            field = error["loc"][0] if error["loc"] else "<root>"
            logger.error(f"Field: {field}, Error: {error['msg']}")
    except PydanticSerializationError as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        logger.error(f"Serialization error: {str(e)}")
        # Point out numpy arrays held directly on the model.
        for field_name, value in context.__dict__.items():
            if isinstance(value, np.ndarray):
                logger.error(f"Field '{field_name}' contains non-serializable type: {type(value)}")
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(e)
    return context_id
async def load_context(self, context_id: str) -> Context | None:
    """
    Load a context from its file in the context directory and cache it.

    The context's user is looked up in self.users (and created/initialized if
    missing), attached to the context, and every agent in the context is
    bound to it.

    Args:
        context_id: UUID string for the context.
    Returns:
        The loaded Context, or None if the file does not exist or its
        contents cannot be parsed and validated.
    """
    file_path = os.path.join(defines.context_dir, context_id)

    # Check if the file exists
    if not os.path.exists(file_path):
        return None

    # Read and deserialize the data
    with open(file_path, "r") as f:
        content = f.read()
    logger.info(
        f"Loading context from {file_path}, content length: {len(content)}"
    )

    json_data = {}
    try:
        # Try parsing as JSON first to ensure valid JSON
        json_data = json.loads(content)
        logger.info("JSON parsed successfully, attempting model validation")

        context = Context.model_validate(json_data)
        username = context.username
        if not User.exists(username):
            raise ValueError(f"Attempt to load context {context.id} with invalid user {username}")

        matching_user = next((user for user in self.users if user.username == username), None)
        if matching_user:
            user = matching_user
        else:
            user = User(username=username, llm=self.llm)
            await user.initialize(prometheus_collector=self.prometheus_collector)
            self.users.append(user)
        context.user = user

        # Now set context on agents manually
        agent_types = [agent.agent_type for agent in context.agents]
        if len(agent_types) != len(set(agent_types)):
            raise ValueError(
                "Context cannot contain multiple agents of the same agent_type"
            )
        for agent in context.agents:
            agent.set_context(context)

        self.contexts[context_id] = context
        logger.info(f"Successfully loaded context {context_id}")
    except ValidationError as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        for error in e.errors():
            # 'loc' may be empty (e.g. model-level errors); don't crash while reporting
            field = error["loc"][0] if error["loc"] else "<model>"
            print(f"Field: {field}, Error: {error['msg']}")
        # Nothing was cached for this ID; report failure instead of falling
        # through to a KeyError (or a stale cached context).
        return None
    except Exception as e:
        logger.error(f"Error validating context: {str(e)}")
        logger.error(traceback.format_exc())
        # Dump a short preview of each top-level key to help diagnose the failure.
        # The parsed root is only iterable this way when it is a JSON object.
        if isinstance(json_data, dict):
            for key, value in json_data.items():
                preview = str(value)[:60] if value else "None"
                logger.info(f"{key} = {type(value)} {preview}")
        logger.info("*" * 50)
        return None

    return context
async def load_or_create_context(self, context_id: str) -> Context:
    """
    Return the stored context for an ID, creating a default one when loading yields nothing.

    Args:
        context_id: UUID string for the context.
    Returns:
        The Context returned by load_context, or a newly created Context for
        the default username registered under context_id.
    """
    loaded = await self.load_context(context_id)
    if not loaded:
        logger.info(f"Context not found. Creating new instance of context {context_id}.")
        created = await self.create_context(
            username=defines.default_username,
            context_id=context_id,
        )
        self.contexts[context_id] = created
        return self.contexts[context_id]

    return loaded
async def create_context(self, username: str, context_id=None) -> Context:
    """
    Create a new context for a user, register it in self.contexts and save it.

    Args:
        username: Name of an existing user who will own the context.
        context_id: Optional UUID string for the context. If not provided, the
            Context model supplies its own ID.
    Returns:
        The newly created Context.
    Raises:
        ValueError: If username does not exist.
        SystemExit: If the Context model fails validation.
    """
    if not User.exists(username):
        raise ValueError(f"{username} does not exist.")

    # Reuse an already-initialized User instance when there is one
    user = next((u for u in self.users if u.username == username), None)
    if user:
        logger.info(f"Found matching user: {user.username}")
    else:
        user = User(username=username, llm=self.llm)
        await user.initialize(prometheus_collector=self.prometheus_collector)
        logger.info(f"Created new instance of user: {user.username}")
        self.users.append(user)

    logger.info(f"Creating context {context_id or 'new'} for user: {user.username}")
    try:
        context_kwargs = {
            "user": user,
            "rags": [rag.model_copy() for rag in user.rags],
            "tools": Tools.all_tools(),
        }
        # Only pass an ID when one was requested so the model default applies otherwise
        if context_id:
            context_kwargs["id"] = context_id
        context = Context(**context_kwargs)
    except ValidationError as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        for error in e.errors():
            # 'loc' may be empty (e.g. model-level errors); don't crash while reporting
            field = error["loc"][0] if error["loc"] else "<model>"
            print(f"Field: {field}, Error: {error['msg']}")
        # Same effect as exit(1) without depending on the site-provided builtin
        raise SystemExit(1) from e

    logger.info(f"New context created with ID: {context.id}")

    if os.path.exists(defines.resume_doc):
        # Context manager so the file handle is closed deterministically
        with open(defines.resume_doc, "r") as resume_file:
            context.user_resume = resume_file.read()

    context.get_or_create_agent(agent_type="chat")
    # system_prompt=system_message)
    # context.add_agent(Resume(system_prompt = system_generate_resume))
    # context.add_agent(JobDescription(system_prompt = system_job_description))
    # context.add_agent(FactCheck(system_prompt = system_fact_check))
    logger.info(f"{context.id} created and added to contexts.")
    self.contexts[context.id] = context
    await self.save_context(context.id)
    return context
async def upsert_context(self, context_id=None) -> Context:
    """
    Return the context for an ID, loading or creating it as needed.

    Args:
        context_id: UUID string for the context. When falsy, a new context is
            created for the default username.
    Returns:
        The cached Context when the ID is already in self.contexts, otherwise
        the result of load_or_create_context (or create_context when no ID
        was given).
    """
    if context_id:
        if context_id not in self.contexts:
            logger.info(f"Context {context_id} is not yet loaded.")
            return await self.load_or_create_context(context_id=context_id)
        return self.contexts[context_id]

    logger.warning("No context ID provided. Creating a new context.")
    return await self.create_context(username=defines.default_username)
@REQUEST_TIME.time()
async def generate_response(
    self, context: Context, agent: Agent, prompt: str, tunables: Tunables | None
) -> AsyncGenerator[Message, None]:
    """
    Run a prompt through an agent and stream the resulting messages.

    The agent's prepare_message stage runs first, followed by
    process_message. Messages whose status is not "done" are yielded as they
    arrive; an "error" message ends the stream right after being yielded.
    Otherwise the last message is marked "done" and yielded at the end.

    Args:
        context: Context the request belongs to (not read by this method).
        agent: Agent that prepares and processes the message.
        prompt: User prompt to send to the agent.
        tunables: Optional per-request settings layered over agent.tunables.
    Yields:
        Message objects produced by the agent, ending with the final one.
    """
    # NOTE(review): REQUEST_TIME.time() wraps an async generator function here;
    # it presumably times only the creation of the generator, not the
    # streaming itself -- confirm against prometheus_client.
    agent_type = agent.get_agent_type()
    logger.info(f"generate_response: type - {agent_type}")

    # Merge tunables to take agent defaults and override with user supplied settings.
    # exclude_unset keeps only the fields the caller actually provided; a plain
    # model_dump() emits every field and would always mask the agent defaults.
    agent_tunables = agent.tunables.model_dump() if agent.tunables else {}
    user_tunables = tunables.model_dump(exclude_unset=True) if tunables else {}
    merged_tunables = {**agent_tunables, **user_tunables}

    message = Message(prompt=prompt, tunables=Tunables(**merged_tunables))

    # The loop variable deliberately rebinds 'message': the last prepared
    # message is what gets handed to process_message below.
    async for message in agent.prepare_message(message):
        # logger.info(f"{agent_type}.prepare_message: {value.status} - {value.response}")
        if message.status == "error":
            yield message
            return
        if message.status != "done":
            yield message

    async for message in agent.process_message(self.llm, self.model, message):
        if message.status != "done":
            yield message
        if message.status == "error":
            return

    # Log only the tail of long responses
    response = message.response
    tail = f"...{response[-20:]}" if len(response) > 20 else response
    logger.info(f"{agent_type}.process_message: {message.status} {tail}")

    message.status = "done"
    yield message
def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
    """
    Start the uvicorn web server for self.app and block until it exits.

    TLS key/cert paths from defines are passed to uvicorn when
    self.ssl_enabled is set. On KeyboardInterrupt every user's observer (if
    any) is stopped and then joined.

    Args:
        host: Interface address to bind.
        port: TCP port to listen on.
        **kwargs: Accepted but not used by this method.
    """
    try:
        server_options = {"host": host, "port": port, "log_config": None}
        if self.ssl_enabled:
            logger.info(f"Starting web server at https://{host}:{port}")
            server_options["ssl_keyfile"] = defines.key_path
            server_options["ssl_certfile"] = defines.cert_path
        else:
            logger.info(f"Starting web server at http://{host}:{port}")
        uvicorn.run(self.app, **server_options)
    except KeyboardInterrupt:
        observers = [user.observer for user in self.users if user.observer]
        # Signal all observers to stop first, then wait for each to finish
        for observer in observers:
            observer.stop()
        for observer in observers:
            observer.join()
# %%
# Main function to run everything
def main():
    """Parse CLI arguments, configure logging and warnings, then run the web server."""
    args = parse_args()

    # Logging level comes from the command line
    logger.setLevel(args.level.upper())

    # Silence noisy warnings from sklearn and umap
    ignored_warnings = (
        (FutureWarning, "sklearn.*"),
        (UserWarning, "umap.*"),
    )
    for category, module_pattern in ignored_warnings:
        warnings.filterwarnings("ignore", category=category, module=module_pattern)

    ollama_client = ollama.Client(host=args.ollama_server)
    server = WebServer(ollama_client, args.ollama_model)
    server.run(host=args.web_host, port=args.web_port, use_reloader=False)
# NOTE(review): this call runs whenever the module is imported, not only when
# it is executed as a script; consider an `if __name__ == "__main__":` guard --
# confirm first that no launcher relies on import-time startup.
main()