#!/usr/bin/env python3
"""
Ollama Context Proxy - Single port with URL-based context routing + auto-sizing

Use URLs like: http://localhost:11434/proxy-context/4096/api/generate
Or auto-sizing: http://localhost:11434/proxy-context/auto/api/generate

(Replace 11434 with whatever --proxy-port the proxy is listening on;
the CLI default below is 11435.)
"""

import argparse
import asyncio
import json
import logging
import os
import re
import sys
import urllib.parse
from typing import Optional, Union

import aiohttp
from aiohttp import ClientSession, web
from aiohttp.web_response import StreamResponse


class OllamaContextProxy:
    def __init__(
        self,
        ollama_host: Optional[str] = None,
        ollama_port: int = 11434,
        proxy_port: int = 11434,
    ):
        # Use OLLAMA_BASE_URL environment variable or construct from host/port
        base_url = os.getenv("OLLAMA_BASE_URL")
        if base_url:
            self.ollama_base_url = base_url.rstrip("/")
        else:
            # Fall back to host/port construction
            if ollama_host is None:
                ollama_host = "ollama"
            self.ollama_base_url = f"http://{ollama_host}:{ollama_port}"

        self.proxy_port = proxy_port
        self.session: Optional[ClientSession] = None
        self.logger = logging.getLogger(__name__)

        # Available context sizes (must be sorted ascending)
        self.available_contexts = [2048, 4096, 8192, 16384, 32768]

        # URL pattern to extract context size or 'auto'
        # (kept for reference; actual routing uses the aiohttp route in create_app)
        self.context_pattern = re.compile(r"^/proxy-context/(auto|\d+)(/.*)?$")

    async def start(self):
        """Initialize the HTTP session"""
        self.session = ClientSession()

    async def stop(self):
        """Cleanup HTTP session"""
        if self.session:
            await self.session.close()

    def create_app(self) -> web.Application:
        """Create the main web application"""
        app = web.Application()
        app["proxy"] = self

        # Add routes - capture everything under /proxy-context/
        app.router.add_route(
            "*",
            r"/proxy-context/{context_spec:(auto|\d+)}{path:.*}",
            self.proxy_handler,
        )
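        # For example (illustrative):
        #   /proxy-context/4096/api/generate -> context_spec="4096", path="/api/generate"
        #   /proxy-context/auto/api/chat     -> context_spec="auto", path="/api/chat"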

        # Optional: add health check and debug endpoints
        app.router.add_get("/", self.health_check)
        app.router.add_get("/health", self.health_check)
        app.router.add_get("/debug/ollama", self.debug_ollama)

        return app

    async def debug_ollama(self, request: web.Request) -> web.Response:
        """Debug endpoint to test connectivity to Ollama"""
        if not self.session:
            return web.Response(
                text="Error: HTTP session not initialized",
                status=500,
                content_type="text/plain",
            )

        test_url = f"{self.ollama_base_url}/api/tags"
        try:
            # Test basic connectivity to Ollama
            self.logger.info(f"Testing Ollama connectivity to: {test_url}")

            async with self.session.get(test_url) as response:
                status = response.status
                content_type = response.headers.get("content-type", "N/A")
                body = await response.text()

                return web.Response(
                    text=f"Ollama Debug Test\n"
                    f"=================\n"
                    f"Target URL: {test_url}\n"
                    f"Status: {status}\n"
                    f"Content-Type: {content_type}\n"
                    f"Body Length: {len(body)}\n"
                    f"Body Preview: {body[:500]}...\n"
                    f"\nProxy Base URL: {self.ollama_base_url}\n"
                    f"Available Contexts: {self.available_contexts}",
                    content_type="text/plain",
                )
        except Exception as e:
            return web.Response(
                text=f"Ollama Debug Test FAILED\n"
                f"========================\n"
                f"Error: {str(e)}\n"
                f"Target URL: {test_url}\n"
                f"Proxy Base URL: {self.ollama_base_url}",
                status=502,
                content_type="text/plain",
            )

    async def health_check(self, request: web.Request) -> web.Response:
        """Health check endpoint"""
        return web.Response(
            text="Ollama Context Proxy is running\n"
            "Usage: /proxy-context/{context_size}/api/{endpoint}\n"
            "       /proxy-context/auto/api/{endpoint}\n"
            "Examples:\n"
            "  Fixed: /proxy-context/4096/api/generate\n"
            "  Auto:  /proxy-context/auto/api/generate\n"
            f"Available contexts: {', '.join(map(str, self.available_contexts))}",
            content_type="text/plain",
        )

    async def proxy_handler(
        self, request: web.Request
    ) -> web.Response | web.StreamResponse:
        """Handle all proxy requests with context size extraction or auto-detection"""

        # Extract context spec and remaining path
        context_spec = request.match_info["context_spec"]
        remaining_path = request.match_info.get("path", "")

        # Remove leading slash if present
        if remaining_path.startswith("/"):
            remaining_path = remaining_path[1:]

        # Get request data first (needed for auto-sizing) - read only once!
        original_data = None
        request_body = None

        if request.content_type == "application/json":
            try:
                original_data = await request.json()
                # Convert back to bytes for forwarding
                request_body = json.dumps(original_data).encode("utf-8")
            except json.JSONDecodeError as e:
                self.logger.error(f"Failed to parse JSON: {e}")
                request_body = await request.read()
                original_data = request_body.decode("utf-8", errors="ignore")
        else:
            request_body = await request.read()
            original_data = request_body

        # Use original_data for analysis, request_body for forwarding
        data_for_analysis = original_data if original_data is not None else {}
        data_for_forwarding = request_body if request_body is not None else b""

        # Determine context size
        if context_spec == "auto":
            context_size = self._auto_determine_context_size(
                data_for_analysis, remaining_path
            )
        else:
            context_size = int(context_spec)

        # Validate context size
        if context_size not in self.available_contexts:
            # Find the next larger available context
            suitable_context = next(
                (ctx for ctx in self.available_contexts if ctx >= context_size),
                self.available_contexts[-1],
            )
            self.logger.warning(
                f"Requested context {context_size} not available, using {suitable_context}"
            )
            context_size = suitable_context

        # Build target URL
        if not remaining_path:
            target_url = self.ollama_base_url
        else:
            target_url = f"{self.ollama_base_url}/{remaining_path}"

        # Enhanced debugging
        self.logger.info("=== REQUEST DEBUG ===")
        self.logger.info(f"Original request path: {request.path}")
        self.logger.info(f"Context spec: {context_spec}")
        self.logger.info(f"Remaining path: '{remaining_path}'")
        self.logger.info(f"Target URL: {target_url}")
        self.logger.info(f"Request method: {request.method}")
        self.logger.info(f"Request headers: {dict(request.headers)}")
        self.logger.info(f"Request query params: {dict(request.query)}")
        self.logger.info(f"Content type: {request.content_type}")
        if isinstance(data_for_analysis, dict):
            self.logger.info(f"Request data keys: {list(data_for_analysis.keys())}")
        else:
            data_len = (
                len(data_for_analysis)
                if hasattr(data_for_analysis, "__len__")
                else "N/A"
            )
            self.logger.info(
                f"Request data type: {type(data_for_analysis)}, length: {data_len}"
            )
        self.logger.info(f"Selected context size: {context_size}")

        # Inject context if needed (modify the JSON data, not the raw bytes)
        modified_data = False
        if self._should_inject_context(remaining_path) and isinstance(
            data_for_analysis, dict
        ):
            if "options" not in data_for_analysis:
                data_for_analysis["options"] = {}
            data_for_analysis["options"]["num_ctx"] = context_size
            self.logger.info(f"Injected num_ctx={context_size} for {remaining_path}")
            # Re-encode the modified JSON
            data_for_forwarding = json.dumps(data_for_analysis).encode("utf-8")
            modified_data = True
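            # After injection, a generate request body looks like (illustrative):
            #   {"model": "llama3", "prompt": "...", "options": {"num_ctx": 4096}}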

        # Prepare headers (exclude hop-by-hop headers)
        headers = {
            key: value
            for key, value in request.headers.items()
            if key.lower() not in ["host", "connection", "upgrade", "content-length"]
        }

        # Update Content-Length if we modified the data
        if modified_data and isinstance(data_for_forwarding, bytes):
            headers["Content-Length"] = str(len(data_for_forwarding))

        # Debug the final data being sent
        self.logger.debug(f"Final data being sent: {data_for_forwarding}")
        self.logger.debug(f"Final headers: {headers}")

        if not self.session:
            raise RuntimeError("HTTP session not initialized")
        try:
            # Make request to Ollama
            async with self.session.request(
                method=request.method,
                url=target_url,
                data=data_for_forwarding,
                headers=headers,
                params=request.query,
            ) as response:
                # Enhanced response debugging
                self.logger.info("=== RESPONSE DEBUG ===")
                self.logger.info(f"Response status: {response.status}")
                self.logger.info(f"Response headers: {dict(response.headers)}")
                self.logger.info(
                    f"Response content-type: {response.headers.get('content-type', 'N/A')}"
                )

                # Log response body for non-streaming 404s
                if response.status == 404:
                    error_body = await response.text()
                    self.logger.error(f"404 Error body: {error_body}")
                    return web.Response(
                        text=f"Ollama 404 Error - URL: {target_url}\nError: {error_body}",
                        status=404,
                        content_type="text/plain",
                    )

                # Handle streaming responses (for generate/chat endpoints).
                # Native Ollama endpoints stream NDJSON; the OpenAI-compatible
                # endpoints stream server-sent events.
                content_type = response.headers.get("content-type", "")
                if content_type.startswith(
                    ("application/x-ndjson", "text/event-stream")
                ):
                    return await self._handle_streaming_response(request, response)
                else:
                    return await self._handle_regular_response(response)

        except aiohttp.ClientError as e:
            self.logger.error(f"Error proxying request to {target_url}: {e}")
            return web.Response(
                text=f"Proxy error: {str(e)}", status=502, content_type="text/plain"
            )

    def _auto_determine_context_size(
        self, data: Union[dict, str, bytes], endpoint: str
    ) -> int:
        """Automatically determine the required context size based on request content"""

        input_tokens = 0
        max_tokens = 0

        if isinstance(data, dict):
            # Extract text content and max_tokens based on endpoint
            if endpoint.startswith("api/generate"):
                # Ollama generate endpoint
                prompt = data.get("prompt", "")
                input_tokens = self._estimate_tokens(prompt)
                max_tokens = data.get("options", {}).get("num_predict", 0)

            elif endpoint.startswith("api/chat"):
                # Ollama chat endpoint
                messages = data.get("messages", [])
                total_text = ""
                for msg in messages:
                    if isinstance(msg, dict) and "content" in msg:
                        total_text += str(msg["content"]) + " "
                input_tokens = self._estimate_tokens(total_text)
                max_tokens = data.get("options", {}).get("num_predict", 0)

            elif endpoint.startswith("v1/chat/completions"):
                # OpenAI-compatible chat endpoint
                messages = data.get("messages", [])
                total_text = ""
                for msg in messages:
                    if isinstance(msg, dict) and "content" in msg:
                        total_text += str(msg["content"]) + " "
                input_tokens = self._estimate_tokens(total_text)
                max_tokens = data.get("max_tokens", 0)

            elif endpoint.startswith("v1/completions"):
                # OpenAI-compatible completions endpoint
                prompt = data.get("prompt", "")
                input_tokens = self._estimate_tokens(prompt)
                max_tokens = data.get("max_tokens", 0)

        elif isinstance(data, (str, bytes)):
            # Fallback for non-JSON data
            text = (
                data if isinstance(data, str) else data.decode("utf-8", errors="ignore")
            )
            input_tokens = self._estimate_tokens(text)

        # Calculate total tokens needed
        system_overhead = 100  # Buffer for system prompts, formatting, etc.
        response_buffer = max(max_tokens, 512)  # Ensure space for response
        safety_margin = 200  # Additional safety buffer

        total_needed = input_tokens + response_buffer + system_overhead + safety_margin
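
        # Worked example (illustrative): a ~10,000-character prompt estimates to
        # ~2,500 input tokens; with num_predict unset, total_needed is
        # 2500 + 512 + 100 + 200 = 3312, so the 4096 context is selected.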

        # Find the smallest context that can accommodate the request
        suitable_context = next(
            (ctx for ctx in self.available_contexts if ctx >= total_needed),
            self.available_contexts[-1],  # Fall back to largest if none are big enough
        )

        self.logger.info(
            f"Auto-sizing analysis: "
            f"input_tokens={input_tokens}, "
            f"max_tokens={max_tokens}, "
            f"total_needed={total_needed}, "
            f"selected_context={suitable_context}"
        )

        # Log warning if we're using the largest context and it might not be enough
        if (
            suitable_context == self.available_contexts[-1]
            and total_needed > suitable_context
        ):
            self.logger.warning(
                f"Request may exceed largest available context! "
                f"Needed: {total_needed}, Available: {suitable_context}"
            )

        return suitable_context

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count from text (rough approximation)"""
        if not text:
            return 0

        # Rough estimation: ~4 characters per token for English
        # This is a conservative estimate - actual tokenization varies by model
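        # For example, the 13-character string "Hello, world!" estimates to
        # 13 // 4 = 3 tokens.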
        char_count = len(str(text))
        estimated_tokens = max(1, char_count // 4)

        self.logger.debug(
            f"Token estimation: {char_count} chars -> ~{estimated_tokens} tokens"
        )
        return estimated_tokens

    def _should_inject_context(self, path: str) -> bool:
        """Determine if we should inject context for this endpoint"""
        # Inject context for endpoints that support the num_ctx parameter
        context_endpoints = [
            "api/generate",
            "api/chat",
            "v1/chat/completions",
            "v1/completions",
        ]
        return any(path.startswith(endpoint) for endpoint in context_endpoints)

    async def _handle_streaming_response(
        self, request: web.Request, response: aiohttp.ClientResponse
    ) -> StreamResponse:
        """Handle streaming responses (NDJSON or server-sent events)"""
        stream_response = StreamResponse(
            status=response.status,
            headers={
                key: value
                for key, value in response.headers.items()
                if key.lower() not in ["content-length", "transfer-encoding"]
            },
        )

        await stream_response.prepare(request)

        # Relay chunks to the client as they arrive from Ollama
        async for chunk in response.content.iter_any():
            await stream_response.write(chunk)

        await stream_response.write_eof()
        return stream_response

    async def _handle_regular_response(
        self, response: aiohttp.ClientResponse
    ) -> web.Response:
        """Handle regular (non-streaming) responses"""
        content = await response.read()

        return web.Response(
            body=content,
            status=response.status,
            headers={
                key: value
                for key, value in response.headers.items()
                if key.lower() not in ["content-length", "transfer-encoding"]
            },
        )


async def main():
    parser = argparse.ArgumentParser(
        description="Ollama Context Proxy - URL-based routing with auto-sizing"
    )

    # Get default host from OLLAMA_BASE_URL if available
    default_host = "ollama"  # Default to "ollama" for Docker environments
    base_url = os.getenv("OLLAMA_BASE_URL")
    if base_url:
        # Extract host from base URL for backward compatibility with CLI args
        parsed = urllib.parse.urlparse(base_url)
        if parsed.hostname:
            default_host = parsed.hostname
    else:
        # If no OLLAMA_BASE_URL, check if we're likely in a Docker environment
        if os.path.exists("/.dockerenv"):
            default_host = "ollama"
        else:
            default_host = "localhost"

    parser.add_argument(
        "--ollama-host",
        default=default_host,
        help=f"Ollama server host (default: {default_host})",
    )
    parser.add_argument(
        "--ollama-port",
        type=int,
        default=11434,
        help="Ollama server port (default: 11434)",
    )
    parser.add_argument(
        "--proxy-port",
        type=int,
        default=11435,
        help="Proxy server port (default: 11435)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log level (default: INFO)",
    )

    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Create proxy instance
    proxy = OllamaContextProxy(args.ollama_host, args.ollama_port, args.proxy_port)
    await proxy.start()

    # Create and start the web application
    app = proxy.create_app()
    runner = web.AppRunner(app)
    await runner.setup()

    site = web.TCPSite(runner, "0.0.0.0", args.proxy_port)
    await site.start()

    logging.info(f"Ollama Context Proxy started on port {args.proxy_port}")
    logging.info(f"Forwarding to Ollama at {proxy.ollama_base_url}")
    logging.info(f"Available context sizes: {proxy.available_contexts}")
    logging.info("Usage examples:")
    logging.info(
        f"  Auto-size: http://localhost:{args.proxy_port}/proxy-context/auto"
    )
    for ctx in proxy.available_contexts:
        logging.info(
            f"  {ctx // 1024}K context: http://localhost:{args.proxy_port}/proxy-context/{ctx}"
        )

    try:
        # Keep running
        while True:
            await asyncio.sleep(1)

    except KeyboardInterrupt:
        logging.info("Shutting down...")
    finally:
        # Cleanup
        await runner.cleanup()
        await proxy.stop()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutdown complete.")
        sys.exit(0)
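
# Illustrative local run (the script filename here is an assumption):
#   OLLAMA_BASE_URL=http://localhost:11434 python3 ollama_context_proxy.py \
#       --proxy-port 11435 --log-level DEBUG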