515 lines
		
	
	
		
			19 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			515 lines
		
	
	
		
			19 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| AI Provider Integration for Advanced Bot Management.
 | |
| 
 | |
| This module provides support for multiple AI providers including OpenAI, Anthropic,
 | |
| and local models for enhanced bot capabilities.
 | |
| """
 | |
| 
 | |
| import os
 | |
| import time
 | |
| import asyncio
 | |
| from abc import ABC, abstractmethod
 | |
| from typing import Dict, List, Optional, Any, AsyncIterator
 | |
| from enum import Enum
 | |
| from dataclasses import dataclass
 | |
| from pydantic import BaseModel, Field
 | |
| 
 | |
| from logger import logger
 | |
| 
 | |
| 
 | |
class AIProviderType(str, Enum):
    """Identifier for each AI backend family this module knows about.

    Members are ``str`` subclasses, so they compare equal to their raw
    string values (e.g. ``AIProviderType.OPENAI == "openai"``).
    """

    # Hosted OpenAI chat-completions API.
    OPENAI = "openai"
    # Hosted Anthropic Messages API.
    ANTHROPIC = "anthropic"
    # Self-hosted model server.
    LOCAL = "local"
    # Reserved for user-supplied providers.
    CUSTOM = "custom"
 | |
| 
 | |
| 
 | |
class MessageRole(str, Enum):
    """Author of a conversation turn, using chat-API role strings."""

    # Instructions that frame the whole conversation.
    SYSTEM = "system"
    # Input from the human side.
    USER = "user"
    # Output produced by the model.
    ASSISTANT = "assistant"
 | |
| 
 | |
| 
 | |
@dataclass
class ConversationMessage:
    """A single turn in a conversation.

    Attributes:
        role: Who produced the message.
        content: The message text.
        timestamp: ``time.time()`` value recorded for the message. ``None``
            (omitted or passed explicitly) is replaced with the current time
            in ``__post_init__``.
        metadata: Free-form per-message data. ``None`` is replaced with a
            fresh empty dict per instance, so instances never share it.
    """

    role: MessageRole
    content: str
    # Optional because None is accepted as a "fill in for me" sentinel;
    # after __post_init__ both fields are always populated.
    timestamp: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        if self.timestamp is None:
            self.timestamp = time.time()
        if self.metadata is None:
            self.metadata = {}
 | |
| 
 | |
| 
 | |
class AIProviderConfig(BaseModel):
    """Settings shared by every AI provider.

    Unknown keyword arguments are accepted and kept (``extra = "allow"``)
    so provider-specific options can travel alongside the common ones.
    """

    provider_type: AIProviderType

    # Credentials / endpoint. None lets each provider fall back to its own
    # default (an environment variable or a built-in URL).
    api_key: Optional[str] = Field(default=None)
    base_url: Optional[str] = Field(default=None)

    # Core generation settings.
    model: str = Field(default="gpt-3.5-turbo")
    max_tokens: int = Field(default=1000)
    temperature: float = Field(default=0.7)

    # Request timeout in seconds, plus retry policy: the providers wait
    # retry_delay * 2**attempt between attempts.
    timeout: float = Field(default=30.0)
    retry_attempts: int = Field(default=3)
    retry_delay: float = Field(default=1.0)

    # Advanced sampling settings; None means "not set".
    top_p: Optional[float] = Field(default=None)
    frequency_penalty: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    stop_sequences: List[str] = Field(default_factory=list)

    class Config:
        extra = "allow"  # Allow additional provider-specific configs
 | |
| 
 | |
| 
 | |
class ConversationContext(BaseModel):
    """Conversation history and per-session state for a single bot.

    ``add_message`` appends to ``messages`` and enforces ``max_history``;
    ``get_context_messages`` renders the transcript as role/content dicts
    for provider chat APIs.
    """

    session_id: str
    bot_name: str
    messages: List[ConversationMessage] = Field(default_factory=list)
    created_at: float = Field(default_factory=time.time)
    last_updated: float = Field(default_factory=time.time)

    # Context management
    # Maximum number of non-system messages kept by add_message. System
    # messages are never trimmed and do not count against this limit.
    max_history: int = 50
    context_window: int = 4000  # Token limit for context (not enforced by this class)
    personality_prompt: Optional[str] = None

    # Metadata
    user_preferences: Dict[str, Any] = Field(default_factory=dict)
    conversation_state: Dict[str, Any] = Field(default_factory=dict)

    def add_message(
        self,
        role: MessageRole,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Append a message, refresh ``last_updated`` and trim old history.

        Args:
            role: Author of the message.
            content: Message text.
            metadata: Optional per-message data; ``None`` becomes ``{}``.
        """
        message = ConversationMessage(role=role, content=content, metadata=metadata or {})
        self.messages.append(message)
        self.last_updated = time.time()
        self._trim_history()

    def _trim_history(self) -> None:
        """Drop the oldest non-system messages beyond ``max_history``.

        System messages are always kept. Surviving messages keep their
        original relative order, so system messages are not moved to the
        front. A ``max_history`` of 0 or less keeps no non-system messages.
        A plain ``[-max_history:]`` slice would keep *everything* for 0,
        because ``lst[-0:]`` is the whole list.
        """
        limit = max(self.max_history, 0)
        # Cheap exit: the non-system count can never exceed the total length.
        if len(self.messages) <= limit:
            return
        non_system = [m for m in self.messages if m.role != MessageRole.SYSTEM]
        overflow = len(non_system) - limit
        if overflow <= 0:
            return
        # Track by identity, not equality: distinct messages may compare equal.
        dropped = {id(m) for m in non_system[:overflow]}
        self.messages = [m for m in self.messages if id(m) not in dropped]

    def get_context_messages(self) -> List[Dict[str, str]]:
        """Return the transcript as ``{"role": ..., "content": ...}`` dicts, oldest first."""
        return [{"role": msg.role.value, "content": msg.content} for msg in self.messages]
 | |
| 
 | |
| 
 | |
class AIProvider(ABC):
    """Common interface implemented by every AI backend.

    A concrete provider is built from an ``AIProviderConfig`` and must
    supply one-shot generation, streaming generation and a health probe.
    """

    def __init__(self, config: AIProviderConfig):
        # Cache the type separately so callers need not reach into config.
        self.provider_type = config.provider_type
        self.config = config

    @abstractmethod
    async def generate_response(
        self,
        context: ConversationContext,
        message: str,
    ) -> str:
        """Return a complete reply to *message* given the conversation *context*."""
        ...

    @abstractmethod
    async def stream_response(
        self,
        context: ConversationContext,
        message: str,
    ) -> AsyncIterator[str]:
        """Yield the reply to *message* incrementally, as text fragments."""
        ...

    @abstractmethod
    async def health_check(self) -> bool:
        """Return True when the backend is reachable and usable, else False."""
        ...
 | |
| 
 | |
| 
 | |
class OpenAIProvider(AIProvider):
    """OpenAI chat-completions provider.

    The ``openai`` SDK is imported lazily on first use, so the package is
    only required when this provider is actually exercised.
    """

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Return the cached ``openai.AsyncOpenAI`` client, creating it on first call.

        The API key falls back to the ``OPENAI_API_KEY`` environment variable.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                import openai
            except ImportError as exc:
                raise ImportError("OpenAI package not installed. Install with: pip install openai") from exc
            self._client = openai.AsyncOpenAI(
                api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
                base_url=self.config.base_url,
                timeout=self.config.timeout,
            )
        return self._client

    def _request_kwargs(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """Build chat-completion arguments shared by the blocking and streaming calls.

        Optional sampling settings that are ``None`` (and an empty stop list)
        are omitted, so the API applies its own defaults.
        """
        kwargs: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
        }
        optional = {
            "top_p": self.config.top_p,
            "frequency_penalty": self.config.frequency_penalty,
            "presence_penalty": self.config.presence_penalty,
            "stop": self.config.stop_sequences or None,
        }
        kwargs.update({key: value for key, value in optional.items() if value is not None})
        return kwargs

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a reply with retries and exponential backoff.

        The user message is recorded in *context* before the request, and
        the reply is recorded after a successful call.

        Returns:
            The reply text; ``""`` if the API returned no text content.

        Raises:
            Exception: whatever the last failed attempt raised.
        """
        client = self._get_client()

        # Add user message to context
        context.add_message(MessageRole.USER, message)
        request = self._request_kwargs(context.get_context_messages())

        # Always make at least one attempt; with retry_attempts <= 0 the loop
        # would otherwise never run and the method would return None.
        attempts = max(1, self.config.retry_attempts)
        for attempt in range(attempts):
            try:
                response = await client.chat.completions.create(**request)
            except Exception as e:
                logger.warning(f"OpenAI API attempt {attempt + 1} failed: {e}")
                if attempt >= attempts - 1:
                    raise
                await asyncio.sleep(self.config.retry_delay * (2 ** attempt))
                continue

            # content may be None (e.g. a tool-call-only reply); keep the str contract.
            response_text = response.choices[0].message.content or ""
            context.add_message(MessageRole.ASSISTANT, response_text)
            return response_text

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a reply, yielding non-empty text fragments as they arrive.

        Uses the same request settings as ``generate_response``. The full
        reply is recorded in *context* once the stream completes.
        """
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        request = self._request_kwargs(context.get_context_messages())

        try:
            stream = await client.chat.completions.create(stream=True, **request)

            parts: List[str] = []
            async for chunk in stream:
                # Some chunks (e.g. usage-only ones) carry no choices at all.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content:
                    parts.append(content)
                    yield content

            # Add complete response to context
            context.add_message(MessageRole.ASSISTANT, "".join(parts))

        except Exception as e:
            logger.error(f"OpenAI streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Return True if listing models succeeds; log and return False otherwise."""
        try:
            client = self._get_client()
            await client.models.list()
            return True
        except Exception as e:
            logger.warning(f"OpenAI health check failed: {e}")
            return False
 | |
| 
 | |
| 
 | |
class AnthropicProvider(AIProvider):
    """Anthropic Claude provider (Messages API).

    The ``anthropic`` SDK is imported lazily on first use.
    """

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Return the cached ``anthropic.AsyncAnthropic`` client, creating it on first call.

        The API key falls back to the ``ANTHROPIC_API_KEY`` environment variable.

        Raises:
            ImportError: if the ``anthropic`` package is not installed.
        """
        if self._client is None:
            try:
                import anthropic
            except ImportError as exc:
                raise ImportError("Anthropic package not installed. Install with: pip install anthropic") from exc
            self._client = anthropic.AsyncAnthropic(
                api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
                timeout=self.config.timeout,
            )
        return self._client

    def _request_kwargs(self, context: ConversationContext) -> Dict[str, Any]:
        """Build Messages API arguments from the conversation.

        The Messages API takes the system prompt as a top-level ``system``
        argument, not as a message. System-role messages are therefore
        removed from ``messages`` and joined, in order and blank-line
        separated, into one prompt. Every system message is kept, matching
        the OpenAI provider, which sends them all.
        """
        system_parts: List[str] = []
        messages: List[Dict[str, str]] = []
        for msg in context.messages:
            if msg.role == MessageRole.SYSTEM:
                system_parts.append(msg.content)
            else:
                messages.append({"role": msg.role.value, "content": msg.content})

        kwargs: Dict[str, Any] = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
        }
        system_prompt = "\n\n".join(part for part in system_parts if part)
        if system_prompt:
            kwargs["system"] = system_prompt
        if self.config.stop_sequences:
            kwargs["stop_sequences"] = list(self.config.stop_sequences)
        return kwargs

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a reply with retries and exponential backoff.

        The user message is recorded in *context* before the request, and
        the reply is recorded after a successful call.

        Returns:
            The concatenated text of all text content blocks (``""`` if none).

        Raises:
            Exception: whatever the last failed attempt raised.
        """
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        request = self._request_kwargs(context)

        # Always make at least one attempt; with retry_attempts <= 0 the loop
        # would otherwise never run and the method would return None.
        attempts = max(1, self.config.retry_attempts)
        for attempt in range(attempts):
            try:
                response = await client.messages.create(**request)
            except Exception as e:
                logger.warning(f"Anthropic API attempt {attempt + 1} failed: {e}")
                if attempt >= attempts - 1:
                    raise
                await asyncio.sleep(self.config.retry_delay * (2 ** attempt))
                continue

            # The first content block is not guaranteed to be text (e.g. thinking
            # or tool-use blocks), so collect every text block instead of [0].
            response_text = "".join(
                block.text for block in response.content if getattr(block, "type", None) == "text"
            )
            context.add_message(MessageRole.ASSISTANT, response_text)
            return response_text

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a reply, yielding text fragments as they arrive.

        Uses the same request settings (including the extracted system
        prompt) as ``generate_response``. The full reply is recorded in
        *context* once the stream completes.
        """
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        request = self._request_kwargs(context)

        try:
            async with client.messages.stream(**request) as stream:
                parts: List[str] = []
                async for text in stream.text_stream:
                    parts.append(text)
                    yield text

                context.add_message(MessageRole.ASSISTANT, "".join(parts))

        except Exception as e:
            logger.error(f"Anthropic streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Send a 1-token request; return True on success, log and return False otherwise."""
        try:
            client = self._get_client()
            # Simple test to verify API connectivity
            await client.messages.create(
                model=self.config.model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1,
            )
            return True
        except Exception as e:
            logger.warning(f"Anthropic health check failed: {e}")
            return False
 | |
| 
 | |
| 
 | |
class LocalProvider(AIProvider):
    """Provider for a locally hosted model server (e.g., Ollama, llama.cpp).

    Chat requests go to ``{base_url}/api/chat`` and health checks to
    ``{base_url}/api/tags``. ``aiohttp`` is imported lazily inside each method.
    """

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        # Strip trailing slashes so f"{base_url}/api/..." never yields "//api".
        self.base_url = (config.base_url or "http://localhost:11434").rstrip("/")

    def _payload(self, context: ConversationContext, stream: bool) -> Dict[str, Any]:
        """Build the /api/chat request body; generation options apply in both modes."""
        return {
            "model": self.config.model,
            "messages": context.get_context_messages(),
            "stream": stream,
            "options": {
                "temperature": self.config.temperature,
                "num_predict": self.config.max_tokens,
            },
        }

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a reply with a single non-streaming request.

        Raises:
            Exception: on a non-200 status, a timeout, or a transport error.
        """
        import aiohttp

        context.add_message(MessageRole.USER, message)
        payload = self._payload(context, stream=False)

        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(
                    f"{self.base_url}/api/chat",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout),
                ) as resp:
                    if resp.status != 200:
                        raise Exception(f"Local API returned status {resp.status}")
                    result = await resp.json()

                response_text = result["message"]["content"]
                context.add_message(MessageRole.ASSISTANT, response_text)
                return response_text

            except Exception as e:
                logger.error(f"Local provider failed: {e}")
                raise

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a reply from newline-delimited JSON, yielding non-empty text fragments.

        The full reply is recorded in *context* once the stream completes.

        Raises:
            Exception: on a non-200 status, an ``{"error": ...}`` line in the
                stream, a timeout, or a transport error.
        """
        import json

        import aiohttp

        context.add_message(MessageRole.USER, message)
        payload = self._payload(context, stream=True)
        # A total timeout would cap the entire streamed generation. Bound
        # connection setup and per-read inactivity instead.
        timeout = aiohttp.ClientTimeout(
            total=None,
            connect=self.config.timeout,
            sock_read=self.config.timeout,
        )

        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(
                    f"{self.base_url}/api/chat",
                    json=payload,
                    timeout=timeout,
                ) as resp:
                    if resp.status != 200:
                        raise Exception(f"Local API returned status {resp.status}")

                    parts: List[str] = []
                    async for raw_line in resp.content:
                        line = raw_line.strip()
                        if not line:
                            continue
                        try:
                            # json.loads accepts bytes directly. ValueError covers
                            # both malformed JSON and undecodable bytes.
                            data = json.loads(line)
                        except ValueError:
                            continue
                        if not isinstance(data, dict):
                            continue
                        # e.g. Ollama reports mid-stream failures as {"error": ...}.
                        if "error" in data:
                            raise Exception(f"Local API error: {data['error']}")
                        msg = data.get("message")
                        content = msg.get("content") if isinstance(msg, dict) else None
                        if content:
                            parts.append(content)
                            yield content

                    context.add_message(MessageRole.ASSISTANT, "".join(parts))

            except Exception as e:
                logger.error(f"Local provider streaming failed: {e}")
                raise

    async def health_check(self) -> bool:
        """Return True if ``GET {base_url}/api/tags`` answers 200 within 5 seconds."""
        try:
            import aiohttp

            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.base_url}/api/tags",
                    timeout=aiohttp.ClientTimeout(total=5),
                ) as resp:
                    return resp.status == 200
        except Exception as e:
            logger.warning(f"Local provider health check failed: {e}")
            return False
 | |
| 
 | |
| 
 | |
class AIProviderManager:
    """Registry of named AI provider instances plus environment-driven default configs."""

    def __init__(self):
        self.providers: Dict[str, AIProvider] = {}
        self.default_configs = self._load_default_configs()

    def _load_default_configs(self) -> Dict[AIProviderType, AIProviderConfig]:
        """Build default configs for the built-in providers from environment variables.

        Raises:
            ValueError: if a ``*_MAX_TOKENS`` / ``*_TEMPERATURE`` variable is
                set to a non-numeric value.
        """
        return {
            AIProviderType.OPENAI: AIProviderConfig(
                provider_type=AIProviderType.OPENAI,
                model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
                max_tokens=int(os.getenv("OPENAI_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("OPENAI_TEMPERATURE", "0.7"))
            ),
            AIProviderType.ANTHROPIC: AIProviderConfig(
                provider_type=AIProviderType.ANTHROPIC,
                model=os.getenv("ANTHROPIC_MODEL", "claude-3-sonnet-20240229"),
                max_tokens=int(os.getenv("ANTHROPIC_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("ANTHROPIC_TEMPERATURE", "0.7"))
            ),
            AIProviderType.LOCAL: AIProviderConfig(
                provider_type=AIProviderType.LOCAL,
                base_url=os.getenv("LOCAL_MODEL_URL", "http://localhost:11434"),
                model=os.getenv("LOCAL_MODEL_NAME", "llama2"),
                max_tokens=int(os.getenv("LOCAL_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("LOCAL_TEMPERATURE", "0.7"))
            )
        }

    def create_provider(self, provider_type: AIProviderType, config: Optional[AIProviderConfig] = None) -> AIProvider:
        """Create (but do not register) a provider instance.

        Args:
            provider_type: Which built-in provider class to instantiate.
            config: Explicit configuration. When omitted, the entry from
                ``default_configs`` is used. NOTE: that same config object is
                then shared by every provider created this way, so mutating
                one provider's config affects the others.

        Raises:
            ValueError: if no default config exists for *provider_type*, or
                the type has no built-in implementation (e.g. ``CUSTOM``).
        """
        if config is None:
            config = self.default_configs.get(provider_type)
            if config is None:
                raise ValueError(f"No default config for provider type: {provider_type}")

        if provider_type == AIProviderType.OPENAI:
            return OpenAIProvider(config)
        elif provider_type == AIProviderType.ANTHROPIC:
            return AnthropicProvider(config)
        elif provider_type == AIProviderType.LOCAL:
            return LocalProvider(config)
        else:
            raise ValueError(f"Unsupported provider type: {provider_type}")

    def register_provider(self, name: str, provider: AIProvider):
        """Register *provider* under *name*, replacing any existing entry."""
        self.providers[name] = provider
        logger.info(f"Registered AI provider: {name} ({provider.provider_type})")

    def get_provider(self, name: str) -> Optional[AIProvider]:
        """Return the provider registered as *name*, or None."""
        return self.providers.get(name)

    def list_providers(self) -> List[str]:
        """Return the registered provider names in registration order."""
        return list(self.providers.keys())

    async def health_check_all(self) -> Dict[str, bool]:
        """Run every registered provider's health check concurrently.

        The registry is snapshotted first. Providers registered while checks
        are in flight are therefore not checked this round, and they cannot
        trigger "dictionary changed size during iteration". A check that
        raises is logged and reported as ``False``.

        Returns:
            Mapping of provider name to health status, in registration order.
        """
        snapshot = list(self.providers.items())

        async def _check(name: str, provider: AIProvider) -> bool:
            # Per-provider guard so one failing check cannot abort the others.
            try:
                return await provider.health_check()
            except Exception as e:
                logger.error(f"Health check failed for provider {name}: {e}")
                return False

        outcomes = await asyncio.gather(*(_check(name, provider) for name, provider in snapshot))
        return {name: ok for (name, _), ok in zip(snapshot, outcomes)}
 | |
| 
 | |
| 
 | |
# Global provider manager instance, created when this module is imported.
# Construction reads the provider environment variables (see
# AIProviderManager._load_default_configs); a non-numeric *_MAX_TOKENS or
# *_TEMPERATURE value therefore raises ValueError at import time.
ai_provider_manager = AIProviderManager()
 |