#!/usr/bin/env python3
"""
Ollama Context Proxy - Single port with URL-based context routing + auto-sizing
Use URLs like: http://localhost:11434/proxy-context/4096/api/generate
Or auto-sizing: http://localhost:11434/proxy-context/auto/api/generate
"""
import asyncio
import json
import logging
import os
import re
import urllib.parse
from typing import Optional, Union
import aiohttp
from aiohttp import web, ClientSession
from aiohttp.web_response import StreamResponse
import argparse
import sys


class OllamaContextProxy:
    def __init__(
        self,
        ollama_host: Optional[str] = None,
        ollama_port: int = 11434,
        proxy_port: int = 11434,
    ):
        # Use OLLAMA_BASE_URL environment variable or construct from host/port
        base_url = os.getenv("OLLAMA_BASE_URL")
        if base_url:
            self.ollama_base_url = base_url.rstrip("/")
        else:
            # Fall back to host/port construction
            if ollama_host is None:
                ollama_host = "localhost"
            self.ollama_base_url = f"http://{ollama_host}:{ollama_port}"
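        # Example (deployment values are placeholders): setting
        # OLLAMA_BASE_URL=http://gpu-host:11434 makes the proxy forward there,
        # taking precedence over --ollama-host/--ollama-port.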
        self.proxy_port = proxy_port
        self.session: Optional[ClientSession] = None
        self.logger = logging.getLogger(__name__)
        # Available context sizes (must be sorted ascending)
        self.available_contexts = [2048, 4096, 8192, 16384, 32768]
        # URL pattern to extract context size or 'auto'
        self.context_pattern = re.compile(r"^/proxy-context/(auto|\d+)(/.*)?$")
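        # Note: actual routing is handled by the aiohttp route registered in
        # create_app(); this compiled pattern mirrors that route and is kept
        # for reference / ad-hoc matching.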

    async def start(self):
        """Initialize the HTTP session"""
        self.session = ClientSession()

    async def stop(self):
        """Cleanup HTTP session"""
        if self.session:
            await self.session.close()

    def create_app(self) -> web.Application:
        """Create the main web application"""
        app = web.Application()
        app["proxy"] = self
        # Add routes - capture everything under /proxy-context/
        app.router.add_route(
            "*",
            r"/proxy-context/{context_spec:(auto|\d+)}{path:.*}",
            self.proxy_handler,
        )
        # Optional: Add a health check endpoint
        app.router.add_get("/", self.health_check)
        app.router.add_get("/health", self.health_check)
        return app

    async def health_check(self, request: web.Request) -> web.Response:
        """Health check endpoint"""
        return web.Response(
            text="Ollama Context Proxy is running\n"
            "Usage: /proxy-context/{context_size}/api/{endpoint}\n"
            "       /proxy-context/auto/api/{endpoint}\n"
            "Examples:\n"
            "  Fixed: /proxy-context/4096/api/generate\n"
            "  Auto:  /proxy-context/auto/api/generate\n"
            f"Available contexts: {', '.join(map(str, self.available_contexts))}",
            content_type="text/plain",
        )

    async def proxy_handler(self, request: web.Request) -> web.Response:
        """Handle all proxy requests with context size extraction or auto-detection"""
        # Extract context spec and remaining path
        context_spec = request.match_info["context_spec"]
        remaining_path = request.match_info.get("path", "")
        # Remove leading slash if present
        if remaining_path.startswith("/"):
            remaining_path = remaining_path[1:]
        # Get request data first (needed for auto-sizing)
        if request.content_type == "application/json":
            try:
                data = await request.json()
            except json.JSONDecodeError:
                data = await request.text()
        else:
            data = await request.read()
        # Determine context size
        if context_spec == "auto":
            context_size = self._auto_determine_context_size(data, remaining_path)
        else:
            context_size = int(context_spec)
        # Validate context size
        if context_size not in self.available_contexts:
            # Find the next larger available context
            suitable_context = next(
                (ctx for ctx in self.available_contexts if ctx >= context_size),
                self.available_contexts[-1],
            )
            self.logger.warning(
                f"Requested context {context_size} not available, using {suitable_context}"
            )
            context_size = suitable_context
        # Build target URL
        if not remaining_path:
            target_url = self.ollama_base_url
        else:
            target_url = f"{self.ollama_base_url}/{remaining_path}"
        self.logger.info(f"Routing to context {context_size} -> {target_url}")
        # Inject context if needed
        if self._should_inject_context(remaining_path) and isinstance(data, dict):
            if "options" not in data:
                data["options"] = {}
            data["options"]["num_ctx"] = context_size
            self.logger.info(f"Injected num_ctx={context_size} for {remaining_path}")
        # Prepare headers (exclude hop-by-hop headers; also drop content-length,
        # since the body may be re-serialized after num_ctx injection)
        headers = {
            key: value
            for key, value in request.headers.items()
            if key.lower()
            not in ["host", "connection", "upgrade", "content-length", "transfer-encoding"]
        }
        if not self.session:
            raise RuntimeError("HTTP session not initialized")
        try:
            # Make request to Ollama
            async with self.session.request(
                method=request.method,
                url=target_url,
                data=json.dumps(data) if isinstance(data, dict) else data,
                headers=headers,
                params=request.query,
            ) as response:
                # Stream NDJSON (native generate/chat) and SSE (OpenAI-compatible
                # endpoints with stream=true); buffer everything else
                content_type = response.headers.get("content-type", "")
                if content_type.startswith(
                    ("application/x-ndjson", "text/event-stream")
                ):
                    return await self._handle_streaming_response(request, response)
                else:
                    return await self._handle_regular_response(response)
        except aiohttp.ClientError as e:
            self.logger.error(f"Error proxying request to {target_url}: {e}")
            return web.Response(
                text=f"Proxy error: {str(e)}", status=502, content_type="text/plain"
            )

    def _auto_determine_context_size(
        self, data: Union[dict, str, bytes], endpoint: str
    ) -> int:
        """Automatically determine the required context size based on request content"""
        input_tokens = 0
        max_tokens = 0
        if isinstance(data, dict):
            # Extract text content and max_tokens based on endpoint
            if endpoint.startswith("api/generate"):
                # Ollama generate endpoint
                prompt = data.get("prompt", "")
                input_tokens = self._estimate_tokens(prompt)
                max_tokens = data.get("options", {}).get("num_predict", 0)
            elif endpoint.startswith("api/chat"):
                # Ollama chat endpoint
                messages = data.get("messages", [])
                total_text = ""
                for msg in messages:
                    if isinstance(msg, dict) and "content" in msg:
                        total_text += str(msg["content"]) + " "
                input_tokens = self._estimate_tokens(total_text)
                max_tokens = data.get("options", {}).get("num_predict", 0)
            elif endpoint.startswith("v1/chat/completions"):
                # OpenAI-compatible chat endpoint
                messages = data.get("messages", [])
                total_text = ""
                for msg in messages:
                    if isinstance(msg, dict) and "content" in msg:
                        total_text += str(msg["content"]) + " "
                input_tokens = self._estimate_tokens(total_text)
                max_tokens = data.get("max_tokens", 0)
            elif endpoint.startswith("v1/completions"):
                # OpenAI-compatible completions endpoint
                prompt = data.get("prompt", "")
                input_tokens = self._estimate_tokens(prompt)
                max_tokens = data.get("max_tokens", 0)
        elif isinstance(data, (str, bytes)):
            # Fallback for non-JSON data
            text = (
                data if isinstance(data, str) else data.decode("utf-8", errors="ignore")
            )
            input_tokens = self._estimate_tokens(text)
        # Calculate total tokens needed
        system_overhead = 100  # Buffer for system prompts, formatting, etc.
        response_buffer = max(max_tokens, 512)  # Ensure space for response
        safety_margin = 200  # Additional safety buffer
        total_needed = input_tokens + response_buffer + system_overhead + safety_margin
        # Find the smallest context that can accommodate the request
        suitable_context = next(
            (ctx for ctx in self.available_contexts if ctx >= total_needed),
            self.available_contexts[-1],  # Fall back to largest if none are big enough
        )
        self.logger.info(
            f"Auto-sizing analysis: "
            f"input_tokens={input_tokens}, "
            f"max_tokens={max_tokens}, "
            f"total_needed={total_needed}, "
            f"selected_context={suitable_context}"
        )
        # Log warning if we're using the largest context and it might not be enough
        if (
            suitable_context == self.available_contexts[-1]
            and total_needed > suitable_context
        ):
            self.logger.warning(
                f"Request may exceed largest available context! "
                f"Needed: {total_needed}, Available: {suitable_context}"
            )
        return suitable_context
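
    # Worked example of the sizing arithmetic above (illustrative numbers only):
    # a 6,000-character prompt estimates to ~1,500 input tokens; with
    # num_predict=1024, total_needed = 1500 + 1024 + 100 + 200 = 2824 tokens,
    # so the smallest sufficient context, 4096, is selected.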

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count from text (rough approximation)"""
        if not text:
            return 0
        # Rough estimation: ~4 characters per token for English
        # This is a conservative estimate - actual tokenization varies by model
        char_count = len(str(text))
        estimated_tokens = max(1, char_count // 4)
        self.logger.debug(
            f"Token estimation: {char_count} chars -> ~{estimated_tokens} tokens"
        )
        return estimated_tokens

    def _should_inject_context(self, path: str) -> bool:
        """Determine if we should inject context for this endpoint"""
        # Inject context for endpoints that support the num_ctx parameter
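        # (paths are matched without a leading slash because proxy_handler strips
        # it from the remainder after the /proxy-context/<size>/ prefix)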
        context_endpoints = [
            "api/generate",
            "api/chat",
            "v1/chat/completions",
            "v1/completions",
        ]
        return any(path.startswith(endpoint) for endpoint in context_endpoints)

    async def _handle_streaming_response(
        self, request: web.Request, response: aiohttp.ClientResponse
    ) -> StreamResponse:
        """Handle streaming responses (NDJSON or SSE), forwarding chunks as they arrive"""
        stream_response = StreamResponse(
            status=response.status,
            headers={
                key: value
                for key, value in response.headers.items()
                if key.lower() not in ["content-length", "transfer-encoding"]
            },
        )
        await stream_response.prepare(request)
        async for chunk in response.content.iter_any():
            await stream_response.write(chunk)
        await stream_response.write_eof()
        return stream_response

    async def _handle_regular_response(
        self, response: aiohttp.ClientResponse
    ) -> web.Response:
        """Handle regular (non-streaming) responses"""
        content = await response.read()
        return web.Response(
            body=content,
            status=response.status,
            headers={
                key: value
                for key, value in response.headers.items()
                if key.lower() not in ["content-length", "transfer-encoding"]
            },
        )


async def main():
    parser = argparse.ArgumentParser(
        description="Ollama Context Proxy - URL-based routing with auto-sizing"
    )
    # Get default host from OLLAMA_BASE_URL if available
    default_host = "localhost"
    base_url = os.getenv("OLLAMA_BASE_URL")
    if base_url:
        # Extract host from base URL for backward compatibility with CLI args
        parsed = urllib.parse.urlparse(base_url)
        if parsed.hostname:
            default_host = parsed.hostname
    parser.add_argument(
        "--ollama-host",
        default=default_host,
        help=f"Ollama server host (default: {default_host})",
    )
    parser.add_argument(
        "--ollama-port",
        type=int,
        default=11434,
        help="Ollama server port (default: 11434)",
    )
    parser.add_argument(
        "--proxy-port",
        type=int,
        default=11435,
        help="Proxy server port (default: 11435)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log level (default: INFO)",
    )
    args = parser.parse_args()
    # Setup logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # Create proxy instance
    proxy = OllamaContextProxy(args.ollama_host, args.ollama_port, args.proxy_port)
    await proxy.start()
    # Create and start the web application
    app = proxy.create_app()
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "0.0.0.0", args.proxy_port)
    await site.start()
    logging.info(f"Ollama Context Proxy started on port {args.proxy_port}")
    logging.info(f"Forwarding to Ollama at {proxy.ollama_base_url}")
    logging.info(f"Available context sizes: {proxy.available_contexts}")
    logging.info("Usage examples:")
    logging.info(
        f"  Auto-size:   http://localhost:{args.proxy_port}/proxy-context/auto"
    )
    logging.info(
        f"  2K context:  http://localhost:{args.proxy_port}/proxy-context/2048"
    )
    logging.info(
        f"  4K context:  http://localhost:{args.proxy_port}/proxy-context/4096"
    )
    logging.info(
        f"  8K context:  http://localhost:{args.proxy_port}/proxy-context/8192"
    )
    logging.info(
        f"  16K context: http://localhost:{args.proxy_port}/proxy-context/16384"
    )
    logging.info(
        f"  32K context: http://localhost:{args.proxy_port}/proxy-context/32768"
    )
    try:
        # Keep running
        while True:
            await asyncio.sleep(1)
    except KeyboardInterrupt:
        logging.info("Shutting down...")
    finally:
        # Cleanup
        await runner.cleanup()
        await proxy.stop()
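
# Example invocation (a sketch; the environment variable is optional and the
# values shown are placeholders):
#
#   OLLAMA_BASE_URL=http://localhost:11434 python ollama-context-proxy.py \
#       --proxy-port 11435 --log-level DEBUG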


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutdown complete.")
        sys.exit(0)