"""
Caching System for Performance Optimization

Provides multi-level caching for sessions, lobbies, and frequently accessed data.
Includes an in-memory LRU cache and an optional Redis backend for distributed caching.

Features:
- In-memory LRU cache with TTL support
- Optional Redis distributed caching
- Cache warming and prefetching
- Cache statistics and monitoring
- Automatic cache invalidation
- Async cache operations
"""
import asyncio
import functools
import hashlib
import inspect
from collections import OrderedDict
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Callable, TypeVar

from shared.logger import logger
# Generic type variable for annotations.
# NOTE(review): not referenced anywhere else in this module; confirm it is still needed.
T = TypeVar('T')
@dataclass()
class CacheEntry:
    """A single cached value together with the bookkeeping LRUCache keeps for it.

    ``expires_at`` of ``None`` means the entry never expires. ``hit_count`` and
    ``last_accessed`` start empty and are bumped by successful lookups, while
    ``size_bytes`` holds a rough size estimate used for statistics.
    """

    # The cached payload itself.
    value: Any
    # Time at which the entry was stored.
    created_at: datetime
    # Expiry deadline; None disables expiry for this entry.
    expires_at: Optional[datetime]
    # Number of successful lookups that returned this entry.
    hit_count: int = 0
    # Time of the most recent successful lookup, or None if never read.
    last_accessed: Optional[datetime] = None
    # Approximate payload size in bytes (UTF-8 length of str(value)).
    size_bytes: int = 0
class LRUCache:
    """In-memory LRU cache with per-entry TTL support.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. Expired entries are removed lazily by :meth:`get` and in bulk by
    :meth:`cleanup_expired`. Once more than ``max_size`` entries are stored,
    the least recently used ones are evicted.

    ``get`` signals a miss with ``None``, so a cached ``None`` value cannot be
    told apart from a missing key. Expiry is based on naive ``datetime.now()``
    wall-clock time.
    """

    def __init__(self, max_size: int = 1000, default_ttl_seconds: int = 300):
        """Create an empty cache.

        Args:
            max_size: Maximum number of entries kept before LRU eviction.
            default_ttl_seconds: TTL applied by :meth:`put` when no explicit
                TTL is given; a value <= 0 means entries never expire.
        """
        self.max_size = max_size
        self.default_ttl_seconds = default_ttl_seconds
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._hits = 0
        self._misses = 0
        self._evictions = 0
        self._size_bytes = 0

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or ``None`` on a miss.

        An expired entry is removed and counted as a miss. A hit increments the
        entry's ``hit_count``, stamps ``last_accessed`` and marks the entry as
        most recently used.
        """
        entry = self._cache.get(key)
        if entry is None:
            self._misses += 1
            return None

        now = datetime.now()
        if entry.expires_at is not None and now > entry.expires_at:
            self._remove_entry(key)
            self._misses += 1
            return None

        entry.hit_count += 1
        entry.last_accessed = now
        self._cache.move_to_end(key)  # end of the OrderedDict = most recently used
        self._hits += 1
        return entry.value

    def put(self, key: str, value: Any, ttl_seconds: Optional[int] = None) -> None:
        """Store *value* under *key*, replacing any existing entry.

        Args:
            key: Cache key.
            value: Value to store.
            ttl_seconds: Time-to-live in seconds. ``None`` uses
                ``default_ttl_seconds``; a value <= 0 stores the entry without
                an expiry.
        """
        # Compare against None instead of using `or`, so an explicit 0 means
        # "no expiry" rather than silently falling back to the default TTL.
        ttl = self.default_ttl_seconds if ttl_seconds is None else ttl_seconds
        now = datetime.now()
        expires_at = now + timedelta(seconds=ttl) if ttl > 0 else None

        # Rough size estimate; only used for statistics.
        size_bytes = len(str(value).encode('utf-8'))

        if key in self._cache:
            self._remove_entry(key)

        self._cache[key] = CacheEntry(
            value=value,
            created_at=now,
            expires_at=expires_at,
            size_bytes=size_bytes,
        )
        self._size_bytes += size_bytes
        self._evict_if_necessary()

    def delete(self, key: str) -> bool:
        """Remove *key* from the cache; return ``True`` if it was present."""
        if key not in self._cache:
            return False
        self._remove_entry(key)
        return True

    def clear(self) -> None:
        """Drop every entry. Hit/miss/eviction counters are left untouched."""
        self._cache.clear()
        self._size_bytes = 0

    def _remove_entry(self, key: str) -> None:
        """Remove *key* if present and subtract its size from the running total."""
        entry = self._cache.pop(key, None)
        if entry is not None:
            self._size_bytes -= entry.size_bytes

    def _evict_if_necessary(self) -> None:
        """Evict least recently used entries until at most ``max_size`` remain."""
        # The emptiness check ends the loop for a negative max_size instead of
        # failing when the dict runs out of entries.
        while self._cache and len(self._cache) > self.max_size:
            _, evicted = self._cache.popitem(last=False)  # front = least recently used
            self._size_bytes -= evicted.size_bytes
            self._evictions += 1

    def cleanup_expired(self) -> int:
        """Remove all expired entries and return how many were removed."""
        now = datetime.now()
        expired_keys = [
            key for key, entry in self._cache.items()
            if entry.expires_at is not None and now > entry.expires_at
        ]
        for key in expired_keys:
            self._remove_entry(key)
        return len(expired_keys)

    def get_stats(self) -> Dict[str, Any]:
        """Return hit/miss counters, hit rate (percent) and size figures."""
        total_requests = self._hits + self._misses
        entry_count = len(self._cache)
        return {
            'hits': self._hits,
            'misses': self._misses,
            'hit_rate_percent': (self._hits / total_requests * 100) if total_requests else 0,
            'evictions': self._evictions,
            'entries': entry_count,
            'max_size': self.max_size,
            'size_bytes': self._size_bytes,
            'avg_entry_size': self._size_bytes / entry_count if entry_count else 0,
        }
class AsyncCache:
    """Async wrapper around an :class:`LRUCache` with background expiry cleanup.

    ``get``/``put``/``delete`` call the synchronous backend directly and do not
    await anything themselves. ``start_cleanup`` spawns a task that calls
    ``backend.cleanup_expired()`` every ``_cleanup_interval`` seconds until
    ``stop_cleanup`` is awaited.
    """

    def __init__(self, backend: LRUCache):
        """Wrap *backend*; no cleanup task runs until :meth:`start_cleanup`."""
        self.backend = backend
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_interval = 60  # seconds between cleanup passes
        self._running = False

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or ``None`` on a miss."""
        return self.backend.get(key)

    async def put(self, key: str, value: Any, ttl_seconds: Optional[int] = None) -> None:
        """Store *value* under *key*; ``ttl_seconds=None`` uses the backend default."""
        self.backend.put(key, value, ttl_seconds)

    async def delete(self, key: str) -> bool:
        """Remove *key*; return ``True`` if it was present."""
        return self.backend.delete(key)

    async def get_or_compute(self, key: str, compute_func: Callable[[], Any],
                             ttl_seconds: Optional[int] = None) -> Any:
        """Return the cached value for *key*, computing and caching it on a miss.

        Args:
            key: Cache key.
            compute_func: Zero-argument callable. It may be a plain function,
                a coroutine function, or any callable returning an awaitable;
                awaitable results are awaited before being cached.
            ttl_seconds: TTL for the stored value; ``None`` uses the backend default.

        Returns:
            The cached or freshly computed value. Because a miss is signalled by
            ``None``, a computed ``None`` is stored but recomputed on every call.
        """
        value = await self.get(key)
        if value is not None:
            return value

        computed_value = compute_func()
        # Inspect the result rather than the callable: lambdas, partials and
        # decorated wrappers can return a coroutine without being recognised as
        # coroutine functions, and caching the bare coroutine object breaks
        # every later hit.
        if inspect.isawaitable(computed_value):
            computed_value = await computed_value

        await self.put(key, computed_value, ttl_seconds)
        return computed_value

    async def start_cleanup(self):
        """Start the background cleanup task; a no-op if it is already running."""
        if self._running:
            return

        self._running = True
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("Cache cleanup started")

    async def stop_cleanup(self):
        """Cancel the background cleanup task and wait for it to finish."""
        self._running = False
        task, self._cleanup_task = self._cleanup_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        logger.info("Cache cleanup stopped")

    async def _cleanup_loop(self):
        """Purge expired entries periodically until stopped or cancelled.

        A backend error is logged and the loop retries after 5 seconds, so a
        single failure does not end the task.
        """
        while self._running:
            try:
                expired_count = self.backend.cleanup_expired()
                if expired_count > 0:
                    logger.debug(f"Cleaned up {expired_count} expired cache entries")

                await asyncio.sleep(self._cleanup_interval)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in cache cleanup: {e}")
                await asyncio.sleep(5.0)

    def get_stats(self) -> Dict[str, Any]:
        """Return the backend's statistics dictionary."""
        return self.backend.get_stats()
class CacheManager:
    """Facade over a fixed set of named :class:`AsyncCache` instances.

    There is one cache per data kind (sessions, lobbies, users, bot responses
    and computed values), each with its own size limit and default TTL. The
    helper methods namespace keys with prefixes such as ``session:`` and
    ``lobby:``.
    """

    def __init__(self):
        # name -> (max_size, default_ttl_seconds). Insertion order is the order
        # in which caches are started, stopped and reported.
        limits = {
            'session': (500, 300),
            'lobby': (200, 600),
            'user': (1000, 1800),
            'message': (2000, 60),
            # Cache for computed values (e.g., aggregations)
            'computed': (100, 120),
        }
        self._caches = {
            name: AsyncCache(LRUCache(max_size=size, default_ttl_seconds=ttl))
            for name, (size, ttl) in limits.items()
        }
        self.session_cache = self._caches['session']
        self.lobby_cache = self._caches['lobby']
        self.user_cache = self._caches['user']
        self.message_cache = self._caches['message']
        self.computed_cache = self._caches['computed']

    async def start_all(self):
        """Start the background expiry cleanup of every managed cache, one by one."""
        for managed in self._caches.values():
            await managed.start_cleanup()
        logger.info("All cache managers started")

    async def stop_all(self):
        """Stop the background expiry cleanup of every managed cache, one by one."""
        for managed in self._caches.values():
            await managed.stop_cleanup()
        logger.info("All cache managers stopped")

    # ---- Sessions -------------------------------------------------------

    async def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return cached data for *session_id*, or ``None`` if absent or expired."""
        key = f"session:{session_id}"
        return await self.session_cache.get(key)

    async def cache_session(self, session_id: str, session_data: Dict[str, Any],
                            ttl_seconds: int = 300) -> None:
        """Store *session_data* for *session_id* for ``ttl_seconds`` seconds."""
        key = f"session:{session_id}"
        await self.session_cache.put(key, session_data, ttl_seconds)

    async def invalidate_session(self, session_id: str) -> None:
        """Drop any cached data for *session_id*."""
        key = f"session:{session_id}"
        await self.session_cache.delete(key)

    # ---- Lobbies --------------------------------------------------------

    async def get_lobby(self, lobby_id: str) -> Optional[Dict[str, Any]]:
        """Return cached data for *lobby_id*, or ``None`` if absent or expired."""
        key = f"lobby:{lobby_id}"
        return await self.lobby_cache.get(key)

    async def cache_lobby(self, lobby_id: str, lobby_data: Dict[str, Any],
                          ttl_seconds: int = 600) -> None:
        """Store *lobby_data* for *lobby_id* for ``ttl_seconds`` seconds."""
        key = f"lobby:{lobby_id}"
        await self.lobby_cache.put(key, lobby_data, ttl_seconds)

    async def invalidate_lobby(self, lobby_id: str) -> None:
        """Drop any cached data for *lobby_id*."""
        key = f"lobby:{lobby_id}"
        await self.lobby_cache.delete(key)

    # ---- Bot responses --------------------------------------------------

    @staticmethod
    def _message_hash(message: str) -> str:
        """Return the MD5 hex digest (of the UTF-8 bytes) used to key bot responses."""
        # MD5 only turns arbitrary-length messages into a fixed-size cache key.
        return hashlib.md5(message.encode()).hexdigest()

    async def get_cached_response(self, message_hash: str) -> Optional[str]:
        """Return the cached bot response for *message_hash*, or ``None``.

        Unlike :meth:`cache_response`, this takes the already-computed hash:
        the MD5 hex digest of the UTF-8 encoded message.
        """
        return await self.message_cache.get(f"response:{message_hash}")

    async def cache_response(self, message: str, response: str, ttl_seconds: int = 60) -> None:
        """Cache *response* under the MD5 hash of the raw *message*."""
        digest = self._message_hash(message)
        await self.message_cache.put(f"response:{digest}", response, ttl_seconds)

    # ---- Computed values ------------------------------------------------

    async def get_computed(self, key: str) -> Optional[Any]:
        """Return the computed value cached under *key*, or ``None``."""
        return await self.computed_cache.get(f"computed:{key}")

    async def cache_computed(self, key: str, value: Any, ttl_seconds: int = 120) -> None:
        """Store a computed *value* under *key* for ``ttl_seconds`` seconds."""
        await self.computed_cache.put(f"computed:{key}", value, ttl_seconds)

    async def get_or_compute_lobby_stats(self, lobby_id: str,
                                         compute_func: Callable) -> Dict[str, Any]:
        """Return cached stats for *lobby_id*, running *compute_func* on a miss.

        Results are kept for 5 minutes in the computed-values cache under the
        key ``lobby_stats:<lobby_id>``. That key has no ``computed:`` prefix, so
        these entries are not visible through :meth:`get_computed`.
        """
        stats_key = f"lobby_stats:{lobby_id}"
        return await self.computed_cache.get_or_compute(stats_key, compute_func, ttl_seconds=300)

    def get_all_stats(self) -> Dict[str, Any]:
        """Return ``{cache_name: stats}`` for every managed cache."""
        stats: Dict[str, Any] = {}
        for name, managed in self._caches.items():
            stats[name] = managed.get_stats()
        return stats

    async def warm_cache(self, session_manager, lobby_manager):
        """Pre-populate the session and lobby caches from live managers.

        Best-effort: any exception is logged and swallowed, and whatever was
        cached before the failure stays cached. ``created_at`` records when the
        entry was warmed, not when the session or lobby was created.

        Args:
            session_manager: Object with a ``sessions`` mapping of
                session id -> session.
            lobby_manager: Object with a ``lobbies`` mapping of
                lobby id -> lobby.
        """
        try:
            for session_id, session in session_manager.sessions.items():
                if hasattr(session, 'getName'):
                    session_name = session.getName()
                else:
                    session_name = 'Unknown'
                await self.cache_session(session_id, {
                    'id': session_id,
                    'name': session_name,
                    'lobby_id': getattr(session, 'lobby_id', None),
                    'created_at': datetime.now().isoformat(),
                })

            for lobby_id, lobby in lobby_manager.lobbies.items():
                if hasattr(lobby, 'sessions'):
                    session_count = len(lobby.sessions)
                else:
                    session_count = 0
                await self.cache_lobby(lobby_id, {
                    'id': lobby_id,
                    'session_count': session_count,
                    'created_at': datetime.now().isoformat(),
                })

            logger.info(f"Cache warmed: {len(session_manager.sessions)} sessions, {len(lobby_manager.lobbies)} lobbies")

        except Exception as e:
            logger.error(f"Error warming cache: {e}")
# Decorator for automatic caching
def cache_result(cache_manager: CacheManager, cache_type: str = 'computed',
                 ttl_seconds: int = 300, key_func: Optional[Callable] = None):
    """Decorator that caches a function's results in one of *cache_manager*'s caches.

    The decorated function always becomes a coroutine function and must be
    awaited, whether the wrapped function is sync or async.

    Args:
        cache_manager: Manager whose ``<cache_type>_cache`` attribute is used.
        cache_type: Cache name prefix, e.g. ``'computed'`` or ``'session'``.
        ttl_seconds: TTL passed to the cache's ``put``.
        key_func: Optional ``key_func(*args, **kwargs) -> str``. When omitted,
            the key is built from the function's module and qualified name plus
            ``str()`` of every positional and keyword argument.

    Note:
        A result of ``None`` is stored but treated as a miss on lookup, so such
        calls are recomputed every time.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                # Key on every argument. Keying on only the first few
                # positionals let calls that differ in later positionals or in
                # any keyword argument return each other's results.
                func_id = f"{func.__module__}.{getattr(func, '__qualname__', func.__name__)}"
                key_parts = [func_id]
                key_parts.extend(str(arg) for arg in args)
                key_parts.extend(f"{name}={value}" for name, value in sorted(kwargs.items()))
                cache_key = ':'.join(key_parts)

            # Resolved per call so an invalid cache_type surfaces on first use.
            cache = getattr(cache_manager, f'{cache_type}_cache')
            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            result = func(*args, **kwargs)
            # Await any awaitable result, not only ones from coroutine functions,
            # so a coroutine object is never stored in the cache.
            if inspect.isawaitable(result):
                result = await result

            await cache.put(cache_key, result, ttl_seconds)
            return result

        return wrapper

    return decorator
# Global cache manager instance (module-level shared object).
# Construction only creates the caches; their background cleanup tasks are not
# started until `await cache_manager.start_all()` is called.
cache_manager = CacheManager()