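"""AI is Really Cool Server

A Flask server exposing an OpenAI-compatible /v1/chat/completions endpoint.
Queries are answered with retrieval-augmented generation (RAG): news articles
are pulled from RSS feeds, embedded with sentence-transformers, indexed in
FAISS, and the most relevant articles are prepended to the prompt for an
ipex-llm-accelerated Intel/neural-chat-7b-v3-3 model running on an Intel XPU.

Run with e.g. (substitute the actual file name):
    python server.py --device 0 --level INFO
"""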
from flask import Flask, request, jsonify

import argparse
import datetime
import logging
import os
import re
import time

import faiss
import feedparser
import numpy as np
import torch
import transformers
from bs4 import BeautifulSoup
from ipex_llm.transformers import AutoModelForCausalLM
from sentence_transformers import SentenceTransformer


def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool Server")
    parser.add_argument("--device", type=int, default=0,
                        help="Device # to use for inference.")
    #parser.add_argument("--device-list", help="List available devices")
    parser.add_argument('--level', type=str,
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help='Set the logging level.')
    return parser.parse_args()


def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    logging.basicConfig(level=numeric_level,
                        format='%(asctime)s - %(levelname)s - %(message)s')
    logging.info(f"Logging is set to {level} level.")


def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, 'xml')  # Use 'xml' parser for XML content
    else:
        soup = BeautifulSoup(content, 'html.parser')  # Default to 'html.parser' for HTML content

    # Extract and return just the text
    return soup.get_text()


class Feed():
    def __init__(self, name, url, poll_limit_min=30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            logging.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now

            content = ""
            if len(feed.entries) > 0:
                content += f"Source: {self.name}\n"
                for entry in feed.entries[:self.max_articles]:
                    title = entry.get("title")
                    if title:
                        content += f"Title: {title}\n"
                    link = entry.get("link")
                    if link:
                        content += f"Link: {link}\n"
                    summary = entry.get("summary")
                    if summary:
                        summary = extract_text_from_html_or_xml(summary, False)
                        content += f"Summary: {summary}\n"
                    published = entry.get("published")
                    if published:
                        content += f"Published: {published}\n"
                    content += "\n"

                self.articles.append(content)
        else:
            logging.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)} remains until refresh.")
        return self.articles

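# Note: update() joins up to max_articles entries into one string, so each
# Feed contributes a single retrievable document rather than one per article.
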
# News RSS Feeds
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]

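# Each Feed fetches its source once at construction (Feed.__init__ calls
# update()), so building this list performs network requests at import time.
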
# Load an embedding model
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

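# all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings; the FAISS
# index below sizes itself from doc_vectors.shape[1], so swapping in another
# model requires no further changes.
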
# Step 1: Collect news from all sources
documents = []
for feed in rss_feeds:
    documents.extend(feed.articles)

# Step 2: Encode and store news articles into FAISS
doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
index = faiss.IndexFlatL2(doc_vectors.shape[1])  # Initialize FAISS index
index.add(doc_vectors)  # Store news vectors

logging.info(f"Stored {len(doc_vectors)} documents in FAISS index.")

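# IndexFlatL2 performs exact (brute-force) L2 search, which is fine at this
# scale; an approximate index such as faiss.IndexIVFFlat would only pay off
# with far more documents.
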
# Step 3: Retrieval function for user queries
def retrieve_documents(query, top_k=2):
    """Retrieve the top-k most relevant news articles."""
    query_vector = np.array(embedding_model.encode([query]), dtype=np.float32)
    D, I = index.search(query_vector, top_k)  # D: distances, I: document indices
    retrieved_docs = [documents[i] for i in I[0]]
    return retrieved_docs

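# Example (hypothetical query): retrieve_documents("elections in Europe")
# returns the two stored documents whose embeddings are closest to the query.
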
# Step 4: Format the RAG prompt
def format_prompt(query, retrieved_docs):
    """Format retrieved documents into a structured RAG prompt."""
    context_str = "\n".join(retrieved_docs)
    prompt = f"""You are an AI assistant with access to world news. Use the Retrieved Context to answer the user's question accurately if relevant, stating which Source provided the information.

## Retrieved Context:
{context_str}

## User Query:
{query}

## Response:
"""
    return prompt


class Chat():
    def __init__(self, device_name):
        super().__init__()
        self.device_name = device_name
        self.system_input = "You are a critical assistant. Give concise and accurate answers in less than 120 characters."
        self.context = None
        self.model_path = 'Intel/neural-chat-7b-v3-3'
        try:
            logging.info(f"Loading tokenizer from: {self.model_path}")
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token  # Set pad_token to eos_token if needed

            self.model = AutoModelForCausalLM.from_pretrained(self.model_path,
                                                              load_in_4bit=True,
                                                              optimize_model=True,
                                                              trust_remote_code=True,
                                                              use_cache=True)
            self.model = self.model.half().to(device_name)
        except Exception as e:
            logging.error(f"Loading error: {e}")
            raise  # Re-raise with the original traceback intact

    def remove_substring(self, string, substring):
        return string.replace(substring, "")

    def generate_response(self, text):
        prompt = text
        start = time.time()

        with torch.autocast(self.device_name, dtype=torch.float16):
            inputs = self.tokenizer.encode_plus(
                prompt,
                add_special_tokens=False,
                return_tensors="pt",
                max_length=7999,  # Prevent 'Asking to truncate to max_length...'
                padding=True,     # Handles padding automatically
                truncation=True
            )
            input_ids = inputs["input_ids"].to(self.device_name)
            attention_mask = inputs["attention_mask"].to(self.device_name)

            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=8000,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id
            )

        final_outputs = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        final_outputs = self.remove_substring(final_outputs, prompt).strip()

        end = time.time()

        return final_outputs, datetime.timedelta(seconds=end - start)

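# Example usage (assumes an XPU device and that the model weights are
# available locally or from the Hugging Face hub):
#   chat = Chat('xpu')
#   reply, elapsed = chat.generate_response("What is the capital of France?")
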
app = Flask(__name__)


# Basic endpoint for chat completions
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    # Without the global declaration, the refresh branch below would only
    # rebind local names and retrieve_documents() would keep using stale data.
    global documents, index

    logging.info('/v1/chat/completions')
    try:
        # Get the JSON data from the request
        data = request.get_json()

        # Extract relevant fields from the request
        model = data.get('model', 'default-model')
        messages = data.get('messages', [])
        temperature = data.get('temperature', 1.0)
        max_tokens = data.get('max_tokens', 2048)

        chat = app.config['chat']
        query = messages[-1]['content']

        if re.match(r"^\s*(update|refresh) news\s*$", query, re.IGNORECASE):
            logging.info("News refresh requested")
            # Collect news from all sources
            documents = []
            for feed in rss_feeds:
                documents.extend(feed.update())
            # Re-encode and store news articles into FAISS
            doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
            index = faiss.IndexFlatL2(doc_vectors.shape[1])  # Initialize FAISS index
            index.add(doc_vectors)  # Store news vectors
            logging.info(f"Stored {len(doc_vectors)} documents in FAISS index.")
            response_content = "News refresh requested."
        else:
            logging.info(f"Query: {query}")
            retrieved_docs = retrieve_documents(query)
            rag_prompt = format_prompt(query, retrieved_docs)
            logging.debug(f"RAG prompt: {rag_prompt}")

            # Get AI-generated response
            response_content, _ = chat.generate_response(rag_prompt)

        logging.info(f"Response: {response_content}")
        # Format response in OpenAI-compatible structure
        response = {
            "id": "chatcmpl-" + str(id(data)),  # Simple unique ID
            "object": "chat.completion",
            "created": int(time.time()),
            "model": chat.model_path,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response_content
                },
                "finish_reason": "stop"
            }],
            # "usage": {
            #     "prompt_tokens": len(str(messages).split()),
            #     "completion_tokens": len(response_content.split()),
            #     "total_tokens": len(str(messages).split()) + len(response_content.split())
            # }
        }

        return jsonify(response)

    except Exception as e:
        logging.error(e)
        return jsonify({
            "error": {
                "message": str(e),
                "type": "invalid_request_error"
            }
        }), 400

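# Example request (hypothetical values; assumes the server is running locally
# on port 5000):
#   curl -s http://localhost:5000/v1/chat/completions \
#     -H 'Content-Type: application/json' \
#     -d '{"messages": [{"role": "user", "content": "What is in the news?"}]}'
# A message of "update news" or "refresh news" triggers an RSS refresh instead
# of a generated answer.
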
# Health check endpoint
@app.route('/health', methods=['GET'])
def health():
    return jsonify({"status": "healthy"}), 200

if __name__ == '__main__':
    # Parse command-line arguments
    args = parse_args()

    # Setup logging based on the provided level
    setup_logging(args.level)

    if not torch.xpu.is_available():
        logging.error("No XPU available.")
        exit(1)
    device_count = torch.xpu.device_count()
    for i in range(device_count):
        logging.info(f"Device {i}: {torch.xpu.get_device_name(i)} Total memory: {torch.xpu.get_device_properties(i).total_memory}")
    device_name = 'xpu'
    device = torch.device(device_name)
    print(f"Using device: {device}")

    # Set environment variables that might help with XPU stability
    os.environ["SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS"] = "1"

    app.config['chat'] = Chat(device_name)

    app.run(host='0.0.0.0', port=5000, debug=True)