""" Rate limiting utilities for guest and authenticated users """ from __future__ import annotations from functools import wraps from datetime import datetime, timedelta, UTC from typing import Callable, Dict, Optional, Any from fastapi import Depends, HTTPException, Request from pydantic import BaseModel # type: ignore from database.manager import RedisDatabase from logger import logger from . dependencies import get_current_user_or_guest, get_database async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter: """Dependency to get rate limiter instance""" return RateLimiter(database) class RateLimitConfig(BaseModel): """Rate limit configuration""" requests_per_minute: int requests_per_hour: int requests_per_day: int burst_limit: int # Maximum requests in a short burst burst_window_seconds: int = 60 # Window for burst detection class GuestRateLimitConfig(RateLimitConfig): """Rate limits for guest users - more restrictive""" requests_per_minute: int = 10 requests_per_hour: int = 100 requests_per_day: int = 500 burst_limit: int = 15 burst_window_seconds: int = 60 class AuthenticatedUserRateLimitConfig(RateLimitConfig): """Rate limits for authenticated users - more generous""" requests_per_minute: int = 60 requests_per_hour: int = 1000 requests_per_day: int = 10000 burst_limit: int = 100 burst_window_seconds: int = 60 class PremiumUserRateLimitConfig(RateLimitConfig): """Rate limits for premium/admin users - most generous""" requests_per_minute: int = 120 requests_per_hour: int = 5000 requests_per_day: int = 50000 burst_limit: int = 200 burst_window_seconds: int = 60 class RateLimitResult(BaseModel): """Result of rate limit check""" allowed: bool reason: Optional[str] = None retry_after_seconds: Optional[int] = None remaining_requests: Dict[str, int] = {} reset_times: Dict[str, datetime] = {} class RateLimiter: """Rate limiter using Redis for distributed rate limiting""" def __init__(self, database: RedisDatabase): self.database = database self.redis = database.redis # Rate limit configurations self.guest_config = GuestRateLimitConfig() self.user_config = AuthenticatedUserRateLimitConfig() self.premium_config = PremiumUserRateLimitConfig() def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig: """Get rate limit configuration based on user type""" if user_type == "guest": return self.guest_config elif is_admin: return self.premium_config else: return self.user_config async def check_rate_limit( self, user_id: str, user_type: str, is_admin: bool = False, endpoint: Optional[str] = None ) -> RateLimitResult: """ Check if user has exceeded rate limits Args: user_id: Unique identifier for the user (guest session ID or user ID) user_type: "guest", "candidate", or "employer" is_admin: Whether user has admin privileges endpoint: Optional endpoint-specific rate limiting Returns: RateLimitResult indicating if request is allowed """ config = self.get_config_for_user(user_type, is_admin) current_time = datetime.now(UTC) # Create Redis keys for different time windows base_key = f"rate_limit:{user_type}:{user_id}" keys = { "minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}", "hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}", "day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}", "burst": f"{base_key}:burst" } # Add endpoint-specific limiting if provided if endpoint: keys = {k: f"{v}:{endpoint}" for k, v in keys.items()} try: # Use Redis pipeline for atomic operations pipe = self.redis.pipeline() # Get current counts for key in keys.values(): pipe.get(key) results = await pipe.execute() current_counts = { "minute": int(results[0] or 0), "hour": int(results[1] or 0), "day": int(results[2] or 0), "burst": int(results[3] or 0) } # Check limits limits = { "minute": config.requests_per_minute, "hour": config.requests_per_hour, "day": config.requests_per_day, "burst": config.burst_limit } # Check each limit for window, current_count in current_counts.items(): limit = limits[window] if current_count >= limit: # Calculate retry after time if window == "minute": retry_after = 60 - current_time.second elif window == "hour": retry_after = 3600 - (current_time.minute * 60 + current_time.second) elif window == "day": retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second) else: # burst retry_after = config.burst_window_seconds logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}") return RateLimitResult( allowed=False, reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}", retry_after_seconds=retry_after, remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()}, reset_times=self._calculate_reset_times(current_time) ) # If we get here, request is allowed - increment counters pipe = self.redis.pipeline() # Increment minute counter (expires after 2 minutes) pipe.incr(keys["minute"]) pipe.expire(keys["minute"], 120) # Increment hour counter (expires after 2 hours) pipe.incr(keys["hour"]) pipe.expire(keys["hour"], 7200) # Increment day counter (expires after 2 days) pipe.incr(keys["day"]) pipe.expire(keys["day"], 172800) # Increment burst counter (expires after burst window) pipe.incr(keys["burst"]) pipe.expire(keys["burst"], config.burst_window_seconds) await pipe.execute() # Calculate remaining requests remaining = { k: max(0, limits[k] - (current_counts[k] + 1)) for k in current_counts.keys() } logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}") return RateLimitResult( allowed=True, remaining_requests=remaining, reset_times=self._calculate_reset_times(current_time) ) except Exception as e: logger.error(f"❌ Rate limit check failed for {user_id}: {e}") # Fail open - allow request if rate limiting system fails return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request") def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]: """Calculate when each rate limit window resets""" next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1) next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1) next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1) return { "minute": next_minute, "hour": next_hour, "day": next_day } async def get_user_rate_limit_status( self, user_id: str, user_type: str, is_admin: bool = False ) -> Dict[str, Any]: """Get current rate limit status for a user""" config = self.get_config_for_user(user_type, is_admin) current_time = datetime.now(UTC) base_key = f"rate_limit:{user_type}:{user_id}" keys = { "minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}", "hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}", "day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}", "burst": f"{base_key}:burst" } try: pipe = self.redis.pipeline() for key in keys.values(): pipe.get(key) results = await pipe.execute() current_counts = { "minute": int(results[0] or 0), "hour": int(results[1] or 0), "day": int(results[2] or 0), "burst": int(results[3] or 0) } limits = { "minute": config.requests_per_minute, "hour": config.requests_per_hour, "day": config.requests_per_day, "burst": config.burst_limit } return { "user_id": user_id, "user_type": user_type, "is_admin": is_admin, "current_usage": current_counts, "limits": limits, "remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()}, "reset_times": self._calculate_reset_times(current_time), "config": config.model_dump() } except Exception as e: logger.error(f"❌ Failed to get rate limit status for {user_id}: {e}") return {"error": str(e)} async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool: """Reset all rate limits for a user (admin function)""" try: base_key = f"rate_limit:{user_type}:{user_id}" pattern = f"{base_key}:*" cursor = 0 deleted_count = 0 while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) if keys: await self.redis.delete(*keys) deleted_count += len(keys) if cursor == 0: break logger.info(f"🔄 Reset {deleted_count} rate limit keys for {user_type} {user_id}") return True except Exception as e: logger.error(f"❌ Failed to reset rate limits for {user_id}: {e}") return False # ============================ # Rate Limited Decorator # ============================ def rate_limited( guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120, endpoint_specific: bool = True ): """ Decorator to easily apply rate limiting to endpoints Args: guest_per_minute: Rate limit for guest users user_per_minute: Rate limit for authenticated users admin_per_minute: Rate limit for admin users endpoint_specific: Whether to apply endpoint-specific limits Usage: @rate_limited(guest_per_minute=5, user_per_minute=30) @api_router.post("/my-endpoint") async def my_endpoint( request: Request, current_user = Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database) ): return {"message": "Rate limited endpoint"} """ def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(*args, **kwargs): # Extract dependencies from function signature import inspect inspect.signature(func) # Get request, current_user, and rate_limiter from kwargs or args request = None current_user = None rate_limiter = None # Try to find dependencies in kwargs first for param_name, param_value in kwargs.items(): if isinstance(param_value, Request): request = param_value elif hasattr(param_value, 'user_type'): # User-like object current_user = param_value elif isinstance(param_value, RateLimiter): rate_limiter = param_value # If not found in kwargs, check if they're provided via Depends if not rate_limiter: # Create rate limiter instance (this should ideally come from DI) database = get_database() rate_limiter = RateLimiter(database) # Apply rate limiting if we have the required components if request and current_user and rate_limiter: await apply_custom_rate_limiting( request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute ) # Call the original function return await func(*args, **kwargs) return wrapper return decorator async def apply_custom_rate_limiting( request: Request, current_user, rate_limiter: RateLimiter, guest_per_minute: int, user_per_minute: int, admin_per_minute: int ): """Apply custom rate limiting with specified limits""" try: # Determine user info user_id = current_user.id user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type) is_admin = getattr(current_user, 'is_admin', False) # Determine appropriate limit if is_admin: requests_per_minute = admin_per_minute elif user_type == "guest": requests_per_minute = guest_per_minute else: requests_per_minute = user_per_minute # Create custom rate limit key current_time = datetime.now(UTC) custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}" # Check current usage current_count = int(await rate_limiter.redis.get(custom_key) or 0) if current_count >= requests_per_minute: logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}") raise HTTPException( status_code=429, detail={ "error": "Rate limit exceeded", "message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute", "retryAfter": 60 - current_time.second, "userType": user_type, "endpoint": request.url.path }, headers={"Retry-After": str(60 - current_time.second)} ) # Increment counter pipe = rate_limiter.redis.pipeline() pipe.incr(custom_key) pipe.expire(custom_key, 120) # 2 minutes TTL await pipe.execute() logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}") except HTTPException: raise except Exception as e: logger.error(f"❌ Custom rate limiting error: {e}") # Fail open # ============================ # Alternative: FastAPI Dependency-Based Rate Limiting # ============================ def create_rate_limit_dependency( guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120 ): """ Create a FastAPI dependency for rate limiting Usage: rate_limit_5_30 = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30) @api_router.post("/my-endpoint") async def my_endpoint( rate_check = Depends(rate_limit_5_30), current_user = Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database) ): return {"message": "Rate limited endpoint"} """ async def rate_limit_dependency( request: Request, current_user = Depends(get_current_user_or_guest), rate_limiter: RateLimiter = Depends(get_rate_limiter) ): await apply_custom_rate_limiting( request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute ) return True return rate_limit_dependency # ============================ # Rate Limiting Utilities # ============================ class EndpointRateLimiter: """Utility class for endpoint-specific rate limiting""" def __init__(self, rate_limiter: RateLimiter): self.rate_limiter = rate_limiter self.custom_limits = {} def set_endpoint_limits(self, endpoint: str, limits: dict): """Set custom limits for an endpoint""" self.custom_limits[endpoint] = limits async def check_endpoint_limit(self, request: Request, current_user) -> bool: """Check if request exceeds endpoint-specific limits""" endpoint = request.url.path if endpoint not in self.custom_limits: return True # No custom limits set limits = self.custom_limits[endpoint] user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type) if getattr(current_user, 'is_admin', False): user_type = "admin" limit = limits.get(user_type, limits.get("default", 60)) current_time = datetime.now(UTC) key = f"endpoint_limit:{endpoint}:{user_type}:{current_user.id}:minute:{current_time.strftime('%Y%m%d%H%M')}" current_count = int(await self.rate_limiter.redis.get(key) or 0) if current_count >= limit: raise HTTPException( status_code=429, detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}" ) # Increment counter await self.rate_limiter.redis.incr(key) await self.rate_limiter.redis.expire(key, 120) return True # Global endpoint rate limiter instance endpoint_rate_limiter = None def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter: """Get endpoint rate limiter instance""" global endpoint_rate_limiter if endpoint_rate_limiter is None: endpoint_rate_limiter = EndpointRateLimiter(rate_limiter) # Configure endpoint-specific limits endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", { "guest": 5, "candidate": 30, "employer": 30, "admin": 100 }) endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", { "guest": 2, "candidate": 10, "employer": 10, "admin": 50 }) endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", { "guest": 1, "candidate": 5, "employer": 20, "admin": 50 }) return endpoint_rate_limiter