"""
AI Provider Integration for Advanced Bot Management.

This module provides support for multiple AI providers including OpenAI,
Anthropic, and local models for enhanced bot capabilities.
"""

import asyncio
import json
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional

from pydantic import BaseModel, Field

from shared.logger import logger


class AIProviderType(str, Enum):
    """Supported AI provider types."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    LOCAL = "local"
    CUSTOM = "custom"


class MessageRole(str, Enum):
    """Message roles in conversation."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


@dataclass
class ConversationMessage:
    """Individual message in a conversation."""

    role: MessageRole
    content: str
    # Creation time from time.time(); filled in by __post_init__ when omitted.
    timestamp: Optional[float] = None
    # Free-form per-message data; a fresh empty dict is created when omitted.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = time.time()
        if self.metadata is None:
            self.metadata = {}


class AIProviderConfig(BaseModel):
    """Configuration for AI providers."""

    provider_type: AIProviderType
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    model: str = "gpt-3.5-turbo"
    max_tokens: int = 1000
    temperature: float = 0.7
    timeout: float = 30.0
    retry_attempts: int = 3
    retry_delay: float = 1.0

    # Advanced settings
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    stop_sequences: List[str] = Field(default_factory=list)

    class Config:
        extra = "allow"  # Allow additional provider-specific configs


class ConversationContext(BaseModel):
    """Conversation context and memory management."""

    session_id: str
    bot_name: str
    messages: List[ConversationMessage] = Field(default_factory=list)
    created_at: float = Field(default_factory=time.time)
    last_updated: float = Field(default_factory=time.time)

    # Context management
    max_history: int = 50
    # Token limit for context. NOTE(review): not read by this module; confirm
    # whether a caller enforces it.
    context_window: int = 4000
    # NOTE(review): not read by this module; presumably injected by callers as
    # a system message — confirm.
    personality_prompt: Optional[str] = None

    # Metadata
    user_preferences: Dict[str, Any] = Field(default_factory=dict)
    conversation_state: Dict[str, Any] = Field(default_factory=dict)

    def add_message(
        self,
        role: MessageRole,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append a message to the conversation and trim the history.

        When more than ``max_history`` messages are stored, every system
        message is kept (moved to the front, in original order) together with
        the most recent ``max_history`` non-system messages.

        Args:
            role: Author of the message.
            content: Message text.
            metadata: Optional per-message data; an empty dict is used if None.
        """
        message = ConversationMessage(role=role, content=content, metadata=metadata or {})
        self.messages.append(message)
        self.last_updated = time.time()

        if len(self.messages) > self.max_history:
            system_messages = [m for m in self.messages if m.role == MessageRole.SYSTEM]
            other_messages = [m for m in self.messages if m.role != MessageRole.SYSTEM]
            # Explicit start index instead of ``[-max_history:]``: with
            # max_history == 0 that slice is ``[-0:]`` and would keep everything.
            keep_from = max(len(other_messages) - self.max_history, 0)
            self.messages = system_messages + other_messages[keep_from:]

    def get_context_messages(self) -> List[Dict[str, str]]:
        """Return the history as ``{"role": ..., "content": ...}`` dicts for chat APIs."""
        return [{"role": msg.role.value, "content": msg.content} for msg in self.messages]


class AIProvider(ABC):
    """Abstract base class for AI providers."""

    def __init__(self, config: AIProviderConfig):
        self.config = config
        self.provider_type = config.provider_type

    @abstractmethod
    async def generate_response(
        self,
        context: ConversationContext,
        message: str
    ) -> str:
        """Generate a response to a message."""
        pass

    @abstractmethod
    async def stream_response(
        self,
        context: ConversationContext,
        message: str
    ) -> AsyncIterator[str]:
        """Stream a response to a message."""
        pass

    @abstractmethod
    async def health_check(self) -> bool:
        """Check if the provider is healthy and available."""
        pass


async def _call_with_retries(
    provider_label: str,
    config: AIProviderConfig,
    call: Callable[[], Awaitable[Any]],
) -> Any:
    """Await ``call()`` with exponential backoff between failed attempts.

    At least one attempt is always made, even if ``config.retry_attempts`` is
    zero or negative. Attempt ``n`` (0-based) that fails is followed by a sleep
    of ``retry_delay * 2**n`` seconds. The exception from the last attempt is
    re-raised.

    Args:
        provider_label: Name used in the warning log line (e.g. "OpenAI").
        config: Supplies ``retry_attempts`` and ``retry_delay``.
        call: Zero-argument factory returning a fresh awaitable per attempt.
    """
    attempts = max(1, config.retry_attempts)
    for attempt in range(attempts):
        try:
            return await call()
        except Exception as e:
            logger.warning(f"{provider_label} API attempt {attempt + 1} failed: {e}")
            if attempt == attempts - 1:
                raise
            await asyncio.sleep(config.retry_delay * (2 ** attempt))
    # Unreachable: the loop runs at least once and either returns or raises.
    raise RuntimeError("retry loop exited without a result")


class OpenAIProvider(AIProvider):
    """OpenAI provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Lazily create and cache the ``openai.AsyncOpenAI`` client.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                import openai
            except ImportError as e:
                raise ImportError("OpenAI package not installed. Install with: pip install openai") from e
            self._client = openai.AsyncOpenAI(
                api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
                base_url=self.config.base_url,
                timeout=self.config.timeout
            )
        return self._client

    def _build_request_kwargs(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """Build chat-completion arguments shared by generate and stream.

        Optional sampling settings are included only when configured, so the
        server applies its own defaults instead of receiving explicit nulls.
        """
        kwargs: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
        }
        optional = {
            "top_p": self.config.top_p,
            "frequency_penalty": self.config.frequency_penalty,
            "presence_penalty": self.config.presence_penalty,
            "stop": self.config.stop_sequences or None,
        }
        kwargs.update({key: value for key, value in optional.items() if value is not None})
        return kwargs

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a response using the OpenAI API.

        The user message is appended to ``context`` before the request and the
        assistant reply after it. Failed requests are retried per the config.

        Returns:
            The reply text (``""`` if the API returned no text content).
        """
        client = self._get_client()

        # Add user message to context
        context.add_message(MessageRole.USER, message)
        request_kwargs = self._build_request_kwargs(context.get_context_messages())

        response = await _call_with_retries(
            "OpenAI",
            self.config,
            lambda: client.chat.completions.create(**request_kwargs),
        )
        # content is None for e.g. refusals/tool calls; keep the str contract.
        response_text = response.choices[0].message.content or ""
        context.add_message(MessageRole.ASSISTANT, response_text)
        return response_text

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a response using the OpenAI API, yielding text deltas.

        The concatenated reply is appended to ``context`` once the stream ends.
        """
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        request_kwargs = self._build_request_kwargs(context.get_context_messages())

        try:
            stream = await client.chat.completions.create(**request_kwargs, stream=True)

            parts: List[str] = []
            async for chunk in stream:
                # Some chunks (e.g. usage-only chunks) carry an empty choices list.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    parts.append(content)
                    yield content

            # Add complete response to context
            context.add_message(MessageRole.ASSISTANT, "".join(parts))

        except Exception as e:
            logger.error(f"OpenAI streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Return True if listing models succeeds; log and return False otherwise."""
        try:
            client = self._get_client()
            await client.models.list()
            return True
        except Exception as e:
            logger.warning(f"OpenAI health check failed: {e}")
            return False


class AnthropicProvider(AIProvider):
    """Anthropic Claude provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Lazily create and cache the ``anthropic.AsyncAnthropic`` client.

        Raises:
            ImportError: if the ``anthropic`` package is not installed.
        """
        if self._client is None:
            try:
                import anthropic
            except ImportError as e:
                raise ImportError("Anthropic package not installed. Install with: pip install anthropic") from e
            self._client = anthropic.AsyncAnthropic(
                api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
                timeout=self.config.timeout
            )
        return self._client

    def _build_request_kwargs(self, context: ConversationContext) -> Dict[str, Any]:
        """Convert the conversation into Messages API arguments.

        Anthropic takes the system prompt as a separate ``system`` argument
        rather than as a message, so system messages are pulled out of the
        history and joined, in order, into one prompt.
        """
        system_parts: List[str] = []
        messages: List[Dict[str, str]] = []
        for msg in context.messages:
            if msg.role == MessageRole.SYSTEM:
                system_parts.append(msg.content)
            else:
                messages.append({"role": msg.role.value, "content": msg.content})

        kwargs: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
        }
        system_prompt = "\n\n".join(part for part in system_parts if part)
        if system_prompt:
            kwargs["system"] = system_prompt
        if self.config.stop_sequences:
            kwargs["stop_sequences"] = list(self.config.stop_sequences)
        return kwargs

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Concatenate the text of every ``text`` content block in a response.

        Returns ``""`` when there are none (e.g. an empty content list).
        """
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a response using the Anthropic Messages API.

        The user message is appended to ``context`` before the request and the
        assistant reply after it. Failed requests are retried per the config.
        """
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        request_kwargs = self._build_request_kwargs(context)

        response = await _call_with_retries(
            "Anthropic",
            self.config,
            lambda: client.messages.create(**request_kwargs),
        )
        response_text = self._extract_text(response)
        context.add_message(MessageRole.ASSISTANT, response_text)
        return response_text

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a response using the Anthropic API, yielding text deltas.

        The concatenated reply is appended to ``context`` once the stream ends.
        """
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        request_kwargs = self._build_request_kwargs(context)

        try:
            parts: List[str] = []
            async with client.messages.stream(**request_kwargs) as stream:
                async for text in stream.text_stream:
                    parts.append(text)
                    yield text

            context.add_message(MessageRole.ASSISTANT, "".join(parts))

        except Exception as e:
            logger.error(f"Anthropic streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Send a 1-token test request; return True on success, False (logged) on error."""
        try:
            client = self._get_client()
            # Simple test to verify API connectivity
            await client.messages.create(
                model=self.config.model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1
            )
            return True
        except Exception as e:
            logger.warning(f"Anthropic health check failed: {e}")
            return False


class LocalProvider(AIProvider):
    """Local model provider (e.g., Ollama, llama.cpp).

    Chat requests are POSTed to ``{base_url}/chat/completions``.
    NOTE(review): Ollama serves its OpenAI-compatible API under ``/v1``; with
    the bare default base URL chat requests may 404 — confirm deployed values.
    """

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        # Strip trailing slashes so endpoint joins never produce "//".
        self.base_url = (config.base_url or "http://localhost:11434").rstrip("/")

    def _chat_url(self) -> str:
        """URL of the chat completions endpoint."""
        return f"{self.base_url}/chat/completions"

    def _build_payload(self, context: ConversationContext, stream: bool) -> Dict[str, Any]:
        """Build the chat request body.

        Sampling settings are sent both as OpenAI-style top-level fields
        (``temperature``/``max_tokens``), which OpenAI-compatible servers read,
        and in Ollama's native ``options`` object.
        """
        return {
            "model": self.config.model,
            "messages": context.get_context_messages(),
            "stream": stream,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "options": {
                "temperature": self.config.temperature,
                "num_predict": self.config.max_tokens
            }
        }

    @staticmethod
    def _extract_response_text(result: Any) -> str:
        """Pull the reply text out of a non-streaming response body.

        Accepts the OpenAI-compatible shape (``choices[0].message.content`` or
        ``choices[0].text``) and, for backward compatibility, the
        ``{"message": {"content": ...}}`` shape.

        Raises:
            Exception: if none of those shapes match.
        """
        if isinstance(result, dict):
            choices = result.get("choices")
            if choices:
                choice = choices[0]
                if isinstance(choice, dict):
                    choice_message = choice.get("message")
                    if isinstance(choice_message, dict) and "content" in choice_message:
                        return choice_message["content"] or ""
                    if "text" in choice:
                        # Some APIs use "text" instead of "message.content"
                        return choice["text"] or ""
                raise Exception(f"Unexpected response format: {result}")
            # Fallback to original format for backward compatibility
            legacy_message = result.get("message")
            if isinstance(legacy_message, dict) and "content" in legacy_message:
                return legacy_message["content"] or ""
        raise Exception(f"Unexpected response format: {result}")

    @staticmethod
    def _decode_stream_line(raw_line: bytes) -> str:
        """Decode one streamed line, stripping whitespace and any SSE ``data:`` prefix.

        Raises:
            UnicodeDecodeError: if the line is not valid UTF-8.
        """
        line = raw_line.decode("utf-8").strip()
        if line.startswith("data:"):
            line = line[len("data:"):].strip()
        return line

    @staticmethod
    def _extract_stream_delta(data: Any) -> str:
        """Return ``choices[0].delta.content`` from a parsed chunk, or ``""``."""
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices")
        if not choices:
            return ""
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a response using the local model API.

        Raises:
            Exception: on a non-200 status or an unrecognised response body
                (also logged).
        """
        context.add_message(MessageRole.USER, message)

        import aiohttp

        payload = self._build_payload(context, stream=False)
        # The payload contains the whole conversation; keep it out of INFO logs.
        logger.debug(f"LocalProvider generate payload: {payload}")

        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(
                    self._chat_url(),
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout)
                ) as resp:
                    if resp.status != 200:
                        raise Exception(f"Local API returned status {resp.status}")
                    # content_type=None: some local servers send JSON with a
                    # non-JSON Content-Type header.
                    result = await resp.json(content_type=None)
                response_text = self._extract_response_text(result)
            except Exception as e:
                logger.error(f"Local provider failed: {e}")
                raise

        context.add_message(MessageRole.ASSISTANT, response_text)
        return response_text

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a response using the local model API, yielding text deltas.

        Lines that cannot be decoded or parsed are logged and skipped; a
        ``[DONE]`` marker ends the stream. The concatenated reply is appended
        to ``context`` once the stream ends.
        """
        context.add_message(MessageRole.USER, message)

        import aiohttp

        payload = self._build_payload(context, stream=True)
        logger.debug(f"LocalProvider stream payload: {self.base_url} {payload}")

        # A total timeout would abort long generations mid-stream; bound only
        # connecting and each idle gap between reads instead.
        stream_timeout = aiohttp.ClientTimeout(
            total=None,
            sock_connect=self.config.timeout,
            sock_read=self.config.timeout,
        )

        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(self._chat_url(), json=payload, timeout=stream_timeout) as resp:
                    if resp.status != 200:
                        raise Exception(f"Local API returned status {resp.status}")

                    parts: List[str] = []
                    async for raw_line in resp.content:
                        try:
                            line = self._decode_stream_line(raw_line)
                            if not line:
                                continue
                            # Skip end-of-stream marker
                            if line == "[DONE]":
                                break
                            content = self._extract_stream_delta(json.loads(line))
                        except (UnicodeDecodeError, json.JSONDecodeError) as e:
                            logger.debug(f"Skipping invalid line: {raw_line} (error: {e})")
                            continue
                        except Exception as e:
                            logger.warning(f"Unexpected error processing line: {e}")
                            continue

                        # Yield outside the per-line try so errors raised into
                        # this generator are not swallowed as "bad lines".
                        if content:
                            parts.append(content)
                            yield content

                    context.add_message(MessageRole.ASSISTANT, "".join(parts))
            except Exception as e:
                logger.error(f"Local provider streaming failed {self.base_url}: {e}")
                raise

    async def health_check(self) -> bool:
        """Return True if an OpenAI-style ``/models`` or Ollama ``/api/tags`` probe returns 200.

        Both are tried because ``base_url`` may point at either API root. Any
        error is logged and reported as False.
        """
        try:
            import aiohttp

            probe_timeout = aiohttp.ClientTimeout(total=5)
            async with aiohttp.ClientSession() as session:
                for path in ("/models", "/api/tags"):
                    async with session.get(f"{self.base_url}{path}", timeout=probe_timeout) as resp:
                        if resp.status == 200:
                            return True
            return False
        except Exception as e:
            logger.warning(f"Local provider health check failed: {e}")
            return False


class AIProviderManager:
    """Manager for AI providers and configurations."""

    def __init__(self):
        self.providers: Dict[str, AIProvider] = {}
        self.default_configs = self._load_default_configs()

    def _load_default_configs(self) -> Dict[AIProviderType, AIProviderConfig]:
        """Build per-provider default configs from environment variables.

        Raises:
            ValueError: if a numeric environment variable cannot be parsed.
        """
        return {
            AIProviderType.OPENAI: AIProviderConfig(
                provider_type=AIProviderType.OPENAI,
                model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
                max_tokens=int(os.getenv("OPENAI_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("OPENAI_TEMPERATURE", "0.7"))
            ),
            AIProviderType.ANTHROPIC: AIProviderConfig(
                provider_type=AIProviderType.ANTHROPIC,
                model=os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022"),
                max_tokens=int(os.getenv("ANTHROPIC_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("ANTHROPIC_TEMPERATURE", "0.7"))
            ),
            AIProviderType.LOCAL: AIProviderConfig(
                provider_type=AIProviderType.LOCAL,
                # LOCAL_* take precedence; the OPENAI_* fallbacks preserve the
                # previous lookup when LOCAL_* is unset.
                base_url=os.getenv("LOCAL_BASE_URL") or os.getenv("OPENAI_BASE_URL", "http://localhost:11434"),
                model=os.getenv("LOCAL_MODEL") or os.getenv("OPENAI_MODEL", "llama2"),
                max_tokens=int(os.getenv("LOCAL_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("LOCAL_TEMPERATURE", "0.7"))
            )
        }

    def create_provider(self, provider_type: AIProviderType, config: Optional[AIProviderConfig] = None) -> AIProvider:
        """Create an AI provider instance.

        Args:
            provider_type: Which implementation to instantiate.
            config: Explicit config; the environment-derived default is used
                when None.

        Raises:
            ValueError: if there is no default config or the type is unsupported.
        """
        if config is None:
            config = self.default_configs.get(provider_type)
            if config is None:
                raise ValueError(f"No default config for provider type: {provider_type}")

        provider_classes = {
            AIProviderType.OPENAI: OpenAIProvider,
            AIProviderType.ANTHROPIC: AnthropicProvider,
            AIProviderType.LOCAL: LocalProvider,
        }
        provider_class = provider_classes.get(provider_type)
        if provider_class is None:
            raise ValueError(f"Unsupported provider type: {provider_type}")
        return provider_class(config)

    def register_provider(self, name: str, provider: AIProvider):
        """Register a provider instance under ``name`` (replacing any existing one)."""
        self.providers[name] = provider
        logger.info(f"Registered AI provider: {name} ({provider.provider_type})")

    def get_provider(self, name: str) -> Optional[AIProvider]:
        """Get a registered provider, or None if ``name`` is unknown."""
        return self.providers.get(name)

    def list_providers(self) -> List[str]:
        """List all registered provider names."""
        return list(self.providers.keys())

    async def health_check_all(self) -> Dict[str, bool]:
        """Health check all registered providers concurrently.

        A provider whose check raises an ``Exception`` is logged and reported
        as False. Non-``Exception`` errors such as cancellation propagate.
        """
        # Snapshot names so registrations during the awaits don't break iteration.
        names = list(self.providers)
        outcomes = await asyncio.gather(
            *(self.providers[name].health_check() for name in names),
            return_exceptions=True,
        )

        results: Dict[str, bool] = {}
        for name, outcome in zip(names, outcomes):
            if isinstance(outcome, Exception):
                logger.error(f"Health check failed for provider {name}: {outcome}")
                results[name] = False
            elif isinstance(outcome, BaseException):
                raise outcome
            else:
                results[name] = outcome
        return results


# Global provider manager instance
ai_provider_manager = AIProviderManager()