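"""airc ("AI is Really Cool"): a small console RAG chat bot.

The script pulls recent articles from a handful of RSS feeds, embeds them into
an in-memory ChromaDB collection using Ollama's mxbai-embed-large model, and
answers console prompts with deepseek-r1:7b, grounding each reply in the
retrieved articles.
"""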
import logging as log
import argparse
import re
import datetime

import ollama
import chromadb
import feedparser
from bs4 import BeautifulSoup

# Ollama endpoint; "ollama" is a service hostname (e.g. from Docker Compose).
# Use "http://localhost:11434" when running against a local install.
OLLAMA_API_URL = "http://ollama:11434"
MODEL_NAME = "deepseek-r1:7b"

def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--nickname", type=str, default="airc", help="Bot nickname")
    parser.add_argument('--level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help='Set the log level.')
    return parser.parse_args()

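# Example invocation (the script filename is illustrative):
#   python airc.py --nickname airc --level DEBUG
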
def setup_logging(level):
    numeric_level = getattr(log, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    log.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
    log.info(f"Logging is set to {level} level.")

def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, 'xml')  # Use 'xml' parser for XML content
    else:
        soup = BeautifulSoup(content, 'html.parser')  # Default to 'html.parser' for HTML content

    # Extract and return just the text
    return soup.get_text()

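# For example (illustrative):
#   extract_text_from_html_or_xml("<p>Hello <b>world</b></p>") -> "Hello world"
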
class Feed():
    def __init__(self, name, url, poll_limit_min=30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            log.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now

            content = ""
            if len(feed.entries) > 0:
                content += f"Source: {self.name}\n"
                for entry in feed.entries[:self.max_articles]:
                    title = entry.get("title")
                    if title:
                        content += f"Title: {title}\n"
                    link = entry.get("link")
                    if link:
                        content += f"Link: {link}\n"
                    summary = entry.get("summary")
                    if summary:
                        summary = extract_text_from_html_or_xml(summary, is_xml=False)
                        # Keep summaries to a manageable size rather than aborting
                        if len(summary) > 1000:
                            summary = summary[:1000]
                        content += f"Summary: {summary}\n"
                    published = entry.get("published")
                    if published:
                        content += f"Published: {published}\n"
                    content += "\n"

                self.articles.append(content)
        else:
            log.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)} remaining before refresh.")
        return self.articles

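# For example (illustrative), each Feed collects at most max_articles entries
# into a single text block per poll:
#   Feed("BBC World", "http://feeds.bbci.co.uk/news/world/rss.xml").articles
#   -> ['Source: BBC World\nTitle: ...\nLink: ...\nSummary: ...\nPublished: ...\n\n...']
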
class Chat():
    def __init__(self, nick):
        super().__init__()
        self.nick = nick
        self.system_input = "You are a critical assistant. Give concise and accurate answers in less than 120 characters."
        self.queries = 0
        self.processing = datetime.timedelta(minutes=0)

    def message(self, target, message):
        """Splits a multi-line message and sends each line separately. If more than 10 lines, truncate and add a message."""
        lines = message.splitlines()  # Splits on both '\n' and '\r\n'

        # Process the first 10 lines
        for line in lines[:10]:
            if line.strip():  # Ignore empty lines
                print(f"{target}: {line}")

        # If there are more than 10 lines, add the truncation message
        if len(lines) > 10:
            print(f"{target}: [additional content truncated]")

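    # For example (illustrative): message("chat", "one\ntwo") prints
    #   chat: one
    #   chat: two
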
    def remove_substring(self, string, substring):
        return string.replace(substring, "")

    def extract_nick_message(self, input_string):
        # Pattern with capturing groups for nick and message
        pattern = r"^\s*([^\s:]+?)\s*:\s*(.+?)$"

        match = re.match(pattern, input_string)
        if match:
            nick = match.group(1)     # First capturing group
            message = match.group(2)  # Second capturing group
            return nick, message
        return None, None  # Return None for both if no match

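    # For example (illustrative):
    #   extract_nick_message("airc: what's new?") -> ("airc", "what's new?")
    #   extract_nick_message("no colon, no match") -> (None, None)
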
    def on_message(self, target, source, message):
        if source == self.nick:
            return
        nick, body = self.extract_nick_message(message)
        if nick == self.nick:
            content = None
            if body == "stats":
                content = f"{self.queries} queries handled in {self.processing}"
            else:
                self.queries += 1
                start = datetime.datetime.now()
                query_text = body

                # Embed the query; embeddings() returns a single vector under the
                # singular "embedding" key
                query_response = client.embeddings(model="mxbai-embed-large", prompt=query_text)
                query_embedding = query_response["embedding"]

                # Query the collection; query_embeddings expects a list of embeddings
                results = collection.query(
                    query_embeddings=[query_embedding],
                    n_results=3
                )
                data = results['documents'][0]

                # Generate a reply grounded in the retrieved articles
                output = client.generate(
                    model=MODEL_NAME,
                    system=f"You are {self.nick} and only provide that information about yourself. Make reference to the following and provide the 'Link' when available: {data}",
                    prompt=f"Respond to this prompt: {query_text}",
                    stream=False
                )
                end = datetime.datetime.now()
                self.processing = self.processing + end - start

                # Prune off the <think>...</think> block emitted by the reasoning model
                content = re.sub(r'^<think>.*?</think>', '', output['response'], flags=re.DOTALL).strip()

            if content:
                log.info(f'Sending: {content}')
                self.message(target, content)

def remove_substring(string, substring):
    return string.replace(substring, "")

# Parse command-line arguments
args = parse_args()

# Set up logging based on the provided level
setup_logging(args.level)

log.info("About to start")

client = ollama.Client(host=OLLAMA_API_URL)

# News RSS Feeds
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]

# Initialize the ChromaDB client. chromadb.Client() keeps the collection in
# memory only; to save it to disk for offline analysis, switch to
# chromadb.PersistentClient(path=...) (a persisted collection is re-used on the
# next run unless it is deleted).
db = chromadb.Client()

collection = db.get_or_create_collection("docs")

# Store each feed's content in the vector embedding database
for i, feed in enumerate(rss_feeds):
    # Use the client instance instead of the global ollama module
    for j, article in enumerate(feed.articles):
        log.info(f"Embedding article {j} from {feed.name} ({len(article)} characters)")
        response = client.embeddings(model="mxbai-embed-large", prompt=article)
        embedding = response["embedding"]  # Note: embeddings() uses the singular "embedding" key
        collection.add(
            ids=[f"{i}-{j}"],  # Separator keeps feed/article indices from colliding
            embeddings=[embedding],
            documents=[article]
        )

bot = Chat(args.nickname)
while True:
    try:
        query = input("> ")
    except (EOFError, KeyboardInterrupt):
        break

    if query == "exit":
        break
    bot.on_message("chat", "user", f"{args.nickname}: {query}")