Fixed database async usage with background tasks
This commit is contained in:
parent
dba497c854
commit
4edf11a62d
@ -1,28 +1,30 @@
|
||||
"""
|
||||
Background tasks for guest cleanup and system maintenance
|
||||
Fixed for event loop safety
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import schedule # type: ignore
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Dict, Any, Callable
|
||||
from logger import logger
|
||||
from database import DatabaseManager
|
||||
|
||||
class BackgroundTaskManager:
|
||||
"""Manages background tasks for the application"""
|
||||
"""Manages background tasks for the application using asyncio instead of threading"""
|
||||
|
||||
def __init__(self, database_manager: DatabaseManager):
|
||||
self.database_manager = database_manager
|
||||
self.running = False
|
||||
self.tasks = []
|
||||
self.scheduler_thread: Optional[threading.Thread] = None
|
||||
self.tasks: List[asyncio.Task] = []
|
||||
self.main_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
async def cleanup_inactive_guests(self, inactive_hours: int = 24):
|
||||
"""Clean up inactive guest sessions"""
|
||||
try:
|
||||
if self.database_manager.is_shutting_down:
|
||||
logger.info("Skipping guest cleanup - application shutting down")
|
||||
return 0
|
||||
|
||||
database = self.database_manager.get_database()
|
||||
cleaned_count = await database.cleanup_inactive_guests(inactive_hours)
|
||||
|
||||
@ -37,6 +39,10 @@ class BackgroundTaskManager:
|
||||
async def cleanup_expired_verification_tokens(self):
|
||||
"""Clean up expired email verification tokens"""
|
||||
try:
|
||||
if self.database_manager.is_shutting_down:
|
||||
logger.info("Skipping token cleanup - application shutting down")
|
||||
return 0
|
||||
|
||||
database = self.database_manager.get_database()
|
||||
cleaned_count = await database.cleanup_expired_verification_tokens()
|
||||
|
||||
@ -51,6 +57,10 @@ class BackgroundTaskManager:
|
||||
async def update_guest_statistics(self):
|
||||
"""Update guest usage statistics"""
|
||||
try:
|
||||
if self.database_manager.is_shutting_down:
|
||||
logger.info("Skipping stats update - application shutting down")
|
||||
return {}
|
||||
|
||||
database = self.database_manager.get_database()
|
||||
stats = await database.get_guest_statistics()
|
||||
|
||||
@ -68,8 +78,15 @@ class BackgroundTaskManager:
|
||||
async def cleanup_old_rate_limit_data(self, days_old: int = 7):
|
||||
"""Clean up old rate limiting data"""
|
||||
try:
|
||||
if self.database_manager.is_shutting_down:
|
||||
logger.info("Skipping rate limit cleanup - application shutting down")
|
||||
return 0
|
||||
|
||||
database = self.database_manager.get_database()
|
||||
redis = database.redis
|
||||
|
||||
# Get Redis client safely (using the event loop safe method)
|
||||
from database import redis_manager
|
||||
redis = await redis_manager.get_client()
|
||||
|
||||
# Clean up rate limit keys older than specified days
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||
@ -103,73 +120,206 @@ class BackgroundTaskManager:
|
||||
logger.error(f"❌ Error cleaning up rate limit data: {e}")
|
||||
return 0
|
||||
|
||||
def schedule_periodic_tasks(self):
|
||||
"""Schedule periodic background tasks with safer intervals"""
|
||||
|
||||
# Guest cleanup - every 6 hours instead of every hour (less aggressive)
|
||||
schedule.every(6).hours.do(self._run_async_task, self.cleanup_inactive_guests, 48) # 48 hours instead of 24
|
||||
|
||||
# Verification token cleanup - every 12 hours
|
||||
schedule.every(12).hours.do(self._run_async_task, self.cleanup_expired_verification_tokens)
|
||||
|
||||
# Guest statistics update - every hour
|
||||
schedule.every().hour.do(self._run_async_task, self.update_guest_statistics)
|
||||
|
||||
# Rate limit data cleanup - daily at 3 AM
|
||||
schedule.every().day.at("03:00").do(self._run_async_task, self.cleanup_old_rate_limit_data, 7)
|
||||
|
||||
logger.info("📅 Background tasks scheduled with safer intervals")
|
||||
|
||||
def _run_async_task(self, coro_func, *args, **kwargs):
|
||||
"""Run an async task in the background"""
|
||||
async def cleanup_orphaned_data(self):
|
||||
"""Clean up orphaned database records"""
|
||||
try:
|
||||
# Create new event loop for this thread if needed
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
if self.database_manager.is_shutting_down:
|
||||
return 0
|
||||
|
||||
database = self.database_manager.get_database()
|
||||
|
||||
# Run the coroutine
|
||||
loop.run_until_complete(coro_func(*args, **kwargs))
|
||||
# Clean up orphaned job requirements
|
||||
orphaned_count = await database.cleanup_orphaned_job_requirements()
|
||||
|
||||
if orphaned_count > 0:
|
||||
logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")
|
||||
|
||||
return orphaned_count
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error running background task {coro_func.__name__}: {e}")
|
||||
logger.error(f"❌ Error cleaning up orphaned data: {e}")
|
||||
return 0
|
||||
|
||||
def _scheduler_worker(self):
|
||||
"""Worker thread for running scheduled tasks"""
|
||||
async def _run_periodic_task(self, name: str, task_func: Callable, interval_seconds: int, *args, **kwargs):
|
||||
"""Run a periodic task safely in the same event loop"""
|
||||
logger.info(f"🔄 Starting periodic task: {name} (every {interval_seconds}s)")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
schedule.run_pending()
|
||||
time.sleep(60) # Check every minute
|
||||
# Verify we're still in the correct event loop
|
||||
current_loop = asyncio.get_running_loop()
|
||||
if current_loop != self.main_loop:
|
||||
logger.error(f"Task {name} detected event loop change! Stopping.")
|
||||
break
|
||||
|
||||
# Run the task
|
||||
await task_func(*args, **kwargs)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Periodic task {name} was cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in scheduler worker: {e}")
|
||||
time.sleep(60)
|
||||
logger.error(f"❌ Error in periodic task {name}: {e}")
|
||||
# Continue running despite errors
|
||||
|
||||
# Sleep with cancellation support
|
||||
try:
|
||||
await asyncio.sleep(interval_seconds)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Periodic task {name} cancelled during sleep")
|
||||
break
|
||||
|
||||
def start(self):
|
||||
"""Start the background task manager"""
|
||||
async def start(self):
|
||||
"""Start all background tasks in the current event loop"""
|
||||
if self.running:
|
||||
logger.warning("⚠️ Background task manager already running")
|
||||
return
|
||||
|
||||
# Store the current event loop
|
||||
self.main_loop = asyncio.get_running_loop()
|
||||
self.running = True
|
||||
self.schedule_periodic_tasks()
|
||||
|
||||
# Start scheduler thread
|
||||
self.scheduler_thread = threading.Thread(target=self._scheduler_worker, daemon=True)
|
||||
self.scheduler_thread.start()
|
||||
# Define periodic tasks with their intervals (in seconds)
|
||||
periodic_tasks = [
|
||||
# (name, function, interval_seconds, *args)
|
||||
("guest_cleanup", self.cleanup_inactive_guests, 6 * 3600, 48), # Every 6 hours, cleanup 48h old
|
||||
("token_cleanup", self.cleanup_expired_verification_tokens, 12 * 3600), # Every 12 hours
|
||||
("guest_stats", self.update_guest_statistics, 3600), # Every hour
|
||||
("rate_limit_cleanup", self.cleanup_old_rate_limit_data, 24 * 3600, 7), # Daily, cleanup 7 days old
|
||||
("orphaned_cleanup", self.cleanup_orphaned_data, 6 * 3600), # Every 6 hours
|
||||
]
|
||||
|
||||
logger.info("🚀 Background task manager started")
|
||||
# Create asyncio tasks for each periodic task
|
||||
for name, func, interval, *args in periodic_tasks:
|
||||
task = asyncio.create_task(
|
||||
self._run_periodic_task(name, func, interval, *args),
|
||||
name=f"background_{name}"
|
||||
)
|
||||
self.tasks.append(task)
|
||||
logger.info(f"📅 Scheduled background task: {name}")
|
||||
|
||||
# Run initial cleanup tasks immediately (but don't wait for them)
|
||||
asyncio.create_task(self._run_initial_cleanup(), name="initial_cleanup")
|
||||
|
||||
logger.info("🚀 Background task manager started with asyncio tasks")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the background task manager"""
|
||||
async def _run_initial_cleanup(self):
|
||||
"""Run some cleanup tasks immediately on startup"""
|
||||
try:
|
||||
logger.info("🧹 Running initial cleanup tasks...")
|
||||
|
||||
# Clean up expired tokens immediately
|
||||
await asyncio.sleep(5) # Give the app time to fully start
|
||||
await self.cleanup_expired_verification_tokens()
|
||||
|
||||
# Clean up very old inactive guests (7 days old)
|
||||
await self.cleanup_inactive_guests(inactive_hours=7 * 24)
|
||||
|
||||
# Update statistics
|
||||
await self.update_guest_statistics()
|
||||
|
||||
logger.info("✅ Initial cleanup tasks completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in initial cleanup: {e}")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop all background tasks gracefully"""
|
||||
logger.info("🛑 Stopping background task manager...")
|
||||
|
||||
self.running = False
|
||||
|
||||
if self.scheduler_thread and self.scheduler_thread.is_alive():
|
||||
self.scheduler_thread.join(timeout=5)
|
||||
# Cancel all running tasks
|
||||
for task in self.tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
# Clear scheduled tasks
|
||||
schedule.clear()
|
||||
# Wait for all tasks to complete with timeout
|
||||
if self.tasks:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*self.tasks, return_exceptions=True),
|
||||
timeout=30.0
|
||||
)
|
||||
logger.info("✅ All background tasks stopped gracefully")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("⚠️ Some background tasks did not stop within timeout")
|
||||
|
||||
self.tasks.clear()
|
||||
self.main_loop = None
|
||||
|
||||
logger.info("🛑 Background task manager stopped")
|
||||
|
||||
async def get_task_status(self) -> Dict[str, Any]:
|
||||
"""Get status of all background tasks"""
|
||||
status = {
|
||||
"running": self.running,
|
||||
"main_loop_id": id(self.main_loop) if self.main_loop else None,
|
||||
"current_loop_id": None,
|
||||
"task_count": len(self.tasks),
|
||||
"tasks": []
|
||||
}
|
||||
|
||||
try:
|
||||
current_loop = asyncio.get_running_loop()
|
||||
status["current_loop_id"] = id(current_loop)
|
||||
status["loop_matches"] = (id(current_loop) == id(self.main_loop)) if self.main_loop else False
|
||||
except RuntimeError:
|
||||
status["current_loop_id"] = "no_running_loop"
|
||||
|
||||
for task in self.tasks:
|
||||
task_info = {
|
||||
"name": task.get_name(),
|
||||
"done": task.done(),
|
||||
"cancelled": task.cancelled(),
|
||||
}
|
||||
|
||||
if task.done() and not task.cancelled():
|
||||
try:
|
||||
task.result() # This will raise an exception if the task failed
|
||||
task_info["status"] = "completed"
|
||||
except Exception as e:
|
||||
task_info["status"] = "failed"
|
||||
task_info["error"] = str(e)
|
||||
elif task.cancelled():
|
||||
task_info["status"] = "cancelled"
|
||||
else:
|
||||
task_info["status"] = "running"
|
||||
|
||||
status["tasks"].append(task_info)
|
||||
|
||||
return status
|
||||
|
||||
async def force_run_task(self, task_name: str) -> Any:
|
||||
"""Manually trigger a specific background task"""
|
||||
task_map = {
|
||||
"guest_cleanup": self.cleanup_inactive_guests,
|
||||
"token_cleanup": self.cleanup_expired_verification_tokens,
|
||||
"guest_stats": self.update_guest_statistics,
|
||||
"rate_limit_cleanup": self.cleanup_old_rate_limit_data,
|
||||
"orphaned_cleanup": self.cleanup_orphaned_data,
|
||||
}
|
||||
|
||||
if task_name not in task_map:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(task_map.keys())}")
|
||||
|
||||
logger.info(f"🔧 Manually running task: {task_name}")
|
||||
result = await task_map[task_name]()
|
||||
logger.info(f"✅ Manual task {task_name} completed")
|
||||
return result
|
||||
|
||||
|
||||
# Usage in your main application
|
||||
async def setup_background_tasks(database_manager: DatabaseManager) -> BackgroundTaskManager:
|
||||
"""Setup and start background tasks"""
|
||||
task_manager = BackgroundTaskManager(database_manager)
|
||||
await task_manager.start()
|
||||
return task_manager
|
||||
|
||||
# For integration with your existing app startup
|
||||
async def initialize_with_background_tasks(database_manager: DatabaseManager):
|
||||
"""Initialize database and background tasks together"""
|
||||
# Start background tasks
|
||||
background_tasks = await setup_background_tasks(database_manager)
|
||||
|
||||
# Return both for your app to manage
|
||||
return database_manager, background_tasks
|
@ -2766,4 +2766,5 @@ class DatabaseManager:
|
||||
if self._shutdown_initiated:
|
||||
raise RuntimeError("Application is shutting down")
|
||||
return self.db
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import hashlib
|
||||
import time
|
||||
import traceback
|
||||
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks, File, UploadFile, Form# type: ignore
|
||||
from fastapi.middleware.cors import CORSMiddleware # type: ignore
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials# type: ignore
|
||||
@ -105,7 +106,8 @@ from models import (
|
||||
# ============================
|
||||
# Startup Event
|
||||
# ============================
|
||||
db_manager = DatabaseManager()
|
||||
db_manager = None
|
||||
background_task_manager = None
|
||||
|
||||
prev_int = signal.getsignal(signal.SIGINT)
|
||||
prev_term = signal.getsignal(signal.SIGTERM)
|
||||
@ -121,237 +123,38 @@ def signal_handler(signum, frame):
|
||||
# Global background task manager
|
||||
background_task_manager: Optional[BackgroundTaskManager] = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
global background_task_manager
|
||||
|
||||
logger.info("🚀 Starting Backstory API with enhanced background tasks")
|
||||
logger.info(f"📝 API Documentation available at: http://{defines.host}:{defines.port}{defines.api_prefix}/docs")
|
||||
logger.info("🔗 API endpoints prefixed with: /api/1.0")
|
||||
try:
|
||||
# Initialize database
|
||||
await db_manager.initialize()
|
||||
entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database())
|
||||
|
||||
# Initialize background task manager
|
||||
background_task_manager = BackgroundTaskManager(db_manager)
|
||||
background_task_manager.start()
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
logger.info("🚀 Application startup completed with background tasks")
|
||||
|
||||
yield # Application is running
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to start application: {e}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Shutdown
|
||||
logger.info("Application shutdown requested")
|
||||
|
||||
# Stop background tasks first
|
||||
if background_task_manager:
|
||||
background_task_manager.stop()
|
||||
|
||||
await db_manager.graceful_shutdown()
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="Backstory API",
|
||||
description="FastAPI backend for Backstory platform with TypeScript frontend",
|
||||
version="1.0.0",
|
||||
docs_url=f"{defines.api_prefix}/docs",
|
||||
redoc_url=f"{defines.api_prefix}/redoc",
|
||||
openapi_url=f"{defines.api_prefix}/openapi.json",
|
||||
)
|
||||
|
||||
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
|
||||
|
||||
if ssl_enabled:
|
||||
allow_origins = ["https://battle-linux.ketrenos.com:3000",
|
||||
"https://backstory-beta.ketrenos.com"]
|
||||
else:
|
||||
allow_origins = ["http://battle-linux.ketrenos.com:3000",
|
||||
"http://backstory-beta.ketrenos.com"]
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "")
|
||||
if JWT_SECRET_KEY == "":
|
||||
raise ValueError("JWT_SECRET_KEY environment variable is not set")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
# ============================
|
||||
# Debug data type failures
|
||||
# ============================
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
|
||||
return JSONResponse(
|
||||
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=json.dumps({"detail": str(exc)}),
|
||||
)
|
||||
|
||||
# ============================
|
||||
# Authentication Utilities
|
||||
# ============================
|
||||
|
||||
# Request/Response Models
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(hours=24)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""Enhanced token verification with guest session recovery"""
|
||||
try:
|
||||
# First decode the token
|
||||
payload = jwt.decode(credentials.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
token_type: str = payload.get("type", "access")
|
||||
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||
|
||||
# Check if token is blacklisted
|
||||
redis = redis_manager.get_client()
|
||||
blacklist_key = f"blacklisted_token:{credentials.credentials}"
|
||||
|
||||
is_blacklisted = await redis.exists(blacklist_key)
|
||||
if is_blacklisted:
|
||||
logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}")
|
||||
raise HTTPException(status_code=401, detail="Token has been revoked")
|
||||
|
||||
# For guest tokens, verify guest still exists and update activity
|
||||
if token_type == "guest" or payload.get("type") == "guest":
|
||||
database = db_manager.get_database()
|
||||
guest_data = await database.get_guest(user_id)
|
||||
|
||||
if not guest_data:
|
||||
logger.warning(f"🚫 Guest session not found for token: {user_id}")
|
||||
raise HTTPException(status_code=401, detail="Guest session expired")
|
||||
|
||||
# Update guest activity
|
||||
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||
await database.set_guest(user_id, guest_data)
|
||||
logger.debug(f"🔄 Guest activity updated: {user_id}")
|
||||
|
||||
return user_id
|
||||
|
||||
except jwt.PyJWTError as e:
|
||||
logger.warning(f"⚠️ JWT decode error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Token verification error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Token verification failed")
|
||||
|
||||
async def get_current_user(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(lambda: db_manager.get_database())
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user from database"""
|
||||
try:
|
||||
# Check candidates
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
# logger.info(f"🔑 Current user is candidate: {candidate['id']}")
|
||||
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) # type: ignore[return-value]
|
||||
# Check candidates
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
# logger.info(f"🔑 Current user is candidate: {candidate['id']}")
|
||||
if candidate_data.get("is_AI"):
|
||||
return model_cast.cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
|
||||
else:
|
||||
return model_cast.cast_to_base_user_with_type(Candidate.model_validate(candidate_data))
|
||||
# Check employers
|
||||
employer = await database.get_employer(user_id)
|
||||
if employer:
|
||||
# logger.info(f"🔑 Current user is employer: {employer['id']}")
|
||||
return Employer.model_validate(employer)
|
||||
|
||||
logger.warning(f"⚠️ User {user_id} not found in database")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
async def get_current_user_or_guest(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(lambda: db_manager.get_database())
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user (including guests) from database"""
|
||||
try:
|
||||
# Check candidates first
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data)
|
||||
|
||||
# Check employers
|
||||
employer_data = await database.get_employer(user_id)
|
||||
if employer_data:
|
||||
return Employer.model_validate(employer_data)
|
||||
|
||||
# Check guests
|
||||
guest_data = await database.get_guest(user_id)
|
||||
if guest_data:
|
||||
return Guest.model_validate(guest_data)
|
||||
|
||||
logger.warning(f"⚠️ User {user_id} not found in database")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
async def get_current_admin(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(lambda: db_manager.get_database())
|
||||
) -> BaseUserWithType:
|
||||
user = await get_current_user(user_id=user_id, database=database)
|
||||
if isinstance(user, Candidate) and user.is_admin:
|
||||
return user
|
||||
elif isinstance(user, Employer) and user.is_admin:
|
||||
return user
|
||||
else:
|
||||
logger.warning(f"⚠️ User {user_id} is not an admin")
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
|
||||
# ============================
|
||||
# Helper Functions
|
||||
# ============================
|
||||
async def get_database() -> RedisDatabase:
|
||||
def get_database() -> RedisDatabase:
|
||||
"""
|
||||
FastAPI dependency to get database instance with shutdown protection
|
||||
Safe database dependency that checks for availability
|
||||
Raises HTTP 503 if database is not available
|
||||
"""
|
||||
return db_manager.get_database()
|
||||
global db_manager
|
||||
|
||||
if db_manager is None:
|
||||
logger.error("Database manager not initialized")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Database not available - service starting up"
|
||||
)
|
||||
|
||||
if db_manager.is_shutting_down:
|
||||
logger.warning("Database is shutting down")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Service is shutting down"
|
||||
)
|
||||
|
||||
try:
|
||||
return db_manager.get_database()
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Database not available: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Database connection not available"
|
||||
)
|
||||
|
||||
async def get_last_item(generator):
|
||||
last_item = None
|
||||
@ -524,6 +327,239 @@ def get_document_type_from_filename(filename: str) -> DocumentType:
|
||||
|
||||
return type_mapping.get(extension, DocumentType.TXT)
|
||||
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
global db_manager, background_task_manager
|
||||
|
||||
logger.info("🚀 Starting Backstory API with enhanced background tasks")
|
||||
logger.info(f"📝 API Documentation available at: http://{defines.host}:{defines.port}{defines.api_prefix}/docs")
|
||||
logger.info("🔗 API endpoints prefixed with: /api/1.0")
|
||||
try:
|
||||
# Initialize database
|
||||
db_manager = DatabaseManager()
|
||||
await db_manager.initialize()
|
||||
entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database())
|
||||
|
||||
# Initialize background task manager
|
||||
background_task_manager = BackgroundTaskManager(db_manager)
|
||||
await background_task_manager.start()
|
||||
|
||||
app.state.db_manager = db_manager
|
||||
app.state.background_task_manager = background_task_manager
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
logger.info("🚀 Application startup completed with background tasks")
|
||||
|
||||
yield # Application is running
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to start application: {e}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Shutdown
|
||||
logger.info("Application shutdown requested")
|
||||
|
||||
# Stop background tasks first
|
||||
if background_task_manager:
|
||||
await background_task_manager.stop()
|
||||
|
||||
if db_manager:
|
||||
await db_manager.graceful_shutdown()
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="Backstory API",
|
||||
description="FastAPI backend for Backstory platform with TypeScript frontend",
|
||||
version="1.0.0",
|
||||
docs_url=f"{defines.api_prefix}/docs",
|
||||
redoc_url=f"{defines.api_prefix}/redoc",
|
||||
openapi_url=f"{defines.api_prefix}/openapi.json",
|
||||
)
|
||||
|
||||
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
|
||||
|
||||
if ssl_enabled:
|
||||
allow_origins = ["https://battle-linux.ketrenos.com:3000",
|
||||
"https://backstory-beta.ketrenos.com"]
|
||||
else:
|
||||
allow_origins = ["http://battle-linux.ketrenos.com:3000",
|
||||
"http://backstory-beta.ketrenos.com"]
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "")
|
||||
if JWT_SECRET_KEY == "":
|
||||
raise ValueError("JWT_SECRET_KEY environment variable is not set")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
# ============================
|
||||
# Debug data type failures
|
||||
# ============================
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
|
||||
return JSONResponse(
|
||||
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=json.dumps({"detail": str(exc)}),
|
||||
)
|
||||
|
||||
# ============================
|
||||
# Authentication Utilities
|
||||
# ============================
|
||||
|
||||
# Request/Response Models
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(hours=24)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""Enhanced token verification with guest session recovery"""
|
||||
try:
|
||||
if not db_manager:
|
||||
raise HTTPException(status_code=500, detail="Database not initialized")
|
||||
# First decode the token
|
||||
payload = jwt.decode(credentials.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
token_type: str = payload.get("type", "access")
|
||||
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||
|
||||
# Check if token is blacklisted
|
||||
redis = redis_manager.get_client()
|
||||
blacklist_key = f"blacklisted_token:{credentials.credentials}"
|
||||
|
||||
is_blacklisted = await redis.exists(blacklist_key)
|
||||
if is_blacklisted:
|
||||
logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}")
|
||||
raise HTTPException(status_code=401, detail="Token has been revoked")
|
||||
|
||||
# For guest tokens, verify guest still exists and update activity
|
||||
if token_type == "guest" or payload.get("type") == "guest":
|
||||
database = db_manager.get_database()
|
||||
guest_data = await database.get_guest(user_id)
|
||||
|
||||
if not guest_data:
|
||||
logger.warning(f"🚫 Guest session not found for token: {user_id}")
|
||||
raise HTTPException(status_code=401, detail="Guest session expired")
|
||||
|
||||
# Update guest activity
|
||||
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||
await database.set_guest(user_id, guest_data)
|
||||
logger.debug(f"🔄 Guest activity updated: {user_id}")
|
||||
|
||||
return user_id
|
||||
|
||||
except jwt.PyJWTError as e:
|
||||
logger.warning(f"⚠️ JWT decode error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Token verification error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Token verification failed")
|
||||
|
||||
async def get_current_user(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user from database"""
|
||||
try:
|
||||
# Check candidates
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
# logger.info(f"🔑 Current user is candidate: {candidate['id']}")
|
||||
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) # type: ignore[return-value]
|
||||
# Check candidates
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
# logger.info(f"🔑 Current user is candidate: {candidate['id']}")
|
||||
if candidate_data.get("is_AI"):
|
||||
return model_cast.cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
|
||||
else:
|
||||
return model_cast.cast_to_base_user_with_type(Candidate.model_validate(candidate_data))
|
||||
# Check employers
|
||||
employer = await database.get_employer(user_id)
|
||||
if employer:
|
||||
# logger.info(f"🔑 Current user is employer: {employer['id']}")
|
||||
return Employer.model_validate(employer)
|
||||
|
||||
logger.warning(f"⚠️ User {user_id} not found in database")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
async def get_current_user_or_guest(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user (including guests) from database"""
|
||||
try:
|
||||
# Check candidates first
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data)
|
||||
|
||||
# Check employers
|
||||
employer_data = await database.get_employer(user_id)
|
||||
if employer_data:
|
||||
return Employer.model_validate(employer_data)
|
||||
|
||||
# Check guests
|
||||
guest_data = await database.get_guest(user_id)
|
||||
if guest_data:
|
||||
return Guest.model_validate(guest_data)
|
||||
|
||||
logger.warning(f"⚠️ User {user_id} not found in database")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
async def get_current_admin(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
user = await get_current_user(user_id=user_id, database=database)
|
||||
if isinstance(user, Candidate) and user.is_admin:
|
||||
return user
|
||||
elif isinstance(user, Employer) and user.is_admin:
|
||||
return user
|
||||
else:
|
||||
logger.warning(f"⚠️ User {user_id} is not an admin")
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
|
||||
# ============================
|
||||
# Rate Limiting Dependencies
|
||||
# ============================
|
||||
@ -605,7 +641,7 @@ async def rate_limit_dependency(
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id = payload.get("sub")
|
||||
if user_id:
|
||||
database = db_manager.get_database()
|
||||
database = get_database()
|
||||
# Quick user lookup for rate limiting
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
@ -2481,8 +2517,8 @@ async def upload_candidate_profile(
|
||||
)
|
||||
)
|
||||
|
||||
# Check file size (limit to 2MB)
|
||||
max_size = 2 * 1024 * 1024 # 2MB
|
||||
# Check file size (limit to 5MB)
|
||||
max_size = 5 * 1024 * 1024 # 2MB
|
||||
file_content = await file.read()
|
||||
if len(file_content) > max_size:
|
||||
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
|
||||
@ -2493,7 +2529,9 @@ async def upload_candidate_profile(
|
||||
|
||||
# Save file to disk as "profile.<extension>"
|
||||
_, extension = os.path.splitext(file.filename or "")
|
||||
file_path = os.path.join(defines.user_dir, candidate.username, f"profile{extension}")
|
||||
file_path = os.path.join(defines.user_dir, candidate.username)
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
file_path = os.path.join(file_path, f"profile{extension}")
|
||||
|
||||
try:
|
||||
with open(file_path, "wb") as f:
|
||||
@ -4712,7 +4750,7 @@ def rate_limited(
|
||||
# If not found in kwargs, check if they're provided via Depends
|
||||
if not rate_limiter:
|
||||
# Create rate limiter instance (this should ideally come from DI)
|
||||
database = db_manager.get_database()
|
||||
database = get_database()
|
||||
rate_limiter = RateLimiter(database)
|
||||
|
||||
# Apply rate limiting if we have the required components
|
||||
@ -5727,10 +5765,11 @@ async def get_redis() -> redis.Redis:
|
||||
return redis_manager.get_client()
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
async def health_check(
|
||||
database = Depends(get_database),
|
||||
):
|
||||
"""Health check endpoint"""
|
||||
try:
|
||||
database = db_manager.get_database()
|
||||
if not redis_manager.redis:
|
||||
raise RuntimeError("Redis client not initialized")
|
||||
|
||||
@ -5756,8 +5795,8 @@ async def health_check():
|
||||
"memory_used": redis_info.get("used_memory_human", "unknown")
|
||||
},
|
||||
"application": {
|
||||
"active_requests": db_manager._active_requests,
|
||||
"shutting_down": db_manager.is_shutting_down
|
||||
"active_requests": database._active_requests,
|
||||
"shutting_down": database.is_shutting_down
|
||||
}
|
||||
}
|
||||
|
||||
@ -5901,33 +5940,226 @@ async def manual_rate_limit_cleanup(
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
|
||||
# ========================================
|
||||
# System Health and Maintenance Endpoints
|
||||
# ========================================
|
||||
|
||||
@api_router.get("/admin/system/health")
|
||||
async def get_system_health(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Get comprehensive system health status (admin only)"""
|
||||
try:
|
||||
# Database health
|
||||
database_manager = getattr(request.app.state, 'database_manager', None)
|
||||
db_health = {"status": "unavailable", "healthy": False}
|
||||
|
||||
if database_manager:
|
||||
try:
|
||||
database = database_manager.get_database()
|
||||
from database import redis_manager
|
||||
redis_health = await redis_manager.health_check()
|
||||
db_health = {
|
||||
"status": redis_health.get("status", "unknown"),
|
||||
"healthy": redis_health.get("status") == "healthy",
|
||||
"details": redis_health
|
||||
}
|
||||
except Exception as e:
|
||||
db_health = {
|
||||
"status": "error",
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Background task health
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
task_health = {"status": "unavailable", "healthy": False}
|
||||
|
||||
if background_task_manager:
|
||||
try:
|
||||
task_status = await background_task_manager.get_task_status()
|
||||
running_tasks = len([t for t in task_status["tasks"] if t["status"] == "running"])
|
||||
failed_tasks = len([t for t in task_status["tasks"] if t["status"] == "failed"])
|
||||
|
||||
task_health = {
|
||||
"status": "healthy" if task_status["running"] and failed_tasks == 0 else "degraded",
|
||||
"healthy": task_status["running"] and failed_tasks == 0,
|
||||
"running_tasks": running_tasks,
|
||||
"failed_tasks": failed_tasks,
|
||||
"total_tasks": task_status["task_count"]
|
||||
}
|
||||
except Exception as e:
|
||||
task_health = {
|
||||
"status": "error",
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Overall health
|
||||
overall_healthy = db_health["healthy"] and task_health["healthy"]
|
||||
|
||||
return create_success_response({
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"overall_healthy": overall_healthy,
|
||||
"components": {
|
||||
"database": db_health,
|
||||
"background_tasks": task_health
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting system health: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("HEALTH_CHECK_ERROR", str(e))
|
||||
)
|
||||
|
||||
@api_router.post("/admin/maintenance/cleanup")
|
||||
async def run_maintenance_cleanup(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Run comprehensive maintenance cleanup (admin only)"""
|
||||
try:
|
||||
cleanup_results = {}
|
||||
|
||||
# Run various cleanup operations
|
||||
cleanup_operations = [
|
||||
("inactive_guests", lambda: database.cleanup_inactive_guests(72)), # 3 days
|
||||
("expired_tokens", lambda: database.cleanup_expired_verification_tokens()),
|
||||
("orphaned_job_requirements", lambda: database.cleanup_orphaned_job_requirements()),
|
||||
]
|
||||
|
||||
for operation_name, operation_func in cleanup_operations:
|
||||
try:
|
||||
result = await operation_func()
|
||||
cleanup_results[operation_name] = {
|
||||
"success": True,
|
||||
"cleaned_count": result,
|
||||
"message": f"Cleaned {result} items"
|
||||
}
|
||||
except Exception as e:
|
||||
cleanup_results[operation_name] = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"Failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Calculate totals
|
||||
total_cleaned = sum(
|
||||
result.get("cleaned_count", 0)
|
||||
for result in cleanup_results.values()
|
||||
if result.get("success", False)
|
||||
)
|
||||
|
||||
successful_operations = len([
|
||||
r for r in cleanup_results.values()
|
||||
if r.get("success", False)
|
||||
])
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
|
||||
"total_cleaned": total_cleaned,
|
||||
"successful_operations": successful_operations,
|
||||
"details": cleanup_results
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in maintenance cleanup: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
|
||||
# ========================================
|
||||
# Background Task Statistics
|
||||
# ========================================
|
||||
|
||||
@api_router.get("/admin/tasks/stats")
|
||||
async def get_task_statistics(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get background task execution statistics (admin only)"""
|
||||
try:
|
||||
# Get guest statistics
|
||||
guest_stats = await database.get_guest_statistics()
|
||||
|
||||
# Get background task manager status
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
task_manager_stats = {}
|
||||
|
||||
if background_task_manager:
|
||||
task_status = await background_task_manager.get_task_status()
|
||||
task_manager_stats = {
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"task_breakdown": {}
|
||||
}
|
||||
|
||||
# Count tasks by status
|
||||
for task in task_status["tasks"]:
|
||||
status = task["status"]
|
||||
task_manager_stats["task_breakdown"][status] = task_manager_stats["task_breakdown"].get(status, 0) + 1
|
||||
|
||||
return create_success_response({
|
||||
"guest_statistics": guest_stats,
|
||||
"task_manager": task_manager_stats,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting task statistics: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
|
||||
# ========================================
|
||||
# Background Task Status Endpoints
|
||||
# ========================================
|
||||
|
||||
@api_router.get("/admin/tasks/status")
|
||||
async def get_background_task_status(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Get background task manager status (admin only)"""
|
||||
try:
|
||||
global background_task_manager
|
||||
# Get background task manager from app state
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
|
||||
if not background_task_manager:
|
||||
return create_success_response({
|
||||
"running": False,
|
||||
"message": "Background task manager not initialized"
|
||||
"message": "Background task manager not initialized",
|
||||
"tasks": [],
|
||||
"task_count": 0
|
||||
})
|
||||
|
||||
# Get next scheduled run times
|
||||
next_runs = []
|
||||
for job in schedule.jobs:
|
||||
next_runs.append({
|
||||
"job": str(job.job_func),
|
||||
"next_run": job.next_run.isoformat() if job.next_run else None
|
||||
})
|
||||
# Get comprehensive task status using the new method
|
||||
task_status = await background_task_manager.get_task_status()
|
||||
|
||||
# Add additional system info
|
||||
system_info = {
|
||||
"uptime_seconds": None, # Could calculate from start time if stored
|
||||
"last_cleanup": None, # Could track last cleanup time
|
||||
}
|
||||
|
||||
# Format the response
|
||||
return create_success_response({
|
||||
"running": background_task_manager.running,
|
||||
"scheduler_thread_alive": background_task_manager.scheduler_thread.is_alive() if background_task_manager.scheduler_thread else False,
|
||||
"scheduled_jobs": len(schedule.jobs),
|
||||
"next_runs": next_runs
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"loop_status": {
|
||||
"main_loop_id": task_status["main_loop_id"],
|
||||
"current_loop_id": task_status["current_loop_id"],
|
||||
"loop_matches": task_status.get("loop_matches", False)
|
||||
},
|
||||
"tasks": task_status["tasks"],
|
||||
"system_info": system_info
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@ -5937,6 +6169,162 @@ async def get_background_task_status(
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
|
||||
@api_router.post("/admin/tasks/run/{task_name}")
|
||||
async def run_background_task(
|
||||
task_name: str,
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Manually trigger a specific background task (admin only)"""
|
||||
try:
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content=create_error_response(
|
||||
"MANAGER_UNAVAILABLE",
|
||||
"Background task manager not initialized"
|
||||
)
|
||||
)
|
||||
|
||||
# List of available tasks
|
||||
available_tasks = [
|
||||
"guest_cleanup",
|
||||
"token_cleanup",
|
||||
"guest_stats",
|
||||
"rate_limit_cleanup",
|
||||
"orphaned_cleanup"
|
||||
]
|
||||
|
||||
if task_name not in available_tasks:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response(
|
||||
"INVALID_TASK",
|
||||
f"Unknown task: {task_name}. Available: {available_tasks}"
|
||||
)
|
||||
)
|
||||
|
||||
# Run the task
|
||||
result = await background_task_manager.force_run_task(task_name)
|
||||
|
||||
return create_success_response({
|
||||
"task_name": task_name,
|
||||
"result": result,
|
||||
"message": f"Task {task_name} completed successfully"
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_TASK", str(e))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error running task {task_name}: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_EXECUTION_ERROR", str(e))
|
||||
)
|
||||
|
||||
@api_router.get("/admin/tasks/list")
|
||||
async def list_available_tasks(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""List all available background tasks (admin only)"""
|
||||
try:
|
||||
tasks = [
|
||||
{
|
||||
"name": "guest_cleanup",
|
||||
"description": "Clean up inactive guest sessions",
|
||||
"interval": "6 hours",
|
||||
"parameters": ["inactive_hours (default: 48)"]
|
||||
},
|
||||
{
|
||||
"name": "token_cleanup",
|
||||
"description": "Clean up expired email verification tokens",
|
||||
"interval": "12 hours",
|
||||
"parameters": []
|
||||
},
|
||||
{
|
||||
"name": "guest_stats",
|
||||
"description": "Update guest usage statistics",
|
||||
"interval": "1 hour",
|
||||
"parameters": []
|
||||
},
|
||||
{
|
||||
"name": "rate_limit_cleanup",
|
||||
"description": "Clean up old rate limiting data",
|
||||
"interval": "24 hours",
|
||||
"parameters": ["days_old (default: 7)"]
|
||||
},
|
||||
{
|
||||
"name": "orphaned_cleanup",
|
||||
"description": "Clean up orphaned database records",
|
||||
"interval": "6 hours",
|
||||
"parameters": []
|
||||
}
|
||||
]
|
||||
|
||||
return create_success_response({
|
||||
"total_tasks": len(tasks),
|
||||
"tasks": tasks
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error listing tasks: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LIST_ERROR", str(e))
|
||||
)
|
||||
|
||||
@api_router.post("/admin/tasks/restart")
|
||||
async def restart_background_tasks(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Restart the background task manager (admin only)"""
|
||||
try:
|
||||
database_manager = getattr(request.app.state, 'database_manager', None)
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
|
||||
if not database_manager:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content=create_error_response(
|
||||
"DATABASE_UNAVAILABLE",
|
||||
"Database manager not available"
|
||||
)
|
||||
)
|
||||
|
||||
# Stop existing background tasks
|
||||
if background_task_manager:
|
||||
await background_task_manager.stop()
|
||||
logger.info("🛑 Stopped existing background task manager")
|
||||
|
||||
# Create and start new background task manager
|
||||
from background_tasks import BackgroundTaskManager
|
||||
new_background_task_manager = BackgroundTaskManager(database_manager)
|
||||
await new_background_task_manager.start()
|
||||
|
||||
# Update app state
|
||||
request.app.state.background_task_manager = new_background_task_manager
|
||||
|
||||
# Get status of new manager
|
||||
status = await new_background_task_manager.get_task_status()
|
||||
|
||||
return create_success_response({
|
||||
"message": "Background task manager restarted successfully",
|
||||
"new_status": status
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error restarting background tasks: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESTART_ERROR", str(e))
|
||||
)
|
||||
|
||||
|
||||
# ============================
|
||||
# Task Monitoring and Metrics
|
||||
@ -6043,6 +6431,8 @@ async def log_requests(request: Request, call_next):
|
||||
# ============================
|
||||
@app.middleware("http")
|
||||
async def track_requests(request, call_next):
|
||||
if not db_manager:
|
||||
raise RuntimeError("Database manager not initialized")
|
||||
"""Middleware to track active requests during shutdown"""
|
||||
if db_manager.is_shutting_down:
|
||||
return JSONResponse(status_code=503, content={"error": "Application is shutting down"})
|
||||
@ -6103,7 +6493,7 @@ async def root():
|
||||
async def periodic_verification_cleanup():
|
||||
"""Background task to periodically clean up expired verification tokens"""
|
||||
try:
|
||||
database = db_manager.get_database()
|
||||
database = get_database()
|
||||
cleaned_count = await database.cleanup_expired_verification_tokens()
|
||||
|
||||
if cleaned_count > 0:
|
||||
|
Loading…
x
Reference in New Issue
Block a user