# backstory/src/server.py

# %%
# Imports [standard]
# Standard library modules (no try-except needed)
import argparse
import asyncio
import json
import logging
import os
import queue
import re
import time
from datetime import datetime
import textwrap
import threading
import uuid
import random


def try_import(module_name, pip_name=None):
    try:
        __import__(module_name)
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f"    pip install {pip_name or module_name}")
# Third-party modules with import checks
try_import('gradio')
try_import('ollama')
try_import('openai')
try_import('pytz')
try_import('requests')
try_import('yfinance', 'yfinance')
try_import('dotenv', 'python-dotenv')
try_import('geopy', 'geopy')
try_import('hyphen', 'PyHyphen')
try_import('bs4', 'beautifulsoup4')
try_import('nltk')
import nltk
from dotenv import load_dotenv
from geopy.geocoders import Nominatim
import gradio as gr
import ollama
import openai
import pytz
import requests
import yfinance as yf
from hyphen import hyphenator
from bs4 import BeautifulSoup
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from tools import (
    get_weather_by_location,
    get_current_datetime,
    get_ticker_price,
    tools
)

rags = [
    {"name": "JPK", "enabled": False, "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects."},
    {"name": "LKML", "enabled": False, "description": "Full associative data for the entire LKML mailing list archive."},
]
# %%
# Defaults
OLLAMA_API_URL = "http://ollama:11434"  # Default Ollama local endpoint
# MODEL_NAME = "deepseek-r1:7b"
MODEL_NAME = "llama3.2"
LOG_LEVEL = "DEBUG"
USE_TLS = False
WEB_HOST = "0.0.0.0"
WEB_PORT = 5000
# %%
# Globals
system_message = f"""
You are a helpful information agent.
You have real time access to any website or URL the user asks about, to stock prices, the current date and time, and current weather information for locations in the United States.
You are running { { 'model': MODEL_NAME, 'gpu': 'Intel Arc B580', 'cpu': 'Intel Core i9-14900KS', 'ram': '64G' } }.
You were launched on {get_current_datetime()}.
If you use any real time access, do not mention your knowledge cutoff.
Give short, courteous answers, no more than 2-3 sentences.
Always be accurate. If you don't know the answer, say so. Do not make up details.
When you receive a response from summarize_site, you must:
1. Review the entire content returned by the second LLM
2. Provide the URL used to obtain the information.
3. Incorporate the information into your response as appropriate
""".strip()
tool_log = []
command_log = []
model = None
client = None
web_server = None
# %%
# Cmd line overrides
def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
    parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
    parser.add_argument("--web-host", type=str, default=WEB_HOST, help=f"Host to bind the web server to. default={WEB_HOST}")
    parser.add_argument("--web-port", type=int, default=WEB_PORT, help=f"Port for the web server. default={WEB_PORT}")
    parser.add_argument("--level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                        default=LOG_LEVEL, help=f"Set the logging level. default={LOG_LEVEL}")
    return parser.parse_args()


def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")
    logging.basicConfig(level=numeric_level, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
    logging.info(f"Logging is set to {level} level.")
# %%
def is_words_downloaded():
    try:
        from nltk.corpus import words
        words.words()  # Attempt to access the dataset
        return True
    except LookupError:
        return False


if not is_words_downloaded():
    logging.info("Downloading nltk words corpus for random nick generation")
    nltk.download('words')
# %%
def split_paragraph_with_hyphenation(text, line_length=80, language='en_US'):
    """
    Split a paragraph into multiple lines with proper hyphenation.

    Args:
        text (str): The text to split.
        line_length (int): The maximum length of each line.
        language (str): The language code for hyphenation rules.

    Returns:
        [str]: The text split into multiple lines with proper hyphenation.
    """
    # Initialize the hyphenator for the specified language
    h = hyphenator.Hyphenator(language)

    # First attempt: try to wrap without hyphenation
    lines = textwrap.wrap(text, width=line_length)

    # If any lines are too long, we need to apply hyphenation
    result_lines = []
    for line in lines:
        # If the line is already short enough, keep it as is
        if len(line) <= line_length:
            result_lines.append(line)
            continue

        # Otherwise, we need to hyphenate
        words = line.split()
        current_line = ""
        for word in words:
            # If adding the word doesn't exceed the limit, add it
            if len(current_line) + len(word) + (1 if current_line else 0) <= line_length:
                if current_line:
                    current_line += " "
                current_line += word
            # If the word itself is too long for a whole line, hyphenate it
            elif len(word) > line_length:
                # If we already have content on the line, add it to the results
                if current_line:
                    result_lines.append(current_line)
                    current_line = ""

                # Get hyphenation points for the word
                hyphenated = h.syllables(word)
                if not hyphenated:
                    # If no hyphenation points found, just add the word to a new line
                    result_lines.append(word)
                    continue

                # Try to find a suitable hyphenation point
                partial_word = ""
                for syllable in hyphenated:
                    if len(partial_word) + len(syllable) + 1 > line_length:
                        # Add a hyphen to the partial word and start a new line
                        if partial_word:
                            result_lines.append(partial_word + "-")
                            partial_word = syllable
                        else:
                            # If a single syllable is too long, just add it
                            result_lines.append(syllable)
                    else:
                        partial_word += syllable

                # Don't forget the remaining part
                if partial_word:
                    current_line = partial_word
            else:
                # Start a new line with this word
                result_lines.append(current_line)
                current_line = word

        # Don't forget any remaining content
        if current_line:
            result_lines.append(current_line)

    return result_lines
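
# Example (illustrative only; actual break points depend on the installed
# en_US hyphenation dictionary):
#   split_paragraph_with_hyphenation(
#       "supercalifragilisticexpialidocious", line_length=12)
#   might return ["supercalifr-", "agilisticex-", "pialidocious"]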
# %%
async def handle_tool_calls(message):
    response = []
    tools_used = []
    for tool_call in message['tool_calls']:
        arguments = tool_call['function']['arguments']
        tool = tool_call['function']['name']
        match tool:
            case 'get_ticker_price':
                ticker = arguments.get('ticker')
                if not ticker:
                    ret = None
                else:
                    ret = get_ticker_price(ticker)
                tools_used.append(f"{tool}({ticker})")
            case 'summarize_site':
                url = arguments.get('url')
                question = arguments.get('question', 'what is the summary of this content?')
                ret = await summarize_site(url, question)
                tools_used.append(f"{tool}('{url}', '{question}')")
            case 'get_current_datetime':
                tz = arguments.get('timezone')
                ret = get_current_datetime(tz)
                tools_used.append(f"{tool}('{tz}')")
            case 'get_weather_by_location':
                city = arguments.get('city')
                state = arguments.get('state')
                ret = get_weather_by_location(city, state)
                tools_used.append(f"{tool}('{city}', '{state}')")
            case _:
                ret = None
        response.append({
            "role": "tool",
            "content": str(ret),
            "name": tool_call['function']['name']
        })
    if len(response) == 1:
        return response[0], tools_used
    return response, tools_used
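
# For reference, a tool-call message from the Ollama client looks roughly
# like this (shape inferred from how it is accessed above; exact fields may
# vary by client version):
# {
#     "role": "assistant",
#     "content": "",
#     "tool_calls": [
#         {"function": {"name": "get_ticker_price", "arguments": {"ticker": "INTC"}}}
#     ]
# }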
# %%
def total_json_length(dict_array):
    total = 0
    for item in dict_array:
        # Convert each dictionary to a minimized JSON string
        json_string = json.dumps(item, separators=(',', ':'))
        total += len(json_string)
    return total
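
# Example: total_json_length([{"a": 1}, {"b": 2}]) == 14
# ('{"a":1}' and '{"b":2}' are 7 characters each when minimized)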


async def summarize_site(url, question):
    """
    Fetches content from a URL, extracts the text, and uses Ollama to summarize it.

    Args:
        url (str): The URL of the website to summarize
        question (str): The question to answer about the content

    Returns:
        dict: The summarizer's response on success, or str: an error message on failure
    """
    global model, client

    try:
        # Fetch the webpage
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        logging.info(f"Fetching {url}")
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        logging.info(f"{url} returned. Processing...")

        # Parse the HTML
        soup = BeautifulSoup(response.text, 'html.parser')

        # Remove script and style elements
        for script in soup(["script", "style"]):
            script.extract()

        # Get text content
        text = soup.get_text(separator=' ', strip=True)

        # Clean up text (remove extra whitespace)
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = ' '.join(chunk for chunk in chunks if chunk)

        # Limit text length if needed (Ollama may have token limits)
        max_chars = 100000
        if len(text) > max_chars:
            text = text[:max_chars] + "..."

        # Generate a summary using the global Ollama client
        # logging.info(f"Requesting summary of: {text}")
        prompt = f"CONTENTS:\n\n{text}\n\n{question}"
        response = client.generate(model=model,
                                   system=f"You are given the contents of {url}. Answer the question about the contents.",
                                   prompt=prompt)
        # logging.info(response['response'])
        return {
            'source': 'summarizer-llm',
            'content': response['response'],
            'metadata': get_current_datetime()
        }
    except requests.exceptions.RequestException as e:
        return f"Error fetching the URL: {str(e)}"
    except Exception as e:
        return f"Error processing the website content: {str(e)}"
# %%
def is_valid_uuid(value):
    try:
        uuid_obj = uuid.UUID(value, version=4)
        return str(uuid_obj) == value
    except (ValueError, TypeError):
        return False
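
# Example: is_valid_uuid(str(uuid.uuid4())) -> True
#          is_valid_uuid("not-a-uuid")      -> False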

def default_tools(tools):
    return [{**tool, "enabled": True} for tool in tools]


def llm_tools(tools):
    return [tool for tool in tools if tool.get("enabled", False)]
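
# default_tools marks every tool as enabled for a fresh context; llm_tools
# filters a context's tool list down to the enabled subset before it is
# passed to the model.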
# %%
class WebServer:
    def __init__(self, logging, client, model=MODEL_NAME):
        self.logging = logging
        self.app = FastAPI()
        self.contexts = {}
        self.client = client
        self.model = model
        self.processing = False

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["http://battle-linux.ketrenos.com:3000"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        self.setup_routes()

    def setup_routes(self):
        @self.app.get('/')
        async def root():
            context = self.create_context()
            self.logging.info(f"Redirecting non-session to {context['id']}")
            return RedirectResponse(url=f"/{context['id']}", status_code=307)

        @self.app.put('/api/reset/{context_id}')
        async def put_reset(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            data = await request.json()
            try:
                response = {}
                for reset in data["reset"]:
                    match reset:
                        case "system-prompt":
                            context["system"] = [{"role": "system", "content": system_message}]
                            response["system-prompt"] = {"system-prompt": system_message}
                        case "rag":
                            context["rags"] = rags.copy()
                            response["rags"] = context["rags"]
                        case "tools":
                            context["tools"] = default_tools(tools)
                            response["tools"] = context["tools"]
                        case "history":
                            context["history"] = []
                            response["history"] = context["history"]
                if not response:
                    return JSONResponse({"error": "Usage: { reset: rag|tools|history|system-prompt }"}, status_code=400)
                return JSONResponse(response)
            except Exception:
                return JSONResponse({"error": "Usage: { reset: rag|tools|history|system-prompt }"}, status_code=400)

        @self.app.put('/api/system-prompt/{context_id}')
        async def put_system_prompt(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            data = await request.json()
            system_prompt = data["system-prompt"].strip()
            if not system_prompt:
                return JSONResponse({"status": "error", "message": "System prompt cannot be empty."}, status_code=400)
            context["system"] = [{"role": "system", "content": system_prompt}]
            return JSONResponse({"system-prompt": system_prompt})

        @self.app.get('/api/system-prompt/{context_id}')
        async def get_system_prompt(context_id: str):
            context = self.upsert_context(context_id)
            system_prompt = context["system"][0]["content"]
            return JSONResponse({"system-prompt": system_prompt})

        @self.app.post('/api/chat/{context_id}')
        async def chat_endpoint(context_id: str, request: Request):
            context = self.upsert_context(context_id)
            data = await request.json()

            # Stream newline-delimited JSON, flushing after each message
            async def flush_generator():
                async for message in self.chat(context=context, content=data['content']):
                    # Convert to JSON and add a newline delimiter
                    yield json.dumps(message) + "\n"
                    # Yield control so the event loop can flush the write
                    await asyncio.sleep(0)

            # Return a StreamingResponse with headers that discourage buffering
            return StreamingResponse(
                flush_generator(),
                media_type="application/json",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"  # Prevents Nginx buffering, if it is in the path
                }
            )
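
        # A client can consume the stream line by line as it arrives, e.g.
        # (hypothetical host/port):
        #   curl -N -X POST http://localhost:5000/api/chat/<context_id> \
        #        -H 'Content-Type: application/json' -d '{"content": "hello"}'
        # Each line of the response body is a single JSON status object.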
        @self.app.post('/api/context')
        async def create_context():
            context = self.create_context()
            self.logging.info(f"Generated new session as {context['id']}")
            return JSONResponse(context)

        @self.app.get('/api/history/{context_id}')
        async def get_history(context_id: str):
            context = self.upsert_context(context_id)
            return JSONResponse(context["history"])

        @self.app.get('/api/tools/{context_id}')
        async def get_tools(context_id: str):
            context = self.upsert_context(context_id)
            return JSONResponse(context["tools"])

        @self.app.put('/api/tools/{context_id}')
        async def put_tools(context_id: str, request: Request):
            if not is_valid_uuid(context_id):
                logging.warning(f"Invalid context_id: {context_id}")
                return JSONResponse({"error": "Invalid context_id"}, status_code=400)
            context = self.upsert_context(context_id)
            try:
                data = await request.json()
                modify = data["tool"]
                enabled = data["enabled"]
                for tool in context["tools"]:
                    if modify == tool["function"]["name"]:
                        tool["enabled"] = enabled
                        return JSONResponse(context["tools"])
                return JSONResponse({"status": f"{modify} not found in tools."}, status_code=404)
            except Exception:
                return JSONResponse({"status": "error"}, status_code=405)

        @self.app.get('/api/rags/{context_id}')
        async def get_rags(context_id: str):
            context = self.upsert_context(context_id)
            return JSONResponse(context["rags"])

        @self.app.get('/api/health')
        async def health_check():
            return JSONResponse({"status": "healthy"})

        @self.app.get('/{path:path}')
        async def serve_static(path: str):
            full_path = os.path.join('/opt/airc/src/ketr-chat/build', path)
            if os.path.exists(full_path) and os.path.isfile(full_path):
                self.logging.info(f"Serve static request for {full_path}")
                return FileResponse(full_path)
            self.logging.info(f"Serve index.html for {path}")
            return FileResponse('/opt/airc/src/ketr-chat/build/index.html')
    def create_context(self, context_id=None):
        if not context_id:
            context_id = str(uuid.uuid4())
        context = {
            "id": context_id,
            "system": [{"role": "system", "content": system_message}],
            "history": [],
            "tools": default_tools(tools),
            "rags": rags.copy()
        }
        logging.info(f"{context_id} created and added to sessions.")
        self.contexts[context_id] = context
        return context

    def upsert_context(self, context_id):
        if not context_id:
            logging.warning("No context ID provided. Creating a new context.")
            return self.create_context()
        if context_id in self.contexts:
            logging.info(f"Context {context_id} found.")
            return self.contexts[context_id]
        logging.info(f"Context {context_id} not found. Creating new context.")
        return self.create_context(context_id)
    async def chat(self, context, content):
        content = content.strip()
        if not content:
            yield {"status": "error", "message": "Invalid request"}
            return
        if self.processing:
            yield {"status": "error", "message": "Busy"}
            return

        self.processing = True
        try:
            history = context["history"]
            history.append({"role": "user", "content": content})
            # Only the system prompt and the most recent user message are sent to the model
            messages = context["system"] + history[-1:]

            yield {"status": "processing", "message": "Processing request..."}

            response = self.client.chat(model=self.model, messages=messages, tools=llm_tools(context["tools"]))
            tools_used = []

            yield {"status": "processing", "message": "Initial response received"}

            if 'tool_calls' in response.get('message', {}):
                yield {"status": "processing", "message": "Processing tool calls..."}
                message = response['message']
                tool_result, tools_used = await handle_tool_calls(message)
                message_dict = {
                    'role': message.get('role', 'assistant'),
                    'content': message.get('content', '')
                }
                if 'tool_calls' in message:
                    message_dict['tool_calls'] = [
                        {'function': {'name': tc['function']['name'], 'arguments': tc['function']['arguments']}}
                        for tc in message['tool_calls']
                    ]
                messages.append(message_dict)
                if isinstance(tool_result, list):
                    messages.extend(tool_result)
                else:
                    messages.append(tool_result)

                yield {"status": "processing", "message": "Generating final response..."}
                response = self.client.chat(model=self.model, messages=messages, stream=False)

            reply = response['message']['content']
            if len(tools_used):
                final_message = {"role": "assistant", "content": reply, 'metadata': {"title": f"🛠️ Tool(s) used: {','.join(tools_used)}"}}
            else:
                final_message = {"role": "assistant", "content": reply}
            yield {"status": "done", "message": final_message}
        except Exception as e:
            logging.exception({'model': self.model, 'messages': messages, 'error': str(e)})
            yield {"status": "error", "message": f"An error occurred: {str(e)}"}
        finally:
            self.processing = False
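
    # chat() yields a stream of dicts: zero or more
    # {"status": "processing", "message": ...} progress updates, followed by
    # exactly one {"status": "done", ...} or {"status": "error", ...}.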
    def run(self, host='0.0.0.0', port=5000):
        import uvicorn
        uvicorn.run(self.app, host=host, port=port)
# %%
# Main function to run everything
def main():
    global client, model, web_server

    # Parse command-line arguments
    args = parse_args()

    # Set up logging based on the provided level
    setup_logging(args.level)

    client = ollama.Client(host=args.ollama_server)
    model = args.ollama_model
    web_server = WebServer(logging, client, model)
    logging.info(f"Starting web server at http://{args.web_host}:{args.web_port}")
    web_server.run(host=args.web_host, port=args.web_port)


# Run the main function when executed as a script
if __name__ == "__main__":
    main()