ai-voicebot/voicebot/ai_providers.py
2025-09-15 13:35:38 -07:00

574 lines
22 KiB
Python

"""
AI Provider Integration for Advanced Bot Management.
This module provides support for multiple AI providers including OpenAI, Anthropic,
and local models for enhanced bot capabilities.
"""
import os
import time
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, AsyncIterator
from enum import Enum
from dataclasses import dataclass
from pydantic import BaseModel, Field
from shared.logger import logger
class AIProviderType(str, Enum):
"""Supported AI provider types."""
OPENAI = "openai"
ANTHROPIC = "anthropic"
LOCAL = "local"
CUSTOM = "custom"
class MessageRole(str, Enum):
"""Message roles in conversation."""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
@dataclass
class ConversationMessage:
"""Individual message in a conversation."""
role: MessageRole
content: str
timestamp: float = None
metadata: Dict[str, Any] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
if self.metadata is None:
self.metadata = {}
class AIProviderConfig(BaseModel):
"""Configuration for AI providers."""
provider_type: AIProviderType
api_key: Optional[str] = None
base_url: Optional[str] = None
model: str = "gpt-3.5-turbo"
max_tokens: int = 1000
temperature: float = 0.7
timeout: float = 30.0
retry_attempts: int = 3
retry_delay: float = 1.0
# Advanced settings
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
stop_sequences: List[str] = Field(default_factory=list)
class Config:
extra = "allow" # Allow additional provider-specific configs
class ConversationContext(BaseModel):
"""Conversation context and memory management."""
session_id: str
bot_name: str
messages: List[ConversationMessage] = Field(default_factory=list)
created_at: float = Field(default_factory=time.time)
last_updated: float = Field(default_factory=time.time)
# Context management
max_history: int = 50
context_window: int = 4000 # Token limit for context
personality_prompt: Optional[str] = None
# Metadata
user_preferences: Dict[str, Any] = Field(default_factory=dict)
conversation_state: Dict[str, Any] = Field(default_factory=dict)
def add_message(self, role: MessageRole, content: str, metadata: Dict[str, Any] = None):
"""Add a message to the conversation."""
message = ConversationMessage(role=role, content=content, metadata=metadata or {})
self.messages.append(message)
self.last_updated = time.time()
# Trim history if needed
if len(self.messages) > self.max_history:
# Keep system messages and recent messages
system_messages = [m for m in self.messages if m.role == MessageRole.SYSTEM]
recent_messages = [m for m in self.messages if m.role != MessageRole.SYSTEM][-self.max_history:]
self.messages = system_messages + recent_messages
def get_context_messages(self) -> List[Dict[str, str]]:
"""Get messages formatted for AI provider APIs."""
messages = []
for msg in self.messages:
messages.append({
"role": msg.role.value,
"content": msg.content
})
return messages
class AIProvider(ABC):
"""Abstract base class for AI providers."""
def __init__(self, config: AIProviderConfig):
self.config = config
self.provider_type = config.provider_type
@abstractmethod
async def generate_response(
self,
context: ConversationContext,
message: str
) -> str:
"""Generate a response to a message."""
pass
@abstractmethod
async def stream_response(
self,
context: ConversationContext,
message: str
) -> AsyncIterator[str]:
"""Stream a response to a message."""
pass
@abstractmethod
async def health_check(self) -> bool:
"""Check if the provider is healthy and available."""
pass
class OpenAIProvider(AIProvider):
"""OpenAI provider implementation."""
def __init__(self, config: AIProviderConfig):
super().__init__(config)
self._client = None
def _get_client(self):
"""Lazy initialization of OpenAI client."""
if self._client is None:
try:
import openai
self._client = openai.AsyncOpenAI(
api_key=self.config.api_key or os.getenv("OPENAI_API_KEY"),
base_url=self.config.base_url,
timeout=self.config.timeout
)
except ImportError:
raise ImportError("OpenAI package not installed. Install with: pip install openai")
return self._client
async def generate_response(self, context: ConversationContext, message: str) -> str:
"""Generate response using OpenAI API."""
client = self._get_client()
# Add user message to context
context.add_message(MessageRole.USER, message)
messages = context.get_context_messages()
for attempt in range(self.config.retry_attempts):
try:
response = await client.chat.completions.create(
model=self.config.model,
messages=messages,
max_tokens=self.config.max_tokens,
temperature=self.config.temperature,
top_p=self.config.top_p,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty,
stop=self.config.stop_sequences or None
)
response_text = response.choices[0].message.content
context.add_message(MessageRole.ASSISTANT, response_text)
return response_text
except Exception as e:
logger.warning(f"OpenAI API attempt {attempt + 1} failed: {e}")
if attempt < self.config.retry_attempts - 1:
await asyncio.sleep(self.config.retry_delay * (2 ** attempt))
else:
raise
async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
"""Stream response using OpenAI API."""
client = self._get_client()
context.add_message(MessageRole.USER, message)
messages = context.get_context_messages()
try:
stream = await client.chat.completions.create(
model=self.config.model,
messages=messages,
max_tokens=self.config.max_tokens,
temperature=self.config.temperature,
stream=True
)
full_response = ""
async for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
full_response += content
yield content
# Add complete response to context
context.add_message(MessageRole.ASSISTANT, full_response)
except Exception as e:
logger.error(f"OpenAI streaming failed: {e}")
raise
async def health_check(self) -> bool:
"""Check OpenAI API health."""
try:
client = self._get_client()
await client.models.list()
return True
except Exception as e:
logger.warning(f"OpenAI health check failed: {e}")
return False
class AnthropicProvider(AIProvider):
"""Anthropic Claude provider implementation."""
def __init__(self, config: AIProviderConfig):
super().__init__(config)
self._client = None
def _get_client(self):
"""Lazy initialization of Anthropic client."""
if self._client is None:
try:
import anthropic
self._client = anthropic.AsyncAnthropic(
api_key=self.config.api_key or os.getenv("ANTHROPIC_API_KEY"),
timeout=self.config.timeout
)
except ImportError:
raise ImportError("Anthropic package not installed. Install with: pip install anthropic")
return self._client
async def generate_response(self, context: ConversationContext, message: str) -> str:
"""Generate response using Anthropic API."""
client = self._get_client()
context.add_message(MessageRole.USER, message)
# Convert messages for Anthropic format
messages = []
system_prompt = None
for msg in context.messages:
if msg.role == MessageRole.SYSTEM:
system_prompt = msg.content
else:
messages.append({
"role": msg.role.value,
"content": msg.content
})
for attempt in range(self.config.retry_attempts):
try:
kwargs = {
"model": self.config.model,
"messages": messages,
"max_tokens": self.config.max_tokens,
"temperature": self.config.temperature,
}
if system_prompt:
kwargs["system"] = system_prompt
response = await client.messages.create(**kwargs)
response_text = response.content[0].text
context.add_message(MessageRole.ASSISTANT, response_text)
return response_text
except Exception as e:
logger.warning(f"Anthropic API attempt {attempt + 1} failed: {e}")
if attempt < self.config.retry_attempts - 1:
await asyncio.sleep(self.config.retry_delay * (2 ** attempt))
else:
raise
async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
"""Stream response using Anthropic API."""
client = self._get_client()
context.add_message(MessageRole.USER, message)
# Convert messages for Anthropic format
messages = []
system_prompt = None
for msg in context.messages:
if msg.role == MessageRole.SYSTEM:
system_prompt = msg.content
else:
messages.append({
"role": msg.role.value,
"content": msg.content
})
try:
kwargs = {
"model": self.config.model,
"messages": messages,
"max_tokens": self.config.max_tokens,
"temperature": self.config.temperature,
}
if system_prompt:
kwargs["system"] = system_prompt
async with client.messages.stream(**kwargs) as stream:
full_response = ""
async for text in stream.text_stream:
full_response += text
yield text
context.add_message(MessageRole.ASSISTANT, full_response)
except Exception as e:
logger.error(f"Anthropic streaming failed: {e}")
raise
async def health_check(self) -> bool:
"""Check Anthropic API health."""
try:
client = self._get_client()
# Simple test to verify API connectivity
await client.messages.create(
model=self.config.model,
messages=[{"role": "user", "content": "test"}],
max_tokens=1
)
return True
except Exception as e:
logger.warning(f"Anthropic health check failed: {e}")
return False
class LocalProvider(AIProvider):
"""Local model provider (e.g., Ollama, llama.cpp)."""
def __init__(self, config: AIProviderConfig):
super().__init__(config)
self.base_url = config.base_url or "http://localhost:11434"
async def generate_response(self, context: ConversationContext, message: str) -> str:
"""Generate response using local model API."""
context.add_message(MessageRole.USER, message)
import aiohttp
async with aiohttp.ClientSession() as session:
payload = {
"model": self.config.model,
"messages": context.get_context_messages(),
"stream": False,
"options": {
"temperature": self.config.temperature,
"num_predict": self.config.max_tokens
}
}
logger.info(f"LocalProvider generate payload: {payload}")
try:
async with session.post(
f"{self.base_url}/chat/completions",
json=payload,
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
) as resp:
if resp.status == 200:
result = await resp.json()
# Handle OpenAI-compatible format
if "choices" in result and len(result["choices"]) > 0:
choice = result["choices"][0]
if "message" in choice and "content" in choice["message"]:
response_text = choice["message"]["content"]
elif "text" in choice:
# Some APIs use "text" instead of "message.content"
response_text = choice["text"]
else:
raise Exception(f"Unexpected response format: {result}")
# Fallback to original format for backward compatibility
elif "message" in result and "content" in result["message"]:
response_text = result["message"]["content"]
else:
raise Exception(f"Unexpected response format: {result}")
context.add_message(MessageRole.ASSISTANT, response_text)
return response_text
else:
raise Exception(f"Local API returned status {resp.status}")
except Exception as e:
logger.error(f"Local provider failed: {e}")
raise
async def stream_response(self, context: ConversationContext, message: str) -> AsyncIterator[str]:
"""Stream response using local model API."""
context.add_message(MessageRole.USER, message)
import aiohttp
async with aiohttp.ClientSession() as session:
payload = {
"model": self.config.model,
"messages": context.get_context_messages(),
"stream": True
}
logger.info(f"LocalProvider stream payload: {self.base_url} {payload}")
try:
async with session.post(
f"{self.base_url}/chat/completions",
json=payload,
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
) as resp:
if resp.status == 200:
full_response = ""
import json
async for line in resp.content:
if line:
try:
# Decode the line
line_str = line.decode('utf-8').strip()
# Skip empty lines
if not line_str:
continue
# Handle Server-Sent Events format
if line_str.startswith('data: '):
line_str = line_str[6:] # Remove 'data: ' prefix
# Skip end-of-stream marker
if line_str == '[DONE]':
break
data = json.loads(line_str)
# Handle OpenAI-compatible format
if "choices" in data and len(data["choices"]) > 0:
choice = data["choices"][0]
if "delta" in choice and "content" in choice["delta"]:
content = choice["delta"]["content"]
if content: # Only yield non-empty content
full_response += content
yield content
except (UnicodeDecodeError, json.JSONDecodeError) as e:
logger.debug(f"Skipping invalid line: {line} (error: {e})")
continue
except Exception as e:
logger.warning(f"Unexpected error processing line: {e}")
continue
context.add_message(MessageRole.ASSISTANT, full_response)
else:
raise Exception(f"Local API returned status {resp.status}")
except Exception as e:
logger.error(f"Local provider streaming failed {self.base_url}: {e}")
raise
async def health_check(self) -> bool:
"""Check local model health."""
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/api/tags",
timeout=aiohttp.ClientTimeout(total=5)
) as resp:
return resp.status == 200
except Exception as e:
logger.warning(f"Local provider health check failed: {e}")
return False
class AIProviderManager:
"""Manager for AI providers and configurations."""
def __init__(self):
self.providers: Dict[str, AIProvider] = {}
self.default_configs = self._load_default_configs()
def _load_default_configs(self) -> Dict[AIProviderType, AIProviderConfig]:
"""Load default configurations for providers."""
return {
AIProviderType.OPENAI: AIProviderConfig(
provider_type=AIProviderType.OPENAI,
model=os.getenv("OPENAI_MODEL", "gpt-3.5-turbo"),
max_tokens=int(os.getenv("OPENAI_MAX_TOKENS", "1000")),
temperature=float(os.getenv("OPENAI_TEMPERATURE", "0.7"))
),
AIProviderType.ANTHROPIC: AIProviderConfig(
provider_type=AIProviderType.ANTHROPIC,
model=os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022"),
max_tokens=int(os.getenv("ANTHROPIC_MAX_TOKENS", "1000")),
temperature=float(os.getenv("ANTHROPIC_TEMPERATURE", "0.7"))
),
AIProviderType.LOCAL: AIProviderConfig(
provider_type=AIProviderType.LOCAL,
base_url=os.getenv("OPENAI_BASE_URL", "http://localhost:11434"),
model=os.getenv("OPENAI_MODEL", "llama2"),
max_tokens=int(os.getenv("LOCAL_MAX_TOKENS", "1000")),
temperature=float(os.getenv("LOCAL_TEMPERATURE", "0.7"))
)
}
def create_provider(self, provider_type: AIProviderType, config: Optional[AIProviderConfig] = None) -> AIProvider:
"""Create an AI provider instance."""
if config is None:
config = self.default_configs.get(provider_type)
if config is None:
raise ValueError(f"No default config for provider type: {provider_type}")
if provider_type == AIProviderType.OPENAI:
return OpenAIProvider(config)
elif provider_type == AIProviderType.ANTHROPIC:
return AnthropicProvider(config)
elif provider_type == AIProviderType.LOCAL:
return LocalProvider(config)
else:
raise ValueError(f"Unsupported provider type: {provider_type}")
def register_provider(self, name: str, provider: AIProvider):
"""Register a provider instance."""
self.providers[name] = provider
logger.info(f"Registered AI provider: {name} ({provider.provider_type})")
def get_provider(self, name: str) -> Optional[AIProvider]:
"""Get a registered provider."""
return self.providers.get(name)
def list_providers(self) -> List[str]:
"""List all registered provider names."""
return list(self.providers.keys())
async def health_check_all(self) -> Dict[str, bool]:
"""Health check all registered providers."""
results = {}
for name, provider in self.providers.items():
try:
results[name] = await provider.health_check()
except Exception as e:
logger.error(f"Health check failed for provider {name}: {e}")
results[name] = False
return results
# Global provider manager instance
ai_provider_manager = AIProviderManager()