# ai-voicebot/voicebot/ai_providers.py
"""
AI Provider Integration for Advanced Bot Management.
This module provides support for multiple AI providers including OpenAI, Anthropic,
and local models for enhanced bot capabilities.
"""
import os
import time
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, AsyncIterator
from enum import Enum
from dataclasses import dataclass
from pydantic import BaseModel, Field
from logger import logger


class AIProviderType(str, Enum):
    """Supported AI provider types."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    LOCAL = "local"
    CUSTOM = "custom"


class MessageRole(str, Enum):
    """Message roles in conversation."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


@dataclass
class ConversationMessage:
    """Individual message in a conversation."""

    role: MessageRole
    content: str
    timestamp: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = time.time()
        if self.metadata is None:
            self.metadata = {}


class AIProviderConfig(BaseModel):
    """Configuration for AI providers."""

    provider_type: AIProviderType
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    model: str = "gpt-3.5-turbo"
    max_tokens: int = 1000
    temperature: float = 0.7
    timeout: float = 30.0
    retry_attempts: int = 3
    retry_delay: float = 1.0

    # Advanced settings
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    stop_sequences: List[str] = Field(default_factory=list)

    class Config:
        extra = "allow"  # Allow additional provider-specific configs


class ConversationContext(BaseModel):
    """Conversation context and memory management."""

    session_id: str
    bot_name: str
    messages: List[ConversationMessage] = Field(default_factory=list)
    created_at: float = Field(default_factory=time.time)
    last_updated: float = Field(default_factory=time.time)

    # Context management
    max_history: int = 50
    context_window: int = 4000  # Token limit for context
    personality_prompt: Optional[str] = None

    # Metadata
    user_preferences: Dict[str, Any] = Field(default_factory=dict)
    conversation_state: Dict[str, Any] = Field(default_factory=dict)

    def add_message(self, role: MessageRole, content: str, metadata: Optional[Dict[str, Any]] = None):
        """Add a message to the conversation."""
        message = ConversationMessage(role=role, content=content, metadata=metadata or {})
        self.messages.append(message)
        self.last_updated = time.time()
        # Trim history if needed: keep all system messages plus the most
        # recent non-system messages.
        if len(self.messages) > self.max_history:
            system_messages = [m for m in self.messages if m.role == MessageRole.SYSTEM]
            recent_messages = [m for m in self.messages if m.role != MessageRole.SYSTEM][-self.max_history:]
            self.messages = system_messages + recent_messages

    def get_context_messages(self) -> List[Dict[str, str]]:
        """Get messages formatted for AI provider APIs."""
        return [{"role": msg.role.value, "content": msg.content} for msg in self.messages]
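
# Example (illustrative): a context accumulates turns and renders them in
# the {"role", "content"} shape the chat APIs expect. The session and bot
# names here are placeholders.
#
#     ctx = ConversationContext(session_id="abc123", bot_name="helper")
#     ctx.add_message(MessageRole.SYSTEM, "You are a helpful voice bot.")
#     ctx.add_message(MessageRole.USER, "What can you do?")
#     ctx.get_context_messages()
#     # -> [{"role": "system", "content": "You are a helpful voice bot."},
#     #     {"role": "user", "content": "What can you do?"}]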


class AIProvider(ABC):
    """Abstract base class for AI providers."""

    def __init__(self, config: AIProviderConfig):
        self.config = config
        self.provider_type = config.provider_type

    @abstractmethod
    async def generate_response(
        self,
        context: ConversationContext,
        message: str
    ) -> str:
        """Generate a response to a message."""
        pass

    @abstractmethod
    async def stream_response(
        self,
        context: ConversationContext,
        message: str
    ) -> AsyncIterator[str]:
        """Stream a response to a message."""
        pass

    @abstractmethod
    async def health_check(self) -> bool:
        """Check if the provider is healthy and available."""
        pass
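
# Sketch (illustrative): a minimal CUSTOM provider only needs the three
# abstract methods above. The echo behavior is a stand-in for a real
# backend call.
#
#     class EchoProvider(AIProvider):
#         async def generate_response(self, context, message):
#             context.add_message(MessageRole.USER, message)
#             reply = f"echo: {message}"
#             context.add_message(MessageRole.ASSISTANT, reply)
#             return reply
#
#         async def stream_response(self, context, message):
#             yield await self.generate_response(context, message)
#
#         async def health_check(self):
#             return True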


class OpenAIProvider(AIProvider):
    """OpenAI provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Lazy initialization of OpenAI client."""
        if self._client is None:
            try:
                import openai
                self._client = openai.AsyncOpenAI(
                    api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
                    base_url=self.config.base_url,
                    timeout=self.config.timeout
                )
            except ImportError:
                raise ImportError("OpenAI package not installed. Install with: pip install openai")
        return self._client

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate response using OpenAI API."""
        client = self._get_client()
        # Add user message to context
        context.add_message(MessageRole.USER, message)
        messages = context.get_context_messages()
        for attempt in range(self.config.retry_attempts):
            try:
                response = await client.chat.completions.create(
                    model=self.config.model,
                    messages=messages,
                    max_tokens=self.config.max_tokens,
                    temperature=self.config.temperature,
                    top_p=self.config.top_p,
                    frequency_penalty=self.config.frequency_penalty,
                    presence_penalty=self.config.presence_penalty,
                    stop=self.config.stop_sequences or None
                )
                response_text = response.choices[0].message.content or ""
                context.add_message(MessageRole.ASSISTANT, response_text)
                return response_text
            except Exception as e:
                logger.warning(f"OpenAI API attempt {attempt + 1} failed: {e}")
                if attempt < self.config.retry_attempts - 1:
                    # Exponential backoff between retries
                    await asyncio.sleep(self.config.retry_delay * (2 ** attempt))
                else:
                    raise
        raise RuntimeError("OpenAI provider exhausted all retry attempts")

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream response using OpenAI API."""
        client = self._get_client()
        context.add_message(MessageRole.USER, message)
        messages = context.get_context_messages()
        try:
            stream = await client.chat.completions.create(
                model=self.config.model,
                messages=messages,
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                stream=True
            )
            full_response = ""
            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_response += content
                    yield content
            # Add complete response to context
            context.add_message(MessageRole.ASSISTANT, full_response)
        except Exception as e:
            logger.error(f"OpenAI streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Check OpenAI API health."""
        try:
            client = self._get_client()
            await client.models.list()
            return True
        except Exception as e:
            logger.warning(f"OpenAI health check failed: {e}")
            return False


class AnthropicProvider(AIProvider):
    """Anthropic Claude provider implementation."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Lazy initialization of Anthropic client."""
        if self._client is None:
            try:
                import anthropic
                self._client = anthropic.AsyncAnthropic(
                    api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
                    timeout=self.config.timeout
                )
            except ImportError:
                raise ImportError("Anthropic package not installed. Install with: pip install anthropic")
        return self._client

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate response using Anthropic API."""
        client = self._get_client()
        context.add_message(MessageRole.USER, message)
        # Anthropic takes the system prompt as a separate parameter, so
        # system messages are split out of the message list.
        messages = []
        system_prompt = None
        for msg in context.messages:
            if msg.role == MessageRole.SYSTEM:
                system_prompt = msg.content
            else:
                messages.append({
                    "role": msg.role.value,
                    "content": msg.content
                })
        for attempt in range(self.config.retry_attempts):
            try:
                kwargs = {
                    "model": self.config.model,
                    "messages": messages,
                    "max_tokens": self.config.max_tokens,
                    "temperature": self.config.temperature,
                }
                if system_prompt:
                    kwargs["system"] = system_prompt
                response = await client.messages.create(**kwargs)
                response_text = response.content[0].text
                context.add_message(MessageRole.ASSISTANT, response_text)
                return response_text
            except Exception as e:
                logger.warning(f"Anthropic API attempt {attempt + 1} failed: {e}")
                if attempt < self.config.retry_attempts - 1:
                    # Exponential backoff between retries
                    await asyncio.sleep(self.config.retry_delay * (2 ** attempt))
                else:
                    raise
        raise RuntimeError("Anthropic provider exhausted all retry attempts")

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream response using Anthropic API."""
        client = self._get_client()
        context.add_message(MessageRole.USER, message)
        # As in generate_response, system messages must be filtered out of
        # the message list and passed via the `system` parameter; Anthropic
        # rejects messages with a "system" role.
        messages = []
        system_prompt = None
        for msg in context.messages:
            if msg.role == MessageRole.SYSTEM:
                system_prompt = msg.content
            else:
                messages.append({"role": msg.role.value, "content": msg.content})
        try:
            kwargs = {
                "model": self.config.model,
                "messages": messages,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
            }
            if system_prompt:
                kwargs["system"] = system_prompt
            async with client.messages.stream(**kwargs) as stream:
                full_response = ""
                async for text in stream.text_stream:
                    full_response += text
                    yield text
            context.add_message(MessageRole.ASSISTANT, full_response)
        except Exception as e:
            logger.error(f"Anthropic streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Check Anthropic API health."""
        try:
            client = self._get_client()
            # Minimal request to verify API connectivity
            await client.messages.create(
                model=self.config.model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1
            )
            return True
        except Exception as e:
            logger.warning(f"Anthropic health check failed: {e}")
            return False


class LocalProvider(AIProvider):
    """Local model provider (e.g., Ollama, llama.cpp)."""

    def __init__(self, config: AIProviderConfig):
        super().__init__(config)
        self.base_url = config.base_url or "http://localhost:11434"

    async def generate_response(self, context: ConversationContext, message: str) -> str:
        """Generate response using local model API."""
        import aiohttp

        context.add_message(MessageRole.USER, message)
        payload = {
            "model": self.config.model,
            "messages": context.get_context_messages(),
            "stream": False,
            "options": {
                "temperature": self.config.temperature,
                "num_predict": self.config.max_tokens
            }
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/api/chat",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout)
                ) as resp:
                    if resp.status != 200:
                        raise RuntimeError(f"Local API returned status {resp.status}")
                    result = await resp.json()
                    response_text = result["message"]["content"]
                    context.add_message(MessageRole.ASSISTANT, response_text)
                    return response_text
        except Exception as e:
            logger.error(f"Local provider failed: {e}")
            raise

    async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
        """Stream response using local model API."""
        import json

        import aiohttp

        context.add_message(MessageRole.USER, message)
        payload = {
            "model": self.config.model,
            "messages": context.get_context_messages(),
            "stream": True
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/api/chat",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout)
                ) as resp:
                    if resp.status != 200:
                        raise RuntimeError(f"Local API returned status {resp.status}")
                    full_response = ""
                    # Ollama-style APIs stream one JSON object per line
                    async for line in resp.content:
                        if not line:
                            continue
                        try:
                            data = json.loads(line.decode())
                        except json.JSONDecodeError:
                            continue
                        if "message" in data and "content" in data["message"]:
                            content = data["message"]["content"]
                            full_response += content
                            yield content
                    context.add_message(MessageRole.ASSISTANT, full_response)
        except Exception as e:
            logger.error(f"Local provider streaming failed: {e}")
            raise

    async def health_check(self) -> bool:
        """Check local model health."""
        try:
            import aiohttp

            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.base_url}/api/tags",
                    timeout=aiohttp.ClientTimeout(total=5)
                ) as resp:
                    return resp.status == 200
        except Exception as e:
            logger.warning(f"Local provider health check failed: {e}")
            return False
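
# For reference, an Ollama-style /api/chat stream emits one JSON object per
# line; a representative (illustrative) chunk looks like:
#
#     {"model": "llama2", "message": {"role": "assistant", "content": "Hel"}, "done": false}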


class AIProviderManager:
    """Manager for AI providers and configurations."""

    def __init__(self):
        self.providers: Dict[str, AIProvider] = {}
        self.default_configs = self._load_default_configs()

    def _load_default_configs(self) -> Dict[AIProviderType, AIProviderConfig]:
        """Load default configurations for providers."""
        return {
            AIProviderType.OPENAI: AIProviderConfig(
                provider_type=AIProviderType.OPENAI,
                model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
                max_tokens=int(os.getenv("OPENAI_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("OPENAI_TEMPERATURE", "0.7"))
            ),
            AIProviderType.ANTHROPIC: AIProviderConfig(
                provider_type=AIProviderType.ANTHROPIC,
                model=os.getenv("ANTHROPIC_MODEL", "claude-3-sonnet-20240229"),
                max_tokens=int(os.getenv("ANTHROPIC_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("ANTHROPIC_TEMPERATURE", "0.7"))
            ),
            AIProviderType.LOCAL: AIProviderConfig(
                provider_type=AIProviderType.LOCAL,
                base_url=os.getenv("LOCAL_MODEL_URL", "http://localhost:11434"),
                model=os.getenv("LOCAL_MODEL_NAME", "llama2"),
                max_tokens=int(os.getenv("LOCAL_MAX_TOKENS", "1000")),
                temperature=float(os.getenv("LOCAL_TEMPERATURE", "0.7"))
            )
        }

    def create_provider(self, provider_type: AIProviderType, config: Optional[AIProviderConfig] = None) -> AIProvider:
        """Create an AI provider instance."""
        if config is None:
            config = self.default_configs.get(provider_type)
        if config is None:
            raise ValueError(f"No default config for provider type: {provider_type}")
        if provider_type == AIProviderType.OPENAI:
            return OpenAIProvider(config)
        elif provider_type == AIProviderType.ANTHROPIC:
            return AnthropicProvider(config)
        elif provider_type == AIProviderType.LOCAL:
            return LocalProvider(config)
        else:
            raise ValueError(f"Unsupported provider type: {provider_type}")

    def register_provider(self, name: str, provider: AIProvider):
        """Register a provider instance."""
        self.providers[name] = provider
        logger.info(f"Registered AI provider: {name} ({provider.provider_type})")

    def get_provider(self, name: str) -> Optional[AIProvider]:
        """Get a registered provider."""
        return self.providers.get(name)

    def list_providers(self) -> List[str]:
        """List all registered provider names."""
        return list(self.providers.keys())

    async def health_check_all(self) -> Dict[str, bool]:
        """Health check all registered providers."""
        results = {}
        for name, provider in self.providers.items():
            try:
                results[name] = await provider.health_check()
            except Exception as e:
                logger.error(f"Health check failed for provider {name}: {e}")
                results[name] = False
        return results


# Global provider manager instance
ai_provider_manager = AIProviderManager()
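

if __name__ == "__main__":
    # Illustrative smoke test, not part of the module API: wires up the
    # default local provider and sends one message. Assumes an
    # Ollama-compatible server at localhost:11434 (or LOCAL_MODEL_URL)
    # with the configured model pulled, and that aiohttp is installed.
    async def _demo() -> None:
        provider = ai_provider_manager.create_provider(AIProviderType.LOCAL)
        ai_provider_manager.register_provider("local-demo", provider)
        print("health:", await ai_provider_manager.health_check_all())

        context = ConversationContext(session_id="demo-session", bot_name="demo-bot")
        context.add_message(MessageRole.SYSTEM, "You are a terse assistant.")
        reply = await provider.generate_response(context, "Say hello in five words.")
        print("reply:", reply)

    asyncio.run(_demo())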