"""
|
|
AI Provider Integration for Advanced Bot Management.
|
|
|
|
This module provides support for multiple AI providers including OpenAI, Anthropic,
|
|
and local models for enhanced bot capabilities.
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import asyncio
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Optional, Any, AsyncIterator
|
|
from enum import Enum
|
|
from dataclasses import dataclass
|
|
from pydantic import BaseModel, Field
|
|
|
|
from shared.logger import logger
|
|
|
|
|
|


class AIProviderType(str, Enum):
    """Supported AI provider types."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    LOCAL = "local"
    CUSTOM = "custom"


class MessageRole(str, Enum):
    """Message roles in a conversation."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


@dataclass
class ConversationMessage:
    """Individual message in a conversation."""

    role: MessageRole
    content: str
    timestamp: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Fill in defaults that cannot be computed at class-definition time.
        if self.timestamp is None:
            self.timestamp = time.time()
        if self.metadata is None:
            self.metadata = {}


class AIProviderConfig(BaseModel):
    """Configuration for AI providers."""

    provider_type: AIProviderType
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    model: str = "gpt-3.5-turbo"
    max_tokens: int = 1000
    temperature: float = 0.7
    timeout: float = 30.0
    retry_attempts: int = 3
    retry_delay: float = 1.0

    # Advanced sampling settings
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    stop_sequences: List[str] = Field(default_factory=list)

    class Config:
        extra = "allow"  # Allow additional provider-specific settings
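

# A minimal configuration sketch (illustrative; the model name and overrides
# below are arbitrary examples, not defaults of this module):
#
#     config = AIProviderConfig(
#         provider_type=AIProviderType.OPENAI,
#         model="gpt-4o-mini",
#         max_tokens=500,
#         temperature=0.2,
#     )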


class ConversationContext(BaseModel):
    """Conversation context and memory management."""

    session_id: str
    bot_name: str
    messages: List[ConversationMessage] = Field(default_factory=list)
    created_at: float = Field(default_factory=time.time)
    last_updated: float = Field(default_factory=time.time)

    # Context management
    max_history: int = 50
    context_window: int = 4000  # Token budget for context
    personality_prompt: Optional[str] = None

    # Metadata
    user_preferences: Dict[str, Any] = Field(default_factory=dict)
    conversation_state: Dict[str, Any] = Field(default_factory=dict)

    def add_message(self, role: MessageRole, content: str, metadata: Optional[Dict[str, Any]] = None):
        """Add a message to the conversation."""
        message = ConversationMessage(role=role, content=content, metadata=metadata or {})
        self.messages.append(message)
        self.last_updated = time.time()

        # Trim history if needed: keep all system messages plus the most
        # recent non-system messages.
        if len(self.messages) > self.max_history:
            system_messages = [m for m in self.messages if m.role == MessageRole.SYSTEM]
            recent_messages = [m for m in self.messages if m.role != MessageRole.SYSTEM][-self.max_history:]
            self.messages = system_messages + recent_messages

    def get_context_messages(self) -> List[Dict[str, str]]:
        """Get messages formatted for AI provider APIs."""
        return [{"role": msg.role.value, "content": msg.content} for msg in self.messages]
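

# Usage sketch (illustrative; the session ID and bot name below are arbitrary
# example values):
#
#     ctx = ConversationContext(session_id="demo-1", bot_name="helper")
#     ctx.add_message(MessageRole.SYSTEM, "You are a helpful assistant.")
#     ctx.add_message(MessageRole.USER, "Hello!")
#     ctx.get_context_messages()
#     # -> [{"role": "system", "content": "You are a helpful assistant."},
#     #     {"role": "user", "content": "Hello!"}]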


class AIProvider(ABC):
    """Abstract base class for AI providers."""

    def __init__(self, config: AIProviderConfig):
        self.config = config
        self.provider_type = config.provider_type

    @abstractmethod
    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a response to a message."""

    @abstractmethod
    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a response to a message."""

    @abstractmethod
    async def health_check(self) -> bool:
        """Check if the provider is healthy and available."""
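

# Sketch of a minimal custom provider (illustrative; "EchoProvider" is a
# hypothetical example, not part of this module). A subclass must implement
# all three abstract methods:
#
#     class EchoProvider(AIProvider):
#         async def generate_response(self, context, message):
#             context.add_message(MessageRole.USER, message)
#             reply = f"echo: {message}"
#             context.add_message(MessageRole.ASSISTANT, reply)
#             return reply
#
#         async def stream_response(self, context, message):
#             yield await self.generate_response(context, message)
#
#         async def health_check(self):
#             return True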


class OpenAIProvider(AIProvider):
    """OpenAI provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Lazily initialize the OpenAI client."""
        if self._client is None:
            try:
                import openai
            except ImportError:
                raise ImportError("OpenAI package not installed. Install with: pip install openai")
            self._client = openai.AsyncOpenAI(
                api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
                base_url=self.config.base_url,
                timeout=self.config.timeout,
            )
        return self._client

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a response using the OpenAI API."""
        client = self._get_client()

        # Add the user message to the context before building the API payload
        context.add_message(MessageRole.USER, message)
        messages = context.get_context_messages()

        for attempt in range(self.config.retry_attempts):
            try:
                response = await client.chat.completions.create(
                    model=self.config.model,
                    messages=messages,
                    max_tokens=self.config.max_tokens,
                    temperature=self.config.temperature,
                    top_p=self.config.top_p,
                    frequency_penalty=self.config.frequency_penalty,
                    presence_penalty=self.config.presence_penalty,
                    stop=self.config.stop_sequences or None,
                )

                response_text = response.choices[0].message.content or ""
                context.add_message(MessageRole.ASSISTANT, response_text)
                return response_text

            except Exception as e:
                logger.warning(f"OpenAI API attempt {attempt + 1} failed: {e}")
                if attempt < self.config.retry_attempts - 1:
                    # Exponential backoff between retries
                    await asyncio.sleep(self.config.retry_delay * (2 ** attempt))
                else:
                    raise

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a response using the OpenAI API."""
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        messages = context.get_context_messages()

        try:
            stream = await client.chat.completions.create(
                model=self.config.model,
                messages=messages,
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                stream=True,
            )

            full_response = ""
            async for chunk in stream:
                # Some chunks (e.g. the final one) may carry no choices
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_response += content
                    yield content

            # Add the complete response to the context
            context.add_message(MessageRole.ASSISTANT, full_response)

        except Exception as e:
            logger.error(f"OpenAI streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Check OpenAI API health."""
        try:
            client = self._get_client()
            await client.models.list()
            return True
        except Exception as e:
            logger.warning(f"OpenAI health check failed: {e}")
            return False


class AnthropicProvider(AIProvider):
    """Anthropic Claude provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Lazily initialize the Anthropic client."""
        if self._client is None:
            try:
                import anthropic
            except ImportError:
                raise ImportError("Anthropic package not installed. Install with: pip install anthropic")
            self._client = anthropic.AsyncAnthropic(
                api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
                timeout=self.config.timeout,
            )
        return self._client

    @staticmethod
    def _split_messages(context: ConversationContext):
        """Convert context messages to Anthropic format.

        Anthropic takes the system prompt as a separate parameter rather than
        as a message, so split it out from the user/assistant turns.
        """
        messages = []
        system_prompt = None
        for msg in context.messages:
            if msg.role == MessageRole.SYSTEM:
                system_prompt = msg.content
            else:
                messages.append({"role": msg.role.value, "content": msg.content})
        return system_prompt, messages

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a response using the Anthropic API."""
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        system_prompt, messages = self._split_messages(context)

        for attempt in range(self.config.retry_attempts):
            try:
                kwargs = {
                    "model": self.config.model,
                    "messages": messages,
                    "max_tokens": self.config.max_tokens,
                    "temperature": self.config.temperature,
                }
                if system_prompt:
                    kwargs["system"] = system_prompt

                response = await client.messages.create(**kwargs)
                response_text = response.content[0].text
                context.add_message(MessageRole.ASSISTANT, response_text)
                return response_text

            except Exception as e:
                logger.warning(f"Anthropic API attempt {attempt + 1} failed: {e}")
                if attempt < self.config.retry_attempts - 1:
                    # Exponential backoff between retries
                    await asyncio.sleep(self.config.retry_delay * (2 ** attempt))
                else:
                    raise

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a response using the Anthropic API."""
        client = self._get_client()

        context.add_message(MessageRole.USER, message)
        system_prompt, messages = self._split_messages(context)

        try:
            kwargs = {
                "model": self.config.model,
                "messages": messages,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
            }
            if system_prompt:
                kwargs["system"] = system_prompt

            async with client.messages.stream(**kwargs) as stream:
                full_response = ""
                async for text in stream.text_stream:
                    full_response += text
                    yield text

            context.add_message(MessageRole.ASSISTANT, full_response)

        except Exception as e:
            logger.error(f"Anthropic streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Check Anthropic API health."""
        try:
            client = self._get_client()
            # Minimal request to verify API connectivity
            await client.messages.create(
                model=self.config.model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1,
            )
            return True
        except Exception as e:
            logger.warning(f"Anthropic health check failed: {e}")
            return False


class LocalProvider(AIProvider):
    """Local model provider (e.g., Ollama, llama.cpp)."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self.base_url = config.base_url or "http://localhost:11434"

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate a response using the local model API."""
        context.add_message(MessageRole.USER, message)

        # Lazy import: aiohttp is only needed when the local provider is used
        import aiohttp

        async with aiohttp.ClientSession() as session:
            payload = {
                "model": self.config.model,
                "messages": context.get_context_messages(),
                "stream": False,
                "options": {
                    "temperature": self.config.temperature,
                    "num_predict": self.config.max_tokens,
                },
            }

            try:
                async with session.post(
                    f"{self.base_url}/api/chat",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout),
                ) as resp:
                    if resp.status != 200:
                        raise RuntimeError(f"Local API returned status {resp.status}")
                    result = await resp.json()
                    response_text = result["message"]["content"]
                    context.add_message(MessageRole.ASSISTANT, response_text)
                    return response_text

            except Exception as e:
                logger.error(f"Local provider failed: {e}")
                raise

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream a response using the local model API."""
        context.add_message(MessageRole.USER, message)

        import aiohttp

        async with aiohttp.ClientSession() as session:
            payload = {
                "model": self.config.model,
                "messages": context.get_context_messages(),
                "stream": True,
            }

            try:
                async with session.post(
                    f"{self.base_url}/api/chat",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout),
                ) as resp:
                    if resp.status != 200:
                        raise RuntimeError(f"Local API returned status {resp.status}")

                    # Ollama streams one JSON object per line
                    full_response = ""
                    async for line in resp.content:
                        if not line:
                            continue
                        try:
                            data = json.loads(line.decode())
                        except json.JSONDecodeError:
                            continue
                        if "message" in data and "content" in data["message"]:
                            content = data["message"]["content"]
                            full_response += content
                            yield content

                    context.add_message(MessageRole.ASSISTANT, full_response)

            except Exception as e:
                logger.error(f"Local provider streaming failed: {e}")
                raise

    async def health_check(self) -> bool:
        """Check local model health."""
        try:
            import aiohttp
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.base_url}/api/tags",
                    timeout=aiohttp.ClientTimeout(total=5),
                ) as resp:
                    return resp.status == 200
        except Exception as e:
            logger.warning(f"Local provider health check failed: {e}")
            return False


class AIProviderManager:
    """Manager for AI providers and configurations."""

    def __init__(self):
        self.providers: Dict[str, AIProvider] = {}
        self.default_configs = self._load_default_configs()

    def _load_default_configs(self) -> Dict[AIProviderType, AIProviderConfig]:
        """Load default configurations for providers from the environment."""
        return {
            AIProviderType.OPENAI: AIProviderConfig(
                provider_type=AIProviderType.OPENAI,
                model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
                max_tokens=int(os.getenv("OPENAI_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("OPENAI_TEMPERATURE", "0.7")),
            ),
            AIProviderType.ANTHROPIC: AIProviderConfig(
                provider_type=AIProviderType.ANTHROPIC,
                model=os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022"),
                max_tokens=int(os.getenv("ANTHROPIC_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("ANTHROPIC_TEMPERATURE", "0.7")),
            ),
            AIProviderType.LOCAL: AIProviderConfig(
                provider_type=AIProviderType.LOCAL,
                base_url=os.getenv("LOCAL_MODEL_URL", "http://localhost:11434"),
                model=os.getenv("LOCAL_MODEL_NAME", "llama2"),
                max_tokens=int(os.getenv("LOCAL_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("LOCAL_TEMPERATURE", "0.7")),
            ),
        }

    def create_provider(self, provider_type: AIProviderType, config: Optional[AIProviderConfig] = None) -> AIProvider:
        """Create an AI provider instance."""
        if config is None:
            config = self.default_configs.get(provider_type)
            if config is None:
                raise ValueError(f"No default config for provider type: {provider_type}")

        if provider_type == AIProviderType.OPENAI:
            return OpenAIProvider(config)
        elif provider_type == AIProviderType.ANTHROPIC:
            return AnthropicProvider(config)
        elif provider_type == AIProviderType.LOCAL:
            return LocalProvider(config)
        else:
            # AIProviderType.CUSTOM has no built-in implementation; custom
            # provider instances should be registered via register_provider()
            raise ValueError(f"Unsupported provider type: {provider_type}")

    def register_provider(self, name: str, provider: AIProvider):
        """Register a provider instance."""
        self.providers[name] = provider
        logger.info(f"Registered AI provider: {name} ({provider.provider_type})")

    def get_provider(self, name: str) -> Optional[AIProvider]:
        """Get a registered provider."""
        return self.providers.get(name)

    def list_providers(self) -> List[str]:
        """List all registered provider names."""
        return list(self.providers.keys())

    async def health_check_all(self) -> Dict[str, bool]:
        """Health check all registered providers."""
        results = {}
        for name, provider in self.providers.items():
            try:
                results[name] = await provider.health_check()
            except Exception as e:
                logger.error(f"Health check failed for provider {name}: {e}")
                results[name] = False
        return results


# Global provider manager instance
ai_provider_manager = AIProviderManager()
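

# End-to-end usage sketch (illustrative; assumes the relevant API key, e.g.
# OPENAI_API_KEY, is set in the environment, and "my-openai" is an arbitrary
# registry name):
#
#     async def demo():
#         provider = ai_provider_manager.create_provider(AIProviderType.OPENAI)
#         ai_provider_manager.register_provider("my-openai", provider)
#         ctx = ConversationContext(session_id="s1", bot_name="demo-bot")
#         print(await provider.generate_response(ctx, "Hello!"))
#
#     # asyncio.run(demo())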