1002 lines
40 KiB
Python
1002 lines
40 KiB
Python
# %%
|
|
# Imports [standard]
|
|
# Standard library modules (no try-except needed)
|
|
import argparse
|
|
import asyncio
|
|
import anyio
|
|
import json
|
|
import logging
|
|
import os
|
|
import queue
|
|
import re
|
|
import time
|
|
from datetime import datetime
|
|
import textwrap
|
|
import threading
|
|
import uuid
|
|
import random
|
|
|
|
def try_import(module_name, pip_name=None):
    """Print an install hint when *module_name* cannot be imported.

    Never raises for a missing module and always returns None; the hint names
    the pip package *pip_name* when given, otherwise *module_name* itself.
    """
    install_target = pip_name or module_name
    try:
        __import__(module_name)
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f" pip install {install_target}")
|
|
|
|
# Third-party modules with import checks
|
|
# (import name, pip package name) pairs checked in order; None means the pip
# package has the same name as the module.
_THIRD_PARTY_MODULES = (
    ('gradio', None),
    ('ollama', None),
    ('openai', None),
    ('pydle', None),
    ('pytz', None),
    ('requests', None),
    ('yfinance', 'yfinance'),
    ('dotenv', 'python-dotenv'),
    ('geopy', 'geopy'),
    ('hyphen', 'PyHyphen'),
    ('bs4', 'beautifulsoup4'),
    ('flask', None),
    ('flask_cors', None),
    ('flask_sock', None),
    ('nltk', None),
)

for _module_name, _pip_name in _THIRD_PARTY_MODULES:
    try_import(_module_name, _pip_name)
del _module_name, _pip_name
|
|
|
|
import nltk
|
|
from dotenv import load_dotenv
|
|
from geopy.geocoders import Nominatim
|
|
import gradio as gr
|
|
import ollama
|
|
import openai
|
|
import pydle
|
|
import pytz
|
|
import requests
|
|
import yfinance as yf
|
|
from hyphen import hyphenator
|
|
from bs4 import BeautifulSoup
|
|
from flask import Flask, request, jsonify, render_template, send_from_directory, redirect
|
|
from flask_cors import CORS
|
|
from flask_sock import Sock
|
|
|
|
from tools import (
|
|
get_weather_by_location,
|
|
get_current_datetime,
|
|
get_ticker_price,
|
|
tools
|
|
)
|
|
|
|
# %%
|
|
# Defaults
|
|
# Ollama endpoint; "ollama" is presumably a container/service hostname.
OLLAMA_API_URL = "http://ollama:11434"
# Alternative model: "deepseek-r1:7b"
MODEL_NAME = "llama3.2"

# IRC connection and identity
IRC_SERVER = "miniircd"
IRC_PORT = 6667
USE_TLS = False
NICK = "airc"
CHANNEL = "#airc-test"
BOT_ADMIN = "james"  # nick allowed to run admin commands

# Logging level name (case-insensitive)
LOG_LEVEL = "debug"

# Optional Gradio console (only started with --gradio-enable)
GRADIO_HOST = "0.0.0.0"
GRADIO_PORT = 60673
GRADIO_ENABLE = False

# Flask web server (started unless --web-disable)
WEB_HOST = "0.0.0.0"
WEB_PORT = 5000
WEB_DISABLE = False
|
|
|
|
# %%
|
|
# Globals
|
|
# System prompt placed ahead of every chat history. The f-string is evaluated
# once at import, so "launched on" records the start-up time.
system_message = f"""
You are a helpful information agent connected to the IRC network {IRC_SERVER}. Your name is {NICK}.
You have real time access to any website or URL the user asks about, to stock prices, the current date and time, and current weather information for locations in the United States.
You are running { { 'model': MODEL_NAME, 'gpu': 'Intel Arc B580', 'cpu': 'Intel Core i9-14900KS', 'ram': '64G' } }.
You were launched on {get_current_datetime()}.
If you use any real time access, do not mention your knowledge cutoff.
Give short, courteous answers, no more than 2-3 sentences.
Always be accurate. If you don't know the answer, say so. Do not make up details.
When you receive a response from summarize_site, you must:
1. Review the entire content returned by the second LLM
2. Provide the URL used to obtain the information.
3. Incorporate the information into your response as appropriate
"""

# Message list prefixed to every chat request; "!reset" and "!system" replace it.
system_log = [{"role": "system", "content": system_message}]
# Records of tool-assisted chat turns and of "!" commands, displayed by the UIs.
tool_log = []
command_log = []

# Runtime singletons, populated by main().
model = client = irc_bot = web_server = None
|
|
|
|
# %%
|
|
# Cmd line overrides
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(description="AI is Really Cool")
|
|
parser.add_argument("--irc-server", type=str, default=IRC_SERVER, help=f"IRC server address. Example: irc.libera.chat default={IRC_SERVER}")
|
|
parser.add_argument("--irc-port", type=int, default=IRC_PORT, help=f"IRC server port. default={IRC_PORT}")
|
|
parser.add_argument("--irc-nickname", type=str, default=NICK, help=f"Bot nickname. default={NICK}")
|
|
parser.add_argument("--irc-channel", type=str, default=CHANNEL, help=f"Channel to join. default={CHANNEL}")
|
|
parser.add_argument("--irc-use-tls", type=bool, default=USE_TLS, help=f"Use TLS with --irc-server. default={USE_TLS}")
|
|
parser.add_argument("--irc-bot-admin", type=str, default=BOT_ADMIN, help=f"Nick that can send admin commands via IRC. default={BOT_ADMIN}")
|
|
parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
|
|
parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
|
|
parser.add_argument("--gradio-host", type=str, default=GRADIO_HOST, help=f"Host to launch gradio on. default={GRADIO_HOST} only if --gradio-enable is specified.")
|
|
parser.add_argument("--gradio-port", type=str, default=GRADIO_PORT, help=f"Port to launch gradio on. default={GRADIO_PORT} only if --gradio-enable is specified.")
|
|
parser.add_argument("--gradio-enable", action="store_true", default=GRADIO_ENABLE, help=f"If set to True, enable Gradio. default={GRADIO_ENABLE}")
|
|
parser.add_argument("--web-host", type=str, default=WEB_HOST, help=f"Host to launch Flask web server. default={WEB_HOST} only if --web-disable not specified.")
|
|
parser.add_argument("--web-port", type=str, default=WEB_PORT, help=f"Port to launch Flask web server. default={WEB_PORT} only if --web-disable not specified.")
|
|
parser.add_argument("--web-disable", action="store_true", default=WEB_DISABLE, help=f"If set to True, disable Flask web server. default={WEB_DISABLE}")
|
|
parser.add_argument('--level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
|
default=LOG_LEVEL, help=f'Set the logging level. default={LOG_LEVEL}')
|
|
return parser.parse_args()
|
|
|
|
def setup_logging(level):
    """Configure the root logger for *level*, a logging level name in any case.

    ``force=True`` replaces any handlers already attached to the root logger.
    Without it, ``logging.basicConfig`` silently does nothing once a handler
    exists. Any earlier module-level ``logging.info(...)`` call installs such
    a handler, which would discard the requested level and format.

    Raises:
        ValueError: if *level* does not name a logging level.
    """
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    logging.basicConfig(
        level=numeric_level,
        format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
        force=True,
    )

    logging.info(f"Logging is set to {level} level.")
|
|
|
|
# %%
|
|
def is_words_downloaded():
    """Report whether the nltk "words" corpus is available locally.

    Returns False when loading the corpus raises LookupError; any other
    error propagates to the caller.
    """
    try:
        from nltk.corpus import words
        words.words()  # touching the data raises LookupError if it is missing
    except LookupError:
        return False
    return True
|
|
|
|
# Fetch the nltk "words" corpus once; random_nick() draws web nicknames from it.
# This runs at import time, before logging is configured. A module-level
# logging.info() here would implicitly call logging.basicConfig() with default
# settings ahead of setup_logging(), so plain print() is used instead.
if not is_words_downloaded():
    print("Downloading nltk words corpus for random nick generation")
    nltk.download('words')
|
|
|
|
def random_nick():
    """Return a random, capitalized word from the nltk "words" corpus.

    Used as the initial nickname of each web-socket user.
    """
    from nltk.corpus import words as word_corpus

    vocabulary = word_corpus.words()
    return random.choice(vocabulary).capitalize()
|
|
|
|
# %%
|
|
# PyHyphen Hyphenator objects keyed by language code, built on first use.
_HYPHENATORS = {}


def _get_hyphenator(language):
    """Return the cached PyHyphen Hyphenator for *language*, creating it once."""
    h = _HYPHENATORS.get(language)
    if h is None:
        h = hyphenator.Hyphenator(language)
        _HYPHENATORS[language] = h
    return h


def _hard_split(word, width):
    """Cut *word* into consecutive pieces of at most *width* characters."""
    return [word[i:i + width] for i in range(0, len(word), width)]


def _hyphenate_long_word(word, width, language):
    """Split a word longer than *width* into pieces that each fit within *width*.

    Breaks at the syllable boundaries reported by PyHyphen and appends "-" to
    every piece except the last. Falls back to a hard split without hyphens
    when *width* cannot hold a hyphen or no usable hyphenation is available.
    """
    if width < 2:
        return _hard_split(word, width)
    try:
        syllables = _get_hyphenator(language).syllables(word)
    except Exception as e:
        # Missing or broken dictionary: degrade to a hard split rather than
        # failing the whole outgoing message.
        logging.warning(f"Hyphenation unavailable for '{language}': {e}")
        syllables = []
    # Only trust the syllables if they reassemble into exactly the original word.
    if not syllables or "".join(syllables) != word:
        return _hard_split(word, width)

    # Each unit must leave room for a trailing hyphen, so cut oversize syllables.
    units = []
    for syllable in syllables:
        units.extend(_hard_split(syllable, width - 1))

    pieces = []
    current = ""
    for unit in units:
        if current and len(current) + len(unit) + 1 > width:
            pieces.append(current + "-")
            current = unit
        else:
            current += unit
    if current:
        pieces.append(current)
    return pieces


def split_paragraph_with_hyphenation(text, line_length=80, language='en_US'):
    """
    Split a paragraph into lines of at most ``line_length`` characters.

    Words are wrapped with ``textwrap.wrap``, so blank text yields ``[]``.
    A word longer than ``line_length`` is hyphenated at syllable boundaries
    using PyHyphen's dictionary for *language*. If that is not possible, the
    word is cut into ``line_length``-sized pieces without hyphens.

    Args:
        text (str): The text to split.
        line_length (int): The maximum length of each line (must be > 0).
        language (str): The language code for hyphenation rules.

    Returns:
        [str]: The wrapped lines.

    Raises:
        ValueError: if ``line_length`` is not positive (raised by textwrap).
    """
    # break_long_words=False keeps an over-long word intact on its own line so
    # it can be hyphenated below. With the default (True), textwrap cut such
    # words itself and the hyphenation path was never reached.
    lines = textwrap.wrap(text, width=line_length, break_long_words=False)

    result_lines = []
    for line in lines:
        if len(line) <= line_length:
            result_lines.append(line)
        else:
            result_lines.extend(_hyphenate_long_word(line, line_length, language))
    return result_lines
|
|
|
|
# %%
|
|
async def handle_tool_calls(message):
|
|
response = []
|
|
tools_used = []
|
|
for tool_call in message['tool_calls']:
|
|
arguments = tool_call['function']['arguments']
|
|
tool = tool_call['function']['name']
|
|
match tool:
|
|
case 'get_ticker_price':
|
|
ticker = arguments.get('ticker')
|
|
if not ticker:
|
|
ret = None
|
|
else:
|
|
ret = get_ticker_price(ticker)
|
|
tools_used.append(f"{tool}({ticker})")
|
|
case 'summarize_site':
|
|
url = arguments.get('url');
|
|
question = arguments.get('question', 'what is the summary of this content?')
|
|
ret = await summarize_site(url, question)
|
|
tools_used.append(f"{tool}('{url}', '{question}')")
|
|
case 'get_current_datetime':
|
|
tz = arguments.get('timezone')
|
|
ret = get_current_datetime(tz)
|
|
tools_used.append(f"{tool}('{tz}')")
|
|
case 'get_weather_by_location':
|
|
city = arguments.get('city')
|
|
state = arguments.get('state')
|
|
ret = get_weather_by_location(city, state)
|
|
tools_used.append(f"{tool}('{city}', '{state}')")
|
|
case _:
|
|
ret = None
|
|
response.append({
|
|
"role": "tool",
|
|
"content": str(ret),
|
|
"name": tool_call['function']['name']
|
|
})
|
|
if len(response) == 1:
|
|
return response[0], tools_used
|
|
else:
|
|
return response, tools_used
|
|
|
|
# %%
|
|
def total_json_length(dict_array):
    """Return the summed length of every item serialized as compact JSON.

    The ``!context`` command uses this as a rough size measure for the system
    prompt, history and tool schemas.
    """
    compact = (',', ':')
    return sum(len(json.dumps(entry, separators=compact)) for entry in dict_array)
|
|
|
|
async def summarize_site(url, question):
    """
    Fetch a web page, extract its visible text, and ask the Ollama model a
    question about it.

    Args:
        url (str): The URL of the page to fetch.
        question (str): The question to answer about the page content.

    Returns:
        dict: ``{'source': 'summarizer-llm', 'content': <model answer>,
        'metadata': <current datetime>}`` on success.
        str: an "Error ..." message if fetching or processing fails. Errors
        are reported to the caller instead of being raised.
    """
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        logging.info(f"Fetching {url}")
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        logging.info(f"{url} returned. Processing...")

        # Decoded text is only trustworthy when the server declared a charset.
        # Otherwise requests falls back to ISO-8859-1 for text/* responses,
        # so pass raw bytes and let BeautifulSoup detect the encoding.
        content_type = response.headers.get('Content-Type', '')
        markup = response.text if 'charset=' in content_type.lower() else response.content
        soup = BeautifulSoup(markup, 'html.parser')

        # Remove script and style elements; they are not readable content.
        for element in soup(["script", "style"]):
            element.extract()

        # Collapse every whitespace run (spaces, tabs, newlines) to one space.
        text = ' '.join(soup.get_text(separator=' ', strip=True).split())

        # Limit text length; the model's context window is finite.
        max_chars = 100000
        if len(text) > max_chars:
            text = text[:max_chars] + "..."

        logging.info(f"Requesting summary of {len(text)} characters from {url}")

        prompt = f"CONTENTS:\n\n{text}\n\n{question}"
        response = client.generate(
            model=model,
            system=f"You are given the contents of {url}. Answer the question about the contents",
            prompt=prompt,
        )

        logging.info(response['response'])

        return {
            'source': 'summarizer-llm',
            'content': response['response'],
            'metadata': get_current_datetime()
        }

    except requests.exceptions.RequestException as e:
        return f"Error fetching the URL: {str(e)}"
    except Exception as e:
        return f"Error processing the website content: {str(e)}"
|
|
|
|
async def chat(history):
    """Run one LLM turn for *history* and append the assistant reply to it.

    *history* is a list of ``{"role", "content"}`` dicts ending with the
    latest user message. It is extended in place and also returned.

    If the model requests tools, they are executed via ``handle_tool_calls``,
    the results are sent back for a second completion, and the turn is
    recorded in ``tool_log``. When no Ollama client is configured, *history*
    is returned unchanged.

    NOTE(review): ``client.chat`` is synchronous and called from a coroutine,
    so the event loop is blocked for the whole inference.
    """
    if not client:
        return history

    logging.info(f"chat('{history[-1]}')")
    messages = system_log + history
    response = client.chat(model=model, messages=messages, tools=tools)
    reply = response['message']['content']
    tools_used = []

    if 'tool_calls' in response['message']:
        message = response['message']
        tool_result, tools_used = await handle_tool_calls(message)

        # Re-send the assistant turn as a plain dict. Subscript access matches
        # handle_tool_calls and works for both dicts and ollama objects.
        message_dict = {
            'role': message.get('role', 'assistant'),
            'content': message.get('content', ''),
            'tool_calls': [
                {'function': {'name': tc['function']['name'], 'arguments': tc['function']['arguments']}}
                for tc in message['tool_calls']
            ],
        }
        messages.append(message_dict)
        if isinstance(tool_result, list):
            messages.extend(tool_result)
        else:
            messages.append(tool_result)

        try:
            response = client.chat(model=model, messages=messages)
            reply = response['message']['content']
        except Exception:
            logging.exception({ 'model': model, 'messages': messages })
            # The tool-requesting response usually has empty content. Rather
            # than silently sending nothing, tell the user the answer failed.
            reply = reply or "Sorry, I was unable to generate a response from the tool results."

        tool_log.append({ 'call': message, 'tool_result': tool_result, 'tools_used': tools_used, 'response': reply})

    entry = {"role": "assistant", "content": reply}
    if tools_used:
        entry['metadata'] = {"title": f"🛠️ Tool(s) used: {','.join(tools_used)}"}
    history.append(entry)

    return history
|
|
|
|
# %%
|
|
|
|
# %%
|
|
|
|
class DynamicIRCBot(pydle.Client):
|
|
def __init__(self, nickname, channel, bot_admin, system_info, burst_limit = 5, rate_limit = 1.0, burst_reset_timeout = 10.0, **kwargs):
|
|
super().__init__(nickname, **kwargs)
|
|
self.histories = {}
|
|
self.channel = channel
|
|
self.bot_admin = bot_admin
|
|
self.system_info = system_info
|
|
# Message throttling
|
|
self.burst_limit = burst_limit
|
|
self.sent_burst = 0
|
|
self.rate_limit = rate_limit
|
|
self.burst_reset_timeout = burst_reset_timeout
|
|
self.sent_burst = 0 # Track messages sent in burst
|
|
self.last_message_time = None # Track last message time
|
|
self._message_queue = asyncio.Queue()
|
|
self._task = asyncio.create_task(self._send_from_queue())
|
|
self.processing = False
|
|
|
|
async def _send_from_queue(self):
|
|
"""Background task that sends queued messages with burst + rate limiting."""
|
|
while True:
|
|
target, message = await self._message_queue.get()
|
|
|
|
# If burst is still available, send immediately
|
|
if self.sent_burst < self.burst_limit:
|
|
self.sent_burst += 1
|
|
else:
|
|
await asyncio.sleep(self.rate_limit) # Apply rate limit
|
|
|
|
logging.debug(f"Sending {message} => {target} from queue")
|
|
await super().message(target, message) # Send message
|
|
self.last_message_time = asyncio.get_event_loop().time() # Update last message timestamp
|
|
|
|
# Start burst reset countdown after each message
|
|
asyncio.create_task(self._reset_burst_after_inactivity())
|
|
|
|
async def _reset_burst_after_inactivity(self):
|
|
"""Resets burst counter only if no new messages are sent within timeout."""
|
|
last_time = self.last_message_time
|
|
await asyncio.sleep(self.burst_reset_timeout) # Wait for inactivity period
|
|
|
|
# Only reset if no new messages were sent during the wait
|
|
if self.last_message_time == last_time:
|
|
self.sent_burst = 0
|
|
logging.info("Burst limit reset due to inactivity.")
|
|
|
|
async def message(self, target, message):
|
|
"""Splits a multi-line message and sends each line separately. If more than 10 lines, truncate and add a message."""
|
|
max_lines = 10
|
|
irc_lines = []
|
|
for line in message.splitlines():
|
|
lines = split_paragraph_with_hyphenation(line, line_length=300)
|
|
irc_lines.extend(lines)
|
|
|
|
# Send the first 'max_lines' non-empty lines
|
|
i=0
|
|
sent_lines=0
|
|
while i < len(irc_lines) and sent_lines < max_lines:
|
|
line = irc_lines[i].strip()
|
|
i+=1
|
|
if line != "":
|
|
sent_lines += 1
|
|
await self._message_queue.put((target, line))
|
|
|
|
# If there are more than 10 lines, add the truncation message
|
|
if len(irc_lines) > max_lines:
|
|
await self._message_queue.put((target, f"...and so on (message truncated to {max_lines} lines.)"))
|
|
|
|
async def on_connect(self):
|
|
await super().on_connect()
|
|
logging.info(f"CONNECT: {self.nickname}")
|
|
if self.channel:
|
|
await self.join(self.channel, self.nickname)
|
|
|
|
async def on_join(self, channel, user):
|
|
await super().on_join(channel, user)
|
|
logging.info(f"JOIN: {user} => {channel}")
|
|
|
|
async def on_part(self, channel, user):
|
|
await super().on_part(channel, user)
|
|
logging.info(f"PART: {channel} => {user}")
|
|
|
|
async def on_disconnect(self, expected):
|
|
self.logger.error("Disconnected. Reconnecting in 5 seconds...")
|
|
await asyncio.sleep(5)
|
|
try:
|
|
await self.connect(self.host, self.port, tls=self.tls)
|
|
except Exception as e:
|
|
self.logger.error(f"Reconnection failed: {e}")
|
|
|
|
async def on_message(self, target, source, message, session=None, local_user=None):
|
|
global system_log, tool_log, system_log, command_log, web_server
|
|
|
|
if not local_user:
|
|
await super().on_message(target, source, message)
|
|
|
|
message = message.strip()
|
|
logging.info(f"MESSAGE: {source} => {target}: {message}")
|
|
if source == self.nickname and not local_user:
|
|
return
|
|
|
|
if session == None:
|
|
session = target
|
|
|
|
if self.processing:
|
|
await self.message(target, f"I'm already processing a query.")
|
|
return
|
|
if session not in self.histories:
|
|
self.histories[session] = []
|
|
history = self.histories[session]
|
|
|
|
last_message = history[-1] if len(history) > 0 and history[-1]["role"] == "user" else None
|
|
try:
|
|
matches = re.match(r"^([^:]+)\s*:\s*(.*)$", message)
|
|
if matches:
|
|
user = matches.group(1).strip()
|
|
content = matches.group(2).strip()
|
|
else:
|
|
user = None
|
|
content = message
|
|
|
|
# If this message is not directed to the bot
|
|
if target != self.nickname and (not user or user != self.nickname):
|
|
logging.info(f"Message not directed to {self.nickname}")
|
|
# Add this message to the history either to the current 'user' context or create
|
|
# add a new message
|
|
if last_message:
|
|
logging.info(f"Modifying last USER context")
|
|
last_message['content'] += f"\n{source}: {content}"
|
|
else:
|
|
logging.info(f"Appending new USER context")
|
|
last_message = {
|
|
"role": "user",
|
|
"content": f"{source}: {content}"
|
|
}
|
|
history.append(last_message)
|
|
return
|
|
|
|
matches = re.match(r"^!([^\s]+)\s*(.*)?$", content)
|
|
if not matches:
|
|
logging.info(f"Non-command directed message to {self.nickname}: Invoking chat...")
|
|
self.processing = True
|
|
if web_server:
|
|
for session in web_server.sessions.values():
|
|
for socket in session['sockets']:
|
|
socket.send(json.dumps({"type": "processing", "value": self.processing}))
|
|
|
|
# Add this message to the history either to the current 'user' context or create
|
|
# add a new message
|
|
if last_message:
|
|
logging.info(f"Modifying last USER context")
|
|
last_message['content'] += f"\n{source}: {content}"
|
|
else:
|
|
logging.info(f"Appending new USER context")
|
|
last_message = {
|
|
"role": "user",
|
|
"content": f"{source}: {content}"
|
|
}
|
|
history.append(last_message)
|
|
history = await chat(history)
|
|
chat_response = history[-1]
|
|
await self.message(target, chat_response['content'])
|
|
|
|
self.processing = False
|
|
if web_server:
|
|
for session in web_server.sessions.values():
|
|
for socket in session['sockets']:
|
|
socket.send(json.dumps({"type": "processing", "value": self.processing}))
|
|
|
|
return
|
|
|
|
command = matches.group(1)
|
|
arguments = matches.group(2).strip()
|
|
logging.info(f"Command directed to {self.nickname}: command={command}, arguments={arguments}")
|
|
is_admin = source == self.nickname or source == self.bot_admin
|
|
match command:
|
|
case "help":
|
|
response = f"info, context, reset, system [prompt], server [address], join channel"
|
|
|
|
case "info":
|
|
response = str(self.system_info)
|
|
|
|
case "context":
|
|
system_log_size = total_json_length(system_log)
|
|
history_size = total_json_length(history)
|
|
tools_size = total_json_length(tools)
|
|
total_size = system_log_size + history_size + tools_size
|
|
response = f"\nsystem prompt: {system_log_size}"
|
|
response += f"\nhistory: {history_size} in {len(history)} entries."
|
|
response += f"\ntools: {tools_size} in {len(tools)} tools."
|
|
response += f"\ntotal context: {total_size}"
|
|
response += f"\ntotal tool calls: {len(tool_log)}"
|
|
|
|
case "reset":
|
|
system_log = [{"role": "system", "content": system_message}]
|
|
history.clear()
|
|
tool_log = []
|
|
command_log = []
|
|
response = 'All contexts reset'
|
|
|
|
case "system":
|
|
if arguments != "":
|
|
system_log = [{"role": "system", "content": arguments}]
|
|
response = 'System message updated.'
|
|
else:
|
|
lines = [line.strip() for line in system_log[0]['content'].split('\n') if line.strip()]
|
|
response = " ".join(lines)
|
|
|
|
case "server":
|
|
if not is_admin:
|
|
response = "You need to be admin to use this command."
|
|
else:
|
|
server = arguments.split(" ", 1)
|
|
if server[0] == "":
|
|
server = IRC_SERVER
|
|
else:
|
|
server = server[0]
|
|
try:
|
|
await self.connect(server, 6667, tls=False)
|
|
response="Connected to {server}"
|
|
except Exception:
|
|
response = f"Unable to connect to {server}"
|
|
logging.exception({ "error": f"Unable to process message {content}"})
|
|
|
|
case "join":
|
|
if not is_admin:
|
|
response = "You need to be admin to use this command."
|
|
else:
|
|
channel = arguments.strip()
|
|
if channel == "" or re.match(r"\s", channel):
|
|
response = "Usage: !join CHANNEL"
|
|
else:
|
|
if not re.match(r"^#", channel):
|
|
channel = f"#{channel}"
|
|
if self.channel and self.channel != channel:
|
|
await self.part(channel)
|
|
if channel:
|
|
await self.bot.join(channel)
|
|
self.channel = channel
|
|
response = f"Joined {channel}."
|
|
|
|
case _:
|
|
response = f"Unrecognized command: {command}"
|
|
|
|
await self.message(target, f"!{command}: {response}")
|
|
command_log.append({ 'source': source, 'command': f"{content}", 'response': response })
|
|
except Exception:
|
|
logging.exception({ "error": f"Unable to process message {content}"})
|
|
await self.message(target, f"I'm experiencing difficulties processing '{content}'")
|
|
|
|
# %%
|
|
# NOTE(review): these two counters are not read anywhere in the visible code.
last_history_len = 0
last_command_len = 0

# Stylesheet handed to gr.Blocks in create_ui().
css = """
body, html, #root {
background-color: #F0F0F0;
}
"""
|
|
async def create_ui():
    """Build the optional Gradio console; main() creates it only with --gradio-enable.

    Text entered here is echoed to the bot's IRC channel as "[console] ...",
    then handled by ``irc_bot.on_message`` under the "gradio" session. The
    side panels show the system prompt plus that session's history, the
    tool-call log and the command log.

    Returns:
        gr.Blocks: the UI, not yet launched.
    """
    # Session id under which console messages are stored in irc_bot.histories.
    session_id = "gradio"

    def gradio_history():
        """Return the console session's chat history ([] before the first message)."""
        return irc_bot.histories.get(session_id, [])

    with gr.Blocks(css=css, fill_height=True, fill_width=True) as ui:
        with gr.Row(scale=1):
            with gr.Column(scale=1):
                chatbot = gr.Chatbot(
                    type="messages",
                    container=True,  # Enable scrolling
                    elem_id="chatbot",
                    scale=1
                )
                entry = gr.Textbox(
                    label="Chat with our AI Assistant:",
                    container=False,
                    scale=0
                )
            with gr.Column(scale=1):
                chat_history = gr.JSON(
                    system_log,
                    label="Chat History",
                    height=100,  # Set height for JSON viewer
                    scale=1
                )
                tool_history = gr.JSON(
                    tool_log,
                    label="Tool calls",
                    height=100,  # Set height for JSON viewer
                    scale=1
                )
                command_history = gr.JSON(
                    command_log,
                    label="Command calls",
                    height=100,  # Set height for JSON viewer
                    scale=1
                )
        with gr.Row(scale=0):
            clear = gr.Button("Clear")
            refresh = gr.Button("Sync with IRC")

        # Each handler returns one value (or gr.skip()) per output component.
        async def do_entry(message):
            if not irc_bot:
                return gr.skip(), gr.skip()
            await irc_bot.message(irc_bot.channel, f"[console] {message}")
            await irc_bot.on_message(irc_bot.channel, irc_bot.nickname, f"{irc_bot.nickname}: {message}", session=session_id, local_user="gradio")
            return "", gradio_history()

        def do_clear():
            if not irc_bot:
                return gr.skip(), gr.skip(), gr.skip(), gr.skip()
            # The bot keeps histories per session (it has no ``history``
            # attribute), and rebinding tool_log/command_log here would only
            # create locals, so clear the real state instead.
            irc_bot.histories.pop(session_id, None)
            tool_log.clear()
            command_log.clear()
            return gradio_history(), system_log, tool_log, command_log

        def update_log(history):
            # Refresh the log panels after the chatbot responds (*history* is unused).
            if not irc_bot:
                return gr.skip(), gr.skip(), gr.skip()
            return system_log + gradio_history(), tool_log, command_log

        def get_history():
            if not irc_bot:
                return gr.skip(), gr.skip(), gr.skip(), gr.skip()
            history = gradio_history()
            return history, system_log + history, tool_log, command_log

        entry.submit(
            do_entry,
            inputs=[entry],
            outputs=[entry, chatbot]
        ).then(
            update_log,  # Updates the log panels after chatbot processing
            inputs=chatbot,
            outputs=[chat_history, tool_history, command_history]
        )

        refresh.click(get_history, inputs=None, outputs=[chatbot, chat_history, tool_history, command_history])

        clear.click(do_clear, inputs=None, outputs=[chatbot, chat_history, tool_history, command_history], queue=False)

    return ui
|
|
|
|
# %%
|
|
def is_valid_uuid(value):
    """Return True if *value* is a canonical (lowercase, hyphenated) version-4 UUID string.

    ``uuid.UUID(..., version=4)`` overwrites the version and variant bits, so
    the round-trip comparison fails for other UUID versions and for
    non-canonical spellings (uppercase, braces, missing hyphens).

    Non-string input is rejected up front: ``uuid.UUID`` raises
    AttributeError for e.g. ints, which is neither ValueError nor TypeError.
    """
    if not isinstance(value, str):
        return False
    try:
        parsed = uuid.UUID(value, version=4)
    except ValueError:
        return False
    return str(parsed) == value
|
|
|
|
# %%
|
|
class WebServer:
|
|
"""Web interface"""
|
|
|
|
def __init__(self, logging):
|
|
self.logging = logging
|
|
self.app = Flask(__name__, static_folder='/opt/airc/src/ketr-chat/build', static_url_path='')
|
|
self.sock = Sock(self.app)
|
|
self.sessions = {}
|
|
|
|
CORS(self.app, resources={r"/*": {"origins": "http://battle-linux.ketrenos.com:3000"}})
|
|
|
|
# Setup routes
|
|
self.setup_routes()
|
|
|
|
# Generate a unique session ID
|
|
def generate_session(self, socket, existing_uuid=None):
|
|
session = {
|
|
"id": existing_uuid if existing_uuid else str(uuid.uuid4()),
|
|
"system": system_message,
|
|
"users": [],
|
|
"history": [],
|
|
"sockets": []
|
|
}
|
|
logging.info(f"{session['id']} created and added to sessions.")
|
|
self.sessions[session['id']] = session
|
|
return session
|
|
|
|
def setup_routes(self):
|
|
"""Setup Flask routes"""
|
|
|
|
@self.sock.route('/api/ws/<session_id>')
|
|
def websocket(ws, session_id):
|
|
user = random_nick()
|
|
if not session_id in self.sessions:
|
|
self.generate_session(ws, session_id)
|
|
|
|
if ws not in self.sessions[session_id]['sockets']:
|
|
self.sessions[session_id]['sockets'].append(ws)
|
|
|
|
while True:
|
|
try:
|
|
data = ws.receive()
|
|
data = json.loads(data)
|
|
|
|
logging.info(f"Message received from {user}: {data['type']}")
|
|
|
|
if 'session' not in data or 'type' not in data:
|
|
ws.send(json.dumps({"type": "error", "error": f"Invalid request: {data}"}))
|
|
continue
|
|
|
|
if session_id != data['session']:
|
|
ws.send(json.dumps({"type": "error", "error": f"Invalid request: {data}"}))
|
|
|
|
self.sessions[session_id]['last_access'] = datetime.now()
|
|
|
|
if user not in self.sessions[session_id]['users']:
|
|
logging.info(f"Adding {user} to session {session_id}. Existing users: {self.sessions[session_id]['users']}")
|
|
self.sessions[session_id]['users'].append(user)
|
|
for socket in self.sessions[session_id]['sockets']:
|
|
socket.send(json.dumps({"type": "users", "update": self.sessions[session_id]['users']}))
|
|
|
|
match data['type']:
|
|
case "processing":
|
|
if irc_bot:
|
|
ws.send(json.dumps({"type": "processing", "value": irc_bot.processing}))
|
|
else:
|
|
ws.send(json.dumps({"type": "processing", "value": False}))
|
|
|
|
case "user-change":
|
|
if data['value'] in self.sessions[session_id]['users']:
|
|
ws.send(json.dumps({"type": "user", "update": user}))
|
|
else:
|
|
self.sessions[session_id]['users'] = [data['value'] if name == user else name for name in self.sessions[session_id]['users']]
|
|
user = data['value']
|
|
ws.send(json.dumps({"type": "user", "update": user}))
|
|
for socket in self.sessions[session_id]['sockets']:
|
|
socket.send(json.dumps({"type": "users", "update": self.sessions[session_id]['users']}))
|
|
|
|
case "user":
|
|
ws.send(json.dumps({"type": "user", "update": user}))
|
|
|
|
case "users":
|
|
ws.send(json.dumps({"type": "users", "update": self.sessions[session_id]['users']}))
|
|
|
|
case "history":
|
|
if not irc_bot:
|
|
ws.send(json.dumps({"type": "history", "update": []}))
|
|
else:
|
|
if session_id in irc_bot.histories:
|
|
ws.send(json.dumps({"type": "history", "update": irc_bot.histories[session_id]}))
|
|
else:
|
|
ws.send(json.dumps({"type": "history", "update": []}))
|
|
|
|
case _:
|
|
ws.send(json.dumps({"type": "error", "error": f"Invalid request type: {data['type']}"}))
|
|
except Exception as e:
|
|
logging.error(f"WebSocket error: {str(e)}")
|
|
if user in self.sessions[session_id]['users']:
|
|
self.sessions[session_id]['users'].remove(user)
|
|
if ws in self.sessions[session_id]['sockets']:
|
|
self.sessions[session_id]['sockets'].remove(ws)
|
|
for socket in self.sessions[session_id]['sockets']:
|
|
socket.send(json.dumps({"type": "users", "update": self.sessions[session_id]['users']}))
|
|
return
|
|
|
|
# Serve React app - This catches all routes not matched by API endpoints
|
|
@self.app.route('/')
|
|
def root():
|
|
# Generate a new unique session ID
|
|
session = self.generate_session(None)
|
|
# Redirect to the unique session path
|
|
self.logging.info(f"Redirecting non-session to {session['id']}")
|
|
return redirect(f"/{session['id']}")
|
|
|
|
# Basic endpoint for chat completions
|
|
@self.app.route('/api/chat/<session_id>', methods=['POST'])
|
|
async def chat(session_id):
|
|
if not irc_bot:
|
|
return jsonify({ "error": "Bot not initialized" }), 400
|
|
try:
|
|
data = request.get_json()
|
|
await irc_bot.on_message(irc_bot.channel, irc_bot.nickname, f"{irc_bot.nickname}: {data}", session=session_id, local_user="web")
|
|
if session_id in irc_bot.histories:
|
|
for socket in self.sessions[session_id]['sockets']:
|
|
socket.send(json.dumps({"type": "history", "update": irc_bot.histories[session_id]}))
|
|
return jsonify({ "success": "Message submitted to chat agent" }), 200
|
|
except Exception as e:
|
|
logging.exception(data)
|
|
return jsonify({
|
|
"error": "Invalid request"
|
|
}), 400
|
|
|
|
# Basic endpoint for chat completions
|
|
@self.app.route('/api/session', methods=['GET'])
|
|
async def create_session():
|
|
# Generate a new unique session ID
|
|
session = self.generate_session(None)
|
|
# Redirect to the unique session path
|
|
self.logging.info(f"Generating new session as {session['id']}")
|
|
return jsonify(session), 200
|
|
|
|
# Context requests
|
|
@self.app.route('/api/history', methods=['GET'])
|
|
def http_history():
|
|
if not irc_bot:
|
|
return jsonify({ "error": "Bot not initialized" }), 400
|
|
return jsonify(irc_bot.history), 200
|
|
|
|
@self.app.route('/api/system', methods=['GET'])
|
|
def http_system():
|
|
if not irc_bot:
|
|
return jsonify({ "error": "Bot not initialized" }), 400
|
|
return jsonify(system_log), 200
|
|
|
|
@self.app.route('/api/tools', methods=['GET'])
|
|
def http_tools():
|
|
if not irc_bot:
|
|
return jsonify({ "error": "Bot not initialized" }), 400
|
|
return jsonify(tool_log), 200
|
|
|
|
# Health check endpoint
|
|
@self.app.route('/api/health', methods=['GET'])
|
|
def health():
|
|
return jsonify({"status": "healthy"}), 200
|
|
|
|
# Session route - serve React app for a specific session
|
|
# @self.app.route('/<session_id>')
|
|
# def session_route(session_id):
|
|
# logging.info(f"{session_id}")
|
|
# # Validate if session_id is a valid UUID format (optional)
|
|
# try:
|
|
# uuid.UUID(session_id)
|
|
# # Here you could look up session data in a database if needed
|
|
# return send_from_directory(
|
|
# self.app.static_folder,
|
|
# 'index.html',
|
|
# mimetype='text/html'
|
|
# )
|
|
# except ValueError:
|
|
# # If not a valid UUID, it might be another path
|
|
# if os.path.exists(self.app.static_folder + '/' + session_id):
|
|
# return send_from_directory(self.app.static_folder, session_id)
|
|
# else:
|
|
# return send_from_directory(self.app.static_folder, 'index.html')
|
|
|
|
# Serve static files from the React build folder
|
|
@self.app.route('/<session_id>')
|
|
def serve_static(session_id):
|
|
logging.info(f"Serve request for {session_id}")
|
|
path = session_id
|
|
if os.path.exists(os.path.join(self.app.static_folder, path)):
|
|
# If the file exists, serve it with the correct MIME type
|
|
return send_from_directory(self.app.static_folder, path)
|
|
else:
|
|
# For nested paths like 'static/js/main.js'
|
|
parts = path.split('/')
|
|
session_id = parts[0]
|
|
rest_path = '/'.join(parts[1:])
|
|
if not rest_path:
|
|
rest_path = 'index.html'
|
|
if os.path.exists(os.path.join(self.app.static_folder, rest_path)):
|
|
self.logging.info(f"Serving '{rest_path}' for {session_id}")
|
|
return send_from_directory(self.app.static_folder, rest_path)
|
|
|
|
# Default to serving index.html
|
|
logging.info("Fall through to index.html")
|
|
return send_from_directory(self.app.static_folder, 'index.html')
|
|
|
|
def run(self, host='0.0.0.0', port=5000, debug=False, **kwargs):
|
|
"""Run the web server"""
|
|
# Load documents
|
|
self.app.run(host=host, port=port, debug=debug, **kwargs)
|
|
|
|
# %%
|
|
|
|
# Main function to run everything
|
|
async def main():
    """Parse options, connect the IRC bot, start the optional front ends, and run the bot."""
    global irc_bot, client, model, web_server

    args = parse_args()
    if not args.irc_channel.startswith("#"):
        args.irc_channel = f"#{args.irc_channel}"

    # Setup logging based on the provided level
    setup_logging(args.level)
    logging.info(args)

    client = ollama.Client(host=args.ollama_server)
    model = args.ollama_model

    irc_bot = DynamicIRCBot(args.irc_nickname, args.irc_channel, args.irc_bot_admin, args)
    await irc_bot.connect(args.irc_server, args.irc_port, tls=args.irc_use_tls)

    if args.gradio_enable:
        ui = await create_ui()
        # int(): tolerate the port arriving from argparse as a string.
        ui.launch(server_name=args.gradio_host, server_port=int(args.gradio_port), prevent_thread_lock=True, pwa=True)

    if not args.web_disable:
        web_server = WebServer(logging)
        logging.info(f"Starting web server at http://{args.web_host}:{args.web_port}")
        # debug=False: the Werkzeug debugger allows arbitrary code execution,
        # and the server listens on all interfaces by default.
        threading.Thread(
            target=web_server.run,
            kwargs={"host": args.web_host, "port": int(args.web_port), "debug": False, "use_reloader": False},
            name="web-server",
        ).start()

    try:
        await irc_bot.handle_forever()
    except Exception:
        logging.exception({ "error": "irc_bot.handle_forever() failed"})
|
|
|
|
# Run the main coroutine with asyncio when executed as a script; importing the
# module no longer starts the bot as a side effect.
if __name__ == "__main__":
    asyncio.run(main())
|