Add some simple RAG from RSS feeds

This commit is contained in:
James Ketr 2025-03-06 16:37:51 -08:00
parent 44d26dad9d
commit 3f749d4e78
3 changed files with 134 additions and 7 deletions

View File

@ -281,7 +281,7 @@ RUN { \
echo ' --ServerApp.allow_origin=* \' ; \
echo ' --ServerApp.base_url="/jupyter" \' ; \
echo ' "${@}" \' ; \
echo ' >> "/root/.cache/jupyter.log" 2>&1' ; \
echo ' 2>&1 | tee -a "/root/.cache/jupyter.log"' ; \
echo ' echo "jupyter notebook died ($?). Restarting."' ; \
echo ' sleep 5' ; \
echo 'done' ; \
@ -393,4 +393,11 @@ RUN dpkg -i /opt/ze-monitor-*deb
WORKDIR /opt/airc
SHELL [ "/opt/airc/shell" ]
# Needed by src/model-server.py
RUN pip install faiss-cpu sentence_transformers feedparser
SHELL [ "/bin/bash", "-c" ]
ENTRYPOINT [ "/entrypoint-airc.sh" ]

View File

@ -162,7 +162,6 @@ class AIRC(pydle.Client):
# Extract and print just the assistant's message if available
if "choices" in response and len(response["choices"]) > 0:
content = response["choices"][0]["message"]["content"]
print(f"\nAssistant: {content}")
if content:
logging.info(f'Sending: {content}')

View File

@ -15,6 +15,11 @@ import asyncio
import aiohttp
import json
from typing import Dict, Any
import feedparser
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
def parse_args():
parser = argparse.ArgumentParser(description="AI is Really Cool Server")
@ -32,6 +37,100 @@ def setup_logging(level):
logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
logging.info(f"Logging is set to {level} level.")
class Feed():
    """A single RSS source with rate-limited polling.

    Each refresh renders up to ``max_articles`` entries from ``url`` into one
    plain-text document ("Source:/Title:/Link:/Summary:/Published:" lines),
    cached in ``self.articles`` for the RAG index.  Network fetches are
    throttled to once every ``poll_limit_min`` minutes; calls inside that
    window return the cached articles unchanged.
    """

    def __init__(self, name, url, poll_limit_min=30, max_articles=5):
        # Human-readable source name; prepended to the rendered document.
        self.name = name
        self.url = url
        # Minimum interval between successive network fetches.
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None   # None => never fetched
        self.articles = []      # cached rendered document(s)
        self.max_articles = max_articles
        # Populate immediately so the feed is usable right after construction.
        # NOTE(review): this performs network I/O inside the constructor.
        self.update()

    def update(self):
        """Refresh the cached articles if the poll window has elapsed.

        Returns:
            The (possibly cached) list of rendered document strings.
        """
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            logging.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now
            content = ""
            # Guard: feedparser returns a result object even on failure;
            # entries may be absent/empty.
            entries = getattr(feed, "entries", [])
            if len(entries) > 0:
                content += f"Source: {self.name}\n"
                for entry in entries[:self.max_articles]:
                    # Every RSS field is optional; emit only what is present.
                    title = entry.get("title")
                    if title:
                        content += f"Title: {title}\n"
                    link = entry.get("link")
                    if link:
                        content += f"Link: {link}\n"
                    summary = entry.get("summary")
                    if summary:
                        content += f"Summary: {summary}\n"
                    published = entry.get("published")
                    if published:
                        content += f"Published: {published}\n"
                    content += "\n"
                # Only cache a document when the fetch produced entries,
                # so an empty/failed poll never stores an empty string.
                self.articles.append(content)
        else:
            # Fix: the value is a timedelta, not a number of seconds, so the
            # old "...s remain" suffix was misleading.
            remaining = self.poll_limit_min - (now - self.last_poll)
            logging.info(f"Not updating {self.name} -- {remaining} remains until refresh.")
        return self.articles
# News RSS Feeds forming the retrieval corpus.
# NOTE(review): constructing each Feed performs a network fetch at import
# time (Feed.update runs in __init__), so importing this module is slow and
# can fail offline — confirm this is intended.
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]
# Load an embedding model (used for both documents and queries below and in
# retrieve_documents).
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# Collect news from all sources; each element is one rendered feed document.
documents = []
for feed in rss_feeds:
    documents.extend(feed.update())
# Step 2: Encode and store news articles into FAISS.
# NOTE(review): if every feed yields nothing, doc_vectors is empty and
# shape[1] raises IndexError — consider guarding before indexing.
doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
index = faiss.IndexFlatL2(doc_vectors.shape[1]) # Flat L2 index sized to the embedding dimension
index.add(doc_vectors) # Store news vectors; FAISS ids match positions in `documents`
print(f"Stored {len(doc_vectors)} documents in FAISS index.")
# Step 3: Retrieval function for user queries
def retrieve_documents(query, top_k=2):
    """Retrieve the top-k most relevant news documents for ``query``.

    Embeds the query with the module-level ``embedding_model`` and searches
    the module-level FAISS ``index``, mapping result ids back into the
    module-level ``documents`` list.

    Args:
        query: User query string.
        top_k: Maximum number of documents to return.

    Returns:
        A list of up to ``top_k`` document strings (fewer if the index holds
        fewer documents; empty if the index is empty).
    """
    query_vector = np.array(embedding_model.encode([query]), dtype=np.float32)
    # Fix: FAISS pads missing results with id -1 when k exceeds the number of
    # stored vectors, and documents[-1] would silently return the last
    # article.  Clamp k to the index size and drop any negative ids.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []
    D, I = index.search(query_vector, k)
    return [documents[i] for i in I[0] if i >= 0]
# Step 4: Format the RAG prompt
def format_prompt(query, retrieved_docs):
    """Assemble the RAG prompt: instructions, retrieved context, then query.

    Args:
        query: The user's question.
        retrieved_docs: Document strings to present as context.

    Returns:
        The complete prompt string handed to the model.
    """
    context = "\n".join(retrieved_docs)
    sections = [
        "You are an AI assistant with access to world news. Use the Retrieved Context to answer the user's question accurately if relevant, stating which Source provided the information.",
        "## Retrieved Context:",
        context,
        "## User Query:",
        query,
        "## Response:",
        "",  # trailing element yields the final newline after "## Response:"
    ]
    return "\n".join(sections)
class Chat():
def __init__(self, device_name):
super().__init__()
@ -53,12 +152,13 @@ class Chat():
self.model = self.model.half().to(device_name)
except Exception as e:
logging.error(f"Loading error: {e}")
raise Exception(e)
def remove_substring(self, string, substring):
return string.replace(substring, "")
def generate_response(self, text):
prompt = f"### System:\n{self.system_input}\n### User:\n{text}\n### Assistant:\n"
prompt = text
start = time.time()
with torch.autocast(self.device_name, dtype=torch.float16):
@ -66,7 +166,7 @@ class Chat():
prompt,
add_special_tokens=False,
return_tensors="pt",
max_length=1000, # Prevent 'Asking to truncate to max_length...'
max_length=8000, # Prevent 'Asking to truncate to max_length...'
padding=True, # Handles padding automatically
truncation=True
)
@ -76,7 +176,7 @@ class Chat():
outputs = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=1000,
max_length=8000,
num_return_sequences=1,
pad_token_id=self.tokenizer.eos_token_id
)
@ -105,8 +205,29 @@ def chat_completions():
max_tokens = data.get('max_tokens', 2048)
chat = app.config['chat']
logging.info(f"Query: {messages}")
response_content, _ = chat.generate_response(messages[-1]['content'])
query = messages[-1]['content']
if re.match(r"^\s*(update|refresh) news\s*$", query, re.IGNORECASE):
logging.info("New refresh requested")
# Collect news from all sources
documents = []
for feed in rss_feeds:
documents.extend(feed.update())
# Step 2: Encode and store news articles into FAISS
doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
index = faiss.IndexFlatL2(doc_vectors.shape[1]) # Initialize FAISS index
index.add(doc_vectors) # Store news vectors
print(f"Stored {len(doc_vectors)} documents in FAISS index.")
response_content = "News refresh requested."
else:
logging.info(f"Query: {query}")
retrieved_docs = retrieve_documents(query)
rag_prompt = format_prompt(query, retrieved_docs)
logging.info(f"RAG prompt: {rag_prompt}")
# Get AI-generated response
response_content, _ = chat.generate_response(rag_prompt)
logging.info(f"Response: {response_content}")
# Format response in OpenAI-compatible structure
response = {