ai-voicebot/server/core/error_handling.py

387 lines
13 KiB
Python

"""
Enhanced Error Handling and Recovery
This module provides robust error handling, recovery mechanisms, and resilience patterns
for the AI Voice Bot server. It includes custom exceptions, circuit breakers, retry logic,
and graceful degradation strategies.
"""
import asyncio
import time
import traceback
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic
from functools import wraps
from dataclasses import dataclass
from fastapi import WebSocket
from logger import logger
# Type variable standing for the return type of the callables wrapped by the
# decorators in this module (used as Callable[..., T] / -> T).
T = TypeVar('T')
class ErrorSeverity(Enum):
    """Severity levels used to categorize errors and decide how they are handled."""

    # Minor issues, continue operation
    LOW = "low"

    # Moderate issues, may need user notification
    MEDIUM = "medium"

    # Serious issues, may affect functionality
    HIGH = "high"

    # Critical issues, immediate attention required
    CRITICAL = "critical"
class ErrorCategory(Enum):
    """Categories used to classify errors by the subsystem they relate to."""

    # Connection-level categories
    WEBSOCKET = "websocket"
    WEBRTC = "webrtc"

    # Application-level categories
    SESSION = "session"
    LOBBY = "lobby"
    AUTH = "auth"

    # Infrastructure and input categories
    PERSISTENCE = "persistence"
    NETWORK = "network"
    VALIDATION = "validation"

    # Default category used by VoiceBotError when none is given
    SYSTEM = "system"
# Custom Exception Classes
class VoiceBotError(Exception):
    """Base exception for all voice bot related errors.

    Args:
        message: Human-readable description; becomes ``str(error)``.
        category: ErrorCategory of the error (defaults to SYSTEM).
        severity: ErrorSeverity of the error (defaults to MEDIUM).
        context: Optional dictionary with extra details about the failure.
        recoverable: Flag marking the error as recoverable (defaults to True).

    Attributes:
        category, severity, recoverable: The values passed to the constructor.
        context: A shallow copy of the ``context`` argument (empty dict if
            none was given).
        timestamp: ``time.time()`` taken when the exception was constructed.
    """

    def __init__(
        self,
        message: str,
        category: ErrorCategory = ErrorCategory.SYSTEM,
        severity: ErrorSeverity = ErrorSeverity.MEDIUM,
        context: Optional[Dict[str, Any]] = None,
        recoverable: bool = True
    ):
        super().__init__(message)
        self.category = category
        self.severity = severity
        # Store a shallow copy: error.context is updated in place by the error
        # handler, and that must not leak into the dict the caller passed in.
        self.context: Dict[str, Any] = dict(context) if context else {}
        self.recoverable = recoverable
        self.timestamp = time.time()
class WebSocketError(VoiceBotError):
    """Raised for WebSocket related errors.

    The category is fixed to ErrorCategory.WEBSOCKET; all other keyword
    arguments are forwarded unchanged to VoiceBotError.
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(
            message,
            category=ErrorCategory.WEBSOCKET,
            **kwargs,
        )
class WebRTCError(VoiceBotError):
    """Raised for WebRTC signaling and connection errors.

    The category is fixed to ErrorCategory.WEBRTC; all other keyword
    arguments are forwarded unchanged to VoiceBotError.
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(
            message,
            category=ErrorCategory.WEBRTC,
            **kwargs,
        )
class SessionError(VoiceBotError):
    """Raised for session management errors.

    The category is fixed to ErrorCategory.SESSION; all other keyword
    arguments are forwarded unchanged to VoiceBotError.
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(
            message,
            category=ErrorCategory.SESSION,
            **kwargs,
        )
class LobbyError(VoiceBotError):
    """Raised for lobby management errors.

    The category is fixed to ErrorCategory.LOBBY; all other keyword
    arguments are forwarded unchanged to VoiceBotError.
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(
            message,
            category=ErrorCategory.LOBBY,
            **kwargs,
        )
class AuthError(VoiceBotError):
    """Raised for authentication and authorization errors.

    The category is fixed to ErrorCategory.AUTH; all other keyword
    arguments are forwarded unchanged to VoiceBotError.
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(
            message,
            category=ErrorCategory.AUTH,
            **kwargs,
        )
class PersistenceError(VoiceBotError):
    """Raised for data persistence and storage errors.

    The category is fixed to ErrorCategory.PERSISTENCE; all other keyword
    arguments are forwarded unchanged to VoiceBotError.
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(
            message,
            category=ErrorCategory.PERSISTENCE,
            **kwargs,
        )
class ValidationError(VoiceBotError):
    """Raised for input validation errors.

    The category is fixed to ErrorCategory.VALIDATION; all other keyword
    arguments are forwarded unchanged to VoiceBotError.
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(
            message,
            category=ErrorCategory.VALIDATION,
            **kwargs,
        )
@dataclass
class CircuitBreakerState:
    """Mutable bookkeeping record holding the state of one circuit breaker."""

    # Number of failures recorded; CircuitBreaker resets it to 0 on success
    failures: int = 0

    # time.time() of the most recent recorded failure (0 until one occurs)
    last_failure_time: float = 0

    # True while the circuit is open
    is_open: bool = False

    # time.time() of the most recent successful call (0 until one occurs)
    last_success_time: float = 0
class CircuitBreaker:
    """Circuit breaker decorator for preventing cascading failures.

    An instance is applied as a decorator to an async callable (the wrapper
    awaits the wrapped function). Failures matching ``expected_exception`` are
    counted; once ``failure_threshold`` is reached the circuit opens and calls
    are rejected with a VoiceBotError until ``recovery_timeout`` seconds have
    passed since the last failure. The failure count is only reset by a
    successful call, so a failed trial call after the timeout reopens the
    circuit immediately.

    Args:
        failure_threshold: Failure count at which the circuit opens.
        recovery_timeout: Seconds after the last failure during which an open
            circuit rejects calls.
        expected_exception: Exception type that counts as a failure.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception
    ):
        self.state = CircuitBreakerState()
        self.expected_exception = expected_exception
        self.recovery_timeout = recovery_timeout
        self.failure_threshold = failure_threshold

    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            if self.state.is_open:
                since_last_failure = time.time() - self.state.last_failure_time
                if since_last_failure < self.recovery_timeout:
                    # Still inside the cool-down window: reject without calling func
                    raise VoiceBotError(
                        f"Circuit breaker open for {func.__name__}",
                        severity=ErrorSeverity.HIGH,
                        recoverable=False
                    )
                # Cool-down elapsed: let this call through as a trial (half-open)
                self.state.is_open = False
                logger.info(f"Circuit breaker half-open for {func.__name__}")

            try:
                outcome = await func(*args, **kwargs)
            except self.expected_exception as e:
                self.state.failures += 1
                self.state.last_failure_time = time.time()
                threshold_reached = self.state.failures >= self.failure_threshold
                if threshold_reached:
                    self.state.is_open = True
                    logger.error(
                        f"Circuit breaker opened for {func.__name__} "
                        f"after {self.state.failures} failures"
                    )
                raise e
            else:
                # Success clears the failure counter and records the time
                self.state.failures = 0
                self.state.last_success_time = time.time()
                return outcome

        return wrapper
class RetryStrategy:
    """Retry decorator with exponential backoff for async callables.

    An instance is applied as a decorator; the wrapper awaits the wrapped
    function and, when it raises a retriable exception, sleeps and tries
    again. The delay before retry ``n`` is
    ``min(base_delay * backoff_factor ** (n - 1), max_delay)`` seconds.

    Args:
        max_attempts: Maximum number of calls to the wrapped function. The
            function is always called at least once, so values below 1 behave
            like 1 (no retries).
        base_delay: Delay in seconds before the first retry.
        max_delay: Upper bound in seconds for any single delay.
        backoff_factor: Multiplier applied to the delay after each attempt.
        retriable_exceptions: Exception types that trigger a retry; defaults
            to ``[Exception]``. Any other exception propagates immediately.
    """

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_factor: float = 2.0,
        retriable_exceptions: Optional[List[type]] = None
    ):
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor
        self.retriable_exceptions = retriable_exceptions or [Exception]

    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            # isinstance() accepts a tuple of types, replacing the manual any() scan
            retriable = tuple(self.retriable_exceptions)
            attempt = 1
            # Loop exits only by returning a result or re-raising, so the
            # function is attempted at least once even if max_attempts < 1
            # (previously that configuration ended in ``raise None``).
            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Non-retriable exceptions propagate untouched
                    if not isinstance(e, retriable):
                        raise

                    if attempt >= self.max_attempts:
                        logger.error(
                            f"Function {func.__name__} failed after {attempt} attempts"
                        )
                        raise

                    # Calculate delay with exponential backoff, capped at max_delay
                    delay = min(
                        self.base_delay * (self.backoff_factor ** (attempt - 1)),
                        self.max_delay
                    )
                    logger.warning(
                        f"Function {func.__name__} failed (attempt {attempt}/{self.max_attempts}), "
                        f"retrying in {delay:.2f}s: {e}"
                    )

                await asyncio.sleep(delay)
                attempt += 1

        return wrapper
class ErrorHandler:
    """Central error handler with context and recovery strategies.

    Attributes:
        error_counts: Number of handled errors per
            ``"<category value>:<exception class name>"`` key.
        error_history: The most recently handled errors, oldest first, holding
            at most ``max_history`` entries.
        max_history: Maximum number of errors kept in ``error_history``.
    """

    def __init__(self):
        self.error_counts: Dict[str, int] = {}
        self.error_history: List[VoiceBotError] = []
        self.max_history = 100

    async def handle_error(
        self,
        error: Exception,
        context: Dict[str, Any],
        websocket: Optional[WebSocket] = None,
        session_id: Optional[str] = None,
        recovery_action: Optional[Callable] = None
    ) -> bool:
        """
        Handle an error with context and optional recovery.

        The error is counted, appended to the history, logged, reported to the
        client (when a websocket is given and severity is not LOW) and, if the
        error is recoverable, ``recovery_action`` is awaited.

        Args:
            error: The exception to handle. Anything that is not a
                VoiceBotError is wrapped in a MEDIUM-severity VoiceBotError.
            context: Extra details merged into the error's context; the
                ``'operation'`` key, if present, is used in the log message.
            websocket: Optional WebSocket used to notify the client.
            session_id: Optional session id included in the client notification.
            recovery_action: Optional zero-argument callable whose result is
                awaited to attempt recovery.

        Returns:
            bool: True if the recovery action succeeded; otherwise True only
            for LOW or MEDIUM severity errors, False for HIGH or CRITICAL.
        """
        # Tolerate a missing context instead of failing inside the handler
        if context is None:
            context = {}

        if isinstance(error, VoiceBotError):
            # Add context to the existing error
            error.context.update(context)
        else:
            # Convert to VoiceBotError; pass a copy so the stored error never
            # shares (and later mutates) the caller's dict
            error = VoiceBotError(
                str(error),
                context=dict(context),
                severity=ErrorSeverity.MEDIUM
            )

        # Track error
        error_key = f"{error.category.value}:{type(error).__name__}"
        self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1

        # Add to history (maintain size limit). Trimming by overflow count also
        # behaves correctly for max_history == 0, where a [-0:] slice would
        # keep the whole list.
        self.error_history.append(error)
        overflow = len(self.error_history) - self.max_history
        if overflow > 0:
            del self.error_history[:overflow]

        # Log error with context
        log_message = (
            f"Error in {context.get('operation', 'unknown')}: {error} "
            f"(Category: {error.category.value}, Severity: {error.severity.value})"
        )
        if error.severity in (ErrorSeverity.HIGH, ErrorSeverity.CRITICAL):
            logger.error(log_message)
            logger.error(f"Error context: {error.context}")
            # __traceback__ always exists on exceptions but is None for errors
            # that were never raised, so test the value rather than hasattr()
            if error.__traceback__ is not None:
                formatted_tb = "".join(traceback.format_tb(error.__traceback__))
                logger.error(f"Traceback: {formatted_tb}")
        else:
            logger.warning(log_message)

        # Notify client if WebSocket available
        if websocket is not None and error.severity != ErrorSeverity.LOW:
            await self._notify_client_error(websocket, error, session_id)

        # Attempt recovery if provided and error is recoverable
        if recovery_action and error.recoverable:
            try:
                await recovery_action()
                logger.info(f"Recovery action succeeded for {error_key}")
                return True
            except Exception as recovery_error:
                # Fall through to the severity-based result below
                logger.error(f"Recovery action failed for {error_key}: {recovery_error}")

        return error.severity in (ErrorSeverity.LOW, ErrorSeverity.MEDIUM)

    async def _notify_client_error(
        self,
        websocket: WebSocket,
        error: VoiceBotError,
        session_id: Optional[str]
    ):
        """Notify client about error via WebSocket.

        Sends a JSON message of type ``"error"`` carrying the error's message,
        category, severity, recoverable flag and timestamp, plus the session id
        when one is given. Failures while sending are logged, not raised.
        """
        try:
            error_message = {
                "type": "error",
                "data": {
                    "message": str(error),
                    "category": error.category.value,
                    "severity": error.severity.value,
                    "recoverable": error.recoverable,
                    "timestamp": error.timestamp
                }
            }
            if session_id:
                error_message["data"]["session_id"] = session_id
            await websocket.send_json(error_message)
        except Exception as notify_error:
            logger.error(f"Failed to notify client of error: {notify_error}")

    def get_error_statistics(self) -> Dict[str, Any]:
        """Get error statistics for monitoring.

        Returns:
            A dict with ``error_counts`` (copy of the per-key counters),
            ``total_errors`` (sum of all counters, i.e. every error handled,
            not just those still in the bounded history) and ``recent_errors``
            (message, category, severity and timestamp of the last 10 errors).
        """
        return {
            "error_counts": dict(self.error_counts),
            # error_history is capped at max_history, so its length would
            # stop growing; the counters record every handled error
            "total_errors": sum(self.error_counts.values()),
            "recent_errors": [
                {
                    "message": str(err),
                    "category": err.category.value,
                    "severity": err.severity.value,
                    "timestamp": err.timestamp
                }
                for err in self.error_history[-10:]  # Last 10 errors
            ]
        }
# Global error handler instance, created at import time and used by the
# decorator shortcuts defined in this module
error_handler = ErrorHandler()
# Decorator shortcuts for common patterns
def with_websocket_error_handling(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator reporting failures of an async WebSocket operation.

    Any Exception raised by ``func`` is passed to the global error handler as
    a WebSocketError, with the function name and the first 200 characters of
    the positional arguments as context, and is then re-raised unchanged.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs) -> T:
        try:
            outcome = await func(*args, **kwargs)
        except Exception as e:
            reported = WebSocketError(f"WebSocket operation failed: {e}")
            details = {"function": func.__name__, "args": str(args)[:200]}
            await error_handler.handle_error(reported, context=details)
            raise
        return outcome

    return wrapper
def with_webrtc_error_handling(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator reporting failures of an async WebRTC operation.

    Any Exception raised by ``func`` is passed to the global error handler as
    a WebRTCError, with the function name and the first 200 characters of the
    positional arguments as context, and is then re-raised unchanged.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs) -> T:
        try:
            outcome = await func(*args, **kwargs)
        except Exception as e:
            reported = WebRTCError(f"WebRTC operation failed: {e}")
            details = {"function": func.__name__, "args": str(args)[:200]}
            await error_handler.handle_error(reported, context=details)
            raise
        return outcome

    return wrapper
def with_session_error_handling(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator reporting failures of an async Session operation.

    Any Exception raised by ``func`` is passed to the global error handler as
    a SessionError, with the function name and the first 200 characters of the
    positional arguments as context, and is then re-raised unchanged.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs) -> T:
        try:
            outcome = await func(*args, **kwargs)
        except Exception as e:
            reported = SessionError(f"Session operation failed: {e}")
            details = {"function": func.__name__, "args": str(args)[:200]}
            await error_handler.handle_error(reported, context=details)
            raise
        return outcome

    return wrapper