backstory/src/backend/utils/rate_limiter.py

526 lines
20 KiB
Python

"""
Rate limiting utilities for guest and authenticated users
"""
from __future__ import annotations
from functools import wraps
from datetime import datetime, timedelta, UTC
from typing import Callable, Dict, Optional, Any
from fastapi import Depends, HTTPException, Request
from pydantic import BaseModel # type: ignore
from database.manager import RedisDatabase
from logger import logger
from . dependencies import get_current_user_or_guest, get_database
async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
    """FastAPI dependency that wraps the injected database in a ``RateLimiter``.

    A fresh limiter is built on every call. It only holds the database, its
    Redis client and the static limit configurations, so construction is cheap.
    """
    limiter = RateLimiter(database)
    return limiter
class RateLimitConfig(BaseModel):
    """Rate limit configuration

    Per-user request budgets read by ``RateLimiter``. The minute/hour/day
    limits are counted in buckets derived from the current UTC time.
    """
    requests_per_minute: int  # max requests in the current UTC minute bucket
    requests_per_hour: int  # max requests in the current UTC hour bucket
    requests_per_day: int  # max requests in the current UTC day bucket
    burst_limit: int # Maximum requests in a short burst
    burst_window_seconds: int = 60 # Window for burst detection; lifetime (seconds) of the burst counter key
class GuestRateLimitConfig(RateLimitConfig):
    """Rate limits for guest users - more restrictive

    Selected by ``RateLimiter.get_config_for_user`` whenever
    ``user_type == "guest"`` (guests never get admin limits).
    """
    requests_per_minute: int = 10
    requests_per_hour: int = 100
    requests_per_day: int = 500
    burst_limit: int = 15
    burst_window_seconds: int = 60
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
    """Rate limits for authenticated users - more generous

    Selected by ``RateLimiter.get_config_for_user`` for any non-guest user
    that is not an admin.
    """
    requests_per_minute: int = 60
    requests_per_hour: int = 1000
    requests_per_day: int = 10000
    burst_limit: int = 100
    burst_window_seconds: int = 60
class PremiumUserRateLimitConfig(RateLimitConfig):
    """Rate limits for premium/admin users - most generous

    Selected by ``RateLimiter.get_config_for_user`` for non-guest users with
    ``is_admin=True``. NOTE(review): no separate "premium" flag is checked in
    this module - confirm whether premium users are meant to get these limits.
    """
    requests_per_minute: int = 120
    requests_per_hour: int = 5000
    requests_per_day: int = 50000
    burst_limit: int = 200
    burst_window_seconds: int = 60
class RateLimitResult(BaseModel):
    """Result of rate limit check

    Returned by ``RateLimiter.check_rate_limit``. The ``{}`` defaults are safe:
    pydantic copies default values per instance.
    """
    allowed: bool  # False when any window's limit has been reached
    reason: Optional[str] = None  # human-readable explanation (rejection or fail-open)
    retry_after_seconds: Optional[int] = None  # set only on rejection
    remaining_requests: Dict[str, int] = {}  # keyed by "minute", "hour", "day", "burst"
    reset_times: Dict[str, datetime] = {}  # next UTC reset for "minute", "hour", "day"
class RateLimiter:
"""Rate limiter using Redis for distributed rate limiting"""
def __init__(self, database: RedisDatabase):
self.database = database
self.redis = database.redis
# Rate limit configurations
self.guest_config = GuestRateLimitConfig()
self.user_config = AuthenticatedUserRateLimitConfig()
self.premium_config = PremiumUserRateLimitConfig()
def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig:
"""Get rate limit configuration based on user type"""
if user_type == "guest":
return self.guest_config
elif is_admin:
return self.premium_config
else:
return self.user_config
async def check_rate_limit(
self,
user_id: str,
user_type: str,
is_admin: bool = False,
endpoint: Optional[str] = None
) -> RateLimitResult:
"""
Check if user has exceeded rate limits
Args:
user_id: Unique identifier for the user (guest session ID or user ID)
user_type: "guest", "candidate", or "employer"
is_admin: Whether user has admin privileges
endpoint: Optional endpoint-specific rate limiting
Returns:
RateLimitResult indicating if request is allowed
"""
config = self.get_config_for_user(user_type, is_admin)
current_time = datetime.now(UTC)
# Create Redis keys for different time windows
base_key = f"rate_limit:{user_type}:{user_id}"
keys = {
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
}
# Add endpoint-specific limiting if provided
if endpoint:
keys = {k: f"{v}:{endpoint}" for k, v in keys.items()}
try:
# Use Redis pipeline for atomic operations
pipe = self.redis.pipeline()
# Get current counts
for key in keys.values():
pipe.get(key)
results = await pipe.execute()
current_counts = {
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
}
# Check limits
limits = {
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
}
# Check each limit
for window, current_count in current_counts.items():
limit = limits[window]
if current_count >= limit:
# Calculate retry after time
if window == "minute":
retry_after = 60 - current_time.second
elif window == "hour":
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
elif window == "day":
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
else: # burst
retry_after = config.burst_window_seconds
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
return RateLimitResult(
allowed=False,
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
retry_after_seconds=retry_after,
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
reset_times=self._calculate_reset_times(current_time)
)
# If we get here, request is allowed - increment counters
pipe = self.redis.pipeline()
# Increment minute counter (expires after 2 minutes)
pipe.incr(keys["minute"])
pipe.expire(keys["minute"], 120)
# Increment hour counter (expires after 2 hours)
pipe.incr(keys["hour"])
pipe.expire(keys["hour"], 7200)
# Increment day counter (expires after 2 days)
pipe.incr(keys["day"])
pipe.expire(keys["day"], 172800)
# Increment burst counter (expires after burst window)
pipe.incr(keys["burst"])
pipe.expire(keys["burst"], config.burst_window_seconds)
await pipe.execute()
# Calculate remaining requests
remaining = {
k: max(0, limits[k] - (current_counts[k] + 1))
for k in current_counts.keys()
}
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
return RateLimitResult(
allowed=True,
remaining_requests=remaining,
reset_times=self._calculate_reset_times(current_time)
)
except Exception as e:
logger.error(f"❌ Rate limit check failed for {user_id}: {e}")
# Fail open - allow request if rate limiting system fails
return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request")
def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]:
"""Calculate when each rate limit window resets"""
next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1)
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
return {
"minute": next_minute,
"hour": next_hour,
"day": next_day
}
async def get_user_rate_limit_status(
self,
user_id: str,
user_type: str,
is_admin: bool = False
) -> Dict[str, Any]:
"""Get current rate limit status for a user"""
config = self.get_config_for_user(user_type, is_admin)
current_time = datetime.now(UTC)
base_key = f"rate_limit:{user_type}:{user_id}"
keys = {
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
}
try:
pipe = self.redis.pipeline()
for key in keys.values():
pipe.get(key)
results = await pipe.execute()
current_counts = {
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
}
limits = {
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
}
return {
"user_id": user_id,
"user_type": user_type,
"is_admin": is_admin,
"current_usage": current_counts,
"limits": limits,
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
"reset_times": self._calculate_reset_times(current_time),
"config": config.model_dump()
}
except Exception as e:
logger.error(f"❌ Failed to get rate limit status for {user_id}: {e}")
return {"error": str(e)}
async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool:
"""Reset all rate limits for a user (admin function)"""
try:
base_key = f"rate_limit:{user_type}:{user_id}"
pattern = f"{base_key}:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
if keys:
await self.redis.delete(*keys)
deleted_count += len(keys)
if cursor == 0:
break
logger.info(f"🔄 Reset {deleted_count} rate limit keys for {user_type} {user_id}")
return True
except Exception as e:
logger.error(f"❌ Failed to reset rate limits for {user_id}: {e}")
return False
# ============================
# Rate Limited Decorator
# ============================
def rate_limited(
    guest_per_minute: int = 10,
    user_per_minute: int = 60,
    admin_per_minute: int = 120,
    endpoint_specific: bool = True
):
    """
    Decorator to easily apply rate limiting to endpoints

    The endpoint must receive a ``Request`` and a user object (anything with a
    ``user_type`` attribute) as keyword arguments. If either is missing, the
    endpoint runs without rate limiting. A ``RateLimiter`` or ``RedisDatabase``
    keyword argument is reused when present.

    Args:
        guest_per_minute: Rate limit for guest users
        user_per_minute: Rate limit for authenticated users
        admin_per_minute: Rate limit for admin users
        endpoint_specific: Accepted for backward compatibility but currently
            unused; limits are always keyed by the request path.

    Usage (the route decorator must be the OUTERMOST one, so the router
    registers the rate-limited wrapper rather than the bare function):
        @api_router.post("/my-endpoint")
        @rate_limited(guest_per_minute=5, user_per_minute=30)
        async def my_endpoint(
            request: Request,
            current_user = Depends(get_current_user_or_guest),
            database: RedisDatabase = Depends(get_database)
        ):
            return {"message": "Rate limited endpoint"}
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = None
            current_user = None
            rate_limiter = None
            database = None

            # FastAPI passes resolved parameters and dependencies as keyword arguments
            for value in kwargs.values():
                if isinstance(value, Request):
                    request = value
                elif hasattr(value, 'user_type'):  # User-like object
                    current_user = value
                elif isinstance(value, RateLimiter):
                    rate_limiter = value
                elif isinstance(value, RedisDatabase):
                    database = value

            if request is not None and current_user is not None:
                if rate_limiter is None:
                    if database is None:
                        import inspect
                        # NOTE(review): get_database is a FastAPI dependency; calling
                        # it directly assumes it hands back the shared database.
                        # Await it in case it is a coroutine function.
                        database = get_database()
                        if inspect.isawaitable(database):
                            database = await database
                    rate_limiter = RateLimiter(database)
                await apply_custom_rate_limiting(
                    request, current_user, rate_limiter,
                    guest_per_minute, user_per_minute, admin_per_minute
                )

            # Call the original function
            return await func(*args, **kwargs)
        return wrapper
    return decorator
async def apply_custom_rate_limiting(
    request: Request,
    current_user,
    rate_limiter: RateLimiter,
    guest_per_minute: int,
    user_per_minute: int,
    admin_per_minute: int
):
    """Enforce a per-path, per-user limit on requests per UTC minute.

    Args:
        request: Incoming request; its URL path scopes the counter.
        current_user: User or guest object exposing ``id`` and ``user_type``
            (an enum or string) and optionally ``is_admin``.
        rate_limiter: Supplies the Redis client.
        guest_per_minute / user_per_minute / admin_per_minute: Limits per
            minute. The admin limit takes precedence over the guest limit.

    Raises:
        HTTPException: 429 with a ``Retry-After`` header when the limit is
            exceeded. Any other error is logged and the request is allowed
            (fail open).
    """
    try:
        # Determine user info
        user_id = current_user.id
        user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
        is_admin = getattr(current_user, 'is_admin', False)

        # Determine appropriate limit
        if is_admin:
            requests_per_minute = admin_per_minute
        elif user_type == "guest":
            requests_per_minute = guest_per_minute
        else:
            requests_per_minute = user_per_minute

        # Create custom rate limit key (one bucket per UTC minute)
        current_time = datetime.now(UTC)
        custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"

        # Count first, compare afterwards. INCR is atomic, so concurrent requests
        # each see a distinct count and cannot all slip through on one stale
        # read, as they could with GET-then-INCR.
        pipe = rate_limiter.redis.pipeline()
        pipe.incr(custom_key)
        pipe.expire(custom_key, 120)  # 2 minutes TTL; the key is per-minute anyway
        new_count, _ = await pipe.execute()
        current_count = int(new_count) - 1  # requests counted before this one

        if current_count >= requests_per_minute:
            # Rejected requests must not consume quota, so undo our increment.
            # The TTL was just refreshed, so the key cannot expire in between.
            try:
                await rate_limiter.redis.decr(custom_key)
            except Exception as undo_error:
                logger.error(f"❌ Failed to roll back custom rate limit counter: {undo_error}")
            retry_after = 60 - current_time.second
            logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}")
            raise HTTPException(
                status_code=429,
                detail={
                    "error": "Rate limit exceeded",
                    "message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
                    "retryAfter": retry_after,
                    "userType": user_type,
                    "endpoint": request.url.path
                },
                headers={"Retry-After": str(retry_after)}
            )

        logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Custom rate limiting error: {e}")
        # Fail open: a rate-limiter/Redis failure must not block the endpoint
# ============================
# Alternative: FastAPI Dependency-Based Rate Limiting
# ============================
def create_rate_limit_dependency(
    guest_per_minute: int = 10,
    user_per_minute: int = 60,
    admin_per_minute: int = 120
):
    """
    Build a FastAPI dependency that enforces per-minute, per-path request limits.

    The returned coroutine function receives the request, the current user (or
    guest) and a RateLimiter through FastAPI's dependency injection. It then
    delegates to ``apply_custom_rate_limiting`` with the limits captured here,
    which raises HTTP 429 when the limit is exceeded. Otherwise it returns True.

    Usage:
        rate_limit_5_30 = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30)

        @api_router.post("/my-endpoint")
        async def my_endpoint(
            rate_check = Depends(rate_limit_5_30),
            current_user = Depends(get_current_user_or_guest),
            database: RedisDatabase = Depends(get_database)
        ):
            return {"message": "Rate limited endpoint"}
    """
    async def rate_limit_dependency(
        request: Request,
        current_user = Depends(get_current_user_or_guest),
        rate_limiter: RateLimiter = Depends(get_rate_limiter)
    ):
        await apply_custom_rate_limiting(
            request=request,
            current_user=current_user,
            rate_limiter=rate_limiter,
            guest_per_minute=guest_per_minute,
            user_per_minute=user_per_minute,
            admin_per_minute=admin_per_minute,
        )
        return True

    return rate_limit_dependency
# ============================
# Rate Limiting Utilities
# ============================
class EndpointRateLimiter:
"""Utility class for endpoint-specific rate limiting"""
def __init__(self, rate_limiter: RateLimiter):
self.rate_limiter = rate_limiter
self.custom_limits = {}
def set_endpoint_limits(self, endpoint: str, limits: dict):
"""Set custom limits for an endpoint"""
self.custom_limits[endpoint] = limits
async def check_endpoint_limit(self, request: Request, current_user) -> bool:
"""Check if request exceeds endpoint-specific limits"""
endpoint = request.url.path
if endpoint not in self.custom_limits:
return True # No custom limits set
limits = self.custom_limits[endpoint]
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
if getattr(current_user, 'is_admin', False):
user_type = "admin"
limit = limits.get(user_type, limits.get("default", 60))
current_time = datetime.now(UTC)
key = f"endpoint_limit:{endpoint}:{user_type}:{current_user.id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
current_count = int(await self.rate_limiter.redis.get(key) or 0)
if current_count >= limit:
raise HTTPException(
status_code=429,
detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
)
# Increment counter
await self.rate_limiter.redis.incr(key)
await self.rate_limiter.redis.expire(key, 120)
return True
# Global endpoint rate limiter instance, created lazily by get_endpoint_rate_limiter()
endpoint_rate_limiter = None


def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
    """Return the process-wide EndpointRateLimiter, building it on first use.

    The ``rate_limiter`` from the first call is captured and reused by every
    later call. NOTE(review): this assumes the injected database/Redis client
    is shared across requests - confirm against get_database.
    """
    global endpoint_rate_limiter
    if endpoint_rate_limiter is not None:
        return endpoint_rate_limiter

    # Per-minute request budgets by user type ("admin" applies to admin users)
    default_endpoint_limits = {
        "/api/1.0/chat/sessions/*/messages/stream": {
            "guest": 5, "candidate": 30, "employer": 30, "admin": 100
        },
        "/api/1.0/candidates/documents/upload": {
            "guest": 2, "candidate": 10, "employer": 10, "admin": 50
        },
        "/api/1.0/jobs": {
            "guest": 1, "candidate": 5, "employer": 20, "admin": 50
        },
    }
    limiter = EndpointRateLimiter(rate_limiter)
    for path, per_user_type in default_endpoint_limits.items():
        limiter.set_endpoint_limits(path, per_user_type)

    endpoint_rate_limiter = limiter
    return limiter