# %%
# Imports [standard]
# Standard library modules (no try-except needed)
import argparse
import asyncio
import json
import logging
import os
import re
import uuid
import subprocess
import math


def try_import(module_name, pip_name=None):
    try:
        __import__(module_name)
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f"  pip install {pip_name or module_name}")


# Third-party modules with import checks
try_import('ollama')
try_import('requests')
try_import('bs4', 'beautifulsoup4')
try_import('fastapi')
try_import('uvicorn')
try_import('sklearn')
try_import('numpy')
try_import('umap')

import ollama
import requests
from bs4 import BeautifulSoup
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import numpy as np
# from sklearn.manifold import TSNE
import umap

from utils import (
    rag as Rag,
    defines
)

from tools import (
    DateTime,
    WeatherForecast,
    TickerValue,
    tools
)

rags = [
    { "name": "JPK", "enabled": True, "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects." },
    # { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
]

def get_installed_ram():
    try:
        with open('/proc/meminfo', 'r') as f:
            meminfo = f.read()
        match = re.search(r'MemTotal:\s+(\d+)', meminfo)
        if match:
            return f"{math.floor(int(match.group(1)) / 1000**2)}GB"  # Convert kB to GB
    except Exception as e:
        return f"Error retrieving RAM: {e}"


def get_graphics_cards():
    gpus = []
    try:
        # Run the ze-monitor utility to enumerate devices
        result = subprocess.run(['ze-monitor'], capture_output=True, text=True, check=True)

        # Clean up the output (remove leading/trailing whitespace and newlines)
        output = result.stdout.strip()
        for index in range(len(output.splitlines())):
            result = subprocess.run(['ze-monitor', '--device', f'{index+1}', '--info'], capture_output=True, text=True, check=True)
            gpu_info = result.stdout.strip().splitlines()
            gpu = {
                "discrete": True,  # Assume it's discrete initially
                "name": None,
                "memory": None
            }
            gpus.append(gpu)
            for line in gpu_info:
                match = re.match(r'^Device: [^(]*\((.*)\)', line)
                if match:
                    gpu["name"] = match.group(1)
                    continue

                match = re.match(r'^\s*Memory: (.*)', line)
                if match:
                    gpu["memory"] = match.group(1)
                    continue

                match = re.match(r'^.*Is integrated with host: Yes.*', line)
                if match:
                    gpu["discrete"] = False
                    continue

        return gpus
    except Exception as e:
        return f"Error retrieving GPU info: {e}"


def get_cpu_info():
    try:
        with open('/proc/cpuinfo', 'r') as f:
            cpuinfo = f.read()
        model_match = re.search(r'model name\s+:\s+(.+)', cpuinfo)
        cores_match = re.findall(r'processor\s+:\s+\d+', cpuinfo)
        if model_match and cores_match:
            return f"{model_match.group(1)} with {len(cores_match)} cores"
    except Exception as e:
        return f"Error retrieving CPU info: {e}"


def system_info(model):
    return {
        "System RAM": get_installed_ram(),
        "Graphics Card": get_graphics_cards(),
        "CPU": get_cpu_info(),
        "LLM Model": model,
        "Context length": defines.max_context
    }

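# Illustrative shape of the dict returned by system_info(); values are
# hypothetical and depend on the host hardware:
#   {
#       "System RAM": "64GB",
#       "Graphics Card": [{"discrete": True, "name": "...", "memory": "..."}],
#       "CPU": "... with 16 cores",
#       "LLM Model": "<model name>",
#       "Context length": defines.max_context,
#   }
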
# %%
# Defaults
OLLAMA_API_URL = defines.ollama_api_url
MODEL_NAME = defines.model
LOG_LEVEL = "info"
USE_TLS = False
WEB_HOST = "0.0.0.0"
WEB_PORT = 8911
DEFAULT_HISTORY_LENGTH = 5

# %%
# Globals
context_tag = "INFO"
system_message = f"""
Launched on {DateTime()}.

When answering queries, follow these steps:

1. First analyze the query to determine if real-time information might be helpful
2. Even when [{context_tag}] is provided, consider whether the tools would provide more current or comprehensive information
3. Use the provided tools whenever they would enhance your response, regardless of whether context is also available
4. When presenting weather forecasts, include relevant emojis immediately before the corresponding text. For example, for a sunny day, say "☀️ Sunny", or if the forecast says there will be "rain showers", say "🌧️ Rain showers". Use this mapping for weather emojis: Sunny: ☀️, Cloudy: ☁️, Rainy: 🌧️, Snowy: ❄️
5. When both [{context_tag}] and tool outputs are relevant, synthesize information from both sources to provide the most complete answer
6. Always prioritize the most up-to-date and relevant information, whether it comes from [{context_tag}] or tools
7. If [{context_tag}] and tool outputs contain conflicting information, prefer the tool outputs as they likely represent more current data

Always use tools and [{context_tag}] when possible. Be concise, and never make up information. If you do not know the answer, say so.
""".strip()

system_generate_resume = f"""
You are a professional resume writer. Your task is to write a polished, tailored resume for a specific job based only on the individual's [WORK HISTORY].

When answering queries, follow these steps:

1. You must not invent or assume any information not explicitly present in the [WORK HISTORY].
2. Analyze the [JOB DESCRIPTION] to identify skills required for the job.
3. Use the [JOB DESCRIPTION] provided to guide the focus, tone, and relevant skills or experience to highlight from the [WORK HISTORY].
4. Identify and emphasize the experiences, achievements, and responsibilities from the [WORK HISTORY] that best align with the [JOB DESCRIPTION].
5. Do not use the [JOB DESCRIPTION] skills unless listed in [WORK HISTORY].

Structure the resume professionally with the following sections where applicable:

* Name: Use full name.
* Professional Summary: A 2-4 sentence overview tailored to the job.
* Skills: A bullet list of key skills derived from the work history and relevant to the job.
* Professional Experience: A detailed list of roles, achievements, and responsibilities from the work history that relate to the job.
* Education: Include only if available in the work history.

Do not include any information unless it is provided in [WORK HISTORY].
Ensure the language is clear, concise, and aligned with industry standards for professional resumes.
"""

system_fact_check = f"""
You are a professional resume fact checker. Your task is to identify any inaccuracies in the [RESUME] based on the individual's [WORK HISTORY].

If there are inaccuracies, list them in a bullet point format.

When answering queries, follow these steps:
1. You must not invent or assume any information not explicitly present in the [WORK HISTORY].
2. Analyze the [RESUME] to identify any discrepancies or inaccuracies based on the [WORK HISTORY].
"""

tool_log = []
command_log = []
model = None
client = None
web_server = None

# %%
# Cmd line overrides
def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
    parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
    parser.add_argument("--web-host", type=str, default=WEB_HOST, help=f"Host to launch the web server on. default={WEB_HOST}")
    parser.add_argument("--web-port", type=int, default=WEB_PORT, help=f"Port to launch the web server on. default={WEB_PORT}")
    parser.add_argument('--level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default=LOG_LEVEL, help=f'Set the logging level. default={LOG_LEVEL}')
    return parser.parse_args()


def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')

    logging.info(f"Logging is set to {level} level.")

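# Example invocation (illustrative; the script filename and model name are placeholders):
#   python server.py --ollama-server http://localhost:11434 \
#       --ollama-model llama3.1 --web-port 8911 --level DEBUG
# All flags are optional; defaults come from utils.defines and the constants above.
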
# %%

async def AnalyzeSite(url, question):
    """
    Fetches content from a URL, extracts the text, and uses Ollama to summarize it.

    Args:
        url (str): The URL of the website to summarize
        question (str): The question to answer about the site's contents

    Returns:
        dict: source, content, and metadata for the answer on success,
        or a str error message on failure.
    """
    global model, client
    try:
        # Fetch the webpage
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        logging.info(f"Fetching {url}")
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        logging.info(f"{url} returned. Processing...")

        # Parse the HTML
        soup = BeautifulSoup(response.text, 'html.parser')

        # Remove script and style elements
        for script in soup(["script", "style"]):
            script.extract()

        # Get text content
        text = soup.get_text(separator=' ', strip=True)

        # Clean up text (remove extra whitespace)
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = ' '.join(chunk for chunk in chunks if chunk)

        # Limit text length if needed (Ollama may have token limits)
        max_chars = 100000
        if len(text) > max_chars:
            text = text[:max_chars] + "..."

        # logging.info(f"Requesting summary of: {text}")

        # Generate summary using Ollama
        prompt = f"CONTENTS:\n\n{text}\n\n{question}"
        response = client.generate(model=model,
                                   system=f"You are given the contents of {url}. Answer the question about the contents",
                                   prompt=prompt)

        # logging.info(response['response'])

        return {
            'source': 'summarizer-llm',
            'content': response['response'],
            'metadata': DateTime()
        }

    except requests.exceptions.RequestException as e:
        return f"Error fetching the URL: {str(e)}"
    except Exception as e:
        return f"Error processing the website content: {str(e)}"

# %%

# %%
def is_valid_uuid(value):
    try:
        uuid_obj = uuid.UUID(value, version=4)
        return str(uuid_obj) == value
    except (ValueError, TypeError):
        return False


def default_tools(tools):
    return [{**tool, "enabled": True} for tool in tools]


def find_summarize_tool(tools):
    return [{**tool, "enabled": True} for tool in tools if tool.get("name", "") == "AnalyzeSite"]


def llm_tools(tools):
    return [tool for tool in tools if tool.get("enabled", False)]

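# These helpers assume each entry in `tools` follows the Ollama function-tool
# schema, roughly (hypothetical example):
#   {"type": "function", "function": {"name": "WeatherForecast", ...}, "enabled": True}
# default_tools() forces "enabled" on, and llm_tools() filters out disabled
# entries before the list is handed to client.chat(tools=...).
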
# %%
async def handle_tool_calls(message):
    """
    Process tool calls and yield status updates along the way.
    The last yielded item will be a tuple containing (tool_result, tools_used).
    """
    tools_used = []
    all_responses = []

    for i, tool_call in enumerate(message['tool_calls']):
        arguments = tool_call['function']['arguments']
        tool = tool_call['function']['name']

        # Yield status update before processing each tool
        yield {"status": "processing", "message": f"Processing tool {i+1}/{len(message['tool_calls'])}: {tool}..."}

        # Process the tool based on its type
        match tool:
            case 'TickerValue':
                ticker = arguments.get('ticker')
                if not ticker:
                    ret = None
                else:
                    ret = TickerValue(ticker)
                tools_used.append({ "tool": f"{tool}({ticker})", "result": ret })

            case 'AnalyzeSite':
                url = arguments.get('url')
                question = arguments.get('question', 'what is the summary of this content?')

                # Additional status update for long-running operations
                yield {"status": "processing", "message": f"Retrieving and summarizing content from {url}..."}
                ret = await AnalyzeSite(url, question)
                tools_used.append({ "tool": f"{tool}('{url}', '{question}')", "result": ret })

            case 'DateTime':
                tz = arguments.get('timezone')
                ret = DateTime(tz)
                tools_used.append({ "tool": f"{tool}('{tz}')", "result": ret })

            case 'WeatherForecast':
                city = arguments.get('city')
                state = arguments.get('state')

                yield {"status": "processing", "message": f"Fetching weather data for {city}, {state}..."}
                ret = WeatherForecast(city, state)
                tools_used.append({ "tool": f"{tool}('{city}', '{state}')", "result": ret })

            case _:
                ret = None

        # Build response for this tool
        tool_response = {
            "role": "tool",
            "content": str(ret),
            "name": tool_call['function']['name']
        }
        all_responses.append(tool_response)

    # Yield the final result as the last item
    final_result = all_responses[0] if len(all_responses) == 1 else all_responses
    yield (final_result, tools_used)

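# Consumption sketch (this mirrors how WebServer.chat() drives the generator):
#
#   async for item in handle_tool_calls(message):
#       if isinstance(item, tuple) and len(item) == 2:
#           tool_result, tools_used = item  # final item
#       else:
#           ...  # status dict, e.g. {"status": "processing", "message": "..."}
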
# %%
class WebServer:
    def __init__(self, logging, client, model=MODEL_NAME):
        self.logging = logging
        self.app = FastAPI()
        self.contexts = {}
        self.client = client
        self.model = model
        self.processing = False
        self.file_watcher = None
        self.observer = None

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["http://battle-linux.ketrenos.com:3000"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        @self.app.on_event("startup")
        async def startup_event():
            # Start the file watcher
            self.observer, self.file_watcher = Rag.start_file_watcher(
                llm=client,
                watch_directory=defines.doc_dir,
                recreate=False  # Don't recreate if exists
            )

            print(f"API started with {self.file_watcher.collection.count()} documents in the collection")

        @self.app.on_event("shutdown")
        async def shutdown_event():
            if self.observer:
                self.observer.stop()
                self.observer.join()
                print("File watcher stopped")

        self.setup_routes()

    def setup_routes(self):
        @self.app.get('/')
        async def root():
            context = self.create_context()
            self.logging.info(f"Redirecting non-session to {context['id']}")
            return RedirectResponse(url=f"/{context['id']}", status_code=307)
            # return JSONResponse({"redirect": f"/{context['id']}"})

        @self.app.get("/api/query")
        async def query_documents(query: str, top_k: int = 3):
            """Query the RAG system with the given prompt."""
            if not self.file_watcher:
                return

            results = self.file_watcher.find_similar(query, top_k=top_k)
            return {
                "query": query,
                "results": [
                    {
                        "content": doc,
                        "metadata": meta,
                        "distance": dist
                    }
                    for doc, meta, dist in zip(
                        results["documents"],
                        results["metadatas"],
                        results["distances"]
                    )
                ]
            }

        @self.app.post("/api/refresh/{file_path:path}")
        async def refresh_document(file_path: str, background_tasks: BackgroundTasks):
            """Manually refresh a specific document in the collection."""
            if not self.file_watcher:
                return

            full_path = os.path.join(defines.doc_dir, file_path)

            if not os.path.exists(full_path):
                return {"status": "error", "message": "File not found"}

            # Schedule the update in the background
            background_tasks.add_task(
                self.file_watcher.process_file_update, full_path
            )

            return {
                "status": "success",
                "message": f"Document refresh scheduled for {file_path}"
            }

        # @self.app.post("/api/refresh-all")
        # async def refresh_all_documents():
        #     """Refresh all documents in the collection."""
        #     if not self.file_watcher:
        #         return
        #
        #     # Re-initialize file hashes and process all files
        #     self.file_watcher._initialize_file_hashes()
        #
        #     # Schedule updates for all files
        #     file_paths = self.file_watcher.file_hashes.keys()
        #     tasks = [self.file_watcher.process_file_update(path) for path in file_paths]
        #
        #     # Wait for all updates to complete
        #     await asyncio.gather(*tasks)
        #
        #     return {
        #         "status": "success",
        #         "message": f"Refreshed {len(file_paths)} documents",
        #         "document_count": file_watcher.collection.count()
        #     }

        @self.app.put('/api/tsne/{context_id}')
        async def put_tsne(context_id: str, request: Request):
            if not self.file_watcher:
                return

            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)

            context = self.upsert_context(context_id)

            try:
                data = await request.json()
                dimensions = data.get('dimensions', 2)
            except:
                dimensions = 2

            try:
                result = self.file_watcher.collection.get(include=['embeddings', 'documents', 'metadatas'])
                vectors = np.array(result['embeddings'])
                umap_model = umap.UMAP(n_components=dimensions, random_state=42)
                embedding = umap_model.fit_transform(vectors)
                context['umap_model'] = umap_model
                result['embeddings'] = embedding.tolist()
                return JSONResponse(result)

            except Exception as e:
                logging.error(e)
                return JSONResponse({"error": str(e)}, 500)

        @self.app.put('/api/similarity/{context_id}')
        async def put_similarity(context_id: str, request: Request):
            if not self.file_watcher:
                return

            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)

            context = self.upsert_context(context_id)
            if not context.get("umap_model"):
                return JSONResponse({"error": "No umap_model found in context"}, status_code=404)

            try:
                data = await request.json()
                query = data.get('query', '')
            except:
                query = ''
            if not query:
                return JSONResponse({"error": "No query provided"}, status_code=400)

            try:
                chroma_results = self.file_watcher.find_similar(query=query, top_k=10)
                if not chroma_results:
                    return JSONResponse({"error": "No results found"}, status_code=404)
                chroma_embedding = chroma_results["query_embedding"]
                normalized = (chroma_embedding - chroma_embedding.min()) / (chroma_embedding.max() - chroma_embedding.min())
                vector_embedding = context["umap_model"].transform([normalized])[0].tolist()
                return JSONResponse({ **chroma_results, "query": query, "vector_embedding": vector_embedding })

            except Exception as e:
                logging.error(e)
                # return JSONResponse({"error": str(e)}, 500)

        @self.app.put('/api/reset/{context_id}')
        async def put_reset(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            data = await request.json()
            try:
                response = {}
                for reset in data["reset"]:
                    match reset:
                        case "system-prompt":
                            context["system"] = [{"role": "system", "content": system_message}]
                            response["system-prompt"] = { "system-prompt": system_message }
                        case "rags":
                            context["rags"] = rags.copy()
                            response["rags"] = context["rags"]
                        case "tools":
                            context["tools"] = default_tools(tools)
                            response["tools"] = context["tools"]
                        case "history":
                            context["llm_history"] = []
                            context["user_history"] = []
                            response["history"] = []
                            context["context_tokens"] = round(len(str(context["system"])) * 3 / 4)  # Estimate context usage
                            response["context_used"] = context["context_tokens"]
                        case "message-history-length":
                            context["message_history_length"] = DEFAULT_HISTORY_LENGTH
                            response["message-history-length"] = DEFAULT_HISTORY_LENGTH

                if not response:
                    return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system-prompt}"})
                else:
                    self.save_context(context_id)
                    return JSONResponse(response)

            except:
                return JSONResponse({ "error": "Usage: { reset: rags|tools|history|system-prompt}"})

        @self.app.put('/api/tunables/{context_id}')
        async def put_tunables(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            data = await request.json()
            for k in data.keys():
                match k:
                    case "system-prompt":
                        system_prompt = data[k].strip()
                        if not system_prompt:
                            return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
                        context["system"] = [{"role": "system", "content": system_prompt}]
                        self.save_context(context_id)
                        return JSONResponse({ "system-prompt": system_prompt })
                    case "message-history-length":
                        value = max(0, int(data[k]))
                        context["message_history_length"] = value
                        self.save_context(context_id)
                        return JSONResponse({ "message-history-length": value })
                    case _:
                        return JSONResponse({ "error": f"Unrecognized tunable {k}"}, 404)

        @self.app.get('/api/tunables/{context_id}')
        async def get_tunables(context_id: str):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            return JSONResponse({
                "system-prompt": context["system"][0]["content"],
                "message-history-length": context["message_history_length"]
            })

        @self.app.get('/api/resume/{context_id}')
        async def get_resume(context_id: str):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            return JSONResponse(context["resume_history"])

        @self.app.get('/api/system-info/{context_id}')
        async def get_system_info(context_id: str):
            return JSONResponse(system_info(self.model))

        @self.app.post('/api/chat/{context_id}')
        async def chat_endpoint(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            data = await request.json()

            # Create a custom generator that ensures flushing
            async def flush_generator():
                async for message in self.chat(context=context, content=data['content']):
                    # Convert to JSON and add newline
                    yield json.dumps(message) + "\n"
                    # Save the history as it's generated
                    self.save_context(context_id)
                    # Explicitly flush after each yield
                    await asyncio.sleep(0)  # Allow the event loop to process the write

            # Return StreamingResponse with appropriate headers
            return StreamingResponse(
                flush_generator(),
                media_type="application/json",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"  # Prevents Nginx buffering if you're using it
                }
            )

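        # Illustrative client for this streaming endpoint (not executed by the
        # server). Assumes the service is reachable at http://localhost:8911 and
        # that a context id was obtained from POST /api/context:
        #
        #   import json, requests
        #   with requests.post(f"http://localhost:8911/api/chat/{context_id}",
        #                      json={"content": "What is the weather in Portland, Oregon?"},
        #                      stream=True) as r:
        #       for line in r.iter_lines():
        #           if line:
        #               print(json.loads(line))  # status updates, then {"status": "done", ...}
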
        @self.app.post('/api/generate-resume/{context_id}')
        async def post_generate_resume(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            data = await request.json()

            # Create a custom generator that ensures flushing
            async def flush_generator():
                async for message in self.generate_resume(context=context, content=data['content']):
                    # Convert to JSON and add newline
                    yield json.dumps(message) + "\n"
                    # Save the history as it's generated
                    self.save_context(context_id)
                    # Explicitly flush after each yield
                    await asyncio.sleep(0)  # Allow the event loop to process the write

            # Return StreamingResponse with appropriate headers
            return StreamingResponse(
                flush_generator(),
                media_type="application/json",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"  # Prevents Nginx buffering if you're using it
                }
            )

        @self.app.post('/api/fact-check/{context_id}')
        async def post_fact_check(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            data = await request.json()

            # Create a custom generator that ensures flushing
            async def flush_generator():
                async for message in self.fact_check(context=context, content=data['content']):
                    # Convert to JSON and add newline
                    yield json.dumps(message) + "\n"
                    # Save the history as it's generated
                    self.save_context(context_id)
                    # Explicitly flush after each yield
                    await asyncio.sleep(0)  # Allow the event loop to process the write

            # Return StreamingResponse with appropriate headers
            return StreamingResponse(
                flush_generator(),
                media_type="application/json",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"  # Prevents Nginx buffering if you're using it
                }
            )

        @self.app.post('/api/context')
        async def create_context():
            context = self.create_context()
            self.logging.info(f"Generated new session as {context['id']}")
            return JSONResponse(context)

        @self.app.get('/api/history/{context_id}')
        async def get_history(context_id: str):
            context = self.upsert_context(context_id)
            return JSONResponse(context["user_history"])

        @self.app.get('/api/tools/{context_id}')
        async def get_tools(context_id: str):
            context = self.upsert_context(context_id)
            return JSONResponse(context["tools"])

        @self.app.put('/api/tools/{context_id}')
        async def put_tools(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            try:
                data = await request.json()
                modify = data["tool"]
                enabled = data["enabled"]
                for tool in context["tools"]:
                    if modify == tool["function"]["name"]:
                        tool["enabled"] = enabled
                        self.save_context(context_id)
                        return JSONResponse(context["tools"])
                return JSONResponse({ "status": f"{modify} not found in tools." }, status_code=404)
            except:
                return JSONResponse({ "status": "error" }, status_code=405)

        @self.app.get('/api/rags/{context_id}')
        async def get_rags(context_id: str):
            context = self.upsert_context(context_id)
            return JSONResponse(context["rags"])

        @self.app.put('/api/rags/{context_id}')
        async def put_rags(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            try:
                data = await request.json()
                modify = data["tool"]
                enabled = data["enabled"]
                for tool in context["rags"]:
                    if modify == tool["name"]:
                        tool["enabled"] = enabled
                        self.save_context(context_id)
                        return JSONResponse(context["rags"])
                return JSONResponse({ "status": f"{modify} not found in rags." }, status_code=404)
            except:
                return JSONResponse({ "status": "error" }, status_code=405)

        @self.app.get('/api/context-status/{context_id}')
        async def get_context_status(context_id):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            return JSONResponse({"context_used": context["context_tokens"], "max_context": defines.max_context})

        @self.app.get('/api/health')
        async def health_check():
            return JSONResponse({"status": "healthy"})

        @self.app.get('/{path:path}')
        async def serve_static(path: str):
            full_path = os.path.join(defines.static_content, path)
            if os.path.exists(full_path) and os.path.isfile(full_path):
                self.logging.info(f"Serve static request for {full_path}")
                return FileResponse(full_path)
            self.logging.info(f"Serve index.html for {path}")
            return FileResponse(os.path.join(defines.static_content, 'index.html'))

    def save_context(self, session_id):
        """
        Serialize a context dictionary to a file in the sessions directory.

        Args:
            session_id: UUID string for the context. If it doesn't exist, it is created

        Returns:
            The session_id used for the file
        """
        context = self.upsert_context(session_id)

        # Create sessions directory if it doesn't exist
        if not os.path.exists(defines.session_dir):
            os.makedirs(defines.session_dir)

        # Create the full file path
        file_path = os.path.join(defines.session_dir, session_id)

        # The UMAP model is not JSON serializable; drop it while writing
        umap_model = context.get("umap_model")
        if umap_model:
            del context["umap_model"]
        # Serialize the data to JSON and write to file
        with open(file_path, 'w') as f:
            json.dump(context, f)
        if umap_model:
            context["umap_model"] = umap_model
        return session_id

    def load_context(self, session_id):
        """
        Load a serialized context dictionary from a file in the sessions directory.

        Args:
            session_id: UUID string for the filename

        Returns:
            The deserialized dictionary, or a new context if it doesn't exist on disk.
        """
        file_path = os.path.join(defines.session_dir, session_id)

        # Check if the file exists
        if not os.path.exists(file_path):
            return self.create_context(session_id)

        # Read and deserialize the data
        with open(file_path, 'r') as f:
            self.contexts[session_id] = json.load(f)

        return self.contexts[session_id]

    def create_context(self, context_id=None):
        if not context_id:
            context_id = str(uuid.uuid4())
        system_context = [{"role": "system", "content": system_message}]
        context = {
            "id": context_id,
            "system": system_context,
            "system_generate_resume": system_generate_resume,
            "llm_history": [],
            "user_history": [],
            "tools": default_tools(tools),
            "resume_history": [],
            "rags": rags.copy(),
            "context_tokens": round(len(str(system_context)) * 3 / 4),  # Estimate context usage
            "message_history_length": 5  # Number of messages to supply in context
        }
        logging.info(f"{context_id} created and added to sessions.")
        self.contexts[context_id] = context
        return context

    def get_optimal_ctx_size(self, context, messages, ctx_buffer=4096):
        # Estimate tokens as roughly 3/4 of the character count, then clamp the
        # result between 2048 and the model's maximum context length.
        ctx = round(context + len(str(messages)) * 3 / 4)
        return min(defines.max_context, max(2048, ctx + ctx_buffer))

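    # Illustrative arithmetic for the estimate above: with context=1000 tokens and a
    # 2000-character message, ctx ≈ 1000 + 1500 = 2500; adding the 4096 buffer gives
    # 6596, which is then capped at defines.max_context (and floored at 2048).
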
    def upsert_context(self, context_id):
        if not context_id:
            logging.warning("No context ID provided. Creating a new context.")
            return self.create_context()
        if context_id in self.contexts:
            logging.info(f"Context {context_id} found.")
            return self.contexts[context_id]
        logging.info(f"Context {context_id} not in memory. Loading from disk or creating a new context.")
        return self.load_context(context_id)

    async def chat(self, context, content):
        if not self.file_watcher:
            return

        content = content.strip()
        if not content:
            yield {"status": "error", "message": "Invalid request"}
            return

        if self.processing:
            yield {"status": "error", "message": "Busy"}
            return

        self.processing = True

        llm_history = context["llm_history"]
        user_history = context["user_history"]
        metadata = {
            "rag": {},
            "tools": [],
            "eval_count": 0,
            "eval_duration": 0,
            "prompt_eval_count": 0,
            "prompt_eval_duration": 0,
        }
        rag_docs = []
        for rag in context["rags"]:
            if rag["enabled"] and rag["name"] == "JPK":  # Only support JPK rag right now...
                yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
                chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
                if chroma_results:
                    rag_docs.extend(chroma_results["documents"])
                    metadata["rag"] = { "name": rag["name"], **chroma_results }
        preamble = ""
        if len(rag_docs):
            preamble = f"""
1. Respond to this query: {content}
2. If there is information in this context to enhance the answer, do so:
[{context_tag}]:\n"""
            for doc in rag_docs:
                preamble += doc
            preamble += f"\n[/{context_tag}]\nUse all of that information to respond to: "

        # Add the query (with any RAG preamble) to the LLM history, and the
        # plain query to the user-visible history
        llm_history.append({"role": "user", "content": preamble + content})
        user_history.append({"role": "user", "content": content})

        if context["message_history_length"]:
            messages = context["system"] + llm_history[-context["message_history_length"]:]
        else:
            messages = context["system"] + llm_history

        try:
            # Estimate token length of new messages
            ctx_size = self.get_optimal_ctx_size(context["context_tokens"], messages=llm_history[-1]["content"])
            yield {"status": "processing", "message": "Processing request...", "num_ctx": ctx_size}

            # Ask the LLM for an initial (non-streamed) response, with tools enabled
            response = self.client.chat(model=self.model, messages=messages, tools=llm_tools(context["tools"]), options={ 'num_ctx': ctx_size })
            metadata["eval_count"] += response['eval_count']
            metadata["eval_duration"] += response['eval_duration']
            metadata["prompt_eval_count"] += response['prompt_eval_count']
            metadata["prompt_eval_duration"] += response['prompt_eval_duration']
            context["context_tokens"] = response['prompt_eval_count'] + response['eval_count']

            tools_used = []

            yield {"status": "processing", "message": "Initial response received..."}

            if 'tool_calls' in response.get('message', {}):
                yield {"status": "processing", "message": "Processing tool calls..."}

                message = response['message']
                tool_result = None

                # Process all yielded items from the handler
                async for item in handle_tool_calls(message):
                    if isinstance(item, tuple) and len(item) == 2:
                        # This is the final result tuple (tool_result, tools_used)
                        tool_result, tools_used = item
                    else:
                        # This is a status update, forward it
                        yield item

                message_dict = {
                    'role': message.get('role', 'assistant'),
                    'content': message.get('content', '')
                }

                if 'tool_calls' in message:
                    message_dict['tool_calls'] = [
                        {'function': {'name': tc['function']['name'], 'arguments': tc['function']['arguments']}}
                        for tc in message['tool_calls']
                    ]

                pre_add_index = len(messages)
                messages.append(message_dict)

                if isinstance(tool_result, list):
                    messages.extend(tool_result)
                else:
                    messages.append(tool_result)

                metadata["tools"] = tools_used

                # Estimate token length of new messages
                ctx_size = self.get_optimal_ctx_size(context["context_tokens"], messages=messages[pre_add_index:])
                yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
                # Decrease creativity when processing tool call requests
                response = self.client.chat(model=self.model, messages=messages, stream=False, options={ 'num_ctx': ctx_size })  # , "temperature": 0.5 })
                metadata["eval_count"] += response['eval_count']
                metadata["eval_duration"] += response['eval_duration']
                metadata["prompt_eval_count"] += response['prompt_eval_count']
                metadata["prompt_eval_duration"] += response['prompt_eval_duration']
                context["context_tokens"] = response['prompt_eval_count'] + response['eval_count']

            reply = response['message']['content']
            final_message = {"role": "assistant", "content": reply }

            # history is provided to the LLM and should not have additional metadata
            llm_history.append(final_message)
            final_message["metadata"] = metadata

            # user_history is provided to the REST API and does not include CONTEXT or metadata
            user_history.append(final_message)

            # Return the REST API with metadata
            yield {"status": "done", "message": final_message }

        except Exception as e:
            logging.exception({ 'model': self.model, 'messages': messages, 'error': str(e) })
            yield {"status": "error", "message": f"An error occurred: {str(e)}"}

        finally:
            self.processing = False

    async def generate_resume(self, context, content):
        if not self.file_watcher:
            return

        content = content.strip()
        if not content:
            yield {"status": "error", "message": "Invalid request"}
            return

        if self.processing:
            yield {"status": "error", "message": "Busy"}
            return

        self.processing = True
        resume_history = context["resume_history"]
        resume = {
            "job_description": content,
            "resume": "",
            "metadata": {},
            "rag": "",
            "fact_check": ""
        }

        metadata = {
            "rag": {},
            "tools": [],
            "eval_count": 0,
            "eval_duration": 0,
            "prompt_eval_count": 0,
            "prompt_eval_duration": 0,
        }
        rag_docs = []
        with open(defines.resume_doc, 'r') as f:
            resume_doc = f.read()
        rag_docs.append(resume_doc)
        for rag in context["rags"]:
            if rag["enabled"] and rag["name"] == "JPK":  # Only support JPK rag right now...
                yield {"status": "processing", "message": f"Checking RAG context {rag['name']}..."}
                chroma_results = self.file_watcher.find_similar(query=content, top_k=10)
                if chroma_results:
                    rag_docs.extend(chroma_results["documents"])
                    metadata["rag"] = { "name": rag["name"], **chroma_results }
        preamble = f"The current time is {DateTime()}\n"
        preamble += "[WORK HISTORY]:\n"
        for doc in rag_docs:
            preamble += f"{doc}\n"
            resume["rag"] += f"{doc}\n"
        preamble += "\n[/WORK HISTORY]\n"

        content = f"{preamble}\nUse the above WORK HISTORY to create the resume for this JOB DESCRIPTION. Do not use the JOB DESCRIPTION skills as skills the user possesses unless listed in WORK HISTORY:\n[JOB DESCRIPTION]\n{content}\n[/JOB DESCRIPTION]\n"

        try:
            # Estimate token length of new messages
            ctx_size = self.get_optimal_ctx_size(context["context_tokens"], messages=[system_generate_resume, content])

            yield {"status": "processing", "message": "Processing request...", "num_ctx": ctx_size}

            #
            # To support URL lookup:
            #
            # 1. Enable tools in a call to chat() with a simple prompt to invoke the tool to generate the summary if requested.
            # 2. If not requested (no tool call), abort the path.
            # 3. Otherwise, we know the URL was good and can use that URL's fetched content as context.
            #
            response = self.client.generate(model=self.model, system=system_generate_resume, prompt=content, options={ 'num_ctx': ctx_size })
            metadata["eval_count"] += response['eval_count']
            metadata["eval_duration"] += response['eval_duration']
            metadata["prompt_eval_count"] += response['prompt_eval_count']
            metadata["prompt_eval_duration"] += response['prompt_eval_duration']
            context["context_tokens"] = response['prompt_eval_count'] + response['eval_count']

            reply = response['response']
            final_message = {"role": "assistant", "content": reply, "metadata": metadata }

            resume['resume'] = final_message
            resume_history.append(resume)

            # Return the REST API with metadata
            yield {"status": "done", "message": final_message }

        except Exception as e:
            logging.exception({ 'model': self.model, 'content': content, 'error': str(e) })
            yield {"status": "error", "message": f"An error occurred: {str(e)}"}

        finally:
            self.processing = False

    async def fact_check(self, context, content):
        content = content.strip()
        if not content:
            yield {"status": "error", "message": "Invalid request"}
            return

        if self.processing:
            yield {"status": "error", "message": "Busy"}
            return

        resume_history = context["resume_history"]
        if len(resume_history) == 0:
            yield {"status": "done", "message": "No resume history found." }
            return

        # Only mark the server busy once we know there is work to do, so the
        # early return above cannot leave the processing flag stuck on.
        self.processing = True

        resume = resume_history[-1]
        metadata = resume["metadata"]
        metadata["eval_count"] = 0
        metadata["eval_duration"] = 0
        metadata["prompt_eval_count"] = 0
        metadata["prompt_eval_duration"] = 0

        content = f"[WORK HISTORY]:{resume['rag']}[/WORK HISTORY]\n\n[RESUME]\n{resume['resume']['content']}\n[/RESUME]\n\n"

        try:
            # Estimate token length of new messages
            ctx_size = self.get_optimal_ctx_size(context["context_tokens"], messages=[system_fact_check, content])
            yield {"status": "processing", "message": "Processing request...", "num_ctx": ctx_size}
            response = self.client.generate(model=self.model, system=system_fact_check, prompt=content, options={ 'num_ctx': ctx_size })
            logging.info(f"Fact checking {ctx_size} tokens.")
            metadata["eval_count"] += response['eval_count']
            metadata["eval_duration"] += response['eval_duration']
            metadata["prompt_eval_count"] += response['prompt_eval_count']
            metadata["prompt_eval_duration"] += response['prompt_eval_duration']
            context["context_tokens"] = response['prompt_eval_count'] + response['eval_count']
            reply = response['response']
            final_message = {"role": "assistant", "content": reply, "metadata": metadata }
            resume['fact_check'] = final_message

            # Return the REST API with metadata
            yield {"status": "done", "message": final_message }

        except Exception as e:
            logging.exception({ 'model': self.model, 'content': content, 'error': str(e) })
            yield {"status": "error", "message": f"An error occurred: {str(e)}"}

        finally:
            self.processing = False

    def run(self, host='0.0.0.0', port=WEB_PORT, **kwargs):
        try:
            uvicorn.run(self.app, host=host, port=port)
        except KeyboardInterrupt:
            if self.observer:
                self.observer.stop()
                self.observer.join()

# %%

# Main function to run everything
def main():
    global client, model, web_server

    # Parse command-line arguments
    args = parse_args()

    # Setup logging based on the provided level
    setup_logging(args.level)

    client = ollama.Client(host=args.ollama_server)
    model = args.ollama_model

    # documents = Rag.load_text_files(defines.doc_dir)
    # print(f"Documents loaded {len(documents)}")
    # chunks = Rag.create_chunks_from_documents(documents)
    # doc_types = set(chunk.metadata['doc_type'] for chunk in chunks)
    # print(f"Document types: {doc_types}")
    # print(f"Vectorstore created with {collection.count()} documents")

    web_server = WebServer(logging, client, model)
    logging.info(f"Starting web server at http://{args.web_host}:{args.web_port}")

    web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)


if __name__ == "__main__":
    main()