1285 lines
54 KiB
Python
1285 lines
54 KiB
Python
# Per-inference time budget in seconds. The streaming loop in
# post_agent_endpoint reports the budget left as "remaining_time" and ends the
# session with an error message once a single inference exceeds it.
LLM_TIMEOUT = 600
|
|
|
|
from utils import logger
|
|
from pydantic import BaseModel, Field, ValidationError # type: ignore
|
|
from pydantic_core import PydanticSerializationError # type: ignore
|
|
from typing import List
|
|
|
|
from typing import AsyncGenerator, Dict, Optional
|
|
|
|
# %%
|
|
# Imports [standard]
|
|
# Standard library modules (no try-except needed)
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import uuid
|
|
import subprocess
|
|
import re
|
|
import math
|
|
import warnings
|
|
# from typing import Any
|
|
import inspect
|
|
import time
|
|
import traceback
|
|
|
|
|
|
def try_import(module_name, pip_name=None):
    """Check that *module_name* is importable, printing an install hint if not.

    Args:
        module_name: Importable module name (e.g. ``"sklearn"``).
        pip_name: Distribution name to suggest to ``pip install`` when it
            differs from the module name; defaults to *module_name*.

    Returns:
        bool: True if the import succeeded, False otherwise. Callers that
        ignore the return value keep working; new callers can react to a
        missing dependency without a second import attempt.
    """
    import importlib
    import sys

    try:
        importlib.import_module(module_name)
    except ImportError:
        # Diagnostics go to stderr so they do not mix with normal program output.
        print(f"Module '{module_name}' not found. Install it using:", file=sys.stderr)
        print(f" pip install {pip_name or module_name}", file=sys.stderr)
        return False
    return True
|
|
|
|
|
|
# Third-party modules with import checks
|
|
# Probe each third-party dependency in turn so every missing one gets an
# install hint, before the unconditional imports below fail on the first.
for _module_name in (
    "ollama",
    "requests",
    "fastapi",
    "uvicorn",
    "numpy",
    "umap",
    "sklearn",
    "prometheus_client",
    "prometheus_fastapi_instrumentator",
):
    try_import(_module_name)
del _module_name
|
|
|
|
import ollama
|
|
from contextlib import asynccontextmanager
|
|
from fastapi import FastAPI, Request, HTTPException # type: ignore
|
|
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore
|
|
from fastapi.middleware.cors import CORSMiddleware # type: ignore
|
|
import uvicorn # type: ignore
|
|
import numpy as np # type: ignore
|
|
|
|
# Prometheus
|
|
from prometheus_client import Summary # type: ignore
|
|
from prometheus_fastapi_instrumentator import Instrumentator # type: ignore
|
|
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
|
|
|
from utils import (
|
|
rag as Rag,
|
|
RagEntry,
|
|
tools as Tools,
|
|
Context,
|
|
Conversation,
|
|
Message,
|
|
Agent,
|
|
Metrics,
|
|
Tunables,
|
|
defines,
|
|
User,
|
|
check_serializable,
|
|
logger,
|
|
)
|
|
|
|
class Query(BaseModel):
    """Request body accepted by ``POST /api/{agent_type}/{context_id}``.

    The endpoint builds it with ``Query(**data)`` from the JSON body and
    forwards ``agent_options`` as keyword arguments to
    ``context.get_or_create_agent``.
    """

    # Prompt text passed to generate_response.
    prompt: str
    # Per-request tunables; a fresh Tunables() is created when omitted.
    tunables: Tunables = Field(default_factory=Tunables)
    # Extra agent-construction options. Excluded from serialization.
    # default_factory makes the per-instance empty dict explicit instead of
    # relying on pydantic copying a literal {} default.
    agent_options: Dict[str, str | int | float | Dict] = Field(
        default_factory=dict, exclude=True
    )
|
|
|
|
# Prometheus Summary created without an explicit registry, so it lands in
# prometheus_client's default global registry (not WebServer's
# CollectorRegistry). NOTE(review): not referenced in the code visible here —
# confirm it is still used.
REQUEST_TIME = Summary("request_processing_seconds", "Time spent processing request")
|
|
|
|
|
|
def get_installed_ram(meminfo_path="/proc/meminfo"):
    """Return total system RAM as a string such as ``"32GB"``.

    Reads the ``MemTotal`` line of a Linux ``/proc/meminfo`` file (value in
    kB) and reports ``floor(kB / 1000**2)`` followed by ``"GB"``. This treats
    a kB as 1000 bytes, so the figure is approximate.

    Args:
        meminfo_path: File to read; defaults to ``/proc/meminfo``.

    Returns:
        str: ``"<N>GB"`` on success. Otherwise an ``"Error retrieving RAM: ..."``
        message, including when the file has no ``MemTotal`` line.
    """
    try:
        with open(meminfo_path, "r", encoding="utf-8") as f:
            meminfo = f.read()
    except (OSError, ValueError) as e:
        # Best effort: the result is shown in /api/system-info, so report
        # rather than raise.
        return f"Error retrieving RAM: {e}"

    match = re.search(r"MemTotal:\s+(\d+)", meminfo)
    if not match:
        return "Error retrieving RAM: MemTotal not found"
    # Integer floor division gives the same floor(kB / 10**6) without a float.
    return f"{int(match.group(1)) // 1000**2}GB"
|
|
|
|
|
|
def _parse_ze_monitor_info(lines):
    """Parse one device's ``ze-monitor --device <n> --info`` output lines.

    Returns:
        dict: ``{"discrete": bool, "name": str | None, "memory": str | None}``.
        ``discrete`` stays True unless an "Is integrated with host: Yes" line
        is seen; ``name`` and ``memory`` stay None when their lines are absent.
    """
    gpu = {"discrete": True, "name": None, "memory": None}
    for line in lines:
        if match := re.match(r"^Device: [^(]*\((.*)\)", line):
            gpu["name"] = match.group(1)
        elif match := re.match(r"^\s*Memory: (.*)", line):
            gpu["memory"] = match.group(1)
        elif re.match(r"^.*Is integrated with host: Yes.*", line):
            gpu["discrete"] = False
    return gpu


def get_graphics_cards(command=("ze-monitor",), timeout=30):
    """Enumerate GPUs using the ``ze-monitor`` utility.

    Running *command* with no extra arguments is expected to print one line
    per device. Each device is then queried with
    ``<command> --device <n> --info`` (1-based ``n``), and its output is
    parsed for name, memory and whether it is integrated with the host.

    Args:
        command: Program (plus any leading arguments) to invoke; defaults to
            ``ze-monitor`` on ``PATH``.
        timeout: Seconds allowed per invocation so a hung tool cannot block
            the caller indefinitely; ``None`` disables the limit.

    Returns:
        list[dict] | str: One ``{"discrete", "name", "memory"}`` dict per
        device. Returns an ``"Error retrieving GPU info: ..."`` string if any
        invocation fails (missing binary, non-zero exit, timeout, ...).
    """

    def run(*args):
        # check=True turns a non-zero exit into CalledProcessError, reported below.
        completed = subprocess.run(
            [*command, *args],
            capture_output=True,
            text=True,
            check=True,
            timeout=timeout,
        )
        return completed.stdout.strip()

    try:
        device_count = len(run().splitlines())
        return [
            _parse_ze_monitor_info(run("--device", str(n), "--info").splitlines())
            for n in range(1, device_count + 1)
        ]
    except Exception as e:
        return f"Error retrieving GPU info: {e}"
|
|
|
|
|
|
def get_cpu_info(cpuinfo_path="/proc/cpuinfo"):
    """Return a CPU summary such as ``"Intel(R) Xeon(R) ... with 8 cores"``.

    The model is taken from the first ``model name`` line of a Linux
    ``/proc/cpuinfo``-format file. The count is the number of ``processor``
    entries, i.e. logical CPUs (hardware threads), although the string says
    "cores".

    Args:
        cpuinfo_path: File to read; defaults to ``/proc/cpuinfo``.

    Returns:
        str: The summary on success. Otherwise an
        ``"Error retrieving CPU info: ..."`` message, including when the file
        lacks a ``model name`` or any ``processor`` entry (as on some
        non-x86 kernels).
    """
    try:
        with open(cpuinfo_path, "r", encoding="utf-8") as f:
            cpuinfo = f.read()
    except (OSError, ValueError) as e:
        return f"Error retrieving CPU info: {e}"

    model_match = re.search(r"model name\s+:\s+(.+)", cpuinfo)
    processors = re.findall(r"processor\s+:\s+\d+", cpuinfo)
    if model_match and processors:
        return f"{model_match.group(1)} with {len(processors)} cores"
    return "Error retrieving CPU info: model name or processor entries not found"
|
|
|
|
def system_info(model):
    """Build the host/model summary served by ``/api/system-info``.

    Args:
        model: Name of the LLM model in use; echoed back unchanged.

    Returns:
        dict: Human-readable labels mapped to RAM, GPU, CPU, LLM model,
        embedding model and context-length values.
    """
    info = {}
    info["System RAM"] = get_installed_ram()
    info["Graphics Card"] = get_graphics_cards()
    info["CPU"] = get_cpu_info()
    info["LLM Model"] = model
    info["Embedding Model"] = defines.embedding_model
    info["Context length"] = defines.max_context
    return info
|
|
|
|
# %%
|
|
# Defaults
|
|
# LLM backend: endpoint and model, taken from the shared `defines` settings.
OLLAMA_API_URL = defines.ollama_api_url
MODEL_NAME = defines.model

# Logging and web-server defaults (overridable through parse_args).
LOG_LEVEL = "info"
WEB_HOST = "0.0.0.0"
WEB_PORT = 8911

# Default conversation history length.
DEFAULT_HISTORY_LENGTH = 5
|
|
|
|
|
|
# %%
|
|
# Cmd line overrides
|
|
def parse_args(argv=None):
    """Parse command-line overrides for the server defaults.

    Args:
        argv: Argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace: ``ollama_server``, ``ollama_model``, ``web_host``,
        ``web_port`` (int) and ``level``.
    """
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument(
        "--ollama-server",
        type=str,
        default=OLLAMA_API_URL,
        help=f"Ollama API endpoint. default={OLLAMA_API_URL}",
    )
    parser.add_argument(
        "--ollama-model",
        type=str,
        default=MODEL_NAME,
        help=f"LLM model to use. default={MODEL_NAME}",
    )
    parser.add_argument(
        "--web-host",
        type=str,
        default=WEB_HOST,
        help=f"Host address the web server binds to. default={WEB_HOST}",
    )
    # type=int so an explicit --web-port has the same type as the int default;
    # previously it arrived as a str.
    parser.add_argument(
        "--web-port",
        type=int,
        default=WEB_PORT,
        help=f"Port the web server listens on. default={WEB_PORT}",
    )
    # NOTE(review): LOG_LEVEL ("info") is not one of these upper-case choices;
    # argparse does not validate defaults against choices, so it passes as-is.
    parser.add_argument(
        "--level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default=LOG_LEVEL,
        help=f"Set the logging level. default={LOG_LEVEL}",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
# %%
|
|
|
|
|
|
# %%
|
|
|
|
|
|
# %%
|
|
def is_valid_uuid(value: str) -> bool:
    """Return True if *value* is a canonical, lower-case version-4 UUID string.

    ``uuid.UUID(value, version=4)`` forces the version/variant bits to v4, so
    the ``str()`` round-trip reproduces *value* only when it already was a
    canonical hyphenated lower-case v4 UUID. Anything else, including
    non-string input such as ``None`` or an int, returns False rather than
    raising.
    """
    if not isinstance(value, str):
        # uuid.UUID(<int>) raises AttributeError, which the except below would miss.
        return False
    try:
        return str(uuid.UUID(value, version=4)) == value
    except (ValueError, TypeError):
        return False
|
|
|
|
|
|
# %%
|
|
class WebServer:
|
|
@asynccontextmanager
async def lifespan(self, app: FastAPI):
    """FastAPI lifespan handler: serve requests, then stop per-user file watchers.

    Nothing runs at startup; the cleanup after ``yield`` runs on shutdown.
    ``finally`` makes it run even if the lifespan exits with an exception,
    and each watcher is stopped independently so one failure does not leave
    the others running.

    Args:
        app: The FastAPI application (unused here; required by the lifespan
            protocol).
    """
    try:
        yield
    finally:
        for user in self.users:
            observer = user.observer
            if not observer:
                continue
            try:
                # presumably a watchdog-style Observer thread (stop, then join) — TODO confirm
                observer.stop()
                observer.join()
                logger.info("File watcher stopped")
            except Exception as e:
                logger.error(f"Failed to stop file watcher: {e}")
|
|
|
|
def __init__(self, llm, model=MODEL_NAME, allow_origins=None):
    """Create the FastAPI app with metrics, a CORS policy and all routes.

    Args:
        llm: LLM client; passed to each ``User`` the routes create.
        model: Model name reported by ``/api/system-info``.
        allow_origins: Optional explicit list of allowed CORS origins. When
            omitted, the built-in battle-linux / backstory-beta origins are
            used, with ``https`` if both ``defines.key_path`` and
            ``defines.cert_path`` exist and ``http`` otherwise.
    """
    self.app = FastAPI(lifespan=self.lifespan)

    # Dedicated registry shared by Metrics and the Instrumentator.
    self.prometheus_collector = CollectorRegistry()
    self.metrics = Metrics(prometheus_collector=self.prometheus_collector)

    # Keep the Instrumentator instance alive for the app's lifetime.
    self.instrumentator = Instrumentator(registry=self.prometheus_collector)
    self.instrumentator.instrument(self.app)
    self.instrumentator.expose(self.app, endpoint="/metrics")

    self.llm = llm
    self.model = model
    self.processing = False

    self.users = []
    self.contexts = {}

    self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(
        defines.cert_path
    )

    if allow_origins is None:
        scheme = "https" if self.ssl_enabled else "http"
        allow_origins = [
            f"{scheme}://battle-linux.ketrenos.com:3000",
            f"{scheme}://backstory-beta.ketrenos.com",
        ]
    else:
        # Copy so later mutation of the caller's list cannot affect CORS.
        allow_origins = list(allow_origins)

    logger.info(f"Allowed origins: {allow_origins}")

    self.app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    self.setup_routes()
|
|
|
|
def sanitize_input(self, input: str):
    """Validate a user-supplied path component; raise ValueError if unsafe.

    Accepts only non-empty strings made of ASCII letters, digits, ``.``,
    ``_`` and ``-``. Also rejects ``"."`` and any string containing
    ``".."``, so the value cannot name the current or a parent directory
    when joined into a filesystem path.

    Raises:
        ValueError: If *input* does not satisfy the rules above.
    """
    # fullmatch rather than match with "$": "$" also matches before a
    # trailing newline, which would let "name\n" through.
    if not re.fullmatch(r"[a-zA-Z0-9._-]+", input):
        raise ValueError("Invalid input format.")
    # Check for ".." anywhere, not only at the start (re.match anchors
    # there). A lone "." would resolve to the directory itself.
    if ".." in input or input == ".":
        raise ValueError("Invalid input format.")
|
|
|
|
|
|
def setup_routes(self):
|
|
# @self.app.get("/")
|
|
# async def root():
|
|
# context = self.create_context(username=defines.default_username)
|
|
# logger.info(f"Redirecting non-context to {context.id}")
|
|
# return RedirectResponse(url=f"/{context.id}", status_code=307)
|
|
# # return JSONResponse({"redirect": f"/{context.id}"})
|
|
|
|
@self.app.get("/api/umap/entry/{doc_id}/{context_id}")
|
|
async def get_umap(doc_id: str, context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
context = await self.upsert_context(context_id)
|
|
if not context:
|
|
return JSONResponse(
|
|
{"error": f"Invalid context: {context_id}"}, status_code=400
|
|
)
|
|
|
|
user = context.user
|
|
collection = user.umap_collection
|
|
if not collection:
|
|
return JSONResponse(
|
|
{"error": "No UMAP collection found"}, status_code=404
|
|
)
|
|
if not collection.get("metadatas", None):
|
|
return JSONResponse(f"Document id {doc_id} not found.", 404)
|
|
|
|
for index, id in enumerate(collection.get("ids", [])):
|
|
if id == doc_id:
|
|
metadata = collection.get("metadatas", [])[index].copy()
|
|
content = user.file_watcher.prepare_metadata(metadata)
|
|
return JSONResponse(content)
|
|
|
|
return JSONResponse(f"Document id {doc_id} not found.", 404)
|
|
except Exception as e:
|
|
logger.error(f"get_umap error: {str(e)}")
|
|
logger.error(traceback.format_exc())
|
|
return JSONResponse({"error": str(e)}, 500)
|
|
|
|
@self.app.put("/api/umap/{context_id}")
|
|
async def put_umap(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
context = await self.upsert_context(context_id)
|
|
if not context:
|
|
return JSONResponse(
|
|
{"error": f"Invalid context: {context_id}"}, status_code=400
|
|
)
|
|
user = context.user
|
|
data = await request.json()
|
|
dimensions = data.get("dimensions", 2)
|
|
collection = user.file_watcher.umap_collection
|
|
if not collection:
|
|
return JSONResponse(
|
|
{"error": "No UMAP collection found"}, status_code=404
|
|
)
|
|
if dimensions == 2:
|
|
logger.info("Returning 2D UMAP")
|
|
umap_embedding = user.file_watcher.umap_embedding_2d
|
|
else:
|
|
logger.info("Returning 3D UMAP")
|
|
umap_embedding = user.file_watcher.umap_embedding_3d
|
|
|
|
if len(umap_embedding) == 0:
|
|
return JSONResponse(
|
|
{"error": "No UMAP embedding found"}, status_code=404
|
|
)
|
|
result = {
|
|
"ids": collection.get("ids", []),
|
|
"metadatas": collection.get("metadatas", []),
|
|
"documents": collection.get("documents", []),
|
|
"embeddings": umap_embedding.tolist(),
|
|
"size": user.file_watcher.collection.count()
|
|
}
|
|
|
|
return JSONResponse(result)
|
|
|
|
except Exception as e:
|
|
logger.error(traceback.format_exc())
|
|
logger.error(f"put_umap error: {str(e)}")
|
|
return JSONResponse({"error": str(e)}, 500)
|
|
|
|
@self.app.put("/api/similarity/{context_id}")
|
|
async def put_similarity(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
context = await self.upsert_context(context_id)
|
|
user = context.user
|
|
try:
|
|
data = await request.json()
|
|
query = data.get("query", "")
|
|
threshold = data.get("threshold", defines.default_rag_threshold)
|
|
results = data.get("results", defines.default_rag_top_k)
|
|
except:
|
|
query = ""
|
|
threshold = defines.default_rag_threshold
|
|
results = defines.default_rag_top_k
|
|
if not query:
|
|
return JSONResponse(
|
|
{"error": "No query provided for similarity search"},
|
|
status_code=400,
|
|
)
|
|
try:
|
|
chroma_results = user.file_watcher.find_similar(
|
|
query=query, top_k=results, threshold=threshold
|
|
)
|
|
if not chroma_results:
|
|
return JSONResponse({"error": "No results found"}, status_code=404)
|
|
|
|
chroma_embedding = np.array(
|
|
chroma_results["query_embedding"]
|
|
).flatten() # Ensure correct shape
|
|
logger.info(f"Chroma embedding shape: {chroma_embedding.shape}")
|
|
|
|
umap_2d = user.file_watcher.umap_model_2d.transform([chroma_embedding])[
|
|
0
|
|
].tolist()
|
|
logger.info(
|
|
f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}"
|
|
) # Debug output
|
|
|
|
umap_3d = user.file_watcher.umap_model_3d.transform([chroma_embedding])[
|
|
0
|
|
].tolist()
|
|
logger.info(
|
|
f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}"
|
|
) # Debug output
|
|
|
|
return JSONResponse({
|
|
"distances": chroma_results["distances"],
|
|
"ids": chroma_results["ids"],
|
|
"metadatas": chroma_results["metadatas"],
|
|
"query": query,
|
|
"umap_embedding_2d": umap_2d,
|
|
"umap_embedding_3d": umap_3d,
|
|
"size": user.file_watcher.collection.count()
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(e)
|
|
logging.error(traceback.format_exc())
|
|
return JSONResponse({"error": str(e)}, 500)
|
|
|
|
@self.app.put("/api/reset/{context_id}/{agent_type}")
|
|
async def put_reset(context_id: str, agent_type: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
if not is_valid_uuid(context_id):
|
|
logger.warning(f"Invalid context_id: {context_id}")
|
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
|
context = await self.upsert_context(context_id)
|
|
agent = context.get_agent(agent_type)
|
|
if not agent:
|
|
response = { "history": [] }
|
|
return JSONResponse(response)
|
|
|
|
data = await request.json()
|
|
try:
|
|
response = {}
|
|
for reset_operation in data["reset"]:
|
|
match reset_operation:
|
|
case "system_prompt":
|
|
logger.info(f"Resetting {reset_operation}")
|
|
case "rags":
|
|
logger.info(f"Resetting {reset_operation}")
|
|
context.rags = [ r.model_copy() for r in context.user.rags]
|
|
response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
|
|
case "tools":
|
|
logger.info(f"Resetting {reset_operation}")
|
|
context.tools = Tools.all_tools()
|
|
response["tools"] = Tools.llm_tools(context.tools)
|
|
case "history":
|
|
reset_map = {
|
|
"job_description": (
|
|
"job_description",
|
|
"resume",
|
|
"fact_check",
|
|
),
|
|
"resume": ("job_description", "resume", "fact_check"),
|
|
"fact_check": (
|
|
"job_description",
|
|
"resume",
|
|
"fact_check",
|
|
),
|
|
"chat": ("chat",),
|
|
}
|
|
resets = reset_map.get(agent_type, ())
|
|
for mode in resets:
|
|
tmp = context.get_agent(mode)
|
|
if not tmp:
|
|
logger.info(
|
|
f"Agent {mode} not found for {context_id}"
|
|
)
|
|
continue
|
|
logger.info(f"Resetting {reset_operation} for {mode}")
|
|
if mode != "chat":
|
|
logger.info(f"Removing agent {tmp.agent_type} from {context_id}")
|
|
context.remove_agent(tmp)
|
|
tmp.conversation.reset()
|
|
response["history"] = []
|
|
response["context_used"] = agent.context_tokens
|
|
|
|
if not response:
|
|
return JSONResponse(
|
|
{"error": "Usage: { reset: rags|tools|history|system_prompt}"}
|
|
)
|
|
else:
|
|
await self.save_context(context_id)
|
|
return JSONResponse(response)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in reset: {e}")
|
|
logger.error(traceback.format_exc())
|
|
return JSONResponse(
|
|
{"error": "Usage: { reset: rags|tools|history|system_prompt}"}
|
|
)
|
|
|
|
@self.app.put("/api/tunables/{context_id}")
|
|
async def put_tunables(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
context = await self.upsert_context(context_id)
|
|
|
|
data = await request.json()
|
|
agent = context.get_agent("chat")
|
|
if not agent:
|
|
logger.info("chat agent does not exist on this context!")
|
|
return JSONResponse(
|
|
{"error": f"chat is not recognized", "context": context.id},
|
|
status_code=404,
|
|
)
|
|
for k in data.keys():
|
|
match k:
|
|
case "tools":
|
|
from typing import Any
|
|
# { "tools": [{ "tool": tool.name, "enabled": tool.enabled }] }
|
|
tools: List[Dict[str, Any]] = data[k]
|
|
if not tools:
|
|
return JSONResponse(
|
|
{
|
|
"status": "error",
|
|
"message": "Tools can not be empty.",
|
|
}
|
|
)
|
|
for tool in tools:
|
|
for context_tool in context.tools:
|
|
if context_tool.tool.function.name == tool["name"]:
|
|
context_tool.enabled = tool.get("enabled", True)
|
|
await self.save_context(context_id)
|
|
return JSONResponse({
|
|
"tools": [{
|
|
**t.function.model_dump(mode='json'),
|
|
"enabled": t.enabled,
|
|
} for t in context.tools]
|
|
})
|
|
|
|
case "rags":
|
|
from typing import Any
|
|
# { "rags": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
|
rag_configs: list[dict[str, Any]] = data[k]
|
|
if not rag_configs:
|
|
return JSONResponse(
|
|
{
|
|
"status": "error",
|
|
"message": "RAGs can not be empty.",
|
|
}
|
|
)
|
|
for config in rag_configs:
|
|
for context_rag in context.rags:
|
|
if context_rag.name == config["name"]:
|
|
context_rag.enabled = config["enabled"]
|
|
await self.save_context(context_id)
|
|
return JSONResponse({"rags": [ r.model_dump(mode="json") for r in context.rags]})
|
|
|
|
case "system_prompt":
|
|
system_prompt = data[k].strip()
|
|
if not system_prompt:
|
|
return JSONResponse(
|
|
{
|
|
"status": "error",
|
|
"message": "System prompt can not be empty.",
|
|
}
|
|
)
|
|
agent.system_prompt = system_prompt
|
|
await self.save_context(context_id)
|
|
return JSONResponse({"system_prompt": system_prompt})
|
|
case _:
|
|
return JSONResponse(
|
|
{"error": f"Unrecognized tunable {k}"}, status_code=404
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Error in put_tunables: {e}")
|
|
return JSONResponse({"error": str(e)}, status_code=500)
|
|
|
|
@self.app.get("/api/user/{context_id}")
|
|
async def get_user(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
context = await self.upsert_context(context_id)
|
|
user = context.user
|
|
user_data = {
|
|
"username": user.username,
|
|
"first_name": user.first_name,
|
|
"last_name": user.last_name,
|
|
"full_name": user.full_name,
|
|
"description": user.description,
|
|
"contact_info": user.contact_info,
|
|
"title": user.title,
|
|
"phone": user.phone,
|
|
"location": user.location,
|
|
"email": user.email,
|
|
"rag_content_size": user.rag_content_size,
|
|
"has_profile": user.has_profile,
|
|
"questions": [ q.model_dump(mode='json') for q in user.user_questions],
|
|
}
|
|
return JSONResponse(user_data)
|
|
|
|
@self.app.get("/api/tunables/{context_id}")
|
|
async def get_tunables(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
context = await self.upsert_context(context_id)
|
|
agent = context.get_agent("chat")
|
|
if not agent:
|
|
logger.info("chat agent does not exist on this context!")
|
|
return JSONResponse(
|
|
{"error": f"chat is not recognized", "context": context.id},
|
|
status_code=404,
|
|
)
|
|
return JSONResponse(
|
|
{
|
|
"system_prompt": agent.system_prompt,
|
|
"rags": [ r.model_dump(mode="json") for r in context.rags ],
|
|
"tools": [{
|
|
**t.function.model_dump(mode='json'),
|
|
"enabled": t.enabled,
|
|
} for t in context.tools],
|
|
}
|
|
)
|
|
|
|
@self.app.get("/api/system-info/{context_id}")
|
|
async def get_system_info(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
return JSONResponse(system_info(self.model))
|
|
|
|
@self.app.post("/api/{agent_type}/{context_id}")
|
|
async def post_agent_endpoint(
|
|
agent_type: str, context_id: str, request: Request
|
|
):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
|
|
try:
|
|
context = await self.upsert_context(context_id)
|
|
except Exception as e:
|
|
error = {
|
|
"error": f"Unable to create or access context {context_id}: {e}"
|
|
}
|
|
logger.info(error)
|
|
return JSONResponse(error, status_code=404)
|
|
|
|
try:
|
|
data = await request.json()
|
|
query: Query = Query(**data)
|
|
except Exception as e:
|
|
error = {"error": f"Attempt to parse request: {e}"}
|
|
logger.info(error)
|
|
return JSONResponse(error, status_code=400)
|
|
|
|
try:
|
|
agent = context.get_or_create_agent(agent_type, **query.agent_options)
|
|
except Exception as e:
|
|
error = {
|
|
"error": f"Attempt to create agent type: {agent_type} failed: {e}"
|
|
}
|
|
logger.error(traceback.format_exc())
|
|
logger.info(error)
|
|
return JSONResponse(error, status_code=404)
|
|
|
|
try:
|
|
|
|
async def flush_generator():
|
|
logger.info(f"{agent.agent_type} - {inspect.stack()[0].function}")
|
|
try:
|
|
start_time = time.perf_counter()
|
|
async for message in self.generate_response(
|
|
context=context,
|
|
agent=agent,
|
|
prompt=query.prompt,
|
|
tunables=query.tunables,
|
|
):
|
|
if message.status != "done" and message.status != "partial":
|
|
if message.status == "streaming":
|
|
result = {
|
|
"status": "streaming",
|
|
"chunk": message.chunk,
|
|
"remaining_time": LLM_TIMEOUT
|
|
- (time.perf_counter() - start_time),
|
|
}
|
|
else:
|
|
start_time = time.perf_counter()
|
|
result = {
|
|
"status": message.status,
|
|
"response": message.response,
|
|
"remaining_time": LLM_TIMEOUT,
|
|
}
|
|
else:
|
|
logger.info(f"Providing {message.status} response.")
|
|
try:
|
|
result = message.model_dump(
|
|
by_alias=True, mode="json"
|
|
)
|
|
except Exception as e:
|
|
result = {"status": "error", "response": str(e)}
|
|
yield json.dumps(result) + "\n"
|
|
break
|
|
|
|
# Convert to JSON and add newline
|
|
message.network_packets += 1
|
|
message.network_bytes += len(result)
|
|
try:
|
|
disconnected = await request.is_disconnected()
|
|
except Exception as e:
|
|
logger.warning(f"Disconnection check failed: {e}")
|
|
disconnected = True
|
|
|
|
if disconnected:
|
|
logger.info("Disconnect detected. Continuing generation to store in cache.")
|
|
disconnected = True
|
|
|
|
if not disconnected:
|
|
yield json.dumps(result) + "\n"
|
|
|
|
current_time = time.perf_counter()
|
|
if current_time - start_time > LLM_TIMEOUT:
|
|
message.status = "error"
|
|
message.response = f"Processing time ({LLM_TIMEOUT}s) exceeded for single LLM inference (likely due to LLM getting stuck.) You will need to retry your query."
|
|
message.partial_response = message.response
|
|
logger.info(message.response + " Ending session")
|
|
result = message.model_dump(by_alias=True, mode="json")
|
|
if not disconnected:
|
|
yield json.dumps(result) + "\n"
|
|
|
|
if message.status == "error":
|
|
context.processing = False
|
|
break
|
|
|
|
# Allow the event loop to process the write
|
|
await asyncio.sleep(0)
|
|
except Exception as e:
|
|
context.processing = False
|
|
logger.error(traceback.format_exc())
|
|
logger.error(f"Error in generate_response: {e}")
|
|
yield json.dumps({"status": "error", "response": str(e)}) + "\n"
|
|
finally:
|
|
# Save context on completion or error
|
|
await self.save_context(context_id)
|
|
logger.info("Flush generator completed normally.")
|
|
|
|
# Return StreamingResponse with appropriate headers
|
|
return StreamingResponse(
|
|
flush_generator(),
|
|
media_type="application/json",
|
|
headers={
|
|
"Cache-Control": "no-cache",
|
|
"Connection": "keep-alive",
|
|
"X-Accel-Buffering": "no", # Prevents Nginx buffering if you're using it
|
|
},
|
|
)
|
|
except Exception as e:
|
|
context.processing = False
|
|
logger.error(f"Error in post_chat_endpoint: {e}")
|
|
return JSONResponse({"error": str(e)}, status_code=500)
|
|
|
|
@self.app.post("/api/create-session")
|
|
async def create_session(request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
context = await self.create_context(username=defines.default_username)
|
|
return JSONResponse({"id": context.id})
|
|
|
|
@self.app.get("/api/join-session/{context_id}")
|
|
async def join_session(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
context = await self.load_context(context_id=context_id)
|
|
if not context:
|
|
return JSONResponse({"error": f"{context_id} does not exist."}, 404)
|
|
return JSONResponse({"id": context.id})
|
|
|
|
@self.app.get("/api/u/{context_id}")
|
|
async def get_users(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
context = await self.load_context(context_id)
|
|
if not context:
|
|
return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
|
|
|
|
users = [User.sanitize(u) for u in User.get_users()]
|
|
|
|
return JSONResponse(users)
|
|
except Exception as e:
|
|
logger.error(traceback.format_exc())
|
|
logger.error(f"get_users error: {str(e)}")
|
|
return JSONResponse({ "error": "Unable to parse users"}, 500)
|
|
|
|
@self.app.get("/api/u/{username}/images/{image_id}/{context_id}")
|
|
async def get_user_image(username: str, image_id: str, context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
self.sanitize_input(context_id)
|
|
self.sanitize_input(username)
|
|
self.sanitize_input(image_id)
|
|
|
|
if not User.exists(username):
|
|
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
|
|
context = await self.load_context(context_id)
|
|
if not context:
|
|
return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
|
|
image_path = os.path.join(defines.user_dir, username, "images", image_id)
|
|
if not os.path.exists(image_path):
|
|
return JSONResponse({ "error": "User {username} does not image {image_id}"}, status_code=404)
|
|
return FileResponse(image_path)
|
|
except ValueError as e:
|
|
return JSONResponse({ "error": f"Invalid input: {image_id}" }, 400)
|
|
except Exception as e:
|
|
logger.error(traceback.format_exc())
|
|
logger.error(e)
|
|
return JSONResponse({ "error": f"Unable to get image {username} {image_id}"}, 500)
|
|
|
|
@self.app.get("/api/u/{username}/profile/{context_id}")
|
|
async def get_user_profile(username: str, context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
if not User.exists(username):
|
|
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
|
|
context = await self.load_context(context_id)
|
|
if not context:
|
|
return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
|
|
profile_path = os.path.join(defines.user_dir, username, f"profile.png")
|
|
if not os.path.exists(profile_path):
|
|
return JSONResponse({ "error": "User {username} does not have a profile picture"}, status_code=404)
|
|
return FileResponse(profile_path)
|
|
except Exception as e:
|
|
return JSONResponse({ "error": f"Unable to load user {username}"}, 500)
|
|
|
|
@self.app.post("/api/u/{username}/{context_id}")
|
|
async def post_user(username: str, context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
if not User.exists(username):
|
|
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
|
|
context = await self.load_context(context_id)
|
|
if not context:
|
|
return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404)
|
|
matching_user = next((user for user in self.users if user.username == username), None)
|
|
if matching_user:
|
|
user = matching_user
|
|
else:
|
|
user = User(username=username, llm=self.llm)
|
|
await user.initialize(prometheus_collector=self.prometheus_collector)
|
|
self.users.append(user)
|
|
reset_map = (
|
|
"chat",
|
|
"job_description",
|
|
"resume",
|
|
"fact_check",
|
|
)
|
|
for mode in reset_map:
|
|
tmp = context.get_agent(mode)
|
|
if not tmp:
|
|
continue
|
|
logger.info(f"User change: Resetting history for {mode}")
|
|
if mode != "chat":
|
|
context.remove_agent(tmp)
|
|
tmp.conversation.reset()
|
|
context.user = user
|
|
user_data = {
|
|
"username": user.username,
|
|
"first_name": user.first_name,
|
|
"last_name": user.last_name,
|
|
"full_name": user.full_name,
|
|
"title": user.title,
|
|
"phone": user.phone,
|
|
"location": user.location,
|
|
"email": user.email,
|
|
"description": user.description,
|
|
"contact_info": user.contact_info,
|
|
"rag_content_size": user.rag_content_size,
|
|
"has_profile": user.has_profile,
|
|
"questions": [ q.model_dump(mode='json') for q in user.user_questions],
|
|
}
|
|
await self.save_context(context_id)
|
|
return JSONResponse(user_data)
|
|
except Exception as e:
|
|
return JSONResponse({ "error": f"Unable to load user {username}"}, 500)
|
|
|
|
@self.app.post("/api/context/u/{username}")
|
|
async def create_user_context(username: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
if not User.exists(username):
|
|
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
|
|
|
|
context = await self.create_context(username=username)
|
|
logger.info(f"Generated new context {context.id} for {username}")
|
|
return JSONResponse({"id": context.id})
|
|
except Exception as e:
|
|
logger.error(f"create_user_context error: {str(e)}")
|
|
logger.error(traceback.format_exc())
|
|
return JSONResponse({"error": f"User {username} not found."}, status_code=404)
|
|
|
|
@self.app.post("/api/context")
|
|
async def create_context(request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
return self.app.create_user_context(defines.default_username, request)
|
|
|
|
@self.app.get("/api/history/{context_id}/{agent_type}")
|
|
async def get_history(context_id: str, agent_type: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
try:
|
|
context = await self.upsert_context(context_id)
|
|
agent = context.get_agent(agent_type)
|
|
if not agent:
|
|
logger.info(
|
|
f"Agent {agent_type} not found. Returning empty history."
|
|
)
|
|
return JSONResponse({"messages": []})
|
|
logger.info(
|
|
f"History for {agent_type} contains {len(agent.conversation)} entries."
|
|
)
|
|
return agent.conversation
|
|
except Exception as e:
|
|
logger.error(traceback.format_exc())
|
|
logger.error(f"get_history error: {str(e)}")
|
|
return JSONResponse({"error": str(e)}, status_code=404)
|
|
|
|
@self.app.get("/api/tools/{context_id}")
|
|
async def get_tools(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
context = await self.upsert_context(context_id)
|
|
return JSONResponse(context.tools)
|
|
|
|
@self.app.put("/api/tools/{context_id}")
|
|
async def put_tools(context_id: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
if not is_valid_uuid(context_id):
|
|
logger.warning(f"Invalid context_id: {context_id}")
|
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
|
context = await self.upsert_context(context_id)
|
|
try:
|
|
data = await request.json()
|
|
modify = data["tool"]
|
|
enabled = data["enabled"]
|
|
for tool in context.tools:
|
|
if modify == tool.function.name:
|
|
tool.enabled = enabled
|
|
await self.save_context(context_id)
|
|
return JSONResponse(context.tools)
|
|
return JSONResponse(
|
|
{"status": f"{modify} not found in tools."}, status_code=404
|
|
)
|
|
except:
|
|
return JSONResponse({"status": "error"}, 405)
|
|
|
|
@self.app.get("/api/context-status/{context_id}/{agent_type}")
|
|
async def get_context_status(context_id, agent_type: str, request: Request):
|
|
logger.info(f"{request.method} {request.url.path}")
|
|
if not is_valid_uuid(context_id):
|
|
logger.warning(f"Invalid context_id: {context_id}")
|
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
|
context = await self.upsert_context(context_id)
|
|
agent = context.get_agent(agent_type)
|
|
if not agent:
|
|
return JSONResponse(
|
|
{"context_used": 0, "max_context": defines.max_context}
|
|
)
|
|
return JSONResponse(
|
|
{
|
|
"context_used": agent.context_tokens,
|
|
"max_context": defines.max_context,
|
|
}
|
|
)
|
|
|
|
@self.app.get("/api/health")
|
|
async def health_check():
|
|
return JSONResponse({"status": "healthy"})
|
|
|
|
@self.app.get("/{path:path}")
async def serve_static(path: str, request: Request):
    """Serve a file from ``defines.static_content``, falling back to index.html.

    This is the catch-all route. A request path that names an existing
    regular file inside the static root is returned as-is. Everything else
    gets ``index.html``, so the front end can handle the route client-side.

    ``path`` comes straight from the URL, so it is normalized and confined
    to the static root before use. Without that, ``..`` segments would
    escape the root, and so would a ``path`` that is itself absolute,
    because ``os.path.join`` discards the base in that case. Either would
    let a client read arbitrary files.
    """
    static_root = os.path.abspath(defines.static_content)
    # abspath also normalizes, collapsing any ".." segments.
    full_path = os.path.abspath(os.path.join(static_root, path))

    # commonpath equals the root only when full_path is the root or below it.
    inside_root = os.path.commonpath([static_root, full_path]) == static_root
    if inside_root and os.path.isfile(full_path):
        logger.info(f"Serve static request for {full_path}")
        return FileResponse(full_path)

    # Not a servable file: hand back the SPA entry point instead of a 404.
    logger.info(f"File not found for path {full_path} -- returning index.html")
    return FileResponse(os.path.join(defines.static_content, "index.html"))
|
|
|
|
async def save_context(self, context_id):
    """
    Serialize a context to a JSON file named ``context_id`` in ``defines.context_dir``.

    The context is obtained with ``upsert_context``, so it is loaded or
    created first if it is not already in memory. It is then checked with
    ``check_serializable`` and dumped with ``model_dump_json(by_alias=True)``.

    The JSON is written to a temporary file in the same directory and then
    moved over the target with ``os.replace``. A failure part-way through
    the write therefore leaves the previously saved context intact.

    Errors are logged, not raised. A failed save leaves the on-disk copy
    unchanged, and the method still returns ``context_id``.

    Args:
        context_id: UUID string for the context. If it is not loaded, it is
            loaded from disk or created.

    Returns:
        The context_id used for the file.
    """
    context = await self.upsert_context(context_id)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(defines.context_dir, exist_ok=True)
    file_path = os.path.join(defines.context_dir, context_id)

    try:
        # Check for non-serializable fields before dumping.
        serialization_errors = check_serializable(context)
        if serialization_errors:
            for error in serialization_errors:
                logger.error(error)
            raise ValueError("Found non-serializable fields in the model")

        # Dump before touching the filesystem so a validation/serialization
        # failure cannot clobber the current on-disk session.
        json_data = context.model_dump_json(by_alias=True)

        tmp_path = f"{file_path}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(json_data)
            # Atomic on POSIX: readers see either the old or the new file.
            os.replace(tmp_path, file_path)
        except BaseException:
            # Don't leave a stray temp file next to the real contexts.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
    except ValidationError as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        for error in e.errors():
            # Model-level errors have an empty loc; don't index into it blindly.
            loc = ".".join(str(part) for part in error["loc"]) or "<root>"
            logger.error(f"Field: {loc}, Error: {error['msg']}")
    except PydanticSerializationError as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        logger.error(f"Serialization error: {str(e)}")
        # Inspect the model to identify problematic fields.
        for field_name, value in context.__dict__.items():
            if isinstance(value, np.ndarray):
                logger.error(
                    f"Field '{field_name}' contains non-serializable type: {type(value)}"
                )
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(e)

    return context_id
|
|
|
|
async def load_context(self, context_id: str) -> Context | None:
    """
    Load a saved context from ``defines.context_dir`` into ``self.contexts``.

    The file is parsed as JSON and validated into a ``Context``. Its
    ``username`` must name an existing user. The matching ``User`` is
    reused from ``self.users``, or created, initialized and appended there.
    Each agent is then bound back to the context with ``set_context``.
    This method never creates a context.

    Args:
        context_id: UUID string for the context (also its file name).

    Returns:
        The loaded Context, which is also stored in ``self.contexts``.
        Returns None if the file does not exist or if loading fails for any
        reason: invalid JSON, a validation error, an unknown user, duplicate
        agent types, and so on. Failures are logged.
    """
    file_path = os.path.join(defines.context_dir, context_id)

    if not os.path.exists(file_path):
        return None

    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    logger.info(f"Loading context from {file_path}, content length: {len(content)}")

    json_data = {}
    try:
        # Parse as plain JSON first so JSON errors are distinguishable in logs.
        json_data = json.loads(content)
        logger.info("JSON parsed successfully, attempting model validation")

        context = Context.model_validate(json_data)
        username = context.username
        if not User.exists(username):
            raise ValueError(f"Attempt to load context {context.id} with invalid user {username}")

        user = next((u for u in self.users if u.username == username), None)
        if not user:
            user = User(username=username, llm=self.llm)
            await user.initialize(prometheus_collector=self.prometheus_collector)
            self.users.append(user)
        context.user = user

        # The context/agent back-reference is not part of the serialized
        # data, so it is restored here.
        agent_types = [agent.agent_type for agent in context.agents]
        if len(agent_types) != len(set(agent_types)):
            raise ValueError(
                "Context cannot contain multiple agents of the same agent_type"
            )
        for agent in context.agents:
            agent.set_context(context)
    except ValidationError as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        for error in e.errors():
            # Model-level errors have an empty loc; don't index into it blindly.
            loc = ".".join(str(part) for part in error["loc"]) or "<root>"
            logger.error(f"Field: {loc}, Error: {error['msg']}")
        # Previously this branch fell through to the final lookup in
        # self.contexts and raised KeyError instead of reporting "not loaded".
        return None
    except Exception as e:
        logger.error(f"Error validating context: {str(e)}")
        logger.error(traceback.format_exc())
        # Dump a short preview of each top-level field to aid debugging.
        # Valid JSON need not be an object, so only do this for dicts.
        if isinstance(json_data, dict):
            for key, value in json_data.items():
                preview = str(value)[:60] if value else "None"
                logger.info(f"{key} = {type(value)} {preview}")
        logger.info("*" * 50)
        return None

    self.contexts[context_id] = context
    logger.info(f"Successfully loaded context {context_id}")
    return context
|
|
|
|
async def load_or_create_context(self, context_id: str) -> Context:
    """
    Return the context stored under ``context_id``, creating it if needed.

    ``load_context`` is tried first. When it produces nothing, a fresh
    context with this id is created for ``defines.default_username``,
    registered in ``self.contexts`` and returned.

    Args:
        context_id: UUID string for the context.

    Returns:
        The loaded or newly created Context.
    """
    loaded = await self.load_context(context_id)
    if not loaded:
        logger.info(f"Context not found. Creating new instance of context {context_id}.")
        self.contexts[context_id] = await self.create_context(
            username=defines.default_username,
            context_id=context_id,
        )
        loaded = self.contexts[context_id]
    return loaded
|
|
|
|
async def create_context(self, username: str, context_id=None) -> Context:
    """
    Create a new context for ``username``, register it and persist it.

    The user is reused from ``self.users``, or created, initialized and
    appended there. The context receives copies of the user's RAG entries
    and the tool set from ``Tools.all_tools()``. If ``defines.resume_doc``
    exists, its text is stored in ``user_resume``. A "chat" agent is
    created. Finally the context is added to ``self.contexts`` and saved
    via ``save_context``.

    Args:
        username: Name of an existing user.
        context_id: Optional UUID string for the context. If not provided,
            the Context model's default id is used (presumably a freshly
            generated UUID).

    Returns:
        The new Context.

    Raises:
        ValueError: If ``username`` does not exist.
        ValidationError: If the Context model rejects its inputs. The
            details are logged first.
    """
    if not User.exists(username):
        raise ValueError(f"{username} does not exist.")

    user = next((u for u in self.users if u.username == username), None)
    if user:
        logger.info(f"Found matching user: {user.username}")
    else:
        user = User(username=username, llm=self.llm)
        await user.initialize(prometheus_collector=self.prometheus_collector)
        logger.info(f"Created new instance of user: {user.username}")
        self.users.append(user)

    logger.info(f"Creating context {context_id or 'new'} for user: {user.username}")

    # Only pass ``id`` when the caller supplied one, so the model's own
    # default applies otherwise.
    id_kwargs = {"id": context_id} if context_id else {}
    try:
        context = Context(
            **id_kwargs,
            user=user,
            rags=[rag.model_copy() for rag in user.rags],
            tools=Tools.all_tools(),
        )
    except ValidationError as e:
        logger.error(e)
        logger.error(traceback.format_exc())
        for error in e.errors():
            # Model-level errors have an empty loc; don't index into it blindly.
            loc = ".".join(str(part) for part in error["loc"]) or "<root>"
            logger.error(f"Field: {loc}, Error: {error['msg']}")
        # Re-raise instead of exit(1). This runs while serving requests
        # (route handlers -> upsert_context), and exiting would take the
        # whole server down for one bad context.
        raise

    logger.info(f"New context created with ID: {context.id}")

    if os.path.exists(defines.resume_doc):
        with open(defines.resume_doc, "r") as f:
            context.user_resume = f.read()
    context.get_or_create_agent(agent_type="chat")

    logger.info(f"{context.id} created and added to contexts.")
    self.contexts[context.id] = context
    await self.save_context(context.id)
    return context
|
|
|
|
async def upsert_context(self, context_id=None) -> Context:
    """
    Get the context for ``context_id``, loading or creating it as necessary.

    The three cases are:

    * No ``context_id`` (None or empty): a brand-new context is created
      for ``defines.default_username``.
    * The id is already in ``self.contexts``: the cached instance is
      returned.
    * Otherwise: the call is delegated to ``load_or_create_context``.

    Args:
        context_id: Optional UUID string for the context.

    Returns:
        A Context object.
    """
    if context_id:
        if context_id not in self.contexts:
            logger.info(f"Context {context_id} is not yet loaded.")
            return await self.load_or_create_context(context_id=context_id)
        return self.contexts[context_id]

    logger.warning("No context ID provided. Creating a new context.")
    return await self.create_context(username=defines.default_username)
|
|
|
|
async def generate_response(
    self, context: Context, agent: Agent, prompt: str, tunables: Tunables | None
) -> AsyncGenerator[Message, None]:
    """
    Run ``prompt`` through ``agent`` and stream the resulting messages.

    Tunables are the agent's defaults overridden by any user-supplied
    values. The message goes through ``agent.prepare_message`` and then
    ``agent.process_message``. Every intermediate message whose status is
    not "done" is yielded, and streaming stops after an "error" message.
    On success, a final message with status "done" is yielded.

    The whole run is recorded in ``REQUEST_TIME``, including time the
    consumer spends between items. The timer is entered inside the
    generator on purpose. Applied as a decorator (``@REQUEST_TIME.time()``),
    it only timed the synchronous call that creates the async-generator
    object. That call returns immediately, so the metric never reflected
    generation time.
    """
    with REQUEST_TIME.time():
        agent_type = agent.get_agent_type()
        logger.info(f"generate_response: type - {agent_type}")

        # Merge tunables: agent defaults first, user-supplied settings win.
        agent_tunables = agent.tunables.model_dump() if agent.tunables else {}
        user_tunables = tunables.model_dump() if tunables else {}
        merged_tunables = {**agent_tunables, **user_tunables}
        message = Message(prompt=prompt, tunables=Tunables(**merged_tunables))

        # Each stage yields updated messages. ``message`` is rebound to the
        # latest one, so the next stage starts from the final prepared state.
        async for message in agent.prepare_message(message):
            if message.status == "error":
                yield message
                return
            if message.status != "done":
                yield message

        async for message in agent.process_message(self.llm, self.model, message):
            if message.status != "done":
                yield message

        # An error from process_message has already been yielded above.
        if message.status == "error":
            return

        logger.info(
            f"{agent_type}.process_message: {message.status} {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}"
        )
        message.status = "done"
        yield message
|
|
|
|
def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
    """
    Start the uvicorn server and block until it exits.

    Serves HTTPS using ``defines.key_path``/``defines.cert_path`` when
    ``self.ssl_enabled`` is set, and plain HTTP otherwise. Extra keyword
    arguments are accepted for backward compatibility and ignored; they
    are not forwarded to uvicorn.

    Each cached user's ``observer`` is stopped and then joined whenever the
    server exits, whether normally, via Ctrl+C, or via an exception.
    Previously this happened only on KeyboardInterrupt. uvicorn installs
    its own signal handlers, so depending on its version Ctrl+C may simply
    make ``uvicorn.run`` return. A KeyboardInterrupt is swallowed; other
    exceptions propagate after cleanup.
    """
    use_ssl = self.ssl_enabled
    scheme = "https" if use_ssl else "http"
    ssl_options = (
        {"ssl_keyfile": defines.key_path, "ssl_certfile": defines.cert_path}
        if use_ssl
        else {}
    )
    try:
        logger.info(f"Starting web server at {scheme}://{host}:{port}")
        uvicorn.run(self.app, host=host, port=port, log_config=None, **ssl_options)
    except KeyboardInterrupt:
        pass  # Ctrl+C: fall through to cleanup.
    finally:
        # Signal every observer to stop first, then wait, so they wind down
        # in parallel rather than one after another.
        observers = [user.observer for user in self.users if user.observer]
        for observer in observers:
            observer.stop()
        for observer in observers:
            observer.join()
|
|
|
|
|
|
# %%
|
|
|
|
|
|
# Main function to run everything
|
|
def main():
    """Parse CLI arguments, configure logging and warnings, and run the web server."""
    args = parse_args()

    # Logging level comes from the command line (e.g. "info" -> "INFO").
    logger.setLevel(args.level.upper())

    # Silence FutureWarning from sklearn modules and UserWarning from umap modules.
    warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.*")
    warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")

    llm = ollama.Client(host=args.ollama_server)  # type: ignore
    web_server = WebServer(llm, args.ollama_model)
    web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)


# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|