# %%
# Imports [standard]

# Standard library modules (no try-except needed)
import argparse
import asyncio
import json
import logging
import os
import re
import uuid
import subprocess
import math
import warnings
from typing import Any


def try_import(module_name, pip_name=None):
    try:
        __import__(module_name)
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f"  pip install {pip_name or module_name}")
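
# Note: try_import only prints an install hint for missing packages; the
# unconditional imports below will still raise ImportError if one is absent.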

# Third-party modules with import checks
try_import("ollama")
try_import("requests")
try_import("bs4", "beautifulsoup4")
try_import("fastapi")
try_import("uvicorn")
try_import("numpy")
try_import("umap", "umap-learn")
try_import("sklearn", "scikit-learn")

import ollama
import requests
from bs4 import BeautifulSoup
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import numpy as np
import umap
from sklearn.preprocessing import MinMaxScaler

from utils import (
    rag as Rag,
    Context, Conversation, Session, Message, Chat, Resume, JobDescription, FactCheck,
    defines
)

from tools import (
    DateTime,
    WeatherForecast,
    TickerValue,
    tools
)

CONTEXT_VERSION = 2

rags = [
    { "name": "JPK", "enabled": True, "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects." },
    # { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
]

system_message = f"""
Launched on {DateTime()}.

When answering queries, follow these steps:

- First analyze the query to determine if real-time information might be helpful.
- Even when <|context|> is provided, consider whether the tools would provide more current or comprehensive information.
- Use the provided tools whenever they would enhance your response, regardless of whether context is also available.
- When presenting weather forecasts, include relevant emojis immediately before the corresponding text. For example, for a sunny day say "☀️ Sunny", or if the forecast says there will be rain showers, say "🌧️ Rain showers". Use this mapping for weather emojis: Sunny: ☀️, Cloudy: ☁️, Rainy: 🌧️, Snowy: ❄️.
- When both <|context|> and tool outputs are relevant, synthesize information from both sources to provide the most complete answer.
- Always prioritize the most up-to-date and relevant information, whether it comes from <|context|> or tools.
- If <|context|> and tool outputs contain conflicting information, prefer the tool outputs as they likely represent more current data.
- If there is information in the <|context|> or <|job_description|> sections to enhance the answer, incorporate it seamlessly and refer to it as 'the latest information' or 'recent data' instead of mentioning '<|context|>' (etc.) or quoting it directly.
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|> or <|job_description|> tags.

Always use tools and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so.
"""

system_generate_resume = f"""
Launched on {DateTime()}.

You are a professional resume writer. Your task is to write a concise, polished, and tailored resume for a specific job based only on the individual's <|context|>.

When answering queries, follow these steps:

- You must not invent or assume any information not explicitly present in the <|context|>.
- Analyze the <|job_description|> to identify skills required for the job.
- Use the <|job_description|> provided to guide the focus, tone, and relevant skills or experience to highlight from the <|context|>.
- Identify and emphasize the experiences, achievements, and responsibilities from the <|context|> that best align with the <|job_description|>.
- Only provide information from <|context|> items if it is relevant to the <|job_description|>.
- Do not use the <|job_description|> skills unless listed in <|context|>.
- Do not include any information unless it is provided in <|context|>.
- Use the <|context|> to create a polished, professional resume.
- Do not list any locations or mailing addresses in the resume.
- If there is information in the <|context|>, <|job_description|>, or <|resume|> sections to enhance the answer, incorporate it seamlessly and refer to it using natural language instead of mentioning '<|job_description|>' (etc.) or quoting it directly.
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|> or <|job_description|> tags.
- Ensure the language is clear, concise, and aligned with industry standards for professional resumes.

Structure the resume professionally with the following sections where applicable:

* Name: Use full name.
* Professional Summary: A 2-4 sentence overview tailored to the job.
* Skills: A bullet list of key skills derived from the work history and relevant to the job.
* Professional Experience: A detailed list of roles, achievements, and responsibilities from <|context|> that relate to the <|job_description|>.
* Education: Include only if available in the work history.
* Notes: Indicate the initial draft of the resume was generated using the Backstory application.
""".strip()

system_fact_check = f"""
Launched on {DateTime()}.

You are a professional resume fact checker. Your task is to identify any inaccuracies in the <|resume|> based on the individual's <|context|>.

If there are inaccuracies, list them in a bullet point format.

When answering queries, follow these steps:

- You must not invent or assume any information not explicitly present in the <|context|>.
- Analyze the <|resume|> to identify any discrepancies or inaccuracies based on the <|context|>.
- If there is information in the <|context|>, <|job_description|>, or <|resume|> sections to enhance the answer, incorporate it seamlessly and refer to it using natural language instead of mentioning '<|job_description|>' (etc.) or quoting it directly.
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|job_description|>, or <|resume|> tags.
""".strip()

system_fact_check_QA = f"""
Launched on {DateTime()}.

You are a professional resume fact checker.

You are provided with a <|resume|> which was generated by you, the <|context|> you used to generate that <|resume|>, and a <|fact_check|> generated by you when you analyzed <|context|> against the <|resume|> to identify discrepancies.

Your task is to answer questions about the <|fact_check|> you generated based on the <|resume|> and <|context|>.
""".strip()

system_job_description = f"""
Launched on {DateTime()}.

You are a hiring and job placement specialist. Your task is to answer questions about a job description.

When answering queries, follow these steps:

- Analyze the <|job_description|> to provide insights for the asked question.
- If any financial information is requested, be sure to account for inflation.
- If there is information in the <|context|>, <|job_description|>, or <|resume|> sections to enhance the answer, incorporate it seamlessly and refer to it using natural language instead of mentioning '<|job_description|>' (etc.) or quoting it directly.
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|>, <|job_description|>, or <|resume|> tags.
""".strip()

def get_installed_ram():
    try:
        with open("/proc/meminfo", "r") as f:
            meminfo = f.read()
        match = re.search(r"MemTotal:\s+(\d+)", meminfo)
        if match:
            return f"{math.floor(int(match.group(1)) / 1000**2)}GB"  # MemTotal is reported in kB; divide by 1000^2 for GB
        return "Unknown"
    except Exception as e:
        return f"Error retrieving RAM: {e}"

def get_graphics_cards():
    gpus = []
    try:
        # Run the ze-monitor utility
        result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)

        # Clean up the output (remove leading/trailing whitespace and newlines)
        output = result.stdout.strip()
        for index in range(len(output.splitlines())):
            result = subprocess.run(["ze-monitor", "--device", f"{index+1}", "--info"], capture_output=True, text=True, check=True)
            gpu_info = result.stdout.strip().splitlines()
            gpu = {
                "discrete": True,  # Assume it's discrete initially
                "name": None,
                "memory": None
            }
            gpus.append(gpu)
            for line in gpu_info:
                match = re.match(r"^Device: [^(]*\((.*)\)", line)
                if match:
                    gpu["name"] = match.group(1)
                    continue

                match = re.match(r"^\s*Memory: (.*)", line)
                if match:
                    gpu["memory"] = match.group(1)
                    continue

                match = re.match(r"^.*Is integrated with host: Yes.*", line)
                if match:
                    gpu["discrete"] = False
                    continue

        return gpus
    except Exception as e:
        return f"Error retrieving GPU info: {e}"

def get_cpu_info():
    try:
        with open("/proc/cpuinfo", "r") as f:
            cpuinfo = f.read()
        model_match = re.search(r"model name\s+:\s+(.+)", cpuinfo)
        cores_match = re.findall(r"processor\s+:\s+\d+", cpuinfo)
        if model_match and cores_match:
            return f"{model_match.group(1)} with {len(cores_match)} cores"
        return "Unknown"
    except Exception as e:
        return f"Error retrieving CPU info: {e}"

def system_info(model):
    return {
        "System RAM": get_installed_ram(),
        "Graphics Card": get_graphics_cards(),
        "CPU": get_cpu_info(),
        "LLM Model": model,
        "Embedding Model": defines.embedding_model,
        "Context length": defines.max_context
    }
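
# Illustrative shape of the returned dict (values are examples, not real output):
#   {
#       "System RAM": "64GB",
#       "Graphics Card": [{"discrete": True, "name": "Intel Arc A770", "memory": "16 GiB"}],
#       "CPU": "Intel(R) Core(TM) i9 with 24 cores",
#       "LLM Model": "...", "Embedding Model": "...", "Context length": 8192
#   }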

# %%
# Defaults
OLLAMA_API_URL = defines.ollama_api_url
MODEL_NAME = defines.model
LOG_LEVEL = "info"
USE_TLS = False
WEB_HOST = "0.0.0.0"
WEB_PORT = 8911
DEFAULT_HISTORY_LENGTH = 5

# %%
# Globals

def create_system_message(prompt):
    return [{"role": "system", "content": prompt}]

tool_log = []
command_log = []
model = None
client = None
web_server = None

# %%
# Cmd line overrides
def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
    parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
    parser.add_argument("--web-host", type=str, default=WEB_HOST, help=f"Host to launch the web server on. default={WEB_HOST}")
    parser.add_argument("--web-port", type=int, default=WEB_PORT, help=f"Port to launch the web server on. default={WEB_PORT}")
    parser.add_argument("--level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                        default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
    return parser.parse_args()
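
# Example invocation (script name and values are illustrative):
#   python server.py --ollama-server http://localhost:11434 --level DEBUG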

def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    logging.basicConfig(
        level=numeric_level,
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True
    )

    # Now reduce verbosity for FastAPI, Uvicorn, Starlette
    for noisy_logger in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi", "starlette"):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)

    logging.info(f"Logging is set to {level} level.")

# %%

async def AnalyzeSite(llm, model: str, url: str, question: str):
    """
    Fetches content from a URL, extracts the text, and uses Ollama to answer a question about it.

    Args:
        llm: The Ollama client to use for generation
        model (str): The LLM model name to generate with
        url (str): The URL of the website to analyze
        question (str): The question to answer about the site's content

    Returns:
        A dict describing the answer, or an error string on failure
    """
    try:
        # Fetch the webpage
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        logging.info(f"Fetching {url}")
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        logging.info(f"{url} returned. Processing...")

        # Parse the HTML
        soup = BeautifulSoup(response.text, "html.parser")

        # Remove script and style elements
        for script in soup(["script", "style"]):
            script.extract()

        # Get text content
        text = soup.get_text(separator=" ", strip=True)

        # Clean up text (remove extra whitespace)
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = " ".join(chunk for chunk in chunks if chunk)

        # Limit text length if needed (Ollama may have token limits)
        max_chars = 100000
        if len(text) > max_chars:
            text = text[:max_chars] + "..."

        # Generate the answer using Ollama
        # logging.info(f"Requesting summary of: {text}")
        prompt = f"CONTENTS:\n\n{text}\n\n{question}"
        response = llm.generate(model=model,
                                system=f"You are given the contents of {url}. Answer the question about the contents.",
                                prompt=prompt)

        # logging.info(response["response"])

        return {
            "source": "summarizer-llm",
            "content": response["response"],
            "metadata": DateTime()
        }

    except requests.exceptions.RequestException as e:
        return f"Error fetching the URL: {str(e)}"
    except Exception as e:
        return f"Error processing the website content: {str(e)}"
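
# Usage sketch (model name is illustrative; llm is a configured ollama.Client):
#   result = await AnalyzeSite(llm=llm, model="llama3.1",
#                              url="https://example.com", question="What is this site about?")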

# %%
def is_valid_uuid(value):
    try:
        uuid_obj = uuid.UUID(value, version=4)
        return str(uuid_obj) == value
    except (ValueError, TypeError):
        return False
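
# Examples:
#   is_valid_uuid(str(uuid.uuid4()))  -> True
#   is_valid_uuid("not-a-uuid")       -> False
#   is_valid_uuid(None)               -> False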

def default_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    return [{**tool, "enabled": True} for tool in tools]

def find_summarize_tool(tools):
    return [{**tool, "enabled": True} for tool in tools if tool.get("name", "") == "AnalyzeSite"]

def llm_tools(tools):
    return [tool for tool in tools if tool.get("enabled", False) == True]
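
# These helpers assume each tool entry follows the Ollama tool schema with an
# "enabled" flag layered on by default_tools; a sketch (not the real registry):
#   { "type": "function",
#     "function": { "name": "TickerValue", "parameters": { ... } },
#     "enabled": True }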

# %%
class WebServer:
    def __init__(self, llm, model=MODEL_NAME):
        self.app = FastAPI()
        self.contexts = {}
        self.llm = llm
        self.model = model
        self.processing = False
        self.file_watcher = None
        self.observer = None

        self.ssl_enabled = os.path.exists(defines.key_path) and os.path.exists(defines.cert_path)

        if self.ssl_enabled:
            allow_origins = ["https://battle-linux.ketrenos.com:3000"]
        else:
            allow_origins = ["http://battle-linux.ketrenos.com:3000"]

        logging.info(f"Allowed origins: {allow_origins}")

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=allow_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        @self.app.on_event("startup")
        async def startup_event():
            # Start the file watcher
            self.observer, self.file_watcher = Rag.start_file_watcher(
                llm=llm,
                watch_directory=defines.doc_dir,
                recreate=False  # Don't recreate if exists
            )
            print(f"API started with {self.file_watcher.collection.count()} documents in the collection")

        @self.app.on_event("shutdown")
        async def shutdown_event():
            if self.observer:
                self.observer.stop()
                self.observer.join()
                print("File watcher stopped")

        self.setup_routes()

    def setup_routes(self):
        @self.app.get("/")
        async def root():
            context = self.create_context()
            logging.info(f"Redirecting non-session to {context.id}")
            return RedirectResponse(url=f"/{context.id}", status_code=307)
            # return JSONResponse({"redirect": f"/{context.id}"})

        @self.app.put("/api/umap/{context_id}")
        async def put_umap(context_id: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            try:
                if not self.file_watcher:
                    raise Exception("File watcher not initialized")

                context = self.upsert_context(context_id)
                if not context:
                    return JSONResponse({"error": f"Invalid context: {context_id}"}, status_code=400)

                data = await request.json()

                dimensions = data.get("dimensions", 2)
                result = self.file_watcher.umap_collection
                if dimensions == 2:
                    logging.info("Returning 2D UMAP")
                    umap_embedding = self.file_watcher.umap_embedding_2d
                else:
                    logging.info("Returning 3D UMAP")
                    umap_embedding = self.file_watcher.umap_embedding_3d

                result["embeddings"] = umap_embedding.tolist()

                return JSONResponse(result)

            except Exception as e:
                logging.error(e)
                return JSONResponse({"error": str(e)}, status_code=500)

        @self.app.put("/api/similarity/{context_id}")
        async def put_similarity(context_id: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            if not self.file_watcher:
                return JSONResponse({"error": "File watcher not initialized"}, status_code=500)

            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)

            try:
                data = await request.json()
                query = data.get("query", "")
            except Exception:
                query = ""
            if not query:
                return JSONResponse({"error": "No query provided for similarity search"}, status_code=400)

            try:
                chroma_results = self.file_watcher.find_similar(query=query, top_k=10)
                if not chroma_results:
                    return JSONResponse({"error": "No results found"}, status_code=404)

                chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()  # Ensure correct shape
                print(f"Chroma embedding shape: {chroma_embedding.shape}")

                umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
                print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output

                umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
                print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output

                return JSONResponse({
                    **chroma_results,
                    "query": query,
                    "umap_embedding_2d": umap_2d,
                    "umap_embedding_3d": umap_3d
                })

            except Exception as e:
                logging.error(e)
                return JSONResponse({"error": str(e)}, status_code=500)

        @self.app.put("/api/reset/{context_id}/{session_type}")
        async def put_reset(context_id: str, session_type: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            session = context.get_session(session_type)
            if not session:
                return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)

            data = await request.json()
            try:
                response = {}
                for reset_operation in data["reset"]:
                    match reset_operation:
                        case "system_prompt":
                            logging.info(f"Resetting {reset_operation}")
                            match session_type:
                                case "chat":
                                    prompt = system_message
                                case "job_description":
                                    prompt = system_generate_resume
                                case "resume":
                                    prompt = system_generate_resume
                                case "fact_check":
                                    prompt = system_message
                                case _:
                                    # Fall back to the generic chat prompt for unknown session types
                                    prompt = system_message

                            session.system_prompt = prompt
                            response["system_prompt"] = { "system_prompt": prompt }
                        case "rags":
                            logging.info(f"Resetting {reset_operation}")
                            context.rags = rags.copy()
                            response["rags"] = context.rags
                        case "tools":
                            logging.info(f"Resetting {reset_operation}")
                            context.tools = default_tools(tools)
                            response["tools"] = context.tools
                        case "history":
                            reset_map = {
                                "job_description": ("job_description", "resume", "fact_check"),
                                "resume": ("job_description", "resume", "fact_check"),
                                "fact_check": ("job_description", "resume", "fact_check"),
                                "chat": ("chat",),
                            }
                            resets = reset_map.get(session_type, ())

                            for mode in resets:
                                tmp = context.get_session(mode)
                                if not tmp:
                                    continue
                                logging.info(f"Resetting {reset_operation} for {mode}")
                                tmp.conversation = Conversation()
                                tmp.context_tokens = round(len(str(tmp.system_prompt)) * 3 / 4)  # Estimate context usage
                            response["history"] = []
                            response["context_used"] = session.context_tokens
                        case "message_history_length":
                            logging.info(f"Resetting {reset_operation}")
                            context.message_history_length = DEFAULT_HISTORY_LENGTH
                            response["message_history_length"] = DEFAULT_HISTORY_LENGTH

                if not response:
                    return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system_prompt }" })
                else:
                    self.save_context(context_id)
                    return JSONResponse(response)

            except Exception:
                return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system_prompt }" })

        @self.app.put("/api/tunables/{context_id}")
        async def put_tunables(context_id: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            try:
                context = self.upsert_context(context_id)

                data = await request.json()
                session = context.get_session("chat")
                if not session:
                    return JSONResponse({ "error": "chat is not recognized", "context": context.id }, status_code=404)
                for k in data.keys():
                    match k:
                        case "tools":
                            # { "tools": [{ "name": tool.name, "enabled": tool.enabled }] }
                            tools_update: list[dict[str, Any]] = data[k]
                            if not tools_update:
                                return JSONResponse({ "status": "error", "message": "Tools can not be empty." })
                            for tool in tools_update:
                                for context_tool in context.tools:
                                    if context_tool["function"]["name"] == tool["name"]:
                                        context_tool["enabled"] = tool["enabled"]
                            self.save_context(context_id)
                            return JSONResponse({ "tools": [ {
                                **t["function"],
                                "enabled": t["enabled"],
                            } for t in context.tools ] })

                        case "rags":
                            # { "rags": [{ "name": rag.name, "enabled": rag.enabled }] }
                            rags_update: list[dict[str, Any]] = data[k]
                            if not rags_update:
                                return JSONResponse({ "status": "error", "message": "RAGs can not be empty." })
                            for rag in rags_update:
                                for context_rag in context.rags:
                                    if context_rag["name"] == rag["name"]:
                                        context_rag["enabled"] = rag["enabled"]
                            self.save_context(context_id)
                            return JSONResponse({ "rags": context.rags })

                        case "system_prompt":
                            system_prompt = data[k].strip()
                            if not system_prompt:
                                return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
                            session.system_prompt = system_prompt
                            self.save_context(context_id)
                            return JSONResponse({ "system_prompt": system_prompt })
                        case "message_history_length":
                            value = max(0, int(data[k]))
                            context.message_history_length = value
                            self.save_context(context_id)
                            return JSONResponse({ "message_history_length": value })
                        case _:
                            return JSONResponse({ "error": f"Unrecognized tunable {k}" }, status_code=404)
            except Exception as e:
                logging.error(f"Error in put_tunables: {e}")
                return JSONResponse({"error": str(e)}, status_code=500)

        @self.app.get("/api/tunables/{context_id}")
        async def get_tunables(context_id: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            session = context.get_session("chat")
            if not session:
                return JSONResponse({ "error": "chat is not recognized", "context": context.id }, status_code=404)
            return JSONResponse({
                "system_prompt": session.system_prompt,
                "message_history_length": context.message_history_length,
                "rags": context.rags,
                "tools": [ {
                    **t["function"],
                    "enabled": t["enabled"],
                } for t in context.tools ]
            })

        @self.app.get("/api/system-info/{context_id}")
        async def get_system_info(context_id: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            return JSONResponse(system_info(self.model))

        @self.app.post("/api/chat/{context_id}/{session_type}")
        async def post_chat_endpoint(context_id: str, session_type: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            try:
                if not is_valid_uuid(context_id):
                    logging.warning(f"Invalid context_id: {context_id}")
                    return JSONResponse({"error": "Invalid context_id"}, status_code=400)
                context = self.upsert_context(context_id)

                try:
                    data = await request.json()
                    session = context.get_session(session_type)
                    if not session and session_type == "job_description":
                        logging.info(f"Session {session_type} not found. Creating it.")
                        # Create a new session if it doesn't exist
                        session = context.get_or_create_session("job_description", system_prompt=system_generate_resume, job_description=data["content"])
                    if not session:
                        return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)
                except Exception as e:
                    logging.info(f"Attempt to create session type {session_type} failed: {e}")
                    return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)

                # Create a custom generator that ensures flushing
                async def flush_generator():
                    async for message in self.generate_response(context=context, session=session, content=data["content"]):
                        # Convert to JSON and add newline
                        yield json.dumps(message) + "\n"
                        # Save the history as it's generated
                        self.save_context(context_id)
                        # Explicitly flush after each yield
                        await asyncio.sleep(0)  # Allow the event loop to process the write

                # Return StreamingResponse with appropriate headers
                return StreamingResponse(
                    flush_generator(),
                    media_type="application/json",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
                        "X-Accel-Buffering": "no"  # Prevents Nginx buffering if you're using it
                    }
                )
            except Exception as e:
                logging.error(f"Error in post_chat_endpoint: {e}")
                return JSONResponse({"error": str(e)}, status_code=500)

        @self.app.post("/api/context")
        async def create_context():
            context = self.create_context()
            logging.info(f"Generated new session as {context.id}")
            return JSONResponse({ "id": context.id })

        @self.app.get("/api/history/{context_id}/{session_type}")
        async def get_history(context_id: str, session_type: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            try:
                context = self.upsert_context(context_id)
                session = context.get_session(session_type)
                if not session:
                    logging.info(f"Session {session_type} not found. Returning empty history.")
                    return JSONResponse({ "messages": [] })
                logging.info(f"History for {session_type} contains {len(session.conversation.messages)} entries.")
                return session.conversation
            except Exception as e:
                logging.error(f"Error in get_history: {e}")
                return JSONResponse({"error": str(e)}, status_code=404)

        @self.app.get("/api/tools/{context_id}")
        async def get_tools(context_id: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            context = self.upsert_context(context_id)
            return JSONResponse(context.tools)

        @self.app.put("/api/tools/{context_id}")
        async def put_tools(context_id: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            try:
                data = await request.json()
                modify = data["tool"]
                enabled = data["enabled"]
                for tool in context.tools:
                    if modify == tool["function"]["name"]:
                        tool["enabled"] = enabled
                        self.save_context(context_id)
                        return JSONResponse(context.tools)
                return JSONResponse({ "status": f"{modify} not found in tools." }, status_code=404)
            except Exception:
                return JSONResponse({ "status": "error" }, status_code=405)

        @self.app.get("/api/context-status/{context_id}/{session_type}")
        async def get_context_status(context_id: str, session_type: str, request: Request):
            logging.info(f"{request.method} {request.url.path}")
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            session = context.get_session(session_type)
            if not session:
                return JSONResponse({"context_used": 0, "max_context": defines.max_context})
            return JSONResponse({"context_used": session.context_tokens, "max_context": defines.max_context})

        @self.app.get("/api/health")
        async def health_check():
            return JSONResponse({"status": "healthy"})

        @self.app.get("/{path:path}")
        async def serve_static(path: str):
            full_path = os.path.join(defines.static_content, path)
            if os.path.exists(full_path) and os.path.isfile(full_path):
                logging.info(f"Serve static request for {full_path}")
                return FileResponse(full_path)
            # SPA fallback: unknown paths get index.html so client-side routing works
            logging.info(f"Serve index.html for {path}")
            return FileResponse(os.path.join(defines.static_content, "index.html"))

    def save_context(self, session_id):
        """
        Serialize a Context to a file in the sessions directory.

        Args:
            session_id: UUID string for the context. If it doesn't exist, it is created

        Returns:
            The session_id used for the file
        """
        context = self.upsert_context(session_id)

        # Create sessions directory if it doesn't exist
        if not os.path.exists(defines.session_dir):
            os.makedirs(defines.session_dir)

        # Create the full file path
        file_path = os.path.join(defines.session_dir, session_id)

        # Serialize the data to JSON and write to file
        with open(file_path, "w") as f:
            f.write(context.model_dump_json())

        return session_id

    def load_context(self, session_id) -> Context:
        """
        Load a context from a file in the sessions directory.

        Args:
            session_id: UUID string for the context. If it doesn't exist, a new context is created.

        Returns:
            A Context object with the specified ID and default settings.
        """
        file_path = os.path.join(defines.session_dir, session_id)

        # Check if the file exists
        if not os.path.exists(file_path):
            self.contexts[session_id] = self.create_context(session_id)
        else:
            # Read and deserialize the data
            with open(file_path, "r") as f:
                self.contexts[session_id] = Context.model_validate_json(f.read())

        return self.contexts[session_id]

    def create_context(self, context_id = None) -> Context:
        """
        Create a new context with a unique ID and default settings.

        Args:
            context_id: Optional UUID string for the context. If not provided, a new UUID is generated.

        Returns:
            A Context object with the specified ID and default settings.
        """
        context = Context(id=context_id)

        if os.path.exists(defines.resume_doc):
            with open(defines.resume_doc, "r") as f:
                context.user_resume = f.read()
        context.add_session(Chat(system_prompt = system_message))
        # context.add_session(Resume(system_prompt = system_generate_resume))
        # context.add_session(JobDescription(system_prompt = system_job_description))
        # context.add_session(FactCheck(system_prompt = system_fact_check))
        context.tools = default_tools(tools)
        context.rags = rags.copy()

        logging.info(f"{context.id} created and added to sessions.")
        self.contexts[context.id] = context
        self.save_context(context.id)
        return context

    def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096):
        ctx = round(context + len(str(messages)) * 3 / 4)
        # Clamp between a 2048-token floor and the model's maximum context
        return min(defines.max_context, max(2048, ctx + ctx_buffer))
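
    # Example (assuming defines.max_context = 8192): with 1000 tokens already
    # in use and ~2000 characters of new messages, ctx = round(1000 + 1500) = 2500,
    # so this returns min(8192, max(2048, 2500 + 4096)) = 6596.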

    # %%
    async def handle_tool_calls(self, message):
        """
        Process tool calls and yield status updates along the way.

        The last yielded item will be a tuple containing (tool_result, tools_used).
        """
        tools_used = []
        all_responses = []

        for i, tool_call in enumerate(message["tool_calls"]):
            arguments = tool_call["function"]["arguments"]
            tool = tool_call["function"]["name"]

            # Yield status update before processing each tool
            yield {"status": "processing", "message": f"Processing tool {i+1}/{len(message['tool_calls'])}: {tool}..."}

            # Process the tool based on its type
            match tool:
                case "TickerValue":
                    ticker = arguments.get("ticker")
                    if not ticker:
                        ret = None
                    else:
                        ret = TickerValue(ticker)
                    tools_used.append({ "tool": f"{tool}({ticker})", "result": ret })

                case "AnalyzeSite":
                    url = arguments.get("url")
                    question = arguments.get("question", "what is the summary of this content?")

                    # Additional status update for long-running operations
                    yield {"status": "processing", "message": f"Retrieving and summarizing content from {url}..."}
                    ret = await AnalyzeSite(llm=self.llm, model=self.model, url=url, question=question)
                    tools_used.append({ "tool": f"{tool}('{url}', '{question}')", "result": ret })

                case "DateTime":
                    tz = arguments.get("timezone")
                    ret = DateTime(tz)
                    tools_used.append({ "tool": f"{tool}('{tz}')", "result": ret })

                case "WeatherForecast":
                    city = arguments.get("city")
                    state = arguments.get("state")

                    yield {"status": "processing", "message": f"Fetching weather data for {city}, {state}..."}
                    ret = WeatherForecast(city, state)
                    tools_used.append({ "tool": f"{tool}('{city}', '{state}')", "result": ret })

                case _:
                    ret = None

            # Build response for this tool
            tool_response = {
                "role": "tool",
                "content": str(ret),
                "name": tool_call["function"]["name"]
            }
            all_responses.append(tool_response)

        # Yield the final result as the last item
        final_result = all_responses[0] if len(all_responses) == 1 else all_responses
        yield (final_result, tools_used)
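
    # Consumers of handle_tool_calls (see generate_response) iterate the
    # generator, forwarding dict status updates to the client and treating a
    # 2-tuple as the final (tool_result, tools_used) pair.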

    def upsert_context(self, context_id = None) -> Context:
        """
        Upsert a context based on the provided context_id.

        Args:
            context_id: UUID string for the context. If it doesn't exist, a new context is created.

        Returns:
            A Context object with the specified ID and default settings.
        """
        if not context_id:
            logging.warning("No context ID provided. Creating a new context.")
            return self.create_context()

        if not is_valid_uuid(context_id):
            logging.info(f"User requested invalid context_id: {context_id}")
            raise ValueError(f"Invalid context_id: {context_id}")

        if context_id in self.contexts:
            return self.contexts[context_id]

        logging.info(f"Context {context_id} not found in memory. Loading or creating it.")
        return self.load_context(context_id)

    def generate_rag_results(self, context, content):
        results_found = False

        if self.file_watcher:
            for rag in context.rags:
                if rag["enabled"] and rag["name"] == "JPK":  # Only support JPK rag right now...
                    yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
                    chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
                    if chroma_results:
                        results_found = True
                        chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()  # Ensure correct shape
                        print(f"Chroma embedding shape: {chroma_embedding.shape}")

                        umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
                        print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output

                        umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
                        print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output

                        yield {
                            **chroma_results,
                            "name": rag["name"],
                            "umap_embedding_2d": umap_2d,
                            "umap_embedding_3d": umap_3d
                        }

        if not results_found:
            yield {"status": "complete", "message": "No RAG context found"}
            yield {
                "rag": None,
                "documents": [],
                "embeddings": [],
                "umap_embedding_2d": [],
                "umap_embedding_3d": []
            }
        else:
            yield {"status": "complete", "message": "RAG processing complete"}
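
    # generate_rag_results yields two kinds of items: status dicts of the form
    # {"status": ..., "message": ...} for progress reporting, and result dicts
    # carrying the Chroma hits plus "umap_embedding_2d"/"umap_embedding_3d".
    # Callers distinguish them by checking for a "status" key.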

    # session_type: chat
    #   * Q&A
    #
    # session_type: job_description
    #   * First message sets Job Description and generates Resume
    #   * Has content (Job Description)
    #   * Then Q&A of Job Description
    #
    # session_type: resume
    #   * First message sets Resume and generates Fact Check
    #   * Has no content
    #   * Then Q&A of Resume
    #
    # session_type: fact_check
    #   * First message sets Fact Check and is Q&A
    #   * Has content
    #   * Then Q&A of Fact Check
    async def generate_response(self, context: Context, session: Session, content: str):
        if not self.file_watcher:
            return

        if self.processing:
            logging.info("TODO: Implement delay queuing; busy for same session, otherwise return queue size and estimated wait time")
            yield {"status": "error", "message": "Busy processing another request."}
            return

        self.processing = True

        conversation: Conversation = session.conversation

        message = Message(prompt=content)
        del content  # Prevent accidental use of content

        # Default to not using tools
        enable_tools = False

        # Default to using RAG if there is content to check
        if message.prompt:
            enable_rag = True
        else:
            enable_rag = False

        # RAG is disabled when asking questions about the resume
        if session.session_type == "resume":
            enable_rag = False

        # The first time through each session session_type a content_seed may be set for
        # future chat sessions; use it once, then clear it
        message.preamble = session.get_and_reset_content_seed()
        system_prompt = session.system_prompt

        # After the first time a particular session session_type is used, it is handled as a chat.
        # The number of messages indicating the session is ready for chat varies based on
        # the session_type of session
        process_type = session.session_type
        match process_type:
            case "job_description":
                logging.info(f"job_description user_history len: {len(conversation.messages)}")
                if len(conversation.messages) >= 2:  # USER, ASSISTANT
                    process_type = "chat"
            case "resume":
                logging.info(f"resume user_history len: {len(conversation.messages)}")
                if len(conversation.messages) >= 3:  # USER, ASSISTANT, FACT_CHECK
                    process_type = "chat"
            case "fact_check":
                process_type = "chat"  # Fact Check is always a chat session

        match process_type:
            # Normal chat interactions with context history
            case "chat":
                if not message.prompt:
                    yield {"status": "error", "message": "No query provided for chat."}
                    logging.info(f"user_history len: {len(conversation.messages)}")
                    self.processing = False
                    return
                enable_tools = True

                # Generate RAG content if enabled, based on the content
                rag_context = ""
                if enable_rag:
                    # Initialize metadata["rag"] to None or a default value
                    message.metadata["rag"] = None

                    for value in self.generate_rag_results(context, message.prompt):
                        if "status" in value:
                            yield value
                        else:
                            if value.get("documents") or value.get("rag") is not None:
                                message.metadata["rag"] = value

                    if message.metadata["rag"]:
                        for doc in message.metadata["rag"]["documents"]:
                            rag_context += f"{doc}\n"

                if rag_context:
                    message.preamble = f"""
<|context|>
{rag_context}
"""
                if context.user_resume:
                    message.preamble += f"""
<|resume|>
{context.user_resume}
"""

                message.preamble += """
<|rules|>
- If there is information in the <|context|> or <|resume|> sections to enhance the answer, incorporate it seamlessly and refer to it using natural language instead of mentioning '<|context|>' or '<|resume|>' or quoting it directly.
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|> or <|resume|>.

<|question|>
Use that information to respond to:"""

                # Use the mode specific system_prompt instead of 'chat'
                system_prompt = session.system_prompt

            # On first entry, a single job_description is provided ("user")
            # Generate a resume to append to RESUME history
            case "job_description":
                # Generate RAG content if enabled, based on the content
                rag_context = ""
                if enable_rag:
                    # Initialize metadata["rag"] to None or a default value
                    message.metadata["rag"] = None

                    for value in self.generate_rag_results(context, message.prompt):
                        if "status" in value:
                            yield value
                        else:
                            if value.get("documents") or value.get("rag") is not None:
                                message.metadata["rag"] = value

                    if message.metadata["rag"]:
                        for doc in message.metadata["rag"]["documents"]:
                            rag_context += f"{doc}\n"

                message.preamble = ""
                if rag_context:
                    message.preamble += f"""
<|context|>
{rag_context}
"""

                if context.user_resume:
                    message.preamble += f"""
<|resume|>
{context.user_resume}
"""

                message.preamble += f"""
<|job_description|>
{message.prompt}
"""
                tmp = context.get_session("job_description")
                if not tmp:
                    raise Exception("Job description session not found.")
                # Set the content seed for the job_description session
                tmp.set_content_seed(message.preamble + "<|question|>\nUse the above information to respond to this prompt: ")

                message.preamble += """
<|rules|>
1. Use the above <|resume|> and <|context|> to create the resume for the <|job_description|>.
2. Do not use content from the <|job_description|> in the response unless the <|context|> or <|resume|> mentions them.

<|question|>
Use the above information to respond to this prompt:
"""

                # For all future calls to job_description, use the system_job_description
                session.system_prompt = system_job_description

                # Seed the history for job_description
                stuffingMessage = Message(prompt=message.prompt)
                stuffingMessage.response = "Job description stored to use in future queries."
                stuffingMessage.metadata["origin"] = "job_description"
                stuffingMessage.metadata["display"] = "hide"
                conversation.add_message(stuffingMessage)

                message.add_action("generate_resume")

                logging.info("TODO: Convert these to generators, e.g. generate_resume(), and then manually add results into the 'resume' session")
                logging.info("TODO: For subsequent runs, have the Session handler generate the follow up prompts so they can have correct context preamble")

                # Switch to resume session for LLM responses
                # message.metadata["origin"] = "resume"
                # session = context.get_or_create_session("resume")
                # system_prompt = session.system_prompt
                # llm_history = session.llm_history = []
                # user_history = session.user_history = []

            # Ignore the passed in content and invoke Fact Check
            case "resume":
                if len(context.get_or_create_session("resume").conversation.messages) < 2:  # USER, **ASSISTANT**
                    raise Exception("No resume found in user history.")
                resume = context.get_or_create_session("resume").conversation.messages[1]

                # Generate RAG content if enabled, based on the resume content
                rag_context = ""
                if enable_rag:
                    # Initialize metadata["rag"] to None or a default value
                    message.metadata["rag"] = None

                    for value in self.generate_rag_results(context, resume["content"]):
                        if "status" in value:
                            yield value
                        else:
                            if value.get("documents") or value.get("rag") is not None:
                                message.metadata["rag"] = value

                    if message.metadata["rag"]:
                        for doc in message.metadata["rag"]["documents"]:
                            rag_context += f"{doc}\n"

                # This is being passed to Fact Check, so do not provide the <|job_description|>
                message.preamble = ""

                if rag_context:
                    message.preamble += f"""
<|context|>
{rag_context}
"""
                if context.user_resume:
                    # Do not prefix the resume with <|resume|>; just add to the <|context|>
                    message.preamble += f"""
{context.user_resume}
"""

                message.preamble += f"""
<|resume|>
{resume['content']}

<|rules|>
1. Do not invent or assume any information not explicitly present in the <|context|>.
2. Analyze the <|resume|> to identify any discrepancies or inaccuracies based on the <|context|>.

<|question|>
"""

                context.get_or_create_session("resume").set_content_seed(f"""
<|resume|>
{resume['content']}

<|question|>
Use the above <|resume|> and <|job_description|> to answer this query:
""")

                message.prompt = "Fact check the resume and report discrepancies."

                # Seed the history for resume
                messages = [ {
                    "role": "user", "content": "Fact check resume", "origin": "resume", "display": "hide"
                }, {
                    "role": "assistant", "content": "Resume fact checked.", "origin": "resume", "display": "hide"
                } ]
                # Do not add this to the LLM history; it is only used for UI presentation
                stuffingMessage = Message(prompt="Fact check resume")
                stuffingMessage.response = "Resume fact checked."
                stuffingMessage.metadata["origin"] = "resume"
                stuffingMessage.metadata["display"] = "hide"
                stuffingMessage.actions = [ "fact_check" ]
                logging.info("TODO: Switch this to use actions to keep the UI from showing it")
                conversation.add_message(stuffingMessage)

                # For all future calls to resume, keep using the cached system prompt
                logging.info("TODO: Create a system_resume_QA prompt to use for the resume session")
                session.system_prompt = system_prompt

                # Switch to fact_check session for LLM responses
                message.metadata["origin"] = "fact_check"
                session = context.get_or_create_session("fact_check", system_prompt=system_fact_check)

                llm_history = session.llm_history = []
                user_history = session.user_history = []

            case _:
                raise Exception(f"Invalid chat session_type: {session.session_type}")

        conversation.add_message(message)
        # llm_history.append({"role": "user", "content": message.preamble + content})
        # user_history.append({"role": "user", "content": content, "origin": message.metadata["origin"]})
        # message.metadata["full_query"] = llm_history[-1]["content"]

        # Uses cached system_prompt as session.system_prompt may have been updated for follow up questions
        messages = create_system_message(system_prompt)
        if context.message_history_length:
            to_add = conversation.messages[-context.message_history_length:]
        else:
            to_add = conversation.messages
        for m in to_add:
            messages.extend([ {
                "role": "user",
                "content": m.content,
            }, {
                "role": "assistant",
                "content": m.response,
            } ])

        message.content = message.preamble + message.prompt

        # To send to the LLM
        messages.append({
            "role": "user",
            "content": message.content
        })

        # Store the full query (system prompt + preamble + prompt) on the message
        message.content = f"""
<|system_prompt|>
{system_prompt}

{message.preamble}
{message.prompt}"""

        # Estimate token length of new messages
        ctx_size = self.get_optimal_ctx_size(context.get_or_create_session(process_type).context_tokens, messages=message.prompt)

        if len(conversation.messages) > 2:
            processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..."
        else:
            match session.session_type:
                case "job_description":
                    processing_message = f"Generating {'RAG augmented ' if enable_rag else ''}resume..."
                case "resume":
                    processing_message = f"Fact Checking {'RAG augmented ' if enable_rag else ''}resume..."
                case _:
                    processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..."

        yield {"status": "processing", "message": processing_message, "num_ctx": ctx_size}

        # Query the LLM, with tools enabled when appropriate
        try:
            if enable_tools:
                response = self.llm.chat(model=self.model, messages=messages, tools=llm_tools(context.tools), options={ "num_ctx": ctx_size })
            else:
                response = self.llm.chat(model=self.model, messages=messages, options={ "num_ctx": ctx_size })
        except Exception as e:
            logging.exception({ "model": self.model, "error": str(e) })
            yield {"status": "error", "message": "An error occurred communicating with the LLM"}
            self.processing = False
            return

        message.metadata["eval_count"] += response["eval_count"]
        message.metadata["eval_duration"] += response["eval_duration"]
        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
        session.context_tokens = response["prompt_eval_count"] + response["eval_count"]

        tools_used = []

        yield {"status": "processing", "message": "Initial response received..."}

        if "tool_calls" in response.get("message", {}):
            yield {"status": "processing", "message": "Processing tool calls..."}

            tool_message = response["message"]
            tool_result = None

            # Process all yielded items from the handler
            async for item in self.handle_tool_calls(tool_message):
                if isinstance(item, tuple) and len(item) == 2:
                    # This is the final result tuple (tool_result, tools_used)
                    tool_result, tools_used = item
                else:
                    # This is a status update, forward it
                    yield item

            message_dict = {
                "role": tool_message.get("role", "assistant"),
                "content": tool_message.get("content", "")
            }

            if "tool_calls" in tool_message:
                message_dict["tool_calls"] = [
                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
                    for tc in tool_message["tool_calls"]
                ]

            pre_add_index = len(messages)
            messages.append(message_dict)

            if isinstance(tool_result, list):
                messages.extend(tool_result)
            else:
                if tool_result:
                    messages.append(tool_result)

            message.metadata["tools"] = tools_used

            # Estimate token length of new messages
            ctx_size = self.get_optimal_ctx_size(session.context_tokens, messages=messages[pre_add_index:])
            yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
            # Decrease creativity when processing tool call requests
            response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size })  # , "temperature": 0.5 })
            message.metadata["eval_count"] += response["eval_count"]
            message.metadata["eval_duration"] += response["eval_duration"]
            message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
            message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
            session.context_tokens = response["prompt_eval_count"] + response["eval_count"]

        reply = response["message"]["content"]
        message.response = reply
        message.metadata["origin"] = session.session_type
        # final_message = {"role": "assistant", "content": reply }

        # # history is provided to the LLM and should not have additional metadata
        # llm_history.append(final_message)

        # user_history is provided to the REST API and does not include CONTEXT
        # It does include metadata
        # final_message["metadata"] = message.metadata
        # user_history.append({**final_message, "origin": message.metadata["origin"]})

        # Return the REST API response with metadata
        yield {
            "status": "done",
            "message": {
                **message.model_dump(mode='json'),
            }
        }

        # except Exception as e:
        #     logging.exception({ "model": self.model, "origin": session_type, "content": content, "error": str(e) })
        #     yield {"status": "error", "message": f"An error occurred: {str(e)}"}

        # finally:
        #     self.processing = False

        self.processing = False
        return

    def run(self, host="0.0.0.0", port=WEB_PORT, **kwargs):
        try:
            if self.ssl_enabled:
                logging.info(f"Starting web server at https://{host}:{port}")
                uvicorn.run(
                    self.app,
                    host=host,
                    port=port,
                    log_config=None,
                    ssl_keyfile=defines.key_path,
                    ssl_certfile=defines.cert_path
                )
            else:
                logging.info(f"Starting web server at http://{host}:{port}")
                uvicorn.run(
                    self.app,
                    host=host,
                    port=port,
                    log_config=None
                )
        except KeyboardInterrupt:
            if self.observer:
                self.observer.stop()
                self.observer.join()

# %%

# Main function to run everything
def main():
    global model

    # Parse command-line arguments
    args = parse_args()

    # Setup logging based on the provided level
    setup_logging(args.level)

    warnings.filterwarnings(
        "ignore",
        category=FutureWarning,
        module="sklearn.*"
    )

    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        module="umap.*"
    )

    llm = ollama.Client(host=args.ollama_server)
    model = args.ollama_model

    web_server = WebServer(llm, model)

    web_server.run(host=args.web_host, port=args.web_port)

main()