621 lines
24 KiB
Python
621 lines
24 KiB
Python
# %%
|
|
# Imports [standard]
|
|
# Standard library modules (no try-except needed)
|
|
import argparse
|
|
import asyncio
|
|
import anyio
|
|
import json
|
|
import logging
|
|
import os
|
|
import queue
|
|
import re
|
|
import time
|
|
from datetime import datetime
|
|
import textwrap
|
|
|
|
def try_import(module_name, pip_name=None):
    """Probe that a module is importable and print an install hint if it is not.

    Args:
        module_name: Name of the module to import.
        pip_name: Name of the pip distribution that provides the module.
            Falls back to ``module_name`` when omitted or empty.

    Returns:
        None. A failed import is reported on stdout; it is not re-raised.
    """
    package = pip_name or module_name
    try:
        __import__(module_name)
    except ImportError:
        hint_lines = (
            f"Module '{module_name}' not found. Install it using:",
            f" pip install {package}",
        )
        for hint in hint_lines:
            print(hint)
|
|
|
|
# Third-party modules with import checks.
# Each entry is (module name, pip distribution name); None means the pip
# name is the same as the module name.
for _module, _pip in (
    ('gradio', None),
    ('ollama', None),
    ('openai', None),
    ('pydle', None),
    ('pytz', None),
    ('requests', None),
    ('yfinance', 'yfinance'),
    ('dotenv', 'python-dotenv'),
    ('geopy', 'geopy'),
    ('hyphen', 'PyHyphen'),
):
    try_import(_module, _pip)
|
|
|
|
from dotenv import load_dotenv
|
|
from geopy.geocoders import Nominatim
|
|
import gradio as gr
|
|
import ollama
|
|
import openai
|
|
import pydle
|
|
import pytz
|
|
import requests
|
|
import yfinance as yf
|
|
from hyphen import hyphenator
|
|
|
|
# Local defined imports
|
|
from tools import (
|
|
get_weather_by_location,
|
|
get_current_datetime,
|
|
get_ticker_price,
|
|
tools
|
|
)
|
|
|
|
# %%
|
|
# Defaults (used as the argparse defaults in parse_args)

# --- Ollama ---
OLLAMA_API_URL = "http://ollama:11434"  # Default Ollama local endpoint
# Alternative model: "deepseek-r1:7b"
MODEL_NAME = "llama3.2"

# --- IRC ---
CHANNEL = "#airc-test"
NICK = "airc"
IRC_SERVER = "miniircd"
IRC_PORT = 6667
USE_TLS = False
BOT_ADMIN = "james"

# --- Logging ---
LOG_LEVEL = "debug"

# --- Gradio ---
GRADIO_HOST = "0.0.0.0"
GRADIO_PORT = 60673
GRADIO_ENABLE = False
|
|
|
|
# %%
|
|
# Globals

# System prompt sent to the model. The f-string is rendered once, when this
# module is executed, so it embeds the IRC_SERVER / NICK / MODEL_NAME defaults
# and the launch time as they are at that moment (not any command-line
# overrides parsed later in main()).
system_message = f"""
You are a helpful information agent connected to the IRC network {IRC_SERVER}. Your name is {NICK}.
You are running { { 'model': MODEL_NAME, 'gpu': 'Intel Arc B580', 'cpu': 'Intel Core i9-14900KS', 'ram': '64G' } }.
You were launched on {get_current_datetime()}.
You have real time access to current stock trading values, the current date and time, and current weather information for locations in the United States.
If you use any real time access, do not mention your knowledge cutoff.
Give short, courteous answers, no more than 2-3 sentences, keeping the answer less than about 100 characters.
If you have to cut the answer short, ask the user if they want more information and provide it if they say Yes.
Always be accurate. If you don't know the answer, say so. Do not make up details.

You have tools to:
* get_current_datetime: Get current time and date.
* get_weather_by_location: Get-real time weather forecast.
* get_ticker_price: Get real-time value of a stock symbol.

Those are the only tools available.
"""
# Message list prepended to every chat request; holds the single system prompt.
system_log = [{"role": "system", "content": system_message}]
# NOTE(review): this module-level history does not appear to be read by the
# code visible in this file (the bot keeps its own self.history) - confirm
# before removing.
history = []
# One entry per tool-assisted chat round (appended by chat()).
tool_log = []
# One entry per "!command" processed by the IRC bot.
command_log = []
# Model name, Ollama client and IRC bot instance; all three are set in main().
model = None
client = None
irc_bot = None
|
|
|
|
# %%
|
|
# Cmd line overrides
|
|
def _str_to_bool(value):
    """Convert a command-line token such as 'true', 'no' or '1' to a bool.

    argparse's ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, as it was under the old ``type=bool``.
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got '{value}'")


def parse_args():
    """Parse the command-line overrides for the module-level defaults.

    Returns:
        argparse.Namespace with the attributes ``irc_server``, ``irc_port``,
        ``irc_nickname``, ``irc_channel``, ``irc_use_tls``, ``ollama_server``,
        ``ollama_model``, ``gradio_host``, ``gradio_port``, ``gradio_enable``,
        ``bot_admin`` and ``level``.
    """
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--irc-server", type=str, default=IRC_SERVER, help=f"IRC server address. default={IRC_SERVER}")
    parser.add_argument("--irc-port", type=int, default=IRC_PORT, help=f"IRC server port. default={IRC_PORT}")
    parser.add_argument("--irc-nickname", type=str, default=NICK, help=f"Bot nickname. default={NICK}")
    parser.add_argument("--irc-channel", type=str, default=CHANNEL, help=f"Channel to join. default={CHANNEL}")
    # Accepts "--irc-use-tls" on its own (meaning True) as well as an explicit
    # value ("--irc-use-tls false"); the previous type=bool turned "False"
    # into True.
    parser.add_argument("--irc-use-tls", type=_str_to_bool, nargs="?", const=True, default=USE_TLS, help=f"Use TLS with --irc-server. default={USE_TLS}")
    parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
    parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
    parser.add_argument("--gradio-host", type=str, default=GRADIO_HOST, help=f"Host to launch gradio on. default={GRADIO_HOST} only if --gradio-enable is specified.")
    # The port is handed to the Gradio launcher as a number, so parse it as int.
    parser.add_argument("--gradio-port", type=int, default=GRADIO_PORT, help=f"Port to launch gradio on. default={GRADIO_PORT} only if --gradio-enable is specified.")
    parser.add_argument("--gradio-enable", action="store_true", default=GRADIO_ENABLE, help=f"If set to True, enable Gradio. default={GRADIO_ENABLE}")
    parser.add_argument("--bot-admin", type=str, default=BOT_ADMIN, help=f"Nick that can send admin commands via IRC. default={BOT_ADMIN}")
    # Normalise to upper case before the choices check so that "--level debug"
    # is accepted, matching the lower-case LOG_LEVEL default.
    parser.add_argument('--level', type=str.upper, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default=LOG_LEVEL.upper(), help=f'Set the logging level. default={LOG_LEVEL.upper()}')
    return parser.parse_args()
|
|
|
|
def setup_logging(level):
    """Configure root logging at the named level.

    Args:
        level: Level name looked up (upper-cased) on the ``logging`` module,
            e.g. "debug" or "INFO".

    Raises:
        ValueError: If the upper-cased name is not an integer attribute of
            the ``logging`` module.
    """
    resolved = getattr(logging, level.upper(), None)
    if isinstance(resolved, int):
        logging.basicConfig(
            level=resolved,
            format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
        )
        logging.info(f"Logging is set to {level} level.")
        return

    raise ValueError(f"Invalid log level: {level}")
|
|
|
|
# %%
|
|
def _break_long_word(word, line_length, syllables):
    """Cut one over-long word into pieces no longer than ``line_length``.

    Pieces end with "-" where the cut falls on a syllable boundary. A
    syllable that is itself longer than ``line_length`` (or a word for which
    no syllables are available) is cut at fixed width without a hyphen.
    """
    pieces = []
    partial = ""
    for syllable in syllables or [word]:
        # Flush when the next syllable plus a trailing hyphen would not fit.
        if partial and len(partial) + len(syllable) + 1 > line_length:
            # A full-width remainder of a hard cut has no room for a hyphen.
            pieces.append(partial + "-" if len(partial) < line_length else partial)
            partial = ""
        while len(syllable) > line_length:
            pieces.append(syllable[:line_length])
            syllable = syllable[line_length:]
        partial += syllable
    if partial:
        pieces.append(partial)
    return pieces


def split_paragraph_with_hyphenation(text, line_length=80, language='en_US'):
    """
    Split a paragraph into multiple lines with proper hyphenation.

    The text is first wrapped on whitespace. Only words that are longer than
    ``line_length`` on their own are hyphenated; the hyphenator is created
    lazily, the first time such a word is met. If the hyphenator cannot be
    created or fails, the word is cut at fixed width instead, so no returned
    line exceeds ``line_length``.

    Args:
        text (str): The text to split.
        line_length (int): The maximum length of each line.
        language (str): The language code for hyphenation rules.

    Returns:
        [str]: The text split into multiple lines with proper hyphenation.
    """
    # break_long_words=False keeps an over-long word intact on its own line.
    # With textwrap's default (True) such words are cut mid-word and the
    # hyphenation below would never run.
    wrapped = textwrap.wrap(text, width=line_length, break_long_words=False)

    result_lines = []
    h = None
    hyphenation_available = True

    for line in wrapped:
        # If the line is already short enough, keep it as is
        if len(line) <= line_length:
            result_lines.append(line)
            continue

        for word in line.split():
            if len(word) <= line_length:
                result_lines.append(word)
                continue

            syllables = []
            if hyphenation_available:
                try:
                    if h is None:
                        h = hyphenator.Hyphenator(language)
                    syllables = h.syllables(word)
                except Exception:
                    # Best effort: a missing dictionary must not stop the
                    # text from being delivered. Fall back to fixed-width
                    # cuts and do not retry for the rest of this call.
                    hyphenation_available = False
                    syllables = []
                    logging.warning(f"Hyphenation unavailable for '{language}'; splitting long words at fixed width")

            result_lines.extend(_break_long_word(word, line_length, syllables))

    return result_lines
|
|
|
|
# %%
|
|
def handle_tool_calls(message):
    """Run every tool call requested in ``message`` and collect the results.

    Args:
        message: Mapping-like object whose ``'tool_calls'`` entry is an
            iterable of calls, each exposing ``['function']['name']`` and
            ``['function']['arguments']``.

    Returns:
        A ``(result, tools_used)`` tuple. ``result`` is a single
        ``{"role": "tool", "content": ..., "name": ...}`` dict when exactly
        one call was made, otherwise a list of such dicts (possibly empty).
        ``tools_used`` lists the names of the recognised tools, in call
        order. Unrecognised tools produce the content ``"None"``.
    """
    def _ticker_price(args):
        symbol = args.get('ticker')
        # No lookup is attempted without a ticker symbol.
        return get_ticker_price(symbol) if symbol else None

    def _current_datetime(args):
        return get_current_datetime(args.get('timezone'))

    def _weather(args):
        return get_weather_by_location(args.get('city'), args.get('state'))

    handlers = {
        'get_ticker_price': _ticker_price,
        'get_current_datetime': _current_datetime,
        'get_weather_by_location': _weather,
    }

    results = []
    used = []
    for call in message['tool_calls']:
        spec = call['function']
        call_args = spec['arguments']
        name = spec['name']

        handler = handlers.get(name)
        if handler is None:
            outcome = None
        else:
            outcome = handler(call_args)
            used.append(name)

        results.append({
            "role": "tool",
            "content": str(outcome),
            "name": name
        })

    payload = results[0] if len(results) == 1 else results
    return payload, used
|
|
|
|
# %%
|
|
async def chat(history, is_irc=False):
    """Send the conversation to the model and append its reply to ``history``.

    The request consists of ``system_log`` followed by ``history``. If the
    model asks for tool calls, they are executed, their results are added to
    the request and the model is queried a second time (without tools) for
    the final answer.

    Args:
        history: List of message dicts; the last entry is the newest message.
            The list is extended in place with the assistant reply.
        is_irc: Accepted for the callers' benefit; not used by this function.

    Returns:
        The same ``history`` list. It is returned unchanged when no client
        has been configured.

    NOTE(review): ``client.chat`` is called directly, without ``await``,
    inside this coroutine; if that call blocks, the event loop is held for
    the duration of the request - confirm whether that is acceptable.
    """
    if not client:
        return history

    message = history[-1]

    logging.info(f"chat('{message}')")
    messages = system_log + history
    response = client.chat(model=model, messages=messages, tools=tools)
    tools_used = []

    assistant_message = response['message']
    # .get() also covers a 'tool_calls' entry that is present but None/empty.
    tool_calls = assistant_message.get('tool_calls')
    if tool_calls:
        # Convert Message object to a proper dictionary format
        message_dict = {
            'role': assistant_message.get('role', 'assistant'),
            'content': assistant_message.get('content', ''),
            'tool_calls': [
                {'function': {'name': tc.function.name, 'arguments': tc.function.arguments}}
                for tc in tool_calls
            ],
        }

        tool_result, tools_used = handle_tool_calls(assistant_message)
        messages.append(message_dict)  # Add properly formatted dict instead of Message object
        # handle_tool_calls returns a list when several tools were called;
        # each result has to be its own entry in the message list.
        if isinstance(tool_result, list):
            messages.extend(tool_result)
        else:
            messages.append(tool_result)
        try:
            response = client.chat(model=model, messages=messages)
        except Exception:
            # Keep the first response so that a reply can still be produced.
            logging.exception({ 'model': model, 'messages': messages })

        tool_log.append({ 'call': assistant_message, 'tool_result': tool_result, 'tools_used': tools_used, 'response': response['message']['content']})

    reply = response['message']['content']
    if tools_used:
        history += [{"role":"assistant", "content":reply, 'metadata': {"title": f"🛠️ Tool(s) used: {','.join(tools_used)}"}}]
    else:
        history += [{"role":"assistant", "content":reply}]

    return history
|
|
|
|
# %%
|
|
|
|
# %%
|
|
|
|
class DynamicIRCBot(pydle.Client):
|
|
def __init__(self, nickname, channel, bot_admin, system_info, burst_limit = 5, rate_limit = 1.0, burst_reset_timeout = 10.0, **kwargs):
|
|
super().__init__(nickname, **kwargs)
|
|
self.history = []
|
|
self.channel = channel
|
|
self.bot_admin = bot_admin
|
|
self.system_info = system_info
|
|
# Message throttling
|
|
self.burst_limit = burst_limit
|
|
self.sent_burst = 0
|
|
self.rate_limit = rate_limit
|
|
self.burst_reset_timeout = burst_reset_timeout
|
|
self.sent_burst = 0 # Track messages sent in burst
|
|
self.last_message_time = None # Track last message time
|
|
self._message_queue = asyncio.Queue()
|
|
self._task = asyncio.create_task(self._send_from_queue())
|
|
|
|
async def _send_from_queue(self):
|
|
"""Background task that sends queued messages with burst + rate limiting."""
|
|
while True:
|
|
target, message = await self._message_queue.get()
|
|
|
|
# If burst is still available, send immediately
|
|
if self.sent_burst < self.burst_limit:
|
|
self.sent_burst += 1
|
|
else:
|
|
await asyncio.sleep(self.rate_limit) # Apply rate limit
|
|
|
|
logging.debug(f"Sending {message} => {target} from queue")
|
|
await super().message(target, message) # Send message
|
|
self.last_message_time = asyncio.get_event_loop().time() # Update last message timestamp
|
|
|
|
# Start burst reset countdown after each message
|
|
asyncio.create_task(self._reset_burst_after_inactivity())
|
|
|
|
async def _reset_burst_after_inactivity(self):
|
|
"""Resets burst counter only if no new messages are sent within timeout."""
|
|
last_time = self.last_message_time
|
|
await asyncio.sleep(self.burst_reset_timeout) # Wait for inactivity period
|
|
|
|
# Only reset if no new messages were sent during the wait
|
|
if self.last_message_time == last_time:
|
|
self.sent_burst = 0
|
|
logging.info("Burst limit reset due to inactivity.")
|
|
|
|
async def message(self, target, message):
|
|
"""Splits a multi-line message and sends each line separately. If more than 10 lines, truncate and add a message."""
|
|
max_lines = 10
|
|
irc_lines = []
|
|
for line in message.splitlines():
|
|
lines = split_paragraph_with_hyphenation(line, line_length=450)
|
|
irc_lines.extend(lines)
|
|
|
|
# Send the first 'max_lines' non-empty lines
|
|
i=0
|
|
sent_lines=0
|
|
while i < len(irc_lines) and sent_lines < max_lines:
|
|
line = irc_lines[i].strip()
|
|
i+=1
|
|
if line != "":
|
|
sent_lines += 1
|
|
await self._message_queue.put((target, line))
|
|
|
|
# If there are more than 10 lines, add the truncation message
|
|
if len(irc_lines) > max_lines:
|
|
await self._message_queue.put((target, f"...and so on (message truncated to {max_lines} lines.)"))
|
|
|
|
async def on_connect(self):
|
|
await super().on_connect()
|
|
logging.info(f"CONNECT: {self.nickname}")
|
|
if self.channel:
|
|
await self.join(self.channel, self.nickname)
|
|
|
|
async def on_join(self, channel, user):
|
|
await super().on_join(channel, user)
|
|
logging.info(f"JOIN: {user} => {channel}")
|
|
|
|
async def on_part(self, channel, user):
|
|
await super().on_part(channel, user)
|
|
logging.info(f"PART: {channel} => {user}")
|
|
|
|
async def on_message(self, target, source, message, is_gradio=False):
|
|
global system_log, tool_log, system_log, command_log
|
|
|
|
if not is_gradio:
|
|
await super().on_message(target, source, message)
|
|
message = message.strip()
|
|
logging.info(f"MESSAGE: {source} => {target}: {message}")
|
|
if source == self.nickname and not is_gradio:
|
|
return
|
|
last_message = self.history[-1] if len(self.history) > 0 and self.history[-1]["role"] == "user" else None
|
|
try:
|
|
matches = re.match(r"^([^:]+)\s*:\s*(.*)$", message)
|
|
if matches:
|
|
user = matches.group(1).strip()
|
|
content = matches.group(2).strip()
|
|
else:
|
|
user = None
|
|
content = message
|
|
|
|
|
|
# If this message is not directed to the bot
|
|
if not user or user != self.nickname:
|
|
logging.info(f"Message not directed to {self.nickname}")
|
|
# Add this message to the history either to the current 'user' context or create
|
|
# add a new message
|
|
if last_message:
|
|
logging.info(f"Modifying last USER context")
|
|
last_message['content'] += f"\n{source}: {content}"
|
|
else:
|
|
logging.info(f"Appending new USER context")
|
|
last_message = {
|
|
"role": "user",
|
|
"content": f"{source}: {content}"
|
|
}
|
|
self.history.append(last_message)
|
|
return
|
|
|
|
matches = re.match(r"^!([^\s]+)\s*(.*)?$", content)
|
|
if not matches or (self.bot_admin and source != self.bot_admin and source != self.nickname):
|
|
logging.info(f"Non-command directed message to {self.nickname}: Invoking chat...")
|
|
# Add this message to the history either to the current 'user' context or create
|
|
# add a new message
|
|
if last_message:
|
|
logging.info(f"Modifying last USER context")
|
|
last_message['content'] += f"\n{source}: {content}"
|
|
else:
|
|
logging.info(f"Appending new USER context")
|
|
last_message = {
|
|
"role": "user",
|
|
"content": f"{source}: {content}"
|
|
}
|
|
self.history.append(last_message)
|
|
self.history = await chat(self.history, is_irc=True)
|
|
chat_response = self.history[-1]
|
|
await self.message(target, chat_response['content'])
|
|
return
|
|
|
|
command = matches.group(1)
|
|
arguments = matches.group(2).strip()
|
|
logging.info(f"Command directed to {self.nickname}: command={command}, arguments={arguments}")
|
|
|
|
match command:
|
|
case "help":
|
|
response = f"info, context, reset, system [prompt], server [address], join channel"
|
|
|
|
case "info":
|
|
response = str(self.system_info)
|
|
|
|
case "context":
|
|
if len(self.history) > 1:
|
|
response = '"' + '","'.join(self.history[-1]['content'].split('\n')) + '"'
|
|
else:
|
|
response = "<no context>"
|
|
|
|
case "reset":
|
|
system_log = [{"role": "system", "content": system_message}]
|
|
history = []
|
|
tool_log = []
|
|
command_log = []
|
|
response = 'All contexts reset'
|
|
|
|
case "system":
|
|
if arguments != "":
|
|
system_log = [{"role": "system", "content": arguments}]
|
|
response = 'System message updated.'
|
|
else:
|
|
lines = [line.strip() for line in system_log[0]['content'].split('\n') if line.strip()]
|
|
response = " ".join(lines)
|
|
|
|
case "server":
|
|
server = arguments.split(" ", 1)
|
|
if server[0] == "":
|
|
server = IRC_SERVER
|
|
else:
|
|
server = server[0]
|
|
try:
|
|
await self.connect(server, 6667, tls=False)
|
|
response="Connected to {server}"
|
|
except Exception:
|
|
response = f"Unable to connect to {server}"
|
|
logging.exception({ "error": f"Unable to process message {content}"})
|
|
|
|
case "join":
|
|
channel = arguments.strip()
|
|
if channel == "" or re.match(r"\s", channel):
|
|
response = "Usage: !join CHANNEL"
|
|
else:
|
|
if not re.match(r"^#", channel):
|
|
channel = f"#{channel}"
|
|
if self.channel and self.channel != channel:
|
|
await self.part(channel)
|
|
if channel:
|
|
await self.bot.join(channel)
|
|
self.channel = channel
|
|
response = f"Joined {channel}."
|
|
|
|
case _:
|
|
response = f"Unrecognized command: {command}"
|
|
|
|
await self.message(target, f"!{command}: {response}")
|
|
command_log.append({ 'source': source, 'command': f"{content}", 'response': response })
|
|
except Exception:
|
|
logging.exception({ "error": f"Unable to process message {content}"})
|
|
await self.message(target, f"I'm experiencing difficulties processing '{content}'")
|
|
|
|
# %%
|
|
# Lengths of irc_bot.history and command_log as of the last UI refresh;
# check_history() in create_ui() compares against them to decide whether the
# chat and command panels need updating.
last_history_len = 0
last_command_len = 0
# Stylesheet passed to gr.Blocks in create_ui().
css = """
body, html, #root {
background-color: #F0F0F0;
}
"""
|
|
async def create_ui():
    """Build the Gradio console for the bot.

    The UI has a chat panel with a text entry, three JSON panels (chat
    history, tool calls, command calls), a "Clear" button and a one-second
    timer that refreshes the chat and command panels when their lengths
    change.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).

    NOTE(review): the nesting of the layout containers below was
    reconstructed from a source whose indentation was lost - confirm it
    against the intended layout.
    """
    with gr.Blocks(css=css, fill_height=True, fill_width=True) as ui:
        with gr.Row(scale=1):
            with gr.Column(scale=1):
                chatbot = gr.Chatbot(
                    # value=check_message_queue,
                    # every=1,
                    type="messages",
                    #height=600, # Set a fixed height
                    container=True, # Enable scrolling
                    elem_id="chatbot",
                    scale=1
                )
                entry = gr.Textbox(
                    label="Chat with our AI Assistant:",
                    container=False,
                    scale=0
                )
            with gr.Column(scale=1):
                chat_history = gr.JSON(
                    system_log,
                    label="Chat History",
                    height=100, # Set height for JSON viewer,
                    scale=1
                )
                tool_history = gr.JSON(
                    tool_log,
                    label="Tool calls",
                    height=100, # Set height for JSON viewer
                    scale=1
                )
                command_history = gr.JSON(
                    command_log,
                    label="Command calls",
                    height=100, # Set height for JSON viewer
                    scale=1
                )
        with gr.Row(scale=0):
            clear = gr.Button("Clear")
            timer = gr.Timer(1)

        async def do_entry(message):
            """Post console input to the channel and feed it to the bot as if addressed to it."""
            if not irc_bot:
                return gr.skip()
            await irc_bot.message(irc_bot.channel, f"[console] {message}")
            await irc_bot.on_message(irc_bot.channel, irc_bot.nickname, f"{irc_bot.nickname}: {message}", is_gradio=True)
            return "", irc_bot.history

        def do_clear():
            """Clear the bot history and the tool/command logs."""
            # Rebind the module-level logs; without the global statement the
            # assignments below would only create locals and the real logs
            # would reappear on the next refresh.
            global tool_log, command_log
            if not irc_bot:
                return gr.skip()
            irc_bot.history = []
            tool_log = []
            command_log = []
            return irc_bot.history, system_log, tool_log, command_log

        def update_log(history):
            # This function updates the log after the chatbot responds
            return system_log + history, tool_log, command_log

        def check_history():
            """Return new panel contents only for the panels whose length changed."""
            global last_history_len, last_command_len
            if not irc_bot or last_history_len == len(irc_bot.history):
                history = gr.skip()
            else:
                history = irc_bot.history
                last_history_len = len(irc_bot.history)
            if last_command_len == len(command_log):
                commands = gr.skip()
            else:
                commands = command_log
                last_command_len = len(command_log)

            return history, commands

        entry.submit(
            do_entry,
            inputs=[entry],
            outputs=[entry, chatbot]
        ).then(
            update_log, # This new function updates the log after chatbot processing
            inputs=chatbot,
            outputs=[chat_history, tool_history, command_history]
        )

        timer.tick(check_history, inputs=None, outputs=[chatbot, command_history])

        clear.click(do_clear, inputs=None, outputs=[chatbot, chat_history, tool_history, command_history], queue=False)

    return ui
|
|
|
|
# %%
|
|
|
|
# Main function to run everything
|
|
async def main():
    """Parse the arguments, start the IRC bot (and optionally Gradio) and run until stopped.

    Sets the module-level ``client``, ``model`` and ``irc_bot``.
    """
    global irc_bot, client, model
    # Parse command-line arguments
    args = parse_args()
    if not args.irc_channel.startswith("#"):
        args.irc_channel = f"#{args.irc_channel}"
    # Setup logging based on the provided level
    setup_logging(args.level)

    client = ollama.Client(host=args.ollama_server)
    model = args.ollama_model

    irc_bot = DynamicIRCBot(args.irc_nickname, args.irc_channel, args.bot_admin, args)
    await irc_bot.connect(args.irc_server, args.irc_port, tls=args.irc_use_tls)

    if args.gradio_enable:
        ui = await create_ui()
        # The port may arrive from the command line as a string; the launcher
        # needs a number.
        ui.launch(server_name=args.gradio_host, server_port=int(args.gradio_port), prevent_thread_lock=True, pwa=True)
    logging.info(args)

    await irc_bot.handle_forever()
|
|
|
|
# Run the main coroutine on a new asyncio event loop.
asyncio.run(main())
|