387 lines
13 KiB
Python
387 lines
13 KiB
Python
"""
|
|
Enhanced Error Handling and Recovery
|
|
|
|
This module provides robust error handling, recovery mechanisms, and resilience patterns
|
|
for the AI Voice Bot server. It includes custom exceptions, circuit breakers, retry logic,
|
|
and graceful degradation strategies.
|
|
"""
|
|
|
|
import asyncio
import inspect
import time
import traceback
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic

from fastapi import WebSocket

from logger import logger
|
|
|
|
# Generic return type shared by the decorator signatures in this module.
T = TypeVar("T")
|
|
|
|
|
|
class ErrorSeverity(Enum):
    """How serious an error is; used to pick logging level and client handling."""

    # Minor issue: operation can simply continue.
    LOW = "low"
    # Moderate issue: the user may need to be told.
    MEDIUM = "medium"
    # Serious issue: functionality may be affected.
    HIGH = "high"
    # Critical issue: needs immediate attention.
    CRITICAL = "critical"
|
|
|
|
|
|
class ErrorCategory(Enum):
    """Subsystem an error belongs to; part of the key errors are counted under."""

    WEBSOCKET = "websocket"
    WEBRTC = "webrtc"
    SESSION = "session"
    LOBBY = "lobby"
    AUTH = "auth"
    PERSISTENCE = "persistence"
    NETWORK = "network"
    VALIDATION = "validation"
    SYSTEM = "system"
|
|
|
|
|
|
# Custom Exception Classes
|
|
class VoiceBotError(Exception):
    """Base exception for all voice bot related errors.

    Args:
        message: Human-readable description; becomes ``str(error)``.
        category: Subsystem the error belongs to.
        severity: How serious the error is.
        context: Extra diagnostic key/value pairs. The mapping is copied, so
            later changes to ``self.context`` never leak back into the dict
            the caller passed in (and vice versa).
        recoverable: Whether a recovery action may be attempted for it.

    Attributes:
        timestamp: Creation time as returned by ``time.time()``.
    """

    def __init__(
        self,
        message: str,
        category: ErrorCategory = ErrorCategory.SYSTEM,
        severity: ErrorSeverity = ErrorSeverity.MEDIUM,
        context: Optional[Dict[str, Any]] = None,
        recoverable: bool = True,
    ):
        super().__init__(message)
        self.category = category
        self.severity = severity
        # Copy rather than alias: handlers add keys to error.context, and that
        # must not mutate a dict still owned (and possibly reused) by the caller.
        self.context: Dict[str, Any] = dict(context) if context else {}
        self.recoverable = recoverable
        self.timestamp = time.time()
|
|
|
|
|
|
class WebSocketError(VoiceBotError):
    """Error from a WebSocket operation; always ``ErrorCategory.WEBSOCKET``.

    Other keyword arguments are forwarded to :class:`VoiceBotError`; passing
    ``category`` explicitly raises ``TypeError`` (duplicate keyword).
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(message, **kwargs, category=ErrorCategory.WEBSOCKET)
|
|
|
|
|
|
class WebRTCError(VoiceBotError):
    """Error from WebRTC signaling or connection handling; always ``ErrorCategory.WEBRTC``.

    Other keyword arguments are forwarded to :class:`VoiceBotError`; passing
    ``category`` explicitly raises ``TypeError`` (duplicate keyword).
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(message, **kwargs, category=ErrorCategory.WEBRTC)
|
|
|
|
|
|
class SessionError(VoiceBotError):
    """Error from session management; always ``ErrorCategory.SESSION``.

    Other keyword arguments are forwarded to :class:`VoiceBotError`; passing
    ``category`` explicitly raises ``TypeError`` (duplicate keyword).
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(message, **kwargs, category=ErrorCategory.SESSION)
|
|
|
|
|
|
class LobbyError(VoiceBotError):
    """Error from lobby management; always ``ErrorCategory.LOBBY``.

    Other keyword arguments are forwarded to :class:`VoiceBotError`; passing
    ``category`` explicitly raises ``TypeError`` (duplicate keyword).
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(message, **kwargs, category=ErrorCategory.LOBBY)
|
|
|
|
|
|
class AuthError(VoiceBotError):
    """Authentication or authorization failure; always ``ErrorCategory.AUTH``.

    Other keyword arguments are forwarded to :class:`VoiceBotError`; passing
    ``category`` explicitly raises ``TypeError`` (duplicate keyword).
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(message, **kwargs, category=ErrorCategory.AUTH)
|
|
|
|
|
|
class PersistenceError(VoiceBotError):
    """Data persistence or storage failure; always ``ErrorCategory.PERSISTENCE``.

    Other keyword arguments are forwarded to :class:`VoiceBotError`; passing
    ``category`` explicitly raises ``TypeError`` (duplicate keyword).
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(message, **kwargs, category=ErrorCategory.PERSISTENCE)
|
|
|
|
|
|
class ValidationError(VoiceBotError):
    """Invalid input; always ``ErrorCategory.VALIDATION``.

    Other keyword arguments are forwarded to :class:`VoiceBotError`; passing
    ``category`` explicitly raises ``TypeError`` (duplicate keyword).
    """

    def __init__(self, message: str, **kwargs):
        super().__init__(message, **kwargs, category=ErrorCategory.VALIDATION)
|
|
|
|
|
|
@dataclass
class CircuitBreakerState:
    """Mutable bookkeeping owned by one circuit breaker instance.

    Constructor arguments follow the field order below.
    """

    # Failures counted since the last successful call.
    failures: int = 0
    # time.time() of the most recent counted failure; 0 until one happens.
    last_failure_time: float = 0
    # True while calls are being rejected.
    is_open: bool = False
    # time.time() of the most recent success; 0 until one happens.
    last_success_time: float = 0
|
|
|
|
|
|
class CircuitBreaker:
    """Circuit breaker pattern implementation for preventing cascading failures.

    Apply an instance to an ``async def`` function. Every call that raises
    ``expected_exception`` increments ``state.failures``; once it reaches
    ``failure_threshold`` the circuit opens and calls fail fast with a
    non-recoverable, HIGH-severity :class:`VoiceBotError` without invoking the
    function. When ``recovery_timeout`` seconds have passed since the last
    counted failure, the next call closes the circuit again ("half-open" in the
    log). Because the failure counter is not reset at that point, one more
    expected failure re-opens it immediately, while a success resets the counter.

    Exceptions that are not instances of ``expected_exception`` propagate
    without touching the breaker state.

    Note: the state lives on the instance, so decorating several functions
    with the same instance makes them share one failure counter.

    Args:
        failure_threshold: Counted failures needed to open the circuit.
        recovery_timeout: Seconds the circuit stays open after the last failure.
        expected_exception: Exception type counted as a failure.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception,
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.state = CircuitBreakerState()
        # Monotonic reading taken at the most recent counted failure. The
        # recovery window is measured with this instead of the wall-clock
        # state.last_failure_time, so system clock adjustments cannot keep the
        # circuit open indefinitely or close it early.
        self._last_failure_monotonic = 0.0

    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            if self.state.is_open:
                elapsed = time.monotonic() - self._last_failure_monotonic
                if elapsed < self.recovery_timeout:
                    raise VoiceBotError(
                        f"Circuit breaker open for {func.__name__}",
                        severity=ErrorSeverity.HIGH,
                        recoverable=False,
                    )
                # Recovery window elapsed: let calls through again (half-open).
                self.state.is_open = False
                logger.info(f"Circuit breaker half-open for {func.__name__}")

            try:
                result = await func(*args, **kwargs)
            except self.expected_exception:
                self.state.failures += 1
                self.state.last_failure_time = time.time()
                self._last_failure_monotonic = time.monotonic()

                if self.state.failures >= self.failure_threshold:
                    self.state.is_open = True
                    logger.error(
                        f"Circuit breaker opened for {func.__name__} "
                        f"after {self.state.failures} failures"
                    )
                # Bare raise re-raises the original exception unchanged.
                raise

            # Success - reset failure count
            self.state.failures = 0
            self.state.last_success_time = time.time()
            return result

        return wrapper
|
|
|
|
|
|
class RetryStrategy:
    """Retry strategy with exponential backoff, used as a decorator.

    Apply an instance to an ``async def`` function. When a call raises an
    instance of one of ``retriable_exceptions``, it is retried. After failed
    attempt ``k`` (1-based), the wrapper sleeps
    ``min(base_delay * backoff_factor ** (k - 1), max_delay)`` seconds. Once
    ``max_attempts`` attempts have failed, the last exception is re-raised.
    Non-retriable exceptions propagate immediately, without retry.

    Args:
        max_attempts: Total number of calls, including the first. Must be >= 1.
        base_delay: Delay in seconds after the first failed attempt.
        max_delay: Upper bound on any single delay, in seconds.
        backoff_factor: Multiplier applied to the delay after each failure.
        retriable_exceptions: Exception types that trigger a retry. ``None``
            (the default) means any ``Exception``; an empty list means nothing
            is retried.

    Raises:
        ValueError: If ``max_attempts`` is less than 1. Such a strategy would
            never call the function; previously this surfaced later as an
            unrelated ``TypeError`` from ``raise None``.
    """

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_factor: float = 2.0,
        retriable_exceptions: Optional[List[type]] = None,
    ):
        if max_attempts < 1:
            raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor
        # `is None` instead of `or`: an explicit empty list must mean "retry
        # nothing", not silently fall back to "retry every Exception".
        self.retriable_exceptions = (
            [Exception] if retriable_exceptions is None else list(retriable_exceptions)
        )

    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            # Read at call time so later changes to the attribute still apply.
            retriable = tuple(self.retriable_exceptions)

            for attempt in range(1, self.max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if not isinstance(e, retriable):
                        raise

                    if attempt == self.max_attempts:
                        logger.error(
                            f"Function {func.__name__} failed after {self.max_attempts} attempts"
                        )
                        raise

                    # Calculate delay with exponential backoff
                    delay = min(
                        self.base_delay * (self.backoff_factor ** (attempt - 1)),
                        self.max_delay,
                    )
                    logger.warning(
                        f"Function {func.__name__} failed (attempt {attempt}/{self.max_attempts}), "
                        f"retrying in {delay:.2f}s: {e}"
                    )
                    await asyncio.sleep(delay)

            # Only reachable if max_attempts was lowered below 1 after
            # construction; fail loudly instead of returning None.
            raise RuntimeError(
                f"Function {func.__name__} was not attempted (max_attempts={self.max_attempts})"
            )

        return wrapper
|
|
|
|
|
|
class ErrorHandler:
    """Central error handler with context and recovery strategies.

    ``handle_error`` does the following for each error:

    * wraps arbitrary exceptions in :class:`VoiceBotError`;
    * counts it under a ``"<category>:<class name>"`` key;
    * appends it to a bounded history;
    * logs it at a level chosen by severity;
    * optionally notifies a WebSocket client;
    * optionally runs a recovery action.

    Attributes:
        error_counts: Handled-error counts keyed by ``"<category>:<class name>"``.
        error_history: Most recent handled errors, oldest first, at most
            ``max_history`` entries.
        max_history: Maximum length of ``error_history``.
    """

    def __init__(self):
        self.error_counts: Dict[str, int] = {}
        self.error_history: List[VoiceBotError] = []
        self.max_history = 100

    async def handle_error(
        self,
        error: Exception,
        context: Optional[Dict[str, Any]] = None,
        websocket: Optional[WebSocket] = None,
        session_id: Optional[str] = None,
        recovery_action: Optional[Callable] = None,
    ) -> bool:
        """
        Handle an error with context and optional recovery.

        Args:
            error: Exception to handle. Anything that is not a VoiceBotError is
                wrapped in one (category SYSTEM, severity MEDIUM) whose
                message is ``str(error)``.
            context: Diagnostic data merged into ``error.context`` (its keys
                win on conflict). ``None`` is treated as an empty dict.
            websocket: When given, the client is sent an ``"error"`` message
                for every severity except LOW.
            session_id: Added to the client notification when truthy.
            recovery_action: Zero-argument callable run for recoverable errors.
                It may be a plain function or return an awaitable.

        Returns:
            bool: True if the recovery action ran without raising. Otherwise
            True for LOW/MEDIUM severity and False for HIGH/CRITICAL.
        """
        context = dict(context) if context else {}

        # Convert to VoiceBotError if needed
        if not isinstance(error, VoiceBotError):
            error = VoiceBotError(
                str(error),
                context=context,
                severity=ErrorSeverity.MEDIUM,
            )

        # Merge into a fresh dict: error.context may be the very dict the code
        # that raised the error still owns, which must not change behind its back.
        error.context = {**error.context, **context}

        error_key = f"{error.category.value}:{type(error).__name__}"
        self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1
        self._remember(error)
        self._log(error)

        # Notify client if WebSocket available
        if websocket is not None and error.severity != ErrorSeverity.LOW:
            await self._notify_client_error(websocket, error, session_id)

        # Attempt recovery if provided and error is recoverable
        if recovery_action is not None and error.recoverable:
            try:
                outcome = recovery_action()
                # Await only when needed: awaiting a sync action's None result
                # used to raise TypeError and be misreported as a failure.
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as recovery_error:
                logger.error(f"Recovery action failed for {error_key}: {recovery_error}")
            else:
                logger.info(f"Recovery action succeeded for {error_key}")
                return True

        return error.severity in (ErrorSeverity.LOW, ErrorSeverity.MEDIUM)

    def _remember(self, error: VoiceBotError) -> None:
        """Append *error* to the history, dropping the oldest entries beyond max_history."""
        self.error_history.append(error)
        # An overflow count instead of `lst[-max_history:]`, because `lst[-0:]`
        # is the whole list and made max_history = 0 keep everything.
        overflow = len(self.error_history) - max(self.max_history, 0)
        if overflow > 0:
            del self.error_history[:overflow]

    def _log(self, error: VoiceBotError) -> None:
        """Log *error*. HIGH/CRITICAL also log context and traceback at ERROR level."""
        # The decorators in this module label the failing call with
        # "function", so fall back to it when no "operation" is given.
        operation = error.context.get("operation", error.context.get("function", "unknown"))
        log_message = (
            f"Error in {operation}: {error} "
            f"(Category: {error.category.value}, Severity: {error.severity.value})"
        )

        if error.severity in (ErrorSeverity.HIGH, ErrorSeverity.CRITICAL):
            logger.error(log_message)
            logger.error(f"Error context: {error.context}")
            # __traceback__ always exists as an attribute but is None for
            # errors that were never raised, such as freshly wrapped ones.
            if error.__traceback__ is not None:
                frames = "".join(traceback.format_tb(error.__traceback__))
                logger.error(f"Traceback:\n{frames}")
        else:
            logger.warning(log_message)

    async def _notify_client_error(
        self,
        websocket: WebSocket,
        error: VoiceBotError,
        session_id: Optional[str],
    ):
        """Send an ``{"type": "error", "data": {...}}`` JSON message to the client.

        Any exception while building or sending the message is logged and
        swallowed, so notification problems never mask the original error.
        """
        try:
            error_message = {
                "type": "error",
                "data": {
                    "message": str(error),
                    "category": error.category.value,
                    "severity": error.severity.value,
                    "recoverable": error.recoverable,
                    "timestamp": error.timestamp,
                },
            }
            if session_id:
                error_message["data"]["session_id"] = session_id

            await websocket.send_json(error_message)
        except Exception as notify_error:
            logger.error(f"Failed to notify client of error: {notify_error}")

    def get_error_statistics(self) -> Dict[str, Any]:
        """Get error statistics for monitoring.

        Returns:
            A dict with these keys:

            * ``error_counts``: a copy of the per-key counters.
            * ``total_errors``: every error handled so far, not capped by the
              history size.
            * ``history_size``: the current length of ``error_history``.
            * ``recent_errors``: summaries of up to the last 10 errors.
        """
        return {
            "error_counts": dict(self.error_counts),
            "total_errors": sum(self.error_counts.values()),
            "history_size": len(self.error_history),
            "recent_errors": [
                {
                    "message": str(err),
                    "category": err.category.value,
                    "severity": err.severity.value,
                    "timestamp": err.timestamp,
                }
                for err in self.error_history[-10:]  # Last 10 errors
            ],
        }
|
|
|
|
|
|
# Module-wide ErrorHandler; the decorator helpers in this module report to it.
error_handler = ErrorHandler()
|
|
|
|
|
|
# Decorator shortcuts for common patterns
|
|
def with_websocket_error_handling(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator for WebSocket operations with error handling.

    Any ``Exception`` raised by the wrapped coroutine is reported to the
    module-level ``error_handler`` as a ``WebSocketError``. The report uses the
    default severity and has no WebSocket attached, so the client is not
    notified. The ORIGINAL exception is then re-raised unchanged, so callers'
    ``except`` clauses keep working.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs) -> T:
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            # Some exceptions (e.g. a bare TimeoutError()) stringify to "";
            # fall back to the class name so the report is not blank.
            detail = str(e) or type(e).__name__
            await error_handler.handle_error(
                WebSocketError(f"WebSocket operation failed: {detail}"),
                context={
                    # handle_error names the failing call via "operation" in
                    # its log line; without it every report said "unknown".
                    "operation": func.__name__,
                    "function": func.__name__,
                    # Truncated to keep log lines bounded.
                    "args": str(args)[:200],
                },
            )
            raise

    return wrapper
|
|
|
|
|
|
def with_webrtc_error_handling(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator for WebRTC operations with error handling.

    Any ``Exception`` raised by the wrapped coroutine is reported to the
    module-level ``error_handler`` as a ``WebRTCError``. The report uses the
    default severity and has no WebSocket attached, so the client is not
    notified. The ORIGINAL exception is then re-raised unchanged, so callers'
    ``except`` clauses keep working.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs) -> T:
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            # Some exceptions (e.g. a bare TimeoutError()) stringify to "";
            # fall back to the class name so the report is not blank.
            detail = str(e) or type(e).__name__
            await error_handler.handle_error(
                WebRTCError(f"WebRTC operation failed: {detail}"),
                context={
                    # handle_error names the failing call via "operation" in
                    # its log line; without it every report said "unknown".
                    "operation": func.__name__,
                    "function": func.__name__,
                    # Truncated to keep log lines bounded.
                    "args": str(args)[:200],
                },
            )
            raise

    return wrapper
|
|
|
|
|
|
def with_session_error_handling(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator for Session operations with error handling.

    Any ``Exception`` raised by the wrapped coroutine is reported to the
    module-level ``error_handler`` as a ``SessionError``. The report uses the
    default severity and has no WebSocket attached, so the client is not
    notified. The ORIGINAL exception is then re-raised unchanged, so callers'
    ``except`` clauses keep working.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs) -> T:
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            # Some exceptions (e.g. a bare TimeoutError()) stringify to "";
            # fall back to the class name so the report is not blank.
            detail = str(e) or type(e).__name__
            await error_handler.handle_error(
                SessionError(f"Session operation failed: {detail}"),
                context={
                    # handle_error names the failing call via "operation" in
                    # its log line; without it every report said "unknown".
                    "operation": func.__name__,
                    "function": func.__name__,
                    # Truncated to keep log lines bounded.
                    "args": str(args)[:200],
                },
            )
            raise

    return wrapper
|