"""
Rate limiting utilities for guest and authenticated users
"""
|
|
from __future__ import annotations
|
|
from functools import wraps
|
|
from datetime import datetime, timedelta, UTC
|
|
from typing import Callable, Dict, Optional, Any
|
|
from fastapi import Depends, HTTPException, Request
|
|
from pydantic import BaseModel # type: ignore
|
|
from database.manager import RedisDatabase
|
|
from logger import logger
|
|
|
|
from . dependencies import get_current_user_or_guest, get_database
|
|
|
|
async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
    """FastAPI dependency that wraps the injected database in a RateLimiter.

    A new RateLimiter is constructed on every resolution; nothing is cached
    here.
    """
    limiter = RateLimiter(database)
    return limiter
|
|
|
|
class RateLimitConfig(BaseModel):
    """Base schema for a set of request quotas.

    The three ``requests_per_*`` fields and ``burst_limit`` have no default
    here and must be supplied (the subclasses in this module do so).
    """

    # Quota for each fixed time window.
    requests_per_minute: int
    requests_per_hour: int
    requests_per_day: int

    # Maximum number of requests tolerated in a short burst.
    burst_limit: int

    # Length, in seconds, of the window used for burst detection.
    burst_window_seconds: int = 60
|
|
|
|
class GuestRateLimitConfig(RateLimitConfig):
    """Quota defaults for guest users (the lowest of the three tiers)."""

    # Fixed-window quotas.
    requests_per_minute: int = 10
    requests_per_hour: int = 100
    requests_per_day: int = 500

    # Burst allowance and the window (seconds) it is measured over.
    burst_limit: int = 15
    burst_window_seconds: int = 60
|
|
|
|
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
    """Quota defaults for signed-in, non-admin users (the middle tier)."""

    # Fixed-window quotas.
    requests_per_minute: int = 60
    requests_per_hour: int = 1_000
    requests_per_day: int = 10_000

    # Burst allowance and the window (seconds) it is measured over.
    burst_limit: int = 100
    burst_window_seconds: int = 60
|
|
|
|
class PremiumUserRateLimitConfig(RateLimitConfig):
    """Quota defaults for premium/admin users (the highest tier)."""

    # Fixed-window quotas.
    requests_per_minute: int = 120
    requests_per_hour: int = 5_000
    requests_per_day: int = 50_000

    # Burst allowance and the window (seconds) it is measured over.
    burst_limit: int = 200
    burst_window_seconds: int = 60
|
|
|
|
class RateLimitResult(BaseModel):
    """Outcome returned by a rate limit check."""

    # True when the request may proceed.
    allowed: bool

    # Human-readable explanation; left as None when nothing needs explaining.
    reason: Optional[str] = None

    # Suggested wait, in seconds, before the caller tries again.
    retry_after_seconds: Optional[int] = None

    # Requests still available, keyed by window name.
    remaining_requests: Dict[str, int] = {}

    # Moment at which each window rolls over, keyed by window name.
    reset_times: Dict[str, datetime] = {}
|
|
|
|
class RateLimiter:
|
|
"""Rate limiter using Redis for distributed rate limiting"""
|
|
|
|
def __init__(self, database: RedisDatabase):
|
|
self.database = database
|
|
self.redis = database.redis
|
|
|
|
# Rate limit configurations
|
|
self.guest_config = GuestRateLimitConfig()
|
|
self.user_config = AuthenticatedUserRateLimitConfig()
|
|
self.premium_config = PremiumUserRateLimitConfig()
|
|
|
|
def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig:
|
|
"""Get rate limit configuration based on user type"""
|
|
if user_type == "guest":
|
|
return self.guest_config
|
|
elif is_admin:
|
|
return self.premium_config
|
|
else:
|
|
return self.user_config
|
|
|
|
async def check_rate_limit(
|
|
self,
|
|
user_id: str,
|
|
user_type: str,
|
|
is_admin: bool = False,
|
|
endpoint: Optional[str] = None
|
|
) -> RateLimitResult:
|
|
"""
|
|
Check if user has exceeded rate limits
|
|
|
|
Args:
|
|
user_id: Unique identifier for the user (guest session ID or user ID)
|
|
user_type: "guest", "candidate", or "employer"
|
|
is_admin: Whether user has admin privileges
|
|
endpoint: Optional endpoint-specific rate limiting
|
|
|
|
Returns:
|
|
RateLimitResult indicating if request is allowed
|
|
"""
|
|
config = self.get_config_for_user(user_type, is_admin)
|
|
current_time = datetime.now(UTC)
|
|
|
|
# Create Redis keys for different time windows
|
|
base_key = f"rate_limit:{user_type}:{user_id}"
|
|
keys = {
|
|
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
|
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
|
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
|
"burst": f"{base_key}:burst"
|
|
}
|
|
|
|
# Add endpoint-specific limiting if provided
|
|
if endpoint:
|
|
keys = {k: f"{v}:{endpoint}" for k, v in keys.items()}
|
|
|
|
try:
|
|
# Use Redis pipeline for atomic operations
|
|
pipe = self.redis.pipeline()
|
|
|
|
# Get current counts
|
|
for key in keys.values():
|
|
pipe.get(key)
|
|
|
|
results = await pipe.execute()
|
|
current_counts = {
|
|
"minute": int(results[0] or 0),
|
|
"hour": int(results[1] or 0),
|
|
"day": int(results[2] or 0),
|
|
"burst": int(results[3] or 0)
|
|
}
|
|
|
|
# Check limits
|
|
limits = {
|
|
"minute": config.requests_per_minute,
|
|
"hour": config.requests_per_hour,
|
|
"day": config.requests_per_day,
|
|
"burst": config.burst_limit
|
|
}
|
|
|
|
# Check each limit
|
|
for window, current_count in current_counts.items():
|
|
limit = limits[window]
|
|
if current_count >= limit:
|
|
# Calculate retry after time
|
|
if window == "minute":
|
|
retry_after = 60 - current_time.second
|
|
elif window == "hour":
|
|
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
|
|
elif window == "day":
|
|
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
|
|
else: # burst
|
|
retry_after = config.burst_window_seconds
|
|
|
|
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
|
|
|
|
return RateLimitResult(
|
|
allowed=False,
|
|
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
|
|
retry_after_seconds=retry_after,
|
|
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
|
|
reset_times=self._calculate_reset_times(current_time)
|
|
)
|
|
|
|
# If we get here, request is allowed - increment counters
|
|
pipe = self.redis.pipeline()
|
|
|
|
# Increment minute counter (expires after 2 minutes)
|
|
pipe.incr(keys["minute"])
|
|
pipe.expire(keys["minute"], 120)
|
|
|
|
# Increment hour counter (expires after 2 hours)
|
|
pipe.incr(keys["hour"])
|
|
pipe.expire(keys["hour"], 7200)
|
|
|
|
# Increment day counter (expires after 2 days)
|
|
pipe.incr(keys["day"])
|
|
pipe.expire(keys["day"], 172800)
|
|
|
|
# Increment burst counter (expires after burst window)
|
|
pipe.incr(keys["burst"])
|
|
pipe.expire(keys["burst"], config.burst_window_seconds)
|
|
|
|
await pipe.execute()
|
|
|
|
# Calculate remaining requests
|
|
remaining = {
|
|
k: max(0, limits[k] - (current_counts[k] + 1))
|
|
for k in current_counts.keys()
|
|
}
|
|
|
|
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
|
|
|
|
return RateLimitResult(
|
|
allowed=True,
|
|
remaining_requests=remaining,
|
|
reset_times=self._calculate_reset_times(current_time)
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ Rate limit check failed for {user_id}: {e}")
|
|
# Fail open - allow request if rate limiting system fails
|
|
return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request")
|
|
|
|
def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]:
|
|
"""Calculate when each rate limit window resets"""
|
|
next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1)
|
|
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
|
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
|
|
|
return {
|
|
"minute": next_minute,
|
|
"hour": next_hour,
|
|
"day": next_day
|
|
}
|
|
|
|
async def get_user_rate_limit_status(
|
|
self,
|
|
user_id: str,
|
|
user_type: str,
|
|
is_admin: bool = False
|
|
) -> Dict[str, Any]:
|
|
"""Get current rate limit status for a user"""
|
|
config = self.get_config_for_user(user_type, is_admin)
|
|
current_time = datetime.now(UTC)
|
|
|
|
base_key = f"rate_limit:{user_type}:{user_id}"
|
|
keys = {
|
|
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
|
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
|
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
|
"burst": f"{base_key}:burst"
|
|
}
|
|
|
|
try:
|
|
pipe = self.redis.pipeline()
|
|
for key in keys.values():
|
|
pipe.get(key)
|
|
|
|
results = await pipe.execute()
|
|
current_counts = {
|
|
"minute": int(results[0] or 0),
|
|
"hour": int(results[1] or 0),
|
|
"day": int(results[2] or 0),
|
|
"burst": int(results[3] or 0)
|
|
}
|
|
|
|
limits = {
|
|
"minute": config.requests_per_minute,
|
|
"hour": config.requests_per_hour,
|
|
"day": config.requests_per_day,
|
|
"burst": config.burst_limit
|
|
}
|
|
|
|
return {
|
|
"user_id": user_id,
|
|
"user_type": user_type,
|
|
"is_admin": is_admin,
|
|
"current_usage": current_counts,
|
|
"limits": limits,
|
|
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
|
|
"reset_times": self._calculate_reset_times(current_time),
|
|
"config": config.model_dump()
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ Failed to get rate limit status for {user_id}: {e}")
|
|
return {"error": str(e)}
|
|
|
|
async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool:
|
|
"""Reset all rate limits for a user (admin function)"""
|
|
try:
|
|
base_key = f"rate_limit:{user_type}:{user_id}"
|
|
pattern = f"{base_key}:*"
|
|
|
|
cursor = 0
|
|
deleted_count = 0
|
|
|
|
while True:
|
|
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
|
|
if keys:
|
|
await self.redis.delete(*keys)
|
|
deleted_count += len(keys)
|
|
|
|
if cursor == 0:
|
|
break
|
|
|
|
logger.info(f"🔄 Reset {deleted_count} rate limit keys for {user_type} {user_id}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ Failed to reset rate limits for {user_id}: {e}")
|
|
return False
|
|
|
|
|
|
# ============================
|
|
# Rate Limited Decorator
|
|
# ============================
|
|
|
|
def rate_limited(
    guest_per_minute: int = 10,
    user_per_minute: int = 60,
    admin_per_minute: int = 120,
    endpoint_specific: bool = True
):
    """
    Decorator to easily apply rate limiting to endpoints

    The wrapper inspects the arguments the endpoint is called with, looking
    for a ``Request``, a user-like object (anything with a ``user_type``
    attribute) and optionally a ``RateLimiter``. If a request and a user are
    both present the per-minute limit is enforced before the endpoint runs;
    otherwise the endpoint is called without any limiting.

    Args:
        guest_per_minute: Rate limit for guest users
        user_per_minute: Rate limit for authenticated users
        admin_per_minute: Rate limit for admin users
        endpoint_specific: Currently not read by the implementation; accepted
            so existing call sites keep working. Counters are keyed by
            request path regardless of this flag.

    Usage:
        The route decorator must be the outermost one so that the
        rate-limited wrapper is the function that gets registered:

        @api_router.post("/my-endpoint")
        @rate_limited(guest_per_minute=5, user_per_minute=30)
        async def my_endpoint(
            request: Request,
            current_user = Depends(get_current_user_or_guest),
            database: RedisDatabase = Depends(get_database)
        ):
            return {"message": "Rate limited endpoint"}
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = None
            current_user = None
            rate_limiter = None

            # Scan positional arguments first and keyword arguments second so
            # that a keyword argument wins when both carry the same kind of
            # object.
            for value in (*args, *kwargs.values()):
                if isinstance(value, Request):
                    request = value
                elif hasattr(value, 'user_type'):  # User-like object
                    current_user = value
                elif isinstance(value, RateLimiter):
                    rate_limiter = value

            # Apply rate limiting only if we have the required components
            if request is not None and current_user is not None:
                if rate_limiter is None:
                    import inspect

                    # Fallback when no RateLimiter was injected (this should
                    # ideally come from DI). get_database is defined outside
                    # this module; await its result if it turns out to be
                    # awaitable.
                    database = get_database()
                    if inspect.isawaitable(database):
                        database = await database
                    rate_limiter = RateLimiter(database)

                await apply_custom_rate_limiting(
                    request, current_user, rate_limiter,
                    guest_per_minute, user_per_minute, admin_per_minute
                )

            # Call the original function
            return await func(*args, **kwargs)

        return wrapper
    return decorator
|
|
|
|
async def apply_custom_rate_limiting(
    request: Request,
    current_user,
    rate_limiter: RateLimiter,
    guest_per_minute: int,
    user_per_minute: int,
    admin_per_minute: int
):
    """Apply custom rate limiting with specified limits

    Counts the request in a per-path, per-user, per-minute Redis counter and
    rejects it when the counter exceeds the limit that applies to the user.

    Args:
        request: Incoming request; its URL path is part of the counter key.
        current_user: Object exposing ``id`` and ``user_type`` (a plain value
            or something with a ``.value``), and optionally ``is_admin``.
        rate_limiter: Supplies the Redis client through ``rate_limiter.redis``.
        guest_per_minute: Limit used when the user type is "guest".
        user_per_minute: Limit used for every other non-admin user.
        admin_per_minute: Limit used when ``is_admin`` is truthy (checked
            before the guest test).

    Raises:
        HTTPException: 429 with a ``Retry-After`` header when the limit is
            exceeded. Any other exception is logged and swallowed so the
            request proceeds (fail open).
    """
    try:
        # Determine user info
        user_id = current_user.id
        raw_type = current_user.user_type
        user_type = raw_type.value if hasattr(raw_type, 'value') else str(raw_type)
        is_admin = getattr(current_user, 'is_admin', False)

        # Determine appropriate limit
        if is_admin:
            requests_per_minute = admin_per_minute
        elif user_type == "guest":
            requests_per_minute = guest_per_minute
        else:
            requests_per_minute = user_per_minute

        # Create custom rate limit key
        current_time = datetime.now(UTC)
        custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"

        # Count first, then decide. INCR returns the post-increment value, so
        # concurrent requests each see a distinct number and cannot all slip
        # under the limit the way a separate GET followed by INCR allows.
        # Rejected requests are counted too; the key is scoped to the current
        # minute, so that has no effect beyond this window.
        pipe = rate_limiter.redis.pipeline()
        pipe.incr(custom_key)
        pipe.expire(custom_key, 120)  # 2 minutes TTL
        results = await pipe.execute()
        request_number = int(results[0])

        if request_number > requests_per_minute:
            # Requests already recorded in this window before the current one
            current_count = request_number - 1
            retry_after = 60 - current_time.second

            logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}")
            raise HTTPException(
                status_code=429,
                detail={
                    "error": "Rate limit exceeded",
                    "message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
                    "retryAfter": retry_after,
                    "userType": user_type,
                    "endpoint": request.url.path
                },
                headers={"Retry-After": str(retry_after)}
            )

        logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {request_number}/{requests_per_minute}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Custom rate limiting error: {e}")
        # Fail open
|
|
|
|
# ============================
|
|
# Alternative: FastAPI Dependency-Based Rate Limiting
|
|
# ============================
|
|
|
|
def create_rate_limit_dependency(
    guest_per_minute: int = 10,
    user_per_minute: int = 60,
    admin_per_minute: int = 120
):
    """
    Build a FastAPI dependency that enforces the given per-minute limits

    The returned coroutine function receives the request, the current user
    and a RateLimiter through dependency injection, delegates to
    apply_custom_rate_limiting and resolves to True when that call returns.

    Usage:
        rate_limit_5_30 = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30)

        @api_router.post("/my-endpoint")
        async def my_endpoint(
            rate_check = Depends(rate_limit_5_30),
            current_user = Depends(get_current_user_or_guest),
            database: RedisDatabase = Depends(get_database)
        ):
            return {"message": "Rate limited endpoint"}
    """
    async def rate_limit_dependency(
        request: Request,
        current_user = Depends(get_current_user_or_guest),
        rate_limiter: RateLimiter = Depends(get_rate_limiter)
    ):
        # The three limits are captured from the enclosing factory call.
        await apply_custom_rate_limiting(
            request=request,
            current_user=current_user,
            rate_limiter=rate_limiter,
            guest_per_minute=guest_per_minute,
            user_per_minute=user_per_minute,
            admin_per_minute=admin_per_minute,
        )
        return True

    return rate_limit_dependency
|
|
|
|
# ============================
|
|
# Rate Limiting Utilities
|
|
# ============================
|
|
|
|
class EndpointRateLimiter:
|
|
"""Utility class for endpoint-specific rate limiting"""
|
|
|
|
def __init__(self, rate_limiter: RateLimiter):
|
|
self.rate_limiter = rate_limiter
|
|
self.custom_limits = {}
|
|
|
|
def set_endpoint_limits(self, endpoint: str, limits: dict):
|
|
"""Set custom limits for an endpoint"""
|
|
self.custom_limits[endpoint] = limits
|
|
|
|
async def check_endpoint_limit(self, request: Request, current_user) -> bool:
|
|
"""Check if request exceeds endpoint-specific limits"""
|
|
endpoint = request.url.path
|
|
|
|
if endpoint not in self.custom_limits:
|
|
return True # No custom limits set
|
|
|
|
limits = self.custom_limits[endpoint]
|
|
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
|
|
|
|
if getattr(current_user, 'is_admin', False):
|
|
user_type = "admin"
|
|
|
|
limit = limits.get(user_type, limits.get("default", 60))
|
|
|
|
current_time = datetime.now(UTC)
|
|
key = f"endpoint_limit:{endpoint}:{user_type}:{current_user.id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
|
|
|
|
current_count = int(await self.rate_limiter.redis.get(key) or 0)
|
|
|
|
if current_count >= limit:
|
|
raise HTTPException(
|
|
status_code=429,
|
|
detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
|
|
)
|
|
|
|
# Increment counter
|
|
await self.rate_limiter.redis.incr(key)
|
|
await self.rate_limiter.redis.expire(key, 120)
|
|
|
|
return True
|
|
|
|
# Module-level singleton; stays None until the dependency below first runs
endpoint_rate_limiter = None


def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
    """Return the shared EndpointRateLimiter, building it on first use

    The instance is created once, from the RateLimiter supplied on the first
    call, and registered with the built-in per-endpoint limits. Later calls
    return that same instance and ignore their ``rate_limiter`` argument.
    """
    global endpoint_rate_limiter

    if endpoint_rate_limiter is not None:
        return endpoint_rate_limiter

    endpoint_rate_limiter = EndpointRateLimiter(rate_limiter)

    # Built-in limits: (endpoint path, requests allowed per user type).
    # NOTE(review): the first path contains a "*"; whether it ever matches a
    # real request depends on how EndpointRateLimiter looks paths up.
    builtin_limits = (
        ("/api/1.0/chat/sessions/*/messages/stream", {
            "guest": 5, "candidate": 30, "employer": 30, "admin": 100
        }),
        ("/api/1.0/candidates/documents/upload", {
            "guest": 2, "candidate": 10, "employer": 10, "admin": 50
        }),
        ("/api/1.0/jobs", {
            "guest": 1, "candidate": 5, "employer": 20, "admin": 50
        }),
    )
    for path, per_type in builtin_limits:
        endpoint_rate_limiter.set_endpoint_limits(path, per_type)

    return endpoint_rate_limiter
|
|
|