#!/usr/bin/env python3 """ Ollama Context Proxy - Single port with URL-based context routing + auto-sizing Use URLs like: http://localhost:11434/proxy-context/4096/api/generate Or auto-sizing: http://localhost:11434/proxy-context/auto/api/generate """ import asyncio import json import logging import os import re import urllib.parse from typing import Optional, Union import aiohttp from aiohttp import web, ClientSession from aiohttp.web_response import StreamResponse import argparse import sys class OllamaContextProxy: def __init__( self, ollama_host: Optional[str] = None, ollama_port: int = 11434, proxy_port: int = 11434, ): # Use OLLAMA_BASE_URL environment variable or construct from host/port base_url = os.getenv("OLLAMA_BASE_URL") if base_url: self.ollama_base_url = base_url.rstrip("/") else: # Fall back to host/port construction if ollama_host is None: ollama_host = "ollama" self.ollama_base_url = f"http://{ollama_host}:{ollama_port}" self.proxy_port = proxy_port self.session: Optional[ClientSession] = None self.logger = logging.getLogger(__name__) # Available context sizes (must be sorted ascending) self.available_contexts = [2048, 4096, 8192, 16384, 32768] # URL pattern to extract context size or 'auto' self.context_pattern = re.compile(r"^/proxy-context/(auto|\d+)(/.*)?$") async def start(self): """Initialize the HTTP session""" self.session = ClientSession() async def stop(self): """Cleanup HTTP session""" if self.session: await self.session.close() def create_app(self) -> web.Application: """Create the main web application""" app = web.Application() app["proxy"] = self # Add routes - capture everything under /proxy-context/ app.router.add_route( "*", r"/proxy-context/{context_spec:(auto|\d+)}{path:.*}", self.proxy_handler, ) # Optional: Add a health check endpoint app.router.add_get("/", self.health_check) app.router.add_get("/health", self.health_check) app.router.add_get("/debug/ollama", self.debug_ollama) return app async def debug_ollama(self, request: web.Request) -> web.Response: """Debug endpoint to test connectivity to Ollama""" if not self.session: return web.Response( text="Error: HTTP session not initialized", status=500, content_type="text/plain", ) test_url = f"{self.ollama_base_url}/api/tags" try: # Test basic connectivity to Ollama self.logger.info(f"Testing Ollama connectivity to: {test_url}") async with self.session.get(test_url) as response: status = response.status content_type = response.headers.get("content-type", "N/A") body = await response.text() return web.Response( text=f"Ollama Debug Test\n" f"=================\n" f"Target URL: {test_url}\n" f"Status: {status}\n" f"Content-Type: {content_type}\n" f"Body Length: {len(body)}\n" f"Body Preview: {body[:500]}...\n" f"\nProxy Base URL: {self.ollama_base_url}\n" f"Available Contexts: {self.available_contexts}", content_type="text/plain", ) except Exception as e: return web.Response( text=f"Ollama Debug Test FAILED\n" f"========================\n" f"Error: {str(e)}\n" f"Target URL: {test_url}\n" f"Proxy Base URL: {self.ollama_base_url}", status=502, content_type="text/plain", ) async def health_check(self, request: web.Request) -> web.Response: """Health check endpoint""" return web.Response( text="Ollama Context Proxy is running\n" "Usage: /proxy-context/{context_size}/api/{endpoint}\n" " /proxy-context/auto/api/{endpoint}\n" "Examples:\n" " Fixed: /proxy-context/4096/api/generate\n" " Auto: /proxy-context/auto/api/generate\n" f"Available contexts: {', '.join(map(str, self.available_contexts))}", content_type="text/plain", ) async def proxy_handler( self, request: web.Request ) -> web.Response | web.StreamResponse: """Handle all proxy requests with context size extraction or auto-detection""" # Extract context spec and remaining path context_spec = request.match_info["context_spec"] remaining_path = request.match_info.get("path", "") # Remove leading slash if present if remaining_path.startswith("/"): remaining_path = remaining_path[1:] # Get request data first (needed for auto-sizing) - read only once! original_data = None request_body = None if request.content_type == "application/json": try: original_data = await request.json() # Convert back to bytes for forwarding request_body = json.dumps(original_data).encode("utf-8") except json.JSONDecodeError as e: self.logger.error(f"Failed to parse JSON: {e}") request_body = await request.read() original_data = request_body.decode("utf-8", errors="ignore") else: request_body = await request.read() original_data = request_body # Use original_data for analysis, request_body for forwarding data_for_analysis = original_data if original_data is not None else {} data_for_forwarding = request_body if request_body is not None else b"" # Determine context size if context_spec == "auto": context_size = self._auto_determine_context_size( data_for_analysis, remaining_path ) else: context_size = int(context_spec) # Validate context size if context_size not in self.available_contexts: # Find the next larger available context suitable_context = next( (ctx for ctx in self.available_contexts if ctx >= context_size), self.available_contexts[-1], ) self.logger.warning( f"Requested context {context_size} not available, using {suitable_context}" ) context_size = suitable_context # Build target URL if not remaining_path: target_url = self.ollama_base_url else: target_url = f"{self.ollama_base_url}/{remaining_path}" # Enhanced debugging self.logger.info("=== REQUEST DEBUG ===") self.logger.info(f"Original request path: {request.path}") self.logger.info(f"Context spec: {context_spec}") self.logger.info(f"Remaining path: '{remaining_path}'") self.logger.info(f"Target URL: {target_url}") self.logger.info(f"Request method: {request.method}") self.logger.info(f"Request headers: {dict(request.headers)}") self.logger.info(f"Request query params: {dict(request.query)}") self.logger.info(f"Content type: {request.content_type}") if isinstance(data_for_analysis, dict): self.logger.info(f"Request data keys: {list(data_for_analysis.keys())}") else: data_len = ( len(data_for_analysis) if hasattr(data_for_analysis, "__len__") else "N/A" ) self.logger.info( f"Request data type: {type(data_for_analysis)}, length: {data_len}" ) self.logger.info(f"Selected context size: {context_size}") # Inject context if needed (modify the JSON data, not the raw bytes) modified_data = False if self._should_inject_context(remaining_path) and isinstance( data_for_analysis, dict ): if "options" not in data_for_analysis: data_for_analysis["options"] = {} data_for_analysis["options"]["num_ctx"] = context_size self.logger.info(f"Injected num_ctx={context_size} for {remaining_path}") # Re-encode the modified JSON data_for_forwarding = json.dumps(data_for_analysis).encode("utf-8") modified_data = True # Prepare headers (exclude hop-by-hop headers) headers = { key: value for key, value in request.headers.items() if key.lower() not in ["host", "connection", "upgrade", "content-length"] } # Update Content-Length if we modified the data if modified_data and isinstance(data_for_forwarding, bytes): headers["Content-Length"] = str(len(data_for_forwarding)) # Debug the final data being sent self.logger.debug(f"Final data being sent: {data_for_forwarding}") self.logger.debug(f"Final headers: {headers}") if not self.session: raise RuntimeError("HTTP session not initialized") try: # Make request to Ollama async with self.session.request( method=request.method, url=target_url, data=data_for_forwarding, headers=headers, params=request.query, ) as response: # Enhanced response debugging self.logger.info("=== RESPONSE DEBUG ===") self.logger.info(f"Response status: {response.status}") self.logger.info(f"Response headers: {dict(response.headers)}") self.logger.info( f"Response content-type: {response.headers.get('content-type', 'N/A')}" ) # Log response body for non-streaming 404s if response.status == 404: error_body = await response.text() self.logger.error(f"404 Error body: {error_body}") return web.Response( text=f"Ollama 404 Error - URL: {target_url}\nError: {error_body}", status=404, content_type="text/plain", ) # Handle streaming responses (for generate/chat endpoints) if response.headers.get("content-type", "").startswith( "application/x-ndjson" ): return await self._handle_streaming_response(request, response) else: return await self._handle_regular_response(response) except aiohttp.ClientError as e: self.logger.error(f"Error proxying request to {target_url}: {e}") return web.Response( text=f"Proxy error: {str(e)}", status=502, content_type="text/plain" ) def _auto_determine_context_size( self, data: Union[dict, str, bytes], endpoint: str ) -> int: """Automatically determine the required context size based on request content""" input_tokens = 0 max_tokens = 0 if isinstance(data, dict): # Extract text content and max_tokens based on endpoint if endpoint.startswith("api/generate"): # Ollama generate endpoint prompt = data.get("prompt", "") input_tokens = self._estimate_tokens(prompt) max_tokens = data.get("options", {}).get("num_predict", 0) elif endpoint.startswith("api/chat"): # Ollama chat endpoint messages = data.get("messages", []) total_text = "" for msg in messages: if isinstance(msg, dict) and "content" in msg: total_text += str(msg["content"]) + " " input_tokens = self._estimate_tokens(total_text) max_tokens = data.get("options", {}).get("num_predict", 0) elif endpoint.startswith("v1/chat/completions"): # OpenAI-compatible chat endpoint messages = data.get("messages", []) total_text = "" for msg in messages: if isinstance(msg, dict) and "content" in msg: total_text += str(msg["content"]) + " " input_tokens = self._estimate_tokens(total_text) max_tokens = data.get("max_tokens", 0) elif endpoint.startswith("v1/completions"): # OpenAI-compatible completions endpoint prompt = data.get("prompt", "") input_tokens = self._estimate_tokens(prompt) max_tokens = data.get("max_tokens", 0) elif isinstance(data, (str, bytes)): # Fallback for non-JSON data text = ( data if isinstance(data, str) else data.decode("utf-8", errors="ignore") ) input_tokens = self._estimate_tokens(text) # Calculate total tokens needed system_overhead = 100 # Buffer for system prompts, formatting, etc. response_buffer = max(max_tokens, 512) # Ensure space for response safety_margin = 200 # Additional safety buffer total_needed = input_tokens + response_buffer + system_overhead + safety_margin # Find the smallest context that can accommodate the request suitable_context = next( (ctx for ctx in self.available_contexts if ctx >= total_needed), self.available_contexts[-1], # Fall back to largest if none are big enough ) self.logger.info( f"Auto-sizing analysis: " f"input_tokens={input_tokens}, " f"max_tokens={max_tokens}, " f"total_needed={total_needed}, " f"selected_context={suitable_context}" ) # Log warning if we're using the largest context and it might not be enough if ( suitable_context == self.available_contexts[-1] and total_needed > suitable_context ): self.logger.warning( f"Request may exceed largest available context! " f"Needed: {total_needed}, Available: {suitable_context}" ) return suitable_context def _estimate_tokens(self, text: str) -> int: """Estimate token count from text (rough approximation)""" if not text: return 0 # Rough estimation: ~4 characters per token for English # This is a conservative estimate - actual tokenization varies by model char_count = len(str(text)) estimated_tokens = max(1, char_count // 4) self.logger.debug( f"Token estimation: {char_count} chars -> ~{estimated_tokens} tokens" ) return estimated_tokens def _should_inject_context(self, path: str) -> bool: """Determine if we should inject context for this endpoint""" # Inject context for endpoints that support the num_ctx parameter context_endpoints = [ "api/generate", "api/chat", "v1/chat/completions", "v1/completions", ] return any(path.startswith(endpoint) for endpoint in context_endpoints) async def _handle_streaming_response( self, request: web.Request, response: aiohttp.ClientResponse ) -> StreamResponse: """Handle streaming responses (NDJSON)""" stream_response = StreamResponse( status=response.status, headers={ key: value for key, value in response.headers.items() if key.lower() not in ["content-length", "transfer-encoding"] }, ) await stream_response.prepare(request) async for chunk in response.content.iter_any(): await stream_response.write(chunk) await stream_response.write_eof() return stream_response async def _handle_regular_response( self, response: aiohttp.ClientResponse ) -> web.Response: """Handle regular (non-streaming) responses""" content = await response.read() return web.Response( body=content, status=response.status, headers={ key: value for key, value in response.headers.items() if key.lower() not in ["content-length", "transfer-encoding"] }, ) async def main(): parser = argparse.ArgumentParser( description="Ollama Context Proxy - URL-based routing with auto-sizing" ) # Get default host from OLLAMA_BASE_URL if available default_host = "ollama" # Default to "ollama" for Docker environments base_url = os.getenv("OLLAMA_BASE_URL") if base_url: # Extract host from base URL for backward compatibility with CLI args parsed = urllib.parse.urlparse(base_url) if parsed.hostname: default_host = parsed.hostname else: # If no OLLAMA_BASE_URL, check if we're likely in a Docker environment if os.path.exists("/.dockerenv"): default_host = "ollama" else: default_host = "localhost" parser.add_argument( "--ollama-host", default=default_host, help=f"Ollama server host (default: {default_host})", ) parser.add_argument( "--ollama-port", type=int, default=11434, help="Ollama server port (default: 11434)", ) parser.add_argument( "--proxy-port", type=int, default=11435, help="Proxy server port (default: 11435)", ) parser.add_argument( "--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Log level (default: INFO)", ) args = parser.parse_args() # Setup logging logging.basicConfig( level=getattr(logging, args.log_level), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) # Create proxy instance proxy = OllamaContextProxy(args.ollama_host, args.ollama_port, args.proxy_port) await proxy.start() # Create and start the web application app = proxy.create_app() runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, "0.0.0.0", args.proxy_port) await site.start() logging.info(f"Ollama Context Proxy started on port {args.proxy_port}") logging.info(f"Forwarding to Ollama at {proxy.ollama_base_url}") logging.info(f"Available context sizes: {proxy.available_contexts}") logging.info("Usage examples:") logging.info( f" Auto-size: http://localhost:{args.proxy_port}/proxy-context/auto" ) logging.info( f" 2K context: http://localhost:{args.proxy_port}/proxy-context/2048" ) logging.info( f" 4K context: http://localhost:{args.proxy_port}/proxy-context/4096" ) logging.info( f" 8K context: http://localhost:{args.proxy_port}/proxy-context/8192" ) logging.info( f" 16K context: http://localhost:{args.proxy_port}/proxy-context/16384" ) logging.info( f" 32K context: http://localhost:{args.proxy_port}/proxy-context/32768" ) try: # Keep running while True: await asyncio.sleep(1) except KeyboardInterrupt: logging.info("Shutting down...") finally: # Cleanup await runner.cleanup() await proxy.stop() if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: print("\nShutdown complete.") sys.exit(0)