""" Shared dependencies for FastAPI routes """ from __future__ import annotations import jwt import os from datetime import datetime, timedelta, timezone, UTC from typing import Optional from fastapi import HTTPException, Depends from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from prometheus_client import CollectorRegistry from prometheus_fastapi_instrumentator import Instrumentator import defines from database import RedisDatabase, redis_manager, DatabaseManager from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest from logger import logger from background_tasks import BackgroundTaskManager #from . rate_limiter import RateLimiter # Security security = HTTPBearer() JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "") if JWT_SECRET_KEY == "": raise ValueError("JWT_SECRET_KEY environment variable is not set") ALGORITHM = "HS256" background_task_manager: Optional[BackgroundTaskManager] = None # Global database manager reference db_manager = None def set_db_manager(manager: DatabaseManager): """Set the global database manager reference""" global db_manager db_manager = manager def get_database() -> RedisDatabase: """ Safe database dependency that checks for availability Raises HTTP 503 if database is not available """ global db_manager if db_manager is None: logger.error("Database manager not initialized") raise HTTPException( status_code=503, detail="Database not available - service starting up" ) if db_manager.is_shutting_down: logger.warning("Database is shutting down") raise HTTPException( status_code=503, detail="Service is shutting down" ) try: return db_manager.get_database() except RuntimeError as e: logger.error(f"Database not available: {e}") raise HTTPException( status_code=503, detail="Database connection not available" ) def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() if expires_delta: expire = datetime.now(UTC) + expires_delta else: expire = datetime.now(UTC) + timedelta(hours=24) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)): """Enhanced token verification with guest session recovery""" try: if not db_manager: raise HTTPException(status_code=500, detail="Database not initialized") # First decode the token payload = jwt.decode(credentials.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM]) user_id: str = payload.get("sub") token_type: str = payload.get("type", "access") if user_id is None: raise HTTPException(status_code=401, detail="Invalid authentication credentials") # Check if token is blacklisted redis = redis_manager.get_client() blacklist_key = f"blacklisted_token:{credentials.credentials}" is_blacklisted = await redis.exists(blacklist_key) if is_blacklisted: logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}") raise HTTPException(status_code=401, detail="Token has been revoked") # For guest tokens, verify guest still exists and update activity if token_type == "guest" or payload.get("type") == "guest": database = db_manager.get_database() guest_data = await database.get_guest(user_id) if not guest_data: logger.warning(f"🚫 Guest session not found for token: {user_id}") raise HTTPException(status_code=401, detail="Guest session expired") # Update guest activity guest_data["last_activity"] = datetime.now(UTC).isoformat() await database.set_guest(user_id, guest_data) logger.debug(f"🔄 Guest activity updated: {user_id}") return user_id except jwt.PyJWTError as e: logger.warning(f"⚠️ JWT decode error: {e}") raise HTTPException(status_code=401, detail="Invalid authentication credentials") except HTTPException: raise except Exception as e: logger.error(f"❌ Token verification error: {e}") raise HTTPException(status_code=401, detail="Token verification failed") async def get_current_user( user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database) ) -> BaseUserWithType: """Get current user from database""" try: # Check candidates candidate_data = await database.get_candidate(user_id) if candidate_data: if candidate_data.get("is_AI"): from model_cast import cast_to_base_user_with_type return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data)) else: from model_cast import cast_to_base_user_with_type return cast_to_base_user_with_type(Candidate.model_validate(candidate_data)) # Check employers employer = await database.get_employer(user_id) if employer: return Employer.model_validate(employer) logger.warning(f"⚠️ User {user_id} not found in database") raise HTTPException(status_code=404, detail="User not found") except Exception as e: logger.error(f"❌ Error getting current user: {e}") raise HTTPException(status_code=404, detail="User not found") async def get_current_user_or_guest( user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database) ) -> BaseUserWithType: """Get current user (including guests) from database""" try: # Check candidates first candidate_data = await database.get_candidate(user_id) if candidate_data: return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) # Check employers employer_data = await database.get_employer(user_id) if employer_data: return Employer.model_validate(employer_data) # Check guests guest_data = await database.get_guest(user_id) if guest_data: return Guest.model_validate(guest_data) logger.warning(f"⚠️ User {user_id} not found in database") raise HTTPException(status_code=404, detail="User not found") except Exception as e: logger.error(f"❌ Error getting current user: {e}") raise HTTPException(status_code=404, detail="User not found") async def get_current_admin( user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database) ) -> BaseUserWithType: user = await get_current_user(user_id=user_id, database=database) if isinstance(user, Candidate) and user.is_admin: return user elif isinstance(user, Employer) and user.is_admin: return user else: logger.warning(f"⚠️ User {user_id} is not an admin") raise HTTPException(status_code=403, detail="Admin access required") prometheus_collector = CollectorRegistry() # Keep the Instrumentator instance alive instrumentator = Instrumentator( should_group_status_codes=True, should_ignore_untemplated=True, should_group_untemplated=True, excluded_handlers=[f"{defines.api_prefix}/metrics"], registry=prometheus_collector )