514 lines
21 KiB
Python
514 lines
21 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
import json
|
|
import logging
|
|
from typing import Any, Dict, TYPE_CHECKING, List, Optional
|
|
|
|
from .protocols import DatabaseProtocol
|
|
|
|
from ..constants import KEY_PREFIXES
|
|
|
|
# Module-scoped logger; its name is this module's dotted import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
class AuthMixin(DatabaseProtocol):
    """Mixin for auth-related database operations.

    All storage goes through ``self.redis`` (provided by the class this mixin
    is combined with). Key layout used here:

    * ``email_verification:<token>`` - email verification tokens (24h TTL)
    * ``verification_attempts:<email>`` - timestamps of recent verification emails
    * ``mfa_code:<email>:<device_id>`` - one-time MFA codes (10 min TTL)
    * ``auth:<user_id>`` - authentication records (no TTL)
    * ``refresh_token:<token>`` - refresh token metadata
    * ``password_reset:<token>`` - password reset tokens
    * ``security_log:<user_id>:<YYYY-MM-DD>`` - per-day security event lists

    Every public method catches ``Exception``, logs it, and reports failure
    through a falsy return value (``None``, ``0``, ``False`` or ``[]``).
    """

    async def _auth_scan_json(self, pattern: str):
        """Asynchronously yield ``(key, data)`` for each key matching *pattern*.

        ``key`` is passed through exactly as the Redis client returned it
        (``str`` or ``bytes``), so it can be reused for GET/SET/DELETE.
        ``data`` is the value decoded as a JSON object.

        Keys that vanish between SCAN and GET are skipped silently. Values
        that are not valid JSON objects are skipped with a warning, so one
        corrupt entry cannot abort a whole scan. The key itself is not logged
        because it may embed a secret token.
        """
        cursor = 0
        while True:
            cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
            for key in keys:
                raw = await self.redis.get(key)
                if not raw:
                    # Expired or deleted after SCAN returned it.
                    continue
                try:
                    data = json.loads(raw)
                except (TypeError, ValueError):
                    logger.warning("⚠️ Skipping malformed JSON value for a key matching %s", pattern)
                    continue
                if not isinstance(data, dict):
                    logger.warning("⚠️ Skipping non-object JSON value for a key matching %s", pattern)
                    continue
                yield key, data
            if cursor == 0:
                break

    async def _auth_set_keep_ttl(self, key: Any, value: str, fallback_ttl: Optional[int] = None) -> None:
        """Overwrite *key* with *value* without extending or dropping its expiry.

        A plain SET clears a key's TTL, and SETEX with a constant restarts it.
        So the remaining TTL is read first and re-applied.

        If Redis reports no positive remaining TTL (key missing, no expiry, or
        expiring imminently), *fallback_ttl* seconds are applied when given.
        Otherwise the value is stored without an expiry.
        """
        remaining = await self.redis.ttl(key)
        if remaining and remaining > 0:
            await self.redis.setex(key, remaining, value)
        elif fallback_ttl is not None:
            await self.redis.setex(key, fallback_ttl, value)
        else:
            await self.redis.set(key, value)

    async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Find a pending (not yet verified) verification token for *email*.

        This scans every ``email_verification:*`` key, so its cost grows with
        the number of outstanding tokens.

        Args:
            email: Address to look up; compared case-insensitively.

        Returns:
            The stored verification data plus a ``"token"`` entry holding the
            token taken from the key. Returns ``None`` if no pending token
            exists or an error occurred.
        """
        prefix = "email_verification:"
        try:
            email_lower = email.lower()
            async for key, verification_info in self._auth_scan_json(f"{prefix}*"):
                if verification_info.get("verified", False):
                    continue
                # `or ""` guards against an explicit null email in the payload.
                if str(verification_info.get("email") or "").lower() != email_lower:
                    continue
                if isinstance(key, bytes):
                    key = key.decode("utf-8", errors="replace")
                # Strip only the leading prefix. str.replace would also remove
                # the prefix text if it ever appeared inside the token.
                verification_info["token"] = key[len(prefix):]
                return verification_info
            return None
        except Exception as e:
            logger.error(f"❌ Error finding verification token by email {email}: {e}")
            return None

    async def get_pending_verifications_count(self) -> int:
        """Count email verification tokens not yet marked verified (admin function).

        Returns:
            The number of pending tokens, or 0 on error.
        """
        try:
            count = 0
            async for _key, verification_info in self._auth_scan_json("email_verification:*"):
                if not verification_info.get("verified", False):
                    count += 1
            return count
        except Exception as e:
            logger.error(f"❌ Error counting pending verifications: {e}")
            return 0

    async def cleanup_expired_verification_tokens(self) -> int:
        """Delete verification tokens whose ``expires_at`` is in the past.

        Tokens with a missing or unparsable ``expires_at`` are skipped rather
        than aborting the whole cleanup.

        Returns:
            The number of tokens deleted. Returns 0 if an error stopped the scan.
        """
        try:
            cleaned_count = 0
            current_time = datetime.now(timezone.utc)

            async for key, verification_info in self._auth_scan_json("email_verification:*"):
                try:
                    expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
                except (TypeError, ValueError):
                    logger.warning("⚠️ Skipping verification token with missing or invalid expires_at")
                    continue
                if expires_at.tzinfo is None:
                    # This mixin writes UTC timestamps. Treat naive ones as UTC
                    # so comparing with an aware datetime cannot raise.
                    expires_at = expires_at.replace(tzinfo=timezone.utc)

                if current_time > expires_at:
                    await self.redis.delete(key)
                    cleaned_count += 1
                    logger.info(f"🧹 Cleaned expired verification token for {verification_info.get('email')}")

            if cleaned_count > 0:
                logger.info(f"🧹 Cleaned up {cleaned_count} expired verification tokens")

            return cleaned_count

        except Exception as e:
            logger.error(f"❌ Error cleaning up expired verification tokens: {e}")
            return 0

    async def get_verification_attempts_count(self, email: str) -> int:
        """Count verification emails sent to *email* in the last 24 hours.

        Returns:
            The number of recorded attempts newer than 24 hours ago. Returns 0
            if none are recorded or on error.
        """
        try:
            key = f"verification_attempts:{email.lower()}"
            data = await self.redis.get(key)

            if not data:
                return 0

            attempts_data = json.loads(data)
            window_start = datetime.now(timezone.utc) - timedelta(hours=24)

            # Only timestamps inside the rolling 24h window count.
            recent_attempts = [
                attempt for attempt in attempts_data
                if datetime.fromisoformat(attempt) > window_start
            ]

            return len(recent_attempts)

        except Exception as e:
            logger.error(f"❌ Error getting verification attempts count for {email}: {e}")
            return 0

    async def record_verification_attempt(self, email: str) -> bool:
        """Append the current time to *email*'s verification-attempt history.

        Entries older than 24 hours are pruned. The list is rewritten with a
        fresh 24h TTL.

        Returns:
            True on success, False on error.
        """
        try:
            key = f"verification_attempts:{email.lower()}"
            current_time = datetime.now(timezone.utc)

            data = await self.redis.get(key)
            attempts_data = json.loads(data) if data else []

            attempts_data.append(current_time.isoformat())

            # Keep only the last 24 hours of attempts.
            window_start = current_time - timedelta(hours=24)
            recent_attempts = [
                attempt for attempt in attempts_data
                if datetime.fromisoformat(attempt) > window_start
            ]

            await self.redis.setex(
                key,
                24 * 60 * 60,  # 24 hours
                json.dumps(recent_attempts)
            )

            return True

        except Exception as e:
            logger.error(f"❌ Error recording verification attempt for {email}: {e}")
            return False

    async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
        """Store an email verification token with the pending user's data.

        The token expires after 24 hours, both via the key TTL and the stored
        ``expires_at``.

        Returns:
            True on success, False on error.
        """
        try:
            key = f"email_verification:{token}"
            # A single "now" keeps created_at and expires_at exactly 24h apart.
            now = datetime.now(timezone.utc)
            verification_data = {
                "email": email.lower(),
                "user_type": user_type,
                "user_data": user_data,
                "expires_at": (now + timedelta(hours=24)).isoformat(),
                "created_at": now.isoformat(),
                "verified": False
            }

            await self.redis.setex(
                key,
                24 * 60 * 60,  # 24 hours in seconds
                json.dumps(verification_data, default=str)
            )

            logger.info(f"📧 Stored email verification token for {email}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing email verification token: {e}")
            return False

    async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the stored data for *token*, or None if absent or on error."""
        try:
            key = f"email_verification:{token}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving email verification token: {e}")
            return None

    async def mark_email_verified(self, token: str) -> bool:
        """Mark *token* as verified, keeping the key's remaining TTL.

        Returns:
            True if the token existed and was updated. False if it was missing
            or an error occurred.
        """
        try:
            key = f"email_verification:{token}"
            token_data = await self.get_email_verification_token(token)
            if token_data:
                token_data["verified"] = True
                token_data["verified_at"] = datetime.now(timezone.utc).isoformat()
                # Re-apply the remaining TTL instead of restarting a full 24h.
                await self._auth_set_keep_ttl(
                    key,
                    json.dumps(token_data, default=str),
                    fallback_ttl=24 * 60 * 60,
                )
                return True
            return False
        except Exception as e:
            logger.error(f"❌ Error marking email verified: {e}")
            return False

    async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
        """Store an MFA code for (*email*, *device_id*) with a 10-minute lifetime.

        Any existing code for the same email/device pair is overwritten, and
        its attempt counter is reset to 0.

        Returns:
            True on success, False on error.
        """
        try:
            logger.info("🔐 Storing MFA code for email: %s", email)
            key = f"mfa_code:{email.lower()}:{device_id}"
            now = datetime.now(timezone.utc)
            # NOTE(review): the code is stored in plaintext; consider storing a
            # hash if Redis contents are not considered trusted.
            mfa_data = {
                "code": code,
                "email": email.lower(),
                "device_id": device_id,
                "expires_at": (now + timedelta(minutes=10)).isoformat(),
                "created_at": now.isoformat(),
                "attempts": 0,
                "verified": False
            }

            await self.redis.setex(
                key,
                10 * 60,  # 10 minutes in seconds
                json.dumps(mfa_data, default=str)
            )

            logger.info(f"🔐 Stored MFA code for {email}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing MFA code: {e}")
            return False

    async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored MFA data for (*email*, *device_id*), or None if absent or on error."""
        try:
            key = f"mfa_code:{email.lower()}:{device_id}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving MFA code: {e}")
            return None

    async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
        """Increment the verification-attempt counter of a stored MFA code.

        The code's remaining TTL is preserved. Restarting it would let every
        failed attempt extend the code's lifetime.

        Returns:
            The new attempt count. Returns 0 if no code exists or on error.
        """
        try:
            key = f"mfa_code:{email.lower()}:{device_id}"
            mfa_data = await self.get_mfa_code(email, device_id)
            if mfa_data:
                mfa_data["attempts"] = int(mfa_data.get("attempts", 0)) + 1
                await self._auth_set_keep_ttl(
                    key,
                    json.dumps(mfa_data, default=str),
                    fallback_ttl=10 * 60,
                )
                return mfa_data["attempts"]
            return 0
        except Exception as e:
            logger.error(f"❌ Error incrementing MFA attempts: {e}")
            return 0

    async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
        """Mark a stored MFA code as verified, keeping its remaining TTL.

        Returns:
            True if the code existed and was updated. False if it was missing
            or an error occurred.
        """
        try:
            key = f"mfa_code:{email.lower()}:{device_id}"
            mfa_data = await self.get_mfa_code(email, device_id)
            if mfa_data:
                mfa_data["verified"] = True
                mfa_data["verified_at"] = datetime.now(timezone.utc).isoformat()
                await self._auth_set_keep_ttl(
                    key,
                    json.dumps(mfa_data, default=str),
                    fallback_ttl=10 * 60,
                )
                return True
            return False
        except Exception as e:
            logger.error(f"❌ Error marking MFA verified: {e}")
            return False

    async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
        """Store (without expiry) the authentication record for *user_id*.

        Returns:
            True on success, False on error.
        """
        try:
            key = f"auth:{user_id}"
            await self.redis.set(key, json.dumps(auth_data, default=str))
            logger.info(f"🔐 Stored authentication record for user {user_id}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing authentication record for {user_id}: {e}")
            return False

    async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the authentication record for *user_id*, or None if absent or on error."""
        try:
            key = f"auth:{user_id}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving authentication record for {user_id}: {e}")
            return None

    async def delete_authentication(self, user_id: str) -> bool:
        """Delete the authentication record for *user_id*.

        Returns:
            True if a record was deleted. False if none existed or on error.
        """
        try:
            key = f"auth:{user_id}"
            deleted = await self.redis.delete(key) > 0
            # Only claim a deletion when Redis actually removed a key.
            if deleted:
                logger.info(f"🔐 Deleted authentication record for user {user_id}")
            return deleted
        except Exception as e:
            logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
            return False

    async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
        """Store refresh-token metadata that expires at *expires_at*.

        Args:
            user_id: Owner of the token.
            token: The refresh token; used as part of the Redis key.
            expires_at: Expiry instant. It is compared with an aware UTC "now",
                so a naive datetime will raise (caught; returns False).
            device_info: Optional ``"device"`` and ``"ip_address"`` entries;
                missing ones are stored as ``"unknown"``.

        Returns:
            True if stored. False if *expires_at* is already past or on error.
        """
        try:
            key = f"refresh_token:{token}"
            now = datetime.now(timezone.utc)
            token_data = {
                "user_id": user_id,
                "expires_at": expires_at.isoformat(),
                "device": device_info.get("device", "unknown"),
                "ip_address": device_info.get("ip_address", "unknown"),
                "is_revoked": False,
                "created_at": now.isoformat()
            }

            # The key's TTL matches the token's own expiry.
            ttl_seconds = int((expires_at - now).total_seconds())
            if ttl_seconds > 0:
                await self.redis.setex(key, ttl_seconds, json.dumps(token_data, default=str))
                logger.info(f"🔐 Stored refresh token for user {user_id}")
                return True
            else:
                logger.warning(f"⚠️ Attempted to store expired refresh token for user {user_id}")
                return False
        except Exception as e:
            logger.error(f"❌ Error storing refresh token for {user_id}: {e}")
            return False

    async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the stored metadata for refresh *token*, or None if absent or on error."""
        try:
            key = f"refresh_token:{token}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving refresh token: {e}")
            return None

    async def revoke_refresh_token(self, token: str) -> bool:
        """Mark refresh *token* as revoked.

        The key keeps its remaining TTL. A plain SET would drop the expiry and
        leave revoked tokens in Redis forever.

        Returns:
            True if the token existed and was revoked. False if it was missing
            or an error occurred.
        """
        try:
            key = f"refresh_token:{token}"
            token_data = await self.get_refresh_token(token)
            if token_data:
                token_data["is_revoked"] = True
                token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
                await self._auth_set_keep_ttl(key, json.dumps(token_data, default=str))
                logger.info("🔐 Revoked refresh token")
                return True
            return False
        except Exception as e:
            logger.error(f"❌ Error revoking refresh token: {e}")
            return False

    async def revoke_all_user_tokens(self, user_id: str) -> bool:
        """Revoke every non-revoked refresh token belonging to *user_id*.

        Each key keeps its remaining TTL.

        Returns:
            True when the scan completes (even if nothing was revoked). False
            on error.
        """
        try:
            # This scans all refresh tokens. Consider a per-user token index
            # for efficiency.
            revoked_count = 0
            async for key, token_info in self._auth_scan_json("refresh_token:*"):
                if token_info.get("user_id") != user_id or token_info.get("is_revoked"):
                    continue
                token_info["is_revoked"] = True
                token_info["revoked_at"] = datetime.now(timezone.utc).isoformat()
                await self._auth_set_keep_ttl(key, json.dumps(token_info, default=str))
                revoked_count += 1

            logger.info(f"🔐 Revoked {revoked_count} refresh tokens for user {user_id}")
            return True

        except Exception as e:
            logger.error(f"❌ Error revoking all tokens for user {user_id}: {e}")
            return False

    # Password Reset Token Methods

    async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
        """Store a password reset token for *email* that expires at *expires_at*.

        *expires_at* is compared with an aware UTC "now", so a naive datetime
        will raise (caught; returns False).

        Returns:
            True if stored. False if *expires_at* is already past or on error.
        """
        try:
            key = f"password_reset:{token}"
            now = datetime.now(timezone.utc)
            token_data = {
                "email": email.lower(),
                "expires_at": expires_at.isoformat(),
                "used": False,
                "created_at": now.isoformat()
            }

            # The key's TTL matches the token's own expiry.
            ttl_seconds = int((expires_at - now).total_seconds())
            if ttl_seconds > 0:
                await self.redis.setex(key, ttl_seconds, json.dumps(token_data, default=str))
                logger.info(f"🔐 Stored password reset token for {email}")
                return True
            else:
                logger.warning(f"⚠️ Attempted to store expired password reset token for {email}")
                return False
        except Exception as e:
            logger.error(f"❌ Error storing password reset token for {email}: {e}")
            return False

    async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the stored data for password reset *token*, or None if absent or on error."""
        try:
            key = f"password_reset:{token}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving password reset token: {e}")
            return None

    async def mark_password_reset_token_used(self, token: str) -> bool:
        """Mark password reset *token* as used, keeping the key's remaining TTL.

        Returns:
            True if the token existed and was updated. False if it was missing
            or an error occurred.
        """
        try:
            key = f"password_reset:{token}"
            token_data = await self.get_password_reset_token(token)
            if token_data:
                token_data["used"] = True
                token_data["used_at"] = datetime.now(timezone.utc).isoformat()
                await self._auth_set_keep_ttl(key, json.dumps(token_data, default=str))
                logger.info("🔐 Marked password reset token as used")
                return True
            return False
        except Exception as e:
            logger.error(f"❌ Error marking password reset token as used: {e}")
            return False

    # User Activity and Security Logging

    async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
        """Append a security event to *user_id*'s log for the current UTC day.

        Events go into a per-day Redis list, newest first. Each list is capped
        at 100 entries and expires 30 days after its most recent write.

        Returns:
            True on success, False on error.
        """
        try:
            # One timestamp so the day key and the event time cannot disagree
            # (e.g. across midnight UTC).
            now = datetime.now(timezone.utc)
            key = f"security_log:{user_id}:{now.strftime('%Y-%m-%d')}"
            event_data = {
                "timestamp": now.isoformat(),
                "user_id": user_id,
                "event_type": event_type,
                "details": details
            }

            # Newest events first.
            await self.redis.lpush(key, json.dumps(event_data, default=str))  # type: ignore

            # Keep only the last 100 events per day.
            await self.redis.ltrim(key, 0, 99)  # type: ignore

            # Expire 30 days after the latest write.
            await self.redis.expire(key, 30 * 24 * 60 * 60)

            logger.info(f"🔒 Logged security event {event_type} for user {user_id}")
            return True
        except Exception as e:
            logger.error(f"❌ Error logging security event for {user_id}: {e}")
            return False

    async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
        """Return *user_id*'s security events from the last *days* UTC days.

        Today counts as day 0. Events are sorted newest first.

        Returns:
            The list of event dicts, or ``[]`` on error.
        """
        try:
            events = []
            now = datetime.now(timezone.utc)
            for i in range(days):
                date = (now - timedelta(days=i)).strftime('%Y-%m-%d')
                key = f"security_log:{user_id}:{date}"

                daily_events = await self.redis.lrange(key, 0, -1)  # type: ignore
                for event_json in daily_events:
                    events.append(json.loads(event_json))

            # ISO-8601 UTC timestamps sort chronologically as strings.
            events.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
            return events
        except Exception as e:
            logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
            return []
|
|
|