514 lines
21 KiB
Python

from datetime import datetime, timedelta, timezone
import json
import logging
from typing import Any, Dict, TYPE_CHECKING, List, Optional
from .protocols import DatabaseProtocol
from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class AuthMixin(DatabaseProtocol):
"""Mixin for auth-related database operations"""
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
    """Look up the first unverified verification record stored for an email.

    Walks all ``email_verification:*`` keys with SCAN and compares the stored
    ``email`` field case-insensitively.  The matching record is returned with
    an extra ``"token"`` entry derived from the Redis key.

    Returns:
        The matching record, or ``None`` when nothing matches or any error
        occurs (errors are logged, not raised).
    """
    key_prefix = "email_verification:"
    try:
        wanted = email.lower()
        position = 0
        keep_scanning = True
        while keep_scanning:
            position, batch = await self.redis.scan(position, match=key_prefix + "*", count=100)
            for redis_key in batch:
                raw = await self.redis.get(redis_key)
                if not raw:
                    continue
                record = json.loads(raw)
                same_email = record.get("email", "").lower() == wanted
                if not same_email or record.get("verified", False):
                    continue
                # The token is whatever remains of the key once the prefix is removed.
                record["token"] = redis_key.replace(key_prefix, "")
                return record
            # SCAN signals completion by handing back cursor 0.
            keep_scanning = position != 0
        return None
    except Exception as e:
        logger.error(f"❌ Error finding verification token by email {email}: {e}")
        return None
async def get_pending_verifications_count(self) -> int:
    """Count email verification records that are not yet marked verified.

    Iterates every ``email_verification:*`` key via SCAN and tallies the
    records whose ``verified`` field is missing or falsy.

    Returns:
        The number of pending records, or 0 if any error occurs (the error
        is logged, not raised).
    """
    try:
        pending = 0
        next_cursor = 0
        while True:
            next_cursor, found = await self.redis.scan(next_cursor, match="email_verification:*", count=100)
            for name in found:
                payload = await self.redis.get(name)
                if payload and not json.loads(payload).get("verified", False):
                    pending += 1
            # A zero cursor means the keyspace walk is complete.
            if next_cursor == 0:
                return pending
    except Exception as e:
        logger.error(f"❌ Error counting pending verifications: {e}")
        return 0
async def cleanup_expired_verification_tokens(self) -> int:
    """Delete expired email verification tokens and report how many were removed.

    Scans every ``email_verification:*`` key and deletes those whose stored
    ``expires_at`` timestamp is earlier than the current UTC time.  A record
    that cannot be interpreted (invalid JSON, missing or malformed
    ``expires_at``, or a timestamp that cannot be compared with the
    timezone-aware current time) is logged and skipped, so a single bad
    entry does not abort the sweep.

    Returns:
        The number of tokens deleted, or 0 if the scan itself fails (the
        failure is logged, not raised).
    """
    try:
        pattern = "email_verification:*"
        cursor = 0
        cleaned_count = 0
        current_time = datetime.now(timezone.utc)
        while True:
            cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
            for key in keys:
                token_data = await self.redis.get(key)
                if not token_data:
                    continue
                try:
                    verification_info = json.loads(token_data)
                    expires_at = datetime.fromisoformat(verification_info["expires_at"])
                    is_expired = current_time > expires_at
                except (KeyError, TypeError, ValueError) as record_error:
                    # Never delete a record whose expiry cannot be determined.
                    # The key is deliberately left out of the log line because
                    # it embeds the verification token itself.
                    logger.warning(f"⚠️ Skipping verification token with unreadable expiry: {record_error}")
                    continue
                if is_expired:
                    await self.redis.delete(key)
                    cleaned_count += 1
                    logger.info(f"🧹 Cleaned expired verification token for {verification_info.get('email')}")
            # SCAN returns cursor 0 once the whole keyspace has been visited.
            if cursor == 0:
                break
        if cleaned_count > 0:
            logger.info(f"🧹 Cleaned up {cleaned_count} expired verification tokens")
        return cleaned_count
    except Exception as e:
        logger.error(f"❌ Error cleaning up expired verification tokens: {e}")
        return 0
async def get_verification_attempts_count(self, email: str) -> int:
    """Count verification emails recorded for an address in the past 24 hours.

    Reads the JSON list of ISO-8601 timestamps stored under
    ``verification_attempts:<lower-cased email>`` and counts the entries
    newer than 24 hours before the current UTC time.

    Returns:
        The number of recent attempts; 0 when nothing is stored or any
        error occurs (errors are logged, not raised).
    """
    try:
        stored = await self.redis.get(f"verification_attempts:{email.lower()}")
        if not stored:
            return 0
        timestamps = json.loads(stored)
        cutoff = datetime.now(timezone.utc) - timedelta(hours=24)
        # Entries at or before the cutoff have aged out of the window.
        return sum(1 for stamp in timestamps if datetime.fromisoformat(stamp) > cutoff)
    except Exception as e:
        logger.error(f"❌ Error getting verification attempts count for {email}: {e}")
        return 0
async def record_verification_attempt(self, email: str) -> bool:
    """Append the current time to an address's verification-attempt history.

    The history is a JSON list of ISO-8601 timestamps stored under
    ``verification_attempts:<lower-cased email>``.  Entries older than 24
    hours are dropped and the list is written back with a 24-hour expiry.

    Returns:
        ``True`` once the history is stored, ``False`` if any error occurs
        (errors are logged, not raised).
    """
    try:
        redis_key = f"verification_attempts:{email.lower()}"
        now = datetime.now(timezone.utc)

        existing = await self.redis.get(redis_key)
        history = json.loads(existing) if existing else []
        history.append(now.isoformat())

        # Prune everything that has fallen outside the 24-hour window.
        cutoff = now - timedelta(hours=24)
        kept = []
        for stamp in history:
            if datetime.fromisoformat(stamp) > cutoff:
                kept.append(stamp)

        one_day_seconds = 24 * 60 * 60
        await self.redis.setex(redis_key, one_day_seconds, json.dumps(kept))
        return True
    except Exception as e:
        logger.error(f"❌ Error recording verification attempt for {email}: {e}")
        return False
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
    """Persist a pending email verification record under its token.

    The record (lower-cased email, user type, user data, expiry and creation
    timestamps, ``verified=False``) is serialized to JSON and stored at
    ``email_verification:<token>`` with a 24-hour expiry.

    Returns:
        ``True`` when stored, ``False`` if any error occurs (errors are
        logged, not raised).
    """
    try:
        lifetime = timedelta(hours=24)
        record = {
            "email": email.lower(),
            "user_type": user_type,
            "user_data": user_data,
            "expires_at": (datetime.now(timezone.utc) + lifetime).isoformat(),
            "created_at": datetime.now(timezone.utc).isoformat(),
            "verified": False
        }
        # default=str stringifies any value json cannot encode natively.
        serialized = json.dumps(record, default=str)
        # The Redis expiry mirrors the 24-hour lifetime written into the record.
        await self.redis.setex(f"email_verification:{token}", 24 * 60 * 60, serialized)
        logger.info(f"📧 Stored email verification token for {email}")
        return True
    except Exception as e:
        logger.error(f"❌ Error storing email verification token: {e}")
        return False
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
    """Fetch the verification record stored at ``email_verification:<token>``.

    Returns:
        The decoded JSON record, or ``None`` when the key is absent/empty or
        any error occurs (errors are logged, not raised).
    """
    try:
        raw = await self.redis.get(f"email_verification:{token}")
        return json.loads(raw) if raw else None
    except Exception as e:
        logger.error(f"❌ Error retrieving email verification token: {e}")
        return None
async def mark_email_verified(self, token: str) -> bool:
    """Flag an email verification token as verified without extending its life.

    Loads the record via ``get_email_verification_token``, sets
    ``verified=True`` and a ``verified_at`` UTC timestamp, and writes it back
    using the key's remaining time-to-live, so that marking a token verified
    does not restart its 24-hour expiry.

    Args:
        token: The verification token (the suffix of the Redis key).

    Returns:
        ``True`` when the record was updated.  ``False`` when the token is
        unknown, has expired between the read and the write, or any error
        occurs (errors are logged, not raised).
    """
    try:
        key = f"email_verification:{token}"
        token_data = await self.get_email_verification_token(token)
        if not token_data:
            return False

        token_data["verified"] = True
        token_data["verified_at"] = datetime.now(timezone.utc).isoformat()

        # TTL reports -2 for a missing key and -1 for a key without expiry.
        remaining_ttl = await self.redis.ttl(key)
        if remaining_ttl > 0:
            expiry_seconds = remaining_ttl
        elif remaining_ttl == -1:
            # No expiry recorded: fall back to the standard 24-hour lifetime
            # so the record cannot linger indefinitely.
            expiry_seconds = 24 * 60 * 60
        else:
            # The key expired after it was read; do not resurrect it.
            return False

        await self.redis.setex(key, expiry_seconds, json.dumps(token_data, default=str))
        return True
    except Exception as e:
        logger.error(f"❌ Error marking email verified: {e}")
        return False
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
    """Persist a freshly issued MFA code for an email/device pair.

    The record (code, lower-cased email, device id, expiry and creation
    timestamps, ``attempts=0``, ``verified=False``) is stored as JSON at
    ``mfa_code:<lower-cased email>:<device_id>`` with a 10-minute expiry.

    Returns:
        ``True`` when stored, ``False`` if any error occurs (errors are
        logged, not raised).
    """
    try:
        logger.info("🔐 Storing MFA code for email: %s", email)
        normalized_email = email.lower()
        validity = timedelta(minutes=10)
        payload = json.dumps(
            {
                "code": code,
                "email": normalized_email,
                "device_id": device_id,
                "expires_at": (datetime.now(timezone.utc) + validity).isoformat(),
                "created_at": datetime.now(timezone.utc).isoformat(),
                "attempts": 0,
                "verified": False
            },
            default=str,
        )
        # The Redis expiry matches the 10-minute validity written into the record.
        await self.redis.setex(f"mfa_code:{normalized_email}:{device_id}", 10 * 60, payload)
        logger.info(f"🔐 Stored MFA code for {email}")
        return True
    except Exception as e:
        logger.error(f"❌ Error storing MFA code: {e}")
        return False
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
    """Fetch the MFA record stored for an email/device pair.

    The key is ``mfa_code:<lower-cased email>:<device_id>``.

    Returns:
        The decoded JSON record, or ``None`` when the key is absent/empty or
        any error occurs (errors are logged, not raised).
    """
    try:
        stored = await self.redis.get(f"mfa_code:{email.lower()}:{device_id}")
        if not stored:
            return None
        return json.loads(stored)
    except Exception as e:
        logger.error(f"❌ Error retrieving MFA code: {e}")
        return None
async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
    """Increase the failed-attempt counter on a stored MFA code.

    The updated record is written back with the key's remaining
    time-to-live.  Re-arming a fresh 10-minute expiry on every attempt would
    let repeated guesses keep the code alive in Redis past its intended
    lifetime.  A record without an ``attempts`` field is treated as having
    zero prior attempts.

    Args:
        email: Address the code was issued for (lower-cased for the key).
        device_id: Device the code was issued for.

    Returns:
        The new attempt count.  0 when no code is stored, the code expired
        between the read and the write, or any error occurs (errors are
        logged, not raised).
    """
    try:
        key = f"mfa_code:{email.lower()}:{device_id}"
        mfa_data = await self.get_mfa_code(email, device_id)
        if not mfa_data:
            return 0

        mfa_data["attempts"] = mfa_data.get("attempts", 0) + 1

        # TTL reports -2 for a missing key and -1 for a key without expiry.
        remaining_ttl = await self.redis.ttl(key)
        if remaining_ttl > 0:
            expiry_seconds = remaining_ttl
        elif remaining_ttl == -1:
            # No expiry recorded: apply the standard 10-minute lifetime.
            expiry_seconds = 10 * 60
        else:
            # The code expired after it was read; do not resurrect it.
            return 0

        await self.redis.setex(key, expiry_seconds, json.dumps(mfa_data, default=str))
        return mfa_data["attempts"]
    except Exception as e:
        logger.error(f"❌ Error incrementing MFA attempts: {e}")
        return 0
async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
    """Flag a stored MFA code as verified without extending its lifetime.

    Loads the record via ``get_mfa_code``, sets ``verified=True`` and a
    ``verified_at`` UTC timestamp, and writes it back using the key's
    remaining time-to-live, so that verification does not restart the
    10-minute expiry.

    Args:
        email: Address the code was issued for (lower-cased for the key).
        device_id: Device the code was issued for.

    Returns:
        ``True`` when the record was updated.  ``False`` when no code is
        stored, the code expired between the read and the write, or any
        error occurs (errors are logged, not raised).
    """
    try:
        key = f"mfa_code:{email.lower()}:{device_id}"
        mfa_data = await self.get_mfa_code(email, device_id)
        if not mfa_data:
            return False

        mfa_data["verified"] = True
        mfa_data["verified_at"] = datetime.now(timezone.utc).isoformat()

        # TTL reports -2 for a missing key and -1 for a key without expiry.
        remaining_ttl = await self.redis.ttl(key)
        if remaining_ttl > 0:
            expiry_seconds = remaining_ttl
        elif remaining_ttl == -1:
            # No expiry recorded: apply the standard 10-minute lifetime.
            expiry_seconds = 10 * 60
        else:
            # The code expired after it was read; do not resurrect it.
            return False

        await self.redis.setex(key, expiry_seconds, json.dumps(mfa_data, default=str))
        return True
    except Exception as e:
        logger.error(f"❌ Error marking MFA verified: {e}")
        return False
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
    """Write a user's authentication record to ``auth:<user_id>``.

    The record is serialized as JSON (non-JSON values are stringified) and
    stored without an expiry.

    Returns:
        ``True`` when stored, ``False`` if any error occurs (errors are
        logged, not raised).
    """
    try:
        redis_key = f"auth:{user_id}"
        serialized = json.dumps(auth_data, default=str)
        await self.redis.set(redis_key, serialized)
        logger.info(f"🔐 Stored authentication record for user {user_id}")
        return True
    except Exception as e:
        logger.error(f"❌ Error storing authentication record for {user_id}: {e}")
        return False
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
    """Fetch the authentication record stored at ``auth:<user_id>``.

    Returns:
        The decoded JSON record, or ``None`` when the key is absent/empty or
        any error occurs (errors are logged, not raised).
    """
    try:
        stored = await self.redis.get(f"auth:{user_id}")
        return json.loads(stored) if stored else None
    except Exception as e:
        logger.error(f"❌ Error retrieving authentication record for {user_id}: {e}")
        return None
async def delete_authentication(self, user_id: str) -> bool:
    """Remove the authentication record stored at ``auth:<user_id>``.

    Returns:
        ``True`` when Redis reports at least one key deleted, ``False`` when
        nothing was deleted or any error occurs (errors are logged, not
        raised).
    """
    try:
        removed = await self.redis.delete(f"auth:{user_id}")
        # The message is logged whether or not a key actually existed.
        logger.info(f"🔐 Deleted authentication record for user {user_id}")
        was_present = removed > 0
        return was_present
    except Exception as e:
        logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
        return False
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
    """Persist a refresh token record that expires together with the token.

    The record (user id, expiry, device, IP address, ``is_revoked=False``,
    creation time) is stored as JSON at ``refresh_token:<token>`` with a
    Redis expiry equal to the whole seconds left until ``expires_at``.

    Returns:
        ``True`` when stored.  ``False`` when ``expires_at`` is not in the
        future (a warning is logged) or any error occurs (logged, not
        raised).
    """
    try:
        record = {
            "user_id": user_id,
            "expires_at": expires_at.isoformat(),
            "device": device_info.get("device", "unknown"),
            "ip_address": device_info.get("ip_address", "unknown"),
            "is_revoked": False,
            "created_at": datetime.now(timezone.utc).isoformat()
        }
        time_left = expires_at - datetime.now(timezone.utc)
        seconds_left = int(time_left.total_seconds())
        # Guard clause: a token with no lifetime left is never written.
        if seconds_left <= 0:
            logger.warning(f"⚠️ Attempted to store expired refresh token for user {user_id}")
            return False
        await self.redis.setex(f"refresh_token:{token}", seconds_left, json.dumps(record, default=str))
        logger.info(f"🔐 Stored refresh token for user {user_id}")
        return True
    except Exception as e:
        logger.error(f"❌ Error storing refresh token for {user_id}: {e}")
        return False
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
    """Fetch the refresh token record stored at ``refresh_token:<token>``.

    Returns:
        The decoded JSON record, or ``None`` when the key is absent/empty or
        any error occurs (errors are logged, not raised).
    """
    try:
        stored = await self.redis.get(f"refresh_token:{token}")
        if not stored:
            return None
        return json.loads(stored)
    except Exception as e:
        logger.error(f"❌ Error retrieving refresh token: {e}")
        return None
async def revoke_refresh_token(self, token: str) -> bool:
    """Mark a refresh token as revoked while keeping its original expiry.

    Loads the record via ``get_refresh_token``, sets ``is_revoked=True`` and
    a ``revoked_at`` UTC timestamp, and writes it back with the key's
    remaining time-to-live.  A plain SET would discard the expiry and leave
    revoked tokens in Redis forever.

    Args:
        token: The refresh token (the suffix of the Redis key).

    Returns:
        ``True`` when the record was updated.  ``False`` when the token is
        unknown, expired between the read and the write, or any error
        occurs (errors are logged, not raised).
    """
    try:
        key = f"refresh_token:{token}"
        token_data = await self.get_refresh_token(token)
        if not token_data:
            return False

        token_data["is_revoked"] = True
        token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
        payload = json.dumps(token_data, default=str)

        # TTL reports -2 for a missing key and -1 for a key without expiry.
        remaining_ttl = await self.redis.ttl(key)
        if remaining_ttl > 0:
            await self.redis.setex(key, remaining_ttl, payload)
        elif remaining_ttl == -1:
            # The key never had an expiry; keep it that way.
            await self.redis.set(key, payload)
        else:
            # The token expired after it was read; nothing left to revoke.
            return False

        logger.info("🔐 Revoked refresh token")
        return True
    except Exception as e:
        logger.error(f"❌ Error revoking refresh token: {e}")
        return False
async def revoke_all_user_tokens(self, user_id: str) -> bool:
    """Revoke every not-yet-revoked refresh token belonging to a user.

    Scans all ``refresh_token:*`` keys (a per-user token index would avoid
    the full scan).  Each matching record gets ``is_revoked=True`` and a
    ``revoked_at`` UTC timestamp and is written back with the key's
    remaining time-to-live, so revocation does not strip the expiry.
    Records that are not valid JSON objects are logged and skipped instead
    of aborting revocation of the remaining tokens.

    Args:
        user_id: Owner whose tokens should be revoked.

    Returns:
        ``True`` when the scan completed (even if nothing matched),
        ``False`` if any error occurs (errors are logged, not raised).
    """
    try:
        pattern = "refresh_token:*"
        cursor = 0
        revoked_count = 0
        while True:
            cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
            for key in keys:
                token_data = await self.redis.get(key)
                if not token_data:
                    continue
                try:
                    token_info = json.loads(token_data)
                except (TypeError, ValueError):
                    # The key embeds the token itself, so it is kept out of the log.
                    logger.warning("⚠️ Skipping unreadable refresh token record during revocation")
                    continue
                if not isinstance(token_info, dict):
                    continue
                if token_info.get("user_id") != user_id or token_info.get("is_revoked"):
                    continue

                token_info["is_revoked"] = True
                token_info["revoked_at"] = datetime.now(timezone.utc).isoformat()
                payload = json.dumps(token_info, default=str)

                # TTL reports -2 for a missing key and -1 for a key without expiry.
                remaining_ttl = await self.redis.ttl(key)
                if remaining_ttl > 0:
                    await self.redis.setex(key, remaining_ttl, payload)
                elif remaining_ttl == -1:
                    # The key never had an expiry; keep it that way.
                    await self.redis.set(key, payload)
                else:
                    # Expired while being processed; nothing left to revoke.
                    continue
                revoked_count += 1
            # SCAN returns cursor 0 once the whole keyspace has been visited.
            if cursor == 0:
                break
        logger.info(f"🔐 Revoked {revoked_count} refresh tokens for user {user_id}")
        return True
    except Exception as e:
        logger.error(f"❌ Error revoking all tokens for user {user_id}: {e}")
        return False
# Password Reset Token Methods
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
    """Persist a password reset token that expires together with the token.

    The record (lower-cased email, expiry, ``used=False``, creation time) is
    stored as JSON at ``password_reset:<token>`` with a Redis expiry equal
    to the whole seconds left until ``expires_at``.

    Returns:
        ``True`` when stored.  ``False`` when ``expires_at`` is not in the
        future (a warning is logged) or any error occurs (logged, not
        raised).
    """
    try:
        record = {
            "email": email.lower(),
            "expires_at": expires_at.isoformat(),
            "used": False,
            "created_at": datetime.now(timezone.utc).isoformat()
        }
        time_left = expires_at - datetime.now(timezone.utc)
        seconds_left = int(time_left.total_seconds())
        # Guard clause: a token with no lifetime left is never written.
        if seconds_left <= 0:
            logger.warning(f"⚠️ Attempted to store expired password reset token for {email}")
            return False
        await self.redis.setex(f"password_reset:{token}", seconds_left, json.dumps(record, default=str))
        logger.info(f"🔐 Stored password reset token for {email}")
        return True
    except Exception as e:
        logger.error(f"❌ Error storing password reset token for {email}: {e}")
        return False
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
    """Fetch the password reset record stored at ``password_reset:<token>``.

    Returns:
        The decoded JSON record, or ``None`` when the key is absent/empty or
        any error occurs (errors are logged, not raised).
    """
    try:
        stored = await self.redis.get(f"password_reset:{token}")
        return json.loads(stored) if stored else None
    except Exception as e:
        logger.error(f"❌ Error retrieving password reset token: {e}")
        return None
async def mark_password_reset_token_used(self, token: str) -> bool:
    """Mark a password reset token as used while keeping its original expiry.

    Loads the record via ``get_password_reset_token``, sets ``used=True``
    and a ``used_at`` UTC timestamp, and writes it back with the key's
    remaining time-to-live.  A plain SET would discard the expiry and leave
    used tokens in Redis forever.

    Args:
        token: The reset token (the suffix of the Redis key).

    Returns:
        ``True`` when the record was updated.  ``False`` when the token is
        unknown, expired between the read and the write, or any error
        occurs (errors are logged, not raised).
    """
    try:
        key = f"password_reset:{token}"
        token_data = await self.get_password_reset_token(token)
        if not token_data:
            return False

        token_data["used"] = True
        token_data["used_at"] = datetime.now(timezone.utc).isoformat()
        payload = json.dumps(token_data, default=str)

        # TTL reports -2 for a missing key and -1 for a key without expiry.
        remaining_ttl = await self.redis.ttl(key)
        if remaining_ttl > 0:
            await self.redis.setex(key, remaining_ttl, payload)
        elif remaining_ttl == -1:
            # The key never had an expiry; keep it that way.
            await self.redis.set(key, payload)
        else:
            # The token expired after it was read; it can no longer be used.
            return False

        logger.info("🔐 Marked password reset token as used")
        return True
    except Exception as e:
        logger.error(f"❌ Error marking password reset token as used: {e}")
        return False
# User Activity and Security Logging
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
    """Append a security event to the user's audit list for the current UTC day.

    Events are pushed to the head of ``security_log:<user_id>:<YYYY-MM-DD>``,
    the list is trimmed to its 100 newest entries, and the key's expiry is
    set to 30 days.

    Returns:
        ``True`` when the event was recorded, ``False`` if any error occurs
        (errors are logged, not raised).
    """
    try:
        day_key = f"security_log:{user_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
        entry = json.dumps(
            {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "user_id": user_id,
                "event_type": event_type,
                "details": details
            },
            default=str,
        )
        # Newest first: push to the head, then cap the list at 100 entries.
        await self.redis.lpush(day_key, entry)  # type: ignore
        await self.redis.ltrim(day_key, 0, 99)  # type: ignore
        retention_seconds = 30 * 24 * 60 * 60
        await self.redis.expire(day_key, retention_seconds)
        logger.info(f"🔒 Logged security event {event_type} for user {user_id}")
        return True
    except Exception as e:
        logger.error(f"❌ Error logging security event for {user_id}: {e}")
        return False
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
    """Collect a user's security events from the most recent ``days`` UTC days.

    Reads the full ``security_log:<user_id>:<YYYY-MM-DD>`` list for today and
    each of the preceding ``days - 1`` days, decodes every entry, and orders
    the combined result by ``timestamp`` with the newest first.

    Returns:
        The decoded events, or an empty list if any error occurs (errors
        are logged, not raised).
    """
    try:
        collected = []
        for offset in range(days):
            day = (datetime.now(timezone.utc) - timedelta(days=offset)).strftime('%Y-%m-%d')
            entries = await self.redis.lrange(f"security_log:{user_id}:{day}", 0, -1)  # type: ignore
            collected.extend(json.loads(entry) for entry in entries)
        # Timestamps are compared as stored, newest first.
        return sorted(collected, key=lambda item: item["timestamp"], reverse=True)
    except Exception as e:
        logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
        return []