
"""
Caching System for Performance Optimization
Provides multi-level caching for sessions, lobbies, and frequently accessed data.
Includes in-memory LRU cache and optional Redis backend for distributed caching.
Features:
- In-memory LRU cache with TTL support
- Optional Redis distributed caching
- Cache warming and prefetching
- Cache statistics and monitoring
- Automatic cache invalidation
- Async cache operations
"""
import asyncio
import hashlib
from collections import OrderedDict
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Dict, Optional, TypeVar

from logger import logger
T = TypeVar('T')
@dataclass
class CacheEntry:
    """Cache entry with value, expiration, and metadata."""
    value: Any  # cached payload (any type)
    created_at: datetime  # when the entry was stored
    expires_at: Optional[datetime]  # None means the entry never expires
    hit_count: int = 0  # number of successful get() hits on this entry
    last_accessed: Optional[datetime] = None  # timestamp of the most recent hit
    size_bytes: int = 0  # rough size estimate (UTF-8 length of str(value))
class LRUCache:
    """In-memory LRU cache with per-entry TTL support.

    Not thread-safe; intended for single-threaded / event-loop use
    (callers wrap it in AsyncCache for background cleanup).
    """
    def __init__(self, max_size: int = 1000, default_ttl_seconds: int = 300):
        """
        Args:
            max_size: Maximum number of entries before LRU eviction.
            default_ttl_seconds: TTL applied when put() is called without an
                explicit ttl_seconds; a value <= 0 means "never expire".
        """
        self.max_size = max_size
        self.default_ttl_seconds = default_ttl_seconds
        # OrderedDict gives O(1) LRU bookkeeping via move_to_end + FIFO pop.
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._hits = 0
        self._misses = 0
        self._evictions = 0
        self._size_bytes = 0
    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None on miss or expiry.

        Note: a cached value of None is indistinguishable from a miss.
        """
        if key not in self._cache:
            self._misses += 1
            return None
        entry = self._cache[key]
        # Lazily expire stale entries on access (the background cleanup
        # handles entries that are never touched again).
        if entry.expires_at and datetime.now() > entry.expires_at:
            self._remove_entry(key)
            self._misses += 1
            return None
        # Update access metadata.
        entry.hit_count += 1
        entry.last_accessed = datetime.now()
        # Move to end = mark as most recently used.
        self._cache.move_to_end(key)
        self._hits += 1
        return entry.value
    def put(self, key: str, value: Any, ttl_seconds: Optional[int] = None) -> None:
        """Store *value* under *key*.

        Args:
            ttl_seconds: Entry lifetime in seconds. None -> use
                default_ttl_seconds; a value <= 0 -> never expire.
        """
        # BUG FIX: explicit None check so an explicit ttl_seconds=0
        # ("never expire") is honored instead of silently falling back
        # to the default TTL via `or`.
        ttl = self.default_ttl_seconds if ttl_seconds is None else ttl_seconds
        expires_at = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
        # Rough size estimate based on the string representation.
        size_bytes = len(str(value).encode('utf-8'))
        # Remove any existing entry first to keep _size_bytes accurate.
        if key in self._cache:
            self._remove_entry(key)
        entry = CacheEntry(
            value=value,
            created_at=datetime.now(),
            expires_at=expires_at,
            size_bytes=size_bytes
        )
        self._cache[key] = entry
        self._size_bytes += size_bytes
        # Evict LRU entries if we exceeded max_size.
        self._evict_if_necessary()
    def delete(self, key: str) -> bool:
        """Delete *key* from the cache. Returns True if it was present."""
        if key in self._cache:
            self._remove_entry(key)
            return True
        return False
    def clear(self) -> None:
        """Drop all entries and reset the byte-size counter."""
        self._cache.clear()
        self._size_bytes = 0
    def _remove_entry(self, key: str) -> None:
        """Remove one entry and keep the byte-size accounting in sync."""
        if key in self._cache:
            entry = self._cache.pop(key)
            self._size_bytes -= entry.size_bytes
    def _evict_if_necessary(self) -> None:
        """Evict least-recently-used entries until within max_size."""
        while len(self._cache) > self.max_size:
            # First item of the OrderedDict is the least recently used.
            oldest_key = next(iter(self._cache))
            self._remove_entry(oldest_key)
            self._evictions += 1
    def cleanup_expired(self) -> int:
        """Remove all expired entries. Returns the number removed."""
        now = datetime.now()
        # Collect keys first: can't mutate the dict while iterating it.
        expired_keys = [
            key for key, entry in self._cache.items()
            if entry.expires_at and now > entry.expires_at
        ]
        for key in expired_keys:
            self._remove_entry(key)
        return len(expired_keys)
    def get_stats(self) -> Dict[str, Any]:
        """Return hit/miss/eviction/size statistics for monitoring."""
        total_requests = self._hits + self._misses
        hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0
        return {
            'hits': self._hits,
            'misses': self._misses,
            'hit_rate_percent': hit_rate,
            'evictions': self._evictions,
            'entries': len(self._cache),
            'max_size': self.max_size,
            'size_bytes': self._size_bytes,
            'avg_entry_size': self._size_bytes / len(self._cache) if self._cache else 0
        }
class AsyncCache:
    """Async facade over an LRUCache backend with periodic expiry cleanup."""
    def __init__(self, backend: LRUCache):
        self.backend = backend
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_interval = 60  # seconds between cleanup passes
        self._running = False
    async def get(self, key: str) -> Optional[Any]:
        """Fetch a value from the backing cache."""
        return self.backend.get(key)
    async def put(self, key: str, value: Any, ttl_seconds: Optional[int] = None) -> None:
        """Store a value in the backing cache."""
        self.backend.put(key, value, ttl_seconds)
    async def delete(self, key: str) -> bool:
        """Remove a key from the backing cache; True if it existed."""
        return self.backend.delete(key)
    async def get_or_compute(self, key: str, compute_func: Callable[[], Any],
                             ttl_seconds: Optional[int] = None) -> Any:
        """Return the cached value, computing and caching it on a miss."""
        cached = await self.get(key)
        if cached is not None:
            return cached
        # Miss: run the producer (sync or async) and cache its result.
        if asyncio.iscoroutinefunction(compute_func):
            fresh = await compute_func()
        else:
            fresh = compute_func()
        await self.put(key, fresh, ttl_seconds)
        return fresh
    async def start_cleanup(self):
        """Launch the background cleanup task (idempotent)."""
        if self._running:
            return
        self._running = True
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("Cache cleanup started")
    async def stop_cleanup(self):
        """Cancel the background cleanup task and wait for it to finish."""
        self._running = False
        task = self._cleanup_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        logger.info("Cache cleanup stopped")
    async def _cleanup_loop(self):
        """Periodically purge expired backend entries until stopped."""
        while self._running:
            try:
                expired_count = self.backend.cleanup_expired()
                if expired_count:
                    logger.debug(f"Cleaned up {expired_count} expired cache entries")
                await asyncio.sleep(self._cleanup_interval)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive on unexpected errors; back off briefly.
                logger.error(f"Error in cache cleanup: {e}")
                await asyncio.sleep(5.0)
    def get_stats(self) -> Dict[str, Any]:
        """Expose the backend's statistics."""
        return self.backend.get_stats()
class CacheManager:
    """High-level cache manager for different data types."""
    def __init__(self):
        # One independently sized/TTL'd cache per data category.
        specs = {
            'session': (500, 300),
            'lobby': (200, 600),
            'user': (1000, 1800),
            'message': (2000, 60),
            'computed': (100, 120),  # aggregations and other derived values
        }
        self._caches = {
            name: AsyncCache(LRUCache(max_size=cap, default_ttl_seconds=ttl))
            for name, (cap, ttl) in specs.items()
        }
        # Named attributes for direct access by callers.
        self.session_cache = self._caches['session']
        self.lobby_cache = self._caches['lobby']
        self.user_cache = self._caches['user']
        self.message_cache = self._caches['message']
        self.computed_cache = self._caches['computed']
    async def start_all(self):
        """Start the cleanup task of every managed cache."""
        for cache in self._caches.values():
            await cache.start_cleanup()
        logger.info("All cache managers started")
    async def stop_all(self):
        """Stop the cleanup task of every managed cache."""
        for cache in self._caches.values():
            await cache.stop_cleanup()
        logger.info("All cache managers stopped")
    # --- Session caching ---
    async def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return cached session data, or None."""
        return await self.session_cache.get(f"session:{session_id}")
    async def cache_session(self, session_id: str, session_data: Dict[str, Any],
                            ttl_seconds: int = 300) -> None:
        """Store session data under a namespaced key."""
        await self.session_cache.put(f"session:{session_id}", session_data, ttl_seconds)
    async def invalidate_session(self, session_id: str) -> None:
        """Drop a session's cached data."""
        await self.session_cache.delete(f"session:{session_id}")
    # --- Lobby caching ---
    async def get_lobby(self, lobby_id: str) -> Optional[Dict[str, Any]]:
        """Return cached lobby data, or None."""
        return await self.lobby_cache.get(f"lobby:{lobby_id}")
    async def cache_lobby(self, lobby_id: str, lobby_data: Dict[str, Any],
                          ttl_seconds: int = 600) -> None:
        """Store lobby data under a namespaced key."""
        await self.lobby_cache.put(f"lobby:{lobby_id}", lobby_data, ttl_seconds)
    async def invalidate_lobby(self, lobby_id: str) -> None:
        """Drop a lobby's cached data."""
        await self.lobby_cache.delete(f"lobby:{lobby_id}")
    # --- Bot-response caching ---
    async def get_cached_response(self, message_hash: str) -> Optional[str]:
        """Look up a cached bot response by its (pre-computed) message hash."""
        return await self.message_cache.get(f"response:{message_hash}")
    async def cache_response(self, message: str, response: str, ttl_seconds: int = 60) -> None:
        """Cache a bot response keyed by the MD5 digest of the message text."""
        # MD5 is used only as a cache key (non-security use); callers of
        # get_cached_response() must compute the same digest.
        digest = hashlib.md5(message.encode()).hexdigest()
        await self.message_cache.put(f"response:{digest}", response, ttl_seconds)
    # --- Computed-value caching ---
    async def get_computed(self, key: str) -> Optional[Any]:
        """Return a cached computed value, or None."""
        return await self.computed_cache.get(f"computed:{key}")
    async def cache_computed(self, key: str, value: Any, ttl_seconds: int = 120) -> None:
        """Store a computed value under a namespaced key."""
        await self.computed_cache.put(f"computed:{key}", value, ttl_seconds)
    async def get_or_compute_lobby_stats(self, lobby_id: str,
                                         compute_func: Callable) -> Dict[str, Any]:
        """Return lobby statistics, computing them on a cache miss."""
        return await self.computed_cache.get_or_compute(
            f"lobby_stats:{lobby_id}",
            compute_func,
            ttl_seconds=300  # 5 minutes
        )
    def get_all_stats(self) -> Dict[str, Any]:
        """Return per-cache statistics keyed by cache name."""
        return {name: cache.get_stats() for name, cache in self._caches.items()}
    async def warm_cache(self, session_manager, lobby_manager):
        """Pre-populate the session and lobby caches from live managers.

        Best-effort: any failure is logged rather than raised.
        """
        try:
            for sid, sess in session_manager.sessions.items():
                await self.cache_session(sid, {
                    'id': sid,
                    'name': sess.getName() if hasattr(sess, 'getName') else 'Unknown',
                    'lobby_id': getattr(sess, 'lobby_id', None),
                    'created_at': datetime.now().isoformat()
                })
            for lid, lobby in lobby_manager.lobbies.items():
                await self.cache_lobby(lid, {
                    'id': lid,
                    'session_count': len(lobby.sessions) if hasattr(lobby, 'sessions') else 0,
                    'created_at': datetime.now().isoformat()
                })
            logger.info(f"Cache warmed: {len(session_manager.sessions)} sessions, {len(lobby_manager.lobbies)} lobbies")
        except Exception as e:
            logger.error(f"Error warming cache: {e}")
# Decorator for automatic caching
def cache_result(cache_manager: "CacheManager", cache_type: str = 'computed',
                 ttl_seconds: int = 300, key_func: Optional[Callable] = None):
    """Decorator that caches a function's results in the given CacheManager.

    Args:
        cache_manager: Manager holding the target caches. (Annotation is a
            forward reference so this def has no load-order dependency.)
        cache_type: Which cache to use: 'session', 'lobby', 'user',
            'message', or 'computed'.
        ttl_seconds: TTL applied to cached results.
        key_func: Optional callable building the cache key from the wrapped
            function's (*args, **kwargs). Defaults to "funcname:arg1:arg2:arg3".

    Note: the wrapper is always a coroutine, so callers must await the
    decorated function even if the original was synchronous. A cached
    result of None is indistinguishable from a miss and is recomputed.
    """
    def decorator(func):
        @wraps(func)  # FIX: preserve __name__/__doc__ of the wrapped function
        async def wrapper(*args, **kwargs):
            # Build the cache key (default: function name + first 3 args,
            # stringified; kwargs are deliberately ignored by the default).
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                key_parts = [func.__name__] + [str(arg) for arg in args[:3]]
                cache_key = ':'.join(key_parts)
            # Resolve the target cache by naming convention.
            cache = getattr(cache_manager, f'{cache_type}_cache')
            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                return cached_result
            # Miss: run the wrapped function (sync or async).
            if asyncio.iscoroutinefunction(func):
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)
            await cache.put(cache_key, result, ttl_seconds)
            return result
        return wrapper
    return decorator
# Global cache manager instance (module-level singleton; import this
# object rather than constructing a second CacheManager).
cache_manager = CacheManager()