backstory/src/airc.py


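"""airc: an IRC bot that joins a channel, embeds recent RSS news articles into a
ChromaDB collection using Ollama embeddings, and answers messages addressed to it
with retrieval-augmented responses from a local Ollama model."""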
import argparse
import asyncio
import datetime
import json
import logging
import os
import re
import time
from typing import Dict, Any

import chromadb
import feedparser
import ollama
import pydle
from bs4 import BeautifulSoup

OLLAMA_API_URL = "http://ollama:11434"  # Ollama API endpoint
MODEL_NAME = "deepseek-r1:7b"
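# Both models are assumed to already be available on the Ollama host, e.g.:
#   ollama pull deepseek-r1:7b
#   ollama pull mxbai-embed-large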
def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--server", type=str, default="irc.libera.chat", help="IRC server address")
    parser.add_argument("--port", type=int, default=6667, help="IRC server port")
    parser.add_argument("--nickname", type=str, default="airc", help="Bot nickname")
    parser.add_argument("--channel", type=str, default="#airc-test", help="Channel to join")
    parser.add_argument("--ai-server", type=str, default="http://localhost:5000", help="OpenAI API endpoint")
    parser.add_argument("--level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                        default="INFO", help="Set the logging level.")
    return parser.parse_args()
def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")
    logging.basicConfig(level=numeric_level, format="%(asctime)s - %(levelname)s - %(message)s")
    logging.info(f"Logging is set to {level} level.")


client = ollama.Client(host=OLLAMA_API_URL)
def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, "xml")  # Use the 'xml' parser for XML content
    else:
        soup = BeautifulSoup(content, "html.parser")  # Default to 'html.parser' for HTML content
    # Extract and return just the text
    return soup.get_text()
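# Example: extract_text_from_html_or_xml("<p>Breaking <b>news</b></p>") -> "Breaking news"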
class Feed():
    """Polls a single RSS feed, keeping a plain-text snippet per article."""

    def __init__(self, name, url, poll_limit_min=30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            logging.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now
            for entry in feed.entries[:self.max_articles]:
                # Build one plain-text block per article
                content = f"Source: {self.name}\n"
                title = entry.get("title")
                if title:
                    content += f"Title: {title}\n"
                link = entry.get("link")
                if link:
                    content += f"Link: {link}\n"
                summary = entry.get("summary")
                if summary:
                    summary = extract_text_from_html_or_xml(summary, is_xml=False)
                    content += f"Summary: {summary}\n"
                published = entry.get("published")
                if published:
                    content += f"Published: {published}\n"
                content += "\n"
                self.articles.append(content)
        else:
            remaining = self.poll_limit_min - (now - self.last_poll)
            logging.info(f"Not updating {self.name} -- {remaining} remaining before next refresh.")
        return self.articles
# News RSS feeds
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml"),
]
documents = [
    "Llamas like to eat penguins",
    "Llamas are not vegetarians and have very efficient digestive systems",
    "Llamas live to be about 120 years old, though some only live for 15 years and others live to be 90 years old",
]
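# NOTE: the sample documents above are not currently embedded; only the RSS
# articles are added to the collection below.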
# Initialize the ChromaDB client.
# The collection is persisted to disk so it can be analyzed offline.
db = chromadb.PersistentClient(path="/root/.cache/chroma.db")
collection = db.get_or_create_collection("docs")

# Store each article in the vector embedding database
for i, feed in enumerate(rss_feeds):
    for j, article in enumerate(feed.articles):
        # Use the client instance instead of the global ollama module
        response = client.embeddings(model="mxbai-embed-large", prompt=article)
        embedding = response["embedding"]  # Note: the key is "embedding", not "embeddings"
        collection.add(
            ids=[f"{i}-{j}"],
            embeddings=[embedding],
            documents=[article],
        )
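# Quick sanity check of the vector store (hypothetical snippet, left commented out):
# probe = client.embeddings(model="mxbai-embed-large", prompt="world news headlines")
# print(collection.query(query_embeddings=[probe["embedding"]], n_results=1)["documents"])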
class AIRC(pydle.Client):
    def __init__(self, nick, channel, client, burst_limit=5, rate_limit=1.0, burst_reset_timeout=10.0):
        super().__init__(nick)
        self.nick = nick
        self.channel = channel
        self.burst_limit = burst_limit
        self.rate_limit = rate_limit
        self.burst_reset_timeout = burst_reset_timeout
        self.sent_burst = 0  # Track messages sent in the current burst
        self.last_message_time = None  # Track last message time
        self.system_input = "You are a critical assistant. Give concise and accurate answers in less than 120 characters."
        self._message_queue = asyncio.Queue()
        self._task = asyncio.create_task(self._send_from_queue())
        self.client = client
        self.queries = 0
        self.processing = datetime.timedelta(minutes=0)
    async def _send_from_queue(self):
        """Background task that sends queued messages with burst + rate limiting."""
        while True:
            target, message = await self._message_queue.get()
            # If burst capacity is still available, send immediately
            if self.sent_burst < self.burst_limit:
                self.sent_burst += 1
            else:
                await asyncio.sleep(self.rate_limit)  # Apply rate limit
            await super().message(target, message)  # Send the message
            self.last_message_time = asyncio.get_event_loop().time()  # Update last message timestamp
            # Start the burst reset countdown after each message
            asyncio.create_task(self._reset_burst_after_inactivity())
    async def _reset_burst_after_inactivity(self):
        """Resets the burst counter only if no new messages are sent within the timeout."""
        last_time = self.last_message_time
        await asyncio.sleep(self.burst_reset_timeout)  # Wait for the inactivity period
        # Only reset if no new messages were sent during the wait
        if self.last_message_time == last_time:
            self.sent_burst = 0
            logging.info("Burst limit reset due to inactivity.")
    async def message(self, target, message):
        """Splits a multi-line message and sends each line separately.
        If there are more than 10 lines, truncate and note the truncation."""
        lines = message.splitlines()  # Splits on both '\n' and '\r\n'
        # Queue the first 10 lines
        for line in lines[:10]:
            if line.strip():  # Ignore empty lines
                await self._message_queue.put((target, line))
        # If there are more than 10 lines, add the truncation notice
        if len(lines) > 10:
            await self._message_queue.put((target, "[additional content truncated]"))
    async def on_connect(self):
        logging.debug("on_connect")
        await self.join(self.channel)

    def remove_substring(self, string, substring):
        return string.replace(substring, "")
    def extract_nick_message(self, input_string):
        # Pattern with capturing groups for nick and message
        pattern = r"^\s*([^\s:]+?)\s*:\s*(.+?)$"
        match = re.match(pattern, input_string)
        if match:
            nick = match.group(1)     # First capturing group
            message = match.group(2)  # Second capturing group
            return nick, message
        return None, None  # Return None for both if no match
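    # Example: extract_nick_message("airc: what is the latest news?") -> ("airc", "what is the latest news?")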
    async def on_message(self, target, source, message):
        if source == self.nick:
            return
        nick, body = self.extract_nick_message(message)
        if nick == self.nick:
            content = None
            if body == "stats":
                content = f"{self.queries} queries handled in {self.processing}"
            else:
                self.queries += 1
                start = datetime.datetime.now()
                query_text = body
                # Embed the query with the same model used for the stored articles
                query_response = self.client.embeddings(model="mxbai-embed-large", prompt=query_text)
                query_embedding = query_response["embedding"]  # Note: singular "embedding", not plural
                # Retrieve the closest stored articles
                results = collection.query(
                    query_embeddings=[query_embedding],  # Must be a list containing the embedding
                    n_results=3
                )
                data = results["documents"][0][0]
                logging.info(f"Data for {query_text}: {data}")
                logging.info(f"From {results}")
                output = self.client.generate(
                    model=MODEL_NAME,
                    system=f"You are {self.nick}. In your response, make reference to this data if appropriate: {data}",
                    prompt=f"Respond to this prompt: {query_text}",
                    stream=False
                )
                end = datetime.datetime.now()
                self.processing = self.processing + end - start
                # Prune off the <think>...</think> block emitted by the reasoning model
                content = re.sub(r'^<think>.*?</think>', '', output['response'], flags=re.DOTALL).strip()
            if content:
                logging.info(f'Sending: {content}')
                await self.message(target, f"{content}")
def remove_substring(string, substring):
    return string.replace(substring, "")
async def main():
    # Parse command-line arguments
    args = parse_args()
    # Set up logging based on the provided level
    setup_logging(args.level)

    bot = AIRC(args.nickname, args.channel, client)
    await bot.connect(args.server, args.port, tls=False)
    await bot.handle_forever()


if __name__ == "__main__":
    asyncio.run(main())
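# Example invocation (flags mirror the parse_args() defaults):
#   python airc.py --server irc.libera.chat --port 6667 --nickname airc --channel "#airc-test" --level DEBUG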