#!/usr/bin/env python3
"""
Ollama Context Proxy - Single port with URL-based context routing + auto-sizing

Use URLs like:  http://localhost:11434/proxy-context/4096/api/generate
Or auto-sizing: http://localhost:11434/proxy-context/auto/api/generate
"""

import argparse
import asyncio
import json
import logging
import os
import re
import sys
import urllib.parse
from typing import Optional, Union

import aiohttp
from aiohttp import ClientSession, web
from aiohttp.web_response import StreamResponse


class OllamaContextProxy:
    def __init__(
        self,
        ollama_host: Optional[str] = None,
        ollama_port: int = 11434,
        proxy_port: int = 11434,
    ):
        # Use OLLAMA_BASE_URL environment variable or construct from host/port
        base_url = os.getenv("OLLAMA_BASE_URL")
        if base_url:
            self.ollama_base_url = base_url.rstrip("/")
        else:
            # Fall back to host/port construction
            if ollama_host is None:
                ollama_host = "localhost"
            self.ollama_base_url = f"http://{ollama_host}:{ollama_port}"

        self.proxy_port = proxy_port
        self.session: Optional[ClientSession] = None
        self.logger = logging.getLogger(__name__)

        # Available context sizes (must be sorted ascending)
        self.available_contexts = [2048, 4096, 8192, 16384, 32768]

        # URL pattern to extract context size or 'auto'
        self.context_pattern = re.compile(r"^/proxy-context/(auto|\d+)(/.*)?$")

    async def start(self):
        """Initialize the HTTP session"""
        self.session = ClientSession()

    async def stop(self):
        """Cleanup HTTP session"""
        if self.session:
            await self.session.close()

    def create_app(self) -> web.Application:
        """Create the main web application"""
        app = web.Application()
        app["proxy"] = self

        # Add routes - capture everything under /proxy-context/
        app.router.add_route(
            "*",
            r"/proxy-context/{context_spec:(auto|\d+)}{path:.*}",
            self.proxy_handler,
        )

        # Optional: Add a health check endpoint
        app.router.add_get("/", self.health_check)
        app.router.add_get("/health", self.health_check)

        return app

    async def health_check(self, request: web.Request) -> web.Response:
        """Health check endpoint"""
        return web.Response(
            text="Ollama Context Proxy is running\n"
            "Usage: /proxy-context/{context_size}/api/{endpoint}\n"
            "       /proxy-context/auto/api/{endpoint}\n"
            "Examples:\n"
            "  Fixed: /proxy-context/4096/api/generate\n"
            "  Auto:  /proxy-context/auto/api/generate\n"
            f"Available contexts: {', '.join(map(str, self.available_contexts))}",
            content_type="text/plain",
        )

    async def proxy_handler(self, request: web.Request) -> web.Response:
        """Handle all proxy requests with context size extraction or auto-detection"""
        # Extract context spec and remaining path
        context_spec = request.match_info["context_spec"]
        remaining_path = request.match_info.get("path", "")

        # Remove leading slash if present
        if remaining_path.startswith("/"):
            remaining_path = remaining_path[1:]

        # Get request data first (needed for auto-sizing)
        if request.content_type == "application/json":
            try:
                data = await request.json()
            except json.JSONDecodeError:
                data = await request.text()
        else:
            data = await request.read()

        # Determine context size
        if context_spec == "auto":
            context_size = self._auto_determine_context_size(data, remaining_path)
        else:
            context_size = int(context_spec)

        # Validate context size
        if context_size not in self.available_contexts:
            # Find the next larger available context
            suitable_context = next(
                (ctx for ctx in self.available_contexts if ctx >= context_size),
                self.available_contexts[-1],
            )
            self.logger.warning(
                f"Requested context {context_size} not available, using {suitable_context}"
            )
            context_size = suitable_context

        # Build target URL
        if not remaining_path:
            target_url = self.ollama_base_url
        else:
            target_url = f"{self.ollama_base_url}/{remaining_path}"

        self.logger.info(f"Routing to context {context_size} -> {target_url}")

        # Inject context if needed
        if self._should_inject_context(remaining_path) and isinstance(data, dict):
            if "options" not in data:
                data["options"] = {}
            data["options"]["num_ctx"] = context_size
            self.logger.info(f"Injected num_ctx={context_size} for {remaining_path}")

        # Prepare headers (exclude hop-by-hop headers; also drop the original
        # Content-Length/Transfer-Encoding, since the body may change after
        # num_ctx injection and aiohttp sets the correct length for the new body)
        headers = {
            key: value
            for key, value in request.headers.items()
            if key.lower()
            not in ["host", "connection", "upgrade", "content-length", "transfer-encoding"]
        }

        if not self.session:
            raise RuntimeError("HTTP session not initialized")

        try:
            # Make request to Ollama
            async with self.session.request(
                method=request.method,
                url=target_url,
                data=json.dumps(data) if isinstance(data, dict) else data,
                headers=headers,
                params=request.query,
            ) as response:
                # Handle streaming responses (for generate/chat endpoints)
                if response.headers.get("content-type", "").startswith(
                    "application/x-ndjson"
                ):
                    return await self._handle_streaming_response(request, response)
                else:
                    return await self._handle_regular_response(response)

        except aiohttp.ClientError as e:
            self.logger.error(f"Error proxying request to {target_url}: {e}")
            return web.Response(
                text=f"Proxy error: {str(e)}", status=502, content_type="text/plain"
            )

    def _auto_determine_context_size(
        self, data: Union[dict, str, bytes], endpoint: str
    ) -> int:
        """Automatically determine the required context size based on request content"""
        input_tokens = 0
        max_tokens = 0

        if isinstance(data, dict):
            # Extract text content and max_tokens based on endpoint
            if endpoint.startswith("api/generate"):
                # Ollama generate endpoint
                prompt = data.get("prompt", "")
                input_tokens = self._estimate_tokens(prompt)
                max_tokens = data.get("options", {}).get("num_predict", 0)
            elif endpoint.startswith("api/chat"):
                # Ollama chat endpoint
                messages = data.get("messages", [])
                total_text = ""
                for msg in messages:
                    if isinstance(msg, dict) and "content" in msg:
                        total_text += str(msg["content"]) + " "
                input_tokens = self._estimate_tokens(total_text)
                max_tokens = data.get("options", {}).get("num_predict", 0)
            elif endpoint.startswith("v1/chat/completions"):
                # OpenAI-compatible chat endpoint
                messages = data.get("messages", [])
                total_text = ""
                for msg in messages:
                    if isinstance(msg, dict) and "content" in msg:
                        total_text += str(msg["content"]) + " "
                input_tokens = self._estimate_tokens(total_text)
                max_tokens = data.get("max_tokens", 0)
            elif endpoint.startswith("v1/completions"):
                # OpenAI-compatible completions endpoint
                prompt = data.get("prompt", "")
                input_tokens = self._estimate_tokens(prompt)
                max_tokens = data.get("max_tokens", 0)
        elif isinstance(data, (str, bytes)):
            # Fallback for non-JSON data
            text = (
                data
                if isinstance(data, str)
                else data.decode("utf-8", errors="ignore")
            )
            input_tokens = self._estimate_tokens(text)

        # Calculate total tokens needed
        system_overhead = 100  # Buffer for system prompts, formatting, etc.
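        # Worked example (illustrative numbers, not taken from a real request):
        # a 10,000-character prompt estimates to ~2,500 tokens; with the default
        # 512-token response buffer, 100-token overhead, and 200-token safety
        # margin below, total_needed = 2500 + 512 + 100 + 200 = 3312, so the
        # smallest sufficient context, 4096, is selected.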
        response_buffer = max(max_tokens, 512)  # Ensure space for response
        safety_margin = 200  # Additional safety buffer

        total_needed = input_tokens + response_buffer + system_overhead + safety_margin

        # Find the smallest context that can accommodate the request
        suitable_context = next(
            (ctx for ctx in self.available_contexts if ctx >= total_needed),
            self.available_contexts[-1],  # Fall back to largest if none are big enough
        )

        self.logger.info(
            f"Auto-sizing analysis: "
            f"input_tokens={input_tokens}, "
            f"max_tokens={max_tokens}, "
            f"total_needed={total_needed}, "
            f"selected_context={suitable_context}"
        )

        # Log warning if we're using the largest context and it might not be enough
        if (
            suitable_context == self.available_contexts[-1]
            and total_needed > suitable_context
        ):
            self.logger.warning(
                f"Request may exceed largest available context! "
                f"Needed: {total_needed}, Available: {suitable_context}"
            )

        return suitable_context

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count from text (rough approximation)"""
        if not text:
            return 0

        # Rough estimation: ~4 characters per token for English
        # This is a conservative estimate - actual tokenization varies by model
        char_count = len(str(text))
        estimated_tokens = max(1, char_count // 4)

        self.logger.debug(
            f"Token estimation: {char_count} chars -> ~{estimated_tokens} tokens"
        )

        return estimated_tokens

    def _should_inject_context(self, path: str) -> bool:
        """Determine if we should inject context for this endpoint"""
        # Inject context for endpoints that support the num_ctx parameter
        context_endpoints = [
            "api/generate",
            "api/chat",
            "v1/chat/completions",
            "v1/completions",
        ]
        return any(path.startswith(endpoint) for endpoint in context_endpoints)

    async def _handle_streaming_response(
        self, request: web.Request, response: aiohttp.ClientResponse
    ) -> StreamResponse:
        """Handle streaming responses (NDJSON)"""
        stream_response = StreamResponse(
            status=response.status,
            headers={
                key: value
                for key, value in response.headers.items()
                if key.lower() not in ["content-length", "transfer-encoding"]
            },
        )
        await stream_response.prepare(request)

        async for chunk in response.content.iter_any():
            await stream_response.write(chunk)

        await stream_response.write_eof()
        return stream_response

    async def _handle_regular_response(
        self, response: aiohttp.ClientResponse
    ) -> web.Response:
        """Handle regular (non-streaming) responses"""
        content = await response.read()
        return web.Response(
            body=content,
            status=response.status,
            headers={
                key: value
                for key, value in response.headers.items()
                if key.lower() not in ["content-length", "transfer-encoding"]
            },
        )


async def main():
    parser = argparse.ArgumentParser(
        description="Ollama Context Proxy - URL-based routing with auto-sizing"
    )

    # Get default host from OLLAMA_BASE_URL if available
    default_host = "localhost"
    base_url = os.getenv("OLLAMA_BASE_URL")
    if base_url:
        # Extract host from base URL for backward compatibility with CLI args
        parsed = urllib.parse.urlparse(base_url)
        if parsed.hostname:
            default_host = parsed.hostname

    parser.add_argument(
        "--ollama-host",
        default=default_host,
        help=f"Ollama server host (default: {default_host})",
    )
    parser.add_argument(
        "--ollama-port",
        type=int,
        default=11434,
        help="Ollama server port (default: 11434)",
    )
    parser.add_argument(
        "--proxy-port",
        type=int,
        default=11435,
        help="Proxy server port (default: 11435)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log level (default: INFO)",
    )

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Create proxy instance
    proxy = OllamaContextProxy(args.ollama_host, args.ollama_port, args.proxy_port)
    await proxy.start()

    # Create and start the web application
    app = proxy.create_app()
    runner = web.AppRunner(app)
    await runner.setup()

    site = web.TCPSite(runner, "0.0.0.0", args.proxy_port)
    await site.start()

    logging.info(f"Ollama Context Proxy started on port {args.proxy_port}")
    logging.info(f"Forwarding to Ollama at {proxy.ollama_base_url}")
    logging.info(f"Available context sizes: {proxy.available_contexts}")
    logging.info("Usage examples:")
    logging.info(
        f"  Auto-size:   http://localhost:{args.proxy_port}/proxy-context/auto"
    )
    logging.info(
        f"  2K context:  http://localhost:{args.proxy_port}/proxy-context/2048"
    )
    logging.info(
        f"  4K context:  http://localhost:{args.proxy_port}/proxy-context/4096"
    )
    logging.info(
        f"  8K context:  http://localhost:{args.proxy_port}/proxy-context/8192"
    )
    logging.info(
        f"  16K context: http://localhost:{args.proxy_port}/proxy-context/16384"
    )
    logging.info(
        f"  32K context: http://localhost:{args.proxy_port}/proxy-context/32768"
    )

    try:
        # Keep running
        while True:
            await asyncio.sleep(1)
    except KeyboardInterrupt:
        logging.info("Shutting down...")
    finally:
        # Cleanup
        await runner.cleanup()
        await proxy.stop()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutdown complete.")
        sys.exit(0)
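
# ---------------------------------------------------------------------------
# Example client call - an illustrative sketch, not part of the proxy itself.
# It assumes the proxy is running with the default --proxy-port 11435 and that
# the backing Ollama server has a model named "llama3" pulled; adjust both to
# your setup.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:11435/proxy-context/auto/api/generate",
#       json={"model": "llama3", "prompt": "Summarize: ...", "stream": False},
#       timeout=300,
#   )
#   print(resp.json()["response"])
#
# The proxy estimates the prompt size, injects a suitable num_ctx into the
# request options, and forwards the request otherwise unchanged to /api/generate.
# ---------------------------------------------------------------------------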