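"""airc: an IRC bot that answers channel questions with a local LLM.

The bot (pydle-based) joins an IRC channel, embeds recent articles from a set
of RSS feeds into a ChromaDB collection using Ollama embeddings, and answers
messages addressed to it by retrieving related articles and prompting a
DeepSeek-R1 model through the Ollama API.

Example invocation (assuming the file is saved as airc.py):

    python airc.py --server irc.libera.chat --port 6667 \
        --nickname airc --channel "#airc-test" --level INFO
"""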
import argparse
import asyncio
import datetime
import json
import logging
import os
import re
import time
from typing import Dict, Any

import chromadb
import feedparser
import ollama
import pydle
from bs4 import BeautifulSoup

OLLAMA_API_URL = "http://ollama:11434"  # Default Ollama local endpoint
MODEL_NAME = "deepseek-r1:7b"


def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--server", type=str, default="irc.libera.chat", help="IRC server address")
    parser.add_argument("--port", type=int, default=6667, help="IRC server port")
    parser.add_argument("--nickname", type=str, default="airc", help="Bot nickname")
    parser.add_argument("--channel", type=str, default="#airc-test", help="Channel to join")
    parser.add_argument("--ai-server", type=str, default="http://localhost:5000", help="OpenAI API endpoint")
    parser.add_argument('--level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help='Set the logging level.')
    return parser.parse_args()


def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
    logging.info(f"Logging is set to {level} level.")


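# Shared Ollama client, used both for embeddings and for chat generation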
client = ollama.Client(host=OLLAMA_API_URL)


def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, 'xml')  # Use the 'xml' parser for XML content
    else:
        soup = BeautifulSoup(content, 'html.parser')  # Default to 'html.parser' for HTML content

    # Extract and return just the text
    return soup.get_text()


class Feed():
    """Polls an RSS feed and caches a text digest of its most recent articles."""

    def __init__(self, name, url, poll_limit_min=30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            logging.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now

            content = ""
            if len(feed.entries) > 0:
                content += f"Source: {self.name}\n"
                for entry in feed.entries[:self.max_articles]:
                    title = entry.get("title")
                    if title:
                        content += f"Title: {title}\n"
                    link = entry.get("link")
                    if link:
                        content += f"Link: {link}\n"
                    summary = entry.get("summary")
                    if summary:
                        summary = extract_text_from_html_or_xml(summary, False)
                        content += f"Summary: {summary}\n"
                    published = entry.get("published")
                    if published:
                        content += f"Published: {published}\n"
                    content += "\n"

                self.articles.append(content)
        else:
            logging.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)} remains until the next refresh.")
        return self.articles


# News RSS feeds
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]

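# Example documents (not referenced below; the RSS feed articles above are
# what actually get embedded).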
documents = [
    "Llamas like to eat penguins",
    "Llamas are not vegetarians and have very efficient digestive systems",
    "Llamas live to be about 120 years old, though some only live for 15 years and others live to be 90 years old",
]

# Initialize the ChromaDB client
db = chromadb.PersistentClient(path="/root/.cache/chroma.db")

# We want to save the collection to disk to analyze it offline, but we don't
# want to re-use it
collection = db.get_or_create_collection("docs")

# Store each article in the vector embedding database
for i, feed in enumerate(rss_feeds):
    for j, article in enumerate(feed.articles):
        # Use the client instance instead of the global ollama module
        response = client.embeddings(model="mxbai-embed-large", prompt=article)
        embedding = response["embedding"]  # Note: the key is "embedding", not "embeddings"
        collection.add(
            ids=[f"{i}-{j}"],  # The separator avoids ID collisions such as (1, 11) vs (11, 1)
            embeddings=[embedding],
            documents=[article]
        )


class AIRC(pydle.Client):
    """pydle IRC bot that answers channel messages with Ollama, using the RSS articles as retrieval context."""

    def __init__(self, nick, channel, client, burst_limit=5, rate_limit=1.0, burst_reset_timeout=10.0):
        super().__init__(nick)
        self.nick = nick
        self.channel = channel
        self.burst_limit = burst_limit
        self.rate_limit = rate_limit
        self.burst_reset_timeout = burst_reset_timeout
        self.sent_burst = 0  # Track messages sent in the current burst
        self.last_message_time = None  # Track the time of the last sent message
        self.system_input = "You are a critical assistant. Give concise and accurate answers in less than 120 characters."
        self._message_queue = asyncio.Queue()
        self._task = asyncio.create_task(self._send_from_queue())
        self.client = client
        self.queries = 0
        self.processing = datetime.timedelta(minutes=0)

    async def _send_from_queue(self):
        """Background task that sends queued messages with burst + rate limiting."""
        while True:
            target, message = await self._message_queue.get()

            # If burst capacity is still available, send immediately
            if self.sent_burst < self.burst_limit:
                self.sent_burst += 1
            else:
                await asyncio.sleep(self.rate_limit)  # Apply the rate limit

            await super().message(target, message)  # Send the message
            self.last_message_time = asyncio.get_event_loop().time()  # Update the last message timestamp

            # Start the burst reset countdown after each message
            asyncio.create_task(self._reset_burst_after_inactivity())

    async def _reset_burst_after_inactivity(self):
        """Resets the burst counter only if no new messages are sent within the timeout."""
        last_time = self.last_message_time
        await asyncio.sleep(self.burst_reset_timeout)  # Wait out the inactivity period

        # Only reset if no new messages were sent during the wait
        if self.last_message_time == last_time:
            self.sent_burst = 0
            logging.info("Burst limit reset due to inactivity.")

    async def message(self, target, message):
        """Splits a multi-line message and queues each line separately. If there are more than 10 lines, truncate and note it."""
        lines = message.splitlines()  # Splits on both '\n' and '\r\n'

        # Queue the first 10 lines
        for line in lines[:10]:
            if line.strip():  # Ignore empty lines
                await self._message_queue.put((target, line))

        # If there are more than 10 lines, add the truncation notice
        if len(lines) > 10:
            await self._message_queue.put((target, "[additional content truncated]"))

    async def on_connect(self):
        logging.debug('on_connect')
        await self.join(self.channel)

    def remove_substring(self, string, substring):
        return string.replace(substring, "")

    def extract_nick_message(self, input_string):
        # Pattern with capturing groups for nick and message, e.g. "airc: hello" -> ("airc", "hello")
        pattern = r"^\s*([^\s:]+?)\s*:\s*(.+?)$"

        match = re.match(pattern, input_string)
        if match:
            nick = match.group(1)     # First capturing group
            message = match.group(2)  # Second capturing group
            return nick, message
        return None, None  # Return None for both if there is no match

    async def on_message(self, target, source, message):
        if source == self.nick:
            return
        nick, body = self.extract_nick_message(message)
        if nick == self.nick:
            content = None
            if body == "stats":
                content = f"{self.queries} queries handled in {self.processing}"
            else:
                self.queries += 1
                start = datetime.datetime.now()
                query_text = body
                query_response = client.embeddings(model="mxbai-embed-large", prompt=query_text)
                query_embedding = query_response["embedding"]  # Note: singular "embedding", not plural

                # Then run the query with the correct structure
                results = collection.query(
                    query_embeddings=[query_embedding],  # Must be a list containing the embedding
                    n_results=3
                )
                data = results['documents'][0][0]
                logging.info(f"Data for {query_text}: {data}")
                logging.info(f"From {results}")
                output = client.generate(
                    model=MODEL_NAME,
                    system=f"You are {self.nick}. In your response, make reference to this data if appropriate: {data}",
                    prompt=f"Respond to this prompt: {query_text}",
                    stream=False
                )
                end = datetime.datetime.now()
                self.processing = self.processing + end - start

                # Prune off the <think>...</think> block
                content = re.sub(r'^<think>.*?</think>', '', output['response'], flags=re.DOTALL).strip()

            if content:
                logging.info(f'Sending: {content}')
                await self.message(target, f"{content}")


def remove_substring(string, substring):
    return string.replace(substring, "")


async def main():
    # Parse command-line arguments
    args = parse_args()

    # Setup logging based on the provided level
    setup_logging(args.level)

    bot = AIRC(args.nickname, args.channel, client)
    await bot.connect(args.server, args.port, tls=False)
    await bot.handle_forever()


if __name__ == "__main__":
    asyncio.run(main())