719 lines
27 KiB
Python
719 lines
27 KiB
Python
# %%
|
|
# Imports [standard]
|
|
# Standard library modules (no try-except needed)
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import queue
|
|
import re
|
|
import time
|
|
from datetime import datetime
|
|
import textwrap
|
|
import threading
|
|
import uuid
|
|
import random
|
|
import subprocess
|
|
import re
|
|
|
|
def try_import(module_name, pip_name=None):
    """Attempt to import a module, printing install instructions on failure.

    Args:
        module_name (str): Importable module name (e.g. 'bs4').
        pip_name (str, optional): pip package name when it differs from the
            module name (e.g. 'beautifulsoup4'). Defaults to module_name.

    Returns:
        bool: True if the module imported successfully, False otherwise.
        (Previously returned None in both cases; returning a bool is
        backward-compatible and lets callers detect missing modules.)
    """
    try:
        __import__(module_name)
        return True
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f" pip install {pip_name or module_name}")
        return False
|
|
|
|
# Third-party modules with import checks
|
|
# Third-party modules with import checks; (module, pip package) pairs —
# a None pip name means the pip package matches the module name.
_THIRD_PARTY_MODULES = (
    ('gradio', None),
    ('ollama', None),
    ('pytz', None),
    ('requests', None),
    ('yfinance', 'yfinance'),
    ('dotenv', 'python-dotenv'),
    ('geopy', 'geopy'),
    ('hyphen', 'PyHyphen'),
    ('bs4', 'beautifulsoup4'),
    ('nltk', None),
    ('fastapi', None),
)
for _module, _pip in _THIRD_PARTY_MODULES:
    try_import(_module, _pip)
|
|
|
import nltk
|
|
from dotenv import load_dotenv
|
|
from geopy.geocoders import Nominatim
|
|
import gradio as gr
|
|
import ollama
|
|
import pytz
|
|
import requests
|
|
import yfinance as yf
|
|
from hyphen import hyphenator
|
|
from bs4 import BeautifulSoup
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from utils import rag
|
|
|
|
from tools import (
|
|
get_weather_by_location,
|
|
get_current_datetime,
|
|
get_ticker_price,
|
|
tools
|
|
)
|
|
|
|
# Available retrieval-augmented-generation (RAG) sources, disabled by default.
# Each session context receives its own copy of this list.
rags = [
    { "name": "JPK", "enabled": False, "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects." },
    { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." },
]
|
|
|
|
|
|
def get_installed_ram(meminfo_path='/proc/meminfo'):
    """Return the total installed RAM as a human-readable string (e.g. '62.57GB').

    Args:
        meminfo_path (str): Path to a meminfo-format file. Defaults to the
            Linux /proc/meminfo; parameterized for testing.

    Returns:
        str: Total RAM in GB, or an error message if it cannot be determined.
        (Previously returned None when MemTotal was absent; now returns an
        explicit error string for consistency with the exception path.)
    """
    try:
        with open(meminfo_path, 'r') as f:
            meminfo = f.read()
        # MemTotal is reported in KiB; convert to GiB.
        match = re.search(r'MemTotal:\s+(\d+)', meminfo)
        if match:
            return f"{round(int(match.group(1)) / 1024**2, 2)}GB"
        return "Error retrieving RAM: MemTotal not found"
    except Exception as e:
        return f"Error retrieving RAM: {e}"
|
|
|
|
def get_graphics_cards():
    """Return the GPU names reported by the `ze-monitor` utility.

    Returns:
        list[str]: One entry per detected device on success, or
        str: an error message if `ze-monitor` could not be run or parsed.
    """
    try:
        # Run the ze-monitor utility and capture its stdout.
        proc = subprocess.run(['ze-monitor'], capture_output=True, text=True, check=True)

        # ze-monitor prints one device per line with the GPU name in
        # parentheses; the pattern tolerates names that contain parentheses.
        pattern = re.compile(r'^[^(]*\((.*)\)')
        names = []
        for device_line in proc.stdout.strip().splitlines():
            found = pattern.match(device_line)
            if found:
                names.append(found.group(1))

        return names
    except Exception as e:
        return f"Error retrieving GPU info: {e}"
|
|
|
|
def get_cpu_info(cpuinfo_path='/proc/cpuinfo'):
    """Return the CPU model and core count, e.g. 'Intel(R) ... with 8 cores'.

    Args:
        cpuinfo_path (str): Path to a cpuinfo-format file. Defaults to the
            Linux /proc/cpuinfo; parameterized for testing.

    Returns:
        str: Model and core count, or an error message if unavailable.
        (Previously returned None when the fields were absent; now returns an
        explicit error string for consistency with the exception path.)
    """
    try:
        with open(cpuinfo_path, 'r') as f:
            cpuinfo = f.read()
        model_match = re.search(r'model name\s+:\s+(.+)', cpuinfo)
        # One 'processor : N' entry per logical core.
        cores_match = re.findall(r'processor\s+:\s+\d+', cpuinfo)
        if model_match and cores_match:
            return f"{model_match.group(1)} with {len(cores_match)} cores"
        return "Error retrieving CPU info: model name not found"
    except Exception as e:
        return f"Error retrieving CPU info: {e}"
|
|
|
|
def system_info():
    """Collect a snapshot of host hardware details for display to clients."""
    info = {}
    info["Installed RAM"] = get_installed_ram()
    info["Graphics Card"] = get_graphics_cards()
    info["CPU"] = get_cpu_info()
    return info
|
|
|
|
# %%
# Defaults (all overridable from the command line — see parse_args)
OLLAMA_API_URL = "http://ollama:11434" # Default Ollama local endpoint
#MODEL_NAME = "deepseek-r1:7b"
#MODEL_NAME = "llama3.2"
MODEL_NAME = "qwen2.5:7b"  # Default LLM; override with --ollama-model
LOG_LEVEL="debug"  # Default logging level; override with --level
USE_TLS=False  # NOTE(review): not referenced elsewhere in this chunk — confirm use
WEB_HOST="0.0.0.0"  # Bind address for the web server
WEB_PORT=5000  # TCP port for the web server
|
|
|
|
# %%
# Globals

# System prompt seeded into every new chat context (see WebServer.create_context).
# NOTE(review): the hardware description below is hard-coded and may not match
# the actual host — system_info() reports live values; confirm they agree.
system_message = f"""
You are a helpful information agent.
You have real time access to any website or URL the user asks about, to stock prices, the current date and time, and current weather information for locations in the United States.
You are running { { 'model': MODEL_NAME, 'gpu': 'Intel Arc B580', 'cpu': 'Intel Core i9-14900KS', 'ram': '64G' } }.
You were launched on {get_current_datetime()}.
If you use any real time access, do not mention your knowledge cutoff.
Give short, courteous answers, no more than 2-3 sentences.
Always be accurate. If you don't know the answer, say so. Do not make up details.
When you receive a response from summarize_site, you must:
1. Review the entire content returned by the second LLM
2. Provide the URL used to obtain the information.
3. Incorporate the information into your response as appropriate
""".strip()
|
|
|
|
tool_log = []      # Record of tool invocations (not written to in this chunk)
command_log = []   # Record of commands (not written to in this chunk)
model = None       # Ollama model name; assigned in main()
client = None      # ollama.Client instance; assigned in main()
web_server = None  # WebServer instance; assigned in main()
|
|
|
|
# %%
# Cmd line overrides
def parse_args():
    """Parse command-line arguments, falling back to the module defaults.

    Returns:
        argparse.Namespace: Parsed options (ollama_server, ollama_model,
        web_host, web_port, level).
    """
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL,
                        help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
    parser.add_argument("--ollama-model", type=str, default=MODEL_NAME,
                        help=f"LLM model to use. default={MODEL_NAME}")
    parser.add_argument("--web-host", type=str, default=WEB_HOST,
                        help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.")
    # BUG FIX: was type=str — a CLI-supplied port arrived as a string while the
    # default stayed an int; uvicorn.run() requires an int port.
    parser.add_argument("--web-port", type=int, default=WEB_PORT,
                        help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.")
    # NOTE(review): the default 'debug' is lowercase while choices are
    # uppercase; setup_logging() upper()s the value so this works, but a
    # user must type the level in uppercase on the command line.
    parser.add_argument('--level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default=LOG_LEVEL, help=f'Set the logging level. default={LOG_LEVEL}')
    return parser.parse_args()
|
|
|
|
def setup_logging(level):
    """Configure root logging at the named level (e.g. 'debug', 'INFO').

    Raises:
        ValueError: If `level` is not a recognized logging level name.
    """
    resolved = getattr(logging, level.upper(), None)
    if not isinstance(resolved, int):
        raise ValueError(f"Invalid log level: {level}")

    log_format = '%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
    logging.basicConfig(level=resolved, format=log_format)

    logging.info(f"Logging is set to {level} level.")
|
|
|
|
# %%
def is_words_downloaded():
    """Check whether the nltk 'words' corpus is available locally.

    Returns:
        bool: True if the corpus can be loaded, False otherwise (including
        when nltk itself cannot be imported).
    """
    try:
        from nltk.corpus import words
        words.words()  # Accessing the dataset raises LookupError if absent
        return True
    except (ImportError, LookupError):
        # ImportError added for robustness: previously a missing nltk
        # install crashed here instead of reporting "not downloaded".
        return False
|
|
|
|
# Ensure the corpus used for random nick generation is present at startup.
if not is_words_downloaded():
    logging.info("Downloading nltk words corpus for random nick generation")
    nltk.download('words')
|
|
|
|
# %%
def split_paragraph_with_hyphenation(text, line_length=80, language='en_US'):
    """
    Split a paragraph into multiple lines with proper hyphenation.

    Args:
        text (str): The text to split.
        line_length (int): The maximum length of each line.
        language (str): The language code for hyphenation rules.

    Returns:
        [str]: The text split into multiple lines with proper hyphenation.
    """
    # Initialize the hyphenator for the specified language
    h = hyphenator.Hyphenator(language)

    # First attempt: try to wrap without hyphenation
    lines = textwrap.wrap(text, width=line_length)

    # If any lines are too long, we need to apply hyphenation
    # NOTE(review): textwrap.wrap breaks long words by default, so lines
    # longer than line_length should be rare — confirm the intent here.
    result_lines = []

    for line in lines:
        # If the line is already short enough, keep it as is
        if len(line) <= line_length:
            result_lines.append(line)
            continue

        # Otherwise, we need to hyphenate
        words = line.split()
        current_line = ""

        for word in words:
            # If adding the word doesn't exceed the limit, add it
            # (the `1 if current_line else 0` accounts for the joining space)
            if len(current_line) + len(word) + (1 if current_line else 0) <= line_length:
                if current_line:
                    current_line += " "
                current_line += word
            # If the word itself is too long, hyphenate it
            elif len(word) > line_length - len(current_line) - (1 if current_line else 0):
                # If we already have content on the line, add it to results
                if current_line:
                    result_lines.append(current_line)
                    current_line = ""

                # Get hyphenation points for the word
                hyphenated = h.syllables(word)

                if not hyphenated:
                    # If no hyphenation points found, just add the word to a new line
                    result_lines.append(word)
                    continue

                # Try to find a suitable hyphenation point
                partial_word = ""
                for syllable in hyphenated:
                    if len(partial_word) + len(syllable) + 1 > line_length:
                        # Add hyphen to the partial word and start a new line
                        if partial_word:
                            result_lines.append(partial_word + "-")
                            partial_word = syllable
                        else:
                            # If a single syllable is too long, just add it
                            result_lines.append(syllable)
                    else:
                        partial_word += syllable

                # Don't forget the remaining part; it becomes the start of the
                # next output line and may be joined with following words
                if partial_word:
                    current_line = partial_word

            else:
                # Start a new line with this word
                result_lines.append(current_line)
                current_line = word

        # Don't forget any remaining content
        if current_line:
            result_lines.append(current_line)

    return result_lines
|
|
|
|
# %%
async def handle_tool_calls(message):
    """Execute each tool call requested by the LLM and collect the results.

    Args:
        message (dict): Assistant message containing a 'tool_calls' list,
            each with 'function': {'name', 'arguments'}.

    Returns:
        tuple: (single response dict when exactly one tool ran, otherwise a
        list of response dicts; list of human-readable tool descriptions).
        Unknown tool names produce a response with content 'None'.
    """
    response = []
    tools_used = []
    for tool_call in message['tool_calls']:
        arguments = tool_call['function']['arguments']
        tool = tool_call['function']['name']
        # Dispatch to the matching tool implementation (imported from tools)
        match tool:
            case 'get_ticker_price':
                ticker = arguments.get('ticker')
                if not ticker:
                    ret = None
                else:
                    ret = get_ticker_price(ticker)
                    tools_used.append(f"{tool}({ticker})")
            case 'summarize_site':
                url = arguments.get('url');
                question = arguments.get('question', 'what is the summary of this content?')
                ret = await summarize_site(url, question)
                tools_used.append(f"{tool}('{url}', '{question}')")
            case 'get_current_datetime':
                tz = arguments.get('timezone')
                ret = get_current_datetime(tz)
                tools_used.append(f"{tool}('{tz}')")
            case 'get_weather_by_location':
                city = arguments.get('city')
                state = arguments.get('state')
                ret = get_weather_by_location(city, state)
                tools_used.append(f"{tool}('{city}', '{state}')")
            case _:
                ret = None
        # Tool results are fed back to the LLM as 'tool' role messages
        response.append({
            "role": "tool",
            "content": str(ret),
            "name": tool_call['function']['name']
        })
    # A single result is unwrapped; callers handle both shapes (see chat())
    if len(response) == 1:
        return response[0], tools_used
    else:
        return response, tools_used
|
|
|
|
# %%
def total_json_length(dict_array):
    """Return the combined character count of each item serialized as
    minimized (no-whitespace) JSON."""
    return sum(
        len(json.dumps(entry, separators=(',', ':')))
        for entry in dict_array
    )
|
|
|
|
async def summarize_site(url, question):
    """
    Fetches content from a URL, extracts the text, and uses Ollama to answer
    a question about it.

    Args:
        url (str): The URL of the website to summarize
        question (str): The question to answer about the page contents

    Returns:
        dict: {'source', 'content', 'metadata'} on success, or
        str: an error message on failure.
    """
    global model, client
    try:
        # Fetch the webpage; browser-like User-Agent avoids naive bot blocks
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        logging.info(f"Fetching {url}")
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        logging.info(f"{url} returned. Processing...")

        # Parse the HTML
        soup = BeautifulSoup(response.text, 'html.parser')

        # Remove script and style elements
        for script in soup(["script", "style"]):
            script.extract()

        # Get text content
        text = soup.get_text(separator=' ', strip=True)

        # Clean up text (remove extra whitespace)
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = ' '.join(chunk for chunk in chunks if chunk)

        # Limit text length if needed (Ollama may have token limits)
        max_chars = 100000
        if len(text) > max_chars:
            text = text[:max_chars] + "..."

        # Generate summary using Ollama
        prompt = f"CONTENTS:\n\n{text}\n\n{question}"
        # BUG FIX: the system string was missing its f-prefix, so the literal
        # text "{url}" was sent to the model instead of the actual URL.
        response = client.generate(model=model,
                                   system=f"You are given the contents of {url}. Answer the question about the contents",
                                   prompt=prompt)

        return {
            'source': 'summarizer-llm',
            'content': response['response'],
            'metadata': get_current_datetime()
        }

    except requests.exceptions.RequestException as e:
        return f"Error fetching the URL: {str(e)}"
    except Exception as e:
        return f"Error processing the website content: {str(e)}"
|
|
|
|
# %%

# %%
def is_valid_uuid(value):
    """Return True only for strings that are canonical version-4 UUIDs."""
    try:
        parsed = uuid.UUID(value, version=4)
    except (ValueError, TypeError):
        # Not parseable as a UUID at all (or not a string)
        return False
    # Round-trip comparison rejects non-canonical spellings
    # (uppercase, braces, 'urn:uuid:' prefix, missing hyphens).
    return str(parsed) == value
|
|
|
|
def default_tools(tools):
    """Return a copy of the tool list with every tool marked enabled."""
    defaults = []
    for tool in tools:
        entry = dict(tool)
        entry["enabled"] = True
        defaults.append(entry)
    return defaults
|
|
|
|
def llm_tools(tools):
    """Return only the tools whose 'enabled' flag equals True."""
    enabled_tools = []
    for tool in tools:
        if tool.get("enabled", False) == True:
            enabled_tools.append(tool)
    return enabled_tools
|
|
|
|
# %%
|
|
class WebServer:
|
|
def __init__(self, logging, client, model=MODEL_NAME):
|
|
self.logging = logging
|
|
self.app = FastAPI()
|
|
self.contexts = {}
|
|
self.client = client
|
|
self.model = model
|
|
self.processing = False
|
|
|
|
self.app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["http://battle-linux.ketrenos.com:3000"],
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
self.setup_routes()
|
|
|
|
def setup_routes(self):
|
|
@self.app.get('/')
|
|
async def root():
|
|
context = self.create_context()
|
|
self.logging.info(f"Redirecting non-session to {context['id']}")
|
|
return RedirectResponse(url=f"/{context['id']}", status_code=307)
|
|
#return JSONResponse({"redirect": f"/{context['id']}"})
|
|
|
|
@self.app.put('/api/reset/{context_id}')
|
|
async def put_reset(context_id: str, request: Request):
|
|
if not is_valid_uuid(context_id):
|
|
logging.warning(f"Invalid context_id: {context_id}")
|
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
|
context = self.upsert_context(context_id)
|
|
data = await request.json()
|
|
try:
|
|
response = {}
|
|
for reset in data["reset"]:
|
|
match reset:
|
|
case "system-prompt":
|
|
context["system"] = [{"role": "system", "content": system_message}]
|
|
response["system-prompt"] = { "system-prompt": system_message }
|
|
case "rag":
|
|
context["rag"] = rags.copy()
|
|
response["rags"] = context["rag"]
|
|
case "tools":
|
|
context["tools"] = default_tools(tools)
|
|
response["tools"] = context["tools"]
|
|
case "history":
|
|
context["history"] = []
|
|
response["history"] = context["history"]
|
|
if not response:
|
|
return JSONResponse({ "error": "Usage: { reset: rag|tools|history|system-prompt}"})
|
|
else:
|
|
self.save_context(context_id)
|
|
return JSONResponse(response)
|
|
|
|
except:
|
|
return JSONResponse({ "error": "Usage: { reset: rag|tools|history|system-prompt}"})
|
|
|
|
|
|
@self.app.put('/api/system-prompt/{context_id}')
|
|
async def put_system_prompt(context_id: str, request: Request):
|
|
if not is_valid_uuid(context_id):
|
|
logging.warning(f"Invalid context_id: {context_id}")
|
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
|
context = self.upsert_context(context_id)
|
|
data = await request.json()
|
|
system_prompt = data["system-prompt"].strip()
|
|
if not system_prompt:
|
|
return JSONResponse({ "status": "error", "message": "System prompt can not be empty." })
|
|
context["system"] = [{"role": "system", "content": system_prompt}]
|
|
self.save_context(context_id)
|
|
return JSONResponse({ "system-prompt": system_prompt })
|
|
|
|
@self.app.get('/api/system-prompt/{context_id}')
|
|
async def get_system_prompt(context_id: str):
|
|
context = self.upsert_context(context_id)
|
|
system_prompt = context["system"][0]["content"];
|
|
return JSONResponse({ "system-prompt": system_prompt })
|
|
|
|
@self.app.get('/api/system-info/{context_id}')
|
|
async def get_system_info(context_id: str):
|
|
return JSONResponse(system_info())
|
|
|
|
@self.app.post('/api/chat/{context_id}')
|
|
async def chat_endpoint(context_id: str, request: Request):
|
|
context = self.upsert_context(context_id)
|
|
data = await request.json()
|
|
|
|
# Create a custom generator that ensures flushing
|
|
async def flush_generator():
|
|
async for message in self.chat(context=context, content=data['content']):
|
|
# Convert to JSON and add newline
|
|
yield json.dumps(message) + "\n"
|
|
# Save the history as its generated
|
|
self.save_context(context_id)
|
|
# Explicitly flush after each yield
|
|
await asyncio.sleep(0) # Allow the event loop to process the write
|
|
|
|
# Return StreamingResponse with appropriate headers
|
|
return StreamingResponse(
|
|
flush_generator(),
|
|
media_type="application/json",
|
|
headers={
|
|
"Cache-Control": "no-cache",
|
|
"Connection": "keep-alive",
|
|
"X-Accel-Buffering": "no" # Prevents Nginx buffering if you're using it
|
|
}
|
|
)
|
|
|
|
@self.app.post('/api/context')
|
|
async def create_context():
|
|
context = self.create_context()
|
|
self.logging.info(f"Generated new session as {context['id']}")
|
|
return JSONResponse(context)
|
|
|
|
@self.app.get('/api/history/{context_id}')
|
|
async def get_history(context_id: str):
|
|
context = self.upsert_context(context_id)
|
|
return JSONResponse(context["history"])
|
|
|
|
@self.app.get('/api/tools/{context_id}')
|
|
async def get_tools(context_id: str):
|
|
context = self.upsert_context(context_id)
|
|
return JSONResponse(context["tools"])
|
|
|
|
@self.app.put('/api/tools/{context_id}')
|
|
async def put_tools(context_id: str, request: Request):
|
|
if not is_valid_uuid(context_id):
|
|
logging.warning(f"Invalid context_id: {context_id}")
|
|
return JSONResponse({"error": "Invalid context_id"}, status_code=400)
|
|
context = self.upsert_context(context_id)
|
|
try:
|
|
data = await request.json()
|
|
modify = data["tool"]
|
|
enabled = data["enabled"]
|
|
for tool in context["tools"]:
|
|
if modify == tool["function"]["name"]:
|
|
tool["enabled"] = enabled
|
|
self.save_context(context_id)
|
|
return JSONResponse(context["tools"])
|
|
return JSONResponse({ "status": f"{modify} not found in tools." }), 404
|
|
except:
|
|
return JSONResponse({ "status": "error" }), 405
|
|
|
|
@self.app.get('/api/rags/{context_id}')
|
|
async def get_rags(context_id: str):
|
|
context = self.upsert_context(context_id)
|
|
return JSONResponse(context["rags"])
|
|
|
|
@self.app.get('/api/health')
|
|
async def health_check():
|
|
return JSONResponse({"status": "healthy"})
|
|
|
|
@self.app.get('/{path:path}')
|
|
async def serve_static(path: str):
|
|
full_path = os.path.join('/opt/airc/src/ketr-chat/build', path)
|
|
if os.path.exists(full_path) and os.path.isfile(full_path):
|
|
self.logging.info(f"Serve static request for {full_path}")
|
|
return FileResponse(full_path)
|
|
self.logging.info(f"Serve index.html for {path}")
|
|
return FileResponse('/opt/airc/src/ketr-chat/build/index.html')
|
|
|
|
def save_context(self, session_id):
|
|
"""
|
|
Serialize a Python dictionary to a file in the sessions directory.
|
|
|
|
Args:
|
|
data: Dictionary containing the session data
|
|
session_id: UUID string for the context. If it doesn't exist, it is created
|
|
|
|
Returns:
|
|
The session_id used for the file
|
|
"""
|
|
context = self.upsert_context(session_id)
|
|
|
|
# Create sessions directory if it doesn't exist
|
|
if not os.path.exists("sessions"):
|
|
os.makedirs("sessions")
|
|
|
|
# Create the full file path
|
|
file_path = os.path.join("sessions", session_id)
|
|
|
|
# Serialize the data to JSON and write to file
|
|
with open(file_path, 'w') as f:
|
|
json.dump(context, f)
|
|
|
|
return session_id
|
|
|
|
def load_context(self, session_id):
|
|
"""
|
|
Load a serialized Python dictionary from a file in the sessions directory.
|
|
|
|
Args:
|
|
session_id: UUID string for the filename
|
|
|
|
Returns:
|
|
The deserialized dictionary, or a new context if it doesn't exist on disk.
|
|
"""
|
|
file_path = os.path.join("sessions", session_id)
|
|
|
|
# Check if the file exists
|
|
if not os.path.exists(file_path):
|
|
return self.create_context(session_id)
|
|
|
|
# Read and deserialize the data
|
|
with open(file_path, 'r') as f:
|
|
self.contexts[session_id] = json.load(f)
|
|
|
|
return self.contexts[session_id]
|
|
|
|
|
|
def create_context(self, context_id = None):
|
|
if not context_id:
|
|
context_id = str(uuid.uuid4())
|
|
context = {
|
|
"id": context_id,
|
|
"system": [{"role": "system", "content": system_message}],
|
|
"history": [],
|
|
"tools": default_tools(tools),
|
|
"rags": rags.copy()
|
|
}
|
|
logging.info(f"{context_id} created and added to sessions.")
|
|
self.contexts[context_id] = context
|
|
return context
|
|
|
|
def upsert_context(self, context_id):
|
|
if not context_id:
|
|
logging.warning("No context ID provided. Creating a new context.")
|
|
return self.create_context()
|
|
if context_id in self.contexts:
|
|
logging.info(f"Context {context_id} found.")
|
|
return self.contexts[context_id]
|
|
logging.info(f"Context {context_id} not found. Creating new context.")
|
|
return self.load_context(context_id)
|
|
|
|
async def chat(self, context, content):
|
|
content = content.strip()
|
|
if not content:
|
|
yield {"status": "error", "message": "Invalid request"}
|
|
return
|
|
|
|
if self.processing:
|
|
yield {"status": "error", "message": "Busy"}
|
|
return
|
|
|
|
self.processing = True
|
|
|
|
try:
|
|
history = context["history"]
|
|
history.append({"role": "user", "content": content})
|
|
|
|
messages = context["system"] + history[-1:]
|
|
#logging.info(messages)
|
|
|
|
yield {"status": "processing", "message": "Processing request..."}
|
|
|
|
# Use the async generator in an async for loop
|
|
response = self.client.chat(model=self.model, messages=messages, tools=llm_tools(context["tools"]))
|
|
tools_used = []
|
|
|
|
yield {"status": "processing", "message": "Initial response received"}
|
|
|
|
if 'tool_calls' in response.get('message', {}):
|
|
yield {"status": "processing", "message": "Processing tool calls..."}
|
|
|
|
message = response['message']
|
|
tool_result, tools_used = await handle_tool_calls(message)
|
|
|
|
message_dict = {
|
|
'role': message.get('role', 'assistant'),
|
|
'content': message.get('content', '')
|
|
}
|
|
|
|
if 'tool_calls' in message:
|
|
message_dict['tool_calls'] = [
|
|
{'function': {'name': tc['function']['name'], 'arguments': tc['function']['arguments']}}
|
|
for tc in message['tool_calls']
|
|
]
|
|
messages.append(message_dict)
|
|
|
|
if isinstance(tool_result, list):
|
|
messages.extend(tool_result)
|
|
else:
|
|
messages.append(tool_result)
|
|
|
|
yield {"status": "processing", "message": "Generating final response..."}
|
|
response = self.client.chat(model=self.model, messages=messages, stream=False)
|
|
|
|
reply = response['message']['content']
|
|
if len(tools_used):
|
|
final_message = {"role": "assistant", "content": reply, 'metadata': {"title": f"🛠️ Tool(s) used: {','.join(tools_used)}"}}
|
|
else:
|
|
final_message = {"role": "assistant", "content": reply}
|
|
history.append(final_message)
|
|
yield {"status": "done", "message": final_message}
|
|
|
|
except Exception as e:
|
|
logging.exception({ 'model': self.model, 'messages': messages, 'error': str(e) })
|
|
yield {"status": "error", "message": f"An error occurred: {str(e)}"}
|
|
|
|
finally:
|
|
self.processing = False
|
|
|
|
def run(self, host='0.0.0.0', port=5000, **kwargs):
|
|
import uvicorn
|
|
uvicorn.run(self.app, host=host, port=port)
|
|
|
|
# %%

# Main function to run everything
def main():
    """Entry point: parse args, configure logging, connect to Ollama, serve."""
    global client, model, web_server
    # Parse command-line arguments
    args = parse_args()

    # Setup logging based on the provided level
    setup_logging(args.level)

    client = ollama.Client(host=args.ollama_server)
    model = args.ollama_model

    web_server = WebServer(logging, client, model)
    logging.info(f"Starting web server at http://{args.web_host}:{args.web_port}")
    # use_reloader is accepted (and ignored) by WebServer.run for compatibility
    web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)


# BUG FIX: main() was called unconditionally, so merely importing this module
# parsed argv and started the server; guard the entry point.
if __name__ == "__main__":
    main()
|