# %%
# Imports [standard]

# Standard library modules (no try-except needed)
import argparse
import asyncio
import anyio
import json
import logging
import os
import queue
import re
import time
from datetime import datetime


def try_import(module_name, pip_name=None):
    try:
        __import__(module_name)
    except ImportError:
        print(f"Module '{module_name}' not found. Install it using:")
        print(f" pip install {pip_name or module_name}")


# Third-party modules with import checks
try_import('gradio')
try_import('ollama')
try_import('openai')
try_import('pydle')
try_import('pytz')
try_import('requests')
try_import('yfinance', 'yfinance')
try_import('dotenv', 'python-dotenv')
try_import('geopy', 'geopy')

from dotenv import load_dotenv
from geopy.geocoders import Nominatim
import gradio as gr
import ollama
import openai
import pydle
import pytz
import requests
import yfinance as yf

# Locally defined imports
from tools import (
    get_weather_by_location,
    get_current_datetime,
    get_ticker_price,
    tools
)
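
# Note: the local `tools` module is expected to expose the three helper
# functions imported above plus a `tools` list describing them for Ollama's
# function-calling API. A minimal sketch of the assumed shape (illustrative
# only, not the actual module contents):
#
# tools = [
#     {
#         "type": "function",
#         "function": {
#             "name": "get_ticker_price",
#             "description": "Get the real-time value of a stock symbol.",
#             "parameters": {
#                 "type": "object",
#                 "properties": {"ticker": {"type": "string"}},
#                 "required": ["ticker"],
#             },
#         },
#     },
#     # get_current_datetime and get_weather_by_location follow the same pattern
# ]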

# %%
# Defaults
OLLAMA_API_URL = "http://ollama:11434"  # Default Ollama local endpoint
#MODEL_NAME = "deepseek-r1:7b"
MODEL_NAME = "llama3.2"
CHANNEL = "#airc-test"
NICK = "airc"
IRC_SERVER = "miniircd"
IRC_PORT = 6667
LOG_LEVEL = "DEBUG"
USE_TLS = False
GRADIO_HOST = "0.0.0.0"
GRADIO_PORT = 60673
GRADIO_ENABLE = False
BOT_ADMIN = "james"

# %%
# Globals
system_message = f"""
You are a helpful information agent connected to the IRC network {IRC_SERVER}. Your name is {NICK}.
You are running { {'model': MODEL_NAME, 'gpu': 'Intel Arc B580', 'cpu': 'Intel Core i9-14900KS', 'ram': '64G'} }.
You were launched on {get_current_datetime()}.
You have real-time access to current stock trading values, the current date and time, and current weather information for locations in the United States.
If you use any real-time access, do not mention your knowledge cutoff.
Give short, courteous answers, no more than 2-3 sentences, keeping the answer less than about 100 characters.
If you have to cut the answer short, ask the user if they want more information and provide it if they say Yes.
Always be accurate. If you don't know the answer, say so. Do not make up details.

You have tools to:
* get_current_datetime: Get the current time and date.
* get_weather_by_location: Get a real-time weather forecast.
* get_ticker_price: Get the real-time value of a stock symbol.

Those are the only tools available.
"""

system_log = [{"role": "system", "content": system_message}]
history = []
tool_log = []
command_log = []
model = None
client = None
irc_bot = None

# %%
# Command-line overrides
def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--irc-server", type=str, default=IRC_SERVER, help=f"IRC server address. default={IRC_SERVER}")
    parser.add_argument("--irc-port", type=int, default=IRC_PORT, help=f"IRC server port. default={IRC_PORT}")
    parser.add_argument("--irc-nickname", type=str, default=NICK, help=f"Bot nickname. default={NICK}")
    parser.add_argument("--irc-channel", type=str, default=CHANNEL, help=f"Channel to join. default={CHANNEL}")
    # argparse's type=bool treats any non-empty string as True, so expose this as a flag instead
    parser.add_argument("--irc-use-tls", action="store_true", default=USE_TLS, help=f"Use TLS with --irc-server. default={USE_TLS}")
    parser.add_argument("--ollama-server", type=str, default=OLLAMA_API_URL, help=f"Ollama API endpoint. default={OLLAMA_API_URL}")
    parser.add_argument("--ollama-model", type=str, default=MODEL_NAME, help=f"LLM model to use. default={MODEL_NAME}")
    parser.add_argument("--gradio-host", type=str, default=GRADIO_HOST, help=f"Host to launch Gradio on; only used if --gradio-enable is specified. default={GRADIO_HOST}")
    parser.add_argument("--gradio-port", type=int, default=GRADIO_PORT, help=f"Port to launch Gradio on; only used if --gradio-enable is specified. default={GRADIO_PORT}")
    parser.add_argument("--gradio-enable", action="store_true", default=GRADIO_ENABLE, help=f"If set, enable Gradio. default={GRADIO_ENABLE}")
    parser.add_argument("--bot-admin", type=str, default=BOT_ADMIN, help=f"Nick that can send admin commands via IRC. default={BOT_ADMIN}")
    parser.add_argument('--level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default=LOG_LEVEL, help=f'Set the logging level. default={LOG_LEVEL}')
    return parser.parse_args()


def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')

    logging.info(f"Logging is set to {level} level.")
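
# Example invocation (illustrative; the script name is assumed):
#   python airc.py --irc-server irc.example.net --irc-channel "#mychannel" \
#       --ollama-model llama3.2 --gradio-enable --level INFO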

# %%
def handle_tool_calls(message):
    """Dispatch the model's tool calls to the local tool functions and return
    the resulting tool message(s) plus the list of tools used."""
    response = []
    tools_used = []
    for tool_call in message['tool_calls']:
        arguments = tool_call['function']['arguments']
        tool = tool_call['function']['name']
        if tool == 'get_ticker_price':
            ticker = arguments.get('ticker')
            if not ticker:
                ret = None
            else:
                ret = get_ticker_price(ticker)
            tools_used.append(tool)
        elif tool == 'get_current_datetime':
            ret = get_current_datetime(arguments.get('timezone'))
            tools_used.append(tool)
        elif tool == 'get_weather_by_location':
            ret = get_weather_by_location(arguments.get('city'), arguments.get('state'))
            tools_used.append(tool)
        else:
            ret = None
        response.append({
            "role": "tool",
            "content": str(ret),
            "name": tool_call['function']['name']
        })
    # A single tool call returns a bare dict; multiple calls return a list.
    if len(response) == 1:
        return response[0], tools_used
    else:
        return response, tools_used
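
# For reference, a tool-call message returned by the Ollama chat API looks
# roughly like this (illustrative values):
#
# message = {
#     'role': 'assistant',
#     'content': '',
#     'tool_calls': [
#         {'function': {'name': 'get_ticker_price', 'arguments': {'ticker': 'AAPL'}}}
#     ]
# }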

# %%
async def chat(history, is_irc=False):
    global client, model, irc_bot, system_log, tool_log
    if not client:
        return history

    message = history[-1]

    logging.info(f"chat('{message}')")
    messages = system_log + history
    response = client.chat(model=model, messages=messages, tools=tools)
    tools_used = []
    if 'tool_calls' in response['message']:
        message = response['message']
        tool_result, tools_used = handle_tool_calls(message)
        messages.append(message)
        # handle_tool_calls returns a single dict for one call and a list for several
        if isinstance(tool_result, list):
            messages.extend(tool_result)
        else:
            messages.append(tool_result)
        try:
            response = client.chat(model=model, messages=messages)
        except Exception:
            logging.exception({'model': model, 'messages': messages})

        tool_log.append({'call': message, 'tool_result': tool_result, 'tools_used': tools_used, 'response': response['message']['content']})

    reply = response['message']['content']
    if len(tools_used):
        history += [{"role": "assistant", "content": reply, 'metadata': {"title": f"🛠️ Tool(s) used: {','.join(tools_used)}"}}]
    else:
        history += [{"role": "assistant", "content": reply}]

    return history
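
# Illustrative direct use of chat() outside of IRC (assumes a reachable Ollama
# server and that main() has already initialized `client` and `model`):
#
#   history = [{"role": "user", "content": "james: what's AAPL trading at?"}]
#   history = await chat(history)
#   print(history[-1]["content"])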

# %%
class DynamicIRCBot(pydle.Client):
    def __init__(self, nickname, channel, bot_admin, system_info, **kwargs):
        super().__init__(nickname, **kwargs)
        self.history = []
        self.channel = channel
        self.bot_admin = bot_admin
        self.system_info = system_info

    async def on_connect(self):
        await super().on_connect()
        logging.info(f"CONNECT: {self.nickname}")
        if self.channel:
            await self.join(self.channel)

    async def on_join(self, channel, user):
        await super().on_join(channel, user)
        logging.info(f"JOIN: {user} => {channel}")

    async def on_part(self, channel, user):
        await super().on_part(channel, user)
        logging.info(f"PART: {channel} => {user}")

    async def on_message(self, target, source, message, is_gradio=False):
        global system_log, tool_log, command_log

        if not is_gradio:
            await super().on_message(target, source, message)
        message = message.strip()
        logging.info(f"MESSAGE: {source} => {target}: {message}")
        if source == self.nickname and not is_gradio:
            return
        last_message = self.history[-1] if len(self.history) > 0 and self.history[-1]["role"] == "user" else None
        try:
            matches = re.match(r"^([^:]+)\s*:\s*(.*)$", message)
            if matches:
                user = matches.group(1).strip()
                content = matches.group(2).strip()
            else:
                user = None
                content = message

            # If this message is not directed to the bot
            if not user or user != self.nickname:
                logging.info(f"Message not directed to {self.nickname}")
                # Add this message to the history, either extending the current
                # 'user' context or appending a new one
                if last_message:
                    logging.info("Modifying last USER context")
                    last_message['content'] += f"\n{source}: {content}"
                else:
                    logging.info("Appending new USER context")
                    last_message = {
                        "role": "user",
                        "content": f"{source}: {content}"
                    }
                    self.history.append(last_message)
                return

            matches = re.match(r"^!([^\s]+)\s*(.*)?$", content)
            if not matches or (self.bot_admin and source != self.bot_admin and source != self.nickname):
                logging.info(f"Non-command directed message to {self.nickname}: Invoking chat...")
                # Add this message to the history, either extending the current
                # 'user' context or appending a new one
                if last_message:
                    logging.info("Modifying last USER context")
                    last_message['content'] += f"\n{source}: {content}"
                else:
                    logging.info("Appending new USER context")
                    last_message = {
                        "role": "user",
                        "content": f"{source}: {content}"
                    }
                    self.history.append(last_message)
                self.history = await chat(self.history, is_irc=True)
                chat_response = self.history[-1]
                lines = [line.strip() for line in chat_response['content'].split('\n') if line.strip()]
                if len(lines) > 10:
                    lines = lines[:9] + ["...and so on. I'll send the rest to you in a PRIVMSG (wip) :)"]
                await self.message(target, "\n".join(lines))
                return

            command = matches.group(1)
            arguments = matches.group(2).strip()
            logging.info(f"Command directed to {self.nickname}: command={command}, arguments={arguments}")

            match command:
                case "test":
                    response = f"test received: {arguments}"

                case "info":
                    response = str(self.system_info)

                case "context":
                    if len(self.history) > 1:
                        response = '"' + '","'.join(self.history[-1]['content'].split('\n')) + '"'
                    else:
                        response = "<no context>"

                case "reset":
                    system_log = [{"role": "system", "content": system_message}]
                    self.history = []
                    tool_log = []
                    command_log = []
                    response = 'All contexts reset'

                case "system":
                    if arguments != "":
                        system_log = [{"role": "system", "content": arguments}]
                        response = 'System message updated.'
                    else:
                        lines = [line.strip() for line in system_log[0]['content'].split('\n') if line.strip()]
                        response = " ".join(lines)

                case "server":
                    # Use the first argument as the server, falling back to the default
                    server = arguments.split(" ", 1)[0] if arguments else IRC_SERVER
                    try:
                        await self.connect(server, 6667, tls=False)
                        response = f"Connected to {server}."
                    except Exception:
                        logging.exception({"error": f"Unable to connect to {server}"})
                        response = f"Unable to connect to {server}"

                case "join":
                    channel = arguments.strip()
                    if channel == "":
                        response = "Usage: !join CHANNEL"
                    else:
                        if not re.match(r"^#", channel):
                            channel = f"#{channel}"
                        if self.channel and self.channel != channel:
                            await self.part(self.channel)
                        if channel:
                            await self.join(channel)
                            self.channel = channel
                        response = f"Joined {channel}."

                case _:
                    response = f"Unrecognized command: {command}"

            await self.message(target, f"!{command}: {response}")
            command_log.append({'source': source, 'command': f"{content}", 'response': response})
        except Exception:
            logging.exception({"error": f"Unable to process message {content}"})
            await self.message(target, f"I'm experiencing difficulties processing '{content}'")
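
# Admin commands recognized over IRC when the bot is addressed as "<nick>: !command":
#   !test <text>     - echo test
#   !info            - show the parsed runtime configuration
#   !context         - show the current user context
#   !reset           - reset the system, chat, tool, and command contexts
#   !system [prompt] - show or replace the system message
#   !server [host]   - reconnect to another IRC server
#   !join <channel>  - part the current channel and join another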

# %%
last_history_len = 0
last_command_len = 0
css = """
body, html, #root {
    background-color: #F0F0F0;
}
"""
async def create_ui():
    global irc_bot

    with gr.Blocks(css=css, fill_height=True, fill_width=True) as ui:
        with gr.Row(scale=1):
            with gr.Column(scale=1):
                chatbot = gr.Chatbot(
                    # value=check_message_queue,
                    # every=1,
                    type="messages",
                    #height=600,  # Set a fixed height
                    container=True,  # Enable scrolling
                    elem_id="chatbot",
                    scale=1
                )
                entry = gr.Textbox(
                    label="Chat with our AI Assistant:",
                    container=False,
                    scale=0
                )
            with gr.Column(scale=1):
                chat_history = gr.JSON(
                    system_log,
                    label="Chat History",
                    height=100,  # Set height for JSON viewer
                    scale=1
                )
                tool_history = gr.JSON(
                    tool_log,
                    label="Tool calls",
                    height=100,  # Set height for JSON viewer
                    scale=1
                )
                command_history = gr.JSON(
                    command_log,
                    label="Command calls",
                    height=100,  # Set height for JSON viewer
                    scale=1
                )
        with gr.Row(scale=0):
            clear = gr.Button("Clear")
            timer = gr.Timer(1)

        async def do_entry(message):
            if not irc_bot:
                return gr.skip()
            await irc_bot.message(irc_bot.channel, f"[console] {message}")
            await irc_bot.on_message(irc_bot.channel, irc_bot.nickname, f"{irc_bot.nickname}: {message}", is_gradio=True)
            return "", irc_bot.history

        def do_clear():
            # The module-level logs are reassigned here, so declare them global
            global tool_log, command_log
            if not irc_bot:
                return gr.skip()
            irc_bot.history = []
            tool_log = []
            command_log = []
            return irc_bot.history, system_log, tool_log, command_log

        def update_log(history):
            # This function updates the log panes after the chatbot responds
            return system_log + history, tool_log, command_log

        def check_history():
            global last_history_len, last_command_len
            if not irc_bot or last_history_len == len(irc_bot.history):
                history = gr.skip()
            else:
                history = irc_bot.history
                last_history_len = len(irc_bot.history)
            if last_command_len == len(command_log):
                commands = gr.skip()
            else:
                commands = command_log
                last_command_len = len(command_log)

            return history, commands

        entry.submit(
            do_entry,
            inputs=[entry],
            outputs=[entry, chatbot]
        ).then(
            update_log,  # Updates the log panes after chatbot processing
            inputs=chatbot,
            outputs=[chat_history, tool_history, command_history]
        )

        timer.tick(check_history, inputs=None, outputs=[chatbot, command_history])

        clear.click(do_clear, inputs=None, outputs=[chatbot, chat_history, tool_history, command_history], queue=False)

    return ui

# %%
# Main function to run everything
async def main():
    global irc_bot, client, model
    # Parse command-line arguments
    args = parse_args()

    # Set up logging based on the provided level
    setup_logging(args.level)

    client = ollama.Client(host=args.ollama_server)
    model = args.ollama_model

    irc_bot = DynamicIRCBot(args.irc_nickname, args.irc_channel, args.bot_admin, args)
    await irc_bot.connect(args.irc_server, args.irc_port, tls=args.irc_use_tls)

    if args.gradio_enable:
        ui = await create_ui()
        ui.launch(server_name=args.gradio_host, server_port=args.gradio_port, prevent_thread_lock=True, pwa=True)
    logging.info(args)

    await irc_bot.handle_forever()


# Run the main function on the asyncio event loop
asyncio.run(main())