diff --git a/Dockerfile b/Dockerfile
index 67cee3a..37e664c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -281,7 +281,7 @@ RUN { \
     echo '    --ServerApp.allow_origin=* \' ; \
     echo '    --ServerApp.base_url="/jupyter" \' ; \
     echo '    "${@}" \' ; \
-    echo '    >> "/root/.cache/jupyter.log" 2>&1' ; \
+    echo '    2>&1 | tee -a "/root/.cache/jupyter.log"' ; \
     echo '  echo "jupyter notebook died ($?). Restarting."' ; \
     echo '  sleep 5' ; \
     echo 'done' ; \
@@ -393,4 +393,11 @@ RUN dpkg -i /opt/ze-monitor-*deb
 
 WORKDIR /opt/airc
 
+SHELL [ "/opt/airc/shell" ]
+
+# Needed by src/model-server.py
+RUN pip install faiss-cpu sentence_transformers feedparser
+
+SHELL [ "/bin/bash", "-c" ]
+
 ENTRYPOINT [ "/entrypoint-airc.sh" ]
diff --git a/src/airc.py b/src/airc.py
index 17bda7f..5ab4b78 100644
--- a/src/airc.py
+++ b/src/airc.py
@@ -162,7 +162,6 @@ class AIRC(pydle.Client):
             # Extract and print just the assistant's message if available
             if "choices" in response and len(response["choices"]) > 0:
                 content = response["choices"][0]["message"]["content"]
-                print(f"\nAssistant: {content}")
 
                 if content:
                     logging.info(f'Sending: {content}')
diff --git a/src/model-server.py b/src/model-server.py
index d7a35a9..21cf353 100644
--- a/src/model-server.py
+++ b/src/model-server.py
@@ -15,6 +15,13 @@ import asyncio
 import aiohttp
 import json
 from typing import Dict, Any
+import datetime  # needed by Feed polling below; harmless if already imported
+import re  # needed by the news-refresh match in chat_completions()
+import feedparser
+import faiss
+import numpy as np
+import torch
+from sentence_transformers import SentenceTransformer
 
 def parse_args():
     parser = argparse.ArgumentParser(description="AI is Really Cool Server")
@@ -32,6 +39,100 @@ def setup_logging(level):
     logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
     logging.info(f"Logging is set to {level} level.")
 
+class Feed():
+    def __init__(self, name, url, poll_limit_min=30, max_articles=5):
+        self.name = name
+        self.url = url
+        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
+        self.last_poll = None
+        self.articles = []
+        self.max_articles = max_articles
+        self.update()
+
+    def update(self):
+        now = datetime.datetime.now()
+        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
+            logging.info(f"Updating {self.name}")
+            feed = feedparser.parse(self.url)
+            self.articles = []
+            self.last_poll = now
+
+            content = ""
+            if len(feed.entries) > 0:
+                content += f"Source: {self.name}\n"
+                for entry in feed.entries[:self.max_articles]:
+                    title = entry.get("title")
+                    if title:
+                        content += f"Title: {title}\n"
+                    link = entry.get("link")
+                    if link:
+                        content += f"Link: {link}\n"
+                    summary = entry.get("summary")
+                    if summary:
+                        content += f"Summary: {summary}\n"
+                    published = entry.get("published")
+                    if published:
+                        content += f"Published: {published}\n"
+                    content += "\n"
+
+                self.articles.append(content)
+        else:
+            logging.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)} remaining before refresh.")
+        return self.articles
+
+
+# News RSS Feeds
+rss_feeds = [
+    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
+    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
+    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
+    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
+    Feed(name="Time", url="https://time.com/feed/"),
+    Feed(name="Euronews", url="https://www.euronews.com/rss"),
+    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
+]
+
+# Load an embedding model
+embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+
+
+# Step 1: Collect news from all sources
+documents = []
+for feed in rss_feeds:
+    documents.extend(feed.update())
+
+# Step 2: Encode the articles and store them in a FAISS index
+doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
+index = faiss.IndexFlatL2(doc_vectors.shape[1])  # Initialize FAISS index
+index.add(doc_vectors)  # Store news vectors
+
+logging.info(f"Stored {len(doc_vectors)} documents in FAISS index.")
+
+# Step 3: Retrieval function for user queries
+def retrieve_documents(query, top_k=2):
+    """Retrieve the top-k most relevant news articles."""
+    query_vector = np.array(embedding_model.encode([query]), dtype=np.float32)
+    D, I = index.search(query_vector, top_k)  # distances, indices
+    retrieved_docs = [documents[i] for i in I[0]]
+    return retrieved_docs
+
+# Step 4: Format the RAG prompt
+def format_prompt(query, retrieved_docs):
+    """Format retrieved documents into a structured RAG prompt."""
+    context_str = "\n".join(retrieved_docs)
+    prompt = f"""You are an AI assistant with access to world news. Use the Retrieved Context to answer the user's question accurately if relevant, stating which Source provided the information.
+
+## Retrieved Context:
+{context_str}
+
+## User Query:
+{query}
+
+## Response:
+"""
+    return prompt
+
+
 class Chat():
     def __init__(self, device_name):
         super().__init__()
@@ -53,12 +154,13 @@ class Chat():
             self.model = self.model.half().to(device_name)
         except Exception as e:
             logging.error(f"Loading error: {e}")
+            raise  # re-raise so startup fails loudly instead of running with no model
 
     def remove_substring(self, string, substring):
         return string.replace(substring, "")
 
     def generate_response(self, text):
-        prompt = f"### System:\n{self.system_input}\n### User:\n{text}\n### Assistant:\n"
+        prompt = text
 
         start = time.time()
         with torch.autocast(self.device_name, dtype=torch.float16):
@@ -66,7 +168,7 @@ class Chat():
             prompt,
             add_special_tokens=False,
             return_tensors="pt",
-            max_length=1000,  # Prevent 'Asking to truncate to max_length...'
+            max_length=8000,  # Prevent 'Asking to truncate to max_length...'
             padding=True,  # Handles padding automatically
             truncation=True
         )
@@ -76,7 +178,7 @@ class Chat():
         outputs = self.model.generate(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            max_length=1000,
+            max_length=8000,
             num_return_sequences=1,
             pad_token_id=self.tokenizer.eos_token_id
         )
@@ -105,8 +207,30 @@ def chat_completions():
     max_tokens = data.get('max_tokens', 2048)
 
     chat = app.config['chat']
-    logging.info(f"Query: {messages}")
-    response_content, _ = chat.generate_response(messages[-1]['content'])
+    query = messages[-1]['content']
+
+    if re.match(r"^\s*(update|refresh) news\s*$", query, re.IGNORECASE):
+        logging.info("News refresh requested")
+        # Rebuild the module-level corpus and FAISS index so that
+        # retrieve_documents() sees the refreshed data
+        global documents, index
+        documents = []
+        for feed in rss_feeds:
+            documents.extend(feed.update())
+        doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
+        index = faiss.IndexFlatL2(doc_vectors.shape[1])  # Reinitialize FAISS index
+        index.add(doc_vectors)  # Store news vectors
+        logging.info(f"Stored {len(doc_vectors)} documents in FAISS index.")
+        response_content = f"News refreshed: {len(doc_vectors)} documents indexed."
+    else:
+        logging.info(f"Query: {query}")
+        retrieved_docs = retrieve_documents(query)
+        rag_prompt = format_prompt(query, retrieved_docs)
+        logging.info(f"RAG prompt: {rag_prompt}")
+
+        # Get AI-generated response
+        response_content, _ = chat.generate_response(rag_prompt)
+        logging.info(f"Response: {response_content}")
 
     # Format response in OpenAI-compatible structure
     response = {
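
A quick way to exercise the new RAG path end-to-end is a small client script. The sketch below is not part of the patch and makes two assumptions: that the Flask route is the OpenAI-style /v1/chat/completions, and that the server listens on localhost:5000 (adjust both to the actual route and port in model-server.py):

    # smoke_test_rag.py -- hypothetical client for the patched model-server.py.
    # Assumptions: route /v1/chat/completions, server on localhost:5000.
    import requests

    URL = "http://localhost:5000/v1/chat/completions"  # assumed endpoint

    def ask(content: str) -> str:
        """POST a single user message and return the assistant's reply text."""
        payload = {"messages": [{"role": "user", "content": content}]}
        resp = requests.post(URL, json=payload, timeout=300)
        resp.raise_for_status()
        # The response is OpenAI-compatible, matching what src/airc.py reads.
        return resp.json()["choices"][0]["message"]["content"]

    print(ask("update news"))                       # exercises the index-rebuild branch
    print(ask("What is happening in world news?"))  # exercises retrieval + RAG prompt

The first call should return the refresh acknowledgment (the rebuild runs synchronously inside the request handler); the second should answer with a Source attribution drawn from the retrieved feed articles.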