417 lines
14 KiB
Python

import hashlib
import time
import traceback
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from functools import wraps
from typing import Callable, Any, Optional
from utils.rate_limiter import RateLimiter, RateLimitResult
import schedule
import os
import shutil
from enum import Enum
import uuid
import defines
import pathlib
from markitdown import MarkItDown, StreamInfo
import io
import uvicorn
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta, UTC
import uuid
import jwt
import os
from contextlib import asynccontextmanager
import redis.asyncio as redis
import re
import asyncio
import signal
import json
import uuid
import logging
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional
from pydantic import BaseModel, EmailStr, field_validator, ValidationError
# Prometheus
from prometheus_client import Summary
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_client import CollectorRegistry, Counter
import secrets
import os
import backstory_traceback
from background_tasks import BackgroundTaskManager
# =============================
# Import custom modules
# =============================
from auth_utils import (
AuthenticationManager,
validate_password_strength,
sanitize_login_input,
SecurityConfig
)
import model_cast
import defines
from logger import logger
from database.manager import RedisDatabase, redis_manager, DatabaseManager
import entities
from email_service import VerificationEmailRateLimiter, email_service
from device_manager import DeviceManager
import agents
from entities.entity_manager import entity_manager
# =============================
# Import utilities
# =============================
from utils.dependencies import get_database, set_db_manager
from utils.responses import create_success_response, create_error_response
from utils.helpers import filter_and_paginate
# =============================
# Import route modules
# =============================
from routes import (
admin,
auth,
candidates,
chat,
employers,
jobs,
providers,
resumes,
system,
users,
)
# =============================
# Import Pydantic models
# =============================
from models import (
# API
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume, ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, DocumentMessage, DocumentOptions, Job, JobRequirements, JobRequirementsMessage, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
# User models
Candidate, Employer, BaseUserWithType, BaseUser, Guest, Authentication, AuthResponse, CandidateAI,
# Job models
JobApplication, ApplicationStatus,
# Chat models
ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType, ChatSenderType, ApiMessageType, ChatContextType,
ChatMessageRagSearch,
# Document models
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
# Supporting models
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
# Email
EmailVerificationRequest
)
# ============================
# Application state and signal handling
# ============================
# Both are assigned by `lifespan` during startup and read by the request
# middleware and the shutdown path.
db_manager: Optional[DatabaseManager] = None
# Global background task manager
background_task_manager: Optional[BackgroundTaskManager] = None

# Dispositions that were in place when this module was imported.
# `signal_handler` chains to them so the previous listener still sees the signal.
prev_int = signal.getsignal(signal.SIGINT)
prev_term = signal.getsignal(signal.SIGTERM)


def signal_handler(signum, frame):
    """Log a shutdown signal, then defer to the disposition it replaced.

    Args:
        signum: Signal number delivered by the OS.
        frame: Current stack frame; forwarded unchanged to the previous handler.

    Only SIGINT and SIGTERM are forwarded. A callable previous handler is
    invoked (it may raise KeyboardInterrupt or exit). If the previous
    disposition was ``SIG_DFL`` (not callable), the default action is
    restored and the signal re-raised so the process behaves as it would
    have without this handler instead of the signal being swallowed.
    ``SIG_IGN`` or ``None`` (non-Python handler) leave the signal ignored.
    """
    logger.info(f"⚠️ Received signal {signum!r}, shutting down…")
    if signum == signal.SIGINT:
        previous = prev_int
    elif signum == signal.SIGTERM:
        previous = prev_term
    else:
        return
    if callable(previous):
        previous(signum, frame)
    elif previous == signal.SIG_DFL:
        # SIG_DFL is an int-like constant, not a function: reinstate it and
        # re-deliver the signal so the OS default action applies.
        signal.signal(signum, signal.SIG_DFL)
        signal.raise_signal(signum)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: start shared services, run the app, then shut down.

    Startup initialises the ``DatabaseManager`` (and registers it with the
    dependency layer), initialises ``entity_manager``, starts the
    ``BackgroundTaskManager``, stores both managers on ``app.state`` and
    installs ``signal_handler`` for SIGTERM/SIGINT.

    Shutdown runs in ``finally`` and therefore also after a partial startup
    failure: background tasks are stopped first, then the database manager
    is shut down gracefully.

    Raises:
        Exception: any startup failure is logged and re-raised.
    """
    global db_manager, background_task_manager
    scheme = "https" if ssl_enabled else "http"
    logger.info("🚀 Starting Backstory API with enhanced background tasks")
    logger.info(f"📝 API Documentation available at: {scheme}://{defines.host}:{defines.port}{defines.api_prefix}/docs")
    logger.info("🔗 API endpoints prefixed with: /api/1.0")
    try:
        # Only startup is guarded by the "failed to start" handler; errors
        # thrown in at the `yield` later must not be reported as startup failures.
        try:
            db_manager = DatabaseManager()
            await db_manager.initialize()
            # Make the database manager available to request dependencies.
            set_db_manager(db_manager)
            entity_manager.initialize(prometheus_collector=prometheus_collector, database=db_manager.get_database())
            background_task_manager = BackgroundTaskManager(db_manager)
            await background_task_manager.start()
            app.state.db_manager = db_manager
            app.state.background_task_manager = background_task_manager
            signal.signal(signal.SIGTERM, signal_handler)
            signal.signal(signal.SIGINT, signal_handler)
        except Exception as e:
            logger.error(f"❌ Failed to start application: {e}")
            raise
        logger.info("🚀 Application startup completed with background tasks")
        yield  # Application is running
    finally:
        logger.info("Application shutdown requested")
        # Stop background tasks first; a failure here must not prevent the
        # database from being shut down.
        if background_task_manager:
            try:
                await background_task_manager.stop()
            except Exception as e:
                logger.error(f"❌ Error stopping background tasks: {e}")
        if db_manager:
            await db_manager.graceful_shutdown()
# ============================
# FastAPI application
# ============================
# Interactive docs, ReDoc and the OpenAPI schema all live under the API prefix.
app = FastAPI(
    title="Backstory API",
    description="FastAPI backend for Backstory platform with TypeScript frontend",
    version="1.0.0",
    lifespan=lifespan,
    openapi_url=f"{defines.api_prefix}/openapi.json",
    docs_url=f"{defines.api_prefix}/docs",
    redoc_url=f"{defines.api_prefix}/redoc",
)
# SSL is on unless SSL_ENABLED is set to something other than "true"
# (case-insensitive); the allowed browser origins use the matching scheme.
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
allow_origins = (
    ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"]
    if ssl_enabled
    else ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"]
)
# Add CORS middleware: credentials allowed, any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============================
# Debug data type failures
# ============================
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a request-validation failure and answer with HTTP 422.

    Args:
        request: The request whose parameters/body failed validation.
        exc: The validation error raised by FastAPI.

    Returns:
        A 422 ``JSONResponse`` whose body is ``{"detail": str(exc)}``.
    """
    logger.error(traceback.format_exc())
    logger.error(backstory_traceback.format_exc())
    logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
    # JSONResponse serialises `content` itself; handing it json.dumps(...)
    # double-encoded the body into a JSON *string* instead of an object.
    return JSONResponse(
        status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        content={"detail": str(exc)},
    )
# ============================
# API router: every route module is mounted under /api/1.0
# ============================
api_router = APIRouter(prefix="/api/1.0")
# Registration order matters for route matching, so keep it explicit.
for _route_module in (
    admin,
    auth,
    candidates,
    chat,
    employers,
    jobs,
    providers,
    resumes,
    system,
    users,
):
    api_router.include_router(_route_module.router)
del _route_module
# ============================
# Health Check and Info Endpoints
# ============================
@app.get("/health")
async def health_check(
    database=Depends(get_database),
):
    """Health check endpoint reporting Redis and database status.

    Args:
        database: Database handle injected via ``get_database``.

    Returns:
        A dict with ``status``:
          * ``"healthy"`` plus database stats and Redis version/uptime/memory;
          * ``"shutting_down"`` when a ``RuntimeError`` is raised, including
            when the Redis client is not initialised;
          * ``"error"`` (also logged) for any other exception.
        All variants are returned with FastAPI's default 200 status.
    """
    try:
        if not redis_manager.redis:
            raise RuntimeError("Redis client not initialized")
        # Test Redis connection
        await redis_manager.redis.ping()
        stats = await database.get_stats()
        redis_info = await redis_manager.redis.info()
        return {
            "status": "healthy",
            # Timezone-aware UTC: datetime.utcnow() is deprecated (Python 3.12)
            # and produced an ambiguous timestamp without an offset.
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "database": {
                "status": "connected",
                "stats": stats
            },
            "redis": {
                "version": redis_info.get("redis_version", "unknown"),
                "uptime": redis_info.get("uptime_in_seconds", 0),
                "memory_used": redis_info.get("used_memory_human", "unknown")
            }
        }
    except RuntimeError as e:
        return {"status": "shutting_down", "message": str(e)}
    except Exception as e:
        logger.error(f"❌ Health check failed: {e}")
        return {"status": "error", "message": str(e)}
# ============================
# Mount the versioned API router on the application
# ============================
app.include_router(router=api_router)
# ============================
# Debug logging
# ============================
logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}")


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Debug request/response logging plus a last-resort error handler.

    When ``defines.debug`` is true, every request except the Prometheus
    metrics endpoint is logged, and non-2xx responses are logged as warnings.
    Any exception escaping the downstream application is logged with both
    tracebacks and converted into a JSON 500 response.
    """
    try:
        # Literal prefix test: a regex would treat '.' in a prefix such as
        # "/api/1.0" as a wildcard. Evaluated once and reused below.
        verbose = defines.debug and not request.url.path.startswith(f"{defines.api_prefix}/metrics")
        if verbose:
            logger.info(f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}")
        response = await call_next(request)
        if verbose and not 200 <= response.status_code < 300:
            logger.warning(f"⚠️ Response {request.method} {response.status_code}: Path: {request.url.path}")
        return response
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(backstory_traceback.format_exc())
        logger.error(f"❌ Error processing request: {str(e)}, Path: {request.url.path}, Method: {request.method}")
        # An exception escaping the application is a server fault, not a
        # malformed client request, so report 500 rather than 400.
        return JSONResponse(status_code=500, content={"detail": "Internal server error"})
# ============================
# Request tracking middleware
# ============================
@app.middleware("http")
async def track_requests(request, call_next):
    """Middleware to track active requests during shutdown.

    Rejects new requests with 503 once the database manager reports it is
    shutting down; otherwise counts the request as in flight for the
    duration of ``call_next`` (the count is decremented even on error).

    Raises:
        RuntimeError: if the database manager has not been initialised.
    """
    # The docstring above used to sit after this check, where it was a no-op
    # string expression rather than the function's __doc__.
    if not db_manager:
        raise RuntimeError("Database manager not initialized")
    if db_manager.is_shutting_down:
        return JSONResponse(status_code=503, content={"error": "Application is shutting down"})
    db_manager.increment_requests()
    try:
        return await call_next(request)
    finally:
        db_manager.decrement_requests()
# ============================
# FastAPI Metrics
# ============================
# Dedicated registry; also passed to entity_manager.initialize() in lifespan.
prometheus_collector = CollectorRegistry()
metrics_endpoint = f"{defines.api_prefix}/metrics"
# Module-level reference keeps the Instrumentator instance alive;
# instrument() returns the instrumentator itself, so it can be chained.
instrumentator = Instrumentator(
    should_group_status_codes=True,
    should_ignore_untemplated=True,
    should_group_untemplated=True,
    excluded_handlers=[metrics_endpoint],
    registry=prometheus_collector,
).instrument(app)
# Publish the scrape endpoint (excluded from its own metrics above).
logger.info(f"Exposing Prometheus metrics at {metrics_endpoint}")
instrumentator.expose(app, endpoint=metrics_endpoint)
# ============================
# Static File Serving
# ============================
@app.get("/{path:path}")
async def serve_static(path: str, request: Request):
    """Serve a file from ``defines.static_content`` or fall back to index.html.

    Args:
        path: Request path below the site root (may be empty).
        request: Incoming request (unused).

    Returns:
        A ``FileResponse`` for ``path`` when it names a regular file inside
        the static directory; otherwise a ``FileResponse`` for index.html.
    """
    static_root = os.path.abspath(defines.static_content)
    # Lexical normalisation so ".." segments, or an absolute `path` (e.g. from
    # a "//etc/passwd" URL, which makes os.path.join discard static_root),
    # cannot escape the static directory. Symlinks inside it still work.
    full_path = os.path.normpath(os.path.join(static_root, path))
    try:
        inside_root = os.path.commonpath([static_root, full_path]) == static_root
    except ValueError:
        # e.g. paths on different drives on Windows: not inside the root.
        inside_root = False
    if inside_root and os.path.isfile(full_path):
        return FileResponse(full_path)
    return FileResponse(os.path.join(defines.static_content, "index.html"))
# Root endpoint when no static files
# NOTE(review): this route is registered after the catch-all "/{path:path}"
# route, which also matches "/", so as written it appears unreachable.
@app.get("/", include_in_schema=False)
async def root():
    """Root endpoint with API information (when no static files).

    Returns:
        A dict naming the API, its version, the API prefix and the URLs of
        the interactive docs and the health check.
    """
    return {
        "message": "Backstory API",
        "version": "1.0.0",
        "api_prefix": defines.api_prefix,
        "documentation": f"{defines.api_prefix}/docs",
        # The health check is mounted on the app at /health, not under the
        # API prefix; the old "{api_prefix}/health" URL did not exist.
        "health": "/health"
    }
async def periodic_verification_cleanup():
    """Remove expired verification tokens; meant to run as a background task.

    Logs how many tokens were removed (only when at least one was). Any
    exception is logged rather than raised.
    """
    try:
        removed = await get_database().cleanup_expired_verification_tokens()
        if removed > 0:
            logger.info(f"🧹 Periodic cleanup: removed {removed} expired verification tokens")
    except Exception as err:
        logger.error(f"❌ Error in periodic verification cleanup: {err}")
if __name__ == "__main__":
    # Development entry point: serve `main:app` with uvicorn. With SSL the
    # server also gets the key/cert and auto-reload (excluding src/cli).
    host = defines.host
    port = defines.port
    server_options = dict(app="main:app", host=host, port=port, log_config=None)
    if not ssl_enabled:
        logger.info(f"Starting web server at http://{host}:{port}")
    else:
        logger.info(f"Starting web server at https://{host}:{port}")
        server_options.update(
            ssl_keyfile=defines.key_path,
            ssl_certfile=defines.cert_path,
            reload=True,
            reload_excludes=["src/cli/**"],
        )
    uvicorn.run(**server_options)