Compare commits: ed2b99e8b9...e5ac267935

2 Commits: e5ac267935, cbd6ead5f3
```diff
@@ -25,7 +25,7 @@ import hashlib
 
 from .base import Agent, agent_registry, LLMMessage
 from models import ActivityType, ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
-import model_cast
+import helpers.model_cast as model_cast
 from logger import logger
 import defines
 import backstory_traceback as traceback
```
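This hunk sets the pattern repeated in several files below: the top-level `model_cast` module moves into a `helpers` package, and the `as model_cast` alias keeps every existing call site unchanged. A minimal sketch, assuming the `cast_to_model` helper shown in the deleted `model_cast.py` near the end of this diff; `user` is a hypothetical Pydantic model instance:

```python
# Before: import model_cast                       (top-level module, now deleted)
# After:  import helpers.model_cast as model_cast
import helpers.model_cast as model_cast
from models import Candidate

# Call sites keep the same spelling either way; `user` is hypothetical.
candidate = model_cast.cast_to_model(Candidate, user)
```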
```diff
@@ -27,7 +27,6 @@ from names_dataset import NameDataset, NameWrapper # type: ignore
 
 from .base import Agent, agent_registry, LLMMessage
 from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
-import model_cast
 from logger import logger
 import defines
 import backstory_traceback as traceback
```
```diff
@@ -20,7 +20,6 @@ import numpy as np # type: ignore
 
 from .base import Agent, agent_registry, LLMMessage
 from models import ApiActivityType, ApiMessage, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
-import model_cast
 from logger import logger
 import defines
 import backstory_traceback as traceback
```
```diff
@@ -22,7 +22,6 @@ from .base import Agent, agent_registry, LLMMessage
 from models import (ApiMessage, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageRagSearch,
                     ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions,
                     ChatSenderType, ApiStatusType, EvidenceDetail, SkillAssessment, Tunables)
-import model_cast
 from logger import logger
 import defines
 import backstory_traceback as traceback
```
```diff
@@ -82,10 +82,8 @@ class BackgroundTaskManager:
             logger.info("Skipping rate limit cleanup - application shutting down")
             return 0
 
-        database = self.database_manager.get_database()
-
         # Get Redis client safely (using the event loop safe method)
-        from backend.database.manager import redis_manager
+        from database.manager import redis_manager
         redis = await redis_manager.get_client()
 
         # Clean up rate limit keys older than specified days
```
```diff
@@ -103,9 +101,13 @@ class BackgroundTaskManager:
             try:
                 ttl = await redis.ttl(key)
                 if ttl == -1: # No expiration set, check creation time
-                    # For simplicity, delete keys without TTL
-                    await redis.delete(key)
-                    deleted_count += 1
+                    creation_time = await redis.hget(key, "created_at") # type: ignore
+                    if creation_time:
+                        creation_time = datetime.fromisoformat(creation_time).replace(tzinfo=UTC)
+                        if creation_time < cutoff_time:
+                            # Key is older than cutoff, delete it
+                            await redis.delete(key)
+                            deleted_count += 1
             except Exception:
                 continue
 
```
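The rewritten cleanup no longer deletes every rate-limit key that lacks a TTL; it keeps keys whose `created_at` hash field is newer than the retention cutoff. A standalone sketch of the new decision, using a hypothetical `should_delete` helper and a `days` retention window (the surrounding loop and `cutoff_time` computation are not shown in the hunk):

```python
from datetime import UTC, datetime, timedelta

async def should_delete(redis, key: str, days: int) -> bool:
    """Hypothetical distillation of the retention check above."""
    cutoff_time = datetime.now(UTC) - timedelta(days=days)
    if await redis.ttl(key) != -1:
        return False  # key already has an expiry; leave it alone
    created_at = await redis.hget(key, "created_at")
    if not created_at:
        return False  # no creation record: the new code keeps such keys
    created = datetime.fromisoformat(created_at).replace(tzinfo=UTC)
    return created < cutoff_time  # older than the retention window
```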
```diff
@@ -10,6 +10,8 @@ from .mixins.job import JobMixin
 from .mixins.skill import SkillMixin
 from .mixins.ai import AIMixin
 
+# RedisDatabase is the main class that combines all mixins for a
+# comprehensive Redis database interface.
 class RedisDatabase(
     AIMixin,
     BaseMixin,
```
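`RedisDatabase` is assembled from feature mixins via multiple inheritance; each mixin relies on attributes (such as `self.redis`) that `BaseMixin` provides at runtime. A minimal sketch of the pattern with simplified stand-in bodies, not the real implementations:

```python
class BaseMixin:
    """Owns the connection; listed last so it sits at the base of the MRO."""
    def __init__(self, redis):
        self.redis = redis

class UserMixin:
    """Feature mixin: uses self.redis, which BaseMixin supplies at runtime."""
    async def get_user_by_username(self, username: str):
        return await self.redis.get(f"users:{username.lower()}")

class RedisDatabase(UserMixin, BaseMixin):
    """Combines the mixins; method lookup follows the listed order (MRO)."""
    pass
```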
(File diff suppressed because it is too large.)
```diff
@@ -11,8 +11,11 @@ from ..constants import KEY_PREFIXES
 logger = logging.getLogger(__name__)
 
 class UserMixin(DatabaseProtocol):
-    """Mixin for user and candidate operations"""
+    """Mixin for user operations"""
 
+    # ================
+    # Guests
+    # ================
     async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
         """Store guest data with enhanced persistence"""
         try:
```
```diff
@@ -211,6 +214,56 @@ class UserMixin(DatabaseProtocol):
             logger.error(f"❌ Error getting guest statistics: {e}")
             return {}
 
+    # ================
+    # Users
+    # ================
+    async def get_user_by_username(self, username: str) -> Optional[Dict]:
+        """Get user by username specifically"""
+        username_key = f"{KEY_PREFIXES['users']}{username.lower()}"
+        data = await self.redis.get(username_key)
+        return self._deserialize(data) if data else None
+
+    async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
+        """Get the last time user's RAG data was updated (returns timezone-aware UTC)"""
+        try:
+            rag_update_key = f"user:{user_id}:rag_last_update"
+            timestamp_str = await self.redis.get(rag_update_key)
+            if timestamp_str:
+                dt = datetime.fromisoformat(timestamp_str)
+                # Ensure the datetime is timezone-aware (assume UTC if naive)
+                if dt.tzinfo is None:
+                    dt = dt.replace(tzinfo=timezone.utc)
+                else:
+                    # Convert to UTC if it's in a different timezone
+                    dt = dt.astimezone(timezone.utc)
+                return dt
+            logger.warning(f"⚠️ No RAG update time found for user {user_id}")
+            return None
+        except Exception as e:
+            logger.error(f"❌ Error getting user RAG update time: {e}")
+            return None
+
+    async def update_user_rag_timestamp(self, user_id: str) -> bool:
+        """Set the user's RAG data update time (stores as UTC ISO format)"""
+        try:
+            update_time = datetime.now(timezone.utc)
+
+            # Ensure we're storing UTC timezone-aware format
+            if update_time.tzinfo is None:
+                update_time = update_time.replace(tzinfo=timezone.utc)
+            else:
+                update_time = update_time.astimezone(timezone.utc)
+
+            rag_update_key = f"user:{user_id}:rag_last_update"
+            # Store as ISO format with timezone info
+            timestamp_str = update_time.isoformat() # This includes timezone
+            await self.redis.set(rag_update_key, timestamp_str)
+            logger.info(f"✅ User RAG update time set for user {user_id}: {timestamp_str}")
+            return True
+        except Exception as e:
+            logger.error(f"❌ Error setting user RAG update time: {e}")
+            return False
+
     async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
         """Store user data with ID as key for direct lookup"""
         try:
```
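Relocated here from the bottom of the file (see the removal hunk below), these two methods give callers a timezone-safe way to detect a stale RAG index. A hedged usage sketch; `db` is a `RedisDatabase` instance and the one-hour threshold is illustrative, not taken from this diff:

```python
from datetime import datetime, timedelta, timezone

async def rag_needs_refresh(db, user_id: str) -> bool:
    """Illustrative staleness check built on the two methods above."""
    last_update = await db.get_user_rag_update_time(user_id)
    if last_update is None:
        return True  # never indexed (or lookup failed): rebuild
    # Both datetimes are timezone-aware UTC, so subtraction is safe.
    return datetime.now(timezone.utc) - last_update > timedelta(hours=1)

# After rebuilding the index, record the new timestamp:
#     await db.update_user_rag_timestamp(user_id)
```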
```diff
@@ -281,7 +334,40 @@ class UserMixin(DatabaseProtocol):
         """Delete user"""
         key = f"{KEY_PREFIXES['users']}{email}"
         await self.redis.delete(key)
 
+    async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
+        """Get user by email or username"""
+        try:
+            login = login.strip().lower()
+            key = f"users:{login}"
+
+            data = await self.redis.get(key)
+            if data:
+                user_data = json.loads(data)
+                logger.info(f"👤 Retrieved user data for {login}")
+                return user_data
+            return None
+        except Exception as e:
+            logger.error(f"❌ Error retrieving user {login}: {e}")
+            return None
+
+    async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
+        """Store user data by email or username"""
+        try:
+            login = login.strip().lower()
+            key = f"users:{login}"
+
+            await self.redis.set(key, json.dumps(user_data, default=str))
+            logger.info(f"👤 Stored user data for {login}")
+            return True
+        except Exception as e:
+            logger.error(f"❌ Error storing user {login}: {e}")
+            return False
+
+
+    # ================
+    # Employers
+    # ================
     async def get_employer(self, employer_id: str) -> Optional[Dict]:
         """Get employer by ID"""
         key = f"{KEY_PREFIXES['employers']}{employer_id}"
```
```diff
@@ -319,36 +405,9 @@ class UserMixin(DatabaseProtocol):
         await self.redis.delete(key)
 
 
-    async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
-        """Get user by email or username"""
-        try:
-            login = login.strip().lower()
-            key = f"users:{login}"
-
-            data = await self.redis.get(key)
-            if data:
-                user_data = json.loads(data)
-                logger.info(f"👤 Retrieved user data for {login}")
-                return user_data
-            return None
-        except Exception as e:
-            logger.error(f"❌ Error retrieving user {login}: {e}")
-            return None
-
-    async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
-        """Store user data by email or username"""
-        try:
-            login = login.strip().lower()
-            key = f"users:{login}"
-
-            await self.redis.set(key, json.dumps(user_data, default=str))
-            logger.info(f"👤 Stored user data for {login}")
-            return True
-        except Exception as e:
-            logger.error(f"❌ Error storing user {login}: {e}")
-            return False
-
-    # Candidates operations
+    # ================
+    # Candidates
+    # ================
     async def get_candidate(self, candidate_id: str) -> Optional[Dict]:
         """Get candidate by ID"""
         key = f"{KEY_PREFIXES['candidates']}{candidate_id}"
```
```diff
@@ -723,13 +782,6 @@ class UserMixin(DatabaseProtocol):
             logger.error(f"❌ Critical error during batch candidate deletion: {e}")
             raise
 
-    # User Operations
-    async def get_user_by_username(self, username: str) -> Optional[Dict]:
-        """Get user by username specifically"""
-        username_key = f"{KEY_PREFIXES['users']}{username.lower()}"
-        data = await self.redis.get(username_key)
-        return self._deserialize(data) if data else None
-
     async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
         """Find candidate by username"""
         all_candidates = await self.get_all_candidates()
```
```diff
@@ -741,7 +793,6 @@ class UserMixin(DatabaseProtocol):
 
         return None
 
-    # Batch Operations
     async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
         """Get multiple candidates by their usernames efficiently"""
         all_candidates = await self.get_all_candidates()
```
```diff
@@ -787,6 +838,9 @@ class UserMixin(DatabaseProtocol):
             "recent_sessions": sessions[:5] # Last 5 sessions
         }
 
+    # ================
+    # Viewers
+    # ================
     async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
         """Get viewer by ID"""
         key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
```
```diff
@@ -824,44 +878,3 @@ class UserMixin(DatabaseProtocol):
         key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
         await self.redis.delete(key)
 
-    async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
-        """Get the last time user's RAG data was updated (returns timezone-aware UTC)"""
-        try:
-            rag_update_key = f"user:{user_id}:rag_last_update"
-            timestamp_str = await self.redis.get(rag_update_key)
-            if timestamp_str:
-                dt = datetime.fromisoformat(timestamp_str)
-                # Ensure the datetime is timezone-aware (assume UTC if naive)
-                if dt.tzinfo is None:
-                    dt = dt.replace(tzinfo=timezone.utc)
-                else:
-                    # Convert to UTC if it's in a different timezone
-                    dt = dt.astimezone(timezone.utc)
-                return dt
-            logger.warning(f"⚠️ No RAG update time found for user {user_id}")
-            return None
-        except Exception as e:
-            logger.error(f"❌ Error getting user RAG update time: {e}")
-            return None
-
-    async def update_user_rag_timestamp(self, user_id: str) -> bool:
-        """Set the user's RAG data update time (stores as UTC ISO format)"""
-        try:
-            update_time = datetime.now(timezone.utc)
-
-            # Ensure we're storing UTC timezone-aware format
-            if update_time.tzinfo is None:
-                update_time = update_time.replace(tzinfo=timezone.utc)
-            else:
-                update_time = update_time.astimezone(timezone.utc)
-
-            rag_update_key = f"user:{user_id}:rag_last_update"
-            # Store as ISO format with timezone info
-            timestamp_str = update_time.isoformat() # This includes timezone
-            await self.redis.set(rag_update_key, timestamp_str)
-            logger.info(f"✅ User RAG update time set for user {user_id}: {timestamp_str}")
-            return True
-        except Exception as e:
-            logger.error(f"❌ Error setting user RAG update time: {e}")
-            return False
-
```
```diff
@@ -60,7 +60,7 @@ from auth_utils import (
     sanitize_login_input,
     SecurityConfig
 )
-import model_cast
+import helpers.model_cast as model_cast
 import defines
 from logger import logger
 from database.manager import RedisDatabase, redis_manager, DatabaseManager
```
```diff
@@ -1,53 +0,0 @@
-from typing import Type, TypeVar
-from pydantic import BaseModel
-import copy
-
-from models import Candidate, CandidateAI, Employer, Guest, BaseUserWithType
-
-# Ensure all user models inherit from BaseUserWithType
-assert issubclass(Candidate, BaseUserWithType), "Candidate must inherit from BaseUserWithType"
-assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from BaseUserWithType"
-assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
-assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"
-
-T = TypeVar('T', bound=BaseModel)
-
-def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
-    data = {field: getattr(source, field) for field in model_cls.__fields__}
-    return model_cls(**data)
-
-def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
-    data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
-    return model_cls(**data)
-
-def cast_to_base_user_with_type(user) -> BaseUserWithType:
-    """
-    Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.
-    This is useful for FastAPI dependencies that expect a common user type.
-    """
-    if isinstance(user, BaseUserWithType):
-        return user
-    # If it's a dict, try to detect type
-    if isinstance(user, dict):
-        user_type = user.get("user_type") or user.get("type")
-        if user_type == "candidate":
-            if user.get("is_AI"):
-                return CandidateAI.model_validate(user)
-            return Candidate.model_validate(user)
-        elif user_type == "employer":
-            return Employer.model_validate(user)
-        elif user_type == "guest":
-            return Guest.model_validate(user)
-        else:
-            raise ValueError(f"Unknown user_type: {user_type}")
-    # If it's a model, check its type
-    if hasattr(user, "user_type"):
-        if getattr(user, "user_type", None) == "candidate":
-            if getattr(user, "is_AI", False):
-                return CandidateAI.model_validate(user.model_dump())
-            return Candidate.model_validate(user.model_dump())
-        elif getattr(user, "user_type", None) == "employer":
-            return Employer.model_validate(user.model_dump())
-        elif getattr(user, "user_type", None) == "guest":
-            return Guest.model_validate(user.model_dump())
-    raise TypeError(f"Cannot cast object of type {type(user)} to BaseUserWithType")
```
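Combined with the import rewrites above, this deletion reads as a relocation rather than a removal: the same helpers are now imported from `helpers.model_cast` (presumably a file added under `helpers/`, which this compare view does not show). A hedged sketch of the dispatch behavior `cast_to_base_user_with_type` preserves; `raw` is a hypothetical dict and would need the full set of `Candidate` fields for `model_validate` to succeed:

```python
from helpers.model_cast import cast_to_base_user_with_type

# Dict input dispatches on "user_type" (and "is_AI" for candidates).
raw = {"user_type": "candidate", "is_AI": False}  # plus all required Candidate fields
user = cast_to_base_user_with_type(raw)           # -> Candidate via model_validate
```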
```diff
@@ -134,11 +134,10 @@ async def get_current_user(
     # Check candidates
     candidate_data = await database.get_candidate(user_id)
     if candidate_data:
+        from helpers.model_cast import cast_to_base_user_with_type
         if candidate_data.get("is_AI"):
-            from model_cast import cast_to_base_user_with_type
             return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
         else:
-            from model_cast import cast_to_base_user_with_type
             return cast_to_base_user_with_type(Candidate.model_validate(candidate_data))
 
     # Check employers
```
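Applied, the hunk hoists the duplicated local import above the `is_AI` branch, so the candidate path of `get_current_user` reads (reconstructed from the diff above):

```python
# Check candidates
candidate_data = await database.get_candidate(user_id)
if candidate_data:
    from helpers.model_cast import cast_to_base_user_with_type
    if candidate_data.get("is_AI"):
        return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
    else:
        return cast_to_base_user_with_type(Candidate.model_validate(candidate_data))
```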