from datetime import datetime, timedelta, timezone
import json
import logging
from typing import Any, Dict, TYPE_CHECKING, List, Optional

from .protocols import DatabaseProtocol
from ..constants import KEY_PREFIXES

logger = logging.getLogger(__name__)


class AuthMixin(DatabaseProtocol):
    """Mixin for auth-related database operations.

    All records are JSON documents stored in Redis via ``self.redis``
    (an async client supplied by the host class). Key layout:

    - ``email_verification:<token>``  - pending/used email verifications (24h TTL)
    - ``verification_attempts:<email>`` - send timestamps for rate limiting (24h TTL)
    - ``mfa_code:<email>:<device_id>`` - MFA codes (10 min TTL)
    - ``auth:<user_id>``               - authentication record (no TTL)
    - ``refresh_token:<token>``        - refresh tokens (TTL = token lifetime)
    - ``password_reset:<token>``       - password reset tokens (TTL = token lifetime)
    - ``security_log:<user_id>:<YYYY-MM-DD>`` - per-day security event list (30 days)
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _rewrite_preserving_ttl(self, key: str, value: str) -> bool:
        """Overwrite ``key`` with ``value`` without changing its expiry.

        Updating a record (marking it verified/used/revoked, bumping a
        counter) must neither extend its lifetime nor strip its expiry.
        The remaining lifetime is read with PTTL and re-applied with PSETEX.

        Args:
            key: Redis key to overwrite.
            value: Serialized value to store.

        Returns:
            True if the value was written. False if the key no longer exists
            (e.g. it expired after the caller read it); in that case it is
            deliberately not recreated.
        """
        remaining_ms = await self.redis.pttl(key)
        if remaining_ms > 0:
            await self.redis.psetex(key, remaining_ms, value)
            return True
        if remaining_ms == -1:
            # Key exists but has no expiry: keep it that way.
            await self.redis.set(key, value)
            return True
        # -2 (missing) or 0: the key is gone; don't resurrect it.
        return False

    # ------------------------------------------------------------------
    # Email verification
    # ------------------------------------------------------------------

    async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Find a pending (unverified) verification record for ``email``.

        Scans every ``email_verification:*`` key, so cost grows with the number
        of outstanding tokens. The email comparison is case-insensitive.

        Returns:
            The stored verification dict with an added ``"token"`` entry (the
            key with its prefix removed), or None if nothing matches or an
            error occurs.
        """
        prefix = "email_verification:"
        try:
            cursor = 0
            email_lower = email.lower()
            while True:
                cursor, keys = await self.redis.scan(cursor, match=f"{prefix}*", count=100)
                for key in keys:
                    token_data = await self.redis.get(key)
                    if not token_data:
                        continue  # expired between SCAN and GET
                    verification_info = json.loads(token_data)
                    if (verification_info.get("email", "").lower() == email_lower
                            and not verification_info.get("verified", False)):
                        # Strip only the leading prefix (SCAN guarantees it is there)
                        verification_info["token"] = key[len(prefix):]
                        return verification_info
                if cursor == 0:
                    break
            return None
        except Exception as e:
            logger.error(f"❌ Error finding verification token by email {email}: {e}")
            return None

    async def get_pending_verifications_count(self) -> int:
        """Count email verifications not yet marked verified (admin function).

        Returns:
            The number of pending verifications, or 0 on error.
        """
        try:
            pattern = "email_verification:*"
            cursor = 0
            count = 0
            while True:
                cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
                for key in keys:
                    token_data = await self.redis.get(key)
                    if token_data:
                        verification_info = json.loads(token_data)
                        if not verification_info.get("verified", False):
                            count += 1
                if cursor == 0:
                    break
            return count
        except Exception as e:
            logger.error(f"❌ Error counting pending verifications: {e}")
            return 0

    async def cleanup_expired_verification_tokens(self) -> int:
        """Delete verification records whose ``expires_at`` is in the past.

        Records that cannot be parsed (bad JSON, missing or malformed
        ``expires_at``) are skipped rather than aborting the sweep; their Redis
        TTL still removes them eventually.

        Returns:
            Number of records deleted, or 0 on error.
        """
        try:
            pattern = "email_verification:*"
            cursor = 0
            cleaned_count = 0
            current_time = datetime.now(timezone.utc)
            while True:
                cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
                for key in keys:
                    token_data = await self.redis.get(key)
                    if not token_data:
                        continue
                    try:
                        verification_info = json.loads(token_data)
                        expires_at = datetime.fromisoformat(verification_info["expires_at"])
                    except (KeyError, TypeError, ValueError):
                        # One corrupt record must not stop the cleanup of the rest.
                        continue
                    if expires_at.tzinfo is None:
                        # store_email_verification_token writes UTC; a naive value
                        # would make the comparison below raise TypeError.
                        expires_at = expires_at.replace(tzinfo=timezone.utc)
                    if current_time > expires_at:
                        await self.redis.delete(key)
                        cleaned_count += 1
                        logger.info(f"🧹 Cleaned expired verification token for {verification_info.get('email')}")
                if cursor == 0:
                    break
            if cleaned_count > 0:
                logger.info(f"🧹 Cleaned up {cleaned_count} expired verification tokens")
            return cleaned_count
        except Exception as e:
            logger.error(f"❌ Error cleaning up expired verification tokens: {e}")
            return 0

    async def get_verification_attempts_count(self, email: str) -> int:
        """Return how many verification emails were sent to ``email`` in the last 24 hours.

        Returns:
            The attempt count, or 0 if none are recorded or on error.
        """
        try:
            key = f"verification_attempts:{email.lower()}"
            data = await self.redis.get(key)
            if not data:
                return 0
            attempts_data = json.loads(data)
            window_start = datetime.now(timezone.utc) - timedelta(hours=24)
            return sum(
                1 for attempt in attempts_data
                if datetime.fromisoformat(attempt) > window_start
            )
        except Exception as e:
            logger.error(f"❌ Error getting verification attempts count for {email}: {e}")
            return 0

    async def record_verification_attempt(self, email: str) -> bool:
        """Record that a verification email was sent to ``email`` now.

        Keeps only attempts from the last 24 hours and stores the list with a
        24-hour expiry.

        Returns:
            True on success, False on error.
        """
        try:
            key = f"verification_attempts:{email.lower()}"
            current_time = datetime.now(timezone.utc)

            data = await self.redis.get(key)
            attempts_data = json.loads(data) if data else []
            attempts_data.append(current_time.isoformat())

            # Drop attempts that have left the 24-hour window
            window_start = current_time - timedelta(hours=24)
            recent_attempts = [
                attempt for attempt in attempts_data
                if datetime.fromisoformat(attempt) > window_start
            ]

            await self.redis.setex(
                key,
                24 * 60 * 60,  # 24 hours
                json.dumps(recent_attempts)
            )
            return True
        except Exception as e:
            logger.error(f"❌ Error recording verification attempt for {email}: {e}")
            return False

    async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
        """Store an email verification token with the pending user's data.

        The record expires after 24 hours (both its ``expires_at`` field and
        the Redis TTL).

        Returns:
            True on success, False on error.
        """
        try:
            key = f"email_verification:{token}"
            now = datetime.now(timezone.utc)
            verification_data = {
                "email": email.lower(),
                "user_type": user_type,
                "user_data": user_data,
                "expires_at": (now + timedelta(hours=24)).isoformat(),
                "created_at": now.isoformat(),
                "verified": False
            }

            await self.redis.setex(
                key,
                24 * 60 * 60,  # 24 hours in seconds
                json.dumps(verification_data, default=str)
            )

            logger.info(f"📧 Stored email verification token for {email}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing email verification token: {e}")
            return False

    async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the stored verification record for ``token``, or None if absent/on error."""
        try:
            key = f"email_verification:{token}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving email verification token: {e}")
            return None

    async def mark_email_verified(self, token: str) -> bool:
        """Mark an email verification token as used.

        The record keeps its remaining lifetime. Previously it was reset to a
        fresh 24 hours.

        Returns:
            True if the record was updated. False if it does not exist (or
            expired mid-update) or on error.
        """
        try:
            key = f"email_verification:{token}"
            token_data = await self.get_email_verification_token(token)
            if not token_data:
                return False
            token_data["verified"] = True
            token_data["verified_at"] = datetime.now(timezone.utc).isoformat()
            return await self._rewrite_preserving_ttl(key, json.dumps(token_data, default=str))
        except Exception as e:
            logger.error(f"❌ Error marking email verified: {e}")
            return False

    # ------------------------------------------------------------------
    # MFA codes
    # ------------------------------------------------------------------

    async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
        """Store an MFA code for (email, device) with a 10-minute expiry.

        Returns:
            True on success, False on error.
        """
        try:
            logger.info("🔐 Storing MFA code for email: %s", email)
            key = f"mfa_code:{email.lower()}:{device_id}"
            now = datetime.now(timezone.utc)
            mfa_data = {
                "code": code,
                "email": email.lower(),
                "device_id": device_id,
                "expires_at": (now + timedelta(minutes=10)).isoformat(),
                "created_at": now.isoformat(),
                "attempts": 0,
                "verified": False
            }

            await self.redis.setex(
                key,
                10 * 60,  # 10 minutes in seconds
                json.dumps(mfa_data, default=str)
            )

            logger.info(f"🔐 Stored MFA code for {email}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing MFA code: {e}")
            return False

    async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored MFA record for (email, device), or None if absent/on error."""
        try:
            key = f"mfa_code:{email.lower()}:{device_id}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving MFA code: {e}")
            return None

    async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
        """Increment the verification-attempt counter of an MFA code.

        The code's remaining lifetime is preserved. Resetting it to 10 minutes
        would extend the code's validity on every attempt.

        Returns:
            The new attempt count, or 0 if no code exists or on error.
        """
        try:
            key = f"mfa_code:{email.lower()}:{device_id}"
            mfa_data = await self.get_mfa_code(email, device_id)
            if not mfa_data:
                return 0
            mfa_data["attempts"] = mfa_data.get("attempts", 0) + 1
            await self._rewrite_preserving_ttl(key, json.dumps(mfa_data, default=str))
            return mfa_data["attempts"]
        except Exception as e:
            logger.error(f"❌ Error incrementing MFA attempts: {e}")
            return 0

    async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
        """Mark an MFA code as verified, keeping its remaining lifetime.

        Returns:
            True if the record was updated. False if it does not exist (or
            expired mid-update) or on error.
        """
        try:
            key = f"mfa_code:{email.lower()}:{device_id}"
            mfa_data = await self.get_mfa_code(email, device_id)
            if not mfa_data:
                return False
            mfa_data["verified"] = True
            mfa_data["verified_at"] = datetime.now(timezone.utc).isoformat()
            return await self._rewrite_preserving_ttl(key, json.dumps(mfa_data, default=str))
        except Exception as e:
            logger.error(f"❌ Error marking MFA verified: {e}")
            return False

    # ------------------------------------------------------------------
    # Authentication records
    # ------------------------------------------------------------------

    async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
        """Store (without expiry) the authentication record for a user.

        Returns:
            True on success, False on error.
        """
        try:
            key = f"auth:{user_id}"
            await self.redis.set(key, json.dumps(auth_data, default=str))
            logger.info(f"🔐 Stored authentication record for user {user_id}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing authentication record for {user_id}: {e}")
            return False

    async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the authentication record for a user, or None if absent/on error."""
        try:
            key = f"auth:{user_id}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving authentication record for {user_id}: {e}")
            return None

    async def delete_authentication(self, user_id: str) -> bool:
        """Delete the authentication record for a user.

        Returns:
            True if a record was deleted, False if none existed or on error.
        """
        try:
            key = f"auth:{user_id}"
            result = await self.redis.delete(key)
            logger.info(f"🔐 Deleted authentication record for user {user_id}")
            return result > 0
        except Exception as e:
            logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
            return False

    # ------------------------------------------------------------------
    # Refresh tokens
    # ------------------------------------------------------------------

    async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
        """Store a refresh token whose Redis TTL matches ``expires_at``.

        Args:
            user_id: Owner of the token.
            token: The refresh token (used as the key suffix).
            expires_at: Expiry time; compared against an aware UTC "now", so a
                naive datetime raises and the method returns False.
            device_info: Optional ``"device"`` and ``"ip_address"`` entries.

        Returns:
            True if stored. False if ``expires_at`` is already past or on error.
        """
        try:
            key = f"refresh_token:{token}"
            now = datetime.now(timezone.utc)
            token_data = {
                "user_id": user_id,
                "expires_at": expires_at.isoformat(),
                "device": device_info.get("device", "unknown"),
                "ip_address": device_info.get("ip_address", "unknown"),
                "is_revoked": False,
                "created_at": now.isoformat()
            }

            ttl_seconds = int((expires_at - now).total_seconds())
            if ttl_seconds <= 0:
                logger.warning(f"⚠️ Attempted to store expired refresh token for user {user_id}")
                return False

            await self.redis.setex(key, ttl_seconds, json.dumps(token_data, default=str))
            logger.info(f"🔐 Stored refresh token for user {user_id}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing refresh token for {user_id}: {e}")
            return False

    async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the stored refresh-token record, or None if absent/on error."""
        try:
            key = f"refresh_token:{token}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving refresh token: {e}")
            return None

    async def revoke_refresh_token(self, token: str) -> bool:
        """Mark a refresh token as revoked.

        The record keeps its original expiry. A plain SET here used to strip
        the TTL, so revoked tokens were never evicted.

        Returns:
            True if the record was updated. False if it does not exist (or
            expired mid-update) or on error.
        """
        try:
            key = f"refresh_token:{token}"
            token_data = await self.get_refresh_token(token)
            if not token_data:
                return False
            token_data["is_revoked"] = True
            token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
            updated = await self._rewrite_preserving_ttl(key, json.dumps(token_data, default=str))
            if updated:
                logger.info("🔐 Revoked refresh token")
            return updated
        except Exception as e:
            logger.error(f"❌ Error revoking refresh token: {e}")
            return False

    async def revoke_all_user_tokens(self, user_id: str) -> bool:
        """Revoke every non-revoked refresh token belonging to ``user_id``.

        Scans all ``refresh_token:*`` keys. A per-user token index would avoid
        the full scan. Each revoked record keeps its original expiry.

        Returns:
            True when the scan completes, False on error.
        """
        try:
            pattern = "refresh_token:*"
            cursor = 0
            revoked_count = 0
            revoked_at = datetime.now(timezone.utc).isoformat()

            while True:
                cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
                for key in keys:
                    token_data = await self.redis.get(key)
                    if not token_data:
                        continue
                    token_info = json.loads(token_data)
                    if token_info.get("user_id") == user_id and not token_info.get("is_revoked"):
                        token_info["is_revoked"] = True
                        token_info["revoked_at"] = revoked_at
                        if await self._rewrite_preserving_ttl(key, json.dumps(token_info, default=str)):
                            revoked_count += 1
                if cursor == 0:
                    break

            logger.info(f"🔐 Revoked {revoked_count} refresh tokens for user {user_id}")
            return True
        except Exception as e:
            logger.error(f"❌ Error revoking all tokens for user {user_id}: {e}")
            return False

    # ------------------------------------------------------------------
    # Password reset tokens
    # ------------------------------------------------------------------

    async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
        """Store a password reset token whose Redis TTL matches ``expires_at``.

        ``expires_at`` is compared against an aware UTC "now", so a naive
        datetime raises and the method returns False.

        Returns:
            True if stored. False if ``expires_at`` is already past or on error.
        """
        try:
            key = f"password_reset:{token}"
            now = datetime.now(timezone.utc)
            token_data = {
                "email": email.lower(),
                "expires_at": expires_at.isoformat(),
                "used": False,
                "created_at": now.isoformat()
            }

            ttl_seconds = int((expires_at - now).total_seconds())
            if ttl_seconds <= 0:
                logger.warning(f"⚠️ Attempted to store expired password reset token for {email}")
                return False

            await self.redis.setex(key, ttl_seconds, json.dumps(token_data, default=str))
            logger.info(f"🔐 Stored password reset token for {email}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing password reset token for {email}: {e}")
            return False

    async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the stored password-reset record, or None if absent/on error."""
        try:
            key = f"password_reset:{token}"
            data = await self.redis.get(key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving password reset token: {e}")
            return None

    async def mark_password_reset_token_used(self, token: str) -> bool:
        """Mark a password reset token as used, keeping its original expiry.

        Returns:
            True if the record was updated. False if it does not exist (or
            expired mid-update) or on error.
        """
        try:
            key = f"password_reset:{token}"
            token_data = await self.get_password_reset_token(token)
            if not token_data:
                return False
            token_data["used"] = True
            token_data["used_at"] = datetime.now(timezone.utc).isoformat()
            updated = await self._rewrite_preserving_ttl(key, json.dumps(token_data, default=str))
            if updated:
                logger.info("🔐 Marked password reset token as used")
            return updated
        except Exception as e:
            logger.error(f"❌ Error marking password reset token as used: {e}")
            return False

    # ------------------------------------------------------------------
    # User activity and security logging
    # ------------------------------------------------------------------

    async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
        """Append a security event to the user's per-day audit list.

        The list keeps the 100 most recent events for the day and expires
        after 30 days.

        Returns:
            True on success, False on error.
        """
        try:
            # One timestamp so the day bucket and the event's timestamp agree
            # (two now() calls could straddle midnight).
            now = datetime.now(timezone.utc)
            key = f"security_log:{user_id}:{now.strftime('%Y-%m-%d')}"
            event_data = {
                "timestamp": now.isoformat(),
                "user_id": user_id,
                "event_type": event_type,
                "details": details
            }

            # LPUSH keeps the newest event at the head of the list
            await self.redis.lpush(key, json.dumps(event_data, default=str))  # type: ignore
            # Keep only the last 100 events per day
            await self.redis.ltrim(key, 0, 99)  # type: ignore
            # Expire the day bucket after 30 days
            await self.redis.expire(key, 30 * 24 * 60 * 60)

            logger.info(f"🔒 Logged security event {event_type} for user {user_id}")
            return True
        except Exception as e:
            logger.error(f"❌ Error logging security event for {user_id}: {e}")
            return False

    async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
        """Return the user's security events from the last ``days`` UTC days, newest first.

        Returns:
            A list of event dicts sorted by ``timestamp`` descending; empty on error.
        """
        try:
            events = []
            today = datetime.now(timezone.utc)
            for i in range(days):
                date = (today - timedelta(days=i)).strftime('%Y-%m-%d')
                key = f"security_log:{user_id}:{date}"
                daily_events = await self.redis.lrange(key, 0, -1)  # type: ignore
                events.extend(json.loads(event_json) for event_json in daily_events)

            # ISO-8601 UTC timestamps sort correctly as strings
            events.sort(key=lambda x: x["timestamp"], reverse=True)
            return events
        except Exception as e:
            logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
            return []