Switching to ollama

parent 5ca70b2933
commit 8027b5f8e3

.gitignore (vendored): 1 line changed
@@ -1,3 +1,4 @@
.env
cache/**
jupyter/**
ollama/**
Dockerfile: 377 lines changed
@@ -1,74 +1,3 @@
FROM ubuntu:oracular AS pytorch-build

SHELL [ "/bin/bash", "-c" ]

# Instructions Dockerfied from:
#
# https://github.com/pytorch/pytorch
#
# and
#
# https://pytorch.org/docs/stable/notes/get_start_xpu.html
# https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpu/2-6.html
#
#
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
gpg \
wget \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

# ipex only supports python 3.11, so use 3.11 instead of latest oracular (3.12)

RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
ca-certificates \
ccache \
cmake \
curl \
git \
gpg-agent \
less \
libbz2-dev \
libffi-dev \
libjpeg-dev \
libpng-dev \
libreadline-dev \
libssl-dev \
libsqlite3-dev \
llvm \
nano \
wget \
zlib1g-dev \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

# python3 \
# python3-pip \
# python3-venv \
# python3-dev \

RUN /usr/sbin/update-ccache-symlinks
RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache

# Build Python in /opt/..., install it locally, then remove the build environment
# collapsed to a single docker layer.
WORKDIR /opt
ENV PYTHON_VERSION=3.11.9

RUN wget -q -O - https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz | tar -xz \
&& cd Python-${PYTHON_VERSION} \
&& ./configure --prefix=/opt/python --enable-optimizations \
&& make -j$(nproc) \
&& make install \
&& cd /opt \
&& rm -rf Python-${PYTHON_VERSION}

WORKDIR /opt/pytorch

FROM ubuntu:oracular AS ze-monitor
# From https://github.com/jketreno/ze-monitor
RUN apt-get update \
@@ -100,10 +29,20 @@ RUN cmake .. \
&& make \
&& cpack

FROM pytorch-build AS pytorch
FROM ubuntu:oracular AS airc

COPY --from=pytorch-build /opt/pytorch /opt/pytorch
# Get a couple prerequisites
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
gpg \
python3 \
python3-pip \
python3-venv \
wget \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

# Install Intel graphics runtimes
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y software-properties-common \
&& add-apt-repository -y ppa:kobuk-team/intel-graphics \
@@ -117,74 +56,176 @@ RUN apt-get update \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

RUN update-alternatives --install /usr/bin/python3 python3 /opt/python/bin/python3.11 2
WORKDIR /opt/airc

# When cache is enabled SYCL runtime will try to cache and reuse JIT-compiled binaries.
ENV SYCL_CACHE_PERSISTENT=1

WORKDIR /opt/pytorch
# Setup the ollama python virtual environment
RUN python3 -m venv --system-site-packages /opt/airc/venv

# Setup the docker pip shell
RUN { \
echo '#!/bin/bash' ; \
update-alternatives --set python3 /opt/python/bin/python3.11 ; \
echo 'source /opt/pytorch/venv/bin/activate' ; \
echo 'source /opt/airc/venv/bin/activate' ; \
echo 'bash -c "${@}"' ; \
} > /opt/pytorch/shell ; \
chmod +x /opt/pytorch/shell
} > /opt/airc/shell ; \
chmod +x /opt/airc/shell

RUN python3 -m venv --system-site-packages /opt/pytorch/venv
# Activate the pip environment on all shell calls
SHELL [ "/opt/airc/shell" ]

SHELL [ "/opt/pytorch/shell" ]
# Install ollama python module
RUN pip3 install ollama
# pydle does not work with newer asyncio due to coroutine
# being deprecated. Patch to work.
COPY /src/pydle.patch /opt/pydle.patch

RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu
RUN pip3 freeze > /opt/pytorch/requirements.txt
RUN pip3 install pydle \
&& patch -d /opt/airc/venv/lib/python3*/site-packages/pydle \
-p1 < /opt/pydle.patch \
&& rm /opt/pydle.patch

RUN pip install setuptools --upgrade
RUN pip install ollama
RUN pip install feedparser bs4 chromadb

SHELL [ "/bin/bash", "-c" ]

RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
libncurses6 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

COPY --from=ze-monitor /opt/ze-monitor/build/ze-monitor-*deb /opt/
RUN dpkg -i /opt/ze-monitor-*deb && rm /opt/ze-monitor-*deb

COPY /src/ /opt/airc/src/

SHELL [ "/bin/bash", "-c" ]

RUN { \
echo '#!/bin/bash' ; \
echo 'echo "Container: pytorch"' ; \
echo 'set -e' ; \
echo 'echo "Setting pip environment to /opt/pytorch"' ; \
echo 'source /opt/pytorch/venv/bin/activate'; \
echo 'if [[ "${1}" == "" ]] || [[ "${1}" == "shell" ]]; then' ; \
echo ' echo "Dropping to shell"' ; \
echo ' /bin/bash -c "source /opt/pytorch/venv/bin/activate ; /bin/bash"' ; \
echo 'else' ; \
echo ' exec "${@}"' ; \
echo 'fi' ; \
echo '#!/bin/bash'; \
echo 'echo "Container: airc"'; \
echo 'set -e'; \
echo 'echo "Setting pip environment to /opt/airc"'; \
echo 'source /opt/airc/venv/bin/activate'; \
echo ''; \
echo 'if [[ "${1}" == "/bin/bash" ]] || [[ "${1}" =~ ^(/opt/airc/)?shell$ ]]; then'; \
echo ' echo "Dropping to shell"'; \
echo ' shift' ; \
echo ' echo "Running: ${@}"' ; \
echo ' if [[ "${1}" != "" ]]; then' ; \
echo ' exec ${@}'; \
echo ' else' ; \
echo ' exec /bin/bash'; \
echo ' fi' ; \
echo 'else'; \
echo ' echo "Launching AIRC chat server..."'; \
echo ' python src/airc.py "${@}"' ; \
echo 'fi'; \
} > /entrypoint.sh \
&& chmod +x /entrypoint.sh

ENTRYPOINT [ "/entrypoint.sh" ]

FROM pytorch AS ipex-llm-src
FROM ubuntu:oracular AS ollama

# Build ipex-llm from source
# Get a couple prerequisites
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
gpg \
wget \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

RUN git clone --branch main --depth 1 https://github.com/intel/ipex-llm.git /opt/ipex-llm \
&& cd /opt/ipex-llm \
&& git fetch --depth 1 origin cb3c4b26ad058c156591816aa37eec4acfcbf765 \
&& git checkout cb3c4b26ad058c156591816aa37eec4acfcbf765
# Install Intel graphics runtimes
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y software-properties-common \
&& add-apt-repository -y ppa:kobuk-team/intel-graphics \
&& apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
libze-intel-gpu1 \
libze1 \
intel-ocloc \
intel-opencl-icd \
xpu-smi \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

WORKDIR /opt/ipex-llm
WORKDIR /opt/ollama

RUN python3 -m venv --system-site-packages /opt/ipex-llm/venv
# Download the nightly ollama release from ipex-llm
RUN wget -qO - https://github.com/intel/ipex-llm/releases/download/v2.2.0-nightly/ollama-0.5.4-ipex-llm-2.2.0b20250226-ubuntu.tgz | \
tar --strip-components=1 -C . -xzv

# Install Python from Oracular (ollama works with 3.12)
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
gpg \
python3 \
python3-pip \
python3-venv \
wget \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

# Setup the ollama python virtual environment
RUN python3 -m venv --system-site-packages /opt/ollama/venv

# Setup the docker pip shell
RUN { \
echo '#!/bin/bash' ; \
update-alternatives --set python3 /opt/python/bin/python3.11 ; \
echo 'source /opt/ipex-llm/venv/bin/activate' ; \
echo 'source /opt/ollama/venv/bin/activate' ; \
echo 'bash -c "${@}"' ; \
} > /opt/ipex-llm/shell ; \
chmod +x /opt/ipex-llm/shell
} > /opt/ollama/shell ; \
chmod +x /opt/ollama/shell

SHELL [ "/opt/ipex-llm/shell" ]
# Activate the pip environment on all shell calls
SHELL [ "/opt/ollama/shell" ]

RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu
# Install ollama python module
RUN pip3 install ollama

WORKDIR /opt/ipex-llm/python/llm
RUN pip install requests wheel
RUN python setup.py clean --all bdist_wheel --linux
SHELL [ "/bin/bash", "-c" ]

RUN { \
echo '#!/bin/bash'; \
echo 'echo "Container: ollama"'; \
echo 'set -e'; \
echo 'echo "Setting pip environment to /opt/ollama"'; \
echo 'source /opt/ollama/venv/bin/activate'; \
echo 'export OLLAMA_NUM_GPU=999'; \
echo 'export ZES_ENABLE_SYSMAN=1'; \
echo 'export SYCL_CACHE_PERSISTENT=1'; \
echo 'export OLLAMA_KEEP_ALIVE=-1'; \
echo 'export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1'; \
echo ''; \
echo 'if [[ "${1}" == "/bin/bash" ]] || [[ "${1}" =~ ^(/opt/ollama/)?shell$ ]]; then'; \
echo ' echo "Dropping to shell"'; \
echo ' exec /bin/bash'; \
echo 'else'; \
echo ' echo "Launching Ollama server..."'; \
echo ' exec ./ollama serve'; \
echo 'fi'; \
} > /entrypoint.sh \
&& chmod +x /entrypoint.sh

RUN { \
echo '#!/bin/bash'; \
echo 'echo "Container: ollama"'; \
echo 'set -e'; \
echo 'echo "Setting pip environment to /opt/ollama"'; \
echo 'source /opt/ollama/venv/bin/activate'; \
echo './ollama pull mxbai-embed-large' ; \
echo './ollama pull deepseek-r1:7b' ; \
} > /fetch-models.sh \
&& chmod +x /fetch-models.sh
ENV PYTHONUNBUFFERED=1

VOLUME [ "/root/.ollama" ]

ENTRYPOINT [ "/entrypoint.sh" ]
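For reference, a minimal sketch of how a sibling container on the compose network can exercise this service with the ollama Python module (the same client the airc stage installs). The host name "ollama" and port 11434 match the compose service and the defaults used in src/, and the model names match /fetch-models.sh; treat this as an assumption-laden smoke test, not part of the commit:

# sketch: verify the ollama container is serving and the expected models exist
import ollama

client = ollama.Client(host="http://ollama:11434")

# Pull the same models /fetch-models.sh pulls, if they are not already present.
for model in ("mxbai-embed-large", "deepseek-r1:7b"):
    client.pull(model)

# One-shot generation as a quick check that the GPU-backed server responds.
reply = client.generate(model="deepseek-r1:7b", prompt="Say hello in one sentence.", stream=False)
print(reply["response"])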
FROM airc AS jupyter

@@ -236,112 +277,4 @@ RUN { \
} > /entrypoint-jupyter.sh \
&& chmod +x /entrypoint-jupyter.sh

ENTRYPOINT [ "/entrypoint-jupyter.sh" ]

FROM pytorch AS airc

RUN python3 -m venv --system-site-packages /opt/airc/venv

# Don't install the full oneapi essentials; just the ones that we seem to need
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor -o /usr/share/keyrings/oneapi-archive-keyring.gpg \
&& echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" \
| tee /etc/apt/sources.list.d/oneAPI.list \
&& apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y \
intel-oneapi-mkl-sycl-2025.0 \
intel-oneapi-dnnl-2025.0 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

RUN { \
echo '#!/bin/bash' ; \
echo 'update-alternatives --set python3 /opt/python/bin/python3.11' ; \
echo 'if [[ -e /opt/intel/oneapi/setvars.sh ]]; then source /opt/intel/oneapi/setvars.sh; fi' ; \
echo 'source /opt/airc/venv/bin/activate' ; \
echo 'if [[ "$1" == "" ]]; then bash -c; else bash -c "${@}"; fi' ; \
} > /opt/airc/shell ; \
chmod +x /opt/airc/shell

SHELL [ "/opt/airc/shell" ]

RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu
# Install ipex-llm built in ipex-llm-src
COPY --from=ipex-llm-src /opt/ipex-llm/python/llm/dist/*.whl /opt/wheels/
RUN for pkg in /opt/wheels/ipex_llm*.whl; do pip install $pkg; done

COPY src/ /opt/airc/src/

# pydle does not work with newer asyncio due to coroutine
# being deprecated. Patch to work.
RUN pip3 install pydle transformers sentencepiece accelerate \
&& patch -d /opt/airc/venv/lib/python3*/site-packages/pydle \
-p1 < /opt/airc/src/pydle.patch

# mistral fails with cache_position errors with transformers>4.40 (or at least it fails with the latest)
# as well as MistralSpda* things missing
RUN pip install "sentence_transformers<3.4.1" "transformers==4.40.0"

# To get xe_linear and other Xe methods
RUN pip3 install 'bigdl-core-xe-all>=2.6.0b'

# trl.core doesn't have what is needed with the default 'pip install trl' version
RUN pip install git+https://github.com/huggingface/trl.git@7630f877f91c556d9e5a3baa4b6e2894d90ff84c

# Needed by src/model-server.py
RUN pip install flask

SHELL [ "/bin/bash", "-c" ]

RUN { \
echo '#!/bin/bash' ; \
echo 'set -e' ; \
echo 'if [[ ! -e "/root/.cache/hub/token" ]]; then' ; \
echo ' if [[ "${HF_ACCESS_TOKEN}" == "" ]]; then' ; \
echo ' echo "Set your HF access token in .env as: HF_ACCESS_TOKEN=<token>" >&2' ; \
echo ' exit 1' ; \
echo ' else' ; \
echo ' if [[ ! -d '/root/.cache/hub' ]]; then mkdir -p /root/.cache/hub; fi' ; \
echo ' echo "${HF_ACCESS_TOKEN}" > /root/.cache/hub/token' ; \
echo ' fi' ; \
echo 'fi' ; \
echo 'echo "Container: airc"' ; \
echo 'echo "Setting pip environment to /opt/airc"' ; \
echo 'if [[ -e /opt/intel/oneapi/setvars.sh ]]; then source /opt/intel/oneapi/setvars.sh; fi' ; \
echo 'source /opt/airc/venv/bin/activate'; \
echo 'if [[ "${1}" == "shell" ]] || [[ "${1}" == "/bin/bash" ]]; then' ; \
echo ' echo "Dropping to shell"' ; \
echo ' /bin/bash -c "source /opt/airc/venv/bin/activate ; /bin/bash"' ; \
echo ' exit $?' ; \
echo 'else' ; \
echo ' while true; do' ; \
echo ' echo "Launching model-server"' ; \
echo ' python src/model-server.py \' ; \
echo ' 2>&1 | tee -a "/root/.cache/model-server.log"'; \
echo ' echo "model-server died ($?). Restarting."' ; \
echo ' sleep 5' ; \
echo ' done &' ; \
echo ' while true; do' ; \
echo ' echo "Launching airc"' ; \
echo ' python src/airc.py "${@}" \' ; \
echo ' 2>&1 | tee -a "/root/.cache/airc.log"' ; \
echo ' echo "airc died ($?). Restarting."' ; \
echo ' sleep 5' ; \
echo ' done' ; \
echo 'fi' ; \
} > /entrypoint-airc.sh \
&& chmod +x /entrypoint-airc.sh

COPY --from=ze-monitor /opt/ze-monitor/build/ze-monitor-*deb /opt/
RUN dpkg -i /opt/ze-monitor-*deb

WORKDIR /opt/airc

SHELL [ "/opt/airc/shell" ]

# Needed by src/model-server.py
RUN pip install faiss-cpu sentence_transformers feedparser bs4

SHELL [ "/bin/bash", "-c" ]

ENTRYPOINT [ "/entrypoint-airc.sh" ]
ENTRYPOINT [ "/entrypoint-jupyter.sh" ]
docker-compose.yml
@@ -10,6 +10,10 @@ services:
      - .env
    devices:
      - /dev/dri:/dev/dri
    depends_on:
      - ollama
    networks:
      - internal
    volumes:
      - ./cache:/root/.cache
      - ./src:/opt/airc/src:rw
@@ -18,6 +22,33 @@ services:
      - CAP_PERFMON # Access to perf_events (vs. overloaded CAP_SYS_ADMIN)
      - CAP_SYS_PTRACE # PTRACE_MODE_READ_REALCREDS ptrace access mode check

  ollama:
    build:
      context: .
      dockerfile: Dockerfile
      target: ollama
    image: ollama
    restart: "no"
    env_file:
      - .env
    environment:
      - OLLAMA_HOST=0.0.0.0
      - ONEAPI_DEVICE_SELECTOR=level_zero:0
    devices:
      - /dev/dri:/dev/dri
    # ports:
    #   - 11434:11434 # ollama serve port
    networks:
      - internal
    volumes:
      - ./cache:/root/.cache # Cache hub models and neo_compiler_cache
      - ./ollama:/root/.ollama # Cache the ollama models
      - ./src:/opt/airc/src:rw # Live mount src
    cap_add: # used for running ze-monitor within airc container
      - CAP_DAC_READ_SEARCH # Bypass all filesystem read access checks
      - CAP_PERFMON # Access to perf_events (vs. overloaded CAP_SYS_ADMIN)
      - CAP_SYS_PTRACE # PTRACE_MODE_READ_REALCREDS ptrace access mode check

  jupyter:
    build:
      context: .
@@ -28,8 +59,17 @@ services:
      - .env
    devices:
      - /dev/dri:/dev/dri
    depends_on:
      - ollama
    ports:
      - 8888:8888 # Jupyter Notebook
    networks:
      - internal
    volumes:
      - ./jupyter:/opt/jupyter:rw
      - ./cache:/root/.cache

networks:
  internal:
    driver: bridge
src/airc.py: 189 lines changed
@@ -1,5 +1,4 @@
import asyncio
import aiohttp
import argparse
import pydle
import logging
@@ -9,7 +8,15 @@ import time
import datetime
import asyncio
import json
import ollama
from typing import Dict, Any
import ollama
import chromadb
import feedparser
from bs4 import BeautifulSoup

OLLAMA_API_URL = "http://ollama:11434" # Default Ollama local endpoint
MODEL_NAME = "deepseek-r1:7b"

def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
@@ -22,50 +29,6 @@ def parse_args():
                        default='INFO', help='Set the logging level.')
    return parser.parse_args()

class AsyncOpenAIClient:
    def __init__(self, base_url: str = "http://localhost:5000"):
        logging.info(f"Using {base_url} as server")
        self.base_url = base_url
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def chat_completion(self,
                              messages: list,
                              model: str = "my-model",
                              temperature: float = 0.7,
                              max_tokens: int = 100) -> Dict[str, Any]:
        """
        Make an async chat completion request
        """
        url = f"{self.base_url}/v1/chat/completions"

        # Prepare the request payload
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        try:
            async with self.session.post(url, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"Request failed with status {response.status}: {error_text}")

                return await response.json()

        except Exception as e:
            print(f"Error during request: {str(e)}")
            return {"error": str(e)}

def setup_logging(level):
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
@@ -74,6 +37,100 @@ def setup_logging(level):
    logging.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
    logging.info(f"Logging is set to {level} level.")


client = ollama.Client(host=OLLAMA_API_URL)

def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, 'xml')  # Use 'xml' parser for XML content
    else:
        soup = BeautifulSoup(content, 'html.parser')  # Default to 'html.parser' for HTML content

    # Extract and return just the text
    return soup.get_text()

class Feed():
    def __init__(self, name, url, poll_limit_min = 30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            logging.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now

            content = ""
            if len(feed.entries) > 0:
                content += f"Source: {self.name}\n"
                for entry in feed.entries[:self.max_articles]:
                    title = entry.get("title")
                    if title:
                        content += f"Title: {title}\n"
                    link = entry.get("link")
                    if link:
                        content += f"Link: {link}\n"
                    summary = entry.get("summary")
                    if summary:
                        summary = extract_text_from_html_or_xml(summary, False)
                        content += f"Summary: {summary}\n"
                    published = entry.get("published")
                    if published:
                        content += f"Published: {published}\n"
                    content += "\n"

                self.articles.append(content)
        else:
            logging.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)}s remain to refresh.")
        return self.articles


# News RSS Feeds
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]

documents = [
    "Llamas like to eat penguins",
    "Llamas are not vegetarians and have very efficient digestive systems",
    "Llamas live to be about 120 years old, though some only live for 15 years and others live to be 90 years old",
]

import chromadb

# Initialize ChromaDB Client
db = chromadb.PersistentClient(path="/root/.cache/chroma.db")

# We want to save the collection to disk to analyze it offline, but we don't
# want to re-use it
collection = db.get_or_create_collection("docs")

# store each document in a vector embedding database
for i, feed in enumerate(rss_feeds):
    # Use the client instance instead of the global ollama module
    for j, article in enumerate(feed.articles):
        response = client.embeddings(model="mxbai-embed-large", prompt=article)
        embeddings = response["embedding"]  # Note: it's "embedding", not "embeddings"
        collection.add(
            ids=[str(i)+str(j)],
            embeddings=embeddings,
            documents=[article]
        )

class AIRC(pydle.Client):
    def __init__(self, nick, channel, client, burst_limit = 5, rate_limit = 1.0, burst_reset_timeout = 10.0):
        super().__init__(nick)
@@ -89,6 +146,8 @@ class AIRC(pydle.Client):
        self._message_queue = asyncio.Queue()
        self._task = asyncio.create_task(self._send_from_queue())
        self.client = client
        self.queries = 0
        self.processing = datetime.timedelta(minutes=0)

    async def _send_from_queue(self):
        """Background task that sends queued messages with burst + rate limiting."""
@@ -157,18 +216,31 @@ class AIRC(pydle.Client):
            if body == "stats":
                content = f"{self.queries} queries handled in {self.processing}s"
            else:
                # Sample messages
                messages = [
                    {"role": "system", "content": self.system_input},
                    {"role": "user", "content": body}
                ]
                self.queries += 1
                start = datetime.datetime.now()
                query_text = body
                query_response = client.embeddings(model="mxbai-embed-large", prompt=query_text)
                query_embedding = query_response["embedding"]  # Note: singular "embedding", not plural

                # Make the request
                response = await self.client.chat_completion(messages)
                # Then run the query with the correct structure
                results = collection.query(
                    query_embeddings=[query_embedding],  # Make sure this is a list containing the embedding
                    n_results=3
                )
                data = results['documents'][0][0]
                logging.info(f"Data for {query_text}: {data}")
                logging.info(f"From {results}")
                output = client.generate(
                    model=MODEL_NAME,
                    system=f"You are {self.nick}. In your response, make reference to this data if appropriate: {data}",
                    prompt=f"Respond to this prompt: {query_text}",
                    stream=False
                )
                end = datetime.datetime.now()
                self.processing = self.processing + end - start

                # Extract and print just the assistant's message if available
                if "choices" in response and len(response["choices"]) > 0:
                    content = response["choices"][0]["message"]["content"]
                # Prune off the <think>...</think>
                content = re.sub(r'^<think>.*?</think>', '', output['response'], flags=re.DOTALL).strip()

            if content:
                logging.info(f'Sending: {content}')
@@ -184,10 +256,9 @@ async def main():
    # Setup logging based on the provided level
    setup_logging(args.level)

    async with AsyncOpenAIClient(base_url=args.ai_server) as client:
        bot = AIRC(args.nickname, args.channel, client)
        await bot.connect(args.server, args.port, tls=False)
        await bot.handle_forever()
    bot = AIRC(args.nickname, args.channel, client)
    await bot.connect(args.server, args.port, tls=False)
    await bot.handle_forever()

if __name__ == "__main__":
    asyncio.run(main())
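The inline comments in this diff keep flagging "embedding" versus "embeddings"; that reflects the two response shapes the ollama Python client can return. A small sketch of the difference, assuming the mxbai-embed-large model is already pulled:

import ollama

client = ollama.Client(host="http://ollama:11434")

# Older endpoint: one prompt in, one vector out, under the singular key.
single = client.embeddings(model="mxbai-embed-large", prompt="hello world")
vector = single["embedding"]    # a flat list of floats

# Newer endpoint: accepts one or many inputs, always returns a list of vectors.
batch = client.embed(model="mxbai-embed-large", input="hello world")
vectors = batch["embeddings"]   # a list of vectors, even for a single input

# ChromaDB's query_embeddings expects a list of vectors, so:
#   collection.query(query_embeddings=[vector], ...)  when using embeddings()
#   collection.query(query_embeddings=vectors, ...)   when using embed()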
src/chat.py: new file, 209 lines
@@ -0,0 +1,209 @@
import logging as log
import argparse
import re
import datetime
import ollama
import chromadb
import feedparser
from bs4 import BeautifulSoup

OLLAMA_API_URL = "http://ollama:11434" # Default Ollama local endpoint
MODEL_NAME = "deepseek-r1:7b"

def parse_args():
    parser = argparse.ArgumentParser(description="AI is Really Cool")
    parser.add_argument("--nickname", type=str, default="airc", help="Bot nickname")
    parser.add_argument('--level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help='Set the log level.')
    return parser.parse_args()

def setup_logging(level):
    numeric_level = getattr(log, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level}")

    log.basicConfig(level=numeric_level, format='%(asctime)s - %(levelname)s - %(message)s')
    log.info(f"Logging is set to {level} level.")

def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, 'xml')  # Use 'xml' parser for XML content
    else:
        soup = BeautifulSoup(content, 'html.parser')  # Default to 'html.parser' for HTML content

    # Extract and return just the text
    return soup.get_text()

class Feed():
    def __init__(self, name, url, poll_limit_min = 30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            log.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now

            content = ""
            if len(feed.entries) > 0:
                content += f"Source: {self.name}\n"
                for entry in feed.entries[:self.max_articles]:
                    title = entry.get("title")
                    if title:
                        content += f"Title: {title}\n"
                    link = entry.get("link")
                    if link:
                        content += f"Link: {link}\n"
                    summary = entry.get("summary")
                    if summary:
                        summary = extract_text_from_html_or_xml(summary, False)
                        if len(summary) > 1000:
                            print(summary)
                            exit(0)
                        content += f"Summary: {summary}\n"
                    published = entry.get("published")
                    if published:
                        content += f"Published: {published}\n"
                    content += "\n"

                self.articles.append(content)
        else:
            log.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)}s remain to refresh.")
        return self.articles


class Chat():
    def __init__(self, nick):
        super().__init__()
        self.nick = nick
        self.system_input = "You are a critical assistant. Give concise and accurate answers in less than 120 characters."
        self.queries = 0
        self.processing = datetime.timedelta(minutes=0)

    def message(self, target, message):
        """Splits a multi-line message and sends each line separately. If more than 10 lines, truncate and add a message."""
        lines = message.splitlines()  # Splits on both '\n' and '\r\n'

        # Process the first 10 lines
        for line in lines[:10]:
            if line.strip():  # Ignore empty lines
                print(f"{target}: {line}")

        # If there are more than 10 lines, add the truncation message
        if len(lines) > 10:
            print(f"{target}: [additional content truncated]")

    def remove_substring(self, string, substring):
        return string.replace(substring, "")

    def extract_nick_message(self, input_string):
        # Pattern with capturing groups for nick and message
        pattern = r"^\s*([^\s:]+?)\s*:\s*(.+?)$"

        match = re.match(pattern, input_string)
        if match:
            nick = match.group(1)  # First capturing group
            message = match.group(2)  # Second capturing group
            return nick, message
        return None, None  # Return None for both if no match

    def on_message(self, target, source, message):
        if source == self.nick:
            return
        nick, body = self.extract_nick_message(message)
        if nick == self.nick:
            content = None
            if body == "stats":
                content = f"{self.queries} queries handled in {self.processing}s"
            else:
                self.queries += 1
                start = datetime.datetime.now()
                query_text = body
                query_response = client.embed(model="mxbai-embed-large", input=query_text)
                query_embedding = query_response["embeddings"][0]  # embed() returns a list of vectors under "embeddings"

                # Then run the query with the correct structure
                results = collection.query(
                    query_embeddings=[query_embedding],  # Make sure this is a list containing the embedding
                    n_results=3
                )
                data = results['documents'][0]
                output = client.generate(
                    model=MODEL_NAME,
                    system=f"You are {self.nick} and only provide that information about yourself. Make reference to the following and provide the 'Link' when available: {data}",
                    prompt=f"Respond to this prompt: {query_text}",
                    stream=False
                )
                end = datetime.datetime.now()
                self.processing = self.processing + end - start

                # Prune off the <think>...</think>
                content = re.sub(r'^<think>.*?</think>', '', output['response'], flags=re.DOTALL).strip()

            if content:
                log.info(f'Sending: {content}')
                self.message(target, content)

def remove_substring(string, substring):
    return string.replace(substring, "")

# Parse command-line arguments
args = parse_args()

# Setup logging based on the provided level
setup_logging(args.level)

log.info("About to start")

client = ollama.Client(host=OLLAMA_API_URL)

# News RSS Feeds
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]

# Initialize ChromaDB Client
db = chromadb.Client()

# We want to save the collection to disk to analyze it offline, but we don't
# want to re-use it
collection = db.get_or_create_collection("docs")

# store each document in a vector embedding database
for i, feed in enumerate(rss_feeds):
    # Use the client instance instead of the global ollama module
    for j, article in enumerate(feed.articles):
        log.info(f"Article {feed.name} {j}. {len(article)}")
        response = client.embeddings(model="mxbai-embed-large", prompt=article)
        embeddings = response["embedding"]  # Note: it's "embedding", not "embeddings"
        collection.add(
            ids=[str(i)+str(j)],
            embeddings=embeddings,
            documents=[article]
        )

bot = Chat(args.nickname)
while True:
    try:
        query = input("> ")
    except Exception as e:
        break

    if query == "exit":
        break
    bot.on_message("chat", "user", f"airc: {query}")
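Both chat.py and airc.py strip the model's reasoning block with the same regular expression before sending a reply. A quick illustration of what that re.sub call does; the sample response text here is made up:

import re

# deepseek-r1 prefixes its answer with a <think>...</think> reasoning block.
output = {"response": "<think>\nThe user greeted me, so reply briefly.\n</think>\n\nHello! How can I help?"}

# Same expression used in src/chat.py and src/airc.py: anchored at the start,
# DOTALL so the reasoning block may span multiple lines.
content = re.sub(r'^<think>.*?</think>', '', output['response'], flags=re.DOTALL).strip()
print(content)  # -> Hello! How can I help?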
src/chunk.py: new file, 468 lines
@@ -0,0 +1,468 @@
import requests
from typing import List, Dict, Any, Union
import tiktoken
import feedparser
import logging as log
import datetime
from bs4 import BeautifulSoup
import chromadb
import ollama
import re
import numpy as np

def normalize(vec):
    return vec / np.linalg.norm(vec)

OLLAMA_API_URL = "http://ollama:11434" # Default Ollama local endpoint
MODEL_NAME = "deepseek-r1:7b"
EMBED_MODEL = "mxbai-embed-large"
PERSIST_DIRECTORY = "/root/.cache/chroma"

client = ollama.Client(host=OLLAMA_API_URL)

def extract_text_from_html_or_xml(content, is_xml=False):
    # Parse the content
    if is_xml:
        soup = BeautifulSoup(content, 'xml')  # Use 'xml' parser for XML content
    else:
        soup = BeautifulSoup(content, 'html.parser')  # Default to 'html.parser' for HTML content

    # Extract and return just the text
    return soup.get_text()

class Feed():
    def __init__(self, name, url, poll_limit_min = 30, max_articles=5):
        self.name = name
        self.url = url
        self.poll_limit_min = datetime.timedelta(minutes=poll_limit_min)
        self.last_poll = None
        self.articles = []
        self.max_articles = max_articles
        self.update()

    def update(self):
        now = datetime.datetime.now()
        if self.last_poll is None or (now - self.last_poll) >= self.poll_limit_min:
            log.info(f"Updating {self.name}")
            feed = feedparser.parse(self.url)
            self.articles = []
            self.last_poll = now

            if len(feed.entries) == 0:
                return

            for i, entry in enumerate(feed.entries[:self.max_articles]):
                content = {}
                content['source'] = self.name
                content['id'] = f"{self.name}{i}"
                title = entry.get("title")
                if title:
                    content['title'] = title
                link = entry.get("link")
                if link:
                    content['link'] = link
                text = entry.get("summary")
                if text:
                    content['text'] = extract_text_from_html_or_xml(text, False)
                else:
                    continue
                published = entry.get("published")
                if published:
                    content['published'] = published

                self.articles.append(content)
        else:
            log.info(f"Not updating {self.name} -- {self.poll_limit_min - (now - self.last_poll)}s remain to refresh.")
        return self.articles

# News RSS Feeds
rss_feeds = [
    Feed(name="BBC World", url="http://feeds.bbci.co.uk/news/world/rss.xml"),
    Feed(name="Reuters World", url="http://feeds.reuters.com/Reuters/worldNews"),
    Feed(name="Al Jazeera", url="https://www.aljazeera.com/xml/rss/all.xml"),
    Feed(name="CNN World", url="http://rss.cnn.com/rss/edition_world.rss"),
    Feed(name="Time", url="https://time.com/feed/"),
    Feed(name="Euronews", url="https://www.euronews.com/rss"),
    # Feed(name="FeedX", url="https://feedx.net/rss/ap.xml")
]

def get_encoding():
    """Get the tokenizer for counting tokens."""
    try:
        return tiktoken.get_encoding("cl100k_base")  # Default encoding used by many embedding models
    except:
        return tiktoken.encoding_for_model(MODEL_NAME)

def count_tokens(text: str) -> int:
    """Count the number of tokens in a text string."""
    encoding = get_encoding()
    return len(encoding.encode(text))

def chunk_text(text: str, max_tokens: int = 512, overlap: int = 50) -> List[str]:
    """
    Split a text into chunks based on token count with overlap between chunks.

    Args:
        text: The text to split into chunks
        max_tokens: Maximum number of tokens per chunk
        overlap: Number of tokens to overlap between chunks

    Returns:
        List of text chunks
    """
    if not text or max_tokens <= 0:
        return []

    encoding = get_encoding()
    tokens = encoding.encode(text)
    chunks = []

    i = 0
    while i < len(tokens):
        # Get the current chunk of tokens
        chunk_end = min(i + max_tokens, len(tokens))
        chunk_tokens = tokens[i:chunk_end]
        chunks.append(encoding.decode(chunk_tokens))

        # Move to the next position with overlap
        if chunk_end == len(tokens):
            break
        i += max_tokens - overlap

    return chunks

def chunk_document(document: Dict[str, Any],
                   text_key: str = "text",
                   max_tokens: int = 512,
                   overlap: int = 50) -> List[Dict[str, Any]]:
    """
    Chunk a document dictionary into multiple chunks.

    Args:
        document: Document dictionary with metadata and text
        text_key: The key in the document that contains the text to chunk
        max_tokens: Maximum number of tokens per chunk
        overlap: Number of tokens to overlap between chunks

    Returns:
        List of document dictionaries, each with chunked text and preserved metadata
    """
    if text_key not in document:
        raise Exception(f"{text_key} not in document")

    # Extract text and create chunks
    if "title" in document:
        text = f"{document['title']}: {document[text_key]}"
    else:
        text = document[text_key]
    chunks = chunk_text(text, max_tokens, overlap)

    # Create document chunks with preserved metadata
    chunked_docs = []
    for i, chunk in enumerate(chunks):
        # Create a new doc with all original fields
        doc_chunk = document.copy()
        # Replace text with the chunk
        doc_chunk[text_key] = chunk
        # Add chunk metadata
        doc_chunk["chunk_id"] = i
        doc_chunk["chunk_total"] = len(chunks)
        chunked_docs.append(doc_chunk)

    return chunked_docs

def init_chroma_client(persist_directory: str = PERSIST_DIRECTORY):
    """Initialize and return a ChromaDB client."""
    return chromadb.PersistentClient(path=persist_directory)

def create_or_get_collection(client, collection_name: str):
    """Create or get a ChromaDB collection."""
    try:
        return client.get_collection(
            name=collection_name
        )
    except:
        return client.create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )

def process_documents_to_chroma(
    documents: List[Dict[str, Any]],
    collection_name: str = "document_collection",
    text_key: str = "text",
    max_tokens: int = 512,
    overlap: int = 50,
    model: str = EMBED_MODEL,
    persist_directory: str = PERSIST_DIRECTORY
):
    """
    Process documents, chunk them, compute embeddings, and store in ChromaDB.

    Args:
        documents: List of document dictionaries
        collection_name: Name for the ChromaDB collection
        text_key: The key containing text content
        max_tokens: Maximum tokens per chunk
        overlap: Token overlap between chunks
        model: Ollama model for embeddings
        persist_directory: Directory to store ChromaDB data
    """
    # Initialize ChromaDB client and collection
    db = init_chroma_client(persist_directory)
    collection = create_or_get_collection(db, collection_name)

    # Process each document
    for doc in documents:
        # Chunk the document
        doc_chunks = chunk_document(doc, text_key, max_tokens, overlap)

        # Prepare data for ChromaDB
        ids = []
        texts = []
        metadatas = []
        embeddings = []

        for chunk in doc_chunks:
            # Create a unique ID for the chunk
            chunk_id = f"{chunk['id']}_{chunk['chunk_id']}"

            # Extract text
            text = chunk[text_key]

            # Create metadata (excluding text and embedding to avoid duplication)
            metadata = {k: v for k, v in chunk.items() if k != text_key and k != "embedding"}

            response = client.embed(model=model, input=text)
            embedding = response["embeddings"][0]
            ids.append(chunk_id)
            texts.append(text)
            metadatas.append(metadata)
            embeddings.append(embedding)

        # Add chunks to ChromaDB collection
        collection.add(
            ids=ids,
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas
        )

    return collection

def query_chroma(
    query_text: str,
    collection_name: str = "document_collection",
    n_results: int = 5,
    model: str = EMBED_MODEL,
    persist_directory: str = PERSIST_DIRECTORY
):
    """
    Query ChromaDB for similar documents.

    Args:
        query_text: The text to search for
        collection_name: Name of the ChromaDB collection
        n_results: Number of results to return
        model: Ollama model for embedding the query
        persist_directory: Directory where ChromaDB data is stored

    Returns:
        Query results from ChromaDB
    """
    # Initialize ChromaDB client and collection
    db = init_chroma_client(persist_directory)
    collection = create_or_get_collection(db, collection_name)

    query_response = client.embed(model=model, input=query_text)
    query_embeddings = query_response["embeddings"]

    # Query the collection
    results = collection.query(
        query_embeddings=query_embeddings,
        n_results=n_results
    )

    return results

def print_top_match(query_results, index=0, documents=None):
    """
    Print detailed information about the top matching document,
    including the full original document content.

    Args:
        query_results: Results from ChromaDB query
        documents: Original documents dictionary to look up full content (optional)
    """
    if not query_results or not query_results["ids"] or len(query_results["ids"][0]) == 0:
        print("No matching documents found.")
        return

    # Get the top result
    top_id = query_results["ids"][0][index]
    top_document_chunk = query_results["documents"][0][index]
    top_metadata = query_results["metadatas"][0][index]
    top_distance = query_results["distances"][0][index]

    print("="*50)
    print("MATCHING DOCUMENT")
    print("="*50)
    print(f"Chunk ID: {top_id}")
    print(f"Similarity Score: {top_distance:.4f}")  # Convert distance to similarity

    print("\nCHUNK METADATA:")
    for key, value in top_metadata.items():
        print(f"  {key}: {value}")

    print("\nMATCHING CHUNK CONTENT:")
    print(top_document_chunk[:500].strip() + ("..." if len(top_document_chunk) > 500 else ""))

    # Extract the original document ID from the chunk ID
    # Chunk IDs are in format "doc_id_chunk_num"
    original_doc_id = top_id.split('_')[0]

def get_top_match(query_results, index=0, documents=None):
    top_id = query_results["ids"][index][0]
    # Extract the original document ID from the chunk ID
    # Chunk IDs are in format "doc_id_chunk_num"
    original_doc_id = top_id.split('_')[0]

    # Return the full document for further processing if needed
    if documents is not None:
        return next((doc for doc in documents if doc["id"] == original_doc_id), None)

    return None

def show_documents(documents=None):
    if not documents:
        return

    # Print the top matching document
    for i, doc in enumerate(documents):
        print(f"Document {i+1}:")
        print(f"  Title: {doc['title']}")
        print(f"  Text: {doc['text'][:100]}...")
        print()

def show_headlines(documents=None):
    if not documents:
        return

    # Print the top matching document
    for doc in documents:
        print(f"{doc['source']}: {doc['title']}")

# Example usage
if __name__ == "__main__":
    documents = []
    for feed in rss_feeds:
        documents.extend(feed.articles)

    show_documents(documents=documents)

    # Process documents and store in ChromaDB
    collection = process_documents_to_chroma(
        documents=documents,
        collection_name="research_papers",
        max_tokens=256,
        overlap=25,
        model=EMBED_MODEL,
        persist_directory="/root/.cache/chroma"
    )

    last_results = None
    while True:
        try:
            search_query = input("> ").strip()
        except Exception as e:
            break

        if search_query == "docs":
            show_documents(documents)
            continue

        if search_query == "":
            show_headlines(documents)
            continue

        if search_query == "why":
            if last_results:
                print_top_match(last_results, documents=documents)
            else:
                print("No match to give info on")
            continue

        if search_query == "scores":
            if last_results:
                for i, _ in enumerate(last_results):
                    print_top_match(last_results, documents=documents, index=i)
            else:
                print("No match to give info on")
            continue

        if search_query == "full":
            if last_results:
                full = get_top_match(last_results, documents=documents)
                if full:
                    print(f"""Context:
Source: {full["source"]}
Title: {full["title"]}
Link: {full["link"]}
Distance: {last_results.get("distances", [[0]])[0][0]}
Full text:
{full["text"]}""")
            else:
                print("No match to give info on")
            continue

        # Query ChromaDB
        results = query_chroma(
            query_text=search_query,
            collection_name="research_papers",
            n_results=10
        )
        last_results = results

        full = get_top_match(results, documents=documents)

        headlines = ""
        for doc in documents:
            headlines += f"{doc['source']}: {doc['title']}\n"

        system = f"""
News headlines:

{headlines}

"""
        if full:
            system += f"""

Make reference to the following and provide the 'Link':

Source: {full["source"]}
Link: {full["link"]}
Text: {full["text"]}

Do not ask to help the user further.

"""
            print(f"""Context:

Source: {full["source"]}
Title: {full["title"]}
Distance: {last_results.get("distances", [[0]])[0][0]}
Link: {full["link"]}""")

            continue

        output = client.generate(
            model=MODEL_NAME,
            system=system,
            prompt=f"Respond to this prompt: {search_query}",
            stream=False
        )
        # Prune off the <think>...</think>
        content = re.sub(r'^<think>.*?</think>', '', output['response'], flags=re.DOTALL).strip()
        print(f"Response> {content}")
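A small usage sketch for the chunking helpers above, showing how overlap carries tokens across chunk boundaries. It assumes it is run from the src/ directory so the local chunk.py is imported; note that importing the module also builds rss_feeds, which polls the feeds at import time. The sample document here is made up:

# sketch: exercise chunk_text / chunk_document without the ChromaDB or generation parts
from chunk import chunk_text, chunk_document, count_tokens

doc = {
    "id": "Example0",
    "source": "Example",
    "title": "Llamas",
    "text": "Llamas like to eat penguins. " * 200,  # long enough to need several chunks
}

pieces = chunk_text(doc["text"], max_tokens=64, overlap=8)
print(len(pieces), "chunks;", count_tokens(pieces[0]), "tokens in the first")

chunked = chunk_document(doc, max_tokens=64, overlap=8)
print(chunked[0]["chunk_id"], chunked[0]["chunk_total"], chunked[0]["source"])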