207 lines
7.7 KiB
Python
207 lines
7.7 KiB
Python
"""
|
|
Shared dependencies for FastAPI routes
|
|
"""
|
|
from __future__ import annotations
|
|
import jwt
|
|
import os
|
|
from datetime import datetime, timedelta, timezone, UTC
|
|
from typing import Optional
|
|
|
|
from fastapi import HTTPException, Depends
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from prometheus_client import CollectorRegistry
|
|
from prometheus_fastapi_instrumentator import Instrumentator
|
|
|
|
import defines
|
|
|
|
from database import RedisDatabase, redis_manager, DatabaseManager
|
|
from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest
|
|
from logger import logger
|
|
from background_tasks import BackgroundTaskManager
|
|
|
|
#from . rate_limiter import RateLimiter
|
|
|
|
# --- Security -----------------------------------------------------------------
# Bearer-token extractor; the auth dependencies below receive its credentials
# through Depends(security).
security = HTTPBearer()

# Secret used to sign and verify JWTs. getenv() with a "" default always
# yields a str, so an empty value means "not configured": fail at import time.
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "")
if not JWT_SECRET_KEY:
    raise ValueError("JWT_SECRET_KEY environment variable is not set")

# Signing algorithm for issued tokens and the only one accepted on decode.
ALGORITHM = "HS256"

# Not assigned anywhere in this section; presumably set during app startup
# elsewhere — confirm.
background_task_manager: Optional[BackgroundTaskManager] = None

# Global database manager reference; assigned via set_db_manager().
db_manager: Optional[DatabaseManager] = None
|
|
|
|
def set_db_manager(manager: DatabaseManager):
    """
    Install ``manager`` as this module's global database manager.

    ``get_database`` and ``verify_token_with_blacklist`` read the module-level
    ``db_manager`` that this function rebinds.
    """
    global db_manager
    db_manager = manager
|
|
|
|
def get_database() -> RedisDatabase:
    """
    FastAPI dependency that hands out the active ``RedisDatabase``.

    Returns:
        Whatever ``db_manager.get_database()`` returns.

    Raises:
        HTTPException: 503 when the global manager has not been set, when the
            manager reports ``is_shutting_down``, or when its
            ``get_database()`` raises ``RuntimeError``.
    """
    # Read-only access to the module global, so no ``global`` statement is needed.
    manager = db_manager

    if manager is None:
        logger.error("Database manager not initialized")
        raise HTTPException(
            status_code=503,
            detail="Database not available - service starting up",
        )

    if manager.is_shutting_down:
        logger.warning("Database is shutting down")
        raise HTTPException(status_code=503, detail="Service is shutting down")

    try:
        database = manager.get_database()
    except RuntimeError as exc:
        logger.error(f"Database not available: {exc}")
        raise HTTPException(
            status_code=503,
            detail="Database connection not available",
        )
    return database
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """
    Create a signed JWT containing ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's mapping is
            not mutated.
        expires_delta: Token lifetime measured from now (UTC). When omitted
            (``None``) the token is valid for 24 hours. An explicit
            ``timedelta(0)`` is honoured; previously a falsy delta was
            mistaken for "not provided" and produced a 24-hour token.

    Returns:
        The value of ``jwt.encode`` signed with ``JWT_SECRET_KEY`` using
        ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy but is a real value.
    if expires_delta is None:
        expires_delta = timedelta(hours=24)

    to_encode = data.copy()
    to_encode["exp"] = datetime.now(UTC) + expires_delta
    return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """
    Validate a bearer JWT and return its subject (the user id).

    Steps:
      1. Decode and verify the token with ``JWT_SECRET_KEY`` / ``ALGORITHM``.
      2. Reject tokens whose ``sub`` claim is missing or empty.
      3. Reject tokens present in the Redis blacklist
         (key ``blacklisted_token:<raw token>``).
      4. For guest tokens (``type == "guest"``), require that the guest record
         still exists and refresh its ``last_activity`` timestamp.

    Returns:
        The token's ``sub`` claim.

    Raises:
        HTTPException: 500 if the database manager is not initialised; 401 for
            undecodable or invalid tokens, revoked tokens, expired guest
            sessions, or any other unexpected error during verification.
    """
    try:
        if not db_manager:
            raise HTTPException(status_code=500, detail="Database not initialized")

        token = credentials.credentials
        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
        user_id: str = payload.get("sub")
        # A missing "type" claim means an ordinary access token.
        token_type: str = payload.get("type", "access")

        # A missing or empty subject cannot identify a user.
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid authentication credentials")

        # Revoked tokens are recorded in Redis keyed by the raw token string.
        redis = redis_manager.get_client()
        if await redis.exists(f"blacklisted_token:{token}"):
            logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}")
            raise HTTPException(status_code=401, detail="Token has been revoked")

        # token_type already reflects payload["type"], so one comparison suffices.
        if token_type == "guest":
            database = db_manager.get_database()
            guest_data = await database.get_guest(user_id)

            if not guest_data:
                logger.warning(f"🚫 Guest session not found for token: {user_id}")
                raise HTTPException(status_code=401, detail="Guest session expired")

            # Touch the guest record so its activity timestamp stays current.
            guest_data["last_activity"] = datetime.now(UTC).isoformat()
            await database.set_guest(user_id, guest_data)
            logger.debug(f"🔄 Guest activity updated: {user_id}")

        return user_id

    except jwt.PyJWTError as e:
        logger.warning(f"⚠️ JWT decode error: {e}")
        raise HTTPException(status_code=401, detail="Invalid authentication credentials") from e
    except HTTPException:
        # Already a deliberate HTTP response; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"❌ Token verification error: {e}")
        raise HTTPException(status_code=401, detail="Token verification failed") from e
|
|
|
|
async def get_current_user(
    user_id: str = Depends(verify_token_with_blacklist),
    database: RedisDatabase = Depends(get_database),
) -> BaseUserWithType:
    """
    Resolve the authenticated user id to a candidate or employer model.

    Candidates are checked first. A record with a truthy ``is_AI`` flag
    validates as ``CandidateAI``, otherwise as ``Candidate``; either way the
    result goes through ``cast_to_base_user_with_type``. Employers are then
    validated as ``Employer``. Guest records are not consulted here.

    Raises:
        HTTPException: 404 when neither a candidate nor an employer record
            exists, or when loading/validating the record fails.
    """
    try:
        candidate_data = await database.get_candidate(user_id)
        if candidate_data:
            # Function-scope import kept as in the original; presumably avoids
            # a circular import — confirm.
            from model_cast import cast_to_base_user_with_type

            model = CandidateAI if candidate_data.get("is_AI") else Candidate
            return cast_to_base_user_with_type(model.model_validate(candidate_data))

        employer = await database.get_employer(user_id)
        if employer:
            return Employer.model_validate(employer)
    except Exception as e:
        logger.error(f"❌ Error getting current user: {e}")
        raise HTTPException(status_code=404, detail="User not found") from e

    # Raised outside the try so the handler above does not re-catch it and log
    # a spurious error for an ordinary "not found".
    logger.warning(f"⚠️ User {user_id} not found in database")
    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
async def get_current_user_or_guest(
    user_id: str = Depends(verify_token_with_blacklist),
    database: RedisDatabase = Depends(get_database),
) -> BaseUserWithType:
    """
    Resolve the authenticated user id to a candidate, employer, or guest model.

    The lookup order is candidate, then employer, then guest. A candidate
    record with a truthy ``is_AI`` flag validates as ``CandidateAI``,
    otherwise as ``Candidate``.

    Raises:
        HTTPException: 404 when no matching record exists, or when
            loading/validating the record fails.
    """
    try:
        candidate_data = await database.get_candidate(user_id)
        if candidate_data:
            model = CandidateAI if candidate_data.get("is_AI") else Candidate
            return model.model_validate(candidate_data)

        employer_data = await database.get_employer(user_id)
        if employer_data:
            return Employer.model_validate(employer_data)

        guest_data = await database.get_guest(user_id)
        if guest_data:
            return Guest.model_validate(guest_data)
    except Exception as e:
        logger.error(f"❌ Error getting current user: {e}")
        raise HTTPException(status_code=404, detail="User not found") from e

    # Raised outside the try so the handler above does not re-catch it and log
    # a spurious error for an ordinary "not found".
    logger.warning(f"⚠️ User {user_id} not found in database")
    raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
async def get_current_admin(
    user_id: str = Depends(verify_token_with_blacklist),
    database: RedisDatabase = Depends(get_database),
) -> BaseUserWithType:
    """
    Return the current user only when they are an admin candidate or employer.

    Raises:
        HTTPException: 403 unless the user is a ``Candidate`` or ``Employer``
            with a truthy ``is_admin``. Errors raised by ``get_current_user``
            (e.g. its 404) propagate unchanged.
    """
    user = await get_current_user(user_id=user_id, database=database)

    has_admin_rights = isinstance(user, (Candidate, Employer)) and user.is_admin
    if not has_admin_rights:
        logger.warning(f"⚠️ User {user_id} is not an admin")
        raise HTTPException(status_code=403, detail="Admin access required")

    return user
|
|
|
|
# A freshly created CollectorRegistry (not prometheus_client's default global
# registry) that receives this service's HTTP metrics.
prometheus_collector = CollectorRegistry()

# Held at module level so the Instrumentator instance stays alive.
instrumentator = Instrumentator(
    registry=prometheus_collector,
    # Skip the "/metrics" route under the API prefix.
    excluded_handlers=[f"{defines.api_prefix}/metrics"],
    should_group_status_codes=True,
    should_group_untemplated=True,
    should_ignore_untemplated=True,
)
|