"""Redis-backed persistence layer: connection management plus CRUD/data-access helpers."""
import asyncio
import json
import logging
import os
from datetime import datetime, timezone, UTC, timedelta
from typing import Any, Dict, List, Optional

import redis.asyncio as redis  # type: ignore

from models import (
    # User models
    Candidate, Employer, BaseUser, Guest, Authentication, AuthResponse,
)
logger = logging.getLogger(__name__)
|
|
|
|
class _RedisManager:
    """Manages the process-wide async Redis client and its connection pool.

    Lifecycle: ``connect()`` -> ``get_client()`` / ``health_check()`` /
    ``force_save()`` -> ``disconnect()``. All state lives on the instance so
    a single shared manager can be reused across the application.
    """

    def __init__(self):
        # Client and pool are created lazily by connect().
        self.redis: Optional[redis.Redis] = None
        # Default URL targets the docker-compose service name "redis".
        self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
        self._connection_pool: Optional[redis.ConnectionPool] = None
        self._is_connected = False

    async def connect(self):
        """Initialize Redis connection with connection pooling.

        Idempotent: returns immediately when already connected. On any
        failure the partially-built state is torn down and the exception
        is re-raised to the caller.
        """
        if self._is_connected and self.redis:
            logger.info("Redis already connected")
            return

        try:
            # Create connection pool for better resource management
            self._connection_pool = redis.ConnectionPool.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True,
                max_connections=20,
                retry_on_timeout=True,
                socket_keepalive=True,
                socket_keepalive_options={},
                health_check_interval=30
            )

            self.redis = redis.Redis(
                connection_pool=self._connection_pool
            )

            if not self.redis:
                raise RuntimeError("Redis client not initialized")

            # Test connection
            await self.redis.ping()
            self._is_connected = True
            logger.info("Successfully connected to Redis")

            # Log Redis info
            info = await self.redis.info()
            logger.info(f"Redis version: {info.get('redis_version', 'unknown')}")

        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self._is_connected = False
            self.redis = None
            self._connection_pool = None
            raise

    async def disconnect(self):
        """Close Redis connection gracefully.

        Connection state is always cleared, even when closing the client or
        pool raises.
        """
        if not self._is_connected:
            logger.info("Redis already disconnected")
            return

        try:
            if self.redis:
                # Wait for any pending operations to complete
                await asyncio.sleep(0.1)

                # Close the client
                await self.redis.aclose()
                logger.info("Redis client closed")

            if self._connection_pool:
                # Close the connection pool
                await self._connection_pool.aclose()
                logger.info("Redis connection pool closed")

            self._is_connected = False
            self.redis = None
            self._connection_pool = None

            logger.info("Successfully disconnected from Redis")

        except Exception as e:
            logger.error(f"Error during Redis disconnect: {e}")
            # Force cleanup even if there's an error
            self._is_connected = False
            self.redis = None
            self._connection_pool = None

    def get_client(self) -> redis.Redis:
        """Get Redis client instance; raises RuntimeError when disconnected."""
        if not self._is_connected or not self.redis:
            raise RuntimeError("Redis client not initialized or disconnected")
        return self.redis

    @property
    def is_connected(self) -> bool:
        """Check if Redis is connected"""
        return self._is_connected and self.redis is not None

    async def health_check(self) -> dict:
        """Perform health check on Redis connection.

        Returns a status dict; an unhealthy server yields status "error"
        rather than raising. Only the inconsistent "connected but no client"
        state raises RuntimeError.
        """
        if not self.is_connected:
            return {"status": "disconnected", "error": "Redis not connected"}

        if not self.redis:
            raise RuntimeError("Redis client not initialized")

        try:
            # Test basic operations
            await self.redis.ping()
            info = await self.redis.info()

            return {
                "status": "healthy",
                "redis_version": info.get("redis_version", "unknown"),
                "uptime_seconds": info.get("uptime_in_seconds", 0),
                "connected_clients": info.get("connected_clients", 0),
                "used_memory_human": info.get("used_memory_human", "unknown"),
                "total_commands_processed": info.get("total_commands_processed", 0)
            }
        except Exception as e:
            logger.error(f"Redis health check failed: {e}")
            return {"status": "error", "error": str(e)}

    async def force_save(self, background: bool = True) -> bool:
        """Force Redis to save data to disk.

        background=True issues BGSAVE (non-blocking); otherwise SAVE, which
        blocks the Redis server until the dump completes. Returns True on
        success, False when disconnected or on error.
        """
        if not self.is_connected:
            logger.warning("Cannot save: Redis not connected")
            return False

        try:
            if not self.redis:
                raise RuntimeError("Redis client not initialized")

            if background:
                # Non-blocking background save
                await self.redis.bgsave()
                logger.info("Background save initiated")
            else:
                # Blocking save
                await self.redis.save()
                logger.info("Synchronous save completed")
            return True
        except Exception as e:
            logger.error(f"Redis save failed: {e}")
            return False

    async def get_info(self) -> Optional[dict]:
        """Get Redis server INFO, or None when disconnected or on error."""
        if not self.is_connected:
            return None

        try:
            if not self.redis:
                raise RuntimeError("Redis client not initialized")
            return await self.redis.info()
        except Exception as e:
            logger.error(f"Failed to get Redis info: {e}")
            return None
class RedisDatabase:
|
|
    def __init__(self, redis: redis.Redis):
        """Wrap an already-connected async Redis client.

        Note: the *redis* parameter shadows the module-level ``redis`` import
        inside this method; the annotation is evaluated at class-definition
        time, so this is safe (but fragile — keep in mind when editing).
        """
        self.redis = redis

        # Redis key prefixes for different data types: each record type gets
        # its own namespace so pattern scans can target a single type.
        self.KEY_PREFIXES = {
            'viewers': 'viewer:',
            'candidates': 'candidate:',
            'employers': 'employer:',
            'jobs': 'job:',
            'job_applications': 'job_application:',
            'chat_sessions': 'chat_session:',
            'chat_messages': 'chat_messages:',  # This will store lists
            'ai_parameters': 'ai_parameters:',
            'users': 'user:',
        }
def _serialize(self, data: Any) -> str:
|
|
"""Serialize data to JSON string for Redis storage"""
|
|
if data is None:
|
|
return ""
|
|
return json.dumps(data, default=str) # default=str handles datetime objects
|
|
|
|
def _deserialize(self, data: str) -> Any:
|
|
"""Deserialize JSON string from Redis"""
|
|
if not data:
|
|
return None
|
|
try:
|
|
return json.loads(data)
|
|
except json.JSONDecodeError:
|
|
logger.error(f"Failed to deserialize data: {data}")
|
|
return None
|
|
|
|
# Viewer operations
|
|
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
|
|
"""Get viewer by ID"""
|
|
key = f"{self.KEY_PREFIXES['viewers']}{viewer_id}"
|
|
data = await self.redis.get(key)
|
|
return self._deserialize(data) if data else None
|
|
|
|
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
|
|
"""Set viewer data"""
|
|
key = f"{self.KEY_PREFIXES['viewers']}{viewer_id}"
|
|
await self.redis.set(key, self._serialize(viewer_data))
|
|
|
|
async def get_all_viewers(self) -> Dict[str, Any]:
|
|
"""Get all viewers"""
|
|
pattern = f"{self.KEY_PREFIXES['viewers']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
# Use pipeline for efficiency
|
|
pipe = self.redis.pipeline()
|
|
for key in keys:
|
|
pipe.get(key)
|
|
values = await pipe.execute()
|
|
|
|
result = {}
|
|
for key, value in zip(keys, values):
|
|
viewer_id = key.replace(self.KEY_PREFIXES['viewers'], '')
|
|
result[viewer_id] = self._deserialize(value)
|
|
|
|
return result
|
|
|
|
async def delete_viewer(self, viewer_id: str):
|
|
"""Delete viewer"""
|
|
key = f"{self.KEY_PREFIXES['viewers']}{viewer_id}"
|
|
await self.redis.delete(key)
|
|
|
|
# Candidates operations
|
|
async def get_candidate(self, candidate_id: str) -> Optional[Dict]:
|
|
"""Get candidate by ID"""
|
|
key = f"{self.KEY_PREFIXES['candidates']}{candidate_id}"
|
|
data = await self.redis.get(key)
|
|
return self._deserialize(data) if data else None
|
|
|
|
async def set_candidate(self, candidate_id: str, candidate_data: Dict):
|
|
"""Set candidate data"""
|
|
key = f"{self.KEY_PREFIXES['candidates']}{candidate_id}"
|
|
await self.redis.set(key, self._serialize(candidate_data))
|
|
|
|
async def get_all_candidates(self) -> Dict[str, Any]:
|
|
"""Get all candidates"""
|
|
pattern = f"{self.KEY_PREFIXES['candidates']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
# Use pipeline for efficiency
|
|
pipe = self.redis.pipeline()
|
|
for key in keys:
|
|
pipe.get(key)
|
|
values = await pipe.execute()
|
|
|
|
result = {}
|
|
for key, value in zip(keys, values):
|
|
candidate_id = key.replace(self.KEY_PREFIXES['candidates'], '')
|
|
result[candidate_id] = self._deserialize(value)
|
|
|
|
return result
|
|
|
|
async def delete_candidate(self, candidate_id: str):
|
|
"""Delete candidate"""
|
|
key = f"{self.KEY_PREFIXES['candidates']}{candidate_id}"
|
|
await self.redis.delete(key)
|
|
|
|
# Employers operations
|
|
async def get_employer(self, employer_id: str) -> Optional[Dict]:
|
|
"""Get employer by ID"""
|
|
key = f"{self.KEY_PREFIXES['employers']}{employer_id}"
|
|
data = await self.redis.get(key)
|
|
return self._deserialize(data) if data else None
|
|
|
|
async def set_employer(self, employer_id: str, employer_data: Dict):
|
|
"""Set employer data"""
|
|
key = f"{self.KEY_PREFIXES['employers']}{employer_id}"
|
|
await self.redis.set(key, self._serialize(employer_data))
|
|
|
|
async def get_all_employers(self) -> Dict[str, Any]:
|
|
"""Get all employers"""
|
|
pattern = f"{self.KEY_PREFIXES['employers']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
pipe = self.redis.pipeline()
|
|
for key in keys:
|
|
pipe.get(key)
|
|
values = await pipe.execute()
|
|
|
|
result = {}
|
|
for key, value in zip(keys, values):
|
|
employer_id = key.replace(self.KEY_PREFIXES['employers'], '')
|
|
result[employer_id] = self._deserialize(value)
|
|
|
|
return result
|
|
|
|
async def delete_employer(self, employer_id: str):
|
|
"""Delete employer"""
|
|
key = f"{self.KEY_PREFIXES['employers']}{employer_id}"
|
|
await self.redis.delete(key)
|
|
|
|
# Jobs operations
|
|
async def get_job(self, job_id: str) -> Optional[Dict]:
|
|
"""Get job by ID"""
|
|
key = f"{self.KEY_PREFIXES['jobs']}{job_id}"
|
|
data = await self.redis.get(key)
|
|
return self._deserialize(data) if data else None
|
|
|
|
async def set_job(self, job_id: str, job_data: Dict):
|
|
"""Set job data"""
|
|
key = f"{self.KEY_PREFIXES['jobs']}{job_id}"
|
|
await self.redis.set(key, self._serialize(job_data))
|
|
|
|
async def get_all_jobs(self) -> Dict[str, Any]:
|
|
"""Get all jobs"""
|
|
pattern = f"{self.KEY_PREFIXES['jobs']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
pipe = self.redis.pipeline()
|
|
for key in keys:
|
|
pipe.get(key)
|
|
values = await pipe.execute()
|
|
|
|
result = {}
|
|
for key, value in zip(keys, values):
|
|
job_id = key.replace(self.KEY_PREFIXES['jobs'], '')
|
|
result[job_id] = self._deserialize(value)
|
|
|
|
return result
|
|
|
|
async def delete_job(self, job_id: str):
|
|
"""Delete job"""
|
|
key = f"{self.KEY_PREFIXES['jobs']}{job_id}"
|
|
await self.redis.delete(key)
|
|
|
|
# Job Applications operations
|
|
async def get_job_application(self, application_id: str) -> Optional[Dict]:
|
|
"""Get job application by ID"""
|
|
key = f"{self.KEY_PREFIXES['job_applications']}{application_id}"
|
|
data = await self.redis.get(key)
|
|
return self._deserialize(data) if data else None
|
|
|
|
async def set_job_application(self, application_id: str, application_data: Dict):
|
|
"""Set job application data"""
|
|
key = f"{self.KEY_PREFIXES['job_applications']}{application_id}"
|
|
await self.redis.set(key, self._serialize(application_data))
|
|
|
|
async def get_all_job_applications(self) -> Dict[str, Any]:
|
|
"""Get all job applications"""
|
|
pattern = f"{self.KEY_PREFIXES['job_applications']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
pipe = self.redis.pipeline()
|
|
for key in keys:
|
|
pipe.get(key)
|
|
values = await pipe.execute()
|
|
|
|
result = {}
|
|
for key, value in zip(keys, values):
|
|
app_id = key.replace(self.KEY_PREFIXES['job_applications'], '')
|
|
result[app_id] = self._deserialize(value)
|
|
|
|
return result
|
|
|
|
async def delete_job_application(self, application_id: str):
|
|
"""Delete job application"""
|
|
key = f"{self.KEY_PREFIXES['job_applications']}{application_id}"
|
|
await self.redis.delete(key)
|
|
|
|
# Chat Sessions operations
|
|
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
|
|
"""Get chat session by ID"""
|
|
key = f"{self.KEY_PREFIXES['chat_sessions']}{session_id}"
|
|
data = await self.redis.get(key)
|
|
return self._deserialize(data) if data else None
|
|
|
|
async def set_chat_session(self, session_id: str, session_data: Dict):
|
|
"""Set chat session data"""
|
|
key = f"{self.KEY_PREFIXES['chat_sessions']}{session_id}"
|
|
await self.redis.set(key, self._serialize(session_data))
|
|
|
|
async def get_all_chat_sessions(self) -> Dict[str, Any]:
|
|
"""Get all chat sessions"""
|
|
pattern = f"{self.KEY_PREFIXES['chat_sessions']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
pipe = self.redis.pipeline()
|
|
for key in keys:
|
|
pipe.get(key)
|
|
values = await pipe.execute()
|
|
|
|
result = {}
|
|
for key, value in zip(keys, values):
|
|
session_id = key.replace(self.KEY_PREFIXES['chat_sessions'], '')
|
|
result[session_id] = self._deserialize(value)
|
|
|
|
return result
|
|
|
|
async def delete_chat_session(self, session_id: str):
|
|
"""Delete chat session"""
|
|
key = f"{self.KEY_PREFIXES['chat_sessions']}{session_id}"
|
|
await self.redis.delete(key)
|
|
|
|
# Chat Messages operations (stored as lists)
|
|
async def get_chat_messages(self, session_id: str) -> List[Dict]:
|
|
"""Get chat messages for a session"""
|
|
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
|
messages = await self.redis.lrange(key, 0, -1)
|
|
return [self._deserialize(msg) for msg in messages if msg]
|
|
|
|
async def add_chat_message(self, session_id: str, message_data: Dict):
|
|
"""Add a chat message to a session"""
|
|
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
|
await self.redis.rpush(key, self._serialize(message_data))
|
|
|
|
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
|
|
"""Set all chat messages for a session (replaces existing)"""
|
|
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
|
|
|
# Clear existing messages
|
|
await self.redis.delete(key)
|
|
|
|
# Add new messages
|
|
if messages:
|
|
serialized_messages = [self._serialize(msg) for msg in messages]
|
|
await self.redis.rpush(key, *serialized_messages)
|
|
|
|
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
|
|
"""Get all chat messages grouped by session"""
|
|
pattern = f"{self.KEY_PREFIXES['chat_messages']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
result = {}
|
|
for key in keys:
|
|
session_id = key.replace(self.KEY_PREFIXES['chat_messages'], '')
|
|
messages = await self.redis.lrange(key, 0, -1)
|
|
result[session_id] = [self._deserialize(msg) for msg in messages if msg]
|
|
|
|
return result
|
|
|
|
async def delete_chat_messages(self, session_id: str):
|
|
"""Delete all chat messages for a session"""
|
|
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
|
await self.redis.delete(key)
|
|
|
|
# Enhanced Chat Session Methods
|
|
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]:
|
|
"""Get all chat sessions for a specific user"""
|
|
all_sessions = await self.get_all_chat_sessions()
|
|
user_sessions = []
|
|
|
|
for session_data in all_sessions.values():
|
|
if session_data.get("userId") == user_id or session_data.get("guestId") == user_id:
|
|
user_sessions.append(session_data)
|
|
|
|
# Sort by last activity (most recent first)
|
|
user_sessions.sort(key=lambda x: x.get("lastActivity", ""), reverse=True)
|
|
return user_sessions
|
|
|
|
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
|
|
"""Get all chat sessions related to a specific candidate"""
|
|
all_sessions = await self.get_all_chat_sessions()
|
|
candidate_sessions = []
|
|
|
|
for session_data in all_sessions.values():
|
|
context = session_data.get("context", {})
|
|
if (context.get("relatedEntityType") == "candidate" and
|
|
context.get("relatedEntityId") == candidate_id):
|
|
candidate_sessions.append(session_data)
|
|
|
|
# Sort by last activity (most recent first)
|
|
candidate_sessions.sort(key=lambda x: x.get("lastActivity", ""), reverse=True)
|
|
return candidate_sessions
|
|
|
|
async def update_chat_session_activity(self, session_id: str):
|
|
"""Update the last activity timestamp for a chat session"""
|
|
session_data = await self.get_chat_session(session_id)
|
|
if session_data:
|
|
session_data["lastActivity"] = datetime.now(UTC).isoformat()
|
|
await self.set_chat_session(session_id, session_data)
|
|
|
|
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
|
|
"""Get the most recent chat messages for a session"""
|
|
messages = await self.get_chat_messages(session_id)
|
|
# Return the last 'limit' messages
|
|
return messages[-limit:] if len(messages) > limit else messages
|
|
|
|
async def get_chat_message_count(self, session_id: str) -> int:
|
|
"""Get the total number of messages in a chat session"""
|
|
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
|
return await self.redis.llen(key)
|
|
|
|
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
|
|
"""Search for messages containing specific text in a session"""
|
|
messages = await self.get_chat_messages(session_id)
|
|
query_lower = query.lower()
|
|
|
|
matching_messages = []
|
|
for msg in messages:
|
|
content = msg.get("content", "").lower()
|
|
if query_lower in content:
|
|
matching_messages.append(msg)
|
|
|
|
return matching_messages
|
|
|
|
# Chat Session Management
|
|
async def archive_chat_session(self, session_id: str):
|
|
"""Archive a chat session"""
|
|
session_data = await self.get_chat_session(session_id)
|
|
if session_data:
|
|
session_data["isArchived"] = True
|
|
session_data["updatedAt"] = datetime.now(UTC).isoformat()
|
|
await self.set_chat_session(session_id, session_data)
|
|
|
|
async def delete_chat_session_completely(self, session_id: str):
|
|
"""Delete a chat session and all its messages"""
|
|
# Delete the session
|
|
await self.delete_chat_session(session_id)
|
|
# Delete all messages
|
|
await self.delete_chat_messages(session_id)
|
|
|
|
async def cleanup_old_chat_sessions(self, days_old: int = 90):
|
|
"""Archive or delete chat sessions older than specified days"""
|
|
cutoff_date = datetime.now(UTC) - timedelta(days=days_old)
|
|
cutoff_iso = cutoff_date.isoformat()
|
|
|
|
all_sessions = await self.get_all_chat_sessions()
|
|
archived_count = 0
|
|
|
|
for session_id, session_data in all_sessions.items():
|
|
last_activity = session_data.get("lastActivity", session_data.get("createdAt", ""))
|
|
|
|
if last_activity < cutoff_iso and not session_data.get("isArchived", False):
|
|
await self.archive_chat_session(session_id)
|
|
archived_count += 1
|
|
|
|
return archived_count
|
|
|
|
# Enhanced User Operations
|
|
async def get_user_by_username(self, username: str) -> Optional[Dict]:
|
|
"""Get user by username specifically"""
|
|
username_key = f"{self.KEY_PREFIXES['users']}{username.lower()}"
|
|
data = await self.redis.get(username_key)
|
|
return self._deserialize(data) if data else None
|
|
|
|
async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
|
|
"""Find candidate by username"""
|
|
all_candidates = await self.get_all_candidates()
|
|
username_lower = username.lower()
|
|
|
|
for candidate_data in all_candidates.values():
|
|
if candidate_data.get("username", "").lower() == username_lower:
|
|
return candidate_data
|
|
|
|
return None
|
|
|
|
# Analytics and Reporting
|
|
async def get_chat_statistics(self) -> Dict[str, Any]:
|
|
"""Get comprehensive chat statistics"""
|
|
all_sessions = await self.get_all_chat_sessions()
|
|
all_messages = await self.get_all_chat_messages()
|
|
|
|
stats = {
|
|
"total_sessions": len(all_sessions),
|
|
"total_messages": sum(len(messages) for messages in all_messages.values()),
|
|
"active_sessions": 0,
|
|
"archived_sessions": 0,
|
|
"sessions_by_type": {},
|
|
"sessions_with_candidates": 0,
|
|
"average_messages_per_session": 0
|
|
}
|
|
|
|
# Analyze sessions
|
|
for session_data in all_sessions.values():
|
|
if session_data.get("isArchived", False):
|
|
stats["archived_sessions"] += 1
|
|
else:
|
|
stats["active_sessions"] += 1
|
|
|
|
# Count by type
|
|
context_type = session_data.get("context", {}).get("type", "unknown")
|
|
stats["sessions_by_type"][context_type] = stats["sessions_by_type"].get(context_type, 0) + 1
|
|
|
|
# Count sessions with candidate association
|
|
if session_data.get("context", {}).get("relatedEntityType") == "candidate":
|
|
stats["sessions_with_candidates"] += 1
|
|
|
|
# Calculate averages
|
|
if stats["total_sessions"] > 0:
|
|
stats["average_messages_per_session"] = stats["total_messages"] / stats["total_sessions"]
|
|
|
|
return stats
|
|
|
|
|
|
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
|
|
"""Get a summary of chat activity for a specific candidate"""
|
|
sessions = await self.get_chat_sessions_by_candidate(candidate_id)
|
|
|
|
if not sessions:
|
|
return {
|
|
"candidate_id": candidate_id,
|
|
"total_sessions": 0,
|
|
"total_messages": 0,
|
|
"first_chat": None,
|
|
"last_chat": None
|
|
}
|
|
|
|
total_messages = 0
|
|
for session in sessions:
|
|
session_id = session.get("id")
|
|
if session_id:
|
|
message_count = await self.get_chat_message_count(session_id)
|
|
total_messages += message_count
|
|
|
|
# Sort sessions by creation date
|
|
sessions_by_date = sorted(sessions, key=lambda x: x.get("createdAt", ""))
|
|
|
|
return {
|
|
"candidate_id": candidate_id,
|
|
"total_sessions": len(sessions),
|
|
"total_messages": total_messages,
|
|
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
|
|
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
|
|
"recent_sessions": sessions[:5] # Last 5 sessions
|
|
}
|
|
|
|
# Batch Operations
|
|
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
|
|
"""Get multiple candidates by their usernames efficiently"""
|
|
all_candidates = await self.get_all_candidates()
|
|
username_set = {username.lower() for username in usernames}
|
|
|
|
result = {}
|
|
for candidate_data in all_candidates.values():
|
|
candidate_username = candidate_data.get("username", "").lower()
|
|
if candidate_username in username_set:
|
|
result[candidate_username] = candidate_data
|
|
|
|
return result
|
|
|
|
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
|
|
"""Bulk update multiple chat sessions"""
|
|
pipe = self.redis.pipeline()
|
|
|
|
for session_id, updates in session_updates.items():
|
|
session_data = await self.get_chat_session(session_id)
|
|
if session_data:
|
|
session_data.update(updates)
|
|
session_data["updatedAt"] = datetime.now(UTC).isoformat()
|
|
key = f"{self.KEY_PREFIXES['chat_sessions']}{session_id}"
|
|
pipe.set(key, self._serialize(session_data))
|
|
|
|
await pipe.execute()
|
|
|
|
|
|
# AI Parameters operations
|
|
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
|
|
"""Get AI parameters by ID"""
|
|
key = f"{self.KEY_PREFIXES['ai_parameters']}{param_id}"
|
|
data = await self.redis.get(key)
|
|
return self._deserialize(data) if data else None
|
|
|
|
async def set_ai_parameters(self, param_id: str, param_data: Dict):
|
|
"""Set AI parameters data"""
|
|
key = f"{self.KEY_PREFIXES['ai_parameters']}{param_id}"
|
|
await self.redis.set(key, self._serialize(param_data))
|
|
|
|
async def get_all_ai_parameters(self) -> Dict[str, Any]:
|
|
"""Get all AI parameters"""
|
|
pattern = f"{self.KEY_PREFIXES['ai_parameters']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
pipe = self.redis.pipeline()
|
|
for key in keys:
|
|
pipe.get(key)
|
|
values = await pipe.execute()
|
|
|
|
result = {}
|
|
for key, value in zip(keys, values):
|
|
param_id = key.replace(self.KEY_PREFIXES['ai_parameters'], '')
|
|
result[param_id] = self._deserialize(value)
|
|
|
|
return result
|
|
|
|
async def delete_ai_parameters(self, param_id: str):
|
|
"""Delete AI parameters"""
|
|
key = f"{self.KEY_PREFIXES['ai_parameters']}{param_id}"
|
|
await self.redis.delete(key)
|
|
|
|
|
|
async def get_all_users(self) -> Dict[str, Any]:
|
|
"""Get all users"""
|
|
pattern = f"{self.KEY_PREFIXES['users']}*"
|
|
keys = await self.redis.keys(pattern)
|
|
|
|
if not keys:
|
|
return {}
|
|
|
|
pipe = self.redis.pipeline()
|
|
for key in keys:
|
|
pipe.get(key)
|
|
values = await pipe.execute()
|
|
|
|
result = {}
|
|
for key, value in zip(keys, values):
|
|
email = key.replace(self.KEY_PREFIXES['users'], '')
|
|
result[email] = self._deserialize(value)
|
|
|
|
return result
|
|
|
|
async def delete_user(self, email: str):
|
|
"""Delete user"""
|
|
key = f"{self.KEY_PREFIXES['users']}{email}"
|
|
await self.redis.delete(key)
|
|
|
|
# Utility methods
|
|
async def clear_all_data(self):
|
|
"""Clear all data from Redis (use with caution!)"""
|
|
for prefix in self.KEY_PREFIXES.values():
|
|
pattern = f"{prefix}*"
|
|
keys = await self.redis.keys(pattern)
|
|
if keys:
|
|
await self.redis.delete(*keys)
|
|
|
|
async def get_stats(self) -> Dict[str, int]:
|
|
"""Get statistics about stored data"""
|
|
stats = {}
|
|
for data_type, prefix in self.KEY_PREFIXES.items():
|
|
pattern = f"{prefix}*"
|
|
keys = await self.redis.keys(pattern)
|
|
stats[data_type] = len(keys)
|
|
return stats
|
|
|
|
# Authentication Record Methods
|
|
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
|
|
"""Store authentication record for a user"""
|
|
try:
|
|
key = f"auth:{user_id}"
|
|
await self.redis.set(key, json.dumps(auth_data, default=str))
|
|
logger.debug(f"🔐 Stored authentication record for user {user_id}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"❌ Error storing authentication record for {user_id}: {e}")
|
|
return False
|
|
|
|
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
|
|
"""Retrieve authentication record for a user"""
|
|
try:
|
|
key = f"auth:{user_id}"
|
|
data = await self.redis.get(key)
|
|
if data:
|
|
return json.loads(data)
|
|
return None
|
|
except Exception as e:
|
|
logger.error(f"❌ Error retrieving authentication record for {user_id}: {e}")
|
|
return None
|
|
|
|
async def delete_authentication(self, user_id: str) -> bool:
|
|
"""Delete authentication record for a user"""
|
|
try:
|
|
key = f"auth:{user_id}"
|
|
result = await self.redis.delete(key)
|
|
logger.debug(f"🔐 Deleted authentication record for user {user_id}")
|
|
return result > 0
|
|
except Exception as e:
|
|
logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
|
|
return False
|
|
|
|
# Enhanced User Methods
|
|
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
|
|
"""Store user data with ID as key for direct lookup"""
|
|
try:
|
|
key = f"user_by_id:{user_id}"
|
|
await self.redis.set(key, json.dumps(user_data, default=str))
|
|
logger.debug(f"👤 Stored user data by ID for {user_id}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"❌ Error storing user by ID {user_id}: {e}")
|
|
return False
|
|
|
|
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
|
"""Retrieve user data by ID"""
|
|
try:
|
|
key = f"user_by_id:{user_id}"
|
|
data = await self.redis.get(key)
|
|
if data:
|
|
return json.loads(data)
|
|
return None
|
|
except Exception as e:
|
|
logger.error(f"❌ Error retrieving user by ID {user_id}: {e}")
|
|
return None
|
|
|
|
    async def user_exists_by_email(self, email: str) -> bool:
        """Check if a user exists with the given email"""
        # NOTE(review): uses the literal 'users:' prefix, but
        # KEY_PREFIXES['users'] is 'user:' (used by get_all_users/
        # delete_user) — confirm which keyspace actually holds users.
        try:
            key = f"users:{email.lower()}"
            exists = await self.redis.exists(key)
            return exists > 0
        except Exception as e:
            # Best-effort: lookup failures are reported as "does not exist".
            logger.error(f"❌ Error checking user existence by email {email}: {e}")
            return False
    async def user_exists_by_username(self, username: str) -> bool:
        """Check if a user exists with the given username"""
        # NOTE(review): uses the literal 'users:' prefix, but
        # KEY_PREFIXES['users'] is 'user:' (used by get_user_by_username) —
        # confirm which keyspace actually holds users.
        try:
            key = f"users:{username.lower()}"
            exists = await self.redis.exists(key)
            return exists > 0
        except Exception as e:
            # Best-effort: lookup failures are reported as "does not exist".
            logger.error(f"❌ Error checking user existence by username {username}: {e}")
            return False
# Enhanced user lookup method to support both email and username
|
|
    async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
        """
        Get user by email or username

        The *login* string is normalized (trimmed, lowercased) before lookup,
        so the same record is found regardless of input casing.
        NOTE(review): reads the literal 'users:' keyspace, while
        get_all_users()/delete_user()/get_user_by_username() use
        KEY_PREFIXES['users'] ('user:') — confirm the intended keyspace.
        """
        try:
            # Normalize the login string
            login = login.strip().lower()
            key = f"users:{login}"

            data = await self.redis.get(key)
            if data:
                user_data = json.loads(data)
                logger.debug(f"👤 Retrieved user data for {login}")
                return user_data

            logger.debug(f"👤 No user found for {login}")
            return None
        except Exception as e:
            logger.error(f"❌ Error retrieving user {login}: {e}")
            return None
    async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
        """
        Enhanced method to store user data by email or username
        Updated version of your existing method

        Returns True on success, False on any storage error.
        NOTE(review): writes to the literal 'users:' keyspace, while
        get_all_users()/delete_user() read KEY_PREFIXES['users'] ('user:') —
        records written here are invisible to those methods; confirm intent.
        """
        try:
            # Normalize the login string
            login = login.strip().lower()
            key = f"users:{login}"

            await self.redis.set(key, json.dumps(user_data, default=str))
            logger.debug(f"👤 Stored user data for {login}")
            return True
        except Exception as e:
            logger.error(f"❌ Error storing user {login}: {e}")
            return False
# Token Management Methods
|
|
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
|
|
"""Store refresh token for a user"""
|
|
try:
|
|
key = f"refresh_token:{token}"
|
|
token_data = {
|
|
"user_id": user_id,
|
|
"expires_at": expires_at.isoformat(),
|
|
"device": device_info.get("device", "unknown"),
|
|
"ip_address": device_info.get("ip_address", "unknown"),
|
|
"is_revoked": False,
|
|
"created_at": datetime.now(timezone.utc).isoformat()
|
|
}
|
|
|
|
# Store with expiration
|
|
ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
|
|
if ttl_seconds > 0:
|
|
await self.redis.setex(key, ttl_seconds, json.dumps(token_data, default=str))
|
|
logger.debug(f"🔐 Stored refresh token for user {user_id}")
|
|
return True
|
|
else:
|
|
logger.warning(f"⚠️ Attempted to store expired refresh token for user {user_id}")
|
|
return False
|
|
except Exception as e:
|
|
logger.error(f"❌ Error storing refresh token for {user_id}: {e}")
|
|
return False
|
|
|
|
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
|
|
"""Retrieve refresh token data"""
|
|
try:
|
|
key = f"refresh_token:{token}"
|
|
data = await self.redis.get(key)
|
|
if data:
|
|
return json.loads(data)
|
|
return None
|
|
except Exception as e:
|
|
logger.error(f"❌ Error retrieving refresh token: {e}")
|
|
return None
|
|
|
|
async def revoke_refresh_token(self, token: str) -> bool:
|
|
"""Revoke a refresh token"""
|
|
try:
|
|
key = f"refresh_token:{token}"
|
|
token_data = await self.get_refresh_token(token)
|
|
if token_data:
|
|
token_data["is_revoked"] = True
|
|
token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
|
|
await self.redis.set(key, json.dumps(token_data, default=str))
|
|
logger.debug(f"🔐 Revoked refresh token")
|
|
return True
|
|
return False
|
|
except Exception as e:
|
|
logger.error(f"❌ Error revoking refresh token: {e}")
|
|
return False
|
|
|
|
async def revoke_all_user_tokens(self, user_id: str) -> bool:
    """Revoke every active refresh token belonging to a user.

    Scans all refresh-token keys (O(total tokens)); consider maintaining a
    per-user token index if this becomes a hot path.

    Returns True when the scan completes, False on error.
    """
    try:
        pattern = "refresh_token:*"
        cursor = 0
        revoked_count = 0

        while True:
            cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)

            for key in keys:
                token_data = await self.redis.get(key)
                if token_data:
                    token_info = json.loads(token_data)
                    if token_info.get("user_id") == user_id and not token_info.get("is_revoked"):
                        token_info["is_revoked"] = True
                        token_info["revoked_at"] = datetime.now(timezone.utc).isoformat()
                        # keepttl=True preserves each key's remaining TTL; a
                        # plain SET would clear it and make revoked records
                        # immortal instead of expiring with their tokens.
                        await self.redis.set(key, json.dumps(token_info, default=str), keepttl=True)
                        revoked_count += 1

            if cursor == 0:
                break

        logger.info(f"🔐 Revoked {revoked_count} refresh tokens for user {user_id}")
        return True
    except Exception as e:
        logger.error(f"❌ Error revoking all tokens for user {user_id}: {e}")
        return False
|
|
|
|
# Password Reset Token Methods
|
|
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
    """Persist a password-reset token record keyed by the raw token.

    The key expires at `expires_at`; already-expired tokens are rejected.
    Returns True on success, False for expired tokens or Redis errors.
    """
    try:
        # Reject tokens whose lifetime has already elapsed.
        remaining_ttl = int((expires_at - datetime.now(timezone.utc)).total_seconds())
        if remaining_ttl <= 0:
            logger.warning(f"⚠️ Attempted to store expired password reset token for {email}")
            return False

        record = {
            "email": email.lower(),
            "expires_at": expires_at.isoformat(),
            "used": False,
            "created_at": datetime.now(timezone.utc).isoformat(),
        }

        # SETEX writes value and TTL atomically.
        await self.redis.setex(
            f"password_reset:{token}",
            remaining_ttl,
            json.dumps(record, default=str),
        )
        logger.debug(f"🔐 Stored password reset token for {email}")
        return True
    except Exception as e:
        logger.error(f"❌ Error storing password reset token for {email}: {e}")
        return False
|
|
|
|
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
    """Fetch the stored record for a password-reset token.

    Returns the decoded dict, or None when the token is unknown/expired
    or the lookup fails.
    """
    try:
        raw = await self.redis.get(f"password_reset:{token}")
        return json.loads(raw) if raw else None
    except Exception as e:
        logger.error(f"❌ Error retrieving password reset token: {e}")
        return None
|
|
|
|
async def mark_password_reset_token_used(self, token: str) -> bool:
    """Mark a password-reset token as consumed, preserving its expiry.

    Returns True if the token existed and was updated, False otherwise.
    """
    try:
        key = f"password_reset:{token}"
        token_data = await self.get_password_reset_token(token)
        if token_data:
            token_data["used"] = True
            token_data["used_at"] = datetime.now(timezone.utc).isoformat()
            # keepttl=True preserves the key's remaining TTL. A plain SET
            # clears it, which would keep the used-token record in Redis
            # forever instead of letting it expire on schedule.
            await self.redis.set(key, json.dumps(token_data, default=str), keepttl=True)
            logger.debug(f"🔐 Marked password reset token as used")
            return True
        return False
    except Exception as e:
        logger.error(f"❌ Error marking password reset token as used: {e}")
        return False
|
|
|
|
# User Activity and Security Logging
|
|
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
    """Append a security audit event to the user's per-day log list.

    Events are stored newest-first in a daily list capped at 100 entries
    and retained for 30 days.

    Returns True on success, False on any Redis/serialization error.
    """
    try:
        key = f"security_log:{user_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
        event_data = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "user_id": user_id,
            "event_type": event_type,
            "details": details
        }

        # Pipeline the three commands into a single round trip instead of
        # three sequential awaits; the observable result is identical.
        pipe = self.redis.pipeline()
        # Add to list (latest events first)
        pipe.lpush(key, json.dumps(event_data, default=str))
        # Keep only last 100 events per day
        pipe.ltrim(key, 0, 99)
        # Set expiration for 30 days
        pipe.expire(key, 30 * 24 * 60 * 60)
        await pipe.execute()

        logger.debug(f"🔒 Logged security event {event_type} for user {user_id}")
        return True
    except Exception as e:
        logger.error(f"❌ Error logging security event for {user_id}: {e}")
        return False
|
|
|
|
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
    """Collect up to `days` days of security events for a user, newest first.

    Returns an empty list on any Redis/decoding error.
    """
    try:
        collected: List[Dict[str, Any]] = []
        for offset in range(days):
            day = (datetime.now(timezone.utc) - timedelta(days=offset)).strftime('%Y-%m-%d')
            daily = await self.redis.lrange(f"security_log:{user_id}:{day}", 0, -1)
            collected.extend(json.loads(entry) for entry in daily)

        # Order across days by timestamp, most recent first.
        return sorted(collected, key=lambda event: event["timestamp"], reverse=True)
    except Exception as e:
        logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
        return []
|
|
|
|
# Global Redis manager instance
|
|
# Module-level singleton shared across the app; DatabaseManager drives its
# connect/disconnect lifecycle, other code uses it for direct Redis access.
redis_manager = _RedisManager()
|
|
|
|
class DatabaseManager:
    """Enhanced database manager with graceful shutdown capabilities.

    Responsibilities:
      - connect the global `redis_manager` and expose a `RedisDatabase` facade
      - track in-flight request counts so shutdown can drain them
      - optionally snapshot critical data to a local JSON file on shutdown
      - force Redis persistence (BGSAVE/SAVE) before disconnecting
    """

    def __init__(self):
        # Database facade; populated by initialize().
        self.db: Optional[RedisDatabase] = None
        self._shutdown_initiated = False
        self._active_requests = 0
        # Max seconds to wait for in-flight requests during shutdown.
        self._shutdown_timeout = int(os.getenv("SHUTDOWN_TIMEOUT", "30"))  # seconds
        self._backup_on_shutdown = os.getenv("BACKUP_ON_SHUTDOWN", "false").lower() == "true"

    async def initialize(self):
        """Connect to Redis, build the database facade, and sanity-check it.

        Returns:
            The ready RedisDatabase instance (also stored on ``self.db``).

        Raises:
            Exception: any connection/ping failure is logged and re-raised.
        """
        try:
            # Connect to Redis
            await redis_manager.connect()
            logger.info("Redis connection established")

            # Create database instance
            self.db = RedisDatabase(redis_manager.get_client())

            # Test connection and log stats
            if not redis_manager.redis:
                raise RuntimeError("Redis client not initialized")
            await redis_manager.redis.ping()
            stats = await self.db.get_stats()
            logger.info(f"Database initialized successfully. Stats: {stats}")

            return self.db

        except Exception as e:
            logger.error(f"Failed to initialize database: {e}")
            raise

    async def backup_data(self) -> Optional[str]:
        """Create a JSON backup of critical data before shutdown.

        Returns:
            The backup filename on success, or None if the database is not
            initialized or the backup fails (failure is logged, not raised).
        """
        if not self.db:
            return None

        try:
            backup_data = {
                "timestamp": datetime.now(UTC).isoformat(),
                "stats": await self.db.get_stats(),
                "users": await self.db.get_all_users(),
                # Add other critical data as needed
            }

            backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json"

            # Save to local file (you might want to save to cloud storage
            # instead). Explicit encoding keeps the dump portable instead of
            # depending on the platform default.
            with open(backup_filename, 'w', encoding="utf-8") as f:
                json.dump(backup_data, f, indent=2, default=str)

            logger.info(f"Backup created: {backup_filename}")
            return backup_filename

        except Exception as e:
            logger.error(f"Backup failed: {e}")
            return None

    async def _drain_active_requests(self):
        """Wait (up to the configured timeout) for in-flight requests to finish."""
        waited = 0
        while self._active_requests > 0 and waited < self._shutdown_timeout:
            logger.info(f"Waiting for {self._active_requests} active requests to complete...")
            await asyncio.sleep(1)
            waited += 1

        if self._active_requests > 0:
            logger.warning(f"Shutdown timeout reached. {self._active_requests} requests may be interrupted.")

    async def _persist_redis(self):
        """Ask Redis to flush its dataset to disk, preferring a background save."""
        try:
            if redis_manager.redis:
                # Try BGSAVE first (non-blocking for Redis clients)
                try:
                    await redis_manager.redis.bgsave()
                    logger.info("Background save initiated")
                    # Give the background save a moment to start before we
                    # tear down the connection.
                    await asyncio.sleep(0.5)
                except Exception as e:
                    logger.warning(f"Background save failed, trying synchronous save: {e}")
                    try:
                        # Fallback to synchronous save
                        await redis_manager.redis.save()
                        logger.info("Synchronous save completed")
                    except Exception as e2:
                        logger.warning(f"Synchronous save also failed (Redis persistence may be disabled): {e2}")
        except Exception as e:
            logger.error(f"Error during Redis save: {e}")

    async def graceful_shutdown(self):
        """Drain requests, optionally back up data, persist Redis, disconnect."""
        self._shutdown_initiated = True
        logger.info("Initiating graceful shutdown...")

        # Wait for active requests to complete (with timeout)
        await self._drain_active_requests()

        # Create backup if configured
        if self._backup_on_shutdown:
            backup_file = await self.backup_data()
            if backup_file:
                logger.info(f"Pre-shutdown backup completed: {backup_file}")

        # Force Redis to save data to disk
        await self._persist_redis()

        # Close Redis connection
        try:
            await redis_manager.disconnect()
            logger.info("Redis connection closed successfully")
        except Exception as e:
            logger.error(f"Error closing Redis connection: {e}")

        logger.info("Graceful shutdown completed")

    def increment_requests(self):
        """Track the start of a request (call from request middleware)."""
        self._active_requests += 1

    def decrement_requests(self):
        """Track request completion; the counter never drops below zero."""
        self._active_requests = max(0, self._active_requests - 1)

    @property
    def is_shutting_down(self) -> bool:
        """True once graceful_shutdown() has begun."""
        return self._shutdown_initiated

    def get_database(self) -> RedisDatabase:
        """Return the live database instance.

        Raises:
            RuntimeError: if the database is not initialized or the
                application is shutting down.
        """
        if self.db is None:
            raise RuntimeError("Database not initialized")
        if self._shutdown_initiated:
            raise RuntimeError("Application is shutting down")
        return self.db
|
|
|