import asyncio
import weakref
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Dict, Optional

from prometheus_client import CollectorRegistry  # type: ignore
from pydantic import BaseModel, Field  # type: ignore

from models import Candidate
from agents.base import CandidateEntity
from database.manager import RedisDatabase


class EntityManager(BaseModel):
    """Manages lifecycle of CandidateEntity instances.

    Entities are cached by candidate id. Each cached entity carries a
    ``reference_count`` and a ``last_accessed`` timestamp that this manager
    updates; a background task evicts entities whose reference count is zero
    and whose last access is older than the configured TTL.
    """

    def __init__(self, default_ttl_minutes: int = 30):
        """Create an empty manager.

        Args:
            default_ttl_minutes: Minutes an unreferenced entity may stay idle
                before the periodic cleanup evicts it.
        """
        # BaseModel.__init__ sets up pydantic's per-instance internal state.
        # It has to run before any attribute is assigned on ``self``;
        # otherwise attribute assignment can fail on pydantic v2.
        super().__init__()
        self._entities: Dict[str, CandidateEntity] = {}
        self._weak_refs: Dict[str, weakref.ReferenceType] = {}
        self._ttl_minutes = default_ttl_minutes
        self._cleanup_task: Optional[asyncio.Task] = None
        self._prometheus_collector: Optional[CollectorRegistry] = None
        self._database: Optional[RedisDatabase] = None

    async def start_cleanup_task(self):
        """Start the background cleanup task if it is not already running."""
        # Also restart when a previous task object is still stored but has
        # already finished, so the manager is never left without cleanup.
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    async def stop_cleanup_task(self):
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    def initialize(
        self,
        prometheus_collector: CollectorRegistry,
        database: RedisDatabase,
    ):
        """Store the components that are passed to every new entity.

        Args:
            prometheus_collector: Registry forwarded to
                ``CandidateEntity.initialize``.
            database: Database handle forwarded to
                ``CandidateEntity.initialize``.
        """
        self._prometheus_collector = prometheus_collector
        self._database = database

    async def get_entity(self, candidate: "Candidate") -> CandidateEntity:
        """Get or create the CandidateEntity for ``candidate``.

        A cached entity has its ``reference_count`` incremented and its
        ``last_accessed`` timestamp refreshed. Otherwise a new entity is
        created, initialized, and stored with a reference count of 1.

        Args:
            candidate: The candidate whose entity is requested; its ``id`` is
                the cache key.

        Returns:
            The cached or newly created entity.

        Raises:
            ValueError: If a new entity is needed but ``initialize`` has not
                supplied both the Prometheus collector and the database.
        """
        # Look up by the candidate's id (not the ``id`` builtin) so repeated
        # requests for the same candidate reuse the cached entity.
        entity = self._entities.get(candidate.id)
        if entity is not None:
            entity.last_accessed = datetime.now()
            entity.reference_count += 1
            return entity

        if not self._prometheus_collector or not self._database:
            raise ValueError("EntityManager has not been initialized with required components.")

        entity = CandidateEntity(candidate=candidate)
        await entity.initialize(
            prometheus_collector=self._prometheus_collector,
            database=self._database,
        )

        # Store with reference tracking
        self._entities[candidate.id] = entity
        self._weak_refs[candidate.id] = weakref.ref(
            entity, self._on_entity_deleted(candidate.id)
        )
        entity.reference_count = 1
        entity.last_accessed = datetime.now()
        return entity

    async def remove_entity(self, candidate_id: str) -> bool:
        """Immediately remove and clean up a candidate entity.

        This should be called when a candidate is being deleted from the
        system.

        Args:
            candidate_id: The ID of the candidate entity to remove.

        Returns:
            True if the entity was found and removed; False if it was not
            found or if an error occurred during removal.
        """
        try:
            entity = self._entities.get(candidate_id)
            if entity is None:
                print(f"Entity {candidate_id} not found in active persistence")
                return False

            # Drop it from both tracking dictionaries before cleanup so it is
            # no longer handed out while cleanup is awaiting.
            self._entities.pop(candidate_id, None)
            self._weak_refs.pop(candidate_id, None)

            await entity.cleanup()
            print(f"Successfully removed entity {candidate_id} from active persistence")
            return True
        except Exception as e:
            # Best-effort removal: report the failure through the return
            # value rather than propagating it to the caller.
            print(f"Error removing entity {candidate_id}: {e}")
            return False

    def _on_entity_deleted(self, user_id: str):
        """Build the weakref callback that untracks ``user_id``.

        Args:
            user_id: Key under which the entity is tracked.

        Returns:
            A callable suitable as the callback argument of ``weakref.ref``.
        """
        def cleanup_callback(weak_ref):
            self._entities.pop(user_id, None)
            self._weak_refs.pop(user_id, None)
            print(f"Entity {user_id} garbage collected")

        return cleanup_callback

    async def release_entity(self, user_id: str):
        """Release one reference to the entity tracked under ``user_id``.

        The reference count never goes below zero. Unknown ids are ignored.
        """
        entity = self._entities.get(user_id)
        if entity is not None:
            entity.reference_count = max(0, entity.reference_count - 1)
            entity.last_accessed = datetime.now()

    async def _periodic_cleanup(self):
        """Background loop that evicts expired entities once a minute."""
        while True:
            try:
                await asyncio.sleep(60)  # Check every minute
                await self._cleanup_expired_entities()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; a failed pass is retried next minute.
                print(f"Error in cleanup task: {e}")

    async def _cleanup_expired_entities(self):
        """Remove entities that exceeded the TTL and have no references."""
        current_time = datetime.now()
        ttl = timedelta(minutes=self._ttl_minutes)

        # Snapshot the ids first so the dict is not mutated while iterating.
        expired_entities = [
            user_id
            for user_id, entity in list(self._entities.items())
            if current_time - entity.last_accessed > ttl
            and entity.reference_count == 0
        ]

        for user_id in expired_entities:
            entity = self._entities.pop(user_id, None)
            self._weak_refs.pop(user_id, None)
            if entity:
                await entity.cleanup()
                print(f"Cleaned up expired entity {user_id}")


# Global entity manager instance
entity_manager = EntityManager(default_ttl_minutes=30)


@asynccontextmanager
async def get_candidate_entity(candidate: Candidate):
    """Context manager for entity access with automatic reference release.

    Acquires the entity for ``candidate`` from the global ``entity_manager``
    and releases the reference on exit, including when the body raises.

    Args:
        candidate: The candidate whose entity is requested.

    Yields:
        The CandidateEntity for ``candidate``.

    Raises:
        ValueError: If the global manager has no Prometheus collector set.
    """
    if not entity_manager._prometheus_collector:
        raise ValueError("EntityManager has not been initialized with a Prometheus collector.")

    entity = await entity_manager.get_entity(candidate=candidate)
    try:
        yield entity
    finally:
        await entity_manager.release_entity(candidate.id)


EntityManager.model_rebuild()