# backstory/src/model-server.py
import argparse
import datetime
import logging
import os
import re
import time

import faiss
import feedparser
import numpy as np
import torch
import transformers
from bs4 import BeautifulSoup
from flask import Flask, request, jsonify
from ipex_llm.transformers import AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool Server")
    # NOTE: --device is parsed but not yet wired up; __main__ currently hardcodes 'xpu'
    parser.add_argument("--device", type=int, default=0, help="Device # to use for inference. See --device-list")
    #parser.add_argument("--device-list", help="List available devices")
    parser.add_argument('--level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help='Set the logging level.')
    return parser.parse_args()

def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")
    logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
    logging.info(f"Logging is set to {level} level.")

def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, 'xml')  # The 'xml' parser requires lxml to be installed
    else:
        soup = BeautifulSoup(content, 'html.parser')  # Default to 'html.parser' for HTML content
    # Extract and return just the text
    return soup.get_text()

class Feed:
    def __init__(self, name, url, poll_limit_min=30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            logging.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now
            content = ""
            if len(feed.entries) > 0:
                content += f"Source: {self.name}\n"
                for entry in feed.entries[:self.max_articles]:
                    title = entry.get("title")
                    if title:
                        content += f"Title: {title}\n"
                    link = entry.get("link")
                    if link:
                        content += f"Link: {link}\n"
                    summary = entry.get("summary")
                    if summary:
                        summary = extract_text_from_html_or_xml(summary, False)
                        content += f"Summary: {summary}\n"
                    published = entry.get("published")
                    if published:
                        content += f"Published: {published}\n"
                    content += "\n"
                # One document per feed: the source header plus its articles
                self.articles.append(content)
        else:
            logging.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)} remaining until refresh.")
        return self.articles
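
# Each Feed fetches its articles once at construction; subsequent update()
# calls re-poll the RSS URL only after poll_limit_min has elapsed.
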
# News RSS Feeds
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]

# Load an embedding model
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

documents = []
index = None

# Step 2: Encode and store news articles into FAISS. This lives in a function
# (rebinding the module-level globals) so the '/v1/chat/completions' refresh
# path can rebuild the same index that retrieve_documents() searches.
def rebuild_index():
    """Collect news from all sources, encode it, and rebuild the FAISS index."""
    global documents, index
    documents = []
    for feed in rss_feeds:
        documents.extend(feed.update())
    doc_vectors = np.array(embedding_model.encode(documents), dtype=np.float32)
    index = faiss.IndexFlatL2(doc_vectors.shape[1])  # Initialize FAISS index
    index.add(doc_vectors)  # Store news vectors
    logging.info(f"Stored {len(doc_vectors)} documents in FAISS index.")

rebuild_index()
# Step 3: Retrieval function for user queries
def retrieve_documents(query, top_k=2):
    """Retrieve the top-k most relevant news articles."""
    query_vector = np.array(embedding_model.encode([query]), dtype=np.float32)
    top_k = min(top_k, index.ntotal)  # FAISS pads with -1 indices when k exceeds the index size
    D, I = index.search(query_vector, top_k)  # D: L2 distances, I: row indices into documents
    retrieved_docs = [documents[i] for i in I[0]]
    return retrieved_docs
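
# Illustrative call (query text hypothetical):
#   retrieve_documents("What happened in the Middle East today?", top_k=2)
# returns the two stored article blocks closest to the query embedding.
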
# Step 4: Format the RAG prompt
def format_prompt(query, retrieved_docs):
    """Format retrieved documents into a structured RAG prompt."""
    context_str = "\n".join(retrieved_docs)
    prompt = f"""You are an AI assistant with access to world news. Use the Retrieved Context to answer the user's question accurately if relevant, stating which Source provided the information.

## Retrieved Context:
{context_str}

## User Query:
{query}

## Response:
"""
    return prompt

class Chat:
    def __init__(self, device_name):
        self.device_name = device_name
        self.system_input = "You are a critical assistant. Give concise and accurate answers in less than 120 characters."
        self.context = None
        self.model_path = 'Intel/neural-chat-7b-v3-3'
        try:
            logging.info(f"Loading tokenizer from: {self.model_path}")
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token  # Set pad_token to eos_token if needed
            # ipex_llm quantizes the weights to 4-bit as they are loaded
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path,
                                                              load_in_4bit=True,
                                                              optimize_model=True,
                                                              trust_remote_code=True,
                                                              use_cache=True)
            self.model = self.model.half().to(device_name)
        except Exception as e:
            logging.error(f"Loading error: {e}")
            raise  # Re-raise as-is to preserve the original traceback
    def remove_substring(self, string, substring):
        return string.replace(substring, "")

    def generate_response(self, text):
        prompt = text
        start = time.time()
        with torch.autocast(self.device_name, dtype=torch.float16):
            inputs = self.tokenizer.encode_plus(
                prompt,
                add_special_tokens=False,
                return_tensors="pt",
                max_length=7999,  # Prevent 'Asking to truncate to max_length...'
                padding=True,     # Handles padding automatically
                truncation=True
            )
            input_ids = inputs["input_ids"].to(self.device_name)
            attention_mask = inputs["attention_mask"].to(self.device_name)
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=8000,  # Upper bound on prompt + generated tokens combined
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens; stripping the prompt string out
        # of the full decoding (remove_substring) can fail when re-tokenization
        # does not round-trip exactly.
        final_outputs = self.tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
        end = time.time()
        return final_outputs, datetime.timedelta(seconds=end - start)
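
# Illustrative usage (variable names hypothetical):
#   chat = Chat('xpu')
#   reply, elapsed = chat.generate_response(format_prompt(query, docs))
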
app = Flask(__name__)
# Basic endpoint for chat completions
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    logging.info('/v1/chat/completions')
    try:
        # Get the JSON data from the request
        data = request.get_json()

        # Extract relevant fields from the request. Only 'messages' is used;
        # 'model', 'temperature', and 'max_tokens' are accepted for
        # OpenAI-compatibility but currently ignored.
        model = data.get('model', 'default-model')
        messages = data.get('messages', [])
        temperature = data.get('temperature', 1.0)
        max_tokens = data.get('max_tokens', 2048)

        chat = app.config['chat']
        query = messages[-1]['content']
        if re.match(r"^\s*(update|refresh) news\s*$", query, re.IGNORECASE):
            logging.info("News refresh requested")
            # Re-collect news from all sources and rebuild the shared FAISS
            # index (module globals, so retrieve_documents sees the update)
            rebuild_index()
            response_content = "News refresh requested."
        else:
            logging.info(f"Query: {query}")
            retrieved_docs = retrieve_documents(query)
            rag_prompt = format_prompt(query, retrieved_docs)
            logging.debug(f"RAG prompt: {rag_prompt}")
            # Get AI-generated response
            response_content, _ = chat.generate_response(rag_prompt)
            logging.info(f"Response: {response_content}")

        # Format response in OpenAI-compatible structure
        response = {
            "id": "chatcmpl-" + str(id(data)),  # Simple unique ID
            "object": "chat.completion",
            "created": int(time.time()),
            "model": chat.model_path,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response_content
                },
                "finish_reason": "stop"
            }],
            # "usage": {
            #     "prompt_tokens": len(str(messages).split()),
            #     "completion_tokens": len(response_content.split()),
            #     "total_tokens": len(str(messages).split()) + len(response_content.split())
            # }
        }
        return jsonify(response)
    except Exception as e:
        logging.error(e)
        return jsonify({
            "error": {
                "message": str(e),
                "type": "invalid_request_error"
            }
        }), 400
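
# Example request (illustrative; assumes the server is running locally):
#   curl -s http://localhost:5000/v1/chat/completions \
#     -H 'Content-Type: application/json' \
#     -d '{"messages": [{"role": "user", "content": "update news"}]}'
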
# Health check endpoint
@app.route('/health', methods=['GET'])
def health():
return jsonify({"status": "healthy"}), 200
if __name__ == '__main__':
    # Parse command-line arguments
    args = parse_args()

    # Setup logging based on the provided level
    setup_logging(args.level)

    if not torch.xpu.is_available():
        logging.error("No XPU available.")
        exit(1)

    device_count = torch.xpu.device_count()
    for i in range(device_count):
        logging.info(f"Device {i}: {torch.xpu.get_device_name(i)} Total memory: {torch.xpu.get_device_properties(i).total_memory}")

    device_name = 'xpu'
    device = torch.device(device_name)
    print(f"Using device: {device}")

    # Set environment variables that might help with XPU stability
    os.environ["SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS"] = "1"

    app.config['chat'] = Chat(device_name)
    # use_reloader=False keeps Flask's debug reloader from loading the model twice
    app.run(host='0.0.0.0', port=5000, debug=True, use_reloader=False)