backstory/src/backend/utils/dependencies.py

207 lines
7.7 KiB
Python

"""
Shared dependencies for FastAPI routes
"""
from __future__ import annotations
import jwt
import os
from datetime import datetime, timedelta, timezone, UTC
from typing import Optional
from fastapi import HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from prometheus_client import CollectorRegistry
from prometheus_fastapi_instrumentator import Instrumentator
import defines
from database import RedisDatabase, redis_manager, DatabaseManager
from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest
from logger import logger
from background_tasks import BackgroundTaskManager
#from . rate_limiter import RateLimiter
# Security: bearer-credential extractor used as a dependency by the token
# verification routine in this module.
security = HTTPBearer()

# Secret used to sign and verify JWTs. Importing this module fails with
# ValueError when the variable is missing or empty.
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "")
if not JWT_SECRET_KEY:
    raise ValueError("JWT_SECRET_KEY environment variable is not set")

# Signing algorithm passed to jwt.encode / jwt.decode.
ALGORITHM = "HS256"

# Starts out unset; nothing in this part of the file assigns it.
background_task_manager: Optional[BackgroundTaskManager] = None

# Global database manager reference (installed via set_db_manager)
db_manager = None
def set_db_manager(manager: DatabaseManager):
    """
    Install ``manager`` as the module-wide database manager.

    The value is stored in the module global ``db_manager``, replacing
    whatever was stored there before.
    """
    # Rebinding a module-level name from inside a function needs ``global``.
    global db_manager
    db_manager = manager
def get_database() -> RedisDatabase:
    """
    Safe database dependency that checks for availability.

    Returns:
        Whatever ``db_manager.get_database()`` returns.

    Raises:
        HTTPException: 503 in three cases - the global ``db_manager`` has not
            been set, the manager reports ``is_shutting_down``, or
            ``db_manager.get_database()`` raises ``RuntimeError``.
    """
    # The global is only read here, so no ``global`` statement is required.
    if db_manager is None:
        logger.error("Database manager not initialized")
        raise HTTPException(
            status_code=503,
            detail="Database not available - service starting up",
        )

    if db_manager.is_shutting_down:
        logger.warning("Database is shutting down")
        raise HTTPException(
            status_code=503,
            detail="Service is shutting down",
        )

    try:
        return db_manager.get_database()
    except RuntimeError as e:
        logger.error(f"Database not available: {e}")
        # Chain explicitly so the RuntimeError is recorded as the direct cause.
        raise HTTPException(
            status_code=503,
            detail="Database connection not available",
        ) from e
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """
    Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified.
        expires_delta: Token lifetime measured from now (UTC). ``None`` (the
            default) selects a 24-hour lifetime.

    Returns:
        The result of ``jwt.encode`` using ``JWT_SECRET_KEY`` and ``ALGORITHM``.
    """
    # Compare against None instead of relying on truthiness: ``timedelta(0)``
    # is falsy, so a caller asking for a zero lifetime would otherwise be
    # handed a 24-hour token.
    if expires_delta is None:
        expires_delta = timedelta(hours=24)

    to_encode = data.copy()
    to_encode["exp"] = datetime.now(UTC) + expires_delta
    return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """
    Verify a bearer JWT, reject blacklisted tokens, and refresh guest activity.

    Steps, in order:
      1. Decode the token with ``JWT_SECRET_KEY`` / ``ALGORITHM``.
      2. Require a ``sub`` claim.
      3. Reject the token if the Redis key ``blacklisted_token:<token>`` exists.
      4. For tokens whose ``type`` claim is ``"guest"``, require that the guest
         record still exists and stamp its ``last_activity`` with the current
         UTC time. A missing guest record is rejected, not recreated.

    Returns:
        The ``sub`` claim of the token (the user id).

    Raises:
        HTTPException: 500 when the global ``db_manager`` is not set; 401 for
            an undecodable token, a missing ``sub`` claim, a blacklisted
            token, a missing guest record, or any other unexpected error
            raised during verification.
    """
    try:
        if not db_manager:
            raise HTTPException(status_code=500, detail="Database not initialized")

        token = credentials.credentials

        # First decode the token
        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
        user_id: Optional[str] = payload.get("sub")
        token_type: str = payload.get("type", "access")

        if user_id is None:
            raise HTTPException(status_code=401, detail="Invalid authentication credentials")

        # Check if token is blacklisted
        redis = redis_manager.get_client()
        if await redis.exists(f"blacklisted_token:{token}"):
            logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}")
            raise HTTPException(status_code=401, detail="Token has been revoked")

        # For guest tokens, verify guest still exists and update activity.
        # ``token_type`` already holds payload["type"], so one comparison suffices.
        if token_type == "guest":
            database = db_manager.get_database()
            guest_data = await database.get_guest(user_id)
            if not guest_data:
                logger.warning(f"🚫 Guest session not found for token: {user_id}")
                raise HTTPException(status_code=401, detail="Guest session expired")

            # Update guest activity
            guest_data["last_activity"] = datetime.now(UTC).isoformat()
            await database.set_guest(user_id, guest_data)
            logger.debug(f"🔄 Guest activity updated: {user_id}")

        return user_id

    except jwt.PyJWTError as e:
        logger.warning(f"⚠️ JWT decode error: {e}")
        raise HTTPException(status_code=401, detail="Invalid authentication credentials") from e
    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"❌ Token verification error: {e}")
        raise HTTPException(status_code=401, detail="Token verification failed") from e
async def get_current_user(
    user_id: str = Depends(verify_token_with_blacklist),
    database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
    """
    Get the current user (candidate or employer) from the database.

    Candidates are looked up first; a record with a truthy ``is_AI`` field is
    validated as ``CandidateAI``, otherwise as ``Candidate``, and the result is
    passed through ``cast_to_base_user_with_type``. If no candidate record
    exists, the employer record is validated as ``Employer``.

    Raises:
        HTTPException: 404 when neither record exists, and also when a lookup
            or validation step raises any other exception.
    """
    try:
        # Check candidates
        candidate_data = await database.get_candidate(user_id)
        if candidate_data:
            # Imported only when needed, as in the original code.
            from model_cast import cast_to_base_user_with_type

            model_cls = CandidateAI if candidate_data.get("is_AI") else Candidate
            return cast_to_base_user_with_type(model_cls.model_validate(candidate_data))

        # Check employers
        employer = await database.get_employer(user_id)
        if employer:
            return Employer.model_validate(employer)

        logger.warning(f"⚠️ User {user_id} not found in database")
        raise HTTPException(status_code=404, detail="User not found")

    except HTTPException:
        # The 404 above is already logged; let it through so the broad handler
        # below does not report it a second time as an error.
        raise
    except Exception as e:
        logger.error(f"❌ Error getting current user: {e}")
        raise HTTPException(status_code=404, detail="User not found") from e
async def get_current_user_or_guest(
    user_id: str = Depends(verify_token_with_blacklist),
    database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
    """
    Get the current user (including guests) from the database.

    Lookup order is candidate, then employer, then guest; the first record
    found is validated with the matching model (``CandidateAI`` when the
    candidate record has a truthy ``is_AI`` field, else ``Candidate``;
    ``Employer``; ``Guest``) and returned.

    Raises:
        HTTPException: 404 when no record exists under ``user_id``, and also
            when a lookup or validation step raises any other exception.
    """
    try:
        # Check candidates first
        candidate_data = await database.get_candidate(user_id)
        if candidate_data:
            model_cls = CandidateAI if candidate_data.get("is_AI") else Candidate
            return model_cls.model_validate(candidate_data)

        # Check employers
        employer_data = await database.get_employer(user_id)
        if employer_data:
            return Employer.model_validate(employer_data)

        # Check guests
        guest_data = await database.get_guest(user_id)
        if guest_data:
            return Guest.model_validate(guest_data)

        logger.warning(f"⚠️ User {user_id} not found in database")
        raise HTTPException(status_code=404, detail="User not found")

    except HTTPException:
        # The 404 above is already logged; let it through so the broad handler
        # below does not report it a second time as an error.
        raise
    except Exception as e:
        logger.error(f"❌ Error getting current user: {e}")
        raise HTTPException(status_code=404, detail="User not found") from e
async def get_current_admin(
    user_id: str = Depends(verify_token_with_blacklist),
    database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
    """
    Resolve the current user and require admin rights.

    The user is loaded through ``get_current_user``; it is returned only when
    it is a ``Candidate`` or ``Employer`` instance whose ``is_admin`` is truthy.

    Raises:
        HTTPException: 403 when the user is not an admin. Errors raised by
            ``get_current_user`` propagate unchanged.
    """
    # NOTE(review): candidates come back via cast_to_base_user_with_type;
    # confirm that result still satisfies isinstance(..., Candidate).
    account = await get_current_user(user_id=user_id, database=database)

    if isinstance(account, (Candidate, Employer)) and account.is_admin:
        return account

    logger.warning(f"⚠️ User {user_id} is not an admin")
    raise HTTPException(status_code=403, detail="Admin access required")
# Registry instance created for this module and handed to the instrumentator.
prometheus_collector = CollectorRegistry()

# Module-level reference keeps the Instrumentator instance alive.
instrumentator = Instrumentator(
    registry=prometheus_collector,
    # The metrics endpoint itself is excluded from instrumentation.
    excluded_handlers=[f"{defines.api_prefix}/metrics"],
    should_group_status_codes=True,
    should_group_untemplated=True,
    should_ignore_untemplated=True,
)