Add some simple RAG from RSS feeds
This commit is contained in:
parent
44d26dad9d
commit
3f749d4e78
@ -281,7 +281,7 @@ RUN { \
|
||||
echo ' --ServerApp.allow_origin=* \' ; \
|
||||
echo ' --ServerApp.base_url="/jupyter" \' ; \
|
||||
echo ' "${@}" \' ; \
|
||||
echo ' >> "/root/.cache/jupyter.log" 2>&1' ; \
|
||||
echo ' 2>&1 | tee -a "/root/.cache/jupyter.log"' ; \
|
||||
echo ' echo "jupyter notebook died ($?). Restarting."' ; \
|
||||
echo ' sleep 5' ; \
|
||||
echo 'done' ; \
|
||||
@ -393,4 +393,11 @@ RUN dpkg -i /opt/ze-monitor-*deb
|
||||
|
||||
WORKDIR /opt/airc
|
||||
|
||||
SHELL [ "/opt/airc/shell" ]
|
||||
|
||||
# Needed by src/model-server.py
|
||||
RUN pip install faiss-cpu sentence_transformers feedparser
|
||||
|
||||
SHELL [ "/bin/bash", "-c" ]
|
||||
|
||||
ENTRYPOINT [ "/entrypoint-airc.sh" ]
|
||||
|
@ -162,7 +162,6 @@ class AIRC(pydle.Client):
|
||||
# Extract and print just the assistant's message if available
|
||||
if "choices" in response and len(response["choices"]) > 0:
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
print(f"\nAssistant: {content}")
|
||||
|
||||
if content:
|
||||
logging.info(f'Sending: {content}')
|
||||
|
@ -15,6 +15,11 @@ import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
import feedparser
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="AI is Really Cool Server")
|
||||
@ -32,6 +37,100 @@ def setup_logging(level):
|
||||
logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logging.info(f"Logging is set to {level} level.")
|
||||
|
||||
class Feed():
    """A single RSS/Atom news source with a rate-limited article cache.

    Each cached article is a plain-text document of the form::

        Source: <feed name>
        Title: <title>
        Link: <link>
        Summary: <summary>
        Published: <published>

    Fields missing from the feed entry are omitted. Every document carries its
    own ``Source:`` line so it remains attributable when retrieved on its own.
    """

    # (label written to the document, key looked up on the feed entry)
    _FIELDS = (
        ("Title", "title"),
        ("Link", "link"),
        ("Summary", "summary"),
        ("Published", "published"),
    )

    def __init__(self, name, url, poll_limit_min = 30, max_articles=5):
        """Create the feed and immediately fetch it once.

        Args:
            name: Human-readable source name, written into every document.
            url: Feed URL handed to ``feedparser.parse``.
            poll_limit_min: Minimum number of minutes between two fetches.
            max_articles: Maximum number of entries kept per fetch.
        """
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        # NOTE: constructing a Feed calls feedparser.parse() on the URL.
        self.update()

    def _format_entry(self, entry):
        """Render one feed entry as a standalone text document."""
        lines = [f"Source: {self.name}"]
        for label, key in self._FIELDS:
            value = entry.get(key)
            if value:
                lines.append(f"{label}: {value}")
        # Trailing blank line keeps documents separated when they are joined.
        return "\n".join(lines) + "\n\n"

    def update(self):
        """Refresh the cached articles if the poll interval has elapsed.

        Returns:
            The list of article documents (one string per article). When the
            poll interval has not elapsed, the previously cached list is
            returned unchanged.
        """
        now = datetime.datetime.now()
        if self.last_poll is not None and (now - self.last_poll) < self.poll_limit_min:
            logging.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)}s remain to refresh.")
            return self.articles

        logging.info(f"Updating {self.name}")
        feed = feedparser.parse(self.url)
        self.last_poll = now
        # Build each document independently so entries never bleed into each other.
        self.articles = [
            self._format_entry(entry)
            for entry in feed.entries[:self.max_articles]
        ]
        return self.articles
|
||||
|
||||
|
||||
# News RSS feeds polled for retrieval context, as (name, url) pairs.
_FEED_SOURCES = (
    ("BBC World", "http://feeds.bbci.co.uk/news/world/rss.xml"),
    ("Reuters World", "http://feeds.reuters.com/Reuters/worldNews"),
    ("Al Jazeera", "https://www.aljazeera.com/xml/rss/all.xml"),
    ("CNN World", "http://rss.cnn.com/rss/edition_world.rss"),
    ("Time", "https://time.com/feed/"),
    ("Euronews", "https://www.euronews.com/rss"),
    ("FeedX", "https://feedx.net/rss/ap.xml"),
)

# Feed objects are created in the order listed above.
rss_feeds = [
    Feed(name=source_name, url=source_url)
    for source_name, source_url in _FEED_SOURCES
]
|
||||
|
||||
# Load the sentence-embedding model used for both documents and queries
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Step 1: Collect news from all sources
documents = []
for feed in rss_feeds:
    documents.extend(feed.update())

# Step 2: Encode and store news articles into FAISS.
# The index dimension comes from the model, not from the encoded batch, so
# start-up does not fail when no feed returned any article.
index = faiss.IndexFlatL2(embedding_model.get_sentence_embedding_dimension())
if documents:
    doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
    index.add(doc_vectors)  # Store news vectors
else:
    doc_vectors = np.empty((0, index.d), dtype=np.float32)
    logging.warning("No news documents were collected; FAISS index is empty.")

print(f"Stored {len(doc_vectors)} documents in FAISS index.")
|
||||
|
||||
# Step 3: Retrieval function for user queries
|
||||
def retrieve_documents(query, top_k=2):
    """Retrieve the top-k most relevant news articles for a query.

    Args:
        query: The user's question, embedded with ``embedding_model``.
        top_k: Maximum number of documents to return.

    Returns:
        A list of at most ``top_k`` document strings taken from ``documents``,
        in the order returned by the index search. Empty when there are no
        documents or ``top_k`` is not positive.
    """
    if not documents or top_k <= 0:
        return []

    query_vector = np.array(embedding_model.encode([query]), dtype=np.float32)
    # Never ask for more neighbours than there are documents: FAISS pads the
    # missing slots with label -1, which would index the last document.
    _, neighbor_ids = index.search(query_vector, min(top_k, len(documents)))
    return [documents[i] for i in neighbor_ids[0] if 0 <= i < len(documents)]
|
||||
|
||||
# Step 4: Format the RAG prompt
|
||||
def format_prompt(query, retrieved_docs):
    """Build the RAG prompt: instructions, retrieved context, then the query.

    Args:
        query: The user's question.
        retrieved_docs: Document strings to place under "Retrieved Context".

    Returns:
        The complete prompt text, ending with an open "## Response:" section.
    """
    instructions = (
        "You are an AI assistant with access to world news. "
        "Use the Retrieved Context to answer the user's question accurately "
        "if relevant, stating which Source provided the information."
    )
    sections = [
        instructions,
        "",
        "## Retrieved Context:",
        "\n".join(retrieved_docs),
        "",
        "## User Query:",
        f"{query}",
        "",
        "## Response:",
        "",  # yields the trailing newline after the last heading
    ]
    return "\n".join(sections)
|
||||
|
||||
|
||||
class Chat():
|
||||
def __init__(self, device_name):
|
||||
super().__init__()
|
||||
@ -53,12 +152,13 @@ class Chat():
|
||||
self.model = self.model.half().to(device_name)
|
||||
except Exception as e:
|
||||
logging.error(f"Loading error: {e}")
|
||||
raise Exception(e)
|
||||
|
||||
def remove_substring(self, string, substring):
|
||||
return string.replace(substring, "")
|
||||
|
||||
def generate_response(self, text):
|
||||
prompt = f"### System:\n{self.system_input}\n### User:\n{text}\n### Assistant:\n"
|
||||
prompt = text
|
||||
start = time.time()
|
||||
|
||||
with torch.autocast(self.device_name, dtype=torch.float16):
|
||||
@ -66,7 +166,7 @@ class Chat():
|
||||
prompt,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
max_length=1000, # Prevent 'Asking to truncate to max_length...'
|
||||
max_length=8000, # Prevent 'Asking to truncate to max_length...'
|
||||
padding=True, # Handles padding automatically
|
||||
truncation=True
|
||||
)
|
||||
@ -76,7 +176,7 @@ class Chat():
|
||||
outputs = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=1000,
|
||||
max_length=8000,
|
||||
num_return_sequences=1,
|
||||
pad_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
@ -105,8 +205,29 @@ def chat_completions():
|
||||
max_tokens = data.get('max_tokens', 2048)
|
||||
|
||||
chat = app.config['chat']
|
||||
logging.info(f"Query: {messages}")
|
||||
response_content, _ = chat.generate_response(messages[-1]['content'])
|
||||
query = messages[-1]['content']
|
||||
|
||||
if re.match(r"^\s*(update|refresh) news\s*$", query, re.IGNORECASE):
|
||||
logging.info("New refresh requested")
|
||||
# Collect news from all sources
|
||||
documents = []
|
||||
for feed in rss_feeds:
|
||||
documents.extend(feed.update())
|
||||
# Step 2: Encode and store news articles into FAISS
|
||||
doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
|
||||
index = faiss.IndexFlatL2(doc_vectors.shape[1]) # Initialize FAISS index
|
||||
index.add(doc_vectors) # Store news vectors
|
||||
print(f"Stored {len(doc_vectors)} documents in FAISS index.")
|
||||
response_content = "News refresh requested."
|
||||
else:
|
||||
logging.info(f"Query: {query}")
|
||||
retrieved_docs = retrieve_documents(query)
|
||||
rag_prompt = format_prompt(query, retrieved_docs)
|
||||
logging.info(f"RAG prompt: {rag_prompt}")
|
||||
|
||||
# Get AI-generated response
|
||||
response_content, _ = chat.generate_response(rag_prompt)
|
||||
|
||||
logging.info(f"Response: {response_content}")
|
||||
# Format response in OpenAI-compatible structure
|
||||
response = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user