
from pydantic import BaseModel, ConfigDict, Field
from redis.asyncio import (Redis, ConnectionPool)
from typing import Any, Dict, List, Optional, TypeGuard, Union
import json
import logging
import os
from datetime import datetime, timezone, UTC, timedelta
import asyncio
from models import (
# User models
Candidate, Employer, BaseUser, EvidenceDetail, Guest, Authentication, AuthResponse, SkillAssessment,
)
import backstory_traceback as traceback
from .constants import KEY_PREFIXES
from .core import RedisDatabase
logger = logging.getLogger(__name__)
class _RedisManager:
def __init__(self):
self.redis: Optional[Redis] = None
self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
self.redis_db = int(os.getenv("REDIS_DB", "0"))
# Append database to URL if not already present
if not self.redis_url.endswith(f"/{self.redis_db}"):
self.redis_url = f"{self.redis_url}/{self.redis_db}"
self._connection_pool: Optional[ConnectionPool] = None
self._is_connected = False
async def connect(self):
"""Initialize Redis connection with connection pooling"""
if self._is_connected and self.redis:
logger.info("Redis already connected")
return
try:
# Create connection pool for better resource management
self._connection_pool = ConnectionPool.from_url(
self.redis_url,
encoding="utf-8",
decode_responses=True,
max_connections=20,
retry_on_timeout=True,
socket_keepalive=True,
socket_keepalive_options={},
health_check_interval=30
)
self.redis = Redis(
connection_pool=self._connection_pool
)
if not self.redis:
raise RuntimeError("Redis client not initialized")
# Test connection
await self.redis.ping()
self._is_connected = True
logger.info("Successfully connected to Redis")
# Log Redis info
info = await self.redis.info()
logger.info(f"Redis version: {info.get('redis_version', 'unknown')}")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
self._is_connected = False
self.redis = None
self._connection_pool = None
raise
async def disconnect(self):
"""Close Redis connection gracefully"""
if not self._is_connected:
logger.info("Redis already disconnected")
return
try:
if self.redis:
# Wait for any pending operations to complete
await asyncio.sleep(0.1)
# Close the client
await self.redis.aclose()
logger.info("Redis client closed")
if self._connection_pool:
# Close the connection pool
await self._connection_pool.aclose()
logger.info("Redis connection pool closed")
self._is_connected = False
self.redis = None
self._connection_pool = None
logger.info("Successfully disconnected from Redis")
except Exception as e:
logger.error(f"Error during Redis disconnect: {e}")
# Force cleanup even if there's an error
self._is_connected = False
self.redis = None
self._connection_pool = None
def get_client(self) -> Redis:
"""Get Redis client instance"""
if not self._is_connected or not self.redis:
raise RuntimeError("Redis client not initialized or disconnected")
return self.redis
@property
def is_connected(self) -> bool:
"""Check if Redis is connected"""
return self._is_connected and self.redis is not None
async def health_check(self) -> dict:
"""Perform health check on Redis connection"""
if not self.is_connected:
return {"status": "disconnected", "error": "Redis not connected"}
if not self.redis:
raise RuntimeError("Redis client not initialized")
try:
# Test basic operations
await self.redis.ping()
info = await self.redis.info()
return {
"status": "healthy",
"redis_version": info.get("redis_version", "unknown"),
"uptime_seconds": info.get("uptime_in_seconds", 0),
"connected_clients": info.get("connected_clients", 0),
"used_memory_human": info.get("used_memory_human", "unknown"),
"total_commands_processed": info.get("total_commands_processed", 0)
}
except Exception as e:
logger.error(f"Redis health check failed: {e}")
return {"status": "error", "error": str(e)}
async def force_save(self, background: bool = True) -> bool:
"""Force Redis to save data to disk"""
if not self.is_connected:
logger.warning("Cannot save: Redis not connected")
return False
try:
if not self.redis:
raise RuntimeError("Redis client not initialized")
if background:
# Non-blocking background save
await self.redis.bgsave()
logger.info("Background save initiated")
else:
# Blocking save
await self.redis.save()
logger.info("Synchronous save completed")
return True
except Exception as e:
logger.error(f"Redis save failed: {e}")
return False
async def get_info(self) -> Optional[dict]:
"""Get Redis server information"""
if not self.is_connected:
return None
try:
if not self.redis:
raise RuntimeError("Redis client not initialized")
return await self.redis.info()
except Exception as e:
logger.error(f"Failed to get Redis info: {e}")
return None
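# A minimal lifecycle sketch for _RedisManager (illustrative only; the entry
# point and variable names below are assumptions, not part of this module):
#
#     manager = _RedisManager()
#     await manager.connect()                  # builds the pool and pings Redis
#     health = await manager.health_check()    # {"status": "healthy", ...}
#     client = manager.get_client()            # raises if not connected
#     await client.set("example:key", "value")
#     await manager.disconnect()               # closes client and pool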
class RedisDatabase2:
def __init__(self, redis: Redis):
self.redis = redis
# Key prefixes for the different data types are defined in .constants as KEY_PREFIXES
def _serialize(self, data: Any) -> str:
"""Serialize data to JSON string for Redis storage"""
if data is None:
return ""
return json.dumps(data, default=str) # default=str handles datetime objects
def _deserialize(self, data: str) -> Any:
"""Deserialize JSON string from Redis"""
if not data:
return None
try:
return json.loads(data)
except json.JSONDecodeError:
logger.error(traceback.format_exc())
logger.error(f"Failed to deserialize data: {data}")
return None
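# Round-trip note (illustrative): because _serialize uses default=str, datetime
# values come back from _deserialize as plain strings, so callers re-parse them:
#
#     raw = self._serialize({"created_at": datetime.now(UTC)})
#     data = self._deserialize(raw)                        # {"created_at": "2025-..."}
#     created = datetime.fromisoformat(data["created_at"])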
# Resume operations
async def set_resume(self, user_id: str, resume_data: Dict) -> bool:
"""Save a resume for a user"""
try:
# Generate resume_id if not present
if 'id' not in resume_data:
raise ValueError("Resume data must include an 'id' field")
resume_id = resume_data['id']
# Store the resume data
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
await self.redis.set(key, self._serialize(resume_data))
# Add resume_id to user's resume list
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
await self.redis.rpush(user_resumes_key, resume_id) # type: ignore
logger.info(f"📄 Saved resume {resume_id} for user {user_id}")
return True
except Exception as e:
logger.error(f"❌ Error saving resume for user {user_id}: {e}")
return False
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]:
"""Get a specific resume for a user"""
try:
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
data = await self.redis.get(key)
if data:
resume_data = self._deserialize(data)
logger.info(f"📄 Retrieved resume {resume_id} for user {user_id}")
return resume_data
logger.info(f"📄 Resume {resume_id} not found for user {user_id}")
return None
except Exception as e:
logger.error(f"❌ Error retrieving resume {resume_id} for user {user_id}: {e}")
return None
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]:
"""Get all resumes for a specific user"""
try:
# Get all resume IDs for this user
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
if not resume_ids:
logger.info(f"📄 No resumes found for user {user_id}")
return []
# Get all resume data
resumes = []
pipe = self.redis.pipeline()
for resume_id in resume_ids:
pipe.get(f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}")
values = await pipe.execute()
for resume_id, value in zip(resume_ids, values):
if value:
resume_data = self._deserialize(value)
if resume_data:
resumes.append(resume_data)
else:
# Clean up orphaned resume ID
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
logger.warning(f"Removed orphaned resume ID {resume_id} for user {user_id}")
# Sort by created_at timestamp (most recent first)
resumes.sort(key=lambda x: x.get("created_at", ""), reverse=True)
logger.info(f"📄 Retrieved {len(resumes)} resumes for user {user_id}")
return resumes
except Exception as e:
logger.error(f"❌ Error retrieving resumes for user {user_id}: {e}")
return []
async def delete_resume(self, user_id: str, resume_id: str) -> bool:
"""Delete a specific resume for a user"""
try:
# Delete the resume data
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
result = await self.redis.delete(key)
# Remove from user's resume list
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
if result > 0:
logger.info(f"🗑️ Deleted resume {resume_id} for user {user_id}")
return True
else:
logger.warning(f"⚠️ Resume {resume_id} not found for user {user_id}")
return False
except Exception as e:
logger.error(f"❌ Error deleting resume {resume_id} for user {user_id}: {e}")
return False
async def delete_all_resumes_for_user(self, user_id: str) -> int:
"""Delete all resumes for a specific user and return count of deleted resumes"""
try:
# Get all resume IDs for this user
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
if not resume_ids:
logger.info(f"📄 No resumes found for user {user_id}")
return 0
deleted_count = 0
# Use pipeline for efficient batch operations
pipe = self.redis.pipeline()
# Delete each resume
for resume_id in resume_ids:
pipe.delete(f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}")
deleted_count += 1
# Delete the user's resume list
pipe.delete(user_resumes_key)
# Execute all operations
await pipe.execute()
logger.info(f"🗑️ Successfully deleted {deleted_count} resumes for user {user_id}")
return deleted_count
except Exception as e:
logger.error(f"❌ Error deleting all resumes for user {user_id}: {e}")
raise
async def get_all_resumes(self) -> Dict[str, List[Dict]]:
"""Get all resumes grouped by user (admin function)"""
try:
pattern = f"{KEY_PREFIXES['resumes']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
# Group by user_id
user_resumes = {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
for key, value in zip(keys, values):
if value:
# Extract user_id from key format: resume:{user_id}:{resume_id}
key_parts = key.replace(KEY_PREFIXES['resumes'], '').split(':', 1)
if len(key_parts) >= 1:
user_id = key_parts[0]
resume_data = self._deserialize(value)
if resume_data:
if user_id not in user_resumes:
user_resumes[user_id] = []
user_resumes[user_id].append(resume_data)
# Sort each user's resumes by created_at
for user_id in user_resumes:
user_resumes[user_id].sort(key=lambda x: x.get("created_at", ""), reverse=True)
return user_resumes
except Exception as e:
logger.error(f"❌ Error retrieving all resumes: {e}")
return {}
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]:
"""Search resumes for a user by content, job title, or candidate name"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
query_lower = query.lower()
matching_resumes = []
for resume in all_resumes:
# Search in resume content, job_id, candidate_id, etc.
searchable_text = " ".join([
resume.get("resume", ""),
resume.get("job_id", ""),
resume.get("candidate_id", ""),
str(resume.get("created_at", ""))
]).lower()
if query_lower in searchable_text:
matching_resumes.append(resume)
logger.info(f"📄 Found {len(matching_resumes)} matching resumes for user {user_id}")
return matching_resumes
except Exception as e:
logger.error(f"❌ Error searching resumes for user {user_id}: {e}")
return []
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]:
"""Get all resumes for a specific candidate created by a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
candidate_resumes = [
resume for resume in all_resumes
if resume.get("candidate_id") == candidate_id
]
logger.info(f"📄 Found {len(candidate_resumes)} resumes for candidate {candidate_id} by user {user_id}")
return candidate_resumes
except Exception as e:
logger.error(f"❌ Error retrieving resumes for candidate {candidate_id} by user {user_id}: {e}")
return []
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]:
"""Get all resumes for a specific job created by a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
job_resumes = [
resume for resume in all_resumes
if resume.get("job_id") == job_id
]
logger.info(f"📄 Found {len(job_resumes)} resumes for job {job_id} by user {user_id}")
return job_resumes
except Exception as e:
logger.error(f"❌ Error retrieving resumes for job {job_id} by user {user_id}: {e}")
return []
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]:
"""Get resume statistics for a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
stats = {
"total_resumes": len(all_resumes),
"resumes_by_candidate": {},
"resumes_by_job": {},
"creation_timeline": {},
"recent_resumes": []
}
for resume in all_resumes:
# Count by candidate
candidate_id = resume.get("candidate_id", "unknown")
stats["resumes_by_candidate"][candidate_id] = stats["resumes_by_candidate"].get(candidate_id, 0) + 1
# Count by job
job_id = resume.get("job_id", "unknown")
stats["resumes_by_job"][job_id] = stats["resumes_by_job"].get(job_id, 0) + 1
# Timeline by date
created_at = resume.get("created_at")
if created_at:
try:
date_key = created_at[:10] # Extract date part
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
except (IndexError, TypeError):
pass
# Get recent resumes (last 5)
stats["recent_resumes"] = all_resumes[:5]
return stats
except Exception as e:
logger.error(f"❌ Error getting resume statistics for user {user_id}: {e}")
return {"total_resumes": 0, "resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, "recent_resumes": []}
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
"""Update specific fields of a resume"""
try:
resume_data = await self.get_resume(user_id, resume_id)
if resume_data:
resume_data.update(updates)
resume_data["updated_at"] = datetime.now(UTC).isoformat()
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
await self.redis.set(key, self._serialize(resume_data))
logger.info(f"📄 Updated resume {resume_id} for user {user_id}")
return resume_data
return None
except Exception as e:
logger.error(f"❌ Error updating resume {resume_id} for user {user_id}: {e}")
return None
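# Usage sketch for the resume helpers above (illustrative; "db" and the IDs are
# placeholders, not values this module defines):
#
#     resume = {"id": "resume-123", "candidate_id": "cand-1", "job_id": "job-9",
#               "resume": "...", "created_at": datetime.now(UTC).isoformat()}
#     await db.set_resume(user_id, resume)
#     await db.update_resume(user_id, "resume-123", {"resume": "...revised..."})
#     for item in await db.get_resumes_by_job(user_id, "job-9"):
#         print(item["id"])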
# Document operations
async def get_document(self, document_id: str) -> Optional[Dict]:
"""Get document metadata by ID"""
key = f"document:{document_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_document(self, document_id: str, document_data: Dict):
"""Set document metadata"""
key = f"document:{document_id}"
await self.redis.set(key, self._serialize(document_data))
async def delete_document(self, document_id: str):
"""Delete document metadata"""
key = f"document:{document_id}"
await self.redis.delete(key)
async def delete_all_candidate_documents(self, candidate_id: str) -> int:
"""Delete all documents for a specific candidate and return count of deleted documents"""
try:
# Get all document IDs for this candidate
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
document_ids = await self.redis.lrange(key, 0, -1)# type: ignore
if not document_ids:
logger.info(f"No documents found for candidate {candidate_id}")
return 0
deleted_count = 0
# Use pipeline for efficient batch operations
pipe = self.redis.pipeline()
# Delete each document's metadata
for doc_id in document_ids:
pipe.delete(f"document:{doc_id}")
pipe.delete(f"{KEY_PREFIXES['job_requirements']}{doc_id}")
deleted_count += 1
# Delete the candidate's document list
pipe.delete(key)
# Execute all operations
await pipe.execute()
logger.info(f"Successfully deleted {deleted_count} documents for candidate {candidate_id}")
return deleted_count
except Exception as e:
logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}")
raise
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]:
"""Get cached skill match assessment"""
try:
json_str = await self.redis.get(cache_key)
if json_str:
json_data = json.loads(json_str)
skill_assessment = SkillAssessment.model_validate(json_data)
return skill_assessment
return None
except Exception as e:
logger.error(f"❌ Error getting cached skill match: {e}")
return None
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None:
"""Cache skill match assessment"""
try:
# Stored without an explicit TTL; use setex/expire here if the cache should expire
await self.redis.set(
cache_key,
json.dumps(assessment.model_dump(mode='json', by_alias=True), default=str) # Serialize with datetime handling
)
logger.info(f"💾 Skill match cached: {cache_key}")
except Exception as e:
logger.error(f"❌ Error caching skill match: {e}")
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
"""Get the last time user's RAG data was updated (returns timezone-aware UTC)"""
try:
rag_update_key = f"user:{user_id}:rag_last_update"
timestamp_str = await self.redis.get(rag_update_key)
if timestamp_str:
dt = datetime.fromisoformat(timestamp_str)
# Ensure the datetime is timezone-aware (assume UTC if naive)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
else:
# Convert to UTC if it's in a different timezone
dt = dt.astimezone(timezone.utc)
return dt
logger.warning(f"⚠️ No RAG update time found for user {user_id}")
return None
except Exception as e:
logger.error(f"❌ Error getting user RAG update time: {e}")
return None
async def update_user_rag_timestamp(self, user_id: str) -> bool:
"""Set the user's RAG data update time (stores as UTC ISO format)"""
try:
update_time = datetime.now(timezone.utc)
# Ensure we're storing UTC timezone-aware format
if update_time.tzinfo is None:
update_time = update_time.replace(tzinfo=timezone.utc)
else:
update_time = update_time.astimezone(timezone.utc)
rag_update_key = f"user:{user_id}:rag_last_update"
# Store as ISO format with timezone info
timestamp_str = update_time.isoformat() # This includes timezone
await self.redis.set(rag_update_key, timestamp_str)
logger.info(f"✅ User RAG update time set for user {user_id}: {timestamp_str}")
return True
except Exception as e:
logger.error(f"❌ Error setting user RAG update time: {e}")
return False
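# Staleness-check sketch (illustrative; assumes the SkillAssessment model
# carries a created_at timestamp, which is not defined in this module):
#
#     rag_updated = await db.get_user_rag_update_time(user_id)
#     cached = await db.get_cached_skill_match(cache_key)
#     if cached and rag_updated and cached.created_at < rag_updated:
#         await db.invalidate_candidate_skill_cache(candidate_id)   # force recompute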
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
"""Invalidate all cached skill matches for a specific candidate"""
try:
pattern = f"skill_match:{candidate_id}:*"
keys = await self.redis.keys(pattern)
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error invalidating candidate skill cache: {e}")
return 0
async def clear_all_skill_match_cache(self) -> int:
"""Clear all skill match cache (useful after major system updates)"""
try:
pattern = "skill_match:*"
keys = await self.redis.keys(pattern)
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error clearing skill match cache: {e}")
return 0
async def invalidate_user_skill_cache(self, user_id: str) -> int:
"""Invalidate all cached skill matches when a user's RAG data is updated"""
try:
# Assumes skill-match cache keys are scoped by this user's ID. If keys are
# scoped by candidate ID instead, this needs a candidate-to-user mapping,
# or fall back to clear_all_skill_match_cache() when any user's RAG data updates.
pattern = f"skill_match:{user_id}:*"
keys = await self.redis.keys(pattern)
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error invalidating user skill cache for user {user_id}: {e}")
return 0
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
"""Get all documents for a specific candidate"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
if not document_ids:
return []
# Get all document metadata
pipe = self.redis.pipeline()
for doc_id in document_ids:
pipe.get(f"document:{doc_id}")
values = await pipe.execute()
documents = []
for doc_id, value in zip(document_ids, values):
if value:
doc_data = self._deserialize(value)
if doc_data:
documents.append(doc_data)
else:
# Clean up orphaned document ID
await self.redis.lrem(key, 0, doc_id)# type: ignore
logger.warning(f"Removed orphaned document ID {doc_id} for candidate {candidate_id}")
return documents
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
"""Add a document ID to a candidate's document list"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
await self.redis.rpush(key, document_id)# type: ignore
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
"""Remove a document ID from a candidate's document list"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
await self.redis.lrem(key, 0, document_id)# type: ignore
async def update_document(self, document_id: str, updates: Dict):
"""Update document metadata"""
document_data = await self.get_document(document_id)
if document_data:
document_data.update(updates)
await self.set_document(document_id, document_data)
return document_data
return None
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]:
"""Get candidate documents filtered by RAG inclusion status"""
all_documents = await self.get_candidate_documents(candidate_id)
return [doc for doc in all_documents if doc.get("include_in_rag", False) == include_in_rag]
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool):
"""Bulk update RAG status for multiple documents"""
pipe = self.redis.pipeline()
for doc_id in document_ids:
doc_data = await self.get_document(doc_id)
if doc_data and doc_data.get("candidate_id") == candidate_id:
doc_data["include_in_rag"] = include_in_rag
doc_data["updatedAt"] = datetime.now(UTC).isoformat()
pipe.set(f"document:{doc_id}", self._serialize(doc_data))
await pipe.execute()
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
"""Get total number of documents for a candidate"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
return await self.redis.llen(key)# type: ignore
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
"""Search documents by filename for a candidate"""
all_documents = await self.get_candidate_documents(candidate_id)
query_lower = query.lower()
return [
doc for doc in all_documents
if (query_lower in doc.get("filename", "").lower() or
query_lower in doc.get("originalName", "").lower())
]
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
"""Get cached job requirements analysis for a document"""
try:
key = f"{KEY_PREFIXES['job_requirements']}{document_id}"
data = await self.redis.get(key)
if data:
requirements_data = self._deserialize(data)
logger.info(f"📋 Retrieved cached job requirements for document {document_id}")
return requirements_data
logger.info(f"📋 No cached job requirements found for document {document_id}")
return None
except Exception as e:
logger.error(f"❌ Error retrieving job requirements for document {document_id}: {e}")
return None
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
"""Save job requirements analysis results for a document"""
try:
key = f"{KEY_PREFIXES['job_requirements']}{document_id}"
# Add metadata to the requirements
requirements_with_meta = {
**requirements,
"cached_at": datetime.now(UTC).isoformat(),
"document_id": document_id
}
await self.redis.set(key, self._serialize(requirements_with_meta))
# Optional: Set expiration (e.g., 30 days) to prevent indefinite storage
# await self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days
logger.info(f"📋 Saved job requirements for document {document_id}")
return True
except Exception as e:
logger.error(f"❌ Error saving job requirements for document {document_id}: {e}")
return False
async def delete_job_requirements(self, document_id: str) -> bool:
"""Delete cached job requirements for a document"""
try:
key = f"{KEY_PREFIXES['job_requirements']}{document_id}"
result = await self.redis.delete(key)
if result > 0:
logger.info(f"📋 Deleted job requirements for document {document_id}")
return True
return False
except Exception as e:
logger.error(f"❌ Error deleting job requirements for document {document_id}: {e}")
return False
async def get_all_job_requirements(self) -> Dict[str, Any]:
"""Get all cached job requirements"""
try:
pattern = f"{KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
if value:
result[document_id] = self._deserialize(value)
return result
except Exception as e:
logger.error(f"❌ Error retrieving all job requirements: {e}")
return {}
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
"""Get all job requirements analysis for documents belonging to a candidate"""
try:
# Get all documents for the candidate
candidate_documents = await self.get_candidate_documents(candidate_id)
if not candidate_documents:
return []
# Get job requirements for each document
job_requirements = []
for doc in candidate_documents:
doc_id = doc.get("id")
if doc_id:
requirements = await self.get_job_requirements(doc_id)
if requirements:
# Add document metadata to requirements
requirements["document_filename"] = doc.get("filename")
requirements["document_original_name"] = doc.get("originalName")
job_requirements.append(requirements)
return job_requirements
except Exception as e:
logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}")
return []
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
"""Invalidate (delete) cached job requirements for a document"""
# This is an alias for delete_job_requirements for semantic clarity
return await self.delete_job_requirements(document_id)
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
"""Delete job requirements for multiple documents and return count of deleted items"""
try:
deleted_count = 0
pipe = self.redis.pipeline()
for doc_id in document_ids:
key = f"{KEY_PREFIXES['job_requirements']}{doc_id}"
pipe.delete(key)
deleted_count += 1
results = await pipe.execute()
actual_deleted = sum(1 for result in results if result > 0)
logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents")
return actual_deleted
except Exception as e:
logger.error(f"❌ Error bulk deleting job requirements: {e}")
return 0
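# Cache-or-compute sketch for job requirements (illustrative; the analysis
# function is a placeholder, not part of this module):
#
#     requirements = await db.get_job_requirements(document_id)
#     if requirements is None:
#         requirements = await analyze_job_document(document_id)   # hypothetical
#         await db.save_job_requirements(document_id, requirements)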
# Viewer operations
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
"""Get viewer by ID"""
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
"""Set viewer data"""
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
await self.redis.set(key, self._serialize(viewer_data))
async def get_all_viewers(self) -> Dict[str, Any]:
"""Get all viewers"""
pattern = f"{KEY_PREFIXES['viewers']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
# Use pipeline for efficiency
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
viewer_id = key.replace(KEY_PREFIXES['viewers'], '')
result[viewer_id] = self._deserialize(value)
return result
async def delete_viewer(self, viewer_id: str):
"""Delete viewer"""
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
await self.redis.delete(key)
# Employers operations
async def get_employer(self, employer_id: str) -> Optional[Dict]:
"""Get employer by ID"""
key = f"{KEY_PREFIXES['employers']}{employer_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_employer(self, employer_id: str, employer_data: Dict):
"""Set employer data"""
key = f"{KEY_PREFIXES['employers']}{employer_id}"
await self.redis.set(key, self._serialize(employer_data))
async def get_all_employers(self) -> Dict[str, Any]:
"""Get all employers"""
pattern = f"{KEY_PREFIXES['employers']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
employer_id = key.replace(KEY_PREFIXES['employers'], '')
result[employer_id] = self._deserialize(value)
return result
async def delete_employer(self, employer_id: str):
"""Delete employer"""
key = f"{KEY_PREFIXES['employers']}{employer_id}"
await self.redis.delete(key)
# Jobs operations
async def get_job(self, job_id: str) -> Optional[Dict]:
"""Get job by ID"""
key = f"{KEY_PREFIXES['jobs']}{job_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_job(self, job_id: str, job_data: Dict):
"""Set job data"""
key = f"{KEY_PREFIXES['jobs']}{job_id}"
await self.redis.set(key, self._serialize(job_data))
async def get_all_jobs(self) -> Dict[str, Any]:
"""Get all jobs"""
pattern = f"{KEY_PREFIXES['jobs']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
job_id = key.replace(KEY_PREFIXES['jobs'], '')
result[job_id] = self._deserialize(value)
return result
async def delete_job(self, job_id: str):
"""Delete job"""
key = f"{KEY_PREFIXES['jobs']}{job_id}"
await self.redis.delete(key)
# MFA and Email Verification operations
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
"""Find pending verification token by email address"""
try:
pattern = "email_verification:*"
cursor = 0
email_lower = email.lower()
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
verification_info = json.loads(token_data)
if (verification_info.get("email", "").lower() == email_lower and
not verification_info.get("verified", False)):
# Extract token from key
token = key.replace("email_verification:", "")
verification_info["token"] = token
return verification_info
if cursor == 0:
break
return None
except Exception as e:
logger.error(f"❌ Error finding verification token by email {email}: {e}")
return None
async def get_pending_verifications_count(self) -> int:
"""Get count of pending email verifications (admin function)"""
try:
pattern = "email_verification:*"
cursor = 0
count = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
verification_info = json.loads(token_data)
if not verification_info.get("verified", False):
count += 1
if cursor == 0:
break
return count
except Exception as e:
logger.error(f"❌ Error counting pending verifications: {e}")
return 0
async def cleanup_expired_verification_tokens(self) -> int:
"""Clean up expired verification tokens and return count of cleaned tokens"""
try:
pattern = "email_verification:*"
cursor = 0
cleaned_count = 0
current_time = datetime.now(timezone.utc)
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
verification_info = json.loads(token_data)
expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
if current_time > expires_at:
await self.redis.delete(key)
cleaned_count += 1
logger.info(f"🧹 Cleaned expired verification token for {verification_info.get('email')}")
if cursor == 0:
break
if cleaned_count > 0:
logger.info(f"🧹 Cleaned up {cleaned_count} expired verification tokens")
return cleaned_count
except Exception as e:
logger.error(f"❌ Error cleaning up expired verification tokens: {e}")
return 0
async def get_verification_attempts_count(self, email: str) -> int:
"""Get the number of verification emails sent for an email in the last 24 hours"""
try:
key = f"verification_attempts:{email.lower()}"
data = await self.redis.get(key)
if not data:
return 0
attempts_data = json.loads(data)
current_time = datetime.now(timezone.utc)
window_start = current_time - timedelta(hours=24)
# Filter out old attempts
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > window_start
]
return len(recent_attempts)
except Exception as e:
logger.error(f"❌ Error getting verification attempts count for {email}: {e}")
return 0
async def record_verification_attempt(self, email: str) -> bool:
"""Record a verification email attempt"""
try:
key = f"verification_attempts:{email.lower()}"
current_time = datetime.now(timezone.utc)
# Get existing attempts
data = await self.redis.get(key)
attempts_data = json.loads(data) if data else []
# Add current attempt
attempts_data.append(current_time.isoformat())
# Keep only last 24 hours of attempts
window_start = current_time - timedelta(hours=24)
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > window_start
]
# Store with 24 hour expiration
await self.redis.setex(
key,
24 * 60 * 60, # 24 hours
json.dumps(recent_attempts)
)
return True
except Exception as e:
logger.error(f"❌ Error recording verification attempt for {email}: {e}")
return False
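# Rate-limit sketch (illustrative): combine the two helpers above to cap
# verification emails per address, assuming a limit of 5 per 24 hours:
#
#     if await db.get_verification_attempts_count(email) >= 5:
#         raise TooManyAttemptsError(email)       # hypothetical exception
#     await send_verification_email(email)        # hypothetical sender
#     await db.record_verification_attempt(email)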
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
"""Store email verification token with user data"""
try:
key = f"email_verification:{token}"
verification_data = {
"email": email.lower(),
"user_type": user_type,
"user_data": user_data,
"expires_at": (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"verified": False
}
# Store with 24 hour expiration
await self.redis.setex(
key,
24 * 60 * 60, # 24 hours in seconds
json.dumps(verification_data, default=str)
)
logger.info(f"📧 Stored email verification token for {email}")
return True
except Exception as e:
logger.error(f"❌ Error storing email verification token: {e}")
return False
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Retrieve email verification token data"""
try:
key = f"email_verification:{token}"
data = await self.redis.get(key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"❌ Error retrieving email verification token: {e}")
return None
async def mark_email_verified(self, token: str) -> bool:
"""Mark email verification token as used"""
try:
key = f"email_verification:{token}"
token_data = await self.get_email_verification_token(token)
if token_data:
token_data["verified"] = True
token_data["verified_at"] = datetime.now(timezone.utc).isoformat()
await self.redis.setex(
key,
24 * 60 * 60,  # Resets the TTL to a fresh 24 hours
json.dumps(token_data, default=str)
)
return True
return False
except Exception as e:
logger.error(f"❌ Error marking email verified: {e}")
return False
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
"""Store MFA code for verification"""
try:
logger.info("🔐 Storing MFA code for email: %s", email )
key = f"mfa_code:{email.lower()}:{device_id}"
mfa_data = {
"code": code,
"email": email.lower(),
"device_id": device_id,
"expires_at": (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"attempts": 0,
"verified": False
}
# Store with 10 minute expiration
await self.redis.setex(
key,
10 * 60, # 10 minutes in seconds
json.dumps(mfa_data, default=str)
)
logger.info(f"🔐 Stored MFA code for {email}")
return True
except Exception as e:
logger.error(f"❌ Error storing MFA code: {e}")
return False
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve MFA code data"""
try:
key = f"mfa_code:{email.lower()}:{device_id}"
data = await self.redis.get(key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"❌ Error retrieving MFA code: {e}")
return None
async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
"""Increment MFA verification attempts"""
try:
key = f"mfa_code:{email.lower()}:{device_id}"
mfa_data = await self.get_mfa_code(email, device_id)
if mfa_data:
mfa_data["attempts"] += 1
await self.redis.setex(
key,
10 * 60,  # Resets the TTL to a fresh 10 minutes
json.dumps(mfa_data, default=str)
)
return mfa_data["attempts"]
return 0
except Exception as e:
logger.error(f"❌ Error incrementing MFA attempts: {e}")
return 0
async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
"""Mark MFA code as verified"""
try:
key = f"mfa_code:{email.lower()}:{device_id}"
mfa_data = await self.get_mfa_code(email, device_id)
if mfa_data:
mfa_data["verified"] = True
mfa_data["verified_at"] = datetime.now(timezone.utc).isoformat()
await self.redis.setex(
key,
10 * 60,  # Resets the TTL to a fresh 10 minutes
json.dumps(mfa_data, default=str)
)
return True
return False
except Exception as e:
logger.error(f"❌ Error marking MFA verified: {e}")
return False
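# MFA verification sketch (illustrative), using the helpers above; submitted_code
# is the user-provided value, not defined in this module:
#
#     mfa = await db.get_mfa_code(email, device_id)
#     if not mfa or datetime.now(timezone.utc) > datetime.fromisoformat(mfa["expires_at"]):
#         ...                                      # expired or never issued
#     elif submitted_code != mfa["code"]:
#         attempts = await db.increment_mfa_attempts(email, device_id)
#     else:
#         await db.mark_mfa_verified(email, device_id)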
# Job Applications operations
async def get_job_application(self, application_id: str) -> Optional[Dict]:
"""Get job application by ID"""
key = f"{KEY_PREFIXES['job_applications']}{application_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_job_application(self, application_id: str, application_data: Dict):
"""Set job application data"""
key = f"{KEY_PREFIXES['job_applications']}{application_id}"
await self.redis.set(key, self._serialize(application_data))
async def get_all_job_applications(self) -> Dict[str, Any]:
"""Get all job applications"""
pattern = f"{KEY_PREFIXES['job_applications']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
app_id = key.replace(KEY_PREFIXES['job_applications'], '')
result[app_id] = self._deserialize(value)
return result
async def delete_job_application(self, application_id: str):
"""Delete job application"""
key = f"{KEY_PREFIXES['job_applications']}{application_id}"
await self.redis.delete(key)
# Chat Sessions operations
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
"""Get chat session by ID"""
key = f"{KEY_PREFIXES['chat_sessions']}{session_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_chat_session(self, session_id: str, session_data: Dict):
"""Set chat session data"""
key = f"{KEY_PREFIXES['chat_sessions']}{session_id}"
await self.redis.set(key, self._serialize(session_data))
async def get_all_chat_sessions(self) -> Dict[str, Any]:
"""Get all chat sessions"""
pattern = f"{KEY_PREFIXES['chat_sessions']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
session_id = key.replace(KEY_PREFIXES['chat_sessions'], '')
result[session_id] = self._deserialize(value)
return result
async def delete_chat_session(self, session_id: str) -> bool:
'''Delete a chat session from Redis'''
try:
result = await self.redis.delete(f"{KEY_PREFIXES['chat_sessions']}{session_id}")
return result > 0
except Exception as e:
logger.error(f"Error deleting chat session {session_id}: {e}")
raise
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
'''Delete a specific chat message from Redis'''
try:
# Remove from the session's message list
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
await self.redis.lrem(key, 0, message_id)# type: ignore
# Delete the message data itself
result = await self.redis.delete(f"chat_message:{message_id}")
return result > 0
except Exception as e:
logger.error(f"Error deleting chat message {message_id}: {e}")
raise
# Chat Messages operations (stored as lists)
async def get_chat_messages(self, session_id: str) -> List[Dict]:
"""Get chat messages for a session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
messages = await self.redis.lrange(key, 0, -1)# type: ignore
return [self._deserialize(msg) for msg in messages if msg]
async def add_chat_message(self, session_id: str, message_data: Dict):
"""Add a chat message to a session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
await self.redis.rpush(key, self._serialize(message_data))# type: ignore
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
"""Set all chat messages for a session (replaces existing)"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
# Clear existing messages
await self.redis.delete(key)
# Add new messages
if messages:
serialized_messages = [self._serialize(msg) for msg in messages]
await self.redis.rpush(key, *serialized_messages)# type: ignore
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
"""Get all chat messages grouped by session"""
pattern = f"{KEY_PREFIXES['chat_messages']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
result = {}
for key in keys:
session_id = key.replace(KEY_PREFIXES['chat_messages'], '')
messages = await self.redis.lrange(key, 0, -1)# type: ignore
result[session_id] = [self._deserialize(msg) for msg in messages if msg]
return result
async def delete_chat_messages(self, session_id: str):
"""Delete all chat messages for a session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
await self.redis.delete(key)
# Enhanced Chat Session Methods
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]:
"""Get all chat sessions for a specific user"""
all_sessions = await self.get_all_chat_sessions()
user_sessions = []
for session_data in all_sessions.values():
if session_data.get("userId") == user_id or session_data.get("guestId") == user_id:
user_sessions.append(session_data)
# Sort by last activity (most recent first)
user_sessions.sort(key=lambda x: x.get("lastActivity", ""), reverse=True)
return user_sessions
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
"""Get all chat sessions related to a specific candidate"""
all_sessions = await self.get_all_chat_sessions()
candidate_sessions = []
for session_data in all_sessions.values():
context = session_data.get("context", {})
if (context.get("relatedEntityType") == "candidate" and
context.get("relatedEntityId") == candidate_id):
candidate_sessions.append(session_data)
# Sort by last activity (most recent first)
candidate_sessions.sort(key=lambda x: x.get("lastActivity", ""), reverse=True)
return candidate_sessions
async def update_chat_session_activity(self, session_id: str):
"""Update the last activity timestamp for a chat session"""
session_data = await self.get_chat_session(session_id)
if session_data:
session_data["lastActivity"] = datetime.now(UTC).isoformat()
await self.set_chat_session(session_id, session_data)
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
"""Get the most recent chat messages for a session"""
messages = await self.get_chat_messages(session_id)
# Return the last 'limit' messages
return messages[-limit:] if len(messages) > limit else messages
async def get_chat_message_count(self, session_id: str) -> int:
"""Get the total number of messages in a chat session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
return await self.redis.llen(key)# type: ignore
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
"""Search for messages containing specific text in a session"""
messages = await self.get_chat_messages(session_id)
query_lower = query.lower()
matching_messages = []
for msg in messages:
content = msg.get("content", "").lower()
if query_lower in content:
matching_messages.append(msg)
return matching_messages
# Chat Session Management
async def archive_chat_session(self, session_id: str):
"""Archive a chat session"""
session_data = await self.get_chat_session(session_id)
if session_data:
session_data["isArchived"] = True
session_data["updatedAt"] = datetime.now(UTC).isoformat()
await self.set_chat_session(session_id, session_data)
async def delete_chat_session_completely(self, session_id: str):
"""Delete a chat session and all its messages"""
# Delete the session
await self.delete_chat_session(session_id)
# Delete all messages
await self.delete_chat_messages(session_id)
async def cleanup_old_chat_sessions(self, days_old: int = 90):
"""Archive or delete chat sessions older than specified days"""
cutoff_date = datetime.now(UTC) - timedelta(days=days_old)
cutoff_iso = cutoff_date.isoformat()
all_sessions = await self.get_all_chat_sessions()
archived_count = 0
for session_id, session_data in all_sessions.items():
last_activity = session_data.get("lastActivity", session_data.get("createdAt", ""))
if last_activity < cutoff_iso and not session_data.get("isArchived", False):
await self.archive_chat_session(session_id)
archived_count += 1
return archived_count
# Analytics and Reporting
async def get_chat_statistics(self) -> Dict[str, Any]:
"""Get comprehensive chat statistics"""
all_sessions = await self.get_all_chat_sessions()
all_messages = await self.get_all_chat_messages()
stats = {
"total_sessions": len(all_sessions),
"total_messages": sum(len(messages) for messages in all_messages.values()),
"active_sessions": 0,
"archived_sessions": 0,
"sessions_by_type": {},
"sessions_with_candidates": 0,
"average_messages_per_session": 0
}
# Analyze sessions
for session_data in all_sessions.values():
if session_data.get("isArchived", False):
stats["archived_sessions"] += 1
else:
stats["active_sessions"] += 1
# Count by type
context_type = session_data.get("context", {}).get("type", "unknown")
stats["sessions_by_type"][context_type] = stats["sessions_by_type"].get(context_type, 0) + 1
# Count sessions with candidate association
if session_data.get("context", {}).get("relatedEntityType") == "candidate":
stats["sessions_with_candidates"] += 1
# Calculate averages
if stats["total_sessions"] > 0:
stats["average_messages_per_session"] = stats["total_messages"] / stats["total_sessions"]
return stats
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
"""Get a summary of chat activity for a specific candidate"""
sessions = await self.get_chat_sessions_by_candidate(candidate_id)
if not sessions:
return {
"candidate_id": candidate_id,
"total_sessions": 0,
"total_messages": 0,
"first_chat": None,
"last_chat": None
}
total_messages = 0
for session in sessions:
session_id = session.get("id")
if session_id:
message_count = await self.get_chat_message_count(session_id)
total_messages += message_count
# Sort sessions by creation date
sessions_by_date = sorted(sessions, key=lambda x: x.get("createdAt", ""))
return {
"candidate_id": candidate_id,
"total_sessions": len(sessions),
"total_messages": total_messages,
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
"recent_sessions": sessions[:5] # Last 5 sessions
}
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
"""Bulk update multiple chat sessions"""
pipe = self.redis.pipeline()
for session_id, updates in session_updates.items():
session_data = await self.get_chat_session(session_id)
if session_data:
session_data.update(updates)
session_data["updatedAt"] = datetime.now(UTC).isoformat()
key = f"{KEY_PREFIXES['chat_sessions']}{session_id}"
pipe.set(key, self._serialize(session_data))
await pipe.execute()
# AI Parameters operations
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
"""Get AI parameters by ID"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_ai_parameters(self, param_id: str, param_data: Dict):
"""Set AI parameters data"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
await self.redis.set(key, self._serialize(param_data))
async def get_all_ai_parameters(self) -> Dict[str, Any]:
"""Get all AI parameters"""
pattern = f"{KEY_PREFIXES['ai_parameters']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
param_id = key.replace(KEY_PREFIXES['ai_parameters'], '')
result[param_id] = self._deserialize(value)
return result
async def delete_ai_parameters(self, param_id: str):
"""Delete AI parameters"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
await self.redis.delete(key)
async def get_all_users(self) -> Dict[str, Any]:
"""Get all users"""
pattern = f"{KEY_PREFIXES['users']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
email = key.replace(KEY_PREFIXES['users'], '')
logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
if isinstance(value, str):
result[email] = value
else:
result[email] = self._deserialize(value)
return result
async def delete_user(self, email: str):
"""Delete user"""
key = f"{KEY_PREFIXES['users']}{email}"
await self.redis.delete(key)
async def get_job_requirements_stats(self) -> Dict[str, Any]:
"""Get statistics about cached job requirements"""
try:
pattern = f"{KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
stats = {
"total_cached_requirements": len(keys),
"cache_dates": {},
"documents_with_requirements": []
}
if keys:
# Get cache dates for analysis
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
for key, value in zip(keys, values):
if value:
requirements_data = self._deserialize(value)
if requirements_data:
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
stats["documents_with_requirements"].append(document_id)
# Track cache dates
cached_at = requirements_data.get("cached_at")
if cached_at:
cache_date = cached_at[:10] # Extract date part
stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1
return stats
except Exception as e:
logger.error(f"❌ Error getting job requirements stats: {e}")
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}
async def cleanup_orphaned_job_requirements(self) -> int:
"""Clean up job requirements for documents that no longer exist"""
try:
# Get all job requirements
all_requirements = await self.get_all_job_requirements()
if not all_requirements:
return 0
orphaned_count = 0
pipe = self.redis.pipeline()
for document_id in all_requirements.keys():
# Check if the document still exists
document_exists = await self.get_document(document_id)
if not document_exists:
# Document no longer exists, delete its job requirements
key = f"{KEY_PREFIXES['job_requirements']}{document_id}"
pipe.delete(key)
orphaned_count += 1
logger.info(f"📋 Queued orphaned job requirements for deletion: {document_id}")
if orphaned_count > 0:
await pipe.execute()
logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")
return orphaned_count
except Exception as e:
logger.error(f"❌ Error cleaning up orphaned job requirements: {e}")
return 0
# Authentication Record Methods
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
"""Store authentication record for a user"""
try:
key = f"auth:{user_id}"
await self.redis.set(key, json.dumps(auth_data, default=str))
logger.info(f"🔐 Stored authentication record for user {user_id}")
return True
except Exception as e:
logger.error(f"❌ Error storing authentication record for {user_id}: {e}")
return False
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve authentication record for a user"""
try:
key = f"auth:{user_id}"
data = await self.redis.get(key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"❌ Error retrieving authentication record for {user_id}: {e}")
return None
async def delete_authentication(self, user_id: str) -> bool:
"""Delete authentication record for a user"""
try:
key = f"auth:{user_id}"
result = await self.redis.delete(key)
logger.info(f"🔐 Deleted authentication record for user {user_id}")
return result > 0
except Exception as e:
logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
return False
# Enhanced User Methods
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
"""Store user data with ID as key for direct lookup"""
try:
key = f"user_by_id:{user_id}"
await self.redis.set(key, json.dumps(user_data, default=str))
logger.info(f"👤 Stored user data by ID for {user_id}")
return True
except Exception as e:
logger.error(f"❌ Error storing user by ID {user_id}: {e}")
return False
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get user lookup data by user ID"""
try:
data = await self.redis.hget("user_lookup_by_id", user_id)# type: ignore
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"❌ Error getting user by ID {user_id}: {e}")
return None
async def user_exists_by_email(self, email: str) -> bool:
"""Check if a user exists with the given email"""
try:
key = f"users:{email.lower()}"
exists = await self.redis.exists(key)
return exists > 0
except Exception as e:
logger.error(f"❌ Error checking user existence by email {email}: {e}")
return False
async def user_exists_by_username(self, username: str) -> bool:
"""Check if a user exists with the given username"""
try:
key = f"users:{username.lower()}"
exists = await self.redis.exists(key)
return exists > 0
except Exception as e:
logger.error(f"❌ Error checking user existence by username {username}: {e}")
return False
# Enhanced user lookup method to support both email and username
async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
"""
Get user by email or username
"""
try:
# Normalize the login string
login = login.strip().lower()
key = f"users:{login}"
data = await self.redis.get(key)
if data:
user_data = json.loads(data)
logger.info(f"👤 Retrieved user data for {login}")
return user_data
logger.info(f"👤 No user found for {login}")
return None
except Exception as e:
logger.error(f"❌ Error retrieving user {login}: {e}")
return None
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
"""
Enhanced method to store user data by email or username
Updated version of your existing method
"""
try:
# Normalize the login string
login = login.strip().lower()
key = f"users:{login}"
await self.redis.set(key, json.dumps(user_data, default=str))
logger.info(f"👤 Stored user data for {login}")
return True
except Exception as e:
logger.error(f"❌ Error storing user {login}: {e}")
return False
# Token Management Methods
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
"""Store refresh token for a user"""
try:
key = f"refresh_token:{token}"
token_data = {
"user_id": user_id,
"expires_at": expires_at.isoformat(),
"device": device_info.get("device", "unknown"),
"ip_address": device_info.get("ip_address", "unknown"),
"is_revoked": False,
"created_at": datetime.now(timezone.utc).isoformat()
}
# Store with expiration
ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
if ttl_seconds > 0:
await self.redis.setex(key, ttl_seconds, json.dumps(token_data, default=str))
logger.info(f"🔐 Stored refresh token for user {user_id}")
return True
else:
logger.warning(f"⚠️ Attempted to store expired refresh token for user {user_id}")
return False
except Exception as e:
logger.error(f"❌ Error storing refresh token for {user_id}: {e}")
return False
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Retrieve refresh token data"""
try:
key = f"refresh_token:{token}"
data = await self.redis.get(key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"❌ Error retrieving refresh token: {e}")
return None
async def revoke_refresh_token(self, token: str) -> bool:
"""Revoke a refresh token"""
try:
key = f"refresh_token:{token}"
token_data = await self.get_refresh_token(token)
if token_data:
token_data["is_revoked"] = True
token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
await self.redis.set(key, json.dumps(token_data, default=str))
logger.info(f"🔐 Revoked refresh token")
return True
return False
except Exception as e:
logger.error(f"❌ Error revoking refresh token: {e}")
return False
async def revoke_all_user_tokens(self, user_id: str) -> bool:
"""Revoke all refresh tokens for a user"""
try:
# This requires scanning all refresh tokens - consider using a user token index for efficiency
pattern = "refresh_token:*"
cursor = 0
revoked_count = 0
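            # SCAN iterates the keyspace incrementally: each call returns an updated cursor plus a
            # batch of matching keys, and the iteration is complete once the cursor comes back as 0.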
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
token_info = json.loads(token_data)
if token_info.get("user_id") == user_id and not token_info.get("is_revoked"):
token_info["is_revoked"] = True
token_info["revoked_at"] = datetime.now(timezone.utc).isoformat()
                            # keepttl (Redis 6.0+) preserves each key's remaining expiry when rewriting the record
                            await self.redis.set(key, json.dumps(token_info, default=str), keepttl=True)
revoked_count += 1
if cursor == 0:
break
logger.info(f"🔐 Revoked {revoked_count} refresh tokens for user {user_id}")
return True
except Exception as e:
logger.error(f"❌ Error revoking all tokens for user {user_id}: {e}")
return False
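    # A possible alternative to the full keyspace SCAN above (a sketch, not implemented here):
    # maintain a per-user set such as "user_tokens:<user_id>" (an assumed key name) that
    # store_refresh_token adds to, so revoking everything becomes one SMEMBERS plus a batch of
    # writes instead of a scan over every refresh token in the database.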
# Password Reset Token Methods
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
"""Store password reset token"""
try:
key = f"password_reset:{token}"
token_data = {
"email": email.lower(),
"expires_at": expires_at.isoformat(),
"used": False,
"created_at": datetime.now(timezone.utc).isoformat()
}
# Store with expiration
ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
if ttl_seconds > 0:
await self.redis.setex(key, ttl_seconds, json.dumps(token_data, default=str))
logger.info(f"🔐 Stored password reset token for {email}")
return True
else:
logger.warning(f"⚠️ Attempted to store expired password reset token for {email}")
return False
except Exception as e:
logger.error(f"❌ Error storing password reset token for {email}: {e}")
return False
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Retrieve password reset token data"""
try:
key = f"password_reset:{token}"
data = await self.redis.get(key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"❌ Error retrieving password reset token: {e}")
return None
async def mark_password_reset_token_used(self, token: str) -> bool:
"""Mark password reset token as used"""
try:
key = f"password_reset:{token}"
token_data = await self.get_password_reset_token(token)
if token_data:
token_data["used"] = True
token_data["used_at"] = datetime.now(timezone.utc).isoformat()
                # keepttl (Redis 6.0+) preserves the remaining expiry so used tokens still age out on schedule
                await self.redis.set(key, json.dumps(token_data, default=str), keepttl=True)
                logger.info("🔐 Marked password reset token as used")
return True
return False
except Exception as e:
logger.error(f"❌ Error marking password reset token as used: {e}")
return False
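    # Note: this only flags the record; callers are presumably expected to verify both "used" and
    # "expires_at" before honoring a reset token.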
# User Activity and Security Logging
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
"""Log security events for audit purposes"""
try:
key = f"security_log:{user_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
event_data = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"user_id": user_id,
"event_type": event_type,
"details": details
}
# Add to list (latest events first)
            await self.redis.lpush(key, json.dumps(event_data, default=str))  # type: ignore
            # Keep only the latest 100 events per day
            await self.redis.ltrim(key, 0, 99)  # type: ignore
# Set expiration for 30 days
await self.redis.expire(key, 30 * 24 * 60 * 60)
logger.info(f"🔒 Logged security event {event_type} for user {user_id}")
return True
except Exception as e:
logger.error(f"❌ Error logging security event for {user_id}: {e}")
return False
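    # Example (hypothetical caller, names assumed): record a failed login attempt.
    #   await db.log_security_event(user.id, "login_failed", {"ip": client_ip, "reason": "bad password"})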
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
"""Retrieve security log for a user"""
try:
events = []
for i in range(days):
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime('%Y-%m-%d')
key = f"security_log:{user_id}:{date}"
                daily_events = await self.redis.lrange(key, 0, -1)  # type: ignore
for event_json in daily_events:
events.append(json.loads(event_json))
# Sort by timestamp (most recent first)
events.sort(key=lambda x: x["timestamp"], reverse=True)
return events
except Exception as e:
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
return []
# ============================
# Guest Management Methods
# ============================
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
"""Store guest data with enhanced persistence"""
try:
# Ensure last_activity is always set
guest_data["last_activity"] = datetime.now(UTC).isoformat()
# Store in Redis with both hash and individual key for redundancy
            await self.redis.hset("guests", guest_id, json.dumps(guest_data))  # type: ignore
# Also store with a longer TTL as backup
await self.redis.setex(
f"guest_backup:{guest_id}",
86400 * 7, # 7 days TTL
json.dumps(guest_data)
)
logger.info(f"💾 Guest stored with backup: {guest_id}")
except Exception as e:
logger.error(f"❌ Error storing guest {guest_id}: {e}")
raise
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
"""Get guest data with fallback to backup"""
try:
# Try primary storage first
            data = await self.redis.hget("guests", guest_id)  # type: ignore
if data:
guest_data = json.loads(data)
# Update last activity when accessed
guest_data["last_activity"] = datetime.now(UTC).isoformat()
await self.set_guest(guest_id, guest_data)
logger.info(f"🔍 Guest found in primary storage: {guest_id}")
return guest_data
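            # Read-repair path: if the primary hash entry is missing, fall back to the TTL-backed
            # backup copy and, on a hit, write it back through set_guest below.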
# Fallback to backup storage
backup_data = await self.redis.get(f"guest_backup:{guest_id}")
if backup_data:
guest_data = json.loads(backup_data)
guest_data["last_activity"] = datetime.now(UTC).isoformat()
# Restore to primary storage
await self.set_guest(guest_id, guest_data)
logger.info(f"🔄 Guest restored from backup: {guest_id}")
return guest_data
logger.warning(f"⚠️ Guest not found: {guest_id}")
return None
except Exception as e:
logger.error(f"❌ Error getting guest {guest_id}: {e}")
return None
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get guest data by session ID"""
try:
all_guests = await self.get_all_guests()
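            # Linear scan: the entire "guests" hash is loaded and filtered in Python, which is fine
            # for modest guest counts but would warrant a session_id -> guest_id index at scale.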
for guest_data in all_guests.values():
if guest_data.get("session_id") == session_id:
return guest_data
return None
except Exception as e:
logger.error(f"❌ Error getting guest by session ID {session_id}: {e}")
return None
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
"""Get all guests"""
try:
            data = await self.redis.hgetall("guests")  # type: ignore
return {
guest_id: json.loads(guest_json)
for guest_id, guest_json in data.items()
}
except Exception as e:
logger.error(f"❌ Error getting all guests: {e}")
return {}
async def delete_guest(self, guest_id: str) -> bool:
"""Delete a guest"""
try:
            result = await self.redis.hdel("guests", guest_id)  # type: ignore
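            # Only the primary hash entry is removed here; the guest_backup:<id> copy is left to
            # lapse on the 7-day TTL set in set_guest.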
if result:
logger.info(f"🗑️ Guest deleted: {guest_id}")
return True
return False
except Exception as e:
logger.error(f"❌ Error deleting guest {guest_id}: {e}")
return False
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
"""Clean up inactive guest sessions with safety checks"""
try:
all_guests = await self.get_all_guests()
current_time = datetime.now(UTC)
cutoff_time = current_time - timedelta(hours=inactive_hours)
deleted_count = 0
preserved_count = 0
for guest_id, guest_data in all_guests.items():
try:
last_activity_str = guest_data.get("last_activity")
created_at_str = guest_data.get("created_at")
# Skip cleanup if guest is very new (less than 1 hour old)
if created_at_str:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
if current_time - created_at < timedelta(hours=1):
preserved_count += 1
logger.info(f"🛡️ Preserving new guest: {guest_id}")
continue
# Check last activity
should_delete = False
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
if last_activity < cutoff_time:
should_delete = True
except ValueError:
                            # Unparseable timestamp: only delete when there is no created_at to fall back on
if not created_at_str:
should_delete = True
else:
                    # No activity recorded: only delete when there is no created_at to fall back on
if not created_at_str:
should_delete = True
if should_delete:
await self.delete_guest(guest_id)
deleted_count += 1
else:
preserved_count += 1
except Exception as e:
logger.error(f"❌ Error processing guest {guest_id} for cleanup: {e}")
preserved_count += 1 # Preserve on error
if deleted_count > 0:
logger.info(f"🧹 Guest cleanup: removed {deleted_count}, preserved {preserved_count}")
return deleted_count
except Exception as e:
logger.error(f"❌ Error in guest cleanup: {e}")
return 0
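    # Example (hypothetical background task, names assumed): run cleanup periodically.
    #   async def _guest_janitor(db):
    #       while True:
    #           await db.cleanup_inactive_guests(inactive_hours=24)
    #           await asyncio.sleep(3600)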
async def get_guest_statistics(self) -> Dict[str, Any]:
"""Get guest usage statistics"""
try:
all_guests = await self.get_all_guests()
current_time = datetime.now(UTC)
stats = {
"total_guests": len(all_guests),
"active_last_hour": 0,
"active_last_day": 0,
"converted_guests": 0,
"by_ip": {},
"creation_timeline": {}
}
hour_ago = current_time - timedelta(hours=1)
day_ago = current_time - timedelta(days=1)
for guest_data in all_guests.values():
# Check activity
last_activity_str = guest_data.get("last_activity")
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
if last_activity > hour_ago:
stats["active_last_hour"] += 1
if last_activity > day_ago:
stats["active_last_day"] += 1
except ValueError:
pass
# Check conversions
if guest_data.get("converted_to_user_id"):
stats["converted_guests"] += 1
# IP tracking
ip = guest_data.get("ip_address", "unknown")
stats["by_ip"][ip] = stats["by_ip"].get(ip, 0) + 1
# Creation timeline
created_at_str = guest_data.get("created_at")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
date_key = created_at.strftime('%Y-%m-%d')
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
except ValueError:
pass
return stats
except Exception as e:
logger.error(f"❌ Error getting guest statistics: {e}")
return {}
# Global Redis manager instance
redis_manager = _RedisManager()
class DatabaseManager:
"""Enhanced database manager with graceful shutdown capabilities"""
def __init__(self):
self.db: Optional[RedisDatabase] = None
self._shutdown_initiated = False
self._active_requests = 0
self._shutdown_timeout = int(os.getenv("SHUTDOWN_TIMEOUT", "30")) # seconds
self._backup_on_shutdown = os.getenv("BACKUP_ON_SHUTDOWN", "false").lower() == "true"
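        # SHUTDOWN_TIMEOUT and BACKUP_ON_SHUTDOWN are read once at construction time, so changing
        # either requires restarting the process.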
async def initialize(self):
"""Initialize database connection"""
try:
# Connect to Redis
await redis_manager.connect()
logger.info("Redis connection established")
# Create database instance
self.db = RedisDatabase(redis_manager.get_client())
# Test connection and log stats
if not redis_manager.redis:
raise RuntimeError("Redis client not initialized")
await redis_manager.redis.ping()
stats = await self.db.get_stats()
logger.info(f"Database initialized successfully. Stats: {stats}")
return self.db
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
raise
async def backup_data(self) -> Optional[str]:
"""Create a backup of critical data before shutdown"""
if not self.db:
return None
try:
backup_data = {
"timestamp": datetime.now(UTC).isoformat(),
"stats": await self.db.get_stats(),
"users": await self.db.get_all_users(),
# Add other critical data as needed
}
backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json"
            # Save to a local file; cloud or object storage is a better fit for production deployments
with open(backup_filename, 'w') as f:
json.dump(backup_data, f, indent=2, default=str)
logger.info(f"Backup created: {backup_filename}")
return backup_filename
except Exception as e:
logger.error(f"Backup failed: {e}")
return None
async def graceful_shutdown(self):
"""Perform graceful shutdown with optional backup"""
self._shutdown_initiated = True
logger.info("Initiating graceful shutdown...")
# Wait for active requests to complete (with timeout)
wait_time = 0
while self._active_requests > 0 and wait_time < self._shutdown_timeout:
logger.info(f"Waiting for {self._active_requests} active requests to complete...")
await asyncio.sleep(1)
wait_time += 1
if self._active_requests > 0:
logger.warning(f"Shutdown timeout reached. {self._active_requests} requests may be interrupted.")
# Create backup if configured
if self._backup_on_shutdown:
backup_file = await self.backup_data()
if backup_file:
logger.info(f"Pre-shutdown backup completed: {backup_file}")
# Force Redis to save data to disk
try:
if redis_manager.redis:
# Try BGSAVE first (non-blocking)
try:
await redis_manager.redis.bgsave()
logger.info("Background save initiated")
# Wait a bit for background save to start
await asyncio.sleep(0.5)
except Exception as e:
logger.warning(f"Background save failed, trying synchronous save: {e}")
try:
# Fallback to synchronous save
await redis_manager.redis.save()
logger.info("Synchronous save completed")
except Exception as e2:
logger.warning(f"Synchronous save also failed (Redis persistence may be disabled): {e2}")
except Exception as e:
logger.error(f"Error during Redis save: {e}")
# Close Redis connection
try:
await redis_manager.disconnect()
logger.info("Redis connection closed successfully")
except Exception as e:
logger.error(f"Error closing Redis connection: {e}")
logger.info("Graceful shutdown completed")
def increment_requests(self):
"""Track active requests"""
self._active_requests += 1
def decrement_requests(self):
"""Track completed requests"""
self._active_requests = max(0, self._active_requests - 1)
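    # Example (hypothetical request middleware, names assumed): bracket each request so
    # graceful_shutdown can drain in-flight work before closing Redis.
    #   db_manager.increment_requests()
    #   try:
    #       response = await call_next(request)
    #   finally:
    #       db_manager.decrement_requests()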
@property
def is_shutting_down(self) -> bool:
"""Check if shutdown is in progress"""
return self._shutdown_initiated
def get_database(self) -> RedisDatabase:
"""Get database instance"""
if self.db is None:
raise RuntimeError("Database not initialized")
if self._shutdown_initiated:
raise RuntimeError("Application is shutting down")
return self.db