Add some simple RAG from RSS feeds

This commit is contained in:
James Ketr 2025-03-06 16:37:51 -08:00
parent 44d26dad9d
commit 3f749d4e78
3 changed files with 134 additions and 7 deletions

View File

@ -281,7 +281,7 @@ RUN { \
echo ' --ServerApp.allow_origin=* \' ; \ echo ' --ServerApp.allow_origin=* \' ; \
echo ' --ServerApp.base_url="/jupyter" \' ; \ echo ' --ServerApp.base_url="/jupyter" \' ; \
echo ' "${@}" \' ; \ echo ' "${@}" \' ; \
echo ' >> "/root/.cache/jupyter.log" 2>&1' ; \ echo ' 2>&1 | tee -a "/root/.cache/jupyter.log"' ; \
echo ' echo "jupyter notebook died ($?). Restarting."' ; \ echo ' echo "jupyter notebook died ($?). Restarting."' ; \
echo ' sleep 5' ; \ echo ' sleep 5' ; \
echo 'done' ; \ echo 'done' ; \
@ -393,4 +393,11 @@ RUN dpkg -i /opt/ze-monitor-*deb
WORKDIR /opt/airc WORKDIR /opt/airc
SHELL [ "/opt/airc/shell" ]
# Needed by src/model-server.py
RUN pip install faiss-cpu sentence_transformers feedparser
SHELL [ "/bin/bash", "-c" ]
ENTRYPOINT [ "/entrypoint-airc.sh" ] ENTRYPOINT [ "/entrypoint-airc.sh" ]

View File

@ -162,7 +162,6 @@ class AIRC(pydle.Client):
# Extract and print just the assistant's message if available # Extract and print just the assistant's message if available
if "choices" in response and len(response["choices"]) > 0: if "choices" in response and len(response["choices"]) > 0:
content = response["choices"][0]["message"]["content"] content = response["choices"][0]["message"]["content"]
print(f"\nAssistant: {content}")
if content: if content:
logging.info(f'Sending: {content}') logging.info(f'Sending: {content}')

View File

@ -15,6 +15,11 @@ import asyncio
import aiohttp import aiohttp
import json import json
from typing import Dict, Any from typing import Dict, Any
import feedparser
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="AI is Really Cool Server") parser = argparse.ArgumentParser(description="AI is Really Cool Server")
@ -32,6 +37,100 @@ def setup_logging(level):
logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
logging.info(f"Logging is set to {level} level.") logging.info(f"Logging is set to {level} level.")
class Feed():
    """One RSS/Atom feed with rate-limited polling.

    Each successful poll renders up to ``max_articles`` entries into a
    single plain-text document (one document per feed, not per article)
    suitable for embedding into the FAISS index.
    """

    def __init__(self, name, url, poll_limit_min = 30, max_articles=5):
        self.name = name            # human-readable source name, embedded in the document
        self.url = url              # feed URL handed to feedparser
        # Minimum interval between network polls of this feed.
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None       # timestamp of the last actual fetch
        self.articles = []          # rendered documents from the last fetch
        self.max_articles = max_articles
        # NOTE: performs network I/O at construction time.
        self.update()

    def update(self):
        """Re-fetch the feed if the poll interval has elapsed.

        Returns the current list of rendered documents; the list may be
        stale when the rate limit suppressed a refresh.
        """
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            logging.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now
            content = ""
            if len(feed.entries) > 0:
                content += f"Source: {self.name}\n"
                for entry in feed.entries[:self.max_articles]:
                    # Every RSS field is optional; emit only what is present.
                    title = entry.get("title")
                    if title:
                        content += f"Title: {title}\n"
                    link = entry.get("link")
                    if link:
                        content += f"Link: {link}\n"
                    summary = entry.get("summary")
                    if summary:
                        content += f"Summary: {summary}\n"
                    published = entry.get("published")
                    if published:
                        content += f"Published: {published}\n"
                    content += "\n"
                # One combined document for the whole feed.
                self.articles.append(content)
        else:
            # Bug fix: the remaining wait is a timedelta, not a number of
            # seconds — the old message labeled it with a bare "s".
            remaining = self.poll_limit_min - (now - self.last_poll)
            logging.info(f"Not updating {self.name} -- {remaining} remains until refresh.")
        return self.articles
# News RSS Feeds.
# NOTE(review): constructing each Feed triggers a network fetch (Feed.__init__
# calls update()), so importing this module blocks on up to 7 HTTP requests.
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]

# Load an embedding model (downloads weights on first use).
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Collect news from all sources: one rendered document per feed.
documents = []
for feed in rss_feeds:
    documents.extend(feed.update())

# Step 2: Encode and store news articles into FAISS.
# NOTE(review): if every feed fails (documents == []), encode() yields an
# empty array and doc_vectors.shape[1] raises IndexError — verify whether a
# guard is wanted here.
doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
index = faiss.IndexFlatL2(doc_vectors.shape[1])  # Initialize FAISS index (L2 distance)
index.add(doc_vectors)  # Store news vectors
print(f"Stored {len(doc_vectors)} documents in FAISS index.")
# Step 3: Retrieval function for user queries
def retrieve_documents(query, top_k=2):
    """Retrieve the top-k most relevant news documents for *query*.

    Embeds the query with the module-level ``embedding_model`` and searches
    the module-level FAISS ``index`` over ``documents``. Returns a list of
    at most ``top_k`` document strings (fewer when the corpus is smaller).
    """
    # Empty corpus: nothing to search, and index.search would be meaningless.
    if not documents:
        return []
    query_vector = np.array(embedding_model.encode([query]), dtype=np.float32)
    # Never request more neighbors than documents exist: FAISS pads the
    # result with label -1 in that case, and documents[-1] would then
    # silently return the LAST document instead of "no match".
    k = min(top_k, len(documents))
    D, I = index.search(query_vector, k)
    # Filter any remaining -1 sentinel labels defensively.
    return [documents[i] for i in I[0] if i >= 0]
# Step 4: Format the RAG prompt
def format_prompt(query, retrieved_docs):
    """Assemble the structured RAG prompt from the query and retrieved news."""
    # Join the prompt sections with newlines; the trailing "" yields the
    # final newline after "## Response:".
    sections = [
        "You are an AI assistant with access to world news. Use the Retrieved Context to answer the user's question accurately if relevant, stating which Source provided the information.",
        "## Retrieved Context:",
        "\n".join(retrieved_docs),
        "## User Query:",
        query,
        "## Response:",
        "",
    ]
    return "\n".join(sections)
class Chat(): class Chat():
def __init__(self, device_name): def __init__(self, device_name):
super().__init__() super().__init__()
@ -53,12 +152,13 @@ class Chat():
self.model = self.model.half().to(device_name) self.model = self.model.half().to(device_name)
except Exception as e: except Exception as e:
logging.error(f"Loading error: {e}") logging.error(f"Loading error: {e}")
raise Exception(e)
def remove_substring(self, string, substring): def remove_substring(self, string, substring):
return string.replace(substring, "") return string.replace(substring, "")
def generate_response(self, text): def generate_response(self, text):
prompt = f"### System:\n{self.system_input}\n### User:\n{text}\n### Assistant:\n" prompt = text
start = time.time() start = time.time()
with torch.autocast(self.device_name, dtype=torch.float16): with torch.autocast(self.device_name, dtype=torch.float16):
@ -66,7 +166,7 @@ class Chat():
prompt, prompt,
add_special_tokens=False, add_special_tokens=False,
return_tensors="pt", return_tensors="pt",
max_length=1000, # Prevent 'Asking to truncate to max_length...' max_length=8000, # Prevent 'Asking to truncate to max_length...'
padding=True, # Handles padding automatically padding=True, # Handles padding automatically
truncation=True truncation=True
) )
@ -76,7 +176,7 @@ class Chat():
outputs = self.model.generate( outputs = self.model.generate(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
max_length=1000, max_length=8000,
num_return_sequences=1, num_return_sequences=1,
pad_token_id=self.tokenizer.eos_token_id pad_token_id=self.tokenizer.eos_token_id
) )
@ -105,8 +205,29 @@ def chat_completions():
max_tokens = data.get('max_tokens', 2048) max_tokens = data.get('max_tokens', 2048)
chat = app.config['chat'] chat = app.config['chat']
logging.info(f"Query: {messages}") query = messages[-1]['content']
response_content, _ = chat.generate_response(messages[-1]['content'])
if re.match(r"^\s*(update|refresh) news\s*$", query, re.IGNORECASE):
logging.info("New refresh requested")
# Collect news from all sources
documents = []
for feed in rss_feeds:
documents.extend(feed.update())
# Step 2: Encode and store news articles into FAISS
doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
index = faiss.IndexFlatL2(doc_vectors.shape[1]) # Initialize FAISS index
index.add(doc_vectors) # Store news vectors
print(f"Stored {len(doc_vectors)} documents in FAISS index.")
response_content = "News refresh requested."
else:
logging.info(f"Query: {query}")
retrieved_docs = retrieve_documents(query)
rag_prompt = format_prompt(query, retrieved_docs)
logging.info(f"RAG prompt: {rag_prompt}")
# Get AI-generated response
response_content, _ = chat.generate_response(rag_prompt)
logging.info(f"Response: {response_content}") logging.info(f"Response: {response_content}")
# Format response in OpenAI-compatible structure # Format response in OpenAI-compatible structure
response = { response = {