Fix auth issue

This commit is contained in:
James Ketr 2025-08-04 16:21:29 -07:00
parent 064868e96e
commit 234148f046

View File

@ -4,6 +4,7 @@ Authentication routes
import json import json
import jwt import jwt
import secrets import secrets
import traceback
import uuid import uuid
import os import os
from datetime import datetime, timedelta, timezone, UTC from datetime import datetime, timedelta, timezone, UTC
@ -11,7 +12,7 @@ from typing import Any, Dict
from fastapi import APIRouter, Depends, Body, Request, BackgroundTasks from fastapi import APIRouter, Depends, Body, Request, BackgroundTasks
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel, EmailStr, ValidationError, field_validator from pydantic import BaseModel, EmailStr, ValidationError, field_validator, Field
import backstory_traceback as backstory_traceback import backstory_traceback as backstory_traceback
from utils.rate_limiter import RateLimiter from utils.rate_limiter import RateLimiter
@ -190,7 +191,6 @@ async def create_guest_session_enhanced(
except Exception as e: except Exception as e:
logger.error(f"❌ Guest session creation error: {e}") logger.error(f"❌ Guest session creation error: {e}")
import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return JSONResponse( return JSONResponse(
@ -443,43 +443,62 @@ async def logout_all_devices(current_user=Depends(get_current_admin), database:
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ALL_ERROR", str(e))) return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ALL_ERROR", str(e)))
class RefreshTokenRequest(BaseModel):
refresh_token: str = Field(..., alias="refreshToken")
@router.post("/refresh") @router.post("/refresh")
async def refresh_token_endpoint( async def refresh_token_endpoint(request: RefreshTokenRequest, database: RedisDatabase = Depends(get_database)):
refresh_token: str = Body(..., alias="refreshToken"), database: RedisDatabase = Depends(get_database)
):
"""Refresh token endpoint""" """Refresh token endpoint"""
try: try:
# Verify refresh token # Verify refresh token
payload = jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(request.refresh_token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
user_id = payload.get("sub") user_id = payload.get("sub")
token_type = payload.get("type") token_type = payload.get("type")
if not user_id or token_type != "refresh": if not user_id or (token_type not in ["refresh", "refresh_guest"]):
return JSONResponse( return JSONResponse(
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token") status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
) )
# Create new access token # Create new access token
access_token = create_access_token(data={"sub": user_id}) if token_type == "refresh_guest":
access_token = create_access_token(
data={"sub": user_id, "type": "guest"},
expires_delta=timedelta(hours=48), # Longer expiry for guests
)
else:
access_token = create_access_token(data={"sub": user_id})
# Get user # Get user
user = None user = None
candidate_data = await database.get_candidate(user_id) if token_type == "refresh_guest":
if candidate_data: guest_data = await database.get_guest(user_id)
user = Candidate.model_validate(candidate_data) if guest_data:
user = Guest.model_validate(guest_data)
else: else:
employer_data = await database.get_employer(user_id) candidate_data = await database.get_candidate(user_id)
if employer_data: if candidate_data:
user = Employer.model_validate(employer_data) user = Candidate.model_validate(candidate_data)
else:
employer_data = await database.get_employer(user_id)
if employer_data:
user = Employer.model_validate(employer_data)
if not user: if not user:
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found")) return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
# Set appropriate expiry time
if token_type == "refresh_guest":
expires_at = int((datetime.now(UTC) + timedelta(hours=48)).timestamp())
else:
expires_at = int((datetime.now(UTC) + timedelta(hours=24)).timestamp())
auth_response = AuthResponse( auth_response = AuthResponse(
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, # Keep same refresh token refresh_token=request.refresh_token, # Keep same refresh token
user=user, user=user,
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()), expires_at=expires_at,
) )
return create_success_response(auth_response.model_dump(by_alias=True)) return create_success_response(auth_response.model_dump(by_alias=True))