ai-voicebot/server/core/session_manager.py

543 lines
20 KiB
Python

"""
Session management for the AI Voice Bot server.
This module handles session lifecycle, persistence, and cleanup operations.
Extracted from main.py to improve maintainability and separation of concerns.
"""
from __future__ import annotations
import json
import os
import time
import threading
import secrets
import asyncio
from typing import Optional, List, Dict, Any
from fastapi import WebSocket
from pydantic import ValidationError
# Import shared models
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from shared.models import SessionSaved, LobbySaved, SessionsPayload, NamePasswordRecord
from logger import logger
# Import the event bus, trying the package-relative path first and the
# absolute path second. If neither resolves (standalone testing), install
# inert stand-ins so the rest of this module keeps working.
try:
    from ..models.events import event_bus, SessionDisconnected
except ImportError:
    try:
        from models.events import event_bus, SessionDisconnected
    except ImportError:
        # Create dummy event system for standalone testing
        class DummyEventBus:
            """No-op event bus used when ``models.events`` cannot be imported."""

            async def publish(self, event):
                """Accept *event* and discard it."""
                pass

        event_bus = DummyEventBus()

        class SessionDisconnected:
            """Stand-in event that stores whatever keyword fields it is given.

            ``SessionManager.remove_session`` builds this event with
            ``session_id``, ``session_name`` and ``lobby_ids`` keyword
            arguments, so the placeholder has to accept keyword arguments
            rather than being an empty class (which would raise ``TypeError``).
            """

            def __init__(self, **fields):
                # Expose each keyword as an attribute, e.g. ``evt.session_id``.
                self.__dict__.update(fields)
def _env_int(variable: str, fallback: str) -> int:
    """Return environment variable *variable* parsed as an int.

    *fallback* is the string used when the variable is not set.
    """
    return int(os.getenv(variable, fallback))


class SessionConfig:
    """Configuration class for session management.

    Every setting is read once, at class-definition time, from the
    environment variable of the same name and converted with ``int``.
    """

    # Age in seconds after which an unnamed, unconnected session is removed.
    ANONYMOUS_SESSION_TIMEOUT = _env_int("ANONYMOUS_SESSION_TIMEOUT", "60")  # 1 minute

    # Idle time in seconds after which a displaced, unconnected session is removed.
    DISPLACED_SESSION_TIMEOUT = _env_int("DISPLACED_SESSION_TIMEOUT", "10800")  # 3 hours

    # Seconds the periodic cleanup task sleeps between passes.
    CLEANUP_INTERVAL = _env_int("CLEANUP_INTERVAL", "300")  # 5 minutes

    # Upper bound on sessions collected for removal in one cleanup pass.
    MAX_SESSIONS_PER_CLEANUP = _env_int("MAX_SESSIONS_PER_CLEANUP", "100")  # Circuit breaker

    # Seconds the periodic validation task sleeps between passes.
    SESSION_VALIDATION_INTERVAL = _env_int("SESSION_VALIDATION_INTERVAL", "1800")  # 30 minutes
class Session:
"""Individual session representing a user or bot connection"""
def __init__(self, id: str, is_bot: bool = False, has_media: bool = True):
logger.info(
f"Instantiating new session {id} (bot: {is_bot}, media: {has_media})"
)
self.id = id
self.short = id[:8]
self.name = ""
self.lobbies: List[Any] = [] # List of lobby objects this session is in
self.lobby_peers: Dict[str, List[str]] = {} # lobby ID -> list of peer session IDs
self.ws: Optional[WebSocket] = None
self.created_at = time.time()
self.last_used = time.time()
self.displaced_at: Optional[float] = None # When name was taken over
self.is_bot = is_bot # Whether this session represents a bot
self.has_media = has_media # Whether this session provides audio/video streams
self.bot_run_id: Optional[str] = None # Bot run ID for tracking
self.bot_provider_id: Optional[str] = None # Bot provider ID
self.session_lock = threading.RLock() # Instance-level lock
def getName(self) -> str:
with self.session_lock:
return f"{self.short}:{self.name if self.name else '[ ---- ]'}"
def setName(self, name: str):
with self.session_lock:
old_name = self.name
self.name = name
self.update_last_used()
# Get lobby IDs for event
lobby_ids = [lobby.id for lobby in self.lobbies]
# Publish name change event (don't await here to avoid blocking)
from ..models.events import UserNameChanged
asyncio.create_task(event_bus.publish(UserNameChanged(
session_id=self.id,
old_name=old_name,
new_name=name,
lobby_ids=lobby_ids
)))
def update_last_used(self):
"""Update the last_used timestamp"""
with self.session_lock:
self.last_used = time.time()
def mark_displaced(self):
"""Mark this session as having its name taken over"""
with self.session_lock:
self.displaced_at = time.time()
async def join_lobby(self, lobby):
"""Join a lobby and update peers"""
with self.session_lock:
if lobby not in self.lobbies:
self.lobbies.append(lobby)
await lobby.addSession(self)
# Publish join event
from ..models.events import SessionJoinedLobby
await event_bus.publish(SessionJoinedLobby(
session_id=self.id,
lobby_id=lobby.id,
session_name=self.name or self.short
))
async def leave_lobby(self, lobby):
"""Leave a lobby and clean up peers"""
with self.session_lock:
if lobby in self.lobbies:
self.lobbies.remove(lobby)
if lobby.id in self.lobby_peers:
del self.lobby_peers[lobby.id]
await lobby.removeSession(self)
# Publish leave event
from ..models.events import SessionLeftLobby
await event_bus.publish(SessionLeftLobby(
session_id=self.id,
lobby_id=lobby.id,
session_name=self.name or self.short
))
def to_saved(self) -> SessionSaved:
"""Convert session to saved format for persistence"""
with self.session_lock:
lobbies_list: List[LobbySaved] = [
LobbySaved(
id=lobby.id, name=lobby.name, private=lobby.private
)
for lobby in self.lobbies
]
return SessionSaved(
id=self.id,
name=self.name or "",
lobbies=lobbies_list,
created_at=self.created_at,
last_used=self.last_used,
displaced_at=self.displaced_at,
is_bot=self.is_bot,
has_media=self.has_media,
bot_run_id=self.bot_run_id,
bot_provider_id=self.bot_provider_id,
)
class SessionManager:
    """Manages all sessions and their lifecycle.

    Holds the in-memory session list, persists it to a JSON file, and runs
    the periodic cleanup and integrity-validation tasks. ``self.lock`` (a
    re-entrant thread lock) guards the session list and is never held across
    an ``await``.
    """

    def __init__(self, save_file: str = "sessions.json"):
        """Create a manager that persists sessions to *save_file*."""
        self._instances: List[Session] = []
        self._save_file = save_file
        self._loaded = False
        self.lock = threading.RLock()  # Thread safety for class-level operations
        # Background task management
        self.cleanup_task_running = False
        self.cleanup_task: Optional[asyncio.Task] = None
        self.validation_task_running = False
        self.validation_task: Optional[asyncio.Task] = None
        # Strong references to fire-and-forget tasks; the event loop itself
        # only keeps weak references to tasks.
        self._pending_tasks: set = set()

    def _spawn(self, coro, description: str) -> Optional[asyncio.Task]:
        """Schedule *coro* on the running loop and keep a reference to it.

        Returns the created task, or ``None`` (after closing *coro*) when no
        event loop is running in the calling thread.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Close the coroutine so it does not trigger a
            # "coroutine was never awaited" warning.
            close = getattr(coro, "close", None)
            if callable(close):
                close()
            logger.warning(f"No running event loop; skipped {description}")
            return None
        task = loop.create_task(coro)
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)
        return task

    def _ensure_loaded(self):
        """Load persisted sessions once, on first use, under the manager lock."""
        with self.lock:
            if self._loaded:
                return
            self.load()
            logger.info(f"Loaded {len(self._instances)} sessions from disk...")
            # Marked loaded even if load() bailed out, so it is tried once.
            self._loaded = True

    def create_session(self, session_id: Optional[str] = None, is_bot: bool = False, has_media: bool = True) -> "Session":
        """Create a new session with given or generated ID.

        Persisted sessions are loaded first, so the save below does not
        overwrite a file that was never read. Returns the new session.
        """
        self._ensure_loaded()
        if not session_id:
            session_id = secrets.token_hex(16)
        session = Session(session_id, is_bot=is_bot, has_media=has_media)
        with self.lock:
            self._instances.append(session)
        self.save()
        return session

    def get_session(self, session_id: str) -> "Optional[Session]":
        """Get session by ID, loading persisted sessions on first use."""
        self._ensure_loaded()
        with self.lock:
            for candidate in self._instances:
                if candidate.id == session_id:
                    return candidate
        return None

    def get_session_by_name(self, name: str) -> "Optional[Session]":
        """Get session by name (case-insensitive); ``None`` for an empty name."""
        if not name:
            return None
        self._ensure_loaded()
        lname = name.lower()
        with self.lock:
            for s in self._instances:
                with s.session_lock:
                    if s.name and s.name.lower() == lname:
                        return s
        return None

    def is_unique_name(self, name: str) -> bool:
        """Check if a name is unique across all sessions.

        The comparison is case-insensitive. An empty name is never unique.
        """
        if not name:
            return False
        self._ensure_loaded()
        lname = name.lower()  # hoisted: invariant across the loop
        with self.lock:
            for s in self._instances:
                with s.session_lock:
                    if s.name.lower() == lname:
                        return False
        return True

    def remove_session(self, session: "Session"):
        """Remove a session from the manager.

        A ``SessionDisconnected`` event is scheduled only when the session
        was actually tracked. It is skipped (with a warning) when no event
        loop is running. The save file is not rewritten here.
        """
        with self.lock:
            if session not in self._instances:
                return
            self._instances.remove(session)
            lobby_ids = [lobby.id for lobby in session.lobbies]

        # Publish disconnect event
        self._spawn(
            event_bus.publish(SessionDisconnected(
                session_id=session.id,
                session_name=session.name or session.short,
                lobby_ids=lobby_ids
            )),
            f"disconnect event for session {session.short}",
        )

    def save(self):
        """Save all sessions to disk.

        The payload is written to ``<save_file>.tmp`` and then moved over the
        save file. Failures are logged, not raised.
        """
        temp_file = self._save_file + ".tmp"
        try:
            with self.lock:
                sessions_list: List[SessionSaved] = [
                    s.to_saved() for s in self._instances
                ]
                # Note: We'll need to handle name_passwords separately or inject it
                # For now, create empty dict - this will be handled by AuthManager
                # NOTE(review): this writes an empty name_passwords map on
                # every save; confirm AuthManager persists passwords elsewhere.
                saved_pw: Dict[str, NamePasswordRecord] = {}
                payload_model = SessionsPayload(
                    sessions=sessions_list, name_passwords=saved_pw
                )
                payload = payload_model.model_dump()
                # Atomic write using temp file
                with open(temp_file, "w", encoding="utf-8") as f:
                    json.dump(payload, f, indent=2)
                # os.replace overwrites an existing destination on every
                # platform; os.rename raises on Windows when it exists.
                os.replace(temp_file, self._save_file)
            logger.info(
                f"Saved {len(sessions_list)} sessions to {self._save_file}"
            )
        except Exception as e:
            logger.error(f"Failed to save sessions: {e}")
            # Clean up temp file if it exists
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
            except OSError as cleanup_error:
                logger.warning(
                    f"Could not remove temp file {temp_file}: {cleanup_error}"
                )

    def load(self):
        """Load sessions from disk.

        Sessions that already meet the cleanup removal criteria are dropped,
        and IDs already present in memory are skipped, so calling this more
        than once does not create duplicates. Read and validation errors are
        logged and leave the in-memory state untouched.
        """
        if not os.path.exists(self._save_file):
            logger.info(f"No session save file found: {self._save_file}")
            return
        try:
            with open(self._save_file, "r", encoding="utf-8") as f:
                raw = json.load(f)
        except Exception as e:
            logger.error(f"Failed to read session save file: {e}")
            return
        try:
            payload = SessionsPayload.model_validate(raw)
        except ValidationError as e:
            logger.exception(f"Failed to validate sessions payload: {e}")
            return

        current_time = time.time()
        sessions_loaded = 0
        sessions_expired = 0
        with self.lock:
            known_ids = {existing.id for existing in self._instances}
            for s_saved in payload.sessions:
                if s_saved.id in known_ids:
                    # Already in memory (created before load, or reloaded).
                    continue
                # Check if this session should be expired during loading
                created_at = getattr(s_saved, "created_at", current_time)
                last_used = getattr(s_saved, "last_used", current_time)
                displaced_at = getattr(s_saved, "displaced_at", None)
                name = s_saved.name or ""
                # Apply same removal criteria as cleanup_old_sessions;
                # a freshly loaded session has no WebSocket, hence ws=None.
                should_expire = self._should_remove_session_static(
                    name, None, created_at, last_used, displaced_at, current_time
                )
                if should_expire:
                    sessions_expired += 1
                    logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load")
                    continue
                session = Session(
                    s_saved.id,
                    is_bot=getattr(s_saved, "is_bot", False),
                    has_media=getattr(s_saved, "has_media", True),
                )
                session.name = name
                session.created_at = created_at
                session.last_used = last_used
                session.displaced_at = displaced_at
                session.bot_run_id = getattr(s_saved, "bot_run_id", None)
                session.bot_provider_id = getattr(s_saved, "bot_provider_id", None)
                # Note: Lobby restoration will be handled by LobbyManager
                self._instances.append(session)
                known_ids.add(session.id)
                sessions_loaded += 1
            # A successful explicit load makes the lazy load unnecessary.
            self._loaded = True

        logger.info(f"Loaded {sessions_loaded} sessions from {self._save_file}")
        if sessions_expired > 0:
            logger.info(f"Expired {sessions_expired} old sessions during load")
            self.save()

    @staticmethod
    def _should_remove_session_static(
        name: str,
        ws: "Optional[WebSocket]",
        created_at: float,
        last_used: float,
        displaced_at: Optional[float],
        current_time: float,
    ) -> bool:
        """Static method to determine if a session should be removed.

        A session with an active WebSocket is never removed. Otherwise it is
        removed when it is unnamed and older than
        ``ANONYMOUS_SESSION_TIMEOUT``, or when it was displaced and has been
        idle longer than ``DISPLACED_SESSION_TIMEOUT``.
        """
        if ws:
            return False
        # Rule 1: Delete sessions with no active connection and no name that are older than threshold
        if (
            not name
            and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT
        ):
            return True
        # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used recently
        if (
            displaced_at is not None
            and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT
        ):
            return True
        return False

    def cleanup_old_sessions(self) -> int:
        """Clean up old/stale sessions and return count of removed sessions.

        At most ``MAX_SESSIONS_PER_CLEANUP`` sessions are removed per call.
        The save file is rewritten only when something was removed.
        """
        current_time = time.time()
        removed_count = 0
        with self.lock:
            sessions_to_remove = []
            for session in self._instances:
                with session.session_lock:
                    if self._should_remove_session_static(
                        session.name,
                        session.ws,
                        session.created_at,
                        session.last_used,
                        session.displaced_at,
                        current_time,
                    ):
                        sessions_to_remove.append(session)
                if len(sessions_to_remove) >= SessionConfig.MAX_SESSIONS_PER_CLEANUP:
                    break

            # Remove sessions
            for session in sessions_to_remove:
                try:
                    # Clean up websocket if open
                    if session.ws:
                        self._spawn(
                            session.ws.close(),
                            f"WebSocket close for session {session.short}",
                        )
                    # Remove from lobbies (will be handled by lobby manager events)
                    for lobby in session.lobbies[:]:
                        self._spawn(
                            session.leave_lobby(lobby),
                            f"lobby leave for session {session.short}",
                        )
                    self._instances.remove(session)
                    removed_count += 1
                    logger.info(f"Cleaned up session {session.getName()}")
                except Exception as e:
                    logger.warning(f"Error cleaning up session {session.getName()}: {e}")

        if removed_count > 0:
            self.save()
        return removed_count

    async def start_background_tasks(self):
        """Start background cleanup and validation tasks.

        A task that is still running is left alone, so calling this twice
        does not spawn duplicate loops.
        """
        logger.info("Starting session background tasks...")
        self.cleanup_task_running = True
        self.validation_task_running = True
        if self.cleanup_task is None or self.cleanup_task.done():
            self.cleanup_task = asyncio.create_task(self._periodic_cleanup())
        if self.validation_task is None or self.validation_task.done():
            self.validation_task = asyncio.create_task(self._periodic_validation())
        logger.info("Session background tasks started")

    async def stop_background_tasks(self):
        """Stop background tasks gracefully, then close all session sockets."""
        logger.info("Shutting down session background tasks...")
        self.cleanup_task_running = False
        self.validation_task_running = False
        # Cancel tasks
        for task in (self.cleanup_task, self.validation_task):
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        self.cleanup_task = None
        self.validation_task = None
        # Clean up all sessions gracefully
        await self._cleanup_all_sessions()
        logger.info("Session background tasks stopped")

    async def _periodic_cleanup(self):
        """Background task to periodically clean up old sessions.

        Stops after ``max_consecutive_errors`` failures in a row.
        """
        cleanup_errors = 0
        max_consecutive_errors = 5
        while self.cleanup_task_running:
            try:
                removed_count = self.cleanup_old_sessions()
                if removed_count > 0:
                    logger.info(f"Periodic cleanup removed {removed_count} old sessions")
                cleanup_errors = 0  # Reset error counter on success
                # Run cleanup at configured interval
                await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL)
            except Exception as e:
                cleanup_errors += 1
                logger.error(
                    f"Error in session cleanup task (attempt {cleanup_errors}): {e}"
                )
                if cleanup_errors >= max_consecutive_errors:
                    logger.error(
                        f"Too many consecutive cleanup errors ({cleanup_errors}), stopping cleanup task"
                    )
                    # Keep the flag in step with the loop having stopped.
                    self.cleanup_task_running = False
                    break
                # Linear backoff on errors: 60s per consecutive failure, capped at 5 minutes
                await asyncio.sleep(min(60 * cleanup_errors, 300))

    async def _periodic_validation(self):
        """Background task to periodically validate session integrity"""
        while self.validation_task_running:
            try:
                issues = self.validate_session_integrity()
                if issues:
                    logger.warning(f"Session integrity issues found: {len(issues)} issues")
                    for issue in issues[:10]:  # Log first 10 issues
                        logger.warning(f"Integrity issue: {issue}")
                await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL)
            except Exception as e:
                logger.error(f"Error in session validation task: {e}")
                await asyncio.sleep(300)  # Wait 5 minutes before retrying on error

    def validate_session_integrity(self) -> List[str]:
        """Validate session integrity and return list of issues.

        Reports empty IDs, creation or last-used timestamps in the future,
        and names shared (case-insensitively) by more than one session. A
        duplicate name is reported once per session that carries it.
        """
        from collections import Counter

        issues: List[str] = []
        with self.lock:
            now = time.time()
            # Count names once up front instead of rescanning per session.
            name_counts = Counter(
                s.name.lower() for s in self._instances if s.name
            )
            for session in self._instances:
                with session.session_lock:
                    # Check for sessions with invalid state
                    if not session.id:
                        issues.append(f"Session with empty ID: {session}")
                    if session.created_at > now:
                        issues.append(f"Session {session.getName()} has future creation time")
                    if session.last_used > now:
                        issues.append(f"Session {session.getName()} has future last_used time")
                    # Check for duplicate names
                    if session.name:
                        count = name_counts[session.name.lower()]
                        if count > 1:
                            issues.append(f"Duplicate name '{session.name}' found in {count} sessions")
        return issues

    async def _cleanup_all_sessions(self):
        """Clean up all sessions during shutdown by closing open WebSockets."""
        # Snapshot under the lock, then await outside it, so the thread lock
        # is not held while this coroutine is suspended.
        with self.lock:
            snapshot = list(self._instances)
        for session in snapshot:
            try:
                if session.ws:
                    await session.ws.close()
            except Exception as e:
                logger.warning(f"Error closing WebSocket for {session.getName()}: {e}")
        logger.info("All sessions cleaned up")

    def get_all_sessions(self) -> "List[Session]":
        """Get all sessions (for admin/debugging purposes); returns a copy."""
        with self.lock:
            return self._instances.copy()

    def get_session_count(self) -> int:
        """Get total session count"""
        with self.lock:
            return len(self._instances)